from __future__ import annotations from datetime import datetime, time, timedelta, timezone from typing import TYPE_CHECKING from uuid import UUID from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession from zoneinfo import ZoneInfo, ZoneInfoNotFoundError from core.db.base_repository import BaseRepository from models.agent_chat_session import AgentChatSession, SessionType from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType from schemas.automation import AutomationJobConfig, ScheduleConfig if TYPE_CHECKING: from v1.automation_jobs.schemas import ( AutomationJobCreateRequest, AutomationJobUpdateRequest, ) class AutomationJobsRepository(BaseRepository[AutomationJob]): def __init__(self, session: AsyncSession) -> None: super().__init__(session=session, model=AutomationJob) async def list_due_jobs( self, *, now_utc: datetime, limit: int, ) -> list[AutomationJob]: stmt = ( select(AutomationJob) .where(AutomationJob.deleted_at.is_(None)) .where(AutomationJob.status == AutomationJobStatus.ACTIVE) .where(AutomationJob.next_run_at <= now_utc) .order_by(AutomationJob.next_run_at.asc()) .limit(max(limit, 1)) ) rows = (await self._session.execute(stmt)).scalars().all() return list(rows) async def update_job_schedule( self, *, job_id: UUID, next_run_at: datetime, last_run_at: datetime, ) -> None: stmt = ( update(AutomationJob) .where(AutomationJob.id == job_id) .where(AutomationJob.deleted_at.is_(None)) .values(next_run_at=next_run_at, last_run_at=last_run_at) ) await self._session.execute(stmt) await self._session.flush() async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID: stmt = ( select(AgentChatSession.id) .where(AgentChatSession.user_id == owner_id) .where(AgentChatSession.deleted_at.is_(None)) .where(AgentChatSession.session_type == SessionType.CHAT) .order_by(AgentChatSession.last_activity_at.desc()) .limit(1) ) existing = (await self._session.execute(stmt)).scalar_one_or_none() if existing is not None: return existing from uuid import uuid4 new_session = AgentChatSession( id=uuid4(), user_id=owner_id, session_type=SessionType.CHAT, ) self._session.add(new_session) await self._session.flush() return new_session.id async def list_by_owner(self, owner_id: UUID) -> list[AutomationJob]: stmt = ( select(AutomationJob) .where(AutomationJob.owner_id == owner_id) .where(AutomationJob.deleted_at.is_(None)) .order_by(AutomationJob.created_at.desc()) ) rows = (await self._session.execute(stmt)).scalars().all() return list(rows) async def count_user_jobs(self, owner_id: UUID) -> int: stmt = ( select(func.count()) .select_from(AutomationJob) .where(AutomationJob.owner_id == owner_id) .where(AutomationJob.deleted_at.is_(None)) .where(AutomationJob.bootstrap_key.is_(None)) ) result = (await self._session.execute(stmt)).scalar_one() return int(result) def _resolve_timezone(self, timezone_str: str) -> ZoneInfo: try: return ZoneInfo(timezone_str) except ZoneInfoNotFoundError: return ZoneInfo("UTC") def _compute_next_run_at( self, *, schedule: ScheduleConfig, timezone_str: str, now_utc: datetime, ) -> datetime: tz = self._resolve_timezone(timezone_str) local_now = now_utc.astimezone(tz) run_clock = time( hour=schedule.run_at.hour, minute=schedule.run_at.minute, tzinfo=tz, ) if schedule.type == ScheduleType.DAILY: candidate_local = datetime.combine(local_now.date(), run_clock) if candidate_local <= local_now: candidate_local = candidate_local + timedelta(days=1) return candidate_local.astimezone(timezone.utc) weekdays = schedule.weekdays or [] if not weekdays: raise ValueError("weekly schedule requires weekdays") normalized_weekdays = sorted(set(weekdays)) for day_offset in range(0, 8): candidate_day = local_now.date() + timedelta(days=day_offset) if candidate_day.isoweekday() not in normalized_weekdays: continue candidate_local = datetime.combine(candidate_day, run_clock) if candidate_local > local_now: return candidate_local.astimezone(timezone.utc) fallback_day = local_now.date() + timedelta(days=7) while fallback_day.isoweekday() not in normalized_weekdays: fallback_day = fallback_day + timedelta(days=1) fallback_local = datetime.combine(fallback_day, run_clock) return fallback_local.astimezone(timezone.utc) async def create( self, owner_id: UUID, data: AutomationJobCreateRequest, ) -> AutomationJob: now_utc = datetime.now(tz=timezone.utc) schedule = data.config.schedule if schedule is None: raise ValueError("config.schedule is required") next_run_at = self._compute_next_run_at( schedule=schedule, timezone_str=data.timezone, now_utc=now_utc, ) new_job = AutomationJob( owner_id=owner_id, created_by=owner_id, bootstrap_key=None, title=data.title, timezone=data.timezone, status=data.status, config=data.config.model_dump(mode="json"), next_run_at=next_run_at, ) self._session.add(new_job) await self._session.flush() return new_job async def update( self, job_id: UUID, data: AutomationJobUpdateRequest, ) -> AutomationJob | None: update_values: dict[str, object] = {} existing_job = await self.get_by_id(job_id) if existing_job is None: return None if data.title is not None: update_values["title"] = data.title if data.timezone is not None: update_values["timezone"] = data.timezone if data.status is not None: update_values["status"] = data.status merged_config_raw: dict[str, object] = dict(existing_job.config or {}) if data.config is not None: merged_config_raw = { **merged_config_raw, **data.config.model_dump(mode="json", exclude_unset=True), } normalized_config = AutomationJobConfig.model_validate(merged_config_raw) update_values["config"] = normalized_config.model_dump(mode="json") else: normalized_config = AutomationJobConfig.model_validate(merged_config_raw) schedule_changed = data.config is not None and ( "schedule" in data.config.model_dump(mode="json", exclude_unset=True) ) if data.timezone is not None or schedule_changed: if normalized_config.schedule is None: raise ValueError("config.schedule is required") effective_timezone = data.timezone or existing_job.timezone update_values["next_run_at"] = self._compute_next_run_at( schedule=normalized_config.schedule, timezone_str=effective_timezone, now_utc=datetime.now(tz=timezone.utc), ) if not update_values: return existing_job return await self.update_by_id(job_id, update_values) async def soft_delete(self, job_id: UUID) -> None: await self.soft_delete_by_id(job_id)