diff --git a/apps/lib/features/chat/data/models/ag_ui_event.g.dart b/apps/lib/features/chat/data/models/ag_ui_event.g.dart deleted file mode 100644 index 79b0616..0000000 --- a/apps/lib/features/chat/data/models/ag_ui_event.g.dart +++ /dev/null @@ -1,201 +0,0 @@ -// GENERATED CODE - DO NOT MODIFY BY HAND - -part of 'ag_ui_event.dart'; - -// ************************************************************************** -// JsonSerializableGenerator -// ************************************************************************** - -Map _$AgUiEventToJson(AgUiEvent instance) => { - 'type': _$AgUiEventTypeEnumMap[instance.type]!, -}; - -const _$AgUiEventTypeEnumMap = { - AgUiEventType.runStarted: 'runStarted', - AgUiEventType.runFinished: 'runFinished', - AgUiEventType.runError: 'runError', - AgUiEventType.stepStarted: 'stepStarted', - AgUiEventType.stepFinished: 'stepFinished', - AgUiEventType.textMessageStart: 'textMessageStart', - AgUiEventType.textMessageContent: 'textMessageContent', - AgUiEventType.textMessageEnd: 'textMessageEnd', - AgUiEventType.toolCallStart: 'toolCallStart', - AgUiEventType.toolCallArgs: 'toolCallArgs', - AgUiEventType.toolCallEnd: 'toolCallEnd', - AgUiEventType.toolCallResult: 'toolCallResult', - AgUiEventType.toolCallError: 'toolCallError', - AgUiEventType.messagesSnapshot: 'messagesSnapshot', - AgUiEventType.unknown: 'unknown', -}; - -RunStartedEvent _$RunStartedEventFromJson(Map json) => - RunStartedEvent( - threadId: json['threadId'] as String, - runId: json['runId'] as String, - ); - -Map _$RunStartedEventToJson(RunStartedEvent instance) => - {'threadId': instance.threadId, 'runId': instance.runId}; - -RunFinishedEvent _$RunFinishedEventFromJson(Map json) => - RunFinishedEvent( - threadId: json['threadId'] as String, - runId: json['runId'] as String, - ); - -Map _$RunFinishedEventToJson(RunFinishedEvent instance) => - {'threadId': instance.threadId, 'runId': instance.runId}; - -RunErrorEvent _$RunErrorEventFromJson(Map json) => - RunErrorEvent( - message: json['message'] as String, - code: json['code'] as String?, - ); - -Map _$RunErrorEventToJson(RunErrorEvent instance) => - {'message': instance.message, 'code': instance.code}; - -StepStartedEvent _$StepStartedEventFromJson(Map json) => - StepStartedEvent(stepName: json['stepName'] as String); - -Map _$StepStartedEventToJson(StepStartedEvent instance) => - {'stepName': instance.stepName}; - -StepFinishedEvent _$StepFinishedEventFromJson(Map json) => - StepFinishedEvent(stepName: json['stepName'] as String); - -Map _$StepFinishedEventToJson(StepFinishedEvent instance) => - {'stepName': instance.stepName}; - -TextMessageStartEvent _$TextMessageStartEventFromJson( - Map json, -) => TextMessageStartEvent( - messageId: json['messageId'] as String, - role: json['role'] as String, -); - -Map _$TextMessageStartEventToJson( - TextMessageStartEvent instance, -) => {'messageId': instance.messageId, 'role': instance.role}; - -TextMessageContentEvent _$TextMessageContentEventFromJson( - Map json, -) => TextMessageContentEvent( - messageId: json['messageId'] as String, - delta: json['delta'] as String, -); - -Map _$TextMessageContentEventToJson( - TextMessageContentEvent instance, -) => { - 'messageId': instance.messageId, - 'delta': instance.delta, -}; - -TextMessageEndEvent _$TextMessageEndEventFromJson(Map json) => - TextMessageEndEvent(messageId: json['messageId'] as String); - -Map _$TextMessageEndEventToJson( - TextMessageEndEvent instance, -) => {'messageId': instance.messageId}; - -ToolCallStartEvent _$ToolCallStartEventFromJson(Map json) => - ToolCallStartEvent( - toolCallId: json['toolCallId'] as String, - toolCallName: json['toolCallName'] as String, - parentMessageId: json['parentMessageId'] as String?, - ); - -Map _$ToolCallStartEventToJson(ToolCallStartEvent instance) => - { - 'toolCallId': instance.toolCallId, - 'toolCallName': instance.toolCallName, - 'parentMessageId': instance.parentMessageId, - }; - -ToolCallArgsEvent _$ToolCallArgsEventFromJson(Map json) => - ToolCallArgsEvent( - toolCallId: json['toolCallId'] as String, - delta: json['delta'] as String, - ); - -Map _$ToolCallArgsEventToJson(ToolCallArgsEvent instance) => - { - 'toolCallId': instance.toolCallId, - 'delta': instance.delta, - }; - -ToolCallEndEvent _$ToolCallEndEventFromJson(Map json) => - ToolCallEndEvent(toolCallId: json['toolCallId'] as String); - -Map _$ToolCallEndEventToJson(ToolCallEndEvent instance) => - {'toolCallId': instance.toolCallId}; - -ToolCallResultEvent _$ToolCallResultEventFromJson(Map json) => - ToolCallResultEvent( - messageId: json['messageId'] as String, - toolCallId: json['toolCallId'] as String, - content: json['content'] as String, - ); - -Map _$ToolCallResultEventToJson( - ToolCallResultEvent instance, -) => { - 'messageId': instance.messageId, - 'toolCallId': instance.toolCallId, - 'content': instance.content, -}; - -ToolCallErrorEvent _$ToolCallErrorEventFromJson(Map json) => - ToolCallErrorEvent( - toolCallId: json['toolCallId'] as String, - error: json['error'] as String, - code: json['code'] as String?, - ); - -Map _$ToolCallErrorEventToJson(ToolCallErrorEvent instance) => - { - 'toolCallId': instance.toolCallId, - 'error': instance.error, - 'code': instance.code, - }; - -MessagesSnapshotEvent _$MessagesSnapshotEventFromJson( - Map json, -) => MessagesSnapshotEvent( - messages: (json['messages'] as List) - .map((e) => SnapshotMessage.fromJson(e as Map)) - .toList(), -); - -Map _$MessagesSnapshotEventToJson( - MessagesSnapshotEvent instance, -) => {'messages': instance.messages}; - -SnapshotMessage _$SnapshotMessageFromJson(Map json) => - SnapshotMessage( - id: json['id'] as String, - role: json['role'] as String, - content: json['content'] as String?, - toolCallId: json['toolCallId'] as String?, - ui: json['ui'] == null - ? null - : UiCard.fromJson(json['ui'] as Map), - timestamp: json['timestamp'] == null - ? null - : DateTime.parse(json['timestamp'] as String), - attachments: (json['attachments'] as List?) - ?.whereType>() - .toList(), - ); - -Map _$SnapshotMessageToJson(SnapshotMessage instance) => - { - 'id': instance.id, - 'role': instance.role, - 'content': instance.content, - 'toolCallId': instance.toolCallId, - 'ui': instance.ui, - 'timestamp': instance.timestamp?.toIso8601String(), - 'attachments': instance.attachments, - }; diff --git a/backend/src/core/agentscope/runtime/react_runner.py b/backend/src/core/agentscope/runtime/react_runner.py deleted file mode 100644 index 2a70f4e..0000000 --- a/backend/src/core/agentscope/runtime/react_runner.py +++ /dev/null @@ -1,691 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import AsyncGenerator, Sequence -from dataclasses import dataclass -from datetime import datetime, timezone -from decimal import Decimal -from typing import TYPE_CHECKING, Any -from uuid import UUID, uuid4 - -from ag_ui.core.types import RunAgentInput -from agentscope.agent import ReActAgent -from agentscope.formatter import OpenAIChatFormatter -from agentscope.memory import InMemoryMemory -from agentscope.message import Msg -from agentscope.model import OpenAIChatModel -from core.agentscope.events.persistence import MessageRepository, SessionRepository -from core.agentscope.prompts.system_prompt import build_system_prompt -from core.agentscope.tools.toolkit import build_stage_toolkit -from core.db.session import AsyncSessionLocal -from core.logging import get_logger -from models.agent_chat_message import AgentChatMessageRole -from models.agent_chat_session import AgentChatSessionStatus -from models.llm import Llm -from models.system_agents import SystemAgents -from schemas.agent.runtime_models import ( - RouterAgentOutput, - ToolAgentOutput, - WorkerAgentOutputLite, - resolve_worker_output_model, -) -from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig -from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata -from schemas.user import UserContext -from services.litellm.service import LiteLLMService -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -if TYPE_CHECKING: - from core.agentscope.runtime.orchestrator import PipelineLike - -logger = get_logger("core.agentscope.runtime.react_runner") - - -@dataclass(frozen=True) -class SystemAgentRuntimeConfig: - agent_type: AgentType - model_code: str - llm_config: SystemAgentLLMConfig - - -@dataclass(frozen=True) -class StageExecutionResult: - message: Msg - payload: dict[str, Any] - response_metadata: dict[str, Any] - - -class _TrackingChatModel: - def __init__(self, inner: OpenAIChatModel) -> None: - self._inner = inner - self._total_input_tokens = 0 - self._total_output_tokens = 0 - self._total_latency_ms = 0 - self._cached_prompt_tokens = 0 - - @property - def stream(self) -> bool: - return self._inner.stream - - @stream.setter - def stream(self, value: bool) -> None: - self._inner.stream = value - - def __getattr__(self, name: str) -> Any: - return getattr(self._inner, name) - - async def __call__(self, *args: Any, **kwargs: Any) -> Any: - response = await self._inner(*args, **kwargs) - if isinstance(response, AsyncGenerator): - return self._track_stream(response) - self._record_usage(getattr(response, "usage", None)) - return response - - async def _track_stream( - self, response: AsyncGenerator[Any, None] - ) -> AsyncGenerator[Any, None]: - latest_usage = None - async for chunk in response: - usage = getattr(chunk, "usage", None) - if usage is not None: - latest_usage = usage - yield chunk - self._record_usage(latest_usage) - - def _record_usage(self, usage: Any) -> None: - if usage is None: - return - self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0) - self._total_output_tokens += max( - int(getattr(usage, "output_tokens", 0) or 0), 0 - ) - self._total_latency_ms += max( - int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0 - ) - metadata = getattr(usage, "metadata", None) - if metadata is not None: - cached_tokens = 0 - if isinstance(metadata, dict): - prompt_details = metadata.get("prompt_tokens_details") - if isinstance(prompt_details, dict): - cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0) - else: - prompt_details = getattr(metadata, "prompt_tokens_details", None) - cached_tokens = int(getattr(prompt_details, "cached_tokens", 0) or 0) - self._cached_prompt_tokens += max(cached_tokens, 0) - - def usage_summary(self) -> dict[str, int]: - return { - "input_tokens": self._total_input_tokens, - "output_tokens": self._total_output_tokens, - "latency_ms": self._total_latency_ms, - "cached_prompt_tokens": self._cached_prompt_tokens, - } - - -class _PipelineStageEmitter: - def __init__( - self, - *, - pipeline: PipelineLike, - session_id: str, - run_id: str, - stage: str, - emit_text_events: bool, - emit_tool_events: bool, - ) -> None: - self._pipeline = pipeline - self._session_id = session_id - self._run_id = run_id - self._stage = stage - self._emit_text_events = emit_text_events - self._emit_tool_events = emit_tool_events - self._text_by_message_id: dict[str, str] = {} - self._emitted_tool_calls: set[str] = set() - self._emitted_tool_results: set[str] = set() - self.latest_text_message_id: str | None = None - self.latest_text: str = "" - - async def handle_print(self, *, msg: Msg, last: bool) -> None: - del last - if self._emit_tool_events: - await self._emit_tool_events_from_msg(msg) - if self._emit_text_events: - await self._emit_text_events_from_msg(msg) - - async def _emit_text_events_from_msg(self, msg: Msg) -> None: - text = msg.get_text_content(separator="") or "" - if not text: - return - message_id = str(msg.id) - previous = self._text_by_message_id.get(message_id, "") - if message_id not in self._text_by_message_id: - await self._emit( - "text.start", - { - "messageId": message_id, - "role": "assistant", - "stage": self._stage, - }, - ) - delta = text[len(previous) :] if text.startswith(previous) else text - if delta: - await self._emit( - "text.delta", - { - "messageId": message_id, - "delta": delta, - "stage": self._stage, - }, - ) - self._text_by_message_id[message_id] = text - self.latest_text_message_id = message_id - self.latest_text = text - - async def _emit_tool_events_from_msg(self, msg: Msg) -> None: - for block in msg.get_content_blocks("tool_use"): - tool_call_id = str(block.get("id", "")).strip() - tool_name = str(block.get("name", "")).strip() - if ( - not tool_call_id - or not tool_name - or tool_call_id in self._emitted_tool_calls - ): - continue - payload = { - "messageId": str(msg.id), - "toolCallId": tool_call_id, - "toolName": tool_name, - "stage": self._stage, - } - await self._emit("tool.start", payload) - await self._emit( - "tool.args", - { - **payload, - "args": block.get("input", {}), - }, - ) - await self._emit("tool.end", payload) - self._emitted_tool_calls.add(tool_call_id) - - for block in msg.get_content_blocks("tool_result"): - tool_call_id = str(block.get("id", "")).strip() - if not tool_call_id or tool_call_id in self._emitted_tool_results: - continue - tool_output = _parse_tool_agent_output(block.get("output")) - if tool_output is None: - continue - await self._emit( - "tool.result", - { - "messageId": str(msg.id), - "toolCallId": tool_call_id, - "toolName": tool_output.tool_name, - "stage": self._stage, - "toolAgentOutput": tool_output.model_dump( - mode="json", exclude_none=True - ), - }, - ) - self._emitted_tool_results.add(tool_call_id) - - async def emit_final_text_end( - self, - *, - worker_output: dict[str, Any], - response_metadata: dict[str, Any], - ) -> None: - message_id = ( - self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}" - ) - if self.latest_text_message_id is None and worker_output.get("answer"): - await self._emit( - "text.start", - { - "messageId": message_id, - "role": "assistant", - "stage": self._stage, - }, - ) - await self._emit( - "text.delta", - { - "messageId": message_id, - "delta": worker_output.get("answer", ""), - "stage": self._stage, - }, - ) - await self._emit( - "text.end", - { - "messageId": message_id, - "role": "assistant", - "stage": self._stage, - "workerAgentOutput": worker_output, - **response_metadata, - }, - ) - - async def _emit(self, event_type: str, data: dict[str, Any]) -> None: - await self._pipeline.emit( - session_id=self._session_id, - event={ - "type": event_type, - "threadId": self._session_id, - "runId": self._run_id, - "data": data, - }, - ) - - -class _PipelineReActAgent(ReActAgent): - def __init__( - self, *, emitter: _PipelineStageEmitter | None = None, **kwargs: Any - ) -> None: - super().__init__(**kwargs) - self._pipeline_emitter = emitter - self.disable_console_output() - - async def print(self, msg: Msg, last: bool = True, speech: Any = None) -> None: - del speech - if self._pipeline_emitter is not None: - await self._pipeline_emitter.handle_print(msg=msg, last=last) - - -def _parse_tool_agent_output(output: Any) -> ToolAgentOutput | None: - blocks = output if isinstance(output, Sequence) else [] - for block in blocks: - if not isinstance(block, dict) or block.get("type") != "text": - continue - text = block.get("text") - if not isinstance(text, str) or not text.strip(): - continue - try: - return ToolAgentOutput.model_validate(json.loads(text)) - except Exception: - return None - return None - - -def _normalize_tool_name(value: str) -> str: - return value.strip().replace(".", "_").replace("-", "_") - - -class AgentScopeReActRunner: - def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None: - self._litellm_service = litellm_service or LiteLLMService() - - async def execute( - self, - *, - user_context: UserContext, - context_messages: list[Msg], - pipeline: PipelineLike, - run_input: RunAgentInput, - ) -> dict[str, Any]: - owner_id = UUID(user_context.id) - enabled_tool_names = self._extract_tool_names(run_input) - - async with AsyncSessionLocal() as session: - router_toolkit, worker_toolkit = self._build_toolkits( - session=session, - owner_id=owner_id, - enabled_tool_names=enabled_tool_names, - ) - - router_config = await self._load_system_agent_config( - session=session, - agent_type=AgentType.ROUTER, - ) - worker_config = await self._load_system_agent_config( - session=session, - agent_type=AgentType.WORKER, - ) - - await self._emit_step_event( - pipeline=pipeline, - run_input=run_input, - step_name="router", - event_type="step.start", - ) - router_result = await self._run_router_stage( - user_context=user_context, - context_messages=context_messages, - toolkit=router_toolkit, - run_input=run_input, - stage_config=router_config, - ) - router_output = RouterAgentOutput.model_validate(router_result.payload) - await self._persist_router_message( - session=session, - thread_id=run_input.thread_id, - run_id=run_input.run_id, - model_code=router_config.model_code, - router_output=router_output, - response_metadata=router_result.response_metadata, - ) - await session.commit() - await self._emit_step_event( - pipeline=pipeline, - run_input=run_input, - step_name="router", - event_type="step.finish", - ) - - worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode) - await self._emit_step_event( - pipeline=pipeline, - run_input=run_input, - step_name="worker", - event_type="step.start", - ) - worker_result = await self._run_worker_stage( - user_context=user_context, - router_output=router_output, - toolkit=worker_toolkit, - run_input=run_input, - stage_config=worker_config, - worker_output_model=worker_output_model, - pipeline=pipeline, - ) - worker_output = worker_output_model.model_validate(worker_result.payload) - await self._emit_step_event( - pipeline=pipeline, - run_input=run_input, - step_name="worker", - event_type="step.finish", - ) - - return { - "router": router_output.model_dump(mode="json", exclude_none=True), - "worker": worker_output.model_dump(mode="json", exclude_none=True), - } - - def _build_toolkits( - self, - *, - session: AsyncSession, - owner_id: UUID, - enabled_tool_names: set[str] | None, - ) -> tuple[Any, Any]: - return ( - build_stage_toolkit( - agent_type=AgentType.ROUTER, - session=session, - owner_id=owner_id, - enabled_tool_names=enabled_tool_names, - ), - build_stage_toolkit( - agent_type=AgentType.WORKER, - session=session, - owner_id=owner_id, - enabled_tool_names=enabled_tool_names, - ), - ) - - def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None: - raw_tools = getattr(run_input, "tools", None) - if not isinstance(raw_tools, list): - return None - selected: set[str] = set() - for item in raw_tools: - if isinstance(item, dict): - name = item.get("name") - else: - name = getattr(item, "name", None) - if isinstance(name, str) and name.strip(): - selected.add(_normalize_tool_name(name)) - return selected - - async def _load_system_agent_config( - self, - *, - session: AsyncSession, - agent_type: AgentType, - ) -> SystemAgentRuntimeConfig: - stmt = ( - select(SystemAgents, Llm) - .join(Llm, SystemAgents.llm_id == Llm.id) - .where(SystemAgents.agent_type == agent_type.value) - ) - row = (await session.execute(stmt)).one_or_none() - if row is None: - raise RuntimeError(f"system agent config not found: {agent_type.value}") - system_agent, llm = row - status = str(system_agent.status).strip().lower() - if status != "active": - raise RuntimeError(f"system agent is not active: {agent_type.value}") - return SystemAgentRuntimeConfig( - agent_type=agent_type, - model_code=llm.model_code, - llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}), - ) - - async def _run_router_stage( - self, - *, - user_context: UserContext, - context_messages: list[Msg], - toolkit: Any, - run_input: RunAgentInput, - stage_config: SystemAgentRuntimeConfig, - ) -> StageExecutionResult: - tracking_model = self._build_model(stage_config=stage_config) - agent = self._build_agent( - agent_name="router", - system_prompt=build_system_prompt( - agent_type=AgentType.ROUTER, - user_context=user_context, - now_utc=datetime.now(timezone.utc), - tools=run_input.tools, - ), - toolkit=toolkit, - model=tracking_model, - ) - response_msg = await agent.reply( - context_messages, structured_model=RouterAgentOutput - ) - payload = RouterAgentOutput.model_validate( - response_msg.metadata or {} - ).model_dump( - mode="json", - exclude_none=True, - ) - return StageExecutionResult( - message=response_msg, - payload=payload, - response_metadata=self._litellm_service.build_usage_metadata( - model=stage_config.model_code, - usage_summary=tracking_model.usage_summary(), - ), - ) - - async def _run_worker_stage( - self, - *, - user_context: UserContext, - router_output: RouterAgentOutput, - toolkit: Any, - run_input: RunAgentInput, - stage_config: SystemAgentRuntimeConfig, - worker_output_model: type[WorkerAgentOutputLite], - pipeline: PipelineLike, - ) -> StageExecutionResult: - worker_input = self._build_worker_input_messages( - router_output=router_output, - ) - tracking_model = self._build_model(stage_config=stage_config) - emitter = _PipelineStageEmitter( - pipeline=pipeline, - session_id=run_input.thread_id, - run_id=run_input.run_id, - stage="worker", - emit_text_events=True, - emit_tool_events=True, - ) - agent = self._build_agent( - agent_name="worker", - system_prompt=build_system_prompt( - agent_type=AgentType.WORKER, - user_context=user_context, - now_utc=datetime.now(timezone.utc), - tools=run_input.tools, - ), - toolkit=toolkit, - model=tracking_model, - emitter=emitter, - ) - response_msg = await agent.reply( - worker_input, - structured_model=worker_output_model, - ) - worker_payload = worker_output_model.model_validate(response_msg.metadata or {}) - response_metadata = self._litellm_service.build_usage_metadata( - model=stage_config.model_code, - usage_summary=tracking_model.usage_summary(), - ) - await emitter.emit_final_text_end( - worker_output=worker_payload.model_dump(mode="json", exclude_none=True), - response_metadata=response_metadata, - ) - return StageExecutionResult( - message=response_msg, - payload=worker_payload.model_dump(mode="json", exclude_none=True), - response_metadata=response_metadata, - ) - - def _build_worker_input_messages( - self, - *, - router_output: RouterAgentOutput, - ) -> list[Msg]: - routing_contract = json.dumps( - router_output.model_dump(mode="json", exclude_none=True), - ensure_ascii=False, - separators=(",", ":"), - ) - routing_msg = Msg( - name="router", - role="user", - content=( - "Use the following routing contract as the execution source of truth. " - f"Do not change the routed objective:\n{routing_contract}" - ), - ) - return [routing_msg] - - def _build_model( - self, *, stage_config: SystemAgentRuntimeConfig - ) -> _TrackingChatModel: - model = OpenAIChatModel( - model_name=stage_config.model_code, - api_key=self._litellm_service.proxy_api_key, - stream=True, - client_kwargs={"base_url": self._litellm_service.proxy_base_url}, - generate_kwargs={ - "temperature": stage_config.llm_config.temperature, - "max_tokens": stage_config.llm_config.max_tokens, - "timeout": stage_config.llm_config.timeout_seconds, - }, - ) - return _TrackingChatModel(model) - - def _build_agent( - self, - *, - agent_name: str, - system_prompt: str, - toolkit: Any, - model: _TrackingChatModel, - emitter: _PipelineStageEmitter | None = None, - ) -> _PipelineReActAgent: - return _PipelineReActAgent( - name=agent_name, - sys_prompt=system_prompt, - model=model, - formatter=OpenAIChatFormatter(), - toolkit=toolkit, - memory=InMemoryMemory(), - emitter=emitter, - ) - - async def _emit_step_event( - self, - *, - pipeline: PipelineLike, - run_input: RunAgentInput, - step_name: str, - event_type: str, - ) -> None: - await pipeline.emit( - session_id=run_input.thread_id, - event={ - "type": event_type, - "threadId": run_input.thread_id, - "runId": run_input.run_id, - "data": {"stepName": step_name}, - }, - ) - - async def _persist_router_message( - self, - *, - session: AsyncSession, - thread_id: str, - run_id: str, - model_code: str, - router_output: RouterAgentOutput, - response_metadata: dict[str, Any], - ) -> None: - session_id = UUID(thread_id) - message_repo = MessageRepository(session) - session_repo = SessionRepository(session) - locked_session = await session_repo.lock_session_for_update( - session_id=session_id - ) - if locked_session is None: - raise RuntimeError("chat session not found for router persistence") - seq = int(getattr(locked_session, "message_count", 0) or 0) + 1 - metadata = AgentChatMessageMetadata( - run_id=run_id, - agent_type=AgentType.ROUTER, - router_agent_output=router_output, - ) - message_payload = AgentChatMessage( - id=uuid4(), - seq=seq, - role=AgentChatMessageRole.ASSISTANT.value, - content="", - model_code=model_code, - tool_name=None, - input_tokens=int(response_metadata.get("inputTokens", 0) or 0), - output_tokens=int(response_metadata.get("outputTokens", 0) or 0), - cost=Decimal(str(response_metadata.get("cost", 0) or 0)), - latency_ms=int(response_metadata.get("latencyMs", 0) or 0), - metadata=metadata, - timestamp=datetime.now(timezone.utc), - ) - await message_repo.append_message( - session_id=session_id, - seq=message_payload.seq, - role=AgentChatMessageRole.ASSISTANT, - content=message_payload.content, - model_code=message_payload.model_code, - tool_name=message_payload.tool_name, - metadata=metadata.model_dump(mode="json", exclude_none=True), - input_tokens=message_payload.input_tokens, - output_tokens=message_payload.output_tokens, - cost=message_payload.cost, - latency_ms=message_payload.latency_ms, - ) - await session_repo.update_runtime_state( - chat_session=locked_session, - status=AgentChatSessionStatus.RUNNING, - state_snapshot=locked_session.state_snapshot or {}, - message_delta=1, - token_delta=message_payload.input_tokens + message_payload.output_tokens, - cost_delta=message_payload.cost, - ) - await session.flush() diff --git a/backend/tests/unit/core/agentscope/runtime/test_react_runner.py b/backend/tests/unit/core/agentscope/runtime/test_react_runner.py deleted file mode 100644 index f677ebe..0000000 --- a/backend/tests/unit/core/agentscope/runtime/test_react_runner.py +++ /dev/null @@ -1,201 +0,0 @@ -from __future__ import annotations - -import pytest -from ag_ui.core import RunAgentInput -from agentscope.message import Msg - -from core.agentscope.runtime.react_runner import ( - AgentScopeReActRunner, - StageExecutionResult, - SystemAgentRuntimeConfig, -) -from schemas.agent.runtime_models import ( - RouterAgentOutput, - UiMode, - WorkerAgentOutputRich, -) -from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig -from schemas.user.context import UserContext, parse_profile_settings - - -class _FakePipeline: - def __init__(self) -> None: - self.events: list[dict[str, object]] = [] - - async def emit(self, *, session_id: str, event: dict[str, object]) -> str: - self.events.append({"session_id": session_id, "event": event}) - return "1-0" - - -class _FakeSessionCtx: - def __init__(self, session: object) -> None: - self._session = session - - async def __aenter__(self) -> object: - return self._session - - async def __aexit__(self, exc_type, exc, tb) -> None: - del exc_type, exc, tb - - -def _user_context() -> UserContext: - return UserContext( - id="00000000-0000-0000-0000-000000000001", - username="alice", - email="alice@example.com", - settings=parse_profile_settings(None), - ) - - -def _run_input() -> RunAgentInput: - return RunAgentInput.model_validate( - { - "threadId": "00000000-0000-0000-0000-000000000010", - "runId": "run-1", - "state": {}, - "messages": [{"id": "u1", "role": "user", "content": "hello"}], - "tools": [ - { - "name": "calendar.read", - "description": "read", - "parameters": {"type": "object"}, - }, - { - "name": "calendar-write", - "description": "write", - "parameters": {"type": "object"}, - }, - ], - "context": [], - "forwardedProps": {}, - } - ) - - -def _router_output(*, ui_mode: UiMode) -> RouterAgentOutput: - return RouterAgentOutput.model_validate( - { - "normalized_task_input": { - "user_text": "hello", - "multimodal_summary": [], - }, - "key_entities": [], - "constraints": [], - "task_typing": {"primary": "knowledge", "secondary": []}, - "execution_mode": "onestep", - "result_typing": {"primary": "direct_answer", "secondary": []}, - "ui": { - "ui_mode": ui_mode.value, - "ui_decision_reason": "need structure" - if ui_mode == UiMode.RICH - else "plain text", - }, - } - ) - - -@pytest.mark.asyncio -async def test_execute_uses_router_ui_mode_to_select_worker_output_model( - monkeypatch: pytest.MonkeyPatch, -) -> None: - runner = AgentScopeReActRunner() - pipeline = _FakePipeline() - worker_model_holder: dict[str, type[object]] = {} - - class _CommitSession: - async def commit(self) -> None: - return None - - monkeypatch.setattr( - "core.agentscope.runtime.react_runner.AsyncSessionLocal", - lambda: _FakeSessionCtx(_CommitSession()), - ) - monkeypatch.setattr( - runner, - "_build_toolkits", - lambda **kwargs: ("router-toolkit", "worker-toolkit"), - ) - - async def _load_system_agent_config(**kwargs): - return SystemAgentRuntimeConfig( - agent_type=kwargs["agent_type"], - model_code="qwen3.5-flash" - if kwargs["agent_type"] == AgentType.ROUTER - else "deepseek-chat", - llm_config=SystemAgentLLMConfig( - temperature=0.1, max_tokens=256, timeout_seconds=30 - ), - ) - - monkeypatch.setattr(runner, "_load_system_agent_config", _load_system_agent_config) - - async def _run_router_stage(**kwargs): - return StageExecutionResult( - message=Msg(name="router", content="", role="assistant"), - payload=_router_output(ui_mode=UiMode.RICH).model_dump(mode="json"), - response_metadata={ - "model": "qwen3.5-flash", - "inputTokens": 12, - "outputTokens": 6, - "cost": 0.001, - "latencyMs": 50, - }, - ) - - monkeypatch.setattr(runner, "_run_router_stage", _run_router_stage) - - async def _persist_router_message(**kwargs) -> None: - assert kwargs["model_code"] == "qwen3.5-flash" - - monkeypatch.setattr(runner, "_persist_router_message", _persist_router_message) - - async def _run_worker_stage(**kwargs): - worker_model_holder["model"] = kwargs["worker_output_model"] - return StageExecutionResult( - message=Msg(name="worker", content="done", role="assistant"), - payload=WorkerAgentOutputRich.model_validate( - { - "status": "success", - "answer": "done", - "key_points": [], - "result_type": "direct_answer", - "suggested_actions": [], - "error": None, - "ui_hints": None, - } - ).model_dump(mode="json", exclude_none=True), - response_metadata={ - "model": "deepseek-chat", - "inputTokens": 8, - "outputTokens": 4, - "cost": 0.002, - "latencyMs": 40, - }, - ) - - monkeypatch.setattr(runner, "_run_worker_stage", _run_worker_stage) - - result = await runner.execute( - user_context=_user_context(), - context_messages=[], - pipeline=pipeline, - run_input=_run_input(), - ) - - assert worker_model_holder["model"].__name__ == "WorkerAgentOutputRich" - event_types = [] - for item in pipeline.events: - event = item.get("event") - if isinstance(event, dict): - event_types.append(event.get("type")) - assert event_types == ["step.start", "step.finish", "step.start", "step.finish"] - assert result["router"]["ui"]["ui_mode"] == "rich" - assert result["worker"]["answer"] == "done" - - -def test_extract_tool_names_normalizes_client_tool_names() -> None: - runner = AgentScopeReActRunner() - - names = runner._extract_tool_names(_run_input()) - - assert names == {"calendar_read", "calendar_write"}