Files
social-app/backend/src/core/agentscope/runtime/runner.py
T

333 lines
11 KiB
Python
Raw Normal View History

from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import UUID
from ag_ui.core.types import RunAgentInput
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.runtime.tool_selection_registry import TOOL_SELECTION_REGISTRY
from core.agentscope.tools.toolkit import build_stage_toolkit
from core.agentscope.utils import patch_agentscope_json_repair_compat
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
from schemas.agent.runtime_models import (
AgentOutput,
)
from schemas.agent.forwarded_props import (
ClientTimeContext,
parse_forwarded_props_client_time,
)
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.user import UserContext
from services.litellm.service import LiteLLMService
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
if TYPE_CHECKING:
from core.agentscope.runtime.orchestrator import PipelineLike
@dataclass(frozen=True)
class SystemAgentRuntimeConfig:
    """Immutable bundle of everything needed to run one system-agent stage.

    Instances are assembled from the system-agent, LLM and LLM-factory rows
    plus the provider API key resolved from application settings.
    """

    # Which system agent this configuration belongs to.
    agent_type: AgentType
    # Model identifier handed to the chat model as its model name.
    model_code: str
    # Base URL passed to the chat client.
    api_base_url: str
    # Provider API key used to authenticate model calls.
    api_key: str
    # Validated per-agent LLM settings.
    llm_config: SystemAgentLLMConfig
@dataclass(frozen=True)
class StageExecutionResult:
    """Immutable outcome of one executed agent stage."""

    # The reply message returned by the agent.
    message: Msg
    # The validated structured output, dumped to a JSON-compatible dict.
    payload: dict[str, Any]
    # Usage metadata built for the stage's model calls.
    response_metadata: dict[str, Any]
class AgentScopeRunner:
    """Run a single system-agent stage on top of AgentScope.

    For each call to :meth:`execute` the runner loads the stage's
    configuration from the database, builds the toolkit, the tracked chat
    model and a ``JsonReActAgent``, runs the agent once, and reports
    ``STEP_STARTED`` / ``STEP_FINISHED`` events through the given pipeline.
    """

    def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
        """Create a runner.

        Args:
            litellm_service: Service used to build usage metadata. A new
                ``LiteLLMService`` is created when omitted (or falsy).
        """
        # Compatibility patch from core.agentscope.utils; invoked on every
        # construction, before any agent is built.
        patch_agentscope_json_repair_compat()
        self._litellm_service: LiteLLMService = litellm_service or LiteLLMService()

    async def execute(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        system_agent_mode: str,
    ) -> dict[str, Any]:
        """Execute the stage selected by ``system_agent_mode``.

        Args:
            user_context: Caller identity; ``user_context.id`` is parsed as a
                UUID and used as the toolkit owner.
            context_messages: Messages passed to the agent as its input.
            pipeline: Sink for step events and emitter events.
            run_input: AG-UI run input; supplies thread id, run id and the
                optional forwarded client-time properties.
            system_agent_mode: Stage selector, see
                :meth:`_resolve_stage_agent_type`.

        Returns:
            ``{"worker": <output>}`` where ``<output>`` is the validated
            ``AgentOutput`` dumped in JSON mode with ``None`` fields removed.
            The key is ``"worker"`` regardless of the selected stage.

        Raises:
            ValueError: If ``user_context.id`` is not a valid UUID string.
            RuntimeError: If the stage configuration is missing, inactive, or
                has no provider API key.
        """
        owner_id = UUID(user_context.id)
        runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
        stage_agent_type = self._resolve_stage_agent_type(system_agent_mode)

        # The toolkit is built with this session, so the agent must run while
        # the session is still open.
        async with AsyncSessionLocal() as session:
            stage_config = await self._load_stage_config(
                session=session,
                agent_type=stage_agent_type,
            )
            stage_toolkit = self._build_stage_toolkit(
                session=session,
                owner_id=owner_id,
                stage_config=stage_config,
            )
            worker_output = await self._execute_worker_step(
                pipeline=pipeline,
                run_input=run_input,
                user_context=user_context,
                context_messages=context_messages,
                toolkit=stage_toolkit,
                stage_config=stage_config,
                runtime_client_time=runtime_client_time,
            )

        return {
            "worker": worker_output.model_dump(mode="json", exclude_none=True),
        }

    def _build_stage_toolkit(
        self,
        *,
        session: AsyncSession,
        owner_id: UUID,
        stage_config: SystemAgentRuntimeConfig,
    ) -> Any:
        """Build the toolkit for a stage, limited to the registry's tool selection."""
        enabled_tool_names = TOOL_SELECTION_REGISTRY.resolve(stage_config=stage_config)
        return build_stage_toolkit(
            agent_type=stage_config.agent_type,
            session=session,
            owner_id=owner_id,
            enabled_tool_names=enabled_tool_names,
        )

    @staticmethod
    def _resolve_stage_agent_type(system_agent_mode: str) -> AgentType:
        """Map a mode string onto an ``AgentType``.

        The mode is trimmed and lower-cased. Only an exact match with
        ``AgentType.MEMORY.value`` selects the memory agent; an empty mode and
        every other value fall back to ``AgentType.WORKER``.
        """
        mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
        if mode == AgentType.MEMORY.value:
            return AgentType.MEMORY
        return AgentType.WORKER

    async def _load_stage_config(
        self,
        *,
        session: AsyncSession,
        agent_type: AgentType,
    ) -> SystemAgentRuntimeConfig:
        """Load the runtime configuration for a stage (thin delegate)."""
        return await self._load_system_agent_config(
            session=session,
            agent_type=agent_type,
        )

    async def _execute_worker_step(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        user_context: UserContext,
        context_messages: list[Msg],
        toolkit: Any,
        stage_config: SystemAgentRuntimeConfig,
        runtime_client_time: ClientTimeContext | None,
    ) -> AgentOutput:
        """Run the stage between ``STEP_STARTED`` and ``STEP_FINISHED`` events.

        The step name is the stage's agent-type value. ``STEP_FINISHED`` is
        emitted only after the stage ran and its payload validated; if either
        raises, the exception propagates and no finish event is sent here.

        Returns:
            The stage payload validated as ``AgentOutput``.
        """
        step_name = stage_config.agent_type.value
        worker_output_model = AgentOutput

        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=step_name,
            event_type="STEP_STARTED",
        )
        worker_result = await self._run_worker_stage(
            user_context=user_context,
            context_messages=context_messages,
            toolkit=toolkit,
            run_input=run_input,
            stage_config=stage_config,
            worker_output_model=worker_output_model,
            pipeline=pipeline,
            runtime_client_time=runtime_client_time,
        )
        worker_output = worker_output_model.model_validate(worker_result.payload)
        await self._emit_step_event(
            pipeline=pipeline,
            run_input=run_input,
            step_name=step_name,
            event_type="STEP_FINISHED",
        )
        return worker_output

    async def _load_system_agent_config(
        self,
        *,
        session: AsyncSession,
        agent_type: AgentType,
    ) -> SystemAgentRuntimeConfig:
        """Read a system agent's row (with its LLM and factory) and build its config.

        Raises:
            RuntimeError: If no row matches ``agent_type``, if the agent's
                status is not ``"active"`` (compared trimmed and
                case-insensitively), or if the factory has no provider API key.
        """
        stmt = (
            select(SystemAgents, Llm, LlmFactory)
            .join(Llm, SystemAgents.llm_id == Llm.id)
            .join(LlmFactory, Llm.factory_id == LlmFactory.id)
            .where(SystemAgents.agent_type == agent_type.value)
        )
        # one_or_none() also raises if more than one row matches.
        row = (await session.execute(stmt)).one_or_none()
        if row is None:
            raise RuntimeError(f"system agent config not found: {agent_type.value}")

        system_agent, llm, factory = row
        status = str(system_agent.status).strip().lower()
        if status != "active":
            raise RuntimeError(f"system agent is not active: {agent_type.value}")

        return SystemAgentRuntimeConfig(
            agent_type=agent_type,
            model_code=llm.model_code,
            api_base_url=factory.request_url,
            api_key=self._resolve_provider_api_key(factory_name=factory.name),
            # A missing/empty stored config validates as an empty mapping.
            llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
        )

    @staticmethod
    def _resolve_provider_api_key(*, factory_name: str) -> str:
        """Return the configured API key for an LLM factory.

        Args:
            factory_name: Factory name as stored; matched case-insensitively
                after trimming, with ``VOLCENGINE`` treated as ``ARK``.

        Raises:
            RuntimeError: If no non-empty key is configured for the factory.
        """
        api_key = AgentScopeRunner._match_provider_api_key(
            config.llm.provider_keys, factory_name
        )
        if not api_key:
            raise RuntimeError(f"provider api key missing for factory: {factory_name}")
        return api_key

    @staticmethod
    def _match_provider_api_key(provider_keys: dict[str, Any], factory_name: str) -> str:
        """Pick the key for ``factory_name`` out of a provider-key mapping.

        Names on both sides are trimmed and upper-cased before comparison and
        ``VOLCENGINE`` is looked up as ``ARK``. Values that are ``None`` or
        blank after trimming are ignored. If several entries normalize to the
        same name, the last usable one wins.

        Returns:
            The trimmed key, or ``""`` when nothing usable is configured.
        """
        wanted = factory_name.strip().upper()
        if wanted == "VOLCENGINE":
            wanted = "ARK"

        matched = ""
        for name, value in provider_keys.items():
            # An unset key must count as missing; str(None) would otherwise
            # yield the literal "None" and be used as a real credential.
            if value is None:
                continue
            candidate = str(value).strip()
            if candidate and str(name).strip().upper() == wanted:
                matched = candidate
        return matched

    async def _run_worker_stage(
        self,
        *,
        user_context: UserContext,
        context_messages: list[Msg],
        toolkit: Any,
        run_input: RunAgentInput,
        stage_config: SystemAgentRuntimeConfig,
        worker_output_model: type[AgentOutput],
        pipeline: PipelineLike,
        runtime_client_time: ClientTimeContext | None,
    ) -> StageExecutionResult:
        """Build the agent for a stage, run it once and package the result.

        The agent's structured output is taken from the reply message's
        ``metadata`` and validated against ``worker_output_model``. Usage
        metadata is built from the tracking model's usage summary, and the
        emitter is told the final text has ended before returning.
        """
        # Copy so the caller's list is not the object handed to the agent.
        worker_input = list(context_messages)
        tracking_model = self._build_model(stage_config=stage_config)
        emitter = PipelineStageEmitter(
            pipeline=pipeline,
            session_id=run_input.thread_id,
            run_id=run_input.run_id,
            stage=stage_config.agent_type.value,
            emit_text_events=True,
            emit_tool_events=True,
        )
        agent = self._build_agent(
            agent_name=stage_config.agent_type.value,
            system_prompt=build_system_prompt(
                agent_type=stage_config.agent_type,
                llm_config=stage_config.llm_config,
                user_context=user_context,
                now_utc=datetime.now(timezone.utc),
                runtime_client_time=runtime_client_time,
                tools=None,
            ),
            toolkit=toolkit,
            model=tracking_model,
            emitter=emitter,
        )

        response_msg = await agent.reply_json(
            worker_input, output_model=worker_output_model
        )
        # Missing metadata is validated as an empty mapping.
        worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
        response_metadata = self._litellm_service.build_usage_metadata(
            model=stage_config.model_code,
            usage_summary=tracking_model.usage_summary(),
        )

        # Dumped separately for the emitter and the result so the two
        # consumers never share one mutable dict.
        await emitter.emit_final_text_end(
            worker_output=worker_payload.model_dump(mode="json", exclude_none=True),
            response_metadata=response_metadata,
        )
        return StageExecutionResult(
            message=response_msg,
            payload=worker_payload.model_dump(mode="json", exclude_none=True),
            response_metadata=response_metadata,
        )

    def _build_model(
        self, *, stage_config: SystemAgentRuntimeConfig
    ) -> TrackingChatModel:
        """Create a non-streaming OpenAI-compatible chat model wrapped for usage tracking."""
        generate_kwargs: dict[str, Any] = {
            "temperature": stage_config.llm_config.temperature,
            "max_tokens": stage_config.llm_config.max_tokens,
            "timeout": stage_config.llm_config.timeout_seconds,
        }
        # NOTE(review): `enable_thinking` is sent in the request's extra body
        # for every provider; confirm all configured providers accept it.
        generate_kwargs["extra_body"] = {"enable_thinking": False}

        model = OpenAIChatModel(
            model_name=stage_config.model_code,
            api_key=stage_config.api_key,
            stream=False,
            client_kwargs={"base_url": stage_config.api_base_url},
            generate_kwargs=generate_kwargs,
        )
        return TrackingChatModel(model)

    def _build_agent(
        self,
        *,
        agent_name: str,
        system_prompt: str,
        toolkit: Any,
        model: TrackingChatModel,
        emitter: PipelineStageEmitter | None = None,
    ) -> JsonReActAgent:
        """Assemble a ``JsonReActAgent`` with a fresh in-memory memory and OpenAI formatter."""
        return JsonReActAgent(
            name=agent_name,
            sys_prompt=system_prompt,
            model=model,
            formatter=OpenAIChatFormatter(),
            toolkit=toolkit,
            memory=InMemoryMemory(),
            emitter=emitter,
        )

    async def _emit_step_event(
        self,
        *,
        pipeline: PipelineLike,
        run_input: RunAgentInput,
        step_name: str,
        event_type: str,
    ) -> None:
        """Emit a step event for this run on the pipeline.

        The event carries ``type``, ``threadId``, ``runId`` and ``stepName``;
        the run's thread id doubles as the pipeline session id.
        """
        await pipeline.emit(
            session_id=run_input.thread_id,
            event={
                "type": event_type,
                "threadId": run_input.thread_id,
                "runId": run_input.run_id,
                "stepName": step_name,
            },
        )

    def _resolve_runtime_client_time(
        self, *, run_input: RunAgentInput
    ) -> ClientTimeContext | None:
        """Parse the client-time context from ``run_input.forwarded_props``.

        ``None`` is passed to the parser when the attribute is absent.
        """
        return parse_forwarded_props_client_time(
            getattr(run_input, "forwarded_props", None)
        )
# Second module-level name bound to the same class object as AgentScopeRunner.
# NOTE(review): presumably kept so imports of the ReAct-specific name keep
# working — confirm against callers before removing.
AgentScopeReActRunner = AgentScopeRunner