feat: AG-UI 协议对齐与路由导航功能
- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具 - 前端: 实现工具调用审批流程,支持 pending 状态展示 - 后端: Agent 状态管理与会话持久化相关重构 - 文档: 新增 agent-agui-full-alignance 设计文档 - 测试: 补充相关单元测试和集成测试
This commit is contained in:
@@ -1,10 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import (
|
||||
RunAgentInput,
|
||||
TextMessageContentEvent,
|
||||
TextMessageEndEvent,
|
||||
TextMessageStartEvent,
|
||||
ToolCallResultEvent,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from core.agent.application.session_state_persistence import SessionStatePersistence
|
||||
from core.agent.application.session_state_persistence import (
|
||||
SessionStatePersistence,
|
||||
compute_tool_args_sha256,
|
||||
)
|
||||
from core.agent.domain.agui_input import extract_latest_tool_result
|
||||
from core.agent.domain.message_metadata import (
|
||||
MessageMetadataAssistantOutput,
|
||||
MessageMetadataToolResult,
|
||||
@@ -25,8 +37,13 @@ class ResumeService:
|
||||
self._session_factory = session_factory
|
||||
self._state_persistence = SessionStatePersistence()
|
||||
|
||||
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
|
||||
session_uuid = UUID(session_id)
|
||||
async def resume(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
) -> dict[str, object]:
|
||||
session_uuid = UUID(run_input.thread_id)
|
||||
tool_call_id, tool_payload = extract_latest_tool_result(run_input)
|
||||
|
||||
async with self._session_factory() as db_session:
|
||||
session_repository = SessionRepository(db_session)
|
||||
@@ -41,6 +58,41 @@ class ResumeService:
|
||||
pending_tool_call = state_snapshot.get("pending_tool_call_id")
|
||||
if pending_tool_call != tool_call_id:
|
||||
raise ValueError("pending tool call does not match")
|
||||
pending_tool_name = state_snapshot.get("pending_tool_name")
|
||||
pending_tool_args_sha256 = state_snapshot.get("pending_tool_args_sha256")
|
||||
pending_tool_nonce = state_snapshot.get("pending_tool_nonce")
|
||||
if (
|
||||
not isinstance(pending_tool_name, str)
|
||||
or not pending_tool_name
|
||||
or not isinstance(pending_tool_args_sha256, str)
|
||||
or not pending_tool_args_sha256
|
||||
or not isinstance(pending_tool_nonce, str)
|
||||
or not pending_tool_nonce
|
||||
):
|
||||
raise ValueError("pending tool guard is incomplete")
|
||||
|
||||
tool_name = tool_payload.get("toolName")
|
||||
tool_args = tool_payload.get("toolArgs")
|
||||
nonce = tool_payload.get("nonce")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
raise ValueError("resume payload missing toolName")
|
||||
if not isinstance(tool_args, dict):
|
||||
raise ValueError("resume payload missing toolArgs")
|
||||
if not isinstance(nonce, str) or not nonce:
|
||||
raise ValueError("resume payload missing nonce")
|
||||
if tool_name != pending_tool_name:
|
||||
raise ValueError("resume toolName does not match pending tool")
|
||||
if nonce != pending_tool_nonce:
|
||||
raise ValueError("resume nonce does not match pending tool")
|
||||
computed_args_sha256 = compute_tool_args_sha256(tool_args)
|
||||
if computed_args_sha256 != pending_tool_args_sha256:
|
||||
raise ValueError("resume toolArgs does not match pending tool")
|
||||
sanitized_tool_payload = self._sanitize_tool_payload(
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
nonce=nonce,
|
||||
tool_payload=tool_payload,
|
||||
)
|
||||
|
||||
next_seq = await session_repository.next_message_seq(
|
||||
session_id=session_uuid
|
||||
@@ -49,9 +101,13 @@ class ResumeService:
|
||||
session_id=session_uuid,
|
||||
seq=next_seq,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
content='{"status":"ok"}',
|
||||
content=json.dumps(
|
||||
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
|
||||
),
|
||||
metadata=MessageMetadataToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
run_id=run_input.run_id,
|
||||
tool_name=tool_name,
|
||||
).model_dump(),
|
||||
)
|
||||
await message_repository.append_message(
|
||||
@@ -71,4 +127,61 @@ class ResumeService:
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
return {"session_id": session_id, "resumed": True, "state_snapshot": snapshot}
|
||||
tool_message_id = f"msg-tool-{next_seq}"
|
||||
assistant_message_id = f"msg-assistant-{next_seq + 1}"
|
||||
events = [
|
||||
ToolCallResultEvent(
|
||||
message_id=tool_message_id,
|
||||
tool_call_id=tool_call_id,
|
||||
content=json.dumps(
|
||||
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
|
||||
),
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True),
|
||||
TextMessageStartEvent(
|
||||
message_id=assistant_message_id,
|
||||
role="assistant",
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True),
|
||||
TextMessageContentEvent(
|
||||
message_id=assistant_message_id,
|
||||
delta="Tool result received",
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True),
|
||||
TextMessageEndEvent(
|
||||
message_id=assistant_message_id
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True),
|
||||
]
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"resumed": True,
|
||||
"state_snapshot": snapshot,
|
||||
"events": events,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_tool_payload(
|
||||
*,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, object],
|
||||
nonce: str,
|
||||
tool_payload: dict[str, object],
|
||||
) -> dict[str, object]:
|
||||
if tool_name != "navigate_to_route":
|
||||
raise ValueError("unsupported frontend tool in resume payload")
|
||||
target = tool_args.get("target")
|
||||
if not isinstance(target, str) or not target:
|
||||
raise ValueError("resume toolArgs missing target")
|
||||
raw_result = tool_payload.get("result")
|
||||
if not isinstance(raw_result, dict) or raw_result.get("ok") is not True:
|
||||
raise ValueError("frontend tool execution failed")
|
||||
sanitized_result = {
|
||||
"ok": True,
|
||||
"target": target,
|
||||
"replace": tool_args.get("replace") is True,
|
||||
"applied": True,
|
||||
}
|
||||
return {
|
||||
"toolName": tool_name,
|
||||
"toolArgs": tool_args,
|
||||
"nonce": nonce,
|
||||
"result": sanitized_result,
|
||||
}
|
||||
|
||||
@@ -1,14 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from decimal import Decimal
|
||||
import json
|
||||
import re
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from ag_ui.core import (
|
||||
TextMessageContentEvent,
|
||||
TextMessageEndEvent,
|
||||
TextMessageStartEvent,
|
||||
ToolCallArgsEvent,
|
||||
ToolCallEndEvent,
|
||||
ToolCallResultEvent,
|
||||
ToolCallStartEvent,
|
||||
RunAgentInput,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from core.agent.application.session_state_persistence import SessionStatePersistence
|
||||
from core.agent.domain.agui_input import extract_latest_user_text
|
||||
from core.agent.application.session_state_persistence import (
|
||||
SessionStatePersistence,
|
||||
compute_tool_args_sha256,
|
||||
)
|
||||
from core.agent.domain.message_metadata import (
|
||||
MessageMetadataAssistantOutput,
|
||||
MessageMetadataToolResult,
|
||||
MessageMetadataToolCall,
|
||||
MessageMetadataUserInput,
|
||||
)
|
||||
@@ -27,6 +47,7 @@ from core.agent.infrastructure.persistence.user_context_loader import (
|
||||
from core.db import AsyncSessionLocal
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.schedule_items import ScheduleItem, ScheduleItemSourceType, ScheduleItemStatus
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.system_agents import SystemAgents
|
||||
@@ -60,9 +81,14 @@ class RunService:
|
||||
self._state_persistence = SessionStatePersistence()
|
||||
self._user_context_cache = user_context_cache or create_user_context_cache()
|
||||
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
session_uuid = UUID(session_id)
|
||||
pending_tool_call_id = f"tool-{uuid4()}"
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
) -> dict[str, object]:
|
||||
session_uuid = UUID(run_input.thread_id)
|
||||
user_input = extract_latest_user_text(run_input)
|
||||
assistant_message_id = f"msg-{uuid4()}"
|
||||
|
||||
async with self._session_factory() as db_session:
|
||||
session_repository = SessionRepository(db_session)
|
||||
@@ -87,8 +113,12 @@ class RunService:
|
||||
user_context = await self._load_user_agent_context(
|
||||
db_session, session_uuid, chat_session.user_id
|
||||
)
|
||||
system_prompt = build_global_system_prompt(user_context)
|
||||
runtime_result = runtime.execute(
|
||||
system_prompt = self._build_system_prompt_with_tools(
|
||||
base_prompt=build_global_system_prompt(user_context),
|
||||
run_input=run_input,
|
||||
)
|
||||
runtime_result = await asyncio.to_thread(
|
||||
runtime.execute,
|
||||
user_input=user_input,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
@@ -97,7 +127,10 @@ class RunService:
|
||||
completion_tokens = _to_int(runtime_result.get("completion_tokens", 0))
|
||||
total_tokens = _to_int(runtime_result.get("total_tokens", 0))
|
||||
cost = _to_decimal(runtime_result.get("cost", 0))
|
||||
agui_events = runtime_result.get("agui_events", [])
|
||||
planned_tool = self._select_tool_plan(
|
||||
user_input=user_input,
|
||||
available_tools={tool.name for tool in run_input.tools},
|
||||
)
|
||||
|
||||
next_seq = await session_repository.next_message_seq(
|
||||
session_id=session_uuid
|
||||
@@ -110,39 +143,354 @@ class RunService:
|
||||
model_code=model_code,
|
||||
metadata=MessageMetadataUserInput().model_dump(),
|
||||
)
|
||||
await message_repository.append_message(
|
||||
session_id=session_uuid,
|
||||
seq=next_seq + 1,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=assistant_text or "Tool call pending approval",
|
||||
model_code=model_code,
|
||||
metadata=MessageMetadataToolCall(
|
||||
tool_call_id=pending_tool_call_id,
|
||||
).model_dump(),
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
cost=cost,
|
||||
)
|
||||
pending_tool_call_id: str | None = None
|
||||
events: list[dict[str, object]] = []
|
||||
message_delta = 2
|
||||
session_status = AgentChatSessionStatus.COMPLETED
|
||||
snapshot = self._state_persistence.build_completed_snapshot()
|
||||
|
||||
if planned_tool is None:
|
||||
await message_repository.append_message(
|
||||
session_id=session_uuid,
|
||||
seq=next_seq + 1,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=assistant_text or "已完成处理。",
|
||||
model_code=model_code,
|
||||
metadata=MessageMetadataAssistantOutput().model_dump(),
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
cost=cost,
|
||||
)
|
||||
events.extend(
|
||||
self._build_text_message_events(
|
||||
message_id=assistant_message_id,
|
||||
text=assistant_text or "已完成处理。",
|
||||
)
|
||||
)
|
||||
elif planned_tool["target"] == "backend":
|
||||
tool_call_id = f"tool-{uuid4()}"
|
||||
tool_name = str(planned_tool["name"])
|
||||
tool_args = planned_tool["args"]
|
||||
if not isinstance(tool_args, dict):
|
||||
tool_args = {}
|
||||
tool_payload = await self._execute_backend_tool(
|
||||
session=db_session,
|
||||
owner_id=chat_session.user_id,
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
)
|
||||
await message_repository.append_message(
|
||||
session_id=session_uuid,
|
||||
seq=next_seq + 1,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
content=json.dumps(
|
||||
tool_payload,
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
metadata=MessageMetadataToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
run_id=run_input.run_id,
|
||||
tool_name=tool_name,
|
||||
).model_dump(),
|
||||
)
|
||||
await message_repository.append_message(
|
||||
session_id=session_uuid,
|
||||
seq=next_seq + 2,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=assistant_text or "后端工具执行完成。",
|
||||
model_code=model_code,
|
||||
metadata=MessageMetadataAssistantOutput().model_dump(),
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
cost=cost,
|
||||
)
|
||||
message_delta = 3
|
||||
events.extend(
|
||||
self._build_tool_call_events(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
ToolCallResultEvent(
|
||||
message_id=f"msg-tool-{uuid4()}",
|
||||
tool_call_id=tool_call_id,
|
||||
content=json.dumps(
|
||||
tool_payload,
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
)
|
||||
events.extend(
|
||||
self._build_text_message_events(
|
||||
message_id=assistant_message_id,
|
||||
text=assistant_text or "后端工具执行完成。",
|
||||
)
|
||||
)
|
||||
else:
|
||||
pending_tool_call_id = f"tool-{uuid4()}"
|
||||
tool_name = str(planned_tool["name"])
|
||||
tool_args = planned_tool["args"]
|
||||
if not isinstance(tool_args, dict):
|
||||
tool_args = {}
|
||||
pending_tool_nonce = uuid4().hex
|
||||
guarded_tool_args = {
|
||||
**tool_args,
|
||||
"__nonce": pending_tool_nonce,
|
||||
}
|
||||
pending_tool_args_sha256 = compute_tool_args_sha256(guarded_tool_args)
|
||||
await message_repository.append_message(
|
||||
session_id=session_uuid,
|
||||
seq=next_seq + 1,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=assistant_text or "Tool call pending approval",
|
||||
model_code=model_code,
|
||||
metadata=MessageMetadataToolCall(
|
||||
tool_call_id=pending_tool_call_id,
|
||||
).model_dump(),
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
cost=cost,
|
||||
)
|
||||
snapshot = self._state_persistence.build_running_snapshot(
|
||||
pending_tool_call_id=pending_tool_call_id,
|
||||
pending_tool_name=tool_name,
|
||||
pending_tool_args_sha256=pending_tool_args_sha256,
|
||||
pending_tool_nonce=pending_tool_nonce,
|
||||
)
|
||||
session_status = AgentChatSessionStatus.RUNNING
|
||||
events.extend(
|
||||
self._build_tool_call_events(
|
||||
tool_call_id=pending_tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_args=guarded_tool_args,
|
||||
)
|
||||
)
|
||||
events.extend(
|
||||
self._build_text_message_events(
|
||||
message_id=assistant_message_id,
|
||||
text=assistant_text or "请确认是否执行前端工具。",
|
||||
)
|
||||
)
|
||||
|
||||
snapshot = self._state_persistence.build_running_snapshot(
|
||||
pending_tool_call_id=pending_tool_call_id
|
||||
)
|
||||
await session_repository.update_runtime_state(
|
||||
chat_session=chat_session,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
status=session_status,
|
||||
state_snapshot=snapshot,
|
||||
message_delta=2,
|
||||
message_delta=message_delta,
|
||||
token_delta=total_tokens,
|
||||
cost_delta=cost,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"persisted": True,
|
||||
"pending_tool_call_id": pending_tool_call_id,
|
||||
"state_snapshot": snapshot,
|
||||
"events": agui_events,
|
||||
"events": events,
|
||||
}
|
||||
|
||||
def _build_system_prompt_with_tools(
|
||||
self, *, base_prompt: str, run_input: RunAgentInput
|
||||
) -> str:
|
||||
if not run_input.tools:
|
||||
return base_prompt
|
||||
tool_lines = [
|
||||
f"- {tool.name}: {tool.description}" for tool in run_input.tools
|
||||
]
|
||||
tools_block = "\n".join(tool_lines)
|
||||
return f"# AVAILABLE_TOOLS\n{tools_block}\n\n{base_prompt}"
|
||||
|
||||
def _select_tool_plan(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
available_tools: set[str],
|
||||
) -> dict[str, object] | None:
|
||||
forced = re.search(r"#tool:(\w+)\s*(\{.*\})?", user_input)
|
||||
if forced is not None:
|
||||
forced_name = forced.group(1)
|
||||
if forced_name not in available_tools:
|
||||
return None
|
||||
raw_args = forced.group(2)
|
||||
args: dict[str, object] = {}
|
||||
if raw_args:
|
||||
try:
|
||||
parsed = json.loads(raw_args)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except (TypeError, ValueError):
|
||||
args = {}
|
||||
target = (
|
||||
"frontend" if forced_name == "navigate_to_route" else "backend"
|
||||
)
|
||||
return {"name": forced_name, "args": args, "target": target}
|
||||
|
||||
normalized = user_input.lower()
|
||||
wants_navigation = any(
|
||||
keyword in normalized for keyword in ("打开", "跳转", "进入", "navigate", "open")
|
||||
)
|
||||
if wants_navigation and "navigate_to_route" in available_tools:
|
||||
target_route = "/calendar/dayweek"
|
||||
if "设置" in user_input:
|
||||
target_route = "/settings"
|
||||
elif "待办" in user_input:
|
||||
target_route = "/todo"
|
||||
return {
|
||||
"name": "navigate_to_route",
|
||||
"args": {"target": target_route, "replace": False},
|
||||
"target": "frontend",
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def _infer_calendar_args(self, user_input: str) -> dict[str, object]:
|
||||
start_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
title = user_input.strip()[:80] or "新的日程"
|
||||
return {
|
||||
"title": title,
|
||||
"description": user_input.strip(),
|
||||
"startAt": start_at.isoformat(),
|
||||
"timezone": "Asia/Shanghai",
|
||||
}
|
||||
|
||||
def _build_text_message_events(
|
||||
self, *, message_id: str, text: str
|
||||
) -> list[dict[str, object]]:
|
||||
events: list[dict[str, object]] = [
|
||||
TextMessageStartEvent(
|
||||
message_id=message_id,
|
||||
role="assistant",
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True),
|
||||
]
|
||||
if text:
|
||||
events.append(
|
||||
TextMessageContentEvent(
|
||||
message_id=message_id,
|
||||
delta=text,
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
)
|
||||
events.append(
|
||||
TextMessageEndEvent(
|
||||
message_id=message_id
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
)
|
||||
return events
|
||||
|
||||
def _build_tool_call_events(
|
||||
self,
|
||||
*,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, object],
|
||||
) -> list[dict[str, object]]:
|
||||
return [
|
||||
ToolCallStartEvent(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_call_name=tool_name,
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True),
|
||||
ToolCallArgsEvent(
|
||||
tool_call_id=tool_call_id,
|
||||
delta=json.dumps(tool_args, ensure_ascii=True, separators=(",", ":")),
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True),
|
||||
ToolCallEndEvent(
|
||||
tool_call_id=tool_call_id
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True),
|
||||
]
|
||||
|
||||
async def _execute_backend_tool(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, object],
|
||||
) -> dict[str, object]:
|
||||
if tool_name != "create_calendar_event":
|
||||
raise ValueError(f"unsupported backend tool: {tool_name}")
|
||||
title = str(tool_args.get("title", "新的日程")).strip() or "新的日程"
|
||||
description = str(tool_args.get("description", "")).strip() or None
|
||||
start_raw = tool_args.get("startAt")
|
||||
start_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
if isinstance(start_raw, str) and start_raw:
|
||||
try:
|
||||
parsed = datetime.fromisoformat(start_raw.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
start_at = parsed.astimezone(timezone.utc)
|
||||
except ValueError:
|
||||
start_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
end_raw = tool_args.get("endAt")
|
||||
end_at: datetime | None = None
|
||||
if isinstance(end_raw, str) and end_raw:
|
||||
try:
|
||||
parsed_end = datetime.fromisoformat(end_raw.replace("Z", "+00:00"))
|
||||
if parsed_end.tzinfo is None:
|
||||
parsed_end = parsed_end.replace(tzinfo=timezone.utc)
|
||||
end_at = parsed_end.astimezone(timezone.utc)
|
||||
except ValueError:
|
||||
end_at = None
|
||||
timezone_value = str(tool_args.get("timezone", "Asia/Shanghai"))
|
||||
location = tool_args.get("location")
|
||||
location_value = str(location) if isinstance(location, str) else None
|
||||
|
||||
schedule_item = ScheduleItem(
|
||||
owner_id=owner_id,
|
||||
title=title,
|
||||
description=description,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
timezone=timezone_value,
|
||||
extra_metadata={"location": location_value} if location_value else {},
|
||||
source_type=ScheduleItemSourceType.AGENT_GENERATED,
|
||||
status=ScheduleItemStatus.ACTIVE,
|
||||
created_by=owner_id,
|
||||
)
|
||||
session.add(schedule_item)
|
||||
await session.flush()
|
||||
|
||||
event_id = str(schedule_item.id)
|
||||
ui_card = {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {
|
||||
"id": event_id,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"startAt": start_at.isoformat(),
|
||||
"endAt": end_at.isoformat() if end_at is not None else None,
|
||||
"timezone": timezone_value,
|
||||
"location": location_value,
|
||||
"color": "#4F46E5",
|
||||
"sourceType": "agent_generated",
|
||||
},
|
||||
"actions": [
|
||||
{
|
||||
"type": "link",
|
||||
"label": "查看详情",
|
||||
"target": f"/calendar/events/{event_id}",
|
||||
}
|
||||
],
|
||||
}
|
||||
return {
|
||||
"result": {
|
||||
"eventId": event_id,
|
||||
"ok": True,
|
||||
"message": "日程已创建",
|
||||
"title": title,
|
||||
"description": description,
|
||||
"startAt": start_at.isoformat(),
|
||||
"endAt": end_at.isoformat() if end_at is not None else None,
|
||||
"timezone": timezone_value,
|
||||
"location": location_value,
|
||||
"sourceType": "agent_generated",
|
||||
},
|
||||
"ui": ui_card,
|
||||
}
|
||||
|
||||
async def _load_user_agent_context(
|
||||
|
||||
@@ -10,17 +10,35 @@ from core.agent.domain.state_snapshot import AgentStateSnapshot
|
||||
|
||||
class SessionStatePersistence:
|
||||
def build_running_snapshot(
|
||||
self, *, pending_tool_call_id: str | None
|
||||
self,
|
||||
*,
|
||||
pending_tool_call_id: str | None,
|
||||
pending_tool_name: str | None = None,
|
||||
pending_tool_args_sha256: str | None = None,
|
||||
pending_tool_nonce: str | None = None,
|
||||
) -> dict[str, object]:
|
||||
return AgentStateSnapshot(
|
||||
status="running",
|
||||
pending_tool_call_id=pending_tool_call_id,
|
||||
pending_tool_name=pending_tool_name,
|
||||
pending_tool_args_sha256=pending_tool_args_sha256,
|
||||
pending_tool_nonce=pending_tool_nonce,
|
||||
).model_dump()
|
||||
|
||||
def build_completed_snapshot(self) -> dict[str, object]:
|
||||
return AgentStateSnapshot(status="completed").model_dump()
|
||||
|
||||
|
||||
def compute_tool_args_sha256(tool_args: dict[str, object]) -> str:
|
||||
encoded = json.dumps(
|
||||
tool_args,
|
||||
ensure_ascii=True,
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
|
||||
class ToolResultStorage(Protocol):
|
||||
async def upload_json(
|
||||
self,
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from pydantic import ValidationError
|
||||
|
||||
MAX_RUN_INPUT_BYTES = 256_000
|
||||
MAX_RUN_ID_LENGTH = 128
|
||||
MAX_MESSAGES = 200
|
||||
MAX_TEXT_CHARS = 10_000
|
||||
|
||||
|
||||
def _safe_len(value: str | None) -> int:
|
||||
if value is None:
|
||||
return 0
|
||||
return len(value)
|
||||
|
||||
|
||||
def _user_text_chars(run_input: RunAgentInput) -> int:
|
||||
total = 0
|
||||
for message in run_input.messages:
|
||||
if getattr(message, "role", None) != "user":
|
||||
continue
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
total += len(content)
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if getattr(item, "type", None) != "text":
|
||||
continue
|
||||
text = getattr(item, "text", None)
|
||||
if isinstance(text, str):
|
||||
total += len(text)
|
||||
return total
|
||||
|
||||
|
||||
def parse_run_input(payload: dict[str, Any]) -> RunAgentInput:
|
||||
payload_bytes = len(
|
||||
json.dumps(payload, ensure_ascii=True, separators=(",", ":")).encode("utf-8")
|
||||
)
|
||||
if payload_bytes > MAX_RUN_INPUT_BYTES:
|
||||
raise ValueError("RunAgentInput payload exceeds size limit")
|
||||
try:
|
||||
run_input = RunAgentInput.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
raise ValueError("invalid AG-UI RunAgentInput payload") from exc
|
||||
try:
|
||||
UUID(run_input.thread_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError("threadId must be a valid UUID") from exc
|
||||
if _safe_len(run_input.run_id) > MAX_RUN_ID_LENGTH:
|
||||
raise ValueError("runId exceeds length limit")
|
||||
if len(run_input.messages) > MAX_MESSAGES:
|
||||
raise ValueError("RunAgentInput.messages exceeds limit")
|
||||
if _user_text_chars(run_input) > MAX_TEXT_CHARS:
|
||||
raise ValueError("RunAgentInput user message text exceeds limit")
|
||||
return run_input
|
||||
|
||||
|
||||
def extract_latest_user_text(run_input: RunAgentInput) -> str:
|
||||
for message in reversed(run_input.messages):
|
||||
role = getattr(message, "role", None)
|
||||
if role != "user":
|
||||
continue
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
text = content.strip()
|
||||
if text:
|
||||
return text
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for item in content:
|
||||
if getattr(item, "type", None) != "text":
|
||||
continue
|
||||
text = getattr(item, "text", None)
|
||||
if isinstance(text, str):
|
||||
text_parts.append(text)
|
||||
combined = "".join(text_parts).strip()
|
||||
if combined:
|
||||
return combined
|
||||
raise ValueError("RunAgentInput.messages requires at least one non-empty user message")
|
||||
|
||||
|
||||
def extract_latest_tool_result(run_input: RunAgentInput) -> tuple[str, dict[str, object]]:
|
||||
for message in reversed(run_input.messages):
|
||||
role = getattr(message, "role", None)
|
||||
if role != "tool":
|
||||
continue
|
||||
tool_call_id = getattr(message, "tool_call_id", None)
|
||||
content = getattr(message, "content", None)
|
||||
if not isinstance(tool_call_id, str) or not tool_call_id:
|
||||
continue
|
||||
if not isinstance(content, str):
|
||||
break
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
except (TypeError, ValueError):
|
||||
return tool_call_id, {"content": content}
|
||||
if isinstance(parsed, dict):
|
||||
return tool_call_id, parsed
|
||||
return tool_call_id, {"content": content}
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires a tool message with toolCallId for resume"
|
||||
)
|
||||
@@ -8,3 +8,6 @@ from pydantic import BaseModel
|
||||
class AgentStateSnapshot(BaseModel):
|
||||
status: Literal["pending", "running", "completed", "failed"]
|
||||
pending_tool_call_id: str | None = None
|
||||
pending_tool_name: str | None = None
|
||||
pending_tool_args_sha256: str | None = None
|
||||
pending_tool_nonce: str | None = None
|
||||
|
||||
@@ -6,3 +6,4 @@ from pydantic import BaseModel, Field
|
||||
class SystemAgentLLMConfig(BaseModel):
|
||||
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
|
||||
max_tokens: int | None = Field(default=None, ge=1)
|
||||
timeout_seconds: float | None = Field(default=30.0, gt=0.0, le=300.0)
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
_EVENT_TYPE_RE = re.compile(r"^[A-Z0-9_]+$")
|
||||
|
||||
|
||||
def to_sse_event(stream_id: str, event: dict[str, Any]) -> str:
|
||||
event_type = str(event.get("type", "MESSAGE"))
|
||||
payload = json.dumps(event.get("data", {}), ensure_ascii=True)
|
||||
raw_event_type = str(event.get("type", "MESSAGE")).replace("\r", "").replace(
|
||||
"\n", ""
|
||||
)
|
||||
event_type = raw_event_type if _EVENT_TYPE_RE.fullmatch(raw_event_type) else "MESSAGE"
|
||||
payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
|
||||
return f"id: {stream_id}\nevent: {event_type}\ndata: {payload}\n\n"
|
||||
|
||||
@@ -129,6 +129,7 @@ def _run_stage(
|
||||
messages=messages,
|
||||
temperature=llm_config.temperature,
|
||||
max_tokens=llm_config.max_tokens,
|
||||
timeout=llm_config.timeout_seconds,
|
||||
)
|
||||
if not isinstance(response, dict):
|
||||
raise ValueError("llm response must be a dict")
|
||||
|
||||
@@ -46,7 +46,7 @@ class RedisStreamEventStore:
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
stream = self._stream_name(session_id)
|
||||
start_id = "$" if last_event_id is None else last_event_id
|
||||
start_id = "0-0" if last_event_id is None else last_event_id
|
||||
raw_response = self._client.xread(
|
||||
{stream: start_id},
|
||||
count=self._read_count,
|
||||
@@ -59,13 +59,37 @@ class RedisStreamEventStore:
|
||||
if not response:
|
||||
return []
|
||||
|
||||
_, entries = response[0]
|
||||
first = response[0]
|
||||
if (
|
||||
not isinstance(first, tuple)
|
||||
or len(first) != 2
|
||||
or not isinstance(first[1], list)
|
||||
):
|
||||
return []
|
||||
_, entries = first
|
||||
result: list[dict[str, Any]] = []
|
||||
for stream_id, payload in entries:
|
||||
for entry in entries:
|
||||
if (
|
||||
not isinstance(entry, tuple)
|
||||
or len(entry) != 2
|
||||
or not isinstance(entry[0], str)
|
||||
or not isinstance(entry[1], dict)
|
||||
):
|
||||
continue
|
||||
stream_id, payload = entry
|
||||
event_payload = payload.get("event")
|
||||
if not isinstance(event_payload, str):
|
||||
continue
|
||||
try:
|
||||
parsed_event = json.loads(event_payload)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if not isinstance(parsed_event, dict):
|
||||
continue
|
||||
result.append(
|
||||
{
|
||||
"id": stream_id,
|
||||
"event": json.loads(payload["event"]),
|
||||
"event": parsed_event,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -12,6 +12,7 @@ def run_completion(
|
||||
messages: list[dict[str, Any]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> Any:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
@@ -23,6 +24,8 @@ def run_completion(
|
||||
kwargs["temperature"] = temperature
|
||||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
if timeout is not None:
|
||||
kwargs["timeout"] = timeout
|
||||
|
||||
response = completion(**kwargs)
|
||||
model_dump = getattr(response, "model_dump", None)
|
||||
|
||||
@@ -9,6 +9,9 @@ import redis.asyncio as redis
|
||||
|
||||
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
|
||||
logger = get_logger("core.agent.infrastructure.persistence.user_context_cache")
|
||||
|
||||
|
||||
class RedisHashClient(Protocol):
|
||||
@@ -47,7 +50,12 @@ class UserContextCache:
|
||||
key = self._key(session_id)
|
||||
try:
|
||||
raw = await _maybe_await(self._client.hgetall(key))
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to read user context cache",
|
||||
session_id=str(session_id),
|
||||
error=str(exc),
|
||||
)
|
||||
return None
|
||||
|
||||
if not isinstance(raw, dict) or not raw:
|
||||
@@ -92,7 +100,12 @@ class UserContextCache:
|
||||
)
|
||||
)
|
||||
await _maybe_await(self._client.expire(key, self._ttl_seconds))
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to write user context cache",
|
||||
session_id=str(session_id),
|
||||
error=str(exc),
|
||||
)
|
||||
return None
|
||||
|
||||
def _key(self, session_id: UUID) -> str:
|
||||
@@ -136,13 +149,21 @@ class UserContextCache:
|
||||
async def _safe_delete(self, key: str) -> None:
|
||||
try:
|
||||
await _maybe_await(self._client.delete(key))
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to delete user context cache key", key=key, error=str(exc))
|
||||
return None
|
||||
|
||||
async def _safe_hincrby(self, key: str, field: str, amount: int) -> None:
|
||||
try:
|
||||
await _maybe_await(self._client.hincrby(key, field, amount))
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to update user context cache usage",
|
||||
key=key,
|
||||
field=field,
|
||||
amount=amount,
|
||||
error=str(exc),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
import re
|
||||
|
||||
from ag_ui.core import RunAgentInput, RunErrorEvent, RunFinishedEvent, RunStartedEvent
|
||||
from core.agent.domain.agui_input import parse_run_input
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
@@ -13,19 +16,65 @@ from services.base.redis import get_or_init_redis_client
|
||||
|
||||
logger = get_logger("core.agent.infrastructure.queue.tasks")
|
||||
|
||||
_NON_ALNUM_RE = re.compile(r"[^A-Za-z0-9]+")
|
||||
_SENSITIVE_KEYS = {
|
||||
"apikey",
|
||||
"authorization",
|
||||
"token",
|
||||
"accesstoken",
|
||||
"refreshtoken",
|
||||
"secret",
|
||||
"password",
|
||||
"cookie",
|
||||
}
|
||||
|
||||
|
||||
class PublishEvent(Protocol):
|
||||
async def __call__(self, event_type: str, payload: dict[str, object]) -> None: ...
|
||||
async def __call__(self, event: dict[str, object]) -> None: ...
|
||||
|
||||
|
||||
class RunServiceLike(Protocol):
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]: ...
|
||||
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]: ...
|
||||
|
||||
|
||||
class ResumeServiceLike(Protocol):
|
||||
async def resume(
|
||||
self, *, session_id: str, tool_call_id: str
|
||||
) -> dict[str, object]: ...
|
||||
async def resume(self, *, run_input: RunAgentInput) -> dict[str, object]: ...
|
||||
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
normalized = _NON_ALNUM_RE.sub("", key.lower())
|
||||
if normalized in _SENSITIVE_KEYS:
|
||||
return True
|
||||
if "token" in normalized:
|
||||
return True
|
||||
if "api" in normalized and "key" in normalized:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _redact_sensitive(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
k: "***REDACTED***" if _is_sensitive_key(str(k)) else _redact_sensitive(v)
|
||||
for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_redact_sensitive(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _normalize_stream_event(
|
||||
*,
|
||||
event: dict[str, object],
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
) -> dict[str, object]:
|
||||
normalized = dict(event)
|
||||
normalized["threadId"] = thread_id
|
||||
normalized["runId"] = run_id
|
||||
if normalized.get("type") == "RUN_STARTED":
|
||||
normalized.pop("input", None)
|
||||
return _redact_sensitive(normalized)
|
||||
|
||||
|
||||
async def _build_redis_publisher() -> PublishEvent:
|
||||
@@ -37,13 +86,13 @@ async def _build_redis_publisher() -> PublishEvent:
|
||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
session_id = str(payload.get("session_id", "")).strip()
|
||||
if not session_id:
|
||||
raise ValueError("session_id is required in event payload")
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
thread_id = str(event.get("threadId", "")).strip()
|
||||
if not thread_id:
|
||||
raise ValueError("threadId is required in event payload")
|
||||
await event_store.append_event(
|
||||
session_id=UUID(session_id),
|
||||
event={"type": event_type, "data": payload},
|
||||
session_id=UUID(thread_id),
|
||||
event=event,
|
||||
)
|
||||
|
||||
return _publish
|
||||
@@ -61,69 +110,69 @@ async def run_agent_task(
|
||||
service_resume = resume_service or ResumeService()
|
||||
|
||||
command_type = str(command.get("command", "run"))
|
||||
session_id = str(command.get("session_id", ""))
|
||||
|
||||
if command_type not in {"run", "resume"}:
|
||||
raise ValueError("invalid command type")
|
||||
if not session_id:
|
||||
raise ValueError("session_id is required")
|
||||
UUID(session_id)
|
||||
raw_run_input = command.get("run_input")
|
||||
if not isinstance(raw_run_input, dict):
|
||||
raise ValueError("run_input is required")
|
||||
run_input = parse_run_input(raw_run_input)
|
||||
UUID(run_input.thread_id)
|
||||
|
||||
tool_call_id = ""
|
||||
user_input = ""
|
||||
if command_type == "resume":
|
||||
tool_call_id = str(command.get("tool_call_id", ""))
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required")
|
||||
else:
|
||||
user_input = str(command.get("user_input", ""))
|
||||
if not user_input:
|
||||
raise ValueError("user_input is required")
|
||||
|
||||
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
|
||||
await publisher(start_event, {"session_id": session_id})
|
||||
await publisher(
|
||||
RunStartedEvent(
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
parent_run_id=run_input.parent_run_id,
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
)
|
||||
|
||||
try:
|
||||
if command_type == "resume":
|
||||
result = await service_resume.resume(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
result = await service_resume.resume(run_input=run_input)
|
||||
else:
|
||||
result = await service_run.run(
|
||||
session_id=session_id,
|
||||
user_input=user_input,
|
||||
)
|
||||
result = await service_run.run(run_input=run_input)
|
||||
|
||||
await publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
|
||||
extra_events = result.get("events") if isinstance(result, dict) else None
|
||||
if isinstance(extra_events, list):
|
||||
for event in extra_events:
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
event_type = event.get("type")
|
||||
event_data = event.get("data")
|
||||
if not isinstance(event_type, str) or not isinstance(event_data, dict):
|
||||
if not isinstance(event_type, str):
|
||||
continue
|
||||
payload = {"session_id": session_id, **event_data}
|
||||
await publisher(event_type, payload)
|
||||
await publisher("RUN_FINISHED", {"session_id": session_id})
|
||||
await publisher(
|
||||
_normalize_stream_event(
|
||||
event=event,
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
)
|
||||
)
|
||||
await publisher(
|
||||
RunFinishedEvent(
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
)
|
||||
return result
|
||||
except Exception: # noqa: BLE001
|
||||
error_id = "agent_runtime_failed"
|
||||
logger.exception(
|
||||
"Agent task failed",
|
||||
session_id=session_id,
|
||||
thread_id=run_input.thread_id,
|
||||
error_id=error_id,
|
||||
)
|
||||
try:
|
||||
await publisher(
|
||||
"RUN_ERROR", {"session_id": session_id, "error_id": error_id}
|
||||
)
|
||||
error_event = RunErrorEvent(
|
||||
message="Agent task failed",
|
||||
code=error_id,
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
error_event["threadId"] = run_input.thread_id
|
||||
error_event["runId"] = run_input.run_id
|
||||
await publisher(error_event)
|
||||
except Exception as publish_exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Failed to publish RUN_ERROR event",
|
||||
session_id=session_id,
|
||||
thread_id=run_input.thread_id,
|
||||
error=str(publish_exc),
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
|
||||
|
||||
@@ -27,13 +30,22 @@ class AgentRepository:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return str(owner_id)
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str:
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
|
||||
session_uuid = None
|
||||
if session_id is not None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
|
||||
|
||||
session = AgentChatSession(
|
||||
id=session_uuid,
|
||||
user_id=user_uuid,
|
||||
)
|
||||
self._session.add(session)
|
||||
@@ -56,3 +68,114 @@ class AgentRepository:
|
||||
if session is not None:
|
||||
await self._session.delete(session)
|
||||
await self._session.flush()
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
|
||||
|
||||
timestamp_stmt = (
|
||||
select(AgentChatMessage.created_at)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.order_by(AgentChatMessage.created_at.desc())
|
||||
)
|
||||
rows = (await self._session.execute(timestamp_stmt)).scalars().all()
|
||||
unique_days: list[date] = []
|
||||
for created_at in rows:
|
||||
if created_at is None:
|
||||
continue
|
||||
day = created_at.astimezone(timezone.utc).date()
|
||||
if day not in unique_days:
|
||||
unique_days.append(day)
|
||||
|
||||
if not unique_days:
|
||||
return None
|
||||
|
||||
target_day: date | None = None
|
||||
if before is None:
|
||||
target_day = unique_days[0]
|
||||
else:
|
||||
for day in unique_days:
|
||||
if day < before:
|
||||
target_day = day
|
||||
break
|
||||
if target_day is None:
|
||||
return None
|
||||
|
||||
start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
|
||||
end = start + timedelta(days=1)
|
||||
message_stmt = (
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.created_at >= start)
|
||||
.where(AgentChatMessage.created_at < end)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = (await self._session.execute(message_stmt)).scalars().all()
|
||||
has_more = any(day < target_day for day in unique_days)
|
||||
return {
|
||||
"day": target_day.isoformat(),
|
||||
"hasMore": has_more,
|
||||
"messages": [self._to_snapshot_message(msg) for msg in messages],
|
||||
}
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == user_uuid)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if latest_id is None:
|
||||
return None
|
||||
return str(latest_id)
|
||||
|
||||
@staticmethod
|
||||
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]:
|
||||
role = (
|
||||
message.role.value
|
||||
if isinstance(message.role, AgentChatMessageRole)
|
||||
else str(message.role)
|
||||
)
|
||||
payload: dict[str, object] = {
|
||||
"id": str(message.id),
|
||||
"role": role,
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
if role == AgentChatMessageRole.TOOL.value:
|
||||
metadata = message.metadata_json or {}
|
||||
tool_call_id = metadata.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
payload["toolCallId"] = tool_call_id
|
||||
|
||||
parsed_content: dict[str, object] | None = None
|
||||
try:
|
||||
decoded = json.loads(message.content)
|
||||
if isinstance(decoded, dict):
|
||||
parsed_content = decoded
|
||||
except (TypeError, ValueError):
|
||||
parsed_content = None
|
||||
if parsed_content is not None:
|
||||
ui = parsed_content.get("ui")
|
||||
if isinstance(ui, dict):
|
||||
payload["ui"] = ui
|
||||
display_content = parsed_content.get("content")
|
||||
if isinstance(display_content, str):
|
||||
payload["content"] = display_content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
return payload
|
||||
|
||||
+141
-32
@@ -2,96 +2,177 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
import asyncio
|
||||
from datetime import date
|
||||
import re
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request, status
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from core.agent.infrastructure.agui.stream import to_sse_event
|
||||
from core.agent.domain.agui_input import parse_run_input
|
||||
from core.auth.models import CurrentUser
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
|
||||
from v1.agent.schemas import TaskAcceptedResponse
|
||||
from v1.agent.service import AgentService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
||||
_RUNS_PER_MINUTE = 30
|
||||
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
||||
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
|
||||
|
||||
async def _allow_run_request(*, user_id: str) -> bool:
|
||||
try:
|
||||
redis = await get_or_init_redis_client()
|
||||
minute_bucket = int(time.time() // 60)
|
||||
key = f"agent:run-rate:{user_id}:{minute_bucket}"
|
||||
count = await redis.incr(key)
|
||||
if count == 1:
|
||||
await redis.expire(key, 70)
|
||||
return int(count) <= _RUNS_PER_MINUTE
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
async def _acquire_sse_slot(*, user_id: str) -> bool:
|
||||
try:
|
||||
redis = await get_or_init_redis_client()
|
||||
key = f"agent:sse-active:{user_id}"
|
||||
count = await redis.incr(key)
|
||||
if count == 1:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
|
||||
await redis.decr(key)
|
||||
return False
|
||||
return True
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
async def _release_sse_slot(*, user_id: str) -> None:
|
||||
try:
|
||||
redis = await get_or_init_redis_client()
|
||||
key = f"agent:sse-active:{user_id}"
|
||||
count = await redis.decr(key)
|
||||
if int(count) <= 0:
|
||||
await redis.delete(key)
|
||||
except Exception: # noqa: BLE001
|
||||
return None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
async def enqueue_run(
|
||||
request: RunRequest,
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
|
||||
task = await service.enqueue_run(
|
||||
session_id=request.session_id,
|
||||
prompt=request.prompt,
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
task_id=task.task_id,
|
||||
session_id=task.session_id,
|
||||
thread_id=task.thread_id,
|
||||
run_id=task.run_id,
|
||||
created=task.created,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs/{session_id}/resume",
|
||||
"/runs/{thread_id}/resume",
|
||||
response_model=TaskAcceptedResponse,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
)
|
||||
async def enqueue_resume(
|
||||
session_id: str,
|
||||
request: ResumeRequest,
|
||||
thread_id: str,
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> TaskAcceptedResponse:
|
||||
if request.thread_id != thread_id:
|
||||
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
|
||||
try:
|
||||
parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
task = await service.enqueue_resume(
|
||||
session_id=session_id,
|
||||
tool_call_id=request.tool_call_id,
|
||||
thread_id=thread_id,
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
task_id=task.task_id,
|
||||
session_id=task.session_id,
|
||||
thread_id=task.thread_id,
|
||||
run_id=task.run_id,
|
||||
created=task.created,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{session_id}/events")
|
||||
@router.get("/runs/{thread_id}/events")
|
||||
async def stream_events(
|
||||
request: Request,
|
||||
session_id: str,
|
||||
thread_id: str,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
|
||||
idle_limit: int = Query(default=300, ge=1, le=3600),
|
||||
) -> StreamingResponse:
|
||||
if (
|
||||
last_event_id is not None
|
||||
and (
|
||||
len(last_event_id) > 32
|
||||
or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
|
||||
)
|
||||
):
|
||||
raise HTTPException(status_code=422, detail="Invalid Last-Event-ID")
|
||||
|
||||
sse_slot_acquired = await _acquire_sse_slot(user_id=str(current_user.id))
|
||||
if not sse_slot_acquired:
|
||||
raise HTTPException(status_code=429, detail="Too many SSE connections")
|
||||
|
||||
async def _event_iter() -> AsyncIterator[str]:
|
||||
cursor = last_event_id
|
||||
idle_polls = 0
|
||||
while not await request.is_disconnected() and idle_polls < idle_limit:
|
||||
rows = await service.stream_events(
|
||||
session_id=session_id,
|
||||
last_event_id=cursor,
|
||||
current_user=current_user,
|
||||
)
|
||||
if not rows:
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
|
||||
idle_polls = 0
|
||||
for row in rows:
|
||||
row_id = str(row.get("id", ""))
|
||||
event = row.get("event")
|
||||
if not row_id or not isinstance(event, dict):
|
||||
try:
|
||||
while not await request.is_disconnected() and idle_polls < idle_limit:
|
||||
rows = await service.stream_events(
|
||||
thread_id=thread_id,
|
||||
last_event_id=cursor,
|
||||
current_user=current_user,
|
||||
)
|
||||
if not rows:
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
cursor = row_id
|
||||
yield to_sse_event(row_id, event)
|
||||
|
||||
idle_polls = 0
|
||||
for row in rows:
|
||||
row_id = str(row.get("id", ""))
|
||||
event = row.get("event")
|
||||
if not row_id or not isinstance(event, dict):
|
||||
continue
|
||||
cursor = row_id
|
||||
yield to_sse_event(row_id, event)
|
||||
finally:
|
||||
await _release_sse_slot(user_id=str(current_user.id))
|
||||
|
||||
return StreamingResponse(
|
||||
_event_iter(),
|
||||
@@ -102,3 +183,31 @@ async def stream_events(
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/history")
|
||||
async def get_history_snapshot(
|
||||
thread_id: str,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
before: date | None = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
return await service.get_history_snapshot(
|
||||
thread_id=thread_id,
|
||||
before=before,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_user_history_snapshot(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
thread_id: str | None = Query(default=None, alias="threadId"),
|
||||
before: date | None = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
return await service.get_user_history_snapshot(
|
||||
current_user=current_user,
|
||||
thread_id=thread_id,
|
||||
before=before,
|
||||
)
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
session_id: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
prompt: str = Field(min_length=1, max_length=5000)
|
||||
|
||||
|
||||
class ResumeRequest(BaseModel):
|
||||
tool_call_id: str = Field(min_length=1, max_length=200)
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
task_id: str
|
||||
session_id: str
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
task_id: str = Field(alias="taskId")
|
||||
thread_id: str = Field(alias="threadId")
|
||||
run_id: str = Field(alias="runId")
|
||||
created: bool
|
||||
|
||||
+109
-29
@@ -1,9 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Protocol
|
||||
|
||||
from ag_ui.core import StateSnapshotEvent
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
|
||||
@@ -11,19 +15,28 @@ from core.auth.models import CurrentUser
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
session_id: str
|
||||
thread_id: str
|
||||
run_id: str
|
||||
created: bool
|
||||
|
||||
|
||||
class AgentRepositoryLike(Protocol):
|
||||
async def get_session_owner(self, *, session_id: str) -> str: ...
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str: ...
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
@@ -60,73 +73,140 @@ class AgentService:
|
||||
async def enqueue_run(
|
||||
self,
|
||||
*,
|
||||
session_id: str | None,
|
||||
prompt: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
target_session_id = session_id
|
||||
if target_session_id is None:
|
||||
target_session_id = await self._repository.create_session_for_user(
|
||||
user_id=str(current_user.id)
|
||||
)
|
||||
created = True
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 404:
|
||||
raise
|
||||
try:
|
||||
await self._repository.create_session_for_user(
|
||||
user_id=str(current_user.id),
|
||||
session_id=thread_id,
|
||||
)
|
||||
await self._repository.commit()
|
||||
created = True
|
||||
except IntegrityError:
|
||||
await self._repository.rollback()
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
else:
|
||||
owner = await self._repository.get_session_owner(
|
||||
session_id=target_session_id
|
||||
)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
if created:
|
||||
await self._repository.commit()
|
||||
|
||||
try:
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"session_id": target_session_id,
|
||||
"user_input": prompt,
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
raise
|
||||
return TaskAccepted(
|
||||
task_id=task_id, session_id=target_session_id, created=created
|
||||
task_id=task_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> TaskAccepted:
|
||||
owner = await self._repository.get_session_owner(session_id=session_id)
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
dedup_key = f"resume:{session_id}:{tool_call_id}"
|
||||
dedup_key = f"resume:{thread_id}:{run_input.run_id}"
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": session_id,
|
||||
"tool_call_id": tool_call_id,
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
return TaskAccepted(task_id=task_id, session_id=session_id, created=False)
|
||||
return TaskAccepted(
|
||||
task_id=task_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
owner = await self._repository.get_session_owner(session_id=session_id)
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
return await self._stream.read(
|
||||
session_id=session_id,
|
||||
session_id=thread_id,
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
|
||||
async def get_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
before: date | None,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, object]:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
)
|
||||
snapshot = {
|
||||
"scope": "history_day",
|
||||
"threadId": thread_id,
|
||||
"day": day_payload["day"] if day_payload else None,
|
||||
"hasMore": day_payload["hasMore"] if day_payload else False,
|
||||
"messages": day_payload["messages"] if day_payload else [],
|
||||
}
|
||||
event = StateSnapshotEvent(snapshot=snapshot).model_dump(
|
||||
mode="json",
|
||||
by_alias=True,
|
||||
exclude_none=True,
|
||||
)
|
||||
event["threadId"] = thread_id
|
||||
return event
|
||||
|
||||
async def get_user_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
current_user: CurrentUser,
|
||||
thread_id: str | None,
|
||||
before: date | None,
|
||||
) -> dict[str, object]:
|
||||
target_thread_id = thread_id
|
||||
if target_thread_id is None:
|
||||
target_thread_id = await self._repository.get_latest_session_id_for_user(
|
||||
user_id=str(current_user.id)
|
||||
)
|
||||
if target_thread_id is None:
|
||||
return StateSnapshotEvent(
|
||||
snapshot={
|
||||
"scope": "history_day",
|
||||
"threadId": None,
|
||||
"day": None,
|
||||
"hasMore": False,
|
||||
"messages": [],
|
||||
}
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
return await self.get_history_snapshot(
|
||||
thread_id=target_thread_id,
|
||||
before=before,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
@@ -84,28 +86,76 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
|
||||
published: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
published.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
published.append(event_type)
|
||||
|
||||
try:
|
||||
run_result = run_agent_task(
|
||||
run_input_payload = {
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "帮我打开日历"},
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "navigate_to_route",
|
||||
"description": "navigate route",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
run_result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": str(session_uuid),
|
||||
"user_input": "hello",
|
||||
"run_input": run_input_payload,
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
pending_tool_call_id = str(run_result["pending_tool_call_id"])
|
||||
state_snapshot = run_result["state_snapshot"]
|
||||
assert isinstance(state_snapshot, dict)
|
||||
pending_tool_nonce = state_snapshot["pending_tool_nonce"]
|
||||
assert isinstance(pending_tool_nonce, str)
|
||||
|
||||
run_agent_task(
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": str(session_uuid),
|
||||
"tool_call_id": pending_tool_call_id,
|
||||
"run_input": {
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-2",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"role": "tool",
|
||||
"toolCallId": pending_tool_call_id,
|
||||
"content": json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": pending_tool_nonce,
|
||||
},
|
||||
"nonce": pending_tool_nonce,
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=RunService(),
|
||||
@@ -123,6 +173,9 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
assert db_session.state_snapshot == {
|
||||
"status": "completed",
|
||||
"pending_tool_call_id": None,
|
||||
"pending_tool_name": None,
|
||||
"pending_tool_args_sha256": None,
|
||||
"pending_tool_nonce": None,
|
||||
}
|
||||
|
||||
rows = await verify_session.execute(
|
||||
@@ -142,7 +195,7 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
assert messages[1].cost == Decimal("0.002500")
|
||||
|
||||
assert "RUN_STARTED" in published
|
||||
assert "RUN_RESUMED" in published
|
||||
assert "RUN_FINISHED" in published
|
||||
assert "TEXT_MESSAGE_CONTENT" in published
|
||||
finally:
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
@@ -219,7 +272,21 @@ async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.commit()
|
||||
|
||||
result = await RunService().run(session_id=str(session_uuid), user_input="hello")
|
||||
result = await RunService().run(
|
||||
run_input=RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": str(session_uuid),
|
||||
"runId": "run-it-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "hello"},
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
assert result["persisted"] is True
|
||||
assert captured["user_input"] == "hello"
|
||||
|
||||
@@ -16,29 +16,38 @@ class _FakeStorage:
|
||||
return "etag-1"
|
||||
|
||||
|
||||
def test_closed_loop_run_flow_frontend_to_sse() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
async def test_closed_loop_run_flow_frontend_to_sse() -> None:
|
||||
thread_id = "00000000-0000-0000-0000-000000000001"
|
||||
published: list[str] = []
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
async def run(self, *, run_input: object) -> dict[str, object]:
|
||||
del run_input
|
||||
return {"threadId": thread_id, "runId": "run-1"}
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
published.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
published.append(event_type)
|
||||
|
||||
result = run_agent_task(
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
"run_input": {
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert result["threadId"] == thread_id
|
||||
assert published[0] == "RUN_STARTED"
|
||||
assert published[-1] == "RUN_FINISHED"
|
||||
|
||||
|
||||
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent import router as agent_router
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
@@ -16,52 +18,122 @@ class _FakeAgentService:
|
||||
self._stream_called = False
|
||||
|
||||
async def enqueue_run(
|
||||
self, *, session_id: str | None, prompt: str, current_user: CurrentUser
|
||||
self, *, run_input: RunAgentInput, current_user: CurrentUser
|
||||
):
|
||||
del prompt, current_user
|
||||
resolved_session = session_id or "auto-created-session"
|
||||
del current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-run-1",
|
||||
session_id=resolved_session,
|
||||
created=session_id is None,
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
):
|
||||
del tool_call_id, current_user
|
||||
del thread_id, current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-resume-1", session_id=session_id, created=False
|
||||
task_id="task-resume-1",
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
del session_id, current_user
|
||||
del thread_id, current_user
|
||||
if self._stream_called:
|
||||
return []
|
||||
self._stream_called = True
|
||||
return [
|
||||
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
|
||||
{
|
||||
"id": "2-0",
|
||||
"event": {
|
||||
"type": "RUN_STARTED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
},
|
||||
"cursor": last_event_id,
|
||||
}
|
||||
]
|
||||
|
||||
async def get_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
before: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, object]:
|
||||
del current_user
|
||||
return {
|
||||
"type": "STATE_SNAPSHOT",
|
||||
"threadId": thread_id,
|
||||
"snapshot": {
|
||||
"scope": "history_day",
|
||||
"day": before or "2026-03-07",
|
||||
"hasMore": False,
|
||||
"messages": [
|
||||
{
|
||||
"id": "msg-h1",
|
||||
"role": "assistant",
|
||||
"content": "history-message",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
async def get_user_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
current_user: CurrentUser,
|
||||
thread_id: str | None,
|
||||
before: str | None,
|
||||
) -> dict[str, object]:
|
||||
del current_user, before
|
||||
return {
|
||||
"type": "STATE_SNAPSHOT",
|
||||
"threadId": thread_id or "00000000-0000-0000-0000-000000000001",
|
||||
"snapshot": {
|
||||
"scope": "history_day",
|
||||
"day": "2026-03-07",
|
||||
"hasMore": False,
|
||||
"messages": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
original_allow_run = agent_router._allow_run_request
|
||||
|
||||
async def _allow_run(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
agent_router._allow_run_request = _allow_run # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
unauthorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
@@ -70,20 +142,23 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
)
|
||||
authorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert authorized.status_code == 202
|
||||
assert authorized.json()["task_id"] == "task-run-1"
|
||||
assert authorized.json()["taskId"] == "task-run-1"
|
||||
assert authorized.json()["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert authorized.json()["runId"] == "run-1"
|
||||
assert authorized.json()["created"] is False
|
||||
|
||||
first_chat = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"prompt": "hello"},
|
||||
)
|
||||
assert first_chat.status_code == 202
|
||||
assert first_chat.json()["session_id"] == "auto-created-session"
|
||||
assert first_chat.json()["created"] is True
|
||||
finally:
|
||||
agent_router._allow_run_request = original_allow_run # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
@@ -93,15 +168,122 @@ def test_stream_reads_from_last_event_id() -> None:
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
original_acquire = agent_router._acquire_sse_slot
|
||||
original_release = agent_router._release_sse_slot
|
||||
|
||||
async def _allow_slot(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _noop_release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
return None
|
||||
|
||||
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/session-1/events?idle_limit=1",
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=1",
|
||||
headers={"Last-Event-ID": "1-0"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert "id: 2-0" in response.text
|
||||
assert "event: RUN_STARTED" in response.text
|
||||
assert '"threadId":"00000000-0000-0000-0000-000000000001"' in response.text
|
||||
finally:
|
||||
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = original_release # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_rejects_invalid_last_event_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events",
|
||||
headers={"Last-Event-ID": "bad-id"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_history_returns_state_snapshot() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
unauthorized = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/history"
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
authorized = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/history",
|
||||
params={"before": "2026-03-07"},
|
||||
)
|
||||
assert authorized.status_code == 200
|
||||
payload = authorized.json()
|
||||
assert payload["type"] == "STATE_SNAPSHOT"
|
||||
assert payload["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert payload["snapshot"]["scope"] == "history_day"
|
||||
assert payload["snapshot"]["day"] == "2026-03-07"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_user_history_returns_latest_snapshot() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/agent/history")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["type"] == "STATE_SNAPSHOT"
|
||||
assert body["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_run_rejects_oversized_user_text_payload() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-oversize",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": "x" * 11000,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
@@ -56,15 +56,25 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
run_resp = await client.post(
|
||||
f"{BASE_URL}/api/v1/agent/runs",
|
||||
headers=headers,
|
||||
json={"prompt": "请用一句话介绍你自己"},
|
||||
json={
|
||||
"threadId": str(uuid4()),
|
||||
"runId": "run-live-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "u1", "role": "user", "content": "请用一句话介绍你自己"}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
accepted = run_resp.json()
|
||||
session_id = str(accepted["session_id"])
|
||||
assert session_id
|
||||
thread_id = str(accepted["threadId"])
|
||||
assert thread_id
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
||||
event_names: list[str] = []
|
||||
async with client.stream("GET", events_url, headers=headers, timeout=20.0) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
@@ -77,13 +87,13 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
session_row = await session.get(AgentChatSession, UUID(session_id))
|
||||
session_row = await session.get(AgentChatSession, UUID(thread_id))
|
||||
assert session_row is not None
|
||||
assert session_row.message_count >= 1
|
||||
assert session_row.total_tokens >= 0
|
||||
assert session_row.total_cost >= 0
|
||||
|
||||
rows = await session.execute(
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(session_id))
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(thread_id))
|
||||
)
|
||||
assert len(list(rows.scalars().all())) >= 1
|
||||
|
||||
@@ -132,7 +132,9 @@ def test_bridge_rejects_unknown_event_type() -> None:
|
||||
|
||||
def test_sse_format_includes_id_event_data() -> None:
|
||||
payload = to_sse_event(
|
||||
stream_id="1-0", event={"type": "RUN_STARTED", "data": {"a": 1}}
|
||||
stream_id="1-0",
|
||||
event={"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"},
|
||||
)
|
||||
|
||||
assert payload.startswith("id: 1-0\nevent: RUN_STARTED\ndata: {")
|
||||
assert '"threadId":"t1"' in payload
|
||||
|
||||
@@ -56,12 +56,14 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
captured["model"] = model
|
||||
captured["api_key"] = api_key
|
||||
captured["messages"] = messages
|
||||
captured["temperature"] = temperature
|
||||
captured["max_tokens"] = max_tokens
|
||||
captured["timeout"] = timeout
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
@@ -113,6 +115,7 @@ def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
assert captured["api_key"] == "env-api-key"
|
||||
assert captured["temperature"] == 0.3
|
||||
assert captured["max_tokens"] == 256
|
||||
assert captured["timeout"] == 30.0
|
||||
assert result["assistant_text"] == "hello"
|
||||
|
||||
|
||||
@@ -128,6 +131,7 @@ def test_runtime_execute_injects_system_prompt_and_intent_template(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
captured["messages"] = messages
|
||||
return {
|
||||
@@ -219,6 +223,7 @@ def test_runtime_execute_short_circuits_on_direct_execution(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
@@ -331,6 +336,7 @@ def test_runtime_execute_runs_execution_and_organization_stages(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
@@ -383,6 +389,7 @@ def test_runtime_execute_rejects_invalid_intent_json(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, messages, temperature, max_tokens
|
||||
return {
|
||||
@@ -506,6 +513,7 @@ def test_runtime_execute_minimizes_prompt_and_execution_payload(
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
|
||||
@@ -21,10 +21,12 @@ def test_run_completion_passes_optional_params_when_provided(monkeypatch) -> Non
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
temperature=0.6,
|
||||
max_tokens=120,
|
||||
timeout=12.5,
|
||||
)
|
||||
|
||||
assert captured["temperature"] == 0.6
|
||||
assert captured["max_tokens"] == 120
|
||||
assert captured["timeout"] == 12.5
|
||||
|
||||
|
||||
def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
|
||||
@@ -45,7 +47,9 @@ def test_run_completion_omits_optional_params_when_none(monkeypatch) -> None:
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
)
|
||||
|
||||
assert "temperature" not in captured
|
||||
assert "max_tokens" not in captured
|
||||
assert "timeout" not in captured
|
||||
|
||||
@@ -2,64 +2,124 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
|
||||
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
}
|
||||
|
||||
|
||||
class _FakeResumeService:
|
||||
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "tool_call_id": tool_call_id}
|
||||
async def resume(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
}
|
||||
|
||||
|
||||
def _build_run_input() -> dict[str, object]:
|
||||
return {
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
events: list[str] = []
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
events.append(event_type)
|
||||
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
|
||||
assert result["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert events == ["RUN_STARTED", "RUN_FINISHED"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_injects_context_and_redacts_sensitive_fields() -> None:
|
||||
published: list[dict[str, object]] = []
|
||||
|
||||
class _RunWithExtraEvents(_FakeRunService):
|
||||
async def run(self, *, run_input: RunAgentInput) -> dict[str, object]:
|
||||
return {
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"events": [
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"messageId": "m1",
|
||||
"delta": "hi",
|
||||
"token": "secret-token",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
published.append(event)
|
||||
|
||||
await run_agent_task(
|
||||
{"command": "run", "run_input": _build_run_input()},
|
||||
publish_event=_publish,
|
||||
run_service=_RunWithExtraEvents(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
run_started = published[0]
|
||||
assert run_started["type"] == "RUN_STARTED"
|
||||
assert "input" not in run_started
|
||||
|
||||
text_event = published[1]
|
||||
assert text_event["type"] == "TEXT_MESSAGE_CONTENT"
|
||||
assert text_event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert text_event["runId"] == "run-1"
|
||||
assert text_event["token"] == "***REDACTED***"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
class _BrokenRunService(_FakeRunService):
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
del session_id, user_input
|
||||
async def run(self, *, run_input: dict[str, object]) -> dict[str, object]:
|
||||
del run_input
|
||||
raise RuntimeError("boom")
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
events.append(event_type)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_BrokenRunService(),
|
||||
@@ -72,16 +132,44 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_rejects_invalid_command() -> None:
|
||||
with pytest.raises(ValueError, match="invalid command type"):
|
||||
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
|
||||
await run_agent_task({"command": "invalid", "run_input": _build_run_input()})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
|
||||
with pytest.raises(ValueError, match="tool_call_id is required"):
|
||||
async def test_run_agent_task_rejects_missing_run_input() -> None:
|
||||
with pytest.raises(ValueError, match="run_input is required"):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
"command": "run",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_resume_uses_run_input() -> None:
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
del event
|
||||
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"run_input": _build_run_input(),
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert result["runId"] == "run-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_rejects_invalid_run_input() -> None:
|
||||
with pytest.raises(ValueError, match="invalid AG-UI RunAgentInput payload"):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"run_input": {"threadId": "x"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -23,11 +23,34 @@ class _FakeRedisClient:
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
key, start_id = next(iter(streams.items()))
|
||||
if start_id == "$":
|
||||
if start_id == "0-0":
|
||||
return [(key, [("11-0", {"event": '{"type":"RUN_STARTED"}'})])]
|
||||
return [(key, [("12-0", {"event": '{"type":"RUN_FINISHED"}'})])]
|
||||
|
||||
|
||||
class _MalformedRedisClient:
|
||||
async def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[object]:
|
||||
del streams, count, block
|
||||
return ["bad-shape"]
|
||||
|
||||
|
||||
class _InvalidJsonRedisClient:
|
||||
async def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
key = next(iter(streams.keys()))
|
||||
return [(key, [("11-0", {"event": "not-json"})])]
|
||||
|
||||
|
||||
def test_append_event_writes_json_payload() -> None:
|
||||
client = _FakeRedisClient()
|
||||
session_id = uuid4()
|
||||
@@ -55,3 +78,26 @@ async def test_read_events_respects_last_event_id() -> None:
|
||||
|
||||
assert from_start[0]["id"] == "11-0"
|
||||
assert from_last[0]["id"] == "12-0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_events_returns_empty_for_malformed_response() -> None:
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(client=_MalformedRedisClient(), stream_prefix="agent:events")
|
||||
|
||||
rows = await store.read_events(session_id=session_id, last_event_id=None)
|
||||
|
||||
assert rows == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_events_skips_invalid_event_json() -> None:
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(
|
||||
client=_InvalidJsonRedisClient(),
|
||||
stream_prefix="agent:events",
|
||||
)
|
||||
|
||||
rows = await store.read_events(session_id=session_id, last_event_id=None)
|
||||
|
||||
assert rows == []
|
||||
|
||||
@@ -5,11 +5,13 @@ from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
|
||||
|
||||
@@ -61,12 +63,69 @@ class _FakeUserContextCache:
|
||||
self.set_calls += 1
|
||||
|
||||
|
||||
def _build_run_input(
|
||||
*,
|
||||
thread_id: str,
|
||||
text: str = "hello",
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": text}],
|
||||
"tools": tools or [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_resume_input(
|
||||
*,
|
||||
thread_id: str,
|
||||
tool_call_id: str,
|
||||
content: str | None = None,
|
||||
) -> RunAgentInput:
|
||||
payload = content
|
||||
if payload is None:
|
||||
payload = json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {"target": "/calendar/dayweek", "replace": False, "__nonce": "nonce-1"},
|
||||
"nonce": "nonce-1",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": "run-2",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"role": "tool",
|
||||
"toolCallId": tool_call_id,
|
||||
"content": payload,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_rejects_invalid_session_id() -> None:
|
||||
run_service = RunService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await run_service.run(session_id="session-1", user_input="hello")
|
||||
await run_service.run(run_input=_build_run_input(thread_id="session-1"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -74,7 +133,272 @@ async def test_resume_service_requires_pending_tool_call() -> None:
|
||||
resume_service = ResumeService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await resume_service.resume(session_id="session-1", tool_call_id="call-1")
|
||||
await resume_service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_validates_pending_tool_guard_and_persists_payload(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: list[dict[str, object]] = []
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
),
|
||||
)
|
||||
|
||||
assert captured[0]["role"] == AgentChatMessageRole.TOOL
|
||||
stored_payload = json.loads(captured[0]["content"])
|
||||
assert stored_payload["toolName"] == "navigate_to_route"
|
||||
assert stored_payload["result"]["ok"] is True
|
||||
assert stored_payload["result"]["applied"] is True
|
||||
assert "ui" not in stored_payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_rejects_mismatched_nonce(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="nonce"):
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": "nonce-1",
|
||||
},
|
||||
"nonce": "nonce-bad",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_rejects_tool_result_when_not_ok(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="execution failed"):
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": "nonce-1",
|
||||
},
|
||||
"nonce": "nonce-1",
|
||||
"result": {"ok": False, "error": "navigator not bound"},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -256,7 +580,9 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
|
||||
session_uuid = session_id
|
||||
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
|
||||
await run_service.run(session_id=str(session_id), user_input="hello")
|
||||
await run_service.run(
|
||||
run_input=_build_run_input(thread_id=str(session_id), text="hello")
|
||||
)
|
||||
|
||||
system_prompt = captured["system_prompt"]
|
||||
assert isinstance(system_prompt, str)
|
||||
@@ -267,6 +593,290 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
|
||||
assert payload["ai_language"] == "en-US"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_emits_frontend_tool_pending_events(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.PENDING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot=None,
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object):
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
captured["update_runtime_state"] = kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
del user_input, system_prompt
|
||||
return {
|
||||
"assistant_text": "请确认是否跳转。",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
del self
|
||||
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
user_id=user_id,
|
||||
username="demo-user",
|
||||
bio=None,
|
||||
settings=SimpleNamespace(
|
||||
preferences=SimpleNamespace(
|
||||
interface_language="zh-CN",
|
||||
ai_language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
country="CN",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.create_runtime",
|
||||
lambda **_kwargs: _FakeRuntime(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
result = await service.run(
|
||||
run_input=_build_run_input(
|
||||
thread_id=str(session_id),
|
||||
text="帮我打开日历",
|
||||
tools=[
|
||||
{
|
||||
"name": "navigate_to_route",
|
||||
"description": "navigate",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["pending_tool_call_id"] is not None
|
||||
tool_start = next(event for event in result["events"] if event["type"] == "TOOL_CALL_START")
|
||||
assert tool_start["toolCallName"] == "navigate_to_route"
|
||||
runtime_state = captured["update_runtime_state"]
|
||||
assert isinstance(runtime_state, dict)
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.RUNNING
|
||||
snapshot = runtime_state["state_snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["pending_tool_name"] == "navigate_to_route"
|
||||
assert isinstance(snapshot["pending_tool_args_sha256"], str)
|
||||
assert isinstance(snapshot["pending_tool_nonce"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.PENDING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot=None,
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object):
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
captured["update_runtime_state"] = kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
del user_input, system_prompt
|
||||
return {
|
||||
"assistant_text": "日历事件已创建。",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
del self
|
||||
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
user_id=user_id,
|
||||
username="demo-user",
|
||||
bio=None,
|
||||
settings=SimpleNamespace(
|
||||
preferences=SimpleNamespace(
|
||||
interface_language="zh-CN",
|
||||
ai_language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
country="CN",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_execute_backend_tool(
|
||||
self,
|
||||
*,
|
||||
session,
|
||||
owner_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
):
|
||||
del self, session, owner_id
|
||||
assert tool_name == "create_calendar_event"
|
||||
assert "title" in tool_args
|
||||
return {
|
||||
"result": {"eventId": "evt-1", "ok": True},
|
||||
"ui": {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {"id": "evt-1", "title": "会议"},
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.create_runtime",
|
||||
lambda **_kwargs: _FakeRuntime(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._execute_backend_tool",
|
||||
_fake_execute_backend_tool,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
result = await service.run(
|
||||
run_input=_build_run_input(
|
||||
thread_id=str(session_id),
|
||||
text='#tool:create_calendar_event {"title":"会议","startAt":"2026-03-07T08:00:00Z"}',
|
||||
tools=[
|
||||
{
|
||||
"name": "create_calendar_event",
|
||||
"description": "create calendar",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["pending_tool_call_id"] is None
|
||||
assert any(event["type"] == "TOOL_CALL_RESULT" for event in result["events"])
|
||||
runtime_state = captured["update_runtime_state"]
|
||||
assert isinstance(runtime_state, dict)
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_user_agent_context_parses_profile_settings_v1() -> None:
|
||||
session_id = uuid4()
|
||||
@@ -519,7 +1129,9 @@ async def test_run_service_still_executes_when_profile_missing(
|
||||
session_uuid = session_id
|
||||
run_service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
|
||||
await run_service.run(session_id=str(session_id), user_input="hello")
|
||||
await run_service.run(
|
||||
run_input=_build_run_input(thread_id=str(session_id), text="hello")
|
||||
)
|
||||
|
||||
system_prompt = captured["system_prompt"]
|
||||
assert isinstance(system_prompt, str)
|
||||
|
||||
@@ -4,9 +4,18 @@ from core.agent.domain.state_snapshot import AgentStateSnapshot
|
||||
|
||||
|
||||
def test_state_snapshot_serialization_round_trip() -> None:
|
||||
snapshot = AgentStateSnapshot(status="running", pending_tool_call_id="call-1")
|
||||
snapshot = AgentStateSnapshot(
|
||||
status="running",
|
||||
pending_tool_call_id="call-1",
|
||||
pending_tool_name="navigate_to_route",
|
||||
pending_tool_args_sha256="abc",
|
||||
pending_tool_nonce="nonce-1",
|
||||
)
|
||||
|
||||
payload = snapshot.model_dump()
|
||||
|
||||
assert payload["status"] == "running"
|
||||
assert payload["pending_tool_call_id"] == "call-1"
|
||||
assert payload["pending_tool_name"] == "navigate_to_route"
|
||||
assert payload["pending_tool_args_sha256"] == "abc"
|
||||
assert payload["pending_tool_nonce"] == "nonce-1"
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent import router as agent_router
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_run_request_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._allow_run_request(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_sse_slot_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._acquire_sse_slot(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
@@ -1,7 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
@@ -11,14 +15,19 @@ class _FakeRepository:
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
self.deleted_session_id: str | None = None
|
||||
self.created_with_session_id: str | None = None
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
del session_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
if session_id == "00000000-0000-0000-0000-000000000001":
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str:
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000999"
|
||||
self.created_with_session_id = session_id
|
||||
return session_id or "00000000-0000-0000-0000-000000000999"
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
@@ -29,6 +38,22 @@ class _FakeRepository:
|
||||
async def delete_session(self, *, session_id: str) -> None:
|
||||
self.deleted_session_id = session_id
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None:
|
||||
del session_id
|
||||
if before is not None and before <= date(2026, 3, 6):
|
||||
return None
|
||||
return {
|
||||
"day": "2026-03-06",
|
||||
"hasMore": False,
|
||||
"messages": [{"id": "m1", "role": "assistant", "content": "hello"}],
|
||||
}
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
async def enqueue(
|
||||
@@ -63,6 +88,20 @@ def _user() -> CurrentUser:
|
||||
)
|
||||
|
||||
|
||||
def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
@@ -70,37 +109,46 @@ async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
user = _user()
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
first = await service.enqueue_resume(
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
second = await service.enqueue_resume(
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
assert first.task_id == second.task_id
|
||||
|
||||
|
||||
async def test_enqueue_run_without_session_creates_new_session() -> None:
|
||||
async def test_enqueue_run_creates_missing_thread_session() -> None:
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
session_id=None,
|
||||
prompt="hello",
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.thread_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.run_id == "run-1"
|
||||
assert accepted.created is True
|
||||
assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert repository.committed is True
|
||||
|
||||
|
||||
@@ -111,11 +159,14 @@ async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
queue=_FailingQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
try:
|
||||
await service.enqueue_run(
|
||||
session_id=None,
|
||||
prompt="hello",
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
raise AssertionError("expected RuntimeError")
|
||||
@@ -123,3 +174,78 @@ async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
assert str(exc) == "enqueue failed"
|
||||
|
||||
assert repository.deleted_session_id is None
|
||||
|
||||
|
||||
async def test_enqueue_run_handles_session_create_race() -> None:
|
||||
class _RaceRepository(_FakeRepository):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.create_calls = 0
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
if self.create_calls == 0:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id, session_id
|
||||
self.create_calls += 1
|
||||
raise IntegrityError("insert", {}, Exception("duplicate key"))
|
||||
|
||||
repository = _RaceRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.created is False
|
||||
assert repository.rolled_back is True
|
||||
|
||||
|
||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
|
||||
event = await service.get_history_snapshot(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
before=date(2026, 3, 7),
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
snapshot = event["snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["scope"] == "history_day"
|
||||
assert snapshot["day"] == "2026-03-06"
|
||||
assert snapshot["messages"][0]["id"] == "m1"
|
||||
|
||||
|
||||
async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
event = await service.get_user_history_snapshot(
|
||||
current_user=_user(),
|
||||
thread_id=None,
|
||||
before=None,
|
||||
)
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
Reference in New Issue
Block a user