98 lines
3.0 KiB
Python
98 lines
3.0 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from datetime import datetime
|
||
|
|
from decimal import Decimal
|
||
|
|
from types import MethodType
|
||
|
|
from uuid import UUID, uuid4
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from core.auth.models import CurrentUser
|
||
|
|
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||
|
|
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||
|
|
from v1.agent_chat.schemas import AgentChatRunRequest
|
||
|
|
from v1.agent_chat.service import AgentChatService
|
||
|
|
|
||
|
|
|
||
|
|
class _FakeAsyncSession:
|
||
|
|
def __init__(self) -> None:
|
||
|
|
self.added: list[object] = []
|
||
|
|
self.committed = False
|
||
|
|
self.rolled_back = False
|
||
|
|
|
||
|
|
def add(self, obj: object) -> None:
|
||
|
|
self.added.append(obj)
|
||
|
|
|
||
|
|
async def flush(self) -> None:
|
||
|
|
return None
|
||
|
|
|
||
|
|
async def commit(self) -> None:
|
||
|
|
self.committed = True
|
||
|
|
|
||
|
|
async def rollback(self) -> None:
|
||
|
|
self.rolled_back = True
|
||
|
|
|
||
|
|
async def refresh(self, obj: object) -> None:
|
||
|
|
if isinstance(obj, AgentChatSession) and obj.id is None:
|
||
|
|
obj.id = uuid4()
|
||
|
|
if isinstance(obj, AgentChatMessage) and obj.id is None:
|
||
|
|
obj.id = uuid4()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_run_persists_messages_and_emits_ordered_events() -> None:
    """A new-session run stores the user/assistant turn and emits events in order.

    ``_resolve_session`` and ``_next_seq_base`` are replaced on the service
    instance, so the test drives only ``AgentChatService.run`` against the
    in-memory ``_FakeAsyncSession``.
    """
    current_user_id = UUID("00000000-0000-0000-0000-000000000001")
    chat_session_id = UUID("00000000-0000-0000-0000-000000000111")

    fake_db = _FakeAsyncSession()
    service = AgentChatService(
        session=fake_db,  # type: ignore[arg-type]
        current_user=CurrentUser(id=current_user_id),
    )

    async def _resolve_session(
        self: AgentChatService,
        *,
        session_id: object | None,
        user_id: UUID,
        first_message: str,
        now: datetime,
    ) -> AgentChatSession:
        # The request carries no session_id, so a new session must be resolved.
        assert session_id is None
        # The session must be resolved for the authenticated caller.
        assert user_id == current_user_id
        assert first_message == "hello"
        return AgentChatSession(
            id=chat_session_id,
            user_id=user_id,
            title="hello",
            status=AgentChatSessionStatus.RUNNING,
            last_activity_at=now,
            message_count=0,
            total_tokens=0,
            total_cost=Decimal("0"),
            created_at=now,
            updated_at=now,
            deleted_at=None,
        )

    async def _next_seq_base(self: AgentChatService, session_id: object) -> int:
        assert session_id == chat_session_id
        # The assertions below expect the new messages to get seq 3 and 4.
        return 2

    service._resolve_session = MethodType(_resolve_session, service)  # type: ignore[method-assign]
    service._next_seq_base = MethodType(_next_seq_base, service)  # type: ignore[method-assign]

    response = await service.run(AgentChatRunRequest(message="hello"))

    # A successful run commits and never rolls back.
    assert fake_db.committed is True
    assert fake_db.rolled_back is False

    inserted_messages = [
        item for item in fake_db.added if isinstance(item, AgentChatMessage)
    ]
    assert len(inserted_messages) == 2
    assert [msg.seq for msg in inserted_messages] == [3, 4]
    assert [msg.role for msg in inserted_messages] == [
        AgentChatMessageRole.USER,
        AgentChatMessageRole.ASSISTANT,
    ]
    assert [event.type for event in response.events] == [
        "run.started",
        "message.delta",
        "run.completed",
    ]
|