feat: 实现 AgentScope ReAct Runner 两阶段执行并重构事件处理
This commit is contained in:
@@ -38,9 +38,10 @@ _INTERNAL_TO_AGUI: dict[str, EventType] = {
|
||||
|
||||
|
||||
def _convert_to_agui_type(internal_type: str) -> EventType:
|
||||
return _INTERNAL_TO_AGUI.get(
|
||||
internal_type, EventType(internal_type.upper().replace(".", "_"))
|
||||
)
|
||||
mapped = _INTERNAL_TO_AGUI.get(internal_type)
|
||||
if mapped is not None:
|
||||
return mapped
|
||||
return EventType(internal_type.upper().replace(".", "_"))
|
||||
|
||||
|
||||
def _is_agui_event(event: dict[str, Any]) -> bool:
|
||||
@@ -142,32 +143,64 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
|
||||
return event
|
||||
|
||||
internal_type = str(event.get("type", "")).strip()
|
||||
|
||||
thread_id = event.get("threadId")
|
||||
run_id = event.get("runId")
|
||||
data = event.get("data")
|
||||
|
||||
if internal_type == "text.end" and isinstance(data, dict):
|
||||
text_end_payload: dict[str, Any] = {
|
||||
"type": _convert_to_agui_type(internal_type).value,
|
||||
}
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
text_end_payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
text_end_payload["runId"] = run_id
|
||||
for key in ("messageId", "workerAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
text_end_payload[key] = value
|
||||
return text_end_payload
|
||||
|
||||
if internal_type == "tool.result" and isinstance(data, dict):
|
||||
tool_result_payload = {
|
||||
"type": _convert_to_agui_type(internal_type).value,
|
||||
}
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
tool_result_payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
tool_result_payload["runId"] = run_id
|
||||
for key in ("messageId", "toolCallId", "toolAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
tool_result_payload[key] = value
|
||||
return tool_result_payload
|
||||
|
||||
builder = _BUILDER_MAP.get(internal_type)
|
||||
|
||||
if builder:
|
||||
agui_event = builder(event)
|
||||
return agui_event.model_dump(by_alias=True, exclude_none=True)
|
||||
payload = agui_event.model_dump(by_alias=True, exclude_none=True)
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
payload["runId"] = run_id
|
||||
if isinstance(data, dict):
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
payload.update({k: v for k, v in data.items() if k not in reserved})
|
||||
return payload
|
||||
|
||||
wire_type = _convert_to_agui_type(internal_type)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"type": wire_type.value,
|
||||
}
|
||||
thread_id = event.get("threadId")
|
||||
run_id = event.get("runId")
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
payload["runId"] = run_id
|
||||
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
if internal_type == "text.end":
|
||||
for key in ("messageId", "workerAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
payload[key] = value
|
||||
return payload
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
data_map = cast(dict[str, Any], data)
|
||||
payload.update({k: v for k, v in data_map.items() if k not in reserved})
|
||||
|
||||
@@ -50,5 +50,5 @@ class AgentScopeEventPipeline:
|
||||
) -> str:
|
||||
event_dict = to_dict(event)
|
||||
wire_event = self._codec.to_wire(event_dict)
|
||||
await self._store.persist(wire_event)
|
||||
await self._store.persist(event_dict)
|
||||
return await self._bus.publish(session_id=session_id, event=wire_event)
|
||||
|
||||
@@ -55,8 +55,8 @@ class SqlAlchemyEventStore:
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
event_type = str(event.get("type", "")).strip().upper()
|
||||
thread_id = event.get("threadId")
|
||||
event_type = str(event.get("type", "")).strip().upper().replace(".", "_")
|
||||
thread_id = self._event_value(event, "threadId")
|
||||
if not isinstance(thread_id, str) or not thread_id:
|
||||
return
|
||||
try:
|
||||
@@ -124,8 +124,8 @@ class SqlAlchemyEventStore:
|
||||
await session.commit()
|
||||
|
||||
def _buffer_text_delta(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = event.get("messageId")
|
||||
delta = event.get("delta")
|
||||
message_id = self._event_value(event, "messageId")
|
||||
delta = self._event_value(event, "delta")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
if not isinstance(delta, str) or not delta:
|
||||
@@ -143,13 +143,13 @@ class SqlAlchemyEventStore:
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = event.get("messageId")
|
||||
message_id = self._event_value(event, "messageId")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
key = (session_key, message_id)
|
||||
role = event.get("role")
|
||||
stage = event.get("stage")
|
||||
tool_name = event.get("toolName")
|
||||
role = self._event_value(event, "role")
|
||||
stage = self._event_value(event, "stage")
|
||||
tool_name = self._event_value(event, "toolName")
|
||||
context: dict[str, object] = {}
|
||||
if isinstance(role, str) and role:
|
||||
context["role"] = role
|
||||
@@ -168,7 +168,7 @@ class SqlAlchemyEventStore:
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
message_id_raw = event.get("messageId")
|
||||
message_id_raw = self._event_value(event, "messageId")
|
||||
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
|
||||
key = (str(session_id), message_id)
|
||||
content = self._message_buffers.get(key, "")
|
||||
@@ -177,26 +177,26 @@ class SqlAlchemyEventStore:
|
||||
|
||||
context = self._message_contexts.get(key, {})
|
||||
|
||||
input_tokens = self._to_int(event.get("inputTokens"))
|
||||
output_tokens = self._to_int(event.get("outputTokens"))
|
||||
input_tokens = self._to_int(self._event_value(event, "inputTokens"))
|
||||
output_tokens = self._to_int(self._event_value(event, "outputTokens"))
|
||||
token_delta = input_tokens + output_tokens
|
||||
cost = self._to_decimal(event.get("cost"))
|
||||
latency_ms = self._to_int_or_none(event.get("latencyMs"))
|
||||
run_id = event.get("runId")
|
||||
model_code = event.get("model")
|
||||
cost = self._to_decimal(self._event_value(event, "cost"))
|
||||
latency_ms = self._to_int_or_none(self._event_value(event, "latencyMs"))
|
||||
run_id = self._event_value(event, "runId")
|
||||
model_code = self._event_value(event, "model")
|
||||
|
||||
metadata: dict[str, object] = {"message_id": message_id}
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
if latency_ms is not None:
|
||||
metadata["latency_ms"] = latency_ms
|
||||
stage = event.get("stage")
|
||||
stage = self._event_value(event, "stage")
|
||||
if not isinstance(stage, str):
|
||||
stage = context.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
|
||||
worker_payload = event.get("workerAgentOutput")
|
||||
worker_payload = self._event_value(event, "workerAgentOutput")
|
||||
if isinstance(worker_payload, dict):
|
||||
try:
|
||||
if "ui_hints" in worker_payload:
|
||||
@@ -264,11 +264,11 @@ class SqlAlchemyEventStore:
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
tool_name = event.get("toolName")
|
||||
tool_name = self._event_value(event, "toolName")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
raw_output = event.get("toolAgentOutput")
|
||||
raw_output = self._event_value(event, "toolAgentOutput")
|
||||
if not isinstance(raw_output, dict):
|
||||
return
|
||||
try:
|
||||
@@ -276,11 +276,11 @@ class SqlAlchemyEventStore:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
run_id = event.get("runId")
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
|
||||
task_id = event.get("taskId")
|
||||
task_id = self._event_value(event, "taskId")
|
||||
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
|
||||
call_id_value = event.get("callId")
|
||||
call_id_value = self._event_value(event, "callId")
|
||||
if not isinstance(call_id_value, str) or not call_id_value:
|
||||
call_id_value = (
|
||||
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
|
||||
@@ -303,7 +303,7 @@ class SqlAlchemyEventStore:
|
||||
}
|
||||
if run_id_value:
|
||||
metadata["run_id"] = run_id_value
|
||||
stage = event.get("stage")
|
||||
stage = self._event_value(event, "stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
if task_id_value:
|
||||
@@ -421,6 +421,19 @@ class SqlAlchemyEventStore:
|
||||
return Decimal("0")
|
||||
return parsed if parsed >= 0 else Decimal("0")
|
||||
|
||||
def _event_value(
|
||||
self,
|
||||
event: dict[str, Any],
|
||||
key: str,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
if key in event:
|
||||
return event.get(key)
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
return data.get(key, default)
|
||||
return default
|
||||
|
||||
|
||||
def _sanitize_path_component(value: str) -> str:
|
||||
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
|
||||
|
||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, Sequence
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from ag_ui.core.types import Tool
|
||||
from core.agentscope.prompts.agent_prompt import (
|
||||
build_agent_prompt,
|
||||
)
|
||||
@@ -193,7 +194,7 @@ def build_system_prompt(
|
||||
user_context: UserContext,
|
||||
now_utc: datetime,
|
||||
extra_context: str | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tools: Sequence[Tool] | None = None,
|
||||
) -> str:
|
||||
sections = [
|
||||
_build_identity_section(),
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Iterable
|
||||
from typing import Iterable
|
||||
|
||||
from ag_ui.core.types import Tool
|
||||
|
||||
|
||||
def _wrap_section(section: str, content: str) -> str:
|
||||
@@ -15,18 +17,16 @@ def _wrap_section(section: str, content: str) -> str:
|
||||
|
||||
def build_tools_prompt(
|
||||
*,
|
||||
tools: Iterable[dict[str, Any]],
|
||||
tools: Iterable[Tool],
|
||||
) -> str:
|
||||
lines: list[str] = []
|
||||
lines.append("[Available Tools]")
|
||||
|
||||
for item in tools:
|
||||
name = item.get("name")
|
||||
description = item.get("description") or ""
|
||||
parameters = item.get("parameters") or {}
|
||||
if not isinstance(name, str) or not name:
|
||||
continue
|
||||
lines.append(f"- {name}: {description}".strip())
|
||||
name = item.name
|
||||
description = item.description or ""
|
||||
parameters = item.parameters or {}
|
||||
lines.append(f"- {name}: {description}")
|
||||
lines.append(
|
||||
" - args_schema: "
|
||||
+ json.dumps(parameters, ensure_ascii=True, separators=(",", ":"))
|
||||
|
||||
@@ -1,16 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from ag_ui.core.types import RunAgentInput
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.runtime_models import (
|
||||
RouterAgentOutput,
|
||||
ToolAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
|
||||
from schemas.user import UserContext
|
||||
from services.litellm.service import LiteLLMService
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agentscope.runtime.orchestrator import PipelineLike
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.react_runner")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SystemAgentRuntimeConfig:
|
||||
agent_type: AgentType
|
||||
model_code: str
|
||||
llm_config: SystemAgentLLMConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StageExecutionResult:
|
||||
message: Msg
|
||||
payload: dict[str, Any]
|
||||
response_metadata: dict[str, Any]
|
||||
|
||||
|
||||
class _TrackingChatModel:
|
||||
def __init__(self, inner: OpenAIChatModel) -> None:
|
||||
self._inner = inner
|
||||
self._total_input_tokens = 0
|
||||
self._total_output_tokens = 0
|
||||
self._total_latency_ms = 0
|
||||
self._cached_prompt_tokens = 0
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
return self._inner.stream
|
||||
|
||||
@stream.setter
|
||||
def stream(self, value: bool) -> None:
|
||||
self._inner.stream = value
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._inner, name)
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
response = await self._inner(*args, **kwargs)
|
||||
if isinstance(response, AsyncGenerator):
|
||||
return self._track_stream(response)
|
||||
self._record_usage(getattr(response, "usage", None))
|
||||
return response
|
||||
|
||||
async def _track_stream(
|
||||
self, response: AsyncGenerator[Any, None]
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
latest_usage = None
|
||||
async for chunk in response:
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage is not None:
|
||||
latest_usage = usage
|
||||
yield chunk
|
||||
self._record_usage(latest_usage)
|
||||
|
||||
def _record_usage(self, usage: Any) -> None:
|
||||
if usage is None:
|
||||
return
|
||||
self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
|
||||
self._total_output_tokens += max(
|
||||
int(getattr(usage, "output_tokens", 0) or 0), 0
|
||||
)
|
||||
self._total_latency_ms += max(
|
||||
int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
|
||||
)
|
||||
metadata = getattr(usage, "metadata", None)
|
||||
if metadata is not None:
|
||||
cached_tokens = 0
|
||||
if isinstance(metadata, dict):
|
||||
prompt_details = metadata.get("prompt_tokens_details")
|
||||
if isinstance(prompt_details, dict):
|
||||
cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0)
|
||||
else:
|
||||
prompt_details = getattr(metadata, "prompt_tokens_details", None)
|
||||
cached_tokens = int(getattr(prompt_details, "cached_tokens", 0) or 0)
|
||||
self._cached_prompt_tokens += max(cached_tokens, 0)
|
||||
|
||||
def usage_summary(self) -> dict[str, int]:
|
||||
return {
|
||||
"input_tokens": self._total_input_tokens,
|
||||
"output_tokens": self._total_output_tokens,
|
||||
"latency_ms": self._total_latency_ms,
|
||||
"cached_prompt_tokens": self._cached_prompt_tokens,
|
||||
}
|
||||
|
||||
|
||||
class _PipelineStageEmitter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
stage: str,
|
||||
emit_text_events: bool,
|
||||
emit_tool_events: bool,
|
||||
) -> None:
|
||||
self._pipeline = pipeline
|
||||
self._session_id = session_id
|
||||
self._run_id = run_id
|
||||
self._stage = stage
|
||||
self._emit_text_events = emit_text_events
|
||||
self._emit_tool_events = emit_tool_events
|
||||
self._text_by_message_id: dict[str, str] = {}
|
||||
self._emitted_tool_calls: set[str] = set()
|
||||
self._emitted_tool_results: set[str] = set()
|
||||
self.latest_text_message_id: str | None = None
|
||||
self.latest_text: str = ""
|
||||
|
||||
async def handle_print(self, *, msg: Msg, last: bool) -> None:
|
||||
del last
|
||||
if self._emit_tool_events:
|
||||
await self._emit_tool_events_from_msg(msg)
|
||||
if self._emit_text_events:
|
||||
await self._emit_text_events_from_msg(msg)
|
||||
|
||||
async def _emit_text_events_from_msg(self, msg: Msg) -> None:
|
||||
text = msg.get_text_content(separator="") or ""
|
||||
if not text:
|
||||
return
|
||||
message_id = str(msg.id)
|
||||
previous = self._text_by_message_id.get(message_id, "")
|
||||
if message_id not in self._text_by_message_id:
|
||||
await self._emit(
|
||||
"text.start",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": self._stage,
|
||||
},
|
||||
)
|
||||
delta = text[len(previous) :] if text.startswith(previous) else text
|
||||
if delta:
|
||||
await self._emit(
|
||||
"text.delta",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"delta": delta,
|
||||
"stage": self._stage,
|
||||
},
|
||||
)
|
||||
self._text_by_message_id[message_id] = text
|
||||
self.latest_text_message_id = message_id
|
||||
self.latest_text = text
|
||||
|
||||
async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
|
||||
for block in msg.get_content_blocks("tool_use"):
|
||||
tool_call_id = str(block.get("id", "")).strip()
|
||||
tool_name = str(block.get("name", "")).strip()
|
||||
if (
|
||||
not tool_call_id
|
||||
or not tool_name
|
||||
or tool_call_id in self._emitted_tool_calls
|
||||
):
|
||||
continue
|
||||
payload = {
|
||||
"messageId": str(msg.id),
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_name,
|
||||
"stage": self._stage,
|
||||
}
|
||||
await self._emit("tool.start", payload)
|
||||
await self._emit(
|
||||
"tool.args",
|
||||
{
|
||||
**payload,
|
||||
"args": block.get("input", {}),
|
||||
},
|
||||
)
|
||||
await self._emit("tool.end", payload)
|
||||
self._emitted_tool_calls.add(tool_call_id)
|
||||
|
||||
for block in msg.get_content_blocks("tool_result"):
|
||||
tool_call_id = str(block.get("id", "")).strip()
|
||||
if not tool_call_id or tool_call_id in self._emitted_tool_results:
|
||||
continue
|
||||
tool_output = _parse_tool_agent_output(block.get("output"))
|
||||
if tool_output is None:
|
||||
continue
|
||||
await self._emit(
|
||||
"tool.result",
|
||||
{
|
||||
"messageId": str(msg.id),
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_output.tool_name,
|
||||
"stage": self._stage,
|
||||
"toolAgentOutput": tool_output.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
),
|
||||
},
|
||||
)
|
||||
self._emitted_tool_results.add(tool_call_id)
|
||||
|
||||
async def emit_final_text_end(
|
||||
self,
|
||||
*,
|
||||
worker_output: dict[str, Any],
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
message_id = (
|
||||
self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"
|
||||
)
|
||||
if self.latest_text_message_id is None and worker_output.get("answer"):
|
||||
await self._emit(
|
||||
"text.start",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": self._stage,
|
||||
},
|
||||
)
|
||||
await self._emit(
|
||||
"text.delta",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"delta": worker_output.get("answer", ""),
|
||||
"stage": self._stage,
|
||||
},
|
||||
)
|
||||
await self._emit(
|
||||
"text.end",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": self._stage,
|
||||
"workerAgentOutput": worker_output,
|
||||
**response_metadata,
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit(self, event_type: str, data: dict[str, Any]) -> None:
|
||||
await self._pipeline.emit(
|
||||
session_id=self._session_id,
|
||||
event={
|
||||
"type": event_type,
|
||||
"threadId": self._session_id,
|
||||
"runId": self._run_id,
|
||||
"data": data,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class _PipelineReActAgent(ReActAgent):
|
||||
def __init__(
|
||||
self, *, emitter: _PipelineStageEmitter | None = None, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._pipeline_emitter = emitter
|
||||
self.disable_console_output()
|
||||
|
||||
async def print(self, msg: Msg, last: bool = True, speech: Any = None) -> None:
|
||||
del speech
|
||||
if self._pipeline_emitter is not None:
|
||||
await self._pipeline_emitter.handle_print(msg=msg, last=last)
|
||||
|
||||
|
||||
def _parse_tool_agent_output(output: Any) -> ToolAgentOutput | None:
|
||||
blocks = output if isinstance(output, Sequence) else []
|
||||
for block in blocks:
|
||||
if not isinstance(block, dict) or block.get("type") != "text":
|
||||
continue
|
||||
text = block.get("text")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
try:
|
||||
return ToolAgentOutput.model_validate(json.loads(text))
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_tool_name(value: str) -> str:
|
||||
return value.strip().replace(".", "_").replace("-", "_")
|
||||
|
||||
|
||||
class AgentScopeReActRunner:
|
||||
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
|
||||
self._litellm_service = litellm_service or LiteLLMService()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
@@ -19,4 +325,367 @@ class AgentScopeReActRunner:
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError("execute method not implemented")
|
||||
owner_id = UUID(user_context.id)
|
||||
enabled_tool_names = self._extract_tool_names(run_input)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
router_toolkit, worker_toolkit = self._build_toolkits(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
)
|
||||
|
||||
router_config = await self._load_system_agent_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_system_agent_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="router",
|
||||
event_type="step.start",
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
toolkit=router_toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=router_config,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await self._persist_router_message(
|
||||
session=session,
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
model_code=router_config.model_code,
|
||||
router_output=router_output,
|
||||
response_metadata=router_result.response_metadata,
|
||||
)
|
||||
await session.commit()
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="router",
|
||||
event_type="step.finish",
|
||||
)
|
||||
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="worker",
|
||||
event_type="step.start",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=worker_config,
|
||||
worker_output_model=worker_output_model,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="worker",
|
||||
event_type="step.finish",
|
||||
)
|
||||
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
def _build_toolkits(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
enabled_tool_names: set[str] | None,
|
||||
) -> tuple[Any, Any]:
|
||||
return (
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.ROUTER,
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
),
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.WORKER,
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
),
|
||||
)
|
||||
|
||||
def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None:
|
||||
raw_tools = getattr(run_input, "tools", None)
|
||||
if not isinstance(raw_tools, list):
|
||||
return None
|
||||
selected: set[str] = set()
|
||||
for item in raw_tools:
|
||||
if isinstance(item, dict):
|
||||
name = item.get("name")
|
||||
else:
|
||||
name = getattr(item, "name", None)
|
||||
if isinstance(name, str) and name.strip():
|
||||
selected.add(_normalize_tool_name(name))
|
||||
return selected
|
||||
|
||||
async def _load_system_agent_config(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
agent_type: AgentType,
|
||||
) -> SystemAgentRuntimeConfig:
|
||||
stmt = (
|
||||
select(SystemAgents, Llm)
|
||||
.join(Llm, SystemAgents.llm_id == Llm.id)
|
||||
.where(SystemAgents.agent_type == agent_type.value)
|
||||
)
|
||||
row = (await session.execute(stmt)).one_or_none()
|
||||
if row is None:
|
||||
raise RuntimeError(f"system agent config not found: {agent_type.value}")
|
||||
system_agent, llm = row
|
||||
status = str(system_agent.status).strip().lower()
|
||||
if status != "active":
|
||||
raise RuntimeError(f"system agent is not active: {agent_type.value}")
|
||||
return SystemAgentRuntimeConfig(
|
||||
agent_type=agent_type,
|
||||
model_code=llm.model_code,
|
||||
llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
|
||||
)
|
||||
|
||||
async def _run_router_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
run_input: RunAgentInput,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
) -> StageExecutionResult:
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
agent = self._build_agent(
|
||||
agent_name="router",
|
||||
system_prompt=build_system_prompt(
|
||||
agent_type=AgentType.ROUTER,
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
tools=run_input.tools,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
model=tracking_model,
|
||||
)
|
||||
response_msg = await agent.reply(
|
||||
context_messages, structured_model=RouterAgentOutput
|
||||
)
|
||||
payload = RouterAgentOutput.model_validate(
|
||||
response_msg.metadata or {}
|
||||
).model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
)
|
||||
return StageExecutionResult(
|
||||
message=response_msg,
|
||||
payload=payload,
|
||||
response_metadata=self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
usage_summary=tracking_model.usage_summary(),
|
||||
),
|
||||
)
|
||||
|
||||
async def _run_worker_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
router_output: RouterAgentOutput,
|
||||
toolkit: Any,
|
||||
run_input: RunAgentInput,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
) -> StageExecutionResult:
|
||||
worker_input = self._build_worker_input_messages(
|
||||
router_output=router_output,
|
||||
)
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
emitter = _PipelineStageEmitter(
|
||||
pipeline=pipeline,
|
||||
session_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
stage="worker",
|
||||
emit_text_events=True,
|
||||
emit_tool_events=True,
|
||||
)
|
||||
agent = self._build_agent(
|
||||
agent_name="worker",
|
||||
system_prompt=build_system_prompt(
|
||||
agent_type=AgentType.WORKER,
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
tools=run_input.tools,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
model=tracking_model,
|
||||
emitter=emitter,
|
||||
)
|
||||
response_msg = await agent.reply(
|
||||
worker_input,
|
||||
structured_model=worker_output_model,
|
||||
)
|
||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||
response_metadata = self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
usage_summary=tracking_model.usage_summary(),
|
||||
)
|
||||
await emitter.emit_final_text_end(
|
||||
worker_output=worker_payload.model_dump(mode="json", exclude_none=True),
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
return StageExecutionResult(
|
||||
message=response_msg,
|
||||
payload=worker_payload.model_dump(mode="json", exclude_none=True),
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
def _build_worker_input_messages(
|
||||
self,
|
||||
*,
|
||||
router_output: RouterAgentOutput,
|
||||
) -> list[Msg]:
|
||||
routing_contract = json.dumps(
|
||||
router_output.model_dump(mode="json", exclude_none=True),
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
routing_msg = Msg(
|
||||
name="router",
|
||||
role="user",
|
||||
content=(
|
||||
"Use the following routing contract as the execution source of truth. "
|
||||
f"Do not change the routed objective:\n{routing_contract}"
|
||||
),
|
||||
)
|
||||
return [routing_msg]
|
||||
|
||||
def _build_model(
|
||||
self, *, stage_config: SystemAgentRuntimeConfig
|
||||
) -> _TrackingChatModel:
|
||||
model = OpenAIChatModel(
|
||||
model_name=stage_config.model_code,
|
||||
api_key=self._litellm_service.proxy_api_key,
|
||||
stream=True,
|
||||
client_kwargs={"base_url": self._litellm_service.proxy_base_url},
|
||||
generate_kwargs={
|
||||
"temperature": stage_config.llm_config.temperature,
|
||||
"max_tokens": stage_config.llm_config.max_tokens,
|
||||
"timeout": stage_config.llm_config.timeout_seconds,
|
||||
},
|
||||
)
|
||||
return _TrackingChatModel(model)
|
||||
|
||||
def _build_agent(
|
||||
self,
|
||||
*,
|
||||
agent_name: str,
|
||||
system_prompt: str,
|
||||
toolkit: Any,
|
||||
model: _TrackingChatModel,
|
||||
emitter: _PipelineStageEmitter | None = None,
|
||||
) -> _PipelineReActAgent:
|
||||
return _PipelineReActAgent(
|
||||
name=agent_name,
|
||||
sys_prompt=system_prompt,
|
||||
model=model,
|
||||
formatter=OpenAIChatFormatter(),
|
||||
toolkit=toolkit,
|
||||
memory=InMemoryMemory(),
|
||||
emitter=emitter,
|
||||
)
|
||||
|
||||
async def _emit_step_event(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
step_name: str,
|
||||
event_type: str,
|
||||
) -> None:
|
||||
await pipeline.emit(
|
||||
session_id=run_input.thread_id,
|
||||
event={
|
||||
"type": event_type,
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"data": {"stepName": step_name},
|
||||
},
|
||||
)
|
||||
|
||||
async def _persist_router_message(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
model_code: str,
|
||||
router_output: RouterAgentOutput,
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
session_id = UUID(thread_id)
|
||||
message_repo = MessageRepository(session)
|
||||
session_repo = SessionRepository(session)
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
)
|
||||
if locked_session is None:
|
||||
raise RuntimeError("chat session not found for router persistence")
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
metadata = AgentChatMessageMetadata(
|
||||
run_id=run_id,
|
||||
agent_type=AgentType.ROUTER,
|
||||
router_agent_output=router_output,
|
||||
)
|
||||
message_payload = AgentChatMessage(
|
||||
id=uuid4(),
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT.value,
|
||||
content="",
|
||||
model_code=model_code,
|
||||
tool_name=None,
|
||||
input_tokens=int(response_metadata.get("inputTokens", 0) or 0),
|
||||
output_tokens=int(response_metadata.get("outputTokens", 0) or 0),
|
||||
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
|
||||
latency_ms=int(response_metadata.get("latencyMs", 0) or 0),
|
||||
metadata=metadata,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=message_payload.seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=message_payload.content,
|
||||
model_code=message_payload.model_code,
|
||||
tool_name=message_payload.tool_name,
|
||||
metadata=metadata.model_dump(mode="json", exclude_none=True),
|
||||
input_tokens=message_payload.input_tokens,
|
||||
output_tokens=message_payload.output_tokens,
|
||||
cost=message_payload.cost,
|
||||
latency_ms=message_payload.latency_ms,
|
||||
)
|
||||
await session_repo.update_runtime_state(
|
||||
chat_session=locked_session,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot=locked_session.state_snapshot or {},
|
||||
message_delta=1,
|
||||
token_delta=message_payload.input_tokens + message_payload.output_tokens,
|
||||
cost_delta=message_payload.cost,
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,9 +23,8 @@ from schemas.agent.ui_hints import (
|
||||
UiHintAction,
|
||||
UiHintActionCopy,
|
||||
UiHintActionStyle,
|
||||
UiHintErrorBlock,
|
||||
UiHintKeyValuePair,
|
||||
UiHintKvBlock,
|
||||
UiHintIntent,
|
||||
UiHintKvItem,
|
||||
UiHintStatus,
|
||||
UiHintsPayload,
|
||||
)
|
||||
@@ -54,18 +53,10 @@ def _lookup_error_output(
|
||||
update={
|
||||
"tool_call_args": tool_call_args,
|
||||
"ui_hints": UiHintsPayload(
|
||||
intent=UiHintIntent.STATUS,
|
||||
status=UiHintStatus.ERROR,
|
||||
title="用户查找失败",
|
||||
description=message,
|
||||
blocks=[
|
||||
UiHintErrorBlock(
|
||||
kind="error",
|
||||
title="查找失败",
|
||||
errorCode=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
],
|
||||
body=message,
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -78,28 +69,15 @@ def _lookup_success_hints(resolved: dict[str, Any]) -> UiHintsPayload:
|
||||
username = str(resolved.get("username") or "")
|
||||
matched_by = str(resolved.get("matchedBy") or "")
|
||||
return UiHintsPayload(
|
||||
intent=UiHintIntent.DATA,
|
||||
status=UiHintStatus.SUCCESS,
|
||||
title="用户信息",
|
||||
description=f"匹配方式: {matched_by}",
|
||||
blocks=[
|
||||
UiHintKvBlock(
|
||||
kind="kv",
|
||||
title="查找结果",
|
||||
pairs=[
|
||||
UiHintKeyValuePair(
|
||||
key="user_id", label="用户ID", value=user_id, copyable=True
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="email", label="邮箱", value=email, copyable=True
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="username", label="用户名", value=username or "-"
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="matched_by", label="匹配方式", value=matched_by
|
||||
),
|
||||
],
|
||||
)
|
||||
items=[
|
||||
UiHintKvItem(key="user_id", label="用户ID", value=user_id, copyable=True),
|
||||
UiHintKvItem(key="email", label="邮箱", value=email, copyable=True),
|
||||
UiHintKvItem(key="username", label="用户名", value=username or "-"),
|
||||
UiHintKvItem(key="matched_by", label="匹配方式", value=matched_by),
|
||||
],
|
||||
actions=[
|
||||
UiHintAction(
|
||||
|
||||
@@ -18,6 +18,7 @@ from core.agentscope.tools.tool_config import (
|
||||
)
|
||||
from core.agentscope.tools.tool_middleware import register_tool_middlewares
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
TOOL_FUNCTIONS: dict[str, Any] = {
|
||||
"calendar_read": calendar_read,
|
||||
@@ -27,9 +28,9 @@ TOOL_FUNCTIONS: dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
STAGE_TO_GROUPS: dict[str, set[ToolGroup]] = {
|
||||
"router": {ToolGroup.READ},
|
||||
"worker": {ToolGroup.READ, ToolGroup.WRITE},
|
||||
AGENT_TYPE_TO_GROUPS: dict[AgentType, set[ToolGroup]] = {
|
||||
AgentType.ROUTER: {ToolGroup.READ},
|
||||
AgentType.WORKER: {ToolGroup.READ, ToolGroup.WRITE},
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +62,6 @@ def build_toolkit(
|
||||
groups: set[ToolGroup] | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
enable_approval_layer: bool = True,
|
||||
):
|
||||
toolkit = Toolkit()
|
||||
enabled_names = _resolve_enabled_tools(
|
||||
@@ -85,7 +85,7 @@ def build_toolkit(
|
||||
preset_kwargs=preset_kwargs,
|
||||
)
|
||||
|
||||
approval_enabled = enable_approval_layer if enable_hitl is None else enable_hitl
|
||||
approval_enabled = enable_hitl if enable_hitl is not None else True
|
||||
if approval_enabled:
|
||||
register_tool_middlewares(toolkit=toolkit, config_by_name=TOOL_CONFIGS)
|
||||
|
||||
@@ -94,20 +94,26 @@ def build_toolkit(
|
||||
|
||||
def build_stage_toolkit(
|
||||
*,
|
||||
stage: str,
|
||||
agent_type: AgentType,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
enable_approval_layer: bool = True,
|
||||
):
|
||||
groups = STAGE_TO_GROUPS.get(stage)
|
||||
groups = AGENT_TYPE_TO_GROUPS.get(agent_type)
|
||||
if groups is None:
|
||||
raise ValueError(f"unknown stage: {stage}")
|
||||
raise ValueError(f"unknown agent_type: {agent_type}")
|
||||
|
||||
stage_enabled_names = resolve_tool_names_by_groups(set(groups))
|
||||
selected_names = (
|
||||
stage_enabled_names
|
||||
if enabled_tool_names is None
|
||||
else stage_enabled_names | set(enabled_tool_names)
|
||||
)
|
||||
|
||||
return build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
groups=set(groups),
|
||||
enabled_tool_names=selected_names,
|
||||
enable_hitl=enable_hitl,
|
||||
enable_approval_layer=enable_approval_layer,
|
||||
)
|
||||
|
||||
@@ -12,17 +12,10 @@ from schemas.agent.ui_hints import (
|
||||
UiHintAction,
|
||||
UiHintActionNavigation,
|
||||
UiHintActionStyle,
|
||||
UiHintErrorBlock,
|
||||
UiHintKeyValuePair,
|
||||
UiHintKvBlock,
|
||||
UiHintListBlock,
|
||||
UiHintIntent,
|
||||
UiHintKvItem,
|
||||
UiHintListItem,
|
||||
UiHintOperationBlock,
|
||||
UiHintOperationResult,
|
||||
UiHintOperationType,
|
||||
UiHintStatus,
|
||||
UiHintTextBlock,
|
||||
UiHintTextFormat,
|
||||
UiHintsPayload,
|
||||
)
|
||||
|
||||
@@ -40,18 +33,10 @@ def calendar_error_output(
|
||||
retryable: bool,
|
||||
) -> ToolResponse:
|
||||
ui_hints = UiHintsPayload(
|
||||
intent=UiHintIntent.STATUS,
|
||||
status=UiHintStatus.ERROR,
|
||||
title="日历操作失败",
|
||||
description=message,
|
||||
blocks=[
|
||||
UiHintErrorBlock(
|
||||
kind="error",
|
||||
title="操作失败",
|
||||
errorCode=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
],
|
||||
body=message,
|
||||
)
|
||||
output = build_error_output(
|
||||
tool_name=tool_name,
|
||||
@@ -84,29 +69,17 @@ def calendar_read_hints(
|
||||
for event in events
|
||||
]
|
||||
return UiHintsPayload(
|
||||
intent=UiHintIntent.LIST,
|
||||
status=UiHintStatus.SUCCESS,
|
||||
title="日程列表",
|
||||
description=f"共 {total} 个日程",
|
||||
blocks=[
|
||||
UiHintKvBlock(
|
||||
kind="kv",
|
||||
title="分页信息",
|
||||
pairs=[
|
||||
UiHintKeyValuePair(key="total", label="总数", value=total),
|
||||
UiHintKeyValuePair(key="page", label="当前页", value=page),
|
||||
UiHintKeyValuePair(key="page_size", label="每页", value=page_size),
|
||||
UiHintKeyValuePair(
|
||||
key="total_pages", label="总页数", value=total_pages
|
||||
),
|
||||
],
|
||||
),
|
||||
UiHintListBlock(
|
||||
kind="list",
|
||||
title="日程项",
|
||||
items=event_items,
|
||||
emptyText="当前没有日程",
|
||||
),
|
||||
items=[
|
||||
UiHintKvItem(key="total", label="总数", value=total),
|
||||
UiHintKvItem(key="page", label="当前页", value=page),
|
||||
UiHintKvItem(key="page_size", label="每页", value=page_size),
|
||||
UiHintKvItem(key="total_pages", label="总页数", value=total_pages),
|
||||
],
|
||||
list_items=event_items,
|
||||
actions=[
|
||||
UiHintAction(
|
||||
label="打开日历",
|
||||
@@ -125,65 +98,38 @@ def calendar_write_hints(
|
||||
event: dict[str, Any] | None,
|
||||
event_id: str | None,
|
||||
) -> UiHintsPayload:
|
||||
operation_type = UiHintOperationType.EXECUTE
|
||||
if operation == "create":
|
||||
operation_type = UiHintOperationType.CREATE
|
||||
elif operation == "update":
|
||||
operation_type = UiHintOperationType.UPDATE
|
||||
elif operation == "delete":
|
||||
operation_type = UiHintOperationType.DELETE
|
||||
kv_items: list[UiHintKvItem] = []
|
||||
|
||||
blocks: list[Any] = [
|
||||
UiHintOperationBlock(
|
||||
kind="operation",
|
||||
title="日历写入结果",
|
||||
operation=operation_type,
|
||||
result=UiHintOperationResult.SUCCESS,
|
||||
message=message,
|
||||
affectedCount=1,
|
||||
)
|
||||
]
|
||||
if event:
|
||||
blocks.append(
|
||||
UiHintKvBlock(
|
||||
kind="kv",
|
||||
title="日程详情",
|
||||
pairs=[
|
||||
UiHintKeyValuePair(
|
||||
key="event_id",
|
||||
label="日程ID",
|
||||
value=str(event.get("id") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="title",
|
||||
label="标题",
|
||||
value=str(event.get("title") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="start_at",
|
||||
label="开始时间",
|
||||
value=str(event.get("startAt") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
kv_items = [
|
||||
UiHintKvItem(
|
||||
key="event_id",
|
||||
label="日程ID",
|
||||
value=str(event.get("id") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
UiHintKvItem(
|
||||
key="title",
|
||||
label="标题",
|
||||
value=str(event.get("title") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
UiHintKvItem(
|
||||
key="start_at",
|
||||
label="开始时间",
|
||||
value=str(event.get("startAt") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
]
|
||||
elif event_id:
|
||||
blocks.append(
|
||||
UiHintTextBlock(
|
||||
kind="text",
|
||||
content=f"目标日程 ID: {event_id}",
|
||||
format=UiHintTextFormat.PLAIN,
|
||||
)
|
||||
)
|
||||
message = f"目标日程 ID: {event_id}\n{message}"
|
||||
|
||||
return UiHintsPayload(
|
||||
intent=UiHintIntent.STATUS,
|
||||
status=UiHintStatus.SUCCESS,
|
||||
title="日历操作完成",
|
||||
description=message,
|
||||
blocks=blocks,
|
||||
body=message,
|
||||
items=kv_items if kv_items else None,
|
||||
actions=[
|
||||
UiHintAction(
|
||||
label="查看日历",
|
||||
@@ -203,36 +149,17 @@ def calendar_share_hints(
|
||||
permission_text = (
|
||||
", ".join([k for k, v in permission.items() if v is True]) or "按邀请人单独设置"
|
||||
)
|
||||
|
||||
return UiHintsPayload(
|
||||
intent=UiHintIntent.STATUS,
|
||||
status=UiHintStatus.SUCCESS,
|
||||
title="日程已分享",
|
||||
description=f"已邀请 {len(invited)} 人",
|
||||
blocks=[
|
||||
UiHintOperationBlock(
|
||||
kind="operation",
|
||||
title="分享结果",
|
||||
operation=UiHintOperationType.EXECUTE,
|
||||
result=UiHintOperationResult.SUCCESS,
|
||||
message=f"已邀请 {len(invited)} 人",
|
||||
affectedCount=len(invited),
|
||||
),
|
||||
UiHintKvBlock(
|
||||
kind="kv",
|
||||
title="分享信息",
|
||||
pairs=[
|
||||
UiHintKeyValuePair(
|
||||
key="event_id", label="日程ID", value=event_id, copyable=True
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="permission", label="权限", value=permission_text
|
||||
),
|
||||
],
|
||||
),
|
||||
UiHintListBlock(
|
||||
kind="list",
|
||||
title="被邀请人",
|
||||
items=[UiHintListItem(title=email) for email in invited],
|
||||
emptyText="暂无被邀请人",
|
||||
),
|
||||
items=[
|
||||
UiHintKvItem(key="event_id", label="日程ID", value=event_id, copyable=True),
|
||||
UiHintKvItem(key="permission", label="权限", value=permission_text),
|
||||
],
|
||||
list_items=[UiHintListItem(title=email) for email in invited]
|
||||
if invited
|
||||
else [],
|
||||
)
|
||||
|
||||
@@ -14,7 +14,9 @@ from schemas.agent.runtime_models import (
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.ui_hints import (
|
||||
UiHintAction,
|
||||
UiHintBlock,
|
||||
UiHintIntent,
|
||||
UiHintSection,
|
||||
UiHintStatus,
|
||||
UiHintsPayload,
|
||||
)
|
||||
|
||||
@@ -29,7 +31,9 @@ __all__ = [
|
||||
"ToolStatus",
|
||||
"UiMode",
|
||||
"UiHintAction",
|
||||
"UiHintBlock",
|
||||
"UiHintIntent",
|
||||
"UiHintSection",
|
||||
"UiHintStatus",
|
||||
"UiHintsPayload",
|
||||
"WorkerAgentOutputLite",
|
||||
"WorkerAgentOutputRich",
|
||||
|
||||
@@ -1,10 +1,26 @@
|
||||
"""
|
||||
UiHints - 描述性 UI 提示
|
||||
|
||||
设计原则:
|
||||
- 描述性而非渲染性: 告诉编译器“要展示什么”,而不是“如何渲染”
|
||||
- 最小化 token: 保持字段简洁
|
||||
- 可编译: 可机械转换为 UiSchemaRenderer
|
||||
- 尽量无损: hints 中的主要内容字段应尽量被保留到 renderer 中
|
||||
|
||||
Version: 2.1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ============================================================
|
||||
# Enums
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintStatus(str, Enum):
|
||||
INFO = "info"
|
||||
@@ -14,6 +30,17 @@ class UiHintStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class UiHintIntent(str, Enum):
|
||||
"""主要展示意图(弱提示,不应决定字段生死)"""
|
||||
|
||||
MESSAGE = "message" # 普通消息/说明
|
||||
DATA = "data" # 数据/结果摘要
|
||||
LIST = "list" # 列表为主
|
||||
STATUS = "status" # 状态结果为主
|
||||
FORM = "form" # 结构化内容(当前不表示真实输入表单)
|
||||
MIXED = "mixed" # 混合内容
|
||||
|
||||
|
||||
class UiHintActionStyle(str, Enum):
|
||||
PRIMARY = "primary"
|
||||
SECONDARY = "secondary"
|
||||
@@ -26,520 +53,238 @@ class UiHintTextFormat(str, Enum):
|
||||
MARKDOWN = "markdown"
|
||||
|
||||
|
||||
class UiHintContainerDirection(str, Enum):
|
||||
VERTICAL = "vertical"
|
||||
HORIZONTAL = "horizontal"
|
||||
class UiHintActionType(str, Enum):
|
||||
NAVIGATION = "navigation"
|
||||
URL = "url"
|
||||
EVENT = "event"
|
||||
TOOL = "tool"
|
||||
COPY = "copy"
|
||||
PAYLOAD = "payload"
|
||||
|
||||
|
||||
class UiHintKvLayout(str, Enum):
|
||||
VERTICAL = "vertical"
|
||||
HORIZONTAL = "horizontal"
|
||||
GRID = "grid"
|
||||
class UiHintIconSource(str, Enum):
|
||||
ICON = "icon"
|
||||
EMOJI = "emoji"
|
||||
URL = "url"
|
||||
|
||||
|
||||
class UiHintOperationType(str, Enum):
|
||||
CREATE = "create"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
EXECUTE = "execute"
|
||||
# ============================================================
|
||||
# Base Config
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintOperationResult(str, Enum):
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
PARTIAL = "partial"
|
||||
|
||||
|
||||
class UiHintConfirm(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
title: str | None = Field(
|
||||
default=None,
|
||||
description="Optional confirmation dialog title.",
|
||||
)
|
||||
message: str | None = Field(
|
||||
default=None,
|
||||
description="Optional confirmation message shown before action execution.",
|
||||
)
|
||||
confirm_label: str | None = Field(
|
||||
default=None,
|
||||
alias="confirmLabel",
|
||||
description="Optional confirm button label, e.g. 'Delete'.",
|
||||
)
|
||||
cancel_label: str | None = Field(
|
||||
default=None,
|
||||
alias="cancelLabel",
|
||||
description="Optional cancel button label, e.g. 'Cancel'.",
|
||||
class UiHintBaseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
|
||||
class UiHintActionNavigation(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
# ============================================================
|
||||
# Action Targets
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintActionNavigation(UiHintBaseModel):
|
||||
type: Literal["navigation"]
|
||||
path: str = Field(
|
||||
...,
|
||||
description="Internal route path to navigate to.",
|
||||
)
|
||||
params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Optional route params for internal navigation.",
|
||||
)
|
||||
path: str = Field(..., description="Internal route path.")
|
||||
params: dict[str, Any] | None = Field(default=None, description="Route params.")
|
||||
|
||||
|
||||
class UiHintActionUrl(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class UiHintActionUrl(UiHintBaseModel):
|
||||
type: Literal["url"]
|
||||
url: str = Field(..., description="External URL to open.")
|
||||
target: Literal["_self", "_blank"] | None = Field(
|
||||
default=None,
|
||||
description="Optional browser target for URL action.",
|
||||
)
|
||||
url: str = Field(..., description="External URL.")
|
||||
target: Literal["_self", "_blank"] | None = Field(default=None)
|
||||
|
||||
|
||||
class UiHintActionEvent(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class UiHintActionEvent(UiHintBaseModel):
|
||||
type: Literal["event"]
|
||||
event: str = Field(
|
||||
...,
|
||||
description="Frontend domain event name, e.g. 'chat.retry'.",
|
||||
)
|
||||
payload: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Optional event payload for frontend event handling.",
|
||||
)
|
||||
event: str = Field(..., description="Frontend event name.")
|
||||
payload: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
|
||||
class UiHintActionTool(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class UiHintActionTool(UiHintBaseModel):
|
||||
type: Literal["tool"]
|
||||
tool_id: str = Field(
|
||||
alias="toolId",
|
||||
description="Tool identifier used to trigger another tool execution.",
|
||||
)
|
||||
params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Optional parameters for tool re-execution.",
|
||||
)
|
||||
tool_id: str = Field(alias="toolId", description="Tool identifier.")
|
||||
params: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
|
||||
class UiHintActionCopy(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class UiHintActionCopy(UiHintBaseModel):
|
||||
type: Literal["copy"]
|
||||
content: str = Field(..., description="Text content to copy to clipboard.")
|
||||
success_message: str | None = Field(
|
||||
default=None,
|
||||
alias="successMessage",
|
||||
description="Optional user-facing success message after copy.",
|
||||
)
|
||||
content: str = Field(..., description="Content to copy.")
|
||||
success_message: str | None = Field(alias="successMessage", default=None)
|
||||
|
||||
|
||||
class UiHintActionPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class UiHintActionPayload(UiHintBaseModel):
|
||||
type: Literal["payload"]
|
||||
payload: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Structured payload to submit to frontend or gateway.",
|
||||
)
|
||||
submit_to: str | None = Field(
|
||||
default=None,
|
||||
alias="submitTo",
|
||||
description="Optional submit target path or endpoint key.",
|
||||
)
|
||||
payload: dict[str, Any] = Field(..., description="Structured payload.")
|
||||
submit_to: str | None = Field(alias="submitTo", default=None)
|
||||
|
||||
|
||||
UiHintActionTarget = Annotated[
|
||||
(
|
||||
UiHintActionNavigation
|
||||
| UiHintActionUrl
|
||||
| UiHintActionEvent
|
||||
| UiHintActionTool
|
||||
| UiHintActionCopy
|
||||
| UiHintActionPayload
|
||||
),
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
UiHintActionTarget = (
|
||||
UiHintActionNavigation
|
||||
| UiHintActionUrl
|
||||
| UiHintActionEvent
|
||||
| UiHintActionTool
|
||||
| UiHintActionCopy
|
||||
| UiHintActionPayload
|
||||
)
|
||||
|
||||
|
||||
class UiHintAction(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
json_schema_extra={
|
||||
"examples": [
|
||||
{
|
||||
"id": "action-open-calendar",
|
||||
"label": "Open calendar",
|
||||
"style": "primary",
|
||||
"action": {"type": "navigation", "path": "/calendar"},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="Optional stable action id for tracking and targeting.",
|
||||
)
|
||||
label: str = Field(
|
||||
...,
|
||||
description="User-facing action label shown on button/link.",
|
||||
)
|
||||
style: UiHintActionStyle | None = Field(
|
||||
default=None,
|
||||
description="Optional semantic button style.",
|
||||
)
|
||||
disabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether this action should be rendered as disabled.",
|
||||
)
|
||||
action: UiHintActionTarget = Field(
|
||||
...,
|
||||
description="Executable action target definition.",
|
||||
)
|
||||
confirm: UiHintConfirm | None = Field(
|
||||
default=None,
|
||||
description="Optional confirmation requirement before execution.",
|
||||
)
|
||||
class UiHintAction(UiHintBaseModel):
|
||||
label: str = Field(..., description="Button label.")
|
||||
style: UiHintActionStyle | None = Field(default=None, description="Button style.")
|
||||
disabled: bool = Field(default=False, description="Disabled state.")
|
||||
action: UiHintActionTarget = Field(..., description="Action to execute.")
|
||||
|
||||
|
||||
class UiHintIcon(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
source: Literal["icon", "emoji", "url"] = Field(
|
||||
...,
|
||||
description="Icon source type.",
|
||||
)
|
||||
value: str = Field(
|
||||
...,
|
||||
description="Icon identifier, emoji text, or image URL based on source.",
|
||||
)
|
||||
color: str | None = Field(
|
||||
default=None,
|
||||
description="Optional semantic color hint. Do not encode pixel-level style rules.",
|
||||
)
|
||||
size: int | None = Field(
|
||||
default=None,
|
||||
description="Optional icon size hint in abstract UI units.",
|
||||
)
|
||||
# ============================================================
|
||||
# Small Descriptive Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintBadge(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
label: str = Field(..., description="Badge text label.")
|
||||
variant: Literal["default", "success", "warning", "error", "info"] = Field(
|
||||
default="default",
|
||||
description="Semantic badge variant.",
|
||||
)
|
||||
class UiHintIcon(UiHintBaseModel):
|
||||
source: UiHintIconSource = Field(default=UiHintIconSource.ICON)
|
||||
value: str = Field(..., description="Icon identifier / emoji / url.")
|
||||
color: str | None = Field(default=None)
|
||||
size: int | None = Field(default=None)
|
||||
|
||||
|
||||
class UiHintKeyValuePair(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
key: str = Field(..., description="Stable key identifier for this pair.")
|
||||
label: str | None = Field(
|
||||
default=None,
|
||||
description="Optional user-facing label. Fallback to key when missing.",
|
||||
)
|
||||
value: str | int | bool | None = Field(
|
||||
default=None,
|
||||
description="Scalar value for this key-value pair.",
|
||||
)
|
||||
copyable: bool = Field(
|
||||
default=False,
|
||||
description="Whether frontend may offer copy interaction for this value.",
|
||||
)
|
||||
class UiHintKvItem(UiHintBaseModel):
|
||||
key: str = Field(..., description="Key identifier.")
|
||||
label: str | None = Field(default=None, description="Display label.")
|
||||
value: Any = Field(default=None, description="Value.")
|
||||
copyable: bool = Field(default=False, description="Allow copy.")
|
||||
|
||||
|
||||
class UiHintListItem(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="Optional stable list item id.",
|
||||
)
|
||||
title: str = Field(..., description="Primary list item title.")
|
||||
subtitle: str | None = Field(
|
||||
default=None,
|
||||
description="Optional short secondary text.",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Optional detailed description for this item.",
|
||||
)
|
||||
icon: UiHintIcon | None = Field(
|
||||
default=None,
|
||||
description="Optional semantic icon metadata.",
|
||||
)
|
||||
badge: UiHintBadge | None = Field(
|
||||
default=None,
|
||||
description="Optional semantic badge metadata.",
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Optional non-visual metadata for analytics or interactions.",
|
||||
)
|
||||
actions: list[UiHintAction] = Field(
|
||||
default_factory=list,
|
||||
description="Optional per-item actions, recommended up to 3.",
|
||||
)
|
||||
class UiHintListItem(UiHintBaseModel):
|
||||
id: str | None = Field(default=None)
|
||||
title: str = Field(..., description="Item title.")
|
||||
subtitle: str | None = Field(default=None)
|
||||
description: str | None = Field(default=None)
|
||||
icon: UiHintIcon | None = Field(default=None)
|
||||
status: UiHintStatus | None = Field(default=None)
|
||||
actions: list[UiHintAction] = Field(default_factory=list)
|
||||
|
||||
|
||||
class UiHintPagination(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
class UiHintSection(UiHintBaseModel):
|
||||
title: str | None = Field(default=None, description="Section title.")
|
||||
description: str | None = Field(default=None, description="Section description.")
|
||||
icon: UiHintIcon | None = Field(default=None, description="Section icon.")
|
||||
|
||||
page: int = Field(..., description="Current page number starting from 1.")
|
||||
page_size: int = Field(
|
||||
alias="pageSize",
|
||||
description="Page size used for this list page.",
|
||||
)
|
||||
total: int = Field(..., description="Total number of records.")
|
||||
has_more: bool = Field(
|
||||
alias="hasMore",
|
||||
description="Whether there are more pages after current page.",
|
||||
)
|
||||
|
||||
|
||||
class UiHintBaseBlock(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="Optional stable block id.",
|
||||
)
|
||||
title: str | None = Field(
|
||||
default=None,
|
||||
description="Optional block title.",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Optional block description.",
|
||||
)
|
||||
status: UiHintStatus | None = Field(
|
||||
default=None,
|
||||
description="Optional semantic status for this block.",
|
||||
)
|
||||
actions: list[UiHintAction] = Field(
|
||||
default_factory=list,
|
||||
description="Optional block-level actions, recommended up to 3.",
|
||||
)
|
||||
|
||||
|
||||
class UiHintTextBlock(UiHintBaseBlock):
|
||||
kind: Literal["text"]
|
||||
content: str = Field(
|
||||
...,
|
||||
description="Main text content to present.",
|
||||
)
|
||||
format: UiHintTextFormat = Field(
|
||||
content: str | None = Field(default=None, description="Main text content.")
|
||||
content_format: UiHintTextFormat = Field(
|
||||
default=UiHintTextFormat.PLAIN,
|
||||
description="Text format: plain or markdown.",
|
||||
alias="contentFormat",
|
||||
description="Section content text format.",
|
||||
)
|
||||
|
||||
|
||||
class UiHintCardBlock(UiHintBaseBlock):
|
||||
kind: Literal["card"]
|
||||
children: list["UiHintBlock"] = Field(
|
||||
items: list[UiHintKvItem] = Field(default_factory=list, description="KV items.")
|
||||
list_items: list[UiHintListItem] = Field(
|
||||
default_factory=list,
|
||||
description="Nested child blocks grouped under this card.",
|
||||
alias="listItems",
|
||||
description="List items.",
|
||||
)
|
||||
actions: list[UiHintAction] = Field(default_factory=list, description="Actions.")
|
||||
|
||||
|
||||
class UiHintKvBlock(UiHintBaseBlock):
|
||||
kind: Literal["kv"]
|
||||
pairs: list[UiHintKeyValuePair] = Field(
|
||||
default_factory=list,
|
||||
description="Key-value pairs to display.",
|
||||
)
|
||||
layout: UiHintKvLayout = Field(
|
||||
default=UiHintKvLayout.VERTICAL,
|
||||
description="Preferred semantic layout for key-value content.",
|
||||
)
|
||||
# ============================================================
|
||||
# Root Payload
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintListBlock(UiHintBaseBlock):
|
||||
kind: Literal["list"]
|
||||
items: list[UiHintListItem] = Field(
|
||||
default_factory=list,
|
||||
description="List items to present.",
|
||||
)
|
||||
pagination: UiHintPagination | None = Field(
|
||||
default=None,
|
||||
description="Optional pagination metadata.",
|
||||
)
|
||||
empty_text: str | None = Field(
|
||||
default=None,
|
||||
alias="emptyText",
|
||||
description="Optional message shown when list items are empty.",
|
||||
)
|
||||
class UiHintsPayload(UiHintBaseModel):
|
||||
"""
|
||||
描述性 UI 提示
|
||||
|
||||
设计目标:
|
||||
- agent 输出尽可能短
|
||||
- 不表达布局细节
|
||||
- 编译器负责转换为完整 UiSchemaRenderer
|
||||
"""
|
||||
|
||||
class UiHintOperationBlock(UiHintBaseBlock):
|
||||
kind: Literal["operation"]
|
||||
operation: UiHintOperationType = Field(
|
||||
...,
|
||||
description="Operation category: create/update/delete/execute.",
|
||||
)
|
||||
result: UiHintOperationResult = Field(
|
||||
...,
|
||||
description="Operation result: success/failure/partial.",
|
||||
)
|
||||
message: str | None = Field(
|
||||
default=None,
|
||||
description="Optional operation summary message.",
|
||||
)
|
||||
affected_count: int | None = Field(
|
||||
default=None,
|
||||
alias="affectedCount",
|
||||
description="Optional affected record count.",
|
||||
)
|
||||
details: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Optional machine-readable operation details.",
|
||||
)
|
||||
|
||||
|
||||
class UiHintErrorBlock(UiHintBaseBlock):
|
||||
kind: Literal["error"]
|
||||
error_code: str = Field(
|
||||
alias="errorCode",
|
||||
description="Stable error code for categorization.",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="Human-readable error message.",
|
||||
)
|
||||
retryable: bool = Field(
|
||||
default=False,
|
||||
description="Whether retry is likely to succeed.",
|
||||
)
|
||||
details: str | None = Field(
|
||||
default=None,
|
||||
description="Optional plain-text diagnostic details.",
|
||||
)
|
||||
suggestions: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Optional actionable suggestions, recommended up to 3.",
|
||||
)
|
||||
|
||||
|
||||
class UiHintContainerBlock(UiHintBaseBlock):
|
||||
kind: Literal["container"]
|
||||
direction: UiHintContainerDirection = Field(
|
||||
default=UiHintContainerDirection.VERTICAL,
|
||||
description="Child block layout direction.",
|
||||
)
|
||||
gap: int | None = Field(
|
||||
default=None,
|
||||
description="Optional semantic spacing hint between children.",
|
||||
)
|
||||
children: list["UiHintBlock"] = Field(
|
||||
default_factory=list,
|
||||
description="Nested child blocks in this container.",
|
||||
)
|
||||
|
||||
|
||||
class UiHintCustomBlock(UiHintBaseBlock):
|
||||
kind: Literal["custom"]
|
||||
renderer_key: str = Field(
|
||||
alias="rendererKey",
|
||||
description=(
|
||||
"Custom semantic renderer key. Use only when standard block kinds "
|
||||
"cannot represent the intent."
|
||||
),
|
||||
)
|
||||
payload: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Structured custom payload consumed by the renderer.",
|
||||
)
|
||||
|
||||
|
||||
UiHintBlock = Annotated[
|
||||
(
|
||||
UiHintTextBlock
|
||||
| UiHintCardBlock
|
||||
| UiHintKvBlock
|
||||
| UiHintListBlock
|
||||
| UiHintOperationBlock
|
||||
| UiHintErrorBlock
|
||||
| UiHintContainerBlock
|
||||
| UiHintCustomBlock
|
||||
),
|
||||
Field(discriminator="kind"),
|
||||
]
|
||||
|
||||
|
||||
class UiHintsPayload(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
json_schema_extra={
|
||||
"examples": [
|
||||
{
|
||||
"version": "1.0",
|
||||
"status": "info",
|
||||
"title": "Schedule update",
|
||||
"blocks": [
|
||||
{
|
||||
"kind": "text",
|
||||
"content": "Your meeting is moved to 3:00 PM.",
|
||||
"format": "plain",
|
||||
},
|
||||
{
|
||||
"kind": "list",
|
||||
"title": "Next steps",
|
||||
"items": [
|
||||
{"title": "Open calendar"},
|
||||
{"title": "Notify attendees"},
|
||||
],
|
||||
},
|
||||
"intent": "status",
|
||||
"status": "success",
|
||||
"title": "日程已创建",
|
||||
"body": "本次创建已成功完成。",
|
||||
"items": [
|
||||
{"key": "title", "label": "主题", "value": "Q1 规划会议"},
|
||||
{"key": "time", "label": "时间", "value": "2026-03-15 14:00"},
|
||||
],
|
||||
"actions": [
|
||||
{
|
||||
"label": "Open calendar",
|
||||
"label": "查看详情",
|
||||
"style": "primary",
|
||||
"action": {"type": "navigation", "path": "/calendar"},
|
||||
}
|
||||
"action": {
|
||||
"type": "navigation",
|
||||
"path": "/calendar/evt_123",
|
||||
},
|
||||
},
|
||||
{
|
||||
"label": "删除",
|
||||
"style": "danger",
|
||||
"action": {
|
||||
"type": "tool",
|
||||
"toolId": "calendar.delete",
|
||||
"params": {"eventId": "evt_123"},
|
||||
},
|
||||
},
|
||||
],
|
||||
"meta": {"source": "worker"},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
version: str = Field(
|
||||
default="1.0",
|
||||
description="Ui hints payload version.",
|
||||
version: str = Field(default="2.1")
|
||||
|
||||
intent: UiHintIntent = Field(
|
||||
default=UiHintIntent.MESSAGE,
|
||||
description="Primary display intent.",
|
||||
)
|
||||
status: UiHintStatus = Field(
|
||||
default=UiHintStatus.INFO,
|
||||
description="Overall semantic status for the full ui_hints payload.",
|
||||
description="Overall status.",
|
||||
)
|
||||
title: str | None = Field(
|
||||
default=None,
|
||||
description="Optional top-level semantic title.",
|
||||
|
||||
title: str | None = Field(default=None, description="Top-level title.")
|
||||
description: str | None = Field(default=None, description="Top-level description.")
|
||||
|
||||
body: str | None = Field(default=None, description="Top-level main body text.")
|
||||
body_format: UiHintTextFormat = Field(
|
||||
default=UiHintTextFormat.PLAIN,
|
||||
alias="bodyFormat",
|
||||
description="Body text format.",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Optional top-level semantic description.",
|
||||
)
|
||||
blocks: list[UiHintBlock] = Field(
|
||||
|
||||
items: list[UiHintKvItem] = Field(
|
||||
default_factory=list,
|
||||
description="Main semantic content blocks.",
|
||||
description="Top-level key-value items.",
|
||||
)
|
||||
list_items: list[UiHintListItem] = Field(
|
||||
default_factory=list,
|
||||
alias="listItems",
|
||||
description="Top-level list items.",
|
||||
)
|
||||
sections: list[UiHintSection] = Field(
|
||||
default_factory=list,
|
||||
description="Grouped sections.",
|
||||
)
|
||||
actions: list[UiHintAction] = Field(
|
||||
default_factory=list,
|
||||
description="Optional top-level actions, recommended up to 3.",
|
||||
description="Top-level actions.",
|
||||
)
|
||||
|
||||
icon: UiHintIcon | None = Field(
|
||||
default=None,
|
||||
description="Top-level icon.",
|
||||
)
|
||||
meta: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Optional non-visual metadata for tracing and integration.",
|
||||
description="Extra meta, e.g. requestId/toolId/traceId/userId.",
|
||||
)
|
||||
|
||||
|
||||
UiHintCardBlock.model_rebuild()
|
||||
UiHintContainerBlock.model_rebuild()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from schemas.agent.runtime_models import RouterAgentOutput, WorkerAgentOutputRich
|
||||
|
||||
from ..agent import AgentType, ToolAgentOutput, WorkerAgentOutput
|
||||
from ..agent import AgentType, ToolAgentOutput
|
||||
|
||||
|
||||
class UserMessageAttachments(BaseModel):
|
||||
@@ -22,8 +24,9 @@ class AgentChatMessageMetadata(BaseModel):
|
||||
run_id: str
|
||||
agent_type: AgentType | None = None
|
||||
user_message_attachments: UserMessageAttachments | None = None
|
||||
router_agent_output: RouterAgentOutput | None = None
|
||||
tool_agent_output: ToolAgentOutput | None = None
|
||||
worker_agent_output: WorkerAgentOutput | None = None
|
||||
worker_agent_output: WorkerAgentOutputRich | None = None
|
||||
|
||||
|
||||
class AgentChatMessage(BaseModel):
|
||||
@@ -35,5 +38,11 @@ class AgentChatMessage(BaseModel):
|
||||
seq: int
|
||||
role: str
|
||||
content: str
|
||||
model_code: str | None = None
|
||||
tool_name: str | None = None
|
||||
input_tokens: int = Field(default=0, ge=0)
|
||||
output_tokens: int = Field(default=0, ge=0)
|
||||
cost: Decimal = Field(default=Decimal("0"))
|
||||
latency_ms: int | None = Field(default=None, ge=0)
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None = None
|
||||
timestamp: datetime
|
||||
|
||||
@@ -118,6 +118,31 @@ class LiteLLMService:
|
||||
+ normalized_completion_tokens * selected_tier.output_cost_per_token
|
||||
)
|
||||
|
||||
def build_usage_metadata(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
usage_summary: dict[str, int] | None,
|
||||
) -> dict[str, Any]:
|
||||
summary = usage_summary or {}
|
||||
input_tokens = max(int(summary.get("input_tokens", 0) or 0), 0)
|
||||
output_tokens = max(int(summary.get("output_tokens", 0) or 0), 0)
|
||||
latency_ms = max(int(summary.get("latency_ms", 0) or 0), 0)
|
||||
cached_prompt_tokens = max(int(summary.get("cached_prompt_tokens", 0) or 0), 0)
|
||||
cost = self.calculate_cost(
|
||||
model=model,
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
cached_prompt_tokens=cached_prompt_tokens,
|
||||
)
|
||||
return {
|
||||
"model": model,
|
||||
"inputTokens": input_tokens,
|
||||
"outputTokens": output_tokens,
|
||||
"cost": cost,
|
||||
"latencyMs": latency_ms,
|
||||
}
|
||||
|
||||
def run_completion_with_cost(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -8,11 +8,6 @@ from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.events import RedisStreamBus
|
||||
from core.agentscope.runtime.tasks import (
|
||||
run_command_task,
|
||||
run_command_task_bulk,
|
||||
run_command_task_critical,
|
||||
)
|
||||
from core.agentscope.tools.tool_result_storage import (
|
||||
create_tool_result_storage,
|
||||
)
|
||||
@@ -48,6 +43,12 @@ class TaskiqQueueClient:
|
||||
|
||||
@staticmethod
|
||||
def _select_queue_task(command: dict[str, object]) -> Any:
|
||||
from core.agentscope.runtime.tasks import (
|
||||
run_command_task,
|
||||
run_command_task_bulk,
|
||||
run_command_task_critical,
|
||||
)
|
||||
|
||||
queue = str(command.get("queue", "default")).strip().lower()
|
||||
if queue == "critical":
|
||||
return run_command_task_critical
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -218,6 +219,12 @@ class AgentRepository:
|
||||
"seq": int(message.seq),
|
||||
"role": role,
|
||||
"content": message.content,
|
||||
"model_code": message.model_code,
|
||||
"tool_name": message.tool_name,
|
||||
"input_tokens": int(message.input_tokens or 0),
|
||||
"output_tokens": int(message.output_tokens or 0),
|
||||
"cost": str(message.cost if message.cost is not None else Decimal("0")),
|
||||
"latency_ms": message.latency_ms,
|
||||
"metadata": message.metadata_json,
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ from v1.agent.schemas import (
|
||||
AttachmentReference,
|
||||
AttachmentSignedUrlResponse,
|
||||
AttachmentUploadResponse,
|
||||
HistorySnapshotResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
from v1.agent.service import AgentService, asr_service
|
||||
@@ -219,13 +220,13 @@ async def stream_events(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
@router.get("/history", response_model=HistorySnapshotResponse)
|
||||
async def get_user_history_snapshot(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
thread_id: str | None = Query(default=None, alias="threadId"),
|
||||
before: date | None = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
) -> HistorySnapshotResponse:
|
||||
return await service.get_user_history_snapshot(
|
||||
current_user=current_user,
|
||||
thread_id=thread_id,
|
||||
|
||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from schemas.agent.ui_schema import UiSchemaRenderer
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
@@ -33,3 +35,35 @@ class AttachmentSignedUrlResponse(BaseModel):
|
||||
bucket: str
|
||||
path: str
|
||||
url: str
|
||||
|
||||
|
||||
class HistoryMessage(BaseModel):
|
||||
"""History message schema for /history endpoint response."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
id: str = Field(description="Message UUID")
|
||||
seq: int = Field(description="Message sequence number")
|
||||
role: str = Field(description="Message role: user | assistant | tool")
|
||||
content: str = Field(description="Message text content")
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="Temporary signed URL for user-attached images",
|
||||
)
|
||||
ui_schema: UiSchemaRenderer | None = Field(
|
||||
default=None,
|
||||
description="Compiled UI schema from worker/tool ui_hints for frontend rendering",
|
||||
)
|
||||
timestamp: str = Field(description="Message creation timestamp in ISO-8601 format")
|
||||
|
||||
|
||||
class HistorySnapshotResponse(BaseModel):
|
||||
"""Response schema for GET /api/v1/agent/history"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
scope: str = Field(default="history_day")
|
||||
thread_id: str | None = Field(default=None, alias="threadId")
|
||||
day: str | None = None
|
||||
has_more: bool = Field(default=False, alias="hasMore")
|
||||
messages: list[HistoryMessage] = Field(default_factory=list)
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import dashscope
|
||||
from ag_ui.core import RunAgentInput, StateSnapshotEvent
|
||||
from ag_ui.core import RunAgentInput
|
||||
from dashscope.audio.asr import Recognition, RecognitionCallback
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -21,6 +21,7 @@ from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachments,
|
||||
)
|
||||
from v1.agent.schemas import HistorySnapshotResponse
|
||||
|
||||
logger = get_logger(__name__)
|
||||
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
|
||||
@@ -416,27 +417,48 @@ class AgentService:
|
||||
thread_id: str,
|
||||
before: date | None,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, object]:
|
||||
) -> HistorySnapshotResponse:
|
||||
from schemas.messages.chat_message import AgentChatMessage
|
||||
from v1.agent.utils import convert_message_to_history
|
||||
from v1.agent.schemas import HistoryMessage
|
||||
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
)
|
||||
snapshot = {
|
||||
"scope": "history_day",
|
||||
"threadId": thread_id,
|
||||
"day": day_payload["day"] if day_payload else None,
|
||||
"hasMore": day_payload["hasMore"] if day_payload else False,
|
||||
"messages": day_payload["messages"] if day_payload else [],
|
||||
}
|
||||
event = StateSnapshotEvent(snapshot=snapshot).model_dump(
|
||||
mode="json",
|
||||
by_alias=True,
|
||||
exclude_none=True,
|
||||
|
||||
messages: list[HistoryMessage] = []
|
||||
if day_payload:
|
||||
raw_messages = day_payload.get("messages") or []
|
||||
for msg_dict in raw_messages:
|
||||
msg = AgentChatMessage.model_validate(msg_dict)
|
||||
|
||||
signed_url: str | None = None
|
||||
if self._attachment_storage and msg.metadata:
|
||||
att = msg.metadata.user_message_attachments
|
||||
if att:
|
||||
signed_url = await self._attachment_storage.create_signed_url(
|
||||
bucket=att.bucket,
|
||||
path=att.path,
|
||||
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
|
||||
)
|
||||
|
||||
converted = convert_message_to_history(msg, None)
|
||||
if signed_url:
|
||||
converted["url"] = signed_url
|
||||
messages.append(HistoryMessage.model_validate(converted))
|
||||
|
||||
return HistorySnapshotResponse(
|
||||
scope="history_day",
|
||||
threadId=thread_id,
|
||||
day=str(day_payload.get("day"))
|
||||
if day_payload and day_payload.get("day")
|
||||
else None,
|
||||
hasMore=bool(day_payload.get("hasMore")) if day_payload else False,
|
||||
messages=messages,
|
||||
)
|
||||
event["threadId"] = thread_id
|
||||
return event
|
||||
|
||||
async def get_user_history_snapshot(
|
||||
self,
|
||||
@@ -444,22 +466,20 @@ class AgentService:
|
||||
current_user: CurrentUser,
|
||||
thread_id: str | None,
|
||||
before: date | None,
|
||||
) -> dict[str, object]:
|
||||
) -> HistorySnapshotResponse:
|
||||
target_thread_id = thread_id
|
||||
if target_thread_id is None:
|
||||
target_thread_id = await self._repository.get_latest_session_id_for_user(
|
||||
user_id=str(current_user.id)
|
||||
)
|
||||
if target_thread_id is None:
|
||||
return StateSnapshotEvent(
|
||||
snapshot={
|
||||
"scope": "history_day",
|
||||
"threadId": None,
|
||||
"day": None,
|
||||
"hasMore": False,
|
||||
"messages": [],
|
||||
}
|
||||
).model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
return HistorySnapshotResponse(
|
||||
scope="history_day",
|
||||
threadId=None,
|
||||
day=None,
|
||||
hasMore=False,
|
||||
messages=[],
|
||||
)
|
||||
return await self.get_history_snapshot(
|
||||
thread_id=target_thread_id,
|
||||
before=before,
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
历史消息转换工具函数
|
||||
|
||||
将数据库中的原始消息转换为 API 响应的数据结构
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessage,
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachments,
|
||||
)
|
||||
|
||||
|
||||
def convert_message_to_history(
|
||||
message: AgentChatMessage,
|
||||
get_signed_url_fn: Callable[[str, str], str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
将 AgentChatMessage 转换为 HistoryMessage 格式
|
||||
|
||||
转换规则:
|
||||
- role=user: 读取 metadata.user_message_attachments,将 bucket 转临时访问 url
|
||||
- role=tool: 读取 content 和 metadata.tool_agent_output.ui_hints,编译成 ui_schema
|
||||
- role=assistant: 读取 metadata.worker_agent_output.ui_hints,编译成 ui_schema
|
||||
"""
|
||||
role = message.role
|
||||
content = message.content
|
||||
metadata = message.metadata
|
||||
|
||||
url: str | None = None
|
||||
ui_schema: dict[str, Any] | None = None
|
||||
|
||||
if role == "user":
|
||||
url = _convert_user_attachments(metadata, get_signed_url_fn)
|
||||
|
||||
elif role == "tool":
|
||||
ui_schema = _compile_tool_ui_hints(metadata)
|
||||
|
||||
elif role == "assistant":
|
||||
ui_schema = _compile_worker_ui_hints(metadata)
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"id": str(message.id),
|
||||
"seq": message.seq,
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": message.timestamp.isoformat(),
|
||||
}
|
||||
|
||||
if url:
|
||||
result["url"] = url
|
||||
|
||||
if ui_schema:
|
||||
result["uiSchema"] = ui_schema
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _convert_user_attachments(
|
||||
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
|
||||
get_signed_url_fn: Callable[[str, str], str] | None,
|
||||
) -> str | None:
|
||||
"""转换用户附件为临时访问 URL"""
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
if isinstance(metadata, AgentChatMessageMetadata):
|
||||
attachments = metadata.user_message_attachments
|
||||
else:
|
||||
attachments_data = metadata.get("user_message_attachments")
|
||||
if not attachments_data:
|
||||
return None
|
||||
attachments = UserMessageAttachments.model_validate(attachments_data)
|
||||
|
||||
if not attachments or not get_signed_url_fn:
|
||||
return None
|
||||
|
||||
try:
|
||||
return get_signed_url_fn(
|
||||
{"bucket": attachments.bucket, "path": attachments.path}
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _compile_tool_ui_hints(
|
||||
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""编译 tool 消息的 ui_hints"""
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
if isinstance(metadata, AgentChatMessageMetadata):
|
||||
tool_output = metadata.tool_agent_output
|
||||
else:
|
||||
tool_output_data = metadata.get("tool_agent_output")
|
||||
if not tool_output_data:
|
||||
return None
|
||||
from schemas.agent.runtime_models import ToolAgentOutput
|
||||
|
||||
tool_output = ToolAgentOutput.model_validate(tool_output_data)
|
||||
|
||||
if not tool_output:
|
||||
return None
|
||||
|
||||
ui_hints = tool_output.ui_hints
|
||||
if not ui_hints:
|
||||
return None
|
||||
|
||||
try:
|
||||
compiled = compile_ui_hints(ui_hints)
|
||||
return compiled
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _compile_worker_ui_hints(
|
||||
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""编译 assistant 消息的 worker ui_hints"""
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
if isinstance(metadata, AgentChatMessageMetadata):
|
||||
worker_output = metadata.worker_agent_output
|
||||
else:
|
||||
worker_output_data = metadata.get("worker_agent_output")
|
||||
if not worker_output_data:
|
||||
return None
|
||||
from schemas.agent.runtime_models import WorkerAgentOutputRich
|
||||
|
||||
worker_output = WorkerAgentOutputRich.model_validate(worker_output_data)
|
||||
|
||||
if not worker_output:
|
||||
return None
|
||||
|
||||
ui_hints = worker_output.ui_hints
|
||||
if not ui_hints:
|
||||
return None
|
||||
|
||||
try:
|
||||
compiled = compile_ui_hints(ui_hints)
|
||||
return compiled
|
||||
except Exception:
|
||||
return None
|
||||
Reference in New Issue
Block a user