feat: AG-UI 协议对齐与路由导航功能

- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具
- 前端: 实现工具调用审批流程,支持 pending 状态展示
- 后端: Agent 状态管理与会话持久化相关重构
- 文档: 新增 agent-agui-full-alignance 设计文档
- 测试: 补充相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-07 17:30:20 +08:00
parent ec33bb0cee
commit 120df903d2
52 changed files with 4305 additions and 1672 deletions
+124 -1
View File
@@ -1,11 +1,14 @@
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
import json
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
@@ -27,13 +30,22 @@ class AgentRepository:
raise HTTPException(status_code=404, detail="Session not found")
return str(owner_id)
async def create_session_for_user(self, *, user_id: str) -> str:
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
session_uuid = None
if session_id is not None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
session = AgentChatSession(
id=session_uuid,
user_id=user_uuid,
)
self._session.add(session)
@@ -56,3 +68,114 @@ class AgentRepository:
if session is not None:
await self._session.delete(session)
await self._session.flush()
async def get_history_day(
self, *, session_id: str, before: date | None
) -> dict[str, object] | None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
timestamp_stmt = (
select(AgentChatMessage.created_at)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.order_by(AgentChatMessage.created_at.desc())
)
rows = (await self._session.execute(timestamp_stmt)).scalars().all()
unique_days: list[date] = []
for created_at in rows:
if created_at is None:
continue
day = created_at.astimezone(timezone.utc).date()
if day not in unique_days:
unique_days.append(day)
if not unique_days:
return None
target_day: date | None = None
if before is None:
target_day = unique_days[0]
else:
for day in unique_days:
if day < before:
target_day = day
break
if target_day is None:
return None
start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
end = start + timedelta(days=1)
message_stmt = (
select(AgentChatMessage)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.where(AgentChatMessage.created_at >= start)
.where(AgentChatMessage.created_at < end)
.order_by(AgentChatMessage.seq.asc())
)
messages = (await self._session.execute(message_stmt)).scalars().all()
has_more = any(day < target_day for day in unique_days)
return {
"day": target_day.isoformat(),
"hasMore": has_more,
"messages": [self._to_snapshot_message(msg) for msg in messages],
}
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
stmt = (
select(AgentChatSession.id)
.where(AgentChatSession.user_id == user_uuid)
.where(AgentChatSession.deleted_at.is_(None))
.order_by(AgentChatSession.last_activity_at.desc())
.limit(1)
)
latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
if latest_id is None:
return None
return str(latest_id)
@staticmethod
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]:
role = (
message.role.value
if isinstance(message.role, AgentChatMessageRole)
else str(message.role)
)
payload: dict[str, object] = {
"id": str(message.id),
"role": role,
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
}
if role == AgentChatMessageRole.TOOL.value:
metadata = message.metadata_json or {}
tool_call_id = metadata.get("tool_call_id")
if isinstance(tool_call_id, str) and tool_call_id:
payload["toolCallId"] = tool_call_id
parsed_content: dict[str, object] | None = None
try:
decoded = json.loads(message.content)
if isinstance(decoded, dict):
parsed_content = decoded
except (TypeError, ValueError):
parsed_content = None
if parsed_content is not None:
ui = parsed_content.get("ui")
if isinstance(ui, dict):
payload["ui"] = ui
display_content = parsed_content.get("content")
if isinstance(display_content, str):
payload["content"] = display_content
else:
payload["content"] = message.content
else:
payload["content"] = message.content
return payload