refactor(backend): 重构 agentscope 运行时模块
This commit is contained in:
@@ -0,0 +1,45 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from schemas.agent.system_agent import ContextBuildStrategy
|
||||||
|
|
||||||
|
ContextLoader = Callable[[Any, str, int], Awaitable[dict[str, object] | None]]
|
||||||
|
|
||||||
|
|
||||||
|
class ContextLoaderRegistry:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._loaders: dict[ContextBuildStrategy, ContextLoader] = {}
|
||||||
|
|
||||||
|
def register(self, *, mode: ContextBuildStrategy, loader: ContextLoader) -> None:
|
||||||
|
self._loaders[mode] = loader
|
||||||
|
|
||||||
|
def resolve(self, *, mode: ContextBuildStrategy) -> ContextLoader:
|
||||||
|
loader = self._loaders.get(mode)
|
||||||
|
if loader is None:
|
||||||
|
raise ValueError(f"unsupported context mode: {mode.value}")
|
||||||
|
return loader
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_number(
|
||||||
|
service: Any, thread_id: str, count: int
|
||||||
|
) -> dict[str, object] | None:
|
||||||
|
return await service.load_by_user_message_window(
|
||||||
|
thread_id=thread_id,
|
||||||
|
user_message_limit=max(count, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_day(
|
||||||
|
service: Any, thread_id: str, count: int
|
||||||
|
) -> dict[str, object] | None:
|
||||||
|
return await service.load_by_day_window(
|
||||||
|
thread_id=thread_id,
|
||||||
|
day_count=max(count, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
CONTEXT_LOADER_REGISTRY = ContextLoaderRegistry()
|
||||||
|
CONTEXT_LOADER_REGISTRY.register(mode=ContextBuildStrategy.NUMBER, loader=_load_number)
|
||||||
|
CONTEXT_LOADER_REGISTRY.register(mode=ContextBuildStrategy.DAY, loader=_load_day)
|
||||||
@@ -0,0 +1,113 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
|
||||||
|
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||||
|
|
||||||
|
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
|
||||||
|
|
||||||
|
|
||||||
|
class ContextRepositoryLike(Protocol):
|
||||||
|
async def get_history_day(
|
||||||
|
self, *, session_id: str, before: date | None
|
||||||
|
) -> dict[str, object] | None: ...
|
||||||
|
|
||||||
|
async def get_recent_messages_by_user_window(
|
||||||
|
self, *, session_id: str, user_message_limit: int
|
||||||
|
) -> list[dict[str, object]]: ...
|
||||||
|
|
||||||
|
async def get_system_agent_config(
|
||||||
|
self, *, agent_type: str
|
||||||
|
) -> dict[str, object] | None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class AgentContextService:
|
||||||
|
def __init__(self, *, repository: ContextRepositoryLike) -> None:
|
||||||
|
self._repository = repository
|
||||||
|
|
||||||
|
async def load_context_messages(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
system_agent_mode: str,
|
||||||
|
) -> dict[str, object] | None:
|
||||||
|
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
|
||||||
|
runtime_config = await self._repository.get_system_agent_config(agent_type=mode)
|
||||||
|
raw_llm_config: dict[str, object] = {}
|
||||||
|
if isinstance(runtime_config, dict):
|
||||||
|
raw_config = runtime_config.get("config")
|
||||||
|
if isinstance(raw_config, dict):
|
||||||
|
raw_llm_config = raw_config
|
||||||
|
|
||||||
|
normalized_config = self._normalize_system_agent_config(raw_llm_config)
|
||||||
|
context_config = normalized_config.context_messages
|
||||||
|
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
|
||||||
|
return await context_loader(self, thread_id, context_config.count)
|
||||||
|
|
||||||
|
async def load_by_user_message_window(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
user_message_limit: int,
|
||||||
|
) -> dict[str, object] | None:
|
||||||
|
messages = await self._repository.get_recent_messages_by_user_window(
|
||||||
|
session_id=thread_id,
|
||||||
|
user_message_limit=max(int(user_message_limit), 1),
|
||||||
|
)
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
return {"messages": messages}
|
||||||
|
|
||||||
|
async def load_by_day_window(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
day_count: int,
|
||||||
|
) -> dict[str, object] | None:
|
||||||
|
messages: list[dict[str, object]] = []
|
||||||
|
before: date | None = None
|
||||||
|
for _ in range(max(day_count, 1)):
|
||||||
|
day_payload = await self._repository.get_history_day(
|
||||||
|
session_id=thread_id,
|
||||||
|
before=before,
|
||||||
|
)
|
||||||
|
if not day_payload:
|
||||||
|
break
|
||||||
|
day_messages = day_payload.get("messages")
|
||||||
|
if isinstance(day_messages, list):
|
||||||
|
messages = [*day_messages, *messages]
|
||||||
|
before = self._parse_history_day(day_payload.get("day"))
|
||||||
|
if before is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
return {"messages": messages}
|
||||||
|
|
||||||
|
def _normalize_system_agent_config(
|
||||||
|
self,
|
||||||
|
raw_config: dict[str, object],
|
||||||
|
) -> SystemAgentLLMConfig:
|
||||||
|
default_payload = {
|
||||||
|
"context_messages": {
|
||||||
|
"mode": "number",
|
||||||
|
"count": _DEFAULT_CONTEXT_WINDOW_USER_MESSAGES,
|
||||||
|
},
|
||||||
|
"enabled_tool_groups": [],
|
||||||
|
}
|
||||||
|
if not raw_config:
|
||||||
|
return SystemAgentLLMConfig.model_validate(default_payload)
|
||||||
|
merged = {**default_payload, **raw_config}
|
||||||
|
return SystemAgentLLMConfig.model_validate(merged)
|
||||||
|
|
||||||
|
def _parse_history_day(self, value: object) -> date | None:
|
||||||
|
if isinstance(value, date):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return date.fromisoformat(value)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
@@ -23,6 +23,7 @@ class RunnerLike(Protocol):
|
|||||||
context_messages: list[Msg],
|
context_messages: list[Msg],
|
||||||
pipeline: PipelineLike,
|
pipeline: PipelineLike,
|
||||||
run_input: RunAgentInput,
|
run_input: RunAgentInput,
|
||||||
|
system_agent_mode: str,
|
||||||
) -> dict[str, Any]: ...
|
) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ class AgentScopeRuntimeOrchestrator:
|
|||||||
run_input: RunAgentInput,
|
run_input: RunAgentInput,
|
||||||
context_messages: list[Msg],
|
context_messages: list[Msg],
|
||||||
user_context: UserContext,
|
user_context: UserContext,
|
||||||
|
system_agent_mode: str,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
thread_id = run_input.thread_id
|
thread_id = run_input.thread_id
|
||||||
run_id = run_input.run_id
|
run_id = run_input.run_id
|
||||||
@@ -63,6 +65,7 @@ class AgentScopeRuntimeOrchestrator:
|
|||||||
context_messages=context_messages,
|
context_messages=context_messages,
|
||||||
pipeline=self._pipeline,
|
pipeline=self._pipeline,
|
||||||
run_input=run_input,
|
run_input=run_input,
|
||||||
|
system_agent_mode=system_agent_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._pipeline.emit(
|
await self._pipeline.emit(
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from decimal import Decimal
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
|
||||||
from models.agent_chat_message import AgentChatMessageRole
|
|
||||||
from models.agent_chat_session import AgentChatSessionStatus
|
|
||||||
from schemas.agent.runtime_models import RouterAgentOutput
|
|
||||||
from schemas.agent.system_agent import AgentType
|
|
||||||
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
def _to_int(value: object) -> int:
|
|
||||||
if value is None:
|
|
||||||
return 0
|
|
||||||
if isinstance(value, bool):
|
|
||||||
return int(value)
|
|
||||||
if isinstance(value, int):
|
|
||||||
return value
|
|
||||||
if isinstance(value, Decimal):
|
|
||||||
return int(value)
|
|
||||||
if isinstance(value, float):
|
|
||||||
return int(value)
|
|
||||||
if isinstance(value, str):
|
|
||||||
text = value.strip()
|
|
||||||
if not text:
|
|
||||||
return 0
|
|
||||||
try:
|
|
||||||
return int(text)
|
|
||||||
except ValueError:
|
|
||||||
return int(float(text))
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
async def persist_router_message(
|
|
||||||
*,
|
|
||||||
session: AsyncSession,
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
model_code: str,
|
|
||||||
router_output: RouterAgentOutput,
|
|
||||||
response_metadata: dict[str, object],
|
|
||||||
) -> None:
|
|
||||||
session_id = UUID(thread_id)
|
|
||||||
message_repo = MessageRepository(session)
|
|
||||||
session_repo = SessionRepository(session)
|
|
||||||
locked_session = await session_repo.lock_session_for_update(session_id=session_id)
|
|
||||||
if locked_session is None:
|
|
||||||
raise RuntimeError("chat session not found for router persistence")
|
|
||||||
|
|
||||||
seq = _to_int(getattr(locked_session, "message_count", 0)) + 1
|
|
||||||
metadata = AgentChatMessageMetadata(
|
|
||||||
run_id=run_id,
|
|
||||||
agent_type=AgentType.ROUTER,
|
|
||||||
router_agent_output=router_output,
|
|
||||||
)
|
|
||||||
message_payload = AgentChatMessage(
|
|
||||||
id=uuid4(),
|
|
||||||
seq=seq,
|
|
||||||
role=AgentChatMessageRole.ASSISTANT.value,
|
|
||||||
content="",
|
|
||||||
model_code=model_code,
|
|
||||||
tool_name=None,
|
|
||||||
input_tokens=_to_int(response_metadata.get("inputTokens", 0)),
|
|
||||||
output_tokens=_to_int(response_metadata.get("outputTokens", 0)),
|
|
||||||
cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
|
|
||||||
latency_ms=_to_int(response_metadata.get("latencyMs", 0)),
|
|
||||||
metadata=metadata,
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
await message_repo.append_message(
|
|
||||||
session_id=session_id,
|
|
||||||
seq=message_payload.seq,
|
|
||||||
role=AgentChatMessageRole.ASSISTANT,
|
|
||||||
content=message_payload.content,
|
|
||||||
model_code=message_payload.model_code,
|
|
||||||
tool_name=message_payload.tool_name,
|
|
||||||
metadata=metadata.model_dump(mode="json", exclude_none=True),
|
|
||||||
input_tokens=message_payload.input_tokens,
|
|
||||||
output_tokens=message_payload.output_tokens,
|
|
||||||
cost=message_payload.cost,
|
|
||||||
latency_ms=message_payload.latency_ms,
|
|
||||||
)
|
|
||||||
await session_repo.update_runtime_state(
|
|
||||||
chat_session=locked_session,
|
|
||||||
status=AgentChatSessionStatus.RUNNING,
|
|
||||||
state_snapshot=locked_session.state_snapshot or {},
|
|
||||||
message_delta=1,
|
|
||||||
token_delta=message_payload.input_tokens + message_payload.output_tokens,
|
|
||||||
cost_delta=message_payload.cost,
|
|
||||||
)
|
|
||||||
await session.flush()
|
|
||||||
@@ -10,27 +10,20 @@ from agentscope.formatter import OpenAIChatFormatter
|
|||||||
from agentscope.memory import InMemoryMemory
|
from agentscope.memory import InMemoryMemory
|
||||||
from agentscope.message import Msg
|
from agentscope.message import Msg
|
||||||
from agentscope.model import OpenAIChatModel
|
from agentscope.model import OpenAIChatModel
|
||||||
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
|
|
||||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||||
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||||||
from core.agentscope.runtime.model_tracking import TrackingChatModel
|
from core.agentscope.runtime.model_tracking import TrackingChatModel
|
||||||
from core.agentscope.runtime.router_persistence import persist_router_message
|
|
||||||
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
|
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
|
||||||
|
from core.agentscope.runtime.tool_selection_registry import TOOL_SELECTION_REGISTRY
|
||||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||||
from core.agentscope.utils import (
|
from core.agentscope.utils import patch_agentscope_json_repair_compat
|
||||||
finalize_json_response,
|
|
||||||
patch_agentscope_json_repair_compat,
|
|
||||||
)
|
|
||||||
from core.config.settings import config
|
from core.config.settings import config
|
||||||
from core.db.session import AsyncSessionLocal
|
from core.db.session import AsyncSessionLocal
|
||||||
from core.logging import get_logger
|
|
||||||
from models.llm import Llm
|
from models.llm import Llm
|
||||||
from models.llm_factory import LlmFactory
|
from models.llm_factory import LlmFactory
|
||||||
from models.system_agents import SystemAgents
|
from models.system_agents import SystemAgents
|
||||||
from schemas.agent.runtime_models import (
|
from schemas.agent.runtime_models import (
|
||||||
RouterAgentOutput,
|
AgentOutput,
|
||||||
WorkerAgentOutputLite,
|
|
||||||
resolve_worker_output_model,
|
|
||||||
)
|
)
|
||||||
from schemas.agent.forwarded_props import (
|
from schemas.agent.forwarded_props import (
|
||||||
ClientTimeContext,
|
ClientTimeContext,
|
||||||
@@ -45,8 +38,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.agentscope.runtime.orchestrator import PipelineLike
|
from core.agentscope.runtime.orchestrator import PipelineLike
|
||||||
|
|
||||||
logger = get_logger("core.agentscope.runtime.runner")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class SystemAgentRuntimeConfig:
|
class SystemAgentRuntimeConfig:
|
||||||
@@ -76,110 +67,68 @@ class AgentScopeRunner:
|
|||||||
context_messages: list[Msg],
|
context_messages: list[Msg],
|
||||||
pipeline: PipelineLike,
|
pipeline: PipelineLike,
|
||||||
run_input: RunAgentInput,
|
run_input: RunAgentInput,
|
||||||
|
system_agent_mode: str,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
owner_id = UUID(user_context.id)
|
owner_id = UUID(user_context.id)
|
||||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||||
|
stage_agent_type = self._resolve_stage_agent_type(system_agent_mode)
|
||||||
|
|
||||||
async with AsyncSessionLocal() as session:
|
async with AsyncSessionLocal() as session:
|
||||||
worker_toolkit = self._build_worker_toolkit(
|
stage_config = await self._load_stage_config(
|
||||||
session=session, owner_id=owner_id
|
|
||||||
)
|
|
||||||
router_config, worker_config = await self._load_stage_configs(
|
|
||||||
session=session
|
|
||||||
)
|
|
||||||
|
|
||||||
router_output = await self._execute_router_step(
|
|
||||||
session=session,
|
session=session,
|
||||||
pipeline=pipeline,
|
agent_type=stage_agent_type,
|
||||||
run_input=run_input,
|
)
|
||||||
user_context=user_context,
|
stage_toolkit = self._build_stage_toolkit(
|
||||||
context_messages=context_messages,
|
session=session,
|
||||||
stage_config=router_config,
|
owner_id=owner_id,
|
||||||
runtime_client_time=runtime_client_time,
|
stage_config=stage_config,
|
||||||
)
|
)
|
||||||
worker_output = await self._execute_worker_step(
|
worker_output = await self._execute_worker_step(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
run_input=run_input,
|
run_input=run_input,
|
||||||
user_context=user_context,
|
user_context=user_context,
|
||||||
router_output=router_output,
|
context_messages=context_messages,
|
||||||
toolkit=worker_toolkit,
|
toolkit=stage_toolkit,
|
||||||
stage_config=worker_config,
|
stage_config=stage_config,
|
||||||
runtime_client_time=runtime_client_time,
|
runtime_client_time=runtime_client_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
|
||||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _build_worker_toolkit(
|
def _build_stage_toolkit(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
owner_id: UUID,
|
owner_id: UUID,
|
||||||
|
stage_config: SystemAgentRuntimeConfig,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
enabled_tool_names = TOOL_SELECTION_REGISTRY.resolve(stage_config=stage_config)
|
||||||
return build_stage_toolkit(
|
return build_stage_toolkit(
|
||||||
agent_type=AgentType.WORKER,
|
agent_type=stage_config.agent_type,
|
||||||
session=session,
|
session=session,
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
|
enabled_tool_names=enabled_tool_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _load_stage_configs(
|
@staticmethod
|
||||||
|
def _resolve_stage_agent_type(system_agent_mode: str) -> AgentType:
|
||||||
|
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
|
||||||
|
if mode == AgentType.MEMORY.value:
|
||||||
|
return AgentType.MEMORY
|
||||||
|
return AgentType.WORKER
|
||||||
|
|
||||||
|
async def _load_stage_config(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
) -> tuple[SystemAgentRuntimeConfig, SystemAgentRuntimeConfig]:
|
agent_type: AgentType,
|
||||||
router_config = await self._load_system_agent_config(
|
) -> SystemAgentRuntimeConfig:
|
||||||
|
return await self._load_system_agent_config(
|
||||||
session=session,
|
session=session,
|
||||||
agent_type=AgentType.ROUTER,
|
agent_type=agent_type,
|
||||||
)
|
)
|
||||||
worker_config = await self._load_system_agent_config(
|
|
||||||
session=session,
|
|
||||||
agent_type=AgentType.WORKER,
|
|
||||||
)
|
|
||||||
return router_config, worker_config
|
|
||||||
|
|
||||||
async def _execute_router_step(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
session: AsyncSession,
|
|
||||||
pipeline: PipelineLike,
|
|
||||||
run_input: RunAgentInput,
|
|
||||||
user_context: UserContext,
|
|
||||||
context_messages: list[Msg],
|
|
||||||
stage_config: SystemAgentRuntimeConfig,
|
|
||||||
runtime_client_time: ClientTimeContext | None,
|
|
||||||
) -> RouterAgentOutput:
|
|
||||||
await self._emit_step_event(
|
|
||||||
pipeline=pipeline,
|
|
||||||
run_input=run_input,
|
|
||||||
step_name="router",
|
|
||||||
event_type="STEP_STARTED",
|
|
||||||
)
|
|
||||||
router_result = await self._run_router_stage(
|
|
||||||
user_context=user_context,
|
|
||||||
context_messages=context_messages,
|
|
||||||
run_input=run_input,
|
|
||||||
stage_config=stage_config,
|
|
||||||
runtime_client_time=runtime_client_time,
|
|
||||||
)
|
|
||||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
|
||||||
await persist_router_message(
|
|
||||||
session=session,
|
|
||||||
thread_id=run_input.thread_id,
|
|
||||||
run_id=run_input.run_id,
|
|
||||||
model_code=stage_config.model_code,
|
|
||||||
router_output=router_output,
|
|
||||||
response_metadata=router_result.response_metadata,
|
|
||||||
)
|
|
||||||
await session.commit()
|
|
||||||
await self._emit_step_event(
|
|
||||||
pipeline=pipeline,
|
|
||||||
run_input=run_input,
|
|
||||||
step_name="router",
|
|
||||||
event_type="STEP_FINISHED",
|
|
||||||
)
|
|
||||||
return router_output
|
|
||||||
|
|
||||||
async def _execute_worker_step(
|
async def _execute_worker_step(
|
||||||
self,
|
self,
|
||||||
@@ -187,21 +136,22 @@ class AgentScopeRunner:
|
|||||||
pipeline: PipelineLike,
|
pipeline: PipelineLike,
|
||||||
run_input: RunAgentInput,
|
run_input: RunAgentInput,
|
||||||
user_context: UserContext,
|
user_context: UserContext,
|
||||||
router_output: RouterAgentOutput,
|
context_messages: list[Msg],
|
||||||
toolkit: Any,
|
toolkit: Any,
|
||||||
stage_config: SystemAgentRuntimeConfig,
|
stage_config: SystemAgentRuntimeConfig,
|
||||||
runtime_client_time: ClientTimeContext | None,
|
runtime_client_time: ClientTimeContext | None,
|
||||||
) -> WorkerAgentOutputLite:
|
) -> AgentOutput:
|
||||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
step_name = stage_config.agent_type.value
|
||||||
|
worker_output_model = AgentOutput
|
||||||
await self._emit_step_event(
|
await self._emit_step_event(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
run_input=run_input,
|
run_input=run_input,
|
||||||
step_name="worker",
|
step_name=step_name,
|
||||||
event_type="STEP_STARTED",
|
event_type="STEP_STARTED",
|
||||||
)
|
)
|
||||||
worker_result = await self._run_worker_stage(
|
worker_result = await self._run_worker_stage(
|
||||||
user_context=user_context,
|
user_context=user_context,
|
||||||
router_output=router_output,
|
context_messages=context_messages,
|
||||||
toolkit=toolkit,
|
toolkit=toolkit,
|
||||||
run_input=run_input,
|
run_input=run_input,
|
||||||
stage_config=stage_config,
|
stage_config=stage_config,
|
||||||
@@ -213,7 +163,7 @@ class AgentScopeRunner:
|
|||||||
await self._emit_step_event(
|
await self._emit_step_event(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
run_input=run_input,
|
run_input=run_input,
|
||||||
step_name="worker",
|
step_name=step_name,
|
||||||
event_type="STEP_FINISHED",
|
event_type="STEP_FINISHED",
|
||||||
)
|
)
|
||||||
return worker_output
|
return worker_output
|
||||||
@@ -261,78 +211,33 @@ class AgentScopeRunner:
|
|||||||
raise RuntimeError(f"provider api key missing for factory: {factory_name}")
|
raise RuntimeError(f"provider api key missing for factory: {factory_name}")
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
async def _run_router_stage(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
user_context: UserContext,
|
|
||||||
context_messages: list[Msg],
|
|
||||||
run_input: RunAgentInput,
|
|
||||||
stage_config: SystemAgentRuntimeConfig,
|
|
||||||
runtime_client_time: ClientTimeContext | None,
|
|
||||||
) -> StageExecutionResult:
|
|
||||||
tracking_model = self._build_model(stage_config=stage_config)
|
|
||||||
system_prompt = build_system_prompt(
|
|
||||||
agent_type=AgentType.ROUTER,
|
|
||||||
user_context=user_context,
|
|
||||||
now_utc=datetime.now(timezone.utc),
|
|
||||||
runtime_client_time=runtime_client_time,
|
|
||||||
tools=None,
|
|
||||||
)
|
|
||||||
response, payload = await finalize_json_response(
|
|
||||||
model=tracking_model,
|
|
||||||
formatter=OpenAIChatFormatter(),
|
|
||||||
base_messages=[Msg("system", system_prompt, "system"), *context_messages],
|
|
||||||
output_model=RouterAgentOutput,
|
|
||||||
retries=0,
|
|
||||||
)
|
|
||||||
response_msg = Msg(
|
|
||||||
name="router",
|
|
||||||
role="assistant",
|
|
||||||
content=list(getattr(response, "content", [])),
|
|
||||||
metadata=payload,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"router_reply_received",
|
|
||||||
run_id=run_input.run_id,
|
|
||||||
thread_id=run_input.thread_id,
|
|
||||||
message_id=str(response_msg.id),
|
|
||||||
)
|
|
||||||
return StageExecutionResult(
|
|
||||||
message=response_msg,
|
|
||||||
payload=payload,
|
|
||||||
response_metadata=self._litellm_service.build_usage_metadata(
|
|
||||||
model=stage_config.model_code,
|
|
||||||
usage_summary=tracking_model.usage_summary(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _run_worker_stage(
|
async def _run_worker_stage(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
user_context: UserContext,
|
user_context: UserContext,
|
||||||
router_output: RouterAgentOutput,
|
context_messages: list[Msg],
|
||||||
toolkit: Any,
|
toolkit: Any,
|
||||||
run_input: RunAgentInput,
|
run_input: RunAgentInput,
|
||||||
stage_config: SystemAgentRuntimeConfig,
|
stage_config: SystemAgentRuntimeConfig,
|
||||||
worker_output_model: type[WorkerAgentOutputLite],
|
worker_output_model: type[AgentOutput],
|
||||||
pipeline: PipelineLike,
|
pipeline: PipelineLike,
|
||||||
runtime_client_time: ClientTimeContext | None,
|
runtime_client_time: ClientTimeContext | None,
|
||||||
) -> StageExecutionResult:
|
) -> StageExecutionResult:
|
||||||
worker_input = self._build_worker_input_messages(router_output=router_output)
|
worker_input = list(context_messages)
|
||||||
tracking_model = self._build_model(stage_config=stage_config)
|
tracking_model = self._build_model(stage_config=stage_config)
|
||||||
emitter = PipelineStageEmitter(
|
emitter = PipelineStageEmitter(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
session_id=run_input.thread_id,
|
session_id=run_input.thread_id,
|
||||||
run_id=run_input.run_id,
|
run_id=run_input.run_id,
|
||||||
stage="worker",
|
stage=stage_config.agent_type.value,
|
||||||
emit_text_events=True,
|
emit_text_events=True,
|
||||||
emit_tool_events=True,
|
emit_tool_events=True,
|
||||||
)
|
)
|
||||||
agent = self._build_agent(
|
agent = self._build_agent(
|
||||||
agent_name="worker",
|
agent_name=stage_config.agent_type.value,
|
||||||
system_prompt=build_system_prompt(
|
system_prompt=build_system_prompt(
|
||||||
agent_type=AgentType.WORKER,
|
agent_type=stage_config.agent_type,
|
||||||
|
llm_config=stage_config.llm_config,
|
||||||
user_context=user_context,
|
user_context=user_context,
|
||||||
now_utc=datetime.now(timezone.utc),
|
now_utc=datetime.now(timezone.utc),
|
||||||
runtime_client_time=runtime_client_time,
|
runtime_client_time=runtime_client_time,
|
||||||
@@ -360,19 +265,6 @@ class AgentScopeRunner:
|
|||||||
response_metadata=response_metadata,
|
response_metadata=response_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_worker_input_messages(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
router_output: RouterAgentOutput,
|
|
||||||
) -> list[Msg]:
|
|
||||||
return [
|
|
||||||
Msg(
|
|
||||||
name="router",
|
|
||||||
role="user",
|
|
||||||
content=build_worker_contract_prompt(router_output=router_output),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _build_model(
|
def _build_model(
|
||||||
self, *, stage_config: SystemAgentRuntimeConfig
|
self, *, stage_config: SystemAgentRuntimeConfig
|
||||||
) -> TrackingChatModel:
|
) -> TrackingChatModel:
|
||||||
@@ -381,8 +273,7 @@ class AgentScopeRunner:
|
|||||||
"max_tokens": stage_config.llm_config.max_tokens,
|
"max_tokens": stage_config.llm_config.max_tokens,
|
||||||
"timeout": stage_config.llm_config.timeout_seconds,
|
"timeout": stage_config.llm_config.timeout_seconds,
|
||||||
}
|
}
|
||||||
if stage_config.agent_type == AgentType.ROUTER:
|
generate_kwargs["extra_body"] = {"enable_thinking": False}
|
||||||
generate_kwargs["extra_body"] = {"enable_thinking": False}
|
|
||||||
|
|
||||||
model = OpenAIChatModel(
|
model = OpenAIChatModel(
|
||||||
model_name=stage_config.model_code,
|
model_name=stage_config.model_code,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from core.agentscope.events import (
|
|||||||
RedisStreamBus,
|
RedisStreamBus,
|
||||||
SqlAlchemyEventStore,
|
SqlAlchemyEventStore,
|
||||||
)
|
)
|
||||||
|
from core.agentscope.runtime.context_service import AgentContextService
|
||||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||||
from core.agentscope.schemas.agui_input import parse_run_input
|
from core.agentscope.schemas.agui_input import parse_run_input
|
||||||
from core.auth.models import CurrentUser
|
from core.auth.models import CurrentUser
|
||||||
@@ -26,7 +27,7 @@ from schemas.messages.chat_message import (
|
|||||||
from schemas.user import UserContext
|
from schemas.user import UserContext
|
||||||
from services.base.redis import get_or_init_redis_client
|
from services.base.redis import get_or_init_redis_client
|
||||||
from services.base.supabase import supabase_service
|
from services.base.supabase import supabase_service
|
||||||
from v1.agent.dependencies import get_agent_service
|
from v1.agent.repository import AgentRepository
|
||||||
from v1.users.dependencies import get_user_service
|
from v1.users.dependencies import get_user_service
|
||||||
|
|
||||||
logger = get_logger("core.agentscope.runtime.tasks")
|
logger = get_logger("core.agentscope.runtime.tasks")
|
||||||
@@ -78,9 +79,13 @@ async def _build_recent_context_messages(
|
|||||||
*,
|
*,
|
||||||
session: Any,
|
session: Any,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
|
system_agent_mode: str,
|
||||||
) -> list[Msg]:
|
) -> list[Msg]:
|
||||||
agent_service = get_agent_service(session)
|
context_service = AgentContextService(repository=AgentRepository(session))
|
||||||
result = await agent_service.load_agent_input_messages(thread_id=thread_id)
|
result = await context_service.load_context_messages(
|
||||||
|
thread_id=thread_id,
|
||||||
|
system_agent_mode=system_agent_mode,
|
||||||
|
)
|
||||||
if not result:
|
if not result:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -165,6 +170,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
|||||||
command_type = str(command.get("command", "run")).strip().lower()
|
command_type = str(command.get("command", "run")).strip().lower()
|
||||||
raw_owner_id = command.get("owner_id")
|
raw_owner_id = command.get("owner_id")
|
||||||
run_input_raw = command.get("run_input")
|
run_input_raw = command.get("run_input")
|
||||||
|
system_agent_mode = str(command.get("system_agent_mode", "worker")).strip().lower()
|
||||||
|
|
||||||
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
|
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
|
||||||
raise ValueError("owner_id is required")
|
raise ValueError("owner_id is required")
|
||||||
@@ -205,12 +211,14 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
|||||||
context_messages = await _build_recent_context_messages(
|
context_messages = await _build_recent_context_messages(
|
||||||
session=session,
|
session=session,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
|
system_agent_mode=system_agent_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
await runtime.run(
|
await runtime.run(
|
||||||
run_input=run_input,
|
run_input=run_input,
|
||||||
context_messages=context_messages,
|
context_messages=context_messages,
|
||||||
user_context=user_context,
|
user_context=user_context,
|
||||||
|
system_agent_mode=system_agent_mode,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"agentscope runtime task completed",
|
"agentscope runtime task completed",
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.agentscope.tools.tool_config import resolve_tool_names_by_groups
|
||||||
|
from schemas.agent.system_agent import AgentType
|
||||||
|
|
||||||
|
ToolNameResolver = Callable[[Any], set[str] | None]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolSelectionRegistry:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._resolvers: dict[AgentType, ToolNameResolver] = {}
|
||||||
|
|
||||||
|
def register(self, *, agent_type: AgentType, resolver: ToolNameResolver) -> None:
|
||||||
|
self._resolvers[agent_type] = resolver
|
||||||
|
|
||||||
|
def resolve(self, *, stage_config: Any) -> set[str] | None:
|
||||||
|
resolver = self._resolvers.get(stage_config.agent_type)
|
||||||
|
if resolver is None:
|
||||||
|
return None
|
||||||
|
return resolver(stage_config)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_group_resolver(stage_config: Any) -> set[str] | None:
|
||||||
|
raw_groups = getattr(stage_config.llm_config, "enabled_tool_groups", [])
|
||||||
|
groups = raw_groups if isinstance(raw_groups, list) else []
|
||||||
|
if not groups:
|
||||||
|
return None
|
||||||
|
return resolve_tool_names_by_groups(set(groups))
|
||||||
|
|
||||||
|
|
||||||
|
TOOL_SELECTION_REGISTRY = ToolSelectionRegistry()
|
||||||
|
TOOL_SELECTION_REGISTRY.register(
|
||||||
|
agent_type=AgentType.WORKER,
|
||||||
|
resolver=_default_group_resolver,
|
||||||
|
)
|
||||||
|
TOOL_SELECTION_REGISTRY.register(
|
||||||
|
agent_type=AgentType.MEMORY,
|
||||||
|
resolver=_default_group_resolver,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user