from __future__ import annotations

from decimal import Decimal
from uuid import uuid4

import pytest
from fastapi import HTTPException
from sqlalchemy import Column, String, Table, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from core.auth.models import CurrentUser
from core.agent.orchestrator import OrchestratorResult
from core.db.base import Base
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from v1.agent.schemas import AgentChatRunRequest
from v1.agent.service import AgentChatService


# NOTE(review): these async fixtures use plain ``@pytest.fixture``; that only
# works if pytest-asyncio runs in "auto" mode (or equivalent config) — confirm
# the project's asyncio_mode, otherwise use ``pytest_asyncio.fixture``.
@pytest.fixture
async def db_engine():
    """Yield an async in-memory SQLite engine with the ORM schema created.

    A minimal ``auth.users`` table (a single ``id`` string primary key) is
    registered on ``Base.metadata`` for the lifetime of the fixture, and a
    second in-memory database is attached under the name ``auth`` so the
    schema-qualified table can be created in SQLite.
    """
    # Stub of the externally-owned ``auth.users`` table; presumably the chat
    # models reference it (e.g. via a foreign key) — TODO confirm.
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        async with engine.begin() as conn:
            # SQLite has no schemas; attaching a database named "auth" makes
            # ``auth.<table>`` resolvable on this connection.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Always unregister the stub and dispose the engine, even when schema
        # creation fails, so the shared metadata is not polluted.
        Base.metadata.remove(users_table)
        await engine.dispose()


@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to ``db_engine``.

    ``expire_on_commit=False`` keeps loaded attributes readable after the
    service commits. Any pending work is rolled back at teardown.
    """
    async_session = async_sessionmaker(
        bind=db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with async_session() as session:
        yield session
        await session.rollback()


@pytest.mark.asyncio
async def test_run_creates_session_and_persists_messages(
    db_session: AsyncSession,
) -> None:
    """A run without ``session_id`` creates a session and stores both turns."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    result = await service.run(AgentChatRunRequest(message="hello"))

    assert result.session_id is not None
    assert result.output == "hello"
    assert [event.type for event in result.events] == [
        "run.started",
        "message.delta",
        "run.completed",
    ]

    session_obj = await db_session.get(AgentChatSession, result.session_id)
    assert session_obj is not None
    assert session_obj.message_count == 2
    assert session_obj.status.value == "completed"

    rows = await db_session.execute(
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == result.session_id)
        .order_by(AgentChatMessage.seq.asc())
    )
    messages = rows.scalars().all()
    assert len(messages) == 2
    assert messages[0].role.value == "user"
    assert messages[1].role.value == "assistant"


@pytest.mark.asyncio
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
    """A second run with the first run's ``session_id`` reuses that session."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    first = await service.run(AgentChatRunRequest(message="first"))
    second = await service.run(
        AgentChatRunRequest(message="second", session_id=first.session_id)
    )

    assert second.session_id == first.session_id
    session_obj = await db_session.get(AgentChatSession, first.session_id)
    assert session_obj is not None
    assert session_obj.message_count == 4


@pytest.mark.asyncio
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
    db_session: AsyncSession,
) -> None:
    """An orchestrator result with ``failed=True`` yields 502 and a failed session."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    class _FailingOrchestrator:
        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            return OrchestratorResult(
                output="",
                usage={
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0,
                    "cost": Decimal("0"),
                    "currency": "USD",
                },
                events=[],
                context={},
                failed=True,
                error="stage failed",
            )

    # Swap in the stub orchestrator on the private attribute the service uses.
    service._orchestrator = _FailingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message="hello"))

    assert exc_info.value.status_code == 502

    rows = await db_session.execute(
        select(AgentChatSession).where(AgentChatSession.user_id == user.id)
    )
    stored_session = rows.scalars().one()
    assert stored_session.status.value == "failed"


@pytest.mark.asyncio
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
    """A whitespace-only message is rejected with 422."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message=" "))

    assert exc_info.value.status_code == 422


@pytest.mark.asyncio
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
    """Referencing an unknown ``session_id`` is rejected with 404."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message="hello", session_id=uuid4()))

    assert exc_info.value.status_code == 404


@pytest.mark.asyncio
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
    db_session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``SQLAlchemyError`` raised from ``commit`` surfaces as 503."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    async def _fail_commit() -> None:
        raise SQLAlchemyError("db down")

    # monkeypatch restores the real commit before the db_session fixture's
    # teardown runs.
    monkeypatch.setattr(db_session, "commit", _fail_commit)

    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message="hello"))

    assert exc_info.value.status_code == 503


@pytest.mark.asyncio
async def test_run_returns_502_for_unexpected_exception(
    db_session: AsyncSession,
) -> None:
    """An arbitrary exception raised by the orchestrator surfaces as 502."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    class _CrashingOrchestrator:
        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            raise RuntimeError("unexpected")

    service._orchestrator = _CrashingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message="hello"))

    assert exc_info.value.status_code == 502