feat: AG-UI 协议对齐与路由导航功能
- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具 - 前端: 实现工具调用审批流程,支持 pending 状态展示 - 后端: Agent 状态管理与会话持久化相关重构 - 文档: 新增 agent-agui-full-alignance 设计文档 - 测试: 补充相关单元测试和集成测试
This commit is contained in:
@@ -132,7 +132,9 @@ def test_bridge_rejects_unknown_event_type() -> None:
|
||||
|
||||
def test_sse_format_includes_id_event_data() -> None:
|
||||
payload = to_sse_event(
|
||||
stream_id="1-0", event={"type": "RUN_STARTED", "data": {"a": 1}}
|
||||
stream_id="1-0",
|
||||
event={"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"},
|
||||
)
|
||||
|
||||
assert payload.startswith("id: 1-0\nevent: RUN_STARTED\ndata: {")
|
||||
assert '"threadId":"t1"' in payload
|
||||
|
||||
@@ -56,12 +56,14 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
captured["model"] = model
|
||||
captured["api_key"] = api_key
|
||||
captured["messages"] = messages
|
||||
captured["temperature"] = temperature
|
||||
captured["max_tokens"] = max_tokens
|
||||
captured["timeout"] = timeout
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
@@ -113,6 +115,7 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
assert captured["api_key"] == "env-api-key"
|
||||
assert captured["temperature"] == 0.3
|
||||
assert captured["max_tokens"] == 256
|
||||
assert captured["timeout"] == 30.0
|
||||
assert result["assistant_text"] == "hello"
|
||||
|
||||
|
||||
@@ -128,6 +131,7 @@ def test_runtime_execute_injects_system_prompt_and_intent_template(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
captured["messages"] = messages
|
||||
return {
|
||||
@@ -219,6 +223,7 @@ def test_runtime_execute_short_circuits_on_direct_execution(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
@@ -331,6 +336,7 @@ def test_runtime_execute_runs_execution_and_organization_stages(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
@@ -383,6 +389,7 @@ def test_runtime_execute_rejects_invalid_intent_json(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, messages, temperature, max_tokens
|
||||
return {
|
||||
@@ -506,6 +513,7 @@ def test_runtime_execute_minimizes_prompt_and_execution_payload(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
|
||||
@@ -21,10 +21,12 @@ def test_run_completion_passes_optional_params_when_provided(monkeypatch) -> Non
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
temperature=0.6,
|
||||
max_tokens=120,
|
||||
timeout=12.5,
|
||||
)
|
||||
|
||||
assert captured["temperature"] == 0.6
|
||||
assert captured["max_tokens"] == 120
|
||||
assert captured["timeout"] == 12.5
|
||||
|
||||
|
||||
def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
|
||||
@@ -45,7 +47,9 @@ def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
)
|
||||
|
||||
assert "temperature" not in captured
|
||||
assert "max_tokens" not in captured
|
||||
assert "timeout" not in captured
|
||||
|
||||
@@ -2,64 +2,124 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
|
||||
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
}
|
||||
|
||||
|
||||
class _FakeResumeService:
|
||||
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "tool_call_id": tool_call_id}
|
||||
async def resume(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
}
|
||||
|
||||
|
||||
def _build_run_input() -> dict[str, object]:
|
||||
return {
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
events: list[str] = []
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
events.append(event_type)
|
||||
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
|
||||
assert result["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert events == ["RUN_STARTED", "RUN_FINISHED"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_injects_context_and_redacts_sensitive_fields() -> None:
|
||||
published: list[dict[str, object]] = []
|
||||
|
||||
class _RunWithExtraEvents(_FakeRunService):
|
||||
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"events": [
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"messageId": "m1",
|
||||
"delta": "hi",
|
||||
"token": "secret-token",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
published.append(event)
|
||||
|
||||
await run_agent_task(
|
||||
{"command": "run", "run_input": _build_run_input()},
|
||||
publish_event=_publish,
|
||||
run_service=_RunWithExtraEvents(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
run_started = published[0]
|
||||
assert run_started["type"] == "RUN_STARTED"
|
||||
assert "input" not in run_started
|
||||
|
||||
text_event = published[1]
|
||||
assert text_event["type"] == "TEXT_MESSAGE_CONTENT"
|
||||
assert text_event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert text_event["runId"] == "run-1"
|
||||
assert text_event["token"] == "***REDACTED***"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
class _BrokenRunService(_FakeRunService):
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
del session_id, user_input
|
||||
async def run(self, *, run_input: dict[str, object]) -> dict[str, object]:
|
||||
del run_input
|
||||
raise RuntimeError("boom")
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
events.append(event_type)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_BrokenRunService(),
|
||||
@@ -72,16 +132,44 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_rejects_invalid_command() -> None:
|
||||
with pytest.raises(ValueError, match="invalid command type"):
|
||||
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
|
||||
await run_agent_task({"command": "invalid", "run_input": _build_run_input()})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
|
||||
with pytest.raises(ValueError, match="tool_call_id is required"):
|
||||
async def test_run_agent_task_rejects_missing_run_input() -> None:
|
||||
with pytest.raises(ValueError, match="run_input is required"):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
"command": "run",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_resume_uses_run_input() -> None:
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
del event
|
||||
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert result["runId"] == "run-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_rejects_invalid_run_input() -> None:
|
||||
with pytest.raises(ValueError, match="invalid AG-UI RunAgentInput payload"):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"run_input": {"threadId": "x"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -23,11 +23,34 @@ class _FakeRedisClient:
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
key, start_id = next(iter(streams.items()))
|
||||
if start_id == "$":
|
||||
if start_id == "0-0":
|
||||
return [(key, [("11-0", {"event": '{"type":"RUN_STARTED"}'})])]
|
||||
return [(key, [("12-0", {"event": '{"type":"RUN_FINISHED"}'})])]
|
||||
|
||||
|
||||
class _MalformedRedisClient:
|
||||
async def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[object]:
|
||||
del streams, count, block
|
||||
return ["bad-shape"]
|
||||
|
||||
|
||||
class _InvalidJsonRedisClient:
|
||||
async def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
key = next(iter(streams.keys()))
|
||||
return [(key, [("11-0", {"event": "not-json"})])]
|
||||
|
||||
|
||||
def test_append_event_writes_json_payload() -> None:
|
||||
client = _FakeRedisClient()
|
||||
session_id = uuid4()
|
||||
@@ -55,3 +78,26 @@ async def test_read_events_respects_last_event_id() -> None:
|
||||
|
||||
assert from_start[0]["id"] == "11-0"
|
||||
assert from_last[0]["id"] == "12-0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_events_returns_empty_for_malformed_response() -> None:
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(client=_MalformedRedisClient(), stream_prefix="agent:events")
|
||||
|
||||
rows = await store.read_events(session_id=session_id, last_event_id=None)
|
||||
|
||||
assert rows == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_events_skips_invalid_event_json() -> None:
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(
|
||||
client=_InvalidJsonRedisClient(),
|
||||
stream_prefix="agent:events",
|
||||
)
|
||||
|
||||
rows = await store.read_events(session_id=session_id, last_event_id=None)
|
||||
|
||||
assert rows == []
|
||||
|
||||
@@ -5,11 +5,13 @@ from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
|
||||
|
||||
@@ -61,12 +63,69 @@ class _FakeUserContextCache:
|
||||
self.set_calls += 1
|
||||
|
||||
|
||||
def _build_run_input(
|
||||
*,
|
||||
thread_id: str,
|
||||
text: str = "hello",
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": text}],
|
||||
"tools": tools or [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_resume_input(
|
||||
*,
|
||||
thread_id: str,
|
||||
tool_call_id: str,
|
||||
content: str | None = None,
|
||||
) -> RunAgentInput:
|
||||
payload = content
|
||||
if payload is None:
|
||||
payload = json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {"target": "/calendar/dayweek", "replace": False, "__nonce": "nonce-1"},
|
||||
"nonce": "nonce-1",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": "run-2",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"role": "tool",
|
||||
"toolCallId": tool_call_id,
|
||||
"content": payload,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_rejects_invalid_session_id() -> None:
|
||||
run_service = RunService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await run_service.run(session_id="session-1", user_input="hello")
|
||||
await run_service.run(run_input=_build_run_input(thread_id="session-1"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -74,7 +133,272 @@ async def test_resume_service_requires_pending_tool_call() -> None:
|
||||
resume_service = ResumeService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await resume_service.resume(session_id="session-1", tool_call_id="call-1")
|
||||
await resume_service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_validates_pending_tool_guard_and_persists_payload(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: list[dict[str, object]] = []
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
),
|
||||
)
|
||||
|
||||
assert captured[0]["role"] == AgentChatMessageRole.TOOL
|
||||
stored_payload = json.loads(captured[0]["content"])
|
||||
assert stored_payload["toolName"] == "navigate_to_route"
|
||||
assert stored_payload["result"]["ok"] is True
|
||||
assert stored_payload["result"]["applied"] is True
|
||||
assert "ui" not in stored_payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_rejects_mismatched_nonce(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="nonce"):
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": "nonce-1",
|
||||
},
|
||||
"nonce": "nonce-bad",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_rejects_tool_result_when_not_ok(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="execution failed"):
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": "nonce-1",
|
||||
},
|
||||
"nonce": "nonce-1",
|
||||
"result": {"ok": False, "error": "navigator not bound"},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -256,7 +580,9 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
|
||||
session_uuid = session_id
|
||||
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
|
||||
await run_service.run(session_id=str(session_id), user_input="hello")
|
||||
await run_service.run(
|
||||
run_input=_build_run_input(thread_id=str(session_id), text="hello")
|
||||
)
|
||||
|
||||
system_prompt = captured["system_prompt"]
|
||||
assert isinstance(system_prompt, str)
|
||||
@@ -267,6 +593,290 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
|
||||
assert payload["ai_language"] == "en-US"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_emits_frontend_tool_pending_events(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.PENDING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot=None,
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object):
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
captured["update_runtime_state"] = kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
del user_input, system_prompt
|
||||
return {
|
||||
"assistant_text": "请确认是否跳转。",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
del self
|
||||
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
user_id=user_id,
|
||||
username="demo-user",
|
||||
bio=None,
|
||||
settings=SimpleNamespace(
|
||||
preferences=SimpleNamespace(
|
||||
interface_language="zh-CN",
|
||||
ai_language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
country="CN",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.create_runtime",
|
||||
lambda **_kwargs: _FakeRuntime(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
result = await service.run(
|
||||
run_input=_build_run_input(
|
||||
thread_id=str(session_id),
|
||||
text="帮我打开日历",
|
||||
tools=[
|
||||
{
|
||||
"name": "navigate_to_route",
|
||||
"description": "navigate",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["pending_tool_call_id"] is not None
|
||||
tool_start = next(event for event in result["events"] if event["type"] == "TOOL_CALL_START")
|
||||
assert tool_start["toolCallName"] == "navigate_to_route"
|
||||
runtime_state = captured["update_runtime_state"]
|
||||
assert isinstance(runtime_state, dict)
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.RUNNING
|
||||
snapshot = runtime_state["state_snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["pending_tool_name"] == "navigate_to_route"
|
||||
assert isinstance(snapshot["pending_tool_args_sha256"], str)
|
||||
assert isinstance(snapshot["pending_tool_nonce"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.PENDING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot=None,
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object):
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
captured["update_runtime_state"] = kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
del user_input, system_prompt
|
||||
return {
|
||||
"assistant_text": "日历事件已创建。",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
del self
|
||||
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
user_id=user_id,
|
||||
username="demo-user",
|
||||
bio=None,
|
||||
settings=SimpleNamespace(
|
||||
preferences=SimpleNamespace(
|
||||
interface_language="zh-CN",
|
||||
ai_language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
country="CN",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_execute_backend_tool(
|
||||
self,
|
||||
*,
|
||||
session,
|
||||
owner_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
):
|
||||
del self, session, owner_id
|
||||
assert tool_name == "create_calendar_event"
|
||||
assert "title" in tool_args
|
||||
return {
|
||||
"result": {"eventId": "evt-1", "ok": True},
|
||||
"ui": {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {"id": "evt-1", "title": "会议"},
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.create_runtime",
|
||||
lambda **_kwargs: _FakeRuntime(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._execute_backend_tool",
|
||||
_fake_execute_backend_tool,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
result = await service.run(
|
||||
run_input=_build_run_input(
|
||||
thread_id=str(session_id),
|
||||
text='#tool:create_calendar_event {"title":"会议","startAt":"2026-03-07T08:00:00Z"}',
|
||||
tools=[
|
||||
{
|
||||
"name": "create_calendar_event",
|
||||
"description": "create calendar",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["pending_tool_call_id"] is None
|
||||
assert any(event["type"] == "TOOL_CALL_RESULT" for event in result["events"])
|
||||
runtime_state = captured["update_runtime_state"]
|
||||
assert isinstance(runtime_state, dict)
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_user_agent_context_parses_profile_settings_v1() -> None:
|
||||
session_id = uuid4()
|
||||
@@ -519,7 +1129,9 @@ async def test_run_service_still_executes_when_profile_missing(
|
||||
session_uuid = session_id
|
||||
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
|
||||
await run_service.run(session_id=str(session_id), user_input="hello")
|
||||
await run_service.run(
|
||||
run_input=_build_run_input(thread_id=str(session_id), text="hello")
|
||||
)
|
||||
|
||||
system_prompt = captured["system_prompt"]
|
||||
assert isinstance(system_prompt, str)
|
||||
|
||||
@@ -4,9 +4,18 @@ from core.agent.domain.state_snapshot import AgentStateSnapshot
|
||||
|
||||
|
||||
def test_state_snapshot_serialization_round_trip() -> None:
|
||||
snapshot = AgentStateSnapshot(status="running", pending_tool_call_id="call-1")
|
||||
snapshot = AgentStateSnapshot(
|
||||
status="running",
|
||||
pending_tool_call_id="call-1",
|
||||
pending_tool_name="navigate_to_route",
|
||||
pending_tool_args_sha256="abc",
|
||||
pending_tool_nonce="nonce-1",
|
||||
)
|
||||
|
||||
payload = snapshot.model_dump()
|
||||
|
||||
assert payload["status"] == "running"
|
||||
assert payload["pending_tool_call_id"] == "call-1"
|
||||
assert payload["pending_tool_name"] == "navigate_to_route"
|
||||
assert payload["pending_tool_args_sha256"] == "abc"
|
||||
assert payload["pending_tool_nonce"] == "nonce-1"
|
||||
|
||||
Reference in New Issue
Block a user