feat(agent): complete task4-6 tool result persistence flow

This commit is contained in:
zl-q
2026-03-08 17:07:09 +08:00
parent 5ada60e834
commit daa1c86d02
15 changed files with 903 additions and 92 deletions
@@ -15,7 +15,9 @@ from core.agent.application.runtime_data_service import RuntimeDataService
from core.agent.application.runtime_loop_service import RuntimeLoopService
from core.agent.application.session_state_persistence import (
SessionStatePersistence,
ToolResultStorage,
compute_tool_args_sha256,
persist_tool_result_payload,
)
from core.agent.domain.agui_input import extract_latest_tool_result
from core.agent.domain.user_context import build_global_system_prompt
@@ -57,10 +59,20 @@ class ResumeService:
self,
*,
session_factory: async_sessionmaker[AsyncSession] = AsyncSessionLocal,
tool_result_storage: ToolResultStorage | None = None,
tool_result_offload_threshold_bytes: int = 4096,
tool_result_bucket: str = "private",
tool_result_prefix: str = "tool-results",
) -> None:
self._session_factory = session_factory
self._state_persistence = SessionStatePersistence()
self._loop_service = RuntimeLoopService()
self._tool_result_storage = tool_result_storage
self._tool_result_offload_threshold_bytes = max(
1, int(tool_result_offload_threshold_bytes)
)
self._tool_result_bucket = tool_result_bucket
self._tool_result_prefix = tool_result_prefix.strip("/") or "tool-results"
async def resume(
self,
@@ -152,18 +164,59 @@ class ResumeService:
next_seq = await session_repository.next_message_seq(
session_id=session_uuid
)
payload_json = json.dumps(
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
)
payload_bytes = len(payload_json.encode("utf-8"))
metadata_payload: dict[str, object] = MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump()
stored_content = payload_json
if (
self._tool_result_storage is not None
and payload_bytes >= self._tool_result_offload_threshold_bytes
):
storage_path = (
f"{self._tool_result_prefix}/{run_input.thread_id}/"
f"{run_input.run_id}/{tool_call_id}.json"
)
try:
metadata_payload = await persist_tool_result_payload(
storage=self._tool_result_storage,
run_id=run_input.run_id,
turn_id=str(next_seq),
tool_call_id=tool_call_id,
tool_name=tool_name,
payload=sanitized_tool_payload,
bucket=self._tool_result_bucket,
path=storage_path,
)
stored_content = json.dumps(
{
"toolName": tool_name,
"offloaded": True,
"storage": {
"bucket": metadata_payload.get("storage_bucket"),
"path": metadata_payload.get("storage_path"),
},
},
ensure_ascii=True,
separators=(",", ":"),
)
except Exception:
metadata_payload = MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump()
tool_message = await message_repository.append_message(
session_id=session_uuid,
seq=next_seq,
role=AgentChatMessageRole.TOOL,
content=json.dumps(
sanitized_tool_payload, ensure_ascii=True, separators=(",", ":")
),
metadata=MessageMetadataToolResult(
tool_call_id=tool_call_id,
run_id=run_input.run_id,
tool_name=tool_name,
).model_dump(),
content=stored_content,
metadata=metadata_payload,
)
snapshot = self._state_persistence.build_resuming_snapshot(
@@ -274,6 +327,11 @@ class ResumeService:
pending_tool_call_id: str | None = None
events: list[dict[str, object]] = []
runtime_events = runtime_result.get("agui_events")
if isinstance(runtime_events, list):
for event in runtime_events:
if isinstance(event, dict):
events.append(event)
message_delta = 1
snapshot = self._state_persistence.build_completed_snapshot()
status = AgentChatSessionStatus.COMPLETED
@@ -162,6 +162,11 @@ class RunService:
)
pending_tool_call_id: str | None = None
events: list[dict[str, object]] = []
runtime_events = runtime_result.get("agui_events")
if isinstance(runtime_events, list):
for event in runtime_events:
if isinstance(event, dict):
events.append(event)
message_delta = 2
session_status = AgentChatSessionStatus.COMPLETED
snapshot = self._state_persistence.build_completed_snapshot()
@@ -416,6 +416,21 @@ class CrewAIRuntime:
completion_tokens = 0
total_tokens = 0
total_cost = 0.0
internal_events: list[dict[str, Any]] = []
def _emit_step_event(
*,
event_type: str,
stage: str,
status: str | None = None,
reason: str | None = None,
) -> None:
data: dict[str, Any] = {"stage": stage}
if status is not None:
data["status"] = status
if reason is not None:
data["reason"] = reason
internal_events.append({"type": event_type, "data": data})
client_front_tools = self._normalize_client_front_tools(tools)
intent_tools = self._resolve_stage_tools_payload(
@@ -432,6 +447,18 @@ class CrewAIRuntime:
)
if resume_from_stage in {"execution", "organization"}:
_emit_step_event(
event_type="stepStarted",
stage="intent",
status="skipped",
reason="resume_from_interrupted_stage",
)
_emit_step_event(
event_type="stepFinished",
stage="intent",
status="skipped",
reason="resume_from_interrupted_stage",
)
intent_result = IntentResult(
route="NEEDS_EXECUTION",
intent_summary="resume_from_interrupted_stage",
@@ -439,6 +466,7 @@ class CrewAIRuntime:
safety_flags=[],
)
else:
_emit_step_event(event_type="stepStarted", stage="intent")
intent_text, intent_usage, _, _ = self._run_stage_with_crewai(
stage="intent",
user_content=user_input,
@@ -451,11 +479,15 @@ class CrewAIRuntime:
total_tokens += intent_usage.total_tokens
total_cost += intent_usage.cost
intent_result = _parse_intent_result(intent_text)
_emit_step_event(
event_type="stepFinished", stage="intent", status="completed"
)
assistant_text = intent_result.assistant_text or ""
pending_front_tool: dict[str, object] | None = None
if intent_result.route == "NEEDS_EXECUTION":
_emit_step_event(event_type="stepStarted", stage="execution")
execution_input = json.dumps(
{
"user_input": user_input,
@@ -483,8 +515,14 @@ class CrewAIRuntime:
execution_tools=execution_tools,
pending_call=pending_call,
)
_emit_step_event(
event_type="stepFinished",
stage="execution",
status="pending_approval" if pending_call is not None else "completed",
)
if pending_call is None and resume_from_stage != "execution":
_emit_step_event(event_type="stepStarted", stage="organization")
execution_result = _parse_execution_result(execution_text)
organization_input = json.dumps(
{
@@ -522,28 +560,68 @@ class CrewAIRuntime:
fallback_text=execution_result.report_brief,
)
assistant_text = organization_result.assistant_text
_emit_step_event(
event_type="stepFinished",
stage="organization",
status="completed",
)
elif pending_call is not None:
assistant_text = (
intent_result.execution_brief or "Tool call pending approval"
)
_emit_step_event(
event_type="stepStarted",
stage="organization",
status="skipped",
reason="pending_tool_approval",
)
_emit_step_event(
event_type="stepFinished",
stage="organization",
status="skipped",
reason="pending_tool_approval",
)
else:
execution_result = _parse_execution_result(execution_text)
assistant_text = execution_result.report_brief
_emit_step_event(
event_type="stepStarted",
stage="organization",
status="skipped",
reason="resume_from_execution",
)
_emit_step_event(
event_type="stepFinished",
stage="organization",
status="skipped",
reason="resume_from_execution",
)
else:
_emit_step_event(
event_type="stepStarted",
stage="execution",
status="skipped",
reason="direct_execution_route",
)
_emit_step_event(
event_type="stepFinished",
stage="execution",
status="skipped",
reason="direct_execution_route",
)
_emit_step_event(
event_type="stepStarted",
stage="organization",
status="skipped",
reason="direct_execution_route",
)
_emit_step_event(
event_type="stepFinished",
stage="organization",
status="skipped",
reason="direct_execution_route",
)
internal_events = [
{"type": "llmStarted", "data": {"model": self._config.model_code}},
{"type": "llmChunk", "data": {"text": assistant_text}},
{
"type": "llmFinished",
"data": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost": total_cost,
"provider": self._config.provider_name,
},
},
]
return {
"assistant_text": assistant_text,
"prompt_tokens": prompt_tokens,
@@ -57,42 +57,28 @@ async def _execute_create_calendar_event(
)
event_id = str(created.id)
return {
"result": {
"eventId": event_id,
"ok": True,
"message": "日程已创建",
"type": "calendar_card.v1",
"version": "v1",
"data": {
"id": event_id,
"title": created.title,
"description": created.description,
"startAt": created.start_at.isoformat(),
"endAt": created.end_at.isoformat() if created.end_at is not None else None,
"timezone": created.timezone,
"location": location_value,
"color": "#4F46E5",
"sourceType": "agent_generated",
"ok": True,
"message": "日程已创建",
},
"ui": {
"type": "calendar_card.v1",
"version": "v1",
"data": {
"id": event_id,
"title": created.title,
"description": created.description,
"startAt": created.start_at.isoformat(),
"endAt": (
created.end_at.isoformat() if created.end_at is not None else None
),
"timezone": created.timezone,
"location": location_value,
"color": "#4F46E5",
"sourceType": "agent_generated",
},
"actions": [
{
"type": "link",
"label": "查看详情",
"target": f"/calendar/events/{event_id}",
}
],
},
"actions": [
{
"type": "link",
"label": "查看详情",
"target": f"/calendar/events/{event_id}",
}
],
}
@@ -9,6 +9,9 @@ from core.agent.domain.agui_input import parse_run_input
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from core.config.settings import config
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
@@ -125,8 +128,13 @@ async def run_agent_task(
) -> dict[str, object]:
publisher = publish_event or await _build_redis_publisher()
enqueue = enqueue_command or _enqueue_followup_command
tool_result_storage = create_tool_result_storage()
service_run = run_service or RunService()
service_resume = resume_service or ResumeService()
service_resume = resume_service or ResumeService(
tool_result_storage=tool_result_storage,
tool_result_bucket="private",
tool_result_prefix="tool-results",
)
command_type = str(command.get("command", "run"))
if command_type not in {"run", "resume", "resume_continue"}:
@@ -0,0 +1,6 @@
from core.agent.infrastructure.storage.tool_result_storage import (
SupabaseToolResultStorage,
create_tool_result_storage,
)
__all__ = ["SupabaseToolResultStorage", "create_tool_result_storage"]
@@ -0,0 +1,78 @@
from __future__ import annotations
import asyncio
import json
from typing import Any
from services.base.supabase import supabase_service
class SupabaseToolResultStorage:
def _bucket_client(self, *, bucket: str) -> Any:
client = supabase_service.get_admin_client()
storage = getattr(client, "storage", None)
if storage is None:
raise RuntimeError("Supabase storage client unavailable")
from_bucket = getattr(storage, "from_", None)
if not callable(from_bucket):
raise RuntimeError("Supabase storage bucket accessor unavailable")
return from_bucket(bucket)
async def upload_json(
self,
*,
bucket: str,
path: str,
payload: dict[str, object],
) -> str:
data = json.dumps(payload, ensure_ascii=True, separators=(",", ":")).encode(
"utf-8"
)
def _upload() -> object:
bucket_client = self._bucket_client(bucket=bucket)
upload = getattr(bucket_client, "upload", None)
if not callable(upload):
raise RuntimeError("Supabase storage upload is unavailable")
return upload(
path,
data,
{
"content-type": "application/json",
"upsert": "true",
},
)
result = await asyncio.to_thread(_upload)
return str(result or "")
async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
def _download() -> object:
bucket_client = self._bucket_client(bucket=bucket)
download = getattr(bucket_client, "download", None)
if not callable(download):
raise RuntimeError("Supabase storage download is unavailable")
return download(path)
raw = await asyncio.to_thread(_download)
if isinstance(raw, bytes):
text = raw.decode("utf-8")
elif isinstance(raw, str):
text = raw
else:
return None
try:
payload = json.loads(text)
except ValueError:
return None
if not isinstance(payload, dict):
return None
return payload
def create_tool_result_storage() -> SupabaseToolResultStorage | None:
try:
supabase_service.get_admin_client()
except Exception:
return None
return SupabaseToolResultStorage()
+5 -1
View File
@@ -9,6 +9,9 @@ from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from core.agent.infrastructure.queue.tasks import (
run_command_task,
run_command_task_bulk,
@@ -109,8 +112,9 @@ class RedisEventStream:
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
tool_result_storage = create_tool_result_storage()
return AgentService(
repository=AgentRepository(session),
repository=AgentRepository(session, tool_result_storage=tool_result_storage),
queue=TaskiqQueueClient(),
stream=RedisEventStream(),
)
+49 -9
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
import json
from typing import Protocol
from uuid import UUID
from fastapi import HTTPException
@@ -12,9 +13,21 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
class ToolResultPayloadStorage(Protocol):
async def read_json(
self, *, bucket: str, path: str
) -> dict[str, object] | None: ...
class AgentRepository:
def __init__(self, session: AsyncSession) -> None:
def __init__(
self,
session: AsyncSession,
*,
tool_result_storage: ToolResultPayloadStorage | None = None,
) -> None:
self._session = session
self._tool_result_storage = tool_result_storage
async def get_session_owner(self, *, session_id: str) -> str:
try:
@@ -42,7 +55,9 @@ class AgentRepository:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
raise HTTPException(
status_code=422, detail="Invalid session_id"
) from exc
session = AgentChatSession(
id=session_uuid,
@@ -118,10 +133,13 @@ class AgentRepository:
)
messages = (await self._session.execute(message_stmt)).scalars().all()
has_more = any(day < target_day for day in unique_days)
snapshot_messages: list[dict[str, object]] = []
for message in messages:
snapshot_messages.append(await self._to_snapshot_message(message))
return {
"day": target_day.isoformat(),
"hasMore": has_more,
"messages": [self._to_snapshot_message(msg) for msg in messages],
"messages": snapshot_messages,
}
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
@@ -141,8 +159,9 @@ class AgentRepository:
return None
return str(latest_id)
@staticmethod
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]:
async def _to_snapshot_message(
self, message: AgentChatMessage
) -> dict[str, object]:
role = (
message.role.value
if isinstance(message.role, AgentChatMessageRole)
@@ -167,14 +186,35 @@ class AgentRepository:
parsed_content = decoded
except (TypeError, ValueError):
parsed_content = None
if parsed_content is not None:
ui = parsed_content.get("ui")
hydrated_content: dict[str, object] | None = None
if self._tool_result_storage is not None:
storage_bucket = metadata.get("storage_bucket")
storage_path = metadata.get("storage_path")
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
try:
hydrated_content = await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
)
except Exception:
hydrated_content = None
resolved_content = hydrated_content or parsed_content
if resolved_content is not None:
result = resolved_content.get("result")
if isinstance(result, dict):
result_content = result.get("content")
if isinstance(result_content, str):
payload["content"] = result_content
ui = resolved_content.get("ui")
if isinstance(ui, dict):
payload["ui"] = ui
display_content = parsed_content.get("content")
display_content = resolved_content.get("content")
if isinstance(display_content, str):
payload["content"] = display_content
else:
if "content" not in payload:
payload["content"] = message.content
else:
payload["content"] = message.content