feat: AG-UI 协议对齐与路由导航功能
- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具 - 前端: 实现工具调用审批流程,支持 pending 状态展示 - 后端: Agent 状态管理与会话持久化相关重构 - 文档: 新增 agent-agui-full-alignance 设计文档 - 测试: 补充相关单元测试和集成测试
This commit is contained in:
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent import router as agent_router
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
@@ -16,52 +18,122 @@ class _FakeAgentService:
|
||||
self._stream_called = False
|
||||
|
||||
async def enqueue_run(
|
||||
self, *, session_id: str | None, prompt: str, current_user: CurrentUser
|
||||
self, *, run_input: RunAgentInput, current_user: CurrentUser
|
||||
):
|
||||
del prompt, current_user
|
||||
resolved_session = session_id or "auto-created-session"
|
||||
del current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-run-1",
|
||||
session_id=resolved_session,
|
||||
created=session_id is None,
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
):
|
||||
del tool_call_id, current_user
|
||||
del thread_id, current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-resume-1", session_id=session_id, created=False
|
||||
task_id="task-resume-1",
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
del session_id, current_user
|
||||
del thread_id, current_user
|
||||
if self._stream_called:
|
||||
return []
|
||||
self._stream_called = True
|
||||
return [
|
||||
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
|
||||
{
|
||||
"id": "2-0",
|
||||
"event": {
|
||||
"type": "RUN_STARTED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
},
|
||||
"cursor": last_event_id,
|
||||
}
|
||||
]
|
||||
|
||||
async def get_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
before: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, object]:
|
||||
del current_user
|
||||
return {
|
||||
"type": "STATE_SNAPSHOT",
|
||||
"threadId": thread_id,
|
||||
"snapshot": {
|
||||
"scope": "history_day",
|
||||
"day": before or "2026-03-07",
|
||||
"hasMore": False,
|
||||
"messages": [
|
||||
{
|
||||
"id": "msg-h1",
|
||||
"role": "assistant",
|
||||
"content": "history-message",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
async def get_user_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
current_user: CurrentUser,
|
||||
thread_id: str | None,
|
||||
before: str | None,
|
||||
) -> dict[str, object]:
|
||||
del current_user, before
|
||||
return {
|
||||
"type": "STATE_SNAPSHOT",
|
||||
"threadId": thread_id or "00000000-0000-0000-0000-000000000001",
|
||||
"snapshot": {
|
||||
"scope": "history_day",
|
||||
"day": "2026-03-07",
|
||||
"hasMore": False,
|
||||
"messages": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
original_allow_run = agent_router._allow_run_request
|
||||
|
||||
async def _allow_run(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
agent_router._allow_run_request = _allow_run # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
unauthorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
@@ -70,20 +142,23 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
)
|
||||
authorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert authorized.status_code == 202
|
||||
assert authorized.json()["task_id"] == "task-run-1"
|
||||
assert authorized.json()["taskId"] == "task-run-1"
|
||||
assert authorized.json()["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert authorized.json()["runId"] == "run-1"
|
||||
assert authorized.json()["created"] is False
|
||||
|
||||
first_chat = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"prompt": "hello"},
|
||||
)
|
||||
assert first_chat.status_code == 202
|
||||
assert first_chat.json()["session_id"] == "auto-created-session"
|
||||
assert first_chat.json()["created"] is True
|
||||
finally:
|
||||
agent_router._allow_run_request = original_allow_run # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
@@ -93,15 +168,122 @@ def test_stream_reads_from_last_event_id() -> None:
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
original_acquire = agent_router._acquire_sse_slot
|
||||
original_release = agent_router._release_sse_slot
|
||||
|
||||
async def _allow_slot(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _noop_release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
return None
|
||||
|
||||
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/session-1/events?idle_limit=1",
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1",
|
||||
headers={"Last-Event-ID": "1-0"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert "id: 2-0" in response.text
|
||||
assert "event: RUN_STARTED" in response.text
|
||||
assert '"threadId":"00000000-0000-0000-0000-000000000001"' in response.text
|
||||
finally:
|
||||
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = original_release # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_rejects_invalid_last_event_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events",
|
||||
headers={"Last-Event-ID": "bad-id"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_history_returns_state_snapshot() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
unauthorized = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/history"
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
authorized = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/history",
|
||||
params={"before": "2026-03-07"},
|
||||
)
|
||||
assert authorized.status_code == 200
|
||||
payload = authorized.json()
|
||||
assert payload["type"] == "STATE_SNAPSHOT"
|
||||
assert payload["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert payload["snapshot"]["scope"] == "history_day"
|
||||
assert payload["snapshot"]["day"] == "2026-03-07"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_user_history_returns_latest_snapshot() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/agent/history")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["type"] == "STATE_SNAPSHOT"
|
||||
assert body["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_run_rejects_oversized_user_text_payload() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-oversize",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": "x" * 11000,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
@@ -56,15 +56,25 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
run_resp = await client.post(
|
||||
f"{BASE_URL}/api/v1/agent/runs",
|
||||
headers=headers,
|
||||
json={"prompt": "请用一句话介绍你自己"},
|
||||
json={
|
||||
"threadId": str(uuid4()),
|
||||
"runId": "run-live-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "请用一句话介绍你自己"}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
accepted = run_resp.json()
|
||||
session_id = str(accepted["session_id"])
|
||||
assert session_id
|
||||
thread_id = str(accepted["threadId"])
|
||||
assert thread_id
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
||||
event_names: list[str] = []
|
||||
async with client.stream("GET", events_url, headers=headers, timeout=20.0) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
@@ -77,13 +87,13 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
session_row = await session.get(AgentChatSession, UUID(session_id))
|
||||
session_row = await session.get(AgentChatSession, UUID(thread_id))
|
||||
assert session_row is not None
|
||||
assert session_row.message_count >= 1
|
||||
assert session_row.total_tokens >= 0
|
||||
assert session_row.total_cost >= 0
|
||||
|
||||
rows = await session.execute(
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(session_id))
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(thread_id))
|
||||
)
|
||||
assert len(list(rows.scalars().all())) >= 1
|
||||
|
||||
Reference in New Issue
Block a user