"""Scan for due automation jobs and dispatch them to the bulk task queue."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Protocol
from uuid import UUID, uuid4

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from core.config.settings import config
from core.logging import get_logger
from models.agent_chat_session import AgentChatSession, SessionType
from models.automation_jobs import AutomationJob, ScheduleType
from schemas.automation.config import AutomationJobConfig
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand

logger = get_logger("core.automation.scheduler")


class _BulkQueueAdapter:
    """Queue adapter that submits commands through ``run_command_task_bulk.kiq``."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Submit ``command`` to the bulk task and return its task id as a string.

        ``dedup_key`` is accepted for interface compatibility but is ignored.
        """
        _ = dedup_key  # intentionally unused by this adapter
        # Imported at call time, presumably to avoid an import cycle or loading
        # the task runtime when this module is imported — TODO confirm.
        from core.agentscope.runtime.tasks import run_command_task_bulk

        task = await run_command_task_bulk.kiq(command)
        return str(task.task_id)


class QueueLike(Protocol):
    """Structural type for anything that can enqueue a dispatch command."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Enqueue ``command`` and return a task identifier.

        ``dedup_key`` may be used by implementations to suppress duplicates.
        """
        ...


class AutomationSchedulerRepositoryLike(Protocol):
    """Structural type for the persistence operations the scheduler needs."""

    async def list_due_jobs(
        self,
        *,
        now_utc: datetime,
        limit: int,
    ) -> list[DueAutomationJob]:
        """Return at most ``limit`` jobs that are due at ``now_utc``."""
        ...

    async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
        """Return the validated configuration of job ``job_id``."""
        ...

    async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
        """Return the id of a chat session for ``owner_id``, creating one if needed."""
        ...

    async def mark_job_dispatched(
        self,
        *,
        job_id: UUID,
        next_run_at: datetime,
        last_run_at: datetime,
    ) -> None:
        """Record that ``job_id`` ran at ``last_run_at`` and is next due at ``next_run_at``."""
        ...

    async def commit(self) -> None:
        """Commit the pending unit of work."""
        ...

    async def rollback(self) -> None:
        """Discard the pending unit of work."""
        ...
@dataclass(slots=True) class DispatchResult: scanned: int dispatched: int class AutomationSchedulerService: def __init__( self, *, repository: AutomationSchedulerRepositoryLike, queue: QueueLike, ) -> None: self._repository = repository self._queue = queue async def scan_and_dispatch( self, *, now_utc: datetime, limit: int, ) -> DispatchResult: safe_limit = max(int(limit), 1) due_jobs = await self._repository.list_due_jobs( now_utc=now_utc, limit=safe_limit ) dispatched = 0 for job in due_jobs: try: config = await self._repository.get_job_config(job_id=job.id) thread_id = await self._repository.ensure_latest_chat_session( owner_id=job.owner_id ) command = self._build_dispatch_command( job=job, thread_id=thread_id, input_text=config.input_template, now_utc=now_utc, ) await self._queue.enqueue(command=command, dedup_key=None) await self._repository.mark_job_dispatched( job_id=job.id, next_run_at=_compute_next_run_at( current_next_run_at=job.next_run_at, now_utc=now_utc, schedule_type=job.schedule_type, ), last_run_at=now_utc, ) await self._repository.commit() dispatched += 1 except Exception as exc: await self._repository.rollback() logger.exception( "automation job dispatch failed", job_id=str(job.id), owner_id=str(job.owner_id), error=str(exc), ) return DispatchResult(scanned=len(due_jobs), dispatched=dispatched) def _build_dispatch_command( self, *, job: DueAutomationJob, thread_id: UUID, input_text: str, now_utc: datetime, ) -> dict[str, object]: run_id = f"auto-{job.id}-{int(now_utc.timestamp())}" payload = SchedulerDispatchCommand( owner_id=job.owner_id, automation_job_id=job.id, thread_id=thread_id, run_id=run_id, input_text=input_text.strip(), ) return { "command": "run", "owner_id": str(payload.owner_id), "automation_job_id": str(payload.automation_job_id), "queue": "bulk", "run_input": { "threadId": str(payload.thread_id), "runId": payload.run_id, "state": {}, "messages": [ { "id": str(uuid4()), "role": "user", "content": payload.input_text, } ], "tools": 
[], "context": [], "forwardedProps": { "agent_type": "memory", }, }, } class SqlAlchemyAutomationSchedulerRepository: def __init__(self, *, session: AsyncSession) -> None: self._session = session async def list_due_jobs( self, *, now_utc: datetime, limit: int, ) -> list[DueAutomationJob]: stmt = ( select(AutomationJob) .where(AutomationJob.deleted_at.is_(None)) .where(AutomationJob.status == "active") .where(AutomationJob.next_run_at <= now_utc) .order_by(AutomationJob.next_run_at.asc()) .limit(max(limit, 1)) ) rows = (await self._session.execute(stmt)).scalars().all() return [ DueAutomationJob( id=row.id, owner_id=row.owner_id, schedule_type=row.schedule_type, timezone=row.timezone, next_run_at=row.next_run_at, ) for row in rows ] async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig: stmt = select(AutomationJob.config).where(AutomationJob.id == job_id) config_payload = (await self._session.execute(stmt)).scalar_one() return AutomationJobConfig.model_validate(config_payload or {}) async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID: stmt = ( select(AgentChatSession.id) .where(AgentChatSession.user_id == owner_id) .where(AgentChatSession.deleted_at.is_(None)) .where(AgentChatSession.session_type == SessionType.CHAT) .order_by(AgentChatSession.last_activity_at.desc()) .limit(1) ) existing = (await self._session.execute(stmt)).scalar_one_or_none() if existing is not None: return existing session = AgentChatSession( id=uuid4(), user_id=owner_id, session_type=SessionType.CHAT, ) self._session.add(session) await self._session.flush() return session.id async def mark_job_dispatched( self, *, job_id: UUID, next_run_at: datetime, last_run_at: datetime, ) -> None: stmt = select(AutomationJob).where(AutomationJob.id == job_id) row = (await self._session.execute(stmt)).scalar_one() row.next_run_at = next_run_at row.last_run_at = last_run_at await self._session.flush() async def commit(self) -> None: await self._session.commit() async def 
rollback(self) -> None: await self._session.rollback() def _compute_next_run_at( *, current_next_run_at: datetime, now_utc: datetime, schedule_type: ScheduleType, ) -> datetime: delta = timedelta(days=1 if schedule_type == ScheduleType.DAILY else 7) next_run_at = current_next_run_at while next_run_at <= now_utc: next_run_at = next_run_at + delta return next_run_at def utc_now() -> datetime: return datetime.now(timezone.utc) async def run_automation_scheduler_scan( *, limit: int | None = None, ) -> dict[str, int]: now = utc_now() safe_limit = ( max(int(limit), 1) if isinstance(limit, int) else int(config.automation_scheduler.batch_limit) ) from core.db.session import AsyncSessionLocal async with AsyncSessionLocal() as session: repository = SqlAlchemyAutomationSchedulerRepository(session=session) service = AutomationSchedulerService( repository=repository, queue=_BulkQueueAdapter(), ) result = await service.scan_and_dispatch(now_utc=now, limit=safe_limit) logger.info( "automation scheduler scan completed", scanned=result.scanned, dispatched=result.dispatched, now_utc=now.astimezone(timezone.utc).isoformat(), ) return { "scanned": int(result.scanned), "dispatched": int(result.dispatched), }