feat(agentscope): add memory system and automation job support

- Add consumer_registry and pipeline_registry for runtime orchestration
- Add Visibility schema for message filtering
- Add PipelineSpec for agent pipeline configuration
- Add automation job models and configuration
- Remove memory_prompt.py (consolidated into memory system)
- Update runtime components: context_loader, context_service, orchestrator, runner, tasks
- Update toolkit: tool_config, tool_middleware, custom tools (calendar, user_lookup)
- Add auth_helpers and calendar_domain utilities
- Add system_agents.yaml configuration
This commit is contained in:
qzl
2026-03-19 18:42:35 +08:00
parent 0661016827
commit 0abf51e837
55 changed files with 2172 additions and 1233 deletions
@@ -10,8 +10,6 @@ from ag_ui.core import (
RunErrorEvent,
StepStartedEvent,
StepFinishedEvent,
TextMessageEndEvent,
ToolCallResultEvent,
)
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
from schemas.agent.ui_hints import UiHintsPayload
@@ -94,9 +92,20 @@ def _build_run_finished(event: dict[str, Any]) -> RunFinishedEvent:
def _build_run_error(event: dict[str, Any]) -> RunErrorEvent:
data = event.get("data", {})
top_level_message = event.get("message")
message = top_level_message if isinstance(top_level_message, str) else ""
top_level_code = event.get("code")
code = top_level_code if isinstance(top_level_code, str) else None
if (not message or code is None) and isinstance(data, dict):
data_message = data.get("message")
if not message and isinstance(data_message, str):
message = data_message
data_code = data.get("code")
if code is None and isinstance(data_code, str):
code = data_code
return RunErrorEvent(
message=data.get("message", "Unknown error"),
code=data.get("code"),
message=message or "Unknown error",
code=code,
)
@@ -120,34 +129,12 @@ def _build_step_finished(event: dict[str, Any]) -> StepFinishedEvent:
)
def _build_text_end(event: dict[str, Any]) -> TextMessageEndEvent:
data = event.get("data", {})
return TextMessageEndEvent(
message_id=data.get("messageId", ""),
)
def _build_tool_result(event: dict[str, Any]) -> ToolCallResultEvent:
data = event.get("data", {})
content = data.get("result")
if not isinstance(content, str):
content = ""
return ToolCallResultEvent(
message_id=data.get("messageId", ""),
tool_call_id=data.get("toolCallId", ""),
content=content,
role="tool",
)
_BUILDER_MAP: dict[str, Any] = {
"run.started": _build_run_started,
"run.finished": _build_run_finished,
"run.error": _build_run_error,
"step.start": _build_step_started,
"step.finish": _build_step_finished,
"text.end": _build_text_end,
"tool.result": _build_tool_result,
}
@@ -208,6 +195,8 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
payload["runId"] = run_id
if isinstance(data, dict):
reserved = {"type", "threadId", "runId"}
if internal_type == "run.error":
reserved = {*reserved, "message", "code"}
payload.update({k: v for k, v in data.items() if k not in reserved})
return payload
@@ -29,6 +29,7 @@ class MessageRepository:
output_tokens: int = 0,
cost: Decimal = Decimal("0"),
latency_ms: int | None = None,
visibility_mask: int = 0,
) -> AgentChatMessage:
message = AgentChatMessage(
session_id=session_id,
@@ -42,6 +43,7 @@ class MessageRepository:
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
visibility_mask=max(int(visibility_mask), 0),
)
self._session.add(message)
await self._session.flush()
+57 -1
View File
@@ -8,9 +8,12 @@ from core.agentscope.events.persistence import MessageRepository, SessionReposit
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.system_agents import SystemAgents
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.agent.runtime_models import AgentOutput, ToolAgentOutput
from schemas.agent.system_agent import AgentType
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.messages.chat_message import AgentChatMessageMetadata
from sqlalchemy import select
class EventStore(Protocol):
@@ -45,6 +48,9 @@ class SqlAlchemyEventStore:
async with self._session_factory() as session:
session_repo = SessionRepository(session)
message_repo = MessageRepository(session)
stage_visibility_bit_map = await self._load_stage_visibility_bit_map(
session=session
)
chat_session = await session_repo.get_session(session_id=session_id)
if chat_session is None:
return
@@ -77,6 +83,7 @@ class SqlAlchemyEventStore:
chat_session=chat_session,
session_repo=session_repo,
message_repo=message_repo,
stage_visibility_bit_map=stage_visibility_bit_map,
)
elif event_type == "TOOL_CALL_RESULT":
await self._persist_tool_call_result(
@@ -85,6 +92,7 @@ class SqlAlchemyEventStore:
chat_session=chat_session,
session_repo=session_repo,
message_repo=message_repo,
stage_visibility_bit_map=stage_visibility_bit_map,
)
await session.commit()
@@ -97,6 +105,7 @@ class SqlAlchemyEventStore:
chat_session: Any,
session_repo: SessionRepository,
message_repo: MessageRepository,
stage_visibility_bit_map: dict[str, int],
) -> None:
message_id_raw = self._event_value(event, "messageId")
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
@@ -188,6 +197,10 @@ class SqlAlchemyEventStore:
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
visibility_mask=self._resolve_stage_visibility_mask(
event=event,
stage_visibility_bit_map=stage_visibility_bit_map,
),
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -213,6 +226,7 @@ class SqlAlchemyEventStore:
chat_session: Any,
session_repo: SessionRepository,
message_repo: MessageRepository,
stage_visibility_bit_map: dict[str, int],
) -> None:
run_id = self._event_value(event, "runId")
run_id_value = run_id if isinstance(run_id, str) and run_id else None
@@ -256,6 +270,10 @@ class SqlAlchemyEventStore:
content=content,
tool_name=tool_output.tool_name,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
visibility_mask=self._resolve_stage_visibility_mask(
event=event,
stage_visibility_bit_map=stage_visibility_bit_map,
),
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -279,6 +297,44 @@ class SqlAlchemyEventStore:
return AgentChatMessageRole.TOOL
return AgentChatMessageRole.ASSISTANT
def _resolve_stage_visibility_mask(
self,
*,
event: dict[str, Any],
stage_visibility_bit_map: dict[str, int],
) -> int:
base = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
raw_stage = self._event_value(event, "stage")
if not isinstance(raw_stage, str):
return base
normalized_stage = raw_stage.strip().lower()
bit = stage_visibility_bit_map.get(normalized_stage)
if bit is None and normalized_stage == AgentType.MEMORY.value:
bit = 18
if bit is None:
return base
return base | bit_mask(bit=bit)
async def _load_stage_visibility_bit_map(
self,
*,
session: Any,
) -> dict[str, int]:
stmt = select(SystemAgents.agent_type, SystemAgents.config).where(
SystemAgents.agent_type.in_(
[AgentType.ROUTER.value, AgentType.WORKER.value, AgentType.MEMORY.value]
)
)
rows = (await session.execute(stmt)).all()
bit_map: dict[str, int] = {}
for agent_type, raw_config in rows:
if not isinstance(agent_type, str):
continue
config_payload = raw_config if isinstance(raw_config, dict) else {}
llm_config = SystemAgentLLMConfig.model_validate(config_payload)
bit_map[agent_type.strip().lower()] = llm_config.visibility_consumer_bit
return bit_map
async def _update_session_state(
self,
*,
@@ -1,7 +1,10 @@
from __future__ import annotations
import json
from collections.abc import Callable
from typing import Any
from schemas.agent.runtime_models import ResultType, RouterAgentOutput, TaskType
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
@@ -14,6 +17,24 @@ def _wrap_section(section: str, content: str) -> str:
return f"{start}\n{body}\n{end}" if body else f"{start}\n{end}"
def _enum_values(enum_cls: Any) -> str:
return ", ".join(item.value for item in enum_cls)
def _config_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
if llm_config is None:
return []
context_mode = llm_config.context_messages.mode.value
context_count = llm_config.context_messages.count
enabled_tools = [tool.value for tool in llm_config.enabled_tools]
return [
"[Runtime Config]",
f"- context_messages.mode={context_mode}",
f"- context_messages.count={context_count}",
f"- enabled_tools={','.join(enabled_tools) if enabled_tools else 'default'}",
]
PromptRuleBuilder = Callable[[SystemAgentLLMConfig | None], list[str]]
@@ -36,36 +57,41 @@ class AgentPromptRegistry:
return builder(llm_config)
def _config_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
if llm_config is None:
return []
context_mode = llm_config.context_messages.mode.value
context_count = llm_config.context_messages.count
tool_groups = [group.value for group in llm_config.enabled_tool_groups]
def _router_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
return [
"[Runtime Config]",
f"- context_messages.mode={context_mode}",
f"- context_messages.count={context_count}",
f"- enabled_tool_groups={','.join(tool_groups) if tool_groups else 'default'}",
"[Router Agent]",
"- Read the latest user input and produce a routing contract for downstream execution.",
"- Return exactly one RouterAgentOutput JSON object.",
"[Responsibilities]",
"- Router only: extract intent and route strategy; never answer user directly.",
"- Preserve intent in normalized_task_input.user_text; keep wording concise and faithful.",
"- Fill multimodal_summary only when image/attachment changes execution decisions.",
"- Return key_entities and constraints that are execution-relevant; low confidence -> omit rather than guess.",
"- Set execution_mode by complexity: onestep / tool_assisted / multistep.",
"- Set result_typing.primary to the most suitable response shape; use clarification_request only when required info is missing.",
"- Set ui.ui_mode and ui.ui_decision_reason based on whether structured UI improves actionability.",
f"- task_typing.primary must use one TaskType enum: {_enum_values(TaskType)}.",
f"- task_typing.secondary max 3 enums: {_enum_values(TaskType)}.",
f"- result_typing.primary must use one ResultType enum: {_enum_values(ResultType)}.",
f"- result_typing.secondary max 3 enums: {_enum_values(ResultType)}.",
*_config_rules(llm_config),
]
def _worker_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
return [
"[Worker Agent]",
"- Process user context directly, identify intent, then execute or answer with available evidence.",
"- Return exactly one agent output JSON object matching the runtime-injected schema.",
"- Execute or answer against the routed objective and available evidence.",
"- Return exactly one worker output JSON object matching the runtime-injected schema.",
"[Responsibilities]",
"- Handle user request directly from conversation context.",
"- Decide intent first, then choose direct answer, tool call, clarification, or refusal.",
"- Prefer a direct answer when no tool result is required.",
"- Call tools only when tool results are necessary to produce a correct answer.",
"- Infer deterministic required tool arguments from context, tool schema, and runtime signals.",
"- Worker only: execute routed objective without changing router intent.",
"- Treat router output as objective/constraints contract, not as a fully-materialized tool-args payload.",
"- Infer deterministic required tool arguments from contract fields, tool schema, and runtime context.",
"- Ask minimal clarification only when required arguments cannot be inferred safely.",
"- If request is unsafe or disallowed, return safe refusal with actionable alternative.",
"- Ground every claim in evidence and tool outputs; never fabricate execution state.",
"- Ground every claim in available evidence and tool results; never fabricate execution state.",
"- Keep status/result_type/answer/key_points/suggested_actions/error internally consistent.",
"[Schema Guidance]",
"- The output schema is injected at runtime; follow it exactly.",
"- The worker output schema is injected at runtime; follow it exactly.",
"- Do not add fields that are not present in the injected schema.",
*_config_rules(llm_config),
]
@@ -88,7 +114,28 @@ def _memory_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
]
def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
contract_json = json.dumps(
router_output.model_dump(mode="json", exclude_none=True),
ensure_ascii=True,
separators=(",", ":"),
)
return "\n".join(
[
"[Worker Contract]",
"- Keep routed objective unchanged.",
"- Use normalized_task_input as objective text.",
"- Use multimodal_summary/key_entities/constraints as execution evidence.",
"- Infer deterministic missing required tool args from evidence + tool schema.",
"- Ask clarification only when safe inference is impossible.",
"[RouterAgentOutput]",
contract_json,
]
)
AGENT_PROMPT_REGISTRY = AgentPromptRegistry()
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.ROUTER, builder=_router_rules)
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.WORKER, builder=_worker_rules)
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.MEMORY, builder=_memory_rules)
@@ -1 +0,0 @@
pass
@@ -0,0 +1,19 @@
from __future__ import annotations
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
def build_consumer_registry(
*,
system_agent_configs: dict[str, dict[str, object]],
) -> ConsumerRegistry:
bindings: list[AgentConsumerBinding] = []
for agent_type, payload in system_agent_configs.items():
config_obj = payload.get("config") if isinstance(payload, dict) else None
if not isinstance(config_obj, dict):
raise ValueError(f"invalid system agent config: {agent_type}")
raw_bit = config_obj.get("visibility_consumer_bit")
if not isinstance(raw_bit, int):
raise ValueError(f"visibility_consumer_bit missing for agent: {agent_type}")
bindings.append(AgentConsumerBinding(agent_type=agent_type, bit=raw_bit))
return ConsumerRegistry(bindings=bindings)
@@ -5,7 +5,7 @@ from typing import Any
from schemas.agent.system_agent import ContextBuildStrategy
ContextLoader = Callable[[Any, str, int], Awaitable[dict[str, object] | None]]
ContextLoader = Callable[[Any, str, int, int], Awaitable[dict[str, object] | None]]
class ContextLoaderRegistry:
@@ -23,20 +23,28 @@ class ContextLoaderRegistry:
async def _load_number(
service: Any, thread_id: str, count: int
service: Any,
thread_id: str,
count: int,
visibility_mask: int,
) -> dict[str, object] | None:
return await service.load_by_user_message_window(
thread_id=thread_id,
user_message_limit=max(count, 1),
visibility_mask=visibility_mask,
)
async def _load_day(
service: Any, thread_id: str, count: int
service: Any,
thread_id: str,
count: int,
visibility_mask: int,
) -> dict[str, object] | None:
return await service.load_by_day_window(
thread_id=thread_id,
day_count=max(count, 1),
visibility_mask=visibility_mask,
)
@@ -5,17 +5,27 @@ from typing import Protocol
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.agent.visibility import bit_mask
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
_DEFAULT_ROUTER_CONTEXT_DAY_COUNT = 20
class ContextRepositoryLike(Protocol):
async def get_history_day(
self, *, session_id: str, before: date | None
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None: ...
async def get_recent_messages_by_user_window(
self, *, session_id: str, user_message_limit: int
self,
*,
session_id: str,
user_message_limit: int,
visibility_mask: int | None = None,
) -> list[dict[str, object]]: ...
async def get_system_agent_config(
@@ -41,20 +51,36 @@ class AgentContextService:
if isinstance(raw_config, dict):
raw_llm_config = raw_config
if mode == "router" and not raw_llm_config:
raw_llm_config = {
"context_messages": {
"mode": "day",
"count": _DEFAULT_ROUTER_CONTEXT_DAY_COUNT,
}
}
normalized_config = self._normalize_system_agent_config(raw_llm_config)
context_config = normalized_config.context_messages
visibility_mask = bit_mask(bit=normalized_config.visibility_consumer_bit)
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
return await context_loader(self, thread_id, context_config.count)
return await context_loader(
self,
thread_id,
context_config.count,
visibility_mask,
)
async def load_by_user_message_window(
self,
*,
thread_id: str,
user_message_limit: int,
visibility_mask: int,
) -> dict[str, object] | None:
messages = await self._repository.get_recent_messages_by_user_window(
session_id=thread_id,
user_message_limit=max(int(user_message_limit), 1),
visibility_mask=visibility_mask,
)
if not messages:
return None
@@ -65,6 +91,7 @@ class AgentContextService:
*,
thread_id: str,
day_count: int,
visibility_mask: int,
) -> dict[str, object] | None:
messages: list[dict[str, object]] = []
before: date | None = None
@@ -72,6 +99,7 @@ class AgentContextService:
day_payload = await self._repository.get_history_day(
session_id=thread_id,
before=before,
visibility_mask=visibility_mask,
)
if not day_payload:
break
@@ -95,7 +123,7 @@ class AgentContextService:
"mode": "number",
"count": _DEFAULT_CONTEXT_WINDOW_USER_MESSAGES,
},
"enabled_tool_groups": [],
"enabled_tools": [],
}
if not raw_config:
return SystemAgentLLMConfig.model_validate(default_payload)
@@ -6,6 +6,7 @@ from ag_ui.core.types import RunAgentInput
from agentscope.message import Msg
from core.agentscope.runtime.runner import AgentScopeRunner
from core.logging import get_logger
from schemas.automation.config import AutomationJobConfig
from schemas.user import UserContext
logger = get_logger("core.agentscope.runtime.orchestrator")
@@ -24,6 +25,7 @@ class RunnerLike(Protocol):
pipeline: PipelineLike,
run_input: RunAgentInput,
system_agent_mode: str,
memory_job_config: AutomationJobConfig | None,
) -> dict[str, Any]: ...
@@ -47,6 +49,7 @@ class AgentScopeRuntimeOrchestrator:
context_messages: list[Msg],
user_context: UserContext,
system_agent_mode: str,
memory_job_config: AutomationJobConfig | None = None,
) -> dict[str, Any]:
thread_id = run_input.thread_id
run_id = run_input.run_id
@@ -66,6 +69,7 @@ class AgentScopeRuntimeOrchestrator:
pipeline=self._pipeline,
run_input=run_input,
system_agent_mode=system_agent_mode,
memory_job_config=memory_job_config,
)
await self._pipeline.emit(
@@ -0,0 +1,58 @@
from __future__ import annotations
from schemas.agent.pipeline_spec import (
ContextPolicy,
ContextWindowMode,
ExecutorKind,
PipelineSpec,
StageSpec,
)
def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
normalized = mode.strip().lower()
if normalized == "worker":
return PipelineSpec(
mode="worker",
stages=[
StageSpec(
stage_name="router",
executor_kind=ExecutorKind.SINGLE_SHOT,
default_visibility_mask=0,
context_policy=ContextPolicy(
consumer_agent_type="router",
window_mode=ContextWindowMode.DAY,
count=20,
),
),
StageSpec(
stage_name="worker",
executor_kind=ExecutorKind.REACT,
default_visibility_mask=0,
context_policy=ContextPolicy(
consumer_agent_type="worker",
window_mode=ContextWindowMode.NUMBER,
count=20,
),
),
],
)
if normalized == "memory":
return PipelineSpec(
mode="memory",
stages=[
StageSpec(
stage_name="memory",
executor_kind=ExecutorKind.REACT,
default_visibility_mask=0,
context_policy=ContextPolicy(
consumer_agent_type="memory",
window_mode=ContextWindowMode.DAY,
count=20,
),
)
],
)
raise ValueError(f"unsupported pipeline mode: {normalized}")
+271 -31
View File
@@ -10,26 +10,40 @@ from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.runtime.tool_selection_registry import TOOL_SELECTION_REGISTRY
from core.agentscope.tools.toolkit import build_stage_toolkit
from core.agentscope.utils import patch_agentscope_json_repair_compat
from core.agentscope.utils import (
finalize_json_response,
patch_agentscope_json_repair_compat,
)
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
from schemas.agent.runtime_models import (
AgentOutput,
)
from schemas.agent.forwarded_props import (
ClientTimeContext,
parse_forwarded_props_client_time,
)
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.automation.config import AutomationJobConfig
from schemas.agent.runtime_models import (
AgentOutput,
RouterAgentOutput,
WorkerAgentOutputLite,
resolve_worker_output_model,
)
from schemas.agent.system_agent import (
AgentType,
ContextMessagesConfig,
ContextBuildStrategy,
SystemAgentLLMConfig,
)
from schemas.user import UserContext
from services.litellm.service import LiteLLMService
from sqlalchemy import select
@@ -46,6 +60,7 @@ class SystemAgentRuntimeConfig:
api_base_url: str
api_key: str
llm_config: SystemAgentLLMConfig
extra_context: str | None = None
@dataclass(frozen=True)
@@ -68,33 +83,83 @@ class AgentScopeRunner:
pipeline: PipelineLike,
run_input: RunAgentInput,
system_agent_mode: str,
memory_job_config: AutomationJobConfig | None = None,
) -> dict[str, Any]:
owner_id = UUID(user_context.id)
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
stage_agent_type = self._resolve_stage_agent_type(system_agent_mode)
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
stage_agent_types = [
self._parse_agent_type(stage_name=stage.stage_name)
for stage in pipeline_spec.stages
]
async with AsyncSessionLocal() as session:
stage_config = await self._load_stage_config(
session=session,
agent_type=stage_agent_type,
)
if stage_agent_types == [AgentType.ROUTER, AgentType.WORKER]:
router_config = await self._load_stage_config(
session=session,
agent_type=AgentType.ROUTER,
)
worker_config = await self._load_stage_config(
session=session,
agent_type=AgentType.WORKER,
)
worker_toolkit = self._build_stage_toolkit(
session=session,
owner_id=owner_id,
stage_config=worker_config,
)
router_output = await self._execute_router_step(
pipeline=pipeline,
run_input=run_input,
user_context=user_context,
context_messages=context_messages,
stage_config=router_config,
runtime_client_time=runtime_client_time,
)
worker_output = await self._execute_worker_step(
pipeline=pipeline,
run_input=run_input,
user_context=user_context,
router_output=router_output,
toolkit=worker_toolkit,
stage_config=worker_config,
runtime_client_time=runtime_client_time,
)
return {
"router": router_output.model_dump(mode="json", exclude_none=True),
"worker": worker_output.model_dump(mode="json", exclude_none=True),
}
if stage_agent_types[0] == AgentType.MEMORY:
if memory_job_config is None:
raise RuntimeError("memory job config is required")
stage_config = await self._build_memory_stage_config(
session=session,
memory_job_config=memory_job_config,
)
else:
stage_config = await self._load_stage_config(
session=session,
agent_type=stage_agent_types[0],
)
stage_toolkit = self._build_stage_toolkit(
session=session,
owner_id=owner_id,
stage_config=stage_config,
)
worker_output = await self._execute_worker_step(
stage_output = await self._execute_single_stage_step(
pipeline=pipeline,
run_input=run_input,
user_context=user_context,
context_messages=context_messages,
input_messages=context_messages,
toolkit=stage_toolkit,
stage_config=stage_config,
runtime_client_time=runtime_client_time,
)
return {
"worker": worker_output.model_dump(mode="json", exclude_none=True),
stage_config.agent_type.value: stage_output.model_dump(
mode="json", exclude_none=True
),
}
def _build_stage_toolkit(
@@ -113,11 +178,15 @@ class AgentScopeRunner:
)
@staticmethod
def _resolve_stage_agent_type(system_agent_mode: str) -> AgentType:
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
if mode == AgentType.MEMORY.value:
def _parse_agent_type(*, stage_name: str) -> AgentType:
normalized = stage_name.strip().lower()
if normalized == AgentType.ROUTER.value:
return AgentType.ROUTER
if normalized == AgentType.WORKER.value:
return AgentType.WORKER
if normalized == AgentType.MEMORY.value:
return AgentType.MEMORY
return AgentType.WORKER
raise ValueError(f"unsupported stage name: {stage_name}")
async def _load_stage_config(
self,
@@ -130,28 +199,60 @@ class AgentScopeRunner:
agent_type=agent_type,
)
async def _execute_worker_step(
async def _execute_router_step(
self,
*,
pipeline: PipelineLike,
run_input: RunAgentInput,
user_context: UserContext,
context_messages: list[Msg],
toolkit: Any,
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
) -> AgentOutput:
step_name = stage_config.agent_type.value
worker_output_model = AgentOutput
) -> RouterAgentOutput:
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name=step_name,
step_name=AgentType.ROUTER.value,
event_type="STEP_STARTED",
)
router_result = await self._run_router_stage(
user_context=user_context,
context_messages=context_messages,
stage_config=stage_config,
runtime_client_time=runtime_client_time,
)
router_output = RouterAgentOutput.model_validate(router_result.payload)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name=AgentType.ROUTER.value,
event_type="STEP_FINISHED",
)
return router_output
async def _execute_worker_step(
self,
*,
pipeline: PipelineLike,
run_input: RunAgentInput,
user_context: UserContext,
router_output: RouterAgentOutput,
toolkit: Any,
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
) -> WorkerAgentOutputLite:
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name=AgentType.WORKER.value,
event_type="STEP_STARTED",
)
worker_result = await self._run_worker_stage(
user_context=user_context,
context_messages=context_messages,
input_messages=self._build_worker_input_messages(
router_output=router_output
),
toolkit=toolkit,
run_input=run_input,
stage_config=stage_config,
@@ -163,11 +264,48 @@ class AgentScopeRunner:
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name=step_name,
step_name=AgentType.WORKER.value,
event_type="STEP_FINISHED",
)
return worker_output
async def _execute_single_stage_step(
self,
*,
pipeline: PipelineLike,
run_input: RunAgentInput,
user_context: UserContext,
input_messages: list[Msg],
toolkit: Any,
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
) -> AgentOutput:
step_name = stage_config.agent_type.value
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name=step_name,
event_type="STEP_STARTED",
)
stage_result = await self._run_worker_stage(
user_context=user_context,
input_messages=input_messages,
toolkit=toolkit,
run_input=run_input,
stage_config=stage_config,
worker_output_model=AgentOutput,
pipeline=pipeline,
runtime_client_time=runtime_client_time,
)
stage_output = AgentOutput.model_validate(stage_result.payload)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name=step_name,
event_type="STEP_FINISHED",
)
return stage_output
async def _load_system_agent_config(
self,
*,
@@ -193,6 +331,50 @@ class AgentScopeRunner:
api_base_url=factory.request_url,
api_key=self._resolve_provider_api_key(factory_name=factory.name),
llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
extra_context=None,
)
async def _build_memory_stage_config(
self,
*,
session: AsyncSession,
memory_job_config: AutomationJobConfig,
) -> SystemAgentRuntimeConfig:
stmt = (
select(Llm, LlmFactory)
.join(LlmFactory, Llm.factory_id == LlmFactory.id)
.where(Llm.model_code == memory_job_config.model_code)
)
row = (await session.execute(stmt)).one_or_none()
if row is None:
raise RuntimeError(
f"memory model not found: {memory_job_config.model_code}"
)
llm, factory = row
llm_config = SystemAgentLLMConfig(
temperature=0.7,
max_tokens=None,
timeout_seconds=30,
visibility_consumer_bit=18,
context_messages=ContextMessagesConfig(
mode=(
ContextBuildStrategy.DAY
if memory_job_config.context.window_mode.value == "day"
else ContextBuildStrategy.NUMBER
),
count=memory_job_config.context.window_count,
),
enabled_tools=memory_job_config.enabled_tools,
)
return SystemAgentRuntimeConfig(
agent_type=AgentType.MEMORY,
model_code=llm.model_code,
api_base_url=factory.request_url,
api_key=self._resolve_provider_api_key(factory_name=factory.name),
llm_config=llm_config,
extra_context=(
f"[Memory Input Template]\n{memory_job_config.input_template.strip()}"
),
)
@staticmethod
@@ -211,19 +393,63 @@ class AgentScopeRunner:
raise RuntimeError(f"provider api key missing for factory: {factory_name}")
return api_key
async def _run_worker_stage(
async def _run_router_stage(
self,
*,
user_context: UserContext,
context_messages: list[Msg],
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
) -> StageExecutionResult:
tracking_model = self._build_model(stage_config=stage_config)
response, payload = await finalize_json_response(
model=tracking_model,
formatter=OpenAIChatFormatter(),
base_messages=[
Msg(
"system",
build_system_prompt(
agent_type=AgentType.ROUTER,
llm_config=stage_config.llm_config,
user_context=user_context,
now_utc=datetime.now(timezone.utc),
runtime_client_time=runtime_client_time,
tools=None,
),
"system",
),
*context_messages,
],
output_model=RouterAgentOutput,
retries=0,
)
response_msg = Msg(
name="router",
role="assistant",
content=list(getattr(response, "content", [])),
metadata=payload,
)
return StageExecutionResult(
message=response_msg,
payload=payload,
response_metadata=self._litellm_service.build_usage_metadata(
model=stage_config.model_code,
usage_summary=tracking_model.usage_summary(),
),
)
async def _run_worker_stage(
self,
*,
user_context: UserContext,
input_messages: list[Msg],
toolkit: Any,
run_input: RunAgentInput,
stage_config: SystemAgentRuntimeConfig,
worker_output_model: type[AgentOutput],
worker_output_model: type[WorkerAgentOutputLite],
pipeline: PipelineLike,
runtime_client_time: ClientTimeContext | None,
) -> StageExecutionResult:
worker_input = list(context_messages)
tracking_model = self._build_model(stage_config=stage_config)
emitter = PipelineStageEmitter(
pipeline=pipeline,
@@ -241,6 +467,7 @@ class AgentScopeRunner:
user_context=user_context,
now_utc=datetime.now(timezone.utc),
runtime_client_time=runtime_client_time,
extra_context=stage_config.extra_context,
tools=None,
),
toolkit=toolkit,
@@ -248,7 +475,7 @@ class AgentScopeRunner:
emitter=emitter,
)
response_msg = await agent.reply_json(
worker_input, output_model=worker_output_model
input_messages, output_model=worker_output_model
)
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
response_metadata = self._litellm_service.build_usage_metadata(
@@ -265,6 +492,19 @@ class AgentScopeRunner:
response_metadata=response_metadata,
)
def _build_worker_input_messages(
self,
*,
router_output: RouterAgentOutput,
) -> list[Msg]:
return [
Msg(
name=AgentType.ROUTER.value,
role="user",
content=build_worker_contract_prompt(router_output=router_output),
)
]
def _build_model(
self, *, stage_config: SystemAgentRuntimeConfig
) -> TrackingChatModel:
@@ -272,8 +512,8 @@ class AgentScopeRunner:
"temperature": stage_config.llm_config.temperature,
"max_tokens": stage_config.llm_config.max_tokens,
"timeout": stage_config.llm_config.timeout_seconds,
"extra_body": {"enable_thinking": False},
}
generate_kwargs["extra_body"] = {"enable_thinking": False}
model = OpenAIChatModel(
model_name=stage_config.model_code,
+122 -7
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
import base64
import json
from datetime import timezone
from typing import Any, cast
from uuid import UUID
@@ -14,26 +15,49 @@ from core.agentscope.events import (
)
from core.agentscope.runtime.context_service import AgentContextService
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
from core.agentscope.schemas.agui_input import parse_run_input
from core.automation.scheduler import (
AutomationSchedulerService,
SqlAlchemyAutomationSchedulerRepository,
utc_now,
)
from core.auth.models import CurrentUser
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from models.automation_jobs import AutomationJob
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.automation.config import AutomationJobConfig
from schemas.messages.chat_message import (
AgentChatMessageMetadata,
extract_user_message_attachments,
)
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
from schemas.user import UserContext
from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service
from v1.agent.repository import AgentRepository
from v1.users.dependencies import get_user_service
from sqlalchemy import select
logger = get_logger("core.agentscope.runtime.tasks")
_MAX_CONTEXT_ATTACHMENTS = 3
class _BulkQueueAdapter:
async def enqueue(
self,
*,
command: dict[str, object],
dedup_key: str | None,
) -> str:
del dedup_key
result = await run_command_task_bulk.kiq(command)
return str(result.task_id)
def _serialize_tool_agent_output(
*,
metadata: AgentChatMessageMetadata | dict[str, object] | None,
@@ -79,13 +103,29 @@ async def _build_recent_context_messages(
*,
session: Any,
thread_id: str,
system_agent_mode: str,
context_mode: str,
memory_job_config: AutomationJobConfig | None = None,
) -> list[Msg]:
context_service = AgentContextService(repository=AgentRepository(session))
result = await context_service.load_context_messages(
thread_id=thread_id,
system_agent_mode=system_agent_mode,
)
if memory_job_config is not None:
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
if memory_job_config.context.window_mode.value == "day":
result = await context_service.load_by_day_window(
thread_id=thread_id,
day_count=memory_job_config.context.window_count,
visibility_mask=visibility_mask,
)
else:
result = await context_service.load_by_user_message_window(
thread_id=thread_id,
user_message_limit=memory_job_config.context.window_count,
visibility_mask=visibility_mask,
)
else:
result = await context_service.load_context_messages(
thread_id=thread_id,
system_agent_mode=context_mode,
)
if not result:
return []
@@ -166,11 +206,33 @@ async def _build_recent_context_messages(
return converted
async def _load_memory_job_config(
*,
session: Any,
owner_id: UUID,
automation_job_id: str,
) -> AutomationJobConfig:
try:
job_uuid = UUID(automation_job_id)
except ValueError as exc:
raise ValueError("automation_job_id is invalid") from exc
stmt = (
select(AutomationJob)
.where(AutomationJob.id == job_uuid)
.where(AutomationJob.owner_id == owner_id)
.where(AutomationJob.deleted_at.is_(None))
)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
raise ValueError("automation job not found")
return AutomationJobConfig.model_validate(row.config or {})
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
command_type = str(command.get("command", "run")).strip().lower()
raw_owner_id = command.get("owner_id")
run_input_raw = command.get("run_input")
system_agent_mode = str(command.get("system_agent_mode", "worker")).strip().lower()
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
raise ValueError("owner_id is required")
@@ -178,6 +240,15 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
raise ValueError("run_input is required")
run_input = parse_run_input(run_input_raw)
system_agent_mode = parse_forwarded_props_agent_type(
getattr(run_input, "forwarded_props", None)
)
raw_automation_job_id = command.get("automation_job_id")
if system_agent_mode == "memory" and (
not isinstance(raw_automation_job_id, str) or not raw_automation_job_id
):
raise ValueError("automation_job_id is required for memory mode")
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
thread_id = run_input.thread_id
run_id = run_input.run_id
owner_id = UUID(raw_owner_id)
@@ -189,6 +260,14 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
async with AsyncSessionLocal() as session:
user_context = await _build_user_context(owner_id=owner_id, session=session)
memory_job_config: AutomationJobConfig | None = None
if system_agent_mode == "memory":
assert isinstance(raw_automation_job_id, str)
memory_job_config = await _load_memory_job_config(
session=session,
owner_id=owner_id,
automation_job_id=raw_automation_job_id,
)
redis_client = await get_or_init_redis_client()
bus = RedisStreamBus(
@@ -211,7 +290,8 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
context_messages = await _build_recent_context_messages(
session=session,
thread_id=thread_id,
system_agent_mode=system_agent_mode,
context_mode=pipeline_spec.stages[0].context_policy.consumer_agent_type,
memory_job_config=memory_job_config,
)
await runtime.run(
@@ -219,6 +299,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
context_messages=context_messages,
user_context=user_context,
system_agent_mode=system_agent_mode,
memory_job_config=memory_job_config,
)
logger.info(
"agentscope runtime task completed",
@@ -233,6 +314,35 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
}
async def run_automation_scheduler_scan(
*,
limit: int | None = None,
) -> dict[str, int]:
now = utc_now()
safe_limit = (
max(int(limit), 1)
if isinstance(limit, int)
else int(config.automation_scheduler.batch_limit)
)
async with AsyncSessionLocal() as session:
repository = SqlAlchemyAutomationSchedulerRepository(session=session)
service = AutomationSchedulerService(
repository=repository,
queue=_BulkQueueAdapter(),
)
result = await service.scan_and_dispatch(now_utc=now, limit=safe_limit)
logger.info(
"automation scheduler scan completed",
scanned=result.scanned,
dispatched=result.dispatched,
now_utc=now.astimezone(timezone.utc).isoformat(),
)
return {
"scanned": int(result.scanned),
"dispatched": int(result.dispatched),
}
@default_broker.task(task_name="tasks.agentscope.run_command")
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
return await run_agentscope_task(command)
@@ -246,3 +356,8 @@ async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
return await run_agentscope_task(command)
@default_broker.task(task_name="tasks.automation.scan_due_jobs")
async def scan_due_automation_jobs_task(limit: int | None = None) -> dict[str, int]:
return await run_automation_scheduler_scan(limit=limit)
@@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable
from typing import Any
from core.agentscope.tools.tool_config import resolve_tool_names_by_groups
from core.agentscope.tools.tool_config import resolve_tool_function_names
from schemas.agent.system_agent import AgentType
ToolNameResolver = Callable[[Any], set[str] | None]
@@ -23,20 +23,19 @@ class ToolSelectionRegistry:
return resolver(stage_config)
def _default_group_resolver(stage_config: Any) -> set[str] | None:
raw_groups = getattr(stage_config.llm_config, "enabled_tool_groups", [])
groups = raw_groups if isinstance(raw_groups, list) else []
if not groups:
def _default_tool_resolver(stage_config: Any) -> set[str] | None:
enabled_tools = getattr(stage_config.llm_config, "enabled_tools", [])
if not enabled_tools:
return None
return resolve_tool_names_by_groups(set(groups))
return resolve_tool_function_names(set(enabled_tools))
TOOL_SELECTION_REGISTRY = ToolSelectionRegistry()
TOOL_SELECTION_REGISTRY.register(
agent_type=AgentType.WORKER,
resolver=_default_group_resolver,
resolver=_default_tool_resolver,
)
TOOL_SELECTION_REGISTRY.register(
agent_type=AgentType.MEMORY,
resolver=_default_group_resolver,
resolver=_default_tool_resolver,
)
@@ -11,7 +11,7 @@ from core.agentscope.tools.utils.calendar_domain import (
map_calendar_exception,
merge_schedule_metadata_for_update,
parse_iso_datetime,
resolve_share_target_email_map,
resolve_share_target_phone_map,
schedule_event_to_dict,
)
from core.agentscope.tools.utils.calendar_ui import (
@@ -580,11 +580,11 @@ async def calendar_share(
)
target_uuid = UUID(event_id)
email_map = resolve_share_target_email_map(
phone_map = resolve_share_target_phone_map(
[invitee.user_id for invitee in invitees]
)
if not email_map:
if not phone_map:
return calendar_error_output(
tool_name=tool_name,
tool_call_args=tool_call_args,
@@ -599,8 +599,8 @@ async def calendar_share(
normalized_user_id = str(UUID(invitee.user_id.strip()))
except ValueError:
continue
email = email_map.get(normalized_user_id)
if email is None:
phone = phone_map.get(normalized_user_id)
if phone is None:
continue
permission = {
"permission_view": invitee.permission_view,
@@ -608,15 +608,15 @@ async def calendar_share(
"permission_invite": invitee.permission_invite,
}
await service.share(
target_uuid, ScheduleItemShareRequest(email=email, **permission)
target_uuid, ScheduleItemShareRequest(phone=phone, **permission)
)
invited.append(email)
invited.append(phone)
if not invited:
return calendar_error_output(
tool_name=tool_name,
tool_call_args=tool_call_args,
code="NOT_FOUND",
message="邀请目标均无有效邮箱",
message="邀请目标均无有效手机号",
retryable=False,
)
@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from agentscope.tool import ToolResponse
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
from core.agentscope.tools.utils import (
find_auth_email_by_user_id,
find_auth_phone_by_user_id,
list_auth_users,
)
from core.agentscope.tools.utils.tool_response_builder import (
@@ -46,22 +46,22 @@ def _lookup_error_output(
async def _resolve_identity(
*,
session: AsyncSession,
user_email: str | None,
user_phone: str | None,
user_name: str | None,
) -> dict[str, Any]:
"""Resolve user identity by email or username."""
email = user_email.strip().lower() if isinstance(user_email, str) else ""
"""Resolve user identity by phone or username."""
phone = user_phone.strip() if isinstance(user_phone, str) else ""
name = user_name.strip() if isinstance(user_name, str) else ""
if bool(email) == bool(name):
if bool(phone) == bool(name):
raise HTTPException(
status_code=400,
detail="请提供 email 或 username 其中之一",
detail="请提供 phone 或 username 其中之一",
)
if email:
if phone:
auth_gateway = SupabaseAuthGateway()
user = await auth_gateway.get_user_by_email(email)
user = await auth_gateway.get_user_by_phone(phone)
user_id = UUID(user.id)
stmt = (
@@ -73,9 +73,9 @@ async def _resolve_identity(
return {
"userId": str(user_id),
"email": user.email,
"phone": user.phone,
"username": username,
"matchedBy": "email",
"matchedBy": "phone",
}
stmt = (
@@ -90,20 +90,20 @@ async def _resolve_identity(
raise HTTPException(status_code=404, detail="用户不存在")
users = list_auth_users()
email_value = find_auth_email_by_user_id(users=users, user_id=profile.id)
phone_value = find_auth_phone_by_user_id(users=users, user_id=profile.id)
return {
"userId": str(profile.id),
"email": email_value,
"phone": phone_value,
"username": profile.username,
"matchedBy": "username",
}
async def user_lookup(
user_email: Annotated[
user_phone: Annotated[
str | None,
Field(description="User email address to look up."),
Field(description="User phone to look up."),
] = None,
user_name: Annotated[
str | None,
@@ -112,16 +112,16 @@ async def user_lookup(
session: Any = None,
owner_id: Any = None,
) -> ToolResponse:
"""Look up user identity by email or username.
"""Look up user identity by phone or username.
Args:
user_email: User email address for lookup.
user_phone: User phone for lookup.
user_name: Username for lookup.
Returns:
ToolResponse with serialized ToolAgentOutput payload.
"""
tool_call_args = {"user_email": user_email, "user_name": user_name}
tool_call_args = {"user_phone": user_phone, "user_name": user_name}
if session is None or owner_id is None:
return _lookup_error_output(
@@ -134,17 +134,17 @@ async def user_lookup(
try:
resolved = await _resolve_identity(
session=cast(AsyncSession, session),
user_email=user_email,
user_phone=user_phone,
user_name=user_name,
)
username = str(resolved.get("username") or "")
email = str(resolved.get("email") or "")
phone = str(resolved.get("phone") or "")
user_id = str(resolved.get("userId") or "")
matched_by = str(resolved.get("matchedBy") or "")
summary = (
f"status=success matched_by={matched_by} user_id={user_id} "
f"username={username} has_email={str(bool(email)).lower()}"
f"username={username} has_phone={str(bool(phone)).lower()}"
)
return _dump_tool_output(
ToolAgentOutput(
@@ -6,7 +6,15 @@ from enum import Enum
class ToolGroup(str, Enum):
READ = "read"
WRITE = "write"
EXECUTE = "execute"
MEMORY = "memory"
class AgentTool(str, Enum):
CALENDAR_READ = "calendar.read"
CALENDAR_WRITE = "calendar.write"
CALENDAR_SHARE = "calendar.share"
USER_LOOKUP = "user.lookup"
@dataclass(frozen=True)
@@ -29,30 +37,51 @@ TOOL_CONFIGS: dict[str, ToolConfig] = {
),
"user_lookup": ToolConfig(
name="user_lookup",
group=ToolGroup.READ,
group=ToolGroup.MEMORY,
approval=ToolApprovalConfig(required=False),
),
"calendar_write": ToolConfig(
name="calendar_write",
group=ToolGroup.WRITE,
group=ToolGroup.EXECUTE,
approval=ToolApprovalConfig(required=False),
),
"calendar_share": ToolConfig(
name="calendar_share",
group=ToolGroup.WRITE,
group=ToolGroup.EXECUTE,
approval=ToolApprovalConfig(required=False),
),
}
AGENT_TOOL_TO_FUNCTION_NAME: dict[AgentTool, str] = {
AgentTool.CALENDAR_READ: "calendar_read",
AgentTool.CALENDAR_WRITE: "calendar_write",
AgentTool.CALENDAR_SHARE: "calendar_share",
AgentTool.USER_LOOKUP: "user_lookup",
}
def get_tool_config(tool_name: str) -> ToolConfig:
config = TOOL_CONFIGS.get(tool_name)
if config is None:
raise ValueError(f"unknown tool: {tool_name}")
return config
TOOL_NAME_ALIASES: dict[str, AgentTool] = {
AgentTool.CALENDAR_READ.value: AgentTool.CALENDAR_READ,
"calendar_read": AgentTool.CALENDAR_READ,
AgentTool.CALENDAR_WRITE.value: AgentTool.CALENDAR_WRITE,
"calendar_write": AgentTool.CALENDAR_WRITE,
AgentTool.CALENDAR_SHARE.value: AgentTool.CALENDAR_SHARE,
"calendar_share": AgentTool.CALENDAR_SHARE,
AgentTool.USER_LOOKUP.value: AgentTool.USER_LOOKUP,
"user_lookup": AgentTool.USER_LOOKUP,
}
def resolve_tool_names_by_groups(groups: set[ToolGroup]) -> set[str]:
if not groups:
return set()
return {name for name, config in TOOL_CONFIGS.items() if config.group in groups}
def parse_agent_tool(value: object) -> AgentTool:
if isinstance(value, AgentTool):
return value
raw_value = str(value or "").strip().lower()
if not raw_value:
raise ValueError("enabled tool value cannot be empty")
tool = TOOL_NAME_ALIASES.get(raw_value)
if tool is None:
raise ValueError(f"unknown enabled tool: {raw_value}")
return tool
def resolve_tool_function_names(tools: set[AgentTool]) -> set[str]:
return {AGENT_TOOL_TO_FUNCTION_NAME[tool] for tool in tools}
@@ -7,10 +7,15 @@ from core.agentscope.tools.tool_call_context import (
reset_current_tool_call_id,
set_current_tool_call_id,
)
from core.agentscope.tools.tool_config import (
AGENT_TOOL_TO_FUNCTION_NAME,
TOOL_CONFIGS,
ToolConfig,
parse_agent_tool,
)
from core.agentscope.tools.utils.tool_response_builder import (
build_error_response,
)
from core.agentscope.tools.tool_config import ToolConfig, TOOL_CONFIGS
def register_tool_middlewares(
@@ -59,6 +64,18 @@ def create_approval_middleware(
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
| None = None,
) -> Callable[..., AsyncGenerator[Any, None]]:
def _resolve_tool_config(*, tool_name: str) -> ToolConfig | None:
config = config_by_name.get(tool_name)
if config is not None:
return config
try:
normalized_tool_name = AGENT_TOOL_TO_FUNCTION_NAME[
parse_agent_tool(tool_name)
]
except ValueError:
return None
return config_by_name.get(normalized_tool_name)
def _resolve_tool_call_id(tool_call: dict[str, Any]) -> str:
raw_tool_call_id = tool_call.get("id")
if isinstance(raw_tool_call_id, str) and raw_tool_call_id.strip():
@@ -81,7 +98,7 @@ def create_approval_middleware(
yield response
return
config = config_by_name.get(tool_name)
config = _resolve_tool_config(tool_name=tool_name)
if config is None or not config.approval.required:
async for response in await next_handler(**kwargs):
yield response
@@ -134,15 +151,3 @@ def create_approval_middleware(
yield pending_response
return approval_middleware
def create_hitl_middleware(
*,
meta_by_name: dict[str, ToolConfig],
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
| None = None,
) -> Callable[..., AsyncGenerator[Any, None]]:
return create_approval_middleware(
config_by_name=meta_by_name,
approval_resolver=approval_resolver,
)
+20 -34
View File
@@ -13,8 +13,6 @@ from core.agentscope.tools.custom.calendar import (
from core.agentscope.tools.custom.user_lookup import user_lookup
from core.agentscope.tools.tool_config import (
TOOL_CONFIGS,
ToolGroup,
resolve_tool_names_by_groups,
)
from core.agentscope.tools.tool_middleware import register_tool_middlewares
from sqlalchemy.ext.asyncio import AsyncSession
@@ -28,46 +26,36 @@ TOOL_FUNCTIONS: dict[str, Any] = {
}
AGENT_TYPE_TO_GROUPS: dict[AgentType, set[ToolGroup]] = {
AgentType.WORKER: {ToolGroup.READ, ToolGroup.WRITE},
AgentType.MEMORY: {ToolGroup.READ, ToolGroup.WRITE},
AGENT_TYPE_TO_DEFAULT_TOOLS: dict[AgentType, set[str]] = {
AgentType.WORKER: {
"calendar_read",
"calendar_write",
"calendar_share",
"user_lookup",
},
AgentType.MEMORY: {"calendar_read", "user_lookup"},
}
def _resolve_enabled_tools(
*,
groups: set[ToolGroup] | None,
enabled_tool_names: set[str] | None,
) -> set[str]:
if enabled_tool_names is not None:
unknown = enabled_tool_names - set(TOOL_FUNCTIONS)
if unknown:
raise ValueError(f"unknown tools in enabled_tool_names: {sorted(unknown)}")
return set(enabled_tool_names)
if groups is None:
return set(TOOL_FUNCTIONS)
resolved = resolve_tool_names_by_groups(groups)
unknown = resolved - set(TOOL_FUNCTIONS)
def _validate_enabled_tool_names(enabled_tool_names: set[str]) -> set[str]:
unknown = enabled_tool_names - set(TOOL_FUNCTIONS)
if unknown:
raise ValueError(f"tool config contains unknown tools: {sorted(unknown)}")
return resolved
raise ValueError(f"unknown tools in enabled_tool_names: {sorted(unknown)}")
return enabled_tool_names
def build_toolkit(
*,
session: AsyncSession,
owner_id: UUID,
groups: set[ToolGroup] | None = None,
enabled_tool_names: set[str] | None = None,
enable_hitl: bool | None = None,
):
toolkit = Toolkit()
enabled_names = _resolve_enabled_tools(
groups=groups,
enabled_tool_names=enabled_tool_names,
)
if enabled_tool_names is None:
enabled_names = set(TOOL_FUNCTIONS)
else:
enabled_names = _validate_enabled_tool_names(set(enabled_tool_names))
preset_kwargs = cast(
dict[str, JSONSerializableObject],
@@ -100,15 +88,13 @@ def build_stage_toolkit(
enabled_tool_names: set[str] | None = None,
enable_hitl: bool | None = None,
):
groups = AGENT_TYPE_TO_GROUPS.get(agent_type)
if groups is None:
default_tools = AGENT_TYPE_TO_DEFAULT_TOOLS.get(agent_type)
if default_tools is None:
raise ValueError(f"unknown agent_type: {agent_type}")
stage_enabled_names = resolve_tool_names_by_groups(set(groups))
selected_names = (
stage_enabled_names
set(default_tools)
if enabled_tool_names is None
else stage_enabled_names | set(enabled_tool_names)
else _validate_enabled_tool_names(set(enabled_tool_names))
)
return build_toolkit(
@@ -1,9 +1,9 @@
from core.agentscope.tools.utils.auth_helpers import (
find_auth_email_by_user_id,
find_auth_phone_by_user_id,
list_auth_users,
)
__all__ = [
"list_auth_users",
"find_auth_email_by_user_id",
"find_auth_phone_by_user_id",
]
@@ -25,12 +25,12 @@ def list_auth_users() -> list[Any]:
return users
def find_auth_email_by_user_id(*, users: list[Any], user_id: UUID) -> str | None:
"""Find auth email by user id from fetched user list."""
def find_auth_phone_by_user_id(*, users: list[Any], user_id: UUID) -> str | None:
"""Find auth phone by user id from fetched user list."""
target = str(user_id)
for user in users:
if str(getattr(user, "id", "")) == target:
email = getattr(user, "email", None)
if isinstance(email, str) and email.strip():
return email.strip()
phone = getattr(user, "phone", None)
if isinstance(phone, str) and phone.strip():
return phone.strip()
return None
@@ -9,7 +9,7 @@ from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from core.agentscope.tools.utils.auth_helpers import (
find_auth_email_by_user_id,
find_auth_phone_by_user_id,
list_auth_users,
)
from core.auth.models import CurrentUser
@@ -125,7 +125,7 @@ def parse_iso_datetime(value: str | None) -> datetime | None:
return parsed.astimezone(timezone.utc)
def resolve_share_target_email_map(invitee_user_ids: list[str]) -> dict[str, str]:
def resolve_share_target_phone_map(invitee_user_ids: list[str]) -> dict[str, str]:
users = list_auth_users()
resolved: dict[str, str] = {}
for raw_user_id in invitee_user_ids:
@@ -138,7 +138,7 @@ def resolve_share_target_email_map(invitee_user_ids: list[str]) -> dict[str, str
user_uuid = UUID(normalized_user_id)
except ValueError:
continue
email = find_auth_email_by_user_id(users=users, user_id=user_uuid)
if email:
resolved[str(user_uuid)] = email.lower()
phone = find_auth_phone_by_user_id(users=users, user_id=user_uuid)
if phone:
resolved[str(user_uuid)] = phone
return resolved
+11
View File
@@ -0,0 +1,11 @@
from core.automation.scheduler import (
AutomationSchedulerService,
DispatchResult,
SqlAlchemyAutomationSchedulerRepository,
)
__all__ = [
"AutomationSchedulerService",
"DispatchResult",
"SqlAlchemyAutomationSchedulerRepository",
]
+247
View File
@@ -0,0 +1,247 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Protocol
from uuid import UUID, uuid4
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
from models.agent_chat_session import AgentChatSession, SessionType
from models.automation_jobs import AutomationJob, ScheduleType
from schemas.automation.config import AutomationJobConfig
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
logger = get_logger("core.automation.scheduler")
class QueueLike(Protocol):
async def enqueue(
self,
*,
command: dict[str, object],
dedup_key: str | None,
) -> str: ...
class AutomationSchedulerRepositoryLike(Protocol):
async def list_due_jobs(
self,
*,
now_utc: datetime,
limit: int,
) -> list[DueAutomationJob]: ...
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig: ...
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID: ...
async def mark_job_dispatched(
self,
*,
job_id: UUID,
next_run_at: datetime,
last_run_at: datetime,
) -> None: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
@dataclass(slots=True)
class DispatchResult:
scanned: int
dispatched: int
class AutomationSchedulerService:
def __init__(
self,
*,
repository: AutomationSchedulerRepositoryLike,
queue: QueueLike,
) -> None:
self._repository = repository
self._queue = queue
async def scan_and_dispatch(
self,
*,
now_utc: datetime,
limit: int,
) -> DispatchResult:
safe_limit = max(int(limit), 1)
due_jobs = await self._repository.list_due_jobs(
now_utc=now_utc, limit=safe_limit
)
dispatched = 0
for job in due_jobs:
try:
config = await self._repository.get_job_config(job_id=job.id)
thread_id = await self._repository.ensure_latest_chat_session(
owner_id=job.owner_id
)
command = self._build_dispatch_command(
job=job,
thread_id=thread_id,
input_text=config.input_template,
now_utc=now_utc,
)
await self._queue.enqueue(command=command, dedup_key=None)
await self._repository.mark_job_dispatched(
job_id=job.id,
next_run_at=_compute_next_run_at(
current_next_run_at=job.next_run_at,
now_utc=now_utc,
schedule_type=job.schedule_type,
),
last_run_at=now_utc,
)
await self._repository.commit()
dispatched += 1
except Exception as exc:
await self._repository.rollback()
logger.exception(
"automation job dispatch failed",
job_id=str(job.id),
owner_id=str(job.owner_id),
error=str(exc),
)
return DispatchResult(scanned=len(due_jobs), dispatched=dispatched)
def _build_dispatch_command(
self,
*,
job: DueAutomationJob,
thread_id: UUID,
input_text: str,
now_utc: datetime,
) -> dict[str, object]:
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
payload = SchedulerDispatchCommand(
owner_id=job.owner_id,
automation_job_id=job.id,
thread_id=thread_id,
run_id=run_id,
input_text=input_text.strip(),
)
return {
"command": "run",
"owner_id": str(payload.owner_id),
"automation_job_id": str(payload.automation_job_id),
"queue": "bulk",
"run_input": {
"threadId": str(payload.thread_id),
"runId": payload.run_id,
"state": {},
"messages": [
{
"id": str(uuid4()),
"role": "user",
"content": payload.input_text,
}
],
"tools": [],
"context": [],
"forwardedProps": {
"agent_type": "memory",
},
},
}
class SqlAlchemyAutomationSchedulerRepository:
def __init__(self, *, session: AsyncSession) -> None:
self._session = session
async def list_due_jobs(
self,
*,
now_utc: datetime,
limit: int,
) -> list[DueAutomationJob]:
stmt = (
select(AutomationJob)
.where(AutomationJob.deleted_at.is_(None))
.where(AutomationJob.status == "active")
.where(AutomationJob.next_run_at <= now_utc)
.order_by(AutomationJob.next_run_at.asc())
.limit(max(limit, 1))
)
rows = (await self._session.execute(stmt)).scalars().all()
return [
DueAutomationJob(
id=row.id,
owner_id=row.owner_id,
schedule_type=row.schedule_type,
timezone=row.timezone,
next_run_at=row.next_run_at,
)
for row in rows
]
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
stmt = select(AutomationJob.config).where(AutomationJob.id == job_id)
config_payload = (await self._session.execute(stmt)).scalar_one()
return AutomationJobConfig.model_validate(config_payload or {})
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
stmt = (
select(AgentChatSession.id)
.where(AgentChatSession.user_id == owner_id)
.where(AgentChatSession.deleted_at.is_(None))
.where(AgentChatSession.session_type == SessionType.CHAT)
.order_by(AgentChatSession.last_activity_at.desc())
.limit(1)
)
existing = (await self._session.execute(stmt)).scalar_one_or_none()
if existing is not None:
return existing
session = AgentChatSession(
id=uuid4(),
user_id=owner_id,
session_type=SessionType.CHAT,
)
self._session.add(session)
await self._session.flush()
return session.id
async def mark_job_dispatched(
self,
*,
job_id: UUID,
next_run_at: datetime,
last_run_at: datetime,
) -> None:
stmt = select(AutomationJob).where(AutomationJob.id == job_id)
row = (await self._session.execute(stmt)).scalar_one()
row.next_run_at = next_run_at
row.last_run_at = last_run_at
await self._session.flush()
async def commit(self) -> None:
await self._session.commit()
async def rollback(self) -> None:
await self._session.rollback()
def _compute_next_run_at(
*,
current_next_run_at: datetime,
now_utc: datetime,
schedule_type: ScheduleType,
) -> datetime:
delta = timedelta(days=1 if schedule_type == ScheduleType.DAILY else 7)
next_run_at = current_next_run_at
while next_run_at <= now_utc:
next_run_at = next_run_at + delta
return next_run_at
def utc_now() -> datetime:
return datetime.now(timezone.utc)
@@ -1,27 +1,30 @@
agents:
- agent_type: worker
llm_model_code: qwen3.5-35b-a3b
status: active
config:
temperature: 0.7
max_tokens: null
timeout_seconds: 30
context_messages:
mode: number
count: 20
enabled_tool_groups:
- read
- write
- agent_type: router
llm_model_code: qwen3.5-flash
status: active
config:
temperature: 0.7
max_tokens: null
timeout_seconds: 30
visibility_consumer_bit: 16
context_messages:
mode: day
count: 2
enabled_tools: []
- agent_type: memory
llm_model_code: qwen3.5-flash
status: active
config:
temperature: 0.7
max_tokens: null
timeout_seconds: 30
context_messages:
mode: day
count: 2
enabled_tool_groups:
- read
- agent_type: worker
llm_model_code: qwen3.5-flash
status: active
config:
temperature: 0.7
max_tokens: null
timeout_seconds: 30
visibility_consumer_bit: 17
context_messages:
mode: number
count: 20
enabled_tools:
- calendar.read
- calendar.write
- calendar.share
- user.lookup
+6
View File
@@ -5,6 +5,7 @@ import uuid
from enum import Enum
from sqlalchemy import (
BigInteger,
JSON,
Enum as SqlEnum,
ForeignKey,
@@ -59,6 +60,11 @@ class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
cost: Mapped[Decimal] = mapped_column(Numeric(12, 6), nullable=False, default=0)
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
visibility_mask: Mapped[int] = mapped_column(
BigInteger,
nullable=False,
default=0,
)
metadata_json: Mapped[dict[str, object] | None] = mapped_column(
"metadata", JSON().with_variant(JSONB, "postgresql"), nullable=True
)
+5 -4
View File
@@ -4,8 +4,8 @@ import uuid
from datetime import datetime
from enum import Enum
from sqlalchemy import DateTime, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy import DateTime, JSON, String
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
@@ -36,9 +36,10 @@ class AutomationJob(TimestampMixin, SoftDeleteMixin, Base):
String(255),
nullable=False,
)
prompt: Mapped[str] = mapped_column(
Text,
config: Mapped[dict[str, object]] = mapped_column(
JSON().with_variant(JSONB, "postgresql"),
nullable=False,
default=dict,
)
schedule_type: Mapped[ScheduleType] = mapped_column(
String(20),
@@ -0,0 +1,44 @@
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
class AgentConsumerBinding(BaseModel):
model_config = ConfigDict(extra="forbid")
agent_type: str = Field(..., min_length=1, max_length=64)
bit: int = Field(..., ge=16, le=63)
@field_validator("agent_type")
@classmethod
def _normalize_agent_type(cls, value: str) -> str:
normalized = value.strip().lower()
if not normalized:
raise ValueError("agent_type must not be empty")
return normalized
class ConsumerRegistry(BaseModel):
model_config = ConfigDict(extra="forbid")
bindings: list[AgentConsumerBinding] = Field(default_factory=list)
@model_validator(mode="after")
def _validate_unique_bindings(self) -> "ConsumerRegistry":
by_agent: set[str] = set()
by_bit: set[int] = set()
for item in self.bindings:
if item.agent_type in by_agent:
raise ValueError(f"duplicate agent_type binding: {item.agent_type}")
if item.bit in by_bit:
raise ValueError(f"duplicate visibility bit binding: {item.bit}")
by_agent.add(item.agent_type)
by_bit.add(item.bit)
return self
def resolve_agent_bit(self, *, agent_type: str) -> int:
target = agent_type.strip().lower()
for item in self.bindings:
if item.agent_type == target:
return item.bit
raise ValueError(f"agent visibility bit not configured: {target}")
@@ -0,0 +1,63 @@
from __future__ import annotations
from enum import Enum
from pydantic import BaseModel, ConfigDict, Field, field_validator
class ExecutorKind(str, Enum):
SINGLE_SHOT = "single_shot"
REACT = "react"
class ContextWindowMode(str, Enum):
DAY = "day"
NUMBER = "number"
class ContextPolicy(BaseModel):
model_config = ConfigDict(extra="forbid")
consumer_agent_type: str = Field(..., min_length=1, max_length=64)
window_mode: ContextWindowMode = ContextWindowMode.NUMBER
count: int = Field(default=20, ge=1, le=200)
@field_validator("consumer_agent_type")
@classmethod
def _normalize_consumer_agent_type(cls, value: str) -> str:
normalized = value.strip().lower()
if not normalized:
raise ValueError("consumer_agent_type must not be empty")
return normalized
class StageSpec(BaseModel):
model_config = ConfigDict(extra="forbid")
stage_name: str = Field(..., min_length=1, max_length=64)
executor_kind: ExecutorKind
default_visibility_mask: int = Field(..., ge=0, le=(1 << 63) - 1)
context_policy: ContextPolicy
@field_validator("stage_name")
@classmethod
def _normalize_stage_name(cls, value: str) -> str:
normalized = value.strip().lower()
if not normalized:
raise ValueError("stage_name must not be empty")
return normalized
class PipelineSpec(BaseModel):
model_config = ConfigDict(extra="forbid")
mode: str = Field(..., min_length=1, max_length=64)
stages: list[StageSpec] = Field(..., min_length=1)
@field_validator("mode")
@classmethod
def _normalize_mode(cls, value: str) -> str:
normalized = value.strip().lower()
if not normalized:
raise ValueError("mode must not be empty")
return normalized
+50
View File
@@ -0,0 +1,50 @@
from __future__ import annotations
from enum import IntEnum
from pydantic import BaseModel, ConfigDict, Field, field_validator
class SystemVisibilityBit(IntEnum):
UI_HISTORY = 0
UI_REALTIME = 1
class VisibilityMask(BaseModel):
model_config = ConfigDict(extra="forbid")
value: int = Field(..., ge=0, le=(1 << 63) - 1)
@classmethod
def from_bits(cls, *, bits: list[int]) -> "VisibilityMask":
mask = 0
for bit in bits:
validate_visibility_bit(bit=bit)
mask |= 1 << bit
return cls(value=mask)
def contains(self, *, bit: int) -> bool:
validate_visibility_bit(bit=bit)
return bool(self.value & (1 << bit))
class VisibilityBitRef(BaseModel):
model_config = ConfigDict(extra="forbid")
bit: int = Field(..., ge=0, le=63)
@field_validator("bit")
@classmethod
def _validate_bit(cls, value: int) -> int:
validate_visibility_bit(bit=value)
return value
def validate_visibility_bit(*, bit: int) -> None:
if bit < 0 or bit > 63:
raise ValueError("visibility bit must be in range [0, 63]")
def bit_mask(*, bit: int) -> int:
validate_visibility_bit(bit=bit)
return 1 << bit
@@ -0,0 +1,20 @@
from schemas.automation.config import (
AutomationAgentType,
AutomationContextSource,
AutomationContextWindowMode,
AutomationJobConfig,
AutomationMemoryContextConfig,
default_memory_job_config,
)
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
__all__ = [
"AutomationAgentType",
"AutomationContextSource",
"AutomationContextWindowMode",
"AutomationJobConfig",
"AutomationMemoryContextConfig",
"default_memory_job_config",
"DueAutomationJob",
"SchedulerDispatchCommand",
]
+62
View File
@@ -0,0 +1,62 @@
from __future__ import annotations
from enum import Enum
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.agentscope.tools.tool_config import AgentTool
class AutomationAgentType(str, Enum):
MEMORY = "memory"
class AutomationContextSource(str, Enum):
LATEST_CHAT = "latest_chat"
class AutomationContextWindowMode(str, Enum):
DAY = "day"
NUMBER = "number"
class AutomationMemoryContextConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
source: AutomationContextSource = AutomationContextSource.LATEST_CHAT
window_mode: AutomationContextWindowMode = AutomationContextWindowMode.DAY
window_count: int = Field(default=2, ge=1, le=200)
class AutomationJobConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
agent_type: AutomationAgentType = AutomationAgentType.MEMORY
model_code: str = Field(default="qwen3.5-flash", min_length=1, max_length=64)
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
input_template: str = Field(..., min_length=1, max_length=4000)
context: AutomationMemoryContextConfig = Field(
default_factory=AutomationMemoryContextConfig
)
@field_validator("model_code")
@classmethod
def _validate_model_code(cls, value: str) -> str:
normalized = value.strip()
if normalized != "qwen3.5-flash":
raise ValueError("model_code must be qwen3.5-flash")
return normalized
def default_memory_job_config() -> AutomationJobConfig:
return AutomationJobConfig(
agent_type=AutomationAgentType.MEMORY,
model_code="qwen3.5-flash",
enabled_tools=[AgentTool.CALENDAR_READ, AgentTool.USER_LOOKUP],
input_template="请基于最近聊天上下文生成一段可执行的记忆总结与建议。",
context=AutomationMemoryContextConfig(
source=AutomationContextSource.LATEST_CHAT,
window_mode=AutomationContextWindowMode.DAY,
window_count=2,
),
)
@@ -0,0 +1,28 @@
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from models.automation_jobs import ScheduleType
class DueAutomationJob(BaseModel):
model_config = ConfigDict(extra="forbid")
id: UUID
owner_id: UUID
schedule_type: ScheduleType
timezone: str = Field(..., min_length=1, max_length=50)
next_run_at: datetime
class SchedulerDispatchCommand(BaseModel):
model_config = ConfigDict(extra="forbid")
owner_id: UUID
automation_job_id: UUID
thread_id: UUID
run_id: str = Field(..., min_length=1, max_length=128)
input_text: str = Field(..., min_length=1, max_length=4000)
@@ -189,3 +189,56 @@ def test_step_started_internal_event_keeps_step_name() -> None:
assert result["type"] == "STEP_STARTED"
assert result["stepName"] == "worker"
def test_run_error_prefers_top_level_message_and_code() -> None:
internal = {
"type": "run.error",
"threadId": "thread-1",
"runId": "run-1",
"message": "runtime failed",
"code": "RUNTIME_ERROR",
"data": {
"message": "nested message",
"code": "NESTED_ERROR",
},
}
result = to_agui_wire_event(internal)
assert result["type"] == "RUN_ERROR"
assert result["message"] == "runtime failed"
assert result["code"] == "RUNTIME_ERROR"
def test_run_error_falls_back_to_data_when_top_level_missing() -> None:
internal = {
"type": "run.error",
"threadId": "thread-1",
"runId": "run-1",
"data": {
"message": "nested message",
"code": "NESTED_ERROR",
},
}
result = to_agui_wire_event(internal)
assert result["type"] == "RUN_ERROR"
assert result["message"] == "nested message"
assert result["code"] == "NESTED_ERROR"
def test_run_error_uses_default_message_when_payload_invalid() -> None:
internal = {
"type": "run.error",
"threadId": "thread-1",
"runId": "run-1",
"data": "invalid",
}
result = to_agui_wire_event(internal)
assert result["type"] == "RUN_ERROR"
assert result["message"] == "Unknown error"
assert "code" not in result
@@ -59,6 +59,16 @@ def _patch_repositories(
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
async def _fake_stage_bit_map(self, *, session: object) -> dict[str, int]:
del self, session
return {"router": 16, "worker": 17, "memory": 18}
monkeypatch.setattr(
store_module.SqlAlchemyEventStore,
"_load_stage_visibility_bit_map",
_fake_stage_bit_map,
)
@pytest.mark.asyncio
async def test_store_persists_worker_output_with_answer_as_content(
@@ -103,6 +113,7 @@ async def test_store_persists_worker_output_with_answer_as_content(
assert metadata["agent_output"]["answer"] == "worker-answer"
assert metadata["agent_output"]["ui_hints"]["intent"] == "message"
assert append_kwargs["cost"] == Decimal("0.123")
assert append_kwargs["visibility_mask"] == ((1 << 0) | (1 << 17))
assert captured["message_delta"] == 1
assert captured["token_delta"] == 8
@@ -141,3 +152,4 @@ async def test_store_persists_tool_output_with_summary_as_content(
metadata["tool_agent_output"]["result"]
== "status=success batch=1 success=1 failed=0 ids=[event-1]"
)
assert append_kwargs["visibility_mask"] == (1 << 0)
@@ -1,12 +1,13 @@
from __future__ import annotations
from uuid import UUID
from uuid import uuid4
import pytest
from core.agentscope.persistence.user_context_cache import UserContextCache
from core.agentscope.schemas.user_context import (
UserAgentContext,
from schemas.user.context import (
UserContext,
parse_profile_settings,
)
@@ -45,15 +46,15 @@ class _FakeRedis:
self.set_store.pop(key, None)
return len(keys)
async def sadd(self, key: str, *values: str) -> int:
bucket = self.set_store.setdefault(key, set())
async def sadd(self, name: str, *values: str) -> int:
bucket = self.set_store.setdefault(name, set())
before = len(bucket)
for value in values:
bucket.add(value)
return len(bucket) - before
async def smembers(self, key: str) -> set[str]:
return set(self.set_store.get(key, set()))
async def smembers(self, name: str) -> set[str]:
return set(self.set_store.get(name, set()))
class _BrokenRedis:
@@ -77,18 +78,18 @@ class _BrokenRedis:
del keys
raise RuntimeError("redis down")
async def sadd(self, key: str, *values: str) -> int:
del key, values
async def sadd(self, name: str, *values: str) -> int:
del name, values
raise RuntimeError("redis down")
async def smembers(self, key: str) -> set[str]:
del key
async def smembers(self, name: str) -> set[str]:
del name
raise RuntimeError("redis down")
def _build_context() -> UserAgentContext:
return UserAgentContext(
user_id=uuid4(),
def _build_context() -> UserContext:
return UserContext(
id=str(uuid4()),
username="demo-user",
bio="demo bio",
settings=parse_profile_settings({"preferences": {"ai_language": "en-US"}}),
@@ -111,11 +112,11 @@ async def test_user_context_cache_set_and_get_hit() -> None:
loaded = await cache.get(session_id=session_id)
assert loaded is not None
assert loaded.user_id == context.user_id
assert loaded.id == context.id
assert loaded.username == "demo-user"
assert redis.expire_calls == [
(f"agent:user-context:{session_id}", 600),
(f"agent:user-context:sessions:{context.user_id}", 600),
(f"agent:user-context:sessions:{context.id}", 600),
]
assert redis.hincrby_calls == [
(f"agent:user-context:{session_id}", "turns_used", 1)
@@ -138,12 +139,14 @@ async def test_user_context_cache_invalidate_user_deletes_all_sessions() -> None
await cache.set(session_id=s1, context=context)
await cache.set(session_id=s2, context=context)
deleted = await cache.invalidate_user(user_id=context.user_id)
deleted = await cache.invalidate_user(user_id=UUID(context.id))
assert deleted == 2
assert f"agent:user-context:{s1}" in redis.delete_calls
assert f"agent:user-context:{s2}" in redis.delete_calls
assert f"agent:user-context:sessions:{context.user_id}" in redis.delete_calls
assert any(
key.startswith("agent:user-context:sessions:") for key in redis.delete_calls
)
@pytest.mark.asyncio
@@ -0,0 +1,28 @@
from __future__ import annotations
import pytest
from core.agentscope.runtime.consumer_registry import build_consumer_registry
def test_build_consumer_registry_from_system_agent_configs() -> None:
registry = build_consumer_registry(
system_agent_configs={
"router": {"config": {"visibility_consumer_bit": 16}},
"worker": {"config": {"visibility_consumer_bit": 17}},
"memory": {"config": {"visibility_consumer_bit": 18}},
}
)
assert registry.resolve_agent_bit(agent_type="router") == 16
assert registry.resolve_agent_bit(agent_type="worker") == 17
def test_build_consumer_registry_rejects_duplicate_bit() -> None:
with pytest.raises(ValueError, match="duplicate visibility bit"):
build_consumer_registry(
system_agent_configs={
"router": {"config": {"visibility_consumer_bit": 16}},
"worker": {"config": {"visibility_consumer_bit": 16}},
}
)
@@ -28,7 +28,7 @@ def _user_context() -> UserContext:
return UserContext(
id="00000000-0000-0000-0000-000000000001",
username="alice",
email="alice@example.com",
phone="+8613900000000",
settings=parse_profile_settings(None),
)
@@ -42,7 +42,7 @@ def _run_input() -> RunAgentInput:
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
}
)
@@ -58,6 +58,7 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
run_input=_run_input(),
context_messages=[],
user_context=_user_context(),
system_agent_mode="worker",
)
assert result["worker"]["answer"] == "done"
@@ -0,0 +1,24 @@
from __future__ import annotations
import pytest
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
def test_build_default_pipeline_spec_worker_has_two_stages() -> None:
spec = build_default_pipeline_spec(mode="worker")
assert spec.mode == "worker"
assert [item.stage_name for item in spec.stages] == ["router", "worker"]
def test_build_default_pipeline_spec_memory_has_single_stage() -> None:
spec = build_default_pipeline_spec(mode="memory")
assert spec.mode == "memory"
assert [item.stage_name for item in spec.stages] == ["memory"]
def test_build_default_pipeline_spec_rejects_unknown_mode() -> None:
with pytest.raises(ValueError, match="unsupported pipeline mode"):
build_default_pipeline_spec(mode="planner")
@@ -3,8 +3,23 @@ from __future__ import annotations
import pytest
from ag_ui.core import RunAgentInput
import core.agentscope.runtime.runner as runner_module
from core.agentscope.runtime.runner import AgentScopeRunner
from schemas.automation.config import default_memory_job_config
from schemas.agent.runtime_models import (
ExecutionMode,
NormalizedTaskInput,
ResultType,
ResultTyping,
RouterAgentOutput,
RouterUiDecision,
TaskType,
TaskTyping,
UiMode,
WorkerAgentOutputLite,
)
from schemas.agent.system_agent import AgentType
from schemas.user import UserContext, parse_profile_settings
def _run_input() -> RunAgentInput:
@@ -16,19 +31,51 @@ def _run_input() -> RunAgentInput:
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
}
)
def test_resolve_stage_agent_type_defaults_to_worker() -> None:
assert AgentScopeRunner._resolve_stage_agent_type("") == AgentType.WORKER
assert AgentScopeRunner._resolve_stage_agent_type("worker") == AgentType.WORKER
assert AgentScopeRunner._resolve_stage_agent_type("unknown") == AgentType.WORKER
def _user_context() -> UserContext:
return UserContext(
id="00000000-0000-0000-0000-000000000001",
username="alice",
phone="+8613900000000",
settings=parse_profile_settings(None),
)
def test_resolve_stage_agent_type_supports_memory() -> None:
assert AgentScopeRunner._resolve_stage_agent_type("memory") == AgentType.MEMORY
def test_parse_agent_type_supports_known_stages() -> None:
assert AgentScopeRunner._parse_agent_type(stage_name="router") == AgentType.ROUTER
assert AgentScopeRunner._parse_agent_type(stage_name="worker") == AgentType.WORKER
assert AgentScopeRunner._parse_agent_type(stage_name="memory") == AgentType.MEMORY
def test_parse_agent_type_rejects_unknown_stage() -> None:
with pytest.raises(ValueError, match="unsupported stage name"):
AgentScopeRunner._parse_agent_type(stage_name="planner")
def test_build_worker_input_messages_only_contains_router_contract() -> None:
runner = AgentScopeRunner()
router_output = RouterAgentOutput(
normalized_task_input=NormalizedTaskInput(user_text="安排明天会议"),
key_entities=[],
constraints=[],
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
execution_mode=ExecutionMode.TOOL_ASSISTED,
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
ui=RouterUiDecision(
ui_mode=UiMode.NONE,
ui_decision_reason="单一执行任务,文本输出足够",
),
)
input_messages = runner._build_worker_input_messages(router_output=router_output)
assert len(input_messages) == 1
assert input_messages[0].role == "user"
assert "[RouterAgentOutput]" in str(input_messages[0].content)
@pytest.mark.asyncio
@@ -43,11 +90,12 @@ async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
"tools": [],
"context": [],
"forwardedProps": {
"agent_type": "worker",
"client_time": {
"device_timezone": "America/Los_Angeles",
"client_now_iso": "2026-03-16T09:12:33-07:00",
"client_epoch_ms": 1773658353000,
}
},
},
}
)
@@ -55,3 +103,146 @@ async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
resolved = runner._resolve_runtime_client_time(run_input=run_input)
assert resolved is not None
assert resolved.device_timezone == "America/Los_Angeles"
@pytest.mark.asyncio
async def test_execute_worker_mode_runs_router_then_worker(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
del session_id, event
return "1-0"
class _FakeSessionCtx:
async def __aenter__(self) -> object:
return object()
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
del exc_type, exc, tb
runner = AgentScopeRunner()
load_calls: list[AgentType] = []
async def _fake_load_stage_config(*, session: object, agent_type: AgentType):
del session
load_calls.append(agent_type)
return runner_module.SystemAgentRuntimeConfig(
agent_type=agent_type,
model_code="demo",
api_base_url="https://example.com",
api_key="test",
llm_config=runner_module.SystemAgentLLMConfig(),
)
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
del kwargs
return RouterAgentOutput(
normalized_task_input=NormalizedTaskInput(user_text="安排会议"),
key_entities=[],
constraints=[],
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
execution_mode=ExecutionMode.TOOL_ASSISTED,
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
ui=RouterUiDecision(
ui_mode=UiMode.NONE,
ui_decision_reason="单任务",
),
)
async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite:
del kwargs
return WorkerAgentOutputLite(answer="ok")
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
monkeypatch.setattr(runner, "_build_stage_toolkit", lambda **kwargs: object())
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
result = await runner.execute(
user_context=_user_context(),
context_messages=[],
pipeline=_FakePipeline(),
run_input=_run_input(),
system_agent_mode="worker",
)
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
assert result["worker"]["answer"] == "ok"
@pytest.mark.asyncio
async def test_execute_memory_mode_requires_memory_job_config() -> None:
runner = AgentScopeRunner()
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
del session_id, event
return "1-0"
with pytest.raises(RuntimeError, match="memory job config is required"):
await runner.execute(
user_context=_user_context(),
context_messages=[],
pipeline=_FakePipeline(),
run_input=_run_input(),
system_agent_mode="memory",
memory_job_config=None,
)
@pytest.mark.asyncio
async def test_execute_memory_mode_uses_memory_job_config(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
del session_id, event
return "1-0"
class _FakeSessionCtx:
async def __aenter__(self) -> object:
return object()
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
del exc_type, exc, tb
runner = AgentScopeRunner()
async def _fake_build_memory_stage_config(**kwargs: object):
del kwargs
return runner_module.SystemAgentRuntimeConfig(
agent_type=AgentType.MEMORY,
model_code="qwen3.5-flash",
api_base_url="https://example.com",
api_key="test",
llm_config=runner_module.SystemAgentLLMConfig(),
)
async def _fake_execute_single_stage_step(**kwargs: object):
del kwargs
return runner_module.AgentOutput(answer="memory")
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
monkeypatch.setattr(
runner, "_build_memory_stage_config", _fake_build_memory_stage_config
)
monkeypatch.setattr(runner, "_build_stage_toolkit", lambda **kwargs: object())
monkeypatch.setattr(
runner,
"_execute_single_stage_step",
_fake_execute_single_stage_step,
)
result = await runner.execute(
user_context=_user_context(),
context_messages=[],
pipeline=_FakePipeline(),
run_input=_run_input(),
system_agent_mode="memory",
memory_job_config=default_memory_job_config(),
)
assert result["memory"]["answer"] == "memory"
@@ -18,7 +18,7 @@ def _run_input_payload() -> dict[str, Any]:
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
}
@@ -35,7 +35,7 @@ async def _fake_user_context(**kwargs: object) -> UserContext:
return UserContext(
id=str(uuid4()),
username="alice",
email="alice@example.com",
phone="+8613900000000",
settings=parse_profile_settings(None),
)
@@ -177,17 +177,53 @@ async def test_run_agentscope_task_rejects_invalid_command_type() -> None:
)
@pytest.mark.asyncio
async def test_run_agentscope_task_requires_forwarded_props_agent_type() -> None:
payload = _run_input_payload()
payload["forwardedProps"] = {}
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
await tasks_module.run_agentscope_task(
{
"command": "run",
"owner_id": str(uuid4()),
"run_input": payload,
}
)
@pytest.mark.asyncio
async def test_run_agentscope_task_memory_mode_requires_automation_job_id() -> None:
payload = _run_input_payload()
payload["forwardedProps"] = {"agent_type": "memory"}
with pytest.raises(
ValueError, match="automation_job_id is required for memory mode"
):
await tasks_module.run_agentscope_task(
{
"command": "run",
"owner_id": str(uuid4()),
"run_input": payload,
}
)
@pytest.mark.asyncio
async def test_build_recent_context_messages_includes_all_user_attachments(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakeAgentService:
async def load_agent_input_messages(
class _FakeContextService:
def __init__(self, *, repository: object) -> None:
del repository
async def load_context_messages(
self,
*,
thread_id: str,
system_agent_mode: str,
) -> dict[str, object] | None:
del thread_id
del thread_id, system_agent_mode
return {
"messages": [
{
@@ -215,14 +251,13 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
return f"{bucket}:{path}".encode("utf-8")
monkeypatch.setattr(
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
)
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
monkeypatch.setattr(tasks_module, "supabase_service", _FakeSupabase())
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_mode="worker",
)
assert len(messages) == 1
@@ -238,13 +273,17 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
async def test_build_recent_context_messages_uses_tool_metadata_output(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakeAgentService:
async def load_agent_input_messages(
class _FakeContextService:
def __init__(self, *, repository: object) -> None:
del repository
async def load_context_messages(
self,
*,
thread_id: str,
system_agent_mode: str,
) -> dict[str, object] | None:
del thread_id
del thread_id, system_agent_mode
return {
"messages": [
{
@@ -268,13 +307,12 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
]
}
monkeypatch.setattr(
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
)
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_mode="worker",
)
assert len(messages) == 1
@@ -290,13 +328,17 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
async def test_build_recent_context_messages_skips_tool_without_metadata_output(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakeAgentService:
async def load_agent_input_messages(
class _FakeContextService:
def __init__(self, *, repository: object) -> None:
del repository
async def load_context_messages(
self,
*,
thread_id: str,
system_agent_mode: str,
) -> dict[str, object] | None:
del thread_id
del thread_id, system_agent_mode
return {
"messages": [
{
@@ -307,13 +349,44 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
]
}
monkeypatch.setattr(
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
)
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_mode="worker",
)
assert messages == []
@pytest.mark.asyncio
async def test_build_recent_context_messages_passes_context_mode_through(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured_mode: dict[str, str | None] = {"mode": None}
class _FakeContextService:
def __init__(self, *, repository: object) -> None:
del repository
async def load_context_messages(
self,
*,
thread_id: str,
system_agent_mode: str,
) -> dict[str, object] | None:
del thread_id
captured_mode["mode"] = system_agent_mode
return None
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_mode="worker",
)
assert messages == []
assert captured_mode["mode"] == "worker"
@@ -1,99 +1,8 @@
from __future__ import annotations
import pytest
from pydantic import ValidationError
from core.agentscope import schemas as exported_schemas
from core.agentscope.schemas.agent_runtime import (
AcceptedTaskResponse,
AgUiWireEvent,
HistorySnapshot,
HistorySnapshotResponse,
InternalRuntimeEvent,
RunCommand,
pytest.skip(
"legacy agent_runtime schemas removed; covered by agui_input tests",
allow_module_level=True,
)
def test_run_command_alias_roundtrip() -> None:
payload = {
"threadId": "thread-001",
"runId": "run-001",
"state": {"cursor": 1},
"messages": [{"role": "user", "content": "hi"}],
"tools": [{"name": "calendar.lookup"}],
"context": {"locale": "zh-CN"},
"forwardedProps": {"traceId": "trace-1"},
}
command = RunCommand.model_validate(payload)
assert command.thread_id == "thread-001"
assert command.run_id == "run-001"
assert command.forwarded_props == {"traceId": "trace-1"}
dumped = command.model_dump(mode="json", by_alias=True)
assert dumped["threadId"] == "thread-001"
assert dumped["runId"] == "run-001"
assert dumped["forwardedProps"] == {"traceId": "trace-1"}
def test_history_snapshot_response_shape() -> None:
response = HistorySnapshotResponse(
threadId="thread-123",
snapshot=HistorySnapshot(
threadId="thread-123",
day="2026-03-11",
hasMore=False,
messages=[{"id": "msg-1"}],
),
)
dumped = response.model_dump(mode="json", by_alias=True, exclude_none=True)
assert dumped["type"] == "STATE_SNAPSHOT"
assert dumped["threadId"] == "thread-123"
assert dumped["snapshot"]["scope"] == "history_day"
assert dumped["snapshot"]["hasMore"] is False
assert dumped["snapshot"]["messages"] == [{"id": "msg-1"}]
def test_runtime_event_validation_basics() -> None:
internal = InternalRuntimeEvent(type="RUN_STARTED", data={"step": 1})
assert internal.type == "RUN_STARTED"
assert internal.model_dump(mode="json", by_alias=True)["data"] == {"step": 1}
wire = AgUiWireEvent(type="TEXT_MESSAGE_CONTENT", payload={"delta": "hello"})
dumped = wire.model_dump(mode="json", by_alias=True, exclude_none=True)
assert dumped["type"] == "TEXT_MESSAGE_CONTENT"
assert dumped["payload"] == {"delta": "hello"}
with pytest.raises(ValidationError):
InternalRuntimeEvent.model_validate({"threadId": "t-1", "data": {}})
with pytest.raises(ValidationError):
AgUiWireEvent.model_validate({"payload": {"delta": "hello"}})
def test_schemas_exports_include_task_and_history_models() -> None:
assert exported_schemas.AcceptedTaskResponse is AcceptedTaskResponse
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
def test_run_command_accepts_agui_context_list_and_parent_run_id() -> None:
payload = {
"threadId": "thread-xyz",
"runId": "run-xyz",
"state": {},
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
"parentRunId": None,
}
command = RunCommand.model_validate(payload)
dumped = command.model_dump(mode="json", by_alias=True)
assert dumped["context"] == []
assert "parentRunId" in dumped
@@ -20,7 +20,7 @@ def _base_payload() -> dict[str, object]:
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
}
@@ -149,7 +149,7 @@ def test_parse_run_input_accepts_snake_case_aliases() -> None:
],
"tools": [],
"context": [],
"forwarded_props": {},
"forwarded_props": {"agent_type": "worker"},
}
run_input = parse_run_input(payload)
@@ -162,11 +162,12 @@ def test_parse_run_input_accepts_snake_case_aliases() -> None:
def test_parse_run_input_accepts_client_time_forwarded_props() -> None:
payload = _base_payload()
payload["forwardedProps"] = {
"agent_type": "worker",
"client_time": {
"device_timezone": "America/Los_Angeles",
"client_now_iso": "2026-03-16T09:12:33-07:00",
"client_epoch_ms": 1773658353000,
}
},
}
run_input = parse_run_input(payload)
@@ -177,11 +178,12 @@ def test_parse_run_input_accepts_client_time_forwarded_props() -> None:
def test_parse_run_input_rejects_invalid_client_time_timezone() -> None:
payload = _base_payload()
payload["forwardedProps"] = {
"agent_type": "worker",
"client_time": {
"device_timezone": "Mars/OlympusMons",
"client_now_iso": "2026-03-16T09:12:33-07:00",
"client_epoch_ms": 1773658353000,
}
},
}
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
@@ -191,11 +193,12 @@ def test_parse_run_input_rejects_invalid_client_time_timezone() -> None:
def test_parse_run_input_rejects_invalid_client_time_now_iso() -> None:
payload = _base_payload()
payload["forwardedProps"] = {
"agent_type": "worker",
"client_time": {
"device_timezone": "America/Los_Angeles",
"client_now_iso": "2026-03-16 09:12:33",
"client_epoch_ms": 1773658353000,
}
},
}
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
@@ -205,11 +208,12 @@ def test_parse_run_input_rejects_invalid_client_time_now_iso() -> None:
def test_parse_run_input_rejects_invalid_client_time_epoch_type() -> None:
payload = _base_payload()
payload["forwardedProps"] = {
"agent_type": "worker",
"client_time": {
"device_timezone": "America/Los_Angeles",
"client_now_iso": "2026-03-16T09:12:33-07:00",
"client_epoch_ms": "1773658353000",
}
},
}
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
@@ -219,6 +223,7 @@ def test_parse_run_input_rejects_invalid_client_time_epoch_type() -> None:
def test_parse_run_input_rejects_unknown_forwarded_props_key() -> None:
payload = _base_payload()
payload["forwardedProps"] = {
"agent_type": "worker",
"client_time": {
"device_timezone": "America/Los_Angeles",
"client_now_iso": "2026-03-16T09:12:33-07:00",
@@ -229,3 +234,17 @@ def test_parse_run_input_rejects_unknown_forwarded_props_key() -> None:
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
parse_run_input(payload)
def test_parse_run_input_rejects_missing_forwarded_props_agent_type() -> None:
payload = _base_payload()
payload["forwardedProps"] = {
"client_time": {
"device_timezone": "America/Los_Angeles",
"client_now_iso": "2026-03-16T09:12:33-07:00",
"client_epoch_ms": 1773658353000,
}
}
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
parse_run_input(payload)
@@ -11,7 +11,7 @@ def test_build_agent_prompt_for_worker_contains_runtime_config() -> None:
{
"temperature": 0.2,
"context_messages": {"mode": "number", "count": 20},
"enabled_tool_groups": ["read", "write"],
"enabled_tools": ["calendar.read", "calendar.write"],
}
),
)
@@ -20,7 +20,7 @@ def test_build_agent_prompt_for_worker_contains_runtime_config() -> None:
assert "- type: worker" in prompt
assert "context_messages.mode=number" in prompt
assert "context_messages.count=20" in prompt
assert "enabled_tool_groups=read,write" in prompt
assert "enabled_tools=calendar.read,calendar.write" in prompt
def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
@@ -29,7 +29,7 @@ def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
llm_config=SystemAgentLLMConfig.model_validate(
{
"context_messages": {"mode": "day", "count": 2},
"enabled_tool_groups": ["read"],
"enabled_tools": ["user.lookup"],
}
),
)
@@ -38,4 +38,4 @@ def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
assert "[Memory Agent]" in prompt
assert "context_messages.mode=day" in prompt
assert "context_messages.count=2" in prompt
assert "enabled_tool_groups=read" in prompt
assert "enabled_tools=user.lookup" in prompt
@@ -1,11 +1,12 @@
from __future__ import annotations
import json
from typing import Any, AsyncGenerator
import pytest
from core.agentscope.tools.hitl_middleware import create_hitl_middleware
from core.agentscope.tools.tool_meta import TOOL_META, ToolMeta
from core.agentscope.tools.tool_config import ToolApprovalConfig, ToolConfig, ToolGroup
from core.agentscope.tools.tool_middleware import create_approval_middleware
async def _next_handler(**kwargs: Any) -> AsyncGenerator[dict[str, object], None]:
@@ -15,9 +16,31 @@ async def _next_handler(**kwargs: Any) -> AsyncGenerator[dict[str, object], None
return _generator()
def _extract_error_payload(chunk: object) -> dict[str, Any]:
content = getattr(chunk, "content", None)
if not isinstance(content, list) or not content:
return {}
first_block = content[0]
text = getattr(first_block, "text", None)
if not isinstance(text, str) and isinstance(first_block, dict):
raw_text = first_block.get("text")
text = raw_text if isinstance(raw_text, str) else None
if not isinstance(text, str):
return {}
return json.loads(text)
@pytest.mark.asyncio
async def test_hitl_middleware_default_write_does_not_require_approval() -> None:
middleware = create_hitl_middleware(meta_by_name=TOOL_META)
middleware = create_approval_middleware(
config_by_name={
"calendar_write": ToolConfig(
name="calendar_write",
group=ToolGroup.EXECUTE,
approval=ToolApprovalConfig(required=False),
)
}
)
responses = []
async for chunk in middleware(
@@ -30,36 +53,39 @@ async def test_hitl_middleware_default_write_does_not_require_approval() -> None
@pytest.mark.asyncio
async def test_hitl_middleware_pending_when_tool_requires_approval(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = create_hitl_middleware(
meta_by_name={
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
async def test_hitl_middleware_pending_when_tool_requires_approval() -> None:
middleware = create_approval_middleware(
config_by_name={
"calendar_write": ToolConfig(
name="calendar_write",
group=ToolGroup.EXECUTE,
approval=ToolApprovalConfig(required=True),
)
}
)
monkeypatch.setattr(
"core.agentscope.tools.hitl_middleware.build_tool_response",
lambda payload: payload,
)
responses = []
async for chunk in middleware(
{"tool_call": {"name": "calendar.write", "input": {"operation": "create"}}},
{"tool_call": {"name": "calendar_write", "input": {"operation": "create"}}},
_next_handler,
):
responses.append(chunk)
assert responses[0]["data"]["status"] == "pending"
payload = _extract_error_payload(responses[0])
assert payload["error"]["code"] == "TOOL_PENDING_APPROVAL"
@pytest.mark.asyncio
async def test_hitl_middleware_passes_when_write_approved() -> None:
middleware = create_hitl_middleware(
meta_by_name={
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
middleware = create_approval_middleware(
config_by_name={
"calendar_write": ToolConfig(
name="calendar_write",
group=ToolGroup.EXECUTE,
approval=ToolApprovalConfig(required=True),
)
},
approval_resolver=lambda _name, _args: "approved",
approval_resolver=lambda _name, _args, _config: "approved",
)
responses = []
@@ -69,6 +95,7 @@ async def test_hitl_middleware_passes_when_write_approved() -> None:
"name": "calendar.write",
"input": {
"operation": "create",
"_hitl": {"approval": "required"},
},
}
},
@@ -82,25 +109,24 @@ async def test_hitl_middleware_passes_when_write_approved() -> None:
@pytest.mark.asyncio
async def test_hitl_middleware_rejected_short_circuits(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = create_hitl_middleware(
meta_by_name={
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
async def test_hitl_middleware_rejected_short_circuits() -> None:
middleware = create_approval_middleware(
config_by_name={
"calendar_write": ToolConfig(
name="calendar_write",
group=ToolGroup.EXECUTE,
approval=ToolApprovalConfig(required=True),
)
},
approval_resolver=lambda _name, _args: "rejected",
)
monkeypatch.setattr(
"core.agentscope.tools.hitl_middleware.build_tool_response",
lambda payload: payload,
approval_resolver=lambda _name, _args, _config: "rejected",
)
responses = []
async for chunk in middleware(
{"tool_call": {"name": "calendar.write", "input": {"operation": "create"}}},
{"tool_call": {"name": "calendar_write", "input": {"operation": "create"}}},
_next_handler,
):
responses.append(chunk)
assert responses[0]["data"]["status"] == "rejected"
payload = _extract_error_payload(responses[0])
assert payload["error"]["code"] == "TOOL_REJECTED"
@@ -16,7 +16,7 @@ def _build_user_context(*, timezone_name: str = "Asia/Shanghai") -> UserContext:
return UserContext(
id=str(uuid4()),
username="alice",
email="alice@example.com",
phone="+8613900000000",
bio="focus on calendars",
settings=parse_profile_settings(
{
@@ -7,7 +7,9 @@ from core.agentscope.tools.toolkit import build_stage_toolkit
from schemas.agent.system_agent import AgentType
def test_build_stage_toolkit_filters_requested_tools_by_agent_type(monkeypatch) -> None:
def test_build_stage_toolkit_uses_explicit_enabled_tools_as_final_set(
monkeypatch,
) -> None:
captured: dict[str, object] = {}
def _fake_build_toolkit(**kwargs):
@@ -22,7 +24,29 @@ def test_build_stage_toolkit_filters_requested_tools_by_agent_type(monkeypatch)
agent_type=AgentType.WORKER,
session=cast(Any, object()),
owner_id=uuid4(),
enabled_tool_names={"calendar_read", "calendar_write", "user_lookup"},
enabled_tool_names={"calendar_read", "user_lookup"},
)
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
def test_build_stage_toolkit_uses_memory_defaults_without_explicit_tools(
monkeypatch,
) -> None:
captured: dict[str, object] = {}
def _fake_build_toolkit(**kwargs):
captured.update(kwargs)
return object()
monkeypatch.setattr(
"core.agentscope.tools.toolkit.build_toolkit", _fake_build_toolkit
)
build_stage_toolkit(
agent_type=AgentType.MEMORY,
session=cast(Any, object()),
owner_id=uuid4(),
)
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
@@ -0,0 +1,134 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from core.automation.scheduler import (
AutomationSchedulerService,
_compute_next_run_at,
)
from models.automation_jobs import ScheduleType
from schemas.automation.config import AutomationJobConfig
from schemas.automation.scheduler import DueAutomationJob
class _FakeRepository:
def __init__(self, jobs: list[DueAutomationJob]) -> None:
self.jobs = jobs
self.marked: list[tuple[UUID, datetime, datetime]] = []
self.commits = 0
self.rollbacks = 0
async def list_due_jobs(
self, *, now_utc: datetime, limit: int
) -> list[DueAutomationJob]:
del now_utc
return self.jobs[:limit]
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
del job_id
return AutomationJobConfig.model_validate(
{
"agent_type": "memory",
"model_code": "qwen3.5-flash",
"enabled_tools": ["calendar.read", "user.lookup"],
"input_template": "auto input",
"context": {
"source": "latest_chat",
"window_mode": "day",
"window_count": 2,
},
}
)
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
return owner_id
async def mark_job_dispatched(
self,
*,
job_id: UUID,
next_run_at: datetime,
last_run_at: datetime,
) -> None:
self.marked.append((job_id, next_run_at, last_run_at))
async def commit(self) -> None:
self.commits += 1
async def rollback(self) -> None:
self.rollbacks += 1
class _FakeQueue:
def __init__(self) -> None:
self.commands: list[dict[str, object]] = []
async def enqueue(
self,
*,
command: dict[str, object],
dedup_key: str | None,
) -> str:
del dedup_key
self.commands.append(command)
return "task-1"
@pytest.mark.asyncio
async def test_scan_and_dispatch_enqueues_memory_run_command() -> None:
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
owner_id = uuid4()
job_id = uuid4()
repo = _FakeRepository(
jobs=[
DueAutomationJob(
id=job_id,
owner_id=owner_id,
schedule_type=ScheduleType.DAILY,
timezone="UTC",
next_run_at=now - timedelta(minutes=1),
)
]
)
queue = _FakeQueue()
service = AutomationSchedulerService(repository=repo, queue=queue)
result = await service.scan_and_dispatch(now_utc=now, limit=10)
assert result.scanned == 1
assert result.dispatched == 1
assert len(queue.commands) == 1
run_input = queue.commands[0]["run_input"]
assert isinstance(run_input, dict)
assert run_input["forwardedProps"] == {"agent_type": "memory"}
assert queue.commands[0]["automation_job_id"] == str(job_id)
assert repo.commits == 1
def test_compute_next_run_at_daily() -> None:
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
current = datetime(2026, 3, 19, 11, 0, tzinfo=timezone.utc)
computed = _compute_next_run_at(
current_next_run_at=current,
now_utc=now,
schedule_type=ScheduleType.DAILY,
)
assert computed == datetime(2026, 3, 20, 11, 0, tzinfo=timezone.utc)
def test_compute_next_run_at_weekly() -> None:
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
current = datetime(2026, 3, 10, 11, 0, tzinfo=timezone.utc)
computed = _compute_next_run_at(
current_next_run_at=current,
now_utc=now,
schedule_type=ScheduleType.WEEKLY,
)
assert computed == datetime(2026, 3, 24, 11, 0, tzinfo=timezone.utc)
@@ -0,0 +1,18 @@
from __future__ import annotations
from pathlib import Path
def test_memory_automation_job_trigger_exists_in_0004_migration() -> None:
migration = (
Path(__file__).resolve().parents[3]
/ "alembic"
/ "versions"
/ "20260319_0004_automation_job_config_for_memory.py"
)
content = migration.read_text(encoding="utf-8")
assert "INSERT INTO public.automation_jobs" in content
assert "'agent_type', 'memory'" in content
assert "ux_automation_jobs_owner_memory_active" in content
assert "input_template" in content
@@ -0,0 +1,37 @@
from __future__ import annotations
import pytest
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
from schemas.agent.pipeline_spec import (
ContextPolicy,
ExecutorKind,
PipelineSpec,
StageSpec,
)
def test_consumer_registry_rejects_duplicate_bits() -> None:
with pytest.raises(ValueError, match="duplicate visibility bit"):
ConsumerRegistry(
bindings=[
AgentConsumerBinding(agent_type="router", bit=16),
AgentConsumerBinding(agent_type="worker", bit=16),
]
)
def test_pipeline_spec_requires_non_empty_stages() -> None:
with pytest.raises(ValueError, match="at least 1 item"):
PipelineSpec(mode="worker", stages=[])
def test_stage_spec_normalizes_stage_name() -> None:
spec = StageSpec(
stage_name=" Worker ",
executor_kind=ExecutorKind.REACT,
default_visibility_mask=1,
context_policy=ContextPolicy(consumer_agent_type="worker", count=20),
)
assert spec.stage_name == "worker"
@@ -0,0 +1,31 @@
from __future__ import annotations
import pytest
from schemas.agent.system_agent import SystemAgentLLMConfig
def test_system_agent_llm_config_normalizes_enabled_tools_aliases() -> None:
config = SystemAgentLLMConfig.model_validate(
{
"enabled_tools": [
"calendar.write",
"calendar_write",
"user.lookup",
]
}
)
assert [tool.value for tool in config.enabled_tools] == [
"calendar.write",
"user.lookup",
]
def test_system_agent_llm_config_rejects_unknown_enabled_tool() -> None:
with pytest.raises(ValueError, match="unknown enabled tool"):
SystemAgentLLMConfig.model_validate(
{
"enabled_tools": ["calendar.remove"],
}
)
@@ -0,0 +1,23 @@
from __future__ import annotations
import pytest
from schemas.agent.visibility import VisibilityMask, bit_mask
def test_visibility_mask_from_bits_and_contains() -> None:
mask = VisibilityMask.from_bits(bits=[0, 16, 18])
assert mask.contains(bit=0) is True
assert mask.contains(bit=16) is True
assert mask.contains(bit=17) is False
def test_visibility_mask_rejects_out_of_range_bit() -> None:
with pytest.raises(ValueError, match="range"):
VisibilityMask.from_bits(bits=[64])
def test_bit_mask_builds_single_bit_integer() -> None:
assert bit_mask(bit=0) == 1
assert bit_mask(bit=16) == (1 << 16)
@@ -0,0 +1,30 @@
from __future__ import annotations
import pytest
from schemas.automation.config import AutomationJobConfig, default_memory_job_config
def test_default_memory_job_config_has_expected_defaults() -> None:
config = default_memory_job_config()
assert config.agent_type.value == "memory"
assert config.model_code == "qwen3.5-flash"
assert config.context.source.value == "latest_chat"
def test_automation_job_config_rejects_non_flash_model() -> None:
with pytest.raises(ValueError, match="model_code must be qwen3.5-flash"):
AutomationJobConfig.model_validate(
{
"agent_type": "memory",
"model_code": "qwen-plus",
"enabled_tools": ["calendar.read"],
"input_template": "x",
"context": {
"source": "latest_chat",
"window_mode": "day",
"window_count": 2,
},
}
)