feat: 优化 Agent 运行时与聊天设置体验

This commit is contained in:
qzl
2026-03-16 18:32:09 +08:00
parent 3f79cf0df7
commit 5a34616287
41 changed files with 2603 additions and 1263 deletions
+139 -432
View File
@@ -1,30 +1,28 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4
from uuid import UUID
from ag_ui.core.types import RunAgentInput
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.tools.toolkit import build_stage_toolkit, build_toolkit
from core.agentscope.runtime.utils import (
normalize_tool_name,
parse_tool_agent_output,
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.router_persistence import persist_router_message
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.tools.toolkit import build_stage_toolkit
from core.agentscope.utils import (
finalize_json_response,
patch_agentscope_json_repair_compat,
)
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.llm import Llm
from models.system_agents import SystemAgents
from schemas.agent.runtime_models import (
@@ -33,7 +31,6 @@ from schemas.agent.runtime_models import (
resolve_worker_output_model,
)
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
from schemas.user import UserContext
from services.litellm.service import LiteLLMService
from sqlalchemy import select
@@ -59,246 +56,9 @@ class StageExecutionResult:
response_metadata: dict[str, Any]
class _TrackingChatModel:
    """Proxy around a chat model that tallies usage across every call.

    The wrapper is awaitable like the wrapped model. After each call it adds
    the reported input tokens, output tokens, latency and cached prompt
    tokens to running totals, available through :meth:`usage_summary`.
    Attributes not defined here are forwarded to the wrapped model.
    """

    def __init__(self, inner: OpenAIChatModel) -> None:
        """Wrap ``inner`` and start every usage counter at zero."""
        self._inner = inner
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_latency_ms = 0
        self._cached_prompt_tokens = 0

    @property
    def stream(self) -> bool:
        """Mirror of the wrapped model's ``stream`` flag."""
        return self._inner.stream

    @stream.setter
    def stream(self, value: bool) -> None:
        self._inner.stream = value

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute lookups to the wrapped model.

        Raises:
            AttributeError: if the wrapper has no ``_inner`` yet, or the
                wrapped model does not have ``name``.
        """
        # __getattr__ only runs after normal lookup failed. Reading
        # ``self._inner`` here would re-enter this method forever when the
        # instance was created without __init__ (copy/pickle do that), so
        # bypass the hook and fail with a plain AttributeError instead.
        try:
            inner = object.__getattribute__(self, "_inner")
        except AttributeError:
            raise AttributeError(name) from None
        return getattr(inner, name)

    @staticmethod
    def _summarize_tools(tools: Any) -> tuple[list[str], dict[str, Any] | None]:
        """Return the tool names in ``tools`` and a digest of ``generate_response``.

        Entries that are not dicts, have no ``function`` dict, or whose name
        is not a string are ignored. The digest holds the ``required`` value
        and the property names of the ``generate_response`` parameters; it is
        ``None`` when no such tool (with dict parameters) is present.
        """
        names: list[str] = []
        schema: dict[str, Any] | None = None
        if not isinstance(tools, list):
            return names, schema
        for tool in tools:
            function = tool.get("function") if isinstance(tool, dict) else None
            if not isinstance(function, dict):
                continue
            name = function.get("name")
            if not isinstance(name, str):
                continue
            names.append(name)
            parameters = function.get("parameters")
            if name == "generate_response" and isinstance(parameters, dict):
                properties = parameters.get("properties", {})
                schema = {
                    "required": parameters.get("required"),
                    "properties": (
                        list(properties) if isinstance(properties, dict) else []
                    ),
                }
        return names, schema

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped model, log the offered tools and record usage.

        Returns:
            The wrapped model's response. When that response is an async
            generator, a tracking generator yielding the same chunks is
            returned instead.
        """
        tool_names, generate_response_schema = self._summarize_tools(
            kwargs.get("tools")
        )
        logger.info(
            "model_call_debug",
            tool_choice=kwargs.get("tool_choice"),
            tool_count=len(tool_names),
            tool_names=tool_names,
            generate_response_schema=generate_response_schema,
        )
        response = await self._inner(*args, **kwargs)
        if isinstance(response, AsyncGenerator):
            return self._track_stream(response)
        self._record_usage(getattr(response, "usage", None))
        return response

    async def _track_stream(
        self, response: AsyncGenerator[Any, None]
    ) -> AsyncGenerator[Any, None]:
        """Re-yield every chunk, then record the last ``usage`` seen.

        Usage is recorded only once the stream has been fully consumed.
        """
        latest_usage = None
        async for chunk in response:
            usage = getattr(chunk, "usage", None)
            if usage is not None:
                latest_usage = usage
            yield chunk
        self._record_usage(latest_usage)

    def _record_usage(self, usage: Any) -> None:
        """Add one usage report to the running totals; ``None`` is ignored.

        Missing or falsy fields count as zero and negative values are clamped
        to zero. ``usage.time`` is multiplied by 1000 and rounded before being
        added to the latency total.
        """
        if usage is None:
            return
        self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
        self._total_output_tokens += max(
            int(getattr(usage, "output_tokens", 0) or 0), 0
        )
        self._total_latency_ms += max(
            int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
        )
        metadata = getattr(usage, "metadata", None)
        if metadata is None:
            return
        cached_tokens = 0
        if isinstance(metadata, dict):
            # Dict metadata: the details entry must itself be a dict to count.
            prompt_details = metadata.get("prompt_tokens_details")
            if isinstance(prompt_details, dict):
                cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0)
        else:
            # Object metadata: read the same fields as attributes.
            prompt_details = getattr(metadata, "prompt_tokens_details", None)
            cached_tokens = int(getattr(prompt_details, "cached_tokens", 0) or 0)
        self._cached_prompt_tokens += max(cached_tokens, 0)

    def usage_summary(self) -> dict[str, int]:
        """Return the accumulated totals as a plain dict."""
        return {
            "input_tokens": self._total_input_tokens,
            "output_tokens": self._total_output_tokens,
            "latency_ms": self._total_latency_ms,
            "cached_prompt_tokens": self._cached_prompt_tokens,
        }
class _PipelineStageEmitter:
    """Turn agent print callbacks for one stage into pipeline events.

    Tool-use and tool-result blocks found on a message are emitted once per
    tool call id. Message text is not emitted as it arrives; only the latest
    non-empty text and its message id are remembered, and
    :meth:`emit_final_text_end` sends the closing ``TEXT_MESSAGE_END`` event.
    """

    def __init__(
        self,
        *,
        pipeline: PipelineLike,
        session_id: str,
        run_id: str,
        stage: str,
        emit_text_events: bool,
        emit_tool_events: bool,
    ) -> None:
        """Store the event target, run identifiers and the two feature switches."""
        self._pipeline = pipeline
        self._session_id, self._run_id, self._stage = session_id, run_id, stage
        self._emit_text_events = emit_text_events
        self._emit_tool_events = emit_tool_events
        # Per-message text plus the ids already reported, so repeated print
        # callbacks for the same message do not produce duplicate events.
        self._text_by_message_id: dict[str, str] = {}
        self._emitted_tool_calls: set[str] = set()
        self._emitted_tool_results: set[str] = set()
        self.latest_text_message_id: str | None = None
        self.latest_text: str = ""

    async def handle_print(self, *, msg: Msg, last: bool) -> None:
        """Process one printed message; tool handling runs before text handling."""
        del last  # accepted for the callback signature, not used
        if self._emit_tool_events:
            await self._emit_tool_events_from_msg(msg)
        if self._emit_text_events:
            await self._emit_text_events_from_msg(msg)

    async def _emit_text_events_from_msg(self, msg: Msg) -> None:
        """Remember the message text; messages without text are ignored."""
        content = msg.get_text_content(separator="")
        if not content:
            return
        key = str(msg.id)
        self._text_by_message_id[key] = content
        self.latest_text_message_id, self.latest_text = key, content

    async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
        """Emit events for new tool calls first, then for new tool results."""
        await self._emit_tool_calls(msg)
        await self._emit_tool_results(msg)

    async def _emit_tool_calls(self, msg: Msg) -> None:
        """Emit START / ARGS / END for each tool_use block not reported yet."""
        for use_block in msg.get_content_blocks("tool_use"):
            call_id = str(use_block.get("id", "")).strip()
            call_name = str(use_block.get("name", "")).strip()
            if not (call_id and call_name):
                continue
            if call_id in self._emitted_tool_calls:
                continue
            base = {
                "messageId": str(msg.id),
                "toolCallId": call_id,
                "toolCallName": call_name,
                "stage": self._stage,
            }
            await self._emit("TOOL_CALL_START", base)
            with_args = dict(base, args=use_block.get("input", {}))
            await self._emit("TOOL_CALL_ARGS", with_args)
            await self._emit("TOOL_CALL_END", base)
            self._emitted_tool_calls.add(call_id)

    async def _emit_tool_results(self, msg: Msg) -> None:
        """Emit TOOL_CALL_RESULT for each parsable tool_result block not reported yet."""
        for result_block in msg.get_content_blocks("tool_result"):
            call_id = str(result_block.get("id", "")).strip()
            if not call_id:
                continue
            if call_id in self._emitted_tool_results:
                continue
            parsed = parse_tool_agent_output(result_block.get("output"))
            if parsed is None:
                # Unparsable output: skipped without marking the id as reported.
                continue
            dumped = parsed.model_dump(mode="json", exclude_none=True)
            event_body = {
                "messageId": str(msg.id),
                "role": "tool",
                "stage": self._stage,
                "tool_name": parsed.tool_name,
                "tool_call_id": parsed.tool_call_id,
                "tool_call_args": parsed.tool_call_args,
                "status": parsed.status.value,
                "result_summary": parsed.result_summary,
            }
            hints = dumped.get("ui_hints")
            if hints is not None:
                event_body["ui_hints"] = hints
            if parsed.error:
                event_body["error"] = parsed.error.model_dump(mode="json")
            await self._emit("TOOL_CALL_RESULT", event_body)
            self._emitted_tool_results.add(call_id)

    async def emit_final_text_end(
        self,
        *,
        worker_output: dict[str, Any],
        response_metadata: dict[str, Any],
    ) -> None:
        """Emit ``TEXT_MESSAGE_END`` carrying the worker output and metadata.

        The event reuses the id of the latest text message; when there is
        none, an id of the form ``worker-<run_id>-<8 hex chars>`` is made up.
        Keys in ``response_metadata`` override same-named output fields.
        """
        final_id = self.latest_text_message_id
        if not final_id:
            final_id = f"worker-{self._run_id}-{uuid4().hex[:8]}"
        body = {
            "messageId": final_id,
            "role": "assistant",
            "stage": self._stage,
            "status": worker_output.get("status"),
            "answer": worker_output.get("answer", ""),
            "key_points": worker_output.get("key_points", []),
            "result_type": worker_output.get("result_type"),
            "suggested_actions": worker_output.get("suggested_actions", []),
            "error": worker_output.get("error"),
        }
        hints = worker_output.get("ui_hints")
        if hints is not None:
            body["ui_hints"] = hints
        await self._emit("TEXT_MESSAGE_END", {**body, **response_metadata})

    async def _emit(self, event_type: str, payload: dict[str, Any]) -> None:
        """Send one event to the pipeline, tagged with the thread and run ids."""
        envelope = {
            "type": event_type,
            "threadId": self._session_id,
            "runId": self._run_id,
            **payload,
        }
        await self._pipeline.emit(session_id=self._session_id, event=envelope)
class AgentScopeRunner:
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
    """Set up the runner.

    Calls ``patch_agentscope_json_repair_compat()`` and then stores the
    LiteLLM service to use.

    Args:
        litellm_service: Service instance to reuse. When it is falsy
            (including the default ``None``) a new ``LiteLLMService`` is
            created.
    """
    patch_agentscope_json_repair_compat()
    if not litellm_service:
        litellm_service = LiteLLMService()
    self._litellm_service = litellm_service
async def execute(
@@ -310,76 +70,30 @@ class AgentScopeRunner:
run_input: RunAgentInput,
) -> dict[str, Any]:
owner_id = UUID(user_context.id)
enabled_tool_names = self._extract_tool_names(run_input)
async with AsyncSessionLocal() as session:
router_toolkit, worker_toolkit = self._build_toolkits(
session=session,
owner_id=owner_id,
enabled_tool_names=enabled_tool_names,
worker_toolkit = self._build_worker_toolkit(
session=session, owner_id=owner_id
)
router_config, worker_config = await self._load_stage_configs(
session=session
)
router_config = await self._load_system_agent_config(
router_output = await self._execute_router_step(
session=session,
agent_type=AgentType.ROUTER,
)
worker_config = await self._load_system_agent_config(
session=session,
agent_type=AgentType.WORKER,
)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="router",
event_type="STEP_STARTED",
)
router_result = await self._run_router_stage(
user_context=user_context,
context_messages=context_messages,
toolkit=router_toolkit,
run_input=run_input,
stage_config=router_config,
)
router_output = RouterAgentOutput.model_validate(router_result.payload)
await self._persist_router_message(
session=session,
thread_id=run_input.thread_id,
run_id=run_input.run_id,
model_code=router_config.model_code,
router_output=router_output,
response_metadata=router_result.response_metadata,
)
await session.commit()
await self._emit_step_event(
worker_output = await self._execute_worker_step(
pipeline=pipeline,
run_input=run_input,
step_name="router",
event_type="STEP_FINISHED",
)
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="worker",
event_type="STEP_STARTED",
)
worker_result = await self._run_worker_stage(
user_context=user_context,
router_output=router_output,
toolkit=worker_toolkit,
run_input=run_input,
stage_config=worker_config,
worker_output_model=worker_output_model,
pipeline=pipeline,
)
worker_output = worker_output_model.model_validate(worker_result.payload)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="worker",
event_type="STEP_FINISHED",
)
return {
@@ -387,40 +101,107 @@ class AgentScopeRunner:
"worker": worker_output.model_dump(mode="json", exclude_none=True),
}
def _build_toolkits(
def _build_worker_toolkit(
self,
*,
session: AsyncSession,
owner_id: UUID,
enabled_tool_names: set[str] | None,
) -> tuple[Any, Any]:
return (
build_toolkit(
session=session,
owner_id=owner_id,
enabled_tool_names=set(),
),
build_stage_toolkit(
agent_type=AgentType.WORKER,
session=session,
owner_id=owner_id,
enabled_tool_names=enabled_tool_names,
),
) -> Any:
return build_stage_toolkit(
agent_type=AgentType.WORKER,
session=session,
owner_id=owner_id,
)
def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None:
    """Collect the normalized names of the tools listed on ``run_input``.

    Entries may be dicts (``name`` key) or objects (``name`` attribute).
    Names that are not strings, or are blank after stripping, are skipped.

    Returns:
        ``None`` when ``run_input.tools`` is missing or not a list; otherwise
        the (possibly empty) set of ``normalize_tool_name`` results.
    """
    tools = getattr(run_input, "tools", None)
    if not isinstance(tools, list):
        return None

    def _name_of(entry: Any) -> Any:
        # Dict entries are read by key, everything else by attribute.
        if isinstance(entry, dict):
            return entry.get("name")
        return getattr(entry, "name", None)

    candidates = (_name_of(entry) for entry in tools)
    return {
        normalize_tool_name(candidate)
        for candidate in candidates
        if isinstance(candidate, str) and candidate.strip()
    }
async def _load_stage_configs(
    self,
    *,
    session: AsyncSession,
) -> tuple[SystemAgentRuntimeConfig, SystemAgentRuntimeConfig]:
    """Load the runtime configs of the router and worker system agents.

    The two configs are loaded one after the other through
    ``_load_system_agent_config``, router first.

    Returns:
        ``(router_config, worker_config)``.
    """
    loaded = []
    for agent_type in (AgentType.ROUTER, AgentType.WORKER):
        stage_config = await self._load_system_agent_config(
            session=session,
            agent_type=agent_type,
        )
        loaded.append(stage_config)
    router_config, worker_config = loaded
    return router_config, worker_config
async def _execute_router_step(
    self,
    *,
    session: AsyncSession,
    pipeline: PipelineLike,
    run_input: RunAgentInput,
    user_context: UserContext,
    context_messages: list[Msg],
    stage_config: SystemAgentRuntimeConfig,
) -> RouterAgentOutput:
    """Run the router stage between its STEP_STARTED / STEP_FINISHED events.

    In order: emit ``STEP_STARTED``, run the router stage, validate its
    payload as ``RouterAgentOutput``, hand the result to
    ``persist_router_message``, commit ``session``, emit ``STEP_FINISHED``.

    Returns:
        The validated router output.
    """
    step_event = {
        "pipeline": pipeline,
        "run_input": run_input,
        "step_name": "router",
    }
    await self._emit_step_event(**step_event, event_type="STEP_STARTED")

    stage_result = await self._run_router_stage(
        user_context=user_context,
        context_messages=context_messages,
        run_input=run_input,
        stage_config=stage_config,
    )
    validated = RouterAgentOutput.model_validate(stage_result.payload)

    # Store the routing decision and commit before announcing the step as done.
    await persist_router_message(
        session=session,
        thread_id=run_input.thread_id,
        run_id=run_input.run_id,
        model_code=stage_config.model_code,
        router_output=validated,
        response_metadata=stage_result.response_metadata,
    )
    await session.commit()

    await self._emit_step_event(**step_event, event_type="STEP_FINISHED")
    return validated
async def _execute_worker_step(
    self,
    *,
    pipeline: PipelineLike,
    run_input: RunAgentInput,
    user_context: UserContext,
    router_output: RouterAgentOutput,
    toolkit: Any,
    stage_config: SystemAgentRuntimeConfig,
) -> WorkerAgentOutputLite:
    """Run the worker stage between its STEP_STARTED / STEP_FINISHED events.

    The output model is chosen from ``router_output.ui.ui_mode`` via
    ``resolve_worker_output_model`` before any event is emitted; the worker
    payload is then validated against that model.

    Returns:
        The validated worker output.
    """
    output_model = resolve_worker_output_model(router_output.ui.ui_mode)
    step_event = {
        "pipeline": pipeline,
        "run_input": run_input,
        "step_name": "worker",
    }
    await self._emit_step_event(**step_event, event_type="STEP_STARTED")

    stage_result = await self._run_worker_stage(
        user_context=user_context,
        router_output=router_output,
        toolkit=toolkit,
        run_input=run_input,
        stage_config=stage_config,
        worker_output_model=output_model,
        pipeline=pipeline,
    )
    validated = output_model.model_validate(stage_result.payload)

    await self._emit_step_event(**step_event, event_type="STEP_FINISHED")
    return validated
async def _load_system_agent_config(
self,
@@ -451,7 +232,6 @@ class AgentScopeRunner:
*,
user_context: UserContext,
context_messages: list[Msg],
toolkit: Any,
run_input: RunAgentInput,
stage_config: SystemAgentRuntimeConfig,
) -> StageExecutionResult:
@@ -462,28 +242,26 @@ class AgentScopeRunner:
now_utc=datetime.now(timezone.utc),
tools=None,
)
agent = self._build_agent(
agent_name="router",
system_prompt=system_prompt,
toolkit=toolkit,
response, payload = await finalize_json_response(
model=tracking_model,
)
response_msg = await agent.reply_json(
context_messages,
formatter=OpenAIChatFormatter(),
base_messages=[Msg("system", system_prompt, "system"), *context_messages],
output_model=RouterAgentOutput,
retries=0,
)
response_msg = Msg(
name="router",
role="assistant",
content=list(getattr(response, "content", [])),
metadata=payload,
)
logger.info(
"router_reply_received",
run_id=run_input.run_id,
thread_id=run_input.thread_id,
message_id=str(response_msg.id),
)
payload = RouterAgentOutput.model_validate(
response_msg.metadata or {}
).model_dump(
mode="json",
exclude_none=True,
)
return StageExecutionResult(
message=response_msg,
payload=payload,
@@ -504,11 +282,9 @@ class AgentScopeRunner:
worker_output_model: type[WorkerAgentOutputLite],
pipeline: PipelineLike,
) -> StageExecutionResult:
worker_input = self._build_worker_input_messages(
router_output=router_output,
)
worker_input = self._build_worker_input_messages(router_output=router_output)
tracking_model = self._build_model(stage_config=stage_config)
emitter = _PipelineStageEmitter(
emitter = PipelineStageEmitter(
pipeline=pipeline,
session_id=run_input.thread_id,
run_id=run_input.run_id,
@@ -522,15 +298,14 @@ class AgentScopeRunner:
agent_type=AgentType.WORKER,
user_context=user_context,
now_utc=datetime.now(timezone.utc),
tools=run_input.tools,
tools=None,
),
toolkit=toolkit,
model=tracking_model,
emitter=emitter,
)
response_msg = await agent.reply_json(
worker_input,
output_model=worker_output_model,
worker_input, output_model=worker_output_model
)
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
response_metadata = self._litellm_service.build_usage_metadata(
@@ -552,24 +327,17 @@ class AgentScopeRunner:
*,
router_output: RouterAgentOutput,
) -> list[Msg]:
routing_contract = json.dumps(
router_output.model_dump(mode="json", exclude_none=True),
ensure_ascii=False,
separators=(",", ":"),
)
routing_msg = Msg(
name="router",
role="user",
content=(
"Use the following routing contract as the execution source of truth. "
f"Do not change the routed objective:\n{routing_contract}"
),
)
return [routing_msg]
return [
Msg(
name="router",
role="user",
content=build_worker_contract_prompt(router_output=router_output),
)
]
def _build_model(
self, *, stage_config: SystemAgentRuntimeConfig
) -> _TrackingChatModel:
) -> TrackingChatModel:
generate_kwargs: dict[str, Any] = {
"temperature": stage_config.llm_config.temperature,
"max_tokens": stage_config.llm_config.max_tokens,
@@ -585,7 +353,7 @@ class AgentScopeRunner:
client_kwargs={"base_url": self._litellm_service.proxy_base_url},
generate_kwargs=generate_kwargs,
)
return _TrackingChatModel(model)
return TrackingChatModel(model)
def _build_agent(
self,
@@ -593,8 +361,8 @@ class AgentScopeRunner:
agent_name: str,
system_prompt: str,
toolkit: Any,
model: _TrackingChatModel,
emitter: _PipelineStageEmitter | None = None,
model: TrackingChatModel,
emitter: PipelineStageEmitter | None = None,
) -> JsonReActAgent:
return JsonReActAgent(
name=agent_name,
@@ -624,66 +392,5 @@ class AgentScopeRunner:
},
)
# Append the router's output to the chat session as an assistant message and
# update the session's runtime counters. The work is flushed to the database
# session but not committed here.
# Raises ValueError if thread_id is not a valid UUID string, and RuntimeError
# if no chat session is returned for it.
async def _persist_router_message(
self,
*,
session: AsyncSession,
thread_id: str,
run_id: str,
model_code: str,
router_output: RouterAgentOutput,
response_metadata: dict[str, Any],
) -> None:
# The thread id doubles as the chat session id.
session_id = UUID(thread_id)
message_repo = MessageRepository(session)
session_repo = SessionRepository(session)
# Fetch the session through lock_session_for_update before computing the
# next sequence number (presumably a row lock - confirm in SessionRepository).
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
)
if locked_session is None:
raise RuntimeError("chat session not found for router persistence")
# Next sequence number: current message_count (missing/falsy counts as 0) plus one.
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
metadata = AgentChatMessageMetadata(
run_id=run_id,
agent_type=AgentType.ROUTER,
router_agent_output=router_output,
)
# Build the message object first so the numeric fields taken from
# response_metadata are converted once and reused by both calls below.
message_payload = AgentChatMessage(
id=uuid4(),
seq=seq,
role=AgentChatMessageRole.ASSISTANT.value,
# The router message carries no text; its output travels in metadata.
content="",
model_code=model_code,
tool_name=None,
input_tokens=int(response_metadata.get("inputTokens", 0) or 0),
output_tokens=int(response_metadata.get("outputTokens", 0) or 0),
# Convert via str so a float cost does not carry binary rounding noise into Decimal.
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
latency_ms=int(response_metadata.get("latencyMs", 0) or 0),
metadata=metadata,
timestamp=datetime.now(timezone.utc),
)
await message_repo.append_message(
session_id=session_id,
seq=message_payload.seq,
role=AgentChatMessageRole.ASSISTANT,
content=message_payload.content,
model_code=message_payload.model_code,
tool_name=message_payload.tool_name,
metadata=metadata.model_dump(mode="json", exclude_none=True),
input_tokens=message_payload.input_tokens,
output_tokens=message_payload.output_tokens,
cost=message_payload.cost,
latency_ms=message_payload.latency_ms,
)
# Mark the session as RUNNING and pass the deltas for this one message:
# one message, its input + output tokens, and its cost.
await session_repo.update_runtime_state(
chat_session=locked_session,
status=AgentChatSessionStatus.RUNNING,
state_snapshot=locked_session.state_snapshot or {},
message_delta=1,
token_delta=message_payload.input_tokens + message_payload.output_tokens,
cost_delta=message_payload.cost,
)
# Flush only; committing is left to the caller.
await session.flush()
# Second module-level name bound to the same class as AgentScopeRunner.
# NOTE(review): presumably kept so code using the older name keeps working -
# confirm there are no remaining users before removing it.
AgentScopeReActRunner = AgentScopeRunner