refactor: 移除 crewai agent 架构相关代码并更新 LLM 配置

This commit is contained in:
qzl
2026-03-04 11:37:09 +08:00
parent 87399f74c8
commit b02a322bf3
71 changed files with 1045 additions and 7499 deletions
-98
View File
@@ -1,98 +0,0 @@
from __future__ import annotations
import json
import socket
import threading
import time
from uuid import UUID
from playwright.sync_api import sync_playwright
import uvicorn
from app import app
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
from v1.agent.service import AgentChatService
class FakeE2EAgentChatService(AgentChatService):
def __init__(self) -> None:
return None
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
session_id = payload.session_id or UUID("00000000-0000-0000-0000-000000000001")
return AgentChatRunResponse(
session_id=session_id,
output=payload.message,
events=[
AgentChatEvent(type="run.started", run_id=str(session_id)),
AgentChatEvent(
type="message.delta", message_id="m1", delta=payload.message
),
AgentChatEvent(
type="run.completed", run_id=str(session_id), output=payload.message
),
],
)
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex((host, port)) == 0:
return
time.sleep(0.05)
raise RuntimeError("Server did not start in time")
def _start_server(host: str, port: int):
config = uvicorn.Config(app, host=host, port=port, log_level="info")
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
_wait_for_port(host, port)
return server, thread
def test_agent_chat_flow_e2e() -> None:
app.dependency_overrides[get_agent_service] = lambda: FakeE2EAgentChatService()
host = "127.0.0.1"
port = _find_free_port()
server, thread = _start_server(host, port)
try:
with sync_playwright() as playwright:
request_context = playwright.request.new_context(
base_url=f"http://{host}:{port}"
)
try:
response = request_context.post(
"/api/v1/agent-chat",
data=json.dumps({"message": "hello"}),
headers={"Content-Type": "application/json"},
)
assert response.status == 200
body = response.json()
assert body["output"] == "hello"
assert [event["type"] for event in body["events"]] == [
"run.started",
"message.delta",
"run.completed",
]
finally:
request_context.dispose()
finally:
app.dependency_overrides = {}
server.should_exit = True
thread.join(timeout=5)
@@ -1,38 +0,0 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent.service import select_recent_session
def test_recent_session_home_default_selection() -> None:
sessions = [
AgentChatSession(
id=UUID("00000000-0000-0000-0000-0000000000a1"),
user_id=UUID("00000000-0000-0000-0000-0000000000c1"),
title="older",
status=AgentChatSessionStatus.COMPLETED,
last_activity_at=datetime(2026, 2, 25, 8, 0, tzinfo=timezone.utc),
message_count=2,
total_tokens=100,
total_cost=Decimal("0.010000"),
),
AgentChatSession(
id=UUID("00000000-0000-0000-0000-0000000000a2"),
user_id=UUID("00000000-0000-0000-0000-0000000000c1"),
title="newer",
status=AgentChatSessionStatus.RUNNING,
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
message_count=3,
total_tokens=120,
total_cost=Decimal("0.020000"),
),
]
selected = select_recent_session(sessions)
assert selected is not None
assert selected.id == UUID("00000000-0000-0000-0000-0000000000a2")
@@ -1,97 +0,0 @@
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from types import MethodType
from uuid import UUID, uuid4
import pytest
from core.auth.models import CurrentUser
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent.schemas import AgentChatRunRequest
from v1.agent.service import AgentChatService
class _FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self.committed = False
self.rolled_back = False
def add(self, obj: object) -> None:
self.added.append(obj)
async def flush(self) -> None:
return None
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
self.rolled_back = True
async def refresh(self, obj: object) -> None:
if isinstance(obj, AgentChatSession) and obj.id is None:
obj.id = uuid4()
if isinstance(obj, AgentChatMessage) and obj.id is None:
obj.id = uuid4()
@pytest.mark.asyncio
async def test_run_persists_messages_and_emits_ordered_events() -> None:
fake_db = _FakeAsyncSession()
service = AgentChatService(
session=fake_db, # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
async def _resolve_session(
self: AgentChatService,
*,
session_id: object | None,
user_id: UUID,
first_message: str,
now: datetime,
) -> AgentChatSession:
assert session_id is None
assert first_message == "hello"
return AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000111"),
user_id=user_id,
title="hello",
status=AgentChatSessionStatus.RUNNING,
last_activity_at=now,
message_count=0,
total_tokens=0,
total_cost=Decimal("0"),
created_at=now,
updated_at=now,
deleted_at=None,
)
async def _next_seq_base(self: AgentChatService, session_id: object) -> int:
assert session_id == UUID("00000000-0000-0000-0000-000000000111")
return 2
service._resolve_session = MethodType(_resolve_session, service) # type: ignore[method-assign]
service._next_seq_base = MethodType(_next_seq_base, service) # type: ignore[method-assign]
response = await service.run(AgentChatRunRequest(message="hello"))
assert fake_db.committed is True
inserted_messages = [
item for item in fake_db.added if isinstance(item, AgentChatMessage)
]
assert len(inserted_messages) == 2
assert [msg.seq for msg in inserted_messages] == [3, 4]
assert [msg.role for msg in inserted_messages] == [
AgentChatMessageRole.USER,
AgentChatMessageRole.ASSISTANT,
]
assert [event.type for event in response.events] == [
"run.started",
"message.delta",
"run.completed",
]
@@ -1,78 +0,0 @@
from __future__ import annotations
from typing import Callable
from uuid import UUID
from fastapi.testclient import TestClient
from app import app
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
from v1.agent.service import AgentChatService
class FakeAgentChatService:
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
return AgentChatRunResponse(
session_id=UUID("00000000-0000-0000-0000-000000000001"),
output=payload.message,
events=[
AgentChatEvent(
type="run.started", run_id="00000000-0000-0000-0000-000000000001"
),
AgentChatEvent(
type="message.delta", message_id="m1", delta=payload.message
),
AgentChatEvent(
type="run.completed",
run_id="00000000-0000-0000-0000-000000000001",
output=payload.message,
),
],
)
def _override_agent_chat_service(
service: FakeAgentChatService,
) -> Callable[[], AgentChatService]:
def _get_service() -> AgentChatService:
return service # type: ignore[return-value]
return _get_service
def test_run_route_returns_response() -> None:
app.dependency_overrides[get_agent_service] = _override_agent_chat_service(
FakeAgentChatService()
)
client = TestClient(app)
try:
response = client.post("/api/v1/agent-chat", json={"message": "hello"})
assert response.status_code == 200
body = response.json()
assert body["output"] == "hello"
assert [event["type"] for event in body["events"]] == [
"run.started",
"message.delta",
"run.completed",
]
finally:
app.dependency_overrides = {}
def test_run_route_validates_payload() -> None:
app.dependency_overrides[get_agent_service] = _override_agent_chat_service(
FakeAgentChatService()
)
client = TestClient(app)
try:
response = client.post("/api/v1/agent-chat", json={"message": ""})
assert response.status_code == 422
finally:
app.dependency_overrides = {}
@@ -1,20 +0,0 @@
from __future__ import annotations
from decimal import Decimal
from v1.agent.service import aggregate_session_cost
def test_aggregate_session_cost_sums_non_negative_values() -> None:
total = aggregate_session_cost([Decimal("0.010000"), Decimal("0.002500")])
assert total == Decimal("0.012500")
def test_aggregate_session_cost_rejects_negative_value() -> None:
try:
aggregate_session_cost([Decimal("-0.010000")])
raised = False
except ValueError:
raised = True
assert raised is True
@@ -1,42 +0,0 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent.service import select_recent_session
def test_select_recent_session_uses_last_activity_desc() -> None:
sessions = [
AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000001"),
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
title="older",
status=AgentChatSessionStatus.COMPLETED,
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
message_count=1,
total_tokens=1,
total_cost=Decimal("0"),
),
AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000002"),
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
title="newer",
status=AgentChatSessionStatus.RUNNING,
last_activity_at=datetime(2026, 2, 25, 10, 0, tzinfo=timezone.utc),
message_count=2,
total_tokens=2,
total_cost=Decimal("0"),
),
]
selected = select_recent_session(sessions)
assert selected is not None
assert selected.id == UUID("00000000-0000-0000-0000-000000000002")
def test_select_recent_session_returns_none_for_empty_collection() -> None:
assert select_recent_session([]) is None
@@ -1,82 +0,0 @@
from __future__ import annotations
from uuid import UUID
import pytest
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import RunAgentInput
from v1.users.dependencies import get_current_user
class FakeAgentService:
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
return None
async def stream_run(self, input_data: RunAgentInput):
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n'
yield 'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n'
async def stream_resume(self, run_id: str, input_data: RunAgentInput):
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n'
yield 'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n'
def _get_test_user() -> CurrentUser:
return CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
@pytest.fixture
def client() -> TestClient:
app.dependency_overrides[get_current_user] = _get_test_user
app.dependency_overrides[get_agent_service] = lambda: FakeAgentService()
yield TestClient(app)
app.dependency_overrides.clear()
class TestChatRoutes:
def test_run_route_streams_sse_events(self, client: TestClient):
payload = {
"threadId": "t1",
"runId": "r1",
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
}
response = client.post("/api/v1/agent/runs", json=payload)
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
events = response.text.split("\n\n")
assert 'data: {"type": "RUN_STARTED"' in events[0]
assert 'data: {"type": "TEXT_MESSAGE_START"' in events[1]
def test_resume_route_streams_sse_events(self, client: TestClient):
payload = {
"threadId": "t1",
"runId": "r1",
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
}
response = client.post("/api/v1/agent/runs/r1/resume", json=payload)
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
events = response.text.split("\n\n")
assert 'data: {"type": "RUN_STARTED"' in events[0]
assert 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"' in events[2]
@@ -1,144 +0,0 @@
from __future__ import annotations
import json
from uuid import UUID
import pytest
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import RunAgentInput
from v1.users.dependencies import get_current_user
class FakeAgentServiceWithInterrupt:
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
return None
async def stream_run(self, input_data: RunAgentInput):
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Let me navigate"}\n\n'
yield 'data: {"type": "TOOL_CALL", "toolName": "ui.navigate_to", "args": {"path": "/home"}}\n\n'
yield (
'data: {"type": "RUN_FINISHED", "runId": "'
+ input_data.runId
+ '", "outcome": "interrupt", "interrupt": {"id": "int-1", "reason": "frontend_tool", "payload": {"toolName": "ui.navigate_to", "args": {"path": "/home"}}}}\n\n'
)
async def stream_resume(self, run_id: str, input_data: RunAgentInput):
if input_data.resume and input_data.resume.get("interruptId") == "int-1":
payload = input_data.resume.get("payload", {})
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
yield (
'data: {"type": "TOOL_RESULT", "toolName": "ui.navigate_to", "result": '
+ json.dumps(payload.get("result", {}))
+ "}\n\n"
)
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Navigation completed"}\n\n'
yield 'data: {"type": "RUN_FINISHED", "runId": "' + run_id + '"}\n\n'
else:
yield (
'data: {"type": "RUN_FINISHED", "runId": "'
+ run_id
+ '", "outcome": "error"}\n\n'
)
def _get_test_user() -> CurrentUser:
return CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
@pytest.fixture
def client() -> TestClient:
app.dependency_overrides[get_current_user] = _get_test_user
app.dependency_overrides[get_agent_service] = (
lambda: FakeAgentServiceWithInterrupt()
)
yield TestClient(app)
app.dependency_overrides.clear()
class TestInterruptResumeFlow:
def test_frontend_tool_interrupt_then_resume_with_result(self, client: TestClient):
payload = {
"threadId": "t1",
"runId": "r1",
"state": {},
"messages": [{"role": "user", "content": "Navigate to home"}],
"tools": [{"name": "ui.navigate_to", "execution_target": "frontend"}],
"context": [],
"forwardedProps": {},
}
response = client.post("/api/v1/agent/runs", json=payload)
assert response.status_code == 200
events = response.text.split("\n\n")
interrupt_event = [e for e in events if '"outcome": "interrupt"' in e][0]
assert '"id": "int-1"' in interrupt_event
assert '"reason": "frontend_tool"' in interrupt_event
resume_payload = {
"threadId": "t1",
"runId": "r1",
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": {
"interruptId": "int-1",
"payload": {"result": {"success": True}},
},
}
resume_response = client.post(
"/api/v1/agent/runs/r1/resume", json=resume_payload
)
assert resume_response.status_code == 200
resume_events = resume_response.text.split("\n\n")
tool_result_event = [e for e in resume_events if '"type": "TOOL_RESULT"' in e][
0
]
assert '"toolName": "ui.navigate_to"' in tool_result_event
assert '"success": true' in tool_result_event.lower()
def test_backend_tool_approval_rejected(self, client: TestClient):
payload = {
"threadId": "t2",
"runId": "r2",
"state": {},
"messages": [{"role": "user", "content": "Transfer funds"}],
"tools": [
{
"name": "srv.transfer_funds",
"execution_target": "backend",
"requires_approval": True,
}
],
"context": [],
"forwardedProps": {},
}
response = client.post("/api/v1/agent/runs", json=payload)
assert response.status_code == 200
resume_payload = {
"threadId": "t2",
"runId": "r2",
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": {
"interruptId": "int-1",
"payload": {"decision": "rejected", "reason": "User denied"},
},
}
resume_response = client.post(
"/api/v1/agent/runs/r2/resume", json=resume_payload
)
assert resume_response.status_code == 200
@@ -1,40 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.agui_adapter import AguiAdapter
def test_to_command_maps_payload_fields() -> None:
adapter = AguiAdapter()
command = adapter.to_command(
{
"message": "hello",
"session_id": "00000000-0000-0000-0000-000000000001",
}
)
assert command["message"] == "hello"
assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
def test_to_protocol_event_maps_internal_event() -> None:
adapter = AguiAdapter()
mapped = adapter.to_protocol_event(
{
"kind": "run_completed",
"session_id": "run-1",
"output": "done",
}
)
assert mapped == {"type": "run.completed", "run_id": "run-1", "output": "done"}
def test_to_protocol_event_raises_for_invalid_event() -> None:
adapter = AguiAdapter()
with pytest.raises(ValueError):
adapter.to_protocol_event({"kind": "unknown"})
@@ -1,30 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.tools.asr_fun_asr import FunASRTool
def test_transcribe_uses_injected_dashscope_callable() -> None:
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
assert filename == "voice.wav"
assert audio_bytes == b"audio"
return {"text": "你好", "request_id": "req-1"}
tool = FunASRTool(transcribe_callable=fake_transcribe)
result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
assert result["text"] == "你好"
assert result["request_id"] == "req-1"
assert result["model"] == "fun-asr-realtime-2025-11-07"
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
raise RuntimeError("upstream timeout")
tool = FunASRTool(transcribe_callable=fake_transcribe)
with pytest.raises(RuntimeError):
tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
@@ -1,48 +0,0 @@
from __future__ import annotations
from decimal import Decimal
from core.agent.litellm_client import get_model_cost
def test_get_model_cost_returns_decimal() -> None:
usage = {
"prompt_tokens": 7,
"completion_tokens": 5,
"total_tokens": 12,
"cost": "0.002500",
}
cost = get_model_cost(usage)
assert cost == Decimal("0.002500")
def test_get_model_cost_with_no_cost() -> None:
usage = {
"prompt_tokens": 7,
"completion_tokens": 5,
"total_tokens": 12,
}
cost = get_model_cost(usage)
assert cost == Decimal("0")
def test_get_model_cost_with_zero_cost() -> None:
usage = {
"prompt_tokens": 7,
"completion_tokens": 5,
"total_tokens": 12,
"cost": "0",
}
cost = get_model_cost(usage)
assert cost == Decimal("0")
def test_get_model_cost_with_numeric_cost() -> None:
usage = {
"prompt_tokens": 7,
"completion_tokens": 5,
"total_tokens": 12,
"cost": 0.0025,
}
cost = get_model_cost(usage)
assert cost == Decimal("0.0025")
@@ -1,61 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.event_bridge import map_internal_event
def test_map_run_started_event() -> None:
event = {"kind": "run_started", "session_id": "s1"}
mapped = map_internal_event(event)
assert mapped == {"type": "run.started", "run_id": "s1"}
def test_map_message_delta_event() -> None:
event = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
mapped = map_internal_event(event)
assert mapped == {"type": "message.delta", "message_id": "m1", "delta": "hello"}
def test_map_tool_events() -> None:
started = {
"kind": "tool_started",
"message_id": "m2",
"tool_name": "asr_fun_asr",
}
completed = {
"kind": "tool_completed",
"message_id": "m2",
"tool_name": "asr_fun_asr",
"result": "ok",
}
mapped_started = map_internal_event(started)
mapped_completed = map_internal_event(completed)
assert mapped_started["type"] == "tool.started"
assert mapped_started["tool_name"] == "asr_fun_asr"
assert mapped_completed["type"] == "tool.completed"
assert mapped_completed["result"] == "ok"
def test_map_run_completed_event() -> None:
event = {"kind": "run_completed", "session_id": "s1", "output": "done"}
mapped = map_internal_event(event)
assert mapped == {"type": "run.completed", "run_id": "s1", "output": "done"}
def test_map_unknown_event_raises() -> None:
with pytest.raises(ValueError):
map_internal_event({"kind": "unknown"})
def test_map_event_missing_required_field_raises_value_error() -> None:
with pytest.raises(ValueError):
map_internal_event({"kind": "message_delta", "message_id": "m1"})
@@ -1,104 +0,0 @@
from __future__ import annotations
from core.agent.orchestrator import AgentChatOrchestrator
async def _intent_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("intent")
return {
"content": f"intent:{message}",
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
}
async def _execution_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("execution")
return {
"content": f"execution:{message}",
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
}
async def _organization_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("organization")
return {
"content": "final answer",
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
}
def test_orchestrator_runs_three_stages_in_order() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-1", user_message="hello")
assert result.context["sequence"] == ["intent", "execution", "organization"]
assert result.output == "final answer"
assert result.usage["total_tokens"] == 13
assert result.events[0]["type"] == "run.started"
assert result.events[-1]["type"] == "run.completed"
async def _failing_execution_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("execution")
raise RuntimeError("boom")
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_failing_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-2", user_message="hello")
assert result.context["sequence"] == ["intent", "execution"]
assert result.events[-1]["type"] == "run.failed"
assert result.events[-1]["run_id"] == "run-2"
assert "boom" in (result.events[-1].get("error") or "")
assert result.failed is True
assert "boom" in (result.error or "")
def test_orchestrator_emits_stage_event_payload_shape() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-3", user_message="hello")
for event in result.events:
assert "type" in event
assert event.get("run_id") == "run-3"
stage_events = [
event for event in result.events if event["type"] == "stage.completed"
]
assert [event["stage"] for event in stage_events] == [
"intent",
"execution",
"organization",
]
@@ -1,23 +0,0 @@
from __future__ import annotations
from datetime import datetime
from v1.agent.service import build_session_title
def test_build_session_title_truncates_first_message() -> None:
now = datetime(2026, 2, 25, 10, 30)
title = build_session_title(
"这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要", now=now
)
assert len(title) == 24
def test_build_session_title_falls_back_when_message_empty() -> None:
now = datetime(2026, 2, 25, 10, 30)
title = build_session_title("\n ", now=now)
assert title == "新对话 2026-02-25 10:30"
@@ -1,40 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.agui_adapter import AguiAdapter
def test_to_command_maps_payload_fields() -> None:
adapter = AguiAdapter()
command = adapter.to_command(
{
"message": "hello",
"session_id": "00000000-0000-0000-0000-000000000001",
}
)
assert command["message"] == "hello"
assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
def test_to_protocol_event_maps_internal_event() -> None:
adapter = AguiAdapter()
mapped = adapter.to_protocol_event(
{
"kind": "run_completed",
"session_id": "run-1",
"output": "done",
}
)
assert mapped == {"type": "run.completed", "run_id": "run-1", "output": "done"}
def test_to_protocol_event_raises_for_invalid_event() -> None:
adapter = AguiAdapter()
with pytest.raises(ValueError):
adapter.to_protocol_event({"kind": "unknown"})
@@ -1,30 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.tools.asr_fun_asr import FunASRTool
def test_transcribe_uses_injected_dashscope_callable() -> None:
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
assert filename == "voice.wav"
assert audio_bytes == b"audio"
return {"text": "你好", "request_id": "req-1"}
tool = FunASRTool(transcribe_callable=fake_transcribe)
result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
assert result["text"] == "你好"
assert result["request_id"] == "req-1"
assert result["model"] == "fun-asr-realtime-2025-11-07"
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
raise RuntimeError("upstream timeout")
tool = FunASRTool(transcribe_callable=fake_transcribe)
with pytest.raises(RuntimeError):
tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
@@ -1,82 +0,0 @@
from __future__ import annotations
from decimal import Decimal
import pytest
from core.agent.cost_tracker import CostTracker
def test_normalize_usage_and_cost_aggregation() -> None:
tracker = CostTracker()
tracker.add_usage(
{
"prompt_tokens": 7,
"completion_tokens": 5,
"cost": "0.002500",
}
)
tracker.add_usage(
{
"input_tokens": 5,
"output_tokens": 3,
"cost": "0.003000",
"currency": "USD",
}
)
snapshot = tracker.snapshot()
assert snapshot["input_tokens"] == 12
assert snapshot["output_tokens"] == 8
assert snapshot["total_tokens"] == 20
assert snapshot["cost"] == Decimal("0.005500")
assert snapshot["currency"] == "USD"
def test_add_usage_rejects_negative_values() -> None:
tracker = CostTracker()
with pytest.raises(ValueError):
tracker.add_usage({"input_tokens": -1})
with pytest.raises(ValueError):
tracker.add_usage({"cost": "-0.010000"})
def test_snapshot_is_zero_before_any_usage() -> None:
tracker = CostTracker()
snapshot = tracker.snapshot()
assert snapshot["input_tokens"] == 0
assert snapshot["output_tokens"] == 0
assert snapshot["total_tokens"] == 0
assert snapshot["cost"] == Decimal("0")
assert snapshot["currency"] == "USD"
def test_add_usage_rejects_currency_mismatch() -> None:
tracker = CostTracker(currency="USD")
tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"})
with pytest.raises(ValueError):
tracker.add_usage(
{
"input_tokens": 1,
"output_tokens": 1,
"cost": "0.001000",
"currency": "CNY",
}
)
def test_add_usage_rejects_non_integral_token_values() -> None:
tracker = CostTracker()
with pytest.raises(ValueError):
tracker.add_usage({"input_tokens": 1.5})
with pytest.raises(ValueError):
tracker.add_usage({"output_tokens": True})
@@ -1,61 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.event_bridge import map_internal_event
def test_map_run_started_event() -> None:
event = {"kind": "run_started", "session_id": "s1"}
mapped = map_internal_event(event)
assert mapped == {"type": "run.started", "run_id": "s1"}
def test_map_message_delta_event() -> None:
event = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
mapped = map_internal_event(event)
assert mapped == {"type": "message.delta", "message_id": "m1", "delta": "hello"}
def test_map_tool_events() -> None:
started = {
"kind": "tool_started",
"message_id": "m2",
"tool_name": "asr_fun_asr",
}
completed = {
"kind": "tool_completed",
"message_id": "m2",
"tool_name": "asr_fun_asr",
"result": "ok",
}
mapped_started = map_internal_event(started)
mapped_completed = map_internal_event(completed)
assert mapped_started["type"] == "tool.started"
assert mapped_started["tool_name"] == "asr_fun_asr"
assert mapped_completed["type"] == "tool.completed"
assert mapped_completed["result"] == "ok"
def test_map_run_completed_event() -> None:
event = {"kind": "run_completed", "session_id": "s1", "output": "done"}
mapped = map_internal_event(event)
assert mapped == {"type": "run.completed", "run_id": "s1", "output": "done"}
def test_map_unknown_event_raises() -> None:
with pytest.raises(ValueError):
map_internal_event({"kind": "unknown"})
def test_map_event_missing_required_field_raises_value_error() -> None:
with pytest.raises(ValueError):
map_internal_event({"kind": "message_delta", "message_id": "m1"})
@@ -1,104 +0,0 @@
from __future__ import annotations
from core.agent.orchestrator import AgentChatOrchestrator
async def _intent_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("intent")
return {
"content": f"intent:{message}",
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
}
async def _execution_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("execution")
return {
"content": f"execution:{message}",
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
}
async def _organization_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("organization")
return {
"content": "final answer",
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
}
def test_orchestrator_runs_three_stages_in_order() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-1", user_message="hello")
assert result.context["sequence"] == ["intent", "execution", "organization"]
assert result.output == "final answer"
assert result.usage["total_tokens"] == 13
assert result.events[0]["type"] == "run.started"
assert result.events[-1]["type"] == "run.completed"
async def _failing_execution_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("execution")
raise RuntimeError("boom")
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_failing_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-2", user_message="hello")
assert result.context["sequence"] == ["intent", "execution"]
assert result.events[-1]["type"] == "run.failed"
assert result.events[-1]["run_id"] == "run-2"
assert "boom" in (result.events[-1].get("error") or "")
assert result.failed is True
assert "boom" in (result.error or "")
def test_orchestrator_emits_stage_event_payload_shape() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-3", user_message="hello")
for event in result.events:
assert "type" in event
assert event.get("run_id") == "run-3"
stage_events = [
event for event in result.events if event["type"] == "stage.completed"
]
assert [event["stage"] for event in stage_events] == [
"intent",
"execution",
"organization",
]
@@ -1,23 +0,0 @@
from __future__ import annotations
from datetime import datetime
from v1.agent.service import build_session_title
def test_build_session_title_truncates_first_message() -> None:
now = datetime(2026, 2, 25, 10, 30)
title = build_session_title(
"这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要", now=now
)
assert len(title) == 24
def test_build_session_title_falls_back_when_message_empty() -> None:
now = datetime(2026, 2, 25, 10, 30)
title = build_session_title("\n ", now=now)
assert title == "新对话 2026-02-25 10:30"
@@ -1,132 +0,0 @@
from __future__ import annotations
from pathlib import Path
import pytest
from core.agent.crewai.template_loader import (
load_crewai_template,
load_tools_whitelist,
validate_workflow_stages,
)
def _write(path: Path, content: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(content, encoding="utf-8")
def _prepare_static_root(root: Path) -> Path:
_write(
root / "agents.yaml",
"""
intent:
role: Intent Agent
execution:
role: Execution Agent
organization:
role: Organization Agent
""".strip(),
)
_write(
root / "tasks.yaml",
"""
intent:
description: classify
execution:
description: run task
organization:
description: summarize
""".strip(),
)
_write(
root / "workflow.yaml",
"""
stages:
- intent
- execution
- organization
""".strip(),
)
_write(
root / "tools.yaml",
"""
tools:
- asr_fun_asr
- doc_extract
""".strip(),
)
return root
def test_load_crewai_template_success_when_all_files_valid(tmp_path: Path) -> None:
static_root = _prepare_static_root(tmp_path)
template = load_crewai_template(static_root)
assert set(template.agents.keys()) == {"intent", "execution", "organization"}
assert set(template.tasks.keys()) == {"intent", "execution", "organization"}
assert template.workflow["stages"] == ["intent", "execution", "organization"]
assert template.tools_whitelist == {"asr_fun_asr", "doc_extract"}
def test_load_crewai_template_raises_file_not_found_when_required_file_missing(
tmp_path: Path,
) -> None:
static_root = _prepare_static_root(tmp_path)
(static_root / "tasks.yaml").unlink()
with pytest.raises(FileNotFoundError):
load_crewai_template(static_root)
def test_load_crewai_template_raises_value_error_when_workflow_stages_invalid(
tmp_path: Path,
) -> None:
static_root = _prepare_static_root(tmp_path)
_write(
static_root / "workflow.yaml",
"""
stages:
- execution
- intent
- organization
""".strip(),
)
with pytest.raises(ValueError):
load_crewai_template(static_root)
def test_load_tools_whitelist_from_tools_yaml(tmp_path: Path) -> None:
static_root = _prepare_static_root(tmp_path)
whitelist = load_tools_whitelist(static_root)
assert whitelist == {"asr_fun_asr", "doc_extract"}
def test_validate_workflow_stages_accepts_exact_intent_execution_organization() -> None:
validate_workflow_stages(["intent", "execution", "organization"])
def test_validate_workflow_stages_rejects_extra_or_missing_stage() -> None:
with pytest.raises(ValueError):
validate_workflow_stages(["intent", "execution"])
with pytest.raises(ValueError):
validate_workflow_stages(["intent", "execution", "organization", "extra"])
def test_load_tools_whitelist_rejects_non_string_item(tmp_path: Path) -> None:
static_root = _prepare_static_root(tmp_path)
_write(
static_root / "tools.yaml",
"""
tools:
- asr_fun_asr
- 123
""".strip(),
)
with pytest.raises(ValueError):
load_tools_whitelist(static_root)
@@ -1,188 +0,0 @@
from __future__ import annotations
from pathlib import Path
import pytest
from sqlalchemy import Column, String, Table, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base
from core.config.initial import init_data
from models.llm import Llm
from models.llm_factory import LlmFactory
def test_llm_catalog_file_exists_and_has_required_fields() -> None:
catalog_path = (
Path(__file__).resolve().parents[3]
/ "src"
/ "core"
/ "config"
/ "static"
/ "database"
/ "llm_catalog.yaml"
)
catalog = init_data.load_llm_catalog(catalog_path)
assert len(catalog["factories"]) == 6
assert len(catalog["llms"]) == 2
assert set(catalog["factories"][0].keys()) == {"name", "request_url", "avatar"}
assert set(catalog["llms"][0].keys()) == {"model_code", "factory_name"}
def test_load_llm_catalog_raises_on_invalid_structure(tmp_path: Path) -> None:
catalog_path = tmp_path / "llm_catalog.yaml"
catalog_path.write_text(
"""
factories:
- name: qwen
llms:
- model_code: qwen3.5-flash
""".strip(),
encoding="utf-8",
)
with pytest.raises(ValueError):
init_data.load_llm_catalog(catalog_path)
@pytest.mark.asyncio
async def test_initialize_data_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
session_maker = async_sessionmaker(
bind=engine, class_=AsyncSession, expire_on_commit=False
)
async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all)
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
first = await init_data.initialize_data()
second = await init_data.initialize_data()
assert first is True
assert second is True
async with session_maker() as session:
factory_count = await session.scalar(
select(func.count()).select_from(LlmFactory)
)
llm_count = await session.scalar(select(func.count()).select_from(Llm))
assert factory_count == 6
assert llm_count == 2
Base.metadata.remove(users_table)
await engine.dispose()
@pytest.mark.asyncio
async def test_initialize_data_rolls_back_on_invalid_factory_mapping(
monkeypatch: pytest.MonkeyPatch,
) -> None:
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
session_maker = async_sessionmaker(
bind=engine, class_=AsyncSession, expire_on_commit=False
)
async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all)
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
monkeypatch.setattr(
init_data,
"load_llm_catalog",
lambda *_: {
"factories": [
{
"name": "qwen",
"request_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"avatar": "https://cdn.example.com/qwen.png",
}
],
"llms": [
{
"model_code": "qwen3.5-flash",
"factory_id": "missing_factory",
}
],
},
)
with pytest.raises(RuntimeError):
await init_data.initialize_data()
async with session_maker() as session:
factory_count = await session.scalar(
select(func.count()).select_from(LlmFactory)
)
llm_count = await session.scalar(select(func.count()).select_from(Llm))
assert factory_count == 0
assert llm_count == 0
Base.metadata.remove(users_table)
await engine.dispose()
def test_user_agent_catalog_file_exists_and_has_required_fields() -> None:
catalog_path = (
Path(__file__).resolve().parents[3]
/ "src"
/ "core"
/ "config"
/ "static"
/ "database"
/ "user_agent_catalog.yaml"
)
assert catalog_path.exists(), f"Catalog file not found: {catalog_path}"
catalog = init_data.load_user_agent_catalog(catalog_path)
assert "agents" in catalog
assert isinstance(catalog["agents"], list)
assert len(catalog["agents"]) == 3
for agent in catalog["agents"]:
assert "agent_type" in agent
assert "llm_model_code" in agent
assert "status" in agent
assert "config" in agent
assert isinstance(agent["config"], dict)
def test_load_user_agent_catalog_raises_on_invalid_structure(
tmp_path: Path,
) -> None:
catalog_path = tmp_path / "user_agent_catalog.yaml"
catalog_path.write_text(
"""
agents:
- agent_type: TEST
llm_model_code: test-model
status: ACTIVE
""".strip(),
encoding="utf-8",
)
with pytest.raises(ValueError, match="Invalid user agent catalog"):
init_data.load_user_agent_catalog(catalog_path)
@@ -1,27 +0,0 @@
from __future__ import annotations
from pathlib import Path
def test_initial_migration_exists_and_creates_expected_tables() -> None:
versions_dir = Path(__file__).resolve().parents[3] / "alembic" / "versions"
migration_files = sorted(versions_dir.glob("20260226_*.py"))
assert len(migration_files) == 5, "split initial migrations should exist"
content = "\n".join(m.read_text(encoding="utf-8") for m in migration_files)
# New tables from social data model redesign
assert "create_table(" in content and "automation_jobs" in content
assert "create_table(" in content and "user_agents" in content
assert "create_table(" in content and "memories" in content
assert "create_table(" in content and "friendships" in content
assert "create_table(" in content and "groups" in content
assert "create_table(" in content and "group_members" in content
assert "create_table(" in content and "schedule_items" in content
assert "create_table(" in content and "schedule_subscriptions" in content
assert "create_table(" in content and "inbox_messages" in content
assert "create_table(" in content and "todos" in content
assert "create_table(" in content and "todo_sources" in content
assert "create_table(" in content and "profiles" in content
assert "create_table(" in content and "sessions" in content
assert "create_table(" in content and "messages" in content
@@ -1,119 +0,0 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from sqlalchemy import Column, String, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
@pytest.fixture
async def db_engine():
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all)
yield engine
Base.metadata.remove(users_table)
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
async_session = async_sessionmaker(
bind=db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_llm_factory_and_llm_relationship(db_session: AsyncSession) -> None:
factory = LlmFactory(
name="qwen",
request_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
avatar="https://cdn.example.com/qwen.png",
)
db_session.add(factory)
await db_session.flush()
llm = Llm(
factory_id=factory.id,
model_code="qwen3.5-flash",
)
db_session.add(llm)
await db_session.commit()
found_llm = await db_session.get(Llm, llm.id)
assert found_llm is not None
assert found_llm.factory_id == factory.id
@pytest.mark.asyncio
async def test_session_status_supports_required_values(
db_session: AsyncSession,
) -> None:
user_id = uuid4()
session = AgentChatSession(
user_id=user_id,
title="test",
status="pending",
)
db_session.add(session)
await db_session.commit()
statuses = [
AgentChatSessionStatus.PENDING,
AgentChatSessionStatus.RUNNING,
AgentChatSessionStatus.COMPLETED,
AgentChatSessionStatus.FAILED,
]
for status in statuses:
session.status = status
await db_session.commit()
await db_session.refresh(session)
assert session.status == status
@pytest.mark.asyncio
async def test_messages_role_supports_tool(db_session: AsyncSession) -> None:
user_id = uuid4()
session = AgentChatSession(
user_id=user_id,
title="tool test",
status="pending",
)
db_session.add(session)
await db_session.flush()
message = AgentChatMessage(
session_id=session.id,
seq=1,
role="tool",
content="tool output",
cost=0,
)
db_session.add(message)
await db_session.commit()
result = await db_session.execute(
select(AgentChatMessage).where(AgentChatMessage.session_id == session.id)
)
found = result.scalar_one()
assert found.role == "tool"
@@ -1,80 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from core.auth.models import CurrentUser
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
from v1.agent.tool_registry import validate_tool_spec
class TestAgentSecurityRules:
def test_tool_name_must_be_allowlisted(self):
validate_tool_spec({"name": "ui.navigate_to", "execution_target": "frontend"})
validate_tool_spec({"name": "srv.search_docs", "execution_target": "backend"})
def test_tool_name_rejected_if_not_in_namespace(self):
try:
validate_tool_spec(
{"name": "malicious.tool", "execution_target": "frontend"}
)
except ValueError:
pass
else:
raise AssertionError("Should have raised ValueError for unknown namespace")
@pytest.mark.asyncio
async def test_frontend_result_fails_when_interrupt_mismatch(self):
session = AgentChatSession(
id=uuid4(),
user_id=UUID("00000000-0000-0000-0000-000000000001"),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
)
class FakeAsyncSession:
def __init__(self, session_obj: AgentChatSession) -> None:
self._session_obj = session_obj
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(self._session_obj)
async def scalar(self, stmt: object) -> AgentChatSession | None:
return self._session_obj
service = AgentChatService(
session=FakeAsyncSession(session), # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-1",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-other",
decision={"decision": "approved"},
)
assert result.applied is False
@@ -1,25 +0,0 @@
from __future__ import annotations
import pytest
from v1.agent.crewai_flow import AgentFlow
class TestCrewAIFlow:
@pytest.mark.asyncio
async def test_flow_stages_run_in_order(self):
flow = AgentFlow()
await flow.run()
assert flow.state.stage_trace == ["intent", "execution", "reporting"]
@pytest.mark.asyncio
async def test_flow_state_initialized(self):
flow = AgentFlow()
assert flow.state.stage_trace == []
assert flow.state.current_stage is None
@pytest.mark.asyncio
async def test_flow_updates_current_stage(self):
flow = AgentFlow()
await flow.run()
assert flow.state.current_stage == "reporting"
@@ -1,187 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
class FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self._sessions: dict[UUID, AgentChatSession] = {}
self.last_fetch_with_lock = False
def add(self, obj: object) -> None:
self.added.append(obj)
if isinstance(obj, AgentChatSession):
self._sessions[obj.id] = obj
async def flush(self) -> None:
return None
async def commit(self) -> None:
pass
async def rollback(self) -> None:
pass
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object):
self.last_fetch_with_lock = "FOR UPDATE" in str(stmt)
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(next(iter(self._sessions.values()), None))
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
return session
return None
@pytest.fixture
def fake_db() -> FakeAsyncSession:
return FakeAsyncSession()
@pytest.fixture
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
sess = AgentChatSession(
id=uuid4(),
user_id=uuid4(),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
)
fake_db.add(sess)
return sess
@pytest.fixture
def service(fake_db: FakeAsyncSession) -> AgentChatService:
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
class TestResumeIdempotency:
@pytest.mark.asyncio
async def test_resume_is_idempotent(
self,
service: AgentChatService,
session: AgentChatSession,
fake_db: FakeAsyncSession,
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-1",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
first = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-1",
decision={"decision": "approved"},
)
second = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-1",
decision={"decision": "approved"},
)
assert first.applied is True
assert second.applied is False
assert fake_db.last_fetch_with_lock is True
@pytest.mark.asyncio
async def test_resume_updates_status_to_approved(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-2",
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-2",
decision={"decision": "approved"},
)
assert result.applied is True
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
assert snapshot["pending_tool_call"]["decision"] == {"decision": "approved"}
@pytest.mark.asyncio
async def test_resume_updates_status_to_rejected(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-3",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-3",
decision={"decision": "rejected"},
)
assert result.applied is True
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "REJECTED"
@pytest.mark.asyncio
async def test_resume_expired_pending_marks_expired_and_not_applied(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) - timedelta(seconds=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-expired",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-expired",
decision={"decision": "approved"},
)
assert result.applied is False
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "EXPIRED"
-127
View File
@@ -1,127 +0,0 @@
from datetime import datetime, timezone
import pytest
from v1.agent.schemas import AgentSessionSnapshot, RunAgentInput
class TestRunAgentInput:
def test_requires_full_fields(self):
payload = {
"threadId": "t1",
"runId": "r1",
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
}
model = RunAgentInput.model_validate(payload)
assert model.threadId == "t1"
assert model.runId == "r1"
assert model.parentRunId is None
assert model.state == {}
assert model.messages == []
assert model.tools == []
assert model.context == []
assert model.forwardedProps == {}
assert model.resume is None
def test_resume_optional(self):
payload = {
"threadId": "t1",
"runId": "r2",
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
}
model = RunAgentInput.model_validate(payload)
assert model.resume is not None
assert model.resume["interruptId"] == "int-1"
assert model.resume["payload"]["decision"] == "approved"
def test_parent_run_id_optional(self):
payload = {
"threadId": "t1",
"runId": "r3",
"parentRunId": "p1",
"state": {"key": "value"},
"messages": [{"role": "user", "content": "hello"}],
"tools": [{"name": "ui.navigate_to"}],
"context": [{"type": "user", "id": "u1"}],
"forwardedProps": {"theme": "dark"},
}
model = RunAgentInput.model_validate(payload)
assert model.parentRunId == "p1"
assert model.state == {"key": "value"}
assert len(model.messages) == 1
assert model.messages[0]["role"] == "user"
class TestAgentSessionSnapshot:
def test_state_snapshot_v2_model_accepts_valid_payload(self):
payload = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00Z",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00Z",
},
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
model = AgentSessionSnapshot.model_validate(payload)
assert model.version == 2
assert model.pending_tool_call is not None
assert model.pending_tool_call.interrupt_id == "int-1"
assert model.pending_tool_call.updated_at == datetime(
2026, 3, 3, 11, 59, tzinfo=timezone.utc
)
def test_state_snapshot_v2_rejects_wrong_version(self):
payload = {
"version": 1,
"pending_tool_call": None,
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
def test_state_snapshot_v2_requires_pending_tool_call_key(self):
payload = {
"version": 2,
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
def test_state_snapshot_v2_rejects_extra_fields(self):
payload = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00Z",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00Z",
"unexpected": True,
},
"run_context": {"thread_id": "t1", "run_id": "r1", "foo": "bar"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
@@ -1,168 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
class FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self._sessions: dict[UUID, AgentChatSession] = {}
def add(self, obj: object) -> None:
self.added.append(obj)
if isinstance(obj, AgentChatSession):
self._sessions[obj.id] = obj
async def flush(self) -> None:
return None
async def commit(self) -> None:
pass
async def rollback(self) -> None:
pass
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(next(iter(self._sessions.values()), None))
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
return session
return None
@pytest.fixture
def fake_db() -> FakeAsyncSession:
return FakeAsyncSession()
@pytest.fixture
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
sess = AgentChatSession(
id=uuid4(),
user_id=uuid4(),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
)
fake_db.add(sess)
return sess
@pytest.fixture
def service(fake_db: FakeAsyncSession) -> AgentChatService:
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
class TestPendingToolCall:
@pytest.mark.asyncio
async def test_save_pending_tool_call_to_state_snapshot(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-1",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is not None
assert snapshot["version"] == 2
assert snapshot["run_context"]["thread_id"] == "t1"
assert snapshot["run_context"]["run_id"] == "r1"
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
@pytest.mark.asyncio
async def test_get_state_snapshot_returns_none_when_empty(
self, service: AgentChatService, session: AgentChatSession
):
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is None
@pytest.mark.asyncio
async def test_update_pending_tool_call_status(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-2",
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
await service.update_pending_tool_call_status(
session_id=session.id,
interrupt_id="int-2",
status="APPROVED_EXECUTING",
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
@pytest.mark.asyncio
async def test_invalid_legacy_snapshot_is_rejected(
self, service: AgentChatService, session: AgentChatSession
):
session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
with pytest.raises(ValueError):
await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-legacy",
decision={"decision": "approved"},
)
@pytest.mark.asyncio
async def test_snapshot_rejects_naive_datetime(
self, service: AgentChatService, session: AgentChatSession
):
session.state_snapshot = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-naive",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00",
},
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-naive",
decision={"decision": "approved"},
)
@@ -1,126 +0,0 @@
from __future__ import annotations
from datetime import datetime, timezone
from uuid import UUID, uuid4
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.schemas import RunAgentInput
from v1.agent.service import AgentChatService
class FakeAsyncSession:
def __init__(self, sessions: list[AgentChatSession]) -> None:
self._sessions = {session.id: session for session in sessions}
self.commit_called = False
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
for session in self._sessions.values():
return _Result(session)
return _Result(None)
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
return session
return None
async def commit(self) -> None:
self.commit_called = True
def _build_input(run_id: str) -> RunAgentInput:
return RunAgentInput.model_validate(
{
"threadId": "t1",
"runId": run_id,
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
}
)
@pytest.mark.asyncio
async def test_stream_resume_rejects_non_owner_session() -> None:
session = AgentChatSession(
id=uuid4(),
user_id=uuid4(),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
state_snapshot={
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": datetime.now(timezone.utc).isoformat(),
"decision": None,
"result": None,
"updated_at": datetime.now(timezone.utc).isoformat(),
},
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
},
)
service = AgentChatService(
session=FakeAsyncSession([session]), # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
with pytest.raises(HTTPException) as exc_info:
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_prepare_resume_commits_expired_state_before_410() -> None:
owner_id = UUID("00000000-0000-0000-0000-000000000001")
session = AgentChatSession(
id=uuid4(),
user_id=owner_id,
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
state_snapshot={
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2000-01-01T00:00:00+00:00",
"decision": None,
"result": None,
"updated_at": datetime.now(timezone.utc).isoformat(),
},
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
},
)
fake_db = FakeAsyncSession([session])
service = AgentChatService(
session=fake_db, # type: ignore[arg-type]
current_user=CurrentUser(id=owner_id),
)
with pytest.raises(HTTPException) as exc_info:
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
assert exc_info.value.status_code == 410
assert fake_db.commit_called is True
@@ -1,65 +0,0 @@
from __future__ import annotations
import pytest
from v1.agent.tool_dispatcher import (
BackendExecutionResult,
InterruptResult,
ToolDispatcher,
dispatch_tool_call,
)
class TestToolDispatcher:
def test_frontend_tool_returns_interrupt(self):
tool = {
"name": "ui.navigate_to",
"execution_target": "frontend",
"args": {"path": "/home"},
}
result = dispatch_tool_call(tool)
assert isinstance(result, InterruptResult)
assert result.interrupt_type == "tool_execution"
assert result.tool_name == "ui.navigate_to"
def test_backend_tool_executes_directly(self):
tool = {
"name": "srv.get_user_info",
"execution_target": "backend",
"args": {"user_id": "u1"},
"requires_approval": False,
}
result = dispatch_tool_call(tool)
assert isinstance(result, BackendExecutionResult)
assert result.tool_name == "srv.get_user_info"
def test_backend_tool_with_approval_returns_interrupt(self):
tool = {
"name": "srv.transfer_funds",
"execution_target": "backend",
"args": {"to": "u2", "amount": 100},
"requires_approval": True,
}
result = dispatch_tool_call(tool)
assert isinstance(result, InterruptResult)
assert result.interrupt_type == "approval_required"
assert result.tool_name == "srv.transfer_funds"
def test_dispatcher_class_can_dispatch(self):
dispatcher = ToolDispatcher()
tool = {
"name": "ui.navigate_to",
"execution_target": "frontend",
"args": {"message": "Hello"},
}
result = dispatcher.dispatch(tool)
assert isinstance(result, InterruptResult)
def test_unknown_frontend_tool_is_rejected(self):
tool = {
"name": "ui.unknown_action",
"execution_target": "frontend",
"args": {},
}
with pytest.raises(ValueError, match="not in allowlist"):
dispatch_tool_call(tool)
@@ -1,27 +0,0 @@
import pytest
from v1.agent.tool_registry import validate_tool_spec
class TestValidateToolSpec:
def test_ui_namespace_must_be_frontend(self):
with pytest.raises(ValueError, match="ui.* must use frontend target"):
validate_tool_spec(
{"name": "ui.navigate_to", "execution_target": "backend"}
)
def test_srv_namespace_must_be_backend(self):
with pytest.raises(ValueError, match="srv.* must use backend target"):
validate_tool_spec(
{"name": "srv.search_docs", "execution_target": "frontend"}
)
def test_ui_namespace_with_frontend_is_valid(self):
validate_tool_spec({"name": "ui.navigate_to", "execution_target": "frontend"})
def test_srv_namespace_with_backend_is_valid(self):
validate_tool_spec({"name": "srv.search_docs", "execution_target": "backend"})
def test_other_namespace_is_rejected(self):
with pytest.raises(ValueError, match="must be in ui.* or srv.* namespace"):
validate_tool_spec({"name": "other.tool", "execution_target": "frontend"})
@@ -1,196 +0,0 @@
from __future__ import annotations
from decimal import Decimal
from uuid import uuid4
import pytest
from sqlalchemy import Column, String, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.agent.orchestrator import OrchestratorResult
from core.db.base import Base
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from v1.agent.schemas import AgentChatRunRequest
from v1.agent.service import AgentChatService
@pytest.fixture
async def db_engine():
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all)
yield engine
Base.metadata.remove(users_table)
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
async_session = async_sessionmaker(
bind=db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_run_creates_session_and_persists_messages(
db_session: AsyncSession,
) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
result = await service.run(AgentChatRunRequest(message="hello"))
assert result.session_id is not None
assert result.output == "hello"
assert [event.type for event in result.events] == [
"run.started",
"message.delta",
"run.completed",
]
session_obj = await db_session.get(AgentChatSession, result.session_id)
assert session_obj is not None
assert session_obj.message_count == 2
assert session_obj.status.value == "completed"
rows = await db_session.execute(
select(AgentChatMessage)
.where(AgentChatMessage.session_id == result.session_id)
.order_by(AgentChatMessage.seq.asc())
)
messages = rows.scalars().all()
assert len(messages) == 2
assert messages[0].role.value == "user"
assert messages[1].role.value == "assistant"
@pytest.mark.asyncio
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
first = await service.run(AgentChatRunRequest(message="first"))
second = await service.run(
AgentChatRunRequest(message="second", session_id=first.session_id)
)
assert second.session_id == first.session_id
session_obj = await db_session.get(AgentChatSession, first.session_id)
assert session_obj is not None
assert session_obj.message_count == 4
@pytest.mark.asyncio
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
db_session: AsyncSession,
) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
class _FailingOrchestrator:
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
return OrchestratorResult(
output="",
usage={
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
"cost": Decimal("0"),
"currency": "USD",
},
events=[],
context={},
failed=True,
error="stage failed",
)
service._orchestrator = _FailingOrchestrator() # type: ignore[assignment]
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message="hello"))
assert exc_info.value.status_code == 502
rows = await db_session.execute(
select(AgentChatSession).where(AgentChatSession.user_id == user.id)
)
stored_session = rows.scalars().one()
assert stored_session.status.value == "failed"
@pytest.mark.asyncio
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message=" "))
assert exc_info.value.status_code == 422
@pytest.mark.asyncio
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message="hello", session_id=uuid4()))
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
db_session: AsyncSession,
monkeypatch: pytest.MonkeyPatch,
) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
async def _fail_commit() -> None:
raise SQLAlchemyError("db down")
monkeypatch.setattr(db_session, "commit", _fail_commit)
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message="hello"))
assert exc_info.value.status_code == 503
@pytest.mark.asyncio
async def test_run_returns_502_for_unexpected_exception(
db_session: AsyncSession,
) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
class _CrashingOrchestrator:
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
raise RuntimeError("unexpected")
service._orchestrator = _CrashingOrchestrator() # type: ignore[assignment]
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message="hello"))
assert exc_info.value.status_code == 502