refactor(agent): remove memory agent, simplify runtime config system
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from core.agentscope.prompts.agent_prompt import build_agent_prompt
|
||||
from core.agentscope.prompts.memory_prompt import build_memory_prompt
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.prompts.tool_prompt import build_tools_prompt
|
||||
|
||||
__all__ = [
|
||||
"build_agent_prompt",
|
||||
"build_memory_prompt",
|
||||
"build_system_prompt",
|
||||
"build_tools_prompt",
|
||||
]
|
||||
|
||||
@@ -66,6 +66,7 @@ def _router_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
"- Router only: extract intent and route strategy; never answer user directly.",
|
||||
"- Preserve intent in normalized_task_input.user_text; keep wording concise and faithful.",
|
||||
"- Fill multimodal_summary only when image/attachment changes execution decisions.",
|
||||
"- Fill normalized_task_input.context_summary with a brief description of what the provided context messages contain; this is critical for worker to understand the conversational background.",
|
||||
"- Return key_entities and constraints that are execution-relevant; low confidence -> omit rather than guess.",
|
||||
"- Set execution_mode by complexity: onestep / tool_assisted / multistep.",
|
||||
"- Set result_typing.primary to the most suitable response shape; use clarification_request only when required info is missing.",
|
||||
@@ -97,23 +98,6 @@ def _worker_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
]
|
||||
|
||||
|
||||
def _memory_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
return [
|
||||
"[Memory Agent]",
|
||||
"- Analyze conversation context and output structured memory-safe conclusions.",
|
||||
"- Return exactly one agent output JSON object matching the runtime-injected schema.",
|
||||
"[Responsibilities]",
|
||||
"- Focus on extracting durable user facts and preferences from context.",
|
||||
"- Keep outputs concise, deterministic, and evidence-backed.",
|
||||
"- Do not invent facts or hidden user intent.",
|
||||
"- Use tool calls only when required by explicit workflow and allowed tool groups.",
|
||||
"[Schema Guidance]",
|
||||
"- The output schema is injected at runtime; follow it exactly.",
|
||||
"- Do not add fields that are not present in the injected schema.",
|
||||
*_config_rules(llm_config),
|
||||
]
|
||||
|
||||
|
||||
def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
|
||||
contract_json = json.dumps(
|
||||
router_output.model_dump(mode="json", exclude_none=True),
|
||||
@@ -125,6 +109,7 @@ def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
|
||||
"[Worker Contract]",
|
||||
"- Keep routed objective unchanged.",
|
||||
"- Use normalized_task_input as objective text.",
|
||||
"- Use context_summary to understand conversational background from chat history.",
|
||||
"- Use multimodal_summary/key_entities/constraints as execution evidence.",
|
||||
"- Infer deterministic missing required tool args from evidence + tool schema.",
|
||||
"- Ask clarification only when safe inference is impossible.",
|
||||
@@ -137,7 +122,6 @@ def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
|
||||
AGENT_PROMPT_REGISTRY = AgentPromptRegistry()
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.ROUTER, builder=_router_rules)
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.WORKER, builder=_worker_rules)
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.MEMORY, builder=_memory_rules)
|
||||
|
||||
|
||||
def build_agent_prompt(
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from schemas.memories import MemoryContext, MemoryListResponse
|
||||
|
||||
|
||||
def _wrap_section(section: str, content: str) -> str:
|
||||
marker_map = {
|
||||
"memory": ("<!-- MEMORY_START -->", "<!-- MEMORY_END -->"),
|
||||
}
|
||||
start, end = marker_map[section]
|
||||
body = content.strip()
|
||||
return f"{start}\n{body}\n{end}" if body else f"{start}\n{end}"
|
||||
|
||||
|
||||
def _format_memory_content(content: dict[str, Any]) -> str:
|
||||
if isinstance(content, dict):
|
||||
return json.dumps(content, ensure_ascii=True, separators=(",", ":"))
|
||||
return str(content)
|
||||
|
||||
|
||||
def _format_memory(ctx: MemoryContext) -> str:
|
||||
parts = [
|
||||
f"[{ctx.memory_type.value.upper()}] {ctx.title or 'Untitled'}",
|
||||
f" source: {ctx.source.value}",
|
||||
f" content: {_format_memory_content(ctx.content)}",
|
||||
]
|
||||
if ctx.created_at:
|
||||
parts.append(f" created_at: {ctx.created_at.isoformat()}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def build_memory_prompt(
|
||||
*,
|
||||
memories: MemoryListResponse,
|
||||
) -> str | None:
|
||||
if not memories.memories:
|
||||
return None
|
||||
|
||||
lines: list[str] = [
|
||||
"[User Memories]",
|
||||
"- Memories are persistent context from previous sessions.",
|
||||
"- Use them to ground responses in known user facts and preferences.",
|
||||
"- Do not invent facts not present in memories.",
|
||||
]
|
||||
|
||||
for ctx in memories.memories:
|
||||
lines.append(_format_memory(ctx))
|
||||
|
||||
return _wrap_section("memory", "\n".join(lines))
|
||||
@@ -9,10 +9,12 @@ from ag_ui.core.types import Tool
|
||||
from core.agentscope.prompts.agent_prompt import (
|
||||
build_agent_prompt,
|
||||
)
|
||||
from core.agentscope.prompts.memory_prompt import build_memory_prompt
|
||||
from core.agentscope.prompts.route_prompt import build_frontend_route_prompt
|
||||
from core.agentscope.prompts.tool_prompt import build_tools_prompt
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.forwarded_props import ClientTimeContext
|
||||
from schemas.memories import MemoryListResponse
|
||||
from schemas.user.context import UserContext
|
||||
|
||||
|
||||
@@ -202,12 +204,13 @@ def _build_route_section() -> str:
|
||||
def build_system_prompt(
|
||||
*,
|
||||
agent_type: AgentType,
|
||||
llm_config: SystemAgentLLMConfig | None,
|
||||
llm_config: SystemAgentLLMConfig | None = None,
|
||||
user_context: UserContext,
|
||||
now_utc: datetime,
|
||||
runtime_client_time: ClientTimeContext | None = None,
|
||||
extra_context: str | None = None,
|
||||
tools: Sequence[Tool | dict[str, Any]] | None = None,
|
||||
memories: MemoryListResponse | None = None,
|
||||
) -> str:
|
||||
include_route_section = agent_type == AgentType.WORKER
|
||||
sections: list[str | None] = [
|
||||
@@ -225,6 +228,7 @@ def build_system_prompt(
|
||||
llm_config=llm_config,
|
||||
),
|
||||
build_tools_prompt(tools=tools) if tools else None,
|
||||
build_memory_prompt(memories=memories) if memories else None,
|
||||
_build_output_rules(),
|
||||
]
|
||||
return "\n\n".join(item for item in sections if item).strip()
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from schemas.agent.system_agent import ContextBuildStrategy
|
||||
|
||||
ContextLoader = Callable[[Any, str, int, int], Awaitable[dict[str, object] | None]]
|
||||
|
||||
|
||||
class ContextLoaderRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._loaders: dict[ContextBuildStrategy, ContextLoader] = {}
|
||||
|
||||
def register(self, *, mode: ContextBuildStrategy, loader: ContextLoader) -> None:
|
||||
self._loaders[mode] = loader
|
||||
|
||||
def resolve(self, *, mode: ContextBuildStrategy) -> ContextLoader:
|
||||
loader = self._loaders.get(mode)
|
||||
if loader is None:
|
||||
raise ValueError(f"unsupported context mode: {mode.value}")
|
||||
return loader
|
||||
|
||||
|
||||
async def _load_number(
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
async def _load_day(
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
CONTEXT_LOADER_REGISTRY = ContextLoaderRegistry()
|
||||
CONTEXT_LOADER_REGISTRY.register(mode=ContextBuildStrategy.NUMBER, loader=_load_number)
|
||||
CONTEXT_LOADER_REGISTRY.register(mode=ContextBuildStrategy.DAY, loader=_load_day)
|
||||
@@ -6,7 +6,8 @@ from ag_ui.core.types import RunAgentInput
|
||||
from agentscope.message import Msg
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from core.logging import get_logger
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation import RuntimeConfig
|
||||
from schemas.memories import MemoryListResponse
|
||||
from schemas.user import UserContext
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.orchestrator")
|
||||
@@ -24,8 +25,8 @@ class RunnerLike(Protocol):
|
||||
context_messages: list[Msg],
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None,
|
||||
runtime_config: RuntimeConfig,
|
||||
memories: MemoryListResponse | None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@@ -48,8 +49,8 @@ class AgentScopeRuntimeOrchestrator:
|
||||
run_input: RunAgentInput,
|
||||
context_messages: list[Msg],
|
||||
user_context: UserContext,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
runtime_config: RuntimeConfig,
|
||||
memories: MemoryListResponse | None = None,
|
||||
) -> dict[str, Any]:
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
@@ -68,8 +69,8 @@ class AgentScopeRuntimeOrchestrator:
|
||||
context_messages=context_messages,
|
||||
pipeline=self._pipeline,
|
||||
run_input=run_input,
|
||||
system_agent_mode=system_agent_mode,
|
||||
memory_job_config=memory_job_config,
|
||||
runtime_config=runtime_config,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agentscope.schemas.pipeline_spec import ExecutorKind, PipelineSpec, StageSpec
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
|
||||
normalized = mode.strip().lower()
|
||||
if normalized == "worker":
|
||||
return PipelineSpec(
|
||||
mode="worker",
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="router",
|
||||
agent_type=AgentType.ROUTER,
|
||||
executor_kind=ExecutorKind.SINGLE_SHOT,
|
||||
),
|
||||
StageSpec(
|
||||
stage_name="worker",
|
||||
agent_type=AgentType.WORKER,
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if normalized == "memory":
|
||||
return PipelineSpec(
|
||||
mode="memory",
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="memory",
|
||||
agent_type=AgentType.MEMORY,
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
raise ValueError(f"unsupported pipeline mode: {normalized}")
|
||||
@@ -1,22 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agentscope.schemas.consumer_registry import (
|
||||
AgentConsumerBinding,
|
||||
ConsumerRegistry,
|
||||
)
|
||||
|
||||
|
||||
def build_consumer_registry(
|
||||
*,
|
||||
system_agent_configs: dict[str, dict[str, object]],
|
||||
) -> ConsumerRegistry:
|
||||
bindings: list[AgentConsumerBinding] = []
|
||||
for agent_type, payload in system_agent_configs.items():
|
||||
config_obj = payload.get("config") if isinstance(payload, dict) else None
|
||||
if not isinstance(config_obj, dict):
|
||||
raise ValueError(f"invalid system agent config: {agent_type}")
|
||||
raw_bit = config_obj.get("visibility_consumer_bit")
|
||||
if not isinstance(raw_bit, int):
|
||||
raise ValueError(f"visibility_consumer_bit missing for agent: {agent_type}")
|
||||
bindings.append(AgentConsumerBinding(agent_type=agent_type, bit=raw_bit))
|
||||
return ConsumerRegistry(bindings=bindings)
|
||||
@@ -12,12 +12,12 @@ from agentscope.message import Msg
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
||||
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||||
from core.agentscope.runtime.model_tracking import TrackingChatModel
|
||||
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
|
||||
from core.agentscope.runtime.tool_selection_registry import TOOL_SELECTION_REGISTRY
|
||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from core.agentscope.tools.toolkit import build_toolkit
|
||||
from core.agentscope.utils import (
|
||||
finalize_json_response,
|
||||
patch_agentscope_json_repair_compat,
|
||||
@@ -31,19 +31,17 @@ from schemas.agent.forwarded_props import (
|
||||
ClientTimeContext,
|
||||
parse_forwarded_props_client_time,
|
||||
)
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
RouterAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import (
|
||||
AgentType,
|
||||
ContextMessagesConfig,
|
||||
ContextBuildStrategy,
|
||||
SystemAgentLLMConfig,
|
||||
)
|
||||
from schemas.automation import RuntimeConfig
|
||||
from schemas.memories import MemoryListResponse
|
||||
from schemas.user import UserContext
|
||||
from services.litellm.service import LiteLLMService
|
||||
from sqlalchemy import select
|
||||
@@ -53,16 +51,6 @@ if TYPE_CHECKING:
|
||||
from core.agentscope.runtime.orchestrator import PipelineLike
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SystemAgentRuntimeConfig:
|
||||
agent_type: AgentType
|
||||
model_code: str
|
||||
api_base_url: str
|
||||
api_key: str
|
||||
llm_config: SystemAgentLLMConfig
|
||||
extra_context: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StageExecutionResult:
|
||||
message: Msg
|
||||
@@ -82,96 +70,63 @@ class AgentScopeRunner:
|
||||
context_messages: list[Msg],
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
runtime_config: RuntimeConfig,
|
||||
memories: MemoryListResponse | None = None,
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
stage_agent_types = [stage.agent_type for stage in pipeline_spec.stages]
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
if stage_agent_types == [AgentType.ROUTER, AgentType.WORKER]:
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_stage_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
stage_config=worker_config,
|
||||
)
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
if stage_agent_types[0] == AgentType.MEMORY:
|
||||
if memory_job_config is None:
|
||||
raise RuntimeError("memory job config is required")
|
||||
stage_config = await self._build_memory_stage_config(
|
||||
session=session,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
else:
|
||||
stage_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=stage_agent_types[0],
|
||||
)
|
||||
stage_toolkit = self._build_stage_toolkit(
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
stage_config=stage_config,
|
||||
enabled_tools=runtime_config.enabled_tools,
|
||||
)
|
||||
stage_output = await self._execute_single_stage_step(
|
||||
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
input_messages=context_messages,
|
||||
toolkit=stage_toolkit,
|
||||
stage_config=stage_config,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
memories=memories,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
memories=memories,
|
||||
)
|
||||
return {
|
||||
stage_config.agent_type.value: stage_output.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
),
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
def _build_stage_toolkit(
|
||||
def _build_toolkit(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
enabled_tools: list[AgentTool],
|
||||
) -> Any:
|
||||
enabled_tool_names = TOOL_SELECTION_REGISTRY.resolve(stage_config=stage_config)
|
||||
return build_stage_toolkit(
|
||||
agent_type=stage_config.agent_type,
|
||||
tool_names = [t.value for t in enabled_tools] if enabled_tools else []
|
||||
return build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
enabled_tool_names=set(tool_names) if tool_names else None,
|
||||
)
|
||||
|
||||
async def _load_stage_config(
|
||||
@@ -179,124 +134,6 @@ class AgentScopeRunner:
|
||||
*,
|
||||
session: AsyncSession,
|
||||
agent_type: AgentType,
|
||||
) -> SystemAgentRuntimeConfig:
|
||||
return await self._load_system_agent_config(
|
||||
session=session,
|
||||
agent_type=agent_type,
|
||||
)
|
||||
|
||||
async def _execute_router_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> RouterAgentOutput:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return router_output
|
||||
|
||||
async def _execute_worker_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
router_output: RouterAgentOutput,
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> WorkerAgentOutputLite:
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
input_messages=self._build_worker_input_messages(
|
||||
router_output=router_output
|
||||
),
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
worker_output_model=worker_output_model,
|
||||
pipeline=pipeline,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return worker_output
|
||||
|
||||
async def _execute_single_stage_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
input_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> AgentOutput:
|
||||
step_name = stage_config.agent_type.value
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
stage_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
input_messages=input_messages,
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
worker_output_model=AgentOutput,
|
||||
pipeline=pipeline,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
stage_output = AgentOutput.model_validate(stage_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return stage_output
|
||||
|
||||
async def _load_system_agent_config(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
agent_type: AgentType,
|
||||
) -> SystemAgentRuntimeConfig:
|
||||
stmt = (
|
||||
select(SystemAgents, Llm, LlmFactory)
|
||||
@@ -320,63 +157,80 @@ class AgentScopeRunner:
|
||||
extra_context=None,
|
||||
)
|
||||
|
||||
async def _build_memory_stage_config(
|
||||
async def _execute_router_step(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
memory_job_config: AutomationJobConfig,
|
||||
) -> SystemAgentRuntimeConfig:
|
||||
stmt = (
|
||||
select(Llm, LlmFactory)
|
||||
.join(LlmFactory, Llm.factory_id == LlmFactory.id)
|
||||
.where(Llm.model_code == memory_job_config.model_code)
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
memories: MemoryListResponse | None,
|
||||
) -> RouterAgentOutput:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
row = (await session.execute(stmt)).one_or_none()
|
||||
if row is None:
|
||||
raise RuntimeError(
|
||||
f"memory model not found: {memory_job_config.model_code}"
|
||||
)
|
||||
llm, factory = row
|
||||
llm_config = SystemAgentLLMConfig(
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
timeout_seconds=30,
|
||||
context_messages=ContextMessagesConfig(
|
||||
mode=(
|
||||
ContextBuildStrategy.DAY
|
||||
if memory_job_config.context.window_mode.value == "day"
|
||||
else ContextBuildStrategy.NUMBER
|
||||
),
|
||||
count=memory_job_config.context.window_count,
|
||||
),
|
||||
enabled_tools=memory_job_config.enabled_tools,
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
memories=memories,
|
||||
run_input=run_input,
|
||||
)
|
||||
return SystemAgentRuntimeConfig(
|
||||
agent_type=AgentType.MEMORY,
|
||||
model_code=llm.model_code,
|
||||
api_base_url=factory.request_url,
|
||||
api_key=self._resolve_provider_api_key(factory_name=factory.name),
|
||||
llm_config=llm_config,
|
||||
extra_context=(
|
||||
f"[Memory Input Template]\n{memory_job_config.input_template.strip()}"
|
||||
),
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return router_output
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_api_key(*, factory_name: str) -> str:
|
||||
normalized_factory_name = factory_name.strip().upper()
|
||||
if normalized_factory_name == "VOLCENGINE":
|
||||
normalized_factory_name = "ARK"
|
||||
|
||||
provider_keys = {
|
||||
str(key).strip().upper(): str(value).strip()
|
||||
for key, value in config.llm.provider_keys.items()
|
||||
if str(value).strip()
|
||||
}
|
||||
api_key = provider_keys.get(normalized_factory_name, "")
|
||||
if not api_key:
|
||||
raise RuntimeError(f"provider api key missing for factory: {factory_name}")
|
||||
return api_key
|
||||
async def _execute_worker_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
router_output: RouterAgentOutput,
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
memories: MemoryListResponse | None,
|
||||
) -> WorkerAgentOutputLite:
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
input_messages=self._build_worker_input_messages(
|
||||
router_output=router_output
|
||||
),
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
worker_output_model=worker_output_model,
|
||||
pipeline=pipeline,
|
||||
runtime_client_time=runtime_client_time,
|
||||
memories=memories,
|
||||
)
|
||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return worker_output
|
||||
|
||||
async def _run_router_stage(
|
||||
self,
|
||||
@@ -385,7 +239,13 @@ class AgentScopeRunner:
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
memories: MemoryListResponse | None,
|
||||
run_input: RunAgentInput,
|
||||
) -> StageExecutionResult:
|
||||
messages_for_router = self._build_router_messages(
|
||||
context_messages=context_messages,
|
||||
run_input=run_input,
|
||||
)
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
response, payload = await finalize_json_response(
|
||||
model=tracking_model,
|
||||
@@ -400,10 +260,11 @@ class AgentScopeRunner:
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
runtime_client_time=runtime_client_time,
|
||||
tools=None,
|
||||
memories=memories,
|
||||
),
|
||||
"system",
|
||||
),
|
||||
*context_messages,
|
||||
*messages_for_router,
|
||||
],
|
||||
output_model=RouterAgentOutput,
|
||||
retries=0,
|
||||
@@ -423,6 +284,30 @@ class AgentScopeRunner:
|
||||
),
|
||||
)
|
||||
|
||||
def _build_router_messages(
|
||||
self,
|
||||
*,
|
||||
context_messages: list[Msg],
|
||||
run_input: RunAgentInput,
|
||||
) -> list[Msg]:
|
||||
if context_messages:
|
||||
last = context_messages[-1]
|
||||
if last.role == "user":
|
||||
return context_messages
|
||||
|
||||
user_text, user_blocks = extract_latest_user_payload(run_input)
|
||||
if (
|
||||
user_blocks
|
||||
and isinstance(user_blocks[0], dict)
|
||||
and user_blocks[0].get("type") == "text"
|
||||
):
|
||||
content: Any = user_text
|
||||
else:
|
||||
content = user_blocks
|
||||
|
||||
user_msg = Msg(name="user", role="user", content=content)
|
||||
return [user_msg, *context_messages]
|
||||
|
||||
async def _run_worker_stage(
|
||||
self,
|
||||
*,
|
||||
@@ -434,6 +319,7 @@ class AgentScopeRunner:
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
memories: MemoryListResponse | None,
|
||||
) -> StageExecutionResult:
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
emitter = PipelineStageEmitter(
|
||||
@@ -454,6 +340,7 @@ class AgentScopeRunner:
|
||||
runtime_client_time=runtime_client_time,
|
||||
extra_context=stage_config.extra_context,
|
||||
tools=None,
|
||||
memories=memories,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
model=tracking_model,
|
||||
@@ -553,5 +440,31 @@ class AgentScopeRunner:
|
||||
getattr(run_input, "forwarded_props", None)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_api_key(*, factory_name: str) -> str:
|
||||
normalized_factory_name = factory_name.strip().upper()
|
||||
if normalized_factory_name == "VOLCENGINE":
|
||||
normalized_factory_name = "ARK"
|
||||
|
||||
provider_keys = {
|
||||
str(key).strip().upper(): str(value).strip()
|
||||
for key, value in config.llm.provider_keys.items()
|
||||
if str(value).strip()
|
||||
}
|
||||
api_key = provider_keys.get(normalized_factory_name, "")
|
||||
if not api_key:
|
||||
raise RuntimeError(f"provider api key missing for factory: {factory_name}")
|
||||
return api_key
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SystemAgentRuntimeConfig:
|
||||
agent_type: AgentType
|
||||
model_code: str
|
||||
api_base_url: str
|
||||
api_key: str
|
||||
llm_config: SystemAgentLLMConfig
|
||||
extra_context: str | None = None
|
||||
|
||||
|
||||
AgentScopeReActRunner = AgentScopeRunner
|
||||
|
||||
@@ -12,31 +12,28 @@ from core.agentscope.events import (
|
||||
RedisStreamBus,
|
||||
SqlAlchemyEventStore,
|
||||
)
|
||||
from core.agentscope.services.context_service import AgentContextService
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.schemas.agui_input import parse_run_input
|
||||
from core.agentscope.services.context_service import AgentContextService
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import worker_agent_broker, worker_automation_broker
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation import MemoryContextConfig, RuntimeConfig
|
||||
from schemas.memories import MemoryListResponse
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
|
||||
from schemas.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.memory.repository import MemoryRepository
|
||||
from v1.memory.service import MemoryService
|
||||
from v1.memories.repository import MemoriesRepository
|
||||
from v1.memories.service import MemoriesService
|
||||
from v1.users.dependencies import get_user_service
|
||||
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
_MAX_CONTEXT_ATTACHMENTS = 3
|
||||
|
||||
@@ -86,29 +83,14 @@ async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
thread_id: str,
|
||||
context_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
context_config: "MemoryContextConfig",
|
||||
) -> list[Msg]:
|
||||
context_service = AgentContextService(repository=AgentRepository(session))
|
||||
if memory_job_config is not None:
|
||||
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
if memory_job_config.context.window_mode.value == "day":
|
||||
result = await context_service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=memory_job_config.context.window_count,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
else:
|
||||
result = await context_service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=memory_job_config.context.window_count,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
else:
|
||||
result = await context_service.load_context_messages(
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=context_mode,
|
||||
)
|
||||
result = await context_service.load_context_messages(
|
||||
thread_id=thread_id,
|
||||
context_config=context_config,
|
||||
)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
@@ -193,6 +175,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
command_type = str(command.get("command", "run")).strip().lower()
|
||||
raw_owner_id = command.get("owner_id")
|
||||
run_input_raw = command.get("run_input")
|
||||
runtime_config_raw = command.get("runtime_config")
|
||||
|
||||
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
|
||||
raise ValueError("owner_id is required")
|
||||
@@ -200,15 +183,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
raise ValueError("run_input is required")
|
||||
|
||||
run_input = parse_run_input(run_input_raw)
|
||||
system_agent_mode = parse_forwarded_props_agent_type(
|
||||
getattr(run_input, "forwarded_props", None)
|
||||
)
|
||||
raw_automation_job_id = command.get("automation_job_id")
|
||||
if system_agent_mode == "memory" and (
|
||||
not isinstance(raw_automation_job_id, str) or not raw_automation_job_id
|
||||
):
|
||||
raise ValueError("automation_job_id is required for memory mode")
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
runtime_config = RuntimeConfig.model_validate(runtime_config_raw or {})
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
owner_id = UUID(raw_owner_id)
|
||||
@@ -220,15 +195,10 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
user_context = await _build_user_context(owner_id=owner_id, session=session)
|
||||
memory_job_config: AutomationJobConfig | None = None
|
||||
if system_agent_mode == "memory":
|
||||
assert isinstance(raw_automation_job_id, str)
|
||||
job_uuid = UUID(raw_automation_job_id)
|
||||
memory_service = MemoryService(MemoryRepository(session))
|
||||
memory_job_config = await memory_service.get_memory_job_config(
|
||||
job_id=job_uuid,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
memories_service = MemoriesService(MemoriesRepository(session))
|
||||
memories: MemoryListResponse = await memories_service.get_all_memories(
|
||||
owner_id=owner_id
|
||||
)
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
bus = RedisStreamBus(
|
||||
@@ -251,16 +221,15 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages = await _build_recent_context_messages(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
context_mode=pipeline_spec.stages[0].agent_type.value,
|
||||
memory_job_config=memory_job_config,
|
||||
context_config=runtime_config.context,
|
||||
)
|
||||
|
||||
await runtime.run(
|
||||
run_input=run_input,
|
||||
context_messages=context_messages,
|
||||
user_context=user_context,
|
||||
system_agent_mode=system_agent_mode,
|
||||
memory_job_config=memory_job_config,
|
||||
runtime_config=runtime_config,
|
||||
memories=memories,
|
||||
)
|
||||
logger.info(
|
||||
"agentscope runtime task completed",
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from core.agentscope.tools.tool_config import resolve_tool_function_names
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
ToolNameResolver = Callable[[Any], set[str] | None]
|
||||
|
||||
|
||||
class ToolSelectionRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._resolvers: dict[AgentType, ToolNameResolver] = {}
|
||||
|
||||
def register(self, *, agent_type: AgentType, resolver: ToolNameResolver) -> None:
|
||||
self._resolvers[agent_type] = resolver
|
||||
|
||||
def resolve(self, *, stage_config: Any) -> set[str] | None:
|
||||
resolver = self._resolvers.get(stage_config.agent_type)
|
||||
if resolver is None:
|
||||
return None
|
||||
return resolver(stage_config)
|
||||
|
||||
|
||||
def _default_tool_resolver(stage_config: Any) -> set[str] | None:
|
||||
enabled_tools = getattr(stage_config.llm_config, "enabled_tools", [])
|
||||
if not enabled_tools:
|
||||
return None
|
||||
return resolve_tool_function_names(set(enabled_tools))
|
||||
|
||||
|
||||
TOOL_SELECTION_REGISTRY = ToolSelectionRegistry()
|
||||
TOOL_SELECTION_REGISTRY.register(
|
||||
agent_type=AgentType.WORKER,
|
||||
resolver=_default_tool_resolver,
|
||||
)
|
||||
TOOL_SELECTION_REGISTRY.register(
|
||||
agent_type=AgentType.MEMORY,
|
||||
resolver=_default_tool_resolver,
|
||||
)
|
||||
@@ -1,44 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class AgentConsumerBinding(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
bit: int = Field(..., ge=16, le=63)
|
||||
|
||||
@field_validator("agent_type")
|
||||
@classmethod
|
||||
def _normalize_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("agent_type must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class ConsumerRegistry(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
bindings: list[AgentConsumerBinding] = Field(default_factory=list)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_unique_bindings(self) -> "ConsumerRegistry":
|
||||
by_agent: set[str] = set()
|
||||
by_bit: set[int] = set()
|
||||
for item in self.bindings:
|
||||
if item.agent_type in by_agent:
|
||||
raise ValueError(f"duplicate agent_type binding: {item.agent_type}")
|
||||
if item.bit in by_bit:
|
||||
raise ValueError(f"duplicate visibility bit binding: {item.bit}")
|
||||
by_agent.add(item.agent_type)
|
||||
by_bit.add(item.bit)
|
||||
return self
|
||||
|
||||
def resolve_agent_bit(self, *, agent_type: str) -> int:
|
||||
target = agent_type.strip().lower()
|
||||
for item in self.bindings:
|
||||
if item.agent_type == target:
|
||||
return item.bit
|
||||
raise ValueError(f"agent visibility bit not configured: {target}")
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
class ExecutorKind(str, Enum):
|
||||
SINGLE_SHOT = "single_shot"
|
||||
REACT = "react"
|
||||
|
||||
|
||||
class StageSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
stage_name: str = Field(..., min_length=1, max_length=64)
|
||||
agent_type: AgentType
|
||||
executor_kind: ExecutorKind
|
||||
|
||||
@field_validator("stage_name")
|
||||
@classmethod
|
||||
def _normalize_stage_name(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("stage_name must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class PipelineSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
mode: str = Field(..., min_length=1, max_length=64)
|
||||
stages: list[StageSpec] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("mode")
|
||||
@classmethod
|
||||
def _normalize_mode(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("mode must not be empty")
|
||||
return normalized
|
||||
@@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import date
|
||||
from typing import Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
|
||||
from schemas.automation import ContextWindowMode, MemoryContextConfig
|
||||
|
||||
|
||||
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
|
||||
_DEFAULT_ROUTER_CONTEXT_DAY_COUNT = 20
|
||||
|
||||
|
||||
class ContextRepositoryLike(Protocol):
|
||||
@@ -28,9 +29,53 @@ class ContextRepositoryLike(Protocol):
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
ContextLoader = Callable[[Any, str, int, int], Awaitable[dict[str, object] | None]]
|
||||
|
||||
|
||||
class ContextLoaderRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._loaders: dict[ContextWindowMode, ContextLoader] = {}
|
||||
|
||||
def register(self, *, mode: ContextWindowMode, loader: ContextLoader) -> None:
|
||||
self._loaders[mode] = loader
|
||||
|
||||
def resolve(self, *, mode: ContextWindowMode) -> ContextLoader:
|
||||
loader = self._loaders.get(mode)
|
||||
if loader is None:
|
||||
raise ValueError(f"unsupported context mode: {mode.value}")
|
||||
return loader
|
||||
|
||||
|
||||
async def _load_number(
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
async def _load_day(
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
CONTEXT_LOADER_REGISTRY = ContextLoaderRegistry()
|
||||
CONTEXT_LOADER_REGISTRY.register(mode=ContextWindowMode.NUMBER, loader=_load_number)
|
||||
CONTEXT_LOADER_REGISTRY.register(mode=ContextWindowMode.DAY, loader=_load_day)
|
||||
|
||||
|
||||
class AgentContextService:
|
||||
@@ -41,32 +86,16 @@ class AgentContextService:
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
context_config: MemoryContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
|
||||
runtime_config = await self._repository.get_system_agent_config(agent_type=mode)
|
||||
raw_llm_config: dict[str, object] = {}
|
||||
if isinstance(runtime_config, dict):
|
||||
raw_config = runtime_config.get("config")
|
||||
if isinstance(raw_config, dict):
|
||||
raw_llm_config = raw_config
|
||||
|
||||
if mode == "router" and not raw_llm_config:
|
||||
raw_llm_config = {
|
||||
"context_messages": {
|
||||
"mode": "day",
|
||||
"count": _DEFAULT_ROUTER_CONTEXT_DAY_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
normalized_config = self._normalize_system_agent_config(raw_llm_config)
|
||||
context_config = normalized_config.context_messages
|
||||
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
|
||||
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
|
||||
context_loader = CONTEXT_LOADER_REGISTRY.resolve(
|
||||
mode=context_config.window_mode
|
||||
)
|
||||
return await context_loader(
|
||||
self,
|
||||
thread_id,
|
||||
context_config.count,
|
||||
context_config.window_count,
|
||||
visibility_mask,
|
||||
)
|
||||
|
||||
@@ -114,22 +143,6 @@ class AgentContextService:
|
||||
return None
|
||||
return {"messages": messages}
|
||||
|
||||
def _normalize_system_agent_config(
|
||||
self,
|
||||
raw_config: dict[str, object],
|
||||
) -> SystemAgentLLMConfig:
|
||||
default_payload = {
|
||||
"context_messages": {
|
||||
"mode": "number",
|
||||
"count": _DEFAULT_CONTEXT_WINDOW_USER_MESSAGES,
|
||||
},
|
||||
"enabled_tools": [],
|
||||
}
|
||||
if not raw_config:
|
||||
return SystemAgentLLMConfig.model_validate(default_payload)
|
||||
merged = {**default_payload, **raw_config}
|
||||
return SystemAgentLLMConfig.model_validate(merged)
|
||||
|
||||
def _parse_history_day(self, value: object) -> date | None:
|
||||
if isinstance(value, date):
|
||||
return value
|
||||
|
||||
@@ -33,7 +33,6 @@ AGENT_TYPE_TO_DEFAULT_TOOLS: dict[AgentType, set[str]] = {
|
||||
"calendar_share",
|
||||
"user_lookup",
|
||||
},
|
||||
AgentType.MEMORY: {"calendar_read", "user_lookup"},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,263 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob, ScheduleType
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
||||
from schemas.automation import RuntimeConfig
|
||||
|
||||
logger = get_logger("core.automation.scheduler")
|
||||
|
||||
|
||||
class _BulkQueueAdapter:
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str:
|
||||
del dedup_key
|
||||
from core.agentscope.runtime.tasks import run_command_task_bulk
|
||||
|
||||
result = await run_command_task_bulk.kiq(command)
|
||||
return str(result.task_id)
|
||||
|
||||
|
||||
class QueueLike(Protocol):
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class AutomationSchedulerRepositoryLike(Protocol):
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[DueAutomationJob]: ...
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig: ...
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID: ...
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DispatchResult:
|
||||
scanned: int
|
||||
dispatched: int
|
||||
|
||||
|
||||
class AutomationSchedulerService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: AutomationSchedulerRepositoryLike,
|
||||
queue: QueueLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._queue = queue
|
||||
|
||||
async def scan_and_dispatch(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> DispatchResult:
|
||||
safe_limit = max(int(limit), 1)
|
||||
due_jobs = await self._repository.list_due_jobs(
|
||||
now_utc=now_utc, limit=safe_limit
|
||||
)
|
||||
dispatched = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
config = await self._repository.get_job_config(job_id=job.id)
|
||||
thread_id = await self._repository.ensure_latest_chat_session(
|
||||
owner_id=job.owner_id
|
||||
)
|
||||
command = self._build_dispatch_command(
|
||||
job=job,
|
||||
thread_id=thread_id,
|
||||
input_text=config.input_template,
|
||||
now_utc=now_utc,
|
||||
)
|
||||
await self._queue.enqueue(command=command, dedup_key=None)
|
||||
await self._repository.mark_job_dispatched(
|
||||
job_id=job.id,
|
||||
next_run_at=_compute_next_run_at(
|
||||
current_next_run_at=job.next_run_at,
|
||||
now_utc=now_utc,
|
||||
schedule_type=job.schedule_type,
|
||||
),
|
||||
last_run_at=now_utc,
|
||||
)
|
||||
await self._repository.commit()
|
||||
dispatched += 1
|
||||
except Exception as exc:
|
||||
await self._repository.rollback()
|
||||
logger.exception(
|
||||
"automation job dispatch failed",
|
||||
job_id=str(job.id),
|
||||
owner_id=str(job.owner_id),
|
||||
error=str(exc),
|
||||
)
|
||||
return DispatchResult(scanned=len(due_jobs), dispatched=dispatched)
|
||||
|
||||
def _build_dispatch_command(
|
||||
self,
|
||||
*,
|
||||
job: DueAutomationJob,
|
||||
thread_id: UUID,
|
||||
input_text: str,
|
||||
now_utc: datetime,
|
||||
) -> dict[str, object]:
|
||||
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
|
||||
payload = SchedulerDispatchCommand(
|
||||
owner_id=job.owner_id,
|
||||
automation_job_id=job.id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
input_text=input_text.strip(),
|
||||
)
|
||||
return {
|
||||
"command": "run",
|
||||
"owner_id": str(payload.owner_id),
|
||||
"automation_job_id": str(payload.automation_job_id),
|
||||
"queue": "bulk",
|
||||
"run_input": {
|
||||
"threadId": str(payload.thread_id),
|
||||
"runId": payload.run_id,
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"role": "user",
|
||||
"content": payload.input_text,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"agent_type": "memory",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SqlAlchemyAutomationSchedulerRepository:
|
||||
def __init__(self, *, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[DueAutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.status == "active")
|
||||
.where(AutomationJob.next_run_at <= now_utc)
|
||||
.order_by(AutomationJob.next_run_at.asc())
|
||||
.limit(max(limit, 1))
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return [
|
||||
DueAutomationJob(
|
||||
id=row.id,
|
||||
owner_id=row.owner_id,
|
||||
schedule_type=row.schedule_type,
|
||||
timezone=row.timezone,
|
||||
next_run_at=row.next_run_at,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
|
||||
stmt = select(AutomationJob.config).where(AutomationJob.id == job_id)
|
||||
config_payload = (await self._session.execute(stmt)).scalar_one()
|
||||
return AutomationJobConfig.model_validate(config_payload or {})
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == owner_id)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.where(AgentChatSession.session_type == SessionType.CHAT)
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
)
|
||||
self._session.add(session)
|
||||
await self._session.flush()
|
||||
return session.id
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None:
|
||||
stmt = select(AutomationJob).where(AutomationJob.id == job_id)
|
||||
row = (await self._session.execute(stmt)).scalar_one()
|
||||
row.next_run_at = next_run_at
|
||||
row.last_run_at = last_run_at
|
||||
await self._session.flush()
|
||||
|
||||
async def commit(self) -> None:
|
||||
await self._session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
await self._session.rollback()
|
||||
|
||||
|
||||
def _compute_next_run_at(
|
||||
*,
|
||||
current_next_run_at: datetime,
|
||||
now_utc: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> datetime:
|
||||
delta = timedelta(days=1 if schedule_type == ScheduleType.DAILY else 7)
|
||||
next_run_at = current_next_run_at
|
||||
while next_run_at <= now_utc:
|
||||
next_run_at = next_run_at + delta
|
||||
return next_run_at
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
@@ -272,22 +24,85 @@ async def run_automation_scheduler_scan(
|
||||
if isinstance(limit, int)
|
||||
else int(config.automation_scheduler.batch_limit)
|
||||
)
|
||||
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
from v1.automation_jobs.service import AutomationJobsService
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
repository = SqlAlchemyAutomationSchedulerRepository(session=session)
|
||||
service = AutomationSchedulerService(
|
||||
repository=repository,
|
||||
queue=_BulkQueueAdapter(),
|
||||
repository = AutomationJobsRepository(session=session)
|
||||
service = AutomationJobsService(repository=repository, session=session)
|
||||
|
||||
result = await service.scan_and_dispatch(
|
||||
now_utc=now,
|
||||
limit=safe_limit,
|
||||
dispatch_fn=_dispatch_automation_run,
|
||||
)
|
||||
result = await service.scan_and_dispatch(now_utc=now, limit=safe_limit)
|
||||
|
||||
logger.info(
|
||||
"automation scheduler scan completed",
|
||||
scanned=result.scanned,
|
||||
dispatched=result.dispatched,
|
||||
now_utc=now.astimezone(timezone.utc).isoformat(),
|
||||
now_utc=now.isoformat(),
|
||||
)
|
||||
return {
|
||||
"scanned": int(result.scanned),
|
||||
"dispatched": int(result.dispatched),
|
||||
"scanned": result.scanned,
|
||||
"dispatched": result.dispatched,
|
||||
}
|
||||
|
||||
|
||||
async def _dispatch_automation_run(
|
||||
*,
|
||||
owner_id: UUID,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
input_text: str,
|
||||
runtime_config: RuntimeConfig,
|
||||
) -> None:
|
||||
from uuid import uuid4
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.auth.models import CurrentUser
|
||||
from core.agentscope.tools.tool_result_storage import create_tool_result_storage
|
||||
from schemas.agent.forwarded_props import RuntimeMode
|
||||
from v1.agent.dependencies import TaskiqQueueClient, RedisEventStream
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
current_user = CurrentUser(id=owner_id)
|
||||
tool_result_storage = create_tool_result_storage()
|
||||
|
||||
run_input = {
|
||||
"threadId": str(thread_id),
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"role": "user",
|
||||
"content": input_text,
|
||||
}
|
||||
],
|
||||
"forwardedProps": {
|
||||
"runtimeMode": RuntimeMode.AUTOMATION.value,
|
||||
},
|
||||
}
|
||||
|
||||
parsed_run_input = RunAgentInput.model_validate(run_input)
|
||||
|
||||
from core.db.session import AsyncSessionLocal
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
repository = AgentRepository(
|
||||
session=session, tool_result_storage=tool_result_storage
|
||||
)
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=TaskiqQueueClient(),
|
||||
stream=RedisEventStream(),
|
||||
)
|
||||
await service.enqueue_run(
|
||||
run_input=parsed_run_input,
|
||||
current_user=current_user,
|
||||
runtime_config=runtime_config,
|
||||
)
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import worker_automation_broker
|
||||
|
||||
logger = get_logger("core.automation.tasks")
|
||||
|
||||
|
||||
@worker_automation_broker.task(task_name="tasks.automation.scan_due_jobs")
|
||||
async def scan_due_automation_jobs_task(limit: int | None = None) -> dict[str, int]:
|
||||
from core.automation.scheduler import run_automation_scheduler_scan
|
||||
|
||||
return await run_automation_scheduler_scan(limit=limit)
|
||||
@@ -18,9 +18,7 @@ agents:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
context_messages:
|
||||
mode: number
|
||||
count: 20
|
||||
context_messages: null
|
||||
enabled_tools:
|
||||
- calendar.read
|
||||
- calendar.write
|
||||
|
||||
@@ -12,7 +12,7 @@ from schemas.inbox.messages import (
|
||||
parse_calendar_content,
|
||||
)
|
||||
from schemas.invite_codes import InviteCodeRewardConfig
|
||||
from schemas.memories import MemoryContent
|
||||
from schemas.memories import MemoryContext
|
||||
from schemas.messages import AgentChatMessageMetadata
|
||||
from schemas.schedule.items import (
|
||||
AttachmentType,
|
||||
@@ -36,7 +36,7 @@ __all__ = [
|
||||
"InboxMessageStatus",
|
||||
"InboxMessageType",
|
||||
"InviteCodeRewardConfig",
|
||||
"MemoryContent",
|
||||
"MemoryContext",
|
||||
"ScheduleItemMetadata",
|
||||
"ScheduleItemMetadataAttachment",
|
||||
"ScheduleItemSourceType",
|
||||
|
||||
@@ -100,6 +100,7 @@ class NormalizedTaskInput(BaseModel):
|
||||
|
||||
user_text: str
|
||||
multimodal_summary: list[str] = Field(default_factory=list)
|
||||
context_summary: str = Field(default="", max_length=2000)
|
||||
|
||||
|
||||
class RouterUiDecision(BaseModel):
|
||||
|
||||
@@ -1,20 +1,77 @@
|
||||
from schemas.automation.config import (
|
||||
AutomationAgentType,
|
||||
AutomationContextSource,
|
||||
AutomationContextWindowMode,
|
||||
AutomationJobConfig,
|
||||
AutomationMemoryContextConfig,
|
||||
default_memory_job_config,
|
||||
)
|
||||
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
"AutomationAgentType",
|
||||
"AutomationContextSource",
|
||||
"AutomationContextWindowMode",
|
||||
"AutomationJobConfig",
|
||||
"AutomationMemoryContextConfig",
|
||||
"default_memory_job_config",
|
||||
"DueAutomationJob",
|
||||
"SchedulerDispatchCommand",
|
||||
]
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from models.automation_jobs import AutomationJob as OrmAutomationJob
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
|
||||
|
||||
class ContextSource(str, Enum):
|
||||
LATEST_CHAT = "latest_chat"
|
||||
|
||||
|
||||
class ContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class MemoryContextConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
source: ContextSource = ContextSource.LATEST_CHAT
|
||||
window_mode: ContextWindowMode = ContextWindowMode.DAY
|
||||
window_count: int = Field(default=2, ge=1, le=200)
|
||||
|
||||
|
||||
class RuntimeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
|
||||
context: MemoryContextConfig = Field(default_factory=MemoryContextConfig)
|
||||
|
||||
|
||||
class AutomationJobConfig(RuntimeConfig):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
input_template: str = Field(..., min_length=1, max_length=4000)
|
||||
|
||||
|
||||
class AutomationJob(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
title: str = Field(..., min_length=1, max_length=255)
|
||||
config: AutomationJobConfig
|
||||
schedule_type: ScheduleType
|
||||
run_at: datetime
|
||||
next_run_at: datetime
|
||||
timezone: str = Field(default="UTC", min_length=1, max_length=50)
|
||||
last_run_at: datetime | None = None
|
||||
status: AutomationJobStatus
|
||||
created_by: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj: OrmAutomationJob) -> "AutomationJob":
|
||||
return cls(
|
||||
id=obj.id,
|
||||
owner_id=obj.owner_id,
|
||||
title=obj.title,
|
||||
config=AutomationJobConfig.model_validate(obj.config or {}),
|
||||
schedule_type=obj.schedule_type,
|
||||
run_at=obj.run_at,
|
||||
next_run_at=obj.next_run_at,
|
||||
timezone=obj.timezone,
|
||||
last_run_at=obj.last_run_at,
|
||||
status=obj.status,
|
||||
created_by=obj.created_by,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
)
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
|
||||
|
||||
class AutomationAgentType(str, Enum):
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class AutomationContextSource(str, Enum):
|
||||
LATEST_CHAT = "latest_chat"
|
||||
|
||||
|
||||
class AutomationContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class AutomationMemoryContextConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
source: AutomationContextSource = AutomationContextSource.LATEST_CHAT
|
||||
window_mode: AutomationContextWindowMode = AutomationContextWindowMode.DAY
|
||||
window_count: int = Field(default=2, ge=1, le=200)
|
||||
|
||||
|
||||
class AutomationJobConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: AutomationAgentType = AutomationAgentType.MEMORY
|
||||
model_code: str = Field(default="qwen3.5-flash", min_length=1, max_length=64)
|
||||
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
|
||||
input_template: str = Field(..., min_length=1, max_length=4000)
|
||||
context: AutomationMemoryContextConfig = Field(
|
||||
default_factory=AutomationMemoryContextConfig
|
||||
)
|
||||
|
||||
@field_validator("model_code")
|
||||
@classmethod
|
||||
def _validate_model_code(cls, value: str) -> str:
|
||||
normalized = value.strip()
|
||||
if normalized != "qwen3.5-flash":
|
||||
raise ValueError("model_code must be qwen3.5-flash")
|
||||
return normalized
|
||||
|
||||
|
||||
def default_memory_job_config() -> AutomationJobConfig:
|
||||
return AutomationJobConfig(
|
||||
agent_type=AutomationAgentType.MEMORY,
|
||||
model_code="qwen3.5-flash",
|
||||
enabled_tools=[AgentTool.CALENDAR_READ, AgentTool.USER_LOOKUP],
|
||||
input_template="请基于最近聊天上下文生成一段可执行的记忆总结与建议。",
|
||||
context=AutomationMemoryContextConfig(
|
||||
source=AutomationContextSource.LATEST_CHAT,
|
||||
window_mode=AutomationContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from models.automation_jobs import ScheduleType
|
||||
|
||||
|
||||
class DueAutomationJob(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
schedule_type: ScheduleType
|
||||
timezone: str = Field(..., min_length=1, max_length=50)
|
||||
next_run_at: datetime
|
||||
|
||||
|
||||
class SchedulerDispatchCommand(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
owner_id: UUID
|
||||
automation_job_id: UUID
|
||||
thread_id: UUID
|
||||
run_id: str = Field(..., min_length=1, max_length=128)
|
||||
input_text: str = Field(..., min_length=1, max_length=4000)
|
||||
@@ -1,11 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class MemoryContent(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")
|
||||
class MemoryType(str, Enum):
|
||||
USER = "user"
|
||||
WORK = "work"
|
||||
|
||||
pass
|
||||
|
||||
class MemorySource(str, Enum):
|
||||
MANUAL = "manual"
|
||||
AGENT = "agent"
|
||||
IMPORTED = "imported"
|
||||
|
||||
|
||||
class MemoryStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class MemoryModel(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
extra="forbid", from_attributes=True
|
||||
)
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
agent_id: UUID | None = None
|
||||
memory_type: Literal["user", "work"]
|
||||
title: str | None = None
|
||||
content: dict[str, Any]
|
||||
source: MemorySource
|
||||
status: MemoryStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MemoryContext(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
memory_type: MemoryType
|
||||
source: MemorySource
|
||||
title: str | None = None
|
||||
content: dict[str, Any]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MemoryListResponse(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
owner_id: UUID
|
||||
memories: list[MemoryContext] = Field(default_factory=list)
|
||||
total: int
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemoryContext",
|
||||
"MemoryListResponse",
|
||||
"MemoryModel",
|
||||
"MemorySource",
|
||||
"MemoryStatus",
|
||||
"MemoryType",
|
||||
]
|
||||
|
||||
@@ -18,6 +18,7 @@ from schemas.agent.forwarded_props import (
|
||||
RuntimeMode,
|
||||
)
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.automation import RuntimeConfig
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachment,
|
||||
@@ -72,6 +73,7 @@ class AgentService:
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
runtime_config: RuntimeConfig | None = None,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
@@ -82,6 +84,13 @@ class AgentService:
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
if runtime_config is None:
|
||||
from v1.agent.system_agents_config import (
|
||||
build_runtime_config_from_system_agents,
|
||||
)
|
||||
|
||||
runtime_config = build_runtime_config_from_system_agents()
|
||||
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except HTTPException as exc:
|
||||
@@ -124,6 +133,9 @@ class AgentService:
|
||||
"run_input": run_input.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
"runtime_config": runtime_config.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
"queue": queue,
|
||||
},
|
||||
dedup_key=None,
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
System agents 配置加载工具
|
||||
|
||||
从 system_agents.yaml 加载配置并构建 RuntimeConfig
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.automation import (
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MemoryContextConfig,
|
||||
RuntimeConfig,
|
||||
)
|
||||
|
||||
|
||||
def _default_system_agents_path() -> Path:
|
||||
return (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "database"
|
||||
/ "system_agents.yaml"
|
||||
)
|
||||
|
||||
|
||||
def _load_system_agents_yaml(path: Path | None = None) -> dict:
|
||||
target_path = path or _default_system_agents_path()
|
||||
with target_path.open("r", encoding="utf-8") as f:
|
||||
loaded = yaml.safe_load(f) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"Invalid system agents format: {target_path}")
|
||||
return loaded
|
||||
|
||||
|
||||
def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextConfig:
|
||||
if not yaml_config:
|
||||
return MemoryContextConfig()
|
||||
mode_str = yaml_config.get("mode", "day")
|
||||
count = yaml_config.get("count", 2)
|
||||
try:
|
||||
source = ContextSource.LATEST_CHAT
|
||||
except ValueError:
|
||||
source = ContextSource.LATEST_CHAT
|
||||
try:
|
||||
window_mode = ContextWindowMode(mode_str)
|
||||
except ValueError:
|
||||
window_mode = ContextWindowMode.DAY
|
||||
return MemoryContextConfig(
|
||||
source=source,
|
||||
window_mode=window_mode,
|
||||
window_count=count,
|
||||
)
|
||||
|
||||
|
||||
def build_runtime_config_from_system_agents(
|
||||
yaml_path: Path | None = None,
|
||||
) -> RuntimeConfig:
|
||||
"""
|
||||
从 system_agents.yaml 构建 RuntimeConfig
|
||||
|
||||
chat 模式使用:
|
||||
- router.context_messages 配置 context
|
||||
- worker.enabled_tools 配置 tools
|
||||
"""
|
||||
raw = _load_system_agents_yaml(yaml_path)
|
||||
agents_list = raw.get("agents", [])
|
||||
|
||||
router_config: SystemAgentLLMConfig | None = None
|
||||
worker_config: SystemAgentLLMConfig | None = None
|
||||
|
||||
for agent in agents_list:
|
||||
agent_type = str(agent.get("agent_type", "")).strip().lower()
|
||||
if agent_type == "router":
|
||||
config_dict = agent.get("config") or {}
|
||||
try:
|
||||
router_config = SystemAgentLLMConfig.model_validate(config_dict)
|
||||
except ValidationError:
|
||||
router_config = SystemAgentLLMConfig()
|
||||
elif agent_type == "worker":
|
||||
config_dict = agent.get("config") or {}
|
||||
try:
|
||||
worker_config = SystemAgentLLMConfig.model_validate(config_dict)
|
||||
except ValidationError:
|
||||
worker_config = SystemAgentLLMConfig()
|
||||
|
||||
context_cfg = _parse_context_messages_config(
|
||||
router_config.context_messages.model_dump() if router_config else None
|
||||
)
|
||||
|
||||
enabled_tools: list[str] = []
|
||||
if worker_config and worker_config.enabled_tools:
|
||||
enabled_tools = [str(t) for t in worker_config.enabled_tools]
|
||||
|
||||
return RuntimeConfig(
|
||||
enabled_tools=enabled_tools,
|
||||
context=context_cfg,
|
||||
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from v1.automation_jobs.service import AutomationJobsService
|
||||
|
||||
__all__ = ["AutomationJobsService"]
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=AutomationJob)
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[AutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.status == "active")
|
||||
.where(AutomationJob.next_run_at <= now_utc)
|
||||
.order_by(AutomationJob.next_run_at.asc())
|
||||
.limit(max(limit, 1))
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
async def update_job_schedule(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.values(next_run_at=next_run_at, last_run_at=last_run_at)
|
||||
)
|
||||
await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
|
||||
async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == owner_id)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.where(AgentChatSession.session_type == SessionType.CHAT)
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
new_session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
)
|
||||
self._session.add(new_session)
|
||||
await self._session.flush()
|
||||
return new_session.id
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation import AutomationJob as AutomationJobSchema, RuntimeConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
|
||||
|
||||
class DispatchFn(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
input_text: str,
|
||||
runtime_config: RuntimeConfig,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
def _compute_next_run_at(
|
||||
*,
|
||||
current_next_run_at: datetime,
|
||||
now_utc: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> datetime:
|
||||
delta = timedelta(days=1 if schedule_type == ScheduleType.DAILY else 7)
|
||||
next_run_at = current_next_run_at
|
||||
while next_run_at <= now_utc:
|
||||
next_run_at = next_run_at + delta
|
||||
return next_run_at
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ScanResult:
|
||||
scanned: int
|
||||
dispatched: int
|
||||
|
||||
|
||||
class AutomationJobsService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: "AutomationJobsRepository",
|
||||
session: "AsyncSession",
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def scan_and_dispatch(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
dispatch_fn: DispatchFn,
|
||||
) -> ScanResult:
|
||||
rows = await self._repository.list_due_jobs(now_utc=now_utc, limit=limit)
|
||||
scanned = len(rows)
|
||||
dispatched = 0
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
job = AutomationJobSchema.from_orm(row)
|
||||
thread_id = await self.get_or_create_chat_session(owner_id=job.owner_id)
|
||||
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
|
||||
|
||||
await dispatch_fn(
|
||||
owner_id=job.owner_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
input_text=job.config.input_template.strip(),
|
||||
runtime_config=RuntimeConfig(
|
||||
enabled_tools=job.config.enabled_tools,
|
||||
context=job.config.context,
|
||||
),
|
||||
)
|
||||
|
||||
await self._repository.update_job_schedule(
|
||||
job_id=job.id,
|
||||
next_run_at=_compute_next_run_at(
|
||||
current_next_run_at=job.next_run_at,
|
||||
now_utc=now_utc,
|
||||
schedule_type=job.schedule_type,
|
||||
),
|
||||
last_run_at=now_utc,
|
||||
)
|
||||
await self._session.commit()
|
||||
dispatched += 1
|
||||
|
||||
except Exception:
|
||||
await self._session.rollback()
|
||||
raise
|
||||
|
||||
return ScanResult(scanned=scanned, dispatched=dispatched)
|
||||
|
||||
async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
return await self._repository.get_or_create_chat_session(owner_id=owner_id)
|
||||
@@ -0,0 +1,3 @@
|
||||
from v1.memories.service import MemoriesService
|
||||
|
||||
__all__ = ["MemoriesService"]
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.memories import Memory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class MemoriesRepositoryLike(Protocol):
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]: ...
|
||||
|
||||
|
||||
class MemoriesRepository(BaseRepository[Memory]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=Memory)
|
||||
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]:
|
||||
stmt = (
|
||||
select(Memory)
|
||||
.where(Memory.owner_id == owner_id)
|
||||
.where(Memory.status == "active")
|
||||
.order_by(Memory.created_at.desc())
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from models.memories import Memory
|
||||
from schemas.memories import MemoryContext, MemoryListResponse, MemorySource, MemoryType
|
||||
from v1.memories.repository import MemoriesRepositoryLike
|
||||
|
||||
|
||||
class MemoriesService:
|
||||
_repository: MemoriesRepositoryLike
|
||||
|
||||
def __init__(self, repository: MemoriesRepositoryLike) -> None:
|
||||
self._repository = repository
|
||||
|
||||
def _to_context(self, memory: Memory) -> MemoryContext:
|
||||
return MemoryContext(
|
||||
memory_type=MemoryType(memory.memory_type.value),
|
||||
source=MemorySource(memory.source.value),
|
||||
title=memory.title,
|
||||
content=memory.content,
|
||||
created_at=memory.created_at,
|
||||
updated_at=memory.updated_at,
|
||||
)
|
||||
|
||||
async def get_user_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
user_memories = [
|
||||
self._to_context(memory)
|
||||
for memory in memories
|
||||
if memory.memory_type.value == "user"
|
||||
]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=user_memories, total=len(user_memories)
|
||||
)
|
||||
|
||||
async def get_agent_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
agent_memories = [
|
||||
self._to_context(memory)
|
||||
for memory in memories
|
||||
if memory.memory_type.value == "work"
|
||||
]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=agent_memories, total=len(agent_memories)
|
||||
)
|
||||
|
||||
async def get_all_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
memory_contexts = [self._to_context(memory) for memory in memories]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=memory_contexts, total=len(memory_contexts)
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
from v1.memory.service import MemoryService
|
||||
|
||||
__all__ = ["MemoryService"]
|
||||
@@ -1,35 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.automation_jobs import AutomationJob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class MemoryRepositoryLike(Protocol):
|
||||
async def get_job_by_id_and_owner(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJob | None: ...
|
||||
|
||||
|
||||
class MemoryRepository(BaseRepository[AutomationJob]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=AutomationJob)
|
||||
|
||||
async def get_job_by_id_and_owner(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJob | None:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
@@ -1,25 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from v1.memory.repository import MemoryRepositoryLike
|
||||
|
||||
|
||||
class MemoryService:
|
||||
_repository: MemoryRepositoryLike
|
||||
|
||||
def __init__(self, repository: MemoryRepositoryLike) -> None:
|
||||
self._repository = repository
|
||||
|
||||
async def get_memory_job_config(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJobConfig:
|
||||
job = await self._repository.get_job_by_id_and_owner(
|
||||
job_id=job_id, owner_id=owner_id
|
||||
)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Automation job not found")
|
||||
return AutomationJobConfig.model_validate(job.config or {})
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from schemas.automation import MemoryContextConfig, RuntimeConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -42,11 +43,18 @@ def _run_input() -> RunAgentInput:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
"forwardedProps": {"runtime_mode": "automation"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _runtime_config() -> RuntimeConfig:
|
||||
return RuntimeConfig(
|
||||
enabled_tools=[],
|
||||
context=MemoryContextConfig(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_emits_run_lifecycle_events() -> None:
|
||||
pipeline = _FakePipeline()
|
||||
@@ -58,7 +66,7 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
|
||||
run_input=_run_input(),
|
||||
context_messages=[],
|
||||
user_context=_user_context(),
|
||||
system_agent_mode="worker",
|
||||
runtime_config=_runtime_config(),
|
||||
)
|
||||
|
||||
assert result["worker"]["answer"] == "done"
|
||||
|
||||
@@ -5,7 +5,6 @@ from ag_ui.core import RunAgentInput
|
||||
|
||||
import core.agentscope.runtime.runner as runner_module
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from schemas.automation.config import default_memory_job_config
|
||||
from schemas.agent.runtime_models import (
|
||||
ExecutionMode,
|
||||
NormalizedTaskInput,
|
||||
@@ -19,6 +18,7 @@ from schemas.agent.runtime_models import (
|
||||
WorkerAgentOutputLite,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.automation import MemoryContextConfig, RuntimeConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ def _run_input() -> RunAgentInput:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
"forwardedProps": {"runtime_mode": "automation"},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -45,10 +45,20 @@ def _user_context() -> UserContext:
|
||||
)
|
||||
|
||||
|
||||
def _runtime_config() -> RuntimeConfig:
|
||||
return RuntimeConfig(
|
||||
enabled_tools=[],
|
||||
context=MemoryContextConfig(),
|
||||
)
|
||||
|
||||
|
||||
def test_build_worker_input_messages_only_contains_router_contract() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
router_output = RouterAgentOutput(
|
||||
normalized_task_input=NormalizedTaskInput(user_text="安排明天会议"),
|
||||
normalized_task_input=NormalizedTaskInput(
|
||||
user_text="安排明天会议",
|
||||
context_summary="用户询问天气",
|
||||
),
|
||||
key_entities=[],
|
||||
constraints=[],
|
||||
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||
@@ -67,6 +77,43 @@ def test_build_worker_input_messages_only_contains_router_contract() -> None:
|
||||
assert "[RouterAgentOutput]" in str(input_messages[0].content)
|
||||
|
||||
|
||||
def test_build_router_messages_injects_user_input_when_context_last_not_user() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
run_input = _run_input()
|
||||
|
||||
messages = runner._build_router_messages(
|
||||
context_messages=[],
|
||||
run_input=run_input,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == "user"
|
||||
assert messages[0].content == "hello"
|
||||
|
||||
|
||||
def test_build_router_messages_skips_injection_when_context_last_is_user() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
run_input = _run_input()
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
existing_context = [
|
||||
Msg(name="user", role="user", content="之前的问题"),
|
||||
Msg(name="assistant", role="assistant", content="回答"),
|
||||
Msg(name="user", role="user", content="最新用户消息"),
|
||||
]
|
||||
|
||||
messages = runner._build_router_messages(
|
||||
context_messages=existing_context,
|
||||
run_input=run_input,
|
||||
)
|
||||
|
||||
assert len(messages) == len(existing_context)
|
||||
for i, msg in enumerate(messages):
|
||||
assert msg.role == existing_context[i].role
|
||||
assert msg.content == existing_context[i].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
@@ -79,7 +126,7 @@ async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"agent_type": "worker",
|
||||
"runtime_mode": "automation",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
@@ -95,7 +142,7 @@ async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_worker_mode_runs_router_then_worker(
|
||||
async def test_execute_runs_router_then_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakePipeline:
|
||||
@@ -127,7 +174,10 @@ async def test_execute_worker_mode_runs_router_then_worker(
|
||||
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
|
||||
del kwargs
|
||||
return RouterAgentOutput(
|
||||
normalized_task_input=NormalizedTaskInput(user_text="安排会议"),
|
||||
normalized_task_input=NormalizedTaskInput(
|
||||
user_text="安排会议",
|
||||
context_summary="用户询问天气",
|
||||
),
|
||||
key_entities=[],
|
||||
constraints=[],
|
||||
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||
@@ -145,7 +195,7 @@ async def test_execute_worker_mode_runs_router_then_worker(
|
||||
|
||||
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
|
||||
monkeypatch.setattr(runner, "_build_stage_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(runner, "_build_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
|
||||
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
|
||||
|
||||
@@ -154,84 +204,9 @@ async def test_execute_worker_mode_runs_router_then_worker(
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="worker",
|
||||
runtime_config=_runtime_config(),
|
||||
)
|
||||
|
||||
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
|
||||
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
|
||||
assert result["worker"]["answer"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_memory_mode_requires_memory_job_config() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
with pytest.raises(RuntimeError, match="memory job config is required"):
|
||||
await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="memory",
|
||||
memory_job_config=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_memory_mode_uses_memory_job_config(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
async def _fake_build_memory_stage_config(**kwargs: object):
|
||||
del kwargs
|
||||
return runner_module.SystemAgentRuntimeConfig(
|
||||
agent_type=AgentType.MEMORY,
|
||||
model_code="qwen3.5-flash",
|
||||
api_base_url="https://example.com",
|
||||
api_key="test",
|
||||
llm_config=runner_module.SystemAgentLLMConfig(),
|
||||
)
|
||||
|
||||
async def _fake_execute_single_stage_step(**kwargs: object):
|
||||
del kwargs
|
||||
return runner_module.AgentOutput(answer="memory")
|
||||
|
||||
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(
|
||||
runner, "_build_memory_stage_config", _fake_build_memory_stage_config
|
||||
)
|
||||
monkeypatch.setattr(runner, "_build_stage_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_execute_single_stage_step",
|
||||
_fake_execute_single_stage_step,
|
||||
)
|
||||
|
||||
result = await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="memory",
|
||||
memory_job_config=default_memory_job_config(),
|
||||
)
|
||||
|
||||
assert result["memory"]["answer"] == "memory"
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
import core.agentscope.runtime.tasks as tasks_module
|
||||
from schemas.agent import ToolStatus
|
||||
from schemas.automation import ContextWindowMode, MemoryContextConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -15,21 +16,36 @@ def _run_input_payload() -> dict[str, Any]:
|
||||
"threadId": str(uuid4()),
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"messages": [{"id": "u1", "role": "user", "content": "现在几点"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
"forwardedProps": {"runtime_mode": "automation"},
|
||||
}
|
||||
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
async def __aenter__(self) -> "_FakeSession":
|
||||
return _FakeSession()
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
async def execute(self, stmt: object) -> object:
|
||||
del stmt
|
||||
|
||||
class FakeResult:
|
||||
def scalars(self) -> object:
|
||||
class FakeScalars:
|
||||
def all(self) -> list[object]:
|
||||
return []
|
||||
|
||||
return FakeScalars()
|
||||
|
||||
return FakeResult()
|
||||
|
||||
|
||||
async def _fake_user_context(**kwargs: object) -> UserContext:
|
||||
del kwargs
|
||||
return UserContext(
|
||||
@@ -81,6 +97,10 @@ async def test_run_agentscope_task_calls_runtime_run(
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": _run_input_payload(),
|
||||
"runtime_config": {
|
||||
"enabled_tools": [],
|
||||
"context": {"window_mode": "day", "window_count": 2},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -89,34 +109,28 @@ async def test_run_agentscope_task_calls_runtime_run(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_includes_recent_context_messages(
|
||||
async def test_run_agentscope_task_injects_runtime_config(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured_messages: list[dict[str, Any]] = []
|
||||
captured_config: dict[str, Any] = {}
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
del kwargs
|
||||
|
||||
async def run(self, **kwargs: object) -> object:
|
||||
raw_context_messages = kwargs.get("context_messages")
|
||||
raw_run_input = kwargs.get("run_input")
|
||||
if isinstance(raw_context_messages, list):
|
||||
captured_messages.extend(raw_context_messages)
|
||||
if raw_run_input is not None:
|
||||
raw_messages = getattr(raw_run_input, "messages", [])
|
||||
if isinstance(raw_messages, list):
|
||||
captured_messages.extend(raw_messages)
|
||||
captured_config.update(
|
||||
{
|
||||
"runtime_config": kwargs.get("runtime_config"),
|
||||
"context_messages": kwargs.get("context_messages"),
|
||||
}
|
||||
)
|
||||
return object()
|
||||
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return object()
|
||||
|
||||
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return []
|
||||
|
||||
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
|
||||
|
||||
@@ -133,25 +147,23 @@ async def test_run_agentscope_task_includes_recent_context_messages(
|
||||
"_build_recent_context_messages",
|
||||
_empty_context,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"_build_recent_context_messages",
|
||||
_fake_context,
|
||||
)
|
||||
|
||||
run_input = _run_input_payload()
|
||||
run_input["messages"] = [{"id": "u1", "role": "user", "content": "现在几点"}]
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": run_input,
|
||||
"run_input": _run_input_payload(),
|
||||
"runtime_config": {
|
||||
"enabled_tools": [],
|
||||
"context": {"window_mode": "day", "window_count": 2},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert len(captured_messages) == 2
|
||||
assert captured_messages[0]["id"] == "ctx-1"
|
||||
assert getattr(captured_messages[1], "id", None) == "u1"
|
||||
assert captured_config["context_messages"] == [
|
||||
{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}
|
||||
]
|
||||
assert captured_config["runtime_config"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -177,38 +189,6 @@ async def test_run_agentscope_task_rejects_invalid_command_type() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_requires_forwarded_props_agent_type() -> None:
|
||||
payload = _run_input_payload()
|
||||
payload["forwardedProps"] = {}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": payload,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_memory_mode_requires_automation_job_id() -> None:
|
||||
payload = _run_input_payload()
|
||||
payload["forwardedProps"] = {"agent_type": "memory"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="automation_job_id is required for memory mode"
|
||||
):
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": payload,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@@ -221,9 +201,9 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
context_config: MemoryContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, system_agent_mode
|
||||
del thread_id, context_config
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -257,7 +237,10 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
context_config=MemoryContextConfig(
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -281,9 +264,9 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
context_config: MemoryContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, system_agent_mode
|
||||
del thread_id, context_config
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -312,7 +295,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
context_config=MemoryContextConfig(),
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -336,9 +319,9 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
context_config: MemoryContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, system_agent_mode
|
||||
del thread_id, context_config
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -354,17 +337,17 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
context_config=MemoryContextConfig(),
|
||||
)
|
||||
|
||||
assert messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_recent_context_messages_passes_context_mode_through(
|
||||
async def test_build_recent_context_messages_passes_context_config(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured_mode: dict[str, str | None] = {"mode": None}
|
||||
captured_config: dict[str, Any] = {"config": None}
|
||||
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
@@ -374,19 +357,21 @@ async def test_build_recent_context_messages_passes_context_mode_through(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
context_config: MemoryContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
captured_mode["mode"] = system_agent_mode
|
||||
captured_config["config"] = context_config
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
cfg = MemoryContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
context_config=cfg,
|
||||
)
|
||||
|
||||
assert messages == []
|
||||
assert captured_mode["mode"] == "worker"
|
||||
assert captured_config["config"].window_mode == ContextWindowMode.NUMBER
|
||||
assert captured_config["config"].window_count == 10
|
||||
|
||||
@@ -23,19 +23,17 @@ def test_build_agent_prompt_for_worker_contains_runtime_config() -> None:
|
||||
assert "enabled_tools=calendar.read,calendar.write" in prompt
|
||||
|
||||
|
||||
def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
|
||||
def test_build_agent_prompt_for_router_contains_task_typing_rules() -> None:
|
||||
prompt = build_agent_prompt(
|
||||
agent_type=AgentType.MEMORY,
|
||||
agent_type=AgentType.ROUTER,
|
||||
llm_config=SystemAgentLLMConfig.model_validate(
|
||||
{
|
||||
"context_messages": {"mode": "day", "count": 2},
|
||||
"enabled_tools": ["user.lookup"],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
assert "- type: memory" in prompt
|
||||
assert "[Memory Agent]" in prompt
|
||||
assert "- type: router" in prompt
|
||||
assert "[Router Agent]" in prompt
|
||||
assert "context_messages.mode=day" in prompt
|
||||
assert "context_messages.count=2" in prompt
|
||||
assert "enabled_tools=user.lookup" in prompt
|
||||
|
||||
@@ -156,3 +156,48 @@ def test_build_system_prompt_keeps_sections_focused_without_language_duplication
|
||||
assert "[Answer Style]" in prompt
|
||||
assert "Default reply language:" not in prompt
|
||||
assert "Follow agent contracts strictly" not in prompt
|
||||
|
||||
|
||||
def test_build_system_prompt_includes_memory_section_when_memories_provided() -> None:
|
||||
from schemas.memories import (
|
||||
MemoryContext,
|
||||
MemoryListResponse,
|
||||
MemorySource,
|
||||
MemoryType,
|
||||
)
|
||||
|
||||
memories = MemoryListResponse(
|
||||
owner_id=uuid4(),
|
||||
memories=[
|
||||
MemoryContext(
|
||||
memory_type=MemoryType.USER,
|
||||
source=MemorySource.MANUAL,
|
||||
title="User prefers morning meetings",
|
||||
content={"text": "User likes meetings before 10am"},
|
||||
created_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
|
||||
),
|
||||
],
|
||||
total=1,
|
||||
)
|
||||
|
||||
prompt = build_system_prompt(
|
||||
agent_type=AgentType.WORKER,
|
||||
user_context=_build_user_context(),
|
||||
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
assert "<!-- MEMORY_START -->" in prompt
|
||||
assert "[User Memories]" in prompt
|
||||
assert "User prefers morning meetings" in prompt
|
||||
|
||||
|
||||
def test_build_system_prompt_omits_memory_section_when_no_memories() -> None:
|
||||
prompt = build_system_prompt(
|
||||
agent_type=AgentType.WORKER,
|
||||
user_context=_build_user_context(),
|
||||
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
assert "<!-- MEMORY_START -->" not in prompt
|
||||
|
||||
@@ -5,17 +5,23 @@ from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
_compute_next_run_at,
|
||||
from models.automation_jobs import AutomationJob as OrmAutomationJob, ScheduleType
|
||||
from schemas.automation import (
|
||||
RuntimeConfig,
|
||||
)
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation.scheduler import DueAutomationJob
|
||||
from v1.automation_jobs.service import AutomationJobsService, _compute_next_run_at
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self, jobs: list[DueAutomationJob]) -> None:
|
||||
def __init__(self, jobs: list[OrmAutomationJob]) -> None:
|
||||
self.jobs = jobs
|
||||
self.marked: list[tuple[UUID, datetime, datetime]] = []
|
||||
self.commits = 0
|
||||
@@ -23,30 +29,14 @@ class _FakeRepository:
|
||||
|
||||
async def list_due_jobs(
|
||||
self, *, now_utc: datetime, limit: int
|
||||
) -> list[DueAutomationJob]:
|
||||
) -> list[OrmAutomationJob]:
|
||||
del now_utc
|
||||
return self.jobs[:limit]
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
|
||||
del job_id
|
||||
return AutomationJobConfig.model_validate(
|
||||
{
|
||||
"agent_type": "memory",
|
||||
"model_code": "qwen3.5-flash",
|
||||
"enabled_tools": ["calendar.read", "user.lookup"],
|
||||
"input_template": "auto input",
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
return owner_id
|
||||
|
||||
async def mark_job_dispatched(
|
||||
async def update_job_schedule(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
@@ -55,57 +45,65 @@ class _FakeRepository:
|
||||
) -> None:
|
||||
self.marked.append((job_id, next_run_at, last_run_at))
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commits += 1
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollbacks += 1
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
self.commands: list[dict[str, object]] = []
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str:
|
||||
del dedup_key
|
||||
self.commands.append(command)
|
||||
return "task-1"
|
||||
def _make_orm_job(
|
||||
*,
|
||||
job_id: UUID | None = None,
|
||||
owner_id: UUID | None = None,
|
||||
schedule_type: ScheduleType = ScheduleType.DAILY,
|
||||
next_run_at: datetime | None = None,
|
||||
) -> OrmAutomationJob:
|
||||
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
|
||||
return OrmAutomationJob(
|
||||
id=job_id or uuid4(),
|
||||
owner_id=owner_id or uuid4(),
|
||||
title="Test Job",
|
||||
config={
|
||||
"enabled_tools": ["calendar.read", "user.lookup"],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
"input_template": "auto input: {date}",
|
||||
},
|
||||
schedule_type=schedule_type,
|
||||
run_at=now - timedelta(hours=1),
|
||||
next_run_at=next_run_at or now - timedelta(minutes=1),
|
||||
timezone="UTC",
|
||||
last_run_at=None,
|
||||
status="active",
|
||||
created_by=None,
|
||||
created_at=now - timedelta(days=1),
|
||||
updated_at=now - timedelta(hours=1),
|
||||
deleted_at=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_and_dispatch_enqueues_memory_run_command() -> None:
|
||||
async def test_scan_and_dispatch_calls_dispatch_fn_with_runtime_config() -> None:
|
||||
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repo = _FakeRepository(
|
||||
jobs=[
|
||||
DueAutomationJob(
|
||||
id=job_id,
|
||||
owner_id=owner_id,
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
timezone="UTC",
|
||||
next_run_at=now - timedelta(minutes=1),
|
||||
)
|
||||
]
|
||||
)
|
||||
queue = _FakeQueue()
|
||||
service = AutomationSchedulerService(repository=repo, queue=queue)
|
||||
repo = _FakeRepository(jobs=[_make_orm_job(job_id=job_id, owner_id=owner_id)])
|
||||
dispatched_calls: list[dict] = []
|
||||
|
||||
result = await service.scan_and_dispatch(now_utc=now, limit=10)
|
||||
async def dispatch_fn(**kwargs: object) -> None:
|
||||
dispatched_calls.append(kwargs)
|
||||
|
||||
service = AutomationJobsService(repository=repo, session=_FakeSession())
|
||||
|
||||
result = await service.scan_and_dispatch(
|
||||
now_utc=now, limit=10, dispatch_fn=dispatch_fn
|
||||
)
|
||||
|
||||
assert result.scanned == 1
|
||||
assert result.dispatched == 1
|
||||
assert len(queue.commands) == 1
|
||||
run_input = queue.commands[0]["run_input"]
|
||||
assert isinstance(run_input, dict)
|
||||
assert run_input["forwardedProps"] == {"agent_type": "memory"}
|
||||
assert queue.commands[0]["automation_job_id"] == str(job_id)
|
||||
assert repo.commits == 1
|
||||
assert len(dispatched_calls) == 1
|
||||
assert dispatched_calls[0]["owner_id"] == owner_id
|
||||
assert dispatched_calls[0]["runtime_config"] is not None
|
||||
cfg: RuntimeConfig = dispatched_calls[0]["runtime_config"]
|
||||
assert len(cfg.enabled_tools) == 2
|
||||
|
||||
|
||||
def test_compute_next_run_at_daily() -> None:
|
||||
|
||||
@@ -202,10 +202,12 @@ interface ForwardedProps {
|
||||
|
||||
### 运行模式说明
|
||||
|
||||
| runtime_mode | 说明 | 后端 Pipeline |
|
||||
|--------------|------|---------------|
|
||||
| `chat` | 标准对话模式 | `router` -> `worker` |
|
||||
| `automation` | 自动化任务模式 | 由后端业务逻辑决定具体 Agent 类型 |
|
||||
| runtime_mode | 说明 | Pipeline | 差异 |
|
||||
|--------------|------|----------|------|
|
||||
| `chat` | 标准对话模式 | `router` -> `worker` | `enabled_tools` 和 `context` 来自 `system_agents.yaml` |
|
||||
| `automation` | 自动化任务模式 | `router` -> `worker` | `enabled_tools` 和 `context` 来自 `AutomationJob.config`(通过 `runtime_config` 注入)|
|
||||
|
||||
> `runtime_mode` 仅影响 `RuntimeConfig`(工具列表与上下文配置),不改变执行阶段。两模式均使用固定两阶段 pipeline。
|
||||
|
||||
### 时间来源优先级(固定)
|
||||
|
||||
|
||||
@@ -326,6 +326,48 @@ cost = uncached_prompt_tokens * input_cost_per_token
|
||||
|
||||
## 8) 可见性与上下文装载说明
|
||||
|
||||
- 持久化消息使用单字段 `visibility_mask`(位掩码)控制 consumer 可见性。
|
||||
- `/history` 仅投影 `ui.history` 可见消息。
|
||||
- 运行时上下文按当前 stage 对应 consumer 位过滤装载,不依赖前端展示可见性。
|
||||
### visibility_mask 位掩码系统
|
||||
|
||||
持久化消息使用单字段 `visibility_mask`(位掩码)控制不同 consumer 的可见性:
|
||||
|
||||
| Bit | 常量名 | 说明 |
|
||||
|-----|--------|------|
|
||||
| 0 | `UI_HISTORY` | `/history` API 投影可见的消息 |
|
||||
| 1 | `CONTEXT_ASSEMBLY` | 运行时上下文装配(context assembly)可见 |
|
||||
|
||||
> 新消息入库时,`chat` 模式设置 `mask = UI_HISTORY | CONTEXT_ASSEMBLY`(值为 3),`automation` 模式设置 `mask = 0`。
|
||||
|
||||
### /history API
|
||||
|
||||
`GET /api/v1/agent/history` 仅投影包含 `UI_HISTORY` 位的消息:
|
||||
|
||||
```sql
|
||||
WHERE (visibility_mask & 1) != 0
|
||||
```
|
||||
|
||||
### 运行时上下文装配
|
||||
|
||||
`load_context_messages` 查询上下文时使用 `CONTEXT_ASSEMBLY` 位过滤:
|
||||
|
||||
```sql
|
||||
WHERE (visibility_mask & 2) != 0
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- `chat` 模式用户输入:mask=3 → 进入 `/history` ✅,进入 context assembly ✅
|
||||
- `automation` 模式用户输入:mask=0 → 进入 `/history` ❌,进入 context assembly ❌
|
||||
|
||||
### Automation 模式上下文注入
|
||||
|
||||
由于 automation 用户输入 `mask=0` 不进入 context assembly,router 调用前会从 `RunAgentInput.messages` 注入最新用户消息到 context 头部(条件:context 为空 或 最后一条非 user)。
|
||||
|
||||
### runtime_mode 差异总结
|
||||
|
||||
| 维度 | `chat` | `automation` |
|
||||
|------|--------|--------------|
|
||||
| Pipeline | `router` -> `worker` | `router` -> `worker` |
|
||||
| 用户输入 visibility_mask | `UI_HISTORY \| CONTEXT_ASSEMBLY` | `0` |
|
||||
| 进入 /history | ✅ | ❌ |
|
||||
| 进入 context assembly | ✅(自动) | ❌(通过 run_input 注入) |
|
||||
| enabled_tools 来源 | `system_agents.yaml` worker 配置 | `AutomationJob.config.enabled_tools` |
|
||||
| context 配置来源 | `system_agents.yaml` router context_messages | `AutomationJob.config.context` |
|
||||
|
||||
Reference in New Issue
Block a user