feat: 添加 Agent 步骤事件与图片附件功能
- 新增 stepStarted/stepFinished 事件类型支持 - 前端实现图片附件上传和预览功能 - 后端增强工具结果存储和事件处理 - 完善相关单元测试和集成测试
This commit is contained in:
@@ -37,6 +37,21 @@ def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
if event_type == "tool.result":
|
||||
for key in (
|
||||
"messageId",
|
||||
"toolCallId",
|
||||
"callId",
|
||||
"toolName",
|
||||
"stage",
|
||||
"taskId",
|
||||
"ui",
|
||||
"content",
|
||||
):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
payload[key] = value
|
||||
return payload
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
data_map = cast(dict[str, Any], data)
|
||||
payload.update({k: v for k, v in data_map.items() if k not in reserved})
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from core.agentscope.events.tool_result_summary import build_tool_content_summary
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
|
||||
@@ -14,6 +17,16 @@ class EventStore(Protocol):
|
||||
async def persist(self, event: dict[str, Any]) -> None: ...
|
||||
|
||||
|
||||
class ToolResultStorageLike(Protocol):
|
||||
async def upload_json(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
payload: dict[str, object],
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class NullEventStore:
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
del event
|
||||
@@ -21,9 +34,20 @@ class NullEventStore:
|
||||
|
||||
class SqlAlchemyEventStore:
|
||||
_session_factory: Callable[[], Any]
|
||||
_tool_result_storage: ToolResultStorageLike | None
|
||||
_tool_result_bucket: str | None
|
||||
_logger = get_logger("core.agentscope.events.store")
|
||||
|
||||
def __init__(self, *, session_factory: Any) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_factory: Any,
|
||||
tool_result_storage: ToolResultStorageLike | None = None,
|
||||
tool_result_bucket: str | None = None,
|
||||
) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._tool_result_storage = tool_result_storage
|
||||
self._tool_result_bucket = tool_result_bucket
|
||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
@@ -228,23 +252,89 @@ class SqlAlchemyEventStore:
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
payload = {
|
||||
"args": event.get("args"),
|
||||
"result": event.get("result"),
|
||||
"error": event.get("error"),
|
||||
"call_id": event.get("callId"),
|
||||
}
|
||||
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||
metadata: dict[str, object] = {"tool_name": tool_name}
|
||||
run_id = event.get("runId")
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
|
||||
task_id = event.get("taskId")
|
||||
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
|
||||
call_id_value = event.get("callId")
|
||||
if not isinstance(call_id_value, str) or not call_id_value:
|
||||
call_id_value = (
|
||||
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
|
||||
if run_id_value
|
||||
else f"{task_id_value}-{uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
summary = build_tool_content_summary(
|
||||
tool_name=tool_name,
|
||||
args=event.get("args") if isinstance(event.get("args"), dict) else None,
|
||||
result=event.get("result"),
|
||||
error=event.get("error"),
|
||||
)
|
||||
|
||||
raw_result_value = event.get("result")
|
||||
raw_result: dict[str, object] = (
|
||||
raw_result_value if isinstance(raw_result_value, dict) else {}
|
||||
)
|
||||
ui_candidate = raw_result.get("ui")
|
||||
ui_schema = ui_candidate if isinstance(ui_candidate, dict) else None
|
||||
result_type = raw_result.get("type")
|
||||
result_data = raw_result.get("data")
|
||||
if (
|
||||
ui_schema is None
|
||||
and isinstance(result_type, str)
|
||||
and isinstance(result_data, dict)
|
||||
):
|
||||
ui_schema = raw_result
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"toolName": tool_name,
|
||||
"ui_schema": ui_schema,
|
||||
"result": _sanitize_result(raw_result),
|
||||
"error": _sanitize_error(event.get("error")),
|
||||
"callId": call_id_value,
|
||||
"runId": run_id_value,
|
||||
"taskId": task_id_value,
|
||||
"content": summary,
|
||||
}
|
||||
|
||||
metadata: dict[str, object] = {
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": call_id_value,
|
||||
"summary_version": "v1",
|
||||
}
|
||||
if run_id_value:
|
||||
metadata["run_id"] = run_id_value
|
||||
stage = event.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
task_id = event.get("taskId")
|
||||
if isinstance(task_id, str) and task_id:
|
||||
metadata["task_id"] = task_id
|
||||
if task_id_value:
|
||||
metadata["task_id"] = task_id_value
|
||||
|
||||
if self._tool_result_storage is not None and self._tool_result_bucket:
|
||||
safe_run = _sanitize_path_component(run_id_value or "run")
|
||||
safe_call = _sanitize_path_component(call_id_value)
|
||||
storage_path = f"tool-results/{session_id}/{safe_run}/{safe_call}.json"
|
||||
try:
|
||||
await self._tool_result_storage.upload_json(
|
||||
bucket=self._tool_result_bucket,
|
||||
path=storage_path,
|
||||
payload=payload,
|
||||
)
|
||||
metadata["storage_bucket"] = self._tool_result_bucket
|
||||
metadata["storage_path"] = storage_path
|
||||
except Exception: # noqa: BLE001
|
||||
metadata["storage_upload_failed"] = True
|
||||
self._logger.warning(
|
||||
"tool result storage upload failed",
|
||||
session_id=str(session_id),
|
||||
run_id=run_id_value,
|
||||
call_id=call_id_value,
|
||||
storage_path=storage_path,
|
||||
)
|
||||
|
||||
content = summary or json.dumps(
|
||||
payload, ensure_ascii=False, separators=(",", ":")
|
||||
)
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
@@ -333,3 +423,69 @@ class SqlAlchemyEventStore:
|
||||
except (InvalidOperation, TypeError, ValueError):
|
||||
return Decimal("0")
|
||||
return parsed if parsed >= 0 else Decimal("0")
|
||||
|
||||
|
||||
def _sanitize_path_component(value: str) -> str:
|
||||
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
|
||||
compact = compact.strip(".-")
|
||||
return compact or "id"
|
||||
|
||||
|
||||
def _sanitize_error(value: object) -> str | None:
|
||||
if isinstance(value, str) and value.strip():
|
||||
return " ".join(value.split())[:300]
|
||||
if isinstance(value, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
item = value.get(key)
|
||||
if isinstance(item, str) and item.strip():
|
||||
return " ".join(item.split())[:300]
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_result(value: object) -> dict[str, object]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
normalized = key.strip().lower().replace("-", "_")
|
||||
if not normalized:
|
||||
return False
|
||||
exact = {
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"credential",
|
||||
"authorization",
|
||||
"auth",
|
||||
}
|
||||
if normalized in exact:
|
||||
return True
|
||||
patterns = (
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"auth",
|
||||
"credential",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"authorization",
|
||||
)
|
||||
return any(pattern in normalized for pattern in patterns)
|
||||
|
||||
def _sanitize_value(item: object) -> object:
|
||||
if isinstance(item, dict):
|
||||
return _sanitize_result(item)
|
||||
if isinstance(item, list):
|
||||
return [_sanitize_value(entry) for entry in item]
|
||||
return item
|
||||
|
||||
sanitized: dict[str, object] = {}
|
||||
for key, item in value.items():
|
||||
key_text = str(key)
|
||||
if _is_sensitive_key(key_text):
|
||||
sanitized[str(key)] = "[REDACTED]"
|
||||
continue
|
||||
sanitized[str(key)] = _sanitize_value(item)
|
||||
return sanitized
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_tool_content_summary(
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any] | None,
|
||||
result: Any,
|
||||
error: Any,
|
||||
) -> str:
|
||||
error_message = _extract_error_message(error)
|
||||
if error_message is not None:
|
||||
return _truncate(f"{tool_name} 执行失败:{error_message}")
|
||||
|
||||
normalized_args = args if isinstance(args, dict) else {}
|
||||
normalized_result = result if isinstance(result, dict) else {}
|
||||
|
||||
business_failure = _extract_business_failure_message(normalized_result)
|
||||
if business_failure is not None:
|
||||
return _truncate(f"{tool_name} 执行失败:{business_failure}")
|
||||
|
||||
if tool_name == "calendar_write":
|
||||
title = _pick_first_str(normalized_result, ("title",)) or _pick_first_str(
|
||||
normalized_args, ("title",)
|
||||
)
|
||||
start_at = _pick_first_str(normalized_result, ("startAt", "start_at"))
|
||||
if title and start_at:
|
||||
return _truncate(f"已创建日程:{title}({start_at})")
|
||||
if title:
|
||||
return _truncate(f"已创建日程:{title}")
|
||||
|
||||
if tool_name == "calendar_read":
|
||||
total = _extract_total(normalized_result)
|
||||
query = _pick_first_str(normalized_args, ("query",)) or "全部"
|
||||
if total is not None:
|
||||
return _truncate(f"查询到 {total} 条日程({query})")
|
||||
|
||||
if tool_name == "calendar_delete":
|
||||
target = _pick_first_str(normalized_result, ("title", "eventId", "event_id"))
|
||||
if target:
|
||||
return _truncate(f"已删除日程:{target}")
|
||||
|
||||
if tool_name == "calendar_share":
|
||||
target = _pick_first_str(normalized_result, ("target", "user", "userName"))
|
||||
if target:
|
||||
return _truncate(f"已分享日程给 {target}")
|
||||
|
||||
if tool_name == "user_resolve":
|
||||
target = _pick_first_str(normalized_result, ("name", "userName", "userId"))
|
||||
if target:
|
||||
return _truncate(f"已匹配用户:{target}")
|
||||
|
||||
result_content = _pick_first_str(normalized_result, ("content", "message"))
|
||||
if result_content:
|
||||
return _truncate(result_content)
|
||||
|
||||
return _truncate(f"{tool_name} 执行完成")
|
||||
|
||||
|
||||
def _extract_error_message(error: Any) -> str | None:
|
||||
if isinstance(error, str) and error.strip():
|
||||
return error.strip()
|
||||
if isinstance(error, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
value = error.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _pick_first_str(payload: dict[str, Any], keys: tuple[str, ...]) -> str | None:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str):
|
||||
normalized = " ".join(value.split())
|
||||
if normalized:
|
||||
return normalized
|
||||
return None
|
||||
|
||||
|
||||
def _extract_total(result: dict[str, Any]) -> int | None:
|
||||
candidates: list[Any] = [result.get("total")]
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict):
|
||||
candidates.append(data.get("total"))
|
||||
events = data.get("events")
|
||||
if isinstance(events, list):
|
||||
candidates.append(len(events))
|
||||
for value in candidates:
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int) and value >= 0:
|
||||
return value
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
return int(value)
|
||||
return None
|
||||
|
||||
|
||||
def _extract_business_failure_message(result: dict[str, Any]) -> str | None:
|
||||
top_ok = result.get("ok")
|
||||
if top_ok is False:
|
||||
top_message = _pick_first_str(result, ("message", "error", "detail"))
|
||||
if top_message:
|
||||
return top_message
|
||||
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict) and data.get("ok") is False:
|
||||
data_message = _pick_first_str(data, ("message", "error", "detail"))
|
||||
if data_message:
|
||||
return data_message
|
||||
code = _pick_first_str(data, ("code",))
|
||||
if code:
|
||||
return code
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _truncate(text: str, limit: int = 80) -> str:
|
||||
normalized = " ".join(text.split())
|
||||
if len(normalized) <= limit:
|
||||
return normalized
|
||||
return normalized[: limit - 3] + "..."
|
||||
Reference in New Issue
Block a user