feat(agentscope): add memory system and automation job support
- Add consumer_registry and pipeline_registry for runtime orchestration - Add Visibility schema for message filtering - Add PipelineSpec for agent pipeline configuration - Add automation job models and configuration - Remove memory_prompt.py (consolidated into memory system) - Update runtime components: context_loader, context_service, orchestrator, runner, tasks - Update toolkit: tool_config, tool_middleware, custom tools (calendar, user_lookup) - Add auth_helpers and calendar_domain utilities - Add system_agents.yaml configuration
This commit is contained in:
@@ -10,8 +10,6 @@ from ag_ui.core import (
|
||||
RunErrorEvent,
|
||||
StepStartedEvent,
|
||||
StepFinishedEvent,
|
||||
TextMessageEndEvent,
|
||||
ToolCallResultEvent,
|
||||
)
|
||||
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
|
||||
from schemas.agent.ui_hints import UiHintsPayload
|
||||
@@ -94,9 +92,20 @@ def _build_run_finished(event: dict[str, Any]) -> RunFinishedEvent:
|
||||
|
||||
def _build_run_error(event: dict[str, Any]) -> RunErrorEvent:
|
||||
data = event.get("data", {})
|
||||
top_level_message = event.get("message")
|
||||
message = top_level_message if isinstance(top_level_message, str) else ""
|
||||
top_level_code = event.get("code")
|
||||
code = top_level_code if isinstance(top_level_code, str) else None
|
||||
if (not message or code is None) and isinstance(data, dict):
|
||||
data_message = data.get("message")
|
||||
if not message and isinstance(data_message, str):
|
||||
message = data_message
|
||||
data_code = data.get("code")
|
||||
if code is None and isinstance(data_code, str):
|
||||
code = data_code
|
||||
return RunErrorEvent(
|
||||
message=data.get("message", "Unknown error"),
|
||||
code=data.get("code"),
|
||||
message=message or "Unknown error",
|
||||
code=code,
|
||||
)
|
||||
|
||||
|
||||
@@ -120,34 +129,12 @@ def _build_step_finished(event: dict[str, Any]) -> StepFinishedEvent:
|
||||
)
|
||||
|
||||
|
||||
def _build_text_end(event: dict[str, Any]) -> TextMessageEndEvent:
|
||||
data = event.get("data", {})
|
||||
return TextMessageEndEvent(
|
||||
message_id=data.get("messageId", ""),
|
||||
)
|
||||
|
||||
|
||||
def _build_tool_result(event: dict[str, Any]) -> ToolCallResultEvent:
|
||||
data = event.get("data", {})
|
||||
content = data.get("result")
|
||||
if not isinstance(content, str):
|
||||
content = ""
|
||||
return ToolCallResultEvent(
|
||||
message_id=data.get("messageId", ""),
|
||||
tool_call_id=data.get("toolCallId", ""),
|
||||
content=content,
|
||||
role="tool",
|
||||
)
|
||||
|
||||
|
||||
_BUILDER_MAP: dict[str, Any] = {
|
||||
"run.started": _build_run_started,
|
||||
"run.finished": _build_run_finished,
|
||||
"run.error": _build_run_error,
|
||||
"step.start": _build_step_started,
|
||||
"step.finish": _build_step_finished,
|
||||
"text.end": _build_text_end,
|
||||
"tool.result": _build_tool_result,
|
||||
}
|
||||
|
||||
|
||||
@@ -208,6 +195,8 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
|
||||
payload["runId"] = run_id
|
||||
if isinstance(data, dict):
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
if internal_type == "run.error":
|
||||
reserved = {*reserved, "message", "code"}
|
||||
payload.update({k: v for k, v in data.items() if k not in reserved})
|
||||
return payload
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ class MessageRepository:
|
||||
output_tokens: int = 0,
|
||||
cost: Decimal = Decimal("0"),
|
||||
latency_ms: int | None = None,
|
||||
visibility_mask: int = 0,
|
||||
) -> AgentChatMessage:
|
||||
message = AgentChatMessage(
|
||||
session_id=session_id,
|
||||
@@ -42,6 +43,7 @@ class MessageRepository:
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=max(int(visibility_mask), 0),
|
||||
)
|
||||
self._session.add(message)
|
||||
await self._session.flush()
|
||||
|
||||
@@ -8,9 +8,12 @@ from core.agentscope.events.persistence import MessageRepository, SessionReposit
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.runtime_models import AgentOutput, ToolAgentOutput
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
class EventStore(Protocol):
|
||||
@@ -45,6 +48,9 @@ class SqlAlchemyEventStore:
|
||||
async with self._session_factory() as session:
|
||||
session_repo = SessionRepository(session)
|
||||
message_repo = MessageRepository(session)
|
||||
stage_visibility_bit_map = await self._load_stage_visibility_bit_map(
|
||||
session=session
|
||||
)
|
||||
chat_session = await session_repo.get_session(session_id=session_id)
|
||||
if chat_session is None:
|
||||
return
|
||||
@@ -77,6 +83,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
await self._persist_tool_call_result(
|
||||
@@ -85,6 +92,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
@@ -97,6 +105,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> None:
|
||||
message_id_raw = self._event_value(event, "messageId")
|
||||
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
|
||||
@@ -188,6 +197,10 @@ class SqlAlchemyEventStore:
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -213,6 +226,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> None:
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else None
|
||||
@@ -256,6 +270,10 @@ class SqlAlchemyEventStore:
|
||||
content=content,
|
||||
tool_name=tool_output.tool_name,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -279,6 +297,44 @@ class SqlAlchemyEventStore:
|
||||
return AgentChatMessageRole.TOOL
|
||||
return AgentChatMessageRole.ASSISTANT
|
||||
|
||||
def _resolve_stage_visibility_mask(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> int:
|
||||
base = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
raw_stage = self._event_value(event, "stage")
|
||||
if not isinstance(raw_stage, str):
|
||||
return base
|
||||
normalized_stage = raw_stage.strip().lower()
|
||||
bit = stage_visibility_bit_map.get(normalized_stage)
|
||||
if bit is None and normalized_stage == AgentType.MEMORY.value:
|
||||
bit = 18
|
||||
if bit is None:
|
||||
return base
|
||||
return base | bit_mask(bit=bit)
|
||||
|
||||
async def _load_stage_visibility_bit_map(
|
||||
self,
|
||||
*,
|
||||
session: Any,
|
||||
) -> dict[str, int]:
|
||||
stmt = select(SystemAgents.agent_type, SystemAgents.config).where(
|
||||
SystemAgents.agent_type.in_(
|
||||
[AgentType.ROUTER.value, AgentType.WORKER.value, AgentType.MEMORY.value]
|
||||
)
|
||||
)
|
||||
rows = (await session.execute(stmt)).all()
|
||||
bit_map: dict[str, int] = {}
|
||||
for agent_type, raw_config in rows:
|
||||
if not isinstance(agent_type, str):
|
||||
continue
|
||||
config_payload = raw_config if isinstance(raw_config, dict) else {}
|
||||
llm_config = SystemAgentLLMConfig.model_validate(config_payload)
|
||||
bit_map[agent_type.strip().lower()] = llm_config.visibility_consumer_bit
|
||||
return bit_map
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from schemas.agent.runtime_models import ResultType, RouterAgentOutput, TaskType
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
|
||||
|
||||
@@ -14,6 +17,24 @@ def _wrap_section(section: str, content: str) -> str:
|
||||
return f"{start}\n{body}\n{end}" if body else f"{start}\n{end}"
|
||||
|
||||
|
||||
def _enum_values(enum_cls: Any) -> str:
|
||||
return ", ".join(item.value for item in enum_cls)
|
||||
|
||||
|
||||
def _config_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
if llm_config is None:
|
||||
return []
|
||||
context_mode = llm_config.context_messages.mode.value
|
||||
context_count = llm_config.context_messages.count
|
||||
enabled_tools = [tool.value for tool in llm_config.enabled_tools]
|
||||
return [
|
||||
"[Runtime Config]",
|
||||
f"- context_messages.mode={context_mode}",
|
||||
f"- context_messages.count={context_count}",
|
||||
f"- enabled_tools={','.join(enabled_tools) if enabled_tools else 'default'}",
|
||||
]
|
||||
|
||||
|
||||
PromptRuleBuilder = Callable[[SystemAgentLLMConfig | None], list[str]]
|
||||
|
||||
|
||||
@@ -36,36 +57,41 @@ class AgentPromptRegistry:
|
||||
return builder(llm_config)
|
||||
|
||||
|
||||
def _config_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
if llm_config is None:
|
||||
return []
|
||||
context_mode = llm_config.context_messages.mode.value
|
||||
context_count = llm_config.context_messages.count
|
||||
tool_groups = [group.value for group in llm_config.enabled_tool_groups]
|
||||
def _router_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
return [
|
||||
"[Runtime Config]",
|
||||
f"- context_messages.mode={context_mode}",
|
||||
f"- context_messages.count={context_count}",
|
||||
f"- enabled_tool_groups={','.join(tool_groups) if tool_groups else 'default'}",
|
||||
"[Router Agent]",
|
||||
"- Read the latest user input and produce a routing contract for downstream execution.",
|
||||
"- Return exactly one RouterAgentOutput JSON object.",
|
||||
"[Responsibilities]",
|
||||
"- Router only: extract intent and route strategy; never answer user directly.",
|
||||
"- Preserve intent in normalized_task_input.user_text; keep wording concise and faithful.",
|
||||
"- Fill multimodal_summary only when image/attachment changes execution decisions.",
|
||||
"- Return key_entities and constraints that are execution-relevant; low confidence -> omit rather than guess.",
|
||||
"- Set execution_mode by complexity: onestep / tool_assisted / multistep.",
|
||||
"- Set result_typing.primary to the most suitable response shape; use clarification_request only when required info is missing.",
|
||||
"- Set ui.ui_mode and ui.ui_decision_reason based on whether structured UI improves actionability.",
|
||||
f"- task_typing.primary must use one TaskType enum: {_enum_values(TaskType)}.",
|
||||
f"- task_typing.secondary max 3 enums: {_enum_values(TaskType)}.",
|
||||
f"- result_typing.primary must use one ResultType enum: {_enum_values(ResultType)}.",
|
||||
f"- result_typing.secondary max 3 enums: {_enum_values(ResultType)}.",
|
||||
*_config_rules(llm_config),
|
||||
]
|
||||
|
||||
|
||||
def _worker_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
return [
|
||||
"[Worker Agent]",
|
||||
"- Process user context directly, identify intent, then execute or answer with available evidence.",
|
||||
"- Return exactly one agent output JSON object matching the runtime-injected schema.",
|
||||
"- Execute or answer against the routed objective and available evidence.",
|
||||
"- Return exactly one worker output JSON object matching the runtime-injected schema.",
|
||||
"[Responsibilities]",
|
||||
"- Handle user request directly from conversation context.",
|
||||
"- Decide intent first, then choose direct answer, tool call, clarification, or refusal.",
|
||||
"- Prefer a direct answer when no tool result is required.",
|
||||
"- Call tools only when tool results are necessary to produce a correct answer.",
|
||||
"- Infer deterministic required tool arguments from context, tool schema, and runtime signals.",
|
||||
"- Worker only: execute routed objective without changing router intent.",
|
||||
"- Treat router output as objective/constraints contract, not as a fully-materialized tool-args payload.",
|
||||
"- Infer deterministic required tool arguments from contract fields, tool schema, and runtime context.",
|
||||
"- Ask minimal clarification only when required arguments cannot be inferred safely.",
|
||||
"- If request is unsafe or disallowed, return safe refusal with actionable alternative.",
|
||||
"- Ground every claim in evidence and tool outputs; never fabricate execution state.",
|
||||
"- Ground every claim in available evidence and tool results; never fabricate execution state.",
|
||||
"- Keep status/result_type/answer/key_points/suggested_actions/error internally consistent.",
|
||||
"[Schema Guidance]",
|
||||
"- The output schema is injected at runtime; follow it exactly.",
|
||||
"- The worker output schema is injected at runtime; follow it exactly.",
|
||||
"- Do not add fields that are not present in the injected schema.",
|
||||
*_config_rules(llm_config),
|
||||
]
|
||||
@@ -88,7 +114,28 @@ def _memory_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
]
|
||||
|
||||
|
||||
def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
|
||||
contract_json = json.dumps(
|
||||
router_output.model_dump(mode="json", exclude_none=True),
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return "\n".join(
|
||||
[
|
||||
"[Worker Contract]",
|
||||
"- Keep routed objective unchanged.",
|
||||
"- Use normalized_task_input as objective text.",
|
||||
"- Use multimodal_summary/key_entities/constraints as execution evidence.",
|
||||
"- Infer deterministic missing required tool args from evidence + tool schema.",
|
||||
"- Ask clarification only when safe inference is impossible.",
|
||||
"[RouterAgentOutput]",
|
||||
contract_json,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
AGENT_PROMPT_REGISTRY = AgentPromptRegistry()
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.ROUTER, builder=_router_rules)
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.WORKER, builder=_worker_rules)
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.MEMORY, builder=_memory_rules)
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
|
||||
|
||||
|
||||
def build_consumer_registry(
|
||||
*,
|
||||
system_agent_configs: dict[str, dict[str, object]],
|
||||
) -> ConsumerRegistry:
|
||||
bindings: list[AgentConsumerBinding] = []
|
||||
for agent_type, payload in system_agent_configs.items():
|
||||
config_obj = payload.get("config") if isinstance(payload, dict) else None
|
||||
if not isinstance(config_obj, dict):
|
||||
raise ValueError(f"invalid system agent config: {agent_type}")
|
||||
raw_bit = config_obj.get("visibility_consumer_bit")
|
||||
if not isinstance(raw_bit, int):
|
||||
raise ValueError(f"visibility_consumer_bit missing for agent: {agent_type}")
|
||||
bindings.append(AgentConsumerBinding(agent_type=agent_type, bit=raw_bit))
|
||||
return ConsumerRegistry(bindings=bindings)
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
|
||||
from schemas.agent.system_agent import ContextBuildStrategy
|
||||
|
||||
ContextLoader = Callable[[Any, str, int], Awaitable[dict[str, object] | None]]
|
||||
ContextLoader = Callable[[Any, str, int, int], Awaitable[dict[str, object] | None]]
|
||||
|
||||
|
||||
class ContextLoaderRegistry:
|
||||
@@ -23,20 +23,28 @@ class ContextLoaderRegistry:
|
||||
|
||||
|
||||
async def _load_number(
|
||||
service: Any, thread_id: str, count: int
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
async def _load_day(
|
||||
service: Any, thread_id: str, count: int
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,17 +5,27 @@ from typing import Protocol
|
||||
|
||||
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.visibility import bit_mask
|
||||
|
||||
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
|
||||
_DEFAULT_ROUTER_CONTEXT_DAY_COUNT = 20
|
||||
|
||||
|
||||
class ContextRepositoryLike(Protocol):
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_recent_messages_by_user_window(
|
||||
self, *, session_id: str, user_message_limit: int
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
@@ -41,20 +51,36 @@ class AgentContextService:
|
||||
if isinstance(raw_config, dict):
|
||||
raw_llm_config = raw_config
|
||||
|
||||
if mode == "router" and not raw_llm_config:
|
||||
raw_llm_config = {
|
||||
"context_messages": {
|
||||
"mode": "day",
|
||||
"count": _DEFAULT_ROUTER_CONTEXT_DAY_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
normalized_config = self._normalize_system_agent_config(raw_llm_config)
|
||||
context_config = normalized_config.context_messages
|
||||
visibility_mask = bit_mask(bit=normalized_config.visibility_consumer_bit)
|
||||
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
|
||||
return await context_loader(self, thread_id, context_config.count)
|
||||
return await context_loader(
|
||||
self,
|
||||
thread_id,
|
||||
context_config.count,
|
||||
visibility_mask,
|
||||
)
|
||||
|
||||
async def load_by_user_message_window(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
messages = await self._repository.get_recent_messages_by_user_window(
|
||||
session_id=thread_id,
|
||||
user_message_limit=max(int(user_message_limit), 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if not messages:
|
||||
return None
|
||||
@@ -65,6 +91,7 @@ class AgentContextService:
|
||||
*,
|
||||
thread_id: str,
|
||||
day_count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
messages: list[dict[str, object]] = []
|
||||
before: date | None = None
|
||||
@@ -72,6 +99,7 @@ class AgentContextService:
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if not day_payload:
|
||||
break
|
||||
@@ -95,7 +123,7 @@ class AgentContextService:
|
||||
"mode": "number",
|
||||
"count": _DEFAULT_CONTEXT_WINDOW_USER_MESSAGES,
|
||||
},
|
||||
"enabled_tool_groups": [],
|
||||
"enabled_tools": [],
|
||||
}
|
||||
if not raw_config:
|
||||
return SystemAgentLLMConfig.model_validate(default_payload)
|
||||
|
||||
@@ -6,6 +6,7 @@ from ag_ui.core.types import RunAgentInput
|
||||
from agentscope.message import Msg
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from core.logging import get_logger
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.user import UserContext
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.orchestrator")
|
||||
@@ -24,6 +25,7 @@ class RunnerLike(Protocol):
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@@ -47,6 +49,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
context_messages: list[Msg],
|
||||
user_context: UserContext,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
@@ -66,6 +69,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
pipeline=self._pipeline,
|
||||
run_input=run_input,
|
||||
system_agent_mode=system_agent_mode,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.pipeline_spec import (
|
||||
ContextPolicy,
|
||||
ContextWindowMode,
|
||||
ExecutorKind,
|
||||
PipelineSpec,
|
||||
StageSpec,
|
||||
)
|
||||
|
||||
|
||||
def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
|
||||
normalized = mode.strip().lower()
|
||||
if normalized == "worker":
|
||||
return PipelineSpec(
|
||||
mode="worker",
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="router",
|
||||
executor_kind=ExecutorKind.SINGLE_SHOT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="router",
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
count=20,
|
||||
),
|
||||
),
|
||||
StageSpec(
|
||||
stage_name="worker",
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="worker",
|
||||
window_mode=ContextWindowMode.NUMBER,
|
||||
count=20,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if normalized == "memory":
|
||||
return PipelineSpec(
|
||||
mode="memory",
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="memory",
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="memory",
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
count=20,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
raise ValueError(f"unsupported pipeline mode: {normalized}")
|
||||
@@ -10,26 +10,40 @@ from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||||
from core.agentscope.runtime.model_tracking import TrackingChatModel
|
||||
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
|
||||
from core.agentscope.runtime.tool_selection_registry import TOOL_SELECTION_REGISTRY
|
||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from core.agentscope.utils import patch_agentscope_json_repair_compat
|
||||
from core.agentscope.utils import (
|
||||
finalize_json_response,
|
||||
patch_agentscope_json_repair_compat,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
)
|
||||
from schemas.agent.forwarded_props import (
|
||||
ClientTimeContext,
|
||||
parse_forwarded_props_client_time,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
RouterAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import (
|
||||
AgentType,
|
||||
ContextMessagesConfig,
|
||||
ContextBuildStrategy,
|
||||
SystemAgentLLMConfig,
|
||||
)
|
||||
from schemas.user import UserContext
|
||||
from services.litellm.service import LiteLLMService
|
||||
from sqlalchemy import select
|
||||
@@ -46,6 +60,7 @@ class SystemAgentRuntimeConfig:
|
||||
api_base_url: str
|
||||
api_key: str
|
||||
llm_config: SystemAgentLLMConfig
|
||||
extra_context: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -68,33 +83,83 @@ class AgentScopeRunner:
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
stage_agent_type = self._resolve_stage_agent_type(system_agent_mode)
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
stage_agent_types = [
|
||||
self._parse_agent_type(stage_name=stage.stage_name)
|
||||
for stage in pipeline_spec.stages
|
||||
]
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
if stage_agent_types == [AgentType.ROUTER, AgentType.WORKER]:
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_stage_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
stage_config=worker_config,
|
||||
)
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
if stage_agent_types[0] == AgentType.MEMORY:
|
||||
if memory_job_config is None:
|
||||
raise RuntimeError("memory job config is required")
|
||||
stage_config = await self._build_memory_stage_config(
|
||||
session=session,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
else:
|
||||
stage_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=stage_agent_type,
|
||||
agent_type=stage_agent_types[0],
|
||||
)
|
||||
stage_toolkit = self._build_stage_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
stage_config=stage_config,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
stage_output = await self._execute_single_stage_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
input_messages=context_messages,
|
||||
toolkit=stage_toolkit,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
|
||||
return {
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
stage_config.agent_type.value: stage_output.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
def _build_stage_toolkit(
|
||||
@@ -113,11 +178,15 @@ class AgentScopeRunner:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_stage_agent_type(system_agent_mode: str) -> AgentType:
|
||||
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
|
||||
if mode == AgentType.MEMORY.value:
|
||||
return AgentType.MEMORY
|
||||
def _parse_agent_type(*, stage_name: str) -> AgentType:
|
||||
normalized = stage_name.strip().lower()
|
||||
if normalized == AgentType.ROUTER.value:
|
||||
return AgentType.ROUTER
|
||||
if normalized == AgentType.WORKER.value:
|
||||
return AgentType.WORKER
|
||||
if normalized == AgentType.MEMORY.value:
|
||||
return AgentType.MEMORY
|
||||
raise ValueError(f"unsupported stage name: {stage_name}")
|
||||
|
||||
async def _load_stage_config(
|
||||
self,
|
||||
@@ -130,28 +199,60 @@ class AgentScopeRunner:
|
||||
agent_type=agent_type,
|
||||
)
|
||||
|
||||
async def _execute_worker_step(
|
||||
async def _execute_router_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> AgentOutput:
|
||||
step_name = stage_config.agent_type.value
|
||||
worker_output_model = AgentOutput
|
||||
) -> RouterAgentOutput:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return router_output
|
||||
|
||||
async def _execute_worker_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
router_output: RouterAgentOutput,
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> WorkerAgentOutputLite:
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
input_messages=self._build_worker_input_messages(
|
||||
router_output=router_output
|
||||
),
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
@@ -163,11 +264,48 @@ class AgentScopeRunner:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return worker_output
|
||||
|
||||
async def _execute_single_stage_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
input_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> AgentOutput:
|
||||
step_name = stage_config.agent_type.value
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
stage_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
input_messages=input_messages,
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
worker_output_model=AgentOutput,
|
||||
pipeline=pipeline,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
stage_output = AgentOutput.model_validate(stage_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return stage_output
|
||||
|
||||
async def _load_system_agent_config(
|
||||
self,
|
||||
*,
|
||||
@@ -193,6 +331,50 @@ class AgentScopeRunner:
|
||||
api_base_url=factory.request_url,
|
||||
api_key=self._resolve_provider_api_key(factory_name=factory.name),
|
||||
llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
|
||||
extra_context=None,
|
||||
)
|
||||
|
||||
async def _build_memory_stage_config(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
memory_job_config: AutomationJobConfig,
|
||||
) -> SystemAgentRuntimeConfig:
|
||||
stmt = (
|
||||
select(Llm, LlmFactory)
|
||||
.join(LlmFactory, Llm.factory_id == LlmFactory.id)
|
||||
.where(Llm.model_code == memory_job_config.model_code)
|
||||
)
|
||||
row = (await session.execute(stmt)).one_or_none()
|
||||
if row is None:
|
||||
raise RuntimeError(
|
||||
f"memory model not found: {memory_job_config.model_code}"
|
||||
)
|
||||
llm, factory = row
|
||||
llm_config = SystemAgentLLMConfig(
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
timeout_seconds=30,
|
||||
visibility_consumer_bit=18,
|
||||
context_messages=ContextMessagesConfig(
|
||||
mode=(
|
||||
ContextBuildStrategy.DAY
|
||||
if memory_job_config.context.window_mode.value == "day"
|
||||
else ContextBuildStrategy.NUMBER
|
||||
),
|
||||
count=memory_job_config.context.window_count,
|
||||
),
|
||||
enabled_tools=memory_job_config.enabled_tools,
|
||||
)
|
||||
return SystemAgentRuntimeConfig(
|
||||
agent_type=AgentType.MEMORY,
|
||||
model_code=llm.model_code,
|
||||
api_base_url=factory.request_url,
|
||||
api_key=self._resolve_provider_api_key(factory_name=factory.name),
|
||||
llm_config=llm_config,
|
||||
extra_context=(
|
||||
f"[Memory Input Template]\n{memory_job_config.input_template.strip()}"
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -211,19 +393,63 @@ class AgentScopeRunner:
|
||||
raise RuntimeError(f"provider api key missing for factory: {factory_name}")
|
||||
return api_key
|
||||
|
||||
async def _run_worker_stage(
|
||||
async def _run_router_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> StageExecutionResult:
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
response, payload = await finalize_json_response(
|
||||
model=tracking_model,
|
||||
formatter=OpenAIChatFormatter(),
|
||||
base_messages=[
|
||||
Msg(
|
||||
"system",
|
||||
build_system_prompt(
|
||||
agent_type=AgentType.ROUTER,
|
||||
llm_config=stage_config.llm_config,
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
runtime_client_time=runtime_client_time,
|
||||
tools=None,
|
||||
),
|
||||
"system",
|
||||
),
|
||||
*context_messages,
|
||||
],
|
||||
output_model=RouterAgentOutput,
|
||||
retries=0,
|
||||
)
|
||||
response_msg = Msg(
|
||||
name="router",
|
||||
role="assistant",
|
||||
content=list(getattr(response, "content", [])),
|
||||
metadata=payload,
|
||||
)
|
||||
return StageExecutionResult(
|
||||
message=response_msg,
|
||||
payload=payload,
|
||||
response_metadata=self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
usage_summary=tracking_model.usage_summary(),
|
||||
),
|
||||
)
|
||||
|
||||
async def _run_worker_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
input_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
run_input: RunAgentInput,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
worker_output_model: type[AgentOutput],
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> StageExecutionResult:
|
||||
worker_input = list(context_messages)
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
emitter = PipelineStageEmitter(
|
||||
pipeline=pipeline,
|
||||
@@ -241,6 +467,7 @@ class AgentScopeRunner:
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
runtime_client_time=runtime_client_time,
|
||||
extra_context=stage_config.extra_context,
|
||||
tools=None,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
@@ -248,7 +475,7 @@ class AgentScopeRunner:
|
||||
emitter=emitter,
|
||||
)
|
||||
response_msg = await agent.reply_json(
|
||||
worker_input, output_model=worker_output_model
|
||||
input_messages, output_model=worker_output_model
|
||||
)
|
||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||
response_metadata = self._litellm_service.build_usage_metadata(
|
||||
@@ -265,6 +492,19 @@ class AgentScopeRunner:
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
def _build_worker_input_messages(
|
||||
self,
|
||||
*,
|
||||
router_output: RouterAgentOutput,
|
||||
) -> list[Msg]:
|
||||
return [
|
||||
Msg(
|
||||
name=AgentType.ROUTER.value,
|
||||
role="user",
|
||||
content=build_worker_contract_prompt(router_output=router_output),
|
||||
)
|
||||
]
|
||||
|
||||
def _build_model(
|
||||
self, *, stage_config: SystemAgentRuntimeConfig
|
||||
) -> TrackingChatModel:
|
||||
@@ -272,8 +512,8 @@ class AgentScopeRunner:
|
||||
"temperature": stage_config.llm_config.temperature,
|
||||
"max_tokens": stage_config.llm_config.max_tokens,
|
||||
"timeout": stage_config.llm_config.timeout_seconds,
|
||||
"extra_body": {"enable_thinking": False},
|
||||
}
|
||||
generate_kwargs["extra_body"] = {"enable_thinking": False}
|
||||
|
||||
model = OpenAIChatModel(
|
||||
model_name=stage_config.model_code,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from datetime import timezone
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
@@ -14,26 +15,49 @@ from core.agentscope.events import (
|
||||
)
|
||||
from core.agentscope.runtime.context_service import AgentContextService
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.schemas.agui_input import parse_run_input
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
SqlAlchemyAutomationSchedulerRepository,
|
||||
utc_now,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from models.automation_jobs import AutomationJob
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
|
||||
from schemas.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.users.dependencies import get_user_service
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
_MAX_CONTEXT_ATTACHMENTS = 3
|
||||
|
||||
|
||||
class _BulkQueueAdapter:
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str:
|
||||
del dedup_key
|
||||
result = await run_command_task_bulk.kiq(command)
|
||||
return str(result.task_id)
|
||||
|
||||
|
||||
def _serialize_tool_agent_output(
|
||||
*,
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None,
|
||||
@@ -79,12 +103,28 @@ async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
context_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> list[Msg]:
|
||||
context_service = AgentContextService(repository=AgentRepository(session))
|
||||
if memory_job_config is not None:
|
||||
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
if memory_job_config.context.window_mode.value == "day":
|
||||
result = await context_service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=memory_job_config.context.window_count,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
else:
|
||||
result = await context_service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=memory_job_config.context.window_count,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
else:
|
||||
result = await context_service.load_context_messages(
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=system_agent_mode,
|
||||
system_agent_mode=context_mode,
|
||||
)
|
||||
if not result:
|
||||
return []
|
||||
@@ -166,11 +206,33 @@ async def _build_recent_context_messages(
|
||||
return converted
|
||||
|
||||
|
||||
async def _load_memory_job_config(
|
||||
*,
|
||||
session: Any,
|
||||
owner_id: UUID,
|
||||
automation_job_id: str,
|
||||
) -> AutomationJobConfig:
|
||||
try:
|
||||
job_uuid = UUID(automation_job_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError("automation_job_id is invalid") from exc
|
||||
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_uuid)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
raise ValueError("automation job not found")
|
||||
return AutomationJobConfig.model_validate(row.config or {})
|
||||
|
||||
|
||||
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
command_type = str(command.get("command", "run")).strip().lower()
|
||||
raw_owner_id = command.get("owner_id")
|
||||
run_input_raw = command.get("run_input")
|
||||
system_agent_mode = str(command.get("system_agent_mode", "worker")).strip().lower()
|
||||
|
||||
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
|
||||
raise ValueError("owner_id is required")
|
||||
@@ -178,6 +240,15 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
raise ValueError("run_input is required")
|
||||
|
||||
run_input = parse_run_input(run_input_raw)
|
||||
system_agent_mode = parse_forwarded_props_agent_type(
|
||||
getattr(run_input, "forwarded_props", None)
|
||||
)
|
||||
raw_automation_job_id = command.get("automation_job_id")
|
||||
if system_agent_mode == "memory" and (
|
||||
not isinstance(raw_automation_job_id, str) or not raw_automation_job_id
|
||||
):
|
||||
raise ValueError("automation_job_id is required for memory mode")
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
owner_id = UUID(raw_owner_id)
|
||||
@@ -189,6 +260,14 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
user_context = await _build_user_context(owner_id=owner_id, session=session)
|
||||
memory_job_config: AutomationJobConfig | None = None
|
||||
if system_agent_mode == "memory":
|
||||
assert isinstance(raw_automation_job_id, str)
|
||||
memory_job_config = await _load_memory_job_config(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
automation_job_id=raw_automation_job_id,
|
||||
)
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
bus = RedisStreamBus(
|
||||
@@ -211,7 +290,8 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages = await _build_recent_context_messages(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=system_agent_mode,
|
||||
context_mode=pipeline_spec.stages[0].context_policy.consumer_agent_type,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
|
||||
await runtime.run(
|
||||
@@ -219,6 +299,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages=context_messages,
|
||||
user_context=user_context,
|
||||
system_agent_mode=system_agent_mode,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
logger.info(
|
||||
"agentscope runtime task completed",
|
||||
@@ -233,6 +314,35 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
}
|
||||
|
||||
|
||||
async def run_automation_scheduler_scan(
|
||||
*,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, int]:
|
||||
now = utc_now()
|
||||
safe_limit = (
|
||||
max(int(limit), 1)
|
||||
if isinstance(limit, int)
|
||||
else int(config.automation_scheduler.batch_limit)
|
||||
)
|
||||
async with AsyncSessionLocal() as session:
|
||||
repository = SqlAlchemyAutomationSchedulerRepository(session=session)
|
||||
service = AutomationSchedulerService(
|
||||
repository=repository,
|
||||
queue=_BulkQueueAdapter(),
|
||||
)
|
||||
result = await service.scan_and_dispatch(now_utc=now, limit=safe_limit)
|
||||
logger.info(
|
||||
"automation scheduler scan completed",
|
||||
scanned=result.scanned,
|
||||
dispatched=result.dispatched,
|
||||
now_utc=now.astimezone(timezone.utc).isoformat(),
|
||||
)
|
||||
return {
|
||||
"scanned": int(result.scanned),
|
||||
"dispatched": int(result.dispatched),
|
||||
}
|
||||
|
||||
|
||||
@default_broker.task(task_name="tasks.agentscope.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
@@ -246,3 +356,8 @@ async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object
|
||||
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
|
||||
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
|
||||
|
||||
@default_broker.task(task_name="tasks.automation.scan_due_jobs")
|
||||
async def scan_due_automation_jobs_task(limit: int | None = None) -> dict[str, int]:
|
||||
return await run_automation_scheduler_scan(limit=limit)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from core.agentscope.tools.tool_config import resolve_tool_names_by_groups
|
||||
from core.agentscope.tools.tool_config import resolve_tool_function_names
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
ToolNameResolver = Callable[[Any], set[str] | None]
|
||||
@@ -23,20 +23,19 @@ class ToolSelectionRegistry:
|
||||
return resolver(stage_config)
|
||||
|
||||
|
||||
def _default_group_resolver(stage_config: Any) -> set[str] | None:
|
||||
raw_groups = getattr(stage_config.llm_config, "enabled_tool_groups", [])
|
||||
groups = raw_groups if isinstance(raw_groups, list) else []
|
||||
if not groups:
|
||||
def _default_tool_resolver(stage_config: Any) -> set[str] | None:
|
||||
enabled_tools = getattr(stage_config.llm_config, "enabled_tools", [])
|
||||
if not enabled_tools:
|
||||
return None
|
||||
return resolve_tool_names_by_groups(set(groups))
|
||||
return resolve_tool_function_names(set(enabled_tools))
|
||||
|
||||
|
||||
TOOL_SELECTION_REGISTRY = ToolSelectionRegistry()
|
||||
TOOL_SELECTION_REGISTRY.register(
|
||||
agent_type=AgentType.WORKER,
|
||||
resolver=_default_group_resolver,
|
||||
resolver=_default_tool_resolver,
|
||||
)
|
||||
TOOL_SELECTION_REGISTRY.register(
|
||||
agent_type=AgentType.MEMORY,
|
||||
resolver=_default_group_resolver,
|
||||
resolver=_default_tool_resolver,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.agentscope.tools.utils.calendar_domain import (
|
||||
map_calendar_exception,
|
||||
merge_schedule_metadata_for_update,
|
||||
parse_iso_datetime,
|
||||
resolve_share_target_email_map,
|
||||
resolve_share_target_phone_map,
|
||||
schedule_event_to_dict,
|
||||
)
|
||||
from core.agentscope.tools.utils.calendar_ui import (
|
||||
@@ -580,11 +580,11 @@ async def calendar_share(
|
||||
)
|
||||
target_uuid = UUID(event_id)
|
||||
|
||||
email_map = resolve_share_target_email_map(
|
||||
phone_map = resolve_share_target_phone_map(
|
||||
[invitee.user_id for invitee in invitees]
|
||||
)
|
||||
|
||||
if not email_map:
|
||||
if not phone_map:
|
||||
return calendar_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
@@ -599,8 +599,8 @@ async def calendar_share(
|
||||
normalized_user_id = str(UUID(invitee.user_id.strip()))
|
||||
except ValueError:
|
||||
continue
|
||||
email = email_map.get(normalized_user_id)
|
||||
if email is None:
|
||||
phone = phone_map.get(normalized_user_id)
|
||||
if phone is None:
|
||||
continue
|
||||
permission = {
|
||||
"permission_view": invitee.permission_view,
|
||||
@@ -608,15 +608,15 @@ async def calendar_share(
|
||||
"permission_invite": invitee.permission_invite,
|
||||
}
|
||||
await service.share(
|
||||
target_uuid, ScheduleItemShareRequest(email=email, **permission)
|
||||
target_uuid, ScheduleItemShareRequest(phone=phone, **permission)
|
||||
)
|
||||
invited.append(email)
|
||||
invited.append(phone)
|
||||
if not invited:
|
||||
return calendar_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
code="NOT_FOUND",
|
||||
message="邀请目标均无有效邮箱",
|
||||
message="邀请目标均无有效手机号",
|
||||
retryable=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from agentscope.tool import ToolResponse
|
||||
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
|
||||
from core.agentscope.tools.utils import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
@@ -46,22 +46,22 @@ def _lookup_error_output(
|
||||
async def _resolve_identity(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
user_email: str | None,
|
||||
user_phone: str | None,
|
||||
user_name: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve user identity by email or username."""
|
||||
email = user_email.strip().lower() if isinstance(user_email, str) else ""
|
||||
"""Resolve user identity by phone or username."""
|
||||
phone = user_phone.strip() if isinstance(user_phone, str) else ""
|
||||
name = user_name.strip() if isinstance(user_name, str) else ""
|
||||
|
||||
if bool(email) == bool(name):
|
||||
if bool(phone) == bool(name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="请提供 email 或 username 其中之一",
|
||||
detail="请提供 phone 或 username 其中之一",
|
||||
)
|
||||
|
||||
if email:
|
||||
if phone:
|
||||
auth_gateway = SupabaseAuthGateway()
|
||||
user = await auth_gateway.get_user_by_email(email)
|
||||
user = await auth_gateway.get_user_by_phone(phone)
|
||||
user_id = UUID(user.id)
|
||||
|
||||
stmt = (
|
||||
@@ -73,9 +73,9 @@ async def _resolve_identity(
|
||||
|
||||
return {
|
||||
"userId": str(user_id),
|
||||
"email": user.email,
|
||||
"phone": user.phone,
|
||||
"username": username,
|
||||
"matchedBy": "email",
|
||||
"matchedBy": "phone",
|
||||
}
|
||||
|
||||
stmt = (
|
||||
@@ -90,20 +90,20 @@ async def _resolve_identity(
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
users = list_auth_users()
|
||||
email_value = find_auth_email_by_user_id(users=users, user_id=profile.id)
|
||||
phone_value = find_auth_phone_by_user_id(users=users, user_id=profile.id)
|
||||
|
||||
return {
|
||||
"userId": str(profile.id),
|
||||
"email": email_value,
|
||||
"phone": phone_value,
|
||||
"username": profile.username,
|
||||
"matchedBy": "username",
|
||||
}
|
||||
|
||||
|
||||
async def user_lookup(
|
||||
user_email: Annotated[
|
||||
user_phone: Annotated[
|
||||
str | None,
|
||||
Field(description="User email address to look up."),
|
||||
Field(description="User phone to look up."),
|
||||
] = None,
|
||||
user_name: Annotated[
|
||||
str | None,
|
||||
@@ -112,16 +112,16 @@ async def user_lookup(
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
"""Look up user identity by email or username.
|
||||
"""Look up user identity by phone or username.
|
||||
|
||||
Args:
|
||||
user_email: User email address for lookup.
|
||||
user_phone: User phone for lookup.
|
||||
user_name: Username for lookup.
|
||||
|
||||
Returns:
|
||||
ToolResponse with serialized ToolAgentOutput payload.
|
||||
"""
|
||||
tool_call_args = {"user_email": user_email, "user_name": user_name}
|
||||
tool_call_args = {"user_phone": user_phone, "user_name": user_name}
|
||||
|
||||
if session is None or owner_id is None:
|
||||
return _lookup_error_output(
|
||||
@@ -134,17 +134,17 @@ async def user_lookup(
|
||||
try:
|
||||
resolved = await _resolve_identity(
|
||||
session=cast(AsyncSession, session),
|
||||
user_email=user_email,
|
||||
user_phone=user_phone,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
username = str(resolved.get("username") or "")
|
||||
email = str(resolved.get("email") or "")
|
||||
phone = str(resolved.get("phone") or "")
|
||||
user_id = str(resolved.get("userId") or "")
|
||||
matched_by = str(resolved.get("matchedBy") or "")
|
||||
summary = (
|
||||
f"status=success matched_by={matched_by} user_id={user_id} "
|
||||
f"username={username} has_email={str(bool(email)).lower()}"
|
||||
f"username={username} has_phone={str(bool(phone)).lower()}"
|
||||
)
|
||||
return _dump_tool_output(
|
||||
ToolAgentOutput(
|
||||
|
||||
@@ -6,7 +6,15 @@ from enum import Enum
|
||||
|
||||
class ToolGroup(str, Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXECUTE = "execute"
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class AgentTool(str, Enum):
|
||||
CALENDAR_READ = "calendar.read"
|
||||
CALENDAR_WRITE = "calendar.write"
|
||||
CALENDAR_SHARE = "calendar.share"
|
||||
USER_LOOKUP = "user.lookup"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -29,30 +37,51 @@ TOOL_CONFIGS: dict[str, ToolConfig] = {
|
||||
),
|
||||
"user_lookup": ToolConfig(
|
||||
name="user_lookup",
|
||||
group=ToolGroup.READ,
|
||||
group=ToolGroup.MEMORY,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.WRITE,
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"calendar_share": ToolConfig(
|
||||
name="calendar_share",
|
||||
group=ToolGroup.WRITE,
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
}
|
||||
|
||||
AGENT_TOOL_TO_FUNCTION_NAME: dict[AgentTool, str] = {
|
||||
AgentTool.CALENDAR_READ: "calendar_read",
|
||||
AgentTool.CALENDAR_WRITE: "calendar_write",
|
||||
AgentTool.CALENDAR_SHARE: "calendar_share",
|
||||
AgentTool.USER_LOOKUP: "user_lookup",
|
||||
}
|
||||
|
||||
def get_tool_config(tool_name: str) -> ToolConfig:
|
||||
config = TOOL_CONFIGS.get(tool_name)
|
||||
if config is None:
|
||||
raise ValueError(f"unknown tool: {tool_name}")
|
||||
return config
|
||||
TOOL_NAME_ALIASES: dict[str, AgentTool] = {
|
||||
AgentTool.CALENDAR_READ.value: AgentTool.CALENDAR_READ,
|
||||
"calendar_read": AgentTool.CALENDAR_READ,
|
||||
AgentTool.CALENDAR_WRITE.value: AgentTool.CALENDAR_WRITE,
|
||||
"calendar_write": AgentTool.CALENDAR_WRITE,
|
||||
AgentTool.CALENDAR_SHARE.value: AgentTool.CALENDAR_SHARE,
|
||||
"calendar_share": AgentTool.CALENDAR_SHARE,
|
||||
AgentTool.USER_LOOKUP.value: AgentTool.USER_LOOKUP,
|
||||
"user_lookup": AgentTool.USER_LOOKUP,
|
||||
}
|
||||
|
||||
|
||||
def resolve_tool_names_by_groups(groups: set[ToolGroup]) -> set[str]:
|
||||
if not groups:
|
||||
return set()
|
||||
return {name for name, config in TOOL_CONFIGS.items() if config.group in groups}
|
||||
def parse_agent_tool(value: object) -> AgentTool:
|
||||
if isinstance(value, AgentTool):
|
||||
return value
|
||||
raw_value = str(value or "").strip().lower()
|
||||
if not raw_value:
|
||||
raise ValueError("enabled tool value cannot be empty")
|
||||
tool = TOOL_NAME_ALIASES.get(raw_value)
|
||||
if tool is None:
|
||||
raise ValueError(f"unknown enabled tool: {raw_value}")
|
||||
return tool
|
||||
|
||||
|
||||
def resolve_tool_function_names(tools: set[AgentTool]) -> set[str]:
|
||||
return {AGENT_TOOL_TO_FUNCTION_NAME[tool] for tool in tools}
|
||||
|
||||
@@ -7,10 +7,15 @@ from core.agentscope.tools.tool_call_context import (
|
||||
reset_current_tool_call_id,
|
||||
set_current_tool_call_id,
|
||||
)
|
||||
from core.agentscope.tools.tool_config import (
|
||||
AGENT_TOOL_TO_FUNCTION_NAME,
|
||||
TOOL_CONFIGS,
|
||||
ToolConfig,
|
||||
parse_agent_tool,
|
||||
)
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
build_error_response,
|
||||
)
|
||||
from core.agentscope.tools.tool_config import ToolConfig, TOOL_CONFIGS
|
||||
|
||||
|
||||
def register_tool_middlewares(
|
||||
@@ -59,6 +64,18 @@ def create_approval_middleware(
|
||||
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
|
||||
| None = None,
|
||||
) -> Callable[..., AsyncGenerator[Any, None]]:
|
||||
def _resolve_tool_config(*, tool_name: str) -> ToolConfig | None:
|
||||
config = config_by_name.get(tool_name)
|
||||
if config is not None:
|
||||
return config
|
||||
try:
|
||||
normalized_tool_name = AGENT_TOOL_TO_FUNCTION_NAME[
|
||||
parse_agent_tool(tool_name)
|
||||
]
|
||||
except ValueError:
|
||||
return None
|
||||
return config_by_name.get(normalized_tool_name)
|
||||
|
||||
def _resolve_tool_call_id(tool_call: dict[str, Any]) -> str:
|
||||
raw_tool_call_id = tool_call.get("id")
|
||||
if isinstance(raw_tool_call_id, str) and raw_tool_call_id.strip():
|
||||
@@ -81,7 +98,7 @@ def create_approval_middleware(
|
||||
yield response
|
||||
return
|
||||
|
||||
config = config_by_name.get(tool_name)
|
||||
config = _resolve_tool_config(tool_name=tool_name)
|
||||
if config is None or not config.approval.required:
|
||||
async for response in await next_handler(**kwargs):
|
||||
yield response
|
||||
@@ -134,15 +151,3 @@ def create_approval_middleware(
|
||||
yield pending_response
|
||||
|
||||
return approval_middleware
|
||||
|
||||
|
||||
def create_hitl_middleware(
|
||||
*,
|
||||
meta_by_name: dict[str, ToolConfig],
|
||||
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
|
||||
| None = None,
|
||||
) -> Callable[..., AsyncGenerator[Any, None]]:
|
||||
return create_approval_middleware(
|
||||
config_by_name=meta_by_name,
|
||||
approval_resolver=approval_resolver,
|
||||
)
|
||||
|
||||
@@ -13,8 +13,6 @@ from core.agentscope.tools.custom.calendar import (
|
||||
from core.agentscope.tools.custom.user_lookup import user_lookup
|
||||
from core.agentscope.tools.tool_config import (
|
||||
TOOL_CONFIGS,
|
||||
ToolGroup,
|
||||
resolve_tool_names_by_groups,
|
||||
)
|
||||
from core.agentscope.tools.tool_middleware import register_tool_middlewares
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -28,46 +26,36 @@ TOOL_FUNCTIONS: dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
AGENT_TYPE_TO_GROUPS: dict[AgentType, set[ToolGroup]] = {
|
||||
AgentType.WORKER: {ToolGroup.READ, ToolGroup.WRITE},
|
||||
AgentType.MEMORY: {ToolGroup.READ, ToolGroup.WRITE},
|
||||
AGENT_TYPE_TO_DEFAULT_TOOLS: dict[AgentType, set[str]] = {
|
||||
AgentType.WORKER: {
|
||||
"calendar_read",
|
||||
"calendar_write",
|
||||
"calendar_share",
|
||||
"user_lookup",
|
||||
},
|
||||
AgentType.MEMORY: {"calendar_read", "user_lookup"},
|
||||
}
|
||||
|
||||
|
||||
def _resolve_enabled_tools(
|
||||
*,
|
||||
groups: set[ToolGroup] | None,
|
||||
enabled_tool_names: set[str] | None,
|
||||
) -> set[str]:
|
||||
if enabled_tool_names is not None:
|
||||
def _validate_enabled_tool_names(enabled_tool_names: set[str]) -> set[str]:
|
||||
unknown = enabled_tool_names - set(TOOL_FUNCTIONS)
|
||||
if unknown:
|
||||
raise ValueError(f"unknown tools in enabled_tool_names: {sorted(unknown)}")
|
||||
return set(enabled_tool_names)
|
||||
|
||||
if groups is None:
|
||||
return set(TOOL_FUNCTIONS)
|
||||
|
||||
resolved = resolve_tool_names_by_groups(groups)
|
||||
unknown = resolved - set(TOOL_FUNCTIONS)
|
||||
if unknown:
|
||||
raise ValueError(f"tool config contains unknown tools: {sorted(unknown)}")
|
||||
return resolved
|
||||
return enabled_tool_names
|
||||
|
||||
|
||||
def build_toolkit(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
groups: set[ToolGroup] | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
):
|
||||
toolkit = Toolkit()
|
||||
enabled_names = _resolve_enabled_tools(
|
||||
groups=groups,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
)
|
||||
if enabled_tool_names is None:
|
||||
enabled_names = set(TOOL_FUNCTIONS)
|
||||
else:
|
||||
enabled_names = _validate_enabled_tool_names(set(enabled_tool_names))
|
||||
|
||||
preset_kwargs = cast(
|
||||
dict[str, JSONSerializableObject],
|
||||
@@ -100,15 +88,13 @@ def build_stage_toolkit(
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
):
|
||||
groups = AGENT_TYPE_TO_GROUPS.get(agent_type)
|
||||
if groups is None:
|
||||
default_tools = AGENT_TYPE_TO_DEFAULT_TOOLS.get(agent_type)
|
||||
if default_tools is None:
|
||||
raise ValueError(f"unknown agent_type: {agent_type}")
|
||||
|
||||
stage_enabled_names = resolve_tool_names_by_groups(set(groups))
|
||||
selected_names = (
|
||||
stage_enabled_names
|
||||
set(default_tools)
|
||||
if enabled_tool_names is None
|
||||
else stage_enabled_names | set(enabled_tool_names)
|
||||
else _validate_enabled_tool_names(set(enabled_tool_names))
|
||||
)
|
||||
|
||||
return build_toolkit(
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from core.agentscope.tools.utils.auth_helpers import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"list_auth_users",
|
||||
"find_auth_email_by_user_id",
|
||||
"find_auth_phone_by_user_id",
|
||||
]
|
||||
|
||||
@@ -25,12 +25,12 @@ def list_auth_users() -> list[Any]:
|
||||
return users
|
||||
|
||||
|
||||
def find_auth_email_by_user_id(*, users: list[Any], user_id: UUID) -> str | None:
|
||||
"""Find auth email by user id from fetched user list."""
|
||||
def find_auth_phone_by_user_id(*, users: list[Any], user_id: UUID) -> str | None:
|
||||
"""Find auth phone by user id from fetched user list."""
|
||||
target = str(user_id)
|
||||
for user in users:
|
||||
if str(getattr(user, "id", "")) == target:
|
||||
email = getattr(user, "email", None)
|
||||
if isinstance(email, str) and email.strip():
|
||||
return email.strip()
|
||||
phone = getattr(user, "phone", None)
|
||||
if isinstance(phone, str) and phone.strip():
|
||||
return phone.strip()
|
||||
return None
|
||||
|
||||
@@ -9,7 +9,7 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.tools.utils.auth_helpers import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
@@ -125,7 +125,7 @@ def parse_iso_datetime(value: str | None) -> datetime | None:
|
||||
return parsed.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def resolve_share_target_email_map(invitee_user_ids: list[str]) -> dict[str, str]:
|
||||
def resolve_share_target_phone_map(invitee_user_ids: list[str]) -> dict[str, str]:
|
||||
users = list_auth_users()
|
||||
resolved: dict[str, str] = {}
|
||||
for raw_user_id in invitee_user_ids:
|
||||
@@ -138,7 +138,7 @@ def resolve_share_target_email_map(invitee_user_ids: list[str]) -> dict[str, str
|
||||
user_uuid = UUID(normalized_user_id)
|
||||
except ValueError:
|
||||
continue
|
||||
email = find_auth_email_by_user_id(users=users, user_id=user_uuid)
|
||||
if email:
|
||||
resolved[str(user_uuid)] = email.lower()
|
||||
phone = find_auth_phone_by_user_id(users=users, user_id=user_uuid)
|
||||
if phone:
|
||||
resolved[str(user_uuid)] = phone
|
||||
return resolved
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
DispatchResult,
|
||||
SqlAlchemyAutomationSchedulerRepository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutomationSchedulerService",
|
||||
"DispatchResult",
|
||||
"SqlAlchemyAutomationSchedulerRepository",
|
||||
]
|
||||
@@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob, ScheduleType
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
||||
|
||||
logger = get_logger("core.automation.scheduler")
|
||||
|
||||
|
||||
class QueueLike(Protocol):
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class AutomationSchedulerRepositoryLike(Protocol):
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[DueAutomationJob]: ...
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig: ...
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID: ...
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DispatchResult:
|
||||
scanned: int
|
||||
dispatched: int
|
||||
|
||||
|
||||
class AutomationSchedulerService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: AutomationSchedulerRepositoryLike,
|
||||
queue: QueueLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._queue = queue
|
||||
|
||||
async def scan_and_dispatch(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> DispatchResult:
|
||||
safe_limit = max(int(limit), 1)
|
||||
due_jobs = await self._repository.list_due_jobs(
|
||||
now_utc=now_utc, limit=safe_limit
|
||||
)
|
||||
dispatched = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
config = await self._repository.get_job_config(job_id=job.id)
|
||||
thread_id = await self._repository.ensure_latest_chat_session(
|
||||
owner_id=job.owner_id
|
||||
)
|
||||
command = self._build_dispatch_command(
|
||||
job=job,
|
||||
thread_id=thread_id,
|
||||
input_text=config.input_template,
|
||||
now_utc=now_utc,
|
||||
)
|
||||
await self._queue.enqueue(command=command, dedup_key=None)
|
||||
await self._repository.mark_job_dispatched(
|
||||
job_id=job.id,
|
||||
next_run_at=_compute_next_run_at(
|
||||
current_next_run_at=job.next_run_at,
|
||||
now_utc=now_utc,
|
||||
schedule_type=job.schedule_type,
|
||||
),
|
||||
last_run_at=now_utc,
|
||||
)
|
||||
await self._repository.commit()
|
||||
dispatched += 1
|
||||
except Exception as exc:
|
||||
await self._repository.rollback()
|
||||
logger.exception(
|
||||
"automation job dispatch failed",
|
||||
job_id=str(job.id),
|
||||
owner_id=str(job.owner_id),
|
||||
error=str(exc),
|
||||
)
|
||||
return DispatchResult(scanned=len(due_jobs), dispatched=dispatched)
|
||||
|
||||
def _build_dispatch_command(
|
||||
self,
|
||||
*,
|
||||
job: DueAutomationJob,
|
||||
thread_id: UUID,
|
||||
input_text: str,
|
||||
now_utc: datetime,
|
||||
) -> dict[str, object]:
|
||||
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
|
||||
payload = SchedulerDispatchCommand(
|
||||
owner_id=job.owner_id,
|
||||
automation_job_id=job.id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
input_text=input_text.strip(),
|
||||
)
|
||||
return {
|
||||
"command": "run",
|
||||
"owner_id": str(payload.owner_id),
|
||||
"automation_job_id": str(payload.automation_job_id),
|
||||
"queue": "bulk",
|
||||
"run_input": {
|
||||
"threadId": str(payload.thread_id),
|
||||
"runId": payload.run_id,
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"role": "user",
|
||||
"content": payload.input_text,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"agent_type": "memory",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SqlAlchemyAutomationSchedulerRepository:
|
||||
def __init__(self, *, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[DueAutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.status == "active")
|
||||
.where(AutomationJob.next_run_at <= now_utc)
|
||||
.order_by(AutomationJob.next_run_at.asc())
|
||||
.limit(max(limit, 1))
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return [
|
||||
DueAutomationJob(
|
||||
id=row.id,
|
||||
owner_id=row.owner_id,
|
||||
schedule_type=row.schedule_type,
|
||||
timezone=row.timezone,
|
||||
next_run_at=row.next_run_at,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
|
||||
stmt = select(AutomationJob.config).where(AutomationJob.id == job_id)
|
||||
config_payload = (await self._session.execute(stmt)).scalar_one()
|
||||
return AutomationJobConfig.model_validate(config_payload or {})
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == owner_id)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.where(AgentChatSession.session_type == SessionType.CHAT)
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
)
|
||||
self._session.add(session)
|
||||
await self._session.flush()
|
||||
return session.id
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None:
|
||||
stmt = select(AutomationJob).where(AutomationJob.id == job_id)
|
||||
row = (await self._session.execute(stmt)).scalar_one()
|
||||
row.next_run_at = next_run_at
|
||||
row.last_run_at = last_run_at
|
||||
await self._session.flush()
|
||||
|
||||
async def commit(self) -> None:
|
||||
await self._session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
await self._session.rollback()
|
||||
|
||||
|
||||
def _compute_next_run_at(
|
||||
*,
|
||||
current_next_run_at: datetime,
|
||||
now_utc: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> datetime:
|
||||
delta = timedelta(days=1 if schedule_type == ScheduleType.DAILY else 7)
|
||||
next_run_at = current_next_run_at
|
||||
while next_run_at <= now_utc:
|
||||
next_run_at = next_run_at + delta
|
||||
return next_run_at
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
@@ -1,27 +1,30 @@
|
||||
agents:
|
||||
- agent_type: worker
|
||||
llm_model_code: qwen3.5-35b-a3b
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
context_messages:
|
||||
mode: number
|
||||
count: 20
|
||||
enabled_tool_groups:
|
||||
- read
|
||||
- write
|
||||
|
||||
- agent_type: memory
|
||||
- agent_type: router
|
||||
llm_model_code: qwen3.5-flash
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
visibility_consumer_bit: 16
|
||||
context_messages:
|
||||
mode: day
|
||||
count: 2
|
||||
enabled_tool_groups:
|
||||
- read
|
||||
enabled_tools: []
|
||||
|
||||
- agent_type: worker
|
||||
llm_model_code: qwen3.5-flash
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
visibility_consumer_bit: 17
|
||||
context_messages:
|
||||
mode: number
|
||||
count: 20
|
||||
enabled_tools:
|
||||
- calendar.read
|
||||
- calendar.write
|
||||
- calendar.share
|
||||
- user.lookup
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
JSON,
|
||||
Enum as SqlEnum,
|
||||
ForeignKey,
|
||||
@@ -59,6 +60,11 @@ class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
|
||||
output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
cost: Mapped[Decimal] = mapped_column(Numeric(12, 6), nullable=False, default=0)
|
||||
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
visibility_mask: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=0,
|
||||
)
|
||||
metadata_json: Mapped[dict[str, object] | None] = mapped_column(
|
||||
"metadata", JSON().with_variant(JSONB, "postgresql"), nullable=True
|
||||
)
|
||||
|
||||
@@ -4,8 +4,8 @@ import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import DateTime, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy import DateTime, JSON, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
@@ -36,9 +36,10 @@ class AutomationJob(TimestampMixin, SoftDeleteMixin, Base):
|
||||
String(255),
|
||||
nullable=False,
|
||||
)
|
||||
prompt: Mapped[str] = mapped_column(
|
||||
Text,
|
||||
config: Mapped[dict[str, object]] = mapped_column(
|
||||
JSON().with_variant(JSONB, "postgresql"),
|
||||
nullable=False,
|
||||
default=dict,
|
||||
)
|
||||
schedule_type: Mapped[ScheduleType] = mapped_column(
|
||||
String(20),
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class AgentConsumerBinding(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
bit: int = Field(..., ge=16, le=63)
|
||||
|
||||
@field_validator("agent_type")
|
||||
@classmethod
|
||||
def _normalize_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("agent_type must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class ConsumerRegistry(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
bindings: list[AgentConsumerBinding] = Field(default_factory=list)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_unique_bindings(self) -> "ConsumerRegistry":
|
||||
by_agent: set[str] = set()
|
||||
by_bit: set[int] = set()
|
||||
for item in self.bindings:
|
||||
if item.agent_type in by_agent:
|
||||
raise ValueError(f"duplicate agent_type binding: {item.agent_type}")
|
||||
if item.bit in by_bit:
|
||||
raise ValueError(f"duplicate visibility bit binding: {item.bit}")
|
||||
by_agent.add(item.agent_type)
|
||||
by_bit.add(item.bit)
|
||||
return self
|
||||
|
||||
def resolve_agent_bit(self, *, agent_type: str) -> int:
|
||||
target = agent_type.strip().lower()
|
||||
for item in self.bindings:
|
||||
if item.agent_type == target:
|
||||
return item.bit
|
||||
raise ValueError(f"agent visibility bit not configured: {target}")
|
||||
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class ExecutorKind(str, Enum):
|
||||
SINGLE_SHOT = "single_shot"
|
||||
REACT = "react"
|
||||
|
||||
|
||||
class ContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class ContextPolicy(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
consumer_agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
window_mode: ContextWindowMode = ContextWindowMode.NUMBER
|
||||
count: int = Field(default=20, ge=1, le=200)
|
||||
|
||||
@field_validator("consumer_agent_type")
|
||||
@classmethod
|
||||
def _normalize_consumer_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("consumer_agent_type must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class StageSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
stage_name: str = Field(..., min_length=1, max_length=64)
|
||||
executor_kind: ExecutorKind
|
||||
default_visibility_mask: int = Field(..., ge=0, le=(1 << 63) - 1)
|
||||
context_policy: ContextPolicy
|
||||
|
||||
@field_validator("stage_name")
|
||||
@classmethod
|
||||
def _normalize_stage_name(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("stage_name must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class PipelineSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
mode: str = Field(..., min_length=1, max_length=64)
|
||||
stages: list[StageSpec] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("mode")
|
||||
@classmethod
|
||||
def _normalize_mode(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("mode must not be empty")
|
||||
return normalized
|
||||
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class SystemVisibilityBit(IntEnum):
|
||||
UI_HISTORY = 0
|
||||
UI_REALTIME = 1
|
||||
|
||||
|
||||
class VisibilityMask(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
value: int = Field(..., ge=0, le=(1 << 63) - 1)
|
||||
|
||||
@classmethod
|
||||
def from_bits(cls, *, bits: list[int]) -> "VisibilityMask":
|
||||
mask = 0
|
||||
for bit in bits:
|
||||
validate_visibility_bit(bit=bit)
|
||||
mask |= 1 << bit
|
||||
return cls(value=mask)
|
||||
|
||||
def contains(self, *, bit: int) -> bool:
|
||||
validate_visibility_bit(bit=bit)
|
||||
return bool(self.value & (1 << bit))
|
||||
|
||||
|
||||
class VisibilityBitRef(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
bit: int = Field(..., ge=0, le=63)
|
||||
|
||||
@field_validator("bit")
|
||||
@classmethod
|
||||
def _validate_bit(cls, value: int) -> int:
|
||||
validate_visibility_bit(bit=value)
|
||||
return value
|
||||
|
||||
|
||||
def validate_visibility_bit(*, bit: int) -> None:
|
||||
if bit < 0 or bit > 63:
|
||||
raise ValueError("visibility bit must be in range [0, 63]")
|
||||
|
||||
|
||||
def bit_mask(*, bit: int) -> int:
|
||||
validate_visibility_bit(bit=bit)
|
||||
return 1 << bit
|
||||
@@ -0,0 +1,20 @@
|
||||
from schemas.automation.config import (
|
||||
AutomationAgentType,
|
||||
AutomationContextSource,
|
||||
AutomationContextWindowMode,
|
||||
AutomationJobConfig,
|
||||
AutomationMemoryContextConfig,
|
||||
default_memory_job_config,
|
||||
)
|
||||
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
||||
|
||||
__all__ = [
|
||||
"AutomationAgentType",
|
||||
"AutomationContextSource",
|
||||
"AutomationContextWindowMode",
|
||||
"AutomationJobConfig",
|
||||
"AutomationMemoryContextConfig",
|
||||
"default_memory_job_config",
|
||||
"DueAutomationJob",
|
||||
"SchedulerDispatchCommand",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
|
||||
|
||||
class AutomationAgentType(str, Enum):
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class AutomationContextSource(str, Enum):
|
||||
LATEST_CHAT = "latest_chat"
|
||||
|
||||
|
||||
class AutomationContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class AutomationMemoryContextConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
source: AutomationContextSource = AutomationContextSource.LATEST_CHAT
|
||||
window_mode: AutomationContextWindowMode = AutomationContextWindowMode.DAY
|
||||
window_count: int = Field(default=2, ge=1, le=200)
|
||||
|
||||
|
||||
class AutomationJobConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: AutomationAgentType = AutomationAgentType.MEMORY
|
||||
model_code: str = Field(default="qwen3.5-flash", min_length=1, max_length=64)
|
||||
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
|
||||
input_template: str = Field(..., min_length=1, max_length=4000)
|
||||
context: AutomationMemoryContextConfig = Field(
|
||||
default_factory=AutomationMemoryContextConfig
|
||||
)
|
||||
|
||||
@field_validator("model_code")
|
||||
@classmethod
|
||||
def _validate_model_code(cls, value: str) -> str:
|
||||
normalized = value.strip()
|
||||
if normalized != "qwen3.5-flash":
|
||||
raise ValueError("model_code must be qwen3.5-flash")
|
||||
return normalized
|
||||
|
||||
|
||||
def default_memory_job_config() -> AutomationJobConfig:
|
||||
return AutomationJobConfig(
|
||||
agent_type=AutomationAgentType.MEMORY,
|
||||
model_code="qwen3.5-flash",
|
||||
enabled_tools=[AgentTool.CALENDAR_READ, AgentTool.USER_LOOKUP],
|
||||
input_template="请基于最近聊天上下文生成一段可执行的记忆总结与建议。",
|
||||
context=AutomationMemoryContextConfig(
|
||||
source=AutomationContextSource.LATEST_CHAT,
|
||||
window_mode=AutomationContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from models.automation_jobs import ScheduleType
|
||||
|
||||
|
||||
class DueAutomationJob(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
schedule_type: ScheduleType
|
||||
timezone: str = Field(..., min_length=1, max_length=50)
|
||||
next_run_at: datetime
|
||||
|
||||
|
||||
class SchedulerDispatchCommand(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
owner_id: UUID
|
||||
automation_job_id: UUID
|
||||
thread_id: UUID
|
||||
run_id: str = Field(..., min_length=1, max_length=128)
|
||||
input_text: str = Field(..., min_length=1, max_length=4000)
|
||||
@@ -189,3 +189,56 @@ def test_step_started_internal_event_keeps_step_name() -> None:
|
||||
|
||||
assert result["type"] == "STEP_STARTED"
|
||||
assert result["stepName"] == "worker"
|
||||
|
||||
|
||||
def test_run_error_prefers_top_level_message_and_code() -> None:
|
||||
internal = {
|
||||
"type": "run.error",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"message": "runtime failed",
|
||||
"code": "RUNTIME_ERROR",
|
||||
"data": {
|
||||
"message": "nested message",
|
||||
"code": "NESTED_ERROR",
|
||||
},
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "RUN_ERROR"
|
||||
assert result["message"] == "runtime failed"
|
||||
assert result["code"] == "RUNTIME_ERROR"
|
||||
|
||||
|
||||
def test_run_error_falls_back_to_data_when_top_level_missing() -> None:
|
||||
internal = {
|
||||
"type": "run.error",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"data": {
|
||||
"message": "nested message",
|
||||
"code": "NESTED_ERROR",
|
||||
},
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "RUN_ERROR"
|
||||
assert result["message"] == "nested message"
|
||||
assert result["code"] == "NESTED_ERROR"
|
||||
|
||||
|
||||
def test_run_error_uses_default_message_when_payload_invalid() -> None:
|
||||
internal = {
|
||||
"type": "run.error",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"data": "invalid",
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "RUN_ERROR"
|
||||
assert result["message"] == "Unknown error"
|
||||
assert "code" not in result
|
||||
|
||||
@@ -59,6 +59,16 @@ def _patch_repositories(
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
async def _fake_stage_bit_map(self, *, session: object) -> dict[str, int]:
|
||||
del self, session
|
||||
return {"router": 16, "worker": 17, "memory": 18}
|
||||
|
||||
monkeypatch.setattr(
|
||||
store_module.SqlAlchemyEventStore,
|
||||
"_load_stage_visibility_bit_map",
|
||||
_fake_stage_bit_map,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_worker_output_with_answer_as_content(
|
||||
@@ -103,6 +113,7 @@ async def test_store_persists_worker_output_with_answer_as_content(
|
||||
assert metadata["agent_output"]["answer"] == "worker-answer"
|
||||
assert metadata["agent_output"]["ui_hints"]["intent"] == "message"
|
||||
assert append_kwargs["cost"] == Decimal("0.123")
|
||||
assert append_kwargs["visibility_mask"] == ((1 << 0) | (1 << 17))
|
||||
assert captured["message_delta"] == 1
|
||||
assert captured["token_delta"] == 8
|
||||
|
||||
@@ -141,3 +152,4 @@ async def test_store_persists_tool_output_with_summary_as_content(
|
||||
metadata["tool_agent_output"]["result"]
|
||||
== "status=success batch=1 success=1 failed=0 ids=[event-1]"
|
||||
)
|
||||
assert append_kwargs["visibility_mask"] == (1 << 0)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.persistence.user_context_cache import UserContextCache
|
||||
from core.agentscope.schemas.user_context import (
|
||||
UserAgentContext,
|
||||
from schemas.user.context import (
|
||||
UserContext,
|
||||
parse_profile_settings,
|
||||
)
|
||||
|
||||
@@ -45,15 +46,15 @@ class _FakeRedis:
|
||||
self.set_store.pop(key, None)
|
||||
return len(keys)
|
||||
|
||||
async def sadd(self, key: str, *values: str) -> int:
|
||||
bucket = self.set_store.setdefault(key, set())
|
||||
async def sadd(self, name: str, *values: str) -> int:
|
||||
bucket = self.set_store.setdefault(name, set())
|
||||
before = len(bucket)
|
||||
for value in values:
|
||||
bucket.add(value)
|
||||
return len(bucket) - before
|
||||
|
||||
async def smembers(self, key: str) -> set[str]:
|
||||
return set(self.set_store.get(key, set()))
|
||||
async def smembers(self, name: str) -> set[str]:
|
||||
return set(self.set_store.get(name, set()))
|
||||
|
||||
|
||||
class _BrokenRedis:
|
||||
@@ -77,18 +78,18 @@ class _BrokenRedis:
|
||||
del keys
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
async def sadd(self, key: str, *values: str) -> int:
|
||||
del key, values
|
||||
async def sadd(self, name: str, *values: str) -> int:
|
||||
del name, values
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
async def smembers(self, key: str) -> set[str]:
|
||||
del key
|
||||
async def smembers(self, name: str) -> set[str]:
|
||||
del name
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
|
||||
def _build_context() -> UserAgentContext:
|
||||
return UserAgentContext(
|
||||
user_id=uuid4(),
|
||||
def _build_context() -> UserContext:
|
||||
return UserContext(
|
||||
id=str(uuid4()),
|
||||
username="demo-user",
|
||||
bio="demo bio",
|
||||
settings=parse_profile_settings({"preferences": {"ai_language": "en-US"}}),
|
||||
@@ -111,11 +112,11 @@ async def test_user_context_cache_set_and_get_hit() -> None:
|
||||
loaded = await cache.get(session_id=session_id)
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.user_id == context.user_id
|
||||
assert loaded.id == context.id
|
||||
assert loaded.username == "demo-user"
|
||||
assert redis.expire_calls == [
|
||||
(f"agent:user-context:{session_id}", 600),
|
||||
(f"agent:user-context:sessions:{context.user_id}", 600),
|
||||
(f"agent:user-context:sessions:{context.id}", 600),
|
||||
]
|
||||
assert redis.hincrby_calls == [
|
||||
(f"agent:user-context:{session_id}", "turns_used", 1)
|
||||
@@ -138,12 +139,14 @@ async def test_user_context_cache_invalidate_user_deletes_all_sessions() -> None
|
||||
await cache.set(session_id=s1, context=context)
|
||||
await cache.set(session_id=s2, context=context)
|
||||
|
||||
deleted = await cache.invalidate_user(user_id=context.user_id)
|
||||
deleted = await cache.invalidate_user(user_id=UUID(context.id))
|
||||
|
||||
assert deleted == 2
|
||||
assert f"agent:user-context:{s1}" in redis.delete_calls
|
||||
assert f"agent:user-context:{s2}" in redis.delete_calls
|
||||
assert f"agent:user-context:sessions:{context.user_id}" in redis.delete_calls
|
||||
assert any(
|
||||
key.startswith("agent:user-context:sessions:") for key in redis.delete_calls
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.runtime.consumer_registry import build_consumer_registry
|
||||
|
||||
|
||||
def test_build_consumer_registry_from_system_agent_configs() -> None:
|
||||
registry = build_consumer_registry(
|
||||
system_agent_configs={
|
||||
"router": {"config": {"visibility_consumer_bit": 16}},
|
||||
"worker": {"config": {"visibility_consumer_bit": 17}},
|
||||
"memory": {"config": {"visibility_consumer_bit": 18}},
|
||||
}
|
||||
)
|
||||
|
||||
assert registry.resolve_agent_bit(agent_type="router") == 16
|
||||
assert registry.resolve_agent_bit(agent_type="worker") == 17
|
||||
|
||||
|
||||
def test_build_consumer_registry_rejects_duplicate_bit() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate visibility bit"):
|
||||
build_consumer_registry(
|
||||
system_agent_configs={
|
||||
"router": {"config": {"visibility_consumer_bit": 16}},
|
||||
"worker": {"config": {"visibility_consumer_bit": 16}},
|
||||
}
|
||||
)
|
||||
@@ -28,7 +28,7 @@ def _user_context() -> UserContext:
|
||||
return UserContext(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
phone="+8613900000000",
|
||||
settings=parse_profile_settings(None),
|
||||
)
|
||||
|
||||
@@ -42,7 +42,7 @@ def _run_input() -> RunAgentInput:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -58,6 +58,7 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
|
||||
run_input=_run_input(),
|
||||
context_messages=[],
|
||||
user_context=_user_context(),
|
||||
system_agent_mode="worker",
|
||||
)
|
||||
|
||||
assert result["worker"]["answer"] == "done"
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_worker_has_two_stages() -> None:
|
||||
spec = build_default_pipeline_spec(mode="worker")
|
||||
|
||||
assert spec.mode == "worker"
|
||||
assert [item.stage_name for item in spec.stages] == ["router", "worker"]
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_memory_has_single_stage() -> None:
|
||||
spec = build_default_pipeline_spec(mode="memory")
|
||||
|
||||
assert spec.mode == "memory"
|
||||
assert [item.stage_name for item in spec.stages] == ["memory"]
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_rejects_unknown_mode() -> None:
|
||||
with pytest.raises(ValueError, match="unsupported pipeline mode"):
|
||||
build_default_pipeline_spec(mode="planner")
|
||||
@@ -3,8 +3,23 @@ from __future__ import annotations
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
import core.agentscope.runtime.runner as runner_module
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from schemas.automation.config import default_memory_job_config
|
||||
from schemas.agent.runtime_models import (
|
||||
ExecutionMode,
|
||||
NormalizedTaskInput,
|
||||
ResultType,
|
||||
ResultTyping,
|
||||
RouterAgentOutput,
|
||||
RouterUiDecision,
|
||||
TaskType,
|
||||
TaskTyping,
|
||||
UiMode,
|
||||
WorkerAgentOutputLite,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
def _run_input() -> RunAgentInput:
|
||||
@@ -16,19 +31,51 @@ def _run_input() -> RunAgentInput:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_stage_agent_type_defaults_to_worker() -> None:
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("") == AgentType.WORKER
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("worker") == AgentType.WORKER
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("unknown") == AgentType.WORKER
|
||||
def _user_context() -> UserContext:
|
||||
return UserContext(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="alice",
|
||||
phone="+8613900000000",
|
||||
settings=parse_profile_settings(None),
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_stage_agent_type_supports_memory() -> None:
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("memory") == AgentType.MEMORY
|
||||
def test_parse_agent_type_supports_known_stages() -> None:
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="router") == AgentType.ROUTER
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="worker") == AgentType.WORKER
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="memory") == AgentType.MEMORY
|
||||
|
||||
|
||||
def test_parse_agent_type_rejects_unknown_stage() -> None:
|
||||
with pytest.raises(ValueError, match="unsupported stage name"):
|
||||
AgentScopeRunner._parse_agent_type(stage_name="planner")
|
||||
|
||||
|
||||
def test_build_worker_input_messages_only_contains_router_contract() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
router_output = RouterAgentOutput(
|
||||
normalized_task_input=NormalizedTaskInput(user_text="安排明天会议"),
|
||||
key_entities=[],
|
||||
constraints=[],
|
||||
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||
execution_mode=ExecutionMode.TOOL_ASSISTED,
|
||||
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
|
||||
ui=RouterUiDecision(
|
||||
ui_mode=UiMode.NONE,
|
||||
ui_decision_reason="单一执行任务,文本输出足够",
|
||||
),
|
||||
)
|
||||
|
||||
input_messages = runner._build_worker_input_messages(router_output=router_output)
|
||||
|
||||
assert len(input_messages) == 1
|
||||
assert input_messages[0].role == "user"
|
||||
assert "[RouterAgentOutput]" in str(input_messages[0].content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -43,11 +90,12 @@ async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -55,3 +103,146 @@ async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
|
||||
resolved = runner._resolve_runtime_client_time(run_input=run_input)
|
||||
assert resolved is not None
|
||||
assert resolved.device_timezone == "America/Los_Angeles"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_worker_mode_runs_router_then_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
runner = AgentScopeRunner()
|
||||
load_calls: list[AgentType] = []
|
||||
|
||||
async def _fake_load_stage_config(*, session: object, agent_type: AgentType):
|
||||
del session
|
||||
load_calls.append(agent_type)
|
||||
return runner_module.SystemAgentRuntimeConfig(
|
||||
agent_type=agent_type,
|
||||
model_code="demo",
|
||||
api_base_url="https://example.com",
|
||||
api_key="test",
|
||||
llm_config=runner_module.SystemAgentLLMConfig(),
|
||||
)
|
||||
|
||||
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
|
||||
del kwargs
|
||||
return RouterAgentOutput(
|
||||
normalized_task_input=NormalizedTaskInput(user_text="安排会议"),
|
||||
key_entities=[],
|
||||
constraints=[],
|
||||
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||
execution_mode=ExecutionMode.TOOL_ASSISTED,
|
||||
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
|
||||
ui=RouterUiDecision(
|
||||
ui_mode=UiMode.NONE,
|
||||
ui_decision_reason="单任务",
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite:
|
||||
del kwargs
|
||||
return WorkerAgentOutputLite(answer="ok")
|
||||
|
||||
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
|
||||
monkeypatch.setattr(runner, "_build_stage_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
|
||||
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
|
||||
|
||||
result = await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="worker",
|
||||
)
|
||||
|
||||
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
|
||||
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
|
||||
assert result["worker"]["answer"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_memory_mode_requires_memory_job_config() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
with pytest.raises(RuntimeError, match="memory job config is required"):
|
||||
await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="memory",
|
||||
memory_job_config=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_memory_mode_uses_memory_job_config(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
async def _fake_build_memory_stage_config(**kwargs: object):
|
||||
del kwargs
|
||||
return runner_module.SystemAgentRuntimeConfig(
|
||||
agent_type=AgentType.MEMORY,
|
||||
model_code="qwen3.5-flash",
|
||||
api_base_url="https://example.com",
|
||||
api_key="test",
|
||||
llm_config=runner_module.SystemAgentLLMConfig(),
|
||||
)
|
||||
|
||||
async def _fake_execute_single_stage_step(**kwargs: object):
|
||||
del kwargs
|
||||
return runner_module.AgentOutput(answer="memory")
|
||||
|
||||
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(
|
||||
runner, "_build_memory_stage_config", _fake_build_memory_stage_config
|
||||
)
|
||||
monkeypatch.setattr(runner, "_build_stage_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_execute_single_stage_step",
|
||||
_fake_execute_single_stage_step,
|
||||
)
|
||||
|
||||
result = await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
system_agent_mode="memory",
|
||||
memory_job_config=default_memory_job_config(),
|
||||
)
|
||||
|
||||
assert result["memory"]["answer"] == "memory"
|
||||
|
||||
@@ -18,7 +18,7 @@ def _run_input_payload() -> dict[str, Any]:
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ async def _fake_user_context(**kwargs: object) -> UserContext:
|
||||
return UserContext(
|
||||
id=str(uuid4()),
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
phone="+8613900000000",
|
||||
settings=parse_profile_settings(None),
|
||||
)
|
||||
|
||||
@@ -177,17 +177,53 @@ async def test_run_agentscope_task_rejects_invalid_command_type() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_requires_forwarded_props_agent_type() -> None:
|
||||
payload = _run_input_payload()
|
||||
payload["forwardedProps"] = {}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": payload,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_memory_mode_requires_automation_job_id() -> None:
|
||||
payload = _run_input_payload()
|
||||
payload["forwardedProps"] = {"agent_type": "memory"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="automation_job_id is required for memory mode"
|
||||
):
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": payload,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeAgentService:
|
||||
async def load_agent_input_messages(
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
del repository
|
||||
|
||||
async def load_context_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
del thread_id, system_agent_mode
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -215,14 +251,13 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
return f"{bucket}:{path}".encode("utf-8")
|
||||
|
||||
monkeypatch.setattr(
|
||||
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
monkeypatch.setattr(tasks_module, "supabase_service", _FakeSupabase())
|
||||
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -238,13 +273,17 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeAgentService:
|
||||
async def load_agent_input_messages(
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
del repository
|
||||
|
||||
async def load_context_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
del thread_id, system_agent_mode
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -268,13 +307,12 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
]
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -290,13 +328,17 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeAgentService:
|
||||
async def load_agent_input_messages(
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
del repository
|
||||
|
||||
async def load_context_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
del thread_id, system_agent_mode
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
@@ -307,13 +349,44 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
]
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
tasks_module, "get_agent_service", lambda session: _FakeAgentService()
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
)
|
||||
|
||||
assert messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_recent_context_messages_passes_context_mode_through(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured_mode: dict[str, str | None] = {"mode": None}
|
||||
|
||||
class _FakeContextService:
|
||||
def __init__(self, *, repository: object) -> None:
|
||||
del repository
|
||||
|
||||
async def load_context_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
captured_mode["mode"] = system_agent_mode
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_mode="worker",
|
||||
)
|
||||
|
||||
assert messages == []
|
||||
assert captured_mode["mode"] == "worker"
|
||||
|
||||
@@ -1,99 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.agentscope import schemas as exported_schemas
|
||||
from core.agentscope.schemas.agent_runtime import (
|
||||
AcceptedTaskResponse,
|
||||
AgUiWireEvent,
|
||||
HistorySnapshot,
|
||||
HistorySnapshotResponse,
|
||||
InternalRuntimeEvent,
|
||||
RunCommand,
|
||||
pytest.skip(
|
||||
"legacy agent_runtime schemas removed; covered by agui_input tests",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def test_run_command_alias_roundtrip() -> None:
|
||||
payload = {
|
||||
"threadId": "thread-001",
|
||||
"runId": "run-001",
|
||||
"state": {"cursor": 1},
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"tools": [{"name": "calendar.lookup"}],
|
||||
"context": {"locale": "zh-CN"},
|
||||
"forwardedProps": {"traceId": "trace-1"},
|
||||
}
|
||||
|
||||
command = RunCommand.model_validate(payload)
|
||||
|
||||
assert command.thread_id == "thread-001"
|
||||
assert command.run_id == "run-001"
|
||||
assert command.forwarded_props == {"traceId": "trace-1"}
|
||||
|
||||
dumped = command.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["threadId"] == "thread-001"
|
||||
assert dumped["runId"] == "run-001"
|
||||
assert dumped["forwardedProps"] == {"traceId": "trace-1"}
|
||||
|
||||
|
||||
def test_history_snapshot_response_shape() -> None:
|
||||
response = HistorySnapshotResponse(
|
||||
threadId="thread-123",
|
||||
snapshot=HistorySnapshot(
|
||||
threadId="thread-123",
|
||||
day="2026-03-11",
|
||||
hasMore=False,
|
||||
messages=[{"id": "msg-1"}],
|
||||
),
|
||||
)
|
||||
|
||||
dumped = response.model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
assert dumped["type"] == "STATE_SNAPSHOT"
|
||||
assert dumped["threadId"] == "thread-123"
|
||||
assert dumped["snapshot"]["scope"] == "history_day"
|
||||
assert dumped["snapshot"]["hasMore"] is False
|
||||
assert dumped["snapshot"]["messages"] == [{"id": "msg-1"}]
|
||||
|
||||
|
||||
def test_runtime_event_validation_basics() -> None:
|
||||
internal = InternalRuntimeEvent(type="RUN_STARTED", data={"step": 1})
|
||||
assert internal.type == "RUN_STARTED"
|
||||
assert internal.model_dump(mode="json", by_alias=True)["data"] == {"step": 1}
|
||||
|
||||
wire = AgUiWireEvent(type="TEXT_MESSAGE_CONTENT", payload={"delta": "hello"})
|
||||
dumped = wire.model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
assert dumped["type"] == "TEXT_MESSAGE_CONTENT"
|
||||
assert dumped["payload"] == {"delta": "hello"}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
InternalRuntimeEvent.model_validate({"threadId": "t-1", "data": {}})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AgUiWireEvent.model_validate({"payload": {"delta": "hello"}})
|
||||
|
||||
|
||||
def test_schemas_exports_include_task_and_history_models() -> None:
|
||||
assert exported_schemas.AcceptedTaskResponse is AcceptedTaskResponse
|
||||
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
|
||||
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
|
||||
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
|
||||
|
||||
|
||||
def test_run_command_accepts_agui_context_list_and_parent_run_id() -> None:
|
||||
payload = {
|
||||
"threadId": "thread-xyz",
|
||||
"runId": "run-xyz",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"parentRunId": None,
|
||||
}
|
||||
|
||||
command = RunCommand.model_validate(payload)
|
||||
|
||||
dumped = command.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["context"] == []
|
||||
assert "parentRunId" in dumped
|
||||
|
||||
@@ -20,7 +20,7 @@ def _base_payload() -> dict[str, object]:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ def test_parse_run_input_accepts_snake_case_aliases() -> None:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwarded_props": {},
|
||||
"forwarded_props": {"agent_type": "worker"},
|
||||
}
|
||||
|
||||
run_input = parse_run_input(payload)
|
||||
@@ -162,11 +162,12 @@ def test_parse_run_input_accepts_snake_case_aliases() -> None:
|
||||
def test_parse_run_input_accepts_client_time_forwarded_props() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
run_input = parse_run_input(payload)
|
||||
@@ -177,11 +178,12 @@ def test_parse_run_input_accepts_client_time_forwarded_props() -> None:
|
||||
def test_parse_run_input_rejects_invalid_client_time_timezone() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "Mars/OlympusMons",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
@@ -191,11 +193,12 @@ def test_parse_run_input_rejects_invalid_client_time_timezone() -> None:
|
||||
def test_parse_run_input_rejects_invalid_client_time_now_iso() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16 09:12:33",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
@@ -205,11 +208,12 @@ def test_parse_run_input_rejects_invalid_client_time_now_iso() -> None:
|
||||
def test_parse_run_input_rejects_invalid_client_time_epoch_type() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": "1773658353000",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
@@ -219,6 +223,7 @@ def test_parse_run_input_rejects_invalid_client_time_epoch_type() -> None:
|
||||
def test_parse_run_input_rejects_unknown_forwarded_props_key() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
@@ -229,3 +234,17 @@ def test_parse_run_input_rejects_unknown_forwarded_props_key() -> None:
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_missing_forwarded_props_agent_type() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
|
||||
@@ -11,7 +11,7 @@ def test_build_agent_prompt_for_worker_contains_runtime_config() -> None:
|
||||
{
|
||||
"temperature": 0.2,
|
||||
"context_messages": {"mode": "number", "count": 20},
|
||||
"enabled_tool_groups": ["read", "write"],
|
||||
"enabled_tools": ["calendar.read", "calendar.write"],
|
||||
}
|
||||
),
|
||||
)
|
||||
@@ -20,7 +20,7 @@ def test_build_agent_prompt_for_worker_contains_runtime_config() -> None:
|
||||
assert "- type: worker" in prompt
|
||||
assert "context_messages.mode=number" in prompt
|
||||
assert "context_messages.count=20" in prompt
|
||||
assert "enabled_tool_groups=read,write" in prompt
|
||||
assert "enabled_tools=calendar.read,calendar.write" in prompt
|
||||
|
||||
|
||||
def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
|
||||
@@ -29,7 +29,7 @@ def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
|
||||
llm_config=SystemAgentLLMConfig.model_validate(
|
||||
{
|
||||
"context_messages": {"mode": "day", "count": 2},
|
||||
"enabled_tool_groups": ["read"],
|
||||
"enabled_tools": ["user.lookup"],
|
||||
}
|
||||
),
|
||||
)
|
||||
@@ -38,4 +38,4 @@ def test_build_agent_prompt_for_memory_uses_memory_rules() -> None:
|
||||
assert "[Memory Agent]" in prompt
|
||||
assert "context_messages.mode=day" in prompt
|
||||
assert "context_messages.count=2" in prompt
|
||||
assert "enabled_tool_groups=read" in prompt
|
||||
assert "enabled_tools=user.lookup" in prompt
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.tools.hitl_middleware import create_hitl_middleware
|
||||
from core.agentscope.tools.tool_meta import TOOL_META, ToolMeta
|
||||
from core.agentscope.tools.tool_config import ToolApprovalConfig, ToolConfig, ToolGroup
|
||||
from core.agentscope.tools.tool_middleware import create_approval_middleware
|
||||
|
||||
|
||||
async def _next_handler(**kwargs: Any) -> AsyncGenerator[dict[str, object], None]:
|
||||
@@ -15,9 +16,31 @@ async def _next_handler(**kwargs: Any) -> AsyncGenerator[dict[str, object], None
|
||||
return _generator()
|
||||
|
||||
|
||||
def _extract_error_payload(chunk: object) -> dict[str, Any]:
|
||||
content = getattr(chunk, "content", None)
|
||||
if not isinstance(content, list) or not content:
|
||||
return {}
|
||||
first_block = content[0]
|
||||
text = getattr(first_block, "text", None)
|
||||
if not isinstance(text, str) and isinstance(first_block, dict):
|
||||
raw_text = first_block.get("text")
|
||||
text = raw_text if isinstance(raw_text, str) else None
|
||||
if not isinstance(text, str):
|
||||
return {}
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_default_write_does_not_require_approval() -> None:
|
||||
middleware = create_hitl_middleware(meta_by_name=TOOL_META)
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
@@ -30,36 +53,39 @@ async def test_hitl_middleware_default_write_does_not_require_approval() -> None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_pending_when_tool_requires_approval(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
middleware = create_hitl_middleware(
|
||||
meta_by_name={
|
||||
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
|
||||
}
|
||||
async def test_hitl_middleware_pending_when_tool_requires_approval() -> None:
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.tools.hitl_middleware.build_tool_response",
|
||||
lambda payload: payload,
|
||||
}
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
{"tool_call": {"name": "calendar.write", "input": {"operation": "create"}}},
|
||||
{"tool_call": {"name": "calendar_write", "input": {"operation": "create"}}},
|
||||
_next_handler,
|
||||
):
|
||||
responses.append(chunk)
|
||||
|
||||
assert responses[0]["data"]["status"] == "pending"
|
||||
payload = _extract_error_payload(responses[0])
|
||||
assert payload["error"]["code"] == "TOOL_PENDING_APPROVAL"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_passes_when_write_approved() -> None:
|
||||
middleware = create_hitl_middleware(
|
||||
meta_by_name={
|
||||
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
},
|
||||
approval_resolver=lambda _name, _args: "approved",
|
||||
approval_resolver=lambda _name, _args, _config: "approved",
|
||||
)
|
||||
|
||||
responses = []
|
||||
@@ -69,6 +95,7 @@ async def test_hitl_middleware_passes_when_write_approved() -> None:
|
||||
"name": "calendar.write",
|
||||
"input": {
|
||||
"operation": "create",
|
||||
"_hitl": {"approval": "required"},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -82,25 +109,24 @@ async def test_hitl_middleware_passes_when_write_approved() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_rejected_short_circuits(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
middleware = create_hitl_middleware(
|
||||
meta_by_name={
|
||||
"calendar.write": ToolMeta(name="calendar.write", requires_approval=True)
|
||||
},
|
||||
approval_resolver=lambda _name, _args: "rejected",
|
||||
async def test_hitl_middleware_rejected_short_circuits() -> None:
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.tools.hitl_middleware.build_tool_response",
|
||||
lambda payload: payload,
|
||||
},
|
||||
approval_resolver=lambda _name, _args, _config: "rejected",
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
{"tool_call": {"name": "calendar.write", "input": {"operation": "create"}}},
|
||||
{"tool_call": {"name": "calendar_write", "input": {"operation": "create"}}},
|
||||
_next_handler,
|
||||
):
|
||||
responses.append(chunk)
|
||||
|
||||
assert responses[0]["data"]["status"] == "rejected"
|
||||
payload = _extract_error_payload(responses[0])
|
||||
assert payload["error"]["code"] == "TOOL_REJECTED"
|
||||
|
||||
@@ -16,7 +16,7 @@ def _build_user_context(*, timezone_name: str = "Asia/Shanghai") -> UserContext:
|
||||
return UserContext(
|
||||
id=str(uuid4()),
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
phone="+8613900000000",
|
||||
bio="focus on calendars",
|
||||
settings=parse_profile_settings(
|
||||
{
|
||||
|
||||
@@ -7,7 +7,9 @@ from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
def test_build_stage_toolkit_filters_requested_tools_by_agent_type(monkeypatch) -> None:
|
||||
def test_build_stage_toolkit_uses_explicit_enabled_tools_as_final_set(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_build_toolkit(**kwargs):
|
||||
@@ -22,7 +24,29 @@ def test_build_stage_toolkit_filters_requested_tools_by_agent_type(monkeypatch)
|
||||
agent_type=AgentType.WORKER,
|
||||
session=cast(Any, object()),
|
||||
owner_id=uuid4(),
|
||||
enabled_tool_names={"calendar_read", "calendar_write", "user_lookup"},
|
||||
enabled_tool_names={"calendar_read", "user_lookup"},
|
||||
)
|
||||
|
||||
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
|
||||
|
||||
|
||||
def test_build_stage_toolkit_uses_memory_defaults_without_explicit_tools(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_build_toolkit(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.tools.toolkit.build_toolkit", _fake_build_toolkit
|
||||
)
|
||||
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.MEMORY,
|
||||
session=cast(Any, object()),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
|
||||
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
_compute_next_run_at,
|
||||
)
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation.scheduler import DueAutomationJob
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self, jobs: list[DueAutomationJob]) -> None:
|
||||
self.jobs = jobs
|
||||
self.marked: list[tuple[UUID, datetime, datetime]] = []
|
||||
self.commits = 0
|
||||
self.rollbacks = 0
|
||||
|
||||
async def list_due_jobs(
|
||||
self, *, now_utc: datetime, limit: int
|
||||
) -> list[DueAutomationJob]:
|
||||
del now_utc
|
||||
return self.jobs[:limit]
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
|
||||
del job_id
|
||||
return AutomationJobConfig.model_validate(
|
||||
{
|
||||
"agent_type": "memory",
|
||||
"model_code": "qwen3.5-flash",
|
||||
"enabled_tools": ["calendar.read", "user.lookup"],
|
||||
"input_template": "auto input",
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
return owner_id
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None:
|
||||
self.marked.append((job_id, next_run_at, last_run_at))
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commits += 1
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollbacks += 1
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
self.commands: list[dict[str, object]] = []
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str:
|
||||
del dedup_key
|
||||
self.commands.append(command)
|
||||
return "task-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_and_dispatch_enqueues_memory_run_command() -> None:
|
||||
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
|
||||
owner_id = uuid4()
|
||||
job_id = uuid4()
|
||||
repo = _FakeRepository(
|
||||
jobs=[
|
||||
DueAutomationJob(
|
||||
id=job_id,
|
||||
owner_id=owner_id,
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
timezone="UTC",
|
||||
next_run_at=now - timedelta(minutes=1),
|
||||
)
|
||||
]
|
||||
)
|
||||
queue = _FakeQueue()
|
||||
service = AutomationSchedulerService(repository=repo, queue=queue)
|
||||
|
||||
result = await service.scan_and_dispatch(now_utc=now, limit=10)
|
||||
|
||||
assert result.scanned == 1
|
||||
assert result.dispatched == 1
|
||||
assert len(queue.commands) == 1
|
||||
run_input = queue.commands[0]["run_input"]
|
||||
assert isinstance(run_input, dict)
|
||||
assert run_input["forwardedProps"] == {"agent_type": "memory"}
|
||||
assert queue.commands[0]["automation_job_id"] == str(job_id)
|
||||
assert repo.commits == 1
|
||||
|
||||
|
||||
def test_compute_next_run_at_daily() -> None:
|
||||
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
|
||||
current = datetime(2026, 3, 19, 11, 0, tzinfo=timezone.utc)
|
||||
|
||||
computed = _compute_next_run_at(
|
||||
current_next_run_at=current,
|
||||
now_utc=now,
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
)
|
||||
|
||||
assert computed == datetime(2026, 3, 20, 11, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_compute_next_run_at_weekly() -> None:
|
||||
now = datetime(2026, 3, 19, 12, 0, tzinfo=timezone.utc)
|
||||
current = datetime(2026, 3, 10, 11, 0, tzinfo=timezone.utc)
|
||||
|
||||
computed = _compute_next_run_at(
|
||||
current_next_run_at=current,
|
||||
now_utc=now,
|
||||
schedule_type=ScheduleType.WEEKLY,
|
||||
)
|
||||
|
||||
assert computed == datetime(2026, 3, 24, 11, 0, tzinfo=timezone.utc)
|
||||
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_memory_automation_job_trigger_exists_in_0004_migration() -> None:
|
||||
migration = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "alembic"
|
||||
/ "versions"
|
||||
/ "20260319_0004_automation_job_config_for_memory.py"
|
||||
)
|
||||
content = migration.read_text(encoding="utf-8")
|
||||
|
||||
assert "INSERT INTO public.automation_jobs" in content
|
||||
assert "'agent_type', 'memory'" in content
|
||||
assert "ux_automation_jobs_owner_memory_active" in content
|
||||
assert "input_template" in content
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
|
||||
from schemas.agent.pipeline_spec import (
|
||||
ContextPolicy,
|
||||
ExecutorKind,
|
||||
PipelineSpec,
|
||||
StageSpec,
|
||||
)
|
||||
|
||||
|
||||
def test_consumer_registry_rejects_duplicate_bits() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate visibility bit"):
|
||||
ConsumerRegistry(
|
||||
bindings=[
|
||||
AgentConsumerBinding(agent_type="router", bit=16),
|
||||
AgentConsumerBinding(agent_type="worker", bit=16),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_pipeline_spec_requires_non_empty_stages() -> None:
|
||||
with pytest.raises(ValueError, match="at least 1 item"):
|
||||
PipelineSpec(mode="worker", stages=[])
|
||||
|
||||
|
||||
def test_stage_spec_normalizes_stage_name() -> None:
|
||||
spec = StageSpec(
|
||||
stage_name=" Worker ",
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=1,
|
||||
context_policy=ContextPolicy(consumer_agent_type="worker", count=20),
|
||||
)
|
||||
|
||||
assert spec.stage_name == "worker"
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
|
||||
|
||||
def test_system_agent_llm_config_normalizes_enabled_tools_aliases() -> None:
|
||||
config = SystemAgentLLMConfig.model_validate(
|
||||
{
|
||||
"enabled_tools": [
|
||||
"calendar.write",
|
||||
"calendar_write",
|
||||
"user.lookup",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert [tool.value for tool in config.enabled_tools] == [
|
||||
"calendar.write",
|
||||
"user.lookup",
|
||||
]
|
||||
|
||||
|
||||
def test_system_agent_llm_config_rejects_unknown_enabled_tool() -> None:
|
||||
with pytest.raises(ValueError, match="unknown enabled tool"):
|
||||
SystemAgentLLMConfig.model_validate(
|
||||
{
|
||||
"enabled_tools": ["calendar.remove"],
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.agent.visibility import VisibilityMask, bit_mask
|
||||
|
||||
|
||||
def test_visibility_mask_from_bits_and_contains() -> None:
|
||||
mask = VisibilityMask.from_bits(bits=[0, 16, 18])
|
||||
|
||||
assert mask.contains(bit=0) is True
|
||||
assert mask.contains(bit=16) is True
|
||||
assert mask.contains(bit=17) is False
|
||||
|
||||
|
||||
def test_visibility_mask_rejects_out_of_range_bit() -> None:
|
||||
with pytest.raises(ValueError, match="range"):
|
||||
VisibilityMask.from_bits(bits=[64])
|
||||
|
||||
|
||||
def test_bit_mask_builds_single_bit_integer() -> None:
|
||||
assert bit_mask(bit=0) == 1
|
||||
assert bit_mask(bit=16) == (1 << 16)
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.automation.config import AutomationJobConfig, default_memory_job_config
|
||||
|
||||
|
||||
def test_default_memory_job_config_has_expected_defaults() -> None:
|
||||
config = default_memory_job_config()
|
||||
|
||||
assert config.agent_type.value == "memory"
|
||||
assert config.model_code == "qwen3.5-flash"
|
||||
assert config.context.source.value == "latest_chat"
|
||||
|
||||
|
||||
def test_automation_job_config_rejects_non_flash_model() -> None:
|
||||
with pytest.raises(ValueError, match="model_code must be qwen3.5-flash"):
|
||||
AutomationJobConfig.model_validate(
|
||||
{
|
||||
"agent_type": "memory",
|
||||
"model_code": "qwen-plus",
|
||||
"enabled_tools": ["calendar.read"],
|
||||
"input_template": "x",
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1,306 +0,0 @@
|
||||
# 自动化记忆与 Agent 重构设计(v3)
|
||||
|
||||
## 1. 目标
|
||||
|
||||
本次重构目标是把现有“router + worker”双阶段执行,收敛为更直接、可控、低 token 成本的执行链路。
|
||||
|
||||
明确目标:
|
||||
|
||||
1. 去除 `router agent`,由 `worker` 直接处理用户消息;
|
||||
2. `worker` 模型从 `deepseek-chat` 切换为 `qwen3.5-30b-a3b`;
|
||||
3. 关闭 `worker` 思考模式(thinking/reasoning off);
|
||||
4. 上下文策略从“今天+昨天”改为“向前回溯直到累计 20 条用户消息”;
|
||||
5. 引入专用 `memory agent`,负责记忆提取,模型为 `qwen3.5-flash`;
|
||||
6. 保持 `/api/v1/agent`、SSE、history 兼容演进,不引入第二运行时。
|
||||
|
||||
## 2. 非目标
|
||||
|
||||
1. 不新增独立部署的执行系统;
|
||||
2. 不引入复杂 DSL 编排;
|
||||
3. 不在本阶段改造前端交互形态;
|
||||
4. 不做跨域业务逻辑重写(仅围绕 agent 执行链路和 memory 提取职责)。
|
||||
|
||||
## 3. 总体架构
|
||||
|
||||
### 3.1 重构后链路
|
||||
|
||||
```text
|
||||
API/Scheduler Trigger
|
||||
-> Context Window Resolver (last 20 user messages)
|
||||
-> Executor Dispatch
|
||||
-> Worker Executor (qwen3.5-30b-a3b, thinking off)
|
||||
-> Memory Executor (qwen3.5-flash, thinking off)
|
||||
-> Tool Authorization
|
||||
-> Persistence + Redis Stream
|
||||
-> SSE/History Consumers
|
||||
```
|
||||
|
||||
### 3.2 核心变化
|
||||
|
||||
1. 删除 router 阶段,缩短调用链路与阶段状态;
|
||||
2. 把“意图识别”内聚到 worker 提示词与执行策略;
|
||||
3. 把“记忆提取”从通用 worker 中剥离为独立 memory executor;
|
||||
4. 上下文装配改为固定规模策略,降低输入 token 波动。
|
||||
|
||||
## 4. 执行角色与职责边界
|
||||
|
||||
### 4.1 Worker Executor(主对话执行器)
|
||||
|
||||
- 模型:`qwen3.5-30b-a3b`
|
||||
- 配置:`thinking=off`
|
||||
- 职责:
|
||||
- 处理用户请求;
|
||||
- 在单次推理中完成意图判断、工具决策、结果回复;
|
||||
- 遵守工具白名单与安全约束。
|
||||
|
||||
### 4.2 Memory Executor(记忆提取执行器)
|
||||
|
||||
- 模型:`qwen3.5-flash`
|
||||
- 配置:`thinking=off`
|
||||
- 职责:
|
||||
- 从对话中提取稳定记忆候选;
|
||||
- 产出结构化记忆写入/遗忘建议;
|
||||
- 不承担通用对话和复杂工具编排。
|
||||
|
||||
### 4.3 工具权限边界
|
||||
|
||||
统一规则:
|
||||
|
||||
`effective_tools = declared_tools ∩ profile_allowlist ∩ system_allowlist`
|
||||
|
||||
- worker 默认可调用通用工具子集;
|
||||
- memory 默认仅允许 `memory_write`、`memory_forget`;
|
||||
- 拒绝调用必须落审计(原因、工具名、请求上下文摘要)。
|
||||
|
||||
## 5. 上下文窗口策略(替代 today_yesterday)
|
||||
|
||||
### 5.1 策略定义
|
||||
|
||||
窗口规则:从当前待处理消息向前回溯,直到累计到 20 条 `role=user` 消息。
|
||||
|
||||
细则:
|
||||
|
||||
1. 计数对象仅为 `role=user`;
|
||||
2. 为保持语义连续,窗口中保留相关 assistant/tool/system 消息;
|
||||
3. 若历史不足 20 条用户消息,返回全部可用历史;
|
||||
4. 当前用户消息默认计入 20 条统计。
|
||||
|
||||
### 5.2 伪代码
|
||||
|
||||
```text
|
||||
collect_context(messages, current_message_id, n=20):
|
||||
included = []
|
||||
user_count = 0
|
||||
|
||||
for msg in reverse(messages up to current_message_id):
|
||||
included.append(msg)
|
||||
if msg.role == "user":
|
||||
user_count += 1
|
||||
if user_count >= n:
|
||||
break
|
||||
|
||||
return reverse(included)
|
||||
```
|
||||
|
||||
### 5.3 预期收益
|
||||
|
||||
1. 输入 token 成本更稳定,不再受日期边界放大;
|
||||
2. router 去除后仍可控住 worker 输入规模;
|
||||
3. 在高频会话下显著降低上下文冗余。
|
||||
|
||||
## 6. 去除 Router 后的意图识别设计
|
||||
|
||||
去掉 router 后,必须在 worker 内完成“识别 + 决策 + 执行”。本节给出可落地方案。
|
||||
|
||||
### 6.1 方案 A:复用 RouterAgentOutput 语义到 Worker Prompt
|
||||
|
||||
做法:把原 router 的标签、判定规则、优先级放进 worker system prompt,让 worker先做内部意图归类,再进入执行。
|
||||
|
||||
优点:
|
||||
|
||||
1. 迁移风险低,行为连续性强;
|
||||
2. 便于快速下线 router;
|
||||
3. 对现有回归样本复用程度高。
|
||||
|
||||
不足:
|
||||
|
||||
1. 提示词偏长;
|
||||
2. 分类与执行耦合,调试颗粒度较粗。
|
||||
|
||||
### 6.2 方案 B:重写 Worker-Native 轻量意图识别提示词
|
||||
|
||||
做法:定义最小标签集(示例:`chat | tool_call | memory_write | memory_forget | reject`)和明确优先级规则,直接服务 worker 执行。
|
||||
|
||||
优点:
|
||||
|
||||
1. prompt 更短,token 更省;
|
||||
2. 更匹配“20 条用户消息窗口”策略;
|
||||
3. 长期维护成本更低。
|
||||
|
||||
不足:
|
||||
|
||||
1. 需要重建回归集与阈值;
|
||||
2. 初期存在行为漂移风险。
|
||||
|
||||
### 6.3 推荐路线:A -> B 两阶段
|
||||
|
||||
1. 第一阶段(稳定迁移):先落地方案 A,确保 router 去除后行为不突变;
|
||||
2. 第二阶段(成本优化):基于线上样本收敛到方案 B,压缩提示词和标签集。
|
||||
|
||||
该路线兼顾了“短期稳定”和“中期降本”。
|
||||
|
||||
## 7. 数据模型与配置约定
|
||||
|
||||
### 7.1 Execution Profile(精简版)
|
||||
|
||||
建议收敛为:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "chat_default",
|
||||
"executor": "worker",
|
||||
"model": {
|
||||
"name": "qwen3.5-30b-a3b",
|
||||
"thinking": "off"
|
||||
},
|
||||
"history_policy": {
|
||||
"mode": "last_n_user_messages",
|
||||
"n": 20,
|
||||
"include_current_user_message": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
memory profile 示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "automation_memory_default",
|
||||
"executor": "memory",
|
||||
"model": {
|
||||
"name": "qwen3.5-flash",
|
||||
"thinking": "off"
|
||||
},
|
||||
"history_policy": {
|
||||
"mode": "last_n_user_messages",
|
||||
"n": 20,
|
||||
"include_current_user_message": true
|
||||
},
|
||||
"tool_policy": {
|
||||
"mode": "intersection",
|
||||
"allowlist": ["memory_write", "memory_forget"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 7.2 兼容字段策略
|
||||
|
||||
- 历史 `enable_router` 字段保留读取兼容,写入路径不再依赖;
|
||||
- 新任务默认不再产出 router 配置;
|
||||
- 执行路径仅依据 `executor` 与 profile 配置。
|
||||
|
||||
### 7.3 Metadata 扩展建议
|
||||
|
||||
建议标准字段:
|
||||
|
||||
- `origin: "chat" | "automation"`
|
||||
- `executor: "worker" | "memory"`
|
||||
- `execution_profile_name: string`
|
||||
- `hidden_from_user: boolean`
|
||||
|
||||
用于用户可见性隔离与全链路审计。
|
||||
|
||||
## 8. 协议影响与文档更新
|
||||
|
||||
按照“协议先行”原则,先更新 `docs/protocols/`:
|
||||
|
||||
1. `docs/protocols/agent/sse-events.md`
|
||||
- step 枚举从 `router|worker|memory` 收敛为 `worker|memory`;
|
||||
- 对旧客户端声明:`router` 事件可能不再出现。
|
||||
|
||||
2. `docs/protocols/agent/run-agent-input.md`
|
||||
- 增加 `history_policy.mode=last_n_user_messages` 语义与边界规则。
|
||||
|
||||
3. `docs/protocols/agent/api-endpoints.md`
|
||||
- 说明执行阶段由后端 profile 决定;
|
||||
- 不要求前端显式传入 router 或 executor 控制参数。
|
||||
|
||||
## 9. 迁移计划
|
||||
|
||||
### 阶段 1:协议与配置就绪
|
||||
|
||||
1. 完成协议文档更新;
|
||||
2. 增加 profile 新字段和兼容读取逻辑;
|
||||
3. 新建任务默认 profile 切到 worker/memory 双执行器模型。
|
||||
|
||||
### 阶段 2:执行链路切换
|
||||
|
||||
1. 下线 router 运行路径;
|
||||
2. worker 切换 `qwen3.5-30b-a3b` + thinking off;
|
||||
3. 上下文装配切为“20 条用户消息策略”。
|
||||
|
||||
### 阶段 3:memory agent 接管记忆提取
|
||||
|
||||
1. memory executor 切换 `qwen3.5-flash`;
|
||||
2. 自动记忆任务全量走 memory executor;
|
||||
3. 对比提取质量与成本,完成灰度放量。
|
||||
|
||||
### 阶段 4:优化与收敛
|
||||
|
||||
1. worker 意图识别从方案 A 迭代到方案 B;
|
||||
2. 清理 router 相关遗留代码和配置分支;
|
||||
3. 固化观测指标与报警阈值。
|
||||
|
||||
## 10. 测试与验收
|
||||
|
||||
### 10.1 单元测试
|
||||
|
||||
1. `last_n_user_messages` 窗口截取逻辑;
|
||||
2. 工具交集授权逻辑;
|
||||
3. profile 解析与兼容字段读取;
|
||||
4. memory 输出结构校验。
|
||||
|
||||
### 10.2 集成测试
|
||||
|
||||
1. 无 router 情况下 worker 正常执行;
|
||||
2. SSE/history 在 `worker|memory` 阶段下可稳定消费;
|
||||
3. 自动记忆任务完整链路可执行;
|
||||
4. hidden 消息对用户不可见但审计可见。
|
||||
|
||||
### 10.3 验收指标
|
||||
|
||||
P0:
|
||||
|
||||
1. router 下线后核心对话流程零阻断;
|
||||
2. 平均输入 token 相比 today_yesterday 明显下降;
|
||||
3. 工具调用越权率为 0。
|
||||
|
||||
P1:
|
||||
|
||||
1. memory 提取质量不低于现网基线;
|
||||
2. 延迟与成本达到预期区间;
|
||||
3. 协议兼容无前端回归。
|
||||
|
||||
## 11. 风险与回滚
|
||||
|
||||
主要风险:
|
||||
|
||||
1. worker 内聚识别导致误判率短期上升;
|
||||
2. 20 条用户消息窗口在极端长任务中可能信息不足;
|
||||
3. memory 轻量模型在复杂语义下提取质量波动。
|
||||
|
||||
回滚策略:
|
||||
|
||||
1. 保留 profile 级灰度开关,支持按租户/任务类型回切旧模型;
|
||||
2. 上下文策略支持临时扩容(20 -> 25)作为应急参数;
|
||||
3. memory agent 保留质量兜底阈值,低置信度结果不落库。
|
||||
|
||||
---
|
||||
|
||||
本版文档确立了清晰的一次性架构方向:
|
||||
|
||||
- router 从执行链路中移除;
|
||||
- worker 直连用户消息并承担意图识别;
|
||||
- memory 由专用 agent 执行;
|
||||
- 上下文按 20 条用户消息定长回溯,控制 token 成本;
|
||||
- 在兼容现有协议消费方式的前提下完成演进。
|
||||
@@ -1,449 +0,0 @@
|
||||
# Reminder Alert Archival Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Deliver alarm-style reminder popups with cancel/snooze actions, 30s timeout auto-snooze, overlap handling, and archived/gray lifecycle consistency across Android and iOS.
|
||||
|
||||
**Architecture:** Keep scheduling local on device (Flutter local notifications), persist user reminder actions with an app-side outbox for eventual backend sync, and use backend PATCH update for archive status as the source of truth. Add a backend safety net job to auto-archive expired active events so app-terminated scenarios still converge. Implement shared reminder payload and action handler with platform-specific notification configuration (Android full-screen intent, iOS category actions).
|
||||
|
||||
**Tech Stack:** Flutter (`flutter_local_notifications`, existing calendar service), FastAPI schedule-items API, SQLAlchemy service layer, uv/pytest, Dart tests.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Update protocol and state semantics first
|
||||
|
||||
**Files:**
|
||||
- Modify: `docs/protocols/` (create a new reminder interaction protocol doc or extend existing schedule protocol doc)
|
||||
- Modify: `docs/runtime/runtime-route.md`
|
||||
|
||||
**Step 1: Write failing doc checks (manual checklist as fail-first gate)**
|
||||
|
||||
```text
|
||||
Checklist fails until all are documented:
|
||||
1) cancel action semantics
|
||||
2) snooze +10 minutes semantics
|
||||
3) timeout(30s) = ignore -> snooze
|
||||
4) overlap aggregation semantics
|
||||
5) archive + gray render semantics
|
||||
6) iOS degraded behavior note
|
||||
```
|
||||
|
||||
**Step 2: Run verification of checklist**
|
||||
|
||||
Run: manual review (expect FAIL before edits)
|
||||
|
||||
**Step 3: Write minimal protocol spec**
|
||||
|
||||
Include exact payload keys and action enum:
|
||||
|
||||
```json
|
||||
{
|
||||
"eventId": "uuid",
|
||||
"title": "string",
|
||||
"startAt": "iso8601",
|
||||
"endAt": "iso8601|null",
|
||||
"timezone": "IANA",
|
||||
"location": "string|null",
|
||||
"notes": "string|null",
|
||||
"color": "#RRGGBB|null",
|
||||
"mode": "single|aggregate",
|
||||
"aggregateIds": ["uuid"]
|
||||
}
|
||||
```
|
||||
|
||||
Actions:
|
||||
- `cancel`: archive target events and stop reminders
|
||||
- `snooze_10m`: reschedule +10m, stop when `now >= endAt`
|
||||
- `timeout_30s`: same as `snooze_10m`
|
||||
|
||||
**Step 4: Verify checklist passes**
|
||||
|
||||
Run: manual review (expect PASS)
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add docs/protocols docs/runtime/runtime-route.md
|
||||
git commit -m "docs: define reminder interaction protocol and lifecycle semantics"
|
||||
```
|
||||
|
||||
### Task 2: Add frontend reminder action models and payload codec
|
||||
|
||||
**Files:**
|
||||
- Create: `apps/lib/features/calendar/reminders/models/reminder_payload.dart`
|
||||
- Create: `apps/lib/features/calendar/reminders/models/reminder_action.dart`
|
||||
- Test: `apps/test/features/calendar/reminders/models/reminder_payload_test.dart`
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```dart
|
||||
test('round-trips payload with single and aggregate modes', () {
|
||||
final payload = ReminderPayload(...);
|
||||
expect(ReminderPayload.fromJson(payload.toJson()), payload);
|
||||
});
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `flutter test test/features/calendar/reminders/models/reminder_payload_test.dart`
|
||||
Expected: FAIL (type/file missing)
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
Implement immutable model + json codec + enum parser.
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `flutter test test/features/calendar/reminders/models/reminder_payload_test.dart`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/calendar/reminders/models apps/test/features/calendar/reminders/models/reminder_payload_test.dart
|
||||
git commit -m "feat: add reminder payload and action models"
|
||||
```
|
||||
|
||||
### Task 3: Refactor local notification service for action-capable reminders (Android + iOS)
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/core/notifications/local_notification_service.dart`
|
||||
- Modify: `apps/lib/main.dart`
|
||||
- Modify: `apps/ios/Runner/AppDelegate.swift`
|
||||
- Test: `apps/test/core/notifications/local_notification_service_test.dart`
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
```dart
|
||||
test('uses alarm-style Android details with actions and timeout', () async {});
|
||||
test('uses Darwin category actions for cancel/snooze', () async {});
|
||||
test('encodes payload in notification details', () async {});
|
||||
```
|
||||
|
||||
**Step 2: Run tests to verify failure**
|
||||
|
||||
Run: `flutter test test/core/notifications/local_notification_service_test.dart`
|
||||
Expected: FAIL
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
Implement:
|
||||
- Android notification actions: `cancel`, `snooze_10m`
|
||||
- `timeoutAfter: 30000`
|
||||
- payload serialization
|
||||
- iOS `DarwinNotificationCategory` + action identifiers
|
||||
- initialize callback registration for action responses
|
||||
|
||||
Also keep existing full-screen alarm setup and exact alarm fallback behavior.
|
||||
|
||||
**Step 4: Run tests to verify pass**
|
||||
|
||||
Run: `flutter test test/core/notifications/local_notification_service_test.dart`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add apps/lib/core/notifications/local_notification_service.dart apps/lib/main.dart apps/ios/Runner/AppDelegate.swift apps/test/core/notifications/local_notification_service_test.dart
|
||||
git commit -m "feat: support actionable reminder notifications on android and ios"
|
||||
```
|
||||
|
||||
### Task 4: Implement reminder action executor + local outbox for eventual consistency
|
||||
|
||||
**Files:**
|
||||
- Create: `apps/lib/features/calendar/reminders/reminder_action_executor.dart`
|
||||
- Create: `apps/lib/features/calendar/reminders/reminder_outbox_store.dart`
|
||||
- Modify: `apps/lib/features/calendar/data/services/calendar_service.dart`
|
||||
- Test: `apps/test/features/calendar/reminders/reminder_action_executor_test.dart`
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
```dart
|
||||
test('cancel archives remotely and cancels local reminders', () async {});
|
||||
test('network failure writes outbox item and keeps local state updated', () async {});
|
||||
test('snooze reschedules +10m and stops after endAt', () async {});
|
||||
```
|
||||
|
||||
**Step 2: Run tests (expect FAIL)**
|
||||
|
||||
Run: `flutter test test/features/calendar/reminders/reminder_action_executor_test.dart`
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
Rules:
|
||||
- Cancel: local cancel now, enqueue archive API job, best-effort immediate PATCH
|
||||
- Snooze: schedule at `now + 10m`, if `next >= endAt` then archive path
|
||||
- Timeout action uses same path as snooze
|
||||
|
||||
**Step 4: Re-run tests (expect PASS)**
|
||||
|
||||
Run: `flutter test test/features/calendar/reminders/reminder_action_executor_test.dart`
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/calendar/reminders apps/lib/features/calendar/data/services/calendar_service.dart apps/test/features/calendar/reminders/reminder_action_executor_test.dart
|
||||
git commit -m "feat: add reminder action executor with offline outbox"
|
||||
```
|
||||
|
||||
### Task 5: Add startup reconciliation (replay outbox + rebuild reminders)
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/core/startup/auth_session_bootstrapper.dart`
|
||||
- Modify: `apps/lib/main.dart`
|
||||
- Test: `apps/test/core/startup/auth_session_bootstrapper_test.dart`
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
```dart
|
||||
test('replays pending reminder actions after login', () async {});
|
||||
test('rebuilds reminders after outbox replay', () async {});
|
||||
```
|
||||
|
||||
**Step 2: Run tests (expect FAIL)**
|
||||
|
||||
Run: `flutter test test/core/startup/auth_session_bootstrapper_test.dart`
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
In authenticated startup flow:
|
||||
1) replay outbox
|
||||
2) fetch events with overlap semantics (active and not ended)
|
||||
3) rebuild active reminders with compensation scheduling:
|
||||
- `now < remindAt`: schedule at remindAt
|
||||
- `remindAt <= now < endAt`: schedule immediate compensation reminder (e.g. +5s)
|
||||
- `now >= endAt`: archive path
|
||||
4) enforce reminder dedupe key to avoid duplicate reminders after reinstall/restart
|
||||
|
||||
**Step 4: Re-run tests (expect PASS)**
|
||||
|
||||
Run: `flutter test test/core/startup/auth_session_bootstrapper_test.dart`
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add apps/lib/core/startup/auth_session_bootstrapper.dart apps/lib/main.dart apps/test/core/startup/auth_session_bootstrapper_test.dart
|
||||
git commit -m "feat: replay reminder outbox on startup"
|
||||
```
|
||||
|
||||
### Task 6: Implement overlap strategy (aggregate popup)
|
||||
|
||||
**Files:**
|
||||
- Create: `apps/lib/features/calendar/reminders/reminder_overlap_policy.dart`
|
||||
- Modify: `apps/lib/core/notifications/local_notification_service.dart`
|
||||
- Test: `apps/test/features/calendar/reminders/reminder_overlap_policy_test.dart`
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
```dart
|
||||
test('groups reminders whose fire time falls into same minute bucket', () {});
|
||||
test('creates aggregate payload with top-3 preview and ids', () {});
|
||||
```
|
||||
|
||||
**Step 2: Run tests (expect FAIL)**
|
||||
|
||||
Run: `flutter test test/features/calendar/reminders/reminder_overlap_policy_test.dart`
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
Policy:
|
||||
- same minute bucket => one aggregate popup
|
||||
- actions apply to all members by default
|
||||
- payload includes aggregateIds
|
||||
|
||||
**Step 4: Re-run tests (expect PASS)**
|
||||
|
||||
Run: `flutter test test/features/calendar/reminders/reminder_overlap_policy_test.dart`
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/calendar/reminders/reminder_overlap_policy.dart apps/lib/core/notifications/local_notification_service.dart apps/test/features/calendar/reminders/reminder_overlap_policy_test.dart
|
||||
git commit -m "feat: add overlap aggregation policy for reminders"
|
||||
```
|
||||
|
||||
### Task 7: Render archived events as gray in calendar UI
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/features/calendar/ui/**` (event color resolution points)
|
||||
- Test: `apps/test/features/calendar/ui/*archived*test.dart` (new if missing)
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
```dart
|
||||
testWidgets('archived events use gray token color', (tester) async {});
|
||||
```
|
||||
|
||||
**Step 2: Run tests (expect FAIL)**
|
||||
|
||||
Run: `flutter test test/features/calendar/ui`
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
Rule:
|
||||
- if `status == archived`, force token-based gray (do not mutate persisted `metadata.color`)
|
||||
|
||||
**Step 4: Re-run tests (expect PASS)**
|
||||
|
||||
Run: `flutter test test/features/calendar/ui`
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/calendar/ui apps/test/features/calendar/ui
|
||||
git commit -m "feat: render archived calendar events in gray"
|
||||
```
|
||||
|
||||
### Task 8: Backend reuse route + add expired-event auto-archive safety job
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/schedule_items/service.py` (if needed for stricter status transition)
|
||||
- Modify: `backend/src/v1/schedule_items/repository.py` (range query to overlap query)
|
||||
- Create: `backend/src/jobs/schedule_item_archive_job.py` (or existing worker module path)
|
||||
- Modify: worker scheduler registration file under `backend/src/core/celery/` (actual existing path)
|
||||
- Test: `backend/tests/unit/v1/schedule_items/test_service.py`
|
||||
- Test: `backend/tests/unit/v1/schedule_items/test_repository.py`
|
||||
- Test: `backend/tests/unit/jobs/test_schedule_item_archive_job.py`
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
```python
|
||||
def test_patch_status_archived_allowed_for_owner() -> None: ...
|
||||
def test_list_by_overlap_includes_started_but_not_ended_items() -> None: ...
|
||||
def test_archive_job_marks_expired_active_items_archived() -> None: ...
|
||||
```
|
||||
|
||||
**Step 2: Run tests (expect FAIL)**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/schedule_items/test_service.py backend/tests/unit/jobs/test_schedule_item_archive_job.py -q`
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
Implement/verify:
|
||||
- route reuse: PATCH status archived works as-is for authorized user
|
||||
- overlap query for bootstrap: `start_at <= window_end AND (end_at IS NULL OR end_at >= window_start)` and `status=active`
|
||||
- periodic archive job: `end_at < now and status=active -> archived`
|
||||
|
||||
**Step 4: Re-run tests (expect PASS)**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/schedule_items/test_service.py backend/tests/unit/jobs/test_schedule_item_archive_job.py -q`
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src backend/tests
|
||||
git commit -m "feat: add expired schedule auto-archive safety job"
|
||||
```
|
||||
|
||||
### Task 9: End-to-end verification and release notes
|
||||
|
||||
**Files:**
|
||||
- Modify: `docs/runtime/runtime-runbook.md`
|
||||
- Modify: `docs/protocols/` reminder doc from Task 1
|
||||
|
||||
**Step 1: Run frontend verification**
|
||||
|
||||
Run:
|
||||
- `flutter analyze`
|
||||
- `flutter test`
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 2: Run backend verification**
|
||||
|
||||
Run:
|
||||
- `uv run pytest backend/tests/unit/v1/schedule_items -q`
|
||||
- `uv run pytest backend/tests/unit/jobs/test_schedule_item_archive_job.py -q`
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 3: Manual device matrix**
|
||||
|
||||
Android:
|
||||
- app foreground/background/terminated for cancel/snooze/timeout
|
||||
- overlap popup behavior
|
||||
- endAt stop reminder + archive
|
||||
|
||||
iOS:
|
||||
- action button behavior in foreground/background/terminated
|
||||
- timeout -> snooze behavior after relaunch sync
|
||||
- archive sync after offline period
|
||||
|
||||
**Step 4: Document operational caveats**
|
||||
|
||||
Include:
|
||||
- Android full-screen may degrade to heads-up by OEM policy
|
||||
- iOS does not guarantee Android-style full-screen alarm behavior
|
||||
- eventual consistency via outbox + startup replay + backend safety job
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add docs/runtime/runtime-runbook.md docs/protocols
|
||||
git commit -m "docs: add reminder action runbook and platform caveats"
|
||||
```
|
||||
|
||||
## Notes on iOS parity
|
||||
|
||||
- iOS supports actionable local notifications via Darwin categories; implement `cancel` and `snooze_10m` action identifiers with same payload model.
|
||||
- iOS cannot guarantee Android-like forced full-screen alarm takeover; use lock-screen alert + sound + action buttons as equivalent UX.
|
||||
- App-terminated network callback reliability is lower on iOS; therefore outbox + startup replay is mandatory for parity.
|
||||
|
||||
## Data contracts and constraints (added)
|
||||
|
||||
### Reminder payload contract
|
||||
|
||||
```json
|
||||
{
|
||||
"eventId": "uuid",
|
||||
"title": "string",
|
||||
"startAt": "iso8601-with-offset",
|
||||
"endAt": "iso8601-with-offset|null",
|
||||
"timezone": "IANA",
|
||||
"location": "string|null",
|
||||
"notes": "string|null",
|
||||
"color": "#RRGGBB|null",
|
||||
"mode": "single|aggregate",
|
||||
"aggregateIds": ["uuid"]
|
||||
}
|
||||
```
|
||||
|
||||
Constraints:
|
||||
- `eventId` required and valid UUID.
|
||||
- `startAt` must be timezone-aware datetime.
|
||||
- `mode=aggregate` requires `aggregateIds.length >= 2`.
|
||||
- Payload versioning should be explicit if schema evolves.
|
||||
|
||||
### Reminder outbox contract
|
||||
|
||||
```json
|
||||
{
|
||||
"opId": "uuid",
|
||||
"eventId": "uuid",
|
||||
"action": "cancel|snooze_10m|timeout_30s|auto_archive",
|
||||
"targetStatus": "archived|null",
|
||||
"occurredAt": "iso8601-with-offset",
|
||||
"retryCount": 0,
|
||||
"nextRetryAt": "iso8601-with-offset|null",
|
||||
"state": "pending|done|dead",
|
||||
"lastError": "string|null"
|
||||
}
|
||||
```
|
||||
|
||||
Constraints:
|
||||
- Idempotency key: `(eventId, action, occurredAtBucket)`.
|
||||
- Exponential backoff retries with capped max attempts.
|
||||
- `cancel` and `auto_archive` both map to backend `status=archived` PATCH.
|
||||
|
||||
### Uniqueness and dedupe rules
|
||||
|
||||
- Notification identity uses deterministic key per event+cycle: `hash(eventId + cycleStartEpochMinutes + mode)`.
|
||||
- Before scheduling any reminder, cancel existing pending reminders for same dedupe key.
|
||||
- On bootstrap/reinstall, dedupe against local pending requests and outbox state before creating new schedules.
|
||||
- Compensation reminder (`remindAt <= now < endAt`) must generate exactly one immediate reminder per cycle window.
|
||||
|
||||
## Rollback plan
|
||||
|
||||
1) Disable action handling by feature flag while keeping plain reminders.
|
||||
2) Keep backend PATCH status route unchanged (safe rollback path).
|
||||
3) Pause auto-archive job if unexpected archival spikes occur.
|
||||
@@ -1,61 +0,0 @@
|
||||
# UI Schema File Split Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** 将 `apps/lib/core/schemas/ui_schema.dart` 从超大单文件拆分为同库的多个 `part` 文件,在不改变协议行为的前提下提升可维护性。
|
||||
|
||||
**Architecture:** 保留 `ui_schema.dart` 作为唯一对外入口和 single source of truth;通过 `part` 把 enums、common types、actions、nodes、document、builder 按职责拆分到 `core/schemas/ui_schema/` 子目录。所有类型名、JSON 字段、默认值与工厂方法逻辑保持完全一致,避免协议漂移。
|
||||
|
||||
**Tech Stack:** Dart library/part、Flutter analyze、Flutter test
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 建立拆分骨架并保留外部接口
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/core/schemas/ui_schema.dart`
|
||||
- Create: `apps/lib/core/schemas/ui_schema/enums.dart`
|
||||
- Create: `apps/lib/core/schemas/ui_schema/common_types.dart`
|
||||
- Create: `apps/lib/core/schemas/ui_schema/actions.dart`
|
||||
- Create: `apps/lib/core/schemas/ui_schema/nodes.dart`
|
||||
- Create: `apps/lib/core/schemas/ui_schema/document.dart`
|
||||
- Create: `apps/lib/core/schemas/ui_schema/builders.dart`
|
||||
|
||||
- [ ] **Step 1: 在主文件添加 `part` 声明并保留文件头注释**
|
||||
- [ ] **Step 1.1: 所有子文件统一使用 `part of '../ui_schema.dart';`,避免子目录相对路径错误**
|
||||
- [ ] **Step 2: 将 enum 定义迁移到 `enums.dart`,语义不变**
|
||||
- [ ] **Step 3: 将基础 DTO 迁移到 `common_types.dart`,语义不变**
|
||||
- [ ] **Step 4: 将 Action 协议与解析逻辑迁移到 `actions.dart`,语义不变**
|
||||
- [ ] **Step 5: 将 UiNode 与各 node 实现迁移到 `nodes.dart`,语义不变**
|
||||
- [ ] **Step 6: 将文档配置/文档模型迁移到 `document.dart`,语义不变**
|
||||
- [ ] **Step 7: 将 `buildSuccessDocument`/`buildErrorDocument` 迁移到 `builders.dart`**
|
||||
|
||||
### Task 2: 进行协议稳定性验证
|
||||
|
||||
**Files:**
|
||||
- Create: `apps/test/core/schemas/ui_schema_test.dart`
|
||||
- Test: `apps/test/features/chat/ui_schema_renderer_test.dart`
|
||||
- Test: `apps/test/features/chat/ui_schema_navigation_test.dart`
|
||||
- Test: `apps/test/features/chat/ag_ui_event_test.dart`
|
||||
|
||||
- [ ] **Step 1: 新增 `ui_schema.dart` 直连回归测试**
|
||||
- 覆盖 enum fallback、`actionSpecFromJson` 分支、`UiNode.fromJson` 分支、Document/builder 默认值、round-trip 稳定性。
|
||||
- [ ] **Step 2: 在 `apps/` 目录运行 analyze,确认 `part` 结构无编译错误**
|
||||
- Run (`apps/`): `flutter analyze`
|
||||
- [ ] **Step 3: 在 `apps/` 目录运行 UI Schema 渲染与导航相关测试**
|
||||
- Run: `flutter test test/features/chat/ui_schema_renderer_test.dart`
|
||||
- Run: `flutter test test/features/chat/ui_schema_navigation_test.dart`
|
||||
- [ ] **Step 4: 在 `apps/` 目录运行 AG-UI 事件模型回归测试**
|
||||
- Run: `flutter test test/features/chat/ag_ui_event_test.dart`
|
||||
- [ ] **Step 5: 在 `apps/` 目录运行新增 schema 回归测试**
|
||||
- Run: `flutter test test/core/schemas/ui_schema_test.dart`
|
||||
|
||||
### Task 3: 完成收尾与风险核对
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/core/schemas/ui_schema.dart`
|
||||
- Modify: `apps/lib/core/schemas/ui_schema/*.dart`
|
||||
|
||||
- [ ] **Step 1: 检查 public API 未变化(类型名/函数名不变)**
|
||||
- [ ] **Step 2: 检查 JSON 键、默认值、fallback 分支未变化**
|
||||
- [ ] **Step 3: 确认 `ui_schema.dart` 行数显著下降并保留 single source 入口定位**
|
||||
Reference in New Issue
Block a user