feat: AG-UI 协议对齐与路由导航功能
- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具 - 前端: 实现工具调用审批流程,支持 pending 状态展示 - 后端: Agent 状态管理与会话持久化相关重构 - 文档: 新增 agent-agui-full-alignance 设计文档 - 测试: 补充相关单元测试和集成测试
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
@@ -84,28 +86,76 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
|
||||
published: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
published.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
published.append(event_type)
|
||||
|
||||
try:
|
||||
run_result = run_agent_task(
|
||||
run_input_payload = {
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "帮我打开日历"},
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "navigate_to_route",
|
||||
"description": "navigate route",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
run_result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": str(session_uuid),
|
||||
"user_input": "hello",
|
||||
"run_input": run_input_payload,
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
pending_tool_call_id = str(run_result["pending_tool_call_id"])
|
||||
state_snapshot = run_result["state_snapshot"]
|
||||
assert isinstance(state_snapshot, dict)
|
||||
pending_tool_nonce = state_snapshot["pending_tool_nonce"]
|
||||
assert isinstance(pending_tool_nonce, str)
|
||||
|
||||
run_agent_task(
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": str(session_uuid),
|
||||
"tool_call_id": pending_tool_call_id,
|
||||
"run_input": {
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-2",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"role": "tool",
|
||||
"toolCallId": pending_tool_call_id,
|
||||
"content": json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": pending_tool_nonce,
|
||||
},
|
||||
"nonce": pending_tool_nonce,
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=RunService(),
|
||||
@@ -123,6 +173,9 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
assert db_session.state_snapshot == {
|
||||
"status": "completed",
|
||||
"pending_tool_call_id": None,
|
||||
"pending_tool_name": None,
|
||||
"pending_tool_args_sha256": None,
|
||||
"pending_tool_nonce": None,
|
||||
}
|
||||
|
||||
rows = await verify_session.execute(
|
||||
@@ -142,7 +195,7 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
assert messages[1].cost == Decimal("0.002500")
|
||||
|
||||
assert "RUN_STARTED" in published
|
||||
assert "RUN_RESUMED" in published
|
||||
assert "RUN_FINISHED" in published
|
||||
assert "TEXT_MESSAGE_CONTENT" in published
|
||||
finally:
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
@@ -219,7 +272,21 @@ async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.commit()
|
||||
|
||||
result = await RunService().run(session_id=str(session_uuid), user_input="hello")
|
||||
result = await RunService().run(
|
||||
run_input=RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "hello"},
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
assert result["persisted"] is True
|
||||
assert captured["user_input"] == "hello"
|
||||
|
||||
@@ -16,29 +16,38 @@ class _FakeStorage:
|
||||
return "etag-1"
|
||||
|
||||
|
||||
def test_closed_loop_run_flow_frontend_to_sse() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
async def test_closed_loop_run_flow_frontend_to_sse() -> None:
|
||||
thread_id = "00000000-0000-0000-0000-000000000001"
|
||||
published: list[str] = []
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
async def run(self, *, run_input: object) -> dict[str, object]:
|
||||
del run_input
|
||||
return {"threadId": thread_id, "runId": "run-1"}
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
published.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
published.append(event_type)
|
||||
|
||||
result = run_agent_task(
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
"run_input": {
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert result["threadId"] == thread_id
|
||||
assert published[0] == "RUN_STARTED"
|
||||
assert published[-1] == "RUN_FINISHED"
|
||||
|
||||
|
||||
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent import router as agent_router
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
@@ -16,52 +18,122 @@ class _FakeAgentService:
|
||||
self._stream_called = False
|
||||
|
||||
async def enqueue_run(
|
||||
self, *, session_id: str | None, prompt: str, current_user: CurrentUser
|
||||
self, *, run_input: RunAgentInput, current_user: CurrentUser
|
||||
):
|
||||
del prompt, current_user
|
||||
resolved_session = session_id or "auto-created-session"
|
||||
del current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-run-1",
|
||||
session_id=resolved_session,
|
||||
created=session_id is None,
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
):
|
||||
del tool_call_id, current_user
|
||||
del thread_id, current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-resume-1", session_id=session_id, created=False
|
||||
task_id="task-resume-1",
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
del session_id, current_user
|
||||
del thread_id, current_user
|
||||
if self._stream_called:
|
||||
return []
|
||||
self._stream_called = True
|
||||
return [
|
||||
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
|
||||
{
|
||||
"id": "2-0",
|
||||
"event": {
|
||||
"type": "RUN_STARTED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
},
|
||||
"cursor": last_event_id,
|
||||
}
|
||||
]
|
||||
|
||||
async def get_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
before: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, object]:
|
||||
del current_user
|
||||
return {
|
||||
"type": "STATE_SNAPSHOT",
|
||||
"threadId": thread_id,
|
||||
"snapshot": {
|
||||
"scope": "history_day",
|
||||
"day": before or "2026-03-07",
|
||||
"hasMore": False,
|
||||
"messages": [
|
||||
{
|
||||
"id": "msg-h1",
|
||||
"role": "assistant",
|
||||
"content": "history-message",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
async def get_user_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
current_user: CurrentUser,
|
||||
thread_id: str | None,
|
||||
before: str | None,
|
||||
) -> dict[str, object]:
|
||||
del current_user, before
|
||||
return {
|
||||
"type": "STATE_SNAPSHOT",
|
||||
"threadId": thread_id or "00000000-0000-0000-0000-000000000001",
|
||||
"snapshot": {
|
||||
"scope": "history_day",
|
||||
"day": "2026-03-07",
|
||||
"hasMore": False,
|
||||
"messages": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
original_allow_run = agent_router._allow_run_request
|
||||
|
||||
async def _allow_run(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
agent_router._allow_run_request = _allow_run # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
unauthorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
@@ -70,20 +142,23 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
)
|
||||
authorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert authorized.status_code == 202
|
||||
assert authorized.json()["task_id"] == "task-run-1"
|
||||
assert authorized.json()["taskId"] == "task-run-1"
|
||||
assert authorized.json()["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert authorized.json()["runId"] == "run-1"
|
||||
assert authorized.json()["created"] is False
|
||||
|
||||
first_chat = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"prompt": "hello"},
|
||||
)
|
||||
assert first_chat.status_code == 202
|
||||
assert first_chat.json()["session_id"] == "auto-created-session"
|
||||
assert first_chat.json()["created"] is True
|
||||
finally:
|
||||
agent_router._allow_run_request = original_allow_run # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
@@ -93,15 +168,122 @@ def test_stream_reads_from_last_event_id() -> None:
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
original_acquire = agent_router._acquire_sse_slot
|
||||
original_release = agent_router._release_sse_slot
|
||||
|
||||
async def _allow_slot(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _noop_release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
return None
|
||||
|
||||
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/session-1/events?idle_limit=1",
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1",
|
||||
headers={"Last-Event-ID": "1-0"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert "id: 2-0" in response.text
|
||||
assert "event: RUN_STARTED" in response.text
|
||||
assert '"threadId":"00000000-0000-0000-0000-000000000001"' in response.text
|
||||
finally:
|
||||
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = original_release # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_rejects_invalid_last_event_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events",
|
||||
headers={"Last-Event-ID": "bad-id"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_history_returns_state_snapshot() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
unauthorized = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/history"
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
authorized = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/history",
|
||||
params={"before": "2026-03-07"},
|
||||
)
|
||||
assert authorized.status_code == 200
|
||||
payload = authorized.json()
|
||||
assert payload["type"] == "STATE_SNAPSHOT"
|
||||
assert payload["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert payload["snapshot"]["scope"] == "history_day"
|
||||
assert payload["snapshot"]["day"] == "2026-03-07"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_user_history_returns_latest_snapshot() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/agent/history")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["type"] == "STATE_SNAPSHOT"
|
||||
assert body["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_run_rejects_oversized_user_text_payload() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-oversize",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": "x" * 11000,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
@@ -56,15 +56,25 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
run_resp = await client.post(
|
||||
f"{BASE_URL}/api/v1/agent/runs",
|
||||
headers=headers,
|
||||
json={"prompt": "请用一句话介绍你自己"},
|
||||
json={
|
||||
"threadId": str(uuid4()),
|
||||
"runId": "run-live-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "请用一句话介绍你自己"}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
accepted = run_resp.json()
|
||||
session_id = str(accepted["session_id"])
|
||||
assert session_id
|
||||
thread_id = str(accepted["threadId"])
|
||||
assert thread_id
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
||||
event_names: list[str] = []
|
||||
async with client.stream("GET", events_url, headers=headers, timeout=20.0) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
@@ -77,13 +87,13 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
session_row = await session.get(AgentChatSession, UUID(session_id))
|
||||
session_row = await session.get(AgentChatSession, UUID(thread_id))
|
||||
assert session_row is not None
|
||||
assert session_row.message_count >= 1
|
||||
assert session_row.total_tokens >= 0
|
||||
assert session_row.total_cost >= 0
|
||||
|
||||
rows = await session.execute(
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(session_id))
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(thread_id))
|
||||
)
|
||||
assert len(list(rows.scalars().all())) >= 1
|
||||
|
||||
@@ -132,7 +132,9 @@ def test_bridge_rejects_unknown_event_type() -> None:
|
||||
|
||||
def test_sse_format_includes_id_event_data() -> None:
|
||||
payload = to_sse_event(
|
||||
stream_id="1-0", event={"type": "RUN_STARTED", "data": {"a": 1}}
|
||||
stream_id="1-0",
|
||||
event={"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"},
|
||||
)
|
||||
|
||||
assert payload.startswith("id: 1-0\nevent: RUN_STARTED\ndata: {")
|
||||
assert '"threadId":"t1"' in payload
|
||||
|
||||
@@ -56,12 +56,14 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
captured["model"] = model
|
||||
captured["api_key"] = api_key
|
||||
captured["messages"] = messages
|
||||
captured["temperature"] = temperature
|
||||
captured["max_tokens"] = max_tokens
|
||||
captured["timeout"] = timeout
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
@@ -113,6 +115,7 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
assert captured["api_key"] == "env-api-key"
|
||||
assert captured["temperature"] == 0.3
|
||||
assert captured["max_tokens"] == 256
|
||||
assert captured["timeout"] == 30.0
|
||||
assert result["assistant_text"] == "hello"
|
||||
|
||||
|
||||
@@ -128,6 +131,7 @@ def test_runtime_execute_injects_system_prompt_and_intent_template(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
captured["messages"] = messages
|
||||
return {
|
||||
@@ -219,6 +223,7 @@ def test_runtime_execute_short_circuits_on_direct_execution(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
@@ -331,6 +336,7 @@ def test_runtime_execute_runs_execution_and_organization_stages(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
@@ -383,6 +389,7 @@ def test_runtime_execute_rejects_invalid_intent_json(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, messages, temperature, max_tokens
|
||||
return {
|
||||
@@ -506,6 +513,7 @@ def test_runtime_execute_minimizes_prompt_and_execution_payload(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
|
||||
@@ -21,10 +21,12 @@ def test_run_completion_passes_optional_params_when_provided(monkeypatch) -> Non
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
temperature=0.6,
|
||||
max_tokens=120,
|
||||
timeout=12.5,
|
||||
)
|
||||
|
||||
assert captured["temperature"] == 0.6
|
||||
assert captured["max_tokens"] == 120
|
||||
assert captured["timeout"] == 12.5
|
||||
|
||||
|
||||
def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
|
||||
@@ -45,7 +47,9 @@ def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
)
|
||||
|
||||
assert "temperature" not in captured
|
||||
assert "max_tokens" not in captured
|
||||
assert "timeout" not in captured
|
||||
|
||||
@@ -2,64 +2,124 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
|
||||
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
}
|
||||
|
||||
|
||||
class _FakeResumeService:
|
||||
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "tool_call_id": tool_call_id}
|
||||
async def resume(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
}
|
||||
|
||||
|
||||
def _build_run_input() -> dict[str, object]:
|
||||
return {
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
events: list[str] = []
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
events.append(event_type)
|
||||
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
|
||||
assert result["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert events == ["RUN_STARTED", "RUN_FINISHED"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_injects_context_and_redacts_sensitive_fields() -> None:
|
||||
published: list[dict[str, object]] = []
|
||||
|
||||
class _RunWithExtraEvents(_FakeRunService):
|
||||
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"events": [
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"messageId": "m1",
|
||||
"delta": "hi",
|
||||
"token": "secret-token",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
published.append(event)
|
||||
|
||||
await run_agent_task(
|
||||
{"command": "run", "run_input": _build_run_input()},
|
||||
publish_event=_publish,
|
||||
run_service=_RunWithExtraEvents(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
run_started = published[0]
|
||||
assert run_started["type"] == "RUN_STARTED"
|
||||
assert "input" not in run_started
|
||||
|
||||
text_event = published[1]
|
||||
assert text_event["type"] == "TEXT_MESSAGE_CONTENT"
|
||||
assert text_event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert text_event["runId"] == "run-1"
|
||||
assert text_event["token"] == "***REDACTED***"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
class _BrokenRunService(_FakeRunService):
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
del session_id, user_input
|
||||
async def run(self, *, run_input: dict[str, object]) -> dict[str, object]:
|
||||
del run_input
|
||||
raise RuntimeError("boom")
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
events.append(event_type)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_BrokenRunService(),
|
||||
@@ -72,16 +132,44 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_rejects_invalid_command() -> None:
|
||||
with pytest.raises(ValueError, match="invalid command type"):
|
||||
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
|
||||
await run_agent_task({"command": "invalid", "run_input": _build_run_input()})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
|
||||
with pytest.raises(ValueError, match="tool_call_id is required"):
|
||||
async def test_run_agent_task_rejects_missing_run_input() -> None:
|
||||
with pytest.raises(ValueError, match="run_input is required"):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
"command": "run",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_resume_uses_run_input() -> None:
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
del event
|
||||
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert result["runId"] == "run-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_rejects_invalid_run_input() -> None:
|
||||
with pytest.raises(ValueError, match="invalid AG-UI RunAgentInput payload"):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"run_input": {"threadId": "x"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -23,11 +23,34 @@ class _FakeRedisClient:
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
key, start_id = next(iter(streams.items()))
|
||||
if start_id == "$":
|
||||
if start_id == "0-0":
|
||||
return [(key, [("11-0", {"event": '{"type":"RUN_STARTED"}'})])]
|
||||
return [(key, [("12-0", {"event": '{"type":"RUN_FINISHED"}'})])]
|
||||
|
||||
|
||||
class _MalformedRedisClient:
|
||||
async def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[object]:
|
||||
del streams, count, block
|
||||
return ["bad-shape"]
|
||||
|
||||
|
||||
class _InvalidJsonRedisClient:
|
||||
async def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
key = next(iter(streams.keys()))
|
||||
return [(key, [("11-0", {"event": "not-json"})])]
|
||||
|
||||
|
||||
def test_append_event_writes_json_payload() -> None:
|
||||
client = _FakeRedisClient()
|
||||
session_id = uuid4()
|
||||
@@ -55,3 +78,26 @@ async def test_read_events_respects_last_event_id() -> None:
|
||||
|
||||
assert from_start[0]["id"] == "11-0"
|
||||
assert from_last[0]["id"] == "12-0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_events_returns_empty_for_malformed_response() -> None:
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(client=_MalformedRedisClient(), stream_prefix="agent:events")
|
||||
|
||||
rows = await store.read_events(session_id=session_id, last_event_id=None)
|
||||
|
||||
assert rows == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_events_skips_invalid_event_json() -> None:
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(
|
||||
client=_InvalidJsonRedisClient(),
|
||||
stream_prefix="agent:events",
|
||||
)
|
||||
|
||||
rows = await store.read_events(session_id=session_id, last_event_id=None)
|
||||
|
||||
assert rows == []
|
||||
|
||||
@@ -5,11 +5,13 @@ from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
|
||||
|
||||
@@ -61,12 +63,69 @@ class _FakeUserContextCache:
|
||||
self.set_calls += 1
|
||||
|
||||
|
||||
def _build_run_input(
|
||||
*,
|
||||
thread_id: str,
|
||||
text: str = "hello",
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": text}],
|
||||
"tools": tools or [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_resume_input(
|
||||
*,
|
||||
thread_id: str,
|
||||
tool_call_id: str,
|
||||
content: str | None = None,
|
||||
) -> RunAgentInput:
|
||||
payload = content
|
||||
if payload is None:
|
||||
payload = json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {"target": "/calendar/dayweek", "replace": False, "__nonce": "nonce-1"},
|
||||
"nonce": "nonce-1",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": "run-2",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"role": "tool",
|
||||
"toolCallId": tool_call_id,
|
||||
"content": payload,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_rejects_invalid_session_id() -> None:
|
||||
run_service = RunService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await run_service.run(session_id="session-1", user_input="hello")
|
||||
await run_service.run(run_input=_build_run_input(thread_id="session-1"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -74,7 +133,272 @@ async def test_resume_service_requires_pending_tool_call() -> None:
|
||||
resume_service = ResumeService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await resume_service.resume(session_id="session-1", tool_call_id="call-1")
|
||||
await resume_service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_validates_pending_tool_guard_and_persists_payload(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: list[dict[str, object]] = []
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
),
|
||||
)
|
||||
|
||||
assert captured[0]["role"] == AgentChatMessageRole.TOOL
|
||||
stored_payload = json.loads(captured[0]["content"])
|
||||
assert stored_payload["toolName"] == "navigate_to_route"
|
||||
assert stored_payload["result"]["ok"] is True
|
||||
assert stored_payload["result"]["applied"] is True
|
||||
assert "ui" not in stored_payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_rejects_mismatched_nonce(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="nonce"):
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": "nonce-1",
|
||||
},
|
||||
"nonce": "nonce-bad",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_rejects_tool_result_when_not_ok(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="execution failed"):
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": "nonce-1",
|
||||
},
|
||||
"nonce": "nonce-1",
|
||||
"result": {"ok": False, "error": "navigator not bound"},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -256,7 +580,9 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
|
||||
session_uuid = session_id
|
||||
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
|
||||
await run_service.run(session_id=str(session_id), user_input="hello")
|
||||
await run_service.run(
|
||||
run_input=_build_run_input(thread_id=str(session_id), text="hello")
|
||||
)
|
||||
|
||||
system_prompt = captured["system_prompt"]
|
||||
assert isinstance(system_prompt, str)
|
||||
@@ -267,6 +593,290 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
|
||||
assert payload["ai_language"] == "en-US"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_emits_frontend_tool_pending_events(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.PENDING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot=None,
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object):
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
captured["update_runtime_state"] = kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
del user_input, system_prompt
|
||||
return {
|
||||
"assistant_text": "请确认是否跳转。",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
del self
|
||||
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
user_id=user_id,
|
||||
username="demo-user",
|
||||
bio=None,
|
||||
settings=SimpleNamespace(
|
||||
preferences=SimpleNamespace(
|
||||
interface_language="zh-CN",
|
||||
ai_language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
country="CN",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.create_runtime",
|
||||
lambda **_kwargs: _FakeRuntime(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
result = await service.run(
|
||||
run_input=_build_run_input(
|
||||
thread_id=str(session_id),
|
||||
text="帮我打开日历",
|
||||
tools=[
|
||||
{
|
||||
"name": "navigate_to_route",
|
||||
"description": "navigate",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["pending_tool_call_id"] is not None
|
||||
tool_start = next(event for event in result["events"] if event["type"] == "TOOL_CALL_START")
|
||||
assert tool_start["toolCallName"] == "navigate_to_route"
|
||||
runtime_state = captured["update_runtime_state"]
|
||||
assert isinstance(runtime_state, dict)
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.RUNNING
|
||||
snapshot = runtime_state["state_snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["pending_tool_name"] == "navigate_to_route"
|
||||
assert isinstance(snapshot["pending_tool_args_sha256"], str)
|
||||
assert isinstance(snapshot["pending_tool_nonce"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.PENDING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot=None,
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object):
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
captured["update_runtime_state"] = kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
del user_input, system_prompt
|
||||
return {
|
||||
"assistant_text": "日历事件已创建。",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
del self
|
||||
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
user_id=user_id,
|
||||
username="demo-user",
|
||||
bio=None,
|
||||
settings=SimpleNamespace(
|
||||
preferences=SimpleNamespace(
|
||||
interface_language="zh-CN",
|
||||
ai_language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
country="CN",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_execute_backend_tool(
|
||||
self,
|
||||
*,
|
||||
session,
|
||||
owner_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
):
|
||||
del self, session, owner_id
|
||||
assert tool_name == "create_calendar_event"
|
||||
assert "title" in tool_args
|
||||
return {
|
||||
"result": {"eventId": "evt-1", "ok": True},
|
||||
"ui": {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {"id": "evt-1", "title": "会议"},
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.create_runtime",
|
||||
lambda **_kwargs: _FakeRuntime(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._execute_backend_tool",
|
||||
_fake_execute_backend_tool,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
result = await service.run(
|
||||
run_input=_build_run_input(
|
||||
thread_id=str(session_id),
|
||||
text='#tool:create_calendar_event {"title":"会议","startAt":"2026-03-07T08:00:00Z"}',
|
||||
tools=[
|
||||
{
|
||||
"name": "create_calendar_event",
|
||||
"description": "create calendar",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["pending_tool_call_id"] is None
|
||||
assert any(event["type"] == "TOOL_CALL_RESULT" for event in result["events"])
|
||||
runtime_state = captured["update_runtime_state"]
|
||||
assert isinstance(runtime_state, dict)
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_user_agent_context_parses_profile_settings_v1() -> None:
|
||||
session_id = uuid4()
|
||||
@@ -519,7 +1129,9 @@ async def test_run_service_still_executes_when_profile_missing(
|
||||
session_uuid = session_id
|
||||
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
|
||||
await run_service.run(session_id=str(session_id), user_input="hello")
|
||||
await run_service.run(
|
||||
run_input=_build_run_input(thread_id=str(session_id), text="hello")
|
||||
)
|
||||
|
||||
system_prompt = captured["system_prompt"]
|
||||
assert isinstance(system_prompt, str)
|
||||
|
||||
@@ -4,9 +4,18 @@ from core.agent.domain.state_snapshot import AgentStateSnapshot
|
||||
|
||||
|
||||
def test_state_snapshot_serialization_round_trip() -> None:
|
||||
snapshot = AgentStateSnapshot(status="running", pending_tool_call_id="call-1")
|
||||
snapshot = AgentStateSnapshot(
|
||||
status="running",
|
||||
pending_tool_call_id="call-1",
|
||||
pending_tool_name="navigate_to_route",
|
||||
pending_tool_args_sha256="abc",
|
||||
pending_tool_nonce="nonce-1",
|
||||
)
|
||||
|
||||
payload = snapshot.model_dump()
|
||||
|
||||
assert payload["status"] == "running"
|
||||
assert payload["pending_tool_call_id"] == "call-1"
|
||||
assert payload["pending_tool_name"] == "navigate_to_route"
|
||||
assert payload["pending_tool_args_sha256"] == "abc"
|
||||
assert payload["pending_tool_nonce"] == "nonce-1"
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent import router as agent_router
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_run_request_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._allow_run_request(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_sse_slot_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._acquire_sse_slot(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
@@ -1,7 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
@@ -11,14 +15,19 @@ class _FakeRepository:
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
self.deleted_session_id: str | None = None
|
||||
self.created_with_session_id: str | None = None
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
del session_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
if session_id == "00000000-0000-0000-0000-000000000001":
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str:
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000999"
|
||||
self.created_with_session_id = session_id
|
||||
return session_id or "00000000-0000-0000-0000-000000000999"
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
@@ -29,6 +38,22 @@ class _FakeRepository:
|
||||
async def delete_session(self, *, session_id: str) -> None:
|
||||
self.deleted_session_id = session_id
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None:
|
||||
del session_id
|
||||
if before is not None and before <= date(2026, 3, 6):
|
||||
return None
|
||||
return {
|
||||
"day": "2026-03-06",
|
||||
"hasMore": False,
|
||||
"messages": [{"id": "m1", "role": "assistant", "content": "hello"}],
|
||||
}
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
async def enqueue(
|
||||
@@ -63,6 +88,20 @@ def _user() -> CurrentUser:
|
||||
)
|
||||
|
||||
|
||||
def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
@@ -70,37 +109,46 @@ async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
user = _user()
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
first = await service.enqueue_resume(
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
second = await service.enqueue_resume(
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
assert first.task_id == second.task_id
|
||||
|
||||
|
||||
async def test_enqueue_run_without_session_creates_new_session() -> None:
|
||||
async def test_enqueue_run_creates_missing_thread_session() -> None:
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
session_id=None,
|
||||
prompt="hello",
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.thread_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.run_id == "run-1"
|
||||
assert accepted.created is True
|
||||
assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert repository.committed is True
|
||||
|
||||
|
||||
@@ -111,11 +159,14 @@ async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
queue=_FailingQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
try:
|
||||
await service.enqueue_run(
|
||||
session_id=None,
|
||||
prompt="hello",
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
raise AssertionError("expected RuntimeError")
|
||||
@@ -123,3 +174,78 @@ async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
assert str(exc) == "enqueue failed"
|
||||
|
||||
assert repository.deleted_session_id is None
|
||||
|
||||
|
||||
async def test_enqueue_run_handles_session_create_race() -> None:
|
||||
class _RaceRepository(_FakeRepository):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.create_calls = 0
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
if self.create_calls == 0:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id, session_id
|
||||
self.create_calls += 1
|
||||
raise IntegrityError("insert", {}, Exception("duplicate key"))
|
||||
|
||||
repository = _RaceRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.created is False
|
||||
assert repository.rolled_back is True
|
||||
|
||||
|
||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
|
||||
event = await service.get_history_snapshot(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
before=date(2026, 3, 7),
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
snapshot = event["snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["scope"] == "history_day"
|
||||
assert snapshot["day"] == "2026-03-06"
|
||||
assert snapshot["messages"][0]["id"] == "m1"
|
||||
|
||||
|
||||
async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
event = await service.get_user_history_snapshot(
|
||||
current_user=_user(),
|
||||
thread_id=None,
|
||||
before=None,
|
||||
)
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
Reference in New Issue
Block a user