feat: AG-UI 协议对齐与路由导航功能

- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具
- 前端: 实现工具调用审批流程,支持 pending 状态展示
- 后端: Agent 状态管理与会话持久化相关重构
- 文档: 新增 agent-agui-full-alignance 设计文档
- 测试: 补充相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-07 17:30:20 +08:00
parent ec33bb0cee
commit 120df903d2
52 changed files with 4305 additions and 1672 deletions
@@ -132,7 +132,9 @@ def test_bridge_rejects_unknown_event_type() -> None:
def test_sse_format_includes_id_event_data() -> None:
payload = to_sse_event(
stream_id="1-0", event={"type": "RUN_STARTED", "data": {"a": 1}}
stream_id="1-0",
event={"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"},
)
assert payload.startswith("id: 1-0\nevent: RUN_STARTED\ndata: {")
assert '"threadId":"t1"' in payload
@@ -56,12 +56,14 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
captured["model"] = model
captured["api_key"] = api_key
captured["messages"] = messages
captured["temperature"] = temperature
captured["max_tokens"] = max_tokens
captured["timeout"] = timeout
return {
"choices": [
{
@@ -113,6 +115,7 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
assert captured["api_key"] == "env-api-key"
assert captured["temperature"] == 0.3
assert captured["max_tokens"] == 256
assert captured["timeout"] == 30.0
assert result["assistant_text"] == "hello"
@@ -128,6 +131,7 @@ def test_runtime_execute_injects_system_prompt_and_intent_template(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
captured["messages"] = messages
return {
@@ -219,6 +223,7 @@ def test_runtime_execute_short_circuits_on_direct_execution(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
del model, api_key, temperature, max_tokens
calls.append(messages)
@@ -331,6 +336,7 @@ def test_runtime_execute_runs_execution_and_organization_stages(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
del model, api_key, temperature, max_tokens
calls.append(messages)
@@ -383,6 +389,7 @@ def test_runtime_execute_rejects_invalid_intent_json(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
del model, api_key, messages, temperature, max_tokens
return {
@@ -506,6 +513,7 @@ def test_runtime_execute_minimizes_prompt_and_execution_payload(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
del model, api_key, temperature, max_tokens
calls.append(messages)
@@ -21,10 +21,12 @@ def test_run_completion_passes_optional_params_when_provided(monkeypatch) -> Non
messages=[{"role": "user", "content": "hi"}],
temperature=0.6,
max_tokens=120,
timeout=12.5,
)
assert captured["temperature"] == 0.6
assert captured["max_tokens"] == 120
assert captured["timeout"] == 12.5
def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
@@ -45,7 +47,9 @@ def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
messages=[{"role": "user", "content": "hi"}],
temperature=None,
max_tokens=None,
timeout=None,
)
assert "temperature" not in captured
assert "max_tokens" not in captured
assert "timeout" not in captured
+114 -26
View File
@@ -2,64 +2,124 @@ from __future__ import annotations
import pytest
from ag_ui.core import RunAgentInput
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
class _FakeRunService:
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
return {"session_id": session_id, "user_input": user_input}
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
return {
"threadId": run_input.thread_id,
"runId": run_input.run_id,
}
class _FakeResumeService:
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
return {"session_id": session_id, "tool_call_id": tool_call_id}
async def resume(
self,
*,
run_input: RunAgentInput,
) -> dict[str, object]:
return {
"threadId": run_input.thread_id,
"runId": run_input.run_id,
}
def _build_run_input() -> dict[str, object]:
return {
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
}
@pytest.mark.asyncio
async def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
session_id = "00000000-0000-0000-0000-000000000001"
events: list[str] = []
async def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
events.append(event_type)
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
events.append(event_type)
result = await run_agent_task(
{
"command": "run",
"session_id": session_id,
"user_input": "hello",
"run_input": _build_run_input(),
},
publish_event=_publish,
run_service=_FakeRunService(),
resume_service=_FakeResumeService(),
)
assert result["session_id"] == session_id
assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
assert result["threadId"] == "00000000-0000-0000-0000-000000000001"
assert events == ["RUN_STARTED", "RUN_FINISHED"]
@pytest.mark.asyncio
async def test_run_agent_task_injects_context_and_redacts_sensitive_fields() -> None:
published: list[dict[str, object]] = []
class _RunWithExtraEvents(_FakeRunService):
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
return {
"threadId": run_input.thread_id,
"runId": run_input.run_id,
"events": [
{
"type": "TEXT_MESSAGE_CONTENT",
"messageId": "m1",
"delta": "hi",
"token": "secret-token",
}
],
}
async def _publish(event: dict[str, object]) -> None:
published.append(event)
await run_agent_task(
{"command": "run", "run_input": _build_run_input()},
publish_event=_publish,
run_service=_RunWithExtraEvents(),
resume_service=_FakeResumeService(),
)
run_started = published[0]
assert run_started["type"] == "RUN_STARTED"
assert "input" not in run_started
text_event = published[1]
assert text_event["type"] == "TEXT_MESSAGE_CONTENT"
assert text_event["threadId"] == "00000000-0000-0000-0000-000000000001"
assert text_event["runId"] == "run-1"
assert text_event["token"] == "***REDACTED***"
@pytest.mark.asyncio
async def test_run_agent_task_emits_error_event_on_exception() -> None:
session_id = "00000000-0000-0000-0000-000000000001"
class _BrokenRunService(_FakeRunService):
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
del session_id, user_input
async def run(self, *, run_input: dict[str, object]) -> dict[str, object]:
del run_input
raise RuntimeError("boom")
events: list[str] = []
async def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
events.append(event_type)
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
events.append(event_type)
with pytest.raises(RuntimeError):
await run_agent_task(
{
"command": "run",
"session_id": session_id,
"user_input": "hello",
"run_input": _build_run_input(),
},
publish_event=_publish,
run_service=_BrokenRunService(),
@@ -72,16 +132,44 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
@pytest.mark.asyncio
async def test_run_agent_task_rejects_invalid_command() -> None:
with pytest.raises(ValueError, match="invalid command type"):
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
await run_agent_task({"command": "invalid", "run_input": _build_run_input()})
@pytest.mark.asyncio
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
with pytest.raises(ValueError, match="tool_call_id is required"):
async def test_run_agent_task_rejects_missing_run_input() -> None:
with pytest.raises(ValueError, match="run_input is required"):
await run_agent_task(
{
"command": "resume",
"session_id": "00000000-0000-0000-0000-000000000001",
"command": "run",
}
)
@pytest.mark.asyncio
async def test_run_agent_task_resume_uses_run_input() -> None:
async def _publish(event: dict[str, object]) -> None:
del event
result = await run_agent_task(
{
"command": "resume",
"run_input": _build_run_input(),
},
publish_event=_publish,
run_service=_FakeRunService(),
resume_service=_FakeResumeService(),
)
assert result["runId"] == "run-1"
@pytest.mark.asyncio
async def test_run_agent_task_rejects_invalid_run_input() -> None:
with pytest.raises(ValueError, match="invalid AG-UI RunAgentInput payload"):
await run_agent_task(
{
"command": "run",
"run_input": {"threadId": "x"},
}
)
@@ -23,11 +23,34 @@ class _FakeRedisClient:
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
del count, block
key, start_id = next(iter(streams.items()))
if start_id == "$":
if start_id == "0-0":
return [(key, [("11-0", {"event": '{"type":"RUN_STARTED"}'})])]
return [(key, [("12-0", {"event": '{"type":"RUN_FINISHED"}'})])]
class _MalformedRedisClient:
async def xread(
self,
streams: dict[str, str],
count: int,
block: int,
) -> list[object]:
del streams, count, block
return ["bad-shape"]
class _InvalidJsonRedisClient:
async def xread(
self,
streams: dict[str, str],
count: int,
block: int,
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
del count, block
key = next(iter(streams.keys()))
return [(key, [("11-0", {"event": "not-json"})])]
def test_append_event_writes_json_payload() -> None:
client = _FakeRedisClient()
session_id = uuid4()
@@ -55,3 +78,26 @@ async def test_read_events_respects_last_event_id() -> None:
assert from_start[0]["id"] == "11-0"
assert from_last[0]["id"] == "12-0"
@pytest.mark.asyncio
async def test_read_events_returns_empty_for_malformed_response() -> None:
session_id = uuid4()
store = RedisStreamEventStore(client=_MalformedRedisClient(), stream_prefix="agent:events")
rows = await store.read_events(session_id=session_id, last_event_id=None)
assert rows == []
@pytest.mark.asyncio
async def test_read_events_skips_invalid_event_json() -> None:
session_id = uuid4()
store = RedisStreamEventStore(
client=_InvalidJsonRedisClient(),
stream_prefix="agent:events",
)
rows = await store.read_events(session_id=session_id, last_event_id=None)
assert rows == []
@@ -5,11 +5,13 @@ from types import SimpleNamespace
from uuid import uuid4
import pytest
from ag_ui.core import RunAgentInput
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
@@ -61,12 +63,69 @@ class _FakeUserContextCache:
self.set_calls += 1
def _build_run_input(
*,
thread_id: str,
text: str = "hello",
tools: list[dict[str, object]] | None = None,
) -> RunAgentInput:
return RunAgentInput.model_validate(
{
"threadId": thread_id,
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": text}],
"tools": tools or [],
"context": [],
"forwardedProps": {},
}
)
def _build_resume_input(
*,
thread_id: str,
tool_call_id: str,
content: str | None = None,
) -> RunAgentInput:
payload = content
if payload is None:
payload = json.dumps(
{
"toolName": "navigate_to_route",
"toolArgs": {"target": "/calendar/dayweek", "replace": False, "__nonce": "nonce-1"},
"nonce": "nonce-1",
"result": {"ok": True},
},
ensure_ascii=True,
separators=(",", ":"),
)
return RunAgentInput.model_validate(
{
"threadId": thread_id,
"runId": "run-2",
"state": {},
"messages": [
{
"id": "tool-1",
"role": "tool",
"toolCallId": tool_call_id,
"content": payload,
}
],
"tools": [],
"context": [],
"forwardedProps": {},
}
)
@pytest.mark.asyncio
async def test_run_service_rejects_invalid_session_id() -> None:
run_service = RunService()
with pytest.raises(ValueError):
await run_service.run(session_id="session-1", user_input="hello")
await run_service.run(run_input=_build_run_input(thread_id="session-1"))
@pytest.mark.asyncio
@@ -74,7 +133,272 @@ async def test_resume_service_requires_pending_tool_call() -> None:
resume_service = ResumeService()
with pytest.raises(ValueError):
await resume_service.resume(session_id="session-1", tool_call_id="call-1")
await resume_service.resume(
run_input=_build_resume_input(
thread_id="session-1",
tool_call_id="call-1",
)
)
@pytest.mark.asyncio
async def test_resume_service_validates_pending_tool_guard_and_persists_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session_id = uuid4()
user_id = uuid4()
captured: list[dict[str, object]] = []
class _FakeDbSession:
async def commit(self) -> None:
return None
class _FakeSessionFactory:
def __call__(self) -> "_FakeSessionFactory":
return self
async def __aenter__(self) -> _FakeDbSession:
return _FakeDbSession()
async def __aexit__(self, exc_type, exc, tb) -> bool:
del exc_type, exc, tb
return False
class _FakeSessionRepository:
def __init__(self, session: object) -> None:
del session
async def lock_session_for_update(self, *, session_id: object):
return SimpleNamespace(
id=session_id,
user_id=user_id,
status=AgentChatSessionStatus.RUNNING,
message_count=0,
total_tokens=0,
total_cost=0,
state_snapshot={
"pending_tool_call_id": "call-1",
"pending_tool_name": "navigate_to_route",
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
"pending_tool_nonce": "nonce-1",
},
)
async def next_message_seq(self, *, session_id: object) -> int:
del session_id
return 1
async def update_runtime_state(self, **kwargs) -> None:
del kwargs
class _FakeMessageRepository:
def __init__(self, session: object) -> None:
del session
async def append_message(self, **kwargs) -> None:
captured.append(kwargs)
monkeypatch.setattr(
"core.agent.application.resume_service.SessionRepository",
_FakeSessionRepository,
)
monkeypatch.setattr(
"core.agent.application.resume_service.MessageRepository",
_FakeMessageRepository,
)
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
await service.resume(
run_input=_build_resume_input(
thread_id=str(session_id),
tool_call_id="call-1",
),
)
assert captured[0]["role"] == AgentChatMessageRole.TOOL
stored_payload = json.loads(captured[0]["content"])
assert stored_payload["toolName"] == "navigate_to_route"
assert stored_payload["result"]["ok"] is True
assert stored_payload["result"]["applied"] is True
assert "ui" not in stored_payload
@pytest.mark.asyncio
async def test_resume_service_rejects_mismatched_nonce(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session_id = uuid4()
user_id = uuid4()
class _FakeDbSession:
async def commit(self) -> None:
return None
class _FakeSessionFactory:
def __call__(self) -> "_FakeSessionFactory":
return self
async def __aenter__(self) -> _FakeDbSession:
return _FakeDbSession()
async def __aexit__(self, exc_type, exc, tb) -> bool:
del exc_type, exc, tb
return False
class _FakeSessionRepository:
def __init__(self, session: object) -> None:
del session
async def lock_session_for_update(self, *, session_id: object):
return SimpleNamespace(
id=session_id,
user_id=user_id,
status=AgentChatSessionStatus.RUNNING,
message_count=0,
total_tokens=0,
total_cost=0,
state_snapshot={
"pending_tool_call_id": "call-1",
"pending_tool_name": "navigate_to_route",
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
"pending_tool_nonce": "nonce-1",
},
)
async def next_message_seq(self, *, session_id: object) -> int:
del session_id
return 1
async def update_runtime_state(self, **kwargs) -> None:
del kwargs
class _FakeMessageRepository:
def __init__(self, session: object) -> None:
del session
async def append_message(self, **kwargs) -> None:
del kwargs
monkeypatch.setattr(
"core.agent.application.resume_service.SessionRepository",
_FakeSessionRepository,
)
monkeypatch.setattr(
"core.agent.application.resume_service.MessageRepository",
_FakeMessageRepository,
)
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
with pytest.raises(ValueError, match="nonce"):
await service.resume(
run_input=_build_resume_input(
thread_id=str(session_id),
tool_call_id="call-1",
content=json.dumps(
{
"toolName": "navigate_to_route",
"toolArgs": {
"target": "/calendar/dayweek",
"replace": False,
"__nonce": "nonce-1",
},
"nonce": "nonce-bad",
"result": {"ok": True},
},
ensure_ascii=True,
separators=(",", ":"),
),
)
)
@pytest.mark.asyncio
async def test_resume_service_rejects_tool_result_when_not_ok(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session_id = uuid4()
user_id = uuid4()
class _FakeDbSession:
async def commit(self) -> None:
return None
class _FakeSessionFactory:
def __call__(self) -> "_FakeSessionFactory":
return self
async def __aenter__(self) -> _FakeDbSession:
return _FakeDbSession()
async def __aexit__(self, exc_type, exc, tb) -> bool:
del exc_type, exc, tb
return False
class _FakeSessionRepository:
def __init__(self, session: object) -> None:
del session
async def lock_session_for_update(self, *, session_id: object):
return SimpleNamespace(
id=session_id,
user_id=user_id,
status=AgentChatSessionStatus.RUNNING,
message_count=0,
total_tokens=0,
total_cost=0,
state_snapshot={
"pending_tool_call_id": "call-1",
"pending_tool_name": "navigate_to_route",
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
"pending_tool_nonce": "nonce-1",
},
)
async def next_message_seq(self, *, session_id: object) -> int:
del session_id
return 1
async def update_runtime_state(self, **kwargs) -> None:
del kwargs
class _FakeMessageRepository:
def __init__(self, session: object) -> None:
del session
async def append_message(self, **kwargs) -> None:
del kwargs
monkeypatch.setattr(
"core.agent.application.resume_service.SessionRepository",
_FakeSessionRepository,
)
monkeypatch.setattr(
"core.agent.application.resume_service.MessageRepository",
_FakeMessageRepository,
)
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
with pytest.raises(ValueError, match="execution failed"):
await service.resume(
run_input=_build_resume_input(
thread_id=str(session_id),
tool_call_id="call-1",
content=json.dumps(
{
"toolName": "navigate_to_route",
"toolArgs": {
"target": "/calendar/dayweek",
"replace": False,
"__nonce": "nonce-1",
},
"nonce": "nonce-1",
"result": {"ok": False, "error": "navigator not bound"},
},
ensure_ascii=True,
separators=(",", ":"),
),
)
)
@pytest.mark.asyncio
@@ -256,7 +580,9 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
session_uuid = session_id
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
await run_service.run(session_id=str(session_id), user_input="hello")
await run_service.run(
run_input=_build_run_input(thread_id=str(session_id), text="hello")
)
system_prompt = captured["system_prompt"]
assert isinstance(system_prompt, str)
@@ -267,6 +593,290 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
assert payload["ai_language"] == "en-US"
@pytest.mark.asyncio
async def test_run_service_emits_frontend_tool_pending_events(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session_id = uuid4()
user_id = uuid4()
captured: dict[str, object] = {}
class _FakeDbSession:
async def commit(self) -> None:
return None
class _FakeSessionFactory:
def __call__(self) -> "_FakeSessionFactory":
return self
async def __aenter__(self) -> _FakeDbSession:
return _FakeDbSession()
async def __aexit__(self, exc_type, exc, tb) -> bool:
del exc_type, exc, tb
return False
class _FakeSessionRepository:
def __init__(self, session: object) -> None:
del session
async def lock_session_for_update(self, *, session_id: object):
return SimpleNamespace(
id=session_id,
user_id=user_id,
status=AgentChatSessionStatus.PENDING,
message_count=0,
total_tokens=0,
total_cost=0,
state_snapshot=None,
)
async def next_message_seq(self, *, session_id: object):
del session_id
return 1
async def update_runtime_state(self, **kwargs) -> None:
captured["update_runtime_state"] = kwargs
class _FakeMessageRepository:
def __init__(self, session: object) -> None:
del session
async def append_message(self, **kwargs) -> None:
captured.setdefault("messages", []).append(kwargs)
class _FakeRuntime:
def execute(self, *, user_input: str, system_prompt: str | None = None):
del user_input, system_prompt
return {
"assistant_text": "请确认是否跳转。",
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
"cost": "0.001",
"agui_events": [],
}
async def _fake_load_agent_model_selection(self, _session):
del self
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
async def _fake_load_user_agent_context(self, session, session_id, user_id):
del self, session, session_id
return SimpleNamespace(
user_id=user_id,
username="demo-user",
bio=None,
settings=SimpleNamespace(
preferences=SimpleNamespace(
interface_language="zh-CN",
ai_language="zh-CN",
timezone="Asia/Shanghai",
country="CN",
)
),
)
monkeypatch.setattr(
"core.agent.application.run_service.SessionRepository",
_FakeSessionRepository,
)
monkeypatch.setattr(
"core.agent.application.run_service.MessageRepository",
_FakeMessageRepository,
)
monkeypatch.setattr(
"core.agent.application.run_service.create_runtime",
lambda **_kwargs: _FakeRuntime(),
)
monkeypatch.setattr(
"core.agent.application.run_service.RunService._load_agent_model_selection",
_fake_load_agent_model_selection,
)
monkeypatch.setattr(
"core.agent.application.run_service.RunService._load_user_agent_context",
_fake_load_user_agent_context,
)
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
result = await service.run(
run_input=_build_run_input(
thread_id=str(session_id),
text="帮我打开日历",
tools=[
{
"name": "navigate_to_route",
"description": "navigate",
"parameters": {"type": "object"},
}
],
)
)
assert result["pending_tool_call_id"] is not None
tool_start = next(event for event in result["events"] if event["type"] == "TOOL_CALL_START")
assert tool_start["toolCallName"] == "navigate_to_route"
runtime_state = captured["update_runtime_state"]
assert isinstance(runtime_state, dict)
assert runtime_state["status"] == AgentChatSessionStatus.RUNNING
snapshot = runtime_state["state_snapshot"]
assert isinstance(snapshot, dict)
assert snapshot["pending_tool_name"] == "navigate_to_route"
assert isinstance(snapshot["pending_tool_args_sha256"], str)
assert isinstance(snapshot["pending_tool_nonce"], str)
@pytest.mark.asyncio
async def test_run_service_executes_backend_calendar_tool_and_emits_result(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session_id = uuid4()
user_id = uuid4()
captured: dict[str, object] = {}
class _FakeDbSession:
async def commit(self) -> None:
return None
class _FakeSessionFactory:
def __call__(self) -> "_FakeSessionFactory":
return self
async def __aenter__(self) -> _FakeDbSession:
return _FakeDbSession()
async def __aexit__(self, exc_type, exc, tb) -> bool:
del exc_type, exc, tb
return False
class _FakeSessionRepository:
def __init__(self, session: object) -> None:
del session
async def lock_session_for_update(self, *, session_id: object):
return SimpleNamespace(
id=session_id,
user_id=user_id,
status=AgentChatSessionStatus.PENDING,
message_count=0,
total_tokens=0,
total_cost=0,
state_snapshot=None,
)
async def next_message_seq(self, *, session_id: object):
del session_id
return 1
async def update_runtime_state(self, **kwargs) -> None:
captured["update_runtime_state"] = kwargs
class _FakeMessageRepository:
def __init__(self, session: object) -> None:
del session
async def append_message(self, **kwargs) -> None:
captured.setdefault("messages", []).append(kwargs)
class _FakeRuntime:
def execute(self, *, user_input: str, system_prompt: str | None = None):
del user_input, system_prompt
return {
"assistant_text": "日历事件已创建。",
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
"cost": "0.001",
"agui_events": [],
}
async def _fake_load_agent_model_selection(self, _session):
del self
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
async def _fake_load_user_agent_context(self, session, session_id, user_id):
del self, session, session_id
return SimpleNamespace(
user_id=user_id,
username="demo-user",
bio=None,
settings=SimpleNamespace(
preferences=SimpleNamespace(
interface_language="zh-CN",
ai_language="zh-CN",
timezone="Asia/Shanghai",
country="CN",
)
),
)
async def _fake_execute_backend_tool(
self,
*,
session,
owner_id,
tool_name,
tool_args,
):
del self, session, owner_id
assert tool_name == "create_calendar_event"
assert "title" in tool_args
return {
"result": {"eventId": "evt-1", "ok": True},
"ui": {
"type": "calendar_card.v1",
"version": "v1",
"data": {"id": "evt-1", "title": "会议"},
},
}
monkeypatch.setattr(
"core.agent.application.run_service.SessionRepository",
_FakeSessionRepository,
)
monkeypatch.setattr(
"core.agent.application.run_service.MessageRepository",
_FakeMessageRepository,
)
monkeypatch.setattr(
"core.agent.application.run_service.create_runtime",
lambda **_kwargs: _FakeRuntime(),
)
monkeypatch.setattr(
"core.agent.application.run_service.RunService._load_agent_model_selection",
_fake_load_agent_model_selection,
)
monkeypatch.setattr(
"core.agent.application.run_service.RunService._load_user_agent_context",
_fake_load_user_agent_context,
)
monkeypatch.setattr(
"core.agent.application.run_service.RunService._execute_backend_tool",
_fake_execute_backend_tool,
)
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
result = await service.run(
run_input=_build_run_input(
thread_id=str(session_id),
text='#tool:create_calendar_event {"title":"会议","startAt":"2026-03-07T08:00:00Z"}',
tools=[
{
"name": "create_calendar_event",
"description": "create calendar",
"parameters": {"type": "object"},
}
],
)
)
assert result["pending_tool_call_id"] is None
assert any(event["type"] == "TOOL_CALL_RESULT" for event in result["events"])
runtime_state = captured["update_runtime_state"]
assert isinstance(runtime_state, dict)
assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
@pytest.mark.asyncio
async def test_load_user_agent_context_parses_profile_settings_v1() -> None:
session_id = uuid4()
@@ -519,7 +1129,9 @@ async def test_run_service_still_executes_when_profile_missing(
session_uuid = session_id
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
await run_service.run(session_id=str(session_id), user_input="hello")
await run_service.run(
run_input=_build_run_input(thread_id=str(session_id), text="hello")
)
system_prompt = captured["system_prompt"]
assert isinstance(system_prompt, str)
@@ -4,9 +4,18 @@ from core.agent.domain.state_snapshot import AgentStateSnapshot
def test_state_snapshot_serialization_round_trip() -> None:
snapshot = AgentStateSnapshot(status="running", pending_tool_call_id="call-1")
snapshot = AgentStateSnapshot(
status="running",
pending_tool_call_id="call-1",
pending_tool_name="navigate_to_route",
pending_tool_args_sha256="abc",
pending_tool_nonce="nonce-1",
)
payload = snapshot.model_dump()
assert payload["status"] == "running"
assert payload["pending_tool_call_id"] == "call-1"
assert payload["pending_tool_name"] == "navigate_to_route"
assert payload["pending_tool_args_sha256"] == "abc"
assert payload["pending_tool_nonce"] == "nonce-1"