feat: 实现 AgentScope ReAct Runner 两阶段执行并重构事件处理

This commit is contained in:
zl-q
2026-03-16 09:01:01 +08:00
parent 072c09d99d
commit dcceb48d84
51 changed files with 5015 additions and 5663 deletions
@@ -38,9 +38,10 @@ _INTERNAL_TO_AGUI: dict[str, EventType] = {
def _convert_to_agui_type(internal_type: str) -> EventType:
return _INTERNAL_TO_AGUI.get(
internal_type, EventType(internal_type.upper().replace(".", "_"))
)
mapped = _INTERNAL_TO_AGUI.get(internal_type)
if mapped is not None:
return mapped
return EventType(internal_type.upper().replace(".", "_"))
def _is_agui_event(event: dict[str, Any]) -> bool:
@@ -142,32 +143,64 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
return event
internal_type = str(event.get("type", "")).strip()
thread_id = event.get("threadId")
run_id = event.get("runId")
data = event.get("data")
if internal_type == "text.end" and isinstance(data, dict):
text_end_payload: dict[str, Any] = {
"type": _convert_to_agui_type(internal_type).value,
}
if isinstance(thread_id, str) and thread_id:
text_end_payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
text_end_payload["runId"] = run_id
for key in ("messageId", "workerAgentOutput"):
value = data.get(key)
if value is not None:
text_end_payload[key] = value
return text_end_payload
if internal_type == "tool.result" and isinstance(data, dict):
tool_result_payload = {
"type": _convert_to_agui_type(internal_type).value,
}
if isinstance(thread_id, str) and thread_id:
tool_result_payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
tool_result_payload["runId"] = run_id
for key in ("messageId", "toolCallId", "toolAgentOutput"):
value = data.get(key)
if value is not None:
tool_result_payload[key] = value
return tool_result_payload
builder = _BUILDER_MAP.get(internal_type)
if builder:
agui_event = builder(event)
return agui_event.model_dump(by_alias=True, exclude_none=True)
payload = agui_event.model_dump(by_alias=True, exclude_none=True)
if isinstance(thread_id, str) and thread_id:
payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
payload["runId"] = run_id
if isinstance(data, dict):
reserved = {"type", "threadId", "runId"}
payload.update({k: v for k, v in data.items() if k not in reserved})
return payload
wire_type = _convert_to_agui_type(internal_type)
payload: dict[str, Any] = {
"type": wire_type.value,
}
thread_id = event.get("threadId")
run_id = event.get("runId")
if isinstance(thread_id, str) and thread_id:
payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
payload["runId"] = run_id
data = event.get("data")
if isinstance(data, dict):
if internal_type == "text.end":
for key in ("messageId", "workerAgentOutput"):
value = data.get(key)
if value is not None:
payload[key] = value
return payload
reserved = {"type", "threadId", "runId"}
data_map = cast(dict[str, Any], data)
payload.update({k: v for k, v in data_map.items() if k not in reserved})
@@ -50,5 +50,5 @@ class AgentScopeEventPipeline:
) -> str:
event_dict = to_dict(event)
wire_event = self._codec.to_wire(event_dict)
await self._store.persist(wire_event)
await self._store.persist(event_dict)
return await self._bus.publish(session_id=session_id, event=wire_event)
+36 -23
View File
@@ -55,8 +55,8 @@ class SqlAlchemyEventStore:
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
async def persist(self, event: dict[str, Any]) -> None:
event_type = str(event.get("type", "")).strip().upper()
thread_id = event.get("threadId")
event_type = str(event.get("type", "")).strip().upper().replace(".", "_")
thread_id = self._event_value(event, "threadId")
if not isinstance(thread_id, str) or not thread_id:
return
try:
@@ -124,8 +124,8 @@ class SqlAlchemyEventStore:
await session.commit()
def _buffer_text_delta(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = event.get("messageId")
delta = event.get("delta")
message_id = self._event_value(event, "messageId")
delta = self._event_value(event, "delta")
if not isinstance(message_id, str) or not message_id:
return
if not isinstance(delta, str) or not delta:
@@ -143,13 +143,13 @@ class SqlAlchemyEventStore:
self._message_contexts.pop(key, None)
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = event.get("messageId")
message_id = self._event_value(event, "messageId")
if not isinstance(message_id, str) or not message_id:
return
key = (session_key, message_id)
role = event.get("role")
stage = event.get("stage")
tool_name = event.get("toolName")
role = self._event_value(event, "role")
stage = self._event_value(event, "stage")
tool_name = self._event_value(event, "toolName")
context: dict[str, object] = {}
if isinstance(role, str) and role:
context["role"] = role
@@ -168,7 +168,7 @@ class SqlAlchemyEventStore:
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
message_id_raw = event.get("messageId")
message_id_raw = self._event_value(event, "messageId")
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
key = (str(session_id), message_id)
content = self._message_buffers.get(key, "")
@@ -177,26 +177,26 @@ class SqlAlchemyEventStore:
context = self._message_contexts.get(key, {})
input_tokens = self._to_int(event.get("inputTokens"))
output_tokens = self._to_int(event.get("outputTokens"))
input_tokens = self._to_int(self._event_value(event, "inputTokens"))
output_tokens = self._to_int(self._event_value(event, "outputTokens"))
token_delta = input_tokens + output_tokens
cost = self._to_decimal(event.get("cost"))
latency_ms = self._to_int_or_none(event.get("latencyMs"))
run_id = event.get("runId")
model_code = event.get("model")
cost = self._to_decimal(self._event_value(event, "cost"))
latency_ms = self._to_int_or_none(self._event_value(event, "latencyMs"))
run_id = self._event_value(event, "runId")
model_code = self._event_value(event, "model")
metadata: dict[str, object] = {"message_id": message_id}
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
if latency_ms is not None:
metadata["latency_ms"] = latency_ms
stage = event.get("stage")
stage = self._event_value(event, "stage")
if not isinstance(stage, str):
stage = context.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
worker_payload = event.get("workerAgentOutput")
worker_payload = self._event_value(event, "workerAgentOutput")
if isinstance(worker_payload, dict):
try:
if "ui_hints" in worker_payload:
@@ -264,11 +264,11 @@ class SqlAlchemyEventStore:
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
tool_name = event.get("toolName")
tool_name = self._event_value(event, "toolName")
if not isinstance(tool_name, str) or not tool_name:
return
raw_output = event.get("toolAgentOutput")
raw_output = self._event_value(event, "toolAgentOutput")
if not isinstance(raw_output, dict):
return
try:
@@ -276,11 +276,11 @@ class SqlAlchemyEventStore:
except Exception:
return
run_id = event.get("runId")
run_id = self._event_value(event, "runId")
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
task_id = event.get("taskId")
task_id = self._event_value(event, "taskId")
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
call_id_value = event.get("callId")
call_id_value = self._event_value(event, "callId")
if not isinstance(call_id_value, str) or not call_id_value:
call_id_value = (
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
@@ -303,7 +303,7 @@ class SqlAlchemyEventStore:
}
if run_id_value:
metadata["run_id"] = run_id_value
stage = event.get("stage")
stage = self._event_value(event, "stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
if task_id_value:
@@ -421,6 +421,19 @@ class SqlAlchemyEventStore:
return Decimal("0")
return parsed if parsed >= 0 else Decimal("0")
def _event_value(
self,
event: dict[str, Any],
key: str,
default: object | None = None,
) -> object | None:
if key in event:
return event.get(key)
data = event.get("data")
if isinstance(data, dict):
return data.get(key, default)
return default
def _sanitize_path_component(value: str) -> str:
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
@@ -2,9 +2,10 @@ from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any
from typing import Any, Sequence
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from ag_ui.core.types import Tool
from core.agentscope.prompts.agent_prompt import (
build_agent_prompt,
)
@@ -193,7 +194,7 @@ def build_system_prompt(
user_context: UserContext,
now_utc: datetime,
extra_context: str | None = None,
tools: list[dict[str, Any]] | None = None,
tools: Sequence[Tool] | None = None,
) -> str:
sections = [
_build_identity_section(),
@@ -1,7 +1,9 @@
from __future__ import annotations
import json
from typing import Any, Iterable
from typing import Iterable
from ag_ui.core.types import Tool
def _wrap_section(section: str, content: str) -> str:
@@ -15,18 +17,16 @@ def _wrap_section(section: str, content: str) -> str:
def build_tools_prompt(
*,
tools: Iterable[dict[str, Any]],
tools: Iterable[Tool],
) -> str:
lines: list[str] = []
lines.append("[Available Tools]")
for item in tools:
name = item.get("name")
description = item.get("description") or ""
parameters = item.get("parameters") or {}
if not isinstance(name, str) or not name:
continue
lines.append(f"- {name}: {description}".strip())
name = item.name
description = item.description or ""
parameters = item.parameters or {}
lines.append(f"- {name}: {description}")
lines.append(
" - args_schema: "
+ json.dumps(parameters, ensure_ascii=True, separators=(",", ":"))
@@ -1,16 +1,322 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4
from ag_ui.core.types import RunAgentInput
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.tools.toolkit import build_stage_toolkit
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.llm import Llm
from models.system_agents import SystemAgents
from schemas.agent.runtime_models import (
RouterAgentOutput,
ToolAgentOutput,
WorkerAgentOutputLite,
resolve_worker_output_model,
)
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
from schemas.user import UserContext
from services.litellm.service import LiteLLMService
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
if TYPE_CHECKING:
from core.agentscope.runtime.orchestrator import PipelineLike
logger = get_logger("core.agentscope.runtime.react_runner")
@dataclass(frozen=True)
class SystemAgentRuntimeConfig:
agent_type: AgentType
model_code: str
llm_config: SystemAgentLLMConfig
@dataclass(frozen=True)
class StageExecutionResult:
message: Msg
payload: dict[str, Any]
response_metadata: dict[str, Any]
class _TrackingChatModel:
def __init__(self, inner: OpenAIChatModel) -> None:
self._inner = inner
self._total_input_tokens = 0
self._total_output_tokens = 0
self._total_latency_ms = 0
self._cached_prompt_tokens = 0
@property
def stream(self) -> bool:
return self._inner.stream
@stream.setter
def stream(self, value: bool) -> None:
self._inner.stream = value
def __getattr__(self, name: str) -> Any:
return getattr(self._inner, name)
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
response = await self._inner(*args, **kwargs)
if isinstance(response, AsyncGenerator):
return self._track_stream(response)
self._record_usage(getattr(response, "usage", None))
return response
async def _track_stream(
self, response: AsyncGenerator[Any, None]
) -> AsyncGenerator[Any, None]:
latest_usage = None
async for chunk in response:
usage = getattr(chunk, "usage", None)
if usage is not None:
latest_usage = usage
yield chunk
self._record_usage(latest_usage)
def _record_usage(self, usage: Any) -> None:
if usage is None:
return
self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
self._total_output_tokens += max(
int(getattr(usage, "output_tokens", 0) or 0), 0
)
self._total_latency_ms += max(
int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
)
metadata = getattr(usage, "metadata", None)
if metadata is not None:
cached_tokens = 0
if isinstance(metadata, dict):
prompt_details = metadata.get("prompt_tokens_details")
if isinstance(prompt_details, dict):
cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0)
else:
prompt_details = getattr(metadata, "prompt_tokens_details", None)
cached_tokens = int(getattr(prompt_details, "cached_tokens", 0) or 0)
self._cached_prompt_tokens += max(cached_tokens, 0)
def usage_summary(self) -> dict[str, int]:
return {
"input_tokens": self._total_input_tokens,
"output_tokens": self._total_output_tokens,
"latency_ms": self._total_latency_ms,
"cached_prompt_tokens": self._cached_prompt_tokens,
}
class _PipelineStageEmitter:
def __init__(
self,
*,
pipeline: PipelineLike,
session_id: str,
run_id: str,
stage: str,
emit_text_events: bool,
emit_tool_events: bool,
) -> None:
self._pipeline = pipeline
self._session_id = session_id
self._run_id = run_id
self._stage = stage
self._emit_text_events = emit_text_events
self._emit_tool_events = emit_tool_events
self._text_by_message_id: dict[str, str] = {}
self._emitted_tool_calls: set[str] = set()
self._emitted_tool_results: set[str] = set()
self.latest_text_message_id: str | None = None
self.latest_text: str = ""
async def handle_print(self, *, msg: Msg, last: bool) -> None:
del last
if self._emit_tool_events:
await self._emit_tool_events_from_msg(msg)
if self._emit_text_events:
await self._emit_text_events_from_msg(msg)
async def _emit_text_events_from_msg(self, msg: Msg) -> None:
text = msg.get_text_content(separator="") or ""
if not text:
return
message_id = str(msg.id)
previous = self._text_by_message_id.get(message_id, "")
if message_id not in self._text_by_message_id:
await self._emit(
"text.start",
{
"messageId": message_id,
"role": "assistant",
"stage": self._stage,
},
)
delta = text[len(previous) :] if text.startswith(previous) else text
if delta:
await self._emit(
"text.delta",
{
"messageId": message_id,
"delta": delta,
"stage": self._stage,
},
)
self._text_by_message_id[message_id] = text
self.latest_text_message_id = message_id
self.latest_text = text
async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
for block in msg.get_content_blocks("tool_use"):
tool_call_id = str(block.get("id", "")).strip()
tool_name = str(block.get("name", "")).strip()
if (
not tool_call_id
or not tool_name
or tool_call_id in self._emitted_tool_calls
):
continue
payload = {
"messageId": str(msg.id),
"toolCallId": tool_call_id,
"toolName": tool_name,
"stage": self._stage,
}
await self._emit("tool.start", payload)
await self._emit(
"tool.args",
{
**payload,
"args": block.get("input", {}),
},
)
await self._emit("tool.end", payload)
self._emitted_tool_calls.add(tool_call_id)
for block in msg.get_content_blocks("tool_result"):
tool_call_id = str(block.get("id", "")).strip()
if not tool_call_id or tool_call_id in self._emitted_tool_results:
continue
tool_output = _parse_tool_agent_output(block.get("output"))
if tool_output is None:
continue
await self._emit(
"tool.result",
{
"messageId": str(msg.id),
"toolCallId": tool_call_id,
"toolName": tool_output.tool_name,
"stage": self._stage,
"toolAgentOutput": tool_output.model_dump(
mode="json", exclude_none=True
),
},
)
self._emitted_tool_results.add(tool_call_id)
async def emit_final_text_end(
self,
*,
worker_output: dict[str, Any],
response_metadata: dict[str, Any],
) -> None:
message_id = (
self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"
)
if self.latest_text_message_id is None and worker_output.get("answer"):
await self._emit(
"text.start",
{
"messageId": message_id,
"role": "assistant",
"stage": self._stage,
},
)
await self._emit(
"text.delta",
{
"messageId": message_id,
"delta": worker_output.get("answer", ""),
"stage": self._stage,
},
)
await self._emit(
"text.end",
{
"messageId": message_id,
"role": "assistant",
"stage": self._stage,
"workerAgentOutput": worker_output,
**response_metadata,
},
)
async def _emit(self, event_type: str, data: dict[str, Any]) -> None:
await self._pipeline.emit(
session_id=self._session_id,
event={
"type": event_type,
"threadId": self._session_id,
"runId": self._run_id,
"data": data,
},
)
class _PipelineReActAgent(ReActAgent):
def __init__(
self, *, emitter: _PipelineStageEmitter | None = None, **kwargs: Any
) -> None:
super().__init__(**kwargs)
self._pipeline_emitter = emitter
self.disable_console_output()
async def print(self, msg: Msg, last: bool = True, speech: Any = None) -> None:
del speech
if self._pipeline_emitter is not None:
await self._pipeline_emitter.handle_print(msg=msg, last=last)
def _parse_tool_agent_output(output: Any) -> ToolAgentOutput | None:
blocks = output if isinstance(output, Sequence) else []
for block in blocks:
if not isinstance(block, dict) or block.get("type") != "text":
continue
text = block.get("text")
if not isinstance(text, str) or not text.strip():
continue
try:
return ToolAgentOutput.model_validate(json.loads(text))
except Exception:
return None
return None
def _normalize_tool_name(value: str) -> str:
return value.strip().replace(".", "_").replace("-", "_")
class AgentScopeReActRunner:
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
self._litellm_service = litellm_service or LiteLLMService()
async def execute(
self,
*,
@@ -19,4 +325,367 @@ class AgentScopeReActRunner:
pipeline: PipelineLike,
run_input: RunAgentInput,
) -> dict[str, Any]:
raise NotImplementedError("execute method not implemented")
owner_id = UUID(user_context.id)
enabled_tool_names = self._extract_tool_names(run_input)
async with AsyncSessionLocal() as session:
router_toolkit, worker_toolkit = self._build_toolkits(
session=session,
owner_id=owner_id,
enabled_tool_names=enabled_tool_names,
)
router_config = await self._load_system_agent_config(
session=session,
agent_type=AgentType.ROUTER,
)
worker_config = await self._load_system_agent_config(
session=session,
agent_type=AgentType.WORKER,
)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="router",
event_type="step.start",
)
router_result = await self._run_router_stage(
user_context=user_context,
context_messages=context_messages,
toolkit=router_toolkit,
run_input=run_input,
stage_config=router_config,
)
router_output = RouterAgentOutput.model_validate(router_result.payload)
await self._persist_router_message(
session=session,
thread_id=run_input.thread_id,
run_id=run_input.run_id,
model_code=router_config.model_code,
router_output=router_output,
response_metadata=router_result.response_metadata,
)
await session.commit()
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="router",
event_type="step.finish",
)
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="worker",
event_type="step.start",
)
worker_result = await self._run_worker_stage(
user_context=user_context,
router_output=router_output,
toolkit=worker_toolkit,
run_input=run_input,
stage_config=worker_config,
worker_output_model=worker_output_model,
pipeline=pipeline,
)
worker_output = worker_output_model.model_validate(worker_result.payload)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="worker",
event_type="step.finish",
)
return {
"router": router_output.model_dump(mode="json", exclude_none=True),
"worker": worker_output.model_dump(mode="json", exclude_none=True),
}
def _build_toolkits(
self,
*,
session: AsyncSession,
owner_id: UUID,
enabled_tool_names: set[str] | None,
) -> tuple[Any, Any]:
return (
build_stage_toolkit(
agent_type=AgentType.ROUTER,
session=session,
owner_id=owner_id,
enabled_tool_names=enabled_tool_names,
),
build_stage_toolkit(
agent_type=AgentType.WORKER,
session=session,
owner_id=owner_id,
enabled_tool_names=enabled_tool_names,
),
)
def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None:
raw_tools = getattr(run_input, "tools", None)
if not isinstance(raw_tools, list):
return None
selected: set[str] = set()
for item in raw_tools:
if isinstance(item, dict):
name = item.get("name")
else:
name = getattr(item, "name", None)
if isinstance(name, str) and name.strip():
selected.add(_normalize_tool_name(name))
return selected
async def _load_system_agent_config(
self,
*,
session: AsyncSession,
agent_type: AgentType,
) -> SystemAgentRuntimeConfig:
stmt = (
select(SystemAgents, Llm)
.join(Llm, SystemAgents.llm_id == Llm.id)
.where(SystemAgents.agent_type == agent_type.value)
)
row = (await session.execute(stmt)).one_or_none()
if row is None:
raise RuntimeError(f"system agent config not found: {agent_type.value}")
system_agent, llm = row
status = str(system_agent.status).strip().lower()
if status != "active":
raise RuntimeError(f"system agent is not active: {agent_type.value}")
return SystemAgentRuntimeConfig(
agent_type=agent_type,
model_code=llm.model_code,
llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
)
async def _run_router_stage(
self,
*,
user_context: UserContext,
context_messages: list[Msg],
toolkit: Any,
run_input: RunAgentInput,
stage_config: SystemAgentRuntimeConfig,
) -> StageExecutionResult:
tracking_model = self._build_model(stage_config=stage_config)
agent = self._build_agent(
agent_name="router",
system_prompt=build_system_prompt(
agent_type=AgentType.ROUTER,
user_context=user_context,
now_utc=datetime.now(timezone.utc),
tools=run_input.tools,
),
toolkit=toolkit,
model=tracking_model,
)
response_msg = await agent.reply(
context_messages, structured_model=RouterAgentOutput
)
payload = RouterAgentOutput.model_validate(
response_msg.metadata or {}
).model_dump(
mode="json",
exclude_none=True,
)
return StageExecutionResult(
message=response_msg,
payload=payload,
response_metadata=self._litellm_service.build_usage_metadata(
model=stage_config.model_code,
usage_summary=tracking_model.usage_summary(),
),
)
async def _run_worker_stage(
self,
*,
user_context: UserContext,
router_output: RouterAgentOutput,
toolkit: Any,
run_input: RunAgentInput,
stage_config: SystemAgentRuntimeConfig,
worker_output_model: type[WorkerAgentOutputLite],
pipeline: PipelineLike,
) -> StageExecutionResult:
worker_input = self._build_worker_input_messages(
router_output=router_output,
)
tracking_model = self._build_model(stage_config=stage_config)
emitter = _PipelineStageEmitter(
pipeline=pipeline,
session_id=run_input.thread_id,
run_id=run_input.run_id,
stage="worker",
emit_text_events=True,
emit_tool_events=True,
)
agent = self._build_agent(
agent_name="worker",
system_prompt=build_system_prompt(
agent_type=AgentType.WORKER,
user_context=user_context,
now_utc=datetime.now(timezone.utc),
tools=run_input.tools,
),
toolkit=toolkit,
model=tracking_model,
emitter=emitter,
)
response_msg = await agent.reply(
worker_input,
structured_model=worker_output_model,
)
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
response_metadata = self._litellm_service.build_usage_metadata(
model=stage_config.model_code,
usage_summary=tracking_model.usage_summary(),
)
await emitter.emit_final_text_end(
worker_output=worker_payload.model_dump(mode="json", exclude_none=True),
response_metadata=response_metadata,
)
return StageExecutionResult(
message=response_msg,
payload=worker_payload.model_dump(mode="json", exclude_none=True),
response_metadata=response_metadata,
)
def _build_worker_input_messages(
self,
*,
router_output: RouterAgentOutput,
) -> list[Msg]:
routing_contract = json.dumps(
router_output.model_dump(mode="json", exclude_none=True),
ensure_ascii=False,
separators=(",", ":"),
)
routing_msg = Msg(
name="router",
role="user",
content=(
"Use the following routing contract as the execution source of truth. "
f"Do not change the routed objective:\n{routing_contract}"
),
)
return [routing_msg]
def _build_model(
self, *, stage_config: SystemAgentRuntimeConfig
) -> _TrackingChatModel:
model = OpenAIChatModel(
model_name=stage_config.model_code,
api_key=self._litellm_service.proxy_api_key,
stream=True,
client_kwargs={"base_url": self._litellm_service.proxy_base_url},
generate_kwargs={
"temperature": stage_config.llm_config.temperature,
"max_tokens": stage_config.llm_config.max_tokens,
"timeout": stage_config.llm_config.timeout_seconds,
},
)
return _TrackingChatModel(model)
def _build_agent(
self,
*,
agent_name: str,
system_prompt: str,
toolkit: Any,
model: _TrackingChatModel,
emitter: _PipelineStageEmitter | None = None,
) -> _PipelineReActAgent:
return _PipelineReActAgent(
name=agent_name,
sys_prompt=system_prompt,
model=model,
formatter=OpenAIChatFormatter(),
toolkit=toolkit,
memory=InMemoryMemory(),
emitter=emitter,
)
async def _emit_step_event(
self,
*,
pipeline: PipelineLike,
run_input: RunAgentInput,
step_name: str,
event_type: str,
) -> None:
await pipeline.emit(
session_id=run_input.thread_id,
event={
"type": event_type,
"threadId": run_input.thread_id,
"runId": run_input.run_id,
"data": {"stepName": step_name},
},
)
async def _persist_router_message(
self,
*,
session: AsyncSession,
thread_id: str,
run_id: str,
model_code: str,
router_output: RouterAgentOutput,
response_metadata: dict[str, Any],
) -> None:
session_id = UUID(thread_id)
message_repo = MessageRepository(session)
session_repo = SessionRepository(session)
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
)
if locked_session is None:
raise RuntimeError("chat session not found for router persistence")
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
metadata = AgentChatMessageMetadata(
run_id=run_id,
agent_type=AgentType.ROUTER,
router_agent_output=router_output,
)
message_payload = AgentChatMessage(
id=uuid4(),
seq=seq,
role=AgentChatMessageRole.ASSISTANT.value,
content="",
model_code=model_code,
tool_name=None,
input_tokens=int(response_metadata.get("inputTokens", 0) or 0),
output_tokens=int(response_metadata.get("outputTokens", 0) or 0),
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
latency_ms=int(response_metadata.get("latencyMs", 0) or 0),
metadata=metadata,
timestamp=datetime.now(timezone.utc),
)
await message_repo.append_message(
session_id=session_id,
seq=message_payload.seq,
role=AgentChatMessageRole.ASSISTANT,
content=message_payload.content,
model_code=message_payload.model_code,
tool_name=message_payload.tool_name,
metadata=metadata.model_dump(mode="json", exclude_none=True),
input_tokens=message_payload.input_tokens,
output_tokens=message_payload.output_tokens,
cost=message_payload.cost,
latency_ms=message_payload.latency_ms,
)
await session_repo.update_runtime_state(
chat_session=locked_session,
status=AgentChatSessionStatus.RUNNING,
state_snapshot=locked_session.state_snapshot or {},
message_delta=1,
token_delta=message_payload.input_tokens + message_payload.output_tokens,
cost_delta=message_payload.cost,
)
await session.flush()
File diff suppressed because it is too large Load Diff
@@ -23,9 +23,8 @@ from schemas.agent.ui_hints import (
UiHintAction,
UiHintActionCopy,
UiHintActionStyle,
UiHintErrorBlock,
UiHintKeyValuePair,
UiHintKvBlock,
UiHintIntent,
UiHintKvItem,
UiHintStatus,
UiHintsPayload,
)
@@ -54,18 +53,10 @@ def _lookup_error_output(
update={
"tool_call_args": tool_call_args,
"ui_hints": UiHintsPayload(
intent=UiHintIntent.STATUS,
status=UiHintStatus.ERROR,
title="用户查找失败",
description=message,
blocks=[
UiHintErrorBlock(
kind="error",
title="查找失败",
errorCode=code,
message=message,
retryable=retryable,
)
],
body=message,
),
}
)
@@ -78,28 +69,15 @@ def _lookup_success_hints(resolved: dict[str, Any]) -> UiHintsPayload:
username = str(resolved.get("username") or "")
matched_by = str(resolved.get("matchedBy") or "")
return UiHintsPayload(
intent=UiHintIntent.DATA,
status=UiHintStatus.SUCCESS,
title="用户信息",
description=f"匹配方式: {matched_by}",
blocks=[
UiHintKvBlock(
kind="kv",
title="查找结果",
pairs=[
UiHintKeyValuePair(
key="user_id", label="用户ID", value=user_id, copyable=True
),
UiHintKeyValuePair(
key="email", label="邮箱", value=email, copyable=True
),
UiHintKeyValuePair(
key="username", label="用户名", value=username or "-"
),
UiHintKeyValuePair(
key="matched_by", label="匹配方式", value=matched_by
),
],
)
items=[
UiHintKvItem(key="user_id", label="用户ID", value=user_id, copyable=True),
UiHintKvItem(key="email", label="邮箱", value=email, copyable=True),
UiHintKvItem(key="username", label="用户名", value=username or "-"),
UiHintKvItem(key="matched_by", label="匹配方式", value=matched_by),
],
actions=[
UiHintAction(
+17 -11
View File
@@ -18,6 +18,7 @@ from core.agentscope.tools.tool_config import (
)
from core.agentscope.tools.tool_middleware import register_tool_middlewares
from sqlalchemy.ext.asyncio import AsyncSession
from schemas.agent.system_agent import AgentType
TOOL_FUNCTIONS: dict[str, Any] = {
"calendar_read": calendar_read,
@@ -27,9 +28,9 @@ TOOL_FUNCTIONS: dict[str, Any] = {
}
STAGE_TO_GROUPS: dict[str, set[ToolGroup]] = {
"router": {ToolGroup.READ},
"worker": {ToolGroup.READ, ToolGroup.WRITE},
AGENT_TYPE_TO_GROUPS: dict[AgentType, set[ToolGroup]] = {
AgentType.ROUTER: {ToolGroup.READ},
AgentType.WORKER: {ToolGroup.READ, ToolGroup.WRITE},
}
@@ -61,7 +62,6 @@ def build_toolkit(
groups: set[ToolGroup] | None = None,
enabled_tool_names: set[str] | None = None,
enable_hitl: bool | None = None,
enable_approval_layer: bool = True,
):
toolkit = Toolkit()
enabled_names = _resolve_enabled_tools(
@@ -85,7 +85,7 @@ def build_toolkit(
preset_kwargs=preset_kwargs,
)
approval_enabled = enable_approval_layer if enable_hitl is None else enable_hitl
approval_enabled = enable_hitl if enable_hitl is not None else True
if approval_enabled:
register_tool_middlewares(toolkit=toolkit, config_by_name=TOOL_CONFIGS)
@@ -94,20 +94,26 @@ def build_toolkit(
def build_stage_toolkit(
*,
stage: str,
agent_type: AgentType,
session: AsyncSession,
owner_id: UUID,
enabled_tool_names: set[str] | None = None,
enable_hitl: bool | None = None,
enable_approval_layer: bool = True,
):
groups = STAGE_TO_GROUPS.get(stage)
groups = AGENT_TYPE_TO_GROUPS.get(agent_type)
if groups is None:
raise ValueError(f"unknown stage: {stage}")
raise ValueError(f"unknown agent_type: {agent_type}")
stage_enabled_names = resolve_tool_names_by_groups(set(groups))
selected_names = (
stage_enabled_names
if enabled_tool_names is None
else stage_enabled_names | set(enabled_tool_names)
)
return build_toolkit(
session=session,
owner_id=owner_id,
groups=set(groups),
enabled_tool_names=selected_names,
enable_hitl=enable_hitl,
enable_approval_layer=enable_approval_layer,
)
@@ -12,17 +12,10 @@ from schemas.agent.ui_hints import (
UiHintAction,
UiHintActionNavigation,
UiHintActionStyle,
UiHintErrorBlock,
UiHintKeyValuePair,
UiHintKvBlock,
UiHintListBlock,
UiHintIntent,
UiHintKvItem,
UiHintListItem,
UiHintOperationBlock,
UiHintOperationResult,
UiHintOperationType,
UiHintStatus,
UiHintTextBlock,
UiHintTextFormat,
UiHintsPayload,
)
@@ -40,18 +33,10 @@ def calendar_error_output(
retryable: bool,
) -> ToolResponse:
ui_hints = UiHintsPayload(
intent=UiHintIntent.STATUS,
status=UiHintStatus.ERROR,
title="日历操作失败",
description=message,
blocks=[
UiHintErrorBlock(
kind="error",
title="操作失败",
errorCode=code,
message=message,
retryable=retryable,
)
],
body=message,
)
output = build_error_output(
tool_name=tool_name,
@@ -84,29 +69,17 @@ def calendar_read_hints(
for event in events
]
return UiHintsPayload(
intent=UiHintIntent.LIST,
status=UiHintStatus.SUCCESS,
title="日程列表",
description=f"{total} 个日程",
blocks=[
UiHintKvBlock(
kind="kv",
title="分页信息",
pairs=[
UiHintKeyValuePair(key="total", label="总数", value=total),
UiHintKeyValuePair(key="page", label="当前页", value=page),
UiHintKeyValuePair(key="page_size", label="每页", value=page_size),
UiHintKeyValuePair(
key="total_pages", label="总页数", value=total_pages
),
],
),
UiHintListBlock(
kind="list",
title="日程项",
items=event_items,
emptyText="当前没有日程",
),
items=[
UiHintKvItem(key="total", label="总数", value=total),
UiHintKvItem(key="page", label="当前页", value=page),
UiHintKvItem(key="page_size", label="每页", value=page_size),
UiHintKvItem(key="total_pages", label="总页数", value=total_pages),
],
list_items=event_items,
actions=[
UiHintAction(
label="打开日历",
@@ -125,65 +98,38 @@ def calendar_write_hints(
event: dict[str, Any] | None,
event_id: str | None,
) -> UiHintsPayload:
operation_type = UiHintOperationType.EXECUTE
if operation == "create":
operation_type = UiHintOperationType.CREATE
elif operation == "update":
operation_type = UiHintOperationType.UPDATE
elif operation == "delete":
operation_type = UiHintOperationType.DELETE
kv_items: list[UiHintKvItem] = []
blocks: list[Any] = [
UiHintOperationBlock(
kind="operation",
title="日历写入结果",
operation=operation_type,
result=UiHintOperationResult.SUCCESS,
message=message,
affectedCount=1,
)
]
if event:
blocks.append(
UiHintKvBlock(
kind="kv",
title="日程详情",
pairs=[
UiHintKeyValuePair(
key="event_id",
label="日程ID",
value=str(event.get("id") or ""),
copyable=True,
),
UiHintKeyValuePair(
key="title",
label="标题",
value=str(event.get("title") or ""),
copyable=True,
),
UiHintKeyValuePair(
key="start_at",
label="开始时间",
value=str(event.get("startAt") or ""),
copyable=True,
),
],
)
)
kv_items = [
UiHintKvItem(
key="event_id",
label="日程ID",
value=str(event.get("id") or ""),
copyable=True,
),
UiHintKvItem(
key="title",
label="标题",
value=str(event.get("title") or ""),
copyable=True,
),
UiHintKvItem(
key="start_at",
label="开始时间",
value=str(event.get("startAt") or ""),
copyable=True,
),
]
elif event_id:
blocks.append(
UiHintTextBlock(
kind="text",
content=f"目标日程 ID: {event_id}",
format=UiHintTextFormat.PLAIN,
)
)
message = f"目标日程 ID: {event_id}\n{message}"
return UiHintsPayload(
intent=UiHintIntent.STATUS,
status=UiHintStatus.SUCCESS,
title="日历操作完成",
description=message,
blocks=blocks,
body=message,
items=kv_items if kv_items else None,
actions=[
UiHintAction(
label="查看日历",
@@ -203,36 +149,17 @@ def calendar_share_hints(
permission_text = (
", ".join([k for k, v in permission.items() if v is True]) or "按邀请人单独设置"
)
return UiHintsPayload(
intent=UiHintIntent.STATUS,
status=UiHintStatus.SUCCESS,
title="日程已分享",
description=f"已邀请 {len(invited)}",
blocks=[
UiHintOperationBlock(
kind="operation",
title="分享结果",
operation=UiHintOperationType.EXECUTE,
result=UiHintOperationResult.SUCCESS,
message=f"已邀请 {len(invited)}",
affectedCount=len(invited),
),
UiHintKvBlock(
kind="kv",
title="分享信息",
pairs=[
UiHintKeyValuePair(
key="event_id", label="日程ID", value=event_id, copyable=True
),
UiHintKeyValuePair(
key="permission", label="权限", value=permission_text
),
],
),
UiHintListBlock(
kind="list",
title="被邀请人",
items=[UiHintListItem(title=email) for email in invited],
emptyText="暂无被邀请人",
),
items=[
UiHintKvItem(key="event_id", label="日程ID", value=event_id, copyable=True),
UiHintKvItem(key="permission", label="权限", value=permission_text),
],
list_items=[UiHintListItem(title=email) for email in invited]
if invited
else [],
)