feat: AG-UI 协议对齐与路由导航功能

- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具
- 前端: 实现工具调用审批流程,支持 pending 状态展示
- 后端: Agent 状态管理与会话持久化相关重构
- 文档: 新增 agent-agui-full-alignance 设计文档
- 测试: 补充相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-07 17:30:20 +08:00
parent ec33bb0cee
commit 120df903d2
52 changed files with 4305 additions and 1672 deletions
@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import uuid
from decimal import Decimal
import pytest
from ag_ui.core import RunAgentInput
from sqlalchemy import delete, select
from core.agent.application.resume_service import ResumeService
@@ -84,28 +86,76 @@ async def test_run_then_resume_persists_messages_and_session_state(
published: list[str] = []
def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
published.append(event_type)
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
published.append(event_type)
try:
run_result = run_agent_task(
run_input_payload = {
"threadId": str(session_uuid),
"runId": "run-it-1",
"state": {},
"messages": [
{"id": "u1", "role": "user", "content": "帮我打开日历"},
],
"tools": [
{
"name": "navigate_to_route",
"description": "navigate route",
"parameters": {"type": "object"},
}
],
"context": [],
"forwardedProps": {},
}
run_result = await run_agent_task(
{
"command": "run",
"session_id": str(session_uuid),
"user_input": "hello",
"run_input": run_input_payload,
},
publish_event=_publish,
run_service=RunService(),
resume_service=ResumeService(),
)
pending_tool_call_id = str(run_result["pending_tool_call_id"])
state_snapshot = run_result["state_snapshot"]
assert isinstance(state_snapshot, dict)
pending_tool_nonce = state_snapshot["pending_tool_nonce"]
assert isinstance(pending_tool_nonce, str)
run_agent_task(
await run_agent_task(
{
"command": "resume",
"session_id": str(session_uuid),
"tool_call_id": pending_tool_call_id,
"run_input": {
"threadId": str(session_uuid),
"runId": "run-it-2",
"state": {},
"messages": [
{
"id": "tool-1",
"role": "tool",
"toolCallId": pending_tool_call_id,
"content": json.dumps(
{
"toolName": "navigate_to_route",
"toolArgs": {
"target": "/calendar/dayweek",
"replace": False,
"__nonce": pending_tool_nonce,
},
"nonce": pending_tool_nonce,
"result": {"ok": True},
},
ensure_ascii=True,
separators=(",", ":"),
),
}
],
"tools": [],
"context": [],
"forwardedProps": {},
},
},
publish_event=_publish,
run_service=RunService(),
@@ -123,6 +173,9 @@ async def test_run_then_resume_persists_messages_and_session_state(
assert db_session.state_snapshot == {
"status": "completed",
"pending_tool_call_id": None,
"pending_tool_name": None,
"pending_tool_args_sha256": None,
"pending_tool_nonce": None,
}
rows = await verify_session.execute(
@@ -142,7 +195,7 @@ async def test_run_then_resume_persists_messages_and_session_state(
assert messages[1].cost == Decimal("0.002500")
assert "RUN_STARTED" in published
assert "RUN_RESUMED" in published
assert "RUN_FINISHED" in published
assert "TEXT_MESSAGE_CONTENT" in published
finally:
async with AsyncSessionLocal() as cleanup_session:
@@ -219,7 +272,21 @@ async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
await seed_session.commit()
result = await RunService().run(session_id=str(session_uuid), user_input="hello")
result = await RunService().run(
run_input=RunAgentInput.model_validate(
{
"threadId": str(session_uuid),
"runId": "run-it-1",
"state": {},
"messages": [
{"id": "u1", "role": "user", "content": "hello"},
],
"tools": [],
"context": [],
"forwardedProps": {},
}
)
)
assert result["persisted"] is True
assert captured["user_input"] == "hello"
@@ -16,29 +16,38 @@ class _FakeStorage:
return "etag-1"
def test_closed_loop_run_flow_frontend_to_sse() -> None:
session_id = "00000000-0000-0000-0000-000000000001"
async def test_closed_loop_run_flow_frontend_to_sse() -> None:
thread_id = "00000000-0000-0000-0000-000000000001"
published: list[str] = []
class _FakeRunService:
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
return {"session_id": session_id, "user_input": user_input}
async def run(self, *, run_input: object) -> dict[str, object]:
del run_input
return {"threadId": thread_id, "runId": "run-1"}
def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
published.append(event_type)
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
published.append(event_type)
result = run_agent_task(
result = await run_agent_task(
{
"command": "run",
"session_id": session_id,
"user_input": "hello",
"run_input": {
"threadId": thread_id,
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
},
},
publish_event=_publish,
run_service=_FakeRunService(),
)
assert result["session_id"] == session_id
assert result["threadId"] == thread_id
assert published[0] == "RUN_STARTED"
assert published[-1] == "RUN_FINISHED"
+206 -24
View File
@@ -3,10 +3,12 @@ from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
from ag_ui.core import RunAgentInput
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.agent import router as agent_router
from v1.agent.dependencies import get_agent_service
from v1.users.dependencies import get_current_user
@@ -16,52 +18,122 @@ class _FakeAgentService:
self._stream_called = False
async def enqueue_run(
self, *, session_id: str | None, prompt: str, current_user: CurrentUser
self, *, run_input: RunAgentInput, current_user: CurrentUser
):
del prompt, current_user
resolved_session = session_id or "auto-created-session"
del current_user
return SimpleNamespace(
task_id="task-run-1",
session_id=resolved_session,
created=session_id is None,
thread_id=run_input.thread_id,
run_id=run_input.run_id,
created=False,
)
async def enqueue_resume(
self,
*,
session_id: str,
tool_call_id: str,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
):
del tool_call_id, current_user
del thread_id, current_user
return SimpleNamespace(
task_id="task-resume-1", session_id=session_id, created=False
task_id="task-resume-1",
thread_id=run_input.thread_id,
run_id=run_input.run_id,
created=False,
)
async def stream_events(
self,
*,
session_id: str,
thread_id: str,
last_event_id: str | None,
current_user: CurrentUser,
) -> list[dict[str, object]]:
del session_id, current_user
del thread_id, current_user
if self._stream_called:
return []
self._stream_called = True
return [
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
{
"id": "2-0",
"event": {
"type": "RUN_STARTED",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
},
"cursor": last_event_id,
}
]
async def get_history_snapshot(
    self,
    *,
    thread_id: str,
    before: str | None,
    current_user: CurrentUser,
) -> dict[str, object]:
    """Return a canned ``STATE_SNAPSHOT`` event for a single thread.

    The given ``thread_id`` is echoed back as ``threadId``. The snapshot's
    ``day`` is ``before`` when that is truthy and ``"2026-03-07"`` otherwise.
    ``current_user`` is accepted only to satisfy the call signature.
    """
    del current_user
    day = before or "2026-03-07"
    history_message: dict[str, object] = {
        "id": "msg-h1",
        "role": "assistant",
        "content": "history-message",
    }
    snapshot: dict[str, object] = {
        "scope": "history_day",
        "day": day,
        "hasMore": False,
        "messages": [history_message],
    }
    return {
        "type": "STATE_SNAPSHOT",
        "threadId": thread_id,
        "snapshot": snapshot,
    }
async def get_user_history_snapshot(
    self,
    *,
    current_user: CurrentUser,
    thread_id: str | None,
    before: str | None,
) -> dict[str, object]:
    """Return a canned, message-less ``STATE_SNAPSHOT`` for the user history call.

    ``threadId`` is ``thread_id`` when that is truthy, otherwise the fixed id
    ``"00000000-0000-0000-0000-000000000001"``. The snapshot's ``day`` is
    ``before`` when that is truthy and ``"2026-03-07"`` otherwise.
    ``current_user`` is accepted only to satisfy the call signature.
    """
    del current_user
    return {
        "type": "STATE_SNAPSHOT",
        "threadId": thread_id or "00000000-0000-0000-0000-000000000001",
        "snapshot": {
            "scope": "history_day",
            # Echo the cursor, as get_history_snapshot does, instead of
            # silently dropping it; a test that passes `before` can then
            # observe whether the value reached the service.
            "day": before or "2026-03-07",
            "hasMore": False,
            "messages": [],
        },
    }
def test_run_requires_auth_and_returns_202_task_id() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
client = TestClient(app)
original_allow_run = agent_router._allow_run_request
async def _allow_run(*, user_id: str) -> bool:
del user_id
return True
agent_router._allow_run_request = _allow_run # type: ignore[assignment]
try:
unauthorized = client.post(
"/api/v1/agent/runs",
json={"session_id": "session-1", "prompt": "hello"},
json={
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
},
)
assert unauthorized.status_code == 401
@@ -70,20 +142,23 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
)
authorized = client.post(
"/api/v1/agent/runs",
json={"session_id": "session-1", "prompt": "hello"},
json={
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
},
)
assert authorized.status_code == 202
assert authorized.json()["task_id"] == "task-run-1"
assert authorized.json()["taskId"] == "task-run-1"
assert authorized.json()["threadId"] == "00000000-0000-0000-0000-000000000001"
assert authorized.json()["runId"] == "run-1"
assert authorized.json()["created"] is False
first_chat = client.post(
"/api/v1/agent/runs",
json={"prompt": "hello"},
)
assert first_chat.status_code == 202
assert first_chat.json()["session_id"] == "auto-created-session"
assert first_chat.json()["created"] is True
finally:
agent_router._allow_run_request = original_allow_run # type: ignore[assignment]
app.dependency_overrides = {}
@@ -93,15 +168,122 @@ def test_stream_reads_from_last_event_id() -> None:
id=uuid4(), email="user@example.com"
)
client = TestClient(app)
original_acquire = agent_router._acquire_sse_slot
original_release = agent_router._release_sse_slot
async def _allow_slot(*, user_id: str) -> bool:
del user_id
return True
async def _noop_release(*, user_id: str) -> None:
del user_id
return None
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
try:
response = client.get(
"/api/v1/agent/runs/session-1/events?idle_limit=1",
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1",
headers={"Last-Event-ID": "1-0"},
)
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/event-stream")
assert "id: 2-0" in response.text
assert "event: RUN_STARTED" in response.text
assert '"threadId":"00000000-0000-0000-0000-000000000001"' in response.text
finally:
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
agent_router._release_sse_slot = original_release # type: ignore[assignment]
app.dependency_overrides = {}
def test_stream_rejects_invalid_last_event_id() -> None:
    """Assert that a ``Last-Event-ID`` of ``bad-id`` on the events route gives 422."""
    events_path = "/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events"
    malformed_cursor = {"Last-Event-ID": "bad-id"}

    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    client = TestClient(app)
    try:
        rejected = client.get(events_path, headers=malformed_cursor)
        assert rejected.status_code == 422
    finally:
        # Reset the overrides even when the assertion fails.
        app.dependency_overrides = {}
def test_history_returns_state_snapshot() -> None:
    """Exercise the per-thread history route without and then with a user.

    With only the agent service overridden the request is expected to get
    401. Once ``get_current_user`` is overridden as well, the response is
    expected to be 200 and to carry a ``STATE_SNAPSHOT`` whose ``threadId``
    and ``day`` match the requested thread and the ``before`` parameter.
    """
    thread_id = "00000000-0000-0000-0000-000000000001"
    history_path = "/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/history"

    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    client = TestClient(app)
    try:
        # No user override yet, so the request goes through the real auth dependency.
        anonymous = client.get(history_path)
        assert anonymous.status_code == 401

        app.dependency_overrides[get_current_user] = lambda: CurrentUser(
            id=uuid4(), email="user@example.com"
        )
        signed_in = client.get(history_path, params={"before": "2026-03-07"})
        assert signed_in.status_code == 200

        body = signed_in.json()
        assert body["type"] == "STATE_SNAPSHOT"
        assert body["threadId"] == thread_id
        snapshot = body["snapshot"]
        assert snapshot["scope"] == "history_day"
        assert snapshot["day"] == "2026-03-07"
    finally:
        # Reset the overrides even when an assertion fails.
        app.dependency_overrides = {}
def test_user_history_returns_latest_snapshot() -> None:
    """Assert that ``GET /api/v1/agent/history`` returns the fake's snapshot.

    The request carries no query parameters; the response is expected to be
    200 with ``type`` ``STATE_SNAPSHOT`` and the fixed fallback thread id.
    """
    fallback_thread_id = "00000000-0000-0000-0000-000000000001"

    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    client = TestClient(app)
    try:
        history = client.get("/api/v1/agent/history")
        assert history.status_code == 200

        snapshot_event = history.json()
        assert snapshot_event["type"] == "STATE_SNAPSHOT"
        assert snapshot_event["threadId"] == fallback_thread_id
    finally:
        # Reset the overrides even when an assertion fails.
        app.dependency_overrides = {}
def test_run_rejects_oversized_user_text_payload() -> None:
    """Assert that a run whose user message is 11000 characters long gives 422."""
    # NOTE(review): 11000 presumably sits just above the router's text-size
    # limit; confirm against the validator.
    oversized_message = {"id": "u1", "role": "user", "content": "x" * 11000}
    run_payload = {
        "threadId": "00000000-0000-0000-0000-000000000001",
        "runId": "run-oversize",
        "state": {},
        "messages": [oversized_message],
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }

    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    client = TestClient(app)
    try:
        # NOTE(review): `agent_router._allow_run_request` is not stubbed here;
        # confirm the body is rejected before that check is reached.
        rejected = client.post("/api/v1/agent/runs", json=run_payload)
        assert rejected.status_code == 422
    finally:
        # Reset the overrides even when the assertion fails.
        app.dependency_overrides = {}
@@ -2,7 +2,7 @@ from __future__ import annotations
import os
from datetime import datetime, timedelta, timezone
from uuid import UUID
from uuid import UUID, uuid4
import httpx
import jwt
@@ -56,15 +56,25 @@ async def test_agent_sse_closed_loop_live() -> None:
run_resp = await client.post(
f"{BASE_URL}/api/v1/agent/runs",
headers=headers,
json={"prompt": "请用一句话介绍你自己"},
json={
"threadId": str(uuid4()),
"runId": "run-live-1",
"state": {},
"messages": [
{"id": "u1", "role": "user", "content": "请用一句话介绍你自己"}
],
"tools": [],
"context": [],
"forwardedProps": {},
},
)
assert run_resp.status_code == 202
accepted = run_resp.json()
session_id = str(accepted["session_id"])
assert session_id
thread_id = str(accepted["threadId"])
assert thread_id
events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
event_names: list[str] = []
async with client.stream("GET", events_url, headers=headers, timeout=20.0) as sse_resp:
assert sse_resp.status_code == 200
@@ -77,13 +87,13 @@ async def test_agent_sse_closed_loop_live() -> None:
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
async with AsyncSessionLocal() as session:
session_row = await session.get(AgentChatSession, UUID(session_id))
session_row = await session.get(AgentChatSession, UUID(thread_id))
assert session_row is not None
assert session_row.message_count >= 1
assert session_row.total_tokens >= 0
assert session_row.total_cost >= 0
rows = await session.execute(
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(session_id))
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(thread_id))
)
assert len(list(rows.scalars().all())) >= 1