Files
social-app/backend/src/v1/automation_jobs/repository.py
T

230 lines
8.0 KiB
Python

from __future__ import annotations

from datetime import date, datetime, time, timedelta, timezone
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from core.db.base_repository import BaseRepository
from models.agent_chat_session import AgentChatSession, SessionType
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
from schemas.automation import AutomationJobConfig, ScheduleConfig

if TYPE_CHECKING:
    from v1.automation_jobs.schemas import (
        AutomationJobCreateRequest,
        AutomationJobUpdateRequest,
    )
class AutomationJobsRepository(BaseRepository[AutomationJob]):
    """Persistence and scheduling helpers for ``AutomationJob`` rows.

    Queries written in this class skip soft-deleted rows (``deleted_at`` set).
    ``next_run_at`` is computed from the job's schedule in the job's own
    timezone and stored as a timezone-aware UTC datetime.
    """

    def __init__(self, session: AsyncSession) -> None:
        super().__init__(session=session, model=AutomationJob)

    async def list_due_jobs(
        self,
        *,
        now_utc: datetime,
        limit: int,
    ) -> list[AutomationJob]:
        """Return active, non-deleted jobs whose ``next_run_at`` has passed.

        Args:
            now_utc: Reference instant; jobs with ``next_run_at <= now_utc``
                are due.
            limit: Maximum number of rows; values below 1 are clamped to 1.

        Returns:
            Due jobs ordered by ``next_run_at`` ascending (most overdue first).
        """
        stmt = (
            select(AutomationJob)
            .where(AutomationJob.deleted_at.is_(None))
            .where(AutomationJob.status == AutomationJobStatus.ACTIVE)
            .where(AutomationJob.next_run_at <= now_utc)
            .order_by(AutomationJob.next_run_at.asc())
            .limit(max(limit, 1))
        )
        rows = (await self._session.execute(stmt)).scalars().all()
        return list(rows)

    async def update_job_schedule(
        self,
        *,
        job_id: UUID,
        next_run_at: datetime,
        last_run_at: datetime,
    ) -> None:
        """Record a run's timestamp and the job's next scheduled run.

        The UPDATE only matches a non-deleted job, so a missing or
        soft-deleted job is left untouched. The session is flushed, not
        committed.
        """
        stmt = (
            update(AutomationJob)
            .where(AutomationJob.id == job_id)
            .where(AutomationJob.deleted_at.is_(None))
            .values(next_run_at=next_run_at, last_run_at=last_run_at)
        )
        await self._session.execute(stmt)
        await self._session.flush()

    async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID:
        """Return the id of the owner's most recently active chat session.

        Only non-deleted sessions of type ``SessionType.CHAT`` are considered.
        When none exists, a new one is added to the session and flushed (not
        committed).

        NOTE(review): this is select-then-insert without locking, so two
        concurrent calls for the same owner may each create a session.
        """
        stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == owner_id)
            .where(AgentChatSession.deleted_at.is_(None))
            .where(AgentChatSession.session_type == SessionType.CHAT)
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(1)
        )
        existing = (await self._session.execute(stmt)).scalar_one_or_none()
        if existing is not None:
            return existing

        new_session = AgentChatSession(
            id=uuid4(),
            user_id=owner_id,
            session_type=SessionType.CHAT,
        )
        self._session.add(new_session)
        await self._session.flush()
        return new_session.id

    async def list_by_owner(self, owner_id: UUID) -> list[AutomationJob]:
        """Return the owner's non-deleted jobs, newest first."""
        stmt = (
            select(AutomationJob)
            .where(AutomationJob.owner_id == owner_id)
            .where(AutomationJob.deleted_at.is_(None))
            .order_by(AutomationJob.created_at.desc())
        )
        rows = (await self._session.execute(stmt)).scalars().all()
        return list(rows)

    async def count_user_jobs(self, owner_id: UUID) -> int:
        """Count the owner's non-deleted jobs that have no ``bootstrap_key``.

        Jobs created through ``create`` have ``bootstrap_key=None`` and are
        counted; rows with a bootstrap key are excluded.
        """
        stmt = (
            select(func.count())
            .select_from(AutomationJob)
            .where(AutomationJob.owner_id == owner_id)
            .where(AutomationJob.deleted_at.is_(None))
            .where(AutomationJob.bootstrap_key.is_(None))
        )
        result = (await self._session.execute(stmt)).scalar_one()
        return int(result)

    def _resolve_timezone(self, timezone_str: str) -> ZoneInfo:
        """Return the IANA zone named ``timezone_str``, falling back to UTC.

        ``ZoneInfo`` raises ``ZoneInfoNotFoundError`` for unknown keys and
        ``ValueError`` for malformed ones (e.g. ``""``, absolute paths, ``..``
        segments). On some Python versions it raises ``IsADirectoryError``
        for keys naming a zone directory such as ``"America"``. All of these
        mean "no usable zone", so they share the UTC fallback.
        """
        try:
            return ZoneInfo(timezone_str)
        except (ZoneInfoNotFoundError, ValueError, IsADirectoryError):
            return ZoneInfo("UTC")

    @staticmethod
    def _local_run_to_utc(day: date, run_clock: time) -> datetime:
        """Combine a local calendar day with a zone-aware clock time, as UTC.

        ``run_clock`` carries ``fold=0``. Per PEP 495, an ambiguous local time
        (DST fold) resolves to its first occurrence, and a nonexistent one
        (DST gap) uses the offset in effect before the transition.
        """
        return datetime.combine(day, run_clock).astimezone(timezone.utc)

    def _compute_next_run_at(
        self,
        *,
        schedule: ScheduleConfig,
        timezone_str: str,
        now_utc: datetime,
    ) -> datetime:
        """Return the first scheduled run strictly after ``now_utc``, in UTC.

        The schedule's ``run_at`` hour and minute (seconds are ignored) are
        read as wall-clock time in ``timezone_str``, or UTC if that zone is
        unusable. ``ScheduleType.DAILY`` runs every day. Any other type is
        treated as weekly and runs on ``schedule.weekdays``, using ISO
        numbering (Monday=1 ... Sunday=7).

        Candidates are compared with ``now_utc`` as UTC instants rather than
        by local wall clock. This way a run time inside a repeated DST hour is
        not returned after that instant has already passed.

        Args:
            schedule: Schedule definition (type, run time, weekdays).
            timezone_str: IANA timezone name of the job.
            now_utc: Reference instant; a naive value is assumed to be UTC.

        Raises:
            ValueError: for a weekly schedule with no weekday in 1-7.
        """
        if now_utc.tzinfo is None:
            # ``astimezone`` would otherwise read a naive value as server-local.
            now_utc = now_utc.replace(tzinfo=timezone.utc)
        tz = self._resolve_timezone(timezone_str)
        local_today = now_utc.astimezone(tz).date()
        run_clock = time(
            hour=schedule.run_at.hour,
            minute=schedule.run_at.minute,
            tzinfo=tz,
        )

        if schedule.type == ScheduleType.DAILY:
            candidate = self._local_run_to_utc(local_today, run_clock)
            if candidate <= now_utc:
                # Step the local date, not the instant, so the wall-clock
                # time stays fixed across DST changes.
                candidate = self._local_run_to_utc(
                    local_today + timedelta(days=1), run_clock
                )
            return candidate

        weekdays = schedule.weekdays or []
        if not weekdays:
            raise ValueError("weekly schedule requires weekdays")
        # Only 1..7 can ever equal ``isoweekday()``. A list made only of
        # out-of-range values used to make the fallback search loop forever.
        iso_weekdays = {day for day in weekdays if day in range(1, 8)}
        if not iso_weekdays:
            raise ValueError("weekly schedule weekdays must be between 1 and 7")

        # Offsets 0..7 cover today plus the same weekday next week, so any
        # non-empty ``iso_weekdays`` yields a future candidate in this loop.
        for day_offset in range(8):
            candidate_day = local_today + timedelta(days=day_offset)
            if candidate_day.isoweekday() not in iso_weekdays:
                continue
            candidate = self._local_run_to_utc(candidate_day, run_clock)
            if candidate > now_utc:
                return candidate
        raise RuntimeError("unreachable: no weekly run found within 8 days")

    async def create(
        self,
        owner_id: UUID,
        data: AutomationJobCreateRequest,
    ) -> AutomationJob:
        """Insert a user-created job and compute its first ``next_run_at``.

        The owner is also recorded as ``created_by``, and ``bootstrap_key`` is
        left ``None``. The session is flushed, not committed.

        Raises:
            ValueError: if ``data.config.schedule`` is missing, or the schedule
                cannot produce a run time (see ``_compute_next_run_at``).
        """
        schedule = data.config.schedule
        if schedule is None:
            raise ValueError("config.schedule is required")
        next_run_at = self._compute_next_run_at(
            schedule=schedule,
            timezone_str=data.timezone,
            now_utc=datetime.now(tz=timezone.utc),
        )
        new_job = AutomationJob(
            owner_id=owner_id,
            created_by=owner_id,
            bootstrap_key=None,
            title=data.title,
            timezone=data.timezone,
            status=data.status,
            config=data.config.model_dump(mode="json"),
            next_run_at=next_run_at,
        )
        self._session.add(new_job)
        await self._session.flush()
        return new_job

    async def update(
        self,
        job_id: UUID,
        data: AutomationJobUpdateRequest,
    ) -> AutomationJob | None:
        """Apply a partial update to a job, keeping ``next_run_at`` in sync.

        Only non-``None`` fields of ``data`` are written. A supplied
        ``config`` is shallow-merged over the stored config (top-level keys
        only) and re-validated. ``next_run_at`` is recomputed when the
        timezone or the ``schedule`` section changes.

        NOTE(review): ``exclude_unset`` also applies inside nested models, so
        a partially specified section replaces the stored section with only
        the fields sent. Confirm that clients always send complete sections.

        Returns:
            The result of ``update_by_id``; the existing job unchanged when
            ``data`` carries no changes; or ``None`` when the job does not
            exist or is soft-deleted.

        Raises:
            ValueError: if a recompute is needed but the merged config has no
                schedule, or the schedule is invalid. Validation errors from
                ``AutomationJobConfig.model_validate`` also propagate.
        """
        existing_job = await self.get_by_id(job_id)
        # Other queries in this repository ignore soft-deleted rows; do the
        # same here instead of depending on how ``get_by_id`` filters.
        if existing_job is None or existing_job.deleted_at is not None:
            return None

        update_values: dict[str, object] = {}
        if data.title is not None:
            update_values["title"] = data.title
        if data.timezone is not None:
            update_values["timezone"] = data.timezone
        if data.status is not None:
            update_values["status"] = data.status

        # Dump once: used both for the merge and to detect a schedule change.
        config_patch: dict[str, object] | None = (
            None
            if data.config is None
            else data.config.model_dump(mode="json", exclude_unset=True)
        )
        merged_config_raw: dict[str, object] = dict(existing_job.config or {})
        normalized_config: AutomationJobConfig | None = None
        if config_patch is not None:
            merged_config_raw.update(config_patch)
            normalized_config = AutomationJobConfig.model_validate(merged_config_raw)
            update_values["config"] = normalized_config.model_dump(mode="json")

        schedule_changed = config_patch is not None and "schedule" in config_patch
        # NOTE(review): a status change alone (e.g. PAUSED -> ACTIVE) does not
        # recompute ``next_run_at``, so a stale past value makes the job due
        # immediately. Confirm this catch-up behaviour is intended.
        if data.timezone is not None or schedule_changed:
            if normalized_config is None:
                # Validate the stored config only when a recompute needs it,
                # so e.g. a title-only update is not blocked by a legacy config.
                normalized_config = AutomationJobConfig.model_validate(
                    merged_config_raw
                )
            if normalized_config.schedule is None:
                raise ValueError("config.schedule is required")
            # Use ``is not None`` (not truthiness) so this zone always matches
            # the value written to ``update_values["timezone"]`` above.
            effective_timezone = (
                data.timezone if data.timezone is not None else existing_job.timezone
            )
            update_values["next_run_at"] = self._compute_next_run_at(
                schedule=normalized_config.schedule,
                timezone_str=effective_timezone,
                now_utc=datetime.now(tz=timezone.utc),
            )

        if not update_values:
            return existing_job
        return await self.update_by_id(job_id, update_values)

    async def soft_delete(self, job_id: UUID) -> None:
        """Soft-delete a job by delegating to ``soft_delete_by_id``."""
        await self.soft_delete_by_id(job_id)