feat: 添加 Agent 步骤事件与图片附件功能

- 新增 stepStarted/stepFinished 事件类型支持
- 前端实现图片附件上传和预览功能
- 后端增强工具结果存储和事件处理
- 完善相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-12 09:29:57 +08:00
parent 87215f9d41
commit 7b8865e256
45 changed files with 3869 additions and 308 deletions
@@ -37,6 +37,21 @@ def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
data = event.get("data")
if isinstance(data, dict):
if event_type == "tool.result":
for key in (
"messageId",
"toolCallId",
"callId",
"toolName",
"stage",
"taskId",
"ui",
"content",
):
value = data.get(key)
if value is not None:
payload[key] = value
return payload
reserved = {"type", "threadId", "runId"}
data_map = cast(dict[str, Any], data)
payload.update({k: v for k, v in data_map.items() if k not in reserved})
+171 -15
View File
@@ -1,11 +1,14 @@
from __future__ import annotations
import json
import re
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
from uuid import UUID, uuid4
from core.agentscope.events.tool_result_summary import build_tool_content_summary
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
@@ -14,6 +17,16 @@ class EventStore(Protocol):
async def persist(self, event: dict[str, Any]) -> None: ...
class ToolResultStorageLike(Protocol):
async def upload_json(
self,
*,
bucket: str,
path: str,
payload: dict[str, object],
) -> str: ...
class NullEventStore:
async def persist(self, event: dict[str, Any]) -> None:
del event
@@ -21,9 +34,20 @@ class NullEventStore:
class SqlAlchemyEventStore:
_session_factory: Callable[[], Any]
_tool_result_storage: ToolResultStorageLike | None
_tool_result_bucket: str | None
_logger = get_logger("core.agentscope.events.store")
def __init__(self, *, session_factory: Any) -> None:
def __init__(
self,
*,
session_factory: Any,
tool_result_storage: ToolResultStorageLike | None = None,
tool_result_bucket: str | None = None,
) -> None:
self._session_factory = session_factory
self._tool_result_storage = tool_result_storage
self._tool_result_bucket = tool_result_bucket
self._message_buffers: dict[tuple[str, str], str] = {}
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
@@ -228,23 +252,89 @@ class SqlAlchemyEventStore:
if not isinstance(tool_name, str) or not tool_name:
return
payload = {
"args": event.get("args"),
"result": event.get("result"),
"error": event.get("error"),
"call_id": event.get("callId"),
}
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
metadata: dict[str, object] = {"tool_name": tool_name}
run_id = event.get("runId")
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
task_id = event.get("taskId")
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
call_id_value = event.get("callId")
if not isinstance(call_id_value, str) or not call_id_value:
call_id_value = (
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
if run_id_value
else f"{task_id_value}-{uuid4().hex[:8]}"
)
summary = build_tool_content_summary(
tool_name=tool_name,
args=event.get("args") if isinstance(event.get("args"), dict) else None,
result=event.get("result"),
error=event.get("error"),
)
raw_result_value = event.get("result")
raw_result: dict[str, object] = (
raw_result_value if isinstance(raw_result_value, dict) else {}
)
ui_candidate = raw_result.get("ui")
ui_schema = ui_candidate if isinstance(ui_candidate, dict) else None
result_type = raw_result.get("type")
result_data = raw_result.get("data")
if (
ui_schema is None
and isinstance(result_type, str)
and isinstance(result_data, dict)
):
ui_schema = raw_result
payload: dict[str, object] = {
"toolName": tool_name,
"ui_schema": ui_schema,
"result": _sanitize_result(raw_result),
"error": _sanitize_error(event.get("error")),
"callId": call_id_value,
"runId": run_id_value,
"taskId": task_id_value,
"content": summary,
}
metadata: dict[str, object] = {
"tool_name": tool_name,
"tool_call_id": call_id_value,
"summary_version": "v1",
}
if run_id_value:
metadata["run_id"] = run_id_value
stage = event.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
task_id = event.get("taskId")
if isinstance(task_id, str) and task_id:
metadata["task_id"] = task_id
if task_id_value:
metadata["task_id"] = task_id_value
if self._tool_result_storage is not None and self._tool_result_bucket:
safe_run = _sanitize_path_component(run_id_value or "run")
safe_call = _sanitize_path_component(call_id_value)
storage_path = f"tool-results/{session_id}/{safe_run}/{safe_call}.json"
try:
await self._tool_result_storage.upload_json(
bucket=self._tool_result_bucket,
path=storage_path,
payload=payload,
)
metadata["storage_bucket"] = self._tool_result_bucket
metadata["storage_path"] = storage_path
except Exception: # noqa: BLE001
metadata["storage_upload_failed"] = True
self._logger.warning(
"tool result storage upload failed",
session_id=str(session_id),
run_id=run_id_value,
call_id=call_id_value,
storage_path=storage_path,
)
content = summary or json.dumps(
payload, ensure_ascii=False, separators=(",", ":")
)
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
@@ -333,3 +423,69 @@ class SqlAlchemyEventStore:
except (InvalidOperation, TypeError, ValueError):
return Decimal("0")
return parsed if parsed >= 0 else Decimal("0")
def _sanitize_path_component(value: str) -> str:
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
compact = compact.strip(".-")
return compact or "id"
def _sanitize_error(value: object) -> str | None:
if isinstance(value, str) and value.strip():
return " ".join(value.split())[:300]
if isinstance(value, dict):
for key in ("message", "error", "detail"):
item = value.get(key)
if isinstance(item, str) and item.strip():
return " ".join(item.split())[:300]
return None
def _sanitize_result(value: object) -> dict[str, object]:
if not isinstance(value, dict):
return {}
def _is_sensitive_key(key: str) -> bool:
normalized = key.strip().lower().replace("-", "_")
if not normalized:
return False
exact = {
"password",
"token",
"secret",
"api_key",
"apikey",
"credential",
"authorization",
"auth",
}
if normalized in exact:
return True
patterns = (
"password",
"token",
"secret",
"auth",
"credential",
"api_key",
"apikey",
"authorization",
)
return any(pattern in normalized for pattern in patterns)
def _sanitize_value(item: object) -> object:
if isinstance(item, dict):
return _sanitize_result(item)
if isinstance(item, list):
return [_sanitize_value(entry) for entry in item]
return item
sanitized: dict[str, object] = {}
for key, item in value.items():
key_text = str(key)
if _is_sensitive_key(key_text):
sanitized[str(key)] = "[REDACTED]"
continue
sanitized[str(key)] = _sanitize_value(item)
return sanitized
@@ -0,0 +1,124 @@
from __future__ import annotations
from typing import Any
def build_tool_content_summary(
*,
tool_name: str,
args: dict[str, Any] | None,
result: Any,
error: Any,
) -> str:
error_message = _extract_error_message(error)
if error_message is not None:
return _truncate(f"{tool_name} 执行失败:{error_message}")
normalized_args = args if isinstance(args, dict) else {}
normalized_result = result if isinstance(result, dict) else {}
business_failure = _extract_business_failure_message(normalized_result)
if business_failure is not None:
return _truncate(f"{tool_name} 执行失败:{business_failure}")
if tool_name == "calendar_write":
title = _pick_first_str(normalized_result, ("title",)) or _pick_first_str(
normalized_args, ("title",)
)
start_at = _pick_first_str(normalized_result, ("startAt", "start_at"))
if title and start_at:
return _truncate(f"已创建日程:{title}{start_at}")
if title:
return _truncate(f"已创建日程:{title}")
if tool_name == "calendar_read":
total = _extract_total(normalized_result)
query = _pick_first_str(normalized_args, ("query",)) or "全部"
if total is not None:
return _truncate(f"查询到 {total} 条日程({query}")
if tool_name == "calendar_delete":
target = _pick_first_str(normalized_result, ("title", "eventId", "event_id"))
if target:
return _truncate(f"已删除日程:{target}")
if tool_name == "calendar_share":
target = _pick_first_str(normalized_result, ("target", "user", "userName"))
if target:
return _truncate(f"已分享日程给 {target}")
if tool_name == "user_resolve":
target = _pick_first_str(normalized_result, ("name", "userName", "userId"))
if target:
return _truncate(f"已匹配用户:{target}")
result_content = _pick_first_str(normalized_result, ("content", "message"))
if result_content:
return _truncate(result_content)
return _truncate(f"{tool_name} 执行完成")
def _extract_error_message(error: Any) -> str | None:
if isinstance(error, str) and error.strip():
return error.strip()
if isinstance(error, dict):
for key in ("message", "error", "detail"):
value = error.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _pick_first_str(payload: dict[str, Any], keys: tuple[str, ...]) -> str | None:
for key in keys:
value = payload.get(key)
if isinstance(value, str):
normalized = " ".join(value.split())
if normalized:
return normalized
return None
def _extract_total(result: dict[str, Any]) -> int | None:
candidates: list[Any] = [result.get("total")]
data = result.get("data")
if isinstance(data, dict):
candidates.append(data.get("total"))
events = data.get("events")
if isinstance(events, list):
candidates.append(len(events))
for value in candidates:
if isinstance(value, bool):
continue
if isinstance(value, int) and value >= 0:
return value
if isinstance(value, str) and value.isdigit():
return int(value)
return None
def _extract_business_failure_message(result: dict[str, Any]) -> str | None:
top_ok = result.get("ok")
if top_ok is False:
top_message = _pick_first_str(result, ("message", "error", "detail"))
if top_message:
return top_message
data = result.get("data")
if isinstance(data, dict) and data.get("ok") is False:
data_message = _pick_first_str(data, ("message", "error", "detail"))
if data_message:
return data_message
code = _pick_first_str(data, ("code",))
if code:
return code
return None
def _truncate(text: str, limit: int = 80) -> str:
normalized = " ".join(text.split())
if len(normalized) <= limit:
return normalized
return normalized[: limit - 3] + "..."
@@ -42,11 +42,19 @@ def build_intent_user_prompt(
*, user_input: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(user_input, list):
context_messages = _conversation_context_messages(user_input)
context_hint = (
json.dumps(context_messages, ensure_ascii=True, separators=(",", ":"))
if context_messages
else "[]"
)
instruction_text = "\n\n".join(
[
INTENT_TASK_INSTRUCTION,
"[Output Schema]",
_schema_json(IntentOutput),
"[Conversation Context]",
context_hint,
"[User Input]",
"Use the following multimodal blocks as the latest user input.",
]
@@ -127,6 +135,56 @@ def _latest_user_content_blocks(
return []
def _conversation_context_messages(
user_input: list[dict[str, Any]],
) -> list[dict[str, str]]:
latest_user_index = -1
for index in range(len(user_input) - 1, -1, -1):
item = user_input[index]
if isinstance(item, dict) and item.get("role") == "user":
latest_user_index = index
break
if latest_user_index <= 0:
return []
context_items: list[dict[str, str]] = []
for item in user_input[:latest_user_index]:
if not isinstance(item, dict):
continue
role = item.get("role")
if role not in {"user", "assistant"}:
continue
content = item.get("content")
text = _content_to_text(content)
if text:
context_items.append({"role": str(role), "content": text})
if len(context_items) <= 12:
return context_items
return context_items[-12:]
def _content_to_text(content: Any) -> str:
if isinstance(content, str):
return " ".join(content.split())
if not isinstance(content, list):
return ""
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type == "text":
text = block.get("text")
if isinstance(text, str) and text.strip():
parts.append(" ".join(text.split()))
elif block_type in {"binary", "image"}:
parts.append("[image]")
return " ".join(parts).strip()
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
mime_type = item.get("mimeType")
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import re
from typing import Any, Protocol
from uuid import UUID
import structlog
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
@@ -31,7 +33,9 @@ class PipelineLike(Protocol):
class AgentRouteRuntime:
_orchestrator: OrchestratorLike
_pipeline: PipelineLike
_logger = get_logger("core.agentscope.runtime.agent_route_runtime")
_logger: structlog.stdlib.BoundLogger = get_logger(
"core.agentscope.runtime.agent_route_runtime"
)
def __init__(
self, *, orchestrator: OrchestratorLike, pipeline: PipelineLike
@@ -144,15 +148,6 @@ class AgentRouteRuntime:
"data": {"stepName": "execution"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "execution"},
},
)
await self._emit_stage_text(
thread_id=command.thread_id,
@@ -191,6 +186,15 @@ class AgentRouteRuntime:
task_id=task.task_id,
tool_calls=_task_tool_calls(task),
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "execution"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
@@ -294,6 +298,13 @@ class AgentRouteRuntime:
tool_name = tool_call.get("tool_name")
if not isinstance(tool_name, str) or not tool_name:
continue
call_id = f"{run_id}-{task_id}-{index}"
result_payload = _build_tool_result_event_payload(
tool_name=tool_name,
call_id=call_id,
raw_result=tool_call.get("result"),
raw_error=tool_call.get("error"),
)
await self._pipeline.emit(
session_id=thread_id,
event={
@@ -301,18 +312,175 @@ class AgentRouteRuntime:
"threadId": thread_id,
"runId": run_id,
"data": {
"callId": f"{run_id}-{task_id}-{index}",
"messageId": result_payload["messageId"],
"toolCallId": call_id,
"callId": call_id,
"stage": "execution",
"taskId": task_id,
"toolName": tool_name,
"args": tool_call.get("args", {}),
"result": tool_call.get("result"),
"error": tool_call.get("error"),
"args": _sanitize_result(tool_call.get("args", {})),
"result": result_payload["result"],
"error": result_payload["error"],
"ui": result_payload["ui"],
"content": result_payload["content"],
},
},
)
def _build_tool_result_event_payload(
*,
tool_name: str,
call_id: str,
raw_result: Any,
raw_error: Any,
) -> dict[str, Any]:
result = _sanitize_result(_normalize_tool_result(raw_result))
error = _sanitize_error(raw_error)
ui: dict[str, Any] | None = None
direct_ui = result.get("ui")
if isinstance(direct_ui, dict):
ui = direct_ui
elif isinstance(result.get("type"), str) and isinstance(result.get("data"), dict):
ui = result
text_content = _extract_result_text_content(result)
if text_content is None and isinstance(error, str):
text_content = error
if text_content is None:
text_content = f"{tool_name} 执行完成"
return {
"messageId": f"tool-result-{call_id}",
"result": result,
"error": error,
"ui": ui,
"content": text_content,
}
def _normalize_tool_result(raw_result: Any) -> dict[str, Any]:
if isinstance(raw_result, dict):
content = raw_result.get("content")
if isinstance(content, str):
parsed = _try_parse_json_object(content)
if parsed is not None:
return parsed
return raw_result
if isinstance(raw_result, str):
parsed = _try_parse_json_object(raw_result)
if parsed is not None:
return parsed
text = raw_result.strip()
if text:
return {"content": text}
if raw_result is not None:
return {"value": raw_result}
return {}
def _try_parse_json_object(value: str) -> dict[str, Any] | None:
raw = value.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
except ValueError:
return None
if not isinstance(parsed, dict):
return None
return parsed
def _extract_result_text_content(result: dict[str, Any]) -> str | None:
content = result.get("content")
if isinstance(content, str) and content.strip():
return content
data = result.get("data")
if isinstance(data, dict):
message = data.get("message")
if isinstance(message, str) and message.strip():
return message
return None
def _sanitize_error(value: Any) -> str | None:
if isinstance(value, str) and value.strip():
text = " ".join(value.split())
return _redact_sensitive_text(text)[:300]
if isinstance(value, dict):
for key in ("message", "error", "detail"):
item = value.get(key)
if isinstance(item, str) and item.strip():
text = " ".join(item.split())
return _redact_sensitive_text(text)[:300]
return None
def _sanitize_result(value: Any) -> dict[str, Any]:
if not isinstance(value, dict):
return {}
def _is_sensitive_key(key: str) -> bool:
normalized = key.strip().lower().replace("-", "_")
if not normalized:
return False
exact = {
"password",
"token",
"secret",
"api_key",
"apikey",
"credential",
"authorization",
"auth",
}
if normalized in exact:
return True
patterns = (
"password",
"token",
"secret",
"credential",
"api_key",
"apikey",
"authorization",
)
return any(pattern in normalized for pattern in patterns)
def _sanitize_value(item: Any) -> Any:
if isinstance(item, dict):
return _sanitize_result(item)
if isinstance(item, list):
return [_sanitize_value(entry) for entry in item]
if isinstance(item, str):
return _redact_sensitive_text(item)
return item
sanitized: dict[str, Any] = {}
for key, item in value.items():
key_text = str(key)
if _is_sensitive_key(key_text):
sanitized[key_text] = "[REDACTED]"
continue
sanitized[key_text] = _sanitize_value(item)
return sanitized
def _redact_sensitive_text(value: str) -> str:
redacted = value
key_value_patterns = (
r"(?i)(authorization)\s*[:=]\s*bearer\s+[^\s,;]+",
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s*[:=]\s*[^\s,;]+",
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s+[^\s,;]+",
)
for pattern in key_value_patterns:
redacted = re.sub(pattern, r"\1=[REDACTED]", redacted)
redacted = re.sub(r"(?i)bearer\s+[^\s,;]+", "Bearer [REDACTED]", redacted)
return redacted
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {}
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
@@ -52,6 +52,24 @@ def _tools_payload_from_schema(
return payload
def _merge_tool_schemas(
*schema_sets: list[dict[str, object]],
) -> list[dict[str, object]]:
merged: list[dict[str, object]] = []
seen_names: set[str] = set()
for schemas in schema_sets:
for schema in schemas:
function = schema.get("function")
if not isinstance(function, dict):
continue
name = function.get("name")
if not isinstance(name, str) or not name or name in seen_names:
continue
seen_names.add(name)
merged.append(schema)
return merged
class AgentScopeRuntimeOrchestrator:
_runner: Any
_config_loader: Callable[[AsyncSession], Awaitable[dict[str, RuntimeStageConfig]]]
@@ -96,10 +114,20 @@ class AgentScopeRuntimeOrchestrator:
enable_hitl=False,
)
intent_tools_schema = intent_toolkit.get_json_schemas()
execution_toolkit = build_stage_toolkit(
stage="execution",
session=session,
owner_id=owner_id,
user_token=user_token,
enable_hitl=True,
)
execution_tools_schema = execution_toolkit.get_json_schemas()
intent_prompt = build_system_prompt(
stage="intent",
user_context=user_context,
tools=_tools_payload_from_schema(intent_tools_schema),
tools=_tools_payload_from_schema(
_merge_tool_schemas(intent_tools_schema, execution_tools_schema)
),
)
intent_payload = await self._runner.run_json_stage(
stage_config=stage_config["intent"],
@@ -125,14 +153,6 @@ class AgentScopeRuntimeOrchestrator:
execution_output: ExecutionBatchOutput | None = None
if intent_output.route == "TASK_EXECUTION":
execution_toolkit = build_stage_toolkit(
stage="execution",
session=session,
owner_id=owner_id,
user_token=user_token,
enable_hitl=True,
)
execution_tools_schema = execution_toolkit.get_json_schemas()
execution_prompt = build_system_prompt(
stage="execution",
user_context=user_context,
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import json
import math
from time import perf_counter
from typing import Any, cast
@@ -106,6 +107,9 @@ class AgentScopeReActRunner:
stage_config=stage_config,
response=response,
latency_ms=latency_ms,
system_prompt=system_prompt,
user_prompt=user_prompt,
assistant_text=text_content,
)
except json.JSONDecodeError as exc:
logger.exception(
@@ -234,6 +238,9 @@ def _merge_stage_response_metadata(
stage_config: RuntimeStageConfig,
response: Any,
latency_ms: int,
system_prompt: str,
user_prompt: str | list[dict[str, Any]],
assistant_text: str,
) -> dict[str, Any]:
result = dict(payload)
existing = result.get("response_metadata")
@@ -247,6 +254,15 @@ def _merge_stage_response_metadata(
completion_tokens = _to_non_negative_int(
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
)
if prompt_tokens is None:
prompt_tokens = _estimate_token_count(
{
"system": system_prompt,
"user": user_prompt,
}
)
if completion_tokens is None:
completion_tokens = _estimate_token_count(assistant_text)
cost = _to_non_negative_float(
_read_value(usage, "cost")
or _read_value(_read_value(usage, "metadata"), "cost")
@@ -352,3 +368,16 @@ def _estimate_cost_by_pricing(
)
except ValueError:
return None
def _estimate_token_count(value: object) -> int:
try:
serialized = (
value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
)
except (TypeError, ValueError):
serialized = str(value)
normalized = serialized.strip()
if not normalized:
return 0
return max(1, math.ceil(len(normalized) / 4))
+7 -8
View File
@@ -21,6 +21,7 @@ from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from core.agentscope.tools.tool_result_storage import create_tool_result_storage
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from services.base.redis import get_or_init_redis_client
@@ -67,16 +68,10 @@ def _build_user_context(*, owner_id: UUID, run_input: RunCommand) -> UserAgentCo
def _extract_user_token(
*, command: dict[str, Any], run_input: RunCommand
) -> str | None:
del run_input
raw_token = command.get("user_token")
if isinstance(raw_token, str) and raw_token.strip():
return raw_token.strip()
forwarded = (
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
)
for key in ("accessToken", "userToken", "token"):
value = forwarded.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
@@ -162,7 +157,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
)
pipeline = AgentScopeEventPipeline(
codec=AgentScopeAgUiCodec(),
store=SqlAlchemyEventStore(session_factory=AsyncSessionLocal),
store=SqlAlchemyEventStore(
session_factory=AsyncSessionLocal,
tool_result_storage=create_tool_result_storage(),
tool_result_bucket=config.storage.bucket,
),
bus=bus,
)
runtime = route_runtime_type(
@@ -67,6 +67,7 @@ def validate_run_request_messages_contract(run_input: RunAgentInput) -> None:
message = run_input.messages[0]
if getattr(message, "role", None) != "user":
raise ValueError("RunAgentInput.messages[0].role must be user")
_validate_user_content_blocks(getattr(message, "content", None))
extract_latest_user_payload(run_input)
@@ -106,84 +107,76 @@ def extract_latest_user_payload(
text_parts.append(text)
blocks.append({"type": "text", "text": text})
continue
if item_type not in {"image", "binary"}:
if item_type != "binary":
continue
source_type: str | None = None
source_value: str | None = None
source_mime: str | None = None
if item_type == "binary":
source_mime = (
item.get("mimeType")
if isinstance(item, dict)
else getattr(item, "mime_type", None)
)
source_url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
source_data = (
item.get("data")
if isinstance(item, dict)
else getattr(item, "data", None)
)
if isinstance(source_url, str) and source_url:
source_type = "url"
source_value = source_url
elif isinstance(source_data, str) and source_data:
source_type = "data"
source_value = source_data
else:
source = getattr(item, "source", None)
source_type = (
source.get("type")
if isinstance(source, dict)
else getattr(source, "type", None)
)
source_value = (
source.get("value")
if isinstance(source, dict)
else getattr(source, "value", None)
)
source_mime = (
source.get("mimeType")
if isinstance(source, dict)
else getattr(source, "mimeType", None)
)
if (
source_type == "url"
and isinstance(source_value, str)
and source_value
):
source_url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
if isinstance(source_url, str) and source_url:
blocks.append(
{"type": "image_url", "image_url": {"url": source_value}}
)
elif (
source_type == "data"
and isinstance(source_value, str)
and source_value
):
mime_type = (
source_mime
if isinstance(source_mime, str) and source_mime
else "image/png"
)
blocks.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{source_value}"
},
}
{"type": "image_url", "image_url": {"url": source_url}}
)
combined = "".join(text_parts).strip()
if combined:
if combined or blocks:
return combined, blocks
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
def _validate_user_content_blocks(content: Any) -> None:
if isinstance(content, str):
if content.strip():
return
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
if not isinstance(content, list):
raise ValueError("RunAgentInput.messages[0].content must be string or list")
has_text = False
has_binary = False
for item in content:
item_type = getattr(item, "type", None)
if item_type == "text":
text = getattr(item, "text", None)
if isinstance(text, str) and text.strip():
has_text = True
continue
if item_type == "binary":
mime_type = (
item.get("mimeType")
if isinstance(item, dict)
else getattr(item, "mime_type", None)
)
url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
data = (
item.get("data")
if isinstance(item, dict)
else getattr(item, "data", None)
)
if not isinstance(mime_type, str) or not mime_type.startswith("image/"):
raise ValueError("binary content requires image mimeType")
if not isinstance(url, str) or not url:
raise ValueError("binary content requires url")
if isinstance(data, str) and data:
raise ValueError("binary content data is not allowed")
has_binary = True
continue
raise ValueError("unsupported content block type")
if not has_text and not has_binary:
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
def extract_latest_tool_result(
run_input: RunAgentInput,
) -> tuple[str, dict[str, object]]: