feat: 实现 AgentScope ReAct Runner 两阶段执行并重构事件处理
This commit is contained in:
@@ -38,9 +38,10 @@ _INTERNAL_TO_AGUI: dict[str, EventType] = {
|
||||
|
||||
|
||||
def _convert_to_agui_type(internal_type: str) -> EventType:
|
||||
return _INTERNAL_TO_AGUI.get(
|
||||
internal_type, EventType(internal_type.upper().replace(".", "_"))
|
||||
)
|
||||
mapped = _INTERNAL_TO_AGUI.get(internal_type)
|
||||
if mapped is not None:
|
||||
return mapped
|
||||
return EventType(internal_type.upper().replace(".", "_"))
|
||||
|
||||
|
||||
def _is_agui_event(event: dict[str, Any]) -> bool:
|
||||
@@ -142,32 +143,64 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
|
||||
return event
|
||||
|
||||
internal_type = str(event.get("type", "")).strip()
|
||||
|
||||
thread_id = event.get("threadId")
|
||||
run_id = event.get("runId")
|
||||
data = event.get("data")
|
||||
|
||||
if internal_type == "text.end" and isinstance(data, dict):
|
||||
text_end_payload: dict[str, Any] = {
|
||||
"type": _convert_to_agui_type(internal_type).value,
|
||||
}
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
text_end_payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
text_end_payload["runId"] = run_id
|
||||
for key in ("messageId", "workerAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
text_end_payload[key] = value
|
||||
return text_end_payload
|
||||
|
||||
if internal_type == "tool.result" and isinstance(data, dict):
|
||||
tool_result_payload = {
|
||||
"type": _convert_to_agui_type(internal_type).value,
|
||||
}
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
tool_result_payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
tool_result_payload["runId"] = run_id
|
||||
for key in ("messageId", "toolCallId", "toolAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
tool_result_payload[key] = value
|
||||
return tool_result_payload
|
||||
|
||||
builder = _BUILDER_MAP.get(internal_type)
|
||||
|
||||
if builder:
|
||||
agui_event = builder(event)
|
||||
return agui_event.model_dump(by_alias=True, exclude_none=True)
|
||||
payload = agui_event.model_dump(by_alias=True, exclude_none=True)
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
payload["runId"] = run_id
|
||||
if isinstance(data, dict):
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
payload.update({k: v for k, v in data.items() if k not in reserved})
|
||||
return payload
|
||||
|
||||
wire_type = _convert_to_agui_type(internal_type)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"type": wire_type.value,
|
||||
}
|
||||
thread_id = event.get("threadId")
|
||||
run_id = event.get("runId")
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
payload["runId"] = run_id
|
||||
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
if internal_type == "text.end":
|
||||
for key in ("messageId", "workerAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
payload[key] = value
|
||||
return payload
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
data_map = cast(dict[str, Any], data)
|
||||
payload.update({k: v for k, v in data_map.items() if k not in reserved})
|
||||
|
||||
@@ -50,5 +50,5 @@ class AgentScopeEventPipeline:
|
||||
) -> str:
|
||||
event_dict = to_dict(event)
|
||||
wire_event = self._codec.to_wire(event_dict)
|
||||
await self._store.persist(wire_event)
|
||||
await self._store.persist(event_dict)
|
||||
return await self._bus.publish(session_id=session_id, event=wire_event)
|
||||
|
||||
@@ -55,8 +55,8 @@ class SqlAlchemyEventStore:
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
event_type = str(event.get("type", "")).strip().upper()
|
||||
thread_id = event.get("threadId")
|
||||
event_type = str(event.get("type", "")).strip().upper().replace(".", "_")
|
||||
thread_id = self._event_value(event, "threadId")
|
||||
if not isinstance(thread_id, str) or not thread_id:
|
||||
return
|
||||
try:
|
||||
@@ -124,8 +124,8 @@ class SqlAlchemyEventStore:
|
||||
await session.commit()
|
||||
|
||||
def _buffer_text_delta(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = event.get("messageId")
|
||||
delta = event.get("delta")
|
||||
message_id = self._event_value(event, "messageId")
|
||||
delta = self._event_value(event, "delta")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
if not isinstance(delta, str) or not delta:
|
||||
@@ -143,13 +143,13 @@ class SqlAlchemyEventStore:
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = event.get("messageId")
|
||||
message_id = self._event_value(event, "messageId")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
key = (session_key, message_id)
|
||||
role = event.get("role")
|
||||
stage = event.get("stage")
|
||||
tool_name = event.get("toolName")
|
||||
role = self._event_value(event, "role")
|
||||
stage = self._event_value(event, "stage")
|
||||
tool_name = self._event_value(event, "toolName")
|
||||
context: dict[str, object] = {}
|
||||
if isinstance(role, str) and role:
|
||||
context["role"] = role
|
||||
@@ -168,7 +168,7 @@ class SqlAlchemyEventStore:
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
message_id_raw = event.get("messageId")
|
||||
message_id_raw = self._event_value(event, "messageId")
|
||||
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
|
||||
key = (str(session_id), message_id)
|
||||
content = self._message_buffers.get(key, "")
|
||||
@@ -177,26 +177,26 @@ class SqlAlchemyEventStore:
|
||||
|
||||
context = self._message_contexts.get(key, {})
|
||||
|
||||
input_tokens = self._to_int(event.get("inputTokens"))
|
||||
output_tokens = self._to_int(event.get("outputTokens"))
|
||||
input_tokens = self._to_int(self._event_value(event, "inputTokens"))
|
||||
output_tokens = self._to_int(self._event_value(event, "outputTokens"))
|
||||
token_delta = input_tokens + output_tokens
|
||||
cost = self._to_decimal(event.get("cost"))
|
||||
latency_ms = self._to_int_or_none(event.get("latencyMs"))
|
||||
run_id = event.get("runId")
|
||||
model_code = event.get("model")
|
||||
cost = self._to_decimal(self._event_value(event, "cost"))
|
||||
latency_ms = self._to_int_or_none(self._event_value(event, "latencyMs"))
|
||||
run_id = self._event_value(event, "runId")
|
||||
model_code = self._event_value(event, "model")
|
||||
|
||||
metadata: dict[str, object] = {"message_id": message_id}
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
if latency_ms is not None:
|
||||
metadata["latency_ms"] = latency_ms
|
||||
stage = event.get("stage")
|
||||
stage = self._event_value(event, "stage")
|
||||
if not isinstance(stage, str):
|
||||
stage = context.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
|
||||
worker_payload = event.get("workerAgentOutput")
|
||||
worker_payload = self._event_value(event, "workerAgentOutput")
|
||||
if isinstance(worker_payload, dict):
|
||||
try:
|
||||
if "ui_hints" in worker_payload:
|
||||
@@ -264,11 +264,11 @@ class SqlAlchemyEventStore:
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
tool_name = event.get("toolName")
|
||||
tool_name = self._event_value(event, "toolName")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
raw_output = event.get("toolAgentOutput")
|
||||
raw_output = self._event_value(event, "toolAgentOutput")
|
||||
if not isinstance(raw_output, dict):
|
||||
return
|
||||
try:
|
||||
@@ -276,11 +276,11 @@ class SqlAlchemyEventStore:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
run_id = event.get("runId")
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
|
||||
task_id = event.get("taskId")
|
||||
task_id = self._event_value(event, "taskId")
|
||||
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
|
||||
call_id_value = event.get("callId")
|
||||
call_id_value = self._event_value(event, "callId")
|
||||
if not isinstance(call_id_value, str) or not call_id_value:
|
||||
call_id_value = (
|
||||
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
|
||||
@@ -303,7 +303,7 @@ class SqlAlchemyEventStore:
|
||||
}
|
||||
if run_id_value:
|
||||
metadata["run_id"] = run_id_value
|
||||
stage = event.get("stage")
|
||||
stage = self._event_value(event, "stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
if task_id_value:
|
||||
@@ -421,6 +421,19 @@ class SqlAlchemyEventStore:
|
||||
return Decimal("0")
|
||||
return parsed if parsed >= 0 else Decimal("0")
|
||||
|
||||
def _event_value(
|
||||
self,
|
||||
event: dict[str, Any],
|
||||
key: str,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
if key in event:
|
||||
return event.get(key)
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
return data.get(key, default)
|
||||
return default
|
||||
|
||||
|
||||
def _sanitize_path_component(value: str) -> str:
|
||||
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
|
||||
|
||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, Sequence
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from ag_ui.core.types import Tool
|
||||
from core.agentscope.prompts.agent_prompt import (
|
||||
build_agent_prompt,
|
||||
)
|
||||
@@ -193,7 +194,7 @@ def build_system_prompt(
|
||||
user_context: UserContext,
|
||||
now_utc: datetime,
|
||||
extra_context: str | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tools: Sequence[Tool] | None = None,
|
||||
) -> str:
|
||||
sections = [
|
||||
_build_identity_section(),
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Iterable
|
||||
from typing import Iterable
|
||||
|
||||
from ag_ui.core.types import Tool
|
||||
|
||||
|
||||
def _wrap_section(section: str, content: str) -> str:
|
||||
@@ -15,18 +17,16 @@ def _wrap_section(section: str, content: str) -> str:
|
||||
|
||||
def build_tools_prompt(
|
||||
*,
|
||||
tools: Iterable[dict[str, Any]],
|
||||
tools: Iterable[Tool],
|
||||
) -> str:
|
||||
lines: list[str] = []
|
||||
lines.append("[Available Tools]")
|
||||
|
||||
for item in tools:
|
||||
name = item.get("name")
|
||||
description = item.get("description") or ""
|
||||
parameters = item.get("parameters") or {}
|
||||
if not isinstance(name, str) or not name:
|
||||
continue
|
||||
lines.append(f"- {name}: {description}".strip())
|
||||
name = item.name
|
||||
description = item.description or ""
|
||||
parameters = item.parameters or {}
|
||||
lines.append(f"- {name}: {description}")
|
||||
lines.append(
|
||||
" - args_schema: "
|
||||
+ json.dumps(parameters, ensure_ascii=True, separators=(",", ":"))
|
||||
|
||||
@@ -1,16 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from ag_ui.core.types import RunAgentInput
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.runtime_models import (
|
||||
RouterAgentOutput,
|
||||
ToolAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
|
||||
from schemas.user import UserContext
|
||||
from services.litellm.service import LiteLLMService
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agentscope.runtime.orchestrator import PipelineLike
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.react_runner")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SystemAgentRuntimeConfig:
|
||||
agent_type: AgentType
|
||||
model_code: str
|
||||
llm_config: SystemAgentLLMConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StageExecutionResult:
|
||||
message: Msg
|
||||
payload: dict[str, Any]
|
||||
response_metadata: dict[str, Any]
|
||||
|
||||
|
||||
class _TrackingChatModel:
|
||||
def __init__(self, inner: OpenAIChatModel) -> None:
|
||||
self._inner = inner
|
||||
self._total_input_tokens = 0
|
||||
self._total_output_tokens = 0
|
||||
self._total_latency_ms = 0
|
||||
self._cached_prompt_tokens = 0
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
return self._inner.stream
|
||||
|
||||
@stream.setter
|
||||
def stream(self, value: bool) -> None:
|
||||
self._inner.stream = value
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._inner, name)
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
response = await self._inner(*args, **kwargs)
|
||||
if isinstance(response, AsyncGenerator):
|
||||
return self._track_stream(response)
|
||||
self._record_usage(getattr(response, "usage", None))
|
||||
return response
|
||||
|
||||
async def _track_stream(
|
||||
self, response: AsyncGenerator[Any, None]
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
latest_usage = None
|
||||
async for chunk in response:
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage is not None:
|
||||
latest_usage = usage
|
||||
yield chunk
|
||||
self._record_usage(latest_usage)
|
||||
|
||||
def _record_usage(self, usage: Any) -> None:
|
||||
if usage is None:
|
||||
return
|
||||
self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
|
||||
self._total_output_tokens += max(
|
||||
int(getattr(usage, "output_tokens", 0) or 0), 0
|
||||
)
|
||||
self._total_latency_ms += max(
|
||||
int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
|
||||
)
|
||||
metadata = getattr(usage, "metadata", None)
|
||||
if metadata is not None:
|
||||
cached_tokens = 0
|
||||
if isinstance(metadata, dict):
|
||||
prompt_details = metadata.get("prompt_tokens_details")
|
||||
if isinstance(prompt_details, dict):
|
||||
cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0)
|
||||
else:
|
||||
prompt_details = getattr(metadata, "prompt_tokens_details", None)
|
||||
cached_tokens = int(getattr(prompt_details, "cached_tokens", 0) or 0)
|
||||
self._cached_prompt_tokens += max(cached_tokens, 0)
|
||||
|
||||
def usage_summary(self) -> dict[str, int]:
|
||||
return {
|
||||
"input_tokens": self._total_input_tokens,
|
||||
"output_tokens": self._total_output_tokens,
|
||||
"latency_ms": self._total_latency_ms,
|
||||
"cached_prompt_tokens": self._cached_prompt_tokens,
|
||||
}
|
||||
|
||||
|
||||
class _PipelineStageEmitter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
stage: str,
|
||||
emit_text_events: bool,
|
||||
emit_tool_events: bool,
|
||||
) -> None:
|
||||
self._pipeline = pipeline
|
||||
self._session_id = session_id
|
||||
self._run_id = run_id
|
||||
self._stage = stage
|
||||
self._emit_text_events = emit_text_events
|
||||
self._emit_tool_events = emit_tool_events
|
||||
self._text_by_message_id: dict[str, str] = {}
|
||||
self._emitted_tool_calls: set[str] = set()
|
||||
self._emitted_tool_results: set[str] = set()
|
||||
self.latest_text_message_id: str | None = None
|
||||
self.latest_text: str = ""
|
||||
|
||||
async def handle_print(self, *, msg: Msg, last: bool) -> None:
|
||||
del last
|
||||
if self._emit_tool_events:
|
||||
await self._emit_tool_events_from_msg(msg)
|
||||
if self._emit_text_events:
|
||||
await self._emit_text_events_from_msg(msg)
|
||||
|
||||
async def _emit_text_events_from_msg(self, msg: Msg) -> None:
|
||||
text = msg.get_text_content(separator="") or ""
|
||||
if not text:
|
||||
return
|
||||
message_id = str(msg.id)
|
||||
previous = self._text_by_message_id.get(message_id, "")
|
||||
if message_id not in self._text_by_message_id:
|
||||
await self._emit(
|
||||
"text.start",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": self._stage,
|
||||
},
|
||||
)
|
||||
delta = text[len(previous) :] if text.startswith(previous) else text
|
||||
if delta:
|
||||
await self._emit(
|
||||
"text.delta",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"delta": delta,
|
||||
"stage": self._stage,
|
||||
},
|
||||
)
|
||||
self._text_by_message_id[message_id] = text
|
||||
self.latest_text_message_id = message_id
|
||||
self.latest_text = text
|
||||
|
||||
async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
|
||||
for block in msg.get_content_blocks("tool_use"):
|
||||
tool_call_id = str(block.get("id", "")).strip()
|
||||
tool_name = str(block.get("name", "")).strip()
|
||||
if (
|
||||
not tool_call_id
|
||||
or not tool_name
|
||||
or tool_call_id in self._emitted_tool_calls
|
||||
):
|
||||
continue
|
||||
payload = {
|
||||
"messageId": str(msg.id),
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_name,
|
||||
"stage": self._stage,
|
||||
}
|
||||
await self._emit("tool.start", payload)
|
||||
await self._emit(
|
||||
"tool.args",
|
||||
{
|
||||
**payload,
|
||||
"args": block.get("input", {}),
|
||||
},
|
||||
)
|
||||
await self._emit("tool.end", payload)
|
||||
self._emitted_tool_calls.add(tool_call_id)
|
||||
|
||||
for block in msg.get_content_blocks("tool_result"):
|
||||
tool_call_id = str(block.get("id", "")).strip()
|
||||
if not tool_call_id or tool_call_id in self._emitted_tool_results:
|
||||
continue
|
||||
tool_output = _parse_tool_agent_output(block.get("output"))
|
||||
if tool_output is None:
|
||||
continue
|
||||
await self._emit(
|
||||
"tool.result",
|
||||
{
|
||||
"messageId": str(msg.id),
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_output.tool_name,
|
||||
"stage": self._stage,
|
||||
"toolAgentOutput": tool_output.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
),
|
||||
},
|
||||
)
|
||||
self._emitted_tool_results.add(tool_call_id)
|
||||
|
||||
async def emit_final_text_end(
|
||||
self,
|
||||
*,
|
||||
worker_output: dict[str, Any],
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
message_id = (
|
||||
self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"
|
||||
)
|
||||
if self.latest_text_message_id is None and worker_output.get("answer"):
|
||||
await self._emit(
|
||||
"text.start",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": self._stage,
|
||||
},
|
||||
)
|
||||
await self._emit(
|
||||
"text.delta",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"delta": worker_output.get("answer", ""),
|
||||
"stage": self._stage,
|
||||
},
|
||||
)
|
||||
await self._emit(
|
||||
"text.end",
|
||||
{
|
||||
"messageId": message_id,
|
||||
"role": "assistant",
|
||||
"stage": self._stage,
|
||||
"workerAgentOutput": worker_output,
|
||||
**response_metadata,
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit(self, event_type: str, data: dict[str, Any]) -> None:
|
||||
await self._pipeline.emit(
|
||||
session_id=self._session_id,
|
||||
event={
|
||||
"type": event_type,
|
||||
"threadId": self._session_id,
|
||||
"runId": self._run_id,
|
||||
"data": data,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class _PipelineReActAgent(ReActAgent):
|
||||
def __init__(
|
||||
self, *, emitter: _PipelineStageEmitter | None = None, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._pipeline_emitter = emitter
|
||||
self.disable_console_output()
|
||||
|
||||
async def print(self, msg: Msg, last: bool = True, speech: Any = None) -> None:
|
||||
del speech
|
||||
if self._pipeline_emitter is not None:
|
||||
await self._pipeline_emitter.handle_print(msg=msg, last=last)
|
||||
|
||||
|
||||
def _parse_tool_agent_output(output: Any) -> ToolAgentOutput | None:
|
||||
blocks = output if isinstance(output, Sequence) else []
|
||||
for block in blocks:
|
||||
if not isinstance(block, dict) or block.get("type") != "text":
|
||||
continue
|
||||
text = block.get("text")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
try:
|
||||
return ToolAgentOutput.model_validate(json.loads(text))
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_tool_name(value: str) -> str:
|
||||
return value.strip().replace(".", "_").replace("-", "_")
|
||||
|
||||
|
||||
class AgentScopeReActRunner:
|
||||
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
|
||||
self._litellm_service = litellm_service or LiteLLMService()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
@@ -19,4 +325,367 @@ class AgentScopeReActRunner:
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError("execute method not implemented")
|
||||
owner_id = UUID(user_context.id)
|
||||
enabled_tool_names = self._extract_tool_names(run_input)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
router_toolkit, worker_toolkit = self._build_toolkits(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
)
|
||||
|
||||
router_config = await self._load_system_agent_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_system_agent_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="router",
|
||||
event_type="step.start",
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
toolkit=router_toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=router_config,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await self._persist_router_message(
|
||||
session=session,
|
||||
thread_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
model_code=router_config.model_code,
|
||||
router_output=router_output,
|
||||
response_metadata=router_result.response_metadata,
|
||||
)
|
||||
await session.commit()
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="router",
|
||||
event_type="step.finish",
|
||||
)
|
||||
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="worker",
|
||||
event_type="step.start",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=worker_config,
|
||||
worker_output_model=worker_output_model,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name="worker",
|
||||
event_type="step.finish",
|
||||
)
|
||||
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
def _build_toolkits(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
enabled_tool_names: set[str] | None,
|
||||
) -> tuple[Any, Any]:
|
||||
return (
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.ROUTER,
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
),
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.WORKER,
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
),
|
||||
)
|
||||
|
||||
def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None:
|
||||
raw_tools = getattr(run_input, "tools", None)
|
||||
if not isinstance(raw_tools, list):
|
||||
return None
|
||||
selected: set[str] = set()
|
||||
for item in raw_tools:
|
||||
if isinstance(item, dict):
|
||||
name = item.get("name")
|
||||
else:
|
||||
name = getattr(item, "name", None)
|
||||
if isinstance(name, str) and name.strip():
|
||||
selected.add(_normalize_tool_name(name))
|
||||
return selected
|
||||
|
||||
async def _load_system_agent_config(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
agent_type: AgentType,
|
||||
) -> SystemAgentRuntimeConfig:
|
||||
stmt = (
|
||||
select(SystemAgents, Llm)
|
||||
.join(Llm, SystemAgents.llm_id == Llm.id)
|
||||
.where(SystemAgents.agent_type == agent_type.value)
|
||||
)
|
||||
row = (await session.execute(stmt)).one_or_none()
|
||||
if row is None:
|
||||
raise RuntimeError(f"system agent config not found: {agent_type.value}")
|
||||
system_agent, llm = row
|
||||
status = str(system_agent.status).strip().lower()
|
||||
if status != "active":
|
||||
raise RuntimeError(f"system agent is not active: {agent_type.value}")
|
||||
return SystemAgentRuntimeConfig(
|
||||
agent_type=agent_type,
|
||||
model_code=llm.model_code,
|
||||
llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
|
||||
)
|
||||
|
||||
async def _run_router_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
run_input: RunAgentInput,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
) -> StageExecutionResult:
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
agent = self._build_agent(
|
||||
agent_name="router",
|
||||
system_prompt=build_system_prompt(
|
||||
agent_type=AgentType.ROUTER,
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
tools=run_input.tools,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
model=tracking_model,
|
||||
)
|
||||
response_msg = await agent.reply(
|
||||
context_messages, structured_model=RouterAgentOutput
|
||||
)
|
||||
payload = RouterAgentOutput.model_validate(
|
||||
response_msg.metadata or {}
|
||||
).model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
)
|
||||
return StageExecutionResult(
|
||||
message=response_msg,
|
||||
payload=payload,
|
||||
response_metadata=self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
usage_summary=tracking_model.usage_summary(),
|
||||
),
|
||||
)
|
||||
|
||||
async def _run_worker_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
router_output: RouterAgentOutput,
|
||||
toolkit: Any,
|
||||
run_input: RunAgentInput,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
) -> StageExecutionResult:
|
||||
worker_input = self._build_worker_input_messages(
|
||||
router_output=router_output,
|
||||
)
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
emitter = _PipelineStageEmitter(
|
||||
pipeline=pipeline,
|
||||
session_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
stage="worker",
|
||||
emit_text_events=True,
|
||||
emit_tool_events=True,
|
||||
)
|
||||
agent = self._build_agent(
|
||||
agent_name="worker",
|
||||
system_prompt=build_system_prompt(
|
||||
agent_type=AgentType.WORKER,
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
tools=run_input.tools,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
model=tracking_model,
|
||||
emitter=emitter,
|
||||
)
|
||||
response_msg = await agent.reply(
|
||||
worker_input,
|
||||
structured_model=worker_output_model,
|
||||
)
|
||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||
response_metadata = self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
usage_summary=tracking_model.usage_summary(),
|
||||
)
|
||||
await emitter.emit_final_text_end(
|
||||
worker_output=worker_payload.model_dump(mode="json", exclude_none=True),
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
return StageExecutionResult(
|
||||
message=response_msg,
|
||||
payload=worker_payload.model_dump(mode="json", exclude_none=True),
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
def _build_worker_input_messages(
|
||||
self,
|
||||
*,
|
||||
router_output: RouterAgentOutput,
|
||||
) -> list[Msg]:
|
||||
routing_contract = json.dumps(
|
||||
router_output.model_dump(mode="json", exclude_none=True),
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
routing_msg = Msg(
|
||||
name="router",
|
||||
role="user",
|
||||
content=(
|
||||
"Use the following routing contract as the execution source of truth. "
|
||||
f"Do not change the routed objective:\n{routing_contract}"
|
||||
),
|
||||
)
|
||||
return [routing_msg]
|
||||
|
||||
def _build_model(
|
||||
self, *, stage_config: SystemAgentRuntimeConfig
|
||||
) -> _TrackingChatModel:
|
||||
model = OpenAIChatModel(
|
||||
model_name=stage_config.model_code,
|
||||
api_key=self._litellm_service.proxy_api_key,
|
||||
stream=True,
|
||||
client_kwargs={"base_url": self._litellm_service.proxy_base_url},
|
||||
generate_kwargs={
|
||||
"temperature": stage_config.llm_config.temperature,
|
||||
"max_tokens": stage_config.llm_config.max_tokens,
|
||||
"timeout": stage_config.llm_config.timeout_seconds,
|
||||
},
|
||||
)
|
||||
return _TrackingChatModel(model)
|
||||
|
||||
def _build_agent(
|
||||
self,
|
||||
*,
|
||||
agent_name: str,
|
||||
system_prompt: str,
|
||||
toolkit: Any,
|
||||
model: _TrackingChatModel,
|
||||
emitter: _PipelineStageEmitter | None = None,
|
||||
) -> _PipelineReActAgent:
|
||||
return _PipelineReActAgent(
|
||||
name=agent_name,
|
||||
sys_prompt=system_prompt,
|
||||
model=model,
|
||||
formatter=OpenAIChatFormatter(),
|
||||
toolkit=toolkit,
|
||||
memory=InMemoryMemory(),
|
||||
emitter=emitter,
|
||||
)
|
||||
|
||||
async def _emit_step_event(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
step_name: str,
|
||||
event_type: str,
|
||||
) -> None:
|
||||
await pipeline.emit(
|
||||
session_id=run_input.thread_id,
|
||||
event={
|
||||
"type": event_type,
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"data": {"stepName": step_name},
|
||||
},
|
||||
)
|
||||
|
||||
async def _persist_router_message(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
model_code: str,
|
||||
router_output: RouterAgentOutput,
|
||||
response_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
session_id = UUID(thread_id)
|
||||
message_repo = MessageRepository(session)
|
||||
session_repo = SessionRepository(session)
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
)
|
||||
if locked_session is None:
|
||||
raise RuntimeError("chat session not found for router persistence")
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
metadata = AgentChatMessageMetadata(
|
||||
run_id=run_id,
|
||||
agent_type=AgentType.ROUTER,
|
||||
router_agent_output=router_output,
|
||||
)
|
||||
message_payload = AgentChatMessage(
|
||||
id=uuid4(),
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT.value,
|
||||
content="",
|
||||
model_code=model_code,
|
||||
tool_name=None,
|
||||
input_tokens=int(response_metadata.get("inputTokens", 0) or 0),
|
||||
output_tokens=int(response_metadata.get("outputTokens", 0) or 0),
|
||||
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
|
||||
latency_ms=int(response_metadata.get("latencyMs", 0) or 0),
|
||||
metadata=metadata,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=message_payload.seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=message_payload.content,
|
||||
model_code=message_payload.model_code,
|
||||
tool_name=message_payload.tool_name,
|
||||
metadata=metadata.model_dump(mode="json", exclude_none=True),
|
||||
input_tokens=message_payload.input_tokens,
|
||||
output_tokens=message_payload.output_tokens,
|
||||
cost=message_payload.cost,
|
||||
latency_ms=message_payload.latency_ms,
|
||||
)
|
||||
await session_repo.update_runtime_state(
|
||||
chat_session=locked_session,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot=locked_session.state_snapshot or {},
|
||||
message_delta=1,
|
||||
token_delta=message_payload.input_tokens + message_payload.output_tokens,
|
||||
cost_delta=message_payload.cost,
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,9 +23,8 @@ from schemas.agent.ui_hints import (
|
||||
UiHintAction,
|
||||
UiHintActionCopy,
|
||||
UiHintActionStyle,
|
||||
UiHintErrorBlock,
|
||||
UiHintKeyValuePair,
|
||||
UiHintKvBlock,
|
||||
UiHintIntent,
|
||||
UiHintKvItem,
|
||||
UiHintStatus,
|
||||
UiHintsPayload,
|
||||
)
|
||||
@@ -54,18 +53,10 @@ def _lookup_error_output(
|
||||
update={
|
||||
"tool_call_args": tool_call_args,
|
||||
"ui_hints": UiHintsPayload(
|
||||
intent=UiHintIntent.STATUS,
|
||||
status=UiHintStatus.ERROR,
|
||||
title="用户查找失败",
|
||||
description=message,
|
||||
blocks=[
|
||||
UiHintErrorBlock(
|
||||
kind="error",
|
||||
title="查找失败",
|
||||
errorCode=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
],
|
||||
body=message,
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -78,28 +69,15 @@ def _lookup_success_hints(resolved: dict[str, Any]) -> UiHintsPayload:
|
||||
username = str(resolved.get("username") or "")
|
||||
matched_by = str(resolved.get("matchedBy") or "")
|
||||
return UiHintsPayload(
|
||||
intent=UiHintIntent.DATA,
|
||||
status=UiHintStatus.SUCCESS,
|
||||
title="用户信息",
|
||||
description=f"匹配方式: {matched_by}",
|
||||
blocks=[
|
||||
UiHintKvBlock(
|
||||
kind="kv",
|
||||
title="查找结果",
|
||||
pairs=[
|
||||
UiHintKeyValuePair(
|
||||
key="user_id", label="用户ID", value=user_id, copyable=True
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="email", label="邮箱", value=email, copyable=True
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="username", label="用户名", value=username or "-"
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="matched_by", label="匹配方式", value=matched_by
|
||||
),
|
||||
],
|
||||
)
|
||||
items=[
|
||||
UiHintKvItem(key="user_id", label="用户ID", value=user_id, copyable=True),
|
||||
UiHintKvItem(key="email", label="邮箱", value=email, copyable=True),
|
||||
UiHintKvItem(key="username", label="用户名", value=username or "-"),
|
||||
UiHintKvItem(key="matched_by", label="匹配方式", value=matched_by),
|
||||
],
|
||||
actions=[
|
||||
UiHintAction(
|
||||
|
||||
@@ -18,6 +18,7 @@ from core.agentscope.tools.tool_config import (
|
||||
)
|
||||
from core.agentscope.tools.tool_middleware import register_tool_middlewares
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
TOOL_FUNCTIONS: dict[str, Any] = {
|
||||
"calendar_read": calendar_read,
|
||||
@@ -27,9 +28,9 @@ TOOL_FUNCTIONS: dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
STAGE_TO_GROUPS: dict[str, set[ToolGroup]] = {
|
||||
"router": {ToolGroup.READ},
|
||||
"worker": {ToolGroup.READ, ToolGroup.WRITE},
|
||||
AGENT_TYPE_TO_GROUPS: dict[AgentType, set[ToolGroup]] = {
|
||||
AgentType.ROUTER: {ToolGroup.READ},
|
||||
AgentType.WORKER: {ToolGroup.READ, ToolGroup.WRITE},
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +62,6 @@ def build_toolkit(
|
||||
groups: set[ToolGroup] | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
enable_approval_layer: bool = True,
|
||||
):
|
||||
toolkit = Toolkit()
|
||||
enabled_names = _resolve_enabled_tools(
|
||||
@@ -85,7 +85,7 @@ def build_toolkit(
|
||||
preset_kwargs=preset_kwargs,
|
||||
)
|
||||
|
||||
approval_enabled = enable_approval_layer if enable_hitl is None else enable_hitl
|
||||
approval_enabled = enable_hitl if enable_hitl is not None else True
|
||||
if approval_enabled:
|
||||
register_tool_middlewares(toolkit=toolkit, config_by_name=TOOL_CONFIGS)
|
||||
|
||||
@@ -94,20 +94,26 @@ def build_toolkit(
|
||||
|
||||
def build_stage_toolkit(
|
||||
*,
|
||||
stage: str,
|
||||
agent_type: AgentType,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
enable_approval_layer: bool = True,
|
||||
):
|
||||
groups = STAGE_TO_GROUPS.get(stage)
|
||||
groups = AGENT_TYPE_TO_GROUPS.get(agent_type)
|
||||
if groups is None:
|
||||
raise ValueError(f"unknown stage: {stage}")
|
||||
raise ValueError(f"unknown agent_type: {agent_type}")
|
||||
|
||||
stage_enabled_names = resolve_tool_names_by_groups(set(groups))
|
||||
selected_names = (
|
||||
stage_enabled_names
|
||||
if enabled_tool_names is None
|
||||
else stage_enabled_names | set(enabled_tool_names)
|
||||
)
|
||||
|
||||
return build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
groups=set(groups),
|
||||
enabled_tool_names=selected_names,
|
||||
enable_hitl=enable_hitl,
|
||||
enable_approval_layer=enable_approval_layer,
|
||||
)
|
||||
|
||||
@@ -12,17 +12,10 @@ from schemas.agent.ui_hints import (
|
||||
UiHintAction,
|
||||
UiHintActionNavigation,
|
||||
UiHintActionStyle,
|
||||
UiHintErrorBlock,
|
||||
UiHintKeyValuePair,
|
||||
UiHintKvBlock,
|
||||
UiHintListBlock,
|
||||
UiHintIntent,
|
||||
UiHintKvItem,
|
||||
UiHintListItem,
|
||||
UiHintOperationBlock,
|
||||
UiHintOperationResult,
|
||||
UiHintOperationType,
|
||||
UiHintStatus,
|
||||
UiHintTextBlock,
|
||||
UiHintTextFormat,
|
||||
UiHintsPayload,
|
||||
)
|
||||
|
||||
@@ -40,18 +33,10 @@ def calendar_error_output(
|
||||
retryable: bool,
|
||||
) -> ToolResponse:
|
||||
ui_hints = UiHintsPayload(
|
||||
intent=UiHintIntent.STATUS,
|
||||
status=UiHintStatus.ERROR,
|
||||
title="日历操作失败",
|
||||
description=message,
|
||||
blocks=[
|
||||
UiHintErrorBlock(
|
||||
kind="error",
|
||||
title="操作失败",
|
||||
errorCode=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
],
|
||||
body=message,
|
||||
)
|
||||
output = build_error_output(
|
||||
tool_name=tool_name,
|
||||
@@ -84,29 +69,17 @@ def calendar_read_hints(
|
||||
for event in events
|
||||
]
|
||||
return UiHintsPayload(
|
||||
intent=UiHintIntent.LIST,
|
||||
status=UiHintStatus.SUCCESS,
|
||||
title="日程列表",
|
||||
description=f"共 {total} 个日程",
|
||||
blocks=[
|
||||
UiHintKvBlock(
|
||||
kind="kv",
|
||||
title="分页信息",
|
||||
pairs=[
|
||||
UiHintKeyValuePair(key="total", label="总数", value=total),
|
||||
UiHintKeyValuePair(key="page", label="当前页", value=page),
|
||||
UiHintKeyValuePair(key="page_size", label="每页", value=page_size),
|
||||
UiHintKeyValuePair(
|
||||
key="total_pages", label="总页数", value=total_pages
|
||||
),
|
||||
],
|
||||
),
|
||||
UiHintListBlock(
|
||||
kind="list",
|
||||
title="日程项",
|
||||
items=event_items,
|
||||
emptyText="当前没有日程",
|
||||
),
|
||||
items=[
|
||||
UiHintKvItem(key="total", label="总数", value=total),
|
||||
UiHintKvItem(key="page", label="当前页", value=page),
|
||||
UiHintKvItem(key="page_size", label="每页", value=page_size),
|
||||
UiHintKvItem(key="total_pages", label="总页数", value=total_pages),
|
||||
],
|
||||
list_items=event_items,
|
||||
actions=[
|
||||
UiHintAction(
|
||||
label="打开日历",
|
||||
@@ -125,65 +98,38 @@ def calendar_write_hints(
|
||||
event: dict[str, Any] | None,
|
||||
event_id: str | None,
|
||||
) -> UiHintsPayload:
|
||||
operation_type = UiHintOperationType.EXECUTE
|
||||
if operation == "create":
|
||||
operation_type = UiHintOperationType.CREATE
|
||||
elif operation == "update":
|
||||
operation_type = UiHintOperationType.UPDATE
|
||||
elif operation == "delete":
|
||||
operation_type = UiHintOperationType.DELETE
|
||||
kv_items: list[UiHintKvItem] = []
|
||||
|
||||
blocks: list[Any] = [
|
||||
UiHintOperationBlock(
|
||||
kind="operation",
|
||||
title="日历写入结果",
|
||||
operation=operation_type,
|
||||
result=UiHintOperationResult.SUCCESS,
|
||||
message=message,
|
||||
affectedCount=1,
|
||||
)
|
||||
]
|
||||
if event:
|
||||
blocks.append(
|
||||
UiHintKvBlock(
|
||||
kind="kv",
|
||||
title="日程详情",
|
||||
pairs=[
|
||||
UiHintKeyValuePair(
|
||||
key="event_id",
|
||||
label="日程ID",
|
||||
value=str(event.get("id") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="title",
|
||||
label="标题",
|
||||
value=str(event.get("title") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="start_at",
|
||||
label="开始时间",
|
||||
value=str(event.get("startAt") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
kv_items = [
|
||||
UiHintKvItem(
|
||||
key="event_id",
|
||||
label="日程ID",
|
||||
value=str(event.get("id") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
UiHintKvItem(
|
||||
key="title",
|
||||
label="标题",
|
||||
value=str(event.get("title") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
UiHintKvItem(
|
||||
key="start_at",
|
||||
label="开始时间",
|
||||
value=str(event.get("startAt") or ""),
|
||||
copyable=True,
|
||||
),
|
||||
]
|
||||
elif event_id:
|
||||
blocks.append(
|
||||
UiHintTextBlock(
|
||||
kind="text",
|
||||
content=f"目标日程 ID: {event_id}",
|
||||
format=UiHintTextFormat.PLAIN,
|
||||
)
|
||||
)
|
||||
message = f"目标日程 ID: {event_id}\n{message}"
|
||||
|
||||
return UiHintsPayload(
|
||||
intent=UiHintIntent.STATUS,
|
||||
status=UiHintStatus.SUCCESS,
|
||||
title="日历操作完成",
|
||||
description=message,
|
||||
blocks=blocks,
|
||||
body=message,
|
||||
items=kv_items if kv_items else None,
|
||||
actions=[
|
||||
UiHintAction(
|
||||
label="查看日历",
|
||||
@@ -203,36 +149,17 @@ def calendar_share_hints(
|
||||
permission_text = (
|
||||
", ".join([k for k, v in permission.items() if v is True]) or "按邀请人单独设置"
|
||||
)
|
||||
|
||||
return UiHintsPayload(
|
||||
intent=UiHintIntent.STATUS,
|
||||
status=UiHintStatus.SUCCESS,
|
||||
title="日程已分享",
|
||||
description=f"已邀请 {len(invited)} 人",
|
||||
blocks=[
|
||||
UiHintOperationBlock(
|
||||
kind="operation",
|
||||
title="分享结果",
|
||||
operation=UiHintOperationType.EXECUTE,
|
||||
result=UiHintOperationResult.SUCCESS,
|
||||
message=f"已邀请 {len(invited)} 人",
|
||||
affectedCount=len(invited),
|
||||
),
|
||||
UiHintKvBlock(
|
||||
kind="kv",
|
||||
title="分享信息",
|
||||
pairs=[
|
||||
UiHintKeyValuePair(
|
||||
key="event_id", label="日程ID", value=event_id, copyable=True
|
||||
),
|
||||
UiHintKeyValuePair(
|
||||
key="permission", label="权限", value=permission_text
|
||||
),
|
||||
],
|
||||
),
|
||||
UiHintListBlock(
|
||||
kind="list",
|
||||
title="被邀请人",
|
||||
items=[UiHintListItem(title=email) for email in invited],
|
||||
emptyText="暂无被邀请人",
|
||||
),
|
||||
items=[
|
||||
UiHintKvItem(key="event_id", label="日程ID", value=event_id, copyable=True),
|
||||
UiHintKvItem(key="permission", label="权限", value=permission_text),
|
||||
],
|
||||
list_items=[UiHintListItem(title=email) for email in invited]
|
||||
if invited
|
||||
else [],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user