294 lines
8.8 KiB
Python
294 lines
8.8 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Protocol
|
|
from uuid import UUID, uuid4
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from core.config.settings import config
|
|
from core.logging import get_logger
|
|
from models.agent_chat_session import AgentChatSession, SessionType
|
|
from models.automation_jobs import AutomationJob, ScheduleType
|
|
from schemas.automation.config import AutomationJobConfig
|
|
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
|
|
|
# Module logger. Call sites below pass structured context as keyword fields
# (for example ``job_id=...``) rather than formatting it into the message.
_LOGGER_NAME = "core.automation.scheduler"
logger = get_logger(_LOGGER_NAME)
|
|
|
|
|
|
class _BulkQueueAdapter:
|
|
async def enqueue(
|
|
self,
|
|
*,
|
|
command: dict[str, object],
|
|
dedup_key: str | None,
|
|
) -> str:
|
|
del dedup_key
|
|
from core.agentscope.runtime.tasks import run_command_task_bulk
|
|
|
|
result = await run_command_task_bulk.kiq(command)
|
|
return str(result.task_id)
|
|
|
|
|
|
class QueueLike(Protocol):
|
|
async def enqueue(
|
|
self,
|
|
*,
|
|
command: dict[str, object],
|
|
dedup_key: str | None,
|
|
) -> str: ...
|
|
|
|
|
|
class AutomationSchedulerRepositoryLike(Protocol):
    """Persistence operations the scheduler service depends on.

    The service drives transaction boundaries itself through ``commit`` and
    ``rollback``.
    """

    async def list_due_jobs(
        self,
        *,
        now_utc: datetime,
        limit: int,
    ) -> list[DueAutomationJob]:
        """Return at most ``limit`` jobs that are due at ``now_utc``."""

    async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
        """Load the parsed configuration of the job identified by ``job_id``."""

    async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
        """Return the id of a chat session for ``owner_id``, creating one if none exists."""

    async def mark_job_dispatched(
        self,
        *,
        job_id: UUID,
        next_run_at: datetime,
        last_run_at: datetime,
    ) -> None:
        """Record that ``job_id`` ran at ``last_run_at`` and reschedule it to ``next_run_at``."""

    async def commit(self) -> None:
        """Persist all pending changes."""

    async def rollback(self) -> None:
        """Discard all pending changes."""
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class DispatchResult:
|
|
scanned: int
|
|
dispatched: int
|
|
|
|
|
|
class AutomationSchedulerService:
    """Find due automation jobs and enqueue an agent run for each one.

    Each job is handled as its own unit of work. If dispatching one job
    fails, the failure is logged, the repository is rolled back, and the
    scan continues with the next job.
    """

    def __init__(
        self,
        *,
        repository: AutomationSchedulerRepositoryLike,
        queue: QueueLike,
    ) -> None:
        """Store the persistence and queue collaborators used by the scan."""
        self._repository = repository
        self._queue = queue

    async def scan_and_dispatch(
        self,
        *,
        now_utc: datetime,
        limit: int,
    ) -> DispatchResult:
        """Dispatch the jobs that are due at ``now_utc``.

        Args:
            now_utc: Reference time. It is passed to the repository's
                due-job query, used to derive the run id, and stored as
                ``last_run_at``.
            limit: Maximum number of jobs to fetch. Values below 1 are
                clamped to 1.

        Returns:
            A ``DispatchResult`` with how many jobs were fetched and how many
            were dispatched without error.
        """
        safe_limit = max(int(limit), 1)
        due_jobs = await self._repository.list_due_jobs(
            now_utc=now_utc, limit=safe_limit
        )
        dispatched = 0
        for job in due_jobs:
            try:
                await self._dispatch_job(job=job, now_utc=now_utc)
            except Exception as exc:
                # Best-effort per job. Log first so the original failure is
                # recorded even if the rollback below raises as well.
                logger.exception(
                    "automation job dispatch failed",
                    job_id=str(job.id),
                    owner_id=str(job.owner_id),
                    error=str(exc),
                )
                await self._repository.rollback()
            else:
                dispatched += 1
        return DispatchResult(scanned=len(due_jobs), dispatched=dispatched)

    async def _dispatch_job(self, *, job: DueAutomationJob, now_utc: datetime) -> None:
        """Enqueue one run for ``job`` and advance its schedule."""
        # Named ``job_config`` so it does not shadow the module-level settings ``config``.
        job_config = await self._repository.get_job_config(job_id=job.id)
        thread_id = await self._repository.ensure_latest_chat_session(
            owner_id=job.owner_id
        )
        # Build (and validate) the payload before anything is persisted, so a
        # bad template rolls back a freshly created chat session as well.
        command = self._build_dispatch_command(
            job=job,
            thread_id=thread_id,
            input_text=job_config.input_template,
            now_utc=now_utc,
        )
        # ensure_latest_chat_session may have created a chat session that is
        # only flushed. Commit it before enqueueing. Otherwise a worker could
        # receive a command whose thread is not yet visible, and a later
        # rollback would leave the queued run pointing at a thread that was
        # never persisted.
        await self._repository.commit()
        await self._queue.enqueue(command=command, dedup_key=None)
        # The schedule is advanced only after a successful enqueue. If this
        # step fails, the job stays due and is picked up again on the next scan.
        await self._repository.mark_job_dispatched(
            job_id=job.id,
            next_run_at=_compute_next_run_at(
                current_next_run_at=job.next_run_at,
                now_utc=now_utc,
                schedule_type=job.schedule_type,
            ),
            last_run_at=now_utc,
        )
        await self._repository.commit()

    def _build_dispatch_command(
        self,
        *,
        job: DueAutomationJob,
        thread_id: UUID,
        input_text: str,
        now_utc: datetime,
    ) -> dict[str, object]:
        """Build the queue payload for one scheduled run.

        The fields are first passed through ``SchedulerDispatchCommand``
        (presumably for validation — confirm). They are then serialised into
        a ``"run"`` command targeting the ``"bulk"`` queue. Its ``run_input``
        carries a single user message containing the stripped ``input_text``.

        ``run_id`` combines the job id with ``now_utc`` truncated to whole
        seconds.
        """
        run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
        payload = SchedulerDispatchCommand(
            owner_id=job.owner_id,
            automation_job_id=job.id,
            thread_id=thread_id,
            run_id=run_id,
            input_text=input_text.strip(),
        )
        return {
            "command": "run",
            "owner_id": str(payload.owner_id),
            "automation_job_id": str(payload.automation_job_id),
            "queue": "bulk",
            "run_input": {
                "threadId": str(payload.thread_id),
                "runId": payload.run_id,
                "state": {},
                "messages": [
                    {
                        "id": str(uuid4()),
                        "role": "user",
                        "content": payload.input_text,
                    }
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {
                    "agent_type": "memory",
                },
            },
        }
|
|
|
|
|
|
class SqlAlchemyAutomationSchedulerRepository:
    """Scheduler repository backed by a SQLAlchemy ``AsyncSession``.

    Mutating methods only ``flush``. Nothing is made durable until the caller
    invokes ``commit``.
    """

    def __init__(self, *, session: AsyncSession) -> None:
        self._session = session

    async def list_due_jobs(
        self,
        *,
        now_utc: datetime,
        limit: int,
    ) -> list[DueAutomationJob]:
        """Return up to ``limit`` (at least 1) active, non-deleted jobs due by ``now_utc``.

        Jobs are ordered by ``next_run_at`` ascending, so the most overdue
        come first.
        """
        query = (
            select(AutomationJob)
            .where(
                AutomationJob.deleted_at.is_(None),
                AutomationJob.status == "active",
                AutomationJob.next_run_at <= now_utc,
            )
            .order_by(AutomationJob.next_run_at.asc())
            .limit(max(limit, 1))
        )
        result = await self._session.execute(query)
        job_rows = result.scalars().all()
        return [
            DueAutomationJob(
                id=job_row.id,
                owner_id=job_row.owner_id,
                schedule_type=job_row.schedule_type,
                timezone=job_row.timezone,
                next_run_at=job_row.next_run_at,
            )
            for job_row in job_rows
        ]

    async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
        """Load and validate the ``config`` column of job ``job_id``.

        Raises whatever ``scalar_one`` raises when the job does not exist.
        A falsy stored value is validated as an empty mapping.
        """
        query = select(AutomationJob.config).where(AutomationJob.id == job_id)
        result = await self._session.execute(query)
        raw_config = result.scalar_one()
        return AutomationJobConfig.model_validate(raw_config or {})

    async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
        """Return the owner's most recently active CHAT session, creating one if needed.

        NOTE(review): if ``last_activity_at`` is nullable, PostgreSQL sorts
        NULLs first under DESC — confirm that ordering is intended.
        """
        query = (
            select(AgentChatSession.id)
            .where(
                AgentChatSession.user_id == owner_id,
                AgentChatSession.deleted_at.is_(None),
                AgentChatSession.session_type == SessionType.CHAT,
            )
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(1)
        )
        if (latest_id := (await self._session.execute(query)).scalar_one_or_none()) is not None:
            return latest_id

        fresh_session = AgentChatSession(
            id=uuid4(),
            user_id=owner_id,
            session_type=SessionType.CHAT,
        )
        self._session.add(fresh_session)
        # Flush only: the INSERT joins the current transaction and is not committed here.
        await self._session.flush()
        return fresh_session.id

    async def mark_job_dispatched(
        self,
        *,
        job_id: UUID,
        next_run_at: datetime,
        last_run_at: datetime,
    ) -> None:
        """Set ``next_run_at`` / ``last_run_at`` on job ``job_id`` and flush."""
        query = select(AutomationJob).where(AutomationJob.id == job_id)
        result = await self._session.execute(query)
        job_row = result.scalar_one()
        job_row.next_run_at = next_run_at
        job_row.last_run_at = last_run_at
        await self._session.flush()

    async def commit(self) -> None:
        """Commit the underlying session's transaction."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying session's transaction."""
        await self._session.rollback()
|
|
|
|
|
|
def _compute_next_run_at(
    *,
    current_next_run_at: datetime,
    now_utc: datetime,
    schedule_type: ScheduleType,
) -> datetime:
    """Return the first schedule slot strictly after ``now_utc``.

    Starting from ``current_next_run_at``, the time is stepped forward by
    one day for ``ScheduleType.DAILY``. Every other schedule type steps by
    seven days. Missed slots are skipped, not replayed. A
    ``current_next_run_at`` already later than ``now_utc`` is returned
    unchanged.

    NOTE(review): steps are fixed-length timedeltas, so for jobs whose
    schedule is defined in a local timezone with DST, the local wall-clock
    run time may shift by an hour across transitions. Confirm this is intended.
    """
    if schedule_type == ScheduleType.DAILY:
        step = timedelta(days=1)
    else:
        step = timedelta(weeks=1)

    candidate = current_next_run_at
    while candidate <= now_utc:
        candidate += step
    return candidate
|
|
|
|
|
|
def utc_now() -> datetime:
    """Return the current instant as a timezone-aware UTC ``datetime``."""
    current = datetime.now(tz=timezone.utc)
    return current
|
|
|
|
|
|
async def run_automation_scheduler_scan(
    *,
    limit: int | None = None,
) -> dict[str, int]:
    """Run one scheduler scan using a fresh database session.

    Args:
        limit: Maximum number of due jobs to fetch. An ``int`` is clamped to
            at least 1. Any other value, including ``None``, falls back to
            ``config.automation_scheduler.batch_limit``.

    Returns:
        A dict with the ``"scanned"`` and ``"dispatched"`` counts of this scan.
    """
    started_at = utc_now()
    if isinstance(limit, int):
        batch_limit = max(int(limit), 1)
    else:
        batch_limit = int(config.automation_scheduler.batch_limit)

    # Imported at call time, presumably to avoid building the session factory
    # when this module is imported — TODO confirm.
    from core.db.session import AsyncSessionLocal

    async with AsyncSessionLocal() as db_session:
        service = AutomationSchedulerService(
            repository=SqlAlchemyAutomationSchedulerRepository(session=db_session),
            queue=_BulkQueueAdapter(),
        )
        outcome = await service.scan_and_dispatch(
            now_utc=started_at, limit=batch_limit
        )

    logger.info(
        "automation scheduler scan completed",
        scanned=outcome.scanned,
        dispatched=outcome.dispatched,
        now_utc=started_at.astimezone(timezone.utc).isoformat(),
    )
    return {
        "scanned": int(outcome.scanned),
        "dispatched": int(outcome.dispatched),
    }
|