feat: AG-UI 协议对齐与路由导航功能

- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具
- 前端: 实现工具调用审批流程,支持 pending 状态展示
- 后端: Agent 状态管理与会话持久化相关重构
- 文档: 新增 agent-agui-full-alignance 设计文档
- 测试: 补充相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-07 17:30:20 +08:00
parent ec33bb0cee
commit 120df903d2
52 changed files with 4305 additions and 1672 deletions
@@ -1,10 +1,22 @@
from __future__ import annotations
import json
from uuid import UUID
from ag_ui.core import (
RunAgentInput,
TextMessageContentEvent,
TextMessageEndEvent,
TextMessageStartEvent,
ToolCallResultEvent,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from core.agent.application.session_state_persistence import SessionStatePersistence
from core.agent.application.session_state_persistence import (
SessionStatePersistence,
compute_tool_args_sha256,
)
from core.agent.domain.agui_input import extract_latest_tool_result
from core.agent.domain.message_metadata import (
MessageMetadataAssistantOutput,
MessageMetadataToolResult,
@@ -25,8 +37,13 @@ class ResumeService:
self._session_factory = session_factory
self._state_persistence = SessionStatePersistence()
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
session_uuid = UUID(session_id)
async def resume(
self,
*,
run_input: RunAgentInput,
) -> dict[str, object]:
session_uuid = UUID(run_input.thread_id)
tool_call_id, tool_payload = extract_latest_tool_result(run_input)
async with self._session_factory() as db_session:
session_repository = SessionRepository(db_session)
@@ -41,6 +58,41 @@ class ResumeService:
pending_tool_call = state_snapshot.get("pending_tool_call_id")
if pending_tool_call != tool_call_id:
raise ValueError("pending tool call does not match")
pending_tool_name = state_snapshot.get("pending_tool_name")
pending_tool_args_sha256 = state_snapshot.get("pending_tool_args_sha256")
pending_tool_nonce = state_snapshot.get("pending_tool_nonce")
if (
not isinstance(pending_tool_name, str)
or not pending_tool_name
or not isinstance(pending_tool_args_sha256, str)
or not pending_tool_args_sha256
or not isinstance(pending_tool_nonce, str)
or not pending_tool_nonce
):
raise ValueError("pending tool guard is incomplete")
tool_name = tool_payload.get("toolName")
tool_args = tool_payload.get("toolArgs")
nonce = tool_payload.get("nonce")
if not isinstance(tool_name, str) or not tool_name:
raise ValueError("resume payload missing toolName")
if not isinstance(tool_args, dict):
raise ValueError("resume payload missing toolArgs")
if not isinstance(nonce, str) or not nonce:
raise ValueError("resume payload missing nonce")
if tool_name != pending_tool_name:
raise ValueError("resume toolName does not match pending tool")
if nonce != pending_tool_nonce:
raise ValueError("resume nonce does not match pending tool")
computed_args_sha256 = compute_tool_args_sha256(tool_args)
if computed_args_sha256 != pending_tool_args_sha256:
raise ValueError("resume toolArgs does not match pending tool")
sanitized_tool_payload = self._sanitize_tool_payload(
tool_name=tool_name,
tool_args=tool_args,
nonce=nonce,
tool_payload=tool_payload,
)
next_seq = await session_repository.next_message_seq(
session_id=session_uuid
@@ -49,9 +101,13 @@ class ResumeService:
session_id=session_uuid,
seq=next_seq,
role=AgentChatMessageRole.TOOL,
content='{"status":"ok"}',
content=json.dumps(
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
),
metadata=MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump(),
)
await message_repository.append_message(
@@ -71,4 +127,61 @@ class ResumeService:
)
await db_session.commit()
return {"session_id": session_id, "resumed": True, "state_snapshot": snapshot}
tool_message_id = f"msg-tool-{next_seq}"
assistant_message_id = f"msg-assistant-{next_seq + 1}"
events = [
ToolCallResultEvent(
message_id=tool_message_id,
tool_call_id=tool_call_id,
content=json.dumps(
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
),
).model_dump(mode="json", by_alias=True, exclude_none=True),
TextMessageStartEvent(
message_id=assistant_message_id,
role="assistant",
).model_dump(mode="json", by_alias=True, exclude_none=True),
TextMessageContentEvent(
message_id=assistant_message_id,
delta="Tool result received",
).model_dump(mode="json", by_alias=True, exclude_none=True),
TextMessageEndEvent(
message_id=assistant_message_id
).model_dump(mode="json", by_alias=True, exclude_none=True),
]
return {
"threadId": run_input.thread_id,
"runId": run_input.run_id,
"resumed": True,
"state_snapshot": snapshot,
"events": events,
}
@staticmethod
def _sanitize_tool_payload(
*,
tool_name: str,
tool_args: dict[str, object],
nonce: str,
tool_payload: dict[str, object],
) -> dict[str, object]:
if tool_name != "navigate_to_route":
raise ValueError("unsupported frontend tool in resume payload")
target = tool_args.get("target")
if not isinstance(target, str) or not target:
raise ValueError("resume toolArgs missing target")
raw_result = tool_payload.get("result")
if not isinstance(raw_result, dict) or raw_result.get("ok") is not True:
raise ValueError("frontend tool execution failed")
sanitized_result = {
"ok": True,
"target": target,
"replace": tool_args.get("replace") is True,
"applied": True,
}
return {
"toolName": tool_name,
"toolArgs": tool_args,
"nonce": nonce,
"result": sanitized_result,
}
+375 -27
View File
@@ -1,14 +1,34 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta, timezone
from decimal import Decimal
import json
import re
from uuid import UUID, uuid4
from ag_ui.core import (
TextMessageContentEvent,
TextMessageEndEvent,
TextMessageStartEvent,
ToolCallArgsEvent,
ToolCallEndEvent,
ToolCallResultEvent,
ToolCallStartEvent,
RunAgentInput,
)
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from core.agent.application.session_state_persistence import SessionStatePersistence
from core.agent.domain.agui_input import extract_latest_user_text
from core.agent.application.session_state_persistence import (
SessionStatePersistence,
compute_tool_args_sha256,
)
from core.agent.domain.message_metadata import (
MessageMetadataAssistantOutput,
MessageMetadataToolResult,
MessageMetadataToolCall,
MessageMetadataUserInput,
)
@@ -27,6 +47,7 @@ from core.agent.infrastructure.persistence.user_context_loader import (
from core.db import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.schedule_items import ScheduleItem, ScheduleItemSourceType, ScheduleItemStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
@@ -60,9 +81,14 @@ class RunService:
self._state_persistence = SessionStatePersistence()
self._user_context_cache = user_context_cache or create_user_context_cache()
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
session_uuid = UUID(session_id)
pending_tool_call_id = f"tool-{uuid4()}"
async def run(
self,
*,
run_input: RunAgentInput,
) -> dict[str, object]:
session_uuid = UUID(run_input.thread_id)
user_input = extract_latest_user_text(run_input)
assistant_message_id = f"msg-{uuid4()}"
async with self._session_factory() as db_session:
session_repository = SessionRepository(db_session)
@@ -87,8 +113,12 @@ class RunService:
user_context = await self._load_user_agent_context(
db_session, session_uuid, chat_session.user_id
)
system_prompt = build_global_system_prompt(user_context)
runtime_result = runtime.execute(
system_prompt = self._build_system_prompt_with_tools(
base_prompt=build_global_system_prompt(user_context),
run_input=run_input,
)
runtime_result = await asyncio.to_thread(
runtime.execute,
user_input=user_input,
system_prompt=system_prompt,
)
@@ -97,7 +127,10 @@ class RunService:
completion_tokens = _to_int(runtime_result.get("completion_tokens", 0))
total_tokens = _to_int(runtime_result.get("total_tokens", 0))
cost = _to_decimal(runtime_result.get("cost", 0))
agui_events = runtime_result.get("agui_events", [])
planned_tool = self._select_tool_plan(
user_input=user_input,
available_tools={tool.name for tool in run_input.tools},
)
next_seq = await session_repository.next_message_seq(
session_id=session_uuid
@@ -110,39 +143,354 @@ class RunService:
model_code=model_code,
metadata=MessageMetadataUserInput().model_dump(),
)
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 1,
role=AgentChatMessageRole.ASSISTANT,
content=assistant_text or "Tool call pending approval",
model_code=model_code,
metadata=MessageMetadataToolCall(
tool_call_id=pending_tool_call_id,
).model_dump(),
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
cost=cost,
)
pending_tool_call_id: str | None = None
events: list[dict[str, object]] = []
message_delta = 2
session_status = AgentChatSessionStatus.COMPLETED
snapshot = self._state_persistence.build_completed_snapshot()
if planned_tool is None:
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 1,
role=AgentChatMessageRole.ASSISTANT,
content=assistant_text or "已完成处理。",
model_code=model_code,
metadata=MessageMetadataAssistantOutput().model_dump(),
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
cost=cost,
)
events.extend(
self._build_text_message_events(
message_id=assistant_message_id,
text=assistant_text or "已完成处理。",
)
)
elif planned_tool["target"] == "backend":
tool_call_id = f"tool-{uuid4()}"
tool_name = str(planned_tool["name"])
tool_args = planned_tool["args"]
if not isinstance(tool_args, dict):
tool_args = {}
tool_payload = await self._execute_backend_tool(
session=db_session,
owner_id=chat_session.user_id,
tool_name=tool_name,
tool_args=tool_args,
)
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 1,
role=AgentChatMessageRole.TOOL,
content=json.dumps(
tool_payload,
ensure_ascii=True,
separators=(",", ":"),
),
metadata=MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump(),
)
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 2,
role=AgentChatMessageRole.ASSISTANT,
content=assistant_text or "后端工具执行完成。",
model_code=model_code,
metadata=MessageMetadataAssistantOutput().model_dump(),
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
cost=cost,
)
message_delta = 3
events.extend(
self._build_tool_call_events(
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_args=tool_args,
)
)
events.append(
ToolCallResultEvent(
message_id=f"msg-tool-{uuid4()}",
tool_call_id=tool_call_id,
content=json.dumps(
tool_payload,
ensure_ascii=True,
separators=(",", ":"),
),
).model_dump(mode="json", by_alias=True, exclude_none=True)
)
events.extend(
self._build_text_message_events(
message_id=assistant_message_id,
text=assistant_text or "后端工具执行完成。",
)
)
else:
pending_tool_call_id = f"tool-{uuid4()}"
tool_name = str(planned_tool["name"])
tool_args = planned_tool["args"]
if not isinstance(tool_args, dict):
tool_args = {}
pending_tool_nonce = uuid4().hex
guarded_tool_args = {
**tool_args,
"__nonce": pending_tool_nonce,
}
pending_tool_args_sha256 = compute_tool_args_sha256(guarded_tool_args)
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 1,
role=AgentChatMessageRole.ASSISTANT,
content=assistant_text or "Tool call pending approval",
model_code=model_code,
metadata=MessageMetadataToolCall(
tool_call_id=pending_tool_call_id,
).model_dump(),
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
cost=cost,
)
snapshot = self._state_persistence.build_running_snapshot(
pending_tool_call_id=pending_tool_call_id,
pending_tool_name=tool_name,
pending_tool_args_sha256=pending_tool_args_sha256,
pending_tool_nonce=pending_tool_nonce,
)
session_status = AgentChatSessionStatus.RUNNING
events.extend(
self._build_tool_call_events(
tool_call_id=pending_tool_call_id,
tool_name=tool_name,
tool_args=guarded_tool_args,
)
)
events.extend(
self._build_text_message_events(
message_id=assistant_message_id,
text=assistant_text or "请确认是否执行前端工具。",
)
)
snapshot = self._state_persistence.build_running_snapshot(
pending_tool_call_id=pending_tool_call_id
)
await session_repository.update_runtime_state(
chat_session=chat_session,
status=AgentChatSessionStatus.RUNNING,
status=session_status,
state_snapshot=snapshot,
message_delta=2,
message_delta=message_delta,
token_delta=total_tokens,
cost_delta=cost,
)
await db_session.commit()
return {
"session_id": session_id,
"threadId": run_input.thread_id,
"runId": run_input.run_id,
"persisted": True,
"pending_tool_call_id": pending_tool_call_id,
"state_snapshot": snapshot,
"events": agui_events,
"events": events,
}
def _build_system_prompt_with_tools(
self, *, base_prompt: str, run_input: RunAgentInput
) -> str:
if not run_input.tools:
return base_prompt
tool_lines = [
f"- {tool.name}: {tool.description}" for tool in run_input.tools
]
tools_block = "\n".join(tool_lines)
return f"# AVAILABLE_TOOLS\n{tools_block}\n\n{base_prompt}"
def _select_tool_plan(
self,
*,
user_input: str,
available_tools: set[str],
) -> dict[str, object] | None:
forced = re.search(r"#tool:(\w+)\s*(\{.*\})?", user_input)
if forced is not None:
forced_name = forced.group(1)
if forced_name not in available_tools:
return None
raw_args = forced.group(2)
args: dict[str, object] = {}
if raw_args:
try:
parsed = json.loads(raw_args)
if isinstance(parsed, dict):
args = parsed
except (TypeError, ValueError):
args = {}
target = (
"frontend" if forced_name == "navigate_to_route" else "backend"
)
return {"name": forced_name, "args": args, "target": target}
normalized = user_input.lower()
wants_navigation = any(
keyword in normalized for keyword in ("打开", "跳转", "进入", "navigate", "open")
)
if wants_navigation and "navigate_to_route" in available_tools:
target_route = "/calendar/dayweek"
if "设置" in user_input:
target_route = "/settings"
elif "待办" in user_input:
target_route = "/todo"
return {
"name": "navigate_to_route",
"args": {"target": target_route, "replace": False},
"target": "frontend",
}
return None
def _infer_calendar_args(self, user_input: str) -> dict[str, object]:
start_at = datetime.now(timezone.utc) + timedelta(hours=1)
title = user_input.strip()[:80] or "新的日程"
return {
"title": title,
"description": user_input.strip(),
"startAt": start_at.isoformat(),
"timezone": "Asia/Shanghai",
}
def _build_text_message_events(
self, *, message_id: str, text: str
) -> list[dict[str, object]]:
events: list[dict[str, object]] = [
TextMessageStartEvent(
message_id=message_id,
role="assistant",
).model_dump(mode="json", by_alias=True, exclude_none=True),
]
if text:
events.append(
TextMessageContentEvent(
message_id=message_id,
delta=text,
).model_dump(mode="json", by_alias=True, exclude_none=True)
)
events.append(
TextMessageEndEvent(
message_id=message_id
).model_dump(mode="json", by_alias=True, exclude_none=True)
)
return events
def _build_tool_call_events(
self,
*,
tool_call_id: str,
tool_name: str,
tool_args: dict[str, object],
) -> list[dict[str, object]]:
return [
ToolCallStartEvent(
tool_call_id=tool_call_id,
tool_call_name=tool_name,
).model_dump(mode="json", by_alias=True, exclude_none=True),
ToolCallArgsEvent(
tool_call_id=tool_call_id,
delta=json.dumps(tool_args, ensure_ascii=True, separators=(",", ":")),
).model_dump(mode="json", by_alias=True, exclude_none=True),
ToolCallEndEvent(
tool_call_id=tool_call_id
).model_dump(mode="json", by_alias=True, exclude_none=True),
]
async def _execute_backend_tool(
self,
*,
session: AsyncSession,
owner_id: UUID,
tool_name: str,
tool_args: dict[str, object],
) -> dict[str, object]:
if tool_name != "create_calendar_event":
raise ValueError(f"unsupported backend tool: {tool_name}")
title = str(tool_args.get("title", "新的日程")).strip() or "新的日程"
description = str(tool_args.get("description", "")).strip() or None
start_raw = tool_args.get("startAt")
start_at = datetime.now(timezone.utc) + timedelta(hours=1)
if isinstance(start_raw, str) and start_raw:
try:
parsed = datetime.fromisoformat(start_raw.replace("Z", "+00:00"))
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
start_at = parsed.astimezone(timezone.utc)
except ValueError:
start_at = datetime.now(timezone.utc) + timedelta(hours=1)
end_raw = tool_args.get("endAt")
end_at: datetime | None = None
if isinstance(end_raw, str) and end_raw:
try:
parsed_end = datetime.fromisoformat(end_raw.replace("Z", "+00:00"))
if parsed_end.tzinfo is None:
parsed_end = parsed_end.replace(tzinfo=timezone.utc)
end_at = parsed_end.astimezone(timezone.utc)
except ValueError:
end_at = None
timezone_value = str(tool_args.get("timezone", "Asia/Shanghai"))
location = tool_args.get("location")
location_value = str(location) if isinstance(location, str) else None
schedule_item = ScheduleItem(
owner_id=owner_id,
title=title,
description=description,
start_at=start_at,
end_at=end_at,
timezone=timezone_value,
extra_metadata={"location": location_value} if location_value else {},
source_type=ScheduleItemSourceType.AGENT_GENERATED,
status=ScheduleItemStatus.ACTIVE,
created_by=owner_id,
)
session.add(schedule_item)
await session.flush()
event_id = str(schedule_item.id)
ui_card = {
"type": "calendar_card.v1",
"version": "v1",
"data": {
"id": event_id,
"title": title,
"description": description,
"startAt": start_at.isoformat(),
"endAt": end_at.isoformat() if end_at is not None else None,
"timezone": timezone_value,
"location": location_value,
"color": "#4F46E5",
"sourceType": "agent_generated",
},
"actions": [
{
"type": "link",
"label": "查看详情",
"target": f"/calendar/events/{event_id}",
}
],
}
return {
"result": {
"eventId": event_id,
"ok": True,
"message": "日程已创建",
"title": title,
"description": description,
"startAt": start_at.isoformat(),
"endAt": end_at.isoformat() if end_at is not None else None,
"timezone": timezone_value,
"location": location_value,
"sourceType": "agent_generated",
},
"ui": ui_card,
}
async def _load_user_agent_context(
@@ -10,17 +10,35 @@ from core.agent.domain.state_snapshot import AgentStateSnapshot
class SessionStatePersistence:
def build_running_snapshot(
self, *, pending_tool_call_id: str | None
self,
*,
pending_tool_call_id: str | None,
pending_tool_name: str | None = None,
pending_tool_args_sha256: str | None = None,
pending_tool_nonce: str | None = None,
) -> dict[str, object]:
return AgentStateSnapshot(
status="running",
pending_tool_call_id=pending_tool_call_id,
pending_tool_name=pending_tool_name,
pending_tool_args_sha256=pending_tool_args_sha256,
pending_tool_nonce=pending_tool_nonce,
).model_dump()
def build_completed_snapshot(self) -> dict[str, object]:
return AgentStateSnapshot(status="completed").model_dump()
def compute_tool_args_sha256(tool_args: dict[str, object]) -> str:
encoded = json.dumps(
tool_args,
ensure_ascii=True,
sort_keys=True,
separators=(",", ":"),
).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
class ToolResultStorage(Protocol):
async def upload_json(
self,
+109
View File
@@ -0,0 +1,109 @@
from __future__ import annotations
import json
from typing import Any
from uuid import UUID
from ag_ui.core import RunAgentInput
from pydantic import ValidationError
MAX_RUN_INPUT_BYTES = 256_000
MAX_RUN_ID_LENGTH = 128
MAX_MESSAGES = 200
MAX_TEXT_CHARS = 10_000
def _safe_len(value: str | None) -> int:
if value is None:
return 0
return len(value)
def _user_text_chars(run_input: RunAgentInput) -> int:
total = 0
for message in run_input.messages:
if getattr(message, "role", None) != "user":
continue
content = getattr(message, "content", None)
if isinstance(content, str):
total += len(content)
continue
if isinstance(content, list):
for item in content:
if getattr(item, "type", None) != "text":
continue
text = getattr(item, "text", None)
if isinstance(text, str):
total += len(text)
return total
def parse_run_input(payload: dict[str, Any]) -> RunAgentInput:
payload_bytes = len(
json.dumps(payload, ensure_ascii=True, separators=(",", ":")).encode("utf-8")
)
if payload_bytes > MAX_RUN_INPUT_BYTES:
raise ValueError("RunAgentInput payload exceeds size limit")
try:
run_input = RunAgentInput.model_validate(payload)
except ValidationError as exc:
raise ValueError("invalid AG-UI RunAgentInput payload") from exc
try:
UUID(run_input.thread_id)
except ValueError as exc:
raise ValueError("threadId must be a valid UUID") from exc
if _safe_len(run_input.run_id) > MAX_RUN_ID_LENGTH:
raise ValueError("runId exceeds length limit")
if len(run_input.messages) > MAX_MESSAGES:
raise ValueError("RunAgentInput.messages exceeds limit")
if _user_text_chars(run_input) > MAX_TEXT_CHARS:
raise ValueError("RunAgentInput user message text exceeds limit")
return run_input
def extract_latest_user_text(run_input: RunAgentInput) -> str:
for message in reversed(run_input.messages):
role = getattr(message, "role", None)
if role != "user":
continue
content = getattr(message, "content", None)
if isinstance(content, str):
text = content.strip()
if text:
return text
continue
if isinstance(content, list):
text_parts: list[str] = []
for item in content:
if getattr(item, "type", None) != "text":
continue
text = getattr(item, "text", None)
if isinstance(text, str):
text_parts.append(text)
combined = "".join(text_parts).strip()
if combined:
return combined
raise ValueError("RunAgentInput.messages requires at least one non-empty user message")
def extract_latest_tool_result(run_input: RunAgentInput) -> tuple[str, dict[str, object]]:
for message in reversed(run_input.messages):
role = getattr(message, "role", None)
if role != "tool":
continue
tool_call_id = getattr(message, "tool_call_id", None)
content = getattr(message, "content", None)
if not isinstance(tool_call_id, str) or not tool_call_id:
continue
if not isinstance(content, str):
break
try:
parsed = json.loads(content)
except (TypeError, ValueError):
return tool_call_id, {"content": content}
if isinstance(parsed, dict):
return tool_call_id, parsed
return tool_call_id, {"content": content}
raise ValueError(
"RunAgentInput.messages requires a tool message with toolCallId for resume"
)
@@ -8,3 +8,6 @@ from pydantic import BaseModel
class AgentStateSnapshot(BaseModel):
status: Literal["pending", "running", "completed", "failed"]
pending_tool_call_id: str | None = None
pending_tool_name: str | None = None
pending_tool_args_sha256: str | None = None
pending_tool_nonce: str | None = None
@@ -6,3 +6,4 @@ from pydantic import BaseModel, Field
class SystemAgentLLMConfig(BaseModel):
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
max_tokens: int | None = Field(default=None, ge=1)
timeout_seconds: float | None = Field(default=30.0, gt=0.0, le=300.0)
@@ -1,10 +1,16 @@
from __future__ import annotations
import json
import re
from typing import Any
_EVENT_TYPE_RE = re.compile(r"^[A-Z0-9_]+$")
def to_sse_event(stream_id: str, event: dict[str, Any]) -> str:
event_type = str(event.get("type", "MESSAGE"))
payload = json.dumps(event.get("data", {}), ensure_ascii=True)
raw_event_type = str(event.get("type", "MESSAGE")).replace("\r", "").replace(
"\n", ""
)
event_type = raw_event_type if _EVENT_TYPE_RE.fullmatch(raw_event_type) else "MESSAGE"
payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
return f"id: {stream_id}\nevent: {event_type}\ndata: {payload}\n\n"
@@ -129,6 +129,7 @@ def _run_stage(
messages=messages,
temperature=llm_config.temperature,
max_tokens=llm_config.max_tokens,
timeout=llm_config.timeout_seconds,
)
if not isinstance(response, dict):
raise ValueError("llm response must be a dict")
@@ -46,7 +46,7 @@ class RedisStreamEventStore:
last_event_id: str | None,
) -> list[dict[str, Any]]:
stream = self._stream_name(session_id)
start_id = "$" if last_event_id is None else last_event_id
start_id = "0-0" if last_event_id is None else last_event_id
raw_response = self._client.xread(
{stream: start_id},
count=self._read_count,
@@ -59,13 +59,37 @@ class RedisStreamEventStore:
if not response:
return []
_, entries = response[0]
first = response[0]
if (
not isinstance(first, tuple)
or len(first) != 2
or not isinstance(first[1], list)
):
return []
_, entries = first
result: list[dict[str, Any]] = []
for stream_id, payload in entries:
for entry in entries:
if (
not isinstance(entry, tuple)
or len(entry) != 2
or not isinstance(entry[0], str)
or not isinstance(entry[1], dict)
):
continue
stream_id, payload = entry
event_payload = payload.get("event")
if not isinstance(event_payload, str):
continue
try:
parsed_event = json.loads(event_payload)
except (TypeError, ValueError):
continue
if not isinstance(parsed_event, dict):
continue
result.append(
{
"id": stream_id,
"event": json.loads(payload["event"]),
"event": parsed_event,
}
)
return result
@@ -12,6 +12,7 @@ def run_completion(
messages: list[dict[str, Any]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
) -> Any:
kwargs: dict[str, Any] = {
"model": model,
@@ -23,6 +24,8 @@ def run_completion(
kwargs["temperature"] = temperature
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if timeout is not None:
kwargs["timeout"] = timeout
response = completion(**kwargs)
model_dump = getattr(response, "model_dump", None)
@@ -9,6 +9,9 @@ import redis.asyncio as redis
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
from core.config.settings import config
from core.logging import get_logger
logger = get_logger("core.agent.infrastructure.persistence.user_context_cache")
class RedisHashClient(Protocol):
@@ -47,7 +50,12 @@ class UserContextCache:
key = self._key(session_id)
try:
raw = await _maybe_await(self._client.hgetall(key))
except Exception:
except Exception as exc:
logger.warning(
"Failed to read user context cache",
session_id=str(session_id),
error=str(exc),
)
return None
if not isinstance(raw, dict) or not raw:
@@ -92,7 +100,12 @@ class UserContextCache:
)
)
await _maybe_await(self._client.expire(key, self._ttl_seconds))
except Exception:
except Exception as exc:
logger.warning(
"Failed to write user context cache",
session_id=str(session_id),
error=str(exc),
)
return None
def _key(self, session_id: UUID) -> str:
@@ -136,13 +149,21 @@ class UserContextCache:
async def _safe_delete(self, key: str) -> None:
try:
await _maybe_await(self._client.delete(key))
except Exception:
except Exception as exc:
logger.warning("Failed to delete user context cache key", key=key, error=str(exc))
return None
async def _safe_hincrby(self, key: str, field: str, amount: int) -> None:
try:
await _maybe_await(self._client.hincrby(key, field, amount))
except Exception:
except Exception as exc:
logger.warning(
"Failed to update user context cache usage",
key=key,
field=field,
amount=amount,
error=str(exc),
)
return None
@@ -2,7 +2,10 @@ from __future__ import annotations
from typing import Any, Protocol
from uuid import UUID
import re
from ag_ui.core import RunAgentInput, RunErrorEvent, RunFinishedEvent, RunStartedEvent
from core.agent.domain.agui_input import parse_run_input
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
@@ -13,19 +16,65 @@ from services.base.redis import get_or_init_redis_client
logger = get_logger("core.agent.infrastructure.queue.tasks")
_NON_ALNUM_RE = re.compile(r"[^A-Za-z0-9]+")
_SENSITIVE_KEYS = {
"apikey",
"authorization",
"token",
"accesstoken",
"refreshtoken",
"secret",
"password",
"cookie",
}
class PublishEvent(Protocol):
async def __call__(self, event_type: str, payload: dict[str, object]) -> None: ...
async def __call__(self, event: dict[str, object]) -> None: ...
class RunServiceLike(Protocol):
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]: ...
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]: ...
class ResumeServiceLike(Protocol):
async def resume(
self, *, session_id: str, tool_call_id: str
) -> dict[str, object]: ...
async def resume(self, *, run_input: RunAgentInput) -> dict[str, object]: ...
def _is_sensitive_key(key: str) -> bool:
normalized = _NON_ALNUM_RE.sub("", key.lower())
if normalized in _SENSITIVE_KEYS:
return True
if "token" in normalized:
return True
if "api" in normalized and "key" in normalized:
return True
return False
def _redact_sensitive(value: Any) -> Any:
if isinstance(value, dict):
return {
k: "***REDACTED***" if _is_sensitive_key(str(k)) else _redact_sensitive(v)
for k, v in value.items()
}
if isinstance(value, list):
return [_redact_sensitive(item) for item in value]
return value
def _normalize_stream_event(
*,
event: dict[str, object],
thread_id: str,
run_id: str,
) -> dict[str, object]:
normalized = dict(event)
normalized["threadId"] = thread_id
normalized["runId"] = run_id
if normalized.get("type") == "RUN_STARTED":
normalized.pop("input", None)
return _redact_sensitive(normalized)
async def _build_redis_publisher() -> PublishEvent:
@@ -37,13 +86,13 @@ async def _build_redis_publisher() -> PublishEvent:
block_ms=config.agent_runtime.redis_stream_block_ms,
)
async def _publish(event_type: str, payload: dict[str, object]) -> None:
session_id = str(payload.get("session_id", "")).strip()
if not session_id:
raise ValueError("session_id is required in event payload")
async def _publish(event: dict[str, object]) -> None:
thread_id = str(event.get("threadId", "")).strip()
if not thread_id:
raise ValueError("threadId is required in event payload")
await event_store.append_event(
session_id=UUID(session_id),
event={"type": event_type, "data": payload},
session_id=UUID(thread_id),
event=event,
)
return _publish
@@ -61,69 +110,69 @@ async def run_agent_task(
service_resume = resume_service or ResumeService()
command_type = str(command.get("command", "run"))
session_id = str(command.get("session_id", ""))
if command_type not in {"run", "resume"}:
raise ValueError("invalid command type")
if not session_id:
raise ValueError("session_id is required")
UUID(session_id)
raw_run_input = command.get("run_input")
if not isinstance(raw_run_input, dict):
raise ValueError("run_input is required")
run_input = parse_run_input(raw_run_input)
UUID(run_input.thread_id)
tool_call_id = ""
user_input = ""
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
await publisher(start_event, {"session_id": session_id})
await publisher(
RunStartedEvent(
thread_id=run_input.thread_id,
run_id=run_input.run_id,
parent_run_id=run_input.parent_run_id,
).model_dump(mode="json", by_alias=True, exclude_none=True)
)
try:
if command_type == "resume":
result = await service_resume.resume(
session_id=session_id,
tool_call_id=tool_call_id,
)
result = await service_resume.resume(run_input=run_input)
else:
result = await service_run.run(
session_id=session_id,
user_input=user_input,
)
result = await service_run.run(run_input=run_input)
await publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
extra_events = result.get("events") if isinstance(result, dict) else None
if isinstance(extra_events, list):
for event in extra_events:
if not isinstance(event, dict):
continue
event_type = event.get("type")
event_data = event.get("data")
if not isinstance(event_type, str) or not isinstance(event_data, dict):
if not isinstance(event_type, str):
continue
payload = {"session_id": session_id, **event_data}
await publisher(event_type, payload)
await publisher("RUN_FINISHED", {"session_id": session_id})
await publisher(
_normalize_stream_event(
event=event,
thread_id=run_input.thread_id,
run_id=run_input.run_id,
)
)
await publisher(
RunFinishedEvent(
thread_id=run_input.thread_id,
run_id=run_input.run_id,
).model_dump(mode="json", by_alias=True, exclude_none=True)
)
return result
except Exception: # noqa: BLE001
error_id = "agent_runtime_failed"
logger.exception(
"Agent task failed",
session_id=session_id,
thread_id=run_input.thread_id,
error_id=error_id,
)
try:
await publisher(
"RUN_ERROR", {"session_id": session_id, "error_id": error_id}
)
error_event = RunErrorEvent(
message="Agent task failed",
code=error_id,
).model_dump(mode="json", by_alias=True, exclude_none=True)
error_event["threadId"] = run_input.thread_id
error_event["runId"] = run_input.run_id
await publisher(error_event)
except Exception as publish_exc: # noqa: BLE001
logger.warning(
"Failed to publish RUN_ERROR event",
session_id=session_id,
thread_id=run_input.thread_id,
error=str(publish_exc),
)
raise
+124 -1
View File
@@ -1,11 +1,14 @@
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
import json
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
@@ -27,13 +30,22 @@ class AgentRepository:
raise HTTPException(status_code=404, detail="Session not found")
return str(owner_id)
async def create_session_for_user(self, *, user_id: str) -> str:
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
session_uuid = None
if session_id is not None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
session = AgentChatSession(
id=session_uuid,
user_id=user_uuid,
)
self._session.add(session)
@@ -56,3 +68,114 @@ class AgentRepository:
if session is not None:
await self._session.delete(session)
await self._session.flush()
async def get_history_day(
self, *, session_id: str, before: date | None
) -> dict[str, object] | None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
timestamp_stmt = (
select(AgentChatMessage.created_at)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.order_by(AgentChatMessage.created_at.desc())
)
rows = (await self._session.execute(timestamp_stmt)).scalars().all()
unique_days: list[date] = []
for created_at in rows:
if created_at is None:
continue
day = created_at.astimezone(timezone.utc).date()
if day not in unique_days:
unique_days.append(day)
if not unique_days:
return None
target_day: date | None = None
if before is None:
target_day = unique_days[0]
else:
for day in unique_days:
if day < before:
target_day = day
break
if target_day is None:
return None
start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
end = start + timedelta(days=1)
message_stmt = (
select(AgentChatMessage)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.where(AgentChatMessage.created_at >= start)
.where(AgentChatMessage.created_at < end)
.order_by(AgentChatMessage.seq.asc())
)
messages = (await self._session.execute(message_stmt)).scalars().all()
has_more = any(day < target_day for day in unique_days)
return {
"day": target_day.isoformat(),
"hasMore": has_more,
"messages": [self._to_snapshot_message(msg) for msg in messages],
}
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
stmt = (
select(AgentChatSession.id)
.where(AgentChatSession.user_id == user_uuid)
.where(AgentChatSession.deleted_at.is_(None))
.order_by(AgentChatSession.last_activity_at.desc())
.limit(1)
)
latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
if latest_id is None:
return None
return str(latest_id)
@staticmethod
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]:
role = (
message.role.value
if isinstance(message.role, AgentChatMessageRole)
else str(message.role)
)
payload: dict[str, object] = {
"id": str(message.id),
"role": role,
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
}
if role == AgentChatMessageRole.TOOL.value:
metadata = message.metadata_json or {}
tool_call_id = metadata.get("tool_call_id")
if isinstance(tool_call_id, str) and tool_call_id:
payload["toolCallId"] = tool_call_id
parsed_content: dict[str, object] | None = None
try:
decoded = json.loads(message.content)
if isinstance(decoded, dict):
parsed_content = decoded
except (TypeError, ValueError):
parsed_content = None
if parsed_content is not None:
ui = parsed_content.get("ui")
if isinstance(ui, dict):
payload["ui"] = ui
display_content = parsed_content.get("content")
if isinstance(display_content, str):
payload["content"] = display_content
else:
payload["content"] = message.content
else:
payload["content"] = message.content
return payload
+141 -32
View File
@@ -2,96 +2,177 @@ from __future__ import annotations
from collections.abc import AsyncIterator
import asyncio
from datetime import date
import re
import time
from typing import Annotated
from ag_ui.core import RunAgentInput
from fastapi import APIRouter, Depends, Header, Query, Request, status
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from core.agent.infrastructure.agui.stream import to_sse_event
from core.agent.domain.agui_input import parse_run_input
from core.auth.models import CurrentUser
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
from v1.agent.schemas import TaskAcceptedResponse
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_RUNS_PER_MINUTE = 30
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60
async def _allow_run_request(*, user_id: str) -> bool:
try:
redis = await get_or_init_redis_client()
minute_bucket = int(time.time() // 60)
key = f"agent:run-rate:{user_id}:{minute_bucket}"
count = await redis.incr(key)
if count == 1:
await redis.expire(key, 70)
return int(count) <= _RUNS_PER_MINUTE
except Exception: # noqa: BLE001
return False
async def _acquire_sse_slot(*, user_id: str) -> bool:
try:
redis = await get_or_init_redis_client()
key = f"agent:sse-active:{user_id}"
count = await redis.incr(key)
if count == 1:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
await redis.decr(key)
return False
return True
except Exception: # noqa: BLE001
return False
async def _release_sse_slot(*, user_id: str) -> None:
try:
redis = await get_or_init_redis_client()
key = f"agent:sse-active:{user_id}"
count = await redis.decr(key)
if int(count) <= 0:
await redis.delete(key)
except Exception: # noqa: BLE001
return None
@router.post(
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
async def enqueue_run(
request: RunRequest,
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
try:
parse_run_input(request.model_dump(mode="json", by_alias=True))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
task = await service.enqueue_run(
session_id=request.session_id,
prompt=request.prompt,
run_input=request,
current_user=current_user,
)
return TaskAcceptedResponse(
task_id=task.task_id,
session_id=task.session_id,
thread_id=task.thread_id,
run_id=task.run_id,
created=task.created,
)
@router.post(
"/runs/{session_id}/resume",
"/runs/{thread_id}/resume",
response_model=TaskAcceptedResponse,
status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
session_id: str,
request: ResumeRequest,
thread_id: str,
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
if request.thread_id != thread_id:
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
try:
parse_run_input(request.model_dump(mode="json", by_alias=True))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
task = await service.enqueue_resume(
session_id=session_id,
tool_call_id=request.tool_call_id,
thread_id=thread_id,
run_input=request,
current_user=current_user,
)
return TaskAcceptedResponse(
task_id=task.task_id,
session_id=task.session_id,
thread_id=task.thread_id,
run_id=task.run_id,
created=task.created,
)
@router.get("/runs/{session_id}/events")
@router.get("/runs/{thread_id}/events")
async def stream_events(
request: Request,
session_id: str,
thread_id: str,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
if (
last_event_id is not None
and (
len(last_event_id) > 32
or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
)
):
raise HTTPException(status_code=422, detail="Invalid Last-Event-ID")
sse_slot_acquired = await _acquire_sse_slot(user_id=str(current_user.id))
if not sse_slot_acquired:
raise HTTPException(status_code=429, detail="Too many SSE connections")
async def _event_iter() -> AsyncIterator[str]:
cursor = last_event_id
idle_polls = 0
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(
session_id=session_id,
last_event_id=cursor,
current_user=current_user,
)
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
idle_polls = 0
for row in rows:
row_id = str(row.get("id", ""))
event = row.get("event")
if not row_id or not isinstance(event, dict):
try:
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(
thread_id=thread_id,
last_event_id=cursor,
current_user=current_user,
)
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
cursor = row_id
yield to_sse_event(row_id, event)
idle_polls = 0
for row in rows:
row_id = str(row.get("id", ""))
event = row.get("event")
if not row_id or not isinstance(event, dict):
continue
cursor = row_id
yield to_sse_event(row_id, event)
finally:
await _release_sse_slot(user_id=str(current_user.id))
return StreamingResponse(
_event_iter(),
@@ -102,3 +183,31 @@ async def stream_events(
"X-Accel-Buffering": "no",
},
)
@router.get("/runs/{thread_id}/history")
async def get_history_snapshot(
thread_id: str,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
before: date | None = Query(default=None),
) -> dict[str, object]:
return await service.get_history_snapshot(
thread_id=thread_id,
before=before,
current_user=current_user,
)
@router.get("/history")
async def get_user_history_snapshot(
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
thread_id: str | None = Query(default=None, alias="threadId"),
before: date | None = Query(default=None),
) -> dict[str, object]:
return await service.get_user_history_snapshot(
current_user=current_user,
thread_id=thread_id,
before=before,
)
+6 -12
View File
@@ -1,18 +1,12 @@
from __future__ import annotations
from pydantic import BaseModel, Field
class RunRequest(BaseModel):
session_id: str | None = Field(default=None, min_length=1, max_length=100)
prompt: str = Field(min_length=1, max_length=5000)
class ResumeRequest(BaseModel):
tool_call_id: str = Field(min_length=1, max_length=200)
from pydantic import BaseModel, ConfigDict, Field
class TaskAcceptedResponse(BaseModel):
task_id: str
session_id: str
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
task_id: str = Field(alias="taskId")
thread_id: str = Field(alias="threadId")
run_id: str = Field(alias="runId")
created: bool
+109 -29
View File
@@ -1,9 +1,13 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Protocol
from ag_ui.core import StateSnapshotEvent
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
@@ -11,19 +15,28 @@ from core.auth.models import CurrentUser
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
session_id: str
thread_id: str
run_id: str
created: bool
class AgentRepositoryLike(Protocol):
async def get_session_owner(self, *, session_id: str) -> str: ...
async def create_session_for_user(self, *, user_id: str) -> str: ...
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
async def get_history_day(
self, *, session_id: str, before: date | None
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -60,73 +73,140 @@ class AgentService:
async def enqueue_run(
self,
*,
session_id: str | None,
prompt: str,
run_input: RunAgentInput,
current_user: CurrentUser,
) -> TaskAccepted:
created = False
target_session_id = session_id
if target_session_id is None:
target_session_id = await self._repository.create_session_for_user(
user_id=str(current_user.id)
)
created = True
thread_id = run_input.thread_id
run_id = run_input.run_id
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
except HTTPException as exc:
if exc.status_code != 404:
raise
try:
await self._repository.create_session_for_user(
user_id=str(current_user.id),
session_id=thread_id,
)
await self._repository.commit()
created = True
except IntegrityError:
await self._repository.rollback()
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
else:
owner = await self._repository.get_session_owner(
session_id=target_session_id
)
ensure_session_owner(owner_id=owner, current_user=current_user)
if created:
await self._repository.commit()
try:
task_id = await self._queue.enqueue(
command={
"command": "run",
"session_id": target_session_id,
"user_input": prompt,
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
)
except Exception: # noqa: BLE001
raise
return TaskAccepted(
task_id=task_id, session_id=target_session_id, created=created
task_id=task_id,
thread_id=thread_id,
run_id=run_id,
created=created,
)
async def enqueue_resume(
self,
*,
session_id: str,
tool_call_id: str,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=session_id)
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
dedup_key = f"resume:{session_id}:{tool_call_id}"
dedup_key = f"resume:{thread_id}:{run_input.run_id}"
task_id = await self._queue.enqueue(
command={
"command": "resume",
"session_id": session_id,
"tool_call_id": tool_call_id,
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
)
return TaskAccepted(task_id=task_id, session_id=session_id, created=False)
return TaskAccepted(
task_id=task_id,
thread_id=thread_id,
run_id=run_input.run_id,
created=False,
)
async def stream_events(
self,
*,
session_id: str,
thread_id: str,
last_event_id: str | None,
current_user: CurrentUser,
) -> list[dict[str, object]]:
owner = await self._repository.get_session_owner(session_id=session_id)
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
return await self._stream.read(
session_id=session_id,
session_id=thread_id,
last_event_id=last_event_id,
)
async def get_history_snapshot(
self,
*,
thread_id: str,
before: date | None,
current_user: CurrentUser,
) -> dict[str, object]:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
day_payload = await self._repository.get_history_day(
session_id=thread_id,
before=before,
)
snapshot = {
"scope": "history_day",
"threadId": thread_id,
"day": day_payload["day"] if day_payload else None,
"hasMore": day_payload["hasMore"] if day_payload else False,
"messages": day_payload["messages"] if day_payload else [],
}
event = StateSnapshotEvent(snapshot=snapshot).model_dump(
mode="json",
by_alias=True,
exclude_none=True,
)
event["threadId"] = thread_id
return event
async def get_user_history_snapshot(
self,
*,
current_user: CurrentUser,
thread_id: str | None,
before: date | None,
) -> dict[str, object]:
target_thread_id = thread_id
if target_thread_id is None:
target_thread_id = await self._repository.get_latest_session_id_for_user(
user_id=str(current_user.id)
)
if target_thread_id is None:
return StateSnapshotEvent(
snapshot={
"scope": "history_day",
"threadId": None,
"day": None,
"hasMore": False,
"messages": [],
}
).model_dump(mode="json", by_alias=True, exclude_none=True)
return await self.get_history_snapshot(
thread_id=target_thread_id,
before=before,
current_user=current_user,
)