690 lines
24 KiB
Python
690 lines
24 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
from collections.abc import AsyncGenerator
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from decimal import Decimal
|
||
|
|
from typing import TYPE_CHECKING, Any
|
||
|
|
from uuid import UUID, uuid4
|
||
|
|
|
||
|
|
from ag_ui.core.types import RunAgentInput
|
||
|
|
from agentscope.formatter import OpenAIChatFormatter
|
||
|
|
from agentscope.memory import InMemoryMemory
|
||
|
|
from agentscope.message import Msg
|
||
|
|
from agentscope.model import OpenAIChatModel
|
||
|
|
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||
|
|
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||
|
|
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||
|
|
from core.agentscope.tools.toolkit import build_stage_toolkit, build_toolkit
|
||
|
|
from core.agentscope.runtime.utils import (
|
||
|
|
normalize_tool_name,
|
||
|
|
parse_tool_agent_output,
|
||
|
|
)
|
||
|
|
from core.db.session import AsyncSessionLocal
|
||
|
|
from core.logging import get_logger
|
||
|
|
from models.agent_chat_message import AgentChatMessageRole
|
||
|
|
from models.agent_chat_session import AgentChatSessionStatus
|
||
|
|
from models.llm import Llm
|
||
|
|
from models.system_agents import SystemAgents
|
||
|
|
from schemas.agent.runtime_models import (
|
||
|
|
RouterAgentOutput,
|
||
|
|
WorkerAgentOutputLite,
|
||
|
|
resolve_worker_output_model,
|
||
|
|
)
|
||
|
|
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||
|
|
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
|
||
|
|
from schemas.user import UserContext
|
||
|
|
from services.litellm.service import LiteLLMService
|
||
|
|
from sqlalchemy import select
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from core.agentscope.runtime.orchestrator import PipelineLike
|
||
|
|
|
||
|
|
# Module-level logger; calls in this file pass an event name plus keyword fields.
logger = get_logger("core.agentscope.runtime.runner")
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class SystemAgentRuntimeConfig:
    """Immutable settings resolved for a single system-agent stage."""

    # Stage the settings belong to (compared against AgentType.ROUTER when
    # the chat model is built).
    agent_type: AgentType
    # Model identifier; used as the model name of the chat model and when
    # building usage metadata.
    model_code: str
    # Generation settings; temperature, max_tokens and timeout_seconds are
    # read from it when the chat model is built.
    llm_config: SystemAgentLLMConfig
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class StageExecutionResult:
    """Immutable outcome of running one agent stage."""

    # Message object returned by the agent's reply call.
    message: Msg
    # JSON-mode dump of the validated stage output model.
    payload: dict[str, Any]
    # Usage metadata produced by the LiteLLM service for this stage.
    response_metadata: dict[str, Any]
|
||
|
|
|
||
|
|
|
||
|
|
class _TrackingChatModel:
    """Proxy around a chat model that accumulates usage across calls.

    Calling the proxy calls the wrapped model and adds the reported input
    tokens, output tokens, latency and cached prompt tokens to running
    totals. Attributes not defined on the proxy are forwarded to the wrapped
    model; ``stream`` is proxied for both reading and writing.
    """

    def __init__(self, inner: OpenAIChatModel) -> None:
        """Keep a reference to ``inner`` and reset all counters to zero."""
        self._inner = inner
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_latency_ms = 0
        self._cached_prompt_tokens = 0

    @property
    def stream(self) -> bool:
        """Streaming flag of the wrapped model."""
        return self._inner.stream

    @stream.setter
    def stream(self, value: bool) -> None:
        self._inner.stream = value

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute reads to the wrapped model.

        Raises:
            AttributeError: if the wrapped model itself is not set yet, or
                if the wrapped model does not have ``name``.
        """
        # __getattr__ only runs after normal lookup failed. If ``_inner`` is
        # the missing name (instance built via __new__, copy or unpickle),
        # reading self._inner below would re-enter this method forever.
        if name == "_inner":
            raise AttributeError(name)
        return getattr(self._inner, name)

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Log the offered tools, call the wrapped model and record usage.

        Returns:
            The wrapped model's response. When that response is an async
            generator, a wrapping generator is returned instead that yields
            the same chunks and records usage once the stream is exhausted.
        """
        tool_names, generate_response_schema = self._summarize_tools(
            kwargs.get("tools")
        )
        logger.info(
            "model_call_debug",
            tool_choice=kwargs.get("tool_choice"),
            tool_count=len(tool_names),
            tool_names=tool_names,
            generate_response_schema=generate_response_schema,
        )
        response = await self._inner(*args, **kwargs)
        if isinstance(response, AsyncGenerator):
            return self._track_stream(response)
        self._record_usage(getattr(response, "usage", None))
        return response

    @staticmethod
    def _summarize_tools(tools: Any) -> tuple[list[str], dict[str, Any] | None]:
        """Collect tool names and a summary of the ``generate_response`` schema.

        Entries that are not dicts, or whose ``function`` / ``name`` fields
        have an unexpected type, are skipped. The schema summary holds the
        ``required`` value and the list of property names; it stays ``None``
        when no ``generate_response`` tool with dict parameters is present.
        """
        tool_names: list[str] = []
        schema_summary: dict[str, Any] | None = None
        if not isinstance(tools, list):
            return tool_names, schema_summary
        for tool in tools:
            if not isinstance(tool, dict):
                continue
            function = tool.get("function")
            if not isinstance(function, dict):
                continue
            name = function.get("name")
            if not isinstance(name, str):
                continue
            tool_names.append(name)
            if name != "generate_response":
                continue
            parameters = function.get("parameters")
            if not isinstance(parameters, dict):
                continue
            properties = parameters.get("properties", {})
            schema_summary = {
                "required": parameters.get("required"),
                "properties": (
                    list(properties) if isinstance(properties, dict) else []
                ),
            }
        return tool_names, schema_summary

    async def _track_stream(
        self, response: AsyncGenerator[Any, None]
    ) -> AsyncGenerator[Any, None]:
        """Re-yield every chunk and record the last non-None usage seen."""
        latest_usage = None
        async for chunk in response:
            usage = getattr(chunk, "usage", None)
            if usage is not None:
                latest_usage = usage
            yield chunk
        self._record_usage(latest_usage)

    @staticmethod
    def _read_field(container: Any, key: str) -> Any:
        """Read ``key`` from a dict or as an attribute; ``None`` when absent."""
        if isinstance(container, dict):
            return container.get(key)
        return getattr(container, key, None)

    def _record_usage(self, usage: Any) -> None:
        """Add one usage record to the running totals.

        ``usage`` is read by attribute (``input_tokens``, ``output_tokens``,
        ``time``, ``metadata``). Missing or falsy values count as zero and
        negative values are clamped to zero. ``time`` is multiplied by 1000
        and rounded before being added to the millisecond total.
        """
        if usage is None:
            return
        self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
        self._total_output_tokens += max(
            int(getattr(usage, "output_tokens", 0) or 0), 0
        )
        self._total_latency_ms += max(
            int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
        )
        metadata = getattr(usage, "metadata", None)
        if metadata is None:
            return
        # metadata and prompt_tokens_details may each be a dict or an object;
        # read both levels the same way so mixed shapes are not counted as 0.
        prompt_details = self._read_field(metadata, "prompt_tokens_details")
        cached_tokens = int(self._read_field(prompt_details, "cached_tokens") or 0)
        self._cached_prompt_tokens += max(cached_tokens, 0)

    def usage_summary(self) -> dict[str, int]:
        """Return the accumulated totals as a plain dict."""
        return {
            "input_tokens": self._total_input_tokens,
            "output_tokens": self._total_output_tokens,
            "latency_ms": self._total_latency_ms,
            "cached_prompt_tokens": self._cached_prompt_tokens,
        }
|
||
|
|
|
||
|
|
|
||
|
|
class _PipelineStageEmitter:
    """Turn the messages of one agent stage into events on a pipeline.

    Tool-use and tool-result blocks are emitted at most once per tool call
    id. Text is not emitted while messages arrive; only the most recent text
    and its message id are remembered, and a single ``TEXT_MESSAGE_END``
    event is sent by :meth:`emit_final_text_end`.
    """

    def __init__(
        self,
        *,
        pipeline: PipelineLike,
        session_id: str,
        run_id: str,
        stage: str,
        emit_text_events: bool,
        emit_tool_events: bool,
    ) -> None:
        """Store the event target, identifiers and feature switches."""
        # Where events go and how they are labelled.
        self._pipeline = pipeline
        self._session_id = session_id
        self._run_id = run_id
        self._stage = stage
        # Feature switches consulted by handle_print.
        self._emit_text_events = emit_text_events
        self._emit_tool_events = emit_tool_events
        # Bookkeeping: last text per message id and already-announced ids.
        self._text_by_message_id: dict[str, str] = {}
        self._emitted_tool_calls: set[str] = set()
        self._emitted_tool_results: set[str] = set()
        self.latest_text_message_id: str | None = None
        self.latest_text: str = ""

    async def handle_print(self, *, msg: Msg, last: bool) -> None:
        """Process one message; tool events first, then text bookkeeping.

        ``last`` is accepted but not used.
        """
        del last
        if self._emit_tool_events:
            await self._emit_tool_events_from_msg(msg)
        if self._emit_text_events:
            await self._emit_text_events_from_msg(msg)

    async def _emit_text_events_from_msg(self, msg: Msg) -> None:
        """Remember the message text; nothing is emitted from here."""
        content = msg.get_text_content(separator="") or ""
        if content:
            key = str(msg.id)
            self._text_by_message_id[key] = content
            self.latest_text_message_id = key
            self.latest_text = content

    async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
        """Emit events for new tool calls, then for new tool results."""
        await self._announce_tool_calls(msg)
        await self._announce_tool_results(msg)

    async def _announce_tool_calls(self, msg: Msg) -> None:
        """Emit START / ARGS / END for each tool_use block not seen before."""
        for use_block in msg.get_content_blocks("tool_use"):
            call_id = str(use_block.get("id", "")).strip()
            call_name = str(use_block.get("name", "")).strip()
            is_usable = bool(call_id) and bool(call_name)
            if not is_usable or call_id in self._emitted_tool_calls:
                continue
            base = {
                "messageId": str(msg.id),
                "toolCallId": call_id,
                "toolCallName": call_name,
                "stage": self._stage,
            }
            with_args = {**base, "args": use_block.get("input", {})}
            await self._emit("TOOL_CALL_START", base)
            await self._emit("TOOL_CALL_ARGS", with_args)
            await self._emit("TOOL_CALL_END", base)
            # Marked only after all three events were sent.
            self._emitted_tool_calls.add(call_id)

    async def _announce_tool_results(self, msg: Msg) -> None:
        """Emit TOOL_CALL_RESULT for each parsable tool_result block not seen before."""
        for result_block in msg.get_content_blocks("tool_result"):
            call_id = str(result_block.get("id", "")).strip()
            if not call_id or call_id in self._emitted_tool_results:
                continue
            parsed = parse_tool_agent_output(result_block.get("output"))
            if parsed is None:
                # Unparsable output: skipped and not marked as emitted.
                continue

            dumped = parsed.model_dump(mode="json", exclude_none=True)

            event_body = {
                "messageId": str(msg.id),
                "role": "tool",
                "stage": self._stage,
                "tool_name": parsed.tool_name,
                "tool_call_id": parsed.tool_call_id,
                "tool_call_args": parsed.tool_call_args,
                "status": parsed.status.value,
                "result_summary": parsed.result_summary,
            }
            hints = dumped.get("ui_hints")
            if hints is not None:
                event_body["ui_hints"] = hints
            if parsed.error:
                event_body["error"] = parsed.error.model_dump(mode="json")

            await self._emit("TOOL_CALL_RESULT", event_body)
            self._emitted_tool_results.add(call_id)

    async def emit_final_text_end(
        self,
        *,
        worker_output: dict[str, Any],
        response_metadata: dict[str, Any],
    ) -> None:
        """Emit ``TEXT_MESSAGE_END`` built from the worker output.

        The event reuses the id of the last text message seen; when there is
        none, an id is generated from the run id and a random suffix. Entries
        of ``response_metadata`` are merged last and therefore win over
        same-named keys taken from ``worker_output``.
        """
        fallback_free_id = self.latest_text_message_id
        message_id = fallback_free_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"

        read = worker_output.get
        output_data = {
            "messageId": message_id,
            "role": "assistant",
            "stage": self._stage,
            "status": read("status"),
            "answer": read("answer", ""),
            "key_points": read("key_points", []),
            "result_type": read("result_type"),
            "suggested_actions": read("suggested_actions", []),
            "error": read("error"),
        }
        hints = read("ui_hints")
        if hints is not None:
            output_data["ui_hints"] = hints

        output_data.update(response_metadata)

        await self._emit("TEXT_MESSAGE_END", output_data)

    async def _emit(self, event_type: str, payload: dict[str, Any]) -> None:
        """Send one event; payload keys are merged after the envelope keys."""
        envelope = {
            "type": event_type,
            "threadId": self._session_id,
            "runId": self._run_id,
            **payload,
        }
        await self._pipeline.emit(session_id=self._session_id, event=envelope)
|
||
|
|
|
||
|
|
|
||
|
|
class AgentScopeRunner:
    """Run a router stage followed by a worker stage for one request."""

    def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
        """Use the given LiteLLM service, or create one when none is passed."""
        self._litellm_service = litellm_service or LiteLLMService()

    async def execute(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        pipeline: PipelineLike,
        run_input: RunAgentInput,
    ) -> dict[str, Any]:
        """Run both stages and return their outputs.

        The router output is persisted and committed before the worker stage
        starts. ``STEP_STARTED`` / ``STEP_FINISHED`` events are emitted
        around each stage.

        Returns:
            A dict with keys ``"router"`` and ``"worker"`` holding the
            JSON-mode dumps (``None`` fields excluded) of the stage outputs.
        """

        async def announce(step_name: str, event_type: str) -> None:
            await self._emit_step_event(
                pipeline=pipeline,
                run_input=run_input,
                step_name=step_name,
                event_type=event_type,
            )

        owner_id = UUID(user_context.id)
        enabled_tool_names = self._extract_tool_names(run_input)

        async with AsyncSessionLocal() as session:
            router_toolkit, worker_toolkit = self._build_toolkits(
                session=session,
                owner_id=owner_id,
                enabled_tool_names=enabled_tool_names,
            )

            router_config = await self._load_system_agent_config(
                session=session, agent_type=AgentType.ROUTER
            )
            worker_config = await self._load_system_agent_config(
                session=session, agent_type=AgentType.WORKER
            )

            # ---- router stage -------------------------------------------
            await announce("router", "STEP_STARTED")
            router_result = await self._run_router_stage(
                user_context=user_context,
                context_messages=context_messages,
                toolkit=router_toolkit,
                run_input=run_input,
                stage_config=router_config,
            )
            router_output = RouterAgentOutput.model_validate(router_result.payload)
            await self._persist_router_message(
                session=session,
                thread_id=run_input.thread_id,
                run_id=run_input.run_id,
                model_code=router_config.model_code,
                router_output=router_output,
                response_metadata=router_result.response_metadata,
            )
            await session.commit()
            await announce("router", "STEP_FINISHED")

            # ---- worker stage -------------------------------------------
            # The worker's output model depends on the UI mode the router chose.
            worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
            await announce("worker", "STEP_STARTED")
            worker_result = await self._run_worker_stage(
                user_context=user_context,
                router_output=router_output,
                toolkit=worker_toolkit,
                run_input=run_input,
                stage_config=worker_config,
                worker_output_model=worker_output_model,
                pipeline=pipeline,
            )
            worker_output = worker_output_model.model_validate(worker_result.payload)
            await announce("worker", "STEP_FINISHED")

            return {
                "router": router_output.model_dump(mode="json", exclude_none=True),
                "worker": worker_output.model_dump(mode="json", exclude_none=True),
            }

    def _build_toolkits(
        self,
        *,
        session: AsyncSession,
        owner_id: UUID,
        enabled_tool_names: set[str] | None,
    ) -> tuple[Any, Any]:
        """Return ``(router_toolkit, worker_toolkit)``.

        The router toolkit is always built with an empty set of enabled tool
        names; the worker toolkit receives ``enabled_tool_names`` as given.
        """
        for_router = build_toolkit(
            session=session,
            owner_id=owner_id,
            enabled_tool_names=set(),
        )
        for_worker = build_stage_toolkit(
            agent_type=AgentType.WORKER,
            session=session,
            owner_id=owner_id,
            enabled_tool_names=enabled_tool_names,
        )
        return for_router, for_worker

    def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None:
        """Collect normalized tool names from ``run_input.tools``.

        Returns ``None`` when ``tools`` is missing or not a list. Entries may
        be dicts or objects with a ``name``; names that are not strings or
        are blank after stripping are ignored.
        """
        declared = getattr(run_input, "tools", None)
        if not isinstance(declared, list):
            return None
        raw_names = (
            entry.get("name") if isinstance(entry, dict) else getattr(entry, "name", None)
            for entry in declared
        )
        return {
            normalize_tool_name(candidate)
            for candidate in raw_names
            if isinstance(candidate, str) and candidate.strip()
        }

    async def _load_system_agent_config(
        self,
        *,
        session: AsyncSession,
        agent_type: AgentType,
    ) -> SystemAgentRuntimeConfig:
        """Load the system agent row of ``agent_type`` joined with its LLM row.

        Raises:
            RuntimeError: when no row exists, or when the row's status is not
                ``"active"`` (compared after ``str()``, strip and lower-case).
        """
        query = (
            select(SystemAgents, Llm)
            .join(Llm, SystemAgents.llm_id == Llm.id)
            .where(SystemAgents.agent_type == agent_type.value)
        )
        result = await session.execute(query)
        found = result.one_or_none()
        if found is None:
            raise RuntimeError(f"system agent config not found: {agent_type.value}")
        agent_row, llm_row = found
        if str(agent_row.status).strip().lower() != "active":
            raise RuntimeError(f"system agent is not active: {agent_type.value}")
        return SystemAgentRuntimeConfig(
            agent_type=agent_type,
            model_code=llm_row.model_code,
            llm_config=SystemAgentLLMConfig.model_validate(agent_row.config or {}),
        )

    async def _run_router_stage(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        toolkit: Any,
        run_input: RunAgentInput,
        stage_config: SystemAgentRuntimeConfig,
    ) -> StageExecutionResult:
        """Run the router agent over ``context_messages``.

        The reply's metadata is validated as ``RouterAgentOutput`` and dumped
        to the result payload; usage metadata comes from the tracking model's
        totals. No emitter is attached to the router agent.
        """
        tracking_model = self._build_model(stage_config=stage_config)
        prompt = build_system_prompt(
            agent_type=AgentType.ROUTER,
            user_context=user_context,
            now_utc=datetime.now(timezone.utc),
            tools=None,
        )
        router_agent = self._build_agent(
            agent_name="router",
            system_prompt=prompt,
            toolkit=toolkit,
            model=tracking_model,
        )
        reply = await router_agent.reply_json(
            context_messages,
            output_model=RouterAgentOutput,
        )
        logger.info(
            "router_reply_received",
            run_id=run_input.run_id,
            thread_id=run_input.thread_id,
            message_id=str(reply.id),
        )
        validated = RouterAgentOutput.model_validate(reply.metadata or {})
        payload = validated.model_dump(mode="json", exclude_none=True)
        usage_metadata = self._litellm_service.build_usage_metadata(
            model=stage_config.model_code,
            usage_summary=tracking_model.usage_summary(),
        )
        return StageExecutionResult(
            message=reply,
            payload=payload,
            response_metadata=usage_metadata,
        )

    async def _run_worker_stage(
        self,
        *,
        user_context: UserContext,
        router_output: RouterAgentOutput,
        toolkit: Any,
        run_input: RunAgentInput,
        stage_config: SystemAgentRuntimeConfig,
        worker_output_model: type[WorkerAgentOutputLite],
        pipeline: PipelineLike,
    ) -> StageExecutionResult:
        """Run the worker agent on the routing contract and emit its events.

        The worker receives only the message built from ``router_output``.
        An emitter with tool and text handling enabled is attached, and a
        final ``TEXT_MESSAGE_END`` event is emitted before returning.
        """
        worker_input = self._build_worker_input_messages(router_output=router_output)
        tracking_model = self._build_model(stage_config=stage_config)
        emitter = _PipelineStageEmitter(
            pipeline=pipeline,
            session_id=run_input.thread_id,
            run_id=run_input.run_id,
            stage="worker",
            emit_text_events=True,
            emit_tool_events=True,
        )
        prompt = build_system_prompt(
            agent_type=AgentType.WORKER,
            user_context=user_context,
            now_utc=datetime.now(timezone.utc),
            tools=run_input.tools,
        )
        worker_agent = self._build_agent(
            agent_name="worker",
            system_prompt=prompt,
            toolkit=toolkit,
            model=tracking_model,
            emitter=emitter,
        )
        reply = await worker_agent.reply_json(
            worker_input,
            output_model=worker_output_model,
        )
        worker_payload = worker_output_model.model_validate(reply.metadata or {})

        def dump_payload() -> dict[str, Any]:
            # Called once per consumer so each gets its own dict.
            return worker_payload.model_dump(mode="json", exclude_none=True)

        response_metadata = self._litellm_service.build_usage_metadata(
            model=stage_config.model_code,
            usage_summary=tracking_model.usage_summary(),
        )
        await emitter.emit_final_text_end(
            worker_output=dump_payload(),
            response_metadata=response_metadata,
        )
        return StageExecutionResult(
            message=reply,
            payload=dump_payload(),
            response_metadata=response_metadata,
        )

    def _build_worker_input_messages(
        self,
        *,
        router_output: RouterAgentOutput,
    ) -> list[Msg]:
        """Wrap the compact JSON dump of ``router_output`` in one user message."""
        contract_fields = router_output.model_dump(mode="json", exclude_none=True)
        routing_contract = json.dumps(
            contract_fields,
            ensure_ascii=False,
            separators=(",", ":"),
        )
        instruction = (
            "Use the following routing contract as the execution source of truth. "
            f"Do not change the routed objective:\n{routing_contract}"
        )
        return [Msg(name="router", role="user", content=instruction)]

    def _build_model(
        self, *, stage_config: SystemAgentRuntimeConfig
    ) -> _TrackingChatModel:
        """Create a non-streaming chat model for the stage, wrapped for usage tracking.

        The API key and base URL are taken from the LiteLLM service's proxy
        settings. For the router stage, ``extra_body`` is set to
        ``{"enable_thinking": False}``.
        """
        settings = stage_config.llm_config
        generate_kwargs: dict[str, Any] = {
            "temperature": settings.temperature,
            "max_tokens": settings.max_tokens,
            "timeout": settings.timeout_seconds,
        }
        if stage_config.agent_type == AgentType.ROUTER:
            generate_kwargs["extra_body"] = {"enable_thinking": False}

        return _TrackingChatModel(
            OpenAIChatModel(
                model_name=stage_config.model_code,
                api_key=self._litellm_service.proxy_api_key,
                stream=False,
                client_kwargs={"base_url": self._litellm_service.proxy_base_url},
                generate_kwargs=generate_kwargs,
            )
        )

    def _build_agent(
        self,
        *,
        agent_name: str,
        system_prompt: str,
        toolkit: Any,
        model: _TrackingChatModel,
        emitter: _PipelineStageEmitter | None = None,
    ) -> JsonReActAgent:
        """Assemble a ``JsonReActAgent`` with a fresh formatter and in-memory memory."""
        agent_kwargs: dict[str, Any] = {
            "name": agent_name,
            "sys_prompt": system_prompt,
            "model": model,
            "formatter": OpenAIChatFormatter(),
            "toolkit": toolkit,
            "memory": InMemoryMemory(),
            "emitter": emitter,
        }
        return JsonReActAgent(**agent_kwargs)

    async def _emit_step_event(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        step_name: str,
        event_type: str,
    ) -> None:
        """Emit a step event carrying the thread id, run id and step name."""
        thread_id = run_input.thread_id
        step_event = {
            "type": event_type,
            "threadId": thread_id,
            "runId": run_input.run_id,
            "stepName": step_name,
        }
        await pipeline.emit(session_id=thread_id, event=step_event)

    async def _persist_router_message(
        self,
        *,
        session: AsyncSession,
        thread_id: str,
        run_id: str,
        model_code: str,
        router_output: RouterAgentOutput,
        response_metadata: dict[str, Any],
    ) -> None:
        """Append the router output as an assistant message and update the session.

        The chat session is locked for update, the message is appended at
        ``message_count + 1``, runtime counters are updated, and the DB
        session is flushed (the caller commits).

        Raises:
            RuntimeError: when the chat session cannot be found.
        """
        session_id = UUID(thread_id)
        message_repo = MessageRepository(session)
        session_repo = SessionRepository(session)

        locked_session = await session_repo.lock_session_for_update(
            session_id=session_id
        )
        if locked_session is None:
            raise RuntimeError("chat session not found for router persistence")

        def metric(key: str) -> int:
            # Missing or falsy metadata values count as zero.
            return int(response_metadata.get(key, 0) or 0)

        next_seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
        metadata = AgentChatMessageMetadata(
            run_id=run_id,
            agent_type=AgentType.ROUTER,
            router_agent_output=router_output,
        )
        # Building the schema object first validates the values before they
        # are handed to the repository.
        message_payload = AgentChatMessage(
            id=uuid4(),
            seq=next_seq,
            role=AgentChatMessageRole.ASSISTANT.value,
            content="",
            model_code=model_code,
            tool_name=None,
            input_tokens=metric("inputTokens"),
            output_tokens=metric("outputTokens"),
            cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
            latency_ms=metric("latencyMs"),
            metadata=metadata,
            timestamp=datetime.now(timezone.utc),
        )

        await message_repo.append_message(
            session_id=session_id,
            seq=message_payload.seq,
            role=AgentChatMessageRole.ASSISTANT,
            content=message_payload.content,
            model_code=message_payload.model_code,
            tool_name=message_payload.tool_name,
            metadata=metadata.model_dump(mode="json", exclude_none=True),
            input_tokens=message_payload.input_tokens,
            output_tokens=message_payload.output_tokens,
            cost=message_payload.cost,
            latency_ms=message_payload.latency_ms,
        )

        total_tokens = message_payload.input_tokens + message_payload.output_tokens
        await session_repo.update_runtime_state(
            chat_session=locked_session,
            status=AgentChatSessionStatus.RUNNING,
            state_snapshot=locked_session.state_snapshot or {},
            message_delta=1,
            token_delta=total_tokens,
            cost_delta=message_payload.cost,
        )
        await session.flush()
|
||
|
|
|
||
|
|
|
||
|
|
# Second name bound to the same class (presumably kept so older imports of
# AgentScopeReActRunner keep working -- confirm against callers).
AgentScopeReActRunner = AgentScopeRunner
|