feat: 添加自动化任务(automation_jobs)功能模块
This commit is contained in:
@@ -0,0 +1,371 @@
|
||||
from datetime import datetime, time, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
from v1.automation_jobs.service import (
|
||||
AutomationJobLimitExceeded,
|
||||
AutomationJobNotFound,
|
||||
AutomationJobsService,
|
||||
SystemJobModificationForbidden,
|
||||
)
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
from schemas.automation import (
|
||||
AgentTool,
|
||||
AutomationJobConfig,
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MessageContextConfig,
|
||||
)
|
||||
|
||||
|
||||
def _make_config() -> AutomationJobConfig:
|
||||
return AutomationJobConfig(
|
||||
input_template="Hello",
|
||||
enabled_tools=[AgentTool.MEMORY_WRITE],
|
||||
context=MessageContextConfig(
|
||||
source=ContextSource.LATEST_CHAT,
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_create_request() -> AutomationJobCreateRequest:
|
||||
return AutomationJobCreateRequest(
|
||||
title="Test Job",
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=time(9, 0, 0),
|
||||
timezone="Asia/Shanghai",
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
config=_make_config(),
|
||||
)
|
||||
|
||||
|
||||
def _make_job(
|
||||
owner_id: MagicMock | None = None, bootstrap_key: str | None = None
|
||||
) -> MagicMock:
|
||||
job = MagicMock()
|
||||
job.id = uuid4()
|
||||
job.owner_id = owner_id or uuid4()
|
||||
job.bootstrap_key = bootstrap_key
|
||||
job.title = "Test Job"
|
||||
job.schedule_type = ScheduleType.DAILY
|
||||
job.run_at = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
job.timezone = "Asia/Shanghai"
|
||||
job.status = AutomationJobStatus.ACTIVE
|
||||
job.config = {"input_template": "Hello"}
|
||||
job.next_run_at = datetime(2024, 1, 2, 9, 0, 0, tzinfo=timezone.utc)
|
||||
job.last_run_at = None
|
||||
job.created_at = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
job.updated_at = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
return job
|
||||
|
||||
|
||||
class TestListByOwner:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_by_owner_returns_jobs(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
repository.list_by_owner.return_value = [job]
|
||||
|
||||
result = await service.list_by_owner(owner_id)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].title == job.title
|
||||
repository.list_by_owner.assert_awaited_once_with(owner_id)
|
||||
|
||||
|
||||
class TestGetById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_returns_job(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
result = await service.get_by_id(job.id, owner_id)
|
||||
|
||||
assert result.title == job.title
|
||||
repository.get_by_id.assert_awaited_once_with(job.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_raises_not_found_when_job_none(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repository.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.get_by_id(job_id, owner_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_raises_not_found_when_owner_mismatch(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
different_owner_id = uuid4()
|
||||
job = _make_job(different_owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.get_by_id(job.id, owner_id)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_raises_limit_exceeded(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
repository.count_user_jobs.return_value = 3
|
||||
|
||||
with pytest.raises(AutomationJobLimitExceeded):
|
||||
await service.create(owner_id, data)
|
||||
|
||||
session.execute.assert_awaited_once()
|
||||
session.rollback.assert_awaited_once()
|
||||
repository.count_user_jobs.assert_awaited_once_with(owner_id)
|
||||
repository.create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_succeeds_when_under_limit(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
job = _make_job(owner_id)
|
||||
repository.count_user_jobs.return_value = 2
|
||||
repository.create.return_value = job
|
||||
|
||||
result = await service.create(owner_id, data)
|
||||
|
||||
assert result.title == job.title
|
||||
session.execute.assert_awaited_once()
|
||||
repository.create.assert_awaited_once_with(owner_id, data)
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_commits_session(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
job = _make_job(owner_id)
|
||||
repository.count_user_jobs.return_value = 0
|
||||
repository.create.return_value = job
|
||||
|
||||
await service.create(owner_id, data)
|
||||
|
||||
session.execute.assert_awaited_once()
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_rollbacks_on_sqlalchemy_error(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
data = _make_create_request()
|
||||
repository.count_user_jobs.return_value = 0
|
||||
repository.create.side_effect = SQLAlchemyError("db down")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await service.create(owner_id, data)
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
session.execute.assert_awaited_once()
|
||||
session.rollback.assert_awaited_once()
|
||||
session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_raises_not_found_when_job_none(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repository.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.update(
|
||||
job_id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_raises_not_found_when_owner_mismatch(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
different_owner_id = uuid4()
|
||||
job = _make_job(different_owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_raises_system_job_forbidden(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id, bootstrap_key="system-key")
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(SystemJobModificationForbidden):
|
||||
await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
repository.update.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_succeeds(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
updated_job = _make_job(owner_id)
|
||||
updated_job.title = "Updated Title"
|
||||
repository.get_by_id.return_value = job
|
||||
repository.update.return_value = updated_job
|
||||
|
||||
result = await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="Updated Title")
|
||||
)
|
||||
|
||||
assert result.title == "Updated Title"
|
||||
repository.update.assert_awaited_once_with(
|
||||
job.id, AutomationJobUpdateRequest(title="Updated Title")
|
||||
)
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_returns_not_found_when_update_returns_none(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
repository.update.return_value = None
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_rollbacks_on_sqlalchemy_error(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id, bootstrap_key=None)
|
||||
repository.get_by_id.return_value = job
|
||||
repository.update.side_effect = SQLAlchemyError("db down")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await service.update(
|
||||
job.id, owner_id, AutomationJobUpdateRequest(title="New")
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
session.rollback.assert_awaited_once()
|
||||
|
||||
|
||||
class TestDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_raises_not_found_when_job_none(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repository.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.delete(job_id, owner_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_raises_not_found_when_owner_mismatch(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
different_owner_id = uuid4()
|
||||
job = _make_job(different_owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(AutomationJobNotFound):
|
||||
await service.delete(job.id, owner_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_raises_system_job_forbidden(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id, bootstrap_key="system-key")
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
with pytest.raises(SystemJobModificationForbidden):
|
||||
await service.delete(job.id, owner_id)
|
||||
|
||||
repository.soft_delete.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_succeeds(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id)
|
||||
repository.get_by_id.return_value = job
|
||||
|
||||
await service.delete(job.id, owner_id)
|
||||
|
||||
repository.soft_delete.assert_awaited_once_with(job.id)
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_rollbacks_on_sqlalchemy_error(self) -> None:
|
||||
session = AsyncMock()
|
||||
repository = AsyncMock()
|
||||
service = AutomationJobsService(repository, session)
|
||||
owner_id = uuid4()
|
||||
job = _make_job(owner_id, bootstrap_key=None)
|
||||
repository.get_by_id.return_value = job
|
||||
repository.soft_delete.side_effect = SQLAlchemyError("db down")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await service.delete(job.id, owner_id)
|
||||
|
||||
assert exc.value.status_code == 503
|
||||
session.rollback.assert_awaited_once()
|
||||
Reference in New Issue
Block a user