feat: AG-UI 协议对齐与路由导航功能

- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具
- 前端: 实现工具调用审批流程,支持 pending 状态展示
- 后端: Agent 状态管理与会话持久化相关重构
- 文档: 新增 agent-agui-full-alignance 设计文档
- 测试: 补充相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-07 17:30:20 +08:00
parent ec33bb0cee
commit 120df903d2
52 changed files with 4305 additions and 1672 deletions
@@ -1,10 +1,22 @@
from __future__ import annotations
import json
from uuid import UUID
from ag_ui.core import (
RunAgentInput,
TextMessageContentEvent,
TextMessageEndEvent,
TextMessageStartEvent,
ToolCallResultEvent,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from core.agent.application.session_state_persistence import SessionStatePersistence
from core.agent.application.session_state_persistence import (
SessionStatePersistence,
compute_tool_args_sha256,
)
from core.agent.domain.agui_input import extract_latest_tool_result
from core.agent.domain.message_metadata import (
MessageMetadataAssistantOutput,
MessageMetadataToolResult,
@@ -25,8 +37,13 @@ class ResumeService:
self._session_factory = session_factory
self._state_persistence = SessionStatePersistence()
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
session_uuid = UUID(session_id)
async def resume(
self,
*,
run_input: RunAgentInput,
) -> dict[str, object]:
session_uuid = UUID(run_input.thread_id)
tool_call_id, tool_payload = extract_latest_tool_result(run_input)
async with self._session_factory() as db_session:
session_repository = SessionRepository(db_session)
@@ -41,6 +58,41 @@ class ResumeService:
pending_tool_call = state_snapshot.get("pending_tool_call_id")
if pending_tool_call != tool_call_id:
raise ValueError("pending tool call does not match")
pending_tool_name = state_snapshot.get("pending_tool_name")
pending_tool_args_sha256 = state_snapshot.get("pending_tool_args_sha256")
pending_tool_nonce = state_snapshot.get("pending_tool_nonce")
if (
not isinstance(pending_tool_name, str)
or not pending_tool_name
or not isinstance(pending_tool_args_sha256, str)
or not pending_tool_args_sha256
or not isinstance(pending_tool_nonce, str)
or not pending_tool_nonce
):
raise ValueError("pending tool guard is incomplete")
tool_name = tool_payload.get("toolName")
tool_args = tool_payload.get("toolArgs")
nonce = tool_payload.get("nonce")
if not isinstance(tool_name, str) or not tool_name:
raise ValueError("resume payload missing toolName")
if not isinstance(tool_args, dict):
raise ValueError("resume payload missing toolArgs")
if not isinstance(nonce, str) or not nonce:
raise ValueError("resume payload missing nonce")
if tool_name != pending_tool_name:
raise ValueError("resume toolName does not match pending tool")
if nonce != pending_tool_nonce:
raise ValueError("resume nonce does not match pending tool")
computed_args_sha256 = compute_tool_args_sha256(tool_args)
if computed_args_sha256 != pending_tool_args_sha256:
raise ValueError("resume toolArgs does not match pending tool")
sanitized_tool_payload = self._sanitize_tool_payload(
tool_name=tool_name,
tool_args=tool_args,
nonce=nonce,
tool_payload=tool_payload,
)
next_seq = await session_repository.next_message_seq(
session_id=session_uuid
@@ -49,9 +101,13 @@ class ResumeService:
session_id=session_uuid,
seq=next_seq,
role=AgentChatMessageRole.TOOL,
content='{"status":"ok"}',
content=json.dumps(
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
),
metadata=MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump(),
)
await message_repository.append_message(
@@ -71,4 +127,61 @@ class ResumeService:
)
await db_session.commit()
return {"session_id": session_id, "resumed": True, "state_snapshot": snapshot}
tool_message_id = f"msg-tool-{next_seq}"
assistant_message_id = f"msg-assistant-{next_seq + 1}"
events = [
ToolCallResultEvent(
message_id=tool_message_id,
tool_call_id=tool_call_id,
content=json.dumps(
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
),
).model_dump(mode="json", by_alias=True, exclude_none=True),
TextMessageStartEvent(
message_id=assistant_message_id,
role="assistant",
).model_dump(mode="json", by_alias=True, exclude_none=True),
TextMessageContentEvent(
message_id=assistant_message_id,
delta="Tool result received",
).model_dump(mode="json", by_alias=True, exclude_none=True),
TextMessageEndEvent(
message_id=assistant_message_id
).model_dump(mode="json", by_alias=True, exclude_none=True),
]
return {
"threadId": run_input.thread_id,
"runId": run_input.run_id,
"resumed": True,
"state_snapshot": snapshot,
"events": events,
}
@staticmethod
def _sanitize_tool_payload(
    *,
    tool_name: str,
    tool_args: dict[str, object],
    nonce: str,
    tool_payload: dict[str, object],
) -> dict[str, object]:
    """Rebuild a frontend tool result from validated fields only.

    Only ``navigate_to_route`` is accepted. The ``result`` object reported
    by the client is used solely as a success flag; the result that gets
    returned is reconstructed from ``tool_args`` so no other client-supplied
    result fields are carried over.

    Raises:
        ValueError: if the tool is not ``navigate_to_route``, if
            ``tool_args["target"]`` is not a non-empty string, or if the
            client did not report ``result.ok is True``.
    """
    if tool_name != "navigate_to_route":
        raise ValueError("unsupported frontend tool in resume payload")

    route = tool_args.get("target")
    if not (isinstance(route, str) and route):
        raise ValueError("resume toolArgs missing target")

    reported = tool_payload.get("result")
    succeeded = isinstance(reported, dict) and reported.get("ok") is True
    if not succeeded:
        raise ValueError("frontend tool execution failed")

    return {
        "toolName": tool_name,
        "toolArgs": tool_args,
        "nonce": nonce,
        "result": {
            "ok": True,
            "target": route,
            # Anything other than a literal True collapses to False.
            "replace": tool_args.get("replace") is True,
            "applied": True,
        },
    }
+375 -27
View File
@@ -1,14 +1,34 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta, timezone
from decimal import Decimal
import json
import re
from uuid import UUID, uuid4
from ag_ui.core import (
TextMessageContentEvent,
TextMessageEndEvent,
TextMessageStartEvent,
ToolCallArgsEvent,
ToolCallEndEvent,
ToolCallResultEvent,
ToolCallStartEvent,
RunAgentInput,
)
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from core.agent.application.session_state_persistence import SessionStatePersistence
from core.agent.domain.agui_input import extract_latest_user_text
from core.agent.application.session_state_persistence import (
SessionStatePersistence,
compute_tool_args_sha256,
)
from core.agent.domain.message_metadata import (
MessageMetadataAssistantOutput,
MessageMetadataToolResult,
MessageMetadataToolCall,
MessageMetadataUserInput,
)
@@ -27,6 +47,7 @@ from core.agent.infrastructure.persistence.user_context_loader import (
from core.db import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.schedule_items import ScheduleItem, ScheduleItemSourceType, ScheduleItemStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
@@ -60,9 +81,14 @@ class RunService:
self._state_persistence = SessionStatePersistence()
self._user_context_cache = user_context_cache or create_user_context_cache()
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
session_uuid = UUID(session_id)
pending_tool_call_id = f"tool-{uuid4()}"
async def run(
self,
*,
run_input: RunAgentInput,
) -> dict[str, object]:
session_uuid = UUID(run_input.thread_id)
user_input = extract_latest_user_text(run_input)
assistant_message_id = f"msg-{uuid4()}"
async with self._session_factory() as db_session:
session_repository = SessionRepository(db_session)
@@ -87,8 +113,12 @@ class RunService:
user_context = await self._load_user_agent_context(
db_session, session_uuid, chat_session.user_id
)
system_prompt = build_global_system_prompt(user_context)
runtime_result = runtime.execute(
system_prompt = self._build_system_prompt_with_tools(
base_prompt=build_global_system_prompt(user_context),
run_input=run_input,
)
runtime_result = await asyncio.to_thread(
runtime.execute,
user_input=user_input,
system_prompt=system_prompt,
)
@@ -97,7 +127,10 @@ class RunService:
completion_tokens = _to_int(runtime_result.get("completion_tokens", 0))
total_tokens = _to_int(runtime_result.get("total_tokens", 0))
cost = _to_decimal(runtime_result.get("cost", 0))
agui_events = runtime_result.get("agui_events", [])
planned_tool = self._select_tool_plan(
user_input=user_input,
available_tools={tool.name for tool in run_input.tools},
)
next_seq = await session_repository.next_message_seq(
session_id=session_uuid
@@ -110,39 +143,354 @@ class RunService:
model_code=model_code,
metadata=MessageMetadataUserInput().model_dump(),
)
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 1,
role=AgentChatMessageRole.ASSISTANT,
content=assistant_text or "Tool call pending approval",
model_code=model_code,
metadata=MessageMetadataToolCall(
tool_call_id=pending_tool_call_id,
).model_dump(),
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
cost=cost,
)
pending_tool_call_id: str | None = None
events: list[dict[str, object]] = []
message_delta = 2
session_status = AgentChatSessionStatus.COMPLETED
snapshot = self._state_persistence.build_completed_snapshot()
if planned_tool is None:
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 1,
role=AgentChatMessageRole.ASSISTANT,
content=assistant_text or "已完成处理。",
model_code=model_code,
metadata=MessageMetadataAssistantOutput().model_dump(),
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
cost=cost,
)
events.extend(
self._build_text_message_events(
message_id=assistant_message_id,
text=assistant_text or "已完成处理。",
)
)
elif planned_tool["target"] == "backend":
tool_call_id = f"tool-{uuid4()}"
tool_name = str(planned_tool["name"])
tool_args = planned_tool["args"]
if not isinstance(tool_args, dict):
tool_args = {}
tool_payload = await self._execute_backend_tool(
session=db_session,
owner_id=chat_session.user_id,
tool_name=tool_name,
tool_args=tool_args,
)
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 1,
role=AgentChatMessageRole.TOOL,
content=json.dumps(
tool_payload,
ensure_ascii=True,
separators=(",", ":"),
),
metadata=MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump(),
)
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 2,
role=AgentChatMessageRole.ASSISTANT,
content=assistant_text or "后端工具执行完成。",
model_code=model_code,
metadata=MessageMetadataAssistantOutput().model_dump(),
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
cost=cost,
)
message_delta = 3
events.extend(
self._build_tool_call_events(
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_args=tool_args,
)
)
events.append(
ToolCallResultEvent(
message_id=f"msg-tool-{uuid4()}",
tool_call_id=tool_call_id,
content=json.dumps(
tool_payload,
ensure_ascii=True,
separators=(",", ":"),
),
).model_dump(mode="json", by_alias=True, exclude_none=True)
)
events.extend(
self._build_text_message_events(
message_id=assistant_message_id,
text=assistant_text or "后端工具执行完成。",
)
)
else:
pending_tool_call_id = f"tool-{uuid4()}"
tool_name = str(planned_tool["name"])
tool_args = planned_tool["args"]
if not isinstance(tool_args, dict):
tool_args = {}
pending_tool_nonce = uuid4().hex
guarded_tool_args = {
**tool_args,
"__nonce": pending_tool_nonce,
}
pending_tool_args_sha256 = compute_tool_args_sha256(guarded_tool_args)
await message_repository.append_message(
session_id=session_uuid,
seq=next_seq + 1,
role=AgentChatMessageRole.ASSISTANT,
content=assistant_text or "Tool call pending approval",
model_code=model_code,
metadata=MessageMetadataToolCall(
tool_call_id=pending_tool_call_id,
).model_dump(),
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
cost=cost,
)
snapshot = self._state_persistence.build_running_snapshot(
pending_tool_call_id=pending_tool_call_id,
pending_tool_name=tool_name,
pending_tool_args_sha256=pending_tool_args_sha256,
pending_tool_nonce=pending_tool_nonce,
)
session_status = AgentChatSessionStatus.RUNNING
events.extend(
self._build_tool_call_events(
tool_call_id=pending_tool_call_id,
tool_name=tool_name,
tool_args=guarded_tool_args,
)
)
events.extend(
self._build_text_message_events(
message_id=assistant_message_id,
text=assistant_text or "请确认是否执行前端工具。",
)
)
snapshot = self._state_persistence.build_running_snapshot(
pending_tool_call_id=pending_tool_call_id
)
await session_repository.update_runtime_state(
chat_session=chat_session,
status=AgentChatSessionStatus.RUNNING,
status=session_status,
state_snapshot=snapshot,
message_delta=2,
message_delta=message_delta,
token_delta=total_tokens,
cost_delta=cost,
)
await db_session.commit()
return {
"session_id": session_id,
"threadId": run_input.thread_id,
"runId": run_input.run_id,
"persisted": True,
"pending_tool_call_id": pending_tool_call_id,
"state_snapshot": snapshot,
"events": agui_events,
"events": events,
}
def _build_system_prompt_with_tools(
    self, *, base_prompt: str, run_input: RunAgentInput
) -> str:
    """Prefix the base system prompt with a list of the client's tools.

    Returns ``base_prompt`` unchanged when ``run_input.tools`` is empty;
    otherwise prepends an ``# AVAILABLE_TOOLS`` section with one
    ``- name: description`` line per tool.
    """
    declared_tools = run_input.tools
    if declared_tools:
        listing = "\n".join(
            f"- {entry.name}: {entry.description}" for entry in declared_tools
        )
        return f"# AVAILABLE_TOOLS\n{listing}\n\n{base_prompt}"
    return base_prompt
def _select_tool_plan(
    self,
    *,
    user_input: str,
    available_tools: set[str],
) -> dict[str, object] | None:
    """Choose at most one tool to invoke for this turn.

    Two strategies are tried in order:

    1. An explicit ``#tool:<name> {json}`` directive in ``user_input``.
       The JSON object is optional. If the named tool is not in
       ``available_tools`` no tool is planned at all (the keyword
       heuristic is not consulted).
    2. A keyword heuristic that plans ``navigate_to_route`` when the text
       contains a navigation keyword and that tool is available.

    Returns:
        ``{"name", "args", "target"}`` where ``target`` is ``"frontend"``
        for ``navigate_to_route`` and ``"backend"`` for any other forced
        tool, or ``None`` when no tool should run.
    """
    directive = re.search(r"#tool:(\w+)\s*", user_input)
    if directive is not None:
        forced_name = directive.group(1)
        if forced_name not in available_tools:
            return None

        args: dict[str, object] = {}
        args_start = directive.end()
        if user_input.startswith("{", args_start):
            # Decode exactly one JSON value starting at the brace. Unlike a
            # greedy single-line regex capture, this tolerates arguments
            # that span several lines and trailing text that happens to
            # contain "}" characters.
            try:
                parsed, _ = json.JSONDecoder().raw_decode(user_input, args_start)
            except ValueError:
                parsed = None
            if isinstance(parsed, dict):
                args = parsed

        target = "frontend" if forced_name == "navigate_to_route" else "backend"
        return {"name": forced_name, "args": args, "target": target}

    normalized = user_input.lower()
    wants_navigation = any(
        keyword in normalized
        for keyword in ("打开", "跳转", "进入", "navigate", "open")
    )
    if wants_navigation and "navigate_to_route" in available_tools:
        # Default destination unless the text names a more specific page.
        target_route = "/calendar/dayweek"
        if "设置" in user_input:
            target_route = "/settings"
        elif "待办" in user_input:
            target_route = "/todo"
        return {
            "name": "navigate_to_route",
            "args": {"target": target_route, "replace": False},
            "target": "frontend",
        }
    return None
def _infer_calendar_args(self, user_input: str) -> dict[str, object]:
    """Derive default calendar-event arguments from free-form user text.

    The title is the stripped text truncated to 80 characters (falling back
    to a placeholder when empty), the description is the full stripped text,
    and the start time is one hour from now in UTC.
    """
    one_hour_ahead = datetime.now(timezone.utc) + timedelta(hours=1)
    text = user_input.strip()
    headline = text[:80]
    if not headline:
        headline = "新的日程"
    inferred: dict[str, object] = {}
    inferred["title"] = headline
    inferred["description"] = text
    inferred["startAt"] = one_hour_ahead.isoformat()
    inferred["timezone"] = "Asia/Shanghai"
    return inferred
def _build_text_message_events(
    self, *, message_id: str, text: str
) -> list[dict[str, object]]:
    """Serialize one assistant text message as AG-UI start/content/end events.

    The content event is omitted when ``text`` is empty, so an empty message
    yields only the start and end events. Each event is dumped in JSON mode
    with aliases and without ``None`` fields.
    """
    dump_options = {"mode": "json", "by_alias": True, "exclude_none": True}

    sequence = [TextMessageStartEvent(message_id=message_id, role="assistant")]
    if text:
        sequence.append(TextMessageContentEvent(message_id=message_id, delta=text))
    sequence.append(TextMessageEndEvent(message_id=message_id))

    return [item.model_dump(**dump_options) for item in sequence]
def _build_tool_call_events(
    self,
    *,
    tool_call_id: str,
    tool_name: str,
    tool_args: dict[str, object],
) -> list[dict[str, object]]:
    """Serialize one tool invocation as AG-UI start/args/end events.

    ``tool_args`` is sent as a single compact, ASCII-only JSON string in the
    args event's ``delta``.
    """
    encoded_args = json.dumps(tool_args, ensure_ascii=True, separators=(",", ":"))
    lifecycle = (
        ToolCallStartEvent(tool_call_id=tool_call_id, tool_call_name=tool_name),
        ToolCallArgsEvent(tool_call_id=tool_call_id, delta=encoded_args),
        ToolCallEndEvent(tool_call_id=tool_call_id),
    )
    return [
        event.model_dump(mode="json", by_alias=True, exclude_none=True)
        for event in lifecycle
    ]
async def _execute_backend_tool(
    self,
    *,
    session: AsyncSession,
    owner_id: UUID,
    tool_name: str,
    tool_args: dict[str, object],
) -> dict[str, object]:
    """Run a server-side tool and return its result plus a UI card.

    Only ``create_calendar_event`` is supported. It adds a ``ScheduleItem``
    owned by ``owner_id`` to ``session`` and flushes so the generated id is
    available; committing is left to the caller.

    Args:
        session: Database session the new row is added to.
        owner_id: Used as both ``owner_id`` and ``created_by`` of the item.
        tool_name: Name of the backend tool to run.
        tool_args: Tool arguments; recognised keys are ``title``,
            ``description``, ``startAt``, ``endAt``, ``timezone`` and
            ``location``.

    Returns:
        ``{"result": {...}, "ui": {...}}`` describing the created event.

    Raises:
        ValueError: if ``tool_name`` is not ``create_calendar_event``.
    """
    if tool_name != "create_calendar_event":
        raise ValueError(f"unsupported backend tool: {tool_name}")

    def _text_arg(key: str, fallback: str) -> str:
        # A JSON null arrives as None; str(None) would store the literal
        # text "None", so treat None like a missing key.
        value = tool_args.get(key)
        return fallback if value is None else str(value)

    def _parse_utc(raw: object) -> datetime | None:
        # Accepts ISO-8601 text (including a trailing "Z"); naive values are
        # taken to be UTC. Returns None for anything missing or unparseable.
        if not isinstance(raw, str) or not raw:
            return None
        try:
            parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
            if parsed.tzinfo is None:
                parsed = parsed.replace(tzinfo=timezone.utc)
            return parsed.astimezone(timezone.utc)
        except ValueError:
            return None

    title = _text_arg("title", "新的日程").strip() or "新的日程"
    description = _text_arg("description", "").strip() or None

    start_at = _parse_utc(tool_args.get("startAt"))
    if start_at is None:
        # Missing or invalid start: default to one hour from now.
        start_at = datetime.now(timezone.utc) + timedelta(hours=1)
    # NOTE(review): end_at is not checked against start_at, so an end before
    # the start is stored as given — confirm whether that should be rejected.
    end_at = _parse_utc(tool_args.get("endAt"))

    # An empty or whitespace-only timezone falls back to the default too.
    timezone_value = _text_arg("timezone", "Asia/Shanghai").strip() or "Asia/Shanghai"
    location = tool_args.get("location")
    location_value = location if isinstance(location, str) else None

    schedule_item = ScheduleItem(
        owner_id=owner_id,
        title=title,
        description=description,
        start_at=start_at,
        end_at=end_at,
        timezone=timezone_value,
        extra_metadata={"location": location_value} if location_value else {},
        source_type=ScheduleItemSourceType.AGENT_GENERATED,
        status=ScheduleItemStatus.ACTIVE,
        created_by=owner_id,
    )
    session.add(schedule_item)
    # Flush so the id is populated before it is embedded in the payload.
    await session.flush()

    event_id = str(schedule_item.id)
    start_iso = start_at.isoformat()
    end_iso = end_at.isoformat() if end_at is not None else None

    ui_card = {
        "type": "calendar_card.v1",
        "version": "v1",
        "data": {
            "id": event_id,
            "title": title,
            "description": description,
            "startAt": start_iso,
            "endAt": end_iso,
            "timezone": timezone_value,
            "location": location_value,
            "color": "#4F46E5",
            "sourceType": "agent_generated",
        },
        "actions": [
            {
                "type": "link",
                "label": "查看详情",
                "target": f"/calendar/events/{event_id}",
            }
        ],
    }
    return {
        "result": {
            "eventId": event_id,
            "ok": True,
            "message": "日程已创建",
            "title": title,
            "description": description,
            "startAt": start_iso,
            "endAt": end_iso,
            "timezone": timezone_value,
            "location": location_value,
            "sourceType": "agent_generated",
        },
        "ui": ui_card,
    }
async def _load_user_agent_context(
@@ -10,17 +10,35 @@ from core.agent.domain.state_snapshot import AgentStateSnapshot
class SessionStatePersistence:
    """Builds the plain-dict state snapshots stored on an agent session."""

    def build_running_snapshot(
        self,
        *,
        pending_tool_call_id: str | None,
        pending_tool_name: str | None = None,
        pending_tool_args_sha256: str | None = None,
        pending_tool_nonce: str | None = None,
    ) -> dict[str, object]:
        """Return a ``running`` snapshot, optionally describing a pending tool call.

        The three guard fields default to ``None`` so callers that only
        know the tool call id keep working.

        Args:
            pending_tool_call_id: Id of the tool call awaiting a result.
            pending_tool_name: Name of that tool.
            pending_tool_args_sha256: Hash of the arguments sent with the call.
            pending_tool_nonce: Per-call nonce issued with the call.
        """
        return AgentStateSnapshot(
            status="running",
            pending_tool_call_id=pending_tool_call_id,
            pending_tool_name=pending_tool_name,
            pending_tool_args_sha256=pending_tool_args_sha256,
            pending_tool_nonce=pending_tool_nonce,
        ).model_dump()

    def build_completed_snapshot(self) -> dict[str, object]:
        """Return a ``completed`` snapshot with every pending field left at its default."""
        return AgentStateSnapshot(status="completed").model_dump()
def compute_tool_args_sha256(tool_args: dict[str, object]) -> str:
    """Return the hex SHA-256 of a canonical JSON encoding of ``tool_args``.

    The encoding sorts keys, uses compact separators and escapes non-ASCII
    characters, so dicts that differ only in key order hash identically.
    """
    canonical = json.dumps(
        tool_args,
        ensure_ascii=True,
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()
class ToolResultStorage(Protocol):
async def upload_json(
self,
+109
View File
@@ -0,0 +1,109 @@
from __future__ import annotations
import json
from typing import Any
from uuid import UUID
from ag_ui.core import RunAgentInput
from pydantic import ValidationError
MAX_RUN_INPUT_BYTES = 256_000
MAX_RUN_ID_LENGTH = 128
MAX_MESSAGES = 200
MAX_TEXT_CHARS = 10_000
def _safe_len(value: str | None) -> int:
    """Return ``len(value)``, counting ``None`` as zero-length."""
    return 0 if value is None else len(value)
def _user_text_chars(run_input: RunAgentInput) -> int:
    """Count the characters of text supplied by user-role messages.

    String content is counted in full. For list content only items whose
    ``type`` is ``"text"`` and whose ``text`` is a string are counted.
    Messages with any other role, or any other content type, add nothing.
    """

    def _content_chars(body: object) -> int:
        if isinstance(body, str):
            return len(body)
        if not isinstance(body, list):
            return 0
        count = 0
        for part in body:
            if getattr(part, "type", None) == "text":
                fragment = getattr(part, "text", None)
                if isinstance(fragment, str):
                    count += len(fragment)
        return count

    return sum(
        _content_chars(getattr(message, "content", None))
        for message in run_input.messages
        if getattr(message, "role", None) == "user"
    )
def parse_run_input(payload: dict[str, Any]) -> RunAgentInput:
    """Validate a raw payload and return it as a ``RunAgentInput``.

    Checks run in a fixed order and the first failure wins: serialized
    size, schema validation, ``threadId`` being a UUID, ``runId`` length,
    message count, and total user text length.

    Raises:
        ValueError: for any failed check; schema and UUID failures chain
            the underlying exception as the cause.
    """
    serialized = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
    if len(serialized.encode("utf-8")) > MAX_RUN_INPUT_BYTES:
        raise ValueError("RunAgentInput payload exceeds size limit")

    try:
        validated = RunAgentInput.model_validate(payload)
    except ValidationError as exc:
        raise ValueError("invalid AG-UI RunAgentInput payload") from exc

    try:
        UUID(validated.thread_id)
    except ValueError as exc:
        raise ValueError("threadId must be a valid UUID") from exc

    # Each predicate is evaluated lazily, in order, so a later check never
    # runs once an earlier one has failed.
    limit_checks = (
        (
            lambda: _safe_len(validated.run_id) > MAX_RUN_ID_LENGTH,
            "runId exceeds length limit",
        ),
        (
            lambda: len(validated.messages) > MAX_MESSAGES,
            "RunAgentInput.messages exceeds limit",
        ),
        (
            lambda: _user_text_chars(validated) > MAX_TEXT_CHARS,
            "RunAgentInput user message text exceeds limit",
        ),
    )
    for exceeded, message in limit_checks:
        if exceeded():
            raise ValueError(message)
    return validated
def extract_latest_user_text(run_input: RunAgentInput) -> str:
    """Return the stripped text of the newest non-empty user message.

    Messages are scanned newest-first. String content is used directly;
    list content is the concatenation of its ``"text"`` items. A user
    message whose text is empty after stripping is skipped in favour of an
    older one.

    Raises:
        ValueError: if no user message yields non-empty text.
    """
    for message in reversed(run_input.messages):
        if getattr(message, "role", None) != "user":
            continue

        body = getattr(message, "content", None)
        if isinstance(body, str):
            candidate = body
        elif isinstance(body, list):
            candidate = "".join(
                part_text
                for part_text in (
                    getattr(part, "text", None)
                    for part in body
                    if getattr(part, "type", None) == "text"
                )
                if isinstance(part_text, str)
            )
        else:
            continue

        candidate = candidate.strip()
        if candidate:
            return candidate

    raise ValueError("RunAgentInput.messages requires at least one non-empty user message")
def extract_latest_tool_result(run_input: RunAgentInput) -> tuple[str, dict[str, object]]:
    """Return ``(tool_call_id, payload)`` for the newest usable tool message.

    Messages are scanned newest-first. Tool messages without a non-empty
    string ``tool_call_id`` are skipped. Content that decodes to a JSON
    object is returned as that object; any other string content is wrapped
    as ``{"content": <raw string>}``.

    Raises:
        ValueError: if there is no tool message with a tool call id, or if
            the newest such message has non-string content (the scan stops
            there rather than falling back to an older message).
    """
    for message in reversed(run_input.messages):
        if getattr(message, "role", None) != "tool":
            continue

        call_id = getattr(message, "tool_call_id", None)
        body = getattr(message, "content", None)
        if not (isinstance(call_id, str) and call_id):
            continue
        if not isinstance(body, str):
            break

        try:
            decoded = json.loads(body)
        except (TypeError, ValueError):
            decoded = None
        return call_id, (decoded if isinstance(decoded, dict) else {"content": body})

    raise ValueError(
        "RunAgentInput.messages requires a tool message with toolCallId for resume"
    )
@@ -8,3 +8,6 @@ from pydantic import BaseModel
# Runtime state of an agent chat session; dumped to a plain dict via
# model_dump() before being stored as the session's state snapshot.
class AgentStateSnapshot(BaseModel):
# Lifecycle state; only these four literal values validate.
status: Literal["pending", "running", "completed", "failed"]
# Id of the tool call awaiting a result from the client; defaults to None.
pending_tool_call_id: str | None = None
# The three fields below describe the pending call. The resume path requires
# all of them to be non-empty strings and compares each against the values
# carried in the incoming tool result before accepting it.
pending_tool_name: str | None = None
# Hex SHA-256 of the canonical JSON of the arguments sent with the call.
pending_tool_args_sha256: str | None = None
# Per-call nonce generated when the call was issued.
pending_tool_nonce: str | None = None
@@ -6,3 +6,4 @@ from pydantic import BaseModel, Field
# Per-agent overrides for LLM completion calls.
class SystemAgentLLMConfig(BaseModel):
# Sampling temperature, validated to lie in [0.0, 2.0]; None leaves it unset.
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
# Upper bound on generated tokens, at least 1; None leaves it unset.
max_tokens: int | None = Field(default=None, ge=1)
# Request timeout in seconds, validated to lie in (0, 300]; defaults to 30.
# Passed as the `timeout` argument of the completion call by the stage runner.
timeout_seconds: float | None = Field(default=30.0, gt=0.0, le=300.0)
@@ -1,10 +1,16 @@
from __future__ import annotations
import json
import re
from typing import Any
# Event names allowed on the SSE "event:" line: upper-case letters, digits
# and underscores only.
_EVENT_TYPE_RE = re.compile(r"^[A-Z0-9_]+$")


def to_sse_event(stream_id: str, event: dict[str, Any]) -> str:
    """Format one stored event as a Server-Sent Events frame.

    The whole ``event`` dict is serialized as compact ASCII JSON on the
    ``data:`` line. The ``event:`` line carries ``event["type"]`` when it
    consists solely of ``A-Z``, ``0-9`` and ``_``; anything else (including
    a missing type) is reported as ``MESSAGE`` so a crafted type cannot
    inject extra SSE fields.

    Args:
        stream_id: Value written on the ``id:`` line.
        event: Event payload.

    Returns:
        The frame text, terminated by a blank line.
    """
    # Strip CR/LF first so a line break can never reach the frame header.
    raw_event_type = str(event.get("type", "MESSAGE")).replace("\r", "").replace(
        "\n", ""
    )
    event_type = raw_event_type if _EVENT_TYPE_RE.fullmatch(raw_event_type) else "MESSAGE"
    payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
    return f"id: {stream_id}\nevent: {event_type}\ndata: {payload}\n\n"
@@ -129,6 +129,7 @@ def _run_stage(
messages=messages,
temperature=llm_config.temperature,
max_tokens=llm_config.max_tokens,
timeout=llm_config.timeout_seconds,
)
if not isinstance(response, dict):
raise ValueError("llm response must be a dict")
@@ -46,7 +46,7 @@ class RedisStreamEventStore:
last_event_id: str | None,
) -> list[dict[str, Any]]:
stream = self._stream_name(session_id)
start_id = "$" if last_event_id is None else last_event_id
start_id = "0-0" if last_event_id is None else last_event_id
raw_response = self._client.xread(
{stream: start_id},
count=self._read_count,
@@ -59,13 +59,37 @@ class RedisStreamEventStore:
if not response:
return []
_, entries = response[0]
first = response[0]
if (
not isinstance(first, tuple)
or len(first) != 2
or not isinstance(first[1], list)
):
return []
_, entries = first
result: list[dict[str, Any]] = []
for stream_id, payload in entries:
for entry in entries:
if (
not isinstance(entry, tuple)
or len(entry) != 2
or not isinstance(entry[0], str)
or not isinstance(entry[1], dict)
):
continue
stream_id, payload = entry
event_payload = payload.get("event")
if not isinstance(event_payload, str):
continue
try:
parsed_event = json.loads(event_payload)
except (TypeError, ValueError):
continue
if not isinstance(parsed_event, dict):
continue
result.append(
{
"id": stream_id,
"event": json.loads(payload["event"]),
"event": parsed_event,
}
)
return result
@@ -12,6 +12,7 @@ def run_completion(
messages: list[dict[str, Any]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
) -> Any:
kwargs: dict[str, Any] = {
"model": model,
@@ -23,6 +24,8 @@ def run_completion(
kwargs["temperature"] = temperature
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if timeout is not None:
kwargs["timeout"] = timeout
response = completion(**kwargs)
model_dump = getattr(response, "model_dump", None)
@@ -9,6 +9,9 @@ import redis.asyncio as redis
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
from core.config.settings import config
from core.logging import get_logger
logger = get_logger("core.agent.infrastructure.persistence.user_context_cache")
class RedisHashClient(Protocol):
@@ -47,7 +50,12 @@ class UserContextCache:
key = self._key(session_id)
try:
raw = await _maybe_await(self._client.hgetall(key))
except Exception:
except Exception as exc:
logger.warning(
"Failed to read user context cache",
session_id=str(session_id),
error=str(exc),
)
return None
if not isinstance(raw, dict) or not raw:
@@ -92,7 +100,12 @@ class UserContextCache:
)
)
await _maybe_await(self._client.expire(key, self._ttl_seconds))
except Exception:
except Exception as exc:
logger.warning(
"Failed to write user context cache",
session_id=str(session_id),
error=str(exc),
)
return None
def _key(self, session_id: UUID) -> str:
@@ -136,13 +149,21 @@ class UserContextCache:
async def _safe_delete(self, key: str) -> None:
try:
await _maybe_await(self._client.delete(key))
except Exception:
except Exception as exc:
logger.warning("Failed to delete user context cache key", key=key, error=str(exc))
return None
async def _safe_hincrby(self, key: str, field: str, amount: int) -> None:
try:
await _maybe_await(self._client.hincrby(key, field, amount))
except Exception:
except Exception as exc:
logger.warning(
"Failed to update user context cache usage",
key=key,
field=field,
amount=amount,
error=str(exc),
)
return None
@@ -2,7 +2,10 @@ from __future__ import annotations
from typing import Any, Protocol
from uuid import UUID
import re
from ag_ui.core import RunAgentInput, RunErrorEvent, RunFinishedEvent, RunStartedEvent
from core.agent.domain.agui_input import parse_run_input
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
@@ -13,19 +16,65 @@ from services.base.redis import get_or_init_redis_client
logger = get_logger("core.agent.infrastructure.queue.tasks")
_NON_ALNUM_RE = re.compile(r"[^A-Za-z0-9]+")
_SENSITIVE_KEYS = {
"apikey",
"authorization",
"token",
"accesstoken",
"refreshtoken",
"secret",
"password",
"cookie",
}
class PublishEvent(Protocol):
    """Async callable that publishes one complete event dict."""

    async def __call__(self, event: dict[str, object]) -> None: ...
class RunServiceLike(Protocol):
    """Structural type of a service that executes a fresh agent run."""

    async def run(self, *, run_input: RunAgentInput) -> dict[str, object]: ...
class ResumeServiceLike(Protocol):
    """Structural type of a service that resumes a run from a tool result."""

    async def resume(self, *, run_input: RunAgentInput) -> dict[str, object]: ...
def _is_sensitive_key(key: str) -> bool:
    """Report whether a dict key looks like it names a credential.

    The key is lower-cased and stripped of non-alphanumeric characters, then
    treated as sensitive when it is in the fixed deny-list, contains
    ``token`` anywhere, or contains both ``api`` and ``key``.
    """
    compact = _NON_ALNUM_RE.sub("", key.lower())
    return (
        compact in _SENSITIVE_KEYS
        or "token" in compact
        or ("api" in compact and "key" in compact)
    )
def _redact_sensitive(value: Any) -> Any:
    """Return a copy of ``value`` with sensitive dict entries masked.

    Dicts and lists are rebuilt recursively. A dict value whose key is
    judged sensitive is replaced by a fixed placeholder without being
    descended into. Every other value is returned as-is.
    """
    if isinstance(value, dict):
        scrubbed: dict[Any, Any] = {}
        for name, inner in value.items():
            if _is_sensitive_key(str(name)):
                scrubbed[name] = "***REDACTED***"
            else:
                scrubbed[name] = _redact_sensitive(inner)
        return scrubbed
    if isinstance(value, list):
        return list(map(_redact_sensitive, value))
    return value
def _normalize_stream_event(
    *,
    event: dict[str, object],
    thread_id: str,
    run_id: str,
) -> dict[str, object]:
    """Stamp an event with its thread/run ids and mask sensitive fields.

    Works on a shallow copy, so ``event`` itself is not modified. Any
    ``threadId``/``runId`` already present is overwritten, the ``input``
    field is dropped from ``RUN_STARTED`` events, and the result is passed
    through the sensitive-key redaction.
    """
    stamped = {**event, "threadId": thread_id, "runId": run_id}
    if stamped.get("type") == "RUN_STARTED" and "input" in stamped:
        del stamped["input"]
    return _redact_sensitive(stamped)
async def _build_redis_publisher() -> PublishEvent:
@@ -37,13 +86,13 @@ async def _build_redis_publisher() -> PublishEvent:
block_ms=config.agent_runtime.redis_stream_block_ms,
)
async def _publish(event_type: str, payload: dict[str, object]) -> None:
session_id = str(payload.get("session_id", "")).strip()
if not session_id:
raise ValueError("session_id is required in event payload")
async def _publish(event: dict[str, object]) -> None:
thread_id = str(event.get("threadId", "")).strip()
if not thread_id:
raise ValueError("threadId is required in event payload")
await event_store.append_event(
session_id=UUID(session_id),
event={"type": event_type, "data": payload},
session_id=UUID(thread_id),
event=event,
)
return _publish
@@ -61,69 +110,69 @@ async def run_agent_task(
service_resume = resume_service or ResumeService()
command_type = str(command.get("command", "run"))
session_id = str(command.get("session_id", ""))
if command_type not in {"run", "resume"}:
raise ValueError("invalid command type")
if not session_id:
raise ValueError("session_id is required")
UUID(session_id)
raw_run_input = command.get("run_input")
if not isinstance(raw_run_input, dict):
raise ValueError("run_input is required")
run_input = parse_run_input(raw_run_input)
UUID(run_input.thread_id)
tool_call_id = ""
user_input = ""
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
await publisher(start_event, {"session_id": session_id})
await publisher(
RunStartedEvent(
thread_id=run_input.thread_id,
run_id=run_input.run_id,
parent_run_id=run_input.parent_run_id,
).model_dump(mode="json", by_alias=True, exclude_none=True)
)
try:
if command_type == "resume":
result = await service_resume.resume(
session_id=session_id,
tool_call_id=tool_call_id,
)
result = await service_resume.resume(run_input=run_input)
else:
result = await service_run.run(
session_id=session_id,
user_input=user_input,
)
result = await service_run.run(run_input=run_input)
await publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
extra_events = result.get("events") if isinstance(result, dict) else None
if isinstance(extra_events, list):
for event in extra_events:
if not isinstance(event, dict):
continue
event_type = event.get("type")
event_data = event.get("data")
if not isinstance(event_type, str) or not isinstance(event_data, dict):
if not isinstance(event_type, str):
continue
payload = {"session_id": session_id, **event_data}
await publisher(event_type, payload)
await publisher("RUN_FINISHED", {"session_id": session_id})
await publisher(
_normalize_stream_event(
event=event,
thread_id=run_input.thread_id,
run_id=run_input.run_id,
)
)
await publisher(
RunFinishedEvent(
thread_id=run_input.thread_id,
run_id=run_input.run_id,
).model_dump(mode="json", by_alias=True, exclude_none=True)
)
return result
except Exception: # noqa: BLE001
error_id = "agent_runtime_failed"
logger.exception(
"Agent task failed",
session_id=session_id,
thread_id=run_input.thread_id,
error_id=error_id,
)
try:
await publisher(
"RUN_ERROR", {"session_id": session_id, "error_id": error_id}
)
error_event = RunErrorEvent(
message="Agent task failed",
code=error_id,
).model_dump(mode="json", by_alias=True, exclude_none=True)
error_event["threadId"] = run_input.thread_id
error_event["runId"] = run_input.run_id
await publisher(error_event)
except Exception as publish_exc: # noqa: BLE001
logger.warning(
"Failed to publish RUN_ERROR event",
session_id=session_id,
thread_id=run_input.thread_id,
error=str(publish_exc),
)
raise
+124 -1
View File
@@ -1,11 +1,14 @@
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
import json
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
@@ -27,13 +30,22 @@ class AgentRepository:
raise HTTPException(status_code=404, detail="Session not found")
return str(owner_id)
async def create_session_for_user(self, *, user_id: str) -> str:
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
session_uuid = None
if session_id is not None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
session = AgentChatSession(
id=session_uuid,
user_id=user_uuid,
)
self._session.add(session)
@@ -56,3 +68,114 @@ class AgentRepository:
if session is not None:
await self._session.delete(session)
await self._session.flush()
async def get_history_day(
    self, *, session_id: str, before: date | None
) -> dict[str, object] | None:
    """Return one UTC calendar day of a session's messages.

    Args:
        session_id: Session identifier; must parse as a UUID.
        before: When given, select the most recent day strictly earlier
            than this date. When ``None``, select the most recent day
            that has any message.

    Returns:
        A dict with ``"day"`` (ISO date string), ``"hasMore"`` (whether an
        earlier day with messages exists) and ``"messages"`` (snapshot
        dicts ordered by ``seq`` ascending), or ``None`` when no
        non-deleted message matches.

    Raises:
        HTTPException: 422 when ``session_id`` is not a valid UUID.
    """
    try:
        session_uuid = UUID(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid session_id") from exc

    # Fetch every non-deleted message timestamp, newest first.
    timestamp_stmt = (
        select(AgentChatMessage.created_at)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .order_by(AgentChatMessage.created_at.desc())
    )
    rows = (await self._session.execute(timestamp_stmt)).scalars().all()

    # dict.fromkeys de-duplicates in O(n) while keeping first-seen order,
    # replacing the quadratic ``day not in list`` membership scan.
    unique_days: list[date] = list(
        dict.fromkeys(
            created_at.astimezone(timezone.utc).date()
            for created_at in rows
            if created_at is not None
        )
    )
    if not unique_days:
        return None

    target_day: date | None
    if before is None:
        target_day = unique_days[0]
    else:
        # First entry strictly earlier than ``before`` (list is newest-first).
        target_day = next((day for day in unique_days if day < before), None)
    if target_day is None:
        return None

    # Half-open UTC window [start, end) covering the selected day.
    start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
    end = start + timedelta(days=1)
    message_stmt = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .where(AgentChatMessage.created_at >= start)
        .where(AgentChatMessage.created_at < end)
        .order_by(AgentChatMessage.seq.asc())
    )
    messages = (await self._session.execute(message_stmt)).scalars().all()

    has_more = any(day < target_day for day in unique_days)
    return {
        "day": target_day.isoformat(),
        "hasMore": has_more,
        "messages": [self._to_snapshot_message(msg) for msg in messages],
    }
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
    """Return the id of the user's most recently active, non-deleted session.

    Args:
        user_id: Owner identifier; must parse as a UUID.

    Returns:
        The session id as a string, or ``None`` when the query yields no row.

    Raises:
        HTTPException: 422 when ``user_id`` is not a valid UUID.
    """
    try:
        owner_uuid = UUID(user_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid user_id") from exc

    # Newest activity first; only the top row is needed.
    query = select(AgentChatSession.id)
    query = query.where(AgentChatSession.user_id == owner_uuid)
    query = query.where(AgentChatSession.deleted_at.is_(None))
    query = query.order_by(AgentChatSession.last_activity_at.desc()).limit(1)

    result = await self._session.execute(query)
    newest_id = result.scalar_one_or_none()
    return None if newest_id is None else str(newest_id)
@staticmethod
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]:
# Convert one persisted chat message into the dict placed in the "messages"
# list built by get_history_day. Always sets "id", "role" and "timestamp";
# may add "toolCallId", "ui" and "content".
# NOTE(review): leading indentation is not preserved in this view, so it is
# unclear whether the JSON-content handling below sits inside the TOOL-role
# branch or applies to every role -- confirm against the real file before
# restructuring this method.
# Normalise the role to its plain string value, whether it is stored as the
# enum or as something else.
role = (
message.role.value
if isinstance(message.role, AgentChatMessageRole)
else str(message.role)
)
payload: dict[str, object] = {
"id": str(message.id),
"role": role,
# Timestamps are emitted as ISO-8601 strings converted to UTC.
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
}
if role == AgentChatMessageRole.TOOL.value:
# Tool messages expose the originating tool call id when the metadata
# carries a non-empty string for it.
metadata = message.metadata_json or {}
tool_call_id = metadata.get("tool_call_id")
if isinstance(tool_call_id, str) and tool_call_id:
payload["toolCallId"] = tool_call_id
# Try to read the stored content as a JSON object; anything that is not a
# JSON object (or is not decodable at all) is treated as plain text.
parsed_content: dict[str, object] | None = None
try:
decoded = json.loads(message.content)
if isinstance(decoded, dict):
parsed_content = decoded
except (TypeError, ValueError):
# json.JSONDecodeError is a ValueError; TypeError covers non-str content.
parsed_content = None
if parsed_content is not None:
# Structured content: forward the "ui" object if present and prefer the
# embedded "content" string as the display text.
ui = parsed_content.get("ui")
if isinstance(ui, dict):
payload["ui"] = ui
display_content = parsed_content.get("content")
if isinstance(display_content, str):
payload["content"] = display_content
else:
# No usable embedded text: fall back to the raw stored content.
payload["content"] = message.content
else:
payload["content"] = message.content
return payload
+141 -32
View File
@@ -2,96 +2,177 @@ from __future__ import annotations
from collections.abc import AsyncIterator
import asyncio
from datetime import date
import re
import time
from typing import Annotated
from ag_ui.core import RunAgentInput
from fastapi import APIRouter, Depends, Header, Query, Request, status
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from core.agent.infrastructure.agui.stream import to_sse_event
from core.agent.domain.agui_input import parse_run_input
from core.auth.models import CurrentUser
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
from v1.agent.schemas import TaskAcceptedResponse
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
# All routes in this module are mounted under the "/agent" prefix.
router = APIRouter(prefix="/agent", tags=["agent"])
# Accepted shape of the Last-Event-ID header: "<digits>-<digits>".
# NOTE(review): this looks like a Redis stream entry id -- confirm.
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
# Maximum run requests counted per user within one minute bucket
# (checked in _allow_run_request).
_RUNS_PER_MINUTE = 30
# Maximum concurrently counted SSE connections per user
# (checked in _acquire_sse_slot).
_MAX_SSE_CONNECTIONS_PER_USER = 3
# Expiry, in seconds, applied to the per-user SSE counter key when it is
# first created.
_SSE_SLOT_TTL_SECONDS = 15 * 60
async def _allow_run_request(*, user_id: str) -> bool:
    """Count this run request against the user's per-minute quota.

    Increments a Redis counter keyed by user and the current minute bucket,
    and reports whether the count is still within ``_RUNS_PER_MINUTE``.

    Returns:
        ``True`` when the request is within the quota; ``False`` when it
        exceeds it or when any error occurs (the check fails closed).
    """
    try:
        client = await get_or_init_redis_client()
        window = int(time.time() // 60)
        counter_key = f"agent:run-rate:{user_id}:{window}"
        hits = await client.incr(counter_key)
        # Only the request that creates the key sets its expiry (70 seconds).
        if hits == 1:
            await client.expire(counter_key, 70)
        within_quota = int(hits) <= _RUNS_PER_MINUTE
        return within_quota
    except Exception:  # noqa: BLE001
        return False
async def _acquire_sse_slot(*, user_id: str) -> bool:
    """Try to reserve one SSE connection slot for the user.

    Increments the per-user counter in Redis; if that pushes the count above
    ``_MAX_SSE_CONNECTIONS_PER_USER`` the increment is undone.

    Returns:
        ``True`` when a slot was reserved; ``False`` when the limit is
        reached or when any error occurs (the check fails closed).
    """
    try:
        client = await get_or_init_redis_client()
        slot_key = f"agent:sse-active:{user_id}"
        active = await client.incr(slot_key)
        # The expiry is set only by the increment that creates the key.
        if active == 1:
            await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
        if int(active) <= _MAX_SSE_CONNECTIONS_PER_USER:
            return True
        # Over the limit: give the slot back before refusing.
        await client.decr(slot_key)
        return False
    except Exception:  # noqa: BLE001
        return False
async def _release_sse_slot(*, user_id: str) -> None:
    """Give back one SSE connection slot for the user (best effort).

    Decrements the per-user counter and removes the key once it reaches
    zero or below. Any error is swallowed so that releasing never raises.
    """
    slot_key = f"agent:sse-active:{user_id}"
    try:
        client = await get_or_init_redis_client()
        remaining = await client.decr(slot_key)
        if int(remaining) > 0:
            return None
        # Nothing left to track: drop the key instead of keeping a zero or
        # negative counter around.
        await client.delete(slot_key)
    except Exception:  # noqa: BLE001
        return None
@router.post(
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
async def enqueue_run(
request: RunRequest,
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
try:
parse_run_input(request.model_dump(mode="json", by_alias=True))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
task = await service.enqueue_run(
session_id=request.session_id,
prompt=request.prompt,
run_input=request,
current_user=current_user,
)
return TaskAcceptedResponse(
task_id=task.task_id,
session_id=task.session_id,
thread_id=task.thread_id,
run_id=task.run_id,
created=task.created,
)
@router.post(
"/runs/{session_id}/resume",
"/runs/{thread_id}/resume",
response_model=TaskAcceptedResponse,
status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
session_id: str,
request: ResumeRequest,
thread_id: str,
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
if request.thread_id != thread_id:
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
try:
parse_run_input(request.model_dump(mode="json", by_alias=True))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
task = await service.enqueue_resume(
session_id=session_id,
tool_call_id=request.tool_call_id,
thread_id=thread_id,
run_input=request,
current_user=current_user,
)
return TaskAcceptedResponse(
task_id=task.task_id,
session_id=task.session_id,
thread_id=task.thread_id,
run_id=task.run_id,
created=task.created,
)
@router.get("/runs/{session_id}/events")
@router.get("/runs/{thread_id}/events")
async def stream_events(
request: Request,
session_id: str,
thread_id: str,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
if (
last_event_id is not None
and (
len(last_event_id) > 32
or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
)
):
raise HTTPException(status_code=422, detail="Invalid Last-Event-ID")
sse_slot_acquired = await _acquire_sse_slot(user_id=str(current_user.id))
if not sse_slot_acquired:
raise HTTPException(status_code=429, detail="Too many SSE connections")
async def _event_iter() -> AsyncIterator[str]:
cursor = last_event_id
idle_polls = 0
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(
session_id=session_id,
last_event_id=cursor,
current_user=current_user,
)
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
idle_polls = 0
for row in rows:
row_id = str(row.get("id", ""))
event = row.get("event")
if not row_id or not isinstance(event, dict):
try:
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(
thread_id=thread_id,
last_event_id=cursor,
current_user=current_user,
)
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
cursor = row_id
yield to_sse_event(row_id, event)
idle_polls = 0
for row in rows:
row_id = str(row.get("id", ""))
event = row.get("event")
if not row_id or not isinstance(event, dict):
continue
cursor = row_id
yield to_sse_event(row_id, event)
finally:
await _release_sse_slot(user_id=str(current_user.id))
return StreamingResponse(
_event_iter(),
@@ -102,3 +183,31 @@ async def stream_events(
"X-Accel-Buffering": "no",
},
)
@router.get("/runs/{thread_id}/history")
async def get_history_snapshot(
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    before: date | None = Query(default=None),
) -> dict[str, object]:
    # Thin HTTP adapter: forwards the path thread id and the optional
    # `before` date query parameter to the service and returns its result
    # unchanged. Comments are used instead of a docstring so the generated
    # OpenAPI description stays as it is.
    snapshot_event = await service.get_history_snapshot(
        thread_id=thread_id,
        before=before,
        current_user=current_user,
    )
    return snapshot_event
@router.get("/history")
async def get_user_history_snapshot(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str | None = Query(default=None, alias="threadId"),
    before: date | None = Query(default=None),
) -> dict[str, object]:
    # Thin HTTP adapter: the thread id is optional here (query parameter
    # `threadId`); the service decides what to return when it is absent.
    # Comments are used instead of a docstring so the generated OpenAPI
    # description stays as it is.
    snapshot_event = await service.get_user_history_snapshot(
        current_user=current_user,
        thread_id=thread_id,
        before=before,
    )
    return snapshot_event
+6 -12
View File
@@ -1,18 +1,12 @@
from __future__ import annotations
from pydantic import BaseModel, Field
class RunRequest(BaseModel):
session_id: str | None = Field(default=None, min_length=1, max_length=100)
prompt: str = Field(min_length=1, max_length=5000)
class ResumeRequest(BaseModel):
tool_call_id: str = Field(min_length=1, max_length=200)
from pydantic import BaseModel, ConfigDict, Field
class TaskAcceptedResponse(BaseModel):
task_id: str
session_id: str
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
task_id: str = Field(alias="taskId")
thread_id: str = Field(alias="threadId")
run_id: str = Field(alias="runId")
created: bool
+109 -29
View File
@@ -1,9 +1,13 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Protocol
from ag_ui.core import StateSnapshotEvent
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
@@ -11,19 +15,28 @@ from core.auth.models import CurrentUser
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
session_id: str
thread_id: str
run_id: str
created: bool
class AgentRepositoryLike(Protocol):
async def get_session_owner(self, *, session_id: str) -> str: ...
async def create_session_for_user(self, *, user_id: str) -> str: ...
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
async def get_history_day(
self, *, session_id: str, before: date | None
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -60,73 +73,140 @@ class AgentService:
async def enqueue_run(
self,
*,
session_id: str | None,
prompt: str,
run_input: RunAgentInput,
current_user: CurrentUser,
) -> TaskAccepted:
created = False
target_session_id = session_id
if target_session_id is None:
target_session_id = await self._repository.create_session_for_user(
user_id=str(current_user.id)
)
created = True
thread_id = run_input.thread_id
run_id = run_input.run_id
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
except HTTPException as exc:
if exc.status_code != 404:
raise
try:
await self._repository.create_session_for_user(
user_id=str(current_user.id),
session_id=thread_id,
)
await self._repository.commit()
created = True
except IntegrityError:
await self._repository.rollback()
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
else:
owner = await self._repository.get_session_owner(
session_id=target_session_id
)
ensure_session_owner(owner_id=owner, current_user=current_user)
if created:
await self._repository.commit()
try:
task_id = await self._queue.enqueue(
command={
"command": "run",
"session_id": target_session_id,
"user_input": prompt,
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
)
except Exception: # noqa: BLE001
raise
return TaskAccepted(
task_id=task_id, session_id=target_session_id, created=created
task_id=task_id,
thread_id=thread_id,
run_id=run_id,
created=created,
)
async def enqueue_resume(
self,
*,
session_id: str,
tool_call_id: str,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=session_id)
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
dedup_key = f"resume:{session_id}:{tool_call_id}"
dedup_key = f"resume:{thread_id}:{run_input.run_id}"
task_id = await self._queue.enqueue(
command={
"command": "resume",
"session_id": session_id,
"tool_call_id": tool_call_id,
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
)
return TaskAccepted(task_id=task_id, session_id=session_id, created=False)
return TaskAccepted(
task_id=task_id,
thread_id=thread_id,
run_id=run_input.run_id,
created=False,
)
async def stream_events(
self,
*,
session_id: str,
thread_id: str,
last_event_id: str | None,
current_user: CurrentUser,
) -> list[dict[str, object]]:
owner = await self._repository.get_session_owner(session_id=session_id)
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
return await self._stream.read(
session_id=session_id,
session_id=thread_id,
last_event_id=last_event_id,
)
async def get_history_snapshot(
    self,
    *,
    thread_id: str,
    before: date | None,
    current_user: CurrentUser,
) -> dict[str, object]:
    """Build a STATE_SNAPSHOT event holding one day of a thread's history.

    Checks that ``current_user`` owns the thread, asks the repository for
    the matching day, and wraps the result in a serialised
    ``StateSnapshotEvent`` with ``threadId`` added at the top level. When
    the repository returns nothing, the snapshot carries ``day=None``,
    ``hasMore=False`` and an empty ``messages`` list.
    """
    session_owner = await self._repository.get_session_owner(session_id=thread_id)
    ensure_session_owner(owner_id=session_owner, current_user=current_user)

    history = await self._repository.get_history_day(
        session_id=thread_id,
        before=before,
    )
    if history:
        day = history["day"]
        has_more = history["hasMore"]
        messages = history["messages"]
    else:
        # No matching day: fall back to the empty-history defaults.
        day, has_more, messages = None, False, []

    snapshot_body = {
        "scope": "history_day",
        "threadId": thread_id,
        "day": day,
        "hasMore": has_more,
        "messages": messages,
    }
    serialized = StateSnapshotEvent(snapshot=snapshot_body).model_dump(
        mode="json",
        by_alias=True,
        exclude_none=True,
    )
    # The thread id is attached to the dumped dict after serialisation.
    serialized["threadId"] = thread_id
    return serialized
async def get_user_history_snapshot(
    self,
    *,
    current_user: CurrentUser,
    thread_id: str | None,
    before: date | None,
) -> dict[str, object]:
    """Return a history snapshot for a given thread or the user's latest one.

    When ``thread_id`` is ``None`` the repository is asked for the user's
    latest session id. If there is still no thread, an empty
    ``history_day`` snapshot event is returned; otherwise the call is
    delegated to ``get_history_snapshot``.
    """
    resolved_thread_id = thread_id
    if resolved_thread_id is None:
        resolved_thread_id = await self._repository.get_latest_session_id_for_user(
            user_id=str(current_user.id)
        )

    if resolved_thread_id is not None:
        return await self.get_history_snapshot(
            thread_id=resolved_thread_id,
            before=before,
            current_user=current_user,
        )

    # The user has no thread at all: answer with an empty snapshot.
    empty_snapshot = {
        "scope": "history_day",
        "threadId": None,
        "day": None,
        "hasMore": False,
        "messages": [],
    }
    return StateSnapshotEvent(snapshot=empty_snapshot).model_dump(
        mode="json", by_alias=True, exclude_none=True
    )
@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import uuid
from decimal import Decimal
import pytest
from ag_ui.core import RunAgentInput
from sqlalchemy import delete, select
from core.agent.application.resume_service import ResumeService
@@ -84,28 +86,76 @@ async def test_run_then_resume_persists_messages_and_session_state(
published: list[str] = []
def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
published.append(event_type)
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
published.append(event_type)
try:
run_result = run_agent_task(
run_input_payload = {
"threadId": str(session_uuid),
"runId": "run-it-1",
"state": {},
"messages": [
{"id": "u1", "role": "user", "content": "帮我打开日历"},
],
"tools": [
{
"name": "navigate_to_route",
"description": "navigate route",
"parameters": {"type": "object"},
}
],
"context": [],
"forwardedProps": {},
}
run_result = await run_agent_task(
{
"command": "run",
"session_id": str(session_uuid),
"user_input": "hello",
"run_input": run_input_payload,
},
publish_event=_publish,
run_service=RunService(),
resume_service=ResumeService(),
)
pending_tool_call_id = str(run_result["pending_tool_call_id"])
state_snapshot = run_result["state_snapshot"]
assert isinstance(state_snapshot, dict)
pending_tool_nonce = state_snapshot["pending_tool_nonce"]
assert isinstance(pending_tool_nonce, str)
run_agent_task(
await run_agent_task(
{
"command": "resume",
"session_id": str(session_uuid),
"tool_call_id": pending_tool_call_id,
"run_input": {
"threadId": str(session_uuid),
"runId": "run-it-2",
"state": {},
"messages": [
{
"id": "tool-1",
"role": "tool",
"toolCallId": pending_tool_call_id,
"content": json.dumps(
{
"toolName": "navigate_to_route",
"toolArgs": {
"target": "/calendar/dayweek",
"replace": False,
"__nonce": pending_tool_nonce,
},
"nonce": pending_tool_nonce,
"result": {"ok": True},
},
ensure_ascii=True,
separators=(",", ":"),
),
}
],
"tools": [],
"context": [],
"forwardedProps": {},
},
},
publish_event=_publish,
run_service=RunService(),
@@ -123,6 +173,9 @@ async def test_run_then_resume_persists_messages_and_session_state(
assert db_session.state_snapshot == {
"status": "completed",
"pending_tool_call_id": None,
"pending_tool_name": None,
"pending_tool_args_sha256": None,
"pending_tool_nonce": None,
}
rows = await verify_session.execute(
@@ -142,7 +195,7 @@ async def test_run_then_resume_persists_messages_and_session_state(
assert messages[1].cost == Decimal("0.002500")
assert "RUN_STARTED" in published
assert "RUN_RESUMED" in published
assert "RUN_FINISHED" in published
assert "TEXT_MESSAGE_CONTENT" in published
finally:
async with AsyncSessionLocal() as cleanup_session:
@@ -219,7 +272,21 @@ async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
await seed_session.commit()
result = await RunService().run(session_id=str(session_uuid), user_input="hello")
result = await RunService().run(
run_input=RunAgentInput.model_validate(
{
"threadId": str(session_uuid),
"runId": "run-it-1",
"state": {},
"messages": [
{"id": "u1", "role": "user", "content": "hello"},
],
"tools": [],
"context": [],
"forwardedProps": {},
}
)
)
assert result["persisted"] is True
assert captured["user_input"] == "hello"
@@ -16,29 +16,38 @@ class _FakeStorage:
return "etag-1"
def test_closed_loop_run_flow_frontend_to_sse() -> None:
session_id = "00000000-0000-0000-0000-000000000001"
async def test_closed_loop_run_flow_frontend_to_sse() -> None:
thread_id = "00000000-0000-0000-0000-000000000001"
published: list[str] = []
class _FakeRunService:
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
return {"session_id": session_id, "user_input": user_input}
async def run(self, *, run_input: object) -> dict[str, object]:
del run_input
return {"threadId": thread_id, "runId": "run-1"}
def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
published.append(event_type)
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
published.append(event_type)
result = run_agent_task(
result = await run_agent_task(
{
"command": "run",
"session_id": session_id,
"user_input": "hello",
"run_input": {
"threadId": thread_id,
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
},
},
publish_event=_publish,
run_service=_FakeRunService(),
)
assert result["session_id"] == session_id
assert result["threadId"] == thread_id
assert published[0] == "RUN_STARTED"
assert published[-1] == "RUN_FINISHED"
+206 -24
View File
@@ -3,10 +3,12 @@ from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
from ag_ui.core import RunAgentInput
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.agent import router as agent_router
from v1.agent.dependencies import get_agent_service
from v1.users.dependencies import get_current_user
@@ -16,52 +18,122 @@ class _FakeAgentService:
self._stream_called = False
async def enqueue_run(
self, *, session_id: str | None, prompt: str, current_user: CurrentUser
self, *, run_input: RunAgentInput, current_user: CurrentUser
):
del prompt, current_user
resolved_session = session_id or "auto-created-session"
del current_user
return SimpleNamespace(
task_id="task-run-1",
session_id=resolved_session,
created=session_id is None,
thread_id=run_input.thread_id,
run_id=run_input.run_id,
created=False,
)
async def enqueue_resume(
self,
*,
session_id: str,
tool_call_id: str,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
):
del tool_call_id, current_user
del thread_id, current_user
return SimpleNamespace(
task_id="task-resume-1", session_id=session_id, created=False
task_id="task-resume-1",
thread_id=run_input.thread_id,
run_id=run_input.run_id,
created=False,
)
async def stream_events(
self,
*,
session_id: str,
thread_id: str,
last_event_id: str | None,
current_user: CurrentUser,
) -> list[dict[str, object]]:
del session_id, current_user
del thread_id, current_user
if self._stream_called:
return []
self._stream_called = True
return [
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
{
"id": "2-0",
"event": {
"type": "RUN_STARTED",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
},
"cursor": last_event_id,
}
]
async def get_history_snapshot(
self,
*,
thread_id: str,
before: str | None,
current_user: CurrentUser,
) -> dict[str, object]:
del current_user
return {
"type": "STATE_SNAPSHOT",
"threadId": thread_id,
"snapshot": {
"scope": "history_day",
"day": before or "2026-03-07",
"hasMore": False,
"messages": [
{
"id": "msg-h1",
"role": "assistant",
"content": "history-message",
}
],
},
}
async def get_user_history_snapshot(
self,
*,
current_user: CurrentUser,
thread_id: str | None,
before: str | None,
) -> dict[str, object]:
del current_user, before
return {
"type": "STATE_SNAPSHOT",
"threadId": thread_id or "00000000-0000-0000-0000-000000000001",
"snapshot": {
"scope": "history_day",
"day": "2026-03-07",
"hasMore": False,
"messages": [],
},
}
def test_run_requires_auth_and_returns_202_task_id() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
client = TestClient(app)
original_allow_run = agent_router._allow_run_request
async def _allow_run(*, user_id: str) -> bool:
del user_id
return True
agent_router._allow_run_request = _allow_run # type: ignore[assignment]
try:
unauthorized = client.post(
"/api/v1/agent/runs",
json={"session_id": "session-1", "prompt": "hello"},
json={
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
},
)
assert unauthorized.status_code == 401
@@ -70,20 +142,23 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
)
authorized = client.post(
"/api/v1/agent/runs",
json={"session_id": "session-1", "prompt": "hello"},
json={
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
},
)
assert authorized.status_code == 202
assert authorized.json()["task_id"] == "task-run-1"
assert authorized.json()["taskId"] == "task-run-1"
assert authorized.json()["threadId"] == "00000000-0000-0000-0000-000000000001"
assert authorized.json()["runId"] == "run-1"
assert authorized.json()["created"] is False
first_chat = client.post(
"/api/v1/agent/runs",
json={"prompt": "hello"},
)
assert first_chat.status_code == 202
assert first_chat.json()["session_id"] == "auto-created-session"
assert first_chat.json()["created"] is True
finally:
agent_router._allow_run_request = original_allow_run # type: ignore[assignment]
app.dependency_overrides = {}
@@ -93,15 +168,122 @@ def test_stream_reads_from_last_event_id() -> None:
id=uuid4(), email="user@example.com"
)
client = TestClient(app)
original_acquire = agent_router._acquire_sse_slot
original_release = agent_router._release_sse_slot
async def _allow_slot(*, user_id: str) -> bool:
del user_id
return True
async def _noop_release(*, user_id: str) -> None:
del user_id
return None
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
try:
response = client.get(
"/api/v1/agent/runs/session-1/events?idle_limit=1",
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1",
headers={"Last-Event-ID": "1-0"},
)
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/event-stream")
assert "id: 2-0" in response.text
assert "event: RUN_STARTED" in response.text
assert '"threadId":"00000000-0000-0000-0000-000000000001"' in response.text
finally:
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
agent_router._release_sse_slot = original_release # type: ignore[assignment]
app.dependency_overrides = {}
def test_stream_rejects_invalid_last_event_id() -> None:
    """The events endpoint answers 422 when Last-Event-ID is "bad-id"."""
    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    http = TestClient(app)
    events_url = "/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events"
    bad_headers = {"Last-Event-ID": "bad-id"}
    try:
        rejected = http.get(events_url, headers=bad_headers)
        assert rejected.status_code == 422
    finally:
        # Always restore the app so later tests see no overrides.
        app.dependency_overrides = {}
def test_history_returns_state_snapshot() -> None:
    """History needs authentication and returns a STATE_SNAPSHOT payload."""
    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    http = TestClient(app)
    history_url = "/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/history"
    try:
        # No user override installed yet: the request must be refused.
        anonymous = http.get(history_url)
        assert anonymous.status_code == 401

        app.dependency_overrides[get_current_user] = lambda: CurrentUser(
            id=uuid4(), email="user@example.com"
        )
        signed_in = http.get(history_url, params={"before": "2026-03-07"})
        assert signed_in.status_code == 200

        body = signed_in.json()
        assert body["type"] == "STATE_SNAPSHOT"
        assert body["threadId"] == "00000000-0000-0000-0000-000000000001"
        snapshot = body["snapshot"]
        assert snapshot["scope"] == "history_day"
        assert snapshot["day"] == "2026-03-07"
    finally:
        app.dependency_overrides = {}
def test_user_history_returns_latest_snapshot() -> None:
    """GET /history without a threadId returns a STATE_SNAPSHOT payload."""
    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    http = TestClient(app)
    try:
        reply = http.get("/api/v1/agent/history")
        assert reply.status_code == 200
        snapshot_event = reply.json()
        assert snapshot_event["type"] == "STATE_SNAPSHOT"
        # The fake service substitutes this thread id when none is passed.
        assert snapshot_event["threadId"] == "00000000-0000-0000-0000-000000000001"
    finally:
        app.dependency_overrides = {}
def test_run_rejects_oversized_user_text_payload() -> None:
    """POST /runs answers 422 for an 11000-character user message."""
    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    http = TestClient(app)
    oversized_message = {
        "id": "u1",
        "role": "user",
        "content": "x" * 11000,
    }
    run_payload = {
        "threadId": "00000000-0000-0000-0000-000000000001",
        "runId": "run-oversize",
        "state": {},
        "messages": [oversized_message],
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }
    try:
        rejected = http.post("/api/v1/agent/runs", json=run_payload)
        assert rejected.status_code == 422
    finally:
        app.dependency_overrides = {}
@@ -2,7 +2,7 @@ from __future__ import annotations
import os
from datetime import datetime, timedelta, timezone
from uuid import UUID
from uuid import UUID, uuid4
import httpx
import jwt
@@ -56,15 +56,25 @@ async def test_agent_sse_closed_loop_live() -> None:
run_resp = await client.post(
f"{BASE_URL}/api/v1/agent/runs",
headers=headers,
json={"prompt": "请用一句话介绍你自己"},
json={
"threadId": str(uuid4()),
"runId": "run-live-1",
"state": {},
"messages": [
{"id": "u1", "role": "user", "content": "请用一句话介绍你自己"}
],
"tools": [],
"context": [],
"forwardedProps": {},
},
)
assert run_resp.status_code == 202
accepted = run_resp.json()
session_id = str(accepted["session_id"])
assert session_id
thread_id = str(accepted["threadId"])
assert thread_id
events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
event_names: list[str] = []
async with client.stream("GET", events_url, headers=headers, timeout=20.0) as sse_resp:
assert sse_resp.status_code == 200
@@ -77,13 +87,13 @@ async def test_agent_sse_closed_loop_live() -> None:
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
async with AsyncSessionLocal() as session:
session_row = await session.get(AgentChatSession, UUID(session_id))
session_row = await session.get(AgentChatSession, UUID(thread_id))
assert session_row is not None
assert session_row.message_count >= 1
assert session_row.total_tokens >= 0
assert session_row.total_cost >= 0
rows = await session.execute(
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(session_id))
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(thread_id))
)
assert len(list(rows.scalars().all())) >= 1
@@ -132,7 +132,9 @@ def test_bridge_rejects_unknown_event_type() -> None:
def test_sse_format_includes_id_event_data() -> None:
payload = to_sse_event(
stream_id="1-0", event={"type": "RUN_STARTED", "data": {"a": 1}}
stream_id="1-0",
event={"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"},
)
assert payload.startswith("id: 1-0\nevent: RUN_STARTED\ndata: {")
assert '"threadId":"t1"' in payload
@@ -56,12 +56,14 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
captured["model"] = model
captured["api_key"] = api_key
captured["messages"] = messages
captured["temperature"] = temperature
captured["max_tokens"] = max_tokens
captured["timeout"] = timeout
return {
"choices": [
{
@@ -113,6 +115,7 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
assert captured["api_key"] == "env-api-key"
assert captured["temperature"] == 0.3
assert captured["max_tokens"] == 256
assert captured["timeout"] == 30.0
assert result["assistant_text"] == "hello"
@@ -128,6 +131,7 @@ def test_runtime_execute_injects_system_prompt_and_intent_template(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
captured["messages"] = messages
return {
@@ -219,6 +223,7 @@ def test_runtime_execute_short_circuits_on_direct_execution(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
del model, api_key, temperature, max_tokens
calls.append(messages)
@@ -331,6 +336,7 @@ def test_runtime_execute_runs_execution_and_organization_stages(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
del model, api_key, temperature, max_tokens
calls.append(messages)
@@ -383,6 +389,7 @@ def test_runtime_execute_rejects_invalid_intent_json(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
del model, api_key, messages, temperature, max_tokens
return {
@@ -506,6 +513,7 @@ def test_runtime_execute_minimizes_prompt_and_execution_payload(
messages: list[dict[str, object]],
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
):
del model, api_key, temperature, max_tokens
calls.append(messages)
@@ -21,10 +21,12 @@ def test_run_completion_passes_optional_params_when_provided(monkeypatch) -> Non
messages=[{"role": "user", "content": "hi"}],
temperature=0.6,
max_tokens=120,
timeout=12.5,
)
assert captured["temperature"] == 0.6
assert captured["max_tokens"] == 120
assert captured["timeout"] == 12.5
def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
@@ -45,7 +47,9 @@ def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
messages=[{"role": "user", "content": "hi"}],
temperature=None,
max_tokens=None,
timeout=None,
)
assert "temperature" not in captured
assert "max_tokens" not in captured
assert "timeout" not in captured
+114 -26
View File
@@ -2,64 +2,124 @@ from __future__ import annotations
import pytest
from ag_ui.core import RunAgentInput
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
class _FakeRunService:
    """Stub run service that echoes the AG-UI identifiers of its input."""

    async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
        """Return the thread id and run id carried by ``run_input``."""
        return {
            "threadId": run_input.thread_id,
            "runId": run_input.run_id,
        }
class _FakeResumeService:
    """Stub resume service that echoes the AG-UI identifiers of its input."""

    async def resume(
        self,
        *,
        run_input: RunAgentInput,
    ) -> dict[str, object]:
        """Return the thread id and run id carried by ``run_input``."""
        return {
            "threadId": run_input.thread_id,
            "runId": run_input.run_id,
        }
def _build_run_input() -> dict[str, object]:
    """Return a fresh run-input payload (camelCase keys) with one user message."""
    user_message = {"id": "u1", "role": "user", "content": "hello"}
    payload: dict[str, object] = {
        "threadId": "00000000-0000-0000-0000-000000000001",
        "runId": "run-1",
        "state": {},
        "messages": [user_message],
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }
    return payload
@pytest.mark.asyncio
async def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
    """A successful run publishes RUN_STARTED followed by RUN_FINISHED."""
    events: list[str] = []

    async def _publish(event: dict[str, object]) -> None:
        # Only the event type matters here; events without a string type are ignored.
        event_type = event.get("type")
        if isinstance(event_type, str):
            events.append(event_type)

    result = await run_agent_task(
        {
            "command": "run",
            "run_input": _build_run_input(),
        },
        publish_event=_publish,
        run_service=_FakeRunService(),
        resume_service=_FakeResumeService(),
    )

    assert result["threadId"] == "00000000-0000-0000-0000-000000000001"
    assert events == ["RUN_STARTED", "RUN_FINISHED"]
@pytest.mark.asyncio
async def test_run_agent_task_injects_context_and_redacts_sensitive_fields() -> None:
    """Service-returned events gain run ids and have ``token`` redacted."""
    emitted: list[dict[str, object]] = []

    class _RunWithExtraEvents(_FakeRunService):
        """Run stub that also returns one text event carrying a ``token`` field."""

        async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
            text_event_with_token = {
                "type": "TEXT_MESSAGE_CONTENT",
                "messageId": "m1",
                "delta": "hi",
                "token": "secret-token",
            }
            return {
                "threadId": run_input.thread_id,
                "runId": run_input.run_id,
                "events": [text_event_with_token],
            }

    async def _collect(event: dict[str, object]) -> None:
        emitted.append(event)

    task_payload = {"command": "run", "run_input": _build_run_input()}
    await run_agent_task(
        task_payload,
        publish_event=_collect,
        run_service=_RunWithExtraEvents(),
        resume_service=_FakeResumeService(),
    )

    # The first published event is the run-start marker and must not echo the input.
    first_event = emitted[0]
    assert first_event["type"] == "RUN_STARTED"
    assert "input" not in first_event

    # The service's own event follows, stamped with both ids and with the token masked.
    second_event = emitted[1]
    assert second_event["type"] == "TEXT_MESSAGE_CONTENT"
    assert second_event["threadId"] == "00000000-0000-0000-0000-000000000001"
    assert second_event["runId"] == "run-1"
    assert second_event["token"] == "***REDACTED***"
@pytest.mark.asyncio
async def test_run_agent_task_emits_error_event_on_exception() -> None:
session_id = "00000000-0000-0000-0000-000000000001"
class _BrokenRunService(_FakeRunService):
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
del session_id, user_input
async def run(self, *, run_input: dict[str, object]) -> dict[str, object]:
del run_input
raise RuntimeError("boom")
events: list[str] = []
async def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
events.append(event_type)
async def _publish(event: dict[str, object]) -> None:
event_type = event.get("type")
if isinstance(event_type, str):
events.append(event_type)
with pytest.raises(RuntimeError):
await run_agent_task(
{
"command": "run",
"session_id": session_id,
"user_input": "hello",
"run_input": _build_run_input(),
},
publish_event=_publish,
run_service=_BrokenRunService(),
@@ -72,16 +132,44 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
@pytest.mark.asyncio
async def test_run_agent_task_rejects_invalid_command() -> None:
    """An unknown ``command`` value is rejected with ``ValueError``."""
    with pytest.raises(ValueError, match="invalid command type"):
        await run_agent_task({"command": "invalid", "run_input": _build_run_input()})
@pytest.mark.asyncio
async def test_run_agent_task_rejects_missing_run_input() -> None:
    """A task payload without ``run_input`` is rejected with ``ValueError``."""
    with pytest.raises(ValueError, match="run_input is required"):
        await run_agent_task(
            {
                "command": "run",
            }
        )
@pytest.mark.asyncio
async def test_run_agent_task_resume_uses_run_input() -> None:
    """The ``resume`` command is driven by ``run_input`` and returns its run id."""

    async def _discard(event: dict[str, object]) -> None:
        del event

    resume_task = {
        "command": "resume",
        "run_input": _build_run_input(),
    }
    outcome = await run_agent_task(
        resume_task,
        publish_event=_discard,
        run_service=_FakeRunService(),
        resume_service=_FakeResumeService(),
    )

    assert outcome["runId"] == "run-1"
@pytest.mark.asyncio
async def test_run_agent_task_rejects_invalid_run_input() -> None:
    """A ``run_input`` carrying only ``threadId`` is rejected as invalid."""
    incomplete_task = {
        "command": "run",
        "run_input": {"threadId": "x"},
    }
    with pytest.raises(ValueError, match="invalid AG-UI RunAgentInput payload"):
        await run_agent_task(incomplete_task)
@@ -23,11 +23,34 @@ class _FakeRedisClient:
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
del count, block
key, start_id = next(iter(streams.items()))
if start_id == "$":
if start_id == "0-0":
return [(key, [("11-0", {"event": '{"type":"RUN_STARTED"}'})])]
return [(key, [("12-0", {"event": '{"type":"RUN_FINISHED"}'})])]
class _MalformedRedisClient:
    """Redis stub whose ``xread`` reply is a bare string inside a list."""

    async def xread(
        self,
        streams: dict[str, str],
        count: int,
        block: int,
    ) -> list[object]:
        """Ignore every argument and return the fixed malformed reply."""
        del streams, count, block
        malformed_reply: list[object] = ["bad-shape"]
        return malformed_reply
class _InvalidJsonRedisClient:
    """Redis stub returning one entry whose ``event`` field is not JSON."""

    async def xread(
        self,
        streams: dict[str, str],
        count: int,
        block: int,
    ) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
        """Return a single entry ``11-0`` for the first requested stream key."""
        del count, block
        first_stream = next(iter(streams))
        broken_entry = ("11-0", {"event": "not-json"})
        return [(first_stream, [broken_entry])]
def test_append_event_writes_json_payload() -> None:
client = _FakeRedisClient()
session_id = uuid4()
@@ -55,3 +78,26 @@ async def test_read_events_respects_last_event_id() -> None:
assert from_start[0]["id"] == "11-0"
assert from_last[0]["id"] == "12-0"
@pytest.mark.asyncio
async def test_read_events_returns_empty_for_malformed_response() -> None:
    """A reply that is not in stream-entry shape yields no events."""
    stream_owner = uuid4()
    event_store = RedisStreamEventStore(
        client=_MalformedRedisClient(),
        stream_prefix="agent:events",
    )

    events = await event_store.read_events(session_id=stream_owner, last_event_id=None)

    assert events == []
@pytest.mark.asyncio
async def test_read_events_skips_invalid_event_json() -> None:
    """An entry whose ``event`` field cannot be parsed as JSON is dropped."""
    stream_owner = uuid4()
    event_store = RedisStreamEventStore(
        client=_InvalidJsonRedisClient(), stream_prefix="agent:events"
    )

    events = await event_store.read_events(session_id=stream_owner, last_event_id=None)

    assert events == []
@@ -5,11 +5,13 @@ from types import SimpleNamespace
from uuid import uuid4
import pytest
from ag_ui.core import RunAgentInput
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
@@ -61,12 +63,69 @@ class _FakeUserContextCache:
self.set_calls += 1
def _build_run_input(
    *,
    thread_id: str,
    text: str = "hello",
    tools: list[dict[str, object]] | None = None,
) -> RunAgentInput:
    """Validate and return a ``RunAgentInput`` holding a single user message.

    ``tools`` falls back to an empty list when it is omitted or empty.
    """
    user_message = {"id": "u1", "role": "user", "content": text}
    raw_input = {
        "threadId": thread_id,
        "runId": "run-1",
        "state": {},
        "messages": [user_message],
        "tools": tools or [],
        "context": [],
        "forwardedProps": {},
    }
    return RunAgentInput.model_validate(raw_input)
def _build_resume_input(
    *,
    thread_id: str,
    tool_call_id: str,
    content: str | None = None,
) -> RunAgentInput:
    """Validate and return a ``RunAgentInput`` holding a single tool message.

    When ``content`` is ``None`` the message body defaults to a compact JSON
    document for a ``navigate_to_route`` call with nonce ``nonce-1`` and the
    result ``{"ok": true}``.
    """
    if content is not None:
        payload = content
    else:
        default_tool_result = {
            "toolName": "navigate_to_route",
            "toolArgs": {
                "target": "/calendar/dayweek",
                "replace": False,
                "__nonce": "nonce-1",
            },
            "nonce": "nonce-1",
            "result": {"ok": True},
        }
        payload = json.dumps(
            default_tool_result,
            ensure_ascii=True,
            separators=(",", ":"),
        )

    tool_message = {
        "id": "tool-1",
        "role": "tool",
        "toolCallId": tool_call_id,
        "content": payload,
    }
    raw_input = {
        "threadId": thread_id,
        "runId": "run-2",
        "state": {},
        "messages": [tool_message],
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }
    return RunAgentInput.model_validate(raw_input)
@pytest.mark.asyncio
async def test_run_service_rejects_invalid_session_id() -> None:
    """A thread id that is not a UUID string makes ``run`` raise ``ValueError``."""
    run_service = RunService()

    with pytest.raises(ValueError):
        await run_service.run(run_input=_build_run_input(thread_id="session-1"))
@pytest.mark.asyncio
@@ -74,7 +133,272 @@ async def test_resume_service_requires_pending_tool_call() -> None:
resume_service = ResumeService()
with pytest.raises(ValueError):
await resume_service.resume(session_id="session-1", tool_call_id="call-1")
await resume_service.resume(
run_input=_build_resume_input(
thread_id="session-1",
tool_call_id="call-1",
)
)
@pytest.mark.asyncio
async def test_resume_service_validates_pending_tool_guard_and_persists_payload(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A tool result matching the pending guard is stored as a TOOL message."""
    session_id = uuid4()
    user_id = uuid4()
    appended: list[dict[str, object]] = []

    class _FakeDbSession:
        async def commit(self) -> None:
            return None

    class _FakeSessionFactory:
        """Callable async context manager standing in for the session factory."""

        def __call__(self) -> "_FakeSessionFactory":
            return self

        async def __aenter__(self) -> _FakeDbSession:
            return _FakeDbSession()

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            del exc_type, exc, tb
            return False

    class _FakeSessionRepository:
        """Repository stub exposing a RUNNING session with a pending tool call."""

        def __init__(self, session: object) -> None:
            del session

        async def lock_session_for_update(self, *, session_id: object):
            pending_guard = {
                "pending_tool_call_id": "call-1",
                "pending_tool_name": "navigate_to_route",
                "pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
                "pending_tool_nonce": "nonce-1",
            }
            return SimpleNamespace(
                id=session_id,
                user_id=user_id,
                status=AgentChatSessionStatus.RUNNING,
                message_count=0,
                total_tokens=0,
                total_cost=0,
                state_snapshot=pending_guard,
            )

        async def next_message_seq(self, *, session_id: object) -> int:
            del session_id
            return 1

        async def update_runtime_state(self, **kwargs) -> None:
            del kwargs

    class _FakeMessageRepository:
        """Message repository stub that records every appended message."""

        def __init__(self, session: object) -> None:
            del session

        async def append_message(self, **kwargs) -> None:
            appended.append(kwargs)

    monkeypatch.setattr(
        "core.agent.application.resume_service.SessionRepository",
        _FakeSessionRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.resume_service.MessageRepository",
        _FakeMessageRepository,
    )

    service = ResumeService(session_factory=_FakeSessionFactory())  # type: ignore[arg-type]
    resume_input = _build_resume_input(
        thread_id=str(session_id),
        tool_call_id="call-1",
    )
    await service.resume(run_input=resume_input)

    first_message = appended[0]
    assert first_message["role"] == AgentChatMessageRole.TOOL
    stored_payload = json.loads(first_message["content"])
    assert stored_payload["toolName"] == "navigate_to_route"
    stored_result = stored_payload["result"]
    assert stored_result["ok"] is True
    assert stored_result["applied"] is True
    assert "ui" not in stored_payload
@pytest.mark.asyncio
async def test_resume_service_rejects_mismatched_nonce(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A tool result whose nonce differs from the pending one is rejected."""
    session_id = uuid4()
    user_id = uuid4()

    class _FakeDbSession:
        async def commit(self) -> None:
            return None

    class _FakeSessionFactory:
        """Callable async context manager standing in for the session factory."""

        def __call__(self) -> "_FakeSessionFactory":
            return self

        async def __aenter__(self) -> _FakeDbSession:
            return _FakeDbSession()

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            del exc_type, exc, tb
            return False

    class _FakeSessionRepository:
        """Repository stub exposing a RUNNING session pending on ``nonce-1``."""

        def __init__(self, session: object) -> None:
            del session

        async def lock_session_for_update(self, *, session_id: object):
            pending_guard = {
                "pending_tool_call_id": "call-1",
                "pending_tool_name": "navigate_to_route",
                "pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
                "pending_tool_nonce": "nonce-1",
            }
            return SimpleNamespace(
                id=session_id,
                user_id=user_id,
                status=AgentChatSessionStatus.RUNNING,
                message_count=0,
                total_tokens=0,
                total_cost=0,
                state_snapshot=pending_guard,
            )

        async def next_message_seq(self, *, session_id: object) -> int:
            del session_id
            return 1

        async def update_runtime_state(self, **kwargs) -> None:
            del kwargs

    class _FakeMessageRepository:
        def __init__(self, session: object) -> None:
            del session

        async def append_message(self, **kwargs) -> None:
            del kwargs

    monkeypatch.setattr(
        "core.agent.application.resume_service.SessionRepository",
        _FakeSessionRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.resume_service.MessageRepository",
        _FakeMessageRepository,
    )

    service = ResumeService(session_factory=_FakeSessionFactory())  # type: ignore[arg-type]

    # Same tool and arguments as the pending call, but the top-level nonce differs.
    wrong_nonce_content = json.dumps(
        {
            "toolName": "navigate_to_route",
            "toolArgs": {
                "target": "/calendar/dayweek",
                "replace": False,
                "__nonce": "nonce-1",
            },
            "nonce": "nonce-bad",
            "result": {"ok": True},
        },
        ensure_ascii=True,
        separators=(",", ":"),
    )
    with pytest.raises(ValueError, match="nonce"):
        await service.resume(
            run_input=_build_resume_input(
                thread_id=str(session_id),
                tool_call_id="call-1",
                content=wrong_nonce_content,
            )
        )
@pytest.mark.asyncio
async def test_resume_service_rejects_tool_result_when_not_ok(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A tool result reporting ``ok: false`` is rejected as a failed execution."""
    session_id = uuid4()
    user_id = uuid4()

    class _FakeDbSession:
        async def commit(self) -> None:
            return None

    class _FakeSessionFactory:
        """Callable async context manager standing in for the session factory."""

        def __call__(self) -> "_FakeSessionFactory":
            return self

        async def __aenter__(self) -> _FakeDbSession:
            return _FakeDbSession()

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            del exc_type, exc, tb
            return False

    class _FakeSessionRepository:
        """Repository stub exposing a RUNNING session with a pending tool call."""

        def __init__(self, session: object) -> None:
            del session

        async def lock_session_for_update(self, *, session_id: object):
            pending_guard = {
                "pending_tool_call_id": "call-1",
                "pending_tool_name": "navigate_to_route",
                "pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
                "pending_tool_nonce": "nonce-1",
            }
            return SimpleNamespace(
                id=session_id,
                user_id=user_id,
                status=AgentChatSessionStatus.RUNNING,
                message_count=0,
                total_tokens=0,
                total_cost=0,
                state_snapshot=pending_guard,
            )

        async def next_message_seq(self, *, session_id: object) -> int:
            del session_id
            return 1

        async def update_runtime_state(self, **kwargs) -> None:
            del kwargs

    class _FakeMessageRepository:
        def __init__(self, session: object) -> None:
            del session

        async def append_message(self, **kwargs) -> None:
            del kwargs

    monkeypatch.setattr(
        "core.agent.application.resume_service.SessionRepository",
        _FakeSessionRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.resume_service.MessageRepository",
        _FakeMessageRepository,
    )

    service = ResumeService(session_factory=_FakeSessionFactory())  # type: ignore[arg-type]

    # Guard fields all match the pending call; only the reported result is a failure.
    failed_result_content = json.dumps(
        {
            "toolName": "navigate_to_route",
            "toolArgs": {
                "target": "/calendar/dayweek",
                "replace": False,
                "__nonce": "nonce-1",
            },
            "nonce": "nonce-1",
            "result": {"ok": False, "error": "navigator not bound"},
        },
        ensure_ascii=True,
        separators=(",", ":"),
    )
    with pytest.raises(ValueError, match="execution failed"):
        await service.resume(
            run_input=_build_resume_input(
                thread_id=str(session_id),
                tool_call_id="call-1",
                content=failed_result_content,
            )
        )
@pytest.mark.asyncio
@@ -256,7 +580,9 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
session_uuid = session_id
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
await run_service.run(session_id=str(session_id), user_input="hello")
await run_service.run(
run_input=_build_run_input(thread_id=str(session_id), text="hello")
)
system_prompt = captured["system_prompt"]
assert isinstance(system_prompt, str)
@@ -267,6 +593,290 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
assert payload["ai_language"] == "en-US"
@pytest.mark.asyncio
async def test_run_service_emits_frontend_tool_pending_events(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A run offering ``navigate_to_route`` ends RUNNING with a pending tool call."""
    session_id = uuid4()
    user_id = uuid4()
    recorded: dict[str, object] = {}

    class _FakeDbSession:
        async def commit(self) -> None:
            return None

    class _FakeSessionFactory:
        """Callable async context manager standing in for the session factory."""

        def __call__(self) -> "_FakeSessionFactory":
            return self

        async def __aenter__(self) -> _FakeDbSession:
            return _FakeDbSession()

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            del exc_type, exc, tb
            return False

    class _FakeSessionRepository:
        """Repository stub exposing a PENDING session with no stored snapshot."""

        def __init__(self, session: object) -> None:
            del session

        async def lock_session_for_update(self, *, session_id: object):
            return SimpleNamespace(
                id=session_id,
                user_id=user_id,
                status=AgentChatSessionStatus.PENDING,
                message_count=0,
                total_tokens=0,
                total_cost=0,
                state_snapshot=None,
            )

        async def next_message_seq(self, *, session_id: object):
            del session_id
            return 1

        async def update_runtime_state(self, **kwargs) -> None:
            recorded["update_runtime_state"] = kwargs

    class _FakeMessageRepository:
        def __init__(self, session: object) -> None:
            del session

        async def append_message(self, **kwargs) -> None:
            recorded.setdefault("messages", []).append(kwargs)

    class _FakeRuntime:
        """Runtime stub returning a fixed assistant reply and no AG-UI events."""

        def execute(self, *, user_input: str, system_prompt: str | None = None):
            del user_input, system_prompt
            return {
                "assistant_text": "请确认是否跳转。",
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
                "cost": "0.001",
                "agui_events": [],
            }

    async def _fake_load_agent_model_selection(self, _session):
        del self
        return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())

    async def _fake_load_user_agent_context(self, session, session_id, user_id):
        del self, session, session_id
        preferences = SimpleNamespace(
            interface_language="zh-CN",
            ai_language="zh-CN",
            timezone="Asia/Shanghai",
            country="CN",
        )
        return SimpleNamespace(
            user_id=user_id,
            username="demo-user",
            bio=None,
            settings=SimpleNamespace(preferences=preferences),
        )

    monkeypatch.setattr(
        "core.agent.application.run_service.SessionRepository",
        _FakeSessionRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.MessageRepository",
        _FakeMessageRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.create_runtime",
        lambda **_kwargs: _FakeRuntime(),
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.RunService._load_agent_model_selection",
        _fake_load_agent_model_selection,
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.RunService._load_user_agent_context",
        _fake_load_user_agent_context,
    )

    service = RunService(session_factory=_FakeSessionFactory())  # type: ignore[arg-type]
    navigate_tool = {
        "name": "navigate_to_route",
        "description": "navigate",
        "parameters": {"type": "object"},
    }
    result = await service.run(
        run_input=_build_run_input(
            thread_id=str(session_id),
            text="帮我打开日历",
            tools=[navigate_tool],
        )
    )

    assert result["pending_tool_call_id"] is not None
    tool_start = next(
        evt for evt in result["events"] if evt["type"] == "TOOL_CALL_START"
    )
    assert tool_start["toolCallName"] == "navigate_to_route"

    # The persisted runtime state must carry the full pending-tool guard.
    runtime_state = recorded["update_runtime_state"]
    assert isinstance(runtime_state, dict)
    assert runtime_state["status"] == AgentChatSessionStatus.RUNNING
    snapshot = runtime_state["state_snapshot"]
    assert isinstance(snapshot, dict)
    assert snapshot["pending_tool_name"] == "navigate_to_route"
    assert isinstance(snapshot["pending_tool_args_sha256"], str)
    assert isinstance(snapshot["pending_tool_nonce"], str)
@pytest.mark.asyncio
async def test_run_service_executes_backend_calendar_tool_and_emits_result(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``#tool:create_calendar_event`` prompt runs the backend tool to completion."""
    session_id = uuid4()
    user_id = uuid4()
    recorded: dict[str, object] = {}

    class _FakeDbSession:
        async def commit(self) -> None:
            return None

    class _FakeSessionFactory:
        """Callable async context manager standing in for the session factory."""

        def __call__(self) -> "_FakeSessionFactory":
            return self

        async def __aenter__(self) -> _FakeDbSession:
            return _FakeDbSession()

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            del exc_type, exc, tb
            return False

    class _FakeSessionRepository:
        """Repository stub exposing a PENDING session with no stored snapshot."""

        def __init__(self, session: object) -> None:
            del session

        async def lock_session_for_update(self, *, session_id: object):
            return SimpleNamespace(
                id=session_id,
                user_id=user_id,
                status=AgentChatSessionStatus.PENDING,
                message_count=0,
                total_tokens=0,
                total_cost=0,
                state_snapshot=None,
            )

        async def next_message_seq(self, *, session_id: object):
            del session_id
            return 1

        async def update_runtime_state(self, **kwargs) -> None:
            recorded["update_runtime_state"] = kwargs

    class _FakeMessageRepository:
        def __init__(self, session: object) -> None:
            del session

        async def append_message(self, **kwargs) -> None:
            recorded.setdefault("messages", []).append(kwargs)

    class _FakeRuntime:
        """Runtime stub returning a fixed assistant reply and no AG-UI events."""

        def execute(self, *, user_input: str, system_prompt: str | None = None):
            del user_input, system_prompt
            return {
                "assistant_text": "日历事件已创建。",
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
                "cost": "0.001",
                "agui_events": [],
            }

    async def _fake_load_agent_model_selection(self, _session):
        del self
        return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())

    async def _fake_load_user_agent_context(self, session, session_id, user_id):
        del self, session, session_id
        preferences = SimpleNamespace(
            interface_language="zh-CN",
            ai_language="zh-CN",
            timezone="Asia/Shanghai",
            country="CN",
        )
        return SimpleNamespace(
            user_id=user_id,
            username="demo-user",
            bio=None,
            settings=SimpleNamespace(preferences=preferences),
        )

    async def _fake_execute_backend_tool(
        self,
        *,
        session,
        owner_id,
        tool_name,
        tool_args,
    ):
        del self, session, owner_id
        # The service must route the parsed tool name and arguments here.
        assert tool_name == "create_calendar_event"
        assert "title" in tool_args
        calendar_card = {
            "type": "calendar_card.v1",
            "version": "v1",
            "data": {"id": "evt-1", "title": "会议"},
        }
        return {
            "result": {"eventId": "evt-1", "ok": True},
            "ui": calendar_card,
        }

    monkeypatch.setattr(
        "core.agent.application.run_service.SessionRepository",
        _FakeSessionRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.MessageRepository",
        _FakeMessageRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.create_runtime",
        lambda **_kwargs: _FakeRuntime(),
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.RunService._load_agent_model_selection",
        _fake_load_agent_model_selection,
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.RunService._load_user_agent_context",
        _fake_load_user_agent_context,
    )
    monkeypatch.setattr(
        "core.agent.application.run_service.RunService._execute_backend_tool",
        _fake_execute_backend_tool,
    )

    service = RunService(session_factory=_FakeSessionFactory())  # type: ignore[arg-type]
    calendar_tool = {
        "name": "create_calendar_event",
        "description": "create calendar",
        "parameters": {"type": "object"},
    }
    result = await service.run(
        run_input=_build_run_input(
            thread_id=str(session_id),
            text='#tool:create_calendar_event {"title":"会议","startAt":"2026-03-07T08:00:00Z"}',
            tools=[calendar_tool],
        )
    )

    assert result["pending_tool_call_id"] is None
    assert any(evt["type"] == "TOOL_CALL_RESULT" for evt in result["events"])
    runtime_state = recorded["update_runtime_state"]
    assert isinstance(runtime_state, dict)
    assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
@pytest.mark.asyncio
async def test_load_user_agent_context_parses_profile_settings_v1() -> None:
session_id = uuid4()
@@ -519,7 +1129,9 @@ async def test_run_service_still_executes_when_profile_missing(
session_uuid = session_id
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
await run_service.run(session_id=str(session_id), user_input="hello")
await run_service.run(
run_input=_build_run_input(thread_id=str(session_id), text="hello")
)
system_prompt = captured["system_prompt"]
assert isinstance(system_prompt, str)
@@ -4,9 +4,18 @@ from core.agent.domain.state_snapshot import AgentStateSnapshot
def test_state_snapshot_serialization_round_trip() -> None:
    """``model_dump`` exposes the status and every pending-tool guard field."""
    snapshot = AgentStateSnapshot(
        status="running",
        pending_tool_call_id="call-1",
        pending_tool_name="navigate_to_route",
        pending_tool_args_sha256="abc",
        pending_tool_nonce="nonce-1",
    )

    payload = snapshot.model_dump()

    assert payload["status"] == "running"
    assert payload["pending_tool_call_id"] == "call-1"
    assert payload["pending_tool_name"] == "navigate_to_route"
    assert payload["pending_tool_args_sha256"] == "abc"
    assert payload["pending_tool_nonce"] == "nonce-1"
@@ -0,0 +1,33 @@
from __future__ import annotations
import pytest
from v1.agent import router as agent_router
@pytest.mark.asyncio
async def test_allow_run_request_fails_closed_when_redis_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The run-request check returns ``False`` when the Redis client cannot be obtained."""

    async def _unavailable_redis():
        raise RuntimeError("redis unavailable")

    monkeypatch.setattr(agent_router, "get_or_init_redis_client", _unavailable_redis)

    decision = await agent_router._allow_run_request(user_id="user-1")

    assert decision is False
@pytest.mark.asyncio
async def test_acquire_sse_slot_fails_closed_when_redis_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The SSE slot check returns ``False`` when the Redis client cannot be obtained."""

    async def _unavailable_redis():
        raise RuntimeError("redis unavailable")

    monkeypatch.setattr(agent_router, "get_or_init_redis_client", _unavailable_redis)

    decision = await agent_router._acquire_sse_slot(user_id="user-1")

    assert decision is False
+140 -14
View File
@@ -1,7 +1,11 @@
from __future__ import annotations
from datetime import date
from uuid import UUID
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
from v1.agent.service import AgentService
@@ -11,14 +15,19 @@ class _FakeRepository:
self.committed = False
self.rolled_back = False
self.deleted_session_id: str | None = None
self.created_with_session_id: str | None = None
async def get_session_owner(self, *, session_id: str) -> str:
    """Return the owner id of the single known session.

    Raises:
        HTTPException: with status 404 for any other ``session_id``.
    """
    if session_id == "00000000-0000-0000-0000-000000000001":
        return "00000000-0000-0000-0000-000000000001"
    raise HTTPException(status_code=404, detail="Session not found")
async def create_session_for_user(
    self, *, user_id: str, session_id: str | None = None
) -> str:
    """Record the requested session id and return it, or a fixed fallback id.

    The requested id (possibly ``None``) is stored on
    ``self.created_with_session_id`` so a test can inspect it afterwards.
    """
    del user_id
    self.created_with_session_id = session_id
    return session_id or "00000000-0000-0000-0000-000000000999"
async def commit(self) -> None:
self.committed = True
@@ -29,6 +38,22 @@ class _FakeRepository:
async def delete_session(self, *, session_id: str) -> None:
self.deleted_session_id = session_id
async def get_history_day(
    self, *, session_id: str, before: date | None
) -> dict[str, object] | None:
    """Return the fixed 2026-03-06 history page, or ``None``.

    ``None`` is returned when ``before`` is on or before 2026-03-06; a missing
    ``before`` or any later date yields the page.
    """
    del session_id
    only_day = date(2026, 3, 6)
    if before is None or before > only_day:
        return {
            "day": "2026-03-06",
            "hasMore": False,
            "messages": [{"id": "m1", "role": "assistant", "content": "hello"}],
        }
    return None
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
    """Return the same fixed session id regardless of ``user_id``."""
    del user_id
    latest_session_id = "00000000-0000-0000-0000-000000000001"
    return latest_session_id
class _FakeQueue:
async def enqueue(
@@ -63,6 +88,20 @@ def _user() -> CurrentUser:
)
def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
    """Validate and return a ``RunAgentInput`` holding a single user message."""
    user_message = {"id": "u1", "role": "user", "content": "hello"}
    raw_input = {
        "threadId": thread_id,
        "runId": run_id,
        "state": {},
        "messages": [user_message],
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }
    return RunAgentInput.model_validate(raw_input)
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
service = AgentService(
repository=_FakeRepository(),
@@ -70,37 +109,46 @@ async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
stream=_FakeStream(),
)
user = _user()
run_input = _build_run_input(
thread_id="00000000-0000-0000-0000-000000000001",
run_id="run-1",
)
first = await service.enqueue_resume(
session_id="session-1",
tool_call_id="call-1",
thread_id="00000000-0000-0000-0000-000000000001",
run_input=run_input,
current_user=user,
)
second = await service.enqueue_resume(
session_id="session-1",
tool_call_id="call-1",
thread_id="00000000-0000-0000-0000-000000000001",
run_input=run_input,
current_user=user,
)
assert first.task_id == second.task_id
async def test_enqueue_run_creates_missing_thread_session() -> None:
    """Enqueueing a run for an unknown thread creates a session with that id."""
    repository = _FakeRepository()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )
    run_input = _build_run_input(
        thread_id="00000000-0000-0000-0000-000000000999",
        run_id="run-1",
    )

    accepted = await service.enqueue_run(
        run_input=run_input,
        current_user=_user(),
    )

    assert accepted.thread_id == "00000000-0000-0000-0000-000000000999"
    assert accepted.run_id == "run-1"
    assert accepted.created is True
    assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
    assert repository.committed is True
@@ -111,11 +159,14 @@ async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
queue=_FailingQueue(),
stream=_FakeStream(),
)
run_input = _build_run_input(
thread_id="00000000-0000-0000-0000-000000000999",
run_id="run-1",
)
try:
await service.enqueue_run(
session_id=None,
prompt="hello",
run_input=run_input,
current_user=_user(),
)
raise AssertionError("expected RuntimeError")
@@ -123,3 +174,78 @@ async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
assert str(exc) == "enqueue failed"
assert repository.deleted_session_id is None
async def test_enqueue_run_handles_session_create_race() -> None:
    """A duplicate-key failure on session creation is rolled back, not surfaced."""

    class _RaceRepository(_FakeRepository):
        """Repository whose session only becomes visible after a failed insert."""

        def __init__(self) -> None:
            super().__init__()
            self.create_calls = 0

        async def get_session_owner(self, *, session_id: str) -> str:
            # Before any create attempt the session is missing; afterwards it exists.
            if self.create_calls:
                return "00000000-0000-0000-0000-000000000001"
            raise HTTPException(status_code=404, detail="Session not found")

        async def create_session_for_user(
            self, *, user_id: str, session_id: str | None = None
        ) -> str:
            del user_id, session_id
            self.create_calls += 1
            raise IntegrityError("insert", {}, Exception("duplicate key"))

    race_repository = _RaceRepository()
    service = AgentService(
        repository=race_repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )
    contested_input = _build_run_input(
        thread_id="00000000-0000-0000-0000-000000000999",
        run_id="run-1",
    )

    accepted = await service.enqueue_run(
        run_input=contested_input,
        current_user=_user(),
    )

    assert accepted.created is False
    assert race_repository.rolled_back is True
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
    """The history day is returned inside a ``STATE_SNAPSHOT`` event."""
    known_thread = "00000000-0000-0000-0000-000000000001"
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    snapshot_event = await service.get_history_snapshot(
        thread_id=known_thread,
        before=date(2026, 3, 7),
        current_user=_user(),
    )

    assert snapshot_event["type"] == "STATE_SNAPSHOT"
    assert snapshot_event["threadId"] == "00000000-0000-0000-0000-000000000001"
    body = snapshot_event["snapshot"]
    assert isinstance(body, dict)
    assert body["scope"] == "history_day"
    assert body["day"] == "2026-03-06"
    assert body["messages"][0]["id"] == "m1"
async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None:
    """Without a thread id, the snapshot is reported for the repository's latest one."""
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    snapshot_event = await service.get_user_history_snapshot(
        current_user=_user(),
        thread_id=None,
        before=None,
    )

    assert snapshot_event["type"] == "STATE_SNAPSHOT"
    assert snapshot_event["threadId"] == "00000000-0000-0000-0000-000000000001"