feat(agent-chat): complete core workflow and strengthen auth rate limiting

This commit is contained in:
qzl
2026-02-25 16:51:12 +08:00
parent 53c72e48e6
commit cd40b2b4f4
62 changed files with 3441 additions and 3 deletions
@@ -0,0 +1,97 @@
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from types import MethodType
from uuid import UUID, uuid4
import pytest
from core.auth.models import CurrentUser
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent_chat.schemas import AgentChatRunRequest
from v1.agent_chat.service import AgentChatService
class _FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self.committed = False
self.rolled_back = False
def add(self, obj: object) -> None:
self.added.append(obj)
async def flush(self) -> None:
return None
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
self.rolled_back = True
async def refresh(self, obj: object) -> None:
if isinstance(obj, AgentChatSession) and obj.id is None:
obj.id = uuid4()
if isinstance(obj, AgentChatMessage) and obj.id is None:
obj.id = uuid4()
@pytest.mark.asyncio
async def test_run_persists_messages_and_emits_ordered_events() -> None:
fake_db = _FakeAsyncSession()
service = AgentChatService(
session=fake_db, # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
async def _resolve_session(
self: AgentChatService,
*,
session_id: object | None,
user_id: UUID,
first_message: str,
now: datetime,
) -> AgentChatSession:
assert session_id is None
assert first_message == "hello"
return AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000111"),
user_id=user_id,
title="hello",
status=AgentChatSessionStatus.RUNNING,
last_activity_at=now,
message_count=0,
total_tokens=0,
total_cost=Decimal("0"),
created_at=now,
updated_at=now,
deleted_at=None,
)
async def _next_seq_base(self: AgentChatService, session_id: object) -> int:
assert session_id == UUID("00000000-0000-0000-0000-000000000111")
return 2
service._resolve_session = MethodType(_resolve_session, service) # type: ignore[method-assign]
service._next_seq_base = MethodType(_next_seq_base, service) # type: ignore[method-assign]
response = await service.run(AgentChatRunRequest(message="hello"))
assert fake_db.committed is True
inserted_messages = [
item for item in fake_db.added if isinstance(item, AgentChatMessage)
]
assert len(inserted_messages) == 2
assert [msg.seq for msg in inserted_messages] == [3, 4]
assert [msg.role for msg in inserted_messages] == [
AgentChatMessageRole.USER,
AgentChatMessageRole.ASSISTANT,
]
assert [event.type for event in response.events] == [
"run.started",
"message.delta",
"run.completed",
]
@@ -0,0 +1,78 @@
from __future__ import annotations
from typing import Callable
from uuid import UUID
from fastapi.testclient import TestClient
from app import app
from v1.agent_chat.dependencies import get_agent_chat_service
from v1.agent_chat.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
from v1.agent_chat.service import AgentChatService
class FakeAgentChatService:
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
return AgentChatRunResponse(
session_id=UUID("00000000-0000-0000-0000-000000000001"),
output=payload.message,
events=[
AgentChatEvent(
type="run.started", run_id="00000000-0000-0000-0000-000000000001"
),
AgentChatEvent(
type="message.delta", message_id="m1", delta=payload.message
),
AgentChatEvent(
type="run.completed",
run_id="00000000-0000-0000-0000-000000000001",
output=payload.message,
),
],
)
def _override_agent_chat_service(
service: FakeAgentChatService,
) -> Callable[[], AgentChatService]:
def _get_service() -> AgentChatService:
return service # type: ignore[return-value]
return _get_service
def test_run_route_returns_response() -> None:
app.dependency_overrides[get_agent_chat_service] = _override_agent_chat_service(
FakeAgentChatService()
)
client = TestClient(app)
try:
response = client.post("/api/v1/agent-chat/run", json={"message": "hello"})
assert response.status_code == 200
body = response.json()
assert body["output"] == "hello"
assert [event["type"] for event in body["events"]] == [
"run.started",
"message.delta",
"run.completed",
]
finally:
app.dependency_overrides = {}
def test_run_route_validates_payload() -> None:
app.dependency_overrides[get_agent_chat_service] = _override_agent_chat_service(
FakeAgentChatService()
)
client = TestClient(app)
try:
response = client.post("/api/v1/agent-chat/run", json={"message": ""})
assert response.status_code == 422
finally:
app.dependency_overrides = {}
@@ -0,0 +1,20 @@
from __future__ import annotations
from decimal import Decimal
from v1.agent_chat.service import aggregate_session_cost
def test_aggregate_session_cost_sums_non_negative_values() -> None:
total = aggregate_session_cost([Decimal("0.010000"), Decimal("0.002500")])
assert total == Decimal("0.012500")
def test_aggregate_session_cost_rejects_negative_value() -> None:
try:
aggregate_session_cost([Decimal("-0.010000")])
raised = False
except ValueError:
raised = True
assert raised is True
@@ -0,0 +1,42 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent_chat.service import select_recent_session
def test_select_recent_session_uses_last_activity_desc() -> None:
sessions = [
AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000001"),
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
title="older",
status=AgentChatSessionStatus.COMPLETED,
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
message_count=1,
total_tokens=1,
total_cost=Decimal("0"),
),
AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000002"),
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
title="newer",
status=AgentChatSessionStatus.RUNNING,
last_activity_at=datetime(2026, 2, 25, 10, 0, tzinfo=timezone.utc),
message_count=2,
total_tokens=2,
total_cost=Decimal("0"),
),
]
selected = select_recent_session(sessions)
assert selected is not None
assert selected.id == UUID("00000000-0000-0000-0000-000000000002")
def test_select_recent_session_returns_none_for_empty_collection() -> None:
assert select_recent_session([]) is None
@@ -416,6 +416,108 @@ def test_logout_returns_no_content() -> None:
app.dependency_overrides = {}
def test_login_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
blocked = client.post(
"/api/v1/auth/login",
json={"email": "user@example.com", "password": "wrongpw"},
)
assert blocked.status_code == 401
blocked = client.post(
"/api/v1/auth/login",
json={"email": "user@example.com", "password": "wrongpw"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_refresh_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
blocked = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid"},
)
assert blocked.status_code == 401
blocked = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_logout_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
ok = client.post(
"/api/v1/auth/logout",
json={"refresh_token": "refresh"},
)
assert ok.status_code == 204
blocked = client.post(
"/api/v1/auth/logout",
json={"refresh_token": "refresh"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_signup_start_validation_error_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(