"""Two-stage (router -> worker) AgentScope runner with cooperative cancellation."""

from __future__ import annotations

import asyncio
import contextlib
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Awaitable, Callable

from ag_ui.core.types import RunAgentInput
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.tools.toolkit import build_toolkit
from core.agentscope.utils import (
    finalize_json_response,
    patch_agentscope_json_repair_compat,
)
from core.auth.credential_issuer import create_credential_issuer
from core.auth.tool_credential_context import (
    set_tool_credential,
    reset_tool_credential,
)
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
from schemas.agent.forwarded_props import (
    ClientTimeContext,
    RuntimeMode,
    parse_forwarded_props_client_time,
    parse_forwarded_props_runtime_mode,
)
from schemas.agent.runtime_models import (
    RouterAgentOutput,
    WorkerAgentOutputLite,
)
from schemas.agent.skill_config import ProjectCliCommand, SkillName
from schemas.agent.system_agent import (
    AgentType,
    SystemAgentLLMConfig,
)
from schemas.domain.automation import RuntimeConfig
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from schemas.shared.user import UserContext
from services.llm_pricing.service import LlmPricingService
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

if TYPE_CHECKING:
    from core.agentscope.runtime.orchestrator import PipelineLike

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class StageExecutionResult:
    """Outcome of running a single stage (router or worker)."""

    # Assistant message produced by the stage.
    message: Msg
    # Structured stage output as a plain dict.
    payload: dict[str, Any]
    # Usage metadata built by the pricing service for this stage's model calls.
    response_metadata: dict[str, Any]


class AgentScopeRunner:
    """Run the router stage followed by the worker stage for one agent run.

    NOTE(review): ``_active_agent`` is a single slot on the runner instance.
    If one runner is shared by concurrent ``execute()`` calls, a cancel of one
    run could interrupt the agent of another run. Confirm that a runner is
    used for one run at a time, or track the active agent per run.
    """

    # Seconds between two polls of the cancel checker in the watcher task.
    _CANCEL_POLL_INTERVAL_SECONDS = 0.2
    # Retry budget handed to ``finalize_json_response`` for the router stage.
    _ROUTER_JSON_RETRIES = 3

    def __init__(self, *, llm_pricing_service: LlmPricingService | None = None) -> None:
        """Create a runner.

        Args:
            llm_pricing_service: Service used to build usage metadata; a new
                ``LlmPricingService`` is created when omitted.
        """
        patch_agentscope_json_repair_compat()
        self._llm_pricing_service: LlmPricingService = (
            llm_pricing_service or LlmPricingService()
        )
        # Worker agent currently replying, so the cancel watcher can interrupt it.
        self._active_agent: JsonReActAgent | None = None
        self._active_agent_lock = asyncio.Lock()

    async def execute(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        runtime_config: RuntimeConfig,
        user_memory: UserMemoryContent | None = None,
        work_memory: WorkProfileContent | None = None,
        cancel_checker: Callable[[], Awaitable[bool]] | None = None,
    ) -> dict[str, Any]:
        """Run router then worker and return both outputs as JSON-ready dicts.

        Args:
            user_context: The user the run is executed for.
            context_messages: Prior conversation messages given to the router.
            pipeline: Sink that step/stage events are emitted to.
            run_input: AG-UI run input (thread id, run id, forwarded props).
            runtime_config: Enabled skills and allowed commands for the toolkit.
            user_memory: Optional user memory passed to the router prompt.
            work_memory: Optional work profile passed to the worker prompt.
            cancel_checker: Optional async callable returning ``True`` once the
                run should be canceled. It is polled in the background while
                the run is in flight and once more between the two stages.

        Returns:
            ``{"router": ..., "worker": ...}`` with each value produced by
            ``model_dump(mode="json", exclude_none=True)``.

        Raises:
            asyncio.CancelledError: If cancellation is requested.
            RuntimeError: If a stage config is missing or inactive, or the
                provider API key is not configured.
        """
        runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
        runtime_mode = self._resolve_runtime_mode(run_input=run_input)

        stop_cancel_watch = asyncio.Event()
        cancel_watch_task: asyncio.Task[None] | None = None
        run_task = asyncio.current_task()
        if cancel_checker is not None and run_task is not None:
            cancel_watch_task = asyncio.create_task(
                self._watch_cancel_signal(
                    cancel_checker=cancel_checker,
                    stop_signal=stop_cancel_watch,
                    run_task=run_task,
                )
            )

        try:
            # The DB session is only needed to load the two stage configs.
            async with AsyncSessionLocal() as session:
                router_config = await self._load_stage_config(
                    session=session,
                    agent_type=AgentType.ROUTER,
                )
                worker_config = await self._load_stage_config(
                    session=session,
                    agent_type=AgentType.WORKER,
                )

            worker_toolkit = self._build_toolkit(
                enabled_skills=runtime_config.enabled_skills,
                allowed_commands=runtime_config.allowed_commands,
            )

            router_output = await self._execute_router_step(
                pipeline=pipeline,
                run_input=run_input,
                user_context=user_context,
                context_messages=context_messages,
                stage_config=router_config,
                runtime_client_time=runtime_client_time,
                runtime_mode=runtime_mode,
                user_memory=user_memory,
            )

            # Explicit check between stages. A failing checker is handled the
            # same way as in the background watcher: logged and treated as
            # "not canceled" instead of aborting the run.
            if cancel_checker is not None and await self._is_cancel_requested(
                cancel_checker
            ):
                raise asyncio.CancelledError("run canceled by user")

            worker_output = await self._execute_worker_step(
                pipeline=pipeline,
                run_input=run_input,
                user_context=user_context,
                router_output=router_output,
                toolkit=worker_toolkit,
                stage_config=worker_config,
                runtime_client_time=runtime_client_time,
                runtime_mode=runtime_mode,
                work_memory=work_memory,
            )

            return {
                "router": router_output.model_dump(mode="json", exclude_none=True),
                "worker": worker_output.model_dump(mode="json", exclude_none=True),
            }
        finally:
            # Always stop the watcher, whether the run finished, failed or was
            # canceled, so no background task outlives the run.
            stop_cancel_watch.set()
            if cancel_watch_task is not None:
                cancel_watch_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await cancel_watch_task

    @staticmethod
    async def _is_cancel_requested(
        cancel_checker: Callable[[], Awaitable[bool]],
        *,
        log_level: int = logging.WARNING,
    ) -> bool:
        """Ask ``cancel_checker`` whether the run should stop (best effort).

        Any ``Exception`` raised by the checker is logged at ``log_level`` and
        reported as ``False`` so a broken checker never aborts a run.
        ``asyncio.CancelledError`` is not an ``Exception`` and propagates.
        """
        try:
            return bool(await cancel_checker())
        except Exception:
            logger.log(
                log_level,
                "cancel checker failed; treating run as not canceled",
                exc_info=True,
            )
            return False

    async def _watch_cancel_signal(
        self,
        *,
        cancel_checker: Callable[[], Awaitable[bool]],
        stop_signal: asyncio.Event,
        run_task: asyncio.Task[object],
    ) -> None:
        """Poll ``cancel_checker`` and cancel ``run_task`` once it returns true.

        Runs until ``stop_signal`` is set or a cancel request is seen. On a
        cancel request the currently active worker agent (if any) is
        interrupted first, then ``run_task`` is canceled.
        """
        while not stop_signal.is_set():
            # DEBUG level: this path runs several times per second, so a
            # persistently failing checker must not flood the log.
            should_cancel = await self._is_cancel_requested(
                cancel_checker, log_level=logging.DEBUG
            )
            if should_cancel:
                async with self._active_agent_lock:
                    active_agent = self._active_agent
                if active_agent is not None:
                    # Best effort: the task is canceled below even if the
                    # agent refuses to be interrupted.
                    try:
                        await active_agent.interrupt()
                    except Exception:
                        logger.warning(
                            "failed to interrupt active agent during cancel",
                            exc_info=True,
                        )
                if not run_task.done():
                    run_task.cancel("run canceled by user")
                return
            await asyncio.sleep(self._CANCEL_POLL_INTERVAL_SECONDS)

    def _build_toolkit(
        self,
        *,
        enabled_skills: list[SkillName],
        allowed_commands: list[ProjectCliCommand],
    ) -> Any:
        """Build the worker toolkit from enabled skills and allowed commands.

        Empty collections are passed to ``build_toolkit`` as ``None``.
        """
        enabled_skill_names = {str(skill.value) for skill in enabled_skills}
        allowed_command_names = {str(command.value) for command in allowed_commands}
        return build_toolkit(
            enabled_skill_names=enabled_skill_names if enabled_skill_names else None,
            allowed_commands=allowed_command_names if allowed_command_names else None,
        )

    async def _load_stage_config(
        self,
        *,
        session: AsyncSession,
        agent_type: AgentType,
    ) -> SystemAgentRuntimeConfig:
        """Load the runtime config of the system agent for ``agent_type``.

        Raises:
            RuntimeError: If no row exists, the agent's status is not
                ``"active"``, or the provider API key is missing.
        """
        stmt = (
            select(SystemAgents, Llm, LlmFactory)
            .join(Llm, SystemAgents.llm_id == Llm.id)
            .join(LlmFactory, Llm.factory_id == LlmFactory.id)
            .where(SystemAgents.agent_type == agent_type.value)
        )
        row = (await session.execute(stmt)).one_or_none()
        if row is None:
            raise RuntimeError(f"system agent config not found: {agent_type.value}")

        system_agent, llm, factory = row
        status = str(system_agent.status).strip().lower()
        if status != "active":
            raise RuntimeError(f"system agent is not active: {agent_type.value}")

        return SystemAgentRuntimeConfig(
            agent_type=agent_type,
            model_code=llm.model_code,
            api_base_url=factory.request_url,
            api_key=self._resolve_provider_api_key(factory_name=factory.name),
            llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
            extra_context=None,
        )

    async def _execute_router_step(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        user_context: UserContext,
        context_messages: list[Msg],
        stage_config: SystemAgentRuntimeConfig,
        runtime_client_time: ClientTimeContext | None,
        runtime_mode: RuntimeMode,
        user_memory: UserMemoryContent | None,
    ) -> RouterAgentOutput:
        """Run the router stage wrapped in STEP_STARTED / STEP_FINISHED events.

        The STEP_FINISHED event additionally carries a ``_router_persist``
        entry with the router output and its usage metadata.
        """
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.ROUTER.value,
            event_type="STEP_STARTED",
            runtime_mode=runtime_mode,
        )
        router_result = await self._run_router_stage(
            user_context=user_context,
            context_messages=context_messages,
            stage_config=stage_config,
            runtime_client_time=runtime_client_time,
            user_memory=user_memory,
            run_input=run_input,
        )
        router_output = RouterAgentOutput.model_validate(router_result.payload)
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.ROUTER.value,
            event_type="STEP_FINISHED",
            runtime_mode=runtime_mode,
            extra_event={
                "_router_persist": {
                    "router_output": router_output.model_dump(
                        mode="json", exclude_none=True
                    ),
                    "response_metadata": router_result.response_metadata,
                }
            },
        )
        return router_output

    async def _execute_worker_step(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        user_context: UserContext,
        router_output: RouterAgentOutput,
        toolkit: Any,
        stage_config: SystemAgentRuntimeConfig,
        runtime_client_time: ClientTimeContext | None,
        runtime_mode: RuntimeMode,
        work_memory: WorkProfileContent | None,
    ) -> WorkerAgentOutputLite:
        """Run the worker stage wrapped in STEP_STARTED / STEP_FINISHED events."""
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.WORKER.value,
            event_type="STEP_STARTED",
            runtime_mode=runtime_mode,
        )
        worker_result = await self._run_worker_stage(
            user_context=user_context,
            input_messages=self._build_worker_input_messages(
                router_output=router_output
            ),
            toolkit=toolkit,
            run_input=run_input,
            stage_config=stage_config,
            worker_output_model=WorkerAgentOutputLite,
            pipeline=pipeline,
            runtime_client_time=runtime_client_time,
            runtime_mode=runtime_mode,
            work_memory=work_memory,
            requires_tool_evidence=router_output.requires_tool_evidence,
        )
        worker_output = WorkerAgentOutputLite.model_validate(worker_result.payload)
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.WORKER.value,
            event_type="STEP_FINISHED",
            runtime_mode=runtime_mode,
        )
        return worker_output

    async def _run_router_stage(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        stage_config: SystemAgentRuntimeConfig,
        runtime_client_time: ClientTimeContext | None,
        user_memory: UserMemoryContent | None,
        run_input: RunAgentInput,
    ) -> StageExecutionResult:
        """Call the router model once (with JSON retries) and package the result."""
        messages_for_router = self._build_router_messages(
            context_messages=context_messages,
            run_input=run_input,
        )
        tracking_model = self._build_model(stage_config=stage_config)
        response, payload = await finalize_json_response(
            model=tracking_model,
            formatter=OpenAIChatFormatter(),
            base_messages=[
                Msg(
                    "system",
                    build_system_prompt(
                        agent_type=AgentType.ROUTER,
                        llm_config=stage_config.llm_config,
                        user_context=user_context,
                        now_utc=datetime.now(timezone.utc),
                        runtime_client_time=runtime_client_time,
                        user_memory=user_memory,
                    ),
                    "system",
                ),
                *messages_for_router,
            ],
            output_model=RouterAgentOutput,
            retries=self._ROUTER_JSON_RETRIES,
        )
        response_msg = Msg(
            name="router",
            role="assistant",
            # The response may lack ``content``; fall back to an empty list.
            content=list(getattr(response, "content", [])),
            metadata=payload,
        )
        return StageExecutionResult(
            message=response_msg,
            payload=payload,
            response_metadata=self._llm_pricing_service.build_usage_metadata(
                model=stage_config.model_code,
                usage_summary=tracking_model.usage_summary(),
            ),
        )

    def _build_router_messages(
        self,
        *,
        context_messages: list[Msg],
        run_input: RunAgentInput,
    ) -> list[Msg]:
        """Return the router's message list, ending with a user message.

        If the context already ends with a user message it is returned as is;
        otherwise the latest user payload from ``run_input`` is appended.
        """
        if context_messages:
            last = context_messages[-1]
            if last.role == "user":
                return context_messages

        user_text, user_blocks = extract_latest_user_payload(run_input)
        # Use the plain text when the first block is a text block, otherwise
        # pass the raw blocks through unchanged.
        if (
            user_blocks
            and isinstance(user_blocks[0], dict)
            and user_blocks[0].get("type") == "text"
        ):
            content: Any = user_text
        else:
            content = user_blocks
        user_msg = Msg(name="user", role="user", content=content)
        return [*context_messages, user_msg]

    async def _run_worker_stage(
        self,
        *,
        user_context: UserContext,
        input_messages: list[Msg],
        toolkit: Any,
        run_input: RunAgentInput,
        stage_config: SystemAgentRuntimeConfig,
        worker_output_model: type[WorkerAgentOutputLite],
        pipeline: PipelineLike,
        runtime_client_time: ClientTimeContext | None,
        runtime_mode: RuntimeMode,
        work_memory: WorkProfileContent | None,
        requires_tool_evidence: bool = False,
    ) -> StageExecutionResult:
        """Run the worker agent with a scoped tool credential.

        A tool credential is issued for the user and installed for the
        duration of the stage; it is always reset afterwards. While the agent
        is replying it is published in ``_active_agent`` so that the cancel
        watcher can interrupt it.
        """
        issuer = create_credential_issuer()
        credential = issuer.issue(
            owner_id=str(user_context.id),
            mode=runtime_mode.value,
        )
        credential_token = set_tool_credential(credential)
        try:
            tracking_model = self._build_model(stage_config=stage_config)
            emitter = PipelineStageEmitter(
                pipeline=pipeline,
                session_id=run_input.thread_id,
                run_id=run_input.run_id,
                stage=stage_config.agent_type.value,
                runtime_mode=runtime_mode.value,
                emit_text_events=True,
                emit_tool_events=True,
            )
            agent = self._build_agent(
                agent_name=stage_config.agent_type.value,
                system_prompt=build_system_prompt(
                    agent_type=stage_config.agent_type,
                    llm_config=stage_config.llm_config,
                    user_context=user_context,
                    now_utc=datetime.now(timezone.utc),
                    runtime_client_time=runtime_client_time,
                    extra_context=stage_config.extra_context,
                    work_memory=work_memory,
                ),
                toolkit=toolkit,
                model=tracking_model,
                emitter=emitter,
                force_tool_on_first_reasoning=requires_tool_evidence,
            )

            async with self._active_agent_lock:
                self._active_agent = agent
            try:
                response_msg = await agent.reply_json(
                    input_messages, output_model=worker_output_model
                )
            finally:
                async with self._active_agent_lock:
                    # Only clear the slot if it still points at this agent.
                    if self._active_agent is agent:
                        self._active_agent = None

            worker_payload = worker_output_model.model_validate(
                response_msg.metadata or {}
            )
            response_metadata = self._llm_pricing_service.build_usage_metadata(
                model=stage_config.model_code,
                usage_summary=tracking_model.usage_summary(),
            )
            await emitter.emit_final_text_end(
                worker_output=worker_payload.model_dump(mode="json", exclude_none=True),
                response_metadata=response_metadata,
            )
            return StageExecutionResult(
                message=response_msg,
                payload=worker_payload.model_dump(mode="json", exclude_none=True),
                response_metadata=response_metadata,
            )
        finally:
            reset_tool_credential(credential_token)

    def _build_worker_input_messages(
        self,
        *,
        router_output: RouterAgentOutput,
    ) -> list[Msg]:
        """Wrap the worker contract prompt in a single user-role message."""
        return [
            Msg(
                name=AgentType.ROUTER.value,
                role="user",
                content=build_worker_contract_prompt(router_output=router_output),
            )
        ]

    def _build_model(
        self, *, stage_config: SystemAgentRuntimeConfig
    ) -> TrackingChatModel:
        """Create a non-streaming OpenAI-compatible model wrapped for usage tracking.

        ``temperature`` and ``max_tokens`` are only forwarded when configured.
        """
        generate_kwargs: dict[str, Any] = {
            "timeout": stage_config.llm_config.timeout_seconds,
            "extra_body": {"enable_thinking": False},
        }
        if stage_config.llm_config.temperature is not None:
            generate_kwargs["temperature"] = stage_config.llm_config.temperature
        if stage_config.llm_config.max_tokens is not None:
            generate_kwargs["max_tokens"] = stage_config.llm_config.max_tokens

        model = OpenAIChatModel(
            model_name=stage_config.model_code,
            api_key=stage_config.api_key,
            stream=False,
            client_kwargs={"base_url": stage_config.api_base_url},
            generate_kwargs=generate_kwargs,
        )
        return TrackingChatModel(model)

    def _build_agent(
        self,
        *,
        agent_name: str,
        system_prompt: str,
        toolkit: Any,
        model: TrackingChatModel,
        emitter: PipelineStageEmitter | None = None,
        force_tool_on_first_reasoning: bool = False,
    ) -> JsonReActAgent:
        """Create a ``JsonReActAgent`` with a fresh in-memory memory."""
        return JsonReActAgent(
            name=agent_name,
            sys_prompt=system_prompt,
            model=model,
            formatter=OpenAIChatFormatter(),
            toolkit=toolkit,
            memory=InMemoryMemory(),
            emitter=emitter,
            force_tool_on_first_reasoning=force_tool_on_first_reasoning,
        )

    async def _emit_step_event(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        step_name: str,
        event_type: str,
        runtime_mode: RuntimeMode,
        extra_event: dict[str, Any] | None = None,
    ) -> None:
        """Emit a step event to the pipeline, merging in ``extra_event`` if given."""
        payload: dict[str, Any] = {
            "type": event_type,
            "threadId": run_input.thread_id,
            "runId": run_input.run_id,
            "runtime_mode": runtime_mode.value,
            "stepName": step_name,
        }
        if extra_event:
            payload.update(extra_event)
        await pipeline.emit(
            session_id=run_input.thread_id,
            event=payload,
        )

    def _resolve_runtime_client_time(
        self, *, run_input: RunAgentInput
    ) -> ClientTimeContext | None:
        """Parse the client time context from ``run_input.forwarded_props``."""
        return parse_forwarded_props_client_time(
            getattr(run_input, "forwarded_props", None)
        )

    @staticmethod
    def _resolve_runtime_mode(*, run_input: RunAgentInput) -> RuntimeMode:
        """Parse the runtime mode, falling back to ``RuntimeMode.CHAT`` on ``ValueError``."""
        try:
            return parse_forwarded_props_runtime_mode(
                getattr(run_input, "forwarded_props", None)
            )
        except ValueError:
            return RuntimeMode.CHAT

    @staticmethod
    def _resolve_provider_api_key(*, factory_name: str) -> str:
        """Look up the API key configured for an LLM factory.

        Names are compared after stripping and upper-casing; the factory name
        ``VOLCENGINE`` is looked up under ``ARK``. Blank key values are ignored.

        Raises:
            RuntimeError: If no non-empty key is configured for the factory.
        """
        normalized_factory_name = factory_name.strip().upper()
        if normalized_factory_name == "VOLCENGINE":
            normalized_factory_name = "ARK"

        provider_keys = {
            str(key).strip().upper(): str(value).strip()
            for key, value in config.llm.provider_keys.items()
            if str(value).strip()
        }
        api_key = provider_keys.get(normalized_factory_name, "")
        if not api_key:
            raise RuntimeError(f"provider api key missing for factory: {factory_name}")
        return api_key


@dataclass(frozen=True)
class SystemAgentRuntimeConfig:
    """Resolved model and LLM settings for one system agent stage."""

    # Stage this config belongs to (router or worker).
    agent_type: AgentType
    # Model identifier sent to the provider.
    model_code: str
    # Base URL of the OpenAI-compatible endpoint.
    api_base_url: str
    # Provider API key resolved from application config.
    api_key: str
    # Validated per-agent LLM options (timeout, temperature, max tokens, ...).
    llm_config: SystemAgentLLMConfig
    # Optional extra text forwarded to the worker system prompt.
    extra_context: str | None = None


# Alias for ``AgentScopeRunner`` under a second name.
AgentScopeReActRunner = AgentScopeRunner