"""AgentScope runner that executes the router -> worker agent pipeline.

The router stage produces a ``RouterAgentOutput``; the worker stage then runs a
``JsonReActAgent`` (with a toolkit) against a contract prompt built from that
output. ``STEP_STARTED`` / ``STEP_FINISHED`` events are emitted to the pipeline
around each stage.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import UUID

from ag_ui.core.types import RunAgentInput
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.tools.tool_config import AgentTool
from core.agentscope.tools.toolkit import build_toolkit
from core.agentscope.utils import (
    finalize_json_response,
    patch_agentscope_json_repair_compat,
)
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
from schemas.agent.forwarded_props import (
    ClientTimeContext,
    parse_forwarded_props_client_time,
)
from schemas.agent.runtime_models import (
    RouterAgentOutput,
    WorkerAgentOutputLite,
    resolve_worker_output_model,
)
from schemas.agent.system_agent import (
    AgentType,
    SystemAgentLLMConfig,
)
from schemas.automation import RuntimeConfig
from schemas.memories import MemoryListResponse
from schemas.user import UserContext
from services.litellm.service import LiteLLMService
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

if TYPE_CHECKING:
    from core.agentscope.runtime.orchestrator import PipelineLike


@dataclass(frozen=True)
class StageExecutionResult:
    """Outcome of a single agent stage (router or worker)."""

    # Final assistant message produced by the stage.
    message: Msg
    # Structured JSON payload of the stage output.
    payload: dict[str, Any]
    # Usage metadata built by ``LiteLLMService.build_usage_metadata``.
    response_metadata: dict[str, Any]


@dataclass(frozen=True)
class SystemAgentRuntimeConfig:
    """Resolved model/provider settings for one system agent stage."""

    # Which stage this config belongs to (router or worker).
    agent_type: AgentType
    # Model identifier sent to the OpenAI-compatible endpoint.
    model_code: str
    # Base URL of the provider endpoint (``LlmFactory.request_url``).
    api_base_url: str
    # Provider API key resolved from ``config.llm.provider_keys``.
    api_key: str
    # Sampling / limit settings validated from ``SystemAgents.config``.
    llm_config: SystemAgentLLMConfig
    # Optional extra text passed to ``build_system_prompt`` for the worker.
    extra_context: str | None = None


class AgentScopeRunner:
    """Run the router stage followed by the worker stage for one AG-UI run."""

    def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
        """Create a runner.

        Args:
            litellm_service: Service used to build usage metadata. A default
                ``LiteLLMService`` is created when omitted.
        """
        # Apply the agentscope JSON-repair compatibility patch (see
        # core.agentscope.utils) before this runner makes any model calls.
        patch_agentscope_json_repair_compat()
        self._litellm_service: LiteLLMService = litellm_service or LiteLLMService()

    async def execute(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        runtime_config: RuntimeConfig,
        memories: MemoryListResponse | None = None,
    ) -> dict[str, Any]:
        """Run the router stage, then the worker stage, and return both outputs.

        Args:
            user_context: The requesting user; ``user_context.id`` must be a
                UUID string.
            context_messages: Conversation messages passed to the router.
            pipeline: Event sink that receives step and stage events.
            run_input: The AG-UI run input (thread/run ids, messages, props).
            runtime_config: Provides the list of tools enabled for the worker.
            memories: Optional user memories injected into system prompts.

        Returns:
            ``{"router": ..., "worker": ...}``, each a JSON-mode dump of the
            stage output with ``None`` fields excluded.

        Raises:
            ValueError: If ``user_context.id`` is not a valid UUID.
            RuntimeError: If a system agent config is missing or inactive, or
                the provider API key is missing.
        """
        owner_id = UUID(user_context.id)
        runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)

        # The worker toolkit is bound to this session, so both stages must run
        # while the session is still open.
        async with AsyncSessionLocal() as session:
            router_config = await self._load_stage_config(
                session=session,
                agent_type=AgentType.ROUTER,
            )
            worker_config = await self._load_stage_config(
                session=session,
                agent_type=AgentType.WORKER,
            )
            worker_toolkit = self._build_toolkit(
                session=session,
                owner_id=owner_id,
                enabled_tools=runtime_config.enabled_tools,
            )

            router_output = await self._execute_router_step(
                pipeline=pipeline,
                run_input=run_input,
                user_context=user_context,
                context_messages=context_messages,
                stage_config=router_config,
                runtime_client_time=runtime_client_time,
                memories=memories,
            )
            worker_output = await self._execute_worker_step(
                pipeline=pipeline,
                run_input=run_input,
                user_context=user_context,
                router_output=router_output,
                toolkit=worker_toolkit,
                stage_config=worker_config,
                runtime_client_time=runtime_client_time,
                memories=memories,
            )

        return {
            "router": router_output.model_dump(mode="json", exclude_none=True),
            "worker": worker_output.model_dump(mode="json", exclude_none=True),
        }

    def _build_toolkit(
        self,
        *,
        session: AsyncSession,
        owner_id: UUID,
        enabled_tools: list[AgentTool],
    ) -> Any:
        """Build the worker toolkit restricted to ``enabled_tools``.

        An empty or missing ``enabled_tools`` passes ``enabled_tool_names=None``
        to ``build_toolkit``. NOTE(review): presumably ``None`` means "no
        filter" there; confirm in ``build_toolkit``.
        """
        enabled_tool_names = {tool.value for tool in enabled_tools or ()}
        return build_toolkit(
            session=session,
            owner_id=owner_id,
            enabled_tool_names=enabled_tool_names or None,
        )

    async def _load_stage_config(
        self,
        *,
        session: AsyncSession,
        agent_type: AgentType,
    ) -> SystemAgentRuntimeConfig:
        """Load the system agent row for ``agent_type`` with its LLM and factory.

        Raises:
            RuntimeError: If no config row exists, the agent is not active, or
                the provider API key is missing.
        """
        stmt = (
            select(SystemAgents, Llm, LlmFactory)
            .join(Llm, SystemAgents.llm_id == Llm.id)
            .join(LlmFactory, Llm.factory_id == LlmFactory.id)
            .where(SystemAgents.agent_type == agent_type.value)
        )
        row = (await session.execute(stmt)).one_or_none()
        if row is None:
            raise RuntimeError(f"system agent config not found: {agent_type.value}")

        system_agent, llm, factory = row
        # Use the enum's value when status is an Enum: ``str()`` of a plain Enum
        # member yields "Cls.MEMBER", which would never equal "active".
        raw_status = getattr(system_agent.status, "value", system_agent.status)
        status = str(raw_status).strip().lower()
        if status != "active":
            raise RuntimeError(f"system agent is not active: {agent_type.value}")

        return SystemAgentRuntimeConfig(
            agent_type=agent_type,
            model_code=llm.model_code,
            api_base_url=factory.request_url,
            api_key=self._resolve_provider_api_key(factory_name=factory.name),
            llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
            extra_context=None,
        )

    async def _execute_router_step(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        user_context: UserContext,
        context_messages: list[Msg],
        stage_config: SystemAgentRuntimeConfig,
        runtime_client_time: ClientTimeContext | None,
        memories: MemoryListResponse | None,
    ) -> RouterAgentOutput:
        """Run the router stage wrapped in STEP_STARTED/STEP_FINISHED events.

        STEP_FINISHED is only emitted when the stage completes without raising.
        """
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.ROUTER.value,
            event_type="STEP_STARTED",
        )
        router_result = await self._run_router_stage(
            user_context=user_context,
            context_messages=context_messages,
            stage_config=stage_config,
            runtime_client_time=runtime_client_time,
            memories=memories,
            run_input=run_input,
        )
        router_output = RouterAgentOutput.model_validate(router_result.payload)
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.ROUTER.value,
            event_type="STEP_FINISHED",
        )
        return router_output

    async def _execute_worker_step(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        user_context: UserContext,
        router_output: RouterAgentOutput,
        toolkit: Any,
        stage_config: SystemAgentRuntimeConfig,
        runtime_client_time: ClientTimeContext | None,
        memories: MemoryListResponse | None,
    ) -> WorkerAgentOutputLite:
        """Run the worker stage wrapped in STEP_STARTED/STEP_FINISHED events.

        The worker's output model is chosen from the router's ``ui.ui_mode``.
        """
        worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.WORKER.value,
            event_type="STEP_STARTED",
        )
        worker_result = await self._run_worker_stage(
            user_context=user_context,
            input_messages=self._build_worker_input_messages(
                router_output=router_output
            ),
            toolkit=toolkit,
            run_input=run_input,
            stage_config=stage_config,
            worker_output_model=worker_output_model,
            pipeline=pipeline,
            runtime_client_time=runtime_client_time,
            memories=memories,
        )
        worker_output = worker_output_model.model_validate(worker_result.payload)
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.WORKER.value,
            event_type="STEP_FINISHED",
        )
        return worker_output

    async def _run_router_stage(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        stage_config: SystemAgentRuntimeConfig,
        runtime_client_time: ClientTimeContext | None,
        memories: MemoryListResponse | None,
        run_input: RunAgentInput,
    ) -> StageExecutionResult:
        """Ask the router model for a ``RouterAgentOutput`` JSON payload.

        No tools are given to the router, and ``finalize_json_response`` is
        called with ``retries=0``.
        """
        messages_for_router = self._build_router_messages(
            context_messages=context_messages,
            run_input=run_input,
        )
        tracking_model = self._build_model(stage_config=stage_config)
        system_msg = Msg(
            "system",
            build_system_prompt(
                agent_type=AgentType.ROUTER,
                llm_config=stage_config.llm_config,
                user_context=user_context,
                now_utc=datetime.now(timezone.utc),
                runtime_client_time=runtime_client_time,
                tools=None,
                memories=memories,
            ),
            "system",
        )
        response, payload = await finalize_json_response(
            model=tracking_model,
            formatter=OpenAIChatFormatter(),
            base_messages=[system_msg, *messages_for_router],
            output_model=RouterAgentOutput,
            retries=0,
        )

        # ``content`` may be present but ``None``; treat that as no content
        # rather than failing inside ``list()``.
        response_content = getattr(response, "content", None) or []
        response_msg = Msg(
            name="router",
            role="assistant",
            content=list(response_content),
            metadata=payload,
        )
        return StageExecutionResult(
            message=response_msg,
            payload=payload,
            response_metadata=self._litellm_service.build_usage_metadata(
                model=stage_config.model_code,
                usage_summary=tracking_model.usage_summary(),
            ),
        )

    def _build_router_messages(
        self,
        *,
        context_messages: list[Msg],
        run_input: RunAgentInput,
    ) -> list[Msg]:
        """Return the conversation for the router, ending with the user turn.

        If ``context_messages`` already ends with a user message it is returned
        unchanged. Otherwise, the latest user payload from ``run_input`` is
        appended after the history. Text-only payloads are sent as a plain
        string; payloads containing any non-text block are sent as blocks so
        that images or files are not dropped.
        """
        history = list(context_messages or [])
        if history and history[-1].role == "user":
            return context_messages

        user_text, user_blocks = extract_latest_user_payload(run_input)
        text_only = all(
            isinstance(block, dict) and block.get("type") == "text"
            for block in user_blocks or []
        )
        content: Any = user_text if text_only else user_blocks
        user_msg = Msg(name="user", role="user", content=content)
        # The latest user turn goes last so the model responds to it.
        return [*history, user_msg]

    async def _run_worker_stage(
        self,
        *,
        user_context: UserContext,
        input_messages: list[Msg],
        toolkit: Any,
        run_input: RunAgentInput,
        stage_config: SystemAgentRuntimeConfig,
        worker_output_model: type[WorkerAgentOutputLite],
        pipeline: PipelineLike,
        runtime_client_time: ClientTimeContext | None,
        memories: MemoryListResponse | None,
    ) -> StageExecutionResult:
        """Run the worker ReAct agent and emit its final text-end event.

        The agent's reply metadata is validated against ``worker_output_model``.
        Text and tool events are streamed to ``pipeline`` via a
        ``PipelineStageEmitter``.
        """
        tracking_model = self._build_model(stage_config=stage_config)
        emitter = PipelineStageEmitter(
            pipeline=pipeline,
            session_id=run_input.thread_id,
            run_id=run_input.run_id,
            stage=stage_config.agent_type.value,
            emit_text_events=True,
            emit_tool_events=True,
        )
        agent = self._build_agent(
            agent_name=stage_config.agent_type.value,
            system_prompt=build_system_prompt(
                agent_type=stage_config.agent_type,
                llm_config=stage_config.llm_config,
                user_context=user_context,
                now_utc=datetime.now(timezone.utc),
                runtime_client_time=runtime_client_time,
                extra_context=stage_config.extra_context,
                tools=None,
                memories=memories,
            ),
            toolkit=toolkit,
            model=tracking_model,
            emitter=emitter,
        )

        response_msg = await agent.reply_json(
            input_messages, output_model=worker_output_model
        )
        worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
        response_metadata = self._litellm_service.build_usage_metadata(
            model=stage_config.model_code,
            usage_summary=tracking_model.usage_summary(),
        )
        # Separate dumps are kept so the emitter and the result never share
        # (and potentially mutate) the same dict.
        await emitter.emit_final_text_end(
            worker_output=worker_payload.model_dump(mode="json", exclude_none=True),
            response_metadata=response_metadata,
        )
        return StageExecutionResult(
            message=response_msg,
            payload=worker_payload.model_dump(mode="json", exclude_none=True),
            response_metadata=response_metadata,
        )

    def _build_worker_input_messages(
        self,
        *,
        router_output: RouterAgentOutput,
    ) -> list[Msg]:
        """Wrap the worker contract prompt as a single user-role message."""
        return [
            Msg(
                name=AgentType.ROUTER.value,
                role="user",
                content=build_worker_contract_prompt(router_output=router_output),
            )
        ]

    def _build_model(
        self, *, stage_config: SystemAgentRuntimeConfig
    ) -> TrackingChatModel:
        """Build a non-streaming OpenAI-compatible model wrapped for usage tracking."""
        generate_kwargs: dict[str, Any] = {
            "temperature": stage_config.llm_config.temperature,
            "max_tokens": stage_config.llm_config.max_tokens,
            "timeout": stage_config.llm_config.timeout_seconds,
            # Sent verbatim in the request body; disables provider "thinking".
            "extra_body": {"enable_thinking": False},
        }
        model = OpenAIChatModel(
            model_name=stage_config.model_code,
            api_key=stage_config.api_key,
            stream=False,
            client_kwargs={"base_url": stage_config.api_base_url},
            generate_kwargs=generate_kwargs,
        )
        return TrackingChatModel(model)

    def _build_agent(
        self,
        *,
        agent_name: str,
        system_prompt: str,
        toolkit: Any,
        model: TrackingChatModel,
        emitter: PipelineStageEmitter | None = None,
    ) -> JsonReActAgent:
        """Create a ``JsonReActAgent`` with a fresh in-memory memory."""
        return JsonReActAgent(
            name=agent_name,
            sys_prompt=system_prompt,
            model=model,
            formatter=OpenAIChatFormatter(),
            toolkit=toolkit,
            memory=InMemoryMemory(),
            emitter=emitter,
        )

    async def _emit_step_event(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        step_name: str,
        event_type: str,
    ) -> None:
        """Emit an AG-UI step event (e.g. STEP_STARTED) for the current run."""
        await pipeline.emit(
            session_id=run_input.thread_id,
            event={
                "type": event_type,
                "threadId": run_input.thread_id,
                "runId": run_input.run_id,
                "stepName": step_name,
            },
        )

    def _resolve_runtime_client_time(
        self, *, run_input: RunAgentInput
    ) -> ClientTimeContext | None:
        """Parse client time info from ``run_input.forwarded_props``, if any."""
        return parse_forwarded_props_client_time(
            getattr(run_input, "forwarded_props", None)
        )

    @staticmethod
    def _resolve_provider_api_key(*, factory_name: str) -> str:
        """Look up the API key for ``factory_name`` in ``config.llm.provider_keys``.

        Lookup is case-insensitive and whitespace-trimmed; entries with blank
        values are ignored. The factory name "VOLCENGINE" is looked up as "ARK".

        Raises:
            RuntimeError: If no non-blank key exists for the factory.
        """
        normalized_factory_name = factory_name.strip().upper()
        if normalized_factory_name == "VOLCENGINE":
            normalized_factory_name = "ARK"

        provider_keys = {
            str(key).strip().upper(): str(value).strip()
            for key, value in config.llm.provider_keys.items()
            if str(value).strip()
        }
        api_key = provider_keys.get(normalized_factory_name, "")
        if not api_key:
            raise RuntimeError(f"provider api key missing for factory: {factory_name}")
        return api_key


# Backward-compatible alias for callers that import the older name.
AgentScopeReActRunner = AgentScopeRunner