feat: AG-UI 协议对齐与路由导航功能
- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具 - 前端: 实现工具调用审批流程,支持 pending 状态展示 - 后端: Agent 状态管理与会话持久化相关重构 - 文档: 新增 agent-agui-full-alignance 设计文档 - 测试: 补充相关单元测试和集成测试
This commit is contained in:
+109
-29
@@ -1,9 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Protocol
|
||||
|
||||
from ag_ui.core import StateSnapshotEvent
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
|
||||
@@ -11,19 +15,28 @@ from core.auth.models import CurrentUser
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
session_id: str
|
||||
thread_id: str
|
||||
run_id: str
|
||||
created: bool
|
||||
|
||||
|
||||
class AgentRepositoryLike(Protocol):
|
||||
async def get_session_owner(self, *, session_id: str) -> str: ...
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str: ...
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
@@ -60,73 +73,140 @@ class AgentService:
|
||||
async def enqueue_run(
|
||||
self,
|
||||
*,
|
||||
session_id: str | None,
|
||||
prompt: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
target_session_id = session_id
|
||||
if target_session_id is None:
|
||||
target_session_id = await self._repository.create_session_for_user(
|
||||
user_id=str(current_user.id)
|
||||
)
|
||||
created = True
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 404:
|
||||
raise
|
||||
try:
|
||||
await self._repository.create_session_for_user(
|
||||
user_id=str(current_user.id),
|
||||
session_id=thread_id,
|
||||
)
|
||||
await self._repository.commit()
|
||||
created = True
|
||||
except IntegrityError:
|
||||
await self._repository.rollback()
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
else:
|
||||
owner = await self._repository.get_session_owner(
|
||||
session_id=target_session_id
|
||||
)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
if created:
|
||||
await self._repository.commit()
|
||||
|
||||
try:
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"session_id": target_session_id,
|
||||
"user_input": prompt,
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
raise
|
||||
return TaskAccepted(
|
||||
task_id=task_id, session_id=target_session_id, created=created
|
||||
task_id=task_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> TaskAccepted:
|
||||
owner = await self._repository.get_session_owner(session_id=session_id)
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
dedup_key = f"resume:{session_id}:{tool_call_id}"
|
||||
dedup_key = f"resume:{thread_id}:{run_input.run_id}"
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": session_id,
|
||||
"tool_call_id": tool_call_id,
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
return TaskAccepted(task_id=task_id, session_id=session_id, created=False)
|
||||
return TaskAccepted(
|
||||
task_id=task_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
owner = await self._repository.get_session_owner(session_id=session_id)
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
return await self._stream.read(
|
||||
session_id=session_id,
|
||||
session_id=thread_id,
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
|
||||
async def get_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
before: date | None,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, object]:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
)
|
||||
snapshot = {
|
||||
"scope": "history_day",
|
||||
"threadId": thread_id,
|
||||
"day": day_payload["day"] if day_payload else None,
|
||||
"hasMore": day_payload["hasMore"] if day_payload else False,
|
||||
"messages": day_payload["messages"] if day_payload else [],
|
||||
}
|
||||
event = StateSnapshotEvent(snapshot=snapshot).model_dump(
|
||||
mode="json",
|
||||
by_alias=True,
|
||||
exclude_none=True,
|
||||
)
|
||||
event["threadId"] = thread_id
|
||||
return event
|
||||
|
||||
async def get_user_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
current_user: CurrentUser,
|
||||
thread_id: str | None,
|
||||
before: date | None,
|
||||
) -> dict[str, object]:
|
||||
target_thread_id = thread_id
|
||||
if target_thread_id is None:
|
||||
target_thread_id = await self._repository.get_latest_session_id_for_user(
|
||||
user_id=str(current_user.id)
|
||||
)
|
||||
if target_thread_id is None:
|
||||
return StateSnapshotEvent(
|
||||
snapshot={
|
||||
"scope": "history_day",
|
||||
"threadId": None,
|
||||
"day": None,
|
||||
"hasMore": False,
|
||||
"messages": [],
|
||||
}
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
return await self.get_history_snapshot(
|
||||
thread_id=target_thread_id,
|
||||
before=before,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user