feat: AG-UI 协议对齐与路由导航功能
- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具 - 前端: 实现工具调用审批流程,支持 pending 状态展示 - 后端: Agent 状态管理与会话持久化相关重构 - 文档: 新增 agent-agui-full-alignance 设计文档 - 测试: 补充相关单元测试和集成测试
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
|
||||
|
||||
@@ -27,13 +30,22 @@ class AgentRepository:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return str(owner_id)
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str:
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
|
||||
session_uuid = None
|
||||
if session_id is not None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
|
||||
|
||||
session = AgentChatSession(
|
||||
id=session_uuid,
|
||||
user_id=user_uuid,
|
||||
)
|
||||
self._session.add(session)
|
||||
@@ -56,3 +68,114 @@ class AgentRepository:
|
||||
if session is not None:
|
||||
await self._session.delete(session)
|
||||
await self._session.flush()
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
|
||||
|
||||
timestamp_stmt = (
|
||||
select(AgentChatMessage.created_at)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.order_by(AgentChatMessage.created_at.desc())
|
||||
)
|
||||
rows = (await self._session.execute(timestamp_stmt)).scalars().all()
|
||||
unique_days: list[date] = []
|
||||
for created_at in rows:
|
||||
if created_at is None:
|
||||
continue
|
||||
day = created_at.astimezone(timezone.utc).date()
|
||||
if day not in unique_days:
|
||||
unique_days.append(day)
|
||||
|
||||
if not unique_days:
|
||||
return None
|
||||
|
||||
target_day: date | None = None
|
||||
if before is None:
|
||||
target_day = unique_days[0]
|
||||
else:
|
||||
for day in unique_days:
|
||||
if day < before:
|
||||
target_day = day
|
||||
break
|
||||
if target_day is None:
|
||||
return None
|
||||
|
||||
start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
|
||||
end = start + timedelta(days=1)
|
||||
message_stmt = (
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.created_at >= start)
|
||||
.where(AgentChatMessage.created_at < end)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = (await self._session.execute(message_stmt)).scalars().all()
|
||||
has_more = any(day < target_day for day in unique_days)
|
||||
return {
|
||||
"day": target_day.isoformat(),
|
||||
"hasMore": has_more,
|
||||
"messages": [self._to_snapshot_message(msg) for msg in messages],
|
||||
}
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == user_uuid)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if latest_id is None:
|
||||
return None
|
||||
return str(latest_id)
|
||||
|
||||
@staticmethod
|
||||
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]:
|
||||
role = (
|
||||
message.role.value
|
||||
if isinstance(message.role, AgentChatMessageRole)
|
||||
else str(message.role)
|
||||
)
|
||||
payload: dict[str, object] = {
|
||||
"id": str(message.id),
|
||||
"role": role,
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
if role == AgentChatMessageRole.TOOL.value:
|
||||
metadata = message.metadata_json or {}
|
||||
tool_call_id = metadata.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
payload["toolCallId"] = tool_call_id
|
||||
|
||||
parsed_content: dict[str, object] | None = None
|
||||
try:
|
||||
decoded = json.loads(message.content)
|
||||
if isinstance(decoded, dict):
|
||||
parsed_content = decoded
|
||||
except (TypeError, ValueError):
|
||||
parsed_content = None
|
||||
if parsed_content is not None:
|
||||
ui = parsed_content.get("ui")
|
||||
if isinstance(ui, dict):
|
||||
payload["ui"] = ui
|
||||
display_content = parsed_content.get("content")
|
||||
if isinstance(display_content, str):
|
||||
payload["content"] = display_content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
return payload
|
||||
|
||||
Reference in New Issue
Block a user