# 197 lines · 6.2 KiB · Python
from __future__ import annotations
|
|
|
|
from decimal import Decimal
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from sqlalchemy import Column, String, Table, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from core.auth.models import CurrentUser
|
|
from core.agent.orchestrator import OrchestratorResult
|
|
from core.db.base import Base
|
|
from models.agent_chat_message import AgentChatMessage
|
|
from models.agent_chat_session import AgentChatSession
|
|
from v1.agent.schemas import AgentChatRunRequest
|
|
from v1.agent.service import AgentChatService
|
|
|
|
|
|
@pytest.fixture
async def db_engine():
    """Yield an async in-memory SQLite engine with the ORM schema created.

    SQLite has no schemas, so an in-memory database is ATTACHed under the name
    ``auth`` to stand in for the ``auth`` schema. A minimal ``auth.users`` stub
    table (``id`` column only) is registered on ``Base.metadata`` so that
    ``create_all`` creates it. Presumably this is needed because some model
    references ``auth.users`` (e.g. via a foreign key).
    NOTE(review): confirm against the model definitions.

    Teardown always disposes the engine. It unregisters the stub table only if
    this fixture added it. An ``auth.users`` table already declared on the
    metadata is left in place.
    """
    # Check before calling Table(): with extend_existing=True, Table() hands
    # back any pre-existing ``auth.users`` object. Removing that on teardown
    # would strip a real table definition from the shared metadata.
    stub_added = "auth.users" not in Base.metadata.tables
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )

    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        async with engine.begin() as conn:
            # ATTACH is scoped to this connection. Later sessions only see
            # ``auth`` if the pool returns the same connection.
            # NOTE(review): relies on SQLAlchemy's default pooling for
            # ``:memory:`` SQLite.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Runs on normal teardown and also when setup above raised.
        if stub_added:
            Base.metadata.remove(users_table)
        await engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to the ``db_engine`` fixture.

    ``expire_on_commit=False`` keeps loaded attributes readable after the code
    under test commits. Any uncommitted work is rolled back once the test body
    finishes, before the session is closed.
    """
    session_factory = async_sessionmaker(
        bind=db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with session_factory() as sess:
        yield sess
        await sess.rollback()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_creates_session_and_persists_messages(
    db_session: AsyncSession,
) -> None:
    """A run without ``session_id`` creates a session and stores both messages.

    Checks the returned result (session id, output, event sequence). It then
    checks the stored ``AgentChatSession`` (message count, status) and the two
    ``AgentChatMessage`` rows in ``seq`` order.
    """
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    result = await service.run(AgentChatRunRequest(message="hello"))

    assert result.session_id is not None
    # Assumes the default orchestrator echoes the user message back.
    assert result.output == "hello"
    assert [event.type for event in result.events] == [
        "run.started",
        "message.delta",
        "run.completed",
    ]

    session_obj = await db_session.get(AgentChatSession, result.session_id)
    assert session_obj is not None
    # get() can return the very instance the service mutated in memory.
    # Refresh so the assertions below reflect what the database holds.
    await db_session.refresh(session_obj)
    assert session_obj.message_count == 2
    assert session_obj.status.value == "completed"

    rows = await db_session.execute(
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == result.session_id)
        .order_by(AgentChatMessage.seq.asc())
    )
    messages = rows.scalars().all()
    assert [message.role.value for message in messages] == ["user", "assistant"]
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
    """A second run with ``session_id`` reuses that session and appends to it."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    first = await service.run(AgentChatRunRequest(message="first"))
    second = await service.run(
        AgentChatRunRequest(message="second", session_id=first.session_id)
    )

    assert second.session_id == first.session_id

    session_obj = await db_session.get(AgentChatSession, first.session_id)
    assert session_obj is not None
    # Reload from the database rather than trusting the identity-map copy.
    await db_session.refresh(session_obj)
    # Two runs, each storing a user message plus an assistant reply.
    assert session_obj.message_count == 4

    rows = await db_session.execute(
        select(AgentChatMessage).where(AgentChatMessage.session_id == first.session_id)
    )
    assert len(rows.scalars().all()) == 4
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
    db_session: AsyncSession,
) -> None:
    """An orchestrator result with ``failed=True`` yields 502 and a failed session."""

    class _FailingOrchestrator:
        """Stub orchestrator whose every run reports failure with zero usage."""

        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            zero_usage = {
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
                "cost": Decimal("0"),
                "currency": "USD",
            }
            return OrchestratorResult(
                output="",
                usage=zero_usage,
                events=[],
                context={},
                failed=True,
                error="stage failed",
            )

    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)
    # Replace the service's orchestrator with the failing stub.
    service._orchestrator = _FailingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message="hello"))
    assert exc_info.value.status_code == 502

    # Exactly one session exists for this user, and it is flagged as failed.
    query_result = await db_session.execute(
        select(AgentChatSession).where(AgentChatSession.user_id == user.id)
    )
    failed_session = query_result.scalars().one()
    assert failed_session.status.value == "failed"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
    """A whitespace-only message is rejected with HTTP 422."""
    service = AgentChatService(
        session=db_session,
        current_user=CurrentUser(id=uuid4()),
    )

    # The request is built inside the raises-block, matching where the
    # rejection is expected to surface.
    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message=" "))

    assert exc_info.value.status_code == 422
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
    """Referencing a ``session_id`` that does not exist yields HTTP 404."""
    service = AgentChatService(
        session=db_session,
        current_user=CurrentUser(id=uuid4()),
    )

    with pytest.raises(HTTPException) as exc_info:
        # A freshly generated UUID cannot match any stored session.
        await service.run(AgentChatRunRequest(message="hello", session_id=uuid4()))

    assert exc_info.value.status_code == 404
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
    db_session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``SQLAlchemyError`` raised from ``commit`` is surfaced as HTTP 503."""

    async def _raise_db_error() -> None:
        raise SQLAlchemyError("db down")

    # Patch commit on this session instance only; monkeypatch restores it.
    monkeypatch.setattr(db_session, "commit", _raise_db_error)

    service = AgentChatService(
        session=db_session,
        current_user=CurrentUser(id=uuid4()),
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message="hello"))

    assert exc_info.value.status_code == 503
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_returns_502_for_unexpected_exception(
    db_session: AsyncSession,
) -> None:
    """An arbitrary exception escaping the orchestrator is mapped to HTTP 502."""

    class _CrashingOrchestrator:
        """Stub orchestrator that blows up with a non-HTTP exception."""

        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            raise RuntimeError("unexpected")

    service = AgentChatService(
        session=db_session,
        current_user=CurrentUser(id=uuid4()),
    )
    service._orchestrator = _CrashingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message="hello"))

    assert exc_info.value.status_code == 502
|