feat(agent-chat): complete core workflow and strengthen auth rate limiting
This commit is contained in:
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.agent_chat.orchestrator import OrchestratorResult
|
||||
from core.db.base import Base
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from v1.agent_chat.schemas import AgentChatRunRequest
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_creates_session_and_persists_messages(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
result = await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert result.session_id is not None
|
||||
assert result.output == "hello"
|
||||
assert [event.type for event in result.events] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
|
||||
session_obj = await db_session.get(AgentChatSession, result.session_id)
|
||||
assert session_obj is not None
|
||||
assert session_obj.message_count == 2
|
||||
assert session_obj.status.value == "completed"
|
||||
|
||||
rows = await db_session.execute(
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == result.session_id)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = rows.scalars().all()
|
||||
assert len(messages) == 2
|
||||
assert messages[0].role.value == "user"
|
||||
assert messages[1].role.value == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
first = await service.run(AgentChatRunRequest(message="first"))
|
||||
second = await service.run(
|
||||
AgentChatRunRequest(message="second", session_id=first.session_id)
|
||||
)
|
||||
|
||||
assert second.session_id == first.session_id
|
||||
|
||||
session_obj = await db_session.get(AgentChatSession, first.session_id)
|
||||
assert session_obj is not None
|
||||
assert session_obj.message_count == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
class _FailingOrchestrator:
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
return OrchestratorResult(
|
||||
output="",
|
||||
usage={
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"cost": Decimal("0"),
|
||||
"currency": "USD",
|
||||
},
|
||||
events=[],
|
||||
context={},
|
||||
failed=True,
|
||||
error="stage failed",
|
||||
)
|
||||
|
||||
service._orchestrator = _FailingOrchestrator() # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
|
||||
rows = await db_session.execute(
|
||||
select(AgentChatSession).where(AgentChatSession.user_id == user.id)
|
||||
)
|
||||
stored_session = rows.scalars().one()
|
||||
assert stored_session.status.value == "failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message=" "))
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello", session_id=uuid4()))
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
|
||||
db_session: AsyncSession,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
async def _fail_commit() -> None:
|
||||
raise SQLAlchemyError("db down")
|
||||
|
||||
monkeypatch.setattr(db_session, "commit", _fail_commit)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_502_for_unexpected_exception(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
class _CrashingOrchestrator:
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
raise RuntimeError("unexpected")
|
||||
|
||||
service._orchestrator = _CrashingOrchestrator() # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
Reference in New Issue
Block a user