# File: social-app/backend/src/core/automation/scheduler.py
# Viewer metadata: 294 lines, 8.8 KiB, Python

from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Protocol
from uuid import UUID, uuid4
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.config.settings import config
from core.logging import get_logger
from models.agent_chat_session import AgentChatSession, SessionType
from models.automation_jobs import AutomationJob, ScheduleType
from schemas.automation.config import AutomationJobConfig
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
# Module-level logger used by the dispatch service and the scan entry point below.
logger = get_logger("core.automation.scheduler")
class _BulkQueueAdapter:
    """Queue adapter that submits commands through ``run_command_task_bulk``.

    ``dedup_key`` is accepted so the adapter matches the queue interface used
    by the scheduler service, but this adapter does not use it.
    """

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Submit ``command`` to the bulk task and return its task id as ``str``."""
        del dedup_key  # accepted but unused by this adapter

        # Imported at call time rather than at module import time.
        from core.agentscope.runtime.tasks import run_command_task_bulk

        queued = await run_command_task_bulk.kiq(command)
        return str(queued.task_id)
class QueueLike(Protocol):
    """Structural interface for a queue that accepts dispatch commands."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Enqueue ``command`` and return a string.

        Both arguments are keyword-only. The meaning of the returned string
        and the handling of ``dedup_key`` are left to the implementation.
        """
        ...
class AutomationSchedulerRepositoryLike(Protocol):
    """Structural interface for the persistence operations the scheduler calls."""

    async def list_due_jobs(
        self, *, now_utc: datetime, limit: int
    ) -> list[DueAutomationJob]:
        """Return the jobs to dispatch for a scan at ``now_utc``, given ``limit``."""
        ...

    async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
        """Return the configuration of the job identified by ``job_id``."""
        ...

    async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
        """Return a chat session id for ``owner_id``."""
        ...

    async def mark_job_dispatched(
        self, *, job_id: UUID, next_run_at: datetime, last_run_at: datetime
    ) -> None:
        """Record ``next_run_at`` and ``last_run_at`` for the job ``job_id``."""
        ...

    async def commit(self) -> None:
        """Commit pending changes."""
        ...

    async def rollback(self) -> None:
        """Roll back pending changes."""
        ...
@dataclass(slots=True)
class DispatchResult:
scanned: int
dispatched: int
class AutomationSchedulerService:
    """Dispatches due automation jobs to a queue and advances their schedule."""

    def __init__(
        self,
        *,
        repository: AutomationSchedulerRepositoryLike,
        queue: QueueLike,
    ) -> None:
        """Store the repository and queue used by :meth:`scan_and_dispatch`."""
        self._queue = queue
        self._repository = repository

    async def scan_and_dispatch(
        self,
        *,
        now_utc: datetime,
        limit: int,
    ) -> DispatchResult:
        """Dispatch the jobs that are due at ``now_utc`` and count the successes.

        For each due job the method loads its config, resolves a chat session
        id, enqueues a run command, records the next and last run times, and
        commits. If any of those steps raises, the repository is rolled back,
        the failure is logged, and the loop moves on to the next job.

        Args:
            now_utc: Scan time; also used as ``last_run_at`` and in run ids.
            limit: Maximum number of jobs to request; values below 1 become 1.

        Returns:
            A ``DispatchResult`` with the number of jobs scanned and the number
            dispatched and committed.
        """
        repo = self._repository
        batch_size = max(int(limit), 1)
        due_jobs = await repo.list_due_jobs(now_utc=now_utc, limit=batch_size)

        succeeded = 0
        for job in due_jobs:
            try:
                # Named ``job_config`` so it does not shadow the module-level
                # ``config`` settings object.
                job_config = await repo.get_job_config(job_id=job.id)
                thread_id = await repo.ensure_latest_chat_session(
                    owner_id=job.owner_id
                )
                command = self._build_dispatch_command(
                    job=job,
                    thread_id=thread_id,
                    input_text=job_config.input_template,
                    now_utc=now_utc,
                )
                # The command is enqueued before the schedule is advanced and
                # committed.
                await self._queue.enqueue(command=command, dedup_key=None)
                upcoming_run_at = _compute_next_run_at(
                    current_next_run_at=job.next_run_at,
                    now_utc=now_utc,
                    schedule_type=job.schedule_type,
                )
                await repo.mark_job_dispatched(
                    job_id=job.id,
                    next_run_at=upcoming_run_at,
                    last_run_at=now_utc,
                )
                await repo.commit()
            except Exception as exc:
                await repo.rollback()
                logger.exception(
                    "automation job dispatch failed",
                    job_id=str(job.id),
                    owner_id=str(job.owner_id),
                    error=str(exc),
                )
            else:
                succeeded += 1

        return DispatchResult(scanned=len(due_jobs), dispatched=succeeded)

    def _build_dispatch_command(
        self,
        *,
        job: DueAutomationJob,
        thread_id: UUID,
        input_text: str,
        now_utc: datetime,
    ) -> dict[str, object]:
        """Build the queue command that runs ``job`` in the given thread.

        The run id has the form ``auto-<job id>-<unix seconds of now_utc>`` and
        the input text is stripped of surrounding whitespace before use.
        """
        payload = SchedulerDispatchCommand(
            owner_id=job.owner_id,
            automation_job_id=job.id,
            thread_id=thread_id,
            run_id=f"auto-{job.id}-{int(now_utc.timestamp())}",
            input_text=input_text.strip(),
        )
        # A single user message carrying the job's input text.
        user_message = {
            "id": str(uuid4()),
            "role": "user",
            "content": payload.input_text,
        }
        run_input = {
            "threadId": str(payload.thread_id),
            "runId": payload.run_id,
            "state": {},
            "messages": [user_message],
            "tools": [],
            "context": [],
            "forwardedProps": {
                "agent_type": "memory",
            },
        }
        return {
            "command": "run",
            "owner_id": str(payload.owner_id),
            "automation_job_id": str(payload.automation_job_id),
            "queue": "bulk",
            "run_input": run_input,
        }
class SqlAlchemyAutomationSchedulerRepository:
    """Scheduler persistence operations backed by a SQLAlchemy ``AsyncSession``."""

    def __init__(self, *, session: AsyncSession) -> None:
        """Keep the session that every statement in this repository runs on."""
        self._session = session

    @staticmethod
    def _to_due_job(row: AutomationJob) -> DueAutomationJob:
        """Copy the scheduling fields of an ``AutomationJob`` row."""
        return DueAutomationJob(
            id=row.id,
            owner_id=row.owner_id,
            schedule_type=row.schedule_type,
            timezone=row.timezone,
            next_run_at=row.next_run_at,
        )

    async def list_due_jobs(
        self,
        *,
        now_utc: datetime,
        limit: int,
    ) -> list[DueAutomationJob]:
        """Return active, non-deleted jobs whose ``next_run_at`` is due.

        Jobs are ordered by ``next_run_at`` ascending and capped at ``limit``
        rows (a limit below 1 is treated as 1).
        """
        query = (
            select(AutomationJob)
            .where(
                AutomationJob.deleted_at.is_(None),
                AutomationJob.status == "active",
                AutomationJob.next_run_at <= now_utc,
            )
            .order_by(AutomationJob.next_run_at.asc())
            .limit(max(limit, 1))
        )
        result = await self._session.execute(query)
        return [self._to_due_job(row) for row in result.scalars().all()]

    async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
        """Load and validate the ``config`` column of the job ``job_id``.

        A falsy stored value is validated as an empty mapping.
        """
        query = select(AutomationJob.config).where(AutomationJob.id == job_id)
        result = await self._session.execute(query)
        raw_config = result.scalar_one()
        return AutomationJobConfig.model_validate(raw_config if raw_config else {})

    async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
        """Return the owner's most recently active chat session id.

        Only non-deleted sessions of type ``SessionType.CHAT`` are considered.
        When none exists, a new one is added and flushed, and its id returned.
        """
        query = (
            select(AgentChatSession.id)
            .where(
                AgentChatSession.user_id == owner_id,
                AgentChatSession.deleted_at.is_(None),
                AgentChatSession.session_type == SessionType.CHAT,
            )
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(1)
        )
        latest_id = (await self._session.execute(query)).scalar_one_or_none()
        if latest_id is None:
            created = AgentChatSession(
                id=uuid4(),
                user_id=owner_id,
                session_type=SessionType.CHAT,
            )
            self._session.add(created)
            await self._session.flush()
            return created.id
        return latest_id

    async def mark_job_dispatched(
        self,
        *,
        job_id: UUID,
        next_run_at: datetime,
        last_run_at: datetime,
    ) -> None:
        """Set ``next_run_at`` and ``last_run_at`` on the job and flush."""
        query = select(AutomationJob).where(AutomationJob.id == job_id)
        job_row = (await self._session.execute(query)).scalar_one()
        job_row.next_run_at = next_run_at
        job_row.last_run_at = last_run_at
        await self._session.flush()

    async def commit(self) -> None:
        """Commit the underlying session."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying session."""
        await self._session.rollback()
def _compute_next_run_at(
    *,
    current_next_run_at: datetime,
    now_utc: datetime,
    schedule_type: ScheduleType,
) -> datetime:
    """Return the first scheduled run time that is strictly after ``now_utc``.

    The schedule advances from ``current_next_run_at`` in whole intervals:
    one day for ``ScheduleType.DAILY`` and seven days for every other schedule
    type. If ``current_next_run_at`` is already after ``now_utc`` it is
    returned unchanged.
    """
    if schedule_type == ScheduleType.DAILY:
        interval = timedelta(days=1)
    else:
        interval = timedelta(days=7)

    # Step forward one interval at a time so missed runs are skipped while the
    # result stays aligned with the original schedule.
    candidate = current_next_run_at
    while candidate <= now_utc:
        candidate += interval
    return candidate
def utc_now() -> datetime:
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    return datetime.now(tz=timezone.utc)
async def run_automation_scheduler_scan(
    *,
    limit: int | None = None,
) -> dict[str, int]:
    """Run one scheduler scan and return its counters.

    Args:
        limit: Maximum number of jobs to scan. When it is an ``int`` it is
            used (raised to at least 1); otherwise
            ``config.automation_scheduler.batch_limit`` is used.

    Returns:
        A dict with the integer counters ``"scanned"`` and ``"dispatched"``.
    """
    scan_time = utc_now()
    if isinstance(limit, int):
        batch_limit = max(int(limit), 1)
    else:
        batch_limit = int(config.automation_scheduler.batch_limit)

    # Imported at call time rather than at module import time.
    from core.db.session import AsyncSessionLocal

    async with AsyncSessionLocal() as session:
        service = AutomationSchedulerService(
            repository=SqlAlchemyAutomationSchedulerRepository(session=session),
            queue=_BulkQueueAdapter(),
        )
        result = await service.scan_and_dispatch(
            now_utc=scan_time,
            limit=batch_limit,
        )

    logger.info(
        "automation scheduler scan completed",
        scanned=result.scanned,
        dispatched=result.dispatched,
        now_utc=scan_time.astimezone(timezone.utc).isoformat(),
    )
    summary = {
        "scanned": int(result.scanned),
        "dispatched": int(result.dispatched),
    }
    return summary