refactor(backend): 重构 agentscope 运行时模块

This commit is contained in:
zl-q
2026-03-19 00:52:05 +08:00
parent 9219e8d047
commit f709023b6d
7 changed files with 261 additions and 255 deletions
@@ -0,0 +1,45 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Any
from schemas.agent.system_agent import ContextBuildStrategy
ContextLoader = Callable[[Any, str, int], Awaitable[dict[str, object] | None]]
class ContextLoaderRegistry:
def __init__(self) -> None:
self._loaders: dict[ContextBuildStrategy, ContextLoader] = {}
def register(self, *, mode: ContextBuildStrategy, loader: ContextLoader) -> None:
self._loaders[mode] = loader
def resolve(self, *, mode: ContextBuildStrategy) -> ContextLoader:
loader = self._loaders.get(mode)
if loader is None:
raise ValueError(f"unsupported context mode: {mode.value}")
return loader
async def _load_number(
service: Any, thread_id: str, count: int
) -> dict[str, object] | None:
return await service.load_by_user_message_window(
thread_id=thread_id,
user_message_limit=max(count, 1),
)
async def _load_day(
service: Any, thread_id: str, count: int
) -> dict[str, object] | None:
return await service.load_by_day_window(
thread_id=thread_id,
day_count=max(count, 1),
)
CONTEXT_LOADER_REGISTRY = ContextLoaderRegistry()
CONTEXT_LOADER_REGISTRY.register(mode=ContextBuildStrategy.NUMBER, loader=_load_number)
CONTEXT_LOADER_REGISTRY.register(mode=ContextBuildStrategy.DAY, loader=_load_day)
@@ -0,0 +1,113 @@
from __future__ import annotations
from datetime import date
from typing import Protocol
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
from schemas.agent.system_agent import SystemAgentLLMConfig
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
class ContextRepositoryLike(Protocol):
async def get_history_day(
self, *, session_id: str, before: date | None
) -> dict[str, object] | None: ...
async def get_recent_messages_by_user_window(
self, *, session_id: str, user_message_limit: int
) -> list[dict[str, object]]: ...
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None: ...
class AgentContextService:
def __init__(self, *, repository: ContextRepositoryLike) -> None:
self._repository = repository
async def load_context_messages(
self,
*,
thread_id: str,
system_agent_mode: str,
) -> dict[str, object] | None:
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
runtime_config = await self._repository.get_system_agent_config(agent_type=mode)
raw_llm_config: dict[str, object] = {}
if isinstance(runtime_config, dict):
raw_config = runtime_config.get("config")
if isinstance(raw_config, dict):
raw_llm_config = raw_config
normalized_config = self._normalize_system_agent_config(raw_llm_config)
context_config = normalized_config.context_messages
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
return await context_loader(self, thread_id, context_config.count)
async def load_by_user_message_window(
self,
*,
thread_id: str,
user_message_limit: int,
) -> dict[str, object] | None:
messages = await self._repository.get_recent_messages_by_user_window(
session_id=thread_id,
user_message_limit=max(int(user_message_limit), 1),
)
if not messages:
return None
return {"messages": messages}
async def load_by_day_window(
self,
*,
thread_id: str,
day_count: int,
) -> dict[str, object] | None:
messages: list[dict[str, object]] = []
before: date | None = None
for _ in range(max(day_count, 1)):
day_payload = await self._repository.get_history_day(
session_id=thread_id,
before=before,
)
if not day_payload:
break
day_messages = day_payload.get("messages")
if isinstance(day_messages, list):
messages = [*day_messages, *messages]
before = self._parse_history_day(day_payload.get("day"))
if before is None:
break
if not messages:
return None
return {"messages": messages}
def _normalize_system_agent_config(
self,
raw_config: dict[str, object],
) -> SystemAgentLLMConfig:
default_payload = {
"context_messages": {
"mode": "number",
"count": _DEFAULT_CONTEXT_WINDOW_USER_MESSAGES,
},
"enabled_tool_groups": [],
}
if not raw_config:
return SystemAgentLLMConfig.model_validate(default_payload)
merged = {**default_payload, **raw_config}
return SystemAgentLLMConfig.model_validate(merged)
def _parse_history_day(self, value: object) -> date | None:
if isinstance(value, date):
return value
if isinstance(value, str):
try:
return date.fromisoformat(value)
except ValueError:
return None
return None
@@ -23,6 +23,7 @@ class RunnerLike(Protocol):
context_messages: list[Msg], context_messages: list[Msg],
pipeline: PipelineLike, pipeline: PipelineLike,
run_input: RunAgentInput, run_input: RunAgentInput,
system_agent_mode: str,
) -> dict[str, Any]: ... ) -> dict[str, Any]: ...
@@ -45,6 +46,7 @@ class AgentScopeRuntimeOrchestrator:
run_input: RunAgentInput, run_input: RunAgentInput,
context_messages: list[Msg], context_messages: list[Msg],
user_context: UserContext, user_context: UserContext,
system_agent_mode: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
thread_id = run_input.thread_id thread_id = run_input.thread_id
run_id = run_input.run_id run_id = run_input.run_id
@@ -63,6 +65,7 @@ class AgentScopeRuntimeOrchestrator:
context_messages=context_messages, context_messages=context_messages,
pipeline=self._pipeline, pipeline=self._pipeline,
run_input=run_input, run_input=run_input,
system_agent_mode=system_agent_mode,
) )
await self._pipeline.emit( await self._pipeline.emit(
@@ -1,96 +0,0 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID, uuid4
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from schemas.agent.runtime_models import RouterAgentOutput
from schemas.agent.system_agent import AgentType
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
from sqlalchemy.ext.asyncio import AsyncSession
def _to_int(value: object) -> int:
if value is None:
return 0
if isinstance(value, bool):
return int(value)
if isinstance(value, int):
return value
if isinstance(value, Decimal):
return int(value)
if isinstance(value, float):
return int(value)
if isinstance(value, str):
text = value.strip()
if not text:
return 0
try:
return int(text)
except ValueError:
return int(float(text))
return 0
async def persist_router_message(
*,
session: AsyncSession,
thread_id: str,
run_id: str,
model_code: str,
router_output: RouterAgentOutput,
response_metadata: dict[str, object],
) -> None:
session_id = UUID(thread_id)
message_repo = MessageRepository(session)
session_repo = SessionRepository(session)
locked_session = await session_repo.lock_session_for_update(session_id=session_id)
if locked_session is None:
raise RuntimeError("chat session not found for router persistence")
seq = _to_int(getattr(locked_session, "message_count", 0)) + 1
metadata = AgentChatMessageMetadata(
run_id=run_id,
agent_type=AgentType.ROUTER,
router_agent_output=router_output,
)
message_payload = AgentChatMessage(
id=uuid4(),
seq=seq,
role=AgentChatMessageRole.ASSISTANT.value,
content="",
model_code=model_code,
tool_name=None,
input_tokens=_to_int(response_metadata.get("inputTokens", 0)),
output_tokens=_to_int(response_metadata.get("outputTokens", 0)),
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
latency_ms=_to_int(response_metadata.get("latencyMs", 0)),
metadata=metadata,
timestamp=datetime.now(timezone.utc),
)
await message_repo.append_message(
session_id=session_id,
seq=message_payload.seq,
role=AgentChatMessageRole.ASSISTANT,
content=message_payload.content,
model_code=message_payload.model_code,
tool_name=message_payload.tool_name,
metadata=metadata.model_dump(mode="json", exclude_none=True),
input_tokens=message_payload.input_tokens,
output_tokens=message_payload.output_tokens,
cost=message_payload.cost,
latency_ms=message_payload.latency_ms,
)
await session_repo.update_runtime_state(
chat_session=locked_session,
status=AgentChatSessionStatus.RUNNING,
state_snapshot=locked_session.state_snapshot or {},
message_delta=1,
token_delta=message_payload.input_tokens + message_payload.output_tokens,
cost_delta=message_payload.cost,
)
await session.flush()
+47 -156
View File
@@ -10,27 +10,20 @@ from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory from agentscope.memory import InMemoryMemory
from agentscope.message import Msg from agentscope.message import Msg
from agentscope.model import OpenAIChatModel from agentscope.model import OpenAIChatModel
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
from core.agentscope.prompts.system_prompt import build_system_prompt from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.runtime.json_react_agent import JsonReActAgent from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.router_persistence import persist_router_message
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.runtime.tool_selection_registry import TOOL_SELECTION_REGISTRY
from core.agentscope.tools.toolkit import build_stage_toolkit from core.agentscope.tools.toolkit import build_stage_toolkit
from core.agentscope.utils import ( from core.agentscope.utils import patch_agentscope_json_repair_compat
finalize_json_response,
patch_agentscope_json_repair_compat,
)
from core.config.settings import config from core.config.settings import config
from core.db.session import AsyncSessionLocal from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from models.llm import Llm from models.llm import Llm
from models.llm_factory import LlmFactory from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents from models.system_agents import SystemAgents
from schemas.agent.runtime_models import ( from schemas.agent.runtime_models import (
RouterAgentOutput, AgentOutput,
WorkerAgentOutputLite,
resolve_worker_output_model,
) )
from schemas.agent.forwarded_props import ( from schemas.agent.forwarded_props import (
ClientTimeContext, ClientTimeContext,
@@ -45,8 +38,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from core.agentscope.runtime.orchestrator import PipelineLike from core.agentscope.runtime.orchestrator import PipelineLike
logger = get_logger("core.agentscope.runtime.runner")
@dataclass(frozen=True) @dataclass(frozen=True)
class SystemAgentRuntimeConfig: class SystemAgentRuntimeConfig:
@@ -76,110 +67,68 @@ class AgentScopeRunner:
context_messages: list[Msg], context_messages: list[Msg],
pipeline: PipelineLike, pipeline: PipelineLike,
run_input: RunAgentInput, run_input: RunAgentInput,
system_agent_mode: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
owner_id = UUID(user_context.id) owner_id = UUID(user_context.id)
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input) runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
stage_agent_type = self._resolve_stage_agent_type(system_agent_mode)
async with AsyncSessionLocal() as session: async with AsyncSessionLocal() as session:
worker_toolkit = self._build_worker_toolkit( stage_config = await self._load_stage_config(
session=session, owner_id=owner_id
)
router_config, worker_config = await self._load_stage_configs(
session=session
)
router_output = await self._execute_router_step(
session=session, session=session,
pipeline=pipeline, agent_type=stage_agent_type,
run_input=run_input, )
user_context=user_context, stage_toolkit = self._build_stage_toolkit(
context_messages=context_messages, session=session,
stage_config=router_config, owner_id=owner_id,
runtime_client_time=runtime_client_time, stage_config=stage_config,
) )
worker_output = await self._execute_worker_step( worker_output = await self._execute_worker_step(
pipeline=pipeline, pipeline=pipeline,
run_input=run_input, run_input=run_input,
user_context=user_context, user_context=user_context,
router_output=router_output, context_messages=context_messages,
toolkit=worker_toolkit, toolkit=stage_toolkit,
stage_config=worker_config, stage_config=stage_config,
runtime_client_time=runtime_client_time, runtime_client_time=runtime_client_time,
) )
return { return {
"router": router_output.model_dump(mode="json", exclude_none=True),
"worker": worker_output.model_dump(mode="json", exclude_none=True), "worker": worker_output.model_dump(mode="json", exclude_none=True),
} }
def _build_worker_toolkit( def _build_stage_toolkit(
self, self,
*, *,
session: AsyncSession, session: AsyncSession,
owner_id: UUID, owner_id: UUID,
stage_config: SystemAgentRuntimeConfig,
) -> Any: ) -> Any:
enabled_tool_names = TOOL_SELECTION_REGISTRY.resolve(stage_config=stage_config)
return build_stage_toolkit( return build_stage_toolkit(
agent_type=AgentType.WORKER, agent_type=stage_config.agent_type,
session=session, session=session,
owner_id=owner_id, owner_id=owner_id,
enabled_tool_names=enabled_tool_names,
) )
async def _load_stage_configs( @staticmethod
def _resolve_stage_agent_type(system_agent_mode: str) -> AgentType:
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
if mode == AgentType.MEMORY.value:
return AgentType.MEMORY
return AgentType.WORKER
async def _load_stage_config(
self, self,
*, *,
session: AsyncSession, session: AsyncSession,
) -> tuple[SystemAgentRuntimeConfig, SystemAgentRuntimeConfig]: agent_type: AgentType,
router_config = await self._load_system_agent_config( ) -> SystemAgentRuntimeConfig:
return await self._load_system_agent_config(
session=session, session=session,
agent_type=AgentType.ROUTER, agent_type=agent_type,
) )
worker_config = await self._load_system_agent_config(
session=session,
agent_type=AgentType.WORKER,
)
return router_config, worker_config
async def _execute_router_step(
self,
*,
session: AsyncSession,
pipeline: PipelineLike,
run_input: RunAgentInput,
user_context: UserContext,
context_messages: list[Msg],
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
) -> RouterAgentOutput:
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="router",
event_type="STEP_STARTED",
)
router_result = await self._run_router_stage(
user_context=user_context,
context_messages=context_messages,
run_input=run_input,
stage_config=stage_config,
runtime_client_time=runtime_client_time,
)
router_output = RouterAgentOutput.model_validate(router_result.payload)
await persist_router_message(
session=session,
thread_id=run_input.thread_id,
run_id=run_input.run_id,
model_code=stage_config.model_code,
router_output=router_output,
response_metadata=router_result.response_metadata,
)
await session.commit()
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="router",
event_type="STEP_FINISHED",
)
return router_output
async def _execute_worker_step( async def _execute_worker_step(
self, self,
@@ -187,21 +136,22 @@ class AgentScopeRunner:
pipeline: PipelineLike, pipeline: PipelineLike,
run_input: RunAgentInput, run_input: RunAgentInput,
user_context: UserContext, user_context: UserContext,
router_output: RouterAgentOutput, context_messages: list[Msg],
toolkit: Any, toolkit: Any,
stage_config: SystemAgentRuntimeConfig, stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None, runtime_client_time: ClientTimeContext | None,
) -> WorkerAgentOutputLite: ) -> AgentOutput:
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode) step_name = stage_config.agent_type.value
worker_output_model = AgentOutput
await self._emit_step_event( await self._emit_step_event(
pipeline=pipeline, pipeline=pipeline,
run_input=run_input, run_input=run_input,
step_name="worker", step_name=step_name,
event_type="STEP_STARTED", event_type="STEP_STARTED",
) )
worker_result = await self._run_worker_stage( worker_result = await self._run_worker_stage(
user_context=user_context, user_context=user_context,
router_output=router_output, context_messages=context_messages,
toolkit=toolkit, toolkit=toolkit,
run_input=run_input, run_input=run_input,
stage_config=stage_config, stage_config=stage_config,
@@ -213,7 +163,7 @@ class AgentScopeRunner:
await self._emit_step_event( await self._emit_step_event(
pipeline=pipeline, pipeline=pipeline,
run_input=run_input, run_input=run_input,
step_name="worker", step_name=step_name,
event_type="STEP_FINISHED", event_type="STEP_FINISHED",
) )
return worker_output return worker_output
@@ -261,78 +211,33 @@ class AgentScopeRunner:
raise RuntimeError(f"provider api key missing for factory: {factory_name}") raise RuntimeError(f"provider api key missing for factory: {factory_name}")
return api_key return api_key
async def _run_router_stage(
self,
*,
user_context: UserContext,
context_messages: list[Msg],
run_input: RunAgentInput,
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
) -> StageExecutionResult:
tracking_model = self._build_model(stage_config=stage_config)
system_prompt = build_system_prompt(
agent_type=AgentType.ROUTER,
user_context=user_context,
now_utc=datetime.now(timezone.utc),
runtime_client_time=runtime_client_time,
tools=None,
)
response, payload = await finalize_json_response(
model=tracking_model,
formatter=OpenAIChatFormatter(),
base_messages=[Msg("system", system_prompt, "system"), *context_messages],
output_model=RouterAgentOutput,
retries=0,
)
response_msg = Msg(
name="router",
role="assistant",
content=list(getattr(response, "content", [])),
metadata=payload,
)
logger.info(
"router_reply_received",
run_id=run_input.run_id,
thread_id=run_input.thread_id,
message_id=str(response_msg.id),
)
return StageExecutionResult(
message=response_msg,
payload=payload,
response_metadata=self._litellm_service.build_usage_metadata(
model=stage_config.model_code,
usage_summary=tracking_model.usage_summary(),
),
)
async def _run_worker_stage( async def _run_worker_stage(
self, self,
*, *,
user_context: UserContext, user_context: UserContext,
router_output: RouterAgentOutput, context_messages: list[Msg],
toolkit: Any, toolkit: Any,
run_input: RunAgentInput, run_input: RunAgentInput,
stage_config: SystemAgentRuntimeConfig, stage_config: SystemAgentRuntimeConfig,
worker_output_model: type[WorkerAgentOutputLite], worker_output_model: type[AgentOutput],
pipeline: PipelineLike, pipeline: PipelineLike,
runtime_client_time: ClientTimeContext | None, runtime_client_time: ClientTimeContext | None,
) -> StageExecutionResult: ) -> StageExecutionResult:
worker_input = self._build_worker_input_messages(router_output=router_output) worker_input = list(context_messages)
tracking_model = self._build_model(stage_config=stage_config) tracking_model = self._build_model(stage_config=stage_config)
emitter = PipelineStageEmitter( emitter = PipelineStageEmitter(
pipeline=pipeline, pipeline=pipeline,
session_id=run_input.thread_id, session_id=run_input.thread_id,
run_id=run_input.run_id, run_id=run_input.run_id,
stage="worker", stage=stage_config.agent_type.value,
emit_text_events=True, emit_text_events=True,
emit_tool_events=True, emit_tool_events=True,
) )
agent = self._build_agent( agent = self._build_agent(
agent_name="worker", agent_name=stage_config.agent_type.value,
system_prompt=build_system_prompt( system_prompt=build_system_prompt(
agent_type=AgentType.WORKER, agent_type=stage_config.agent_type,
llm_config=stage_config.llm_config,
user_context=user_context, user_context=user_context,
now_utc=datetime.now(timezone.utc), now_utc=datetime.now(timezone.utc),
runtime_client_time=runtime_client_time, runtime_client_time=runtime_client_time,
@@ -360,19 +265,6 @@ class AgentScopeRunner:
response_metadata=response_metadata, response_metadata=response_metadata,
) )
def _build_worker_input_messages(
self,
*,
router_output: RouterAgentOutput,
) -> list[Msg]:
return [
Msg(
name="router",
role="user",
content=build_worker_contract_prompt(router_output=router_output),
)
]
def _build_model( def _build_model(
self, *, stage_config: SystemAgentRuntimeConfig self, *, stage_config: SystemAgentRuntimeConfig
) -> TrackingChatModel: ) -> TrackingChatModel:
@@ -381,8 +273,7 @@ class AgentScopeRunner:
"max_tokens": stage_config.llm_config.max_tokens, "max_tokens": stage_config.llm_config.max_tokens,
"timeout": stage_config.llm_config.timeout_seconds, "timeout": stage_config.llm_config.timeout_seconds,
} }
if stage_config.agent_type == AgentType.ROUTER: generate_kwargs["extra_body"] = {"enable_thinking": False}
generate_kwargs["extra_body"] = {"enable_thinking": False}
model = OpenAIChatModel( model = OpenAIChatModel(
model_name=stage_config.model_code, model_name=stage_config.model_code,
+11 -3
View File
@@ -12,6 +12,7 @@ from core.agentscope.events import (
RedisStreamBus, RedisStreamBus,
SqlAlchemyEventStore, SqlAlchemyEventStore,
) )
from core.agentscope.runtime.context_service import AgentContextService
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
from core.agentscope.schemas.agui_input import parse_run_input from core.agentscope.schemas.agui_input import parse_run_input
from core.auth.models import CurrentUser from core.auth.models import CurrentUser
@@ -26,7 +27,7 @@ from schemas.messages.chat_message import (
from schemas.user import UserContext from schemas.user import UserContext
from services.base.redis import get_or_init_redis_client from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service from services.base.supabase import supabase_service
from v1.agent.dependencies import get_agent_service from v1.agent.repository import AgentRepository
from v1.users.dependencies import get_user_service from v1.users.dependencies import get_user_service
logger = get_logger("core.agentscope.runtime.tasks") logger = get_logger("core.agentscope.runtime.tasks")
@@ -78,9 +79,13 @@ async def _build_recent_context_messages(
*, *,
session: Any, session: Any,
thread_id: str, thread_id: str,
system_agent_mode: str,
) -> list[Msg]: ) -> list[Msg]:
agent_service = get_agent_service(session) context_service = AgentContextService(repository=AgentRepository(session))
result = await agent_service.load_agent_input_messages(thread_id=thread_id) result = await context_service.load_context_messages(
thread_id=thread_id,
system_agent_mode=system_agent_mode,
)
if not result: if not result:
return [] return []
@@ -165,6 +170,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
command_type = str(command.get("command", "run")).strip().lower() command_type = str(command.get("command", "run")).strip().lower()
raw_owner_id = command.get("owner_id") raw_owner_id = command.get("owner_id")
run_input_raw = command.get("run_input") run_input_raw = command.get("run_input")
system_agent_mode = str(command.get("system_agent_mode", "worker")).strip().lower()
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip(): if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
raise ValueError("owner_id is required") raise ValueError("owner_id is required")
@@ -205,12 +211,14 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
context_messages = await _build_recent_context_messages( context_messages = await _build_recent_context_messages(
session=session, session=session,
thread_id=thread_id, thread_id=thread_id,
system_agent_mode=system_agent_mode,
) )
await runtime.run( await runtime.run(
run_input=run_input, run_input=run_input,
context_messages=context_messages, context_messages=context_messages,
user_context=user_context, user_context=user_context,
system_agent_mode=system_agent_mode,
) )
logger.info( logger.info(
"agentscope runtime task completed", "agentscope runtime task completed",
@@ -0,0 +1,42 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from core.agentscope.tools.tool_config import resolve_tool_names_by_groups
from schemas.agent.system_agent import AgentType
ToolNameResolver = Callable[[Any], set[str] | None]
class ToolSelectionRegistry:
def __init__(self) -> None:
self._resolvers: dict[AgentType, ToolNameResolver] = {}
def register(self, *, agent_type: AgentType, resolver: ToolNameResolver) -> None:
self._resolvers[agent_type] = resolver
def resolve(self, *, stage_config: Any) -> set[str] | None:
resolver = self._resolvers.get(stage_config.agent_type)
if resolver is None:
return None
return resolver(stage_config)
def _default_group_resolver(stage_config: Any) -> set[str] | None:
raw_groups = getattr(stage_config.llm_config, "enabled_tool_groups", [])
groups = raw_groups if isinstance(raw_groups, list) else []
if not groups:
return None
return resolve_tool_names_by_groups(set(groups))
TOOL_SELECTION_REGISTRY = ToolSelectionRegistry()
TOOL_SELECTION_REGISTRY.register(
agent_type=AgentType.WORKER,
resolver=_default_group_resolver,
)
TOOL_SELECTION_REGISTRY.register(
agent_type=AgentType.MEMORY,
resolver=_default_group_resolver,
)