feat: 添加自动化任务(automation_jobs)功能模块
This commit is contained in:
@@ -1,283 +1,287 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, time, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
from schemas.automation import (
|
||||
AgentTool,
|
||||
AutomationJobConfig,
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MessageContextConfig,
|
||||
)
|
||||
|
||||
|
||||
class _ExecuteResult:
|
||||
def __init__(self, value: object) -> None:
|
||||
self._value = value
|
||||
|
||||
def scalar_one_or_none(self) -> object:
|
||||
return self._value
|
||||
|
||||
def scalar_one(self) -> int:
|
||||
return self._value # type: ignore[return-value]
|
||||
def _make_config() -> AutomationJobConfig:
|
||||
return AutomationJobConfig(
|
||||
input_template="Hello",
|
||||
enabled_tools=[AgentTool.MEMORY_WRITE],
|
||||
context=MessageContextConfig(
|
||||
source=ContextSource.LATEST_CHAT,
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class _ScalarRows:
|
||||
def __init__(self, rows: list[object]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
def all(self) -> list[object]:
|
||||
return self._rows
|
||||
|
||||
|
||||
class _ExecuteRowsResult:
|
||||
def __init__(self, rows: list[object]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
def scalars(self) -> _ScalarRows:
|
||||
return _ScalarRows(self._rows)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self.flushed = False
|
||||
self._execute_result: object = None
|
||||
self._return_rows: bool = False
|
||||
|
||||
def set_execute_result(self, value: object) -> None:
|
||||
self._execute_result = value
|
||||
self._return_rows = isinstance(value, list)
|
||||
|
||||
async def execute(self, stmt): # noqa: ANN001
|
||||
del stmt
|
||||
if self._return_rows:
|
||||
return _ExecuteRowsResult(self._execute_result)
|
||||
return _ExecuteResult(self._execute_result)
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def flush(self) -> None:
|
||||
self.flushed = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_session() -> _FakeSession:
|
||||
return _FakeSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(fake_session: _FakeSession) -> AutomationJobsRepository:
|
||||
return AutomationJobsRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_job() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=uuid4(),
|
||||
owner_id=uuid4(),
|
||||
bootstrap_key=None,
|
||||
def _make_create_request() -> AutomationJobCreateRequest:
|
||||
return AutomationJobCreateRequest(
|
||||
title="Test Job",
|
||||
config={"input_template": "Hello {name}"},
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=datetime(2026, 3, 23, 0, 0, tzinfo=timezone.utc),
|
||||
next_run_at=datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc),
|
||||
timezone="UTC",
|
||||
run_at=time(9, 0, 0),
|
||||
timezone="Asia/Shanghai",
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
created_by=uuid4(),
|
||||
deleted_at=None,
|
||||
config=_make_config(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_by_owner_returns_jobs(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
fake_session.set_execute_result([sample_job])
|
||||
|
||||
async def test_list_by_owner_returns_jobs() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
jobs = await repository.list_by_owner(owner_id)
|
||||
job_one = MagicMock()
|
||||
job_two = MagicMock()
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalars.return_value.all.return_value = [job_one, job_two]
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].title == "Test Job"
|
||||
result = await repository.list_by_owner(owner_id)
|
||||
|
||||
assert result == [job_one, job_two]
|
||||
session.execute.assert_awaited_once()
|
||||
call_args = session.execute.call_args
|
||||
stmt = call_args[0][0]
|
||||
assert "owner_id" in str(stmt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_by_owner_returns_empty_list(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
fake_session.set_execute_result([])
|
||||
|
||||
async def test_count_user_jobs_counts_non_bootstrap_jobs() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
jobs = await repository.list_by_owner(owner_id)
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one.return_value = 3
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
assert jobs == []
|
||||
result = await repository.count_user_jobs(owner_id)
|
||||
|
||||
assert result == 3
|
||||
session.execute.assert_awaited_once()
|
||||
call_args = session.execute.call_args
|
||||
stmt = call_args[0][0]
|
||||
stmt_str = str(stmt)
|
||||
assert "bootstrap_key" in stmt_str
|
||||
assert "IS NULL" in stmt_str or "is_(None)" in stmt_str.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_returns_job(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
fake_session.set_execute_result(sample_job)
|
||||
async def test_create_sets_bootstrap_key_to_none() -> None:
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
|
||||
await repository.create(owner_id, data)
|
||||
|
||||
session.add.assert_called_once()
|
||||
call_args = session.add.call_args[0][0]
|
||||
assert call_args.bootstrap_key is None
|
||||
session.flush.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_sets_correct_fields() -> None:
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
|
||||
await repository.create(owner_id, data)
|
||||
|
||||
call_args = session.add.call_args[0][0]
|
||||
assert call_args.owner_id == owner_id
|
||||
assert call_args.title == data.title
|
||||
assert call_args.schedule_type == data.schedule_type
|
||||
assert call_args.timezone == data.timezone
|
||||
assert call_args.status == data.status
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_returns_updated_job() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
job = await repository.get_by_id(job_id)
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.config = {"input_template": "Old"}
|
||||
updated_job = MagicMock()
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = updated_job
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
assert job is not None
|
||||
assert job.title == "Test Job"
|
||||
data = AutomationJobUpdateRequest(title="Updated Title")
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert result is updated_job
|
||||
session.flush.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_returns_none_when_not_found(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
fake_session.set_execute_result(None)
|
||||
|
||||
async def test_update_merges_config() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
job = await repository.get_by_id(job_id)
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.config = {"input_template": "Old", "enabled_tools": []}
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = existing_job
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
assert job is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_user_jobs_returns_count(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
fake_session.set_execute_result(5)
|
||||
|
||||
owner_id = uuid4()
|
||||
count = await repository.count_user_jobs(owner_id)
|
||||
|
||||
assert count == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_user_jobs_returns_zero_when_none(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
fake_session.set_execute_result(0)
|
||||
|
||||
owner_id = uuid4()
|
||||
count = await repository.count_user_jobs(owner_id)
|
||||
|
||||
assert count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobCreateRequest
|
||||
from schemas.automation import AutomationJobConfig
|
||||
|
||||
owner_id = uuid4()
|
||||
request = AutomationJobCreateRequest(
|
||||
title="New Job",
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=time(0, 0),
|
||||
timezone="UTC",
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
config=AutomationJobConfig(input_template="Test"),
|
||||
data = AutomationJobUpdateRequest(
|
||||
config={"input_template": "New", "context": {"source": "latest_chat"}}
|
||||
)
|
||||
await repository.update(job_id, data)
|
||||
|
||||
job = await repository.create(owner_id, request)
|
||||
|
||||
assert job.title == "New Job"
|
||||
assert job.owner_id == owner_id
|
||||
assert job.created_by == owner_id
|
||||
assert job.bootstrap_key is None
|
||||
assert job.schedule_type == ScheduleType.DAILY
|
||||
assert fake_session.flushed is True
|
||||
assert len(fake_session.added) == 1
|
||||
session.flush.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
async def test_update_returns_none_when_job_not_found() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = None
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
data = AutomationJobUpdateRequest(title="Updated Title")
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_calls_soft_delete_by_id() -> None:
|
||||
session = AsyncMock()
|
||||
session.flush = AsyncMock()
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = None
|
||||
session.execute.return_value = execute_result
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
fake_session.set_execute_result(None)
|
||||
|
||||
await repository.soft_delete(job_id)
|
||||
|
||||
assert fake_session.flushed is True
|
||||
session.flush.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_title(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobUpdateRequest
|
||||
async def test_list_due_jobs_filters_by_active_status() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
sample_job.title = "Updated Title"
|
||||
fake_session.set_execute_result(sample_job)
|
||||
await repository.list_due_jobs(now_utc=MagicMock(), limit=10)
|
||||
|
||||
request = AutomationJobUpdateRequest(title="Updated Title")
|
||||
job = await repository.update(sample_job.id, request)
|
||||
|
||||
assert job is not None
|
||||
assert job.title == "Updated Title"
|
||||
session.execute.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_run_at_recomputes_next_run_at(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobUpdateRequest
|
||||
async def test_create_stores_run_at_as_timezone_aware() -> None:
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
|
||||
fake_session.set_execute_result(sample_job)
|
||||
await repository.create(owner_id, data)
|
||||
|
||||
request = AutomationJobUpdateRequest(
|
||||
run_at=time(12, 0),
|
||||
timezone="UTC",
|
||||
call_args = session.add.call_args[0][0]
|
||||
assert call_args.run_at.tzinfo is not None, "run_at should be timezone-aware"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_run_at_with_timezone_none_uses_existing_timezone() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.timezone = "America/New_York"
|
||||
existing_job.config = {}
|
||||
existing_job.run_at = None
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = existing_job
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
repository.update_by_id = AsyncMock(return_value=existing_job)
|
||||
|
||||
data = AutomationJobUpdateRequest(run_at=time(14, 30, 0))
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert result is not None
|
||||
update_values = repository.update_by_id.call_args[0][1]
|
||||
assert "run_at" in update_values
|
||||
assert "next_run_at" in update_values
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_schedule_type_recomputes_next_run_at() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.timezone = "UTC"
|
||||
existing_job.run_at = datetime(2026, 1, 1, 8, 0, 0, tzinfo=timezone.utc)
|
||||
existing_job.config = {}
|
||||
|
||||
repository.get_by_id = AsyncMock(return_value=existing_job)
|
||||
repository.update_by_id = AsyncMock(return_value=existing_job)
|
||||
|
||||
data = AutomationJobUpdateRequest(schedule_type=ScheduleType.WEEKLY)
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert result is not None
|
||||
update_values = repository.update_by_id.call_args[0][1]
|
||||
assert update_values["schedule_type"] == ScheduleType.WEEKLY
|
||||
assert "run_at" in update_values
|
||||
assert "next_run_at" in update_values
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_config_serializes_enum_values_to_json() -> None:
|
||||
session = AsyncMock()
|
||||
repository = AutomationJobsRepository(session)
|
||||
job_id = uuid4()
|
||||
existing_job = MagicMock()
|
||||
existing_job.schedule_type = ScheduleType.DAILY
|
||||
existing_job.timezone = "UTC"
|
||||
existing_job.run_at = datetime(2026, 1, 1, 8, 0, 0, tzinfo=timezone.utc)
|
||||
existing_job.config = {"input_template": "Old"}
|
||||
|
||||
repository.get_by_id = AsyncMock(return_value=existing_job)
|
||||
repository.update_by_id = AsyncMock(return_value=existing_job)
|
||||
|
||||
data = AutomationJobUpdateRequest(
|
||||
config={"enabled_tools": [AgentTool.MEMORY_WRITE]},
|
||||
)
|
||||
job = await repository.update(sample_job.id, request)
|
||||
result = await repository.update(job_id, data)
|
||||
|
||||
assert job is not None
|
||||
assert fake_session.flushed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_returns_none_when_job_not_found(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobUpdateRequest
|
||||
|
||||
fake_session.set_execute_result(None)
|
||||
|
||||
request = AutomationJobUpdateRequest(title="New Title")
|
||||
job = await repository.update(uuid4(), request)
|
||||
|
||||
assert job is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_with_no_changes_returns_existing_job(
|
||||
repository: AutomationJobsRepository,
|
||||
fake_session: _FakeSession,
|
||||
sample_job: SimpleNamespace,
|
||||
) -> None:
|
||||
from v1.automation_jobs.schemas import AutomationJobUpdateRequest
|
||||
|
||||
fake_session.set_execute_result(sample_job)
|
||||
|
||||
request = AutomationJobUpdateRequest()
|
||||
job = await repository.update(sample_job.id, request)
|
||||
|
||||
assert job is not None
|
||||
assert job.title == "Test Job"
|
||||
assert result is not None
|
||||
update_values = repository.update_by_id.call_args[0][1]
|
||||
enabled_tools = update_values["config"]["enabled_tools"]
|
||||
assert isinstance(enabled_tools[0], str)
|
||||
|
||||
Reference in New Issue
Block a user