feat(agent-chat): complete core workflow and strengthen auth rate limiting
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from v1.agent_chat.dependencies import get_agent_chat_service
|
||||
from v1.agent_chat.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
)
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
class FakeE2EAgentChatService(AgentChatService):
|
||||
def __init__(self) -> None:
|
||||
return None
|
||||
|
||||
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
|
||||
session_id = payload.session_id or UUID("00000000-0000-0000-0000-000000000001")
|
||||
return AgentChatRunResponse(
|
||||
session_id=session_id,
|
||||
output=payload.message,
|
||||
events=[
|
||||
AgentChatEvent(type="run.started", run_id=str(session_id)),
|
||||
AgentChatEvent(
|
||||
type="message.delta", message_id="m1", delta=payload.message
|
||||
),
|
||||
AgentChatEvent(
|
||||
type="run.completed", run_id=str(session_id), output=payload.message
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_agent_chat_flow_e2e() -> None:
|
||||
app.dependency_overrides[get_agent_chat_service] = lambda: FakeE2EAgentChatService()
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
response = request_context.post(
|
||||
"/api/v1/agent-chat/run",
|
||||
data=json.dumps({"message": "hello"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert response.status == 200
|
||||
body = response.json()
|
||||
assert body["output"] == "hello"
|
||||
assert [event["type"] for event in body["events"]] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent_chat.service import select_recent_session
|
||||
|
||||
|
||||
def test_recent_session_home_default_selection() -> None:
|
||||
sessions = [
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-0000000000a1"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000c1"),
|
||||
title="older",
|
||||
status=AgentChatSessionStatus.COMPLETED,
|
||||
last_activity_at=datetime(2026, 2, 25, 8, 0, tzinfo=timezone.utc),
|
||||
message_count=2,
|
||||
total_tokens=100,
|
||||
total_cost=Decimal("0.010000"),
|
||||
),
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-0000000000a2"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000c1"),
|
||||
title="newer",
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
|
||||
message_count=3,
|
||||
total_tokens=120,
|
||||
total_cost=Decimal("0.020000"),
|
||||
),
|
||||
]
|
||||
|
||||
selected = select_recent_session(sessions)
|
||||
|
||||
assert selected is not None
|
||||
assert selected.id == UUID("00000000-0000-0000-0000-0000000000a2")
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from types import MethodType
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent_chat.schemas import AgentChatRunRequest
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
class _FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
async def refresh(self, obj: object) -> None:
|
||||
if isinstance(obj, AgentChatSession) and obj.id is None:
|
||||
obj.id = uuid4()
|
||||
if isinstance(obj, AgentChatMessage) and obj.id is None:
|
||||
obj.id = uuid4()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_persists_messages_and_emits_ordered_events() -> None:
|
||||
fake_db = _FakeAsyncSession()
|
||||
service = AgentChatService(
|
||||
session=fake_db, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
async def _resolve_session(
|
||||
self: AgentChatService,
|
||||
*,
|
||||
session_id: object | None,
|
||||
user_id: UUID,
|
||||
first_message: str,
|
||||
now: datetime,
|
||||
) -> AgentChatSession:
|
||||
assert session_id is None
|
||||
assert first_message == "hello"
|
||||
return AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000111"),
|
||||
user_id=user_id,
|
||||
title="hello",
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=now,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=Decimal("0"),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
deleted_at=None,
|
||||
)
|
||||
|
||||
async def _next_seq_base(self: AgentChatService, session_id: object) -> int:
|
||||
assert session_id == UUID("00000000-0000-0000-0000-000000000111")
|
||||
return 2
|
||||
|
||||
service._resolve_session = MethodType(_resolve_session, service) # type: ignore[method-assign]
|
||||
service._next_seq_base = MethodType(_next_seq_base, service) # type: ignore[method-assign]
|
||||
|
||||
response = await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert fake_db.committed is True
|
||||
inserted_messages = [
|
||||
item for item in fake_db.added if isinstance(item, AgentChatMessage)
|
||||
]
|
||||
assert len(inserted_messages) == 2
|
||||
assert [msg.seq for msg in inserted_messages] == [3, 4]
|
||||
assert [msg.role for msg in inserted_messages] == [
|
||||
AgentChatMessageRole.USER,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
]
|
||||
assert [event.type for event in response.events] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from v1.agent_chat.dependencies import get_agent_chat_service
|
||||
from v1.agent_chat.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
)
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAgentChatService:
|
||||
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
|
||||
return AgentChatRunResponse(
|
||||
session_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
output=payload.message,
|
||||
events=[
|
||||
AgentChatEvent(
|
||||
type="run.started", run_id="00000000-0000-0000-0000-000000000001"
|
||||
),
|
||||
AgentChatEvent(
|
||||
type="message.delta", message_id="m1", delta=payload.message
|
||||
),
|
||||
AgentChatEvent(
|
||||
type="run.completed",
|
||||
run_id="00000000-0000-0000-0000-000000000001",
|
||||
output=payload.message,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _override_agent_chat_service(
|
||||
service: FakeAgentChatService,
|
||||
) -> Callable[[], AgentChatService]:
|
||||
def _get_service() -> AgentChatService:
|
||||
return service # type: ignore[return-value]
|
||||
|
||||
return _get_service
|
||||
|
||||
|
||||
def test_run_route_returns_response() -> None:
|
||||
app.dependency_overrides[get_agent_chat_service] = _override_agent_chat_service(
|
||||
FakeAgentChatService()
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/agent-chat/run", json={"message": "hello"})
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["output"] == "hello"
|
||||
assert [event["type"] for event in body["events"]] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_run_route_validates_payload() -> None:
|
||||
app.dependency_overrides[get_agent_chat_service] = _override_agent_chat_service(
|
||||
FakeAgentChatService()
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/agent-chat/run", json={"message": ""})
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from v1.agent_chat.service import aggregate_session_cost
|
||||
|
||||
|
||||
def test_aggregate_session_cost_sums_non_negative_values() -> None:
|
||||
total = aggregate_session_cost([Decimal("0.010000"), Decimal("0.002500")])
|
||||
assert total == Decimal("0.012500")
|
||||
|
||||
|
||||
def test_aggregate_session_cost_rejects_negative_value() -> None:
|
||||
try:
|
||||
aggregate_session_cost([Decimal("-0.010000")])
|
||||
raised = False
|
||||
except ValueError:
|
||||
raised = True
|
||||
|
||||
assert raised is True
|
||||
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent_chat.service import select_recent_session
|
||||
|
||||
|
||||
def test_select_recent_session_uses_last_activity_desc() -> None:
|
||||
sessions = [
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
|
||||
title="older",
|
||||
status=AgentChatSessionStatus.COMPLETED,
|
||||
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
|
||||
message_count=1,
|
||||
total_tokens=1,
|
||||
total_cost=Decimal("0"),
|
||||
),
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000002"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
|
||||
title="newer",
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=datetime(2026, 2, 25, 10, 0, tzinfo=timezone.utc),
|
||||
message_count=2,
|
||||
total_tokens=2,
|
||||
total_cost=Decimal("0"),
|
||||
),
|
||||
]
|
||||
|
||||
selected = select_recent_session(sessions)
|
||||
|
||||
assert selected is not None
|
||||
assert selected.id == UUID("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
|
||||
def test_select_recent_session_returns_none_for_empty_collection() -> None:
|
||||
assert select_recent_session([]) is None
|
||||
@@ -416,6 +416,108 @@ def test_logout_returns_no_content() -> None:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_login_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(10):
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "user@example.com", "password": "wrongpw"},
|
||||
)
|
||||
assert blocked.status_code == 401
|
||||
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "user@example.com", "password": "wrongpw"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
body = blocked.json()
|
||||
assert body["detail"] == "Too many requests"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_refresh_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(10):
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid"},
|
||||
)
|
||||
assert blocked.status_code == 401
|
||||
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
body = blocked.json()
|
||||
assert body["detail"] == "Too many requests"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_logout_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(10):
|
||||
ok = client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": "refresh"},
|
||||
)
|
||||
assert ok.status_code == 204
|
||||
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": "refresh"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
body = blocked.json()
|
||||
assert body["detail"] == "Too many requests"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_start_validation_error_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.agui_adapter import AguiAdapter
|
||||
|
||||
|
||||
def test_to_command_maps_payload_fields() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
command = adapter.to_command(
|
||||
{
|
||||
"message": "hello",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
)
|
||||
|
||||
assert command["message"] == "hello"
|
||||
assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
def test_to_protocol_event_maps_internal_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
mapped = adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "run_completed",
|
||||
"session_id": "run-1",
|
||||
"output": "done",
|
||||
}
|
||||
)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "run-1", "output": "done"}
|
||||
|
||||
|
||||
def test_to_protocol_event_raises_for_invalid_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
adapter.to_protocol_event({"kind": "unknown"})
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_transcribe_uses_injected_dashscope_callable() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
assert filename == "voice.wav"
|
||||
assert audio_bytes == b"audio"
|
||||
return {"text": "你好", "request_id": "req-1"}
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
|
||||
assert result["text"] == "你好"
|
||||
assert result["request_id"] == "req-1"
|
||||
assert result["model"] == "fun-asr-realtime-2025-11-07"
|
||||
|
||||
|
||||
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
raise RuntimeError("upstream timeout")
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.cost_tracker import CostTracker
|
||||
|
||||
|
||||
def test_normalize_usage_and_cost_aggregation() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"cost": "0.002500",
|
||||
}
|
||||
)
|
||||
tracker.add_usage(
|
||||
{
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 3,
|
||||
"cost": "0.003000",
|
||||
"currency": "USD",
|
||||
}
|
||||
)
|
||||
|
||||
snapshot = tracker.snapshot()
|
||||
|
||||
assert snapshot["input_tokens"] == 12
|
||||
assert snapshot["output_tokens"] == 8
|
||||
assert snapshot["total_tokens"] == 20
|
||||
assert snapshot["cost"] == Decimal("0.005500")
|
||||
assert snapshot["currency"] == "USD"
|
||||
|
||||
|
||||
def test_add_usage_rejects_negative_values() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"input_tokens": -1})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"cost": "-0.010000"})
|
||||
|
||||
|
||||
def test_snapshot_is_zero_before_any_usage() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
snapshot = tracker.snapshot()
|
||||
|
||||
assert snapshot["input_tokens"] == 0
|
||||
assert snapshot["output_tokens"] == 0
|
||||
assert snapshot["total_tokens"] == 0
|
||||
assert snapshot["cost"] == Decimal("0")
|
||||
assert snapshot["currency"] == "USD"
|
||||
|
||||
|
||||
def test_add_usage_rejects_currency_mismatch() -> None:
|
||||
tracker = CostTracker(currency="USD")
|
||||
tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage(
|
||||
{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"cost": "0.001000",
|
||||
"currency": "CNY",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_add_usage_rejects_non_integral_token_values() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"input_tokens": 1.5})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"output_tokens": True})
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.event_bridge import map_internal_event
|
||||
|
||||
|
||||
def test_map_run_started_event() -> None:
|
||||
event = {"kind": "run_started", "session_id": "s1"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.started", "run_id": "s1"}
|
||||
|
||||
|
||||
def test_map_message_delta_event() -> None:
|
||||
event = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "message.delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
|
||||
def test_map_tool_events() -> None:
|
||||
started = {
|
||||
"kind": "tool_started",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
}
|
||||
completed = {
|
||||
"kind": "tool_completed",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
"result": "ok",
|
||||
}
|
||||
|
||||
mapped_started = map_internal_event(started)
|
||||
mapped_completed = map_internal_event(completed)
|
||||
|
||||
assert mapped_started["type"] == "tool.started"
|
||||
assert mapped_started["tool_name"] == "asr_fun_asr"
|
||||
assert mapped_completed["type"] == "tool.completed"
|
||||
assert mapped_completed["result"] == "ok"
|
||||
|
||||
|
||||
def test_map_run_completed_event() -> None:
|
||||
event = {"kind": "run_completed", "session_id": "s1", "output": "done"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "s1", "output": "done"}
|
||||
|
||||
|
||||
def test_map_unknown_event_raises() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "unknown"})
|
||||
|
||||
|
||||
def test_map_event_missing_required_field_raises_value_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "message_delta", "message_id": "m1"})
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.multimodal import AttachmentInput, MultimodalProcessor
|
||||
from core.agent_chat.storage_adapter import StorageAdapter
|
||||
from core.agent_chat.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
|
||||
storage = StorageAdapter(bucket="agent-chat-attachments")
|
||||
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
assert audio_bytes == b"audio"
|
||||
assert filename == "voice.wav"
|
||||
return {"text": "hello world", "request_id": "req-1"}
|
||||
|
||||
processor = MultimodalProcessor(
|
||||
storage=storage,
|
||||
asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
|
||||
max_file_size_mb=1,
|
||||
)
|
||||
|
||||
result = processor.process(
|
||||
user_id="u1",
|
||||
session_id="s1",
|
||||
message_seq=4,
|
||||
attachments=[
|
||||
AttachmentInput(
|
||||
filename="voice.wav",
|
||||
mime_type="audio/wav",
|
||||
content=b"audio",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert len(result.attachments) == 1
|
||||
metadata = result.attachments[0]
|
||||
assert (
|
||||
metadata["object_path"]
|
||||
== "agent-chat/u1/s1/4/6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
|
||||
)
|
||||
assert metadata["mime_type"] == "audio/wav"
|
||||
assert result.preview_texts == ["hello world"]
|
||||
|
||||
|
||||
def test_multimodal_rejects_unsupported_mime_type() -> None:
|
||||
storage = StorageAdapter(bucket="agent-chat-attachments")
|
||||
processor = MultimodalProcessor(
|
||||
storage=storage, asr_tool=FunASRTool(lambda **_: {})
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
processor.process(
|
||||
user_id="u1",
|
||||
session_id="s1",
|
||||
message_seq=1,
|
||||
attachments=[
|
||||
AttachmentInput(
|
||||
filename="malware.exe",
|
||||
mime_type="application/octet-stream",
|
||||
content=b"bad",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_multimodal_rejects_attachment_over_max_size() -> None:
|
||||
storage = StorageAdapter(bucket="agent-chat-attachments")
|
||||
processor = MultimodalProcessor(
|
||||
storage=storage,
|
||||
asr_tool=FunASRTool(lambda **_: {}),
|
||||
max_file_size_mb=1,
|
||||
)
|
||||
|
||||
oversized = b"x" * (1024 * 1024 + 1)
|
||||
with pytest.raises(ValueError):
|
||||
processor.process(
|
||||
user_id="u1",
|
||||
session_id="s1",
|
||||
message_seq=1,
|
||||
attachments=[
|
||||
AttachmentInput(
|
||||
filename="big.wav",
|
||||
mime_type="audio/wav",
|
||||
content=oversized,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent_chat.orchestrator import AgentChatOrchestrator
|
||||
|
||||
|
||||
async def _intent_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("intent")
|
||||
return {
|
||||
"content": f"intent:{message}",
|
||||
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
|
||||
}
|
||||
|
||||
|
||||
async def _execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
return {
|
||||
"content": f"execution:{message}",
|
||||
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
|
||||
}
|
||||
|
||||
|
||||
async def _organization_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("organization")
|
||||
return {
|
||||
"content": "final answer",
|
||||
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
|
||||
}
|
||||
|
||||
|
||||
def test_orchestrator_runs_three_stages_in_order() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-1", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution", "organization"]
|
||||
assert result.output == "final answer"
|
||||
assert result.usage["total_tokens"] == 13
|
||||
assert result.events[0]["type"] == "run.started"
|
||||
assert result.events[-1]["type"] == "run.completed"
|
||||
|
||||
|
||||
async def _failing_execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_failing_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-2", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution"]
|
||||
assert result.events[-1]["type"] == "run.failed"
|
||||
assert result.events[-1]["run_id"] == "run-2"
|
||||
assert "boom" in (result.events[-1].get("error") or "")
|
||||
assert result.failed is True
|
||||
assert "boom" in (result.error or "")
|
||||
|
||||
|
||||
def test_orchestrator_emits_stage_event_payload_shape() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-3", user_message="hello")
|
||||
|
||||
for event in result.events:
|
||||
assert "type" in event
|
||||
assert event.get("run_id") == "run-3"
|
||||
|
||||
stage_events = [
|
||||
event for event in result.events if event["type"] == "stage.completed"
|
||||
]
|
||||
assert [event["stage"] for event in stage_events] == [
|
||||
"intent",
|
||||
"execution",
|
||||
"organization",
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from v1.agent_chat.service import build_session_title
|
||||
|
||||
|
||||
def test_build_session_title_truncates_first_message() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title(
|
||||
"这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要", now=now
|
||||
)
|
||||
|
||||
assert len(title) == 24
|
||||
|
||||
|
||||
def test_build_session_title_falls_back_when_message_empty() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title("\n ", now=now)
|
||||
|
||||
assert title == "新对话 2026-02-25 10:30"
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent_chat.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
def test_build_object_path_uses_expected_pattern() -> None:
|
||||
adapter = StorageAdapter(bucket="agent-chat-attachments")
|
||||
|
||||
path = adapter.build_object_path(
|
||||
user_id="u1",
|
||||
session_id="s1",
|
||||
message_seq=3,
|
||||
checksum_sha256="abc123",
|
||||
extension="wav",
|
||||
)
|
||||
|
||||
assert path == "agent-chat/u1/s1/3/abc123.wav"
|
||||
|
||||
|
||||
def test_build_attachment_metadata_contains_required_fields() -> None:
|
||||
adapter = StorageAdapter(bucket="agent-chat-attachments")
|
||||
|
||||
metadata = adapter.build_attachment_metadata(
|
||||
object_path="agent-chat/u1/s1/3/abc123.wav",
|
||||
mime_type="audio/wav",
|
||||
size=1024,
|
||||
checksum_sha256="abc123",
|
||||
origin="user_upload",
|
||||
preview_text="hello",
|
||||
)
|
||||
|
||||
assert metadata["object_path"] == "agent-chat/u1/s1/3/abc123.wav"
|
||||
assert metadata["mime_type"] == "audio/wav"
|
||||
assert metadata["size"] == 1024
|
||||
assert metadata["checksum_sha256"] == "abc123"
|
||||
assert metadata["origin"] == "user_upload"
|
||||
assert metadata["preview_text"] == "hello"
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.crewai.template_loader import (
|
||||
load_crewai_template,
|
||||
load_tools_whitelist,
|
||||
validate_workflow_stages,
|
||||
)
|
||||
|
||||
|
||||
def _write(path: Path, content: str) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def _prepare_static_root(root: Path) -> Path:
|
||||
_write(
|
||||
root / "crewai" / "agents.yaml",
|
||||
"""
|
||||
intent:
|
||||
role: Intent Agent
|
||||
execution:
|
||||
role: Execution Agent
|
||||
organization:
|
||||
role: Organization Agent
|
||||
""".strip(),
|
||||
)
|
||||
_write(
|
||||
root / "crewai" / "tasks.yaml",
|
||||
"""
|
||||
intent:
|
||||
description: classify
|
||||
execution:
|
||||
description: run task
|
||||
organization:
|
||||
description: summarize
|
||||
""".strip(),
|
||||
)
|
||||
_write(
|
||||
root / "crewai" / "workflow.yaml",
|
||||
"""
|
||||
stages:
|
||||
- intent
|
||||
- execution
|
||||
- organization
|
||||
""".strip(),
|
||||
)
|
||||
_write(root / "crewai" / "prompts" / "intent.md", "intent prompt")
|
||||
_write(root / "crewai" / "prompts" / "execution.md", "execution prompt")
|
||||
_write(root / "crewai" / "prompts" / "organization.md", "organization prompt")
|
||||
_write(
|
||||
root / "tools.yaml",
|
||||
"""
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- doc_extract
|
||||
""".strip(),
|
||||
)
|
||||
return root
|
||||
|
||||
|
||||
def test_load_crewai_template_success_when_all_files_valid(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path / "agent_chat")
|
||||
|
||||
template = load_crewai_template(static_root)
|
||||
|
||||
assert set(template.agents.keys()) == {"intent", "execution", "organization"}
|
||||
assert set(template.tasks.keys()) == {"intent", "execution", "organization"}
|
||||
assert template.workflow["stages"] == ["intent", "execution", "organization"]
|
||||
assert template.prompts["intent"] == "intent prompt"
|
||||
assert template.prompts["execution"] == "execution prompt"
|
||||
assert template.prompts["organization"] == "organization prompt"
|
||||
assert template.tools_whitelist == {"asr_fun_asr", "doc_extract"}
|
||||
|
||||
|
||||
def test_load_crewai_template_raises_file_not_found_when_required_file_missing(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
static_root = _prepare_static_root(tmp_path / "agent_chat")
|
||||
(static_root / "crewai" / "tasks.yaml").unlink()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_crewai_template(static_root)
|
||||
|
||||
|
||||
def test_load_crewai_template_raises_value_error_when_workflow_stages_invalid(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
static_root = _prepare_static_root(tmp_path / "agent_chat")
|
||||
_write(
|
||||
static_root / "crewai" / "workflow.yaml",
|
||||
"""
|
||||
stages:
|
||||
- execution
|
||||
- intent
|
||||
- organization
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
load_crewai_template(static_root)
|
||||
|
||||
|
||||
def test_load_tools_whitelist_from_tools_yaml(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path / "agent_chat")
|
||||
|
||||
whitelist = load_tools_whitelist(static_root)
|
||||
|
||||
assert whitelist == {"asr_fun_asr", "doc_extract"}
|
||||
|
||||
|
||||
def test_validate_workflow_stages_accepts_exact_intent_execution_organization() -> None:
|
||||
validate_workflow_stages(["intent", "execution", "organization"])
|
||||
|
||||
|
||||
def test_validate_workflow_stages_rejects_extra_or_missing_stage() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
validate_workflow_stages(["intent", "execution"])
|
||||
with pytest.raises(ValueError):
|
||||
validate_workflow_stages(["intent", "execution", "organization", "extra"])
|
||||
|
||||
|
||||
def test_load_tools_whitelist_rejects_non_string_item(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path / "agent_chat")
|
||||
_write(
|
||||
static_root / "tools.yaml",
|
||||
"""
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- 123
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
load_tools_whitelist(static_root)
|
||||
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from core.initialization import init_data
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
|
||||
def test_llm_catalog_file_exists_and_has_required_fields() -> None:
|
||||
catalog_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "src"
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "agent_chat"
|
||||
/ "llm_catalog.yaml"
|
||||
)
|
||||
|
||||
catalog = init_data.load_llm_catalog(catalog_path)
|
||||
|
||||
assert len(catalog["factories"]) == 6
|
||||
assert len(catalog["llms"]) == 2
|
||||
assert set(catalog["factories"][0].keys()) == {"name", "request_url", "avatar"}
|
||||
assert set(catalog["llms"][0].keys()) == {"model_code", "factory_id"}
|
||||
|
||||
|
||||
def test_load_llm_catalog_raises_on_invalid_structure(tmp_path: Path) -> None:
|
||||
catalog_path = tmp_path / "llm_catalog.yaml"
|
||||
catalog_path.write_text(
|
||||
"""
|
||||
factories:
|
||||
- name: qwen
|
||||
llms:
|
||||
- model_code: qwen3.5-flash
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
init_data.load_llm_catalog(catalog_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_data_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
session_maker = async_sessionmaker(
|
||||
bind=engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
|
||||
|
||||
first = await init_data.initialize_data()
|
||||
second = await init_data.initialize_data()
|
||||
|
||||
assert first is True
|
||||
assert second is True
|
||||
|
||||
async with session_maker() as session:
|
||||
factory_count = await session.scalar(
|
||||
select(func.count()).select_from(LlmFactory)
|
||||
)
|
||||
llm_count = await session.scalar(select(func.count()).select_from(Llm))
|
||||
|
||||
assert factory_count == 6
|
||||
assert llm_count == 2
|
||||
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_data_rolls_back_on_invalid_factory_mapping(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
session_maker = async_sessionmaker(
|
||||
bind=engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
|
||||
monkeypatch.setattr(
|
||||
init_data,
|
||||
"load_llm_catalog",
|
||||
lambda *_: {
|
||||
"factories": [
|
||||
{
|
||||
"name": "qwen",
|
||||
"request_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"avatar": "https://cdn.example.com/qwen.png",
|
||||
}
|
||||
],
|
||||
"llms": [
|
||||
{
|
||||
"model_code": "qwen3.5-flash",
|
||||
"factory_id": "missing_factory",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await init_data.initialize_data()
|
||||
|
||||
async with session_maker() as session:
|
||||
factory_count = await session.scalar(
|
||||
select(func.count()).select_from(LlmFactory)
|
||||
)
|
||||
llm_count = await session.scalar(select(func.count()).select_from(Llm))
|
||||
|
||||
assert factory_count == 0
|
||||
assert llm_count == 0
|
||||
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_agent_chat_migration_exists_and_creates_expected_tables() -> None:
|
||||
versions_dir = Path(__file__).resolve().parents[3] / "alembic" / "versions"
|
||||
migration = versions_dir / "20260226_create_agent_chat_core_tables.py"
|
||||
|
||||
assert migration.exists()
|
||||
|
||||
content = migration.read_text(encoding="utf-8")
|
||||
assert 'create_table(\n "llm_factory"' in content
|
||||
assert 'create_table(\n "llms"' in content
|
||||
assert 'create_table(\n "sessions"' in content
|
||||
assert 'create_table(\n "messages"' in content
|
||||
assert "tool_calls" not in content
|
||||
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_factory_and_llm_relationship(db_session: AsyncSession) -> None:
|
||||
factory = LlmFactory(
|
||||
name="qwen",
|
||||
request_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
avatar="https://cdn.example.com/qwen.png",
|
||||
)
|
||||
db_session.add(factory)
|
||||
await db_session.flush()
|
||||
|
||||
llm = Llm(
|
||||
factory_id=factory.id,
|
||||
model_code="qwen3.5-flash",
|
||||
)
|
||||
db_session.add(llm)
|
||||
await db_session.commit()
|
||||
|
||||
found_llm = await db_session.get(Llm, llm.id)
|
||||
assert found_llm is not None
|
||||
assert found_llm.factory_id == factory.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_status_supports_required_values(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
session = AgentChatSession(
|
||||
user_id=user_id,
|
||||
title="test",
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(session)
|
||||
await db_session.commit()
|
||||
|
||||
statuses = [
|
||||
AgentChatSessionStatus.PENDING,
|
||||
AgentChatSessionStatus.RUNNING,
|
||||
AgentChatSessionStatus.COMPLETED,
|
||||
AgentChatSessionStatus.FAILED,
|
||||
]
|
||||
for status in statuses:
|
||||
session.status = status
|
||||
await db_session.commit()
|
||||
await db_session.refresh(session)
|
||||
assert session.status == status
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_role_supports_tool(db_session: AsyncSession) -> None:
|
||||
user_id = uuid4()
|
||||
session = AgentChatSession(
|
||||
user_id=user_id,
|
||||
title="tool test",
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(session)
|
||||
await db_session.flush()
|
||||
|
||||
message = AgentChatMessage(
|
||||
session_id=session.id,
|
||||
seq=1,
|
||||
role="tool",
|
||||
content="tool output",
|
||||
cost=0,
|
||||
)
|
||||
db_session.add(message)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.execute(
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == session.id)
|
||||
)
|
||||
found = result.scalar_one()
|
||||
assert found.role == "tool"
|
||||
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import ValidationError
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_social_prefixed_storage_env_populates_settings(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("SOCIAL_STORAGE__PROVIDER", "supabase")
|
||||
monkeypatch.setenv("SOCIAL_STORAGE__BUCKET", "agent-chat-attachments")
|
||||
monkeypatch.setenv("SOCIAL_STORAGE__SIGNED_URL_TTL_SECONDS", "900")
|
||||
monkeypatch.setenv("SOCIAL_STORAGE__MAX_FILE_SIZE_MB", "25")
|
||||
monkeypatch.setenv("SOCIAL_STORAGE__RETENTION_DAYS", "45")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.storage.provider == "supabase"
|
||||
assert settings.storage.bucket == "agent-chat-attachments"
|
||||
assert settings.storage.signed_url_ttl_seconds == 900
|
||||
assert settings.storage.max_file_size_mb == 25
|
||||
assert settings.storage.retention_days == 45
|
||||
|
||||
|
||||
def test_storage_settings_validation_rejects_invalid_provider(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("SOCIAL_STORAGE__PROVIDER", "s3")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Settings()
|
||||
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.agent_chat.orchestrator import OrchestratorResult
|
||||
from core.db.base import Base
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from v1.agent_chat.schemas import AgentChatRunRequest
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_creates_session_and_persists_messages(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
result = await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert result.session_id is not None
|
||||
assert result.output == "hello"
|
||||
assert [event.type for event in result.events] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
|
||||
session_obj = await db_session.get(AgentChatSession, result.session_id)
|
||||
assert session_obj is not None
|
||||
assert session_obj.message_count == 2
|
||||
assert session_obj.status.value == "completed"
|
||||
|
||||
rows = await db_session.execute(
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == result.session_id)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = rows.scalars().all()
|
||||
assert len(messages) == 2
|
||||
assert messages[0].role.value == "user"
|
||||
assert messages[1].role.value == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
first = await service.run(AgentChatRunRequest(message="first"))
|
||||
second = await service.run(
|
||||
AgentChatRunRequest(message="second", session_id=first.session_id)
|
||||
)
|
||||
|
||||
assert second.session_id == first.session_id
|
||||
|
||||
session_obj = await db_session.get(AgentChatSession, first.session_id)
|
||||
assert session_obj is not None
|
||||
assert session_obj.message_count == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
class _FailingOrchestrator:
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
return OrchestratorResult(
|
||||
output="",
|
||||
usage={
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"cost": Decimal("0"),
|
||||
"currency": "USD",
|
||||
},
|
||||
events=[],
|
||||
context={},
|
||||
failed=True,
|
||||
error="stage failed",
|
||||
)
|
||||
|
||||
service._orchestrator = _FailingOrchestrator() # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
|
||||
rows = await db_session.execute(
|
||||
select(AgentChatSession).where(AgentChatSession.user_id == user.id)
|
||||
)
|
||||
stored_session = rows.scalars().one()
|
||||
assert stored_session.status.value == "failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message=" "))
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello", session_id=uuid4()))
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
|
||||
db_session: AsyncSession,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
async def _fail_commit() -> None:
|
||||
raise SQLAlchemyError("db down")
|
||||
|
||||
monkeypatch.setattr(db_session, "commit", _fail_commit)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_502_for_unexpected_exception(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
class _CrashingOrchestrator:
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
raise RuntimeError("unexpected")
|
||||
|
||||
service._orchestrator = _CrashingOrchestrator() # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
Reference in New Issue
Block a user