feat(agent): complete task4-6 tool result persistence flow

This commit is contained in:
zl-q
2026-03-08 17:07:09 +08:00
parent 5ada60e834
commit daa1c86d02
15 changed files with 903 additions and 92 deletions
@@ -15,7 +15,9 @@ from core.agent.application.runtime_data_service import RuntimeDataService
from core.agent.application.runtime_loop_service import RuntimeLoopService from core.agent.application.runtime_loop_service import RuntimeLoopService
from core.agent.application.session_state_persistence import ( from core.agent.application.session_state_persistence import (
SessionStatePersistence, SessionStatePersistence,
ToolResultStorage,
compute_tool_args_sha256, compute_tool_args_sha256,
persist_tool_result_payload,
) )
from core.agent.domain.agui_input import extract_latest_tool_result from core.agent.domain.agui_input import extract_latest_tool_result
from core.agent.domain.user_context import build_global_system_prompt from core.agent.domain.user_context import build_global_system_prompt
@@ -57,10 +59,20 @@ class ResumeService:
self, self,
*, *,
session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal, session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal,
tool_result_storage: ToolResultStorage | None = None,
tool_result_offload_threshold_bytes: int = 4096,
tool_result_bucket: str = "private",
tool_result_prefix: str = "tool-results",
) -> None: ) -> None:
self._session_factory = session_factory self._session_factory = session_factory
self._state_persistence = SessionStatePersistence() self._state_persistence = SessionStatePersistence()
self._loop_service = RuntimeLoopService() self._loop_service = RuntimeLoopService()
self._tool_result_storage = tool_result_storage
self._tool_result_offload_threshold_bytes = max(
1, int(tool_result_offload_threshold_bytes)
)
self._tool_result_bucket = tool_result_bucket
self._tool_result_prefix = tool_result_prefix.strip("/") or "tool-results"
async def resume( async def resume(
self, self,
@@ -152,18 +164,59 @@ class ResumeService:
next_seq = await session_repository.next_message_seq( next_seq = await session_repository.next_message_seq(
session_id=session_uuid session_id=session_uuid
) )
payload_json = json.dumps(
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
)
payload_bytes = len(payload_json.encode("utf-8"))
metadata_payload: dict[str, object] = MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump()
stored_content = payload_json
if (
self._tool_result_storage is not None
and payload_bytes >= self._tool_result_offload_threshold_bytes
):
storage_path = (
f"{self._tool_result_prefix}/{run_input.thread_id}/"
f"{run_input.run_id}/{tool_call_id}.json"
)
try:
metadata_payload = await persist_tool_result_payload(
storage=self._tool_result_storage,
run_id=run_input.run_id,
turn_id=str(next_seq),
tool_call_id=tool_call_id,
tool_name=tool_name,
payload=sanitized_tool_payload,
bucket=self._tool_result_bucket,
path=storage_path,
)
stored_content = json.dumps(
{
"toolName": tool_name,
"offloaded": True,
"storage": {
"bucket": metadata_payload.get("storage_bucket"),
"path": metadata_payload.get("storage_path"),
},
},
ensure_ascii=True,
separators=(",", ":"),
)
except Exception:
metadata_payload = MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump()
tool_message = await message_repository.append_message( tool_message = await message_repository.append_message(
session_id=session_uuid, session_id=session_uuid,
seq=next_seq, seq=next_seq,
role=AgentChatMessageRole.TOOL, role=AgentChatMessageRole.TOOL,
content=json.dumps( content=stored_content,
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":") metadata=metadata_payload,
),
metadata=MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump(),
) )
snapshot = self._state_persistence.build_resuming_snapshot( snapshot = self._state_persistence.build_resuming_snapshot(
@@ -274,6 +327,11 @@ class ResumeService:
pending_tool_call_id: str | None = None pending_tool_call_id: str | None = None
events: list[dict[str, object]] = [] events: list[dict[str, object]] = []
runtime_events = runtime_result.get("agui_events")
if isinstance(runtime_events, list):
for event in runtime_events:
if isinstance(event, dict):
events.append(event)
message_delta = 1 message_delta = 1
snapshot = self._state_persistence.build_completed_snapshot() snapshot = self._state_persistence.build_completed_snapshot()
status = AgentChatSessionStatus.COMPLETED status = AgentChatSessionStatus.COMPLETED
@@ -162,6 +162,11 @@ class RunService:
) )
pending_tool_call_id: str | None = None pending_tool_call_id: str | None = None
events: list[dict[str, object]] = [] events: list[dict[str, object]] = []
runtime_events = runtime_result.get("agui_events")
if isinstance(runtime_events, list):
for event in runtime_events:
if isinstance(event, dict):
events.append(event)
message_delta = 2 message_delta = 2
session_status = AgentChatSessionStatus.COMPLETED session_status = AgentChatSessionStatus.COMPLETED
snapshot = self._state_persistence.build_completed_snapshot() snapshot = self._state_persistence.build_completed_snapshot()
@@ -416,6 +416,21 @@ class CrewAIRuntime:
completion_tokens = 0 completion_tokens = 0
total_tokens = 0 total_tokens = 0
total_cost = 0.0 total_cost = 0.0
internal_events: list[dict[str, Any]] = []
def _emit_step_event(
*,
event_type: str,
stage: str,
status: str | None = None,
reason: str | None = None,
) -> None:
data: dict[str, Any] = {"stage": stage}
if status is not None:
data["status"] = status
if reason is not None:
data["reason"] = reason
internal_events.append({"type": event_type, "data": data})
client_front_tools = self._normalize_client_front_tools(tools) client_front_tools = self._normalize_client_front_tools(tools)
intent_tools = self._resolve_stage_tools_payload( intent_tools = self._resolve_stage_tools_payload(
@@ -432,6 +447,18 @@ class CrewAIRuntime:
) )
if resume_from_stage in {"execution", "organization"}: if resume_from_stage in {"execution", "organization"}:
_emit_step_event(
event_type="stepStarted",
stage="intent",
status="skipped",
reason="resume_from_interrupted_stage",
)
_emit_step_event(
event_type="stepFinished",
stage="intent",
status="skipped",
reason="resume_from_interrupted_stage",
)
intent_result = IntentResult( intent_result = IntentResult(
route="NEEDS_EXECUTION", route="NEEDS_EXECUTION",
intent_summary="resume_from_interrupted_stage", intent_summary="resume_from_interrupted_stage",
@@ -439,6 +466,7 @@ class CrewAIRuntime:
safety_flags=[], safety_flags=[],
) )
else: else:
_emit_step_event(event_type="stepStarted", stage="intent")
intent_text, intent_usage, _, _ = self._run_stage_with_crewai( intent_text, intent_usage, _, _ = self._run_stage_with_crewai(
stage="intent", stage="intent",
user_content=user_input, user_content=user_input,
@@ -451,11 +479,15 @@ class CrewAIRuntime:
total_tokens += intent_usage.total_tokens total_tokens += intent_usage.total_tokens
total_cost += intent_usage.cost total_cost += intent_usage.cost
intent_result = _parse_intent_result(intent_text) intent_result = _parse_intent_result(intent_text)
_emit_step_event(
event_type="stepFinished", stage="intent", status="completed"
)
assistant_text = intent_result.assistant_text or "" assistant_text = intent_result.assistant_text or ""
pending_front_tool: dict[str, object] | None = None pending_front_tool: dict[str, object] | None = None
if intent_result.route == "NEEDS_EXECUTION": if intent_result.route == "NEEDS_EXECUTION":
_emit_step_event(event_type="stepStarted", stage="execution")
execution_input = json.dumps( execution_input = json.dumps(
{ {
"user_input": user_input, "user_input": user_input,
@@ -483,8 +515,14 @@ class CrewAIRuntime:
execution_tools=execution_tools, execution_tools=execution_tools,
pending_call=pending_call, pending_call=pending_call,
) )
_emit_step_event(
event_type="stepFinished",
stage="execution",
status="pending_approval" if pending_call is not None else "completed",
)
if pending_call is None and resume_from_stage != "execution": if pending_call is None and resume_from_stage != "execution":
_emit_step_event(event_type="stepStarted", stage="organization")
execution_result = _parse_execution_result(execution_text) execution_result = _parse_execution_result(execution_text)
organization_input = json.dumps( organization_input = json.dumps(
{ {
@@ -522,28 +560,68 @@ class CrewAIRuntime:
fallback_text=execution_result.report_brief, fallback_text=execution_result.report_brief,
) )
assistant_text = organization_result.assistant_text assistant_text = organization_result.assistant_text
_emit_step_event(
event_type="stepFinished",
stage="organization",
status="completed",
)
elif pending_call is not None: elif pending_call is not None:
assistant_text = ( assistant_text = (
intent_result.execution_brief or "Tool call pending approval" intent_result.execution_brief or "Tool call pending approval"
) )
_emit_step_event(
event_type="stepStarted",
stage="organization",
status="skipped",
reason="pending_tool_approval",
)
_emit_step_event(
event_type="stepFinished",
stage="organization",
status="skipped",
reason="pending_tool_approval",
)
else: else:
execution_result = _parse_execution_result(execution_text) execution_result = _parse_execution_result(execution_text)
assistant_text = execution_result.report_brief assistant_text = execution_result.report_brief
_emit_step_event(
event_type="stepStarted",
stage="organization",
status="skipped",
reason="resume_from_execution",
)
_emit_step_event(
event_type="stepFinished",
stage="organization",
status="skipped",
reason="resume_from_execution",
)
else:
_emit_step_event(
event_type="stepStarted",
stage="execution",
status="skipped",
reason="direct_execution_route",
)
_emit_step_event(
event_type="stepFinished",
stage="execution",
status="skipped",
reason="direct_execution_route",
)
_emit_step_event(
event_type="stepStarted",
stage="organization",
status="skipped",
reason="direct_execution_route",
)
_emit_step_event(
event_type="stepFinished",
stage="organization",
status="skipped",
reason="direct_execution_route",
)
internal_events = [
{"type": "llmStarted", "data": {"model": self._config.model_code}},
{"type": "llmChunk", "data": {"text": assistant_text}},
{
"type": "llmFinished",
"data": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost": total_cost,
"provider": self._config.provider_name,
},
},
]
return { return {
"assistant_text": assistant_text, "assistant_text": assistant_text,
"prompt_tokens": prompt_tokens, "prompt_tokens": prompt_tokens,
@@ -57,19 +57,6 @@ async def _execute_create_calendar_event(
) )
event_id = str(created.id) event_id = str(created.id)
return { return {
"result": {
"eventId": event_id,
"ok": True,
"message": "日程已创建",
"title": created.title,
"description": created.description,
"startAt": created.start_at.isoformat(),
"endAt": created.end_at.isoformat() if created.end_at is not None else None,
"timezone": created.timezone,
"location": location_value,
"sourceType": "agent_generated",
},
"ui": {
"type": "calendar_card.v1", "type": "calendar_card.v1",
"version": "v1", "version": "v1",
"data": { "data": {
@@ -77,13 +64,13 @@ async def _execute_create_calendar_event(
"title": created.title, "title": created.title,
"description": created.description, "description": created.description,
"startAt": created.start_at.isoformat(), "startAt": created.start_at.isoformat(),
"endAt": ( "endAt": created.end_at.isoformat() if created.end_at is not None else None,
created.end_at.isoformat() if created.end_at is not None else None
),
"timezone": created.timezone, "timezone": created.timezone,
"location": location_value, "location": location_value,
"color": "#4F46E5", "color": "#4F46E5",
"sourceType": "agent_generated", "sourceType": "agent_generated",
"ok": True,
"message": "日程已创建",
}, },
"actions": [ "actions": [
{ {
@@ -92,7 +79,6 @@ async def _execute_create_calendar_event(
"target": f"/calendar/events/{event_id}", "target": f"/calendar/events/{event_id}",
} }
], ],
},
} }
@@ -9,6 +9,9 @@ from core.agent.domain.agui_input import parse_run_input
from core.agent.application.resume_service import ResumeService from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from core.config.settings import config from core.config.settings import config
from core.logging import get_logger from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker from core.taskiq.app import bulk_broker, critical_broker, default_broker
@@ -125,8 +128,13 @@ async def run_agent_task(
) -> dict[str, object]: ) -> dict[str, object]:
publisher = publish_event or await _build_redis_publisher() publisher = publish_event or await _build_redis_publisher()
enqueue = enqueue_command or _enqueue_followup_command enqueue = enqueue_command or _enqueue_followup_command
tool_result_storage = create_tool_result_storage()
service_run = run_service or RunService() service_run = run_service or RunService()
service_resume = resume_service or ResumeService() service_resume = resume_service or ResumeService(
tool_result_storage=tool_result_storage,
tool_result_bucket="private",
tool_result_prefix="tool-results",
)
command_type = str(command.get("command", "run")) command_type = str(command.get("command", "run"))
if command_type not in {"run", "resume", "resume_continue"}: if command_type not in {"run", "resume", "resume_continue"}:
@@ -0,0 +1,6 @@
from core.agent.infrastructure.storage.tool_result_storage import (
SupabaseToolResultStorage,
create_tool_result_storage,
)
__all__ = ["SupabaseToolResultStorage", "create_tool_result_storage"]
@@ -0,0 +1,78 @@
from __future__ import annotations
import asyncio
import json
from typing import Any
from services.base.supabase import supabase_service
class SupabaseToolResultStorage:
    """Read and write tool-result JSON documents through Supabase Storage.

    Every SDK call is executed via ``asyncio.to_thread`` so that it does not
    run on the event-loop thread.
    """

    def _bucket_client(self, *, bucket: str) -> Any:
        """Return the storage accessor for ``bucket``.

        Raises:
            RuntimeError: If the admin client has no ``storage`` attribute or
                that object exposes no callable ``from_`` accessor.
        """
        client = supabase_service.get_admin_client()
        storage = getattr(client, "storage", None)
        if storage is None:
            raise RuntimeError("Supabase storage client unavailable")
        from_bucket = getattr(storage, "from_", None)
        if not callable(from_bucket):
            raise RuntimeError("Supabase storage bucket accessor unavailable")
        return from_bucket(bucket)

    async def upload_json(
        self,
        *,
        bucket: str,
        path: str,
        payload: dict[str, object],
    ) -> str:
        """Serialize ``payload`` as compact ASCII JSON and upload it.

        Args:
            bucket: Name of the target bucket.
            path: Object path inside the bucket.
            payload: JSON-serializable mapping to store.

        Returns:
            ``str()`` of whatever the SDK's ``upload`` returned, or ``""``
            when that value is falsy.

        Raises:
            RuntimeError: If the bucket client has no callable ``upload``.
        """
        data = json.dumps(payload, ensure_ascii=True, separators=(",", ":")).encode(
            "utf-8"
        )

        def _upload() -> object:
            bucket_client = self._bucket_client(bucket=bucket)
            upload = getattr(bucket_client, "upload", None)
            if not callable(upload):
                raise RuntimeError("Supabase storage upload is unavailable")
            return upload(
                path,
                data,
                {
                    "content-type": "application/json",
                    # Passed as a string, exactly as the original code did.
                    "upsert": "true",
                },
            )

        result = await asyncio.to_thread(_upload)
        return str(result or "")

    async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
        """Download an object and parse it as a JSON object.

        Args:
            bucket: Name of the bucket to read from.
            path: Object path inside the bucket.

        Returns:
            The decoded mapping, or ``None`` when the downloaded value is
            neither ``bytes`` nor ``str``, is not valid UTF-8, is not valid
            JSON, or does not decode to a ``dict``.

        Raises:
            RuntimeError: If the bucket client has no callable ``download``.
        """

        def _download() -> object:
            bucket_client = self._bucket_client(bucket=bucket)
            download = getattr(bucket_client, "download", None)
            if not callable(download):
                raise RuntimeError("Supabase storage download is unavailable")
            return download(path)

        raw = await asyncio.to_thread(_download)
        if not isinstance(raw, (bytes, str)):
            return None
        try:
            # Decode inside the guarded section: UnicodeDecodeError is a
            # ValueError subclass, so corrupt bytes yield None just like
            # malformed JSON instead of escaping to the caller.
            text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            payload = json.loads(text)
        except ValueError:
            return None
        if not isinstance(payload, dict):
            return None
        return payload
def create_tool_result_storage() -> SupabaseToolResultStorage | None:
    """Build a Supabase-backed tool-result storage if an admin client exists.

    Returns:
        A new ``SupabaseToolResultStorage`` when
        ``supabase_service.get_admin_client()`` succeeds, otherwise ``None``.
        Any exception raised while obtaining the client is treated as
        "storage unavailable" rather than propagated.
    """
    try:
        # Only used as an availability probe; the returned client is discarded.
        supabase_service.get_admin_client()
    except Exception:
        return None
    else:
        return SupabaseToolResultStorage()
+5 -1
View File
@@ -9,6 +9,9 @@ from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from core.agent.infrastructure.queue.tasks import ( from core.agent.infrastructure.queue.tasks import (
run_command_task, run_command_task,
run_command_task_bulk, run_command_task_bulk,
@@ -109,8 +112,9 @@ class RedisEventStream:
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService: def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
tool_result_storage = create_tool_result_storage()
return AgentService( return AgentService(
repository=AgentRepository(session), repository=AgentRepository(session, tool_result_storage=tool_result_storage),
queue=TaskiqQueueClient(), queue=TaskiqQueueClient(),
stream=RedisEventStream(), stream=RedisEventStream(),
) )
+49 -9
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone from datetime import date, datetime, time, timedelta, timezone
import json import json
from typing import Protocol
from uuid import UUID from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
@@ -12,9 +13,21 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession from models.agent_chat_session import AgentChatSession
class ToolResultPayloadStorage(Protocol):
async def read_json(
self, *, bucket: str, path: str
) -> dict[str, object] | None: ...
class AgentRepository: class AgentRepository:
def __init__(self, session: AsyncSession) -> None: def __init__(
self,
session: AsyncSession,
*,
tool_result_storage: ToolResultPayloadStorage | None = None,
) -> None:
self._session = session self._session = session
self._tool_result_storage = tool_result_storage
async def get_session_owner(self, *, session_id: str) -> str: async def get_session_owner(self, *, session_id: str) -> str:
try: try:
@@ -42,7 +55,9 @@ class AgentRepository:
try: try:
session_uuid = UUID(session_id) session_uuid = UUID(session_id)
except ValueError as exc: except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc raise HTTPException(
status_code=422, detail="Invalid session_id"
) from exc
session = AgentChatSession( session = AgentChatSession(
id=session_uuid, id=session_uuid,
@@ -118,10 +133,13 @@ class AgentRepository:
) )
messages = (await self._session.execute(message_stmt)).scalars().all() messages = (await self._session.execute(message_stmt)).scalars().all()
has_more = any(day < target_day for day in unique_days) has_more = any(day < target_day for day in unique_days)
snapshot_messages: list[dict[str, object]] = []
for message in messages:
snapshot_messages.append(await self._to_snapshot_message(message))
return { return {
"day": target_day.isoformat(), "day": target_day.isoformat(),
"hasMore": has_more, "hasMore": has_more,
"messages": [self._to_snapshot_message(msg) for msg in messages], "messages": snapshot_messages,
} }
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
@@ -141,8 +159,9 @@ class AgentRepository:
return None return None
return str(latest_id) return str(latest_id)
@staticmethod async def _to_snapshot_message(
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]: self, message: AgentChatMessage
) -> dict[str, object]:
role = ( role = (
message.role.value message.role.value
if isinstance(message.role, AgentChatMessageRole) if isinstance(message.role, AgentChatMessageRole)
@@ -167,14 +186,35 @@ class AgentRepository:
parsed_content = decoded parsed_content = decoded
except (TypeError, ValueError): except (TypeError, ValueError):
parsed_content = None parsed_content = None
if parsed_content is not None:
ui = parsed_content.get("ui") hydrated_content: dict[str, object] | None = None
if self._tool_result_storage is not None:
storage_bucket = metadata.get("storage_bucket")
storage_path = metadata.get("storage_path")
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
try:
hydrated_content = await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
)
except Exception:
hydrated_content = None
resolved_content = hydrated_content or parsed_content
if resolved_content is not None:
result = resolved_content.get("result")
if isinstance(result, dict):
result_content = result.get("content")
if isinstance(result_content, str):
payload["content"] = result_content
ui = resolved_content.get("ui")
if isinstance(ui, dict): if isinstance(ui, dict):
payload["ui"] = ui payload["ui"] = ui
display_content = parsed_content.get("content") display_content = resolved_content.get("content")
if isinstance(display_content, str): if isinstance(display_content, str):
payload["content"] = display_content payload["content"] = display_content
else:
if "content" not in payload:
payload["content"] = message.content payload["content"] = message.content
else: else:
payload["content"] = message.content payload["content"] = message.content
@@ -12,6 +12,10 @@ from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService from core.agent.application.run_service import RunService
from core.agent.infrastructure.persistence.session_repository import SessionRepository from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.agent.infrastructure.queue.tasks import run_agent_task from core.agent.infrastructure.queue.tasks import run_agent_task
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from services.base.supabase import supabase_service
from core.db import AsyncSessionLocal, engine from core.db import AsyncSessionLocal, engine
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
@@ -242,6 +246,299 @@ async def test_run_then_resume_persists_messages_and_session_state(
await cleanup_session.commit() await cleanup_session.commit()
@pytest.mark.asyncio
async def test_resume_tool_result_offloads_to_supabase_storage_for_calendar_tool(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Run, then resume with a large tool result, and check it is offloaded.

    The runtime is patched so the first ``execute`` call requests the
    ``front.create_calendar_event`` tool and the second one completes. The
    resume command carries a tool result with a 9000-character description;
    the test then asserts the persisted TOOL message metadata points at the
    ``private`` bucket and that the stored object holds the original payload.

    The test is skipped when Supabase cannot be initialized or when the
    Storage upload probe fails. Cleanup removes only rows this test created.
    """
    call_count = {"n": 0}

    def _fake_execute(
        self,
        *,
        user_input: str,
        system_prompt: str | None = None,
        tools: list[dict[str, object]] | None = None,
    ) -> dict[str, object]:
        del self, user_input, system_prompt, tools
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First call (the "run" command): ask for a front-end tool call.
            return {
                "assistant_text": "我来创建日历事件,请稍候确认。",
                "prompt_tokens": 10,
                "completion_tokens": 6,
                "total_tokens": 16,
                "cost": 0.002,
                "pending_front_tool": {
                    "name": "front.create_calendar_event",
                    "args": {
                        "title": "测试日程",
                        "start": "2026-03-09T09:00:00+08:00",
                        "end": "2026-03-09T10:00:00+08:00",
                    },
                    "target": "frontend",
                },
                "agui_events": [],
            }
        # Later calls (the "resume" command): finish without further tools.
        return {
            "assistant_text": "日历已创建。",
            "prompt_tokens": 2,
            "completion_tokens": 2,
            "total_tokens": 4,
            "cost": 0.001,
            "pending_front_tool": None,
            "agui_events": [],
        }

    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
        _fake_execute,
    )

    factory_id = uuid.uuid4()
    # Track what this test inserts so cleanup never touches pre-existing rows.
    created_factory_id: uuid.UUID | None = None
    created_llm_id: uuid.UUID | None = None
    test_user_id: str | None = None
    test_user_email = f"agent-it-{uuid.uuid4().hex[:8]}@example.com"
    owner_id = uuid.uuid4()
    session_uuid = uuid.uuid4()
    agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
    uploaded_path: str | None = None
    admin_client = None
    storage = None

    initialized = await supabase_service.initialize()
    if not initialized:
        pytest.skip("Supabase service is unavailable")

    # Everything after a successful initialize() runs under try/finally so the
    # Supabase user, DB rows and the service connection are always released.
    try:
        admin_client = supabase_service.get_admin_client()
        tool_result_storage = create_tool_result_storage()
        assert tool_result_storage is not None

        created_user = admin_client.auth.admin.create_user(
            {
                "email": test_user_email,
                "password": "Passw0rd!123",
                "email_confirm": True,
                "user_metadata": {"source": "integration-test"},
            }
        )
        test_user_id = str(created_user.user.id)
        owner_id = uuid.UUID(test_user_id)

        await engine.dispose()

        async with AsyncSessionLocal() as lookup_session:
            llm_row = await lookup_session.execute(select(Llm.id).limit(1))
            llm_id = llm_row.scalar_one_or_none()

        if llm_id is None:
            # No model row exists: seed one, reusing a "dashscope" factory
            # when it is already present.
            async with AsyncSessionLocal() as seed_session:
                factory_row = await seed_session.execute(
                    select(LlmFactory.id).where(LlmFactory.name == "dashscope").limit(1)
                )
                existing_factory_id = factory_row.scalar_one_or_none()
                if existing_factory_id is None:
                    seed_session.add(
                        LlmFactory(
                            id=factory_id,
                            name="dashscope",
                            request_url="https://dashscope.example",
                        )
                    )
                    await seed_session.commit()
                    created_factory_id = factory_id
                else:
                    factory_id = existing_factory_id

            async with AsyncSessionLocal() as seed_session:
                llm_id = uuid.uuid4()
                seed_session.add(
                    Llm(
                        id=llm_id,
                        factory_id=factory_id,
                        model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
                    )
                )
                await seed_session.commit()
                created_llm_id = llm_id

        storage = admin_client.storage
        try:
            storage.get_bucket("private")
        except Exception:
            storage.create_bucket("private", "private", {"public": False})

        # Probe the upload API first so environments without working Storage
        # skip instead of failing.
        probe_path = f"tool-results/probe/{uuid.uuid4().hex}.json"
        try:
            storage.from_("private").upload(probe_path, b"{}")
            storage.from_("private").remove([probe_path])
        except Exception:
            pytest.skip(
                "Supabase Storage upload API unavailable in current environment"
            )

        async with AsyncSessionLocal() as seed_session:
            existing_profile = await seed_session.get(Profile, owner_id)
            if existing_profile is None:
                seed_session.add(
                    Profile(
                        id=owner_id,
                        username=f"it_{uuid.uuid4().hex[:8]}",
                    )
                )
            seed_session.add(
                SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
            )
            seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
            await seed_session.commit()

        run_result = await run_agent_task(
            {
                "command": "run",
                "run_input": {
                    "threadId": str(session_uuid),
                    "runId": "run-storage-1",
                    "state": {},
                    "messages": [
                        {
                            "id": "u1",
                            "role": "user",
                            "content": "帮我创建明天9点到10点的日历",
                        }
                    ],
                    "tools": [
                        {
                            "name": "front.create_calendar_event",
                            "description": "Create calendar event",
                            "parameters": {"type": "object"},
                        }
                    ],
                    "context": [],
                    "forwardedProps": {},
                },
            },
            run_service=RunService(),
            resume_service=ResumeService(
                tool_result_storage=tool_result_storage,
                tool_result_bucket="private",
                tool_result_prefix="tool-results",
            ),
        )

        pending_tool_call_id = str(run_result["pending_tool_call_id"])
        snapshot = run_result["state_snapshot"]
        assert isinstance(snapshot, dict)
        pending_tool_nonce = snapshot.get("pending_tool_nonce")
        assert isinstance(pending_tool_nonce, str)

        await run_agent_task(
            {
                "command": "resume",
                "run_input": {
                    "threadId": str(session_uuid),
                    "runId": "run-storage-2",
                    "state": {},
                    "messages": [
                        {
                            "id": "tool-1",
                            "role": "tool",
                            "toolCallId": pending_tool_call_id,
                            "content": json.dumps(
                                {
                                    "toolName": "front.create_calendar_event",
                                    "toolArgs": {
                                        "title": "测试日程",
                                        "start": "2026-03-09T09:00:00+08:00",
                                        "end": "2026-03-09T10:00:00+08:00",
                                        "__nonce": pending_tool_nonce,
                                    },
                                    "nonce": pending_tool_nonce,
                                    "result": {
                                        "ok": True,
                                        "type": "calendar_card.v1",
                                        "version": "v1",
                                        "data": {
                                            "id": "evt-test",
                                            "title": "测试日程",
                                            # Large enough to exceed the
                                            # 4096-byte default threshold.
                                            "description": "x" * 9000,
                                        },
                                        "actions": [
                                            {
                                                "type": "link",
                                                "label": "查看详情",
                                                "target": "/calendar/events/evt-test",
                                            }
                                        ],
                                    },
                                },
                                ensure_ascii=True,
                                separators=(",", ":"),
                            ),
                        }
                    ],
                    "tools": [],
                    "context": [],
                    "forwardedProps": {},
                },
            },
            run_service=RunService(),
            resume_service=ResumeService(
                tool_result_storage=tool_result_storage,
                tool_result_bucket="private",
                tool_result_prefix="tool-results",
            ),
        )

        await engine.dispose()

        async with AsyncSessionLocal() as verify_session:
            rows = await verify_session.execute(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.role == AgentChatMessageRole.TOOL)
                .order_by(AgentChatMessage.seq.desc())
            )
            tool_message = rows.scalars().first()

        assert tool_message is not None
        metadata = tool_message.metadata_json or {}
        storage_bucket = metadata.get("storage_bucket")
        storage_path = metadata.get("storage_path")
        assert storage_bucket == "private"
        assert isinstance(storage_path, str)
        assert storage_path.startswith("tool-results/")
        uploaded_path = storage_path

        downloaded = storage.from_("private").download(uploaded_path)
        if isinstance(downloaded, bytes):
            downloaded_payload = json.loads(downloaded.decode("utf-8"))
        else:
            downloaded_payload = json.loads(str(downloaded))
        assert downloaded_payload["toolName"] == "front.create_calendar_event"
        result_payload = downloaded_payload["result"]
        assert result_payload["type"] == "calendar_card.v1"
        assert result_payload["data"]["id"] == "evt-test"
    finally:
        if uploaded_path and storage is not None:
            try:
                storage.from_("private").remove([uploaded_path])
            except Exception:
                # Best-effort cleanup: a leftover object must not mask the
                # real test outcome.
                pass

        async with AsyncSessionLocal() as cleanup_session:
            await cleanup_session.execute(
                delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
            )
            await cleanup_session.execute(
                delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
            )
            await cleanup_session.execute(delete(Profile).where(Profile.id == owner_id))
            # Delete only the rows seeded above; a reused, pre-existing
            # factory (and any models attached to it) is left untouched.
            if created_llm_id is not None:
                await cleanup_session.execute(
                    delete(Llm).where(Llm.id == created_llm_id)
                )
            if created_factory_id is not None:
                await cleanup_session.execute(
                    delete(LlmFactory).where(LlmFactory.id == created_factory_id)
                )
            await cleanup_session.commit()

        if test_user_id is not None and admin_client is not None:
            try:
                admin_client.auth.admin.delete_user(test_user_id)
            except Exception:
                # Best-effort cleanup of the temporary auth user.
                pass

        await supabase_service.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_service_embeds_profile_settings_in_runtime_system_prompt( async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
@@ -0,0 +1,65 @@
from __future__ import annotations
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import cast
from uuid import uuid4
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool import (
_execute_create_calendar_event,
)
@pytest.mark.asyncio
async def test_create_calendar_event_tool_returns_ui_schema_v1_top_level(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The tool result exposes the ``calendar_card.v1`` fields at top level.

    The schedule service and repository are replaced with stubs so that the
    tool receives a fixed, in-memory event and no database is touched.
    """
    expected_id = uuid4()
    fake_event = SimpleNamespace(
        id=expected_id,
        title="晨会",
        description="同步计划",
        start_at=datetime(2026, 3, 8, 1, 0, tzinfo=timezone.utc),
        end_at=datetime(2026, 3, 8, 2, 0, tzinfo=timezone.utc),
        timezone="Asia/Shanghai",
    )

    class _StubRepository:
        # Accepts and ignores the session it is constructed with.
        def __init__(self, session) -> None:
            del session

    class _StubService:
        # Accepts and ignores whatever keyword arguments the tool passes.
        def __init__(self, **kwargs) -> None:
            del kwargs

        async def create_agent_generated(self, payload):
            # Always hands back the canned event, whatever the payload.
            del payload
            return fake_event

    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool.ScheduleItemService",
        _StubService,
    )
    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool.SQLAlchemyScheduleItemRepository",
        _StubRepository,
    )

    raw_result = await _execute_create_calendar_event(
        session=cast(AsyncSession, SimpleNamespace()),
        owner_id=uuid4(),
        tool_args={"title": "晨会"},
    )
    result = cast(dict[str, object], raw_result)

    assert result["type"] == "calendar_card.v1"
    assert result["version"] == "v1"
    card_data = cast(dict[str, object], result["data"])
    card_actions = cast(list[dict[str, object]], result["actions"])
    assert card_data["id"] == str(expected_id)
    assert card_actions
@@ -116,7 +116,7 @@ def test_runtime_needs_execution_and_collects_front_tool_call() -> None:
tools = item["tools"] tools = item["tools"]
assert isinstance(tools, list) assert isinstance(tools, list)
assert any(t.get("name") == "front.navigate_to_route" for t in tools) assert any(t.get("name") == "front.navigate_to_route" for t in tools)
execution_tools = calls[1]["tools"] execution_tools = cast(list[dict[str, object]], calls[1]["tools"])
assert any(t.get("name") == "back.create_calendar_event" for t in execution_tools) assert any(t.get("name") == "back.create_calendar_event" for t in execution_tools)
assert result["assistant_text"] == "do it" assert result["assistant_text"] == "do it"
assert result["pending_front_tool"] == { assert result["pending_front_tool"] == {
@@ -131,3 +131,51 @@ def test_runtime_backend_registry_check() -> None:
runtime = _build_runtime() runtime = _build_runtime()
assert runtime.is_registered_backend_tool("back.create_calendar_event") is True assert runtime.is_registered_backend_tool("back.create_calendar_event") is True
assert runtime.is_registered_backend_tool("back.unknown") is False assert runtime.is_registered_backend_tool("back.unknown") is False
def test_runtime_emits_step_started_finished_for_all_three_stages() -> None:
    """Check that STEP_STARTED/STEP_FINISHED events cover all three stages.

    The runtime's ``_run_stage_with_crewai`` is replaced by a stub that
    returns a canned reply per stage. After ``execute`` the AG-UI event list
    must contain exactly six step events whose ``data["stage"]`` values are,
    in order, intent, intent, execution, execution, organization,
    organization.
    """
    runtime = _build_runtime()

    def _canned_stage(self, **kwargs):
        # Choose the reply text and usage record from the requested stage;
        # any stage other than intent/execution gets the organization reply.
        stage_name = kwargs["stage"]
        if stage_name == "intent":
            reply_text = '{"route":"NEEDS_EXECUTION","intent_summary":"need tool","execution_brief":"do it","safety_flags":[]}'
            usage = UsageCost(1, 1, 2, 0.01)
        elif stage_name == "execution":
            reply_text = '{"status":"SUCCESS","execution_summary":"done","execution_data":{},"report_brief":"ok"}'
            usage = UsageCost(2, 2, 4, 0.02)
        else:
            reply_text = '{"assistant_text":"final answer","response_metadata":{"source":"organization"}}'
            usage = UsageCost(3, 3, 6, 0.03)
        return (reply_text, usage, [], None)

    runtime._run_stage_with_crewai = MethodType(_canned_stage, runtime)  # type: ignore[method-assign]

    outcome = runtime.execute(user_input="go", tools=[])
    all_events = cast(list[dict[str, object]], outcome["agui_events"])

    step_types = {"STEP_STARTED", "STEP_FINISHED"}
    step_events = [evt for evt in all_events if evt.get("type") in step_types]
    assert len(step_events) == 6

    observed_stages = [
        cast(dict[str, object], evt["data"])["stage"] for evt in step_events
    ]
    assert observed_stages == [
        "intent",
        "intent",
        "execution",
        "execution",
        "organization",
        "organization",
    ]
@@ -406,6 +406,114 @@ async def test_resume_service_rejects_tool_result_when_not_ok(
) )
@pytest.mark.asyncio
async def test_resume_service_offloads_large_tool_result_payload_to_object_storage(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that resuming with a large tool result records storage metadata.

    The session factory, both repositories and the object storage are
    replaced by stubs. The service is built with an offload threshold of one
    byte and resumed with a tool result carrying a 4096-character payload.
    The first appended message must then have a ``metadata`` dict with
    ``storage_bucket == "private"``, a ``storage_path`` starting with
    ``"tool-results/"`` and a string ``payload_sha256``.
    """
    session_id = uuid4()
    user_id = uuid4()
    # Keyword arguments of every append_message call, in call order.
    captured: list[dict[str, object]] = []

    class _StubDbSession:
        """Database session whose commit does nothing."""

        async def commit(self) -> None:
            return None

    class _StubSessionFactory:
        """Callable async context manager yielding a ``_StubDbSession``."""

        def __call__(self) -> "_StubSessionFactory":
            return self

        async def __aenter__(self) -> _StubDbSession:
            return _StubDbSession()

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            del exc_type, exc, tb
            # Returning False lets any exception propagate.
            return False

    class _StubSessionRepository:
        """Session repository returning a running session with a pending front tool call."""

        def __init__(self, session: object) -> None:
            del session

        async def lock_session_for_update(self, *, session_id: object):
            # A fresh snapshot dict is built on every call.
            pending_snapshot = {
                "pending_tool_call_id": "call-1",
                "pending_tool_name": "front.navigate_to_route",
                "pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
                "pending_tool_nonce": "nonce-1",
            }
            return SimpleNamespace(
                id=session_id,
                user_id=user_id,
                status=AgentChatSessionStatus.RUNNING,
                message_count=0,
                total_tokens=0,
                total_cost=0,
                state_snapshot=pending_snapshot,
            )

        async def next_message_seq(self, *, session_id: object) -> int:
            del session_id
            return 1

        async def update_runtime_state(self, **kwargs) -> None:
            del kwargs

    class _StubMessageRepository:
        """Message repository that records appended messages in ``captured``."""

        def __init__(self, session: object) -> None:
            del session

        async def append_message(self, **kwargs) -> None:
            captured.append(kwargs)

    class _StubStorage:
        """Object storage that accepts any upload and returns a fixed etag."""

        async def upload_json(
            self, *, bucket: str, path: str, payload: dict[str, object]
        ) -> str:
            del bucket, path, payload
            return "etag-1"

    # Swap the repositories that resume_service looks up by name.
    monkeypatch.setattr(
        "core.agent.application.resume_service.SessionRepository",
        _StubSessionRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.resume_service.MessageRepository",
        _StubMessageRepository,
    )

    service = ResumeService(  # type: ignore[call-arg]
        session_factory=_StubSessionFactory(),  # type: ignore[arg-type]
        tool_result_storage=_StubStorage(),
        tool_result_offload_threshold_bytes=1,
        tool_result_bucket="private",
        tool_result_prefix="tool-results",
    )

    tool_result_message = {
        "toolName": "front.navigate_to_route",
        "toolArgs": {
            "target": "/calendar/dayweek",
            "replace": False,
            "__nonce": "nonce-1",
        },
        "nonce": "nonce-1",
        "result": {"ok": True, "payload": "x" * 4096},
    }
    serialized_content = json.dumps(
        tool_result_message,
        ensure_ascii=True,
        separators=(",", ":"),
    )
    resume_input = _build_resume_input(
        thread_id=str(session_id),
        tool_call_id="call-1",
        content=serialized_content,
    )
    await service.resume(run_input=resume_input)

    stored_meta = captured[0]["metadata"]
    assert isinstance(stored_meta, dict)
    assert stored_meta["storage_bucket"] == "private"
    assert stored_meta["storage_path"].startswith("tool-results/")
    assert isinstance(stored_meta["payload_sha256"], str)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_agent_model_selection_returns_validated_llm_config() -> None: async def test_load_agent_model_selection_returns_validated_llm_config() -> None:
run_service = RunService() run_service = RunService()
+58 -29
View File
@@ -1,43 +1,72 @@
from __future__ import annotations from __future__ import annotations
from fastapi import HTTPException from datetime import datetime, timezone
from types import SimpleNamespace
from uuid import uuid4
import pytest
from models.agent_chat_message import AgentChatMessageRole
from v1.agent.repository import AgentRepository from v1.agent.repository import AgentRepository
class _FakeSession: class _FakeToolResultStorage:
def __init__(self) -> None: def __init__(self, payload: dict[str, object] | None) -> None:
self.added: list[object] = [] self._payload = payload
def add(self, obj: object) -> None: async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
self.added.append(obj) del bucket, path
return self._payload
async def flush(self) -> None:
return None
async def refresh(self, _obj: object) -> None:
return None
async def test_create_session_for_user_creates_session_row() -> None: @pytest.mark.asyncio
session = _FakeSession() async def test_tool_message_hydrates_content_from_object_storage() -> None:
repository = AgentRepository(session=session) # type: ignore[arg-type] repository = AgentRepository(
session=SimpleNamespace(), # type: ignore[arg-type]
await repository.create_session_for_user( tool_result_storage=_FakeToolResultStorage(
user_id="00000000-0000-0000-0000-000000000001" {
"toolName": "front.navigate_to_route",
"result": {"ok": True, "applied": True, "content": "已跳转"},
}
),
)
message = SimpleNamespace(
id=uuid4(),
role=AgentChatMessageRole.TOOL,
created_at=datetime.now(timezone.utc),
content='{"offloaded":true}',
metadata_json={
"tool_call_id": "call-1",
"storage_bucket": "private",
"storage_path": "tool-results/run-1/call-1.json",
},
) )
session_row = session.added[0] payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
assert str(getattr(session_row, "user_id")) == "00000000-0000-0000-0000-000000000001"
assert payload["toolCallId"] == "call-1"
assert payload["content"] == "已跳转"
async def test_create_session_for_user_rejects_invalid_uuid() -> None: @pytest.mark.asyncio
session = _FakeSession() async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
repository = AgentRepository(session=session) # type: ignore[arg-type] repository = AgentRepository(
session=SimpleNamespace(), # type: ignore[arg-type]
tool_result_storage=_FakeToolResultStorage(None),
)
message = SimpleNamespace(
id=uuid4(),
role=AgentChatMessageRole.TOOL,
created_at=datetime.now(timezone.utc),
content="inline-tool-content",
metadata_json={
"tool_call_id": "call-2",
"storage_bucket": "private",
"storage_path": "tool-results/run-1/call-2.json",
},
)
try: payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
await repository.create_session_for_user(user_id="invalid-uuid")
raise AssertionError("expected invalid user_id") assert payload["toolCallId"] == "call-2"
except HTTPException as exc: assert payload["content"] == "inline-tool-content"
assert exc.status_code == 422
assert exc.detail == "Invalid user_id"
+2 -1
View File
@@ -201,7 +201,7 @@ services:
image: supabase/storage-api:v1.33.0 image: supabase/storage-api:v1.33.0
restart: unless-stopped restart: unless-stopped
volumes: volumes:
- ./volumes/storage:/var/lib/storage:z - storage_data:/var/lib/storage
healthcheck: healthcheck:
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://storage:5000/status"] test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://storage:5000/status"]
timeout: 5s timeout: 5s
@@ -443,3 +443,4 @@ services:
volumes: volumes:
redis_data: redis_data:
db-config: db-config:
storage_data: