399 lines
13 KiB
Python
399 lines
13 KiB
Python
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from uuid import UUID, uuid4
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
|
from v1.automation_jobs.service import (
|
|
AutomationJobLimitExceeded,
|
|
AutomationJobNotFound,
|
|
AutomationJobsService,
|
|
SystemJobModificationForbidden,
|
|
)
|
|
from v1.automation_jobs.schemas import (
|
|
AutomationJobCreateRequest,
|
|
AutomationJobUpdateRequest,
|
|
)
|
|
from schemas.automation import (
|
|
AgentTool,
|
|
AutomationJobConfig,
|
|
ContextSource,
|
|
ContextWindowMode,
|
|
MessageContextConfig,
|
|
ScheduleConfig,
|
|
ScheduleRunAt,
|
|
)
|
|
|
|
|
|
def _make_config() -> AutomationJobConfig:
    """Build the minimal job configuration shared by these tests.

    The config uses the literal template ``"Hello"`` and enables only
    ``AgentTool.MEMORY_WRITE``. Context comes from ``ContextSource.LATEST_CHAT``
    with window mode ``DAY`` and a window count of 2. The schedule is
    ``ScheduleType.DAILY`` with ``run_at`` 09:00.
    """
    context = MessageContextConfig(
        source=ContextSource.LATEST_CHAT,
        window_mode=ContextWindowMode.DAY,
        window_count=2,
    )
    schedule = ScheduleConfig(
        type=ScheduleType.DAILY,
        run_at=ScheduleRunAt(hour=9, minute=0),
    )
    return AutomationJobConfig(
        input_template="Hello",
        enabled_tools=[AgentTool.MEMORY_WRITE],
        context=context,
        schedule=schedule,
    )
|
|
|
|
|
|
def _make_create_request() -> AutomationJobCreateRequest:
    """Return an ACTIVE create request titled ``"Test Job"`` (Asia/Shanghai).

    The embedded job configuration comes from ``_make_config()``.
    """
    request_fields = {
        "title": "Test Job",
        "timezone": "Asia/Shanghai",
        "status": AutomationJobStatus.ACTIVE,
        "config": _make_config(),
    }
    return AutomationJobCreateRequest(**request_fields)
|
|
|
|
|
|
def _make_job(
    owner_id: UUID | None = None, bootstrap_key: str | None = None
) -> MagicMock:
    """Build a ``MagicMock`` standing in for a stored automation job.

    Args:
        owner_id: Owner to assign. When ``None`` (the default), a fresh random
            UUID is generated.
        bootstrap_key: Value stored on ``job.bootstrap_key``. The tests in this
            module pass a non-``None`` value to model a system job, which the
            service rejects with ``SystemJobModificationForbidden``.

    Returns:
        A mock with the following attributes set:
        - ``id``, ``owner_id``, ``bootstrap_key``, ``title``, ``timezone``,
          ``status`` and ``config``;
        - ``next_run_at``, ``last_run_at``, ``created_at`` and ``updated_at``.

        ``config`` is a plain dict whose values mirror ``_make_config()`` in
        serialized form. This assumes the enum values serialize to these
        strings; confirm against the schema.
    """
    job = MagicMock()
    job.id = uuid4()
    # Explicit ``is not None`` test rather than ``or``: only an omitted owner
    # should fall back to a random id, never some other falsy value.
    job.owner_id = owner_id if owner_id is not None else uuid4()
    job.bootstrap_key = bootstrap_key
    job.title = "Test Job"
    job.timezone = "Asia/Shanghai"
    job.status = AutomationJobStatus.ACTIVE
    job.config = {
        "input_template": "Hello",
        "enabled_tools": ["memory.write"],
        "context": {
            "source": "latest_chat",
            "window_mode": "day",
            "window_count": 2,
        },
        "schedule": {
            "type": "daily",
            "run_at": {"hour": 9, "minute": 0},
        },
    }
    job.next_run_at = datetime(2024, 1, 2, 9, 0, 0, tzinfo=timezone.utc)
    job.last_run_at = None
    job.created_at = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
    job.updated_at = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
    return job
|
|
|
|
|
|
class TestListByOwner:
    """Tests for ``AutomationJobsService.list_by_owner``."""

    @pytest.mark.asyncio
    async def test_list_by_owner_returns_jobs(self) -> None:
        """The repository's jobs for the owner are exposed via ``items``."""
        repo = AsyncMock()
        db_session = AsyncMock()
        svc = AutomationJobsService(repo, db_session)
        owner = uuid4()
        stored_job = _make_job(owner)
        repo.list_by_owner.return_value = [stored_job]

        response = await svc.list_by_owner(owner)

        repo.list_by_owner.assert_awaited_once_with(owner)
        assert len(response.items) == 1
        first_item = response.items[0]
        assert first_item.title == stored_job.title
|
|
|
|
|
|
class TestGetById:
    """Tests for ``AutomationJobsService.get_by_id`` and its ownership check."""

    @pytest.mark.asyncio
    async def test_get_by_id_returns_job(self) -> None:
        """A job owned by the caller is looked up by id and returned."""
        repo = AsyncMock()
        db_session = AsyncMock()
        svc = AutomationJobsService(repo, db_session)
        caller_id = uuid4()
        stored_job = _make_job(caller_id)
        repo.get_by_id.return_value = stored_job

        fetched = await svc.get_by_id(stored_job.id, caller_id)

        repo.get_by_id.assert_awaited_once_with(stored_job.id)
        assert fetched.title == stored_job.title

    @pytest.mark.asyncio
    async def test_get_by_id_raises_not_found_when_job_none(self) -> None:
        """A repository miss surfaces as ``AutomationJobNotFound``."""
        repo = AsyncMock()
        db_session = AsyncMock()
        svc = AutomationJobsService(repo, db_session)
        caller_id, missing_job_id = uuid4(), uuid4()
        repo.get_by_id.return_value = None

        with pytest.raises(AutomationJobNotFound):
            await svc.get_by_id(missing_job_id, caller_id)

    @pytest.mark.asyncio
    async def test_get_by_id_raises_not_found_when_owner_mismatch(self) -> None:
        """A job owned by another user is also reported as not found."""
        repo = AsyncMock()
        db_session = AsyncMock()
        svc = AutomationJobsService(repo, db_session)
        caller_id = uuid4()
        other_owner_id = uuid4()
        foreign_job = _make_job(other_owner_id)
        repo.get_by_id.return_value = foreign_job

        with pytest.raises(AutomationJobNotFound):
            await svc.get_by_id(foreign_job.id, caller_id)
|
|
|
|
|
|
class TestCreate:
    """Tests for ``AutomationJobsService.create``: quota, commit and rollback.

    Every path is expected to await ``session.execute`` exactly once. The
    purpose of that call is not visible from this file.
    """

    @pytest.mark.asyncio
    async def test_create_raises_limit_exceeded(self) -> None:
        """With a job count of 3 the quota check fails and nothing is created."""
        repo = AsyncMock()
        db_session = AsyncMock()
        svc = AutomationJobsService(repo, db_session)
        owner = uuid4()
        payload = _make_create_request()
        repo.count_user_jobs.return_value = 3

        with pytest.raises(AutomationJobLimitExceeded):
            await svc.create(owner, payload)

        repo.count_user_jobs.assert_awaited_once_with(owner)
        repo.create.assert_not_called()
        db_session.execute.assert_awaited_once()
        db_session.rollback.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_succeeds_when_under_limit(self) -> None:
        """With a job count of 2 the job is created and the session committed."""
        repo = AsyncMock()
        db_session = AsyncMock()
        svc = AutomationJobsService(repo, db_session)
        owner = uuid4()
        payload = _make_create_request()
        new_job = _make_job(owner)
        repo.count_user_jobs.return_value = 2
        repo.create.return_value = new_job

        created = await svc.create(owner, payload)

        db_session.execute.assert_awaited_once()
        repo.create.assert_awaited_once_with(owner, payload)
        db_session.commit.assert_awaited_once()
        assert created.title == new_job.title

    @pytest.mark.asyncio
    async def test_create_commits_session(self) -> None:
        """A first job (count 0) is committed exactly once."""
        repo = AsyncMock()
        db_session = AsyncMock()
        svc = AutomationJobsService(repo, db_session)
        owner = uuid4()
        payload = _make_create_request()
        repo.count_user_jobs.return_value = 0
        repo.create.return_value = _make_job(owner)

        await svc.create(owner, payload)

        db_session.execute.assert_awaited_once()
        db_session.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_rollbacks_on_sqlalchemy_error(self) -> None:
        """A database error becomes HTTP 503, rolls back, and never commits."""
        repo = AsyncMock()
        db_session = AsyncMock()
        svc = AutomationJobsService(repo, db_session)
        owner = uuid4()
        payload = _make_create_request()
        repo.count_user_jobs.return_value = 0
        repo.create.side_effect = SQLAlchemyError("db down")

        with pytest.raises(HTTPException) as raised:
            await svc.create(owner, payload)

        assert raised.value.status_code == 503
        db_session.execute.assert_awaited_once()
        db_session.rollback.assert_awaited_once()
        db_session.commit.assert_not_awaited()
|
|
|
|
|
|
class TestUpdate:
    """Tests for ``AutomationJobsService.update``.

    Every rejected update is also checked for side effects: the repository's
    ``update`` must not be called and the session must not be committed.
    Without those checks, a service that modified the row and then raised
    would still pass.
    """

    @pytest.mark.asyncio
    async def test_update_raises_not_found_when_job_none(self) -> None:
        """A missing job raises ``AutomationJobNotFound`` and writes nothing."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job_id = uuid4()
        repository.get_by_id.return_value = None

        with pytest.raises(AutomationJobNotFound):
            await service.update(
                job_id,
                owner_id,
                AutomationJobUpdateRequest(title="New", timezone="UTC"),
            )

        repository.update.assert_not_called()
        session.commit.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_update_raises_not_found_when_owner_mismatch(self) -> None:
        """Another user's job is reported as not found and left unmodified."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        different_owner_id = uuid4()
        job = _make_job(different_owner_id)
        repository.get_by_id.return_value = job

        with pytest.raises(AutomationJobNotFound):
            await service.update(
                job.id,
                owner_id,
                AutomationJobUpdateRequest(title="New", timezone="UTC"),
            )

        # The ownership check must happen before any write.
        repository.update.assert_not_called()
        session.commit.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_update_raises_system_job_forbidden(self) -> None:
        """A job with a ``bootstrap_key`` (system job) cannot be updated."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job = _make_job(owner_id, bootstrap_key="system-key")
        repository.get_by_id.return_value = job

        with pytest.raises(SystemJobModificationForbidden):
            await service.update(
                job.id,
                owner_id,
                AutomationJobUpdateRequest(title="New", timezone="UTC"),
            )

        repository.update.assert_not_called()
        session.commit.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_update_succeeds(self) -> None:
        """An owned, non-system job is updated and the session committed."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job = _make_job(owner_id)
        updated_job = _make_job(owner_id)
        updated_job.title = "Updated Title"
        repository.get_by_id.return_value = job
        repository.update.return_value = updated_job

        result = await service.update(
            job.id,
            owner_id,
            AutomationJobUpdateRequest(title="Updated Title", timezone="UTC"),
        )

        assert result.title == "Updated Title"
        repository.update.assert_awaited_once_with(
            job.id,
            AutomationJobUpdateRequest(title="Updated Title", timezone="UTC"),
        )
        session.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_update_returns_not_found_when_update_returns_none(self) -> None:
        """If the repository's ``update`` yields ``None``, NotFound is raised.

        Despite the test name, the service raises the error rather than
        returning it.
        """
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job = _make_job(owner_id)
        repository.get_by_id.return_value = job
        repository.update.return_value = None

        with pytest.raises(AutomationJobNotFound):
            await service.update(
                job.id,
                owner_id,
                AutomationJobUpdateRequest(title="New", timezone="UTC"),
            )

    @pytest.mark.asyncio
    async def test_update_rollbacks_on_sqlalchemy_error(self) -> None:
        """A database error becomes HTTP 503, rolls back, and never commits."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job = _make_job(owner_id, bootstrap_key=None)
        repository.get_by_id.return_value = job
        repository.update.side_effect = SQLAlchemyError("db down")

        with pytest.raises(HTTPException) as exc:
            await service.update(
                job.id,
                owner_id,
                AutomationJobUpdateRequest(title="New", timezone="UTC"),
            )

        assert exc.value.status_code == 503
        session.rollback.assert_awaited_once()
        # Mirrors TestCreate's rollback test: a failed write must not commit.
        session.commit.assert_not_awaited()
|
|
|
|
|
|
class TestDelete:
    """Tests for ``AutomationJobsService.delete``.

    Every rejected delete is also checked for side effects: ``soft_delete``
    must not be called and the session must not be committed.
    """

    @pytest.mark.asyncio
    async def test_delete_raises_not_found_when_job_none(self) -> None:
        """A missing job raises ``AutomationJobNotFound`` and deletes nothing."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job_id = uuid4()
        repository.get_by_id.return_value = None

        with pytest.raises(AutomationJobNotFound):
            await service.delete(job_id, owner_id)

        repository.soft_delete.assert_not_called()
        session.commit.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_delete_raises_not_found_when_owner_mismatch(self) -> None:
        """Another user's job is reported as not found and is not deleted."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        different_owner_id = uuid4()
        job = _make_job(different_owner_id)
        repository.get_by_id.return_value = job

        with pytest.raises(AutomationJobNotFound):
            await service.delete(job.id, owner_id)

        # The ownership check must happen before any write.
        repository.soft_delete.assert_not_called()
        session.commit.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_delete_raises_system_job_forbidden(self) -> None:
        """A job with a ``bootstrap_key`` (system job) cannot be deleted."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job = _make_job(owner_id, bootstrap_key="system-key")
        repository.get_by_id.return_value = job

        with pytest.raises(SystemJobModificationForbidden):
            await service.delete(job.id, owner_id)

        repository.soft_delete.assert_not_called()
        session.commit.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_delete_succeeds(self) -> None:
        """An owned, non-system job is soft-deleted and the session committed."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job = _make_job(owner_id)
        repository.get_by_id.return_value = job

        await service.delete(job.id, owner_id)

        repository.soft_delete.assert_awaited_once_with(job.id)
        session.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_delete_rollbacks_on_sqlalchemy_error(self) -> None:
        """A database error becomes HTTP 503, rolls back, and never commits."""
        session = AsyncMock()
        repository = AsyncMock()
        service = AutomationJobsService(repository, session)
        owner_id = uuid4()
        job = _make_job(owner_id, bootstrap_key=None)
        repository.get_by_id.return_value = job
        repository.soft_delete.side_effect = SQLAlchemyError("db down")

        with pytest.raises(HTTPException) as exc:
            await service.delete(job.id, owner_id)

        assert exc.value.status_code == 503
        session.rollback.assert_awaited_once()
        # Mirrors TestCreate's rollback test: a failed write must not commit.
        session.commit.assert_not_awaited()
|