Files
social-app/backend/tests/unit/v1/agent_chat/test_service.py
T

197 lines
6.2 KiB
Python
Raw Normal View History

from __future__ import annotations
from decimal import Decimal
from uuid import uuid4
import pytest
from sqlalchemy import Column, String, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.agent.orchestrator import OrchestratorResult
from core.db.base import Base
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from v1.agent.schemas import AgentChatRunRequest
from v1.agent.service import AgentChatService
@pytest.fixture
async def db_engine():
    """Yield an in-memory async SQLite engine with ``Base.metadata`` created.

    SQLite has no schemas, so a second in-memory database is ATTACHed under
    the name ``auth``; tables declared with ``schema="auth"`` are then created
    inside it. A minimal ``auth.users`` table (a single string ``id`` primary
    key) is registered on ``Base.metadata`` first — presumably because models
    reference ``auth.users`` without declaring it; confirm against the models.

    Teardown disposes of the engine and unregisters the stub table even when
    schema creation fails, so the shared ``Base.metadata`` is never left with
    a stray table. If ``auth.users`` was already declared elsewhere, it is
    neither modified nor removed.
    """
    # Only register (and later unregister) the stub when nobody else declared
    # ``auth.users``; otherwise ``extend_existing`` would rewrite and then
    # remove a table this fixture does not own.
    users_table = None
    if "auth.users" not in Base.metadata.tables:
        users_table = Table(
            "users",
            Base.metadata,
            Column("id", String, primary_key=True),
            schema="auth",
        )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        async with engine.begin() as conn:
            # Emulate the ``auth`` schema with an attached in-memory database.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Runs on normal teardown and when setup above raised.
        if users_table is not None:
            Base.metadata.remove(users_table)
        await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to the per-test ``db_engine``.

    ``expire_on_commit=False`` keeps already-loaded attribute values available
    after a commit. Whatever is still pending when the test finishes is
    rolled back before the session is closed.
    """
    make_session = async_sessionmaker(
        bind=db_engine,
        expire_on_commit=False,
        class_=AsyncSession,
    )
    async with make_session() as sess:
        yield sess
        await sess.rollback()
@pytest.mark.asyncio
async def test_run_creates_session_and_persists_messages(
    db_session: AsyncSession,
) -> None:
    """A run without ``session_id`` creates a session and stores both turns.

    Verifies the returned result (new session id, output equal to the input
    message, event sequence), the persisted session row (two messages,
    ``completed`` status) and the message rows in ``seq`` order (user first,
    then assistant).
    """
    service = AgentChatService(
        session=db_session, current_user=CurrentUser(id=uuid4())
    )

    outcome = await service.run(AgentChatRunRequest(message="hello"))

    assert outcome.session_id is not None
    assert outcome.output == "hello"
    event_types = [evt.type for evt in outcome.events]
    assert event_types == ["run.started", "message.delta", "run.completed"]

    chat_session = await db_session.get(AgentChatSession, outcome.session_id)
    assert chat_session is not None
    assert chat_session.message_count == 2
    assert chat_session.status.value == "completed"

    stmt = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == outcome.session_id)
        .order_by(AgentChatMessage.seq.asc())
    )
    stored = (await db_session.execute(stmt)).scalars().all()
    assert len(stored) == 2
    assert [msg.role.value for msg in stored] == ["user", "assistant"]
@pytest.mark.asyncio
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
    """A follow-up run that passes ``session_id`` reuses that session.

    After two runs against the same session, its ``message_count`` is
    expected to be 4.
    """
    owner = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=owner)

    initial = await service.run(AgentChatRunRequest(message="first"))
    follow_up_request = AgentChatRunRequest(
        message="second", session_id=initial.session_id
    )
    follow_up = await service.run(follow_up_request)

    assert follow_up.session_id == initial.session_id
    chat_session = await db_session.get(AgentChatSession, initial.session_id)
    assert chat_session is not None
    assert chat_session.message_count == 4
@pytest.mark.asyncio
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
    db_session: AsyncSession,
) -> None:
    """An orchestrator result flagged ``failed=True`` surfaces as HTTP 502.

    The service's orchestrator is swapped for a stub that always reports
    failure. The run must raise ``HTTPException`` with status 502, and the
    user's single session row must end up with status ``failed``.
    """
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    class _FailingOrchestrator:
        """Stub orchestrator whose result is always marked as failed."""

        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            zero_usage = dict(
                input_tokens=0,
                output_tokens=0,
                total_tokens=0,
                cost=Decimal("0"),
                currency="USD",
            )
            return OrchestratorResult(
                output="",
                usage=zero_usage,
                events=[],
                context={},
                failed=True,
                error="stage failed",
            )

    service._orchestrator = _FailingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as raised:
        await service.run(AgentChatRunRequest(message="hello"))
    assert raised.value.status_code == 502

    result = await db_session.execute(
        select(AgentChatSession).where(AgentChatSession.user_id == user.id)
    )
    assert result.scalars().one().status.value == "failed"
@pytest.mark.asyncio
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
    """A whitespace-only message is rejected with HTTP 422."""
    service = AgentChatService(
        session=db_session, current_user=CurrentUser(id=uuid4())
    )

    with pytest.raises(HTTPException) as raised:
        await service.run(AgentChatRunRequest(message=" "))

    assert raised.value.status_code == 422
@pytest.mark.asyncio
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
    """Referencing an unknown ``session_id`` yields HTTP 404."""
    service = AgentChatService(
        session=db_session, current_user=CurrentUser(id=uuid4())
    )

    with pytest.raises(HTTPException) as raised:
        unknown_id = uuid4()
        await service.run(AgentChatRunRequest(message="hello", session_id=unknown_id))

    assert raised.value.status_code == 404
@pytest.mark.asyncio
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
    db_session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A database error raised on commit is reported as HTTP 503.

    ``db_session.commit`` is replaced (for this test only) by a coroutine
    that raises ``SQLAlchemyError``.
    """
    service = AgentChatService(
        session=db_session, current_user=CurrentUser(id=uuid4())
    )

    async def _raise_db_error() -> None:
        raise SQLAlchemyError("db down")

    monkeypatch.setattr(db_session, "commit", _raise_db_error)

    with pytest.raises(HTTPException) as raised:
        await service.run(AgentChatRunRequest(message="hello"))

    assert raised.value.status_code == 503
@pytest.mark.asyncio
async def test_run_returns_502_for_unexpected_exception(
    db_session: AsyncSession,
) -> None:
    """An arbitrary exception from the orchestrator is mapped to HTTP 502."""
    service = AgentChatService(
        session=db_session, current_user=CurrentUser(id=uuid4())
    )

    class _CrashingOrchestrator:
        """Stub orchestrator that always raises ``RuntimeError``."""

        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            raise RuntimeError("unexpected")

    service._orchestrator = _CrashingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as raised:
        await service.run(AgentChatRunRequest(message="hello"))

    assert raised.value.status_code == 502