feat: AG-UI 协议对齐与路由导航功能

- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具
- 前端: 实现工具调用审批流程,支持 pending 状态展示
- 后端: Agent 状态管理与会话持久化相关重构
- 文档: 新增 agent-agui-full-alignance 设计文档
- 测试: 补充相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-07 17:30:20 +08:00
parent ec33bb0cee
commit 120df903d2
52 changed files with 4305 additions and 1672 deletions
+124 -1
View File
@@ -1,11 +1,14 @@
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
import json
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
@@ -27,13 +30,22 @@ class AgentRepository:
raise HTTPException(status_code=404, detail="Session not found")
return str(owner_id)
async def create_session_for_user(self, *, user_id: str) -> str:
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
session_uuid = None
if session_id is not None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
session = AgentChatSession(
id=session_uuid,
user_id=user_uuid,
)
self._session.add(session)
@@ -56,3 +68,114 @@ class AgentRepository:
if session is not None:
await self._session.delete(session)
await self._session.flush()
async def get_history_day(
    self, *, session_id: str, before: date | None
) -> dict[str, object] | None:
    """Return one UTC calendar day of a session's non-deleted messages.

    The day is chosen from the distinct UTC dates on which the session has
    non-deleted messages: the most recent one when ``before`` is ``None``,
    otherwise the most recent one strictly earlier than ``before``.

    Args:
        session_id: Session identifier; must parse as a UUID.
        before: Exclusive upper bound for the selected day, or ``None`` to
            select the latest day.

    Returns:
        ``None`` when the session has no timestamped, non-deleted messages
        or no day earlier than ``before``. Otherwise a dict with ``"day"``
        (ISO date string), ``"hasMore"`` (whether an older day exists) and
        ``"messages"`` (snapshot dicts ordered by ascending ``seq``).

    Raises:
        HTTPException: 422 when ``session_id`` is not a valid UUID.
    """
    try:
        session_uuid = UUID(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid session_id") from exc

    timestamp_stmt = (
        select(AgentChatMessage.created_at)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .order_by(AgentChatMessage.created_at.desc())
    )
    rows = (await self._session.execute(timestamp_stmt)).scalars().all()

    # dict.fromkeys de-duplicates in linear time while keeping first-seen
    # (newest-first) order; a list membership test per row is quadratic.
    unique_days: list[date] = list(
        dict.fromkeys(
            created_at.astimezone(timezone.utc).date()
            for created_at in rows
            if created_at is not None
        )
    )
    if not unique_days:
        return None

    target_day: date | None
    if before is None:
        target_day = unique_days[0]
    else:
        # Days are newest-first, so the first match is the closest earlier day.
        target_day = next((day for day in unique_days if day < before), None)
    if target_day is None:
        return None

    # Half-open UTC window [start, end) covering exactly the selected day.
    start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
    end = start + timedelta(days=1)
    message_stmt = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .where(AgentChatMessage.created_at >= start)
        .where(AgentChatMessage.created_at < end)
        .order_by(AgentChatMessage.seq.asc())
    )
    messages = (await self._session.execute(message_stmt)).scalars().all()
    has_more = any(day < target_day for day in unique_days)
    return {
        "day": target_day.isoformat(),
        "hasMore": has_more,
        "messages": [self._to_snapshot_message(msg) for msg in messages],
    }
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
    """Return the id of the user's most recently active, non-deleted session.

    Args:
        user_id: User identifier; must parse as a UUID.

    Returns:
        The session id as a string, or ``None`` when the query yields no row.

    Raises:
        HTTPException: 422 when ``user_id`` is not a valid UUID.
    """
    try:
        owner_uuid = UUID(user_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid user_id") from exc

    # Newest activity first; only the top row is needed.
    newest_first = (
        select(AgentChatSession.id)
        .where(AgentChatSession.user_id == owner_uuid)
        .where(AgentChatSession.deleted_at.is_(None))
        .order_by(AgentChatSession.last_activity_at.desc())
        .limit(1)
    )
    result = await self._session.execute(newest_first)
    found_id = result.scalar_one_or_none()
    return None if found_id is None else str(found_id)
@staticmethod
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]:
    """Convert a stored chat message into its snapshot dict.

    Every payload carries ``id``, ``role`` and a UTC ISO ``timestamp``.
    Non-tool messages pass ``content`` through verbatim. Tool messages
    additionally expose ``toolCallId`` (from ``metadata_json``) and, when
    the stored content is a JSON object, its ``ui`` dict and string
    ``content`` field; if that object has no string ``content`` the key is
    omitted. Tool content that is not a JSON object is passed through raw.
    """
    stored_role = message.role
    if isinstance(stored_role, AgentChatMessageRole):
        role = stored_role.value
    else:
        role = str(stored_role)

    payload: dict[str, object] = {
        "id": str(message.id),
        "role": role,
        "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
    }

    # Everything except tool messages is emitted as-is.
    if role != AgentChatMessageRole.TOOL.value:
        payload["content"] = message.content
        return payload

    call_id = (message.metadata_json or {}).get("tool_call_id")
    if isinstance(call_id, str) and call_id:
        payload["toolCallId"] = call_id

    try:
        decoded = json.loads(message.content)
    except (TypeError, ValueError):
        decoded = None

    # Only a JSON object is treated as a structured tool result.
    if not isinstance(decoded, dict):
        payload["content"] = message.content
        return payload

    ui_block = decoded.get("ui")
    if isinstance(ui_block, dict):
        payload["ui"] = ui_block
    text = decoded.get("content")
    if isinstance(text, str):
        payload["content"] = text
    return payload
+141 -32
View File
@@ -2,96 +2,177 @@ from __future__ import annotations
from collections.abc import AsyncIterator
import asyncio
from datetime import date
import re
import time
from typing import Annotated
from ag_ui.core import RunAgentInput
from fastapi import APIRouter, Depends, Header, Query, Request, status
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from core.agent.infrastructure.agui.stream import to_sse_event
from core.agent.domain.agui_input import parse_run_input
from core.auth.models import CurrentUser
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
from v1.agent.schemas import TaskAcceptedResponse
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_RUNS_PER_MINUTE = 30
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60
async def _allow_run_request(*, user_id: str) -> bool:
    """Report whether the user is still within the per-minute run quota.

    Increments a Redis counter keyed by user and the current minute bucket
    and compares it with ``_RUNS_PER_MINUTE``. Any exception (including a
    Redis failure) results in ``False``, i.e. the request is refused.
    """
    try:
        client = await get_or_init_redis_client()
        bucket = int(time.time() // 60)
        counter_key = f"agent:run-rate:{user_id}:{bucket}"
        hits = await client.incr(counter_key)
        # The first hit in a bucket creates the key, so that is when it
        # gets its expiry (slightly longer than the one-minute bucket).
        if hits == 1:
            await client.expire(counter_key, 70)
        within_quota = int(hits) <= _RUNS_PER_MINUTE
        return within_quota
    except Exception:  # noqa: BLE001
        return False
async def _acquire_sse_slot(*, user_id: str) -> bool:
    """Try to reserve one SSE connection slot for the user.

    Increments the user's active-connection counter in Redis. If the new
    value exceeds ``_MAX_SSE_CONNECTIONS_PER_USER`` the increment is undone
    and ``False`` is returned. Any exception also yields ``False``.
    """
    try:
        client = await get_or_init_redis_client()
        slot_key = f"agent:sse-active:{user_id}"
        active = await client.incr(slot_key)
        # The expiry is attached only when this call created the counter.
        if active == 1:
            await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
        if int(active) <= _MAX_SSE_CONNECTIONS_PER_USER:
            return True
        # Over the limit: give the slot back before refusing.
        await client.decr(slot_key)
        return False
    except Exception:  # noqa: BLE001
        return False
async def _release_sse_slot(*, user_id: str) -> None:
    """Give back one SSE connection slot for the user (best effort).

    Decrements the user's active-connection counter and deletes the key
    once it reaches zero or below. Exceptions are swallowed so that
    releasing never raises.
    """
    try:
        client = await get_or_init_redis_client()
        slot_key = f"agent:sse-active:{user_id}"
        remaining = int(await client.decr(slot_key))
        if remaining < 1:
            await client.delete(slot_key)
    except Exception:  # noqa: BLE001
        return None
@router.post(
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
async def enqueue_run(
request: RunRequest,
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
try:
parse_run_input(request.model_dump(mode="json", by_alias=True))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
task = await service.enqueue_run(
session_id=request.session_id,
prompt=request.prompt,
run_input=request,
current_user=current_user,
)
return TaskAcceptedResponse(
task_id=task.task_id,
session_id=task.session_id,
thread_id=task.thread_id,
run_id=task.run_id,
created=task.created,
)
@router.post(
"/runs/{session_id}/resume",
"/runs/{thread_id}/resume",
response_model=TaskAcceptedResponse,
status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
session_id: str,
request: ResumeRequest,
thread_id: str,
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
if request.thread_id != thread_id:
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
try:
parse_run_input(request.model_dump(mode="json", by_alias=True))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
task = await service.enqueue_resume(
session_id=session_id,
tool_call_id=request.tool_call_id,
thread_id=thread_id,
run_input=request,
current_user=current_user,
)
return TaskAcceptedResponse(
task_id=task.task_id,
session_id=task.session_id,
thread_id=task.thread_id,
run_id=task.run_id,
created=task.created,
)
@router.get("/runs/{session_id}/events")
@router.get("/runs/{thread_id}/events")
async def stream_events(
request: Request,
session_id: str,
thread_id: str,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
if (
last_event_id is not None
and (
len(last_event_id) > 32
or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
)
):
raise HTTPException(status_code=422, detail="Invalid Last-Event-ID")
sse_slot_acquired = await _acquire_sse_slot(user_id=str(current_user.id))
if not sse_slot_acquired:
raise HTTPException(status_code=429, detail="Too many SSE connections")
async def _event_iter() -> AsyncIterator[str]:
cursor = last_event_id
idle_polls = 0
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(
session_id=session_id,
last_event_id=cursor,
current_user=current_user,
)
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
idle_polls = 0
for row in rows:
row_id = str(row.get("id", ""))
event = row.get("event")
if not row_id or not isinstance(event, dict):
try:
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(
thread_id=thread_id,
last_event_id=cursor,
current_user=current_user,
)
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
cursor = row_id
yield to_sse_event(row_id, event)
idle_polls = 0
for row in rows:
row_id = str(row.get("id", ""))
event = row.get("event")
if not row_id or not isinstance(event, dict):
continue
cursor = row_id
yield to_sse_event(row_id, event)
finally:
await _release_sse_slot(user_id=str(current_user.id))
return StreamingResponse(
_event_iter(),
@@ -102,3 +183,31 @@ async def stream_events(
"X-Accel-Buffering": "no",
},
)
@router.get("/runs/{thread_id}/history")
async def get_history_snapshot(
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    before: date | None = Query(default=None),
) -> dict[str, object]:
    # Return the history snapshot for one thread by delegating to
    # AgentService.get_history_snapshot. `before` is an optional date query
    # parameter forwarded unchanged.
    # (Kept as comments, not a docstring, so the OpenAPI description of the
    # endpoint is not altered.)
    snapshot_event = await service.get_history_snapshot(
        current_user=current_user,
        thread_id=thread_id,
        before=before,
    )
    return snapshot_event
@router.get("/history")
async def get_user_history_snapshot(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str | None = Query(default=None, alias="threadId"),
    before: date | None = Query(default=None),
) -> dict[str, object]:
    # Return a history snapshot for the current user by delegating to
    # AgentService.get_user_history_snapshot. The thread is optional and is
    # read from the `threadId` query parameter; `before` is an optional date.
    # (Kept as comments, not a docstring, so the OpenAPI description of the
    # endpoint is not altered.)
    snapshot_event = await service.get_user_history_snapshot(
        before=before,
        thread_id=thread_id,
        current_user=current_user,
    )
    return snapshot_event
+6 -12
View File
@@ -1,18 +1,12 @@
from __future__ import annotations
from pydantic import BaseModel, Field
class RunRequest(BaseModel):
session_id: str | None = Field(default=None, min_length=1, max_length=100)
prompt: str = Field(min_length=1, max_length=5000)
class ResumeRequest(BaseModel):
tool_call_id: str = Field(min_length=1, max_length=200)
from pydantic import BaseModel, ConfigDict, Field
class TaskAcceptedResponse(BaseModel):
task_id: str
session_id: str
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
task_id: str = Field(alias="taskId")
thread_id: str = Field(alias="threadId")
run_id: str = Field(alias="runId")
created: bool
+109 -29
View File
@@ -1,9 +1,13 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Protocol
from ag_ui.core import StateSnapshotEvent
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
@@ -11,19 +15,28 @@ from core.auth.models import CurrentUser
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
session_id: str
thread_id: str
run_id: str
created: bool
class AgentRepositoryLike(Protocol):
async def get_session_owner(self, *, session_id: str) -> str: ...
async def create_session_for_user(self, *, user_id: str) -> str: ...
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
async def get_history_day(
self, *, session_id: str, before: date | None
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -60,73 +73,140 @@ class AgentService:
async def enqueue_run(
self,
*,
session_id: str | None,
prompt: str,
run_input: RunAgentInput,
current_user: CurrentUser,
) -> TaskAccepted:
created = False
target_session_id = session_id
if target_session_id is None:
target_session_id = await self._repository.create_session_for_user(
user_id=str(current_user.id)
)
created = True
thread_id = run_input.thread_id
run_id = run_input.run_id
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
except HTTPException as exc:
if exc.status_code != 404:
raise
try:
await self._repository.create_session_for_user(
user_id=str(current_user.id),
session_id=thread_id,
)
await self._repository.commit()
created = True
except IntegrityError:
await self._repository.rollback()
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
else:
owner = await self._repository.get_session_owner(
session_id=target_session_id
)
ensure_session_owner(owner_id=owner, current_user=current_user)
if created:
await self._repository.commit()
try:
task_id = await self._queue.enqueue(
command={
"command": "run",
"session_id": target_session_id,
"user_input": prompt,
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
)
except Exception: # noqa: BLE001
raise
return TaskAccepted(
task_id=task_id, session_id=target_session_id, created=created
task_id=task_id,
thread_id=thread_id,
run_id=run_id,
created=created,
)
async def enqueue_resume(
self,
*,
session_id: str,
tool_call_id: str,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=session_id)
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
dedup_key = f"resume:{session_id}:{tool_call_id}"
dedup_key = f"resume:{thread_id}:{run_input.run_id}"
task_id = await self._queue.enqueue(
command={
"command": "resume",
"session_id": session_id,
"tool_call_id": tool_call_id,
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
)
return TaskAccepted(task_id=task_id, session_id=session_id, created=False)
return TaskAccepted(
task_id=task_id,
thread_id=thread_id,
run_id=run_input.run_id,
created=False,
)
async def stream_events(
self,
*,
session_id: str,
thread_id: str,
last_event_id: str | None,
current_user: CurrentUser,
) -> list[dict[str, object]]:
owner = await self._repository.get_session_owner(session_id=session_id)
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
return await self._stream.read(
session_id=session_id,
session_id=thread_id,
last_event_id=last_event_id,
)
async def get_history_snapshot(
    self,
    *,
    thread_id: str,
    before: date | None,
    current_user: CurrentUser,
) -> dict[str, object]:
    """Build a state-snapshot event holding one day of a thread's history.

    Ownership is checked first via ``ensure_session_owner``. The repository
    then supplies the day payload; when it returns nothing, the snapshot
    carries ``day=None``, ``hasMore=False`` and an empty message list.

    Returns:
        The dumped ``StateSnapshotEvent`` (JSON mode, aliases, ``None``
        fields excluded) with a top-level ``"threadId"`` key added.
    """
    session_owner = await self._repository.get_session_owner(session_id=thread_id)
    ensure_session_owner(owner_id=session_owner, current_user=current_user)

    history = await self._repository.get_history_day(
        session_id=thread_id,
        before=before,
    )
    if history:
        day = history["day"]
        has_more = history["hasMore"]
        messages = history["messages"]
    else:
        # No matching day: emit an empty snapshot rather than failing.
        day, has_more, messages = None, False, []

    snapshot_event = StateSnapshotEvent(
        snapshot={
            "scope": "history_day",
            "threadId": thread_id,
            "day": day,
            "hasMore": has_more,
            "messages": messages,
        }
    )
    dumped = snapshot_event.model_dump(
        mode="json",
        by_alias=True,
        exclude_none=True,
    )
    # The thread id is also exposed at the top level of the event.
    dumped["threadId"] = thread_id
    return dumped
async def get_user_history_snapshot(
    self,
    *,
    current_user: CurrentUser,
    thread_id: str | None,
    before: date | None,
) -> dict[str, object]:
    """Return a history snapshot for the given thread or the user's latest.

    When ``thread_id`` is ``None`` the repository is asked for the user's
    latest session id. If there is none, an empty ``history_day`` snapshot
    event is returned; otherwise the call is forwarded to
    ``get_history_snapshot`` for the resolved thread.
    """
    resolved_thread_id = thread_id
    if resolved_thread_id is None:
        resolved_thread_id = await self._repository.get_latest_session_id_for_user(
            user_id=str(current_user.id)
        )

    if resolved_thread_id is not None:
        return await self.get_history_snapshot(
            thread_id=resolved_thread_id,
            before=before,
            current_user=current_user,
        )

    # The user has no session yet: answer with an empty snapshot.
    empty_snapshot = {
        "scope": "history_day",
        "threadId": None,
        "day": None,
        "hasMore": False,
        "messages": [],
    }
    return StateSnapshotEvent(snapshot=empty_snapshot).model_dump(
        mode="json", by_alias=True, exclude_none=True
    )