feat: AG-UI 协议对齐与路由导航功能
- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具 - 前端: 实现工具调用审批流程,支持 pending 状态展示 - 后端: Agent 状态管理与会话持久化相关重构 - 文档: 新增 agent-agui-full-alignance 设计文档 - 测试: 补充相关单元测试和集成测试
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
@@ -11,14 +15,19 @@ class _FakeRepository:
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
self.deleted_session_id: str | None = None
|
||||
self.created_with_session_id: str | None = None
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
del session_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
if session_id == "00000000-0000-0000-0000-000000000001":
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str:
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000999"
|
||||
self.created_with_session_id = session_id
|
||||
return session_id or "00000000-0000-0000-0000-000000000999"
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
@@ -29,6 +38,22 @@ class _FakeRepository:
|
||||
async def delete_session(self, *, session_id: str) -> None:
|
||||
self.deleted_session_id = session_id
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None:
|
||||
del session_id
|
||||
if before is not None and before <= date(2026, 3, 6):
|
||||
return None
|
||||
return {
|
||||
"day": "2026-03-06",
|
||||
"hasMore": False,
|
||||
"messages": [{"id": "m1", "role": "assistant", "content": "hello"}],
|
||||
}
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
async def enqueue(
|
||||
@@ -63,6 +88,20 @@ def _user() -> CurrentUser:
|
||||
)
|
||||
|
||||
|
||||
def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
@@ -70,37 +109,46 @@ async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
user = _user()
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
first = await service.enqueue_resume(
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
second = await service.enqueue_resume(
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
assert first.task_id == second.task_id
|
||||
|
||||
|
||||
async def test_enqueue_run_without_session_creates_new_session() -> None:
|
||||
async def test_enqueue_run_creates_missing_thread_session() -> None:
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
session_id=None,
|
||||
prompt="hello",
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.thread_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.run_id == "run-1"
|
||||
assert accepted.created is True
|
||||
assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert repository.committed is True
|
||||
|
||||
|
||||
@@ -111,11 +159,14 @@ async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
queue=_FailingQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
try:
|
||||
await service.enqueue_run(
|
||||
session_id=None,
|
||||
prompt="hello",
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
raise AssertionError("expected RuntimeError")
|
||||
@@ -123,3 +174,78 @@ async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
assert str(exc) == "enqueue failed"
|
||||
|
||||
assert repository.deleted_session_id is None
|
||||
|
||||
|
||||
async def test_enqueue_run_handles_session_create_race() -> None:
|
||||
class _RaceRepository(_FakeRepository):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.create_calls = 0
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
if self.create_calls == 0:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id, session_id
|
||||
self.create_calls += 1
|
||||
raise IntegrityError("insert", {}, Exception("duplicate key"))
|
||||
|
||||
repository = _RaceRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.created is False
|
||||
assert repository.rolled_back is True
|
||||
|
||||
|
||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
|
||||
event = await service.get_history_snapshot(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
before=date(2026, 3, 7),
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
snapshot = event["snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["scope"] == "history_day"
|
||||
assert snapshot["day"] == "2026-03-06"
|
||||
assert snapshot["messages"][0]["id"] == "m1"
|
||||
|
||||
|
||||
async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
event = await service.get_user_history_snapshot(
|
||||
current_user=_user(),
|
||||
thread_id=None,
|
||||
before=None,
|
||||
)
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
Reference in New Issue
Block a user