from __future__ import annotations

import asyncio
import contextlib
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable

from ag_ui.core.types import RunAgentInput
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.tool import Toolkit
from agentscope.model import OpenAIChatModel
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.prompts.user_prompt import (
    build_divination_user_prompt,
    build_follow_up_user_prompt,
)
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.divination import derive_divination
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.utils.compat import patch_agentscope_json_repair_compat
from core.agentscope.utils.json_finalize import finalize_json_response
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
from schemas.agent.forwarded_props import (
    RuntimeMode,
    parse_forwarded_props_divination_payload,
    parse_forwarded_props_runtime_mode,
)
from schemas.domain.divination import DerivedDivinationData
from schemas.agent.runtime_models import (
    FollowUpOutput,
    WorkerAgentOutputLite,
    resolve_worker_output_model,
)
from schemas.agent.system_agent import (
    AgentType,
    SystemAgentLLMConfig,
)
from schemas.shared.user import UserContext
from services.llm_pricing.service import LlmPricingService
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.agentscope.runtime.protocols import PipelineLike


@dataclass(frozen=True)
class StageExecutionResult:
    """Outcome of one executed agent stage."""

    # Assistant message carrying the stage's textual answer.
    message: Msg
    # Pydantic-validated structured output produced by the model.
    validated_output: WorkerAgentOutputLite | FollowUpOutput
    # Usage/pricing metadata built by LlmPricingService for this stage.
    response_metadata: dict[str, Any]


@dataclass(frozen=True)
class SystemAgentRuntimeConfig:
    """Fully materialized model configuration for one system agent.

    Holds plain values only (no ORM objects), so it stays valid after the
    database session that produced it has been closed.
    """

    agent_type: AgentType
    model_code: str
    api_base_url: str
    api_key: str
    llm_config: SystemAgentLLMConfig


class AgentScopeRunner:
    """Runs the single "worker" LLM stage for an AG-UI run.

    The runner loads the worker's model configuration from the database,
    builds the prompt, asks the model for a JSON answer validated against
    the runtime-mode-specific output model, and streams step/text events
    to the supplied pipeline. An optional cancel checker is polled in the
    background and cancels the run when it reports True.
    """

    def __init__(self, *, llm_pricing_service: LlmPricingService | None = None) -> None:
        """Create a runner.

        Args:
            llm_pricing_service: Service used to turn token usage into
                response metadata; a default instance is created if omitted.
        """
        patch_agentscope_json_repair_compat()
        self._llm_pricing_service: LlmPricingService = (
            llm_pricing_service or LlmPricingService()
        )
        # NOTE(review): nothing in this module assigns _active_agent after
        # __init__ (and _build_agent is never called here), so the interrupt
        # path in _watch_cancel_signal is currently inert unless a subclass or
        # external caller sets it.
        self._active_agent: JsonReActAgent | None = None
        self._active_agent_lock = asyncio.Lock()

    async def execute(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        runtime_config: Any,
        cancel_checker: Callable[[], Awaitable[bool]] | None = None,
    ) -> dict[str, Any]:
        """Run the worker stage for one AG-UI run and return its output.

        Args:
            user_context: Caller context; the prompt language is read from
                ``settings.preferences.language`` (default ``"zh-CN"``).
            context_messages: Prior conversation messages sent to the model.
                The caller's list is not modified.
            pipeline: Event sink that step and text events are emitted to.
            run_input: AG-UI input carrying thread/run ids and forwarded props.
            runtime_config: Accepted for interface compatibility; unused.
            cancel_checker: Optional async predicate. It is checked once after
                the config is loaded and then polled every 0.2s in a
                background task; a True result cancels the current task.

        Returns:
            ``{"worker": <validated worker output dumped to JSON mode>}``.

        Raises:
            asyncio.CancelledError: If the run is cancelled via cancel_checker.
            RuntimeError: If the worker agent config is missing or inactive,
                or the provider API key is not configured.
            ValueError: In CHAT mode when forwardedProps.divinationPayload is
                missing.
        """
        _ = runtime_config  # kept only for interface compatibility
        runtime_mode = self._resolve_runtime_mode(run_input=run_input)

        stop_cancel_watch = asyncio.Event()
        cancel_watch_task: asyncio.Task[None] | None = None
        run_task = asyncio.current_task()
        if cancel_checker is not None and run_task is not None:
            cancel_watch_task = asyncio.create_task(
                self._watch_cancel_signal(
                    cancel_checker=cancel_checker,
                    stop_signal=stop_cancel_watch,
                    run_task=run_task,
                )
            )

        try:
            # Hold the DB session only for the config lookup. The LLM call
            # below can take a long time and must not keep a pooled
            # connection checked out; the returned config holds plain values.
            async with AsyncSessionLocal() as session:
                worker_config = await self._load_stage_config(
                    session=session,
                    agent_type=AgentType.WORKER,
                )
            worker_toolkit = self._build_toolkit()

            if cancel_checker is not None and await cancel_checker():
                raise asyncio.CancelledError("run canceled by user")

            derived_divination: DerivedDivinationData | None = None
            if runtime_mode == RuntimeMode.CHAT:
                derived_divination = self._resolve_derived_divination(
                    run_input=run_input
                )
                await self._emit_step_event(
                    pipeline=pipeline,
                    run_input=run_input,
                    step_name="divination",
                    event_type="DIVINATION_DERIVED",
                    runtime_mode=runtime_mode,
                    extra_event={
                        "divination": derived_divination.model_dump(
                            mode="json", by_alias=True, exclude_none=True
                        )
                    },
                )

            worker_output = await self._execute_worker_step(
                pipeline=pipeline,
                run_input=run_input,
                user_context=user_context,
                context_messages=context_messages,
                toolkit=worker_toolkit,
                stage_config=worker_config,
                runtime_mode=runtime_mode,
                derived_divination=derived_divination,
            )
            return {
                "worker": worker_output.model_dump(mode="json", exclude_none=True),
            }
        finally:
            # Always stop and reap the background cancel watcher.
            stop_cancel_watch.set()
            if cancel_watch_task is not None:
                cancel_watch_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await cancel_watch_task

    async def _watch_cancel_signal(
        self,
        *,
        cancel_checker: Callable[[], Awaitable[bool]],
        stop_signal: asyncio.Event,
        run_task: asyncio.Task[object],
    ) -> None:
        """Poll ``cancel_checker`` every 0.2s and cancel ``run_task`` on True.

        Before cancelling, the active agent (if any) is asked to interrupt.
        Errors from the checker are treated as "not cancelled" so a flaky
        cancel backend never aborts a healthy run. Exits when
        ``stop_signal`` is set.
        """
        while not stop_signal.is_set():
            try:
                should_cancel = await cancel_checker()
            except Exception:
                # Deliberate best-effort: a failing checker must not kill the run.
                should_cancel = False

            if should_cancel:
                async with self._active_agent_lock:
                    active_agent = self._active_agent
                if active_agent is not None:
                    with contextlib.suppress(Exception):
                        await active_agent.interrupt()
                if not run_task.done():
                    run_task.cancel("run canceled by user")
                return

            await asyncio.sleep(0.2)

    def _build_toolkit(
        self,
    ) -> Toolkit:
        """Return the (currently empty) toolkit for the worker stage."""
        return Toolkit()

    async def _load_stage_config(
        self,
        *,
        session: AsyncSession,
        agent_type: AgentType,
    ) -> SystemAgentRuntimeConfig:
        """Load and validate the model configuration for ``agent_type``.

        Joins SystemAgents -> Llm -> LlmFactory and copies every needed value
        into a ``SystemAgentRuntimeConfig`` so the session may be closed
        afterwards.

        Raises:
            RuntimeError: If no config row exists, the agent's status is not
                ``"active"``, or the provider API key is missing.
        """
        stmt = (
            select(SystemAgents, Llm, LlmFactory)
            .join(Llm, SystemAgents.llm_id == Llm.id)
            .join(LlmFactory, Llm.factory_id == LlmFactory.id)
            .where(SystemAgents.agent_type == agent_type.value)
        )
        row = (await session.execute(stmt)).one_or_none()
        if row is None:
            raise RuntimeError(f"system agent config not found: {agent_type.value}")

        system_agent, llm, factory = row
        # Use an Enum's ``.value`` when present; str(Enum) would otherwise
        # render as "ClassName.MEMBER" and never match "active".
        raw_status = getattr(system_agent.status, "value", system_agent.status)
        status = str(raw_status).strip().lower()
        if status != "active":
            raise RuntimeError(f"system agent is not active: {agent_type.value}")

        return SystemAgentRuntimeConfig(
            agent_type=agent_type,
            model_code=llm.model_code,
            api_base_url=factory.request_url,
            api_key=self._resolve_provider_api_key(factory_name=factory.name),
            llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
        )

    async def _execute_worker_step(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        user_context: UserContext,
        context_messages: list[Msg],
        toolkit: Any,
        stage_config: SystemAgentRuntimeConfig,
        runtime_mode: RuntimeMode,
        derived_divination: DerivedDivinationData | None,
    ) -> WorkerAgentOutputLite | FollowUpOutput:
        """Run the worker stage wrapped in STEP_STARTED / STEP_FINISHED events.

        Resolves the output model for ``runtime_mode`` and the prompt
        language from the user's preferences (default ``"zh-CN"``), then
        returns the validated worker output.
        """
        worker_output_model = resolve_worker_output_model(
            runtime_mode=runtime_mode.value
        )

        language = "zh-CN"
        if user_context.settings is not None:
            prefs = getattr(user_context.settings, "preferences", None)
            if prefs is not None:
                language = getattr(prefs, "language", "zh-CN") or "zh-CN"

        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.WORKER.value,
            event_type="STEP_STARTED",
            runtime_mode=runtime_mode,
        )
        worker_result = await self._run_worker_stage(
            user_context=user_context,
            input_messages=self._build_worker_input_messages(
                context_messages=context_messages,
                run_input=run_input,
                derived_divination=derived_divination,
                language=language,
            ),
            toolkit=toolkit,
            run_input=run_input,
            stage_config=stage_config,
            worker_output_model=worker_output_model,
            pipeline=pipeline,
            runtime_mode=runtime_mode,
            derived_divination=derived_divination,
            language=language,
        )
        worker_output = worker_result.validated_output
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.WORKER.value,
            event_type="STEP_FINISHED",
            runtime_mode=runtime_mode,
        )
        return worker_output

    async def _run_worker_stage(
        self,
        *,
        user_context: UserContext,
        input_messages: list[Msg],
        toolkit: Any,
        run_input: RunAgentInput,
        stage_config: SystemAgentRuntimeConfig,
        worker_output_model: type[WorkerAgentOutputLite | FollowUpOutput],
        pipeline: PipelineLike,
        runtime_mode: RuntimeMode,
        derived_divination: DerivedDivinationData | None,
        language: str,
    ) -> StageExecutionResult:
        """Ask the model for a JSON answer and emit the final text event.

        ``user_context`` and ``toolkit`` are accepted but not used here (no
        tools are passed to the system prompt). The model call goes through
        ``finalize_json_response`` with 2 retries against
        ``worker_output_model``.
        """
        tracking_model = self._build_model(stage_config=stage_config)
        formatter = OpenAIChatFormatter()
        emitter = PipelineStageEmitter(
            pipeline=pipeline,
            session_id=run_input.thread_id,
            run_id=run_input.run_id,
            stage=stage_config.agent_type.value,
            runtime_mode=runtime_mode.value,
            emit_text_events=True,
            emit_tool_events=False,
        )
        system_prompt = build_system_prompt(
            agent_type=stage_config.agent_type,
            language=language,
            llm_config=stage_config.llm_config,
            tools=None,
            now_utc=datetime.now(timezone.utc),
            runtime_mode=runtime_mode.value,
        )
        _, worker_payload = await finalize_json_response(
            model=tracking_model,
            formatter=formatter,
            base_messages=[Msg("system", system_prompt, "system"), *input_messages],
            output_model=worker_output_model,
            retries=2,
            language=language,
        )
        response_metadata = self._llm_pricing_service.build_usage_metadata(
            model=stage_config.model_code,
            usage_summary=tracking_model.usage_summary(),
        )
        await emitter.emit_final_text_end(
            worker_output=self._build_final_worker_output(
                worker_payload=worker_payload,
                runtime_mode=runtime_mode,
                derived_divination=derived_divination,
            ),
            response_metadata=response_metadata,
        )
        return StageExecutionResult(
            message=Msg(
                name=stage_config.agent_type.value,
                role="assistant",
                content=worker_payload.answer,
            ),
            validated_output=worker_payload,
            response_metadata=response_metadata,
        )

    def _build_worker_input_messages(
        self,
        *,
        context_messages: list[Msg],
        run_input: RunAgentInput,
        derived_divination: DerivedDivinationData | None,
        language: str = "zh-CN",
    ) -> list[Msg]:
        """Build the conversation sent to the worker model.

        The final user turn is a divination prompt when ``derived_divination``
        is given, otherwise a follow-up prompt wrapping the latest user text
        from ``run_input``. If the context already ends with a user message,
        that message's content is replaced; otherwise a new user message is
        appended.

        Returns:
            A new list; ``context_messages`` itself is never modified.
        """
        if derived_divination is not None:
            user_text = build_divination_user_prompt(
                derived=derived_divination, language=language
            )
        else:
            raw_user_text, _ = extract_latest_user_payload(run_input)
            user_text = build_follow_up_user_prompt(
                question=raw_user_text, language=language
            )

        # Shallow copy so the caller's history is not rewritten in place.
        messages = list(context_messages)
        if messages and messages[-1].role == "user":
            last = messages[-1]
            messages[-1] = Msg(
                name=last.name,
                role=last.role,
                content=user_text,
            )
            return messages

        user_msg = Msg(name="user", role="user", content=user_text)
        return [*messages, user_msg]

    @staticmethod
    def _resolve_derived_divination(
        *, run_input: RunAgentInput
    ) -> DerivedDivinationData:
        """Derive divination data from ``forwardedProps.divinationPayload``.

        Raises:
            ValueError: If the payload is absent.
        """
        payload = parse_forwarded_props_divination_payload(
            getattr(run_input, "forwarded_props", None)
        )
        if payload is None:
            raise ValueError("forwardedProps.divinationPayload is required")
        return derive_divination(payload)

    def _build_model(
        self, *, stage_config: SystemAgentRuntimeConfig
    ) -> TrackingChatModel:
        """Create a non-streaming OpenAI-compatible chat model with usage tracking.

        ``temperature`` and ``max_tokens`` are forwarded only when configured;
        the provider-specific "thinking" feature is explicitly disabled.
        """
        generate_kwargs: dict[str, Any] = {
            "timeout": stage_config.llm_config.timeout_seconds,
            "extra_body": {"thinking": {"type": "disabled"}},
        }
        if stage_config.llm_config.temperature is not None:
            generate_kwargs["temperature"] = stage_config.llm_config.temperature
        if stage_config.llm_config.max_tokens is not None:
            generate_kwargs["max_tokens"] = stage_config.llm_config.max_tokens

        model = OpenAIChatModel(
            model_name=stage_config.model_code,
            api_key=stage_config.api_key,
            stream=False,
            client_kwargs={"base_url": stage_config.api_base_url},
            generate_kwargs=generate_kwargs,
        )
        return TrackingChatModel(model)

    def _build_agent(
        self,
        *,
        agent_name: str,
        system_prompt: str,
        toolkit: Any,
        model: TrackingChatModel,
        emitter: PipelineStageEmitter | None = None,
    ) -> JsonReActAgent:
        """Build a JsonReActAgent with fresh in-memory memory.

        NOTE(review): not called anywhere in this module.
        """
        return JsonReActAgent(
            name=agent_name,
            sys_prompt=system_prompt,
            model=model,
            formatter=OpenAIChatFormatter(),
            toolkit=toolkit,
            memory=InMemoryMemory(),
            emitter=emitter,
        )

    async def _emit_step_event(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        step_name: str,
        event_type: str,
        runtime_mode: RuntimeMode,
        extra_event: dict[str, Any] | None = None,
    ) -> None:
        """Emit a step event keyed by the run's thread id.

        ``extra_event`` keys are merged into (and may override) the base
        payload.
        """
        payload: dict[str, Any] = {
            "type": event_type,
            "threadId": run_input.thread_id,
            "runId": run_input.run_id,
            "runtime_mode": runtime_mode.value,
            "stepName": step_name,
        }
        if extra_event:
            payload.update(extra_event)
        await pipeline.emit(
            session_id=run_input.thread_id,
            event=payload,
        )

    @staticmethod
    def _resolve_runtime_mode(*, run_input: RunAgentInput) -> RuntimeMode:
        """Read the runtime mode from ``run_input.forwarded_props``."""
        return parse_forwarded_props_runtime_mode(
            getattr(run_input, "forwarded_props", None)
        )

    @staticmethod
    def _build_final_worker_output(
        *,
        worker_payload: WorkerAgentOutputLite | FollowUpOutput,
        runtime_mode: RuntimeMode,
        derived_divination: DerivedDivinationData | None,
    ) -> dict[str, Any]:
        """Dump the worker payload, attaching divination data in CHAT mode."""
        payload = worker_payload.model_dump(mode="json", exclude_none=True)
        if runtime_mode == RuntimeMode.CHAT and derived_divination is not None:
            payload["divination_derived"] = derived_divination.model_dump(
                mode="json", by_alias=True, exclude_none=True
            )
        return payload

    @staticmethod
    def _resolve_provider_api_key(*, factory_name: str) -> str:
        """Look up the API key for an LLM factory in ``config.llm.provider_keys``.

        Matching is case-insensitive and whitespace-trimmed; the factory name
        "VOLCENGINE" is looked up under "ARK". ``None`` or blank values are
        treated as missing (previously ``None`` was stringified to "None").

        Raises:
            RuntimeError: If no non-empty key is configured for the factory.
        """
        normalized_factory_name = factory_name.strip().upper()
        if normalized_factory_name == "VOLCENGINE":
            normalized_factory_name = "ARK"

        raw_keys = config.llm.provider_keys or {}
        provider_keys = {
            str(key).strip().upper(): str(value).strip()
            for key, value in raw_keys.items()
            if value is not None and str(value).strip()
        }
        api_key = provider_keys.get(normalized_factory_name, "")
        if not api_key:
            raise RuntimeError(f"provider api key missing for factory: {factory_name}")
        return api_key


# Backward-compatible alias for callers using the older name.
AgentScopeReActRunner = AgentScopeRunner