chore: 清理废弃的生成文件和重命名文件
This commit is contained in:
@@ -1,201 +0,0 @@
|
|||||||
// GENERATED CODE - DO NOT MODIFY BY HAND
|
|
||||||
|
|
||||||
part of 'ag_ui_event.dart';
|
|
||||||
|
|
||||||
// **************************************************************************
|
|
||||||
// JsonSerializableGenerator
|
|
||||||
// **************************************************************************
|
|
||||||
|
|
||||||
Map<String, dynamic> _$AgUiEventToJson(AgUiEvent instance) => <String, dynamic>{
|
|
||||||
'type': _$AgUiEventTypeEnumMap[instance.type]!,
|
|
||||||
};
|
|
||||||
|
|
||||||
const _$AgUiEventTypeEnumMap = {
|
|
||||||
AgUiEventType.runStarted: 'runStarted',
|
|
||||||
AgUiEventType.runFinished: 'runFinished',
|
|
||||||
AgUiEventType.runError: 'runError',
|
|
||||||
AgUiEventType.stepStarted: 'stepStarted',
|
|
||||||
AgUiEventType.stepFinished: 'stepFinished',
|
|
||||||
AgUiEventType.textMessageStart: 'textMessageStart',
|
|
||||||
AgUiEventType.textMessageContent: 'textMessageContent',
|
|
||||||
AgUiEventType.textMessageEnd: 'textMessageEnd',
|
|
||||||
AgUiEventType.toolCallStart: 'toolCallStart',
|
|
||||||
AgUiEventType.toolCallArgs: 'toolCallArgs',
|
|
||||||
AgUiEventType.toolCallEnd: 'toolCallEnd',
|
|
||||||
AgUiEventType.toolCallResult: 'toolCallResult',
|
|
||||||
AgUiEventType.toolCallError: 'toolCallError',
|
|
||||||
AgUiEventType.messagesSnapshot: 'messagesSnapshot',
|
|
||||||
AgUiEventType.unknown: 'unknown',
|
|
||||||
};
|
|
||||||
|
|
||||||
RunStartedEvent _$RunStartedEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
RunStartedEvent(
|
|
||||||
threadId: json['threadId'] as String,
|
|
||||||
runId: json['runId'] as String,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$RunStartedEventToJson(RunStartedEvent instance) =>
|
|
||||||
<String, dynamic>{'threadId': instance.threadId, 'runId': instance.runId};
|
|
||||||
|
|
||||||
RunFinishedEvent _$RunFinishedEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
RunFinishedEvent(
|
|
||||||
threadId: json['threadId'] as String,
|
|
||||||
runId: json['runId'] as String,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$RunFinishedEventToJson(RunFinishedEvent instance) =>
|
|
||||||
<String, dynamic>{'threadId': instance.threadId, 'runId': instance.runId};
|
|
||||||
|
|
||||||
RunErrorEvent _$RunErrorEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
RunErrorEvent(
|
|
||||||
message: json['message'] as String,
|
|
||||||
code: json['code'] as String?,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$RunErrorEventToJson(RunErrorEvent instance) =>
|
|
||||||
<String, dynamic>{'message': instance.message, 'code': instance.code};
|
|
||||||
|
|
||||||
StepStartedEvent _$StepStartedEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
StepStartedEvent(stepName: json['stepName'] as String);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$StepStartedEventToJson(StepStartedEvent instance) =>
|
|
||||||
<String, dynamic>{'stepName': instance.stepName};
|
|
||||||
|
|
||||||
StepFinishedEvent _$StepFinishedEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
StepFinishedEvent(stepName: json['stepName'] as String);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$StepFinishedEventToJson(StepFinishedEvent instance) =>
|
|
||||||
<String, dynamic>{'stepName': instance.stepName};
|
|
||||||
|
|
||||||
TextMessageStartEvent _$TextMessageStartEventFromJson(
|
|
||||||
Map<String, dynamic> json,
|
|
||||||
) => TextMessageStartEvent(
|
|
||||||
messageId: json['messageId'] as String,
|
|
||||||
role: json['role'] as String,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$TextMessageStartEventToJson(
|
|
||||||
TextMessageStartEvent instance,
|
|
||||||
) => <String, dynamic>{'messageId': instance.messageId, 'role': instance.role};
|
|
||||||
|
|
||||||
TextMessageContentEvent _$TextMessageContentEventFromJson(
|
|
||||||
Map<String, dynamic> json,
|
|
||||||
) => TextMessageContentEvent(
|
|
||||||
messageId: json['messageId'] as String,
|
|
||||||
delta: json['delta'] as String,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$TextMessageContentEventToJson(
|
|
||||||
TextMessageContentEvent instance,
|
|
||||||
) => <String, dynamic>{
|
|
||||||
'messageId': instance.messageId,
|
|
||||||
'delta': instance.delta,
|
|
||||||
};
|
|
||||||
|
|
||||||
TextMessageEndEvent _$TextMessageEndEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
TextMessageEndEvent(messageId: json['messageId'] as String);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$TextMessageEndEventToJson(
|
|
||||||
TextMessageEndEvent instance,
|
|
||||||
) => <String, dynamic>{'messageId': instance.messageId};
|
|
||||||
|
|
||||||
ToolCallStartEvent _$ToolCallStartEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
ToolCallStartEvent(
|
|
||||||
toolCallId: json['toolCallId'] as String,
|
|
||||||
toolCallName: json['toolCallName'] as String,
|
|
||||||
parentMessageId: json['parentMessageId'] as String?,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$ToolCallStartEventToJson(ToolCallStartEvent instance) =>
|
|
||||||
<String, dynamic>{
|
|
||||||
'toolCallId': instance.toolCallId,
|
|
||||||
'toolCallName': instance.toolCallName,
|
|
||||||
'parentMessageId': instance.parentMessageId,
|
|
||||||
};
|
|
||||||
|
|
||||||
ToolCallArgsEvent _$ToolCallArgsEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
ToolCallArgsEvent(
|
|
||||||
toolCallId: json['toolCallId'] as String,
|
|
||||||
delta: json['delta'] as String,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$ToolCallArgsEventToJson(ToolCallArgsEvent instance) =>
|
|
||||||
<String, dynamic>{
|
|
||||||
'toolCallId': instance.toolCallId,
|
|
||||||
'delta': instance.delta,
|
|
||||||
};
|
|
||||||
|
|
||||||
ToolCallEndEvent _$ToolCallEndEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
ToolCallEndEvent(toolCallId: json['toolCallId'] as String);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$ToolCallEndEventToJson(ToolCallEndEvent instance) =>
|
|
||||||
<String, dynamic>{'toolCallId': instance.toolCallId};
|
|
||||||
|
|
||||||
ToolCallResultEvent _$ToolCallResultEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
ToolCallResultEvent(
|
|
||||||
messageId: json['messageId'] as String,
|
|
||||||
toolCallId: json['toolCallId'] as String,
|
|
||||||
content: json['content'] as String,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$ToolCallResultEventToJson(
|
|
||||||
ToolCallResultEvent instance,
|
|
||||||
) => <String, dynamic>{
|
|
||||||
'messageId': instance.messageId,
|
|
||||||
'toolCallId': instance.toolCallId,
|
|
||||||
'content': instance.content,
|
|
||||||
};
|
|
||||||
|
|
||||||
ToolCallErrorEvent _$ToolCallErrorEventFromJson(Map<String, dynamic> json) =>
|
|
||||||
ToolCallErrorEvent(
|
|
||||||
toolCallId: json['toolCallId'] as String,
|
|
||||||
error: json['error'] as String,
|
|
||||||
code: json['code'] as String?,
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$ToolCallErrorEventToJson(ToolCallErrorEvent instance) =>
|
|
||||||
<String, dynamic>{
|
|
||||||
'toolCallId': instance.toolCallId,
|
|
||||||
'error': instance.error,
|
|
||||||
'code': instance.code,
|
|
||||||
};
|
|
||||||
|
|
||||||
MessagesSnapshotEvent _$MessagesSnapshotEventFromJson(
|
|
||||||
Map<String, dynamic> json,
|
|
||||||
) => MessagesSnapshotEvent(
|
|
||||||
messages: (json['messages'] as List<dynamic>)
|
|
||||||
.map((e) => SnapshotMessage.fromJson(e as Map<String, dynamic>))
|
|
||||||
.toList(),
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$MessagesSnapshotEventToJson(
|
|
||||||
MessagesSnapshotEvent instance,
|
|
||||||
) => <String, dynamic>{'messages': instance.messages};
|
|
||||||
|
|
||||||
SnapshotMessage _$SnapshotMessageFromJson(Map<String, dynamic> json) =>
|
|
||||||
SnapshotMessage(
|
|
||||||
id: json['id'] as String,
|
|
||||||
role: json['role'] as String,
|
|
||||||
content: json['content'] as String?,
|
|
||||||
toolCallId: json['toolCallId'] as String?,
|
|
||||||
ui: json['ui'] == null
|
|
||||||
? null
|
|
||||||
: UiCard.fromJson(json['ui'] as Map<String, dynamic>),
|
|
||||||
timestamp: json['timestamp'] == null
|
|
||||||
? null
|
|
||||||
: DateTime.parse(json['timestamp'] as String),
|
|
||||||
attachments: (json['attachments'] as List<dynamic>?)
|
|
||||||
?.whereType<Map<String, dynamic>>()
|
|
||||||
.toList(),
|
|
||||||
);
|
|
||||||
|
|
||||||
Map<String, dynamic> _$SnapshotMessageToJson(SnapshotMessage instance) =>
|
|
||||||
<String, dynamic>{
|
|
||||||
'id': instance.id,
|
|
||||||
'role': instance.role,
|
|
||||||
'content': instance.content,
|
|
||||||
'toolCallId': instance.toolCallId,
|
|
||||||
'ui': instance.ui,
|
|
||||||
'timestamp': instance.timestamp?.toIso8601String(),
|
|
||||||
'attachments': instance.attachments,
|
|
||||||
};
|
|
||||||
@@ -1,691 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from collections.abc import AsyncGenerator, Sequence
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from decimal import Decimal
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from ag_ui.core.types import RunAgentInput
|
|
||||||
from agentscope.agent import ReActAgent
|
|
||||||
from agentscope.formatter import OpenAIChatFormatter
|
|
||||||
from agentscope.memory import InMemoryMemory
|
|
||||||
from agentscope.message import Msg
|
|
||||||
from agentscope.model import OpenAIChatModel
|
|
||||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
|
||||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
|
||||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
|
||||||
from core.db.session import AsyncSessionLocal
|
|
||||||
from core.logging import get_logger
|
|
||||||
from models.agent_chat_message import AgentChatMessageRole
|
|
||||||
from models.agent_chat_session import AgentChatSessionStatus
|
|
||||||
from models.llm import Llm
|
|
||||||
from models.system_agents import SystemAgents
|
|
||||||
from schemas.agent.runtime_models import (
|
|
||||||
RouterAgentOutput,
|
|
||||||
ToolAgentOutput,
|
|
||||||
WorkerAgentOutputLite,
|
|
||||||
resolve_worker_output_model,
|
|
||||||
)
|
|
||||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
|
||||||
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
|
|
||||||
from schemas.user import UserContext
|
|
||||||
from services.litellm.service import LiteLLMService
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.agentscope.runtime.orchestrator import PipelineLike
|
|
||||||
|
|
||||||
logger = get_logger("core.agentscope.runtime.react_runner")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class SystemAgentRuntimeConfig:
|
|
||||||
agent_type: AgentType
|
|
||||||
model_code: str
|
|
||||||
llm_config: SystemAgentLLMConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class StageExecutionResult:
|
|
||||||
message: Msg
|
|
||||||
payload: dict[str, Any]
|
|
||||||
response_metadata: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class _TrackingChatModel:
|
|
||||||
def __init__(self, inner: OpenAIChatModel) -> None:
|
|
||||||
self._inner = inner
|
|
||||||
self._total_input_tokens = 0
|
|
||||||
self._total_output_tokens = 0
|
|
||||||
self._total_latency_ms = 0
|
|
||||||
self._cached_prompt_tokens = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def stream(self) -> bool:
|
|
||||||
return self._inner.stream
|
|
||||||
|
|
||||||
@stream.setter
|
|
||||||
def stream(self, value: bool) -> None:
|
|
||||||
self._inner.stream = value
|
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> Any:
|
|
||||||
return getattr(self._inner, name)
|
|
||||||
|
|
||||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
||||||
response = await self._inner(*args, **kwargs)
|
|
||||||
if isinstance(response, AsyncGenerator):
|
|
||||||
return self._track_stream(response)
|
|
||||||
self._record_usage(getattr(response, "usage", None))
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def _track_stream(
|
|
||||||
self, response: AsyncGenerator[Any, None]
|
|
||||||
) -> AsyncGenerator[Any, None]:
|
|
||||||
latest_usage = None
|
|
||||||
async for chunk in response:
|
|
||||||
usage = getattr(chunk, "usage", None)
|
|
||||||
if usage is not None:
|
|
||||||
latest_usage = usage
|
|
||||||
yield chunk
|
|
||||||
self._record_usage(latest_usage)
|
|
||||||
|
|
||||||
def _record_usage(self, usage: Any) -> None:
|
|
||||||
if usage is None:
|
|
||||||
return
|
|
||||||
self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
|
|
||||||
self._total_output_tokens += max(
|
|
||||||
int(getattr(usage, "output_tokens", 0) or 0), 0
|
|
||||||
)
|
|
||||||
self._total_latency_ms += max(
|
|
||||||
int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
|
|
||||||
)
|
|
||||||
metadata = getattr(usage, "metadata", None)
|
|
||||||
if metadata is not None:
|
|
||||||
cached_tokens = 0
|
|
||||||
if isinstance(metadata, dict):
|
|
||||||
prompt_details = metadata.get("prompt_tokens_details")
|
|
||||||
if isinstance(prompt_details, dict):
|
|
||||||
cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0)
|
|
||||||
else:
|
|
||||||
prompt_details = getattr(metadata, "prompt_tokens_details", None)
|
|
||||||
cached_tokens = int(getattr(prompt_details, "cached_tokens", 0) or 0)
|
|
||||||
self._cached_prompt_tokens += max(cached_tokens, 0)
|
|
||||||
|
|
||||||
def usage_summary(self) -> dict[str, int]:
|
|
||||||
return {
|
|
||||||
"input_tokens": self._total_input_tokens,
|
|
||||||
"output_tokens": self._total_output_tokens,
|
|
||||||
"latency_ms": self._total_latency_ms,
|
|
||||||
"cached_prompt_tokens": self._cached_prompt_tokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class _PipelineStageEmitter:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
pipeline: PipelineLike,
|
|
||||||
session_id: str,
|
|
||||||
run_id: str,
|
|
||||||
stage: str,
|
|
||||||
emit_text_events: bool,
|
|
||||||
emit_tool_events: bool,
|
|
||||||
) -> None:
|
|
||||||
self._pipeline = pipeline
|
|
||||||
self._session_id = session_id
|
|
||||||
self._run_id = run_id
|
|
||||||
self._stage = stage
|
|
||||||
self._emit_text_events = emit_text_events
|
|
||||||
self._emit_tool_events = emit_tool_events
|
|
||||||
self._text_by_message_id: dict[str, str] = {}
|
|
||||||
self._emitted_tool_calls: set[str] = set()
|
|
||||||
self._emitted_tool_results: set[str] = set()
|
|
||||||
self.latest_text_message_id: str | None = None
|
|
||||||
self.latest_text: str = ""
|
|
||||||
|
|
||||||
async def handle_print(self, *, msg: Msg, last: bool) -> None:
|
|
||||||
del last
|
|
||||||
if self._emit_tool_events:
|
|
||||||
await self._emit_tool_events_from_msg(msg)
|
|
||||||
if self._emit_text_events:
|
|
||||||
await self._emit_text_events_from_msg(msg)
|
|
||||||
|
|
||||||
async def _emit_text_events_from_msg(self, msg: Msg) -> None:
|
|
||||||
text = msg.get_text_content(separator="") or ""
|
|
||||||
if not text:
|
|
||||||
return
|
|
||||||
message_id = str(msg.id)
|
|
||||||
previous = self._text_by_message_id.get(message_id, "")
|
|
||||||
if message_id not in self._text_by_message_id:
|
|
||||||
await self._emit(
|
|
||||||
"text.start",
|
|
||||||
{
|
|
||||||
"messageId": message_id,
|
|
||||||
"role": "assistant",
|
|
||||||
"stage": self._stage,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
delta = text[len(previous) :] if text.startswith(previous) else text
|
|
||||||
if delta:
|
|
||||||
await self._emit(
|
|
||||||
"text.delta",
|
|
||||||
{
|
|
||||||
"messageId": message_id,
|
|
||||||
"delta": delta,
|
|
||||||
"stage": self._stage,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self._text_by_message_id[message_id] = text
|
|
||||||
self.latest_text_message_id = message_id
|
|
||||||
self.latest_text = text
|
|
||||||
|
|
||||||
async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
|
|
||||||
for block in msg.get_content_blocks("tool_use"):
|
|
||||||
tool_call_id = str(block.get("id", "")).strip()
|
|
||||||
tool_name = str(block.get("name", "")).strip()
|
|
||||||
if (
|
|
||||||
not tool_call_id
|
|
||||||
or not tool_name
|
|
||||||
or tool_call_id in self._emitted_tool_calls
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
payload = {
|
|
||||||
"messageId": str(msg.id),
|
|
||||||
"toolCallId": tool_call_id,
|
|
||||||
"toolName": tool_name,
|
|
||||||
"stage": self._stage,
|
|
||||||
}
|
|
||||||
await self._emit("tool.start", payload)
|
|
||||||
await self._emit(
|
|
||||||
"tool.args",
|
|
||||||
{
|
|
||||||
**payload,
|
|
||||||
"args": block.get("input", {}),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await self._emit("tool.end", payload)
|
|
||||||
self._emitted_tool_calls.add(tool_call_id)
|
|
||||||
|
|
||||||
for block in msg.get_content_blocks("tool_result"):
|
|
||||||
tool_call_id = str(block.get("id", "")).strip()
|
|
||||||
if not tool_call_id or tool_call_id in self._emitted_tool_results:
|
|
||||||
continue
|
|
||||||
tool_output = _parse_tool_agent_output(block.get("output"))
|
|
||||||
if tool_output is None:
|
|
||||||
continue
|
|
||||||
await self._emit(
|
|
||||||
"tool.result",
|
|
||||||
{
|
|
||||||
"messageId": str(msg.id),
|
|
||||||
"toolCallId": tool_call_id,
|
|
||||||
"toolName": tool_output.tool_name,
|
|
||||||
"stage": self._stage,
|
|
||||||
"toolAgentOutput": tool_output.model_dump(
|
|
||||||
mode="json", exclude_none=True
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self._emitted_tool_results.add(tool_call_id)
|
|
||||||
|
|
||||||
async def emit_final_text_end(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
worker_output: dict[str, Any],
|
|
||||||
response_metadata: dict[str, Any],
|
|
||||||
) -> None:
|
|
||||||
message_id = (
|
|
||||||
self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"
|
|
||||||
)
|
|
||||||
if self.latest_text_message_id is None and worker_output.get("answer"):
|
|
||||||
await self._emit(
|
|
||||||
"text.start",
|
|
||||||
{
|
|
||||||
"messageId": message_id,
|
|
||||||
"role": "assistant",
|
|
||||||
"stage": self._stage,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await self._emit(
|
|
||||||
"text.delta",
|
|
||||||
{
|
|
||||||
"messageId": message_id,
|
|
||||||
"delta": worker_output.get("answer", ""),
|
|
||||||
"stage": self._stage,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await self._emit(
|
|
||||||
"text.end",
|
|
||||||
{
|
|
||||||
"messageId": message_id,
|
|
||||||
"role": "assistant",
|
|
||||||
"stage": self._stage,
|
|
||||||
"workerAgentOutput": worker_output,
|
|
||||||
**response_metadata,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _emit(self, event_type: str, data: dict[str, Any]) -> None:
|
|
||||||
await self._pipeline.emit(
|
|
||||||
session_id=self._session_id,
|
|
||||||
event={
|
|
||||||
"type": event_type,
|
|
||||||
"threadId": self._session_id,
|
|
||||||
"runId": self._run_id,
|
|
||||||
"data": data,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _PipelineReActAgent(ReActAgent):
|
|
||||||
def __init__(
|
|
||||||
self, *, emitter: _PipelineStageEmitter | None = None, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._pipeline_emitter = emitter
|
|
||||||
self.disable_console_output()
|
|
||||||
|
|
||||||
async def print(self, msg: Msg, last: bool = True, speech: Any = None) -> None:
|
|
||||||
del speech
|
|
||||||
if self._pipeline_emitter is not None:
|
|
||||||
await self._pipeline_emitter.handle_print(msg=msg, last=last)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_tool_agent_output(output: Any) -> ToolAgentOutput | None:
|
|
||||||
blocks = output if isinstance(output, Sequence) else []
|
|
||||||
for block in blocks:
|
|
||||||
if not isinstance(block, dict) or block.get("type") != "text":
|
|
||||||
continue
|
|
||||||
text = block.get("text")
|
|
||||||
if not isinstance(text, str) or not text.strip():
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
return ToolAgentOutput.model_validate(json.loads(text))
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_tool_name(value: str) -> str:
|
|
||||||
return value.strip().replace(".", "_").replace("-", "_")
|
|
||||||
|
|
||||||
|
|
||||||
class AgentScopeReActRunner:
|
|
||||||
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
|
|
||||||
self._litellm_service = litellm_service or LiteLLMService()
|
|
||||||
|
|
||||||
async def execute(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
user_context: UserContext,
|
|
||||||
context_messages: list[Msg],
|
|
||||||
pipeline: PipelineLike,
|
|
||||||
run_input: RunAgentInput,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
owner_id = UUID(user_context.id)
|
|
||||||
enabled_tool_names = self._extract_tool_names(run_input)
|
|
||||||
|
|
||||||
async with AsyncSessionLocal() as session:
|
|
||||||
router_toolkit, worker_toolkit = self._build_toolkits(
|
|
||||||
session=session,
|
|
||||||
owner_id=owner_id,
|
|
||||||
enabled_tool_names=enabled_tool_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
router_config = await self._load_system_agent_config(
|
|
||||||
session=session,
|
|
||||||
agent_type=AgentType.ROUTER,
|
|
||||||
)
|
|
||||||
worker_config = await self._load_system_agent_config(
|
|
||||||
session=session,
|
|
||||||
agent_type=AgentType.WORKER,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._emit_step_event(
|
|
||||||
pipeline=pipeline,
|
|
||||||
run_input=run_input,
|
|
||||||
step_name="router",
|
|
||||||
event_type="step.start",
|
|
||||||
)
|
|
||||||
router_result = await self._run_router_stage(
|
|
||||||
user_context=user_context,
|
|
||||||
context_messages=context_messages,
|
|
||||||
toolkit=router_toolkit,
|
|
||||||
run_input=run_input,
|
|
||||||
stage_config=router_config,
|
|
||||||
)
|
|
||||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
|
||||||
await self._persist_router_message(
|
|
||||||
session=session,
|
|
||||||
thread_id=run_input.thread_id,
|
|
||||||
run_id=run_input.run_id,
|
|
||||||
model_code=router_config.model_code,
|
|
||||||
router_output=router_output,
|
|
||||||
response_metadata=router_result.response_metadata,
|
|
||||||
)
|
|
||||||
await session.commit()
|
|
||||||
await self._emit_step_event(
|
|
||||||
pipeline=pipeline,
|
|
||||||
run_input=run_input,
|
|
||||||
step_name="router",
|
|
||||||
event_type="step.finish",
|
|
||||||
)
|
|
||||||
|
|
||||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
|
||||||
await self._emit_step_event(
|
|
||||||
pipeline=pipeline,
|
|
||||||
run_input=run_input,
|
|
||||||
step_name="worker",
|
|
||||||
event_type="step.start",
|
|
||||||
)
|
|
||||||
worker_result = await self._run_worker_stage(
|
|
||||||
user_context=user_context,
|
|
||||||
router_output=router_output,
|
|
||||||
toolkit=worker_toolkit,
|
|
||||||
run_input=run_input,
|
|
||||||
stage_config=worker_config,
|
|
||||||
worker_output_model=worker_output_model,
|
|
||||||
pipeline=pipeline,
|
|
||||||
)
|
|
||||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
|
||||||
await self._emit_step_event(
|
|
||||||
pipeline=pipeline,
|
|
||||||
run_input=run_input,
|
|
||||||
step_name="worker",
|
|
||||||
event_type="step.finish",
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
|
||||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_toolkits(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
session: AsyncSession,
|
|
||||||
owner_id: UUID,
|
|
||||||
enabled_tool_names: set[str] | None,
|
|
||||||
) -> tuple[Any, Any]:
|
|
||||||
return (
|
|
||||||
build_stage_toolkit(
|
|
||||||
agent_type=AgentType.ROUTER,
|
|
||||||
session=session,
|
|
||||||
owner_id=owner_id,
|
|
||||||
enabled_tool_names=enabled_tool_names,
|
|
||||||
),
|
|
||||||
build_stage_toolkit(
|
|
||||||
agent_type=AgentType.WORKER,
|
|
||||||
session=session,
|
|
||||||
owner_id=owner_id,
|
|
||||||
enabled_tool_names=enabled_tool_names,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None:
|
|
||||||
raw_tools = getattr(run_input, "tools", None)
|
|
||||||
if not isinstance(raw_tools, list):
|
|
||||||
return None
|
|
||||||
selected: set[str] = set()
|
|
||||||
for item in raw_tools:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
name = item.get("name")
|
|
||||||
else:
|
|
||||||
name = getattr(item, "name", None)
|
|
||||||
if isinstance(name, str) and name.strip():
|
|
||||||
selected.add(_normalize_tool_name(name))
|
|
||||||
return selected
|
|
||||||
|
|
||||||
async def _load_system_agent_config(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
session: AsyncSession,
|
|
||||||
agent_type: AgentType,
|
|
||||||
) -> SystemAgentRuntimeConfig:
|
|
||||||
stmt = (
|
|
||||||
select(SystemAgents, Llm)
|
|
||||||
.join(Llm, SystemAgents.llm_id == Llm.id)
|
|
||||||
.where(SystemAgents.agent_type == agent_type.value)
|
|
||||||
)
|
|
||||||
row = (await session.execute(stmt)).one_or_none()
|
|
||||||
if row is None:
|
|
||||||
raise RuntimeError(f"system agent config not found: {agent_type.value}")
|
|
||||||
system_agent, llm = row
|
|
||||||
status = str(system_agent.status).strip().lower()
|
|
||||||
if status != "active":
|
|
||||||
raise RuntimeError(f"system agent is not active: {agent_type.value}")
|
|
||||||
return SystemAgentRuntimeConfig(
|
|
||||||
agent_type=agent_type,
|
|
||||||
model_code=llm.model_code,
|
|
||||||
llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _run_router_stage(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
user_context: UserContext,
|
|
||||||
context_messages: list[Msg],
|
|
||||||
toolkit: Any,
|
|
||||||
run_input: RunAgentInput,
|
|
||||||
stage_config: SystemAgentRuntimeConfig,
|
|
||||||
) -> StageExecutionResult:
|
|
||||||
tracking_model = self._build_model(stage_config=stage_config)
|
|
||||||
agent = self._build_agent(
|
|
||||||
agent_name="router",
|
|
||||||
system_prompt=build_system_prompt(
|
|
||||||
agent_type=AgentType.ROUTER,
|
|
||||||
user_context=user_context,
|
|
||||||
now_utc=datetime.now(timezone.utc),
|
|
||||||
tools=run_input.tools,
|
|
||||||
),
|
|
||||||
toolkit=toolkit,
|
|
||||||
model=tracking_model,
|
|
||||||
)
|
|
||||||
response_msg = await agent.reply(
|
|
||||||
context_messages, structured_model=RouterAgentOutput
|
|
||||||
)
|
|
||||||
payload = RouterAgentOutput.model_validate(
|
|
||||||
response_msg.metadata or {}
|
|
||||||
).model_dump(
|
|
||||||
mode="json",
|
|
||||||
exclude_none=True,
|
|
||||||
)
|
|
||||||
return StageExecutionResult(
|
|
||||||
message=response_msg,
|
|
||||||
payload=payload,
|
|
||||||
response_metadata=self._litellm_service.build_usage_metadata(
|
|
||||||
model=stage_config.model_code,
|
|
||||||
usage_summary=tracking_model.usage_summary(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _run_worker_stage(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
user_context: UserContext,
|
|
||||||
router_output: RouterAgentOutput,
|
|
||||||
toolkit: Any,
|
|
||||||
run_input: RunAgentInput,
|
|
||||||
stage_config: SystemAgentRuntimeConfig,
|
|
||||||
worker_output_model: type[WorkerAgentOutputLite],
|
|
||||||
pipeline: PipelineLike,
|
|
||||||
) -> StageExecutionResult:
|
|
||||||
worker_input = self._build_worker_input_messages(
|
|
||||||
router_output=router_output,
|
|
||||||
)
|
|
||||||
tracking_model = self._build_model(stage_config=stage_config)
|
|
||||||
emitter = _PipelineStageEmitter(
|
|
||||||
pipeline=pipeline,
|
|
||||||
session_id=run_input.thread_id,
|
|
||||||
run_id=run_input.run_id,
|
|
||||||
stage="worker",
|
|
||||||
emit_text_events=True,
|
|
||||||
emit_tool_events=True,
|
|
||||||
)
|
|
||||||
agent = self._build_agent(
|
|
||||||
agent_name="worker",
|
|
||||||
system_prompt=build_system_prompt(
|
|
||||||
agent_type=AgentType.WORKER,
|
|
||||||
user_context=user_context,
|
|
||||||
now_utc=datetime.now(timezone.utc),
|
|
||||||
tools=run_input.tools,
|
|
||||||
),
|
|
||||||
toolkit=toolkit,
|
|
||||||
model=tracking_model,
|
|
||||||
emitter=emitter,
|
|
||||||
)
|
|
||||||
response_msg = await agent.reply(
|
|
||||||
worker_input,
|
|
||||||
structured_model=worker_output_model,
|
|
||||||
)
|
|
||||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
|
||||||
response_metadata = self._litellm_service.build_usage_metadata(
|
|
||||||
model=stage_config.model_code,
|
|
||||||
usage_summary=tracking_model.usage_summary(),
|
|
||||||
)
|
|
||||||
await emitter.emit_final_text_end(
|
|
||||||
worker_output=worker_payload.model_dump(mode="json", exclude_none=True),
|
|
||||||
response_metadata=response_metadata,
|
|
||||||
)
|
|
||||||
return StageExecutionResult(
|
|
||||||
message=response_msg,
|
|
||||||
payload=worker_payload.model_dump(mode="json", exclude_none=True),
|
|
||||||
response_metadata=response_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _build_worker_input_messages(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
router_output: RouterAgentOutput,
|
|
||||||
) -> list[Msg]:
|
|
||||||
routing_contract = json.dumps(
|
|
||||||
router_output.model_dump(mode="json", exclude_none=True),
|
|
||||||
ensure_ascii=False,
|
|
||||||
separators=(",", ":"),
|
|
||||||
)
|
|
||||||
routing_msg = Msg(
|
|
||||||
name="router",
|
|
||||||
role="user",
|
|
||||||
content=(
|
|
||||||
"Use the following routing contract as the execution source of truth. "
|
|
||||||
f"Do not change the routed objective:\n{routing_contract}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return [routing_msg]
|
|
||||||
|
|
||||||
def _build_model(
|
|
||||||
self, *, stage_config: SystemAgentRuntimeConfig
|
|
||||||
) -> _TrackingChatModel:
|
|
||||||
model = OpenAIChatModel(
|
|
||||||
model_name=stage_config.model_code,
|
|
||||||
api_key=self._litellm_service.proxy_api_key,
|
|
||||||
stream=True,
|
|
||||||
client_kwargs={"base_url": self._litellm_service.proxy_base_url},
|
|
||||||
generate_kwargs={
|
|
||||||
"temperature": stage_config.llm_config.temperature,
|
|
||||||
"max_tokens": stage_config.llm_config.max_tokens,
|
|
||||||
"timeout": stage_config.llm_config.timeout_seconds,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return _TrackingChatModel(model)
|
|
||||||
|
|
||||||
def _build_agent(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
agent_name: str,
|
|
||||||
system_prompt: str,
|
|
||||||
toolkit: Any,
|
|
||||||
model: _TrackingChatModel,
|
|
||||||
emitter: _PipelineStageEmitter | None = None,
|
|
||||||
) -> _PipelineReActAgent:
|
|
||||||
return _PipelineReActAgent(
|
|
||||||
name=agent_name,
|
|
||||||
sys_prompt=system_prompt,
|
|
||||||
model=model,
|
|
||||||
formatter=OpenAIChatFormatter(),
|
|
||||||
toolkit=toolkit,
|
|
||||||
memory=InMemoryMemory(),
|
|
||||||
emitter=emitter,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _emit_step_event(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
pipeline: PipelineLike,
|
|
||||||
run_input: RunAgentInput,
|
|
||||||
step_name: str,
|
|
||||||
event_type: str,
|
|
||||||
) -> None:
|
|
||||||
await pipeline.emit(
|
|
||||||
session_id=run_input.thread_id,
|
|
||||||
event={
|
|
||||||
"type": event_type,
|
|
||||||
"threadId": run_input.thread_id,
|
|
||||||
"runId": run_input.run_id,
|
|
||||||
"data": {"stepName": step_name},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _persist_router_message(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
session: AsyncSession,
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
model_code: str,
|
|
||||||
router_output: RouterAgentOutput,
|
|
||||||
response_metadata: dict[str, Any],
|
|
||||||
) -> None:
|
|
||||||
session_id = UUID(thread_id)
|
|
||||||
message_repo = MessageRepository(session)
|
|
||||||
session_repo = SessionRepository(session)
|
|
||||||
locked_session = await session_repo.lock_session_for_update(
|
|
||||||
session_id=session_id
|
|
||||||
)
|
|
||||||
if locked_session is None:
|
|
||||||
raise RuntimeError("chat session not found for router persistence")
|
|
||||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
|
||||||
metadata = AgentChatMessageMetadata(
|
|
||||||
run_id=run_id,
|
|
||||||
agent_type=AgentType.ROUTER,
|
|
||||||
router_agent_output=router_output,
|
|
||||||
)
|
|
||||||
message_payload = AgentChatMessage(
|
|
||||||
id=uuid4(),
|
|
||||||
seq=seq,
|
|
||||||
role=AgentChatMessageRole.ASSISTANT.value,
|
|
||||||
content="",
|
|
||||||
model_code=model_code,
|
|
||||||
tool_name=None,
|
|
||||||
input_tokens=int(response_metadata.get("inputTokens", 0) or 0),
|
|
||||||
output_tokens=int(response_metadata.get("outputTokens", 0) or 0),
|
|
||||||
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
|
|
||||||
latency_ms=int(response_metadata.get("latencyMs", 0) or 0),
|
|
||||||
metadata=metadata,
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
await message_repo.append_message(
|
|
||||||
session_id=session_id,
|
|
||||||
seq=message_payload.seq,
|
|
||||||
role=AgentChatMessageRole.ASSISTANT,
|
|
||||||
content=message_payload.content,
|
|
||||||
model_code=message_payload.model_code,
|
|
||||||
tool_name=message_payload.tool_name,
|
|
||||||
metadata=metadata.model_dump(mode="json", exclude_none=True),
|
|
||||||
input_tokens=message_payload.input_tokens,
|
|
||||||
output_tokens=message_payload.output_tokens,
|
|
||||||
cost=message_payload.cost,
|
|
||||||
latency_ms=message_payload.latency_ms,
|
|
||||||
)
|
|
||||||
await session_repo.update_runtime_state(
|
|
||||||
chat_session=locked_session,
|
|
||||||
status=AgentChatSessionStatus.RUNNING,
|
|
||||||
state_snapshot=locked_session.state_snapshot or {},
|
|
||||||
message_delta=1,
|
|
||||||
token_delta=message_payload.input_tokens + message_payload.output_tokens,
|
|
||||||
cost_delta=message_payload.cost,
|
|
||||||
)
|
|
||||||
await session.flush()
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from ag_ui.core import RunAgentInput
|
|
||||||
from agentscope.message import Msg
|
|
||||||
|
|
||||||
from core.agentscope.runtime.react_runner import (
|
|
||||||
AgentScopeReActRunner,
|
|
||||||
StageExecutionResult,
|
|
||||||
SystemAgentRuntimeConfig,
|
|
||||||
)
|
|
||||||
from schemas.agent.runtime_models import (
|
|
||||||
RouterAgentOutput,
|
|
||||||
UiMode,
|
|
||||||
WorkerAgentOutputRich,
|
|
||||||
)
|
|
||||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
|
||||||
from schemas.user.context import UserContext, parse_profile_settings
|
|
||||||
|
|
||||||
|
|
||||||
class _FakePipeline:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.events: list[dict[str, object]] = []
|
|
||||||
|
|
||||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
|
||||||
self.events.append({"session_id": session_id, "event": event})
|
|
||||||
return "1-0"
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeSessionCtx:
|
|
||||||
def __init__(self, session: object) -> None:
|
|
||||||
self._session = session
|
|
||||||
|
|
||||||
async def __aenter__(self) -> object:
|
|
||||||
return self._session
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
|
||||||
del exc_type, exc, tb
|
|
||||||
|
|
||||||
|
|
||||||
def _user_context() -> UserContext:
|
|
||||||
return UserContext(
|
|
||||||
id="00000000-0000-0000-0000-000000000001",
|
|
||||||
username="alice",
|
|
||||||
email="alice@example.com",
|
|
||||||
settings=parse_profile_settings(None),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_input() -> RunAgentInput:
|
|
||||||
return RunAgentInput.model_validate(
|
|
||||||
{
|
|
||||||
"threadId": "00000000-0000-0000-0000-000000000010",
|
|
||||||
"runId": "run-1",
|
|
||||||
"state": {},
|
|
||||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
|
||||||
"tools": [
|
|
||||||
{
|
|
||||||
"name": "calendar.read",
|
|
||||||
"description": "read",
|
|
||||||
"parameters": {"type": "object"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "calendar-write",
|
|
||||||
"description": "write",
|
|
||||||
"parameters": {"type": "object"},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"context": [],
|
|
||||||
"forwardedProps": {},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _router_output(*, ui_mode: UiMode) -> RouterAgentOutput:
|
|
||||||
return RouterAgentOutput.model_validate(
|
|
||||||
{
|
|
||||||
"normalized_task_input": {
|
|
||||||
"user_text": "hello",
|
|
||||||
"multimodal_summary": [],
|
|
||||||
},
|
|
||||||
"key_entities": [],
|
|
||||||
"constraints": [],
|
|
||||||
"task_typing": {"primary": "knowledge", "secondary": []},
|
|
||||||
"execution_mode": "onestep",
|
|
||||||
"result_typing": {"primary": "direct_answer", "secondary": []},
|
|
||||||
"ui": {
|
|
||||||
"ui_mode": ui_mode.value,
|
|
||||||
"ui_decision_reason": "need structure"
|
|
||||||
if ui_mode == UiMode.RICH
|
|
||||||
else "plain text",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_uses_router_ui_mode_to_select_worker_output_model(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
runner = AgentScopeReActRunner()
|
|
||||||
pipeline = _FakePipeline()
|
|
||||||
worker_model_holder: dict[str, type[object]] = {}
|
|
||||||
|
|
||||||
class _CommitSession:
|
|
||||||
async def commit(self) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"core.agentscope.runtime.react_runner.AsyncSessionLocal",
|
|
||||||
lambda: _FakeSessionCtx(_CommitSession()),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
runner,
|
|
||||||
"_build_toolkits",
|
|
||||||
lambda **kwargs: ("router-toolkit", "worker-toolkit"),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _load_system_agent_config(**kwargs):
|
|
||||||
return SystemAgentRuntimeConfig(
|
|
||||||
agent_type=kwargs["agent_type"],
|
|
||||||
model_code="qwen3.5-flash"
|
|
||||||
if kwargs["agent_type"] == AgentType.ROUTER
|
|
||||||
else "deepseek-chat",
|
|
||||||
llm_config=SystemAgentLLMConfig(
|
|
||||||
temperature=0.1, max_tokens=256, timeout_seconds=30
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(runner, "_load_system_agent_config", _load_system_agent_config)
|
|
||||||
|
|
||||||
async def _run_router_stage(**kwargs):
|
|
||||||
return StageExecutionResult(
|
|
||||||
message=Msg(name="router", content="", role="assistant"),
|
|
||||||
payload=_router_output(ui_mode=UiMode.RICH).model_dump(mode="json"),
|
|
||||||
response_metadata={
|
|
||||||
"model": "qwen3.5-flash",
|
|
||||||
"inputTokens": 12,
|
|
||||||
"outputTokens": 6,
|
|
||||||
"cost": 0.001,
|
|
||||||
"latencyMs": 50,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(runner, "_run_router_stage", _run_router_stage)
|
|
||||||
|
|
||||||
async def _persist_router_message(**kwargs) -> None:
|
|
||||||
assert kwargs["model_code"] == "qwen3.5-flash"
|
|
||||||
|
|
||||||
monkeypatch.setattr(runner, "_persist_router_message", _persist_router_message)
|
|
||||||
|
|
||||||
async def _run_worker_stage(**kwargs):
|
|
||||||
worker_model_holder["model"] = kwargs["worker_output_model"]
|
|
||||||
return StageExecutionResult(
|
|
||||||
message=Msg(name="worker", content="done", role="assistant"),
|
|
||||||
payload=WorkerAgentOutputRich.model_validate(
|
|
||||||
{
|
|
||||||
"status": "success",
|
|
||||||
"answer": "done",
|
|
||||||
"key_points": [],
|
|
||||||
"result_type": "direct_answer",
|
|
||||||
"suggested_actions": [],
|
|
||||||
"error": None,
|
|
||||||
"ui_hints": None,
|
|
||||||
}
|
|
||||||
).model_dump(mode="json", exclude_none=True),
|
|
||||||
response_metadata={
|
|
||||||
"model": "deepseek-chat",
|
|
||||||
"inputTokens": 8,
|
|
||||||
"outputTokens": 4,
|
|
||||||
"cost": 0.002,
|
|
||||||
"latencyMs": 40,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(runner, "_run_worker_stage", _run_worker_stage)
|
|
||||||
|
|
||||||
result = await runner.execute(
|
|
||||||
user_context=_user_context(),
|
|
||||||
context_messages=[],
|
|
||||||
pipeline=pipeline,
|
|
||||||
run_input=_run_input(),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert worker_model_holder["model"].__name__ == "WorkerAgentOutputRich"
|
|
||||||
event_types = []
|
|
||||||
for item in pipeline.events:
|
|
||||||
event = item.get("event")
|
|
||||||
if isinstance(event, dict):
|
|
||||||
event_types.append(event.get("type"))
|
|
||||||
assert event_types == ["step.start", "step.finish", "step.start", "step.finish"]
|
|
||||||
assert result["router"]["ui"]["ui_mode"] == "rich"
|
|
||||||
assert result["worker"]["answer"] == "done"
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tool_names_normalizes_client_tool_names() -> None:
|
|
||||||
runner = AgentScopeReActRunner()
|
|
||||||
|
|
||||||
names = runner._extract_tool_names(_run_input())
|
|
||||||
|
|
||||||
assert names == {"calendar_read", "calendar_write"}
|
|
||||||
Reference in New Issue
Block a user