feat: 添加自动化任务(automation_jobs)功能模块
This commit is contained in:
@@ -8,7 +8,11 @@ from typing import Any
|
||||
import yaml
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from schemas.automation import AutomationJobConfig, MessageContextConfig
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation import (
|
||||
AutomationJobConfig,
|
||||
MessageContextConfig,
|
||||
)
|
||||
|
||||
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
@@ -43,4 +47,8 @@ def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfi
|
||||
raise ValueError(
|
||||
"memory_extraction context must be latest_chat/day with window_count=2"
|
||||
)
|
||||
if config.schedule is None:
|
||||
raise ValueError("memory_extraction schedule must be configured")
|
||||
if config.schedule.type != ScheduleType.DAILY:
|
||||
raise ValueError("memory_extraction schedule type must be daily")
|
||||
return config
|
||||
|
||||
@@ -22,9 +22,6 @@ from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
|
||||
logger = get_logger("v1.auth.registration_bootstrap")
|
||||
|
||||
_LOCAL_RUN_HOUR = 8
|
||||
_LOCAL_RUN_MINUTE = 0
|
||||
|
||||
|
||||
class RegistrationBootstrapRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
@@ -49,6 +46,7 @@ class RegistrationBootstrapRepository:
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> bool:
|
||||
stmt = (
|
||||
insert(AutomationJob)
|
||||
@@ -58,7 +56,7 @@ class RegistrationBootstrapRepository:
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=title,
|
||||
config=config.model_dump(mode="json"),
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
schedule_type=schedule_type,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
timezone=timezone_name,
|
||||
@@ -107,6 +105,7 @@ class RegistrationBootstrapRepositoryLike(Protocol):
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> bool: ...
|
||||
|
||||
async def upsert_initial_memory(
|
||||
@@ -130,6 +129,7 @@ def compute_next_local_time_utc(
|
||||
timezone_name: str,
|
||||
local_hour: int,
|
||||
local_minute: int,
|
||||
schedule_type: ScheduleType,
|
||||
) -> tuple[datetime, datetime]:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
@@ -147,7 +147,10 @@ def compute_next_local_time_utc(
|
||||
if local_now <= today_run_local
|
||||
else today_run_local + timedelta(days=1)
|
||||
)
|
||||
next_local = run_local + timedelta(days=1)
|
||||
if schedule_type == ScheduleType.WEEKLY:
|
||||
next_local = run_local + timedelta(weeks=1)
|
||||
else:
|
||||
next_local = run_local + timedelta(days=1)
|
||||
return run_local.astimezone(UTC), next_local.astimezone(UTC)
|
||||
|
||||
|
||||
@@ -170,9 +173,7 @@ class RegistrationAutomationBootstrapService:
|
||||
{
|
||||
"bootstrap_key": "memory_extraction",
|
||||
"config_name": "memory_extraction",
|
||||
"title": "Memory Agent",
|
||||
"local_hour": _LOCAL_RUN_HOUR,
|
||||
"local_minute": _LOCAL_RUN_MINUTE,
|
||||
"title": "记忆推送",
|
||||
}
|
||||
]
|
||||
|
||||
@@ -197,11 +198,17 @@ class RegistrationAutomationBootstrapService:
|
||||
job_config = load_static_automation_job_config(
|
||||
config_name=str(definition["config_name"])
|
||||
)
|
||||
schedule = job_config.schedule
|
||||
if schedule is None:
|
||||
raise ValueError(
|
||||
f"bootstrap job {bootstrap_key} has no schedule configured"
|
||||
)
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=datetime.now(UTC),
|
||||
timezone_name=timezone_name,
|
||||
local_hour=int(definition["local_hour"]),
|
||||
local_minute=int(definition["local_minute"]),
|
||||
local_hour=schedule.run_at.hour,
|
||||
local_minute=schedule.run_at.minute,
|
||||
schedule_type=schedule.type,
|
||||
)
|
||||
inserted = (
|
||||
await self._repository.insert_bootstrap_automation_job_if_absent(
|
||||
@@ -212,6 +219,7 @@ class RegistrationAutomationBootstrapService:
|
||||
timezone_name=timezone_name,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
schedule_type=schedule.type,
|
||||
)
|
||||
)
|
||||
inserted_any = inserted_any or inserted
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
from v1.automation_jobs.service import AutomationJobsService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
async def get_automation_jobs_repository(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AutomationJobsRepository:
|
||||
return AutomationJobsRepository(session=session)
|
||||
|
||||
|
||||
async def get_automation_jobs_service(
|
||||
repository: Annotated[
|
||||
AutomationJobsRepository, Depends(get_automation_jobs_repository)
|
||||
],
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AutomationJobsService:
|
||||
return AutomationJobsService(repository=repository, session=session)
|
||||
|
||||
|
||||
async def get_current_user_id(
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> UUID:
|
||||
return current_user.id
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, time, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -10,7 +10,7 @@ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob
|
||||
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from v1.automation_jobs.schemas import (
|
||||
@@ -19,144 +19,10 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
def _compute_next_local_time_utc(
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
local_hour: int,
|
||||
local_minute: int,
|
||||
) -> tuple[datetime, datetime]:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
except ZoneInfoNotFoundError:
|
||||
timezone_obj = ZoneInfo("UTC")
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
today_run_local = local_now.replace(
|
||||
hour=local_hour,
|
||||
minute=local_minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
run_local = (
|
||||
today_run_local
|
||||
if local_now <= today_run_local
|
||||
else today_run_local + timedelta(days=1)
|
||||
)
|
||||
next_local = run_local + timedelta(days=1)
|
||||
return run_local.astimezone(timezone.utc), next_local.astimezone(timezone.utc)
|
||||
|
||||
|
||||
class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=AutomationJob)
|
||||
|
||||
async def list_by_owner(self, owner_id: UUID) -> list[AutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.order_by(AutomationJob.created_at.desc())
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
async def get_by_id(self, job_id: UUID) -> AutomationJob | None: # type: ignore[override]
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def count_user_jobs(self, owner_id: UUID) -> int:
|
||||
stmt = (
|
||||
select(func.count(AutomationJob.id))
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.bootstrap_key.is_(None))
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = (await self._session.execute(stmt)).scalar_one()
|
||||
return int(result)
|
||||
|
||||
async def create(
|
||||
self, owner_id: UUID, data: "AutomationJobCreateRequest"
|
||||
) -> AutomationJob:
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
run_at_dt, next_run_at = _compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name=data.timezone,
|
||||
local_hour=data.run_at.hour,
|
||||
local_minute=data.run_at.minute,
|
||||
)
|
||||
new_job = AutomationJob(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
created_by=owner_id,
|
||||
bootstrap_key=None,
|
||||
title=data.title,
|
||||
config=data.config.model_dump(mode="json"),
|
||||
schedule_type=data.schedule_type,
|
||||
run_at=run_at_dt,
|
||||
next_run_at=next_run_at,
|
||||
timezone=data.timezone,
|
||||
status=data.status,
|
||||
)
|
||||
self._session.add(new_job)
|
||||
await self._session.flush()
|
||||
return new_job
|
||||
|
||||
async def update(
|
||||
self, job_id: UUID, data: "AutomationJobUpdateRequest"
|
||||
) -> AutomationJob | None:
|
||||
update_values: dict[str, object] = {}
|
||||
if data.title is not None:
|
||||
update_values["title"] = data.title
|
||||
if data.schedule_type is not None:
|
||||
update_values["schedule_type"] = data.schedule_type
|
||||
if data.run_at is not None:
|
||||
stmt = select(AutomationJob).where(AutomationJob.id == job_id)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is None:
|
||||
return None
|
||||
run_at_dt, next_run_at = _compute_next_local_time_utc(
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
timezone_name=data.timezone or existing.timezone,
|
||||
local_hour=data.run_at.hour,
|
||||
local_minute=data.run_at.minute,
|
||||
)
|
||||
update_values["run_at"] = run_at_dt
|
||||
update_values["next_run_at"] = next_run_at
|
||||
update_values["timezone"] = data.timezone or existing.timezone
|
||||
if data.status is not None:
|
||||
update_values["status"] = data.status
|
||||
if data.config is not None:
|
||||
update_values["config"] = data.config.model_dump(mode="json")
|
||||
|
||||
if not update_values:
|
||||
return await self.get_by_id(job_id)
|
||||
|
||||
stmt = (
|
||||
update(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.values(**update_values)
|
||||
.returning(AutomationJob)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def soft_delete(self, job_id: UUID) -> None:
|
||||
stmt = (
|
||||
update(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
@@ -166,7 +32,7 @@ class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.status == "active")
|
||||
.where(AutomationJob.status == AutomationJobStatus.ACTIVE)
|
||||
.where(AutomationJob.next_run_at <= now_utc)
|
||||
.order_by(AutomationJob.next_run_at.asc())
|
||||
.limit(max(limit, 1))
|
||||
@@ -213,3 +79,160 @@ class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
self._session.add(new_session)
|
||||
await self._session.flush()
|
||||
return new_session.id
|
||||
|
||||
async def list_by_owner(self, owner_id: UUID) -> list[AutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.order_by(AutomationJob.created_at.desc())
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
async def count_user_jobs(self, owner_id: UUID) -> int:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(AutomationJob)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.bootstrap_key.is_(None))
|
||||
)
|
||||
result = (await self._session.execute(stmt)).scalar_one()
|
||||
return int(result)
|
||||
|
||||
def _resolve_timezone(self, timezone_str: str) -> ZoneInfo:
|
||||
try:
|
||||
return ZoneInfo(timezone_str)
|
||||
except ZoneInfoNotFoundError:
|
||||
return ZoneInfo("UTC")
|
||||
|
||||
def _compute_initial_next_run_at(
|
||||
self,
|
||||
*,
|
||||
run_at: time,
|
||||
timezone_str: str,
|
||||
now_utc: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> datetime:
|
||||
tz = self._resolve_timezone(timezone_str)
|
||||
local_now = now_utc.astimezone(tz)
|
||||
run_at_local = datetime.combine(local_now.date(), run_at, tz)
|
||||
if run_at_local.tzinfo is None:
|
||||
run_at_local = run_at_local.replace(tzinfo=tz)
|
||||
next_run_at = run_at_local
|
||||
if next_run_at <= local_now:
|
||||
if schedule_type == ScheduleType.DAILY:
|
||||
next_run_at = next_run_at + timedelta(days=1)
|
||||
else:
|
||||
next_run_at = next_run_at + timedelta(weeks=1)
|
||||
return next_run_at.astimezone(timezone.utc)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
owner_id: UUID,
|
||||
data: AutomationJobCreateRequest,
|
||||
) -> AutomationJob:
|
||||
now_utc = datetime.now(tz=timezone.utc)
|
||||
timezone_obj = self._resolve_timezone(data.timezone)
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
date_ref = local_now.date()
|
||||
local_dt = datetime.combine(date_ref, data.run_at, timezone_obj)
|
||||
run_at_datetime = local_dt.astimezone(timezone.utc)
|
||||
next_run_at = self._compute_initial_next_run_at(
|
||||
run_at=data.run_at,
|
||||
timezone_str=data.timezone,
|
||||
now_utc=now_utc,
|
||||
schedule_type=data.schedule_type,
|
||||
)
|
||||
|
||||
new_job = AutomationJob(
|
||||
owner_id=owner_id,
|
||||
created_by=owner_id,
|
||||
bootstrap_key=None,
|
||||
title=data.title,
|
||||
schedule_type=data.schedule_type,
|
||||
run_at=run_at_datetime,
|
||||
timezone=data.timezone,
|
||||
status=data.status,
|
||||
config=data.config.model_dump(mode="json"),
|
||||
next_run_at=next_run_at,
|
||||
)
|
||||
self._session.add(new_job)
|
||||
await self._session.flush()
|
||||
return new_job
|
||||
|
||||
async def update(
|
||||
self,
|
||||
job_id: UUID,
|
||||
data: AutomationJobUpdateRequest,
|
||||
) -> AutomationJob | None:
|
||||
update_values: dict[str, object] = {}
|
||||
existing_job: AutomationJob | None = None
|
||||
|
||||
if data.title is not None:
|
||||
update_values["title"] = data.title
|
||||
if data.schedule_type is not None:
|
||||
update_values["schedule_type"] = data.schedule_type
|
||||
|
||||
should_recompute_schedule = (
|
||||
data.run_at is not None
|
||||
or data.schedule_type is not None
|
||||
or data.timezone is not None
|
||||
)
|
||||
if should_recompute_schedule:
|
||||
now_utc = datetime.now(tz=timezone.utc)
|
||||
if existing_job is None:
|
||||
existing_job = await self.get_by_id(job_id)
|
||||
if existing_job is None:
|
||||
return None
|
||||
|
||||
effective_timezone = data.timezone or existing_job.timezone
|
||||
effective_timezone_obj = self._resolve_timezone(effective_timezone)
|
||||
effective_schedule_type = data.schedule_type or existing_job.schedule_type
|
||||
|
||||
if data.run_at is not None:
|
||||
effective_run_at = data.run_at
|
||||
else:
|
||||
existing_timezone_obj = self._resolve_timezone(existing_job.timezone)
|
||||
effective_run_at = (
|
||||
existing_job.run_at.astimezone(existing_timezone_obj)
|
||||
.time()
|
||||
.replace(microsecond=0)
|
||||
)
|
||||
|
||||
local_now = now_utc.astimezone(effective_timezone_obj)
|
||||
local_dt = datetime.combine(
|
||||
local_now.date(),
|
||||
effective_run_at,
|
||||
effective_timezone_obj,
|
||||
)
|
||||
update_values["run_at"] = local_dt.astimezone(timezone.utc)
|
||||
update_values["next_run_at"] = self._compute_initial_next_run_at(
|
||||
run_at=effective_run_at,
|
||||
timezone_str=effective_timezone,
|
||||
now_utc=now_utc,
|
||||
schedule_type=effective_schedule_type,
|
||||
)
|
||||
if data.timezone is not None:
|
||||
update_values["timezone"] = data.timezone
|
||||
if data.status is not None:
|
||||
update_values["status"] = data.status
|
||||
if data.config is not None:
|
||||
if existing_job is None:
|
||||
existing_job = await self.get_by_id(job_id)
|
||||
if existing_job is None:
|
||||
return None
|
||||
merged_config = {
|
||||
**existing_job.config,
|
||||
**data.config.model_dump(mode="json", exclude_unset=True),
|
||||
}
|
||||
update_values["config"] = merged_config
|
||||
|
||||
if not update_values:
|
||||
return await self.get_by_id(job_id)
|
||||
|
||||
return await self.update_by_id(job_id, update_values)
|
||||
|
||||
async def soft_delete(self, job_id: UUID) -> None:
|
||||
await self.soft_delete_by_id(job_id)
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from v1.automation_jobs.dependencies import (
|
||||
get_automation_jobs_service,
|
||||
get_current_user_id,
|
||||
)
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobListResponse,
|
||||
AutomationJobResponse,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
from v1.automation_jobs.service import AutomationJobsService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/automation-jobs", tags=["automation-jobs"])
|
||||
|
||||
|
||||
@router.get("", response_model=AutomationJobListResponse)
|
||||
async def list_automation_jobs(
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> AutomationJobListResponse:
|
||||
return await service.list_by_owner(owner_id=current_user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"", response_model=AutomationJobResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def create_automation_job(
|
||||
request: AutomationJobCreateRequest,
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> AutomationJobResponse:
|
||||
return await service.create(owner_id=current_user_id, data=request)
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=AutomationJobResponse)
|
||||
async def get_automation_job(
|
||||
job_id: UUID,
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> AutomationJobResponse:
|
||||
return await service.get_by_id(job_id=job_id, owner_id=current_user_id)
|
||||
|
||||
|
||||
@router.patch("/{job_id}", response_model=AutomationJobResponse)
|
||||
async def update_automation_job(
|
||||
job_id: UUID,
|
||||
request: AutomationJobUpdateRequest,
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> AutomationJobResponse:
|
||||
return await service.update(job_id=job_id, owner_id=current_user_id, data=request)
|
||||
|
||||
|
||||
@router.delete("/{job_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_automation_job(
|
||||
job_id: UUID,
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> None:
|
||||
await service.delete(job_id=job_id, owner_id=current_user_id)
|
||||
@@ -3,14 +3,13 @@ from __future__ import annotations
|
||||
from datetime import datetime, time
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from models.automation_jobs import AutomationJob as OrmAutomationJob
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
from schemas.automation import (
|
||||
AutomationJobConfig,
|
||||
)
|
||||
from schemas.automation import AutomationJobConfig
|
||||
|
||||
|
||||
class AutomationJobResponse(BaseModel):
|
||||
@@ -61,6 +60,15 @@ class AutomationJobCreateRequest(BaseModel):
|
||||
status: AutomationJobStatus = Field(default=AutomationJobStatus.ACTIVE)
|
||||
config: AutomationJobConfig
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str) -> str:
|
||||
try:
|
||||
ZoneInfo(value)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
raise ValueError("timezone must be a valid IANA timezone") from exc
|
||||
return value
|
||||
|
||||
|
||||
class AutomationJobUpdateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
@@ -72,6 +80,17 @@ class AutomationJobUpdateRequest(BaseModel):
|
||||
status: AutomationJobStatus | None = None
|
||||
config: AutomationJobConfig | None = None
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
try:
|
||||
ZoneInfo(value)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
raise ValueError("timezone must be a valid IANA timezone") from exc
|
||||
return value
|
||||
|
||||
|
||||
class AutomationJobListResponse(BaseModel):
|
||||
items: list[AutomationJobResponse]
|
||||
|
||||
@@ -5,14 +5,54 @@ from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation import AutomationJob as AutomationJobSchema, RuntimeConfig
|
||||
from schemas.automation import (
|
||||
AutomationJob as AutomationJobSchema,
|
||||
MessageContextConfig,
|
||||
RuntimeConfig,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from core.logging import get_logger
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobListResponse,
|
||||
AutomationJobResponse,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
|
||||
logger = get_logger("v1.automation_jobs.service")
|
||||
|
||||
|
||||
class AutomationJobLimitExceeded(HTTPException):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Maximum of 3 user jobs allowed",
|
||||
)
|
||||
|
||||
|
||||
class SystemJobModificationForbidden(HTTPException):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System job cannot be modified",
|
||||
)
|
||||
|
||||
|
||||
class AutomationJobNotFound(HTTPException):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Automation job not found",
|
||||
)
|
||||
|
||||
|
||||
class DispatchFn(Protocol):
|
||||
async def __call__(
|
||||
@@ -46,6 +86,9 @@ class ScanResult:
|
||||
|
||||
|
||||
class AutomationJobsService:
|
||||
_repository: "AutomationJobsRepository"
|
||||
_session: "AsyncSession"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: "AutomationJobsRepository",
|
||||
@@ -71,14 +114,15 @@ class AutomationJobsService:
|
||||
thread_id = await self.get_or_create_chat_session(owner_id=job.owner_id)
|
||||
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
|
||||
|
||||
input_text = (job.config.input_template or "").strip()
|
||||
await dispatch_fn(
|
||||
owner_id=job.owner_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
input_text=job.config.input_template.strip(),
|
||||
input_text=input_text,
|
||||
runtime_config=RuntimeConfig(
|
||||
enabled_tools=job.config.enabled_tools,
|
||||
context=job.config.context,
|
||||
enabled_tools=job.config.enabled_tools or [],
|
||||
context=job.config.context or MessageContextConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -102,3 +146,82 @@ class AutomationJobsService:
|
||||
|
||||
async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
return await self._repository.get_or_create_chat_session(owner_id=owner_id)
|
||||
|
||||
async def list_by_owner(self, owner_id: UUID) -> AutomationJobListResponse:
|
||||
jobs = await self._repository.list_by_owner(owner_id)
|
||||
return AutomationJobListResponse(
|
||||
items=[AutomationJobResponse.from_orm(job) for job in jobs],
|
||||
)
|
||||
|
||||
async def get_by_id(self, job_id: UUID, owner_id: UUID) -> AutomationJobResponse:
|
||||
job = await self._repository.get_by_id(job_id)
|
||||
if job is None or job.owner_id != owner_id:
|
||||
raise AutomationJobNotFound()
|
||||
return AutomationJobResponse.from_orm(job)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
owner_id: UUID,
|
||||
data: AutomationJobCreateRequest,
|
||||
) -> AutomationJobResponse:
|
||||
try:
|
||||
await self._session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(abs(hashtext(:owner_id)))"),
|
||||
{"owner_id": str(owner_id)},
|
||||
)
|
||||
count = await self._repository.count_user_jobs(owner_id)
|
||||
if count >= 3:
|
||||
await self._session.rollback()
|
||||
raise AutomationJobLimitExceeded()
|
||||
job = await self._repository.create(owner_id, data)
|
||||
await self._session.commit()
|
||||
return AutomationJobResponse.from_orm(job)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Failed to create automation job", owner_id=str(owner_id))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Automation job store unavailable",
|
||||
)
|
||||
|
||||
async def update(
|
||||
self,
|
||||
job_id: UUID,
|
||||
owner_id: UUID,
|
||||
data: AutomationJobUpdateRequest,
|
||||
) -> AutomationJobResponse:
|
||||
try:
|
||||
job = await self._repository.get_by_id(job_id)
|
||||
if job is None or job.owner_id != owner_id:
|
||||
raise AutomationJobNotFound()
|
||||
if job.bootstrap_key is not None:
|
||||
raise SystemJobModificationForbidden()
|
||||
updated_job = await self._repository.update(job_id, data)
|
||||
if updated_job is None:
|
||||
raise AutomationJobNotFound()
|
||||
await self._session.commit()
|
||||
return AutomationJobResponse.from_orm(updated_job)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Failed to update automation job", job_id=str(job_id))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Automation job store unavailable",
|
||||
)
|
||||
|
||||
async def delete(self, job_id: UUID, owner_id: UUID) -> None:
|
||||
try:
|
||||
job = await self._repository.get_by_id(job_id)
|
||||
if job is None or job.owner_id != owner_id:
|
||||
raise AutomationJobNotFound()
|
||||
if job.bootstrap_key is not None:
|
||||
raise SystemJobModificationForbidden()
|
||||
await self._repository.soft_delete(job_id)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Failed to delete automation job", job_id=str(job_id))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Automation job store unavailable",
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter
|
||||
|
||||
from v1.agent.router import router as agent_router
|
||||
from v1.app.router import router as app_router
|
||||
from v1.automation_jobs.router import router as automation_jobs_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.friendships.router import router as friendships_router
|
||||
from v1.inbox_messages.router import router as inbox_messages_router
|
||||
@@ -17,6 +18,7 @@ router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(app_router)
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(automation_jobs_router)
|
||||
router.include_router(friendships_router)
|
||||
router.include_router(memories_router)
|
||||
router.include_router(users_router)
|
||||
|
||||
Reference in New Issue
Block a user