# File: eryao/backend/src/core/agentscope/runtime/runner.py
# (Python source; the repository viewer reported 473 lines, 17 KiB.)

from __future__ import annotations
import asyncio
import contextlib
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable
from ag_ui.core.types import RunAgentInput
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.tool import Toolkit
from agentscope.model import OpenAIChatModel
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.prompts.user_prompt import (
build_divination_user_prompt,
build_follow_up_user_prompt,
)
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.divination import derive_divination
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.utils.compat import patch_agentscope_json_repair_compat
from core.agentscope.utils.json_finalize import finalize_json_response
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
from schemas.agent.forwarded_props import (
RuntimeMode,
parse_forwarded_props_divination_payload,
parse_forwarded_props_runtime_mode,
)
from schemas.domain.divination import DerivedDivinationData
from schemas.agent.runtime_models import (
FollowUpOutput,
WorkerAgentOutputLite,
resolve_worker_output_model,
)
from schemas.agent.system_agent import (
AgentType,
SystemAgentLLMConfig,
)
from schemas.shared.user import UserContext
from services.llm_pricing.service import LlmPricingService
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.agentscope.runtime.protocols import PipelineLike
@dataclass(frozen=True)
class StageExecutionResult:
    """Immutable bundle describing the outcome of one executed stage."""

    # Assistant message built from the stage's answer text.
    message: Msg
    # Structured output produced for the stage (worker or follow-up shape).
    validated_output: WorkerAgentOutputLite | FollowUpOutput
    # Usage metadata dictionary attached to the response.
    response_metadata: dict[str, Any]
class AgentScopeRunner:
    """Run the worker stage for one request and emit its pipeline events.

    ``execute`` is the entry point: it loads the worker stage configuration
    from the database, optionally derives divination data, calls the model
    through ``finalize_json_response`` and reports progress via
    ``pipeline.emit``. While a run is in flight an optional background task
    polls a caller-supplied cancel checker and cancels the run on request.
    """

    # Seconds to wait between two polls of the external cancel checker.
    _CANCEL_POLL_INTERVAL_SECONDS = 0.2

    def __init__(
        self, *, llm_pricing_service: LlmPricingService | None = None
    ) -> None:
        """Create a runner.

        Args:
            llm_pricing_service: Service used to build usage metadata. A new
                ``LlmPricingService`` is created when none is supplied.
        """
        patch_agentscope_json_repair_compat()
        self._llm_pricing_service: LlmPricingService = (
            llm_pricing_service or LlmPricingService()
        )
        # Agent the cancel watcher would interrupt; nothing in this class
        # assigns it, so it stays ``None`` unless set from elsewhere.
        self._active_agent: JsonReActAgent | None = None
        self._active_agent_lock = asyncio.Lock()

    async def execute(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        runtime_config: Any,
        cancel_checker: Callable[[], Awaitable[bool]] | None = None,
    ) -> dict[str, Any]:
        """Execute the worker stage and return its serialized output.

        Args:
            user_context: Context of the requesting user (language lookup).
            context_messages: Prior conversation messages; not modified.
            pipeline: Sink that receives emitted events.
            run_input: AG-UI run input (thread id, run id, forwarded props).
            runtime_config: Accepted for interface compatibility; unused.
            cancel_checker: Optional coroutine function returning ``True``
                once the run should be cancelled.

        Returns:
            ``{"worker": <worker output dumped in JSON mode>}``.

        Raises:
            asyncio.CancelledError: When cancellation is requested.
            RuntimeError: When the stage configuration cannot be loaded.
            ValueError: In chat mode, when no divination payload is present.
        """
        _ = runtime_config
        runtime_mode = self._resolve_runtime_mode(run_input=run_input)

        stop_cancel_watch = asyncio.Event()
        cancel_watch_task: asyncio.Task[None] | None = None
        run_task = asyncio.current_task()
        if cancel_checker is not None and run_task is not None:
            cancel_watch_task = asyncio.create_task(
                self._watch_cancel_signal(
                    cancel_checker=cancel_checker,
                    stop_signal=stop_cancel_watch,
                    run_task=run_task,
                )
            )

        try:
            # The session is only needed to read the stage configuration, so
            # it is released before the (potentially long) model call.
            async with AsyncSessionLocal() as session:
                worker_config = await self._load_stage_config(
                    session=session,
                    agent_type=AgentType.WORKER,
                )
            worker_toolkit = self._build_toolkit()

            # Check once up front so an already-cancelled run never starts.
            if cancel_checker is not None and await cancel_checker():
                raise asyncio.CancelledError("run canceled by user")

            derived_divination: DerivedDivinationData | None = None
            if runtime_mode == RuntimeMode.CHAT:
                derived_divination = self._resolve_derived_divination(
                    run_input=run_input
                )
                await self._emit_step_event(
                    pipeline=pipeline,
                    run_input=run_input,
                    step_name="divination",
                    event_type="DIVINATION_DERIVED",
                    runtime_mode=runtime_mode,
                    extra_event={
                        "divination": derived_divination.model_dump(
                            mode="json", by_alias=True, exclude_none=True
                        )
                    },
                )

            worker_output = await self._execute_worker_step(
                pipeline=pipeline,
                run_input=run_input,
                user_context=user_context,
                context_messages=context_messages,
                toolkit=worker_toolkit,
                stage_config=worker_config,
                runtime_mode=runtime_mode,
                derived_divination=derived_divination,
            )
            return {
                "worker": worker_output.model_dump(mode="json", exclude_none=True),
            }
        finally:
            # Always stop and reap the watcher so it cannot outlive the run.
            stop_cancel_watch.set()
            if cancel_watch_task is not None:
                cancel_watch_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await cancel_watch_task

    async def _watch_cancel_signal(
        self,
        *,
        cancel_checker: Callable[[], Awaitable[bool]],
        stop_signal: asyncio.Event,
        run_task: asyncio.Task[object],
    ) -> None:
        """Poll ``cancel_checker`` and cancel ``run_task`` once it returns true.

        The loop ends when ``stop_signal`` is set or after a cancellation has
        been delivered. Before cancelling, the currently registered agent (if
        any) is asked to interrupt; failures of that call are ignored.
        """
        while not stop_signal.is_set():
            try:
                should_cancel = await cancel_checker()
            except Exception:
                # Best effort: a failing checker counts as "not cancelled"
                # and is simply polled again on the next iteration.
                should_cancel = False

            if should_cancel:
                # Hold the lock only to read the reference, not while
                # awaiting the agent's interrupt.
                async with self._active_agent_lock:
                    active_agent = self._active_agent
                if active_agent is not None:
                    with contextlib.suppress(Exception):
                        await active_agent.interrupt()
                if not run_task.done():
                    run_task.cancel("run canceled by user")
                return

            await asyncio.sleep(self._CANCEL_POLL_INTERVAL_SECONDS)

    def _build_toolkit(
        self,
    ) -> Toolkit:
        """Return a fresh, empty ``Toolkit``."""
        return Toolkit()

    async def _load_stage_config(
        self,
        *,
        session: AsyncSession,
        agent_type: AgentType,
    ) -> SystemAgentRuntimeConfig:
        """Load the runtime configuration of the given system agent.

        Joins ``SystemAgents`` with its ``Llm`` and ``LlmFactory`` rows and
        resolves the provider API key from application settings.

        Raises:
            RuntimeError: When no row exists for ``agent_type``, when the
                agent's status is not ``"active"`` (compared after strip and
                lower-casing), or when the provider API key is missing.
        """
        stmt = (
            select(SystemAgents, Llm, LlmFactory)
            .join(Llm, SystemAgents.llm_id == Llm.id)
            .join(LlmFactory, Llm.factory_id == LlmFactory.id)
            .where(SystemAgents.agent_type == agent_type.value)
        )
        row = (await session.execute(stmt)).one_or_none()
        if row is None:
            raise RuntimeError(f"system agent config not found: {agent_type.value}")

        system_agent, llm, factory = row
        status = str(system_agent.status).strip().lower()
        if status != "active":
            raise RuntimeError(f"system agent is not active: {agent_type.value}")

        return SystemAgentRuntimeConfig(
            agent_type=agent_type,
            model_code=llm.model_code,
            api_base_url=factory.request_url,
            api_key=self._resolve_provider_api_key(factory_name=factory.name),
            llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
        )

    async def _execute_worker_step(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        user_context: UserContext,
        context_messages: list[Msg],
        toolkit: Any,
        stage_config: SystemAgentRuntimeConfig,
        runtime_mode: RuntimeMode,
        derived_divination: DerivedDivinationData | None,
    ) -> WorkerAgentOutputLite | FollowUpOutput:
        """Run the worker stage between STEP_STARTED and STEP_FINISHED events.

        The reply language comes from ``user_context.settings.preferences``
        when available and falls back to ``"zh-CN"``. STEP_FINISHED is only
        emitted when the stage completes without raising.
        """
        worker_output_model = resolve_worker_output_model(
            runtime_mode=runtime_mode.value
        )

        language = "zh-CN"
        if user_context.settings is not None:
            prefs = getattr(user_context.settings, "preferences", None)
            if prefs is not None:
                language = getattr(prefs, "language", "zh-CN") or "zh-CN"

        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.WORKER.value,
            event_type="STEP_STARTED",
            runtime_mode=runtime_mode,
        )
        worker_result = await self._run_worker_stage(
            user_context=user_context,
            input_messages=self._build_worker_input_messages(
                context_messages=context_messages,
                run_input=run_input,
                derived_divination=derived_divination,
                language=language,
            ),
            toolkit=toolkit,
            run_input=run_input,
            stage_config=stage_config,
            worker_output_model=worker_output_model,
            pipeline=pipeline,
            runtime_mode=runtime_mode,
            derived_divination=derived_divination,
            language=language,
        )
        worker_output = worker_result.validated_output
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=AgentType.WORKER.value,
            event_type="STEP_FINISHED",
            runtime_mode=runtime_mode,
        )
        return worker_output

    async def _run_worker_stage(
        self,
        *,
        user_context: UserContext,
        input_messages: list[Msg],
        toolkit: Any,
        run_input: RunAgentInput,
        stage_config: SystemAgentRuntimeConfig,
        worker_output_model: type[WorkerAgentOutputLite | FollowUpOutput],
        pipeline: PipelineLike,
        runtime_mode: RuntimeMode,
        derived_divination: DerivedDivinationData | None,
        language: str,
    ) -> StageExecutionResult:
        """Call the model once for the worker stage and publish the result.

        Builds the model, system prompt and stage emitter, obtains a payload
        of type ``worker_output_model`` via ``finalize_json_response`` (with
        ``retries=2``), computes usage metadata and emits the final text
        event. ``user_context`` and ``toolkit`` are accepted but not used
        here.
        """
        tracking_model = self._build_model(stage_config=stage_config)
        formatter = OpenAIChatFormatter()
        emitter = PipelineStageEmitter(
            pipeline=pipeline,
            session_id=run_input.thread_id,
            run_id=run_input.run_id,
            stage=stage_config.agent_type.value,
            runtime_mode=runtime_mode.value,
            emit_text_events=True,
            emit_tool_events=False,
        )
        system_prompt = build_system_prompt(
            agent_type=stage_config.agent_type,
            language=language,
            llm_config=stage_config.llm_config,
            tools=None,
            now_utc=datetime.now(timezone.utc),
            runtime_mode=runtime_mode.value,
        )

        # Only the second element of the returned pair is used below.
        _, worker_payload = await finalize_json_response(
            model=tracking_model,
            formatter=formatter,
            base_messages=[Msg("system", system_prompt, "system"), *input_messages],
            output_model=worker_output_model,
            retries=2,
            language=language,
        )

        response_metadata = self._llm_pricing_service.build_usage_metadata(
            model=stage_config.model_code,
            usage_summary=tracking_model.usage_summary(),
        )
        await emitter.emit_final_text_end(
            worker_output=self._build_final_worker_output(
                worker_payload=worker_payload,
                runtime_mode=runtime_mode,
                derived_divination=derived_divination,
            ),
            response_metadata=response_metadata,
        )
        return StageExecutionResult(
            message=Msg(
                name=stage_config.agent_type.value,
                role="assistant",
                content=worker_payload.answer,
            ),
            validated_output=worker_payload,
            response_metadata=response_metadata,
        )

    def _build_worker_input_messages(
        self,
        *,
        context_messages: list[Msg],
        run_input: RunAgentInput,
        derived_divination: DerivedDivinationData | None,
        language: str = "zh-CN",
    ) -> list[Msg]:
        """Return the message list sent to the worker model.

        The user prompt is built from ``derived_divination`` when present,
        otherwise from the latest user text found in ``run_input``. If the
        last context message has role ``"user"`` its content is replaced by
        that prompt; otherwise a new user message is appended.

        A new list is always returned; ``context_messages`` is never modified,
        so the caller's history keeps its original last message.
        """
        if derived_divination is not None:
            user_text = build_divination_user_prompt(
                derived=derived_divination, language=language
            )
        else:
            raw_user_text, _ = extract_latest_user_payload(run_input)
            user_text = build_follow_up_user_prompt(
                question=raw_user_text, language=language
            )

        if context_messages and context_messages[-1].role == "user":
            last = context_messages[-1]
            replacement = Msg(
                name=last.name,
                role=last.role,
                content=user_text,
            )
            # Copy rather than assign into the caller's list.
            return [*context_messages[:-1], replacement]

        user_msg = Msg(name="user", role="user", content=user_text)
        return [*context_messages, user_msg]

    @staticmethod
    def _resolve_derived_divination(
        *, run_input: RunAgentInput
    ) -> DerivedDivinationData:
        """Derive divination data from the forwarded props of ``run_input``.

        Raises:
            ValueError: When no divination payload can be parsed.
        """
        payload = parse_forwarded_props_divination_payload(
            getattr(run_input, "forwarded_props", None)
        )
        if payload is None:
            raise ValueError("forwardedProps.divinationPayload is required")
        return derive_divination(payload)

    def _build_model(
        self, *, stage_config: SystemAgentRuntimeConfig
    ) -> TrackingChatModel:
        """Build a non-streaming OpenAI-compatible model wrapped for tracking.

        Temperature and max tokens are only forwarded when configured; the
        request body always carries ``thinking: {"type": "disabled"}``.
        """
        generate_kwargs: dict[str, Any] = {
            "timeout": stage_config.llm_config.timeout_seconds,
            "extra_body": {"thinking": {"type": "disabled"}},
        }
        if stage_config.llm_config.temperature is not None:
            generate_kwargs["temperature"] = stage_config.llm_config.temperature
        if stage_config.llm_config.max_tokens is not None:
            generate_kwargs["max_tokens"] = stage_config.llm_config.max_tokens

        model = OpenAIChatModel(
            model_name=stage_config.model_code,
            api_key=stage_config.api_key,
            stream=False,
            client_kwargs={"base_url": stage_config.api_base_url},
            generate_kwargs=generate_kwargs,
        )
        return TrackingChatModel(model)

    def _build_agent(
        self,
        *,
        agent_name: str,
        system_prompt: str,
        toolkit: Any,
        model: TrackingChatModel,
        emitter: PipelineStageEmitter | None = None,
    ) -> JsonReActAgent:
        """Create a ``JsonReActAgent`` with fresh in-memory memory."""
        return JsonReActAgent(
            name=agent_name,
            sys_prompt=system_prompt,
            model=model,
            formatter=OpenAIChatFormatter(),
            toolkit=toolkit,
            memory=InMemoryMemory(),
            emitter=emitter,
        )

    async def _emit_step_event(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        step_name: str,
        event_type: str,
        runtime_mode: RuntimeMode,
        extra_event: dict[str, Any] | None = None,
    ) -> None:
        """Emit one step event on ``pipeline`` for the run's thread.

        ``extra_event`` entries are merged last and therefore override the
        standard keys when names collide.
        """
        payload: dict[str, Any] = {
            "type": event_type,
            "threadId": run_input.thread_id,
            "runId": run_input.run_id,
            "runtime_mode": runtime_mode.value,
            "stepName": step_name,
        }
        if extra_event:
            payload.update(extra_event)
        await pipeline.emit(
            session_id=run_input.thread_id,
            event=payload,
        )

    @staticmethod
    def _resolve_runtime_mode(*, run_input: RunAgentInput) -> RuntimeMode:
        """Parse the runtime mode from the forwarded props of ``run_input``."""
        return parse_forwarded_props_runtime_mode(
            getattr(run_input, "forwarded_props", None)
        )

    @staticmethod
    def _build_final_worker_output(
        *,
        worker_payload: WorkerAgentOutputLite | FollowUpOutput,
        runtime_mode: RuntimeMode,
        derived_divination: DerivedDivinationData | None,
    ) -> dict[str, Any]:
        """Dump the worker payload, adding derived divination in chat mode."""
        payload = worker_payload.model_dump(mode="json", exclude_none=True)
        if runtime_mode == RuntimeMode.CHAT and derived_divination is not None:
            payload["divination_derived"] = derived_divination.model_dump(
                mode="json", by_alias=True, exclude_none=True
            )
        return payload

    @staticmethod
    def _resolve_provider_api_key(*, factory_name: str) -> str:
        """Look up the API key configured for an LLM factory.

        Names are compared after strip and upper-casing, and ``VOLCENGINE``
        is looked up under ``ARK``. Blank configured values are ignored.

        Raises:
            RuntimeError: When no non-empty key is configured for the factory.
        """
        normalized_factory_name = factory_name.strip().upper()
        if normalized_factory_name == "VOLCENGINE":
            normalized_factory_name = "ARK"

        provider_keys = {
            str(key).strip().upper(): str(value).strip()
            for key, value in config.llm.provider_keys.items()
            if str(value).strip()
        }
        api_key = provider_keys.get(normalized_factory_name, "")
        if not api_key:
            raise RuntimeError(f"provider api key missing for factory: {factory_name}")
        return api_key
@dataclass(frozen=True)
class SystemAgentRuntimeConfig:
    """Immutable runtime settings resolved for one system agent."""

    # Which system agent these settings belong to.
    agent_type: AgentType
    # Model identifier passed to the chat model as its model name.
    model_code: str
    # Base URL handed to the model client.
    api_base_url: str
    # Provider API key used to authenticate model calls.
    api_key: str
    # Validated per-agent LLM options (timeout, temperature, max tokens, ...).
    llm_config: SystemAgentLLMConfig
# Second module-level name bound to the very same class object.
# NOTE(review): presumably kept so imports of the older "ReAct" runner name
# keep working — confirm against callers before removing.
AgentScopeReActRunner = AgentScopeRunner