feat(automation_jobs): add CRUD repository methods
This commit is contained in:
@@ -1,24 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
|
||||
|
||||
def _compute_next_local_time_utc(
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
local_hour: int,
|
||||
local_minute: int,
|
||||
) -> tuple[datetime, datetime]:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
except ZoneInfoNotFoundError:
|
||||
timezone_obj = ZoneInfo("UTC")
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
today_run_local = local_now.replace(
|
||||
hour=local_hour,
|
||||
minute=local_minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
run_local = (
|
||||
today_run_local
|
||||
if local_now <= today_run_local
|
||||
else today_run_local + timedelta(days=1)
|
||||
)
|
||||
next_local = run_local + timedelta(days=1)
|
||||
return run_local.astimezone(timezone.utc), next_local.astimezone(timezone.utc)
|
||||
|
||||
|
||||
class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=AutomationJob)
|
||||
|
||||
async def list_by_owner(self, owner_id: UUID) -> list[AutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.order_by(AutomationJob.created_at.desc())
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
async def get_by_id(self, job_id: UUID) -> AutomationJob | None: # type: ignore[override]
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def count_user_jobs(self, owner_id: UUID) -> int:
|
||||
stmt = (
|
||||
select(func.count(AutomationJob.id))
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.bootstrap_key.is_(None))
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = (await self._session.execute(stmt)).scalar_one()
|
||||
return int(result)
|
||||
|
||||
async def create(
|
||||
self, owner_id: UUID, data: "AutomationJobCreateRequest"
|
||||
) -> AutomationJob:
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
run_at_dt, next_run_at = _compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name=data.timezone,
|
||||
local_hour=data.run_at.hour,
|
||||
local_minute=data.run_at.minute,
|
||||
)
|
||||
new_job = AutomationJob(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
created_by=owner_id,
|
||||
bootstrap_key=None,
|
||||
title=data.title,
|
||||
config=data.config.model_dump(mode="json"),
|
||||
schedule_type=data.schedule_type,
|
||||
run_at=run_at_dt,
|
||||
next_run_at=next_run_at,
|
||||
timezone=data.timezone,
|
||||
status=data.status,
|
||||
)
|
||||
self._session.add(new_job)
|
||||
await self._session.flush()
|
||||
return new_job
|
||||
|
||||
async def update(
|
||||
self, job_id: UUID, data: "AutomationJobUpdateRequest"
|
||||
) -> AutomationJob | None:
|
||||
update_values: dict[str, object] = {}
|
||||
if data.title is not None:
|
||||
update_values["title"] = data.title
|
||||
if data.schedule_type is not None:
|
||||
update_values["schedule_type"] = data.schedule_type
|
||||
if data.run_at is not None:
|
||||
stmt = select(AutomationJob).where(AutomationJob.id == job_id)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is None:
|
||||
return None
|
||||
run_at_dt, next_run_at = _compute_next_local_time_utc(
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
timezone_name=data.timezone or existing.timezone,
|
||||
local_hour=data.run_at.hour,
|
||||
local_minute=data.run_at.minute,
|
||||
)
|
||||
update_values["run_at"] = run_at_dt
|
||||
update_values["next_run_at"] = next_run_at
|
||||
update_values["timezone"] = data.timezone or existing.timezone
|
||||
if data.status is not None:
|
||||
update_values["status"] = data.status
|
||||
if data.config is not None:
|
||||
update_values["config"] = data.config.model_dump(mode="json")
|
||||
|
||||
if not update_values:
|
||||
return await self.get_by_id(job_id)
|
||||
|
||||
stmt = (
|
||||
update(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.values(**update_values)
|
||||
.returning(AutomationJob)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def soft_delete(self, job_id: UUID) -> None:
|
||||
stmt = (
|
||||
update(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user