feat(agentscope): add memory system and automation job support
- Add consumer_registry and pipeline_registry for runtime orchestration - Add Visibility schema for message filtering - Add PipelineSpec for agent pipeline configuration - Add automation job models and configuration - Remove memory_prompt.py (consolidated into memory system) - Update runtime components: context_loader, context_service, orchestrator, runner, tasks - Update toolkit: tool_config, tool_middleware, custom tools (calendar, user_lookup) - Add auth_helpers and calendar_domain utilities - Add system_agents.yaml configuration
This commit is contained in:
@@ -10,8 +10,6 @@ from ag_ui.core import (
|
||||
RunErrorEvent,
|
||||
StepStartedEvent,
|
||||
StepFinishedEvent,
|
||||
TextMessageEndEvent,
|
||||
ToolCallResultEvent,
|
||||
)
|
||||
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
|
||||
from schemas.agent.ui_hints import UiHintsPayload
|
||||
@@ -94,9 +92,20 @@ def _build_run_finished(event: dict[str, Any]) -> RunFinishedEvent:
|
||||
|
||||
def _build_run_error(event: dict[str, Any]) -> RunErrorEvent:
|
||||
data = event.get("data", {})
|
||||
top_level_message = event.get("message")
|
||||
message = top_level_message if isinstance(top_level_message, str) else ""
|
||||
top_level_code = event.get("code")
|
||||
code = top_level_code if isinstance(top_level_code, str) else None
|
||||
if (not message or code is None) and isinstance(data, dict):
|
||||
data_message = data.get("message")
|
||||
if not message and isinstance(data_message, str):
|
||||
message = data_message
|
||||
data_code = data.get("code")
|
||||
if code is None and isinstance(data_code, str):
|
||||
code = data_code
|
||||
return RunErrorEvent(
|
||||
message=data.get("message", "Unknown error"),
|
||||
code=data.get("code"),
|
||||
message=message or "Unknown error",
|
||||
code=code,
|
||||
)
|
||||
|
||||
|
||||
@@ -120,34 +129,12 @@ def _build_step_finished(event: dict[str, Any]) -> StepFinishedEvent:
|
||||
)
|
||||
|
||||
|
||||
def _build_text_end(event: dict[str, Any]) -> TextMessageEndEvent:
|
||||
data = event.get("data", {})
|
||||
return TextMessageEndEvent(
|
||||
message_id=data.get("messageId", ""),
|
||||
)
|
||||
|
||||
|
||||
def _build_tool_result(event: dict[str, Any]) -> ToolCallResultEvent:
|
||||
data = event.get("data", {})
|
||||
content = data.get("result")
|
||||
if not isinstance(content, str):
|
||||
content = ""
|
||||
return ToolCallResultEvent(
|
||||
message_id=data.get("messageId", ""),
|
||||
tool_call_id=data.get("toolCallId", ""),
|
||||
content=content,
|
||||
role="tool",
|
||||
)
|
||||
|
||||
|
||||
_BUILDER_MAP: dict[str, Any] = {
|
||||
"run.started": _build_run_started,
|
||||
"run.finished": _build_run_finished,
|
||||
"run.error": _build_run_error,
|
||||
"step.start": _build_step_started,
|
||||
"step.finish": _build_step_finished,
|
||||
"text.end": _build_text_end,
|
||||
"tool.result": _build_tool_result,
|
||||
}
|
||||
|
||||
|
||||
@@ -208,6 +195,8 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
|
||||
payload["runId"] = run_id
|
||||
if isinstance(data, dict):
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
if internal_type == "run.error":
|
||||
reserved = {*reserved, "message", "code"}
|
||||
payload.update({k: v for k, v in data.items() if k not in reserved})
|
||||
return payload
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ class MessageRepository:
|
||||
output_tokens: int = 0,
|
||||
cost: Decimal = Decimal("0"),
|
||||
latency_ms: int | None = None,
|
||||
visibility_mask: int = 0,
|
||||
) -> AgentChatMessage:
|
||||
message = AgentChatMessage(
|
||||
session_id=session_id,
|
||||
@@ -42,6 +43,7 @@ class MessageRepository:
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=max(int(visibility_mask), 0),
|
||||
)
|
||||
self._session.add(message)
|
||||
await self._session.flush()
|
||||
|
||||
@@ -8,9 +8,12 @@ from core.agentscope.events.persistence import MessageRepository, SessionReposit
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.runtime_models import AgentOutput, ToolAgentOutput
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
class EventStore(Protocol):
|
||||
@@ -45,6 +48,9 @@ class SqlAlchemyEventStore:
|
||||
async with self._session_factory() as session:
|
||||
session_repo = SessionRepository(session)
|
||||
message_repo = MessageRepository(session)
|
||||
stage_visibility_bit_map = await self._load_stage_visibility_bit_map(
|
||||
session=session
|
||||
)
|
||||
chat_session = await session_repo.get_session(session_id=session_id)
|
||||
if chat_session is None:
|
||||
return
|
||||
@@ -77,6 +83,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
await self._persist_tool_call_result(
|
||||
@@ -85,6 +92,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
@@ -97,6 +105,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> None:
|
||||
message_id_raw = self._event_value(event, "messageId")
|
||||
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
|
||||
@@ -188,6 +197,10 @@ class SqlAlchemyEventStore:
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -213,6 +226,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> None:
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else None
|
||||
@@ -256,6 +270,10 @@ class SqlAlchemyEventStore:
|
||||
content=content,
|
||||
tool_name=tool_output.tool_name,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -279,6 +297,44 @@ class SqlAlchemyEventStore:
|
||||
return AgentChatMessageRole.TOOL
|
||||
return AgentChatMessageRole.ASSISTANT
|
||||
|
||||
def _resolve_stage_visibility_mask(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> int:
|
||||
base = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
raw_stage = self._event_value(event, "stage")
|
||||
if not isinstance(raw_stage, str):
|
||||
return base
|
||||
normalized_stage = raw_stage.strip().lower()
|
||||
bit = stage_visibility_bit_map.get(normalized_stage)
|
||||
if bit is None and normalized_stage == AgentType.MEMORY.value:
|
||||
bit = 18
|
||||
if bit is None:
|
||||
return base
|
||||
return base | bit_mask(bit=bit)
|
||||
|
||||
async def _load_stage_visibility_bit_map(
|
||||
self,
|
||||
*,
|
||||
session: Any,
|
||||
) -> dict[str, int]:
|
||||
stmt = select(SystemAgents.agent_type, SystemAgents.config).where(
|
||||
SystemAgents.agent_type.in_(
|
||||
[AgentType.ROUTER.value, AgentType.WORKER.value, AgentType.MEMORY.value]
|
||||
)
|
||||
)
|
||||
rows = (await session.execute(stmt)).all()
|
||||
bit_map: dict[str, int] = {}
|
||||
for agent_type, raw_config in rows:
|
||||
if not isinstance(agent_type, str):
|
||||
continue
|
||||
config_payload = raw_config if isinstance(raw_config, dict) else {}
|
||||
llm_config = SystemAgentLLMConfig.model_validate(config_payload)
|
||||
bit_map[agent_type.strip().lower()] = llm_config.visibility_consumer_bit
|
||||
return bit_map
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from schemas.agent.runtime_models import ResultType, RouterAgentOutput, TaskType
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
|
||||
|
||||
@@ -14,6 +17,24 @@ def _wrap_section(section: str, content: str) -> str:
|
||||
return f"{start}\n{body}\n{end}" if body else f"{start}\n{end}"
|
||||
|
||||
|
||||
def _enum_values(enum_cls: Any) -> str:
|
||||
return ", ".join(item.value for item in enum_cls)
|
||||
|
||||
|
||||
def _config_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
if llm_config is None:
|
||||
return []
|
||||
context_mode = llm_config.context_messages.mode.value
|
||||
context_count = llm_config.context_messages.count
|
||||
enabled_tools = [tool.value for tool in llm_config.enabled_tools]
|
||||
return [
|
||||
"[Runtime Config]",
|
||||
f"- context_messages.mode={context_mode}",
|
||||
f"- context_messages.count={context_count}",
|
||||
f"- enabled_tools={','.join(enabled_tools) if enabled_tools else 'default'}",
|
||||
]
|
||||
|
||||
|
||||
PromptRuleBuilder = Callable[[SystemAgentLLMConfig | None], list[str]]
|
||||
|
||||
|
||||
@@ -36,36 +57,41 @@ class AgentPromptRegistry:
|
||||
return builder(llm_config)
|
||||
|
||||
|
||||
def _config_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
if llm_config is None:
|
||||
return []
|
||||
context_mode = llm_config.context_messages.mode.value
|
||||
context_count = llm_config.context_messages.count
|
||||
tool_groups = [group.value for group in llm_config.enabled_tool_groups]
|
||||
def _router_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
return [
|
||||
"[Runtime Config]",
|
||||
f"- context_messages.mode={context_mode}",
|
||||
f"- context_messages.count={context_count}",
|
||||
f"- enabled_tool_groups={','.join(tool_groups) if tool_groups else 'default'}",
|
||||
"[Router Agent]",
|
||||
"- Read the latest user input and produce a routing contract for downstream execution.",
|
||||
"- Return exactly one RouterAgentOutput JSON object.",
|
||||
"[Responsibilities]",
|
||||
"- Router only: extract intent and route strategy; never answer user directly.",
|
||||
"- Preserve intent in normalized_task_input.user_text; keep wording concise and faithful.",
|
||||
"- Fill multimodal_summary only when image/attachment changes execution decisions.",
|
||||
"- Return key_entities and constraints that are execution-relevant; low confidence -> omit rather than guess.",
|
||||
"- Set execution_mode by complexity: onestep / tool_assisted / multistep.",
|
||||
"- Set result_typing.primary to the most suitable response shape; use clarification_request only when required info is missing.",
|
||||
"- Set ui.ui_mode and ui.ui_decision_reason based on whether structured UI improves actionability.",
|
||||
f"- task_typing.primary must use one TaskType enum: {_enum_values(TaskType)}.",
|
||||
f"- task_typing.secondary max 3 enums: {_enum_values(TaskType)}.",
|
||||
f"- result_typing.primary must use one ResultType enum: {_enum_values(ResultType)}.",
|
||||
f"- result_typing.secondary max 3 enums: {_enum_values(ResultType)}.",
|
||||
*_config_rules(llm_config),
|
||||
]
|
||||
|
||||
|
||||
def _worker_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
return [
|
||||
"[Worker Agent]",
|
||||
"- Process user context directly, identify intent, then execute or answer with available evidence.",
|
||||
"- Return exactly one agent output JSON object matching the runtime-injected schema.",
|
||||
"- Execute or answer against the routed objective and available evidence.",
|
||||
"- Return exactly one worker output JSON object matching the runtime-injected schema.",
|
||||
"[Responsibilities]",
|
||||
"- Handle user request directly from conversation context.",
|
||||
"- Decide intent first, then choose direct answer, tool call, clarification, or refusal.",
|
||||
"- Prefer a direct answer when no tool result is required.",
|
||||
"- Call tools only when tool results are necessary to produce a correct answer.",
|
||||
"- Infer deterministic required tool arguments from context, tool schema, and runtime signals.",
|
||||
"- Worker only: execute routed objective without changing router intent.",
|
||||
"- Treat router output as objective/constraints contract, not as a fully-materialized tool-args payload.",
|
||||
"- Infer deterministic required tool arguments from contract fields, tool schema, and runtime context.",
|
||||
"- Ask minimal clarification only when required arguments cannot be inferred safely.",
|
||||
"- If request is unsafe or disallowed, return safe refusal with actionable alternative.",
|
||||
"- Ground every claim in evidence and tool outputs; never fabricate execution state.",
|
||||
"- Ground every claim in available evidence and tool results; never fabricate execution state.",
|
||||
"- Keep status/result_type/answer/key_points/suggested_actions/error internally consistent.",
|
||||
"[Schema Guidance]",
|
||||
"- The output schema is injected at runtime; follow it exactly.",
|
||||
"- The worker output schema is injected at runtime; follow it exactly.",
|
||||
"- Do not add fields that are not present in the injected schema.",
|
||||
*_config_rules(llm_config),
|
||||
]
|
||||
@@ -88,7 +114,28 @@ def _memory_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
]
|
||||
|
||||
|
||||
def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
|
||||
contract_json = json.dumps(
|
||||
router_output.model_dump(mode="json", exclude_none=True),
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return "\n".join(
|
||||
[
|
||||
"[Worker Contract]",
|
||||
"- Keep routed objective unchanged.",
|
||||
"- Use normalized_task_input as objective text.",
|
||||
"- Use multimodal_summary/key_entities/constraints as execution evidence.",
|
||||
"- Infer deterministic missing required tool args from evidence + tool schema.",
|
||||
"- Ask clarification only when safe inference is impossible.",
|
||||
"[RouterAgentOutput]",
|
||||
contract_json,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
AGENT_PROMPT_REGISTRY = AgentPromptRegistry()
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.ROUTER, builder=_router_rules)
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.WORKER, builder=_worker_rules)
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.MEMORY, builder=_memory_rules)
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
|
||||
|
||||
|
||||
def build_consumer_registry(
|
||||
*,
|
||||
system_agent_configs: dict[str, dict[str, object]],
|
||||
) -> ConsumerRegistry:
|
||||
bindings: list[AgentConsumerBinding] = []
|
||||
for agent_type, payload in system_agent_configs.items():
|
||||
config_obj = payload.get("config") if isinstance(payload, dict) else None
|
||||
if not isinstance(config_obj, dict):
|
||||
raise ValueError(f"invalid system agent config: {agent_type}")
|
||||
raw_bit = config_obj.get("visibility_consumer_bit")
|
||||
if not isinstance(raw_bit, int):
|
||||
raise ValueError(f"visibility_consumer_bit missing for agent: {agent_type}")
|
||||
bindings.append(AgentConsumerBinding(agent_type=agent_type, bit=raw_bit))
|
||||
return ConsumerRegistry(bindings=bindings)
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
|
||||
from schemas.agent.system_agent import ContextBuildStrategy
|
||||
|
||||
ContextLoader = Callable[[Any, str, int], Awaitable[dict[str, object] | None]]
|
||||
ContextLoader = Callable[[Any, str, int, int], Awaitable[dict[str, object] | None]]
|
||||
|
||||
|
||||
class ContextLoaderRegistry:
|
||||
@@ -23,20 +23,28 @@ class ContextLoaderRegistry:
|
||||
|
||||
|
||||
async def _load_number(
|
||||
service: Any, thread_id: str, count: int
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
async def _load_day(
|
||||
service: Any, thread_id: str, count: int
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,17 +5,27 @@ from typing import Protocol
|
||||
|
||||
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.visibility import bit_mask
|
||||
|
||||
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
|
||||
_DEFAULT_ROUTER_CONTEXT_DAY_COUNT = 20
|
||||
|
||||
|
||||
class ContextRepositoryLike(Protocol):
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_recent_messages_by_user_window(
|
||||
self, *, session_id: str, user_message_limit: int
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
@@ -41,20 +51,36 @@ class AgentContextService:
|
||||
if isinstance(raw_config, dict):
|
||||
raw_llm_config = raw_config
|
||||
|
||||
if mode == "router" and not raw_llm_config:
|
||||
raw_llm_config = {
|
||||
"context_messages": {
|
||||
"mode": "day",
|
||||
"count": _DEFAULT_ROUTER_CONTEXT_DAY_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
normalized_config = self._normalize_system_agent_config(raw_llm_config)
|
||||
context_config = normalized_config.context_messages
|
||||
visibility_mask = bit_mask(bit=normalized_config.visibility_consumer_bit)
|
||||
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
|
||||
return await context_loader(self, thread_id, context_config.count)
|
||||
return await context_loader(
|
||||
self,
|
||||
thread_id,
|
||||
context_config.count,
|
||||
visibility_mask,
|
||||
)
|
||||
|
||||
async def load_by_user_message_window(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
messages = await self._repository.get_recent_messages_by_user_window(
|
||||
session_id=thread_id,
|
||||
user_message_limit=max(int(user_message_limit), 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if not messages:
|
||||
return None
|
||||
@@ -65,6 +91,7 @@ class AgentContextService:
|
||||
*,
|
||||
thread_id: str,
|
||||
day_count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
messages: list[dict[str, object]] = []
|
||||
before: date | None = None
|
||||
@@ -72,6 +99,7 @@ class AgentContextService:
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if not day_payload:
|
||||
break
|
||||
@@ -95,7 +123,7 @@ class AgentContextService:
|
||||
"mode": "number",
|
||||
"count": _DEFAULT_CONTEXT_WINDOW_USER_MESSAGES,
|
||||
},
|
||||
"enabled_tool_groups": [],
|
||||
"enabled_tools": [],
|
||||
}
|
||||
if not raw_config:
|
||||
return SystemAgentLLMConfig.model_validate(default_payload)
|
||||
|
||||
@@ -6,6 +6,7 @@ from ag_ui.core.types import RunAgentInput
|
||||
from agentscope.message import Msg
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from core.logging import get_logger
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.user import UserContext
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.orchestrator")
|
||||
@@ -24,6 +25,7 @@ class RunnerLike(Protocol):
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@@ -47,6 +49,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
context_messages: list[Msg],
|
||||
user_context: UserContext,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
@@ -66,6 +69,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
pipeline=self._pipeline,
|
||||
run_input=run_input,
|
||||
system_agent_mode=system_agent_mode,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.pipeline_spec import (
|
||||
ContextPolicy,
|
||||
ContextWindowMode,
|
||||
ExecutorKind,
|
||||
PipelineSpec,
|
||||
StageSpec,
|
||||
)
|
||||
|
||||
|
||||
def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
|
||||
normalized = mode.strip().lower()
|
||||
if normalized == "worker":
|
||||
return PipelineSpec(
|
||||
mode="worker",
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="router",
|
||||
executor_kind=ExecutorKind.SINGLE_SHOT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="router",
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
count=20,
|
||||
),
|
||||
),
|
||||
StageSpec(
|
||||
stage_name="worker",
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="worker",
|
||||
window_mode=ContextWindowMode.NUMBER,
|
||||
count=20,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if normalized == "memory":
|
||||
return PipelineSpec(
|
||||
mode="memory",
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="memory",
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="memory",
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
count=20,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
raise ValueError(f"unsupported pipeline mode: {normalized}")
|
||||
@@ -10,26 +10,40 @@ from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||||
from core.agentscope.runtime.model_tracking import TrackingChatModel
|
||||
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
|
||||
from core.agentscope.runtime.tool_selection_registry import TOOL_SELECTION_REGISTRY
|
||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from core.agentscope.utils import patch_agentscope_json_repair_compat
|
||||
from core.agentscope.utils import (
|
||||
finalize_json_response,
|
||||
patch_agentscope_json_repair_compat,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
)
|
||||
from schemas.agent.forwarded_props import (
|
||||
ClientTimeContext,
|
||||
parse_forwarded_props_client_time,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
RouterAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import (
|
||||
AgentType,
|
||||
ContextMessagesConfig,
|
||||
ContextBuildStrategy,
|
||||
SystemAgentLLMConfig,
|
||||
)
|
||||
from schemas.user import UserContext
|
||||
from services.litellm.service import LiteLLMService
|
||||
from sqlalchemy import select
|
||||
@@ -46,6 +60,7 @@ class SystemAgentRuntimeConfig:
|
||||
api_base_url: str
|
||||
api_key: str
|
||||
llm_config: SystemAgentLLMConfig
|
||||
extra_context: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -68,33 +83,83 @@ class AgentScopeRunner:
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
stage_agent_type = self._resolve_stage_agent_type(system_agent_mode)
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
stage_agent_types = [
|
||||
self._parse_agent_type(stage_name=stage.stage_name)
|
||||
for stage in pipeline_spec.stages
|
||||
]
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
stage_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=stage_agent_type,
|
||||
)
|
||||
if stage_agent_types == [AgentType.ROUTER, AgentType.WORKER]:
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_stage_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
stage_config=worker_config,
|
||||
)
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
if stage_agent_types[0] == AgentType.MEMORY:
|
||||
if memory_job_config is None:
|
||||
raise RuntimeError("memory job config is required")
|
||||
stage_config = await self._build_memory_stage_config(
|
||||
session=session,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
else:
|
||||
stage_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=stage_agent_types[0],
|
||||
)
|
||||
stage_toolkit = self._build_stage_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
stage_config=stage_config,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
stage_output = await self._execute_single_stage_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
input_messages=context_messages,
|
||||
toolkit=stage_toolkit,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
|
||||
return {
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
stage_config.agent_type.value: stage_output.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
def _build_stage_toolkit(
|
||||
@@ -113,11 +178,15 @@ class AgentScopeRunner:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_stage_agent_type(system_agent_mode: str) -> AgentType:
|
||||
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
|
||||
if mode == AgentType.MEMORY.value:
|
||||
def _parse_agent_type(*, stage_name: str) -> AgentType:
|
||||
normalized = stage_name.strip().lower()
|
||||
if normalized == AgentType.ROUTER.value:
|
||||
return AgentType.ROUTER
|
||||
if normalized == AgentType.WORKER.value:
|
||||
return AgentType.WORKER
|
||||
if normalized == AgentType.MEMORY.value:
|
||||
return AgentType.MEMORY
|
||||
return AgentType.WORKER
|
||||
raise ValueError(f"unsupported stage name: {stage_name}")
|
||||
|
||||
async def _load_stage_config(
|
||||
self,
|
||||
@@ -130,28 +199,60 @@ class AgentScopeRunner:
|
||||
agent_type=agent_type,
|
||||
)
|
||||
|
||||
async def _execute_worker_step(
|
||||
async def _execute_router_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> AgentOutput:
|
||||
step_name = stage_config.agent_type.value
|
||||
worker_output_model = AgentOutput
|
||||
) -> RouterAgentOutput:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return router_output
|
||||
|
||||
async def _execute_worker_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
router_output: RouterAgentOutput,
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> WorkerAgentOutputLite:
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
input_messages=self._build_worker_input_messages(
|
||||
router_output=router_output
|
||||
),
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
@@ -163,11 +264,48 @@ class AgentScopeRunner:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return worker_output
|
||||
|
||||
async def _execute_single_stage_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
input_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> AgentOutput:
|
||||
step_name = stage_config.agent_type.value
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
stage_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
input_messages=input_messages,
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
worker_output_model=AgentOutput,
|
||||
pipeline=pipeline,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
stage_output = AgentOutput.model_validate(stage_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return stage_output
|
||||
|
||||
async def _load_system_agent_config(
|
||||
self,
|
||||
*,
|
||||
@@ -193,6 +331,50 @@ class AgentScopeRunner:
|
||||
api_base_url=factory.request_url,
|
||||
api_key=self._resolve_provider_api_key(factory_name=factory.name),
|
||||
llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
|
||||
extra_context=None,
|
||||
)
|
||||
|
||||
async def _build_memory_stage_config(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
memory_job_config: AutomationJobConfig,
|
||||
) -> SystemAgentRuntimeConfig:
|
||||
stmt = (
|
||||
select(Llm, LlmFactory)
|
||||
.join(LlmFactory, Llm.factory_id == LlmFactory.id)
|
||||
.where(Llm.model_code == memory_job_config.model_code)
|
||||
)
|
||||
row = (await session.execute(stmt)).one_or_none()
|
||||
if row is None:
|
||||
raise RuntimeError(
|
||||
f"memory model not found: {memory_job_config.model_code}"
|
||||
)
|
||||
llm, factory = row
|
||||
llm_config = SystemAgentLLMConfig(
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
timeout_seconds=30,
|
||||
visibility_consumer_bit=18,
|
||||
context_messages=ContextMessagesConfig(
|
||||
mode=(
|
||||
ContextBuildStrategy.DAY
|
||||
if memory_job_config.context.window_mode.value == "day"
|
||||
else ContextBuildStrategy.NUMBER
|
||||
),
|
||||
count=memory_job_config.context.window_count,
|
||||
),
|
||||
enabled_tools=memory_job_config.enabled_tools,
|
||||
)
|
||||
return SystemAgentRuntimeConfig(
|
||||
agent_type=AgentType.MEMORY,
|
||||
model_code=llm.model_code,
|
||||
api_base_url=factory.request_url,
|
||||
api_key=self._resolve_provider_api_key(factory_name=factory.name),
|
||||
llm_config=llm_config,
|
||||
extra_context=(
|
||||
f"[Memory Input Template]\n{memory_job_config.input_template.strip()}"
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -211,19 +393,63 @@ class AgentScopeRunner:
|
||||
raise RuntimeError(f"provider api key missing for factory: {factory_name}")
|
||||
return api_key
|
||||
|
||||
async def _run_worker_stage(
|
||||
async def _run_router_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> StageExecutionResult:
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
response, payload = await finalize_json_response(
|
||||
model=tracking_model,
|
||||
formatter=OpenAIChatFormatter(),
|
||||
base_messages=[
|
||||
Msg(
|
||||
"system",
|
||||
build_system_prompt(
|
||||
agent_type=AgentType.ROUTER,
|
||||
llm_config=stage_config.llm_config,
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
runtime_client_time=runtime_client_time,
|
||||
tools=None,
|
||||
),
|
||||
"system",
|
||||
),
|
||||
*context_messages,
|
||||
],
|
||||
output_model=RouterAgentOutput,
|
||||
retries=0,
|
||||
)
|
||||
response_msg = Msg(
|
||||
name="router",
|
||||
role="assistant",
|
||||
content=list(getattr(response, "content", [])),
|
||||
metadata=payload,
|
||||
)
|
||||
return StageExecutionResult(
|
||||
message=response_msg,
|
||||
payload=payload,
|
||||
response_metadata=self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
usage_summary=tracking_model.usage_summary(),
|
||||
),
|
||||
)
|
||||
|
||||
async def _run_worker_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
input_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
run_input: RunAgentInput,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
worker_output_model: type[AgentOutput],
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> StageExecutionResult:
|
||||
worker_input = list(context_messages)
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
emitter = PipelineStageEmitter(
|
||||
pipeline=pipeline,
|
||||
@@ -241,6 +467,7 @@ class AgentScopeRunner:
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
runtime_client_time=runtime_client_time,
|
||||
extra_context=stage_config.extra_context,
|
||||
tools=None,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
@@ -248,7 +475,7 @@ class AgentScopeRunner:
|
||||
emitter=emitter,
|
||||
)
|
||||
response_msg = await agent.reply_json(
|
||||
worker_input, output_model=worker_output_model
|
||||
input_messages, output_model=worker_output_model
|
||||
)
|
||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||
response_metadata = self._litellm_service.build_usage_metadata(
|
||||
@@ -265,6 +492,19 @@ class AgentScopeRunner:
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
def _build_worker_input_messages(
|
||||
self,
|
||||
*,
|
||||
router_output: RouterAgentOutput,
|
||||
) -> list[Msg]:
|
||||
return [
|
||||
Msg(
|
||||
name=AgentType.ROUTER.value,
|
||||
role="user",
|
||||
content=build_worker_contract_prompt(router_output=router_output),
|
||||
)
|
||||
]
|
||||
|
||||
def _build_model(
|
||||
self, *, stage_config: SystemAgentRuntimeConfig
|
||||
) -> TrackingChatModel:
|
||||
@@ -272,8 +512,8 @@ class AgentScopeRunner:
|
||||
"temperature": stage_config.llm_config.temperature,
|
||||
"max_tokens": stage_config.llm_config.max_tokens,
|
||||
"timeout": stage_config.llm_config.timeout_seconds,
|
||||
"extra_body": {"enable_thinking": False},
|
||||
}
|
||||
generate_kwargs["extra_body"] = {"enable_thinking": False}
|
||||
|
||||
model = OpenAIChatModel(
|
||||
model_name=stage_config.model_code,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from datetime import timezone
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
@@ -14,26 +15,49 @@ from core.agentscope.events import (
|
||||
)
|
||||
from core.agentscope.runtime.context_service import AgentContextService
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.schemas.agui_input import parse_run_input
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
SqlAlchemyAutomationSchedulerRepository,
|
||||
utc_now,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from models.automation_jobs import AutomationJob
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
|
||||
from schemas.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.users.dependencies import get_user_service
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
_MAX_CONTEXT_ATTACHMENTS = 3
|
||||
|
||||
|
||||
class _BulkQueueAdapter:
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str:
|
||||
del dedup_key
|
||||
result = await run_command_task_bulk.kiq(command)
|
||||
return str(result.task_id)
|
||||
|
||||
|
||||
def _serialize_tool_agent_output(
|
||||
*,
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None,
|
||||
@@ -79,13 +103,29 @@ async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
context_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> list[Msg]:
|
||||
context_service = AgentContextService(repository=AgentRepository(session))
|
||||
result = await context_service.load_context_messages(
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=system_agent_mode,
|
||||
)
|
||||
if memory_job_config is not None:
|
||||
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
if memory_job_config.context.window_mode.value == "day":
|
||||
result = await context_service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=memory_job_config.context.window_count,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
else:
|
||||
result = await context_service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=memory_job_config.context.window_count,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
else:
|
||||
result = await context_service.load_context_messages(
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=context_mode,
|
||||
)
|
||||
if not result:
|
||||
return []
|
||||
|
||||
@@ -166,11 +206,33 @@ async def _build_recent_context_messages(
|
||||
return converted
|
||||
|
||||
|
||||
async def _load_memory_job_config(
|
||||
*,
|
||||
session: Any,
|
||||
owner_id: UUID,
|
||||
automation_job_id: str,
|
||||
) -> AutomationJobConfig:
|
||||
try:
|
||||
job_uuid = UUID(automation_job_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError("automation_job_id is invalid") from exc
|
||||
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_uuid)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
raise ValueError("automation job not found")
|
||||
return AutomationJobConfig.model_validate(row.config or {})
|
||||
|
||||
|
||||
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
command_type = str(command.get("command", "run")).strip().lower()
|
||||
raw_owner_id = command.get("owner_id")
|
||||
run_input_raw = command.get("run_input")
|
||||
system_agent_mode = str(command.get("system_agent_mode", "worker")).strip().lower()
|
||||
|
||||
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
|
||||
raise ValueError("owner_id is required")
|
||||
@@ -178,6 +240,15 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
raise ValueError("run_input is required")
|
||||
|
||||
run_input = parse_run_input(run_input_raw)
|
||||
system_agent_mode = parse_forwarded_props_agent_type(
|
||||
getattr(run_input, "forwarded_props", None)
|
||||
)
|
||||
raw_automation_job_id = command.get("automation_job_id")
|
||||
if system_agent_mode == "memory" and (
|
||||
not isinstance(raw_automation_job_id, str) or not raw_automation_job_id
|
||||
):
|
||||
raise ValueError("automation_job_id is required for memory mode")
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
owner_id = UUID(raw_owner_id)
|
||||
@@ -189,6 +260,14 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
user_context = await _build_user_context(owner_id=owner_id, session=session)
|
||||
memory_job_config: AutomationJobConfig | None = None
|
||||
if system_agent_mode == "memory":
|
||||
assert isinstance(raw_automation_job_id, str)
|
||||
memory_job_config = await _load_memory_job_config(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
automation_job_id=raw_automation_job_id,
|
||||
)
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
bus = RedisStreamBus(
|
||||
@@ -211,7 +290,8 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages = await _build_recent_context_messages(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=system_agent_mode,
|
||||
context_mode=pipeline_spec.stages[0].context_policy.consumer_agent_type,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
|
||||
await runtime.run(
|
||||
@@ -219,6 +299,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages=context_messages,
|
||||
user_context=user_context,
|
||||
system_agent_mode=system_agent_mode,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
logger.info(
|
||||
"agentscope runtime task completed",
|
||||
@@ -233,6 +314,35 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
}
|
||||
|
||||
|
||||
async def run_automation_scheduler_scan(
|
||||
*,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, int]:
|
||||
now = utc_now()
|
||||
safe_limit = (
|
||||
max(int(limit), 1)
|
||||
if isinstance(limit, int)
|
||||
else int(config.automation_scheduler.batch_limit)
|
||||
)
|
||||
async with AsyncSessionLocal() as session:
|
||||
repository = SqlAlchemyAutomationSchedulerRepository(session=session)
|
||||
service = AutomationSchedulerService(
|
||||
repository=repository,
|
||||
queue=_BulkQueueAdapter(),
|
||||
)
|
||||
result = await service.scan_and_dispatch(now_utc=now, limit=safe_limit)
|
||||
logger.info(
|
||||
"automation scheduler scan completed",
|
||||
scanned=result.scanned,
|
||||
dispatched=result.dispatched,
|
||||
now_utc=now.astimezone(timezone.utc).isoformat(),
|
||||
)
|
||||
return {
|
||||
"scanned": int(result.scanned),
|
||||
"dispatched": int(result.dispatched),
|
||||
}
|
||||
|
||||
|
||||
@default_broker.task(task_name="tasks.agentscope.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
@@ -246,3 +356,8 @@ async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object
|
||||
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
|
||||
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
|
||||
|
||||
@default_broker.task(task_name="tasks.automation.scan_due_jobs")
|
||||
async def scan_due_automation_jobs_task(limit: int | None = None) -> dict[str, int]:
|
||||
return await run_automation_scheduler_scan(limit=limit)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from core.agentscope.tools.tool_config import resolve_tool_names_by_groups
|
||||
from core.agentscope.tools.tool_config import resolve_tool_function_names
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
ToolNameResolver = Callable[[Any], set[str] | None]
|
||||
@@ -23,20 +23,19 @@ class ToolSelectionRegistry:
|
||||
return resolver(stage_config)
|
||||
|
||||
|
||||
def _default_group_resolver(stage_config: Any) -> set[str] | None:
|
||||
raw_groups = getattr(stage_config.llm_config, "enabled_tool_groups", [])
|
||||
groups = raw_groups if isinstance(raw_groups, list) else []
|
||||
if not groups:
|
||||
def _default_tool_resolver(stage_config: Any) -> set[str] | None:
|
||||
enabled_tools = getattr(stage_config.llm_config, "enabled_tools", [])
|
||||
if not enabled_tools:
|
||||
return None
|
||||
return resolve_tool_names_by_groups(set(groups))
|
||||
return resolve_tool_function_names(set(enabled_tools))
|
||||
|
||||
|
||||
TOOL_SELECTION_REGISTRY = ToolSelectionRegistry()
|
||||
TOOL_SELECTION_REGISTRY.register(
|
||||
agent_type=AgentType.WORKER,
|
||||
resolver=_default_group_resolver,
|
||||
resolver=_default_tool_resolver,
|
||||
)
|
||||
TOOL_SELECTION_REGISTRY.register(
|
||||
agent_type=AgentType.MEMORY,
|
||||
resolver=_default_group_resolver,
|
||||
resolver=_default_tool_resolver,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.agentscope.tools.utils.calendar_domain import (
|
||||
map_calendar_exception,
|
||||
merge_schedule_metadata_for_update,
|
||||
parse_iso_datetime,
|
||||
resolve_share_target_email_map,
|
||||
resolve_share_target_phone_map,
|
||||
schedule_event_to_dict,
|
||||
)
|
||||
from core.agentscope.tools.utils.calendar_ui import (
|
||||
@@ -580,11 +580,11 @@ async def calendar_share(
|
||||
)
|
||||
target_uuid = UUID(event_id)
|
||||
|
||||
email_map = resolve_share_target_email_map(
|
||||
phone_map = resolve_share_target_phone_map(
|
||||
[invitee.user_id for invitee in invitees]
|
||||
)
|
||||
|
||||
if not email_map:
|
||||
if not phone_map:
|
||||
return calendar_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
@@ -599,8 +599,8 @@ async def calendar_share(
|
||||
normalized_user_id = str(UUID(invitee.user_id.strip()))
|
||||
except ValueError:
|
||||
continue
|
||||
email = email_map.get(normalized_user_id)
|
||||
if email is None:
|
||||
phone = phone_map.get(normalized_user_id)
|
||||
if phone is None:
|
||||
continue
|
||||
permission = {
|
||||
"permission_view": invitee.permission_view,
|
||||
@@ -608,15 +608,15 @@ async def calendar_share(
|
||||
"permission_invite": invitee.permission_invite,
|
||||
}
|
||||
await service.share(
|
||||
target_uuid, ScheduleItemShareRequest(email=email, **permission)
|
||||
target_uuid, ScheduleItemShareRequest(phone=phone, **permission)
|
||||
)
|
||||
invited.append(email)
|
||||
invited.append(phone)
|
||||
if not invited:
|
||||
return calendar_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
code="NOT_FOUND",
|
||||
message="邀请目标均无有效邮箱",
|
||||
message="邀请目标均无有效手机号",
|
||||
retryable=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from agentscope.tool import ToolResponse
|
||||
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
|
||||
from core.agentscope.tools.utils import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
@@ -46,22 +46,22 @@ def _lookup_error_output(
|
||||
async def _resolve_identity(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
user_email: str | None,
|
||||
user_phone: str | None,
|
||||
user_name: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve user identity by email or username."""
|
||||
email = user_email.strip().lower() if isinstance(user_email, str) else ""
|
||||
"""Resolve user identity by phone or username."""
|
||||
phone = user_phone.strip() if isinstance(user_phone, str) else ""
|
||||
name = user_name.strip() if isinstance(user_name, str) else ""
|
||||
|
||||
if bool(email) == bool(name):
|
||||
if bool(phone) == bool(name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="请提供 email 或 username 其中之一",
|
||||
detail="请提供 phone 或 username 其中之一",
|
||||
)
|
||||
|
||||
if email:
|
||||
if phone:
|
||||
auth_gateway = SupabaseAuthGateway()
|
||||
user = await auth_gateway.get_user_by_email(email)
|
||||
user = await auth_gateway.get_user_by_phone(phone)
|
||||
user_id = UUID(user.id)
|
||||
|
||||
stmt = (
|
||||
@@ -73,9 +73,9 @@ async def _resolve_identity(
|
||||
|
||||
return {
|
||||
"userId": str(user_id),
|
||||
"email": user.email,
|
||||
"phone": user.phone,
|
||||
"username": username,
|
||||
"matchedBy": "email",
|
||||
"matchedBy": "phone",
|
||||
}
|
||||
|
||||
stmt = (
|
||||
@@ -90,20 +90,20 @@ async def _resolve_identity(
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
users = list_auth_users()
|
||||
email_value = find_auth_email_by_user_id(users=users, user_id=profile.id)
|
||||
phone_value = find_auth_phone_by_user_id(users=users, user_id=profile.id)
|
||||
|
||||
return {
|
||||
"userId": str(profile.id),
|
||||
"email": email_value,
|
||||
"phone": phone_value,
|
||||
"username": profile.username,
|
||||
"matchedBy": "username",
|
||||
}
|
||||
|
||||
|
||||
async def user_lookup(
|
||||
user_email: Annotated[
|
||||
user_phone: Annotated[
|
||||
str | None,
|
||||
Field(description="User email address to look up."),
|
||||
Field(description="User phone to look up."),
|
||||
] = None,
|
||||
user_name: Annotated[
|
||||
str | None,
|
||||
@@ -112,16 +112,16 @@ async def user_lookup(
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
"""Look up user identity by email or username.
|
||||
"""Look up user identity by phone or username.
|
||||
|
||||
Args:
|
||||
user_email: User email address for lookup.
|
||||
user_phone: User phone for lookup.
|
||||
user_name: Username for lookup.
|
||||
|
||||
Returns:
|
||||
ToolResponse with serialized ToolAgentOutput payload.
|
||||
"""
|
||||
tool_call_args = {"user_email": user_email, "user_name": user_name}
|
||||
tool_call_args = {"user_phone": user_phone, "user_name": user_name}
|
||||
|
||||
if session is None or owner_id is None:
|
||||
return _lookup_error_output(
|
||||
@@ -134,17 +134,17 @@ async def user_lookup(
|
||||
try:
|
||||
resolved = await _resolve_identity(
|
||||
session=cast(AsyncSession, session),
|
||||
user_email=user_email,
|
||||
user_phone=user_phone,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
username = str(resolved.get("username") or "")
|
||||
email = str(resolved.get("email") or "")
|
||||
phone = str(resolved.get("phone") or "")
|
||||
user_id = str(resolved.get("userId") or "")
|
||||
matched_by = str(resolved.get("matchedBy") or "")
|
||||
summary = (
|
||||
f"status=success matched_by={matched_by} user_id={user_id} "
|
||||
f"username={username} has_email={str(bool(email)).lower()}"
|
||||
f"username={username} has_phone={str(bool(phone)).lower()}"
|
||||
)
|
||||
return _dump_tool_output(
|
||||
ToolAgentOutput(
|
||||
|
||||
@@ -6,7 +6,15 @@ from enum import Enum
|
||||
|
||||
class ToolGroup(str, Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXECUTE = "execute"
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class AgentTool(str, Enum):
|
||||
CALENDAR_READ = "calendar.read"
|
||||
CALENDAR_WRITE = "calendar.write"
|
||||
CALENDAR_SHARE = "calendar.share"
|
||||
USER_LOOKUP = "user.lookup"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -29,30 +37,51 @@ TOOL_CONFIGS: dict[str, ToolConfig] = {
|
||||
),
|
||||
"user_lookup": ToolConfig(
|
||||
name="user_lookup",
|
||||
group=ToolGroup.READ,
|
||||
group=ToolGroup.MEMORY,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.WRITE,
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"calendar_share": ToolConfig(
|
||||
name="calendar_share",
|
||||
group=ToolGroup.WRITE,
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
}
|
||||
|
||||
AGENT_TOOL_TO_FUNCTION_NAME: dict[AgentTool, str] = {
|
||||
AgentTool.CALENDAR_READ: "calendar_read",
|
||||
AgentTool.CALENDAR_WRITE: "calendar_write",
|
||||
AgentTool.CALENDAR_SHARE: "calendar_share",
|
||||
AgentTool.USER_LOOKUP: "user_lookup",
|
||||
}
|
||||
|
||||
def get_tool_config(tool_name: str) -> ToolConfig:
|
||||
config = TOOL_CONFIGS.get(tool_name)
|
||||
if config is None:
|
||||
raise ValueError(f"unknown tool: {tool_name}")
|
||||
return config
|
||||
TOOL_NAME_ALIASES: dict[str, AgentTool] = {
|
||||
AgentTool.CALENDAR_READ.value: AgentTool.CALENDAR_READ,
|
||||
"calendar_read": AgentTool.CALENDAR_READ,
|
||||
AgentTool.CALENDAR_WRITE.value: AgentTool.CALENDAR_WRITE,
|
||||
"calendar_write": AgentTool.CALENDAR_WRITE,
|
||||
AgentTool.CALENDAR_SHARE.value: AgentTool.CALENDAR_SHARE,
|
||||
"calendar_share": AgentTool.CALENDAR_SHARE,
|
||||
AgentTool.USER_LOOKUP.value: AgentTool.USER_LOOKUP,
|
||||
"user_lookup": AgentTool.USER_LOOKUP,
|
||||
}
|
||||
|
||||
|
||||
def resolve_tool_names_by_groups(groups: set[ToolGroup]) -> set[str]:
|
||||
if not groups:
|
||||
return set()
|
||||
return {name for name, config in TOOL_CONFIGS.items() if config.group in groups}
|
||||
def parse_agent_tool(value: object) -> AgentTool:
|
||||
if isinstance(value, AgentTool):
|
||||
return value
|
||||
raw_value = str(value or "").strip().lower()
|
||||
if not raw_value:
|
||||
raise ValueError("enabled tool value cannot be empty")
|
||||
tool = TOOL_NAME_ALIASES.get(raw_value)
|
||||
if tool is None:
|
||||
raise ValueError(f"unknown enabled tool: {raw_value}")
|
||||
return tool
|
||||
|
||||
|
||||
def resolve_tool_function_names(tools: set[AgentTool]) -> set[str]:
|
||||
return {AGENT_TOOL_TO_FUNCTION_NAME[tool] for tool in tools}
|
||||
|
||||
@@ -7,10 +7,15 @@ from core.agentscope.tools.tool_call_context import (
|
||||
reset_current_tool_call_id,
|
||||
set_current_tool_call_id,
|
||||
)
|
||||
from core.agentscope.tools.tool_config import (
|
||||
AGENT_TOOL_TO_FUNCTION_NAME,
|
||||
TOOL_CONFIGS,
|
||||
ToolConfig,
|
||||
parse_agent_tool,
|
||||
)
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
build_error_response,
|
||||
)
|
||||
from core.agentscope.tools.tool_config import ToolConfig, TOOL_CONFIGS
|
||||
|
||||
|
||||
def register_tool_middlewares(
|
||||
@@ -59,6 +64,18 @@ def create_approval_middleware(
|
||||
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
|
||||
| None = None,
|
||||
) -> Callable[..., AsyncGenerator[Any, None]]:
|
||||
def _resolve_tool_config(*, tool_name: str) -> ToolConfig | None:
|
||||
config = config_by_name.get(tool_name)
|
||||
if config is not None:
|
||||
return config
|
||||
try:
|
||||
normalized_tool_name = AGENT_TOOL_TO_FUNCTION_NAME[
|
||||
parse_agent_tool(tool_name)
|
||||
]
|
||||
except ValueError:
|
||||
return None
|
||||
return config_by_name.get(normalized_tool_name)
|
||||
|
||||
def _resolve_tool_call_id(tool_call: dict[str, Any]) -> str:
|
||||
raw_tool_call_id = tool_call.get("id")
|
||||
if isinstance(raw_tool_call_id, str) and raw_tool_call_id.strip():
|
||||
@@ -81,7 +98,7 @@ def create_approval_middleware(
|
||||
yield response
|
||||
return
|
||||
|
||||
config = config_by_name.get(tool_name)
|
||||
config = _resolve_tool_config(tool_name=tool_name)
|
||||
if config is None or not config.approval.required:
|
||||
async for response in await next_handler(**kwargs):
|
||||
yield response
|
||||
@@ -134,15 +151,3 @@ def create_approval_middleware(
|
||||
yield pending_response
|
||||
|
||||
return approval_middleware
|
||||
|
||||
|
||||
def create_hitl_middleware(
|
||||
*,
|
||||
meta_by_name: dict[str, ToolConfig],
|
||||
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
|
||||
| None = None,
|
||||
) -> Callable[..., AsyncGenerator[Any, None]]:
|
||||
return create_approval_middleware(
|
||||
config_by_name=meta_by_name,
|
||||
approval_resolver=approval_resolver,
|
||||
)
|
||||
|
||||
@@ -13,8 +13,6 @@ from core.agentscope.tools.custom.calendar import (
|
||||
from core.agentscope.tools.custom.user_lookup import user_lookup
|
||||
from core.agentscope.tools.tool_config import (
|
||||
TOOL_CONFIGS,
|
||||
ToolGroup,
|
||||
resolve_tool_names_by_groups,
|
||||
)
|
||||
from core.agentscope.tools.tool_middleware import register_tool_middlewares
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -28,46 +26,36 @@ TOOL_FUNCTIONS: dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
AGENT_TYPE_TO_GROUPS: dict[AgentType, set[ToolGroup]] = {
|
||||
AgentType.WORKER: {ToolGroup.READ, ToolGroup.WRITE},
|
||||
AgentType.MEMORY: {ToolGroup.READ, ToolGroup.WRITE},
|
||||
AGENT_TYPE_TO_DEFAULT_TOOLS: dict[AgentType, set[str]] = {
|
||||
AgentType.WORKER: {
|
||||
"calendar_read",
|
||||
"calendar_write",
|
||||
"calendar_share",
|
||||
"user_lookup",
|
||||
},
|
||||
AgentType.MEMORY: {"calendar_read", "user_lookup"},
|
||||
}
|
||||
|
||||
|
||||
def _resolve_enabled_tools(
|
||||
*,
|
||||
groups: set[ToolGroup] | None,
|
||||
enabled_tool_names: set[str] | None,
|
||||
) -> set[str]:
|
||||
if enabled_tool_names is not None:
|
||||
unknown = enabled_tool_names - set(TOOL_FUNCTIONS)
|
||||
if unknown:
|
||||
raise ValueError(f"unknown tools in enabled_tool_names: {sorted(unknown)}")
|
||||
return set(enabled_tool_names)
|
||||
|
||||
if groups is None:
|
||||
return set(TOOL_FUNCTIONS)
|
||||
|
||||
resolved = resolve_tool_names_by_groups(groups)
|
||||
unknown = resolved - set(TOOL_FUNCTIONS)
|
||||
def _validate_enabled_tool_names(enabled_tool_names: set[str]) -> set[str]:
|
||||
unknown = enabled_tool_names - set(TOOL_FUNCTIONS)
|
||||
if unknown:
|
||||
raise ValueError(f"tool config contains unknown tools: {sorted(unknown)}")
|
||||
return resolved
|
||||
raise ValueError(f"unknown tools in enabled_tool_names: {sorted(unknown)}")
|
||||
return enabled_tool_names
|
||||
|
||||
|
||||
def build_toolkit(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
groups: set[ToolGroup] | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
):
|
||||
toolkit = Toolkit()
|
||||
enabled_names = _resolve_enabled_tools(
|
||||
groups=groups,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
)
|
||||
if enabled_tool_names is None:
|
||||
enabled_names = set(TOOL_FUNCTIONS)
|
||||
else:
|
||||
enabled_names = _validate_enabled_tool_names(set(enabled_tool_names))
|
||||
|
||||
preset_kwargs = cast(
|
||||
dict[str, JSONSerializableObject],
|
||||
@@ -100,15 +88,13 @@ def build_stage_toolkit(
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
):
|
||||
groups = AGENT_TYPE_TO_GROUPS.get(agent_type)
|
||||
if groups is None:
|
||||
default_tools = AGENT_TYPE_TO_DEFAULT_TOOLS.get(agent_type)
|
||||
if default_tools is None:
|
||||
raise ValueError(f"unknown agent_type: {agent_type}")
|
||||
|
||||
stage_enabled_names = resolve_tool_names_by_groups(set(groups))
|
||||
selected_names = (
|
||||
stage_enabled_names
|
||||
set(default_tools)
|
||||
if enabled_tool_names is None
|
||||
else stage_enabled_names | set(enabled_tool_names)
|
||||
else _validate_enabled_tool_names(set(enabled_tool_names))
|
||||
)
|
||||
|
||||
return build_toolkit(
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from core.agentscope.tools.utils.auth_helpers import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"list_auth_users",
|
||||
"find_auth_email_by_user_id",
|
||||
"find_auth_phone_by_user_id",
|
||||
]
|
||||
|
||||
@@ -25,12 +25,12 @@ def list_auth_users() -> list[Any]:
|
||||
return users
|
||||
|
||||
|
||||
def find_auth_email_by_user_id(*, users: list[Any], user_id: UUID) -> str | None:
|
||||
"""Find auth email by user id from fetched user list."""
|
||||
def find_auth_phone_by_user_id(*, users: list[Any], user_id: UUID) -> str | None:
|
||||
"""Find auth phone by user id from fetched user list."""
|
||||
target = str(user_id)
|
||||
for user in users:
|
||||
if str(getattr(user, "id", "")) == target:
|
||||
email = getattr(user, "email", None)
|
||||
if isinstance(email, str) and email.strip():
|
||||
return email.strip()
|
||||
phone = getattr(user, "phone", None)
|
||||
if isinstance(phone, str) and phone.strip():
|
||||
return phone.strip()
|
||||
return None
|
||||
|
||||
@@ -9,7 +9,7 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.tools.utils.auth_helpers import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
@@ -125,7 +125,7 @@ def parse_iso_datetime(value: str | None) -> datetime | None:
|
||||
return parsed.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def resolve_share_target_email_map(invitee_user_ids: list[str]) -> dict[str, str]:
|
||||
def resolve_share_target_phone_map(invitee_user_ids: list[str]) -> dict[str, str]:
|
||||
users = list_auth_users()
|
||||
resolved: dict[str, str] = {}
|
||||
for raw_user_id in invitee_user_ids:
|
||||
@@ -138,7 +138,7 @@ def resolve_share_target_email_map(invitee_user_ids: list[str]) -> dict[str, str
|
||||
user_uuid = UUID(normalized_user_id)
|
||||
except ValueError:
|
||||
continue
|
||||
email = find_auth_email_by_user_id(users=users, user_id=user_uuid)
|
||||
if email:
|
||||
resolved[str(user_uuid)] = email.lower()
|
||||
phone = find_auth_phone_by_user_id(users=users, user_id=user_uuid)
|
||||
if phone:
|
||||
resolved[str(user_uuid)] = phone
|
||||
return resolved
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
DispatchResult,
|
||||
SqlAlchemyAutomationSchedulerRepository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutomationSchedulerService",
|
||||
"DispatchResult",
|
||||
"SqlAlchemyAutomationSchedulerRepository",
|
||||
]
|
||||
@@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob, ScheduleType
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
||||
|
||||
logger = get_logger("core.automation.scheduler")
|
||||
|
||||
|
||||
class QueueLike(Protocol):
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class AutomationSchedulerRepositoryLike(Protocol):
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[DueAutomationJob]: ...
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig: ...
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID: ...
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DispatchResult:
|
||||
scanned: int
|
||||
dispatched: int
|
||||
|
||||
|
||||
class AutomationSchedulerService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: AutomationSchedulerRepositoryLike,
|
||||
queue: QueueLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._queue = queue
|
||||
|
||||
async def scan_and_dispatch(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> DispatchResult:
|
||||
safe_limit = max(int(limit), 1)
|
||||
due_jobs = await self._repository.list_due_jobs(
|
||||
now_utc=now_utc, limit=safe_limit
|
||||
)
|
||||
dispatched = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
config = await self._repository.get_job_config(job_id=job.id)
|
||||
thread_id = await self._repository.ensure_latest_chat_session(
|
||||
owner_id=job.owner_id
|
||||
)
|
||||
command = self._build_dispatch_command(
|
||||
job=job,
|
||||
thread_id=thread_id,
|
||||
input_text=config.input_template,
|
||||
now_utc=now_utc,
|
||||
)
|
||||
await self._queue.enqueue(command=command, dedup_key=None)
|
||||
await self._repository.mark_job_dispatched(
|
||||
job_id=job.id,
|
||||
next_run_at=_compute_next_run_at(
|
||||
current_next_run_at=job.next_run_at,
|
||||
now_utc=now_utc,
|
||||
schedule_type=job.schedule_type,
|
||||
),
|
||||
last_run_at=now_utc,
|
||||
)
|
||||
await self._repository.commit()
|
||||
dispatched += 1
|
||||
except Exception as exc:
|
||||
await self._repository.rollback()
|
||||
logger.exception(
|
||||
"automation job dispatch failed",
|
||||
job_id=str(job.id),
|
||||
owner_id=str(job.owner_id),
|
||||
error=str(exc),
|
||||
)
|
||||
return DispatchResult(scanned=len(due_jobs), dispatched=dispatched)
|
||||
|
||||
def _build_dispatch_command(
|
||||
self,
|
||||
*,
|
||||
job: DueAutomationJob,
|
||||
thread_id: UUID,
|
||||
input_text: str,
|
||||
now_utc: datetime,
|
||||
) -> dict[str, object]:
|
||||
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
|
||||
payload = SchedulerDispatchCommand(
|
||||
owner_id=job.owner_id,
|
||||
automation_job_id=job.id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
input_text=input_text.strip(),
|
||||
)
|
||||
return {
|
||||
"command": "run",
|
||||
"owner_id": str(payload.owner_id),
|
||||
"automation_job_id": str(payload.automation_job_id),
|
||||
"queue": "bulk",
|
||||
"run_input": {
|
||||
"threadId": str(payload.thread_id),
|
||||
"runId": payload.run_id,
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"role": "user",
|
||||
"content": payload.input_text,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"agent_type": "memory",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SqlAlchemyAutomationSchedulerRepository:
|
||||
def __init__(self, *, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[DueAutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.status == "active")
|
||||
.where(AutomationJob.next_run_at <= now_utc)
|
||||
.order_by(AutomationJob.next_run_at.asc())
|
||||
.limit(max(limit, 1))
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return [
|
||||
DueAutomationJob(
|
||||
id=row.id,
|
||||
owner_id=row.owner_id,
|
||||
schedule_type=row.schedule_type,
|
||||
timezone=row.timezone,
|
||||
next_run_at=row.next_run_at,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
|
||||
stmt = select(AutomationJob.config).where(AutomationJob.id == job_id)
|
||||
config_payload = (await self._session.execute(stmt)).scalar_one()
|
||||
return AutomationJobConfig.model_validate(config_payload or {})
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == owner_id)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.where(AgentChatSession.session_type == SessionType.CHAT)
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
)
|
||||
self._session.add(session)
|
||||
await self._session.flush()
|
||||
return session.id
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None:
|
||||
stmt = select(AutomationJob).where(AutomationJob.id == job_id)
|
||||
row = (await self._session.execute(stmt)).scalar_one()
|
||||
row.next_run_at = next_run_at
|
||||
row.last_run_at = last_run_at
|
||||
await self._session.flush()
|
||||
|
||||
async def commit(self) -> None:
|
||||
await self._session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
await self._session.rollback()
|
||||
|
||||
|
||||
def _compute_next_run_at(
|
||||
*,
|
||||
current_next_run_at: datetime,
|
||||
now_utc: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> datetime:
|
||||
delta = timedelta(days=1 if schedule_type == ScheduleType.DAILY else 7)
|
||||
next_run_at = current_next_run_at
|
||||
while next_run_at <= now_utc:
|
||||
next_run_at = next_run_at + delta
|
||||
return next_run_at
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
@@ -1,27 +1,30 @@
|
||||
agents:
|
||||
- agent_type: worker
|
||||
llm_model_code: qwen3.5-35b-a3b
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
context_messages:
|
||||
mode: number
|
||||
count: 20
|
||||
enabled_tool_groups:
|
||||
- read
|
||||
- write
|
||||
- agent_type: router
|
||||
llm_model_code: qwen3.5-flash
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
visibility_consumer_bit: 16
|
||||
context_messages:
|
||||
mode: day
|
||||
count: 2
|
||||
enabled_tools: []
|
||||
|
||||
- agent_type: memory
|
||||
llm_model_code: qwen3.5-flash
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
context_messages:
|
||||
mode: day
|
||||
count: 2
|
||||
enabled_tool_groups:
|
||||
- read
|
||||
- agent_type: worker
|
||||
llm_model_code: qwen3.5-flash
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
visibility_consumer_bit: 17
|
||||
context_messages:
|
||||
mode: number
|
||||
count: 20
|
||||
enabled_tools:
|
||||
- calendar.read
|
||||
- calendar.write
|
||||
- calendar.share
|
||||
- user.lookup
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
JSON,
|
||||
Enum as SqlEnum,
|
||||
ForeignKey,
|
||||
@@ -59,6 +60,11 @@ class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
|
||||
output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
cost: Mapped[Decimal] = mapped_column(Numeric(12, 6), nullable=False, default=0)
|
||||
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
visibility_mask: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=0,
|
||||
)
|
||||
metadata_json: Mapped[dict[str, object] | None] = mapped_column(
|
||||
"metadata", JSON().with_variant(JSONB, "postgresql"), nullable=True
|
||||
)
|
||||
|
||||
@@ -4,8 +4,8 @@ import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import DateTime, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy import DateTime, JSON, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
@@ -36,9 +36,10 @@ class AutomationJob(TimestampMixin, SoftDeleteMixin, Base):
|
||||
String(255),
|
||||
nullable=False,
|
||||
)
|
||||
prompt: Mapped[str] = mapped_column(
|
||||
Text,
|
||||
config: Mapped[dict[str, object]] = mapped_column(
|
||||
JSON().with_variant(JSONB, "postgresql"),
|
||||
nullable=False,
|
||||
default=dict,
|
||||
)
|
||||
schedule_type: Mapped[ScheduleType] = mapped_column(
|
||||
String(20),
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class AgentConsumerBinding(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
bit: int = Field(..., ge=16, le=63)
|
||||
|
||||
@field_validator("agent_type")
|
||||
@classmethod
|
||||
def _normalize_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("agent_type must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class ConsumerRegistry(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
bindings: list[AgentConsumerBinding] = Field(default_factory=list)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_unique_bindings(self) -> "ConsumerRegistry":
|
||||
by_agent: set[str] = set()
|
||||
by_bit: set[int] = set()
|
||||
for item in self.bindings:
|
||||
if item.agent_type in by_agent:
|
||||
raise ValueError(f"duplicate agent_type binding: {item.agent_type}")
|
||||
if item.bit in by_bit:
|
||||
raise ValueError(f"duplicate visibility bit binding: {item.bit}")
|
||||
by_agent.add(item.agent_type)
|
||||
by_bit.add(item.bit)
|
||||
return self
|
||||
|
||||
def resolve_agent_bit(self, *, agent_type: str) -> int:
|
||||
target = agent_type.strip().lower()
|
||||
for item in self.bindings:
|
||||
if item.agent_type == target:
|
||||
return item.bit
|
||||
raise ValueError(f"agent visibility bit not configured: {target}")
|
||||
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class ExecutorKind(str, Enum):
|
||||
SINGLE_SHOT = "single_shot"
|
||||
REACT = "react"
|
||||
|
||||
|
||||
class ContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class ContextPolicy(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
consumer_agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
window_mode: ContextWindowMode = ContextWindowMode.NUMBER
|
||||
count: int = Field(default=20, ge=1, le=200)
|
||||
|
||||
@field_validator("consumer_agent_type")
|
||||
@classmethod
|
||||
def _normalize_consumer_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("consumer_agent_type must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class StageSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
stage_name: str = Field(..., min_length=1, max_length=64)
|
||||
executor_kind: ExecutorKind
|
||||
default_visibility_mask: int = Field(..., ge=0, le=(1 << 63) - 1)
|
||||
context_policy: ContextPolicy
|
||||
|
||||
@field_validator("stage_name")
|
||||
@classmethod
|
||||
def _normalize_stage_name(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("stage_name must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class PipelineSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
mode: str = Field(..., min_length=1, max_length=64)
|
||||
stages: list[StageSpec] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("mode")
|
||||
@classmethod
|
||||
def _normalize_mode(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("mode must not be empty")
|
||||
return normalized
|
||||
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class SystemVisibilityBit(IntEnum):
|
||||
UI_HISTORY = 0
|
||||
UI_REALTIME = 1
|
||||
|
||||
|
||||
class VisibilityMask(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
value: int = Field(..., ge=0, le=(1 << 63) - 1)
|
||||
|
||||
@classmethod
|
||||
def from_bits(cls, *, bits: list[int]) -> "VisibilityMask":
|
||||
mask = 0
|
||||
for bit in bits:
|
||||
validate_visibility_bit(bit=bit)
|
||||
mask |= 1 << bit
|
||||
return cls(value=mask)
|
||||
|
||||
def contains(self, *, bit: int) -> bool:
|
||||
validate_visibility_bit(bit=bit)
|
||||
return bool(self.value & (1 << bit))
|
||||
|
||||
|
||||
class VisibilityBitRef(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
bit: int = Field(..., ge=0, le=63)
|
||||
|
||||
@field_validator("bit")
|
||||
@classmethod
|
||||
def _validate_bit(cls, value: int) -> int:
|
||||
validate_visibility_bit(bit=value)
|
||||
return value
|
||||
|
||||
|
||||
def validate_visibility_bit(*, bit: int) -> None:
|
||||
if bit < 0 or bit > 63:
|
||||
raise ValueError("visibility bit must be in range [0, 63]")
|
||||
|
||||
|
||||
def bit_mask(*, bit: int) -> int:
|
||||
validate_visibility_bit(bit=bit)
|
||||
return 1 << bit
|
||||
@@ -0,0 +1,20 @@
|
||||
from schemas.automation.config import (
|
||||
AutomationAgentType,
|
||||
AutomationContextSource,
|
||||
AutomationContextWindowMode,
|
||||
AutomationJobConfig,
|
||||
AutomationMemoryContextConfig,
|
||||
default_memory_job_config,
|
||||
)
|
||||
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
||||
|
||||
__all__ = [
|
||||
"AutomationAgentType",
|
||||
"AutomationContextSource",
|
||||
"AutomationContextWindowMode",
|
||||
"AutomationJobConfig",
|
||||
"AutomationMemoryContextConfig",
|
||||
"default_memory_job_config",
|
||||
"DueAutomationJob",
|
||||
"SchedulerDispatchCommand",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
|
||||
|
||||
class AutomationAgentType(str, Enum):
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class AutomationContextSource(str, Enum):
|
||||
LATEST_CHAT = "latest_chat"
|
||||
|
||||
|
||||
class AutomationContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class AutomationMemoryContextConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
source: AutomationContextSource = AutomationContextSource.LATEST_CHAT
|
||||
window_mode: AutomationContextWindowMode = AutomationContextWindowMode.DAY
|
||||
window_count: int = Field(default=2, ge=1, le=200)
|
||||
|
||||
|
||||
class AutomationJobConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: AutomationAgentType = AutomationAgentType.MEMORY
|
||||
model_code: str = Field(default="qwen3.5-flash", min_length=1, max_length=64)
|
||||
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
|
||||
input_template: str = Field(..., min_length=1, max_length=4000)
|
||||
context: AutomationMemoryContextConfig = Field(
|
||||
default_factory=AutomationMemoryContextConfig
|
||||
)
|
||||
|
||||
@field_validator("model_code")
|
||||
@classmethod
|
||||
def _validate_model_code(cls, value: str) -> str:
|
||||
normalized = value.strip()
|
||||
if normalized != "qwen3.5-flash":
|
||||
raise ValueError("model_code must be qwen3.5-flash")
|
||||
return normalized
|
||||
|
||||
|
||||
def default_memory_job_config() -> AutomationJobConfig:
|
||||
return AutomationJobConfig(
|
||||
agent_type=AutomationAgentType.MEMORY,
|
||||
model_code="qwen3.5-flash",
|
||||
enabled_tools=[AgentTool.CALENDAR_READ, AgentTool.USER_LOOKUP],
|
||||
input_template="请基于最近聊天上下文生成一段可执行的记忆总结与建议。",
|
||||
context=AutomationMemoryContextConfig(
|
||||
source=AutomationContextSource.LATEST_CHAT,
|
||||
window_mode=AutomationContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from models.automation_jobs import ScheduleType
|
||||
|
||||
|
||||
class DueAutomationJob(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
schedule_type: ScheduleType
|
||||
timezone: str = Field(..., min_length=1, max_length=50)
|
||||
next_run_at: datetime
|
||||
|
||||
|
||||
class SchedulerDispatchCommand(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
owner_id: UUID
|
||||
automation_job_id: UUID
|
||||
thread_id: UUID
|
||||
run_id: str = Field(..., min_length=1, max_length=128)
|
||||
input_text: str = Field(..., min_length=1, max_length=4000)
|
||||
@@ -189,3 +189,56 @@ def test_step_started_internal_event_keeps_step_name() -> None:
|
||||
|
||||
assert result["type"] == "STEP_STARTED"
|
||||
assert result["stepName"] == "worker"
|
||||
|
||||
|
||||
def test_run_error_prefers_top_level_message_and_code() -> None:
|
||||
internal = {
|
||||
"type": "run.error",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"message": "runtime failed",
|
||||
"code": "RUNTIME_ERROR",
|
||||
"data": {
|
||||
"message": "nested message",
|
||||
"code": "NESTED_ERROR",
|
||||
},
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "RUN_ERROR"
|
||||
assert result["message"] == "runtime failed"
|
||||
assert result["code"] == "RUNTIME_ERROR"
|
||||
|
||||
|
||||
def test_run_error_falls_back_to_data_when_top_level_missing() -> None:
|
||||
internal = {
|
||||
"type": "run.error",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"data": {
|
||||
"message": "nested message",
|
||||
"code": "NESTED_ERROR",
|
||||
},
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "RUN_ERROR"
|
||||
assert result["message"] == "nested message"
|
||||
assert result["code"] == "NESTED_ERROR"
|
||||
|
||||
|
||||
def test_run_error_uses_default_message_when_payload_invalid() -> None:
|
||||
internal = {
|
||||
"type": "run.error",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"data": "invalid",
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "RUN_ERROR"
|
||||
assert result["message"] == "Unknown error"
|
||||
assert "code" not in result
|
||||
|
||||
@@ -59,6 +59,16 @@ def _patch_repositories(
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
async def _fake_stage_bit_map(self, *, session: object) -> dict[str, int]:
|
||||
del self, session
|
||||
return {"router": 16, "worker": 17, "memory": 18}
|
||||
|
||||
monkeypatch.setattr(
|
||||
store_module.SqlAlchemyEventStore,
|
||||
"_load_stage_visibility_bit_map",
|
||||
_fake_stage_bit_map,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_worker_output_with_answer_as_content(
|
||||
@@ -103,6 +113,7 @@ async def test_store_persists_worker_output_with_answer_as_content(
|
||||
assert metadata["agent_output"]["answer"] == "worker-answer"
|
||||
assert metadata["agent_output"]["ui_hints"]["intent"] == "message"
|
||||
assert append_kwargs["cost"] == Decimal("0.123")
|
||||
assert append_kwargs["visibility_mask"] == ((1 << 0) | (1 << 17))
|
||||
assert captured["message_delta"] == 1
|
||||
assert captured["token_delta"] == 8
|
||||
|
||||
@@ -141,3 +152,4 @@ async def test_store_persists_tool_output_with_summary_as_content(
|
||||
metadata["tool_agent_output"]["result"]
|
||||
== "status=success batch=1 success=1 failed=0 ids=[event-1]"
|
||||
)
|
||||
assert append_kwargs["visibility_mask"] == (1 << 0)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.persistence.user_context_cache import UserContextCache
|
||||
from core.agentscope.schemas.user_context import (
|
||||
UserAgentContext,
|
||||
from schemas.user.context import (
|
||||
UserContext,
|
||||
parse_profile_settings,
|
||||
)
|
||||
|
||||
@@ -45,15 +46,15 @@ class _FakeRedis:
|
||||
self.set_store.pop(key, None)
|
||||
return len(keys)
|
||||
|
||||
async def sadd(self, key: str, *values: str) -> int:
|
||||
bucket = self.set_store.setdefault(key, set())
|
||||
async def sadd(self, name: str, *values: str) -> int:
|
||||
bucket = self.set_store.setdefault(name, set())
|
||||
before = len(bucket)
|
||||
for value in values:
|
||||
bucket.add(value)
|
||||
return len(bucket) - before
|
||||
|
||||
async def smembers(self, key: str) -> set[str]:
|
||||
return set(self.set_store.get(key, set()))
|
||||
async def smembers(self, name: str) -> set[str]:
|
||||
return set(self.set_store.get(name, set()))
|
||||
|
||||
|
||||
class _BrokenRedis:
|
||||
@@ -77,18 +78,18 @@ class _BrokenRedis:
|
||||
del keys
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
async def sadd(self, key: str, *values: str) -> int:
|
||||
del key, values
|
||||
async def sadd(self, name: str, *values: str) -> int:
|
||||
del name, values
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
async def smembers(self, key: str) -> set[str]:
|
||||
del key
|
||||
async def smembers(self, name: str) -> set[str]:
|
||||
del name
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
|
||||
def _build_context() -> UserAgentContext:
|
||||
return UserAgentContext(
|
||||
user_id=uuid4(),
|
||||
def _build_context() -> UserContext:
|
||||
return UserContext(
|
||||
id=str(uuid4()),
|
||||
username="demo-user",
|
||||
bio="demo bio",
|
||||
settings=parse_profile_settings({"preferences": {"ai_language": "en-US"}}),
|
||||
@@ -111,11 +112,11 @@ async def test_user_context_cache_set_and_get_hit() -> None:
|
||||
loaded = await cache.get(session_id=session_id)
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.user_id == context.user_id
|
||||
assert loaded.id == context.id
|
||||
assert loaded.username == "demo-user"
|
||||
assert redis.expire_calls == [
|
||||
(f"agent:user-context:{session_id}", 600),
|
||||
(f"agent:user-context:sessions:{context.user_id}", 600),
|
||||
(f"agent:user-context:sessions:{context.id}", 600),
|
||||
]
|
||||
assert redis.hincrby_calls == [
|
||||
(f"agent:user-context:{session_id}", "turns_used", 1)
|
||||
@@ -138,12 +139,14 @@ async def test_user_context_cache_invalidate_user_deletes_all_sessions() -> None
|
||||
await cache.set(session_id=s1, context=context)
|
||||
await cache.set(session_id=s2, context=context)
|
||||
|
||||
deleted = await cache.invalidate_user(user_id=context.user_id)
|
||||
deleted = await cache.invalidate_user(user_id=UUID(context.id))
|
||||
|
||||
assert deleted == 2
|
||||
assert f"agent:user-context:{s1}" in redis.delete_calls
|
||||
assert f"agent:user-context:{s2}" in redis.delete_calls
|
||||
assert f"agent:user-context:sessions:{context.user_id}" in redis.delete_calls
|
||||
assert any(
|
||||
key.startswith("agent:user-context:sessions:") for key in redis.delete_calls
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.runtime.consumer_registry import build_consumer_registry
|
||||
|
||||
|
||||
def test_build_consumer_registry_from_system_agent_configs() -> None:
|
||||
registry = build_consumer_registry(
|
||||
system_agent_configs={
|
||||
"router": {"config": {"visibility_consumer_bit": 16}},
|
||||
"worker": {"config": {"visibility_consumer_bit": 17}},
|
||||
"memory": {"config": {"visibility_consumer_bit": 18}},
|
||||
}
|
||||
)
|
||||
|
||||
assert registry.resolve_agent_bit(agent_type="router") == 16
|
||||
assert registry.resolve_agent_bit(agent_type="worker") == 17
|
||||
|
||||
|
||||
def test_build_consumer_registry_rejects_duplicate_bit() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate visibility bit"):
|
||||
build_consumer_registry(
|
||||
system_agent_configs={
|
||||
"router": {"config": {"visibility_consumer_bit": 16}},
|
||||
"worker": {"config": {"visibility_consumer_bit": 16}},
|
||||
}
|
||||
)
|
||||
@@ -28,7 +28,7 @@ def _user_context() -> UserContext:
|
||||
return UserContext(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
phone="+8613900000000",
|
||||
settings=parse_profile_settings(None),
|
||||
)
|
||||
|
||||
@@ -42,7 +42,7 @@ def _run_input() -> RunAgentInput:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -58,6 +58,7 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
|
||||
run_input=_run_input(),
|
||||
context_messages=[],
|
||||
user_context=_user_context(),
|
||||
system_agent_mode="worker",
|
||||
)
|
||||
|
||||
assert result["worker"]["answer"] == "done"
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_worker_has_two_stages() -> None:
|
||||
spec = build_default_pipeline_spec(mode="worker")
|
||||
|
||||
assert spec.mode == "worker"
|
||||
assert [item.stage_name for item in spec.stages] == ["router", "worker"]
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_memory_has_single_stage() -> None:
|
||||
spec = build_default_pipeline_spec(mode="memory")
|
||||
|
||||
assert spec.mode == "memory"
|
||||
assert [item.stage_name for item in spec.stages] == ["memory"]
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_rejects_unknown_mode() -> None:
|
||||
with pytest.raises(ValueError, match="unsupported pipeline mode"):
|
||||
build_default_pipeline_spec(mode="planner")
|
||||
@@ -3,8 +3,23 @@ from __future__ import annotations
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
import core.agentscope.runtime.runner as runner_module
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from schemas.automation.config import default_memory_job_config
|
||||
from schemas.agent.runtime_models import (
|
||||
ExecutionMode,
|
||||
NormalizedTaskInput,
|
||||
ResultType,
|
||||
ResultTyping,
|
||||
RouterAgentOutput,
|
||||
RouterUiDecision,
|
||||
TaskType,
|
||||
TaskTyping,
|
||||
UiMode,
|
||||
WorkerAgentOutputLite,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
def _run_input() -> RunAgentInput:
|
||||
@@ -16,19 +31,51 @@ def _run_input() -> RunAgentInput:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_stage_agent_type_defaults_to_worker() -> None:
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("") == AgentType.WORKER
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("worker") == AgentType.WORKER
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("unknown") == AgentType.WORKER
|
||||
def _user_context() -> UserContext:
|
||||
return UserContext(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="alice",
|
||||
phone="+8613900000000",
|
||||
settings=parse_profile_settings(None),
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_stage_agent_type_supports_memory() -> None:
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("memory") == AgentType.MEMORY
|
||||
def test_parse_agent_type_supports_known_stages() -> None:
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="router") == AgentType.ROUTER
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="worker") == AgentType.WORKER
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="memory") == AgentType.MEMORY
|
||||
|
||||
|
||||
def test_parse_agent_type_rejects_unknown_stage() -> None:
|
||||
with pytest.raises(ValueError, match="unsupported stage name"):
|
||||
AgentScopeRunner._parse_agent_type(stage_name="planner")
|
||||
|
||||
|
||||
def test_build_worker_input_messages_only_contains_router_contract() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
router_output = RouterAgentOutput(
|
||||
normalized_task_input=NormalizedTaskInput(user_text="安排明天会议"),
|
||||
key_entities=[],
|
||||
constraints=[],
|
||||
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||
execution_mode=ExecutionMode.TOOL_ASSISTED,
|
||||
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
|
||||
ui=RouterUiDecision(
|
||||
ui_mode=UiMode.NONE,
|
||||
ui_decision_reason="单一执行任务,文本输出足够",
|
||||
),
|
||||
)
|
||||
|
||||
input_messages = runner._build_worker_input_messages(router_output=router_output)
|
||||
|
||||
assert len(input_messages) == 1
|
||||
assert input_messages[0].role == "user"
|
||||
assert "[RouterAgentOutput]" in str(input_messages[0].content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -43,11 +90,12 @@ async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -55,3 +103,146 @@ async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
|
||||
resolved = runner._resolve_runtime_client_time(run_input=run_input)
|
||||
assert resolved is not None
|
||||
assert resolved.device_timezone == "America/Los_Angeles"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_worker_mode_runs_router_then_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
runner = AgentScopeRunner()
|
||||
load_calls: list[AgentType] = []
|
||||
|
||||
async def _fake_load_stage_config(*, session: object, agent_type: AgentType):
|
||||
del session
|
||||
load_calls.append(agent_type)
|
||||
return runner_module.SystemAgentRuntimeConfig(
|
||||
agent_type=agent_type,
|
||||
model_code="demo",
|
||||
api_base_url="https://example.com",
|
||||
api_key="test",
|
||||
llm_config=runner_module.SystemAgentLLMConfig(),
|
||||
)
|
||||
|
||||
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
|
||||
del kwargs
|
||||
return RouterAgentOutput(
|
||||
normalized_task_input=NormalizedTaskInput(user_text="安排会议"),
|
||||
key_entities=[],
|
||||
constraints=[],
|
||||
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||
execution_mode=ExecutionMode.TOOL_ASSISTED,
|
||||
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
|
||||
ui=RouterUiDecision(
|
||||
ui_mode=UiMode.NONE,
|
||||
ui_decision_reason="单任务",
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite:
|
||||
del kwargs
|
||||
return WorkerAgentOutputLite(answer="ok")
|
||||
|
||||
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
|
||||
monkeypatch.setattr(runner, "_build_stage_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
|
||||
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
|
||||
|
||||
result = await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="worker",
|
||||
)
|
||||
|
||||
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
|
||||
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
|
||||
assert result["worker"]["answer"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_memory_mode_requires_memory_job_config() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
with pytest.raises(RuntimeError, match="memory job config is required"):
|
||||
await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="memory",
|
||||
memory_job_config=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_memory_mode_uses_memory_job_config(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
async def _fake_build_memory_stage_config(**kwargs: object):
|
||||
del kwargs
|
||||
return runner_module.SystemAgentRuntimeConfig(
|
||||
agent_type=AgentType.MEMORY,
|
||||
model_code="qwen3.5-flash",
|
||||
api_base_url="https://example.com",
|
||||
api_key="test",
|
||||
llm_config=runner_module.SystemAgentLLMConfig(),
|
||||
)
|
||||
|
||||
async def _fake_execute_single_stage_step(**kwargs: object):
|
||||
del kwargs
|
||||
return runner_module.AgentOutput(answer="memory")
|
||||
|
||||
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(
|
||||
runner, "_build_memory_stage_config", _fake_build_memory_stage_config
|
||||
)
|
||||
monkeypatch.setattr(runner, "_build_stage_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_execute_single_stage_step",
|
||||
_fake_execute_single_stage_step,
|
||||
)
|
||||
|
||||
result = await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="memory",
|
||||
memory_job_config=default_memory_job_config(),
|
||||
)
|
||||
|
||||
assert result["memory"]["answer"] == "memory"
|
||||
|
||||
@@ -18,7 +18,7 @@ def _run_input_payload() -> dict[str, Any]:
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ async def _fake_user_context(**kwargs: object) -> UserContext:
|
||||
return UserContext(
|
||||
id=str(uuid4()),
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
phone="+8613900000000",
|
||||
settings=parse_profile_settings(None),
|
||||
)
|
||||
|
||||
@@ -177,17 +177,53 @@ async def test_run_agentscope_task_rejects_invalid_command_type() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_requires_forwarded_props_agent_type() -> None:
|
||||
payload = _run_input_payload()
|
||||
payload["forwardedProps"] = {}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": payload,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_memory_mode_requires_automation_job_id() -> None:
|
||||
payload = _run_input_payload()
|
||||
payload["forwardedProps"] = {"agent_type": "memory"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="automation_job_id is required for memory mode"
|
||||
):
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": payload,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeAgentService:
|
||||
async def load_agent_input_messages(
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
del repository
|
||||
|
||||
async def load_context_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
del thread_id, system_agent_mode
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -215,14 +251,13 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
return f"{bucket}:{path}".encode("utf-8")
|
||||
|
||||
monkeypatch.setattr(
|
||||
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
monkeypatch.setattr(tasks_module, "supabase_service", _FakeSupabase())
|
||||
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -238,13 +273,17 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeAgentService:
|
||||
async def load_agent_input_messages(
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
del repository
|
||||
|
||||
async def load_context_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
del thread_id, system_agent_mode
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -268,13 +307,12 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
]
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -290,13 +328,17 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeAgentService:
|
||||
async def load_agent_input_messages(
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
del repository
|
||||
|
||||
async def load_context_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
del thread_id, system_agent_mode
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -307,13 +349,44 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
]
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
)
|
||||
|
||||
assert messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_recent_context_messages_passes_context_mode_through(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured_mode: dict[str, str | None] = {"mode": None}
|
||||
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
del repository
|
||||
|
||||
async def load_context_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
captured_mode["mode"] = system_agent_mode
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
)
|
||||
|
||||
assert messages == []
|
||||
assert captured_mode["mode"] == "worker"
|
||||
|
||||
@@ -1,99 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.agentscope import schemas as exported_schemas
|
||||
from core.agentscope.schemas.agent_runtime import (
|
||||
AcceptedTaskResponse,
|
||||
AgUiWireEvent,
|
||||
HistorySnapshot,
|
||||
HistorySnapshotResponse,
|
||||
InternalRuntimeEvent,
|
||||
RunCommand,
|
||||
pytest.skip(
|
||||
"legacy agent_runtime schemas removed; covered by agui_input tests",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def test_run_command_alias_roundtrip() -> None:
|
||||
payload = {
|
||||
"threadId": "thread-001",
|
||||
"runId": "run-001",
|
||||
"state": {"cursor": 1},
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"tools": [{"name": "calendar.lookup"}],
|
||||
"context": {"locale": "zh-CN"},
|
||||
"forwardedProps": {"traceId": "trace-1"},
|
||||
}
|
||||
|
||||
command = RunCommand.model_validate(payload)
|
||||
|
||||
assert command.thread_id == "thread-001"
|
||||
assert command.run_id == "run-001"
|
||||
assert command.forwarded_props == {"traceId": "trace-1"}
|
||||
|
||||
dumped = command.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["threadId"] == "thread-001"
|
||||
assert dumped["runId"] == "run-001"
|
||||
assert dumped["forwardedProps"] == {"traceId": "trace-1"}
|
||||
|
||||
|
||||
def test_history_snapshot_response_shape() -> None:
|
||||
response = HistorySnapshotResponse(
|
||||
threadId="thread-123",
|
||||
snapshot=HistorySnapshot(
|
||||
threadId="thread-123",
|
||||
day="2026-03-11",
|
||||
hasMore=False,
|
||||
messages=[{"id": "msg-1"}],
|
||||
),
|
||||
)
|
||||
|
||||
dumped = response.model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
assert dumped["type"] == "STATE_SNAPSHOT"
|
||||
assert dumped["threadId"] == "thread-123"
|
||||
assert dumped["snapshot"]["scope"] == "history_day"
|
||||
assert dumped["snapshot"]["hasMore"] is False
|
||||
assert dumped["snapshot"]["messages"] == [{"id": "msg-1"}]
|
||||
|
||||
|
||||
def test_runtime_event_validation_basics() -> None:
|
||||
internal = InternalRuntimeEvent(type="RUN_STARTED", data={"step": 1})
|
||||
assert internal.type == "RUN_STARTED"
|
||||
assert internal.model_dump(mode="json", by_alias=True)["data"] == {"step": 1}
|
||||
|
||||
wire = AgUiWireEvent(type="TEXT_MESSAGE_CONTENT", payload={"delta": "hello"})
|
||||
dumped = wire.model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
assert dumped["type"] == "TEXT_MESSAGE_CONTENT"
|
||||
assert dumped["payload"] == {"delta": "hello"}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
InternalRuntimeEvent.model_validate({"threadId": "t-1", "data": {}})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AgUiWireEvent.model_validate({"payload": {"delta": "hello"}})
|
||||
|
||||
|
||||
def test_schemas_exports_include_task_and_history_models() -> None:
|
||||
assert exported_schemas.AcceptedTaskResponse is AcceptedTaskResponse
|
||||
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
|
||||
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
|
||||
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
|
||||
|
||||
|
||||
def test_run_command_accepts_agui_context_list_and_parent_run_id() -> None:
|
||||
payload = {
|
||||
"threadId": "thread-xyz",
|
||||
"runId": "run-xyz",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"parentRunId": None,
|
||||
}
|
||||
|
||||
command = RunCommand.model_validate(payload)
|
||||
|
||||
dumped = command.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["context"] == []
|
||||
assert "parentRunId" in dumped
|
||||
|
||||
@@ -20,7 +20,7 @@ def _base_payload() -> dict[str, object]:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ def test_parse_run_input_accepts_snake_case_aliases() -> None:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwarded_props": {},
|
||||
"forwarded_props": {"agent_type": "worker"},
|
||||
}
|
||||
|
||||
run_input = parse_run_input(payload)
|
||||
@@ -162,11 +162,12 @@ def test_parse_run_input_accepts_snake_case_aliases() -> None:
|
||||
def test_parse_run_input_accepts_client_time_forwarded_props() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
run_input = parse_run_input(payload)
|
||||
@@ -177,11 +178,12 @@ def test_parse_run_input_accepts_client_time_forwarded_props() -> None:
|
||||
def test_parse_run_input_rejects_invalid_client_time_timezone() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "Mars/OlympusMons",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
@@ -191,11 +193,12 @@ def test_parse_run_input_rejects_invalid_client_time_timezone() -> None:
|
||||
def test_parse_run_input_rejects_invalid_client_time_now_iso() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16 09:12:33",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
@@ -205,11 +208,12 @@ def test_parse_run_input_rejects_invalid_client_time_now_iso() -> None:
|
||||
def test_parse_run_input_rejects_invalid_client_time_epoch_type() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": "1773658353000",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
@@ -219,6 +223,7 @@ def test_parse_run_input_rejects_invalid_client_time_epoch_type() -> None:
|
||||
def test_parse_run_input_rejects_unknown_forwarded_props_key() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
@@ -229,3 +234,17 @@ def test_parse_run_input_rejects_unknown_forwarded_props_key() -> None:
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_missing_forwarded_props_agent_type() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
|
||||
@@ -11,7 +11,7 @@ def test_build_agent_prompt_for_worker_contains_runtime_config() -> None:
|
||||
{
|
||||
"temperature": 0.2,
|
||||
"context_messages": {"mode": "number", "count": 20},
|
||||
"enabled_tool_groups": ["read", "write"],
|
||||
"enabled_tools": ["calendar.read", "calendar.write"],
|
||||
}
|
||||
),
|
||||
)
|
||||
@@ -20,7 +20,7 @@ def test_build_agent_prompt_for_worker_contains_runtime_config() -> None:
|
||||
assert "- type: worker" in prompt
|
||||
assert "context_messages.mode=number" in prompt
|
||||
assert "context_messages.count=20" in prompt
|
||||
assert "enabled_tool_groups=read,write" in prompt
|
||||
assert "enabled_tools=calendar.read,calendar.write" in prompt
|
||||
|
||||
|
||||
def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
|
||||
@@ -29,7 +29,7 @@ def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
|
||||
llm_config=SystemAgentLLMConfig.model_validate(
|
||||
{
|
||||
"context_messages": {"mode": "day", "count": 2},
|
||||
"enabled_tool_groups": ["read"],
|
||||
"enabled_tools": ["user.lookup"],
|
||||
}
|
||||
),
|
||||
)
|
||||
@@ -38,4 +38,4 @@ def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
|
||||
assert "[Memory Agent]" in prompt
|
||||
assert "context_messages.mode=day" in prompt
|
||||
assert "context_messages.count=2" in prompt
|
||||
assert "enabled_tool_groups=read" in prompt
|
||||
assert "enabled_tools=user.lookup" in prompt
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.tools.hitl_middleware import create_hitl_middleware
|
||||
from core.agentscope.tools.tool_meta import TOOL_META, ToolMeta
|
||||
from core.agentscope.tools.tool_config import ToolApprovalConfig, ToolConfig, ToolGroup
|
||||
from core.agentscope.tools.tool_middleware import create_approval_middleware
|
||||
|
||||
|
||||
async def _next_handler(**kwargs: Any) -> AsyncGenerator[dict[str, object], None]:
|
||||
@@ -15,9 +16,31 @@ async def _next_handler(**kwargs: Any) -> AsyncGenerator[dict[str, object], None
|
||||
return _generator()
|
||||
|
||||
|
||||
def _extract_error_payload(chunk: object) -> dict[str, Any]:
|
||||
content = getattr(chunk, "content", None)
|
||||
if not isinstance(content, list) or not content:
|
||||
return {}
|
||||
first_block = content[0]
|
||||
text = getattr(first_block, "text", None)
|
||||
if not isinstance(text, str) and isinstance(first_block, dict):
|
||||
raw_text = first_block.get("text")
|
||||
text = raw_text if isinstance(raw_text, str) else None
|
||||
if not isinstance(text, str):
|
||||
return {}
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_default_write_does_not_require_approval() -> None:
|
||||
middleware = create_hitl_middleware(meta_by_name=TOOL_META)
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
@@ -30,36 +53,39 @@ async def test_hitl_middleware_default_write_does_not_require_approval() -> None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_pending_when_tool_requires_approval(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
middleware = create_hitl_middleware(
|
||||
meta_by_name={
|
||||
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
|
||||
async def test_hitl_middleware_pending_when_tool_requires_approval() -> None:
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.tools.hitl_middleware.build_tool_response",
|
||||
lambda payload: payload,
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
{"tool_call": {"name": "calendar.write", "input": {"operation": "create"}}},
|
||||
{"tool_call": {"name": "calendar_write", "input": {"operation": "create"}}},
|
||||
_next_handler,
|
||||
):
|
||||
responses.append(chunk)
|
||||
|
||||
assert responses[0]["data"]["status"] == "pending"
|
||||
payload = _extract_error_payload(responses[0])
|
||||
assert payload["error"]["code"] == "TOOL_PENDING_APPROVAL"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_passes_when_write_approved() -> None:
|
||||
middleware = create_hitl_middleware(
|
||||
meta_by_name={
|
||||
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
},
|
||||
approval_resolver=lambda _name, _args: "approved",
|
||||
approval_resolver=lambda _name, _args, _config: "approved",
|
||||
)
|
||||
|
||||
responses = []
|
||||
@@ -69,6 +95,7 @@ async def test_hitl_middleware_passes_when_write_approved() -> None:
|
||||
"name": "calendar.write",
|
||||
"input": {
|
||||
"operation": "create",
|
||||
"_hitl": {"approval": "required"},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -82,25 +109,24 @@ async def test_hitl_middleware_passes_when_write_approved() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_rejected_short_circuits(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
middleware = create_hitl_middleware(
|
||||
meta_by_name={
|
||||
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
|
||||
async def test_hitl_middleware_rejected_short_circuits() -> None:
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
},
|
||||
approval_resolver=lambda _name, _args: "rejected",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.tools.hitl_middleware.build_tool_response",
|
||||
lambda payload: payload,
|
||||
approval_resolver=lambda _name, _args, _config: "rejected",
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
{"tool_call": {"name": "calendar.write", "input": {"operation": "create"}}},
|
||||
{"tool_call": {"name": "calendar_write", "input": {"operation": "create"}}},
|
||||
_next_handler,
|
||||
):
|
||||
responses.append(chunk)
|
||||
|
||||
assert responses[0]["data"]["status"] == "rejected"
|
||||
payload = _extract_error_payload(responses[0])
|
||||
assert payload["error"]["code"] == "TOOL_REJECTED"
|
||||
|
||||
@@ -16,7 +16,7 @@ def _build_user_context(*, timezone_name: str = "Asia/Shanghai") -> UserContext:
|
||||
return UserContext(
|
||||
id=str(uuid4()),
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
phone="+8613900000000",
|
||||
bio="focus on calendars",
|
||||
settings=parse_profile_settings(
|
||||
{
|
||||
|
||||
@@ -7,7 +7,9 @@ from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
def test_build_stage_toolkit_filters_requested_tools_by_agent_type(monkeypatch) -> None:
|
||||
def test_build_stage_toolkit_uses_explicit_enabled_tools_as_final_set(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_build_toolkit(**kwargs):
|
||||
@@ -22,7 +24,29 @@ def test_build_stage_toolkit_filters_requested_tools_by_agent_type(monkeypatch)
|
||||
agent_type=AgentType.WORKER,
|
||||
session=cast(Any, object()),
|
||||
owner_id=uuid4(),
|
||||
enabled_tool_names={"calendar_read", "calendar_write", "user_lookup"},
|
||||
enabled_tool_names={"calendar_read", "user_lookup"},
|
||||
)
|
||||
|
||||
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
|
||||
|
||||
|
||||
def test_build_stage_toolkit_uses_memory_defaults_without_explicit_tools(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_build_toolkit(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.tools.toolkit.build_toolkit", _fake_build_toolkit
|
||||
)
|
||||
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.MEMORY,
|
||||
session=cast(Any, object()),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
|
||||
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
_compute_next_run_at,
|
||||
)
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation.scheduler import DueAutomationJob
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self, jobs: list[DueAutomationJob]) -> None:
|
||||
self.jobs = jobs
|
||||
self.marked: list[tuple[UUID, datetime, datetime]] = []
|
||||
self.commits = 0
|
||||
self.rollbacks = 0
|
||||
|
||||
async def list_due_jobs(
|
||||
self, *, now_utc: datetime, limit: int
|
||||
) -> list[DueAutomationJob]:
|
||||
del now_utc
|
||||
return self.jobs[:limit]
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
|
||||
del job_id
|
||||
return AutomationJobConfig.model_validate(
|
||||
{
|
||||
"agent_type": "memory",
|
||||
"model_code": "qwen3.5-flash",
|
||||
"enabled_tools": ["calendar.read", "user.lookup"],
|
||||
"input_template": "auto input",
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
return owner_id
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None:
|
||||
self.marked.append((job_id, next_run_at, last_run_at))
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commits += 1
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollbacks += 1
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
self.commands: list[dict[str, object]] = []
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str:
|
||||
del dedup_key
|
||||
self.commands.append(command)
|
||||
return "task-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_and_dispatch_enqueues_memory_run_command() -> None:
|
||||
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repo = _FakeRepository(
|
||||
jobs=[
|
||||
DueAutomationJob(
|
||||
id=job_id,
|
||||
owner_id=owner_id,
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
timezone="UTC",
|
||||
next_run_at=now - timedelta(minutes=1),
|
||||
)
|
||||
]
|
||||
)
|
||||
queue = _FakeQueue()
|
||||
service = AutomationSchedulerService(repository=repo, queue=queue)
|
||||
|
||||
result = await service.scan_and_dispatch(now_utc=now, limit=10)
|
||||
|
||||
assert result.scanned == 1
|
||||
assert result.dispatched == 1
|
||||
assert len(queue.commands) == 1
|
||||
run_input = queue.commands[0]["run_input"]
|
||||
assert isinstance(run_input, dict)
|
||||
assert run_input["forwardedProps"] == {"agent_type": "memory"}
|
||||
assert queue.commands[0]["automation_job_id"] == str(job_id)
|
||||
assert repo.commits == 1
|
||||
|
||||
|
||||
def test_compute_next_run_at_daily() -> None:
|
||||
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
|
||||
current = datetime(2026, 3, 19, 11, 0, tzinfo=timezone.utc)
|
||||
|
||||
computed = _compute_next_run_at(
|
||||
current_next_run_at=current,
|
||||
now_utc=now,
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
)
|
||||
|
||||
assert computed == datetime(2026, 3, 20, 11, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_compute_next_run_at_weekly() -> None:
|
||||
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
|
||||
current = datetime(2026, 3, 10, 11, 0, tzinfo=timezone.utc)
|
||||
|
||||
computed = _compute_next_run_at(
|
||||
current_next_run_at=current,
|
||||
now_utc=now,
|
||||
schedule_type=ScheduleType.WEEKLY,
|
||||
)
|
||||
|
||||
assert computed == datetime(2026, 3, 24, 11, 0, tzinfo=timezone.utc)
|
||||
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_memory_automation_job_trigger_exists_in_0004_migration() -> None:
|
||||
migration = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "alembic"
|
||||
/ "versions"
|
||||
/ "20260319_0004_automation_job_config_for_memory.py"
|
||||
)
|
||||
content = migration.read_text(encoding="utf-8")
|
||||
|
||||
assert "INSERT INTO public.automation_jobs" in content
|
||||
assert "'agent_type', 'memory'" in content
|
||||
assert "ux_automation_jobs_owner_memory_active" in content
|
||||
assert "input_template" in content
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
|
||||
from schemas.agent.pipeline_spec import (
|
||||
ContextPolicy,
|
||||
ExecutorKind,
|
||||
PipelineSpec,
|
||||
StageSpec,
|
||||
)
|
||||
|
||||
|
||||
def test_consumer_registry_rejects_duplicate_bits() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate visibility bit"):
|
||||
ConsumerRegistry(
|
||||
bindings=[
|
||||
AgentConsumerBinding(agent_type="router", bit=16),
|
||||
AgentConsumerBinding(agent_type="worker", bit=16),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_pipeline_spec_requires_non_empty_stages() -> None:
|
||||
with pytest.raises(ValueError, match="at least 1 item"):
|
||||
PipelineSpec(mode="worker", stages=[])
|
||||
|
||||
|
||||
def test_stage_spec_normalizes_stage_name() -> None:
|
||||
spec = StageSpec(
|
||||
stage_name=" Worker ",
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=1,
|
||||
context_policy=ContextPolicy(consumer_agent_type="worker", count=20),
|
||||
)
|
||||
|
||||
assert spec.stage_name == "worker"
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
|
||||
|
||||
def test_system_agent_llm_config_normalizes_enabled_tools_aliases() -> None:
|
||||
config = SystemAgentLLMConfig.model_validate(
|
||||
{
|
||||
"enabled_tools": [
|
||||
"calendar.write",
|
||||
"calendar_write",
|
||||
"user.lookup",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert [tool.value for tool in config.enabled_tools] == [
|
||||
"calendar.write",
|
||||
"user.lookup",
|
||||
]
|
||||
|
||||
|
||||
def test_system_agent_llm_config_rejects_unknown_enabled_tool() -> None:
|
||||
with pytest.raises(ValueError, match="unknown enabled tool"):
|
||||
SystemAgentLLMConfig.model_validate(
|
||||
{
|
||||
"enabled_tools": ["calendar.remove"],
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.agent.visibility import VisibilityMask, bit_mask
|
||||
|
||||
|
||||
def test_visibility_mask_from_bits_and_contains() -> None:
|
||||
mask = VisibilityMask.from_bits(bits=[0, 16, 18])
|
||||
|
||||
assert mask.contains(bit=0) is True
|
||||
assert mask.contains(bit=16) is True
|
||||
assert mask.contains(bit=17) is False
|
||||
|
||||
|
||||
def test_visibility_mask_rejects_out_of_range_bit() -> None:
|
||||
with pytest.raises(ValueError, match="range"):
|
||||
VisibilityMask.from_bits(bits=[64])
|
||||
|
||||
|
||||
def test_bit_mask_builds_single_bit_integer() -> None:
|
||||
assert bit_mask(bit=0) == 1
|
||||
assert bit_mask(bit=16) == (1 << 16)
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.automation.config import AutomationJobConfig, default_memory_job_config
|
||||
|
||||
|
||||
def test_default_memory_job_config_has_expected_defaults() -> None:
|
||||
config = default_memory_job_config()
|
||||
|
||||
assert config.agent_type.value == "memory"
|
||||
assert config.model_code == "qwen3.5-flash"
|
||||
assert config.context.source.value == "latest_chat"
|
||||
|
||||
|
||||
def test_automation_job_config_rejects_non_flash_model() -> None:
|
||||
with pytest.raises(ValueError, match="model_code must be qwen3.5-flash"):
|
||||
AutomationJobConfig.model_validate(
|
||||
{
|
||||
"agent_type": "memory",
|
||||
"model_code": "qwen-plus",
|
||||
"enabled_tools": ["calendar.read"],
|
||||
"input_template": "x",
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user