126 lines
3.2 KiB
Python
126 lines
3.2 KiB
Python
from __future__ import annotations
|
|
|
|
from uuid import UUID
|
|
|
|
from core.auth.models import CurrentUser
|
|
from v1.agent.service import AgentService
|
|
|
|
|
|
class _FakeRepository:
    """In-memory repository stub that records which operations were invoked.

    Lookups return fixed ids; ``commit``, ``rollback`` and ``delete_session``
    only flip or store attributes so a test can inspect them afterwards.
    """

    # Canned ids handed back by the lookup/creation methods.
    _OWNER_ID = "00000000-0000-0000-0000-000000000001"
    _NEW_SESSION_ID = "00000000-0000-0000-0000-000000000999"

    def __init__(self) -> None:
        """Start with nothing committed, rolled back or deleted."""
        self.deleted_session_id: str | None = None
        self.rolled_back = False
        self.committed = False

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the fixed owner id, whatever ``session_id`` is passed."""
        del session_id  # argument is accepted but never consulted
        return self._OWNER_ID

    async def create_session_for_user(self, *, user_id: str) -> str:
        """Return the fixed new-session id, whatever ``user_id`` is passed."""
        del user_id  # argument is accepted but never consulted
        return self._NEW_SESSION_ID

    async def commit(self) -> None:
        """Record that ``commit`` was called."""
        self.committed = True

    async def rollback(self) -> None:
        """Record that ``rollback`` was called."""
        self.rolled_back = True

    async def delete_session(self, *, session_id: str) -> None:
        """Remember the id of the session that was asked to be deleted."""
        self.deleted_session_id = session_id
|
|
|
|
|
|
class _FakeQueue:
    """Queue stub whose ``enqueue`` always succeeds with a fixed task id."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Discard both arguments and return ``"task-1"``."""
        del command
        del dedup_key
        return "task-1"
|
|
|
|
|
|
class _FailingQueue:
    """Queue stub whose ``enqueue`` always raises ``RuntimeError``."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Discard both arguments, then fail.

        Raises:
            RuntimeError: always, with the message ``"enqueue failed"``.
        """
        del command
        del dedup_key
        raise RuntimeError("enqueue failed")
|
|
|
|
|
|
class _FakeStream:
    """Stream stub that yields one fixed ``RUN_STARTED`` event per read."""

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]:
        """Return a single-entry event list.

        ``session_id`` is ignored; ``last_event_id`` is echoed back under the
        ``"cursor"`` key of the returned entry.
        """
        del session_id
        entry: dict[str, object] = {
            "id": "2-0",
            "event": {"type": "RUN_STARTED"},
            "cursor": last_event_id,
        }
        return [entry]
|
|
|
|
|
|
def _user() -> CurrentUser:
    """Build the ``CurrentUser`` fixture shared by the tests in this module."""
    user_id = UUID("00000000-0000-0000-0000-000000000001")
    return CurrentUser(id=user_id, email="user@example.com")
|
|
|
|
|
|
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
    """Two identical ``enqueue_resume`` calls must report the same task id.

    NOTE(review): the Redis lock mentioned in the test name is not observable
    in this body; only ``task_id`` equality is asserted — confirm the fakes
    actually exercise the lock path.
    """
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )
    # Same arguments for both calls so the second one is an exact repeat.
    resume_kwargs = {
        "session_id": "session-1",
        "tool_call_id": "call-1",
        "current_user": _user(),
    }

    initial = await service.enqueue_resume(**resume_kwargs)
    repeat = await service.enqueue_resume(**resume_kwargs)

    assert initial.task_id == repeat.task_id
|
|
|
|
|
|
async def test_enqueue_run_without_session_creates_new_session() -> None:
    """``enqueue_run`` with ``session_id=None`` creates and commits a session.

    Checks that the accepted result carries the id returned by the fake
    repository, is flagged as created, and that the repository was committed.
    """
    fake_repo = _FakeRepository()
    service = AgentService(
        repository=fake_repo,
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    result = await service.enqueue_run(
        session_id=None,
        prompt="hello",
        current_user=_user(),
    )

    assert result.session_id == "00000000-0000-0000-0000-000000000999"
    assert result.created is True
    assert fake_repo.committed is True
|
|
|
|
|
|
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
    """A failing queue propagates its error and no session is deleted.

    The service is wired to a queue whose ``enqueue`` raises ``RuntimeError``;
    the test requires that error to surface unchanged and that the repository
    never received a ``delete_session`` call.
    """
    fake_repo = _FakeRepository()
    service = AgentService(
        repository=fake_repo,
        queue=_FailingQueue(),
        stream=_FakeStream(),
    )

    try:
        await service.enqueue_run(
            session_id=None,
            prompt="hello",
            current_user=_user(),
        )
    except RuntimeError as exc:
        assert str(exc) == "enqueue failed"
    else:
        # Reaching here means the queue failure was swallowed.
        raise AssertionError("expected RuntimeError")

    assert fake_repo.deleted_session_id is None
|