feat: 添加自动化任务(automation_jobs)功能模块
This commit is contained in:
@@ -13,3 +13,8 @@ context:
|
||||
source: latest_chat
|
||||
window_mode: day
|
||||
window_count: 2
|
||||
schedule:
|
||||
type: daily
|
||||
run_at:
|
||||
hour: 8
|
||||
minute: 0
|
||||
|
||||
@@ -10,21 +10,6 @@ routes:
|
||||
description: Login entry for unauthenticated users.
|
||||
category: auth
|
||||
auth_required: false
|
||||
- route_id: auth.register
|
||||
path: /register
|
||||
description: Account registration page.
|
||||
category: auth
|
||||
auth_required: false
|
||||
- route_id: auth.register_verification
|
||||
path: /register/verification
|
||||
description: Verifies registration code after signup.
|
||||
category: auth
|
||||
auth_required: false
|
||||
- route_id: auth.reset_password
|
||||
path: /reset-password
|
||||
description: Resets password using verification flow.
|
||||
category: auth
|
||||
auth_required: false
|
||||
- route_id: home.main
|
||||
path: /home
|
||||
description: Main assistant home screen.
|
||||
@@ -126,22 +111,44 @@ routes:
|
||||
auth_required: true
|
||||
- route_id: settings.features
|
||||
path: /settings/features
|
||||
description: Cycle planning settings page.
|
||||
description: Automation job list page.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.job_new
|
||||
path: /settings/job/new
|
||||
description: Create page for one automation job.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.job_detail
|
||||
path: /settings/job/{id}
|
||||
description: Detail page for one automation job.
|
||||
category: settings
|
||||
auth_required: true
|
||||
path_params:
|
||||
- id
|
||||
- route_id: settings.memory
|
||||
path: /settings/memory
|
||||
description: Memory preferences and controls.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.account
|
||||
path: /settings/account
|
||||
description: Account profile and security entry points.
|
||||
- route_id: settings.memory_user
|
||||
path: /settings/memory/user
|
||||
description: User memory summary view.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.change_password
|
||||
path: /change-password
|
||||
description: Password change page.
|
||||
- route_id: settings.memory_work
|
||||
path: /settings/memory/work
|
||||
description: Work memory summary view.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.memory_user_edit
|
||||
path: /settings/memory/user/edit
|
||||
description: Edit user memory details.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.memory_work_edit
|
||||
path: /settings/memory/work/edit
|
||||
description: Edit work memory details.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.edit_profile
|
||||
|
||||
@@ -28,6 +28,20 @@ class MessageContextConfig(BaseModel):
|
||||
window_count: int = Field(default=2, ge=1, le=200)
|
||||
|
||||
|
||||
class ScheduleRunAt(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
hour: int = Field(default=8, ge=0, le=23)
|
||||
minute: int = Field(default=0, ge=0, le=59)
|
||||
|
||||
|
||||
class ScheduleConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: ScheduleType
|
||||
run_at: ScheduleRunAt
|
||||
|
||||
|
||||
class RuntimeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@@ -35,10 +49,13 @@ class RuntimeConfig(BaseModel):
|
||||
context: MessageContextConfig = Field(default_factory=MessageContextConfig)
|
||||
|
||||
|
||||
class AutomationJobConfig(RuntimeConfig):
|
||||
class AutomationJobConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
input_template: str = Field(..., min_length=1, max_length=4000)
|
||||
enabled_tools: list[AgentTool] | None = Field(default=None, max_length=32)
|
||||
context: MessageContextConfig | None = None
|
||||
input_template: str | None = Field(default=None, min_length=1, max_length=4000)
|
||||
schedule: ScheduleConfig | None = None
|
||||
|
||||
|
||||
class AutomationJob(BaseModel):
|
||||
@@ -59,10 +76,6 @@ class AutomationJob(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@property
|
||||
def is_system(self) -> bool:
|
||||
return self.bootstrap_key is not None
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj: OrmAutomationJob) -> "AutomationJob":
|
||||
return cls(
|
||||
@@ -81,3 +94,7 @@ class AutomationJob(BaseModel):
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_system(self) -> bool:
|
||||
return self.bootstrap_key is not None
|
||||
|
||||
@@ -8,7 +8,11 @@ from typing import Any
|
||||
import yaml
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from schemas.automation import AutomationJobConfig, MessageContextConfig
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation import (
|
||||
AutomationJobConfig,
|
||||
MessageContextConfig,
|
||||
)
|
||||
|
||||
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
@@ -43,4 +47,8 @@ def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfi
|
||||
raise ValueError(
|
||||
"memory_extraction context must be latest_chat/day with window_count=2"
|
||||
)
|
||||
if config.schedule is None:
|
||||
raise ValueError("memory_extraction schedule must be configured")
|
||||
if config.schedule.type != ScheduleType.DAILY:
|
||||
raise ValueError("memory_extraction schedule type must be daily")
|
||||
return config
|
||||
|
||||
@@ -22,9 +22,6 @@ from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
|
||||
logger = get_logger("v1.auth.registration_bootstrap")
|
||||
|
||||
_LOCAL_RUN_HOUR = 8
|
||||
_LOCAL_RUN_MINUTE = 0
|
||||
|
||||
|
||||
class RegistrationBootstrapRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
@@ -49,6 +46,7 @@ class RegistrationBootstrapRepository:
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> bool:
|
||||
stmt = (
|
||||
insert(AutomationJob)
|
||||
@@ -58,7 +56,7 @@ class RegistrationBootstrapRepository:
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=title,
|
||||
config=config.model_dump(mode="json"),
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
schedule_type=schedule_type,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
timezone=timezone_name,
|
||||
@@ -107,6 +105,7 @@ class RegistrationBootstrapRepositoryLike(Protocol):
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> bool: ...
|
||||
|
||||
async def upsert_initial_memory(
|
||||
@@ -130,6 +129,7 @@ def compute_next_local_time_utc(
|
||||
timezone_name: str,
|
||||
local_hour: int,
|
||||
local_minute: int,
|
||||
schedule_type: ScheduleType,
|
||||
) -> tuple[datetime, datetime]:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
@@ -147,7 +147,10 @@ def compute_next_local_time_utc(
|
||||
if local_now <= today_run_local
|
||||
else today_run_local + timedelta(days=1)
|
||||
)
|
||||
next_local = run_local + timedelta(days=1)
|
||||
if schedule_type == ScheduleType.WEEKLY:
|
||||
next_local = run_local + timedelta(weeks=1)
|
||||
else:
|
||||
next_local = run_local + timedelta(days=1)
|
||||
return run_local.astimezone(UTC), next_local.astimezone(UTC)
|
||||
|
||||
|
||||
@@ -170,9 +173,7 @@ class RegistrationAutomationBootstrapService:
|
||||
{
|
||||
"bootstrap_key": "memory_extraction",
|
||||
"config_name": "memory_extraction",
|
||||
"title": "Memory Agent",
|
||||
"local_hour": _LOCAL_RUN_HOUR,
|
||||
"local_minute": _LOCAL_RUN_MINUTE,
|
||||
"title": "记忆推送",
|
||||
}
|
||||
]
|
||||
|
||||
@@ -197,11 +198,17 @@ class RegistrationAutomationBootstrapService:
|
||||
job_config = load_static_automation_job_config(
|
||||
config_name=str(definition["config_name"])
|
||||
)
|
||||
schedule = job_config.schedule
|
||||
if schedule is None:
|
||||
raise ValueError(
|
||||
f"bootstrap job {bootstrap_key} has no schedule configured"
|
||||
)
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=datetime.now(UTC),
|
||||
timezone_name=timezone_name,
|
||||
local_hour=int(definition["local_hour"]),
|
||||
local_minute=int(definition["local_minute"]),
|
||||
local_hour=schedule.run_at.hour,
|
||||
local_minute=schedule.run_at.minute,
|
||||
schedule_type=schedule.type,
|
||||
)
|
||||
inserted = (
|
||||
await self._repository.insert_bootstrap_automation_job_if_absent(
|
||||
@@ -212,6 +219,7 @@ class RegistrationAutomationBootstrapService:
|
||||
timezone_name=timezone_name,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
schedule_type=schedule.type,
|
||||
)
|
||||
)
|
||||
inserted_any = inserted_any or inserted
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
from v1.automation_jobs.service import AutomationJobsService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
async def get_automation_jobs_repository(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AutomationJobsRepository:
|
||||
return AutomationJobsRepository(session=session)
|
||||
|
||||
|
||||
async def get_automation_jobs_service(
|
||||
repository: Annotated[
|
||||
AutomationJobsRepository, Depends(get_automation_jobs_repository)
|
||||
],
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AutomationJobsService:
|
||||
return AutomationJobsService(repository=repository, session=session)
|
||||
|
||||
|
||||
async def get_current_user_id(
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> UUID:
|
||||
return current_user.id
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, time, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -10,7 +10,7 @@ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob
|
||||
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from v1.automation_jobs.schemas import (
|
||||
@@ -19,144 +19,10 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
def _compute_next_local_time_utc(
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
local_hour: int,
|
||||
local_minute: int,
|
||||
) -> tuple[datetime, datetime]:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
except ZoneInfoNotFoundError:
|
||||
timezone_obj = ZoneInfo("UTC")
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
today_run_local = local_now.replace(
|
||||
hour=local_hour,
|
||||
minute=local_minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
run_local = (
|
||||
today_run_local
|
||||
if local_now <= today_run_local
|
||||
else today_run_local + timedelta(days=1)
|
||||
)
|
||||
next_local = run_local + timedelta(days=1)
|
||||
return run_local.astimezone(timezone.utc), next_local.astimezone(timezone.utc)
|
||||
|
||||
|
||||
class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=AutomationJob)
|
||||
|
||||
async def list_by_owner(self, owner_id: UUID) -> list[AutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.order_by(AutomationJob.created_at.desc())
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
async def get_by_id(self, job_id: UUID) -> AutomationJob | None: # type: ignore[override]
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def count_user_jobs(self, owner_id: UUID) -> int:
|
||||
stmt = (
|
||||
select(func.count(AutomationJob.id))
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.bootstrap_key.is_(None))
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = (await self._session.execute(stmt)).scalar_one()
|
||||
return int(result)
|
||||
|
||||
async def create(
|
||||
self, owner_id: UUID, data: "AutomationJobCreateRequest"
|
||||
) -> AutomationJob:
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
run_at_dt, next_run_at = _compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name=data.timezone,
|
||||
local_hour=data.run_at.hour,
|
||||
local_minute=data.run_at.minute,
|
||||
)
|
||||
new_job = AutomationJob(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
created_by=owner_id,
|
||||
bootstrap_key=None,
|
||||
title=data.title,
|
||||
config=data.config.model_dump(mode="json"),
|
||||
schedule_type=data.schedule_type,
|
||||
run_at=run_at_dt,
|
||||
next_run_at=next_run_at,
|
||||
timezone=data.timezone,
|
||||
status=data.status,
|
||||
)
|
||||
self._session.add(new_job)
|
||||
await self._session.flush()
|
||||
return new_job
|
||||
|
||||
async def update(
|
||||
self, job_id: UUID, data: "AutomationJobUpdateRequest"
|
||||
) -> AutomationJob | None:
|
||||
update_values: dict[str, object] = {}
|
||||
if data.title is not None:
|
||||
update_values["title"] = data.title
|
||||
if data.schedule_type is not None:
|
||||
update_values["schedule_type"] = data.schedule_type
|
||||
if data.run_at is not None:
|
||||
stmt = select(AutomationJob).where(AutomationJob.id == job_id)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is None:
|
||||
return None
|
||||
run_at_dt, next_run_at = _compute_next_local_time_utc(
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
timezone_name=data.timezone or existing.timezone,
|
||||
local_hour=data.run_at.hour,
|
||||
local_minute=data.run_at.minute,
|
||||
)
|
||||
update_values["run_at"] = run_at_dt
|
||||
update_values["next_run_at"] = next_run_at
|
||||
update_values["timezone"] = data.timezone or existing.timezone
|
||||
if data.status is not None:
|
||||
update_values["status"] = data.status
|
||||
if data.config is not None:
|
||||
update_values["config"] = data.config.model_dump(mode="json")
|
||||
|
||||
if not update_values:
|
||||
return await self.get_by_id(job_id)
|
||||
|
||||
stmt = (
|
||||
update(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.values(**update_values)
|
||||
.returning(AutomationJob)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def soft_delete(self, job_id: UUID) -> None:
|
||||
stmt = (
|
||||
update(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
@@ -166,7 +32,7 @@ class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.status == "active")
|
||||
.where(AutomationJob.status == AutomationJobStatus.ACTIVE)
|
||||
.where(AutomationJob.next_run_at <= now_utc)
|
||||
.order_by(AutomationJob.next_run_at.asc())
|
||||
.limit(max(limit, 1))
|
||||
@@ -213,3 +79,160 @@ class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
self._session.add(new_session)
|
||||
await self._session.flush()
|
||||
return new_session.id
|
||||
|
||||
async def list_by_owner(self, owner_id: UUID) -> list[AutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.order_by(AutomationJob.created_at.desc())
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
async def count_user_jobs(self, owner_id: UUID) -> int:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(AutomationJob)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.bootstrap_key.is_(None))
|
||||
)
|
||||
result = (await self._session.execute(stmt)).scalar_one()
|
||||
return int(result)
|
||||
|
||||
def _resolve_timezone(self, timezone_str: str) -> ZoneInfo:
|
||||
try:
|
||||
return ZoneInfo(timezone_str)
|
||||
except ZoneInfoNotFoundError:
|
||||
return ZoneInfo("UTC")
|
||||
|
||||
def _compute_initial_next_run_at(
|
||||
self,
|
||||
*,
|
||||
run_at: time,
|
||||
timezone_str: str,
|
||||
now_utc: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> datetime:
|
||||
tz = self._resolve_timezone(timezone_str)
|
||||
local_now = now_utc.astimezone(tz)
|
||||
run_at_local = datetime.combine(local_now.date(), run_at, tz)
|
||||
if run_at_local.tzinfo is None:
|
||||
run_at_local = run_at_local.replace(tzinfo=tz)
|
||||
next_run_at = run_at_local
|
||||
if next_run_at <= local_now:
|
||||
if schedule_type == ScheduleType.DAILY:
|
||||
next_run_at = next_run_at + timedelta(days=1)
|
||||
else:
|
||||
next_run_at = next_run_at + timedelta(weeks=1)
|
||||
return next_run_at.astimezone(timezone.utc)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
owner_id: UUID,
|
||||
data: AutomationJobCreateRequest,
|
||||
) -> AutomationJob:
|
||||
now_utc = datetime.now(tz=timezone.utc)
|
||||
timezone_obj = self._resolve_timezone(data.timezone)
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
date_ref = local_now.date()
|
||||
local_dt = datetime.combine(date_ref, data.run_at, timezone_obj)
|
||||
run_at_datetime = local_dt.astimezone(timezone.utc)
|
||||
next_run_at = self._compute_initial_next_run_at(
|
||||
run_at=data.run_at,
|
||||
timezone_str=data.timezone,
|
||||
now_utc=now_utc,
|
||||
schedule_type=data.schedule_type,
|
||||
)
|
||||
|
||||
new_job = AutomationJob(
|
||||
owner_id=owner_id,
|
||||
created_by=owner_id,
|
||||
bootstrap_key=None,
|
||||
title=data.title,
|
||||
schedule_type=data.schedule_type,
|
||||
run_at=run_at_datetime,
|
||||
timezone=data.timezone,
|
||||
status=data.status,
|
||||
config=data.config.model_dump(mode="json"),
|
||||
next_run_at=next_run_at,
|
||||
)
|
||||
self._session.add(new_job)
|
||||
await self._session.flush()
|
||||
return new_job
|
||||
|
||||
async def update(
|
||||
self,
|
||||
job_id: UUID,
|
||||
data: AutomationJobUpdateRequest,
|
||||
) -> AutomationJob | None:
|
||||
update_values: dict[str, object] = {}
|
||||
existing_job: AutomationJob | None = None
|
||||
|
||||
if data.title is not None:
|
||||
update_values["title"] = data.title
|
||||
if data.schedule_type is not None:
|
||||
update_values["schedule_type"] = data.schedule_type
|
||||
|
||||
should_recompute_schedule = (
|
||||
data.run_at is not None
|
||||
or data.schedule_type is not None
|
||||
or data.timezone is not None
|
||||
)
|
||||
if should_recompute_schedule:
|
||||
now_utc = datetime.now(tz=timezone.utc)
|
||||
if existing_job is None:
|
||||
existing_job = await self.get_by_id(job_id)
|
||||
if existing_job is None:
|
||||
return None
|
||||
|
||||
effective_timezone = data.timezone or existing_job.timezone
|
||||
effective_timezone_obj = self._resolve_timezone(effective_timezone)
|
||||
effective_schedule_type = data.schedule_type or existing_job.schedule_type
|
||||
|
||||
if data.run_at is not None:
|
||||
effective_run_at = data.run_at
|
||||
else:
|
||||
existing_timezone_obj = self._resolve_timezone(existing_job.timezone)
|
||||
effective_run_at = (
|
||||
existing_job.run_at.astimezone(existing_timezone_obj)
|
||||
.time()
|
||||
.replace(microsecond=0)
|
||||
)
|
||||
|
||||
local_now = now_utc.astimezone(effective_timezone_obj)
|
||||
local_dt = datetime.combine(
|
||||
local_now.date(),
|
||||
effective_run_at,
|
||||
effective_timezone_obj,
|
||||
)
|
||||
update_values["run_at"] = local_dt.astimezone(timezone.utc)
|
||||
update_values["next_run_at"] = self._compute_initial_next_run_at(
|
||||
run_at=effective_run_at,
|
||||
timezone_str=effective_timezone,
|
||||
now_utc=now_utc,
|
||||
schedule_type=effective_schedule_type,
|
||||
)
|
||||
if data.timezone is not None:
|
||||
update_values["timezone"] = data.timezone
|
||||
if data.status is not None:
|
||||
update_values["status"] = data.status
|
||||
if data.config is not None:
|
||||
if existing_job is None:
|
||||
existing_job = await self.get_by_id(job_id)
|
||||
if existing_job is None:
|
||||
return None
|
||||
merged_config = {
|
||||
**existing_job.config,
|
||||
**data.config.model_dump(mode="json", exclude_unset=True),
|
||||
}
|
||||
update_values["config"] = merged_config
|
||||
|
||||
if not update_values:
|
||||
return await self.get_by_id(job_id)
|
||||
|
||||
return await self.update_by_id(job_id, update_values)
|
||||
|
||||
async def soft_delete(self, job_id: UUID) -> None:
|
||||
await self.soft_delete_by_id(job_id)
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from v1.automation_jobs.dependencies import (
|
||||
get_automation_jobs_service,
|
||||
get_current_user_id,
|
||||
)
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobListResponse,
|
||||
AutomationJobResponse,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
from v1.automation_jobs.service import AutomationJobsService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/automation-jobs", tags=["automation-jobs"])
|
||||
|
||||
|
||||
@router.get("", response_model=AutomationJobListResponse)
|
||||
async def list_automation_jobs(
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> AutomationJobListResponse:
|
||||
return await service.list_by_owner(owner_id=current_user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"", response_model=AutomationJobResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def create_automation_job(
|
||||
request: AutomationJobCreateRequest,
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> AutomationJobResponse:
|
||||
return await service.create(owner_id=current_user_id, data=request)
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=AutomationJobResponse)
|
||||
async def get_automation_job(
|
||||
job_id: UUID,
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> AutomationJobResponse:
|
||||
return await service.get_by_id(job_id=job_id, owner_id=current_user_id)
|
||||
|
||||
|
||||
@router.patch("/{job_id}", response_model=AutomationJobResponse)
|
||||
async def update_automation_job(
|
||||
job_id: UUID,
|
||||
request: AutomationJobUpdateRequest,
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> AutomationJobResponse:
|
||||
return await service.update(job_id=job_id, owner_id=current_user_id, data=request)
|
||||
|
||||
|
||||
@router.delete("/{job_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_automation_job(
|
||||
job_id: UUID,
|
||||
service: Annotated[AutomationJobsService, Depends(get_automation_jobs_service)],
|
||||
current_user_id: Annotated[UUID, Depends(get_current_user_id)],
|
||||
) -> None:
|
||||
await service.delete(job_id=job_id, owner_id=current_user_id)
|
||||
@@ -3,14 +3,13 @@ from __future__ import annotations
|
||||
from datetime import datetime, time
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from models.automation_jobs import AutomationJob as OrmAutomationJob
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
from schemas.automation import (
|
||||
AutomationJobConfig,
|
||||
)
|
||||
from schemas.automation import AutomationJobConfig
|
||||
|
||||
|
||||
class AutomationJobResponse(BaseModel):
|
||||
@@ -61,6 +60,15 @@ class AutomationJobCreateRequest(BaseModel):
|
||||
status: AutomationJobStatus = Field(default=AutomationJobStatus.ACTIVE)
|
||||
config: AutomationJobConfig
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str) -> str:
|
||||
try:
|
||||
ZoneInfo(value)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
raise ValueError("timezone must be a valid IANA timezone") from exc
|
||||
return value
|
||||
|
||||
|
||||
class AutomationJobUpdateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
@@ -72,6 +80,17 @@ class AutomationJobUpdateRequest(BaseModel):
|
||||
status: AutomationJobStatus | None = None
|
||||
config: AutomationJobConfig | None = None
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
try:
|
||||
ZoneInfo(value)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
raise ValueError("timezone must be a valid IANA timezone") from exc
|
||||
return value
|
||||
|
||||
|
||||
class AutomationJobListResponse(BaseModel):
|
||||
items: list[AutomationJobResponse]
|
||||
|
||||
@@ -5,14 +5,54 @@ from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation import AutomationJob as AutomationJobSchema, RuntimeConfig
|
||||
from schemas.automation import (
|
||||
AutomationJob as AutomationJobSchema,
|
||||
MessageContextConfig,
|
||||
RuntimeConfig,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from core.logging import get_logger
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobListResponse,
|
||||
AutomationJobResponse,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
|
||||
logger = get_logger("v1.automation_jobs.service")
|
||||
|
||||
|
||||
class AutomationJobLimitExceeded(HTTPException):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Maximum of 3 user jobs allowed",
|
||||
)
|
||||
|
||||
|
||||
class SystemJobModificationForbidden(HTTPException):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System job cannot be modified",
|
||||
)
|
||||
|
||||
|
||||
class AutomationJobNotFound(HTTPException):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Automation job not found",
|
||||
)
|
||||
|
||||
|
||||
class DispatchFn(Protocol):
|
||||
async def __call__(
|
||||
@@ -46,6 +86,9 @@ class ScanResult:
|
||||
|
||||
|
||||
class AutomationJobsService:
|
||||
_repository: "AutomationJobsRepository"
|
||||
_session: "AsyncSession"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: "AutomationJobsRepository",
|
||||
@@ -71,14 +114,15 @@ class AutomationJobsService:
|
||||
thread_id = await self.get_or_create_chat_session(owner_id=job.owner_id)
|
||||
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
|
||||
|
||||
input_text = (job.config.input_template or "").strip()
|
||||
await dispatch_fn(
|
||||
owner_id=job.owner_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
input_text=job.config.input_template.strip(),
|
||||
input_text=input_text,
|
||||
runtime_config=RuntimeConfig(
|
||||
enabled_tools=job.config.enabled_tools,
|
||||
context=job.config.context,
|
||||
enabled_tools=job.config.enabled_tools or [],
|
||||
context=job.config.context or MessageContextConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -102,3 +146,82 @@ class AutomationJobsService:
|
||||
|
||||
async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
return await self._repository.get_or_create_chat_session(owner_id=owner_id)
|
||||
|
||||
async def list_by_owner(self, owner_id: UUID) -> AutomationJobListResponse:
|
||||
jobs = await self._repository.list_by_owner(owner_id)
|
||||
return AutomationJobListResponse(
|
||||
items=[AutomationJobResponse.from_orm(job) for job in jobs],
|
||||
)
|
||||
|
||||
async def get_by_id(self, job_id: UUID, owner_id: UUID) -> AutomationJobResponse:
|
||||
job = await self._repository.get_by_id(job_id)
|
||||
if job is None or job.owner_id != owner_id:
|
||||
raise AutomationJobNotFound()
|
||||
return AutomationJobResponse.from_orm(job)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
owner_id: UUID,
|
||||
data: AutomationJobCreateRequest,
|
||||
) -> AutomationJobResponse:
|
||||
try:
|
||||
await self._session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(abs(hashtext(:owner_id)))"),
|
||||
{"owner_id": str(owner_id)},
|
||||
)
|
||||
count = await self._repository.count_user_jobs(owner_id)
|
||||
if count >= 3:
|
||||
await self._session.rollback()
|
||||
raise AutomationJobLimitExceeded()
|
||||
job = await self._repository.create(owner_id, data)
|
||||
await self._session.commit()
|
||||
return AutomationJobResponse.from_orm(job)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Failed to create automation job", owner_id=str(owner_id))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Automation job store unavailable",
|
||||
)
|
||||
|
||||
async def update(
|
||||
self,
|
||||
job_id: UUID,
|
||||
owner_id: UUID,
|
||||
data: AutomationJobUpdateRequest,
|
||||
) -> AutomationJobResponse:
|
||||
try:
|
||||
job = await self._repository.get_by_id(job_id)
|
||||
if job is None or job.owner_id != owner_id:
|
||||
raise AutomationJobNotFound()
|
||||
if job.bootstrap_key is not None:
|
||||
raise SystemJobModificationForbidden()
|
||||
updated_job = await self._repository.update(job_id, data)
|
||||
if updated_job is None:
|
||||
raise AutomationJobNotFound()
|
||||
await self._session.commit()
|
||||
return AutomationJobResponse.from_orm(updated_job)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Failed to update automation job", job_id=str(job_id))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Automation job store unavailable",
|
||||
)
|
||||
|
||||
async def delete(self, job_id: UUID, owner_id: UUID) -> None:
|
||||
try:
|
||||
job = await self._repository.get_by_id(job_id)
|
||||
if job is None or job.owner_id != owner_id:
|
||||
raise AutomationJobNotFound()
|
||||
if job.bootstrap_key is not None:
|
||||
raise SystemJobModificationForbidden()
|
||||
await self._repository.soft_delete(job_id)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Failed to delete automation job", job_id=str(job_id))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Automation job store unavailable",
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter
|
||||
|
||||
from v1.agent.router import router as agent_router
|
||||
from v1.app.router import router as app_router
|
||||
from v1.automation_jobs.router import router as automation_jobs_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.friendships.router import router as friendships_router
|
||||
from v1.inbox_messages.router import router as inbox_messages_router
|
||||
@@ -17,6 +18,7 @@ router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(app_router)
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(automation_jobs_router)
|
||||
router.include_router(friendships_router)
|
||||
router.include_router(memories_router)
|
||||
router.include_router(users_router)
|
||||
|
||||
@@ -0,0 +1,371 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, time, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.automation_jobs.dependencies import get_automation_jobs_service
|
||||
from v1.automation_jobs.service import (
|
||||
AutomationJobLimitExceeded,
|
||||
AutomationJobNotFound,
|
||||
)
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobListResponse,
|
||||
AutomationJobResponse,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
def _make_job_response(
|
||||
job_id: UUID | None = None, owner_id: UUID | None = None, **overrides
|
||||
) -> AutomationJobResponse:
|
||||
now = datetime.now(timezone.utc)
|
||||
return AutomationJobResponse(
|
||||
id=job_id or uuid4(),
|
||||
owner_id=owner_id or uuid4(),
|
||||
title=overrides.get("title", "Test Job"),
|
||||
schedule_type=overrides.get("schedule_type", "daily"),
|
||||
run_at=overrides.get("run_at", time(9, 0, 0)),
|
||||
timezone=overrides.get("timezone", "Asia/Shanghai"),
|
||||
status=overrides.get("status", "active"),
|
||||
is_system=overrides.get("is_system", False),
|
||||
config=overrides.get(
|
||||
"config", {"input_template": "Hello", "enabled_tools": [], "context": {}}
|
||||
),
|
||||
next_run_at=overrides.get("next_run_at", now),
|
||||
created_at=overrides.get("created_at", now),
|
||||
updated_at=overrides.get("updated_at", now),
|
||||
)
|
||||
|
||||
|
||||
def test_list_automation_jobs_requires_auth() -> None:
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/automation-jobs")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_list_automation_jobs_returns_empty_when_no_jobs() -> None:
|
||||
class FakeService:
|
||||
async def list_by_owner(self, *, owner_id: UUID) -> AutomationJobListResponse:
|
||||
return AutomationJobListResponse(items=[])
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get("/api/v1/automation-jobs")
|
||||
assert response.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_list_automation_jobs_returns_jobs() -> None:
|
||||
user_id = uuid4()
|
||||
job = _make_job_response(owner_id=user_id)
|
||||
|
||||
class FakeService:
|
||||
async def list_by_owner(self, *, owner_id: UUID) -> AutomationJobListResponse:
|
||||
if owner_id == user_id:
|
||||
return AutomationJobListResponse(items=[job])
|
||||
return AutomationJobListResponse(items=[])
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get("/api/v1/automation-jobs")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["items"]) == 1
|
||||
assert data["items"][0]["title"] == "Test Job"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_create_automation_job_requires_auth() -> None:
|
||||
class FakeService:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/automation-jobs",
|
||||
json={
|
||||
"title": "New Job",
|
||||
"schedule_type": "daily",
|
||||
"run_at": "09:00:00",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"config": {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {},
|
||||
},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_create_automation_job_succeeds() -> None:
|
||||
user_id = uuid4()
|
||||
new_job = _make_job_response(owner_id=user_id, title="New Job")
|
||||
|
||||
class FakeService:
|
||||
async def create(
|
||||
self, *, owner_id: UUID, data: AutomationJobCreateRequest
|
||||
) -> AutomationJobResponse:
|
||||
return new_job
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/automation-jobs",
|
||||
json={
|
||||
"title": "New Job",
|
||||
"schedule_type": "daily",
|
||||
"run_at": "09:00:00",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"status": "active",
|
||||
"config": {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {},
|
||||
},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["title"] == "New Job"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_create_automation_job_respects_limit() -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
class FakeService:
|
||||
async def create(
|
||||
self, *, owner_id: UUID, data: AutomationJobCreateRequest
|
||||
) -> AutomationJobResponse:
|
||||
raise AutomationJobLimitExceeded()
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/automation-jobs",
|
||||
json={
|
||||
"title": "New Job",
|
||||
"schedule_type": "daily",
|
||||
"run_at": "09:00:00",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"status": "active",
|
||||
"config": {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {},
|
||||
},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "maximum" in response.json()["detail"].lower()
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_get_automation_job_requires_auth() -> None:
|
||||
client = TestClient(app)
|
||||
response = client.get(f"/api/v1/automation-jobs/{uuid4()}")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_get_automation_job_returns_job() -> None:
|
||||
user_id = uuid4()
|
||||
job_id = uuid4()
|
||||
job = _make_job_response(id=job_id, owner_id=user_id)
|
||||
|
||||
captured_job_id = job_id
|
||||
captured_owner_id = user_id
|
||||
|
||||
class FakeService:
|
||||
async def get_by_id(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJobResponse:
|
||||
if job_id == captured_job_id and owner_id == captured_owner_id:
|
||||
return job
|
||||
raise AutomationJobNotFound()
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get(f"/api/v1/automation-jobs/{job_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["title"] == "Test Job"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_get_automation_job_returns_404_when_not_found() -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
class FakeService:
|
||||
async def get_by_id(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJobResponse:
|
||||
raise AutomationJobNotFound()
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get(f"/api/v1/automation-jobs/{uuid4()}")
|
||||
assert response.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_update_automation_job_requires_auth() -> None:
|
||||
client = TestClient(app)
|
||||
response = client.patch(
|
||||
f"/api/v1/automation-jobs/{uuid4()}",
|
||||
json={"title": "Updated"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_update_automation_job_succeeds() -> None:
|
||||
user_id = uuid4()
|
||||
job_id = uuid4()
|
||||
updated_job = _make_job_response(id=job_id, owner_id=user_id, title="Updated Title")
|
||||
|
||||
class FakeService:
|
||||
async def update(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
owner_id: UUID,
|
||||
data: AutomationJobUpdateRequest,
|
||||
) -> AutomationJobResponse:
|
||||
return updated_job
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.patch(
|
||||
f"/api/v1/automation-jobs/{job_id}",
|
||||
json={"title": "Updated Title"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["title"] == "Updated Title"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_update_automation_job_returns_404_when_not_found() -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
class FakeService:
|
||||
async def update(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
owner_id: UUID,
|
||||
data: AutomationJobUpdateRequest,
|
||||
) -> AutomationJobResponse:
|
||||
raise AutomationJobNotFound()
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.patch(
|
||||
f"/api/v1/automation-jobs/{uuid4()}", json={"title": "Updated"}
|
||||
)
|
||||
assert response.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_delete_automation_job_requires_auth() -> None:
|
||||
client = TestClient(app)
|
||||
response = client.delete(f"/api/v1/automation-jobs/{uuid4()}")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_delete_automation_job_succeeds() -> None:
|
||||
user_id = uuid4()
|
||||
job_id = uuid4()
|
||||
|
||||
class FakeService:
|
||||
async def delete(self, *, job_id: UUID, owner_id: UUID) -> None:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.delete(f"/api/v1/automation-jobs/{job_id}")
|
||||
assert response.status_code == 204
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_delete_automation_job_returns_404_when_not_found() -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
class FakeService:
|
||||
async def delete(self, *, job_id: UUID, owner_id: UUID) -> None:
|
||||
raise AutomationJobNotFound()
|
||||
|
||||
app.dependency_overrides[get_automation_jobs_service] = lambda: FakeService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=user_id, phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.delete(f"/api/v1/automation-jobs/{uuid4()}")
|
||||
assert response.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -12,6 +12,10 @@ def test_memory_automation_static_config_contract() -> None:
|
||||
"memory.write",
|
||||
"memory.forget",
|
||||
]
|
||||
prompt = config.input_template
|
||||
assert "提取" in prompt
|
||||
assert "遗忘" in prompt
|
||||
assert config.input_template is not None
|
||||
assert "提取" in config.input_template
|
||||
assert "遗忘" in config.input_template
|
||||
assert config.schedule is not None
|
||||
assert config.schedule.type.value == "daily"
|
||||
assert config.schedule.run_at.hour == 8
|
||||
assert config.schedule.run_at.minute == 0
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.automation_jobs import ScheduleType
|
||||
from v1.auth.registration_bootstrap import (
|
||||
compute_next_local_time_utc,
|
||||
)
|
||||
@@ -19,6 +20,7 @@ def test_compute_next_local_time_utc_from_asia_shanghai() -> None:
|
||||
timezone_name="Asia/Shanghai",
|
||||
local_hour=8,
|
||||
local_minute=0,
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
)
|
||||
|
||||
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
|
||||
@@ -33,6 +35,7 @@ def test_compute_next_local_time_utc_rolls_to_next_day_when_passed() -> None:
|
||||
timezone_name="Asia/Shanghai",
|
||||
local_hour=8,
|
||||
local_minute=0,
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
)
|
||||
|
||||
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
@@ -1,283 +1,287 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, time, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
from schemas.automation import (
|
||||
AgentTool,
|
||||
AutomationJobConfig,
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MessageContextConfig,
|
||||
)
|
||||
|
||||
|
||||
class _ExecuteResult:
|
||||
def __init__(self, value: object) -> None:
|
||||
self._value = value
|
||||
|
||||
def scalar_one_or_none(self) -> object:
|
||||
return self._value
|
||||
|
||||
def scalar_one(self) -> int:
|
||||
return self._value # type: ignore[return-value]
|
||||
def _make_config() -> AutomationJobConfig:
|
||||
return AutomationJobConfig(
|
||||
input_template="Hello",
|
||||
enabled_tools=[AgentTool.MEMORY_WRITE],
|
||||
context=MessageContextConfig(
|
||||
source=ContextSource.LATEST_CHAT,
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class _ScalarRows:
|
||||
def __init__(self, rows: list[object]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
def all(self) -> list[object]:
|
||||
return self._rows
|
||||
|
||||
|
||||
class _ExecuteRowsResult:
|
||||
def __init__(self, rows: list[object]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
def scalars(self) -> _ScalarRows:
|
||||
return _ScalarRows(self._rows)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self.flushed = False
|
||||
self._execute_result: object = None
|
||||
self._return_rows: bool = False
|
||||
|
||||
def set_execute_result(self, value: object) -> None:
|
||||
self._execute_result = value
|
||||
self._return_rows = isinstance(value, list)
|
||||
|
||||
async def execute(self, stmt): # noqa: ANN001
|
||||
del stmt
|
||||
if self._return_rows:
|
||||
return _ExecuteRowsResult(self._execute_result)
|
||||
return _ExecuteResult(self._execute_result)
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def flush(self) -> None:
|
||||
self.flushed = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_session() -> _FakeSession:
|
||||
return _FakeSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(fake_session: _FakeSession) -> AutomationJobsRepository:
|
||||
return AutomationJobsRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_job() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=uuid4(),
|
||||
owner_id=uuid4(),
|
||||
bootstrap_key=None,
|
||||
def _make_create_request() -> AutomationJobCreateRequest:
|
||||
return AutomationJobCreateRequest(
|
||||
title="Test Job",
|
||||
config={"input_template": "Hello {name}"},
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=datetime(2026, 3, 23, 0, 0, tzinfo=timezone.utc),
|
||||
next_run_at=datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc),
|
||||
timezone="UTC",
|
||||
run_at=time(9, 0, 0),
|
||||
timezone="Asia/Shanghai",
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
created_by=uuid4(),
|
||||
deleted_at=None,
|
||||
config=_make_config(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_by_owner_returns_jobs(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
fake_session.set_execute_result([sample_job])
|
||||
|
||||
async def test_list_by_owner_returns_jobs() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
jobs = await repository.list_by_owner(owner_id)
|
||||
job_one = MagicMock()
|
||||
job_two = MagicMock()
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalars.return_value.all.return_value = [job_one, job_two]
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].title == "Test Job"
|
||||
result = await repository.list_by_owner(owner_id)
|
||||
|
||||
assert result == [job_one, job_two]
|
||||
session.execute.assert_awaited_once()
|
||||
call_args = session.execute.call_args
|
||||
stmt = call_args[0][0]
|
||||
assert "owner_id" in str(stmt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_by_owner_returns_empty_list(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
fake_session.set_execute_result([])
|
||||
|
||||
async def test_count_user_jobs_counts_non_bootstrap_jobs() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
jobs = await repository.list_by_owner(owner_id)
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one.return_value = 3
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
assert jobs == []
|
||||
result = await repository.count_user_jobs(owner_id)
|
||||
|
||||
assert result == 3
|
||||
session.execute.assert_awaited_once()
|
||||
call_args = session.execute.call_args
|
||||
stmt = call_args[0][0]
|
||||
stmt_str = str(stmt)
|
||||
assert "bootstrap_key" in stmt_str
|
||||
assert "IS NULL" in stmt_str or "is_(None)" in stmt_str.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_returns_job(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
fake_session.set_execute_result(sample_job)
|
||||
async def test_create_sets_bootstrap_key_to_none() -> None:
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
|
||||
await repository.create(owner_id, data)
|
||||
|
||||
session.add.assert_called_once()
|
||||
call_args = session.add.call_args[0][0]
|
||||
assert call_args.bootstrap_key is None
|
||||
session.flush.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_sets_correct_fields() -> None:
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
|
||||
await repository.create(owner_id, data)
|
||||
|
||||
call_args = session.add.call_args[0][0]
|
||||
assert call_args.owner_id == owner_id
|
||||
assert call_args.title == data.title
|
||||
assert call_args.schedule_type == data.schedule_type
|
||||
assert call_args.timezone == data.timezone
|
||||
assert call_args.status == data.status
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_returns_updated_job() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
job = await repository.get_by_id(job_id)
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.config = {"input_template": "Old"}
|
||||
updated_job = MagicMock()
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = updated_job
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
assert job is not None
|
||||
assert job.title == "Test Job"
|
||||
data = AutomationJobUpdateRequest(title="Updated Title")
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert result is updated_job
|
||||
session.flush.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_returns_none_when_not_found(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
fake_session.set_execute_result(None)
|
||||
|
||||
async def test_update_merges_config() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
job = await repository.get_by_id(job_id)
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.config = {"input_template": "Old", "enabled_tools": []}
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = existing_job
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
assert job is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_user_jobs_returns_count(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
fake_session.set_execute_result(5)
|
||||
|
||||
owner_id = uuid4()
|
||||
count = await repository.count_user_jobs(owner_id)
|
||||
|
||||
assert count == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_user_jobs_returns_zero_when_none(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
fake_session.set_execute_result(0)
|
||||
|
||||
owner_id = uuid4()
|
||||
count = await repository.count_user_jobs(owner_id)
|
||||
|
||||
assert count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobCreateRequest
|
||||
from schemas.automation import AutomationJobConfig
|
||||
|
||||
owner_id = uuid4()
|
||||
request = AutomationJobCreateRequest(
|
||||
title="New Job",
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=time(0, 0),
|
||||
timezone="UTC",
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
config=AutomationJobConfig(input_template="Test"),
|
||||
data = AutomationJobUpdateRequest(
|
||||
config={"input_template": "New", "context": {"source": "latest_chat"}}
|
||||
)
|
||||
await repository.update(job_id, data)
|
||||
|
||||
job = await repository.create(owner_id, request)
|
||||
|
||||
assert job.title == "New Job"
|
||||
assert job.owner_id == owner_id
|
||||
assert job.created_by == owner_id
|
||||
assert job.bootstrap_key is None
|
||||
assert job.schedule_type == ScheduleType.DAILY
|
||||
assert fake_session.flushed is True
|
||||
assert len(fake_session.added) == 1
|
||||
session.flush.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
async def test_update_returns_none_when_job_not_found() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = None
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
data = AutomationJobUpdateRequest(title="Updated Title")
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_calls_soft_delete_by_id() -> None:
|
||||
session = AsyncMock()
|
||||
session.flush = AsyncMock()
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = None
|
||||
session.execute.return_value = execute_result
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
fake_session.set_execute_result(None)
|
||||
|
||||
await repository.soft_delete(job_id)
|
||||
|
||||
assert fake_session.flushed is True
|
||||
session.flush.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_title(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobUpdateRequest
|
||||
async def test_list_due_jobs_filters_by_active_status() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
sample_job.title = "Updated Title"
|
||||
fake_session.set_execute_result(sample_job)
|
||||
await repository.list_due_jobs(now_utc=MagicMock(), limit=10)
|
||||
|
||||
request = AutomationJobUpdateRequest(title="Updated Title")
|
||||
job = await repository.update(sample_job.id, request)
|
||||
|
||||
assert job is not None
|
||||
assert job.title == "Updated Title"
|
||||
session.execute.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_run_at_recomputes_next_run_at(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobUpdateRequest
|
||||
async def test_create_stores_run_at_as_timezone_aware() -> None:
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
|
||||
fake_session.set_execute_result(sample_job)
|
||||
await repository.create(owner_id, data)
|
||||
|
||||
request = AutomationJobUpdateRequest(
|
||||
run_at=time(12, 0),
|
||||
timezone="UTC",
|
||||
call_args = session.add.call_args[0][0]
|
||||
assert call_args.run_at.tzinfo is not None, "run_at should be timezone-aware"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_run_at_with_timezone_none_uses_existing_timezone() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.timezone = "America/New_York"
|
||||
existing_job.config = {}
|
||||
existing_job.run_at = None
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = existing_job
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
repository.update_by_id = AsyncMock(return_value=existing_job)
|
||||
|
||||
data = AutomationJobUpdateRequest(run_at=time(14, 30, 0))
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert result is not None
|
||||
update_values = repository.update_by_id.call_args[0][1]
|
||||
assert "run_at" in update_values
|
||||
assert "next_run_at" in update_values
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_schedule_type_recomputes_next_run_at() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.timezone = "UTC"
|
||||
existing_job.run_at = datetime(2026, 1, 1, 8, 0, 0, tzinfo=timezone.utc)
|
||||
existing_job.config = {}
|
||||
|
||||
repository.get_by_id = AsyncMock(return_value=existing_job)
|
||||
repository.update_by_id = AsyncMock(return_value=existing_job)
|
||||
|
||||
data = AutomationJobUpdateRequest(schedule_type=ScheduleType.WEEKLY)
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert result is not None
|
||||
update_values = repository.update_by_id.call_args[0][1]
|
||||
assert update_values["schedule_type"] == ScheduleType.WEEKLY
|
||||
assert "run_at" in update_values
|
||||
assert "next_run_at" in update_values
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_config_serializes_enum_values_to_json() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.timezone = "UTC"
|
||||
existing_job.run_at = datetime(2026, 1, 1, 8, 0, 0, tzinfo=timezone.utc)
|
||||
existing_job.config = {"input_template": "Old"}
|
||||
|
||||
repository.get_by_id = AsyncMock(return_value=existing_job)
|
||||
repository.update_by_id = AsyncMock(return_value=existing_job)
|
||||
|
||||
data = AutomationJobUpdateRequest(
|
||||
config={"enabled_tools": [AgentTool.MEMORY_WRITE]},
|
||||
)
|
||||
job = await repository.update(sample_job.id, request)
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert job is not None
|
||||
assert fake_session.flushed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_returns_none_when_job_not_found(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobUpdateRequest
|
||||
|
||||
fake_session.set_execute_result(None)
|
||||
|
||||
request = AutomationJobUpdateRequest(title="New Title")
|
||||
job = await repository.update(uuid4(), request)
|
||||
|
||||
assert job is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_with_no_changes_returns_existing_job(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobUpdateRequest
|
||||
|
||||
fake_session.set_execute_result(sample_job)
|
||||
|
||||
request = AutomationJobUpdateRequest()
|
||||
job = await repository.update(sample_job.id, request)
|
||||
|
||||
assert job is not None
|
||||
assert job.title == "Test Job"
|
||||
assert result is not None
|
||||
update_values = repository.update_by_id.call_args[0][1]
|
||||
enabled_tools = update_values["config"]["enabled_tools"]
|
||||
assert isinstance(enabled_tools[0], str)
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobUpdateRequest,
|
||||
AutomationJobResponse,
|
||||
)
|
||||
from schemas.automation import AgentTool, AutomationJobConfig
|
||||
|
||||
|
||||
class TestIsSystemProperty:
|
||||
def test_is_system_true_when_bootstrap_key_present(self):
|
||||
mock_orm_job = MagicMock()
|
||||
mock_orm_job.id = uuid4()
|
||||
mock_orm_job.owner_id = uuid4()
|
||||
mock_orm_job.bootstrap_key = "memory_extraction"
|
||||
mock_orm_job.title = "Test Job"
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.run_at = datetime.now()
|
||||
mock_orm_job.config = {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {},
|
||||
}
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.status = "active"
|
||||
mock_orm_job.timezone = "Asia/Shanghai"
|
||||
mock_orm_job.next_run_at = datetime.now()
|
||||
mock_orm_job.last_run_at = None
|
||||
mock_orm_job.created_at = datetime.now()
|
||||
mock_orm_job.updated_at = datetime.now()
|
||||
mock_orm_job.deleted_at = None
|
||||
|
||||
resp = AutomationJobResponse.from_orm(mock_orm_job)
|
||||
assert resp.is_system is True
|
||||
|
||||
def test_is_system_false_when_bootstrap_key_none(self):
|
||||
mock_orm_job = MagicMock()
|
||||
mock_orm_job.id = uuid4()
|
||||
mock_orm_job.owner_id = uuid4()
|
||||
mock_orm_job.bootstrap_key = None
|
||||
mock_orm_job.title = "Test Job"
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.run_at = datetime.now()
|
||||
mock_orm_job.config = {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {},
|
||||
}
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.status = "active"
|
||||
mock_orm_job.timezone = "Asia/Shanghai"
|
||||
mock_orm_job.next_run_at = datetime.now()
|
||||
mock_orm_job.last_run_at = None
|
||||
mock_orm_job.created_at = datetime.now()
|
||||
mock_orm_job.updated_at = datetime.now()
|
||||
mock_orm_job.deleted_at = None
|
||||
|
||||
resp = AutomationJobResponse.from_orm(mock_orm_job)
|
||||
assert resp.is_system is False
|
||||
|
||||
|
||||
class TestFromOrm:
|
||||
def test_run_at_converted_from_datetime_to_time(self):
|
||||
run_at_datetime = datetime(2024, 6, 15, 14, 30, 0)
|
||||
mock_orm_job = MagicMock()
|
||||
mock_orm_job.id = uuid4()
|
||||
mock_orm_job.owner_id = uuid4()
|
||||
mock_orm_job.bootstrap_key = None
|
||||
mock_orm_job.title = "Test Job"
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.run_at = run_at_datetime
|
||||
mock_orm_job.config = {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {},
|
||||
}
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.status = "active"
|
||||
mock_orm_job.timezone = "Asia/Shanghai"
|
||||
mock_orm_job.next_run_at = datetime.now()
|
||||
mock_orm_job.last_run_at = None
|
||||
mock_orm_job.created_at = datetime.now()
|
||||
mock_orm_job.updated_at = datetime.now()
|
||||
mock_orm_job.deleted_at = None
|
||||
|
||||
resp = AutomationJobResponse.from_orm(mock_orm_job)
|
||||
assert resp.run_at == run_at_datetime.time()
|
||||
|
||||
def test_config_deserialized(self):
|
||||
config = {
|
||||
"input_template": "Test template",
|
||||
"enabled_tools": [AgentTool.MEMORY_WRITE],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 5,
|
||||
},
|
||||
}
|
||||
mock_orm_job = MagicMock()
|
||||
mock_orm_job.id = uuid4()
|
||||
mock_orm_job.owner_id = uuid4()
|
||||
mock_orm_job.bootstrap_key = None
|
||||
mock_orm_job.title = "Test Job"
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.run_at = datetime.now()
|
||||
mock_orm_job.config = config
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.status = "active"
|
||||
mock_orm_job.timezone = "Asia/Shanghai"
|
||||
mock_orm_job.next_run_at = datetime.now()
|
||||
mock_orm_job.last_run_at = None
|
||||
mock_orm_job.created_at = datetime.now()
|
||||
mock_orm_job.updated_at = datetime.now()
|
||||
mock_orm_job.deleted_at = None
|
||||
|
||||
resp = AutomationJobResponse.from_orm(mock_orm_job)
|
||||
assert resp.config.input_template == "Test template"
|
||||
assert resp.config.enabled_tools == [AgentTool.MEMORY_WRITE]
|
||||
assert resp.config.context.window_count == 5
|
||||
|
||||
def test_is_system_derived_from_bootstrap_key(self):
|
||||
mock_orm_job = MagicMock()
|
||||
mock_orm_job.id = uuid4()
|
||||
mock_orm_job.owner_id = uuid4()
|
||||
mock_orm_job.bootstrap_key = "system_bootstrap"
|
||||
mock_orm_job.title = "Test Job"
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.run_at = datetime.now()
|
||||
mock_orm_job.config = {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {},
|
||||
}
|
||||
mock_orm_job.schedule_type = "daily"
|
||||
mock_orm_job.status = "active"
|
||||
mock_orm_job.timezone = "UTC"
|
||||
mock_orm_job.next_run_at = datetime.now()
|
||||
mock_orm_job.last_run_at = None
|
||||
mock_orm_job.created_at = datetime.now()
|
||||
mock_orm_job.updated_at = datetime.now()
|
||||
mock_orm_job.deleted_at = None
|
||||
|
||||
resp = AutomationJobResponse.from_orm(mock_orm_job)
|
||||
assert resp.is_system is True
|
||||
assert resp.bootstrap_key == "system_bootstrap"
|
||||
|
||||
|
||||
class TestTimezoneValidation:
|
||||
def test_valid_timezone(self):
|
||||
request = AutomationJobCreateRequest.model_validate(
|
||||
{
|
||||
"title": "Test Job",
|
||||
"schedule_type": "daily",
|
||||
"run_at": "09:00:00",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"config": {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
assert request.timezone == "Asia/Shanghai"
|
||||
|
||||
def test_invalid_timezone(self):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AutomationJobCreateRequest.model_validate(
|
||||
{
|
||||
"title": "Test Job",
|
||||
"schedule_type": "daily",
|
||||
"run_at": "09:00:00",
|
||||
"timezone": "Invalid/Timezone",
|
||||
"config": {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
assert "timezone must be a valid IANA timezone" in str(exc_info.value)
|
||||
|
||||
def test_update_valid_timezone(self):
|
||||
request = AutomationJobUpdateRequest.model_validate(
|
||||
{
|
||||
"timezone": "America/New_York",
|
||||
}
|
||||
)
|
||||
assert request.timezone == "America/New_York"
|
||||
|
||||
def test_update_invalid_timezone(self):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AutomationJobUpdateRequest.model_validate(
|
||||
{
|
||||
"timezone": "Invalid/Timezone",
|
||||
}
|
||||
)
|
||||
assert "timezone must be a valid IANA timezone" in str(exc_info.value)
|
||||
|
||||
def test_update_none_timezone_allowed(self):
|
||||
request = AutomationJobUpdateRequest.model_validate(
|
||||
{
|
||||
"timezone": None,
|
||||
}
|
||||
)
|
||||
assert request.timezone is None
|
||||
|
||||
|
||||
class TestAutomationJobConfigPatch:
|
||||
def test_all_fields_optional(self):
|
||||
patch = AutomationJobConfig.model_validate({})
|
||||
assert patch.input_template is None
|
||||
assert patch.enabled_tools is None
|
||||
assert patch.context is None
|
||||
|
||||
def test_partial_input_template(self):
|
||||
patch = AutomationJobConfig.model_validate(
|
||||
{
|
||||
"input_template": "Updated template",
|
||||
}
|
||||
)
|
||||
assert patch.input_template == "Updated template"
|
||||
assert patch.enabled_tools is None
|
||||
assert patch.context is None
|
||||
|
||||
def test_extra_fields_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
AutomationJobConfig.model_validate(
|
||||
{
|
||||
"input_template": "Test",
|
||||
"unknown_field": "value",
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,371 @@
|
||||
from datetime import datetime, time, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
from v1.automation_jobs.service import (
|
||||
AutomationJobLimitExceeded,
|
||||
AutomationJobNotFound,
|
||||
AutomationJobsService,
|
||||
SystemJobModificationForbidden,
|
||||
)
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
from schemas.automation import (
|
||||
AgentTool,
|
||||
AutomationJobConfig,
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MessageContextConfig,
|
||||
)
|
||||
|
||||
|
||||
def _make_config() -> AutomationJobConfig:
|
||||
return AutomationJobConfig(
|
||||
input_template="Hello",
|
||||
enabled_tools=[AgentTool.MEMORY_WRITE],
|
||||
context=MessageContextConfig(
|
||||
source=ContextSource.LATEST_CHAT,
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_create_request() -> AutomationJobCreateRequest:
|
||||
return AutomationJobCreateRequest(
|
||||
title="Test Job",
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=time(9, 0, 0),
|
||||
timezone="Asia/Shanghai",
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
config=_make_config(),
|
||||
)
|
||||
|
||||
|
||||
def _make_job(
|
||||
owner_id: MagicMock | None = None, bootstrap_key: str | None = None
|
||||
) -> MagicMock:
|
||||
job = MagicMock()
|
||||
job.id = uuid4()
|
||||
job.owner_id = owner_id or uuid4()
|
||||
job.bootstrap_key = bootstrap_key
|
||||
job.title = "Test Job"
|
||||
job.schedule_type = ScheduleType.DAILY
|
||||
job.run_at = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
job.timezone = "Asia/Shanghai"
|
||||
job.status = AutomationJobStatus.ACTIVE
|
||||
job.config = {"input_template": "Hello"}
|
||||
job.next_run_at = datetime(2024, 1, 2, 9, 0, 0, tzinfo=timezone.utc)
|
||||
job.last_run_at = None
|
||||
job.created_at = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
job.updated_at = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
return job
|
||||
|
||||
|
||||
class TestListByOwner:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_by_owner_returns_jobs(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
repository.list_by_owner.return_value = [job]
|
||||
|
||||
result = await service.list_by_owner(owner_id)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].title == job.title
|
||||
repository.list_by_owner.assert_awaited_once_with(owner_id)
|
||||
|
||||
|
||||
class TestGetById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_returns_job(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
result = await service.get_by_id(job.id, owner_id)
|
||||
|
||||
assert result.title == job.title
|
||||
repository.get_by_id.assert_awaited_once_with(job.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_raises_not_found_when_job_none(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repository.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.get_by_id(job_id, owner_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_raises_not_found_when_owner_mismatch(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
different_owner_id = uuid4()
|
||||
job = _make_job(different_owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.get_by_id(job.id, owner_id)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_raises_limit_exceeded(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
repository.count_user_jobs.return_value = 3
|
||||
|
||||
with pytest.raises(AutomationJobLimitExceeded):
|
||||
await service.create(owner_id, data)
|
||||
|
||||
session.execute.assert_awaited_once()
|
||||
session.rollback.assert_awaited_once()
|
||||
repository.count_user_jobs.assert_awaited_once_with(owner_id)
|
||||
repository.create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_succeeds_when_under_limit(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
job = _make_job(owner_id)
|
||||
repository.count_user_jobs.return_value = 2
|
||||
repository.create.return_value = job
|
||||
|
||||
result = await service.create(owner_id, data)
|
||||
|
||||
assert result.title == job.title
|
||||
session.execute.assert_awaited_once()
|
||||
repository.create.assert_awaited_once_with(owner_id, data)
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_commits_session(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
job = _make_job(owner_id)
|
||||
repository.count_user_jobs.return_value = 0
|
||||
repository.create.return_value = job
|
||||
|
||||
await service.create(owner_id, data)
|
||||
|
||||
session.execute.assert_awaited_once()
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_rollbacks_on_sqlalchemy_error(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
repository.count_user_jobs.return_value = 0
|
||||
repository.create.side_effect = SQLAlchemyError("db down")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await service.create(owner_id, data)
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
session.execute.assert_awaited_once()
|
||||
session.rollback.assert_awaited_once()
|
||||
session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_raises_not_found_when_job_none(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repository.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.update(
|
||||
job_id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_raises_not_found_when_owner_mismatch(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
different_owner_id = uuid4()
|
||||
job = _make_job(different_owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_raises_system_job_forbidden(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id, bootstrap_key="system-key")
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(SystemJobModificationForbidden):
|
||||
await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
repository.update.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_succeeds(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
updated_job = _make_job(owner_id)
|
||||
updated_job.title = "Updated Title"
|
||||
repository.get_by_id.return_value = job
|
||||
repository.update.return_value = updated_job
|
||||
|
||||
result = await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="Updated Title")
|
||||
)
|
||||
|
||||
assert result.title == "Updated Title"
|
||||
repository.update.assert_awaited_once_with(
|
||||
job.id, AutomationJobUpdateRequest(title="Updated Title")
|
||||
)
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_returns_not_found_when_update_returns_none(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
repository.update.return_value = None
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_rollbacks_on_sqlalchemy_error(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id, bootstrap_key=None)
|
||||
repository.get_by_id.return_value = job
|
||||
repository.update.side_effect = SQLAlchemyError("db down")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
session.rollback.assert_awaited_once()
|
||||
|
||||
|
||||
class TestDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_raises_not_found_when_job_none(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repository.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.delete(job_id, owner_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_raises_not_found_when_owner_mismatch(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
different_owner_id = uuid4()
|
||||
job = _make_job(different_owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.delete(job.id, owner_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_raises_system_job_forbidden(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id, bootstrap_key="system-key")
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(SystemJobModificationForbidden):
|
||||
await service.delete(job.id, owner_id)
|
||||
|
||||
repository.soft_delete.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_succeeds(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
await service.delete(job.id, owner_id)
|
||||
|
||||
repository.soft_delete.assert_awaited_once_with(job.id)
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_rollbacks_on_sqlalchemy_error(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id, bootstrap_key=None)
|
||||
repository.get_by_id.return_value = job
|
||||
repository.soft_delete.side_effect = SQLAlchemyError("db down")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await service.delete(job.id, owner_id)
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
session.rollback.assert_awaited_once()
|
||||
Reference in New Issue
Block a user