feat(agentscope): add memory system and automation job support
- Add consumer_registry and pipeline_registry for runtime orchestration - Add Visibility schema for message filtering - Add PipelineSpec for agent pipeline configuration - Add automation job models and configuration - Remove memory_prompt.py (consolidated into memory system) - Update runtime components: context_loader, context_service, orchestrator, runner, tasks - Update toolkit: tool_config, tool_middleware, custom tools (calendar, user_lookup) - Add auth_helpers and calendar_domain utilities - Add system_agents.yaml configuration
This commit is contained in:
@@ -10,8 +10,6 @@ from ag_ui.core import (
|
||||
RunErrorEvent,
|
||||
StepStartedEvent,
|
||||
StepFinishedEvent,
|
||||
TextMessageEndEvent,
|
||||
ToolCallResultEvent,
|
||||
)
|
||||
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
|
||||
from schemas.agent.ui_hints import UiHintsPayload
|
||||
@@ -94,9 +92,20 @@ def _build_run_finished(event: dict[str, Any]) -> RunFinishedEvent:
|
||||
|
||||
def _build_run_error(event: dict[str, Any]) -> RunErrorEvent:
|
||||
data = event.get("data", {})
|
||||
top_level_message = event.get("message")
|
||||
message = top_level_message if isinstance(top_level_message, str) else ""
|
||||
top_level_code = event.get("code")
|
||||
code = top_level_code if isinstance(top_level_code, str) else None
|
||||
if (not message or code is None) and isinstance(data, dict):
|
||||
data_message = data.get("message")
|
||||
if not message and isinstance(data_message, str):
|
||||
message = data_message
|
||||
data_code = data.get("code")
|
||||
if code is None and isinstance(data_code, str):
|
||||
code = data_code
|
||||
return RunErrorEvent(
|
||||
message=data.get("message", "Unknown error"),
|
||||
code=data.get("code"),
|
||||
message=message or "Unknown error",
|
||||
code=code,
|
||||
)
|
||||
|
||||
|
||||
@@ -120,34 +129,12 @@ def _build_step_finished(event: dict[str, Any]) -> StepFinishedEvent:
|
||||
)
|
||||
|
||||
|
||||
def _build_text_end(event: dict[str, Any]) -> TextMessageEndEvent:
|
||||
data = event.get("data", {})
|
||||
return TextMessageEndEvent(
|
||||
message_id=data.get("messageId", ""),
|
||||
)
|
||||
|
||||
|
||||
def _build_tool_result(event: dict[str, Any]) -> ToolCallResultEvent:
|
||||
data = event.get("data", {})
|
||||
content = data.get("result")
|
||||
if not isinstance(content, str):
|
||||
content = ""
|
||||
return ToolCallResultEvent(
|
||||
message_id=data.get("messageId", ""),
|
||||
tool_call_id=data.get("toolCallId", ""),
|
||||
content=content,
|
||||
role="tool",
|
||||
)
|
||||
|
||||
|
||||
_BUILDER_MAP: dict[str, Any] = {
|
||||
"run.started": _build_run_started,
|
||||
"run.finished": _build_run_finished,
|
||||
"run.error": _build_run_error,
|
||||
"step.start": _build_step_started,
|
||||
"step.finish": _build_step_finished,
|
||||
"text.end": _build_text_end,
|
||||
"tool.result": _build_tool_result,
|
||||
}
|
||||
|
||||
|
||||
@@ -208,6 +195,8 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
|
||||
payload["runId"] = run_id
|
||||
if isinstance(data, dict):
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
if internal_type == "run.error":
|
||||
reserved = {*reserved, "message", "code"}
|
||||
payload.update({k: v for k, v in data.items() if k not in reserved})
|
||||
return payload
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ class MessageRepository:
|
||||
output_tokens: int = 0,
|
||||
cost: Decimal = Decimal("0"),
|
||||
latency_ms: int | None = None,
|
||||
visibility_mask: int = 0,
|
||||
) -> AgentChatMessage:
|
||||
message = AgentChatMessage(
|
||||
session_id=session_id,
|
||||
@@ -42,6 +43,7 @@ class MessageRepository:
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=max(int(visibility_mask), 0),
|
||||
)
|
||||
self._session.add(message)
|
||||
await self._session.flush()
|
||||
|
||||
@@ -8,9 +8,12 @@ from core.agentscope.events.persistence import MessageRepository, SessionReposit
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.runtime_models import AgentOutput, ToolAgentOutput
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
class EventStore(Protocol):
|
||||
@@ -45,6 +48,9 @@ class SqlAlchemyEventStore:
|
||||
async with self._session_factory() as session:
|
||||
session_repo = SessionRepository(session)
|
||||
message_repo = MessageRepository(session)
|
||||
stage_visibility_bit_map = await self._load_stage_visibility_bit_map(
|
||||
session=session
|
||||
)
|
||||
chat_session = await session_repo.get_session(session_id=session_id)
|
||||
if chat_session is None:
|
||||
return
|
||||
@@ -77,6 +83,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
await self._persist_tool_call_result(
|
||||
@@ -85,6 +92,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
@@ -97,6 +105,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> None:
|
||||
message_id_raw = self._event_value(event, "messageId")
|
||||
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
|
||||
@@ -188,6 +197,10 @@ class SqlAlchemyEventStore:
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -213,6 +226,7 @@ class SqlAlchemyEventStore:
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> None:
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else None
|
||||
@@ -256,6 +270,10 @@ class SqlAlchemyEventStore:
|
||||
content=content,
|
||||
tool_name=tool_output.tool_name,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -279,6 +297,44 @@ class SqlAlchemyEventStore:
|
||||
return AgentChatMessageRole.TOOL
|
||||
return AgentChatMessageRole.ASSISTANT
|
||||
|
||||
def _resolve_stage_visibility_mask(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> int:
|
||||
base = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
raw_stage = self._event_value(event, "stage")
|
||||
if not isinstance(raw_stage, str):
|
||||
return base
|
||||
normalized_stage = raw_stage.strip().lower()
|
||||
bit = stage_visibility_bit_map.get(normalized_stage)
|
||||
if bit is None and normalized_stage == AgentType.MEMORY.value:
|
||||
bit = 18
|
||||
if bit is None:
|
||||
return base
|
||||
return base | bit_mask(bit=bit)
|
||||
|
||||
async def _load_stage_visibility_bit_map(
|
||||
self,
|
||||
*,
|
||||
session: Any,
|
||||
) -> dict[str, int]:
|
||||
stmt = select(SystemAgents.agent_type, SystemAgents.config).where(
|
||||
SystemAgents.agent_type.in_(
|
||||
[AgentType.ROUTER.value, AgentType.WORKER.value, AgentType.MEMORY.value]
|
||||
)
|
||||
)
|
||||
rows = (await session.execute(stmt)).all()
|
||||
bit_map: dict[str, int] = {}
|
||||
for agent_type, raw_config in rows:
|
||||
if not isinstance(agent_type, str):
|
||||
continue
|
||||
config_payload = raw_config if isinstance(raw_config, dict) else {}
|
||||
llm_config = SystemAgentLLMConfig.model_validate(config_payload)
|
||||
bit_map[agent_type.strip().lower()] = llm_config.visibility_consumer_bit
|
||||
return bit_map
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from schemas.agent.runtime_models import ResultType, RouterAgentOutput, TaskType
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
|
||||
|
||||
@@ -14,6 +17,24 @@ def _wrap_section(section: str, content: str) -> str:
|
||||
return f"{start}\n{body}\n{end}" if body else f"{start}\n{end}"
|
||||
|
||||
|
||||
def _enum_values(enum_cls: Any) -> str:
|
||||
return ", ".join(item.value for item in enum_cls)
|
||||
|
||||
|
||||
def _config_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
if llm_config is None:
|
||||
return []
|
||||
context_mode = llm_config.context_messages.mode.value
|
||||
context_count = llm_config.context_messages.count
|
||||
enabled_tools = [tool.value for tool in llm_config.enabled_tools]
|
||||
return [
|
||||
"[Runtime Config]",
|
||||
f"- context_messages.mode={context_mode}",
|
||||
f"- context_messages.count={context_count}",
|
||||
f"- enabled_tools={','.join(enabled_tools) if enabled_tools else 'default'}",
|
||||
]
|
||||
|
||||
|
||||
PromptRuleBuilder = Callable[[SystemAgentLLMConfig | None], list[str]]
|
||||
|
||||
|
||||
@@ -36,36 +57,41 @@ class AgentPromptRegistry:
|
||||
return builder(llm_config)
|
||||
|
||||
|
||||
def _config_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
if llm_config is None:
|
||||
return []
|
||||
context_mode = llm_config.context_messages.mode.value
|
||||
context_count = llm_config.context_messages.count
|
||||
tool_groups = [group.value for group in llm_config.enabled_tool_groups]
|
||||
def _router_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
return [
|
||||
"[Runtime Config]",
|
||||
f"- context_messages.mode={context_mode}",
|
||||
f"- context_messages.count={context_count}",
|
||||
f"- enabled_tool_groups={','.join(tool_groups) if tool_groups else 'default'}",
|
||||
"[Router Agent]",
|
||||
"- Read the latest user input and produce a routing contract for downstream execution.",
|
||||
"- Return exactly one RouterAgentOutput JSON object.",
|
||||
"[Responsibilities]",
|
||||
"- Router only: extract intent and route strategy; never answer user directly.",
|
||||
"- Preserve intent in normalized_task_input.user_text; keep wording concise and faithful.",
|
||||
"- Fill multimodal_summary only when image/attachment changes execution decisions.",
|
||||
"- Return key_entities and constraints that are execution-relevant; low confidence -> omit rather than guess.",
|
||||
"- Set execution_mode by complexity: onestep / tool_assisted / multistep.",
|
||||
"- Set result_typing.primary to the most suitable response shape; use clarification_request only when required info is missing.",
|
||||
"- Set ui.ui_mode and ui.ui_decision_reason based on whether structured UI improves actionability.",
|
||||
f"- task_typing.primary must use one TaskType enum: {_enum_values(TaskType)}.",
|
||||
f"- task_typing.secondary max 3 enums: {_enum_values(TaskType)}.",
|
||||
f"- result_typing.primary must use one ResultType enum: {_enum_values(ResultType)}.",
|
||||
f"- result_typing.secondary max 3 enums: {_enum_values(ResultType)}.",
|
||||
*_config_rules(llm_config),
|
||||
]
|
||||
|
||||
|
||||
def _worker_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
return [
|
||||
"[Worker Agent]",
|
||||
"- Process user context directly, identify intent, then execute or answer with available evidence.",
|
||||
"- Return exactly one agent output JSON object matching the runtime-injected schema.",
|
||||
"- Execute or answer against the routed objective and available evidence.",
|
||||
"- Return exactly one worker output JSON object matching the runtime-injected schema.",
|
||||
"[Responsibilities]",
|
||||
"- Handle user request directly from conversation context.",
|
||||
"- Decide intent first, then choose direct answer, tool call, clarification, or refusal.",
|
||||
"- Prefer a direct answer when no tool result is required.",
|
||||
"- Call tools only when tool results are necessary to produce a correct answer.",
|
||||
"- Infer deterministic required tool arguments from context, tool schema, and runtime signals.",
|
||||
"- Worker only: execute routed objective without changing router intent.",
|
||||
"- Treat router output as objective/constraints contract, not as a fully-materialized tool-args payload.",
|
||||
"- Infer deterministic required tool arguments from contract fields, tool schema, and runtime context.",
|
||||
"- Ask minimal clarification only when required arguments cannot be inferred safely.",
|
||||
"- If request is unsafe or disallowed, return safe refusal with actionable alternative.",
|
||||
"- Ground every claim in evidence and tool outputs; never fabricate execution state.",
|
||||
"- Ground every claim in available evidence and tool results; never fabricate execution state.",
|
||||
"- Keep status/result_type/answer/key_points/suggested_actions/error internally consistent.",
|
||||
"[Schema Guidance]",
|
||||
"- The output schema is injected at runtime; follow it exactly.",
|
||||
"- The worker output schema is injected at runtime; follow it exactly.",
|
||||
"- Do not add fields that are not present in the injected schema.",
|
||||
*_config_rules(llm_config),
|
||||
]
|
||||
@@ -88,7 +114,28 @@ def _memory_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
|
||||
]
|
||||
|
||||
|
||||
def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
|
||||
contract_json = json.dumps(
|
||||
router_output.model_dump(mode="json", exclude_none=True),
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
return "\n".join(
|
||||
[
|
||||
"[Worker Contract]",
|
||||
"- Keep routed objective unchanged.",
|
||||
"- Use normalized_task_input as objective text.",
|
||||
"- Use multimodal_summary/key_entities/constraints as execution evidence.",
|
||||
"- Infer deterministic missing required tool args from evidence + tool schema.",
|
||||
"- Ask clarification only when safe inference is impossible.",
|
||||
"[RouterAgentOutput]",
|
||||
contract_json,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
AGENT_PROMPT_REGISTRY = AgentPromptRegistry()
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.ROUTER, builder=_router_rules)
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.WORKER, builder=_worker_rules)
|
||||
AGENT_PROMPT_REGISTRY.register(agent_type=AgentType.MEMORY, builder=_memory_rules)
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
|
||||
|
||||
|
||||
def build_consumer_registry(
|
||||
*,
|
||||
system_agent_configs: dict[str, dict[str, object]],
|
||||
) -> ConsumerRegistry:
|
||||
bindings: list[AgentConsumerBinding] = []
|
||||
for agent_type, payload in system_agent_configs.items():
|
||||
config_obj = payload.get("config") if isinstance(payload, dict) else None
|
||||
if not isinstance(config_obj, dict):
|
||||
raise ValueError(f"invalid system agent config: {agent_type}")
|
||||
raw_bit = config_obj.get("visibility_consumer_bit")
|
||||
if not isinstance(raw_bit, int):
|
||||
raise ValueError(f"visibility_consumer_bit missing for agent: {agent_type}")
|
||||
bindings.append(AgentConsumerBinding(agent_type=agent_type, bit=raw_bit))
|
||||
return ConsumerRegistry(bindings=bindings)
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
|
||||
from schemas.agent.system_agent import ContextBuildStrategy
|
||||
|
||||
ContextLoader = Callable[[Any, str, int], Awaitable[dict[str, object] | None]]
|
||||
ContextLoader = Callable[[Any, str, int, int], Awaitable[dict[str, object] | None]]
|
||||
|
||||
|
||||
class ContextLoaderRegistry:
|
||||
@@ -23,20 +23,28 @@ class ContextLoaderRegistry:
|
||||
|
||||
|
||||
async def _load_number(
|
||||
service: Any, thread_id: str, count: int
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
async def _load_day(
|
||||
service: Any, thread_id: str, count: int
|
||||
service: Any,
|
||||
thread_id: str,
|
||||
count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
return await service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=max(count, 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,17 +5,27 @@ from typing import Protocol
|
||||
|
||||
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.visibility import bit_mask
|
||||
|
||||
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
|
||||
_DEFAULT_ROUTER_CONTEXT_DAY_COUNT = 20
|
||||
|
||||
|
||||
class ContextRepositoryLike(Protocol):
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_recent_messages_by_user_window(
|
||||
self, *, session_id: str, user_message_limit: int
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
@@ -41,20 +51,36 @@ class AgentContextService:
|
||||
if isinstance(raw_config, dict):
|
||||
raw_llm_config = raw_config
|
||||
|
||||
if mode == "router" and not raw_llm_config:
|
||||
raw_llm_config = {
|
||||
"context_messages": {
|
||||
"mode": "day",
|
||||
"count": _DEFAULT_ROUTER_CONTEXT_DAY_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
normalized_config = self._normalize_system_agent_config(raw_llm_config)
|
||||
context_config = normalized_config.context_messages
|
||||
visibility_mask = bit_mask(bit=normalized_config.visibility_consumer_bit)
|
||||
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
|
||||
return await context_loader(self, thread_id, context_config.count)
|
||||
return await context_loader(
|
||||
self,
|
||||
thread_id,
|
||||
context_config.count,
|
||||
visibility_mask,
|
||||
)
|
||||
|
||||
async def load_by_user_message_window(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
messages = await self._repository.get_recent_messages_by_user_window(
|
||||
session_id=thread_id,
|
||||
user_message_limit=max(int(user_message_limit), 1),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if not messages:
|
||||
return None
|
||||
@@ -65,6 +91,7 @@ class AgentContextService:
|
||||
*,
|
||||
thread_id: str,
|
||||
day_count: int,
|
||||
visibility_mask: int,
|
||||
) -> dict[str, object] | None:
|
||||
messages: list[dict[str, object]] = []
|
||||
before: date | None = None
|
||||
@@ -72,6 +99,7 @@ class AgentContextService:
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if not day_payload:
|
||||
break
|
||||
@@ -95,7 +123,7 @@ class AgentContextService:
|
||||
"mode": "number",
|
||||
"count": _DEFAULT_CONTEXT_WINDOW_USER_MESSAGES,
|
||||
},
|
||||
"enabled_tool_groups": [],
|
||||
"enabled_tools": [],
|
||||
}
|
||||
if not raw_config:
|
||||
return SystemAgentLLMConfig.model_validate(default_payload)
|
||||
|
||||
@@ -6,6 +6,7 @@ from ag_ui.core.types import RunAgentInput
|
||||
from agentscope.message import Msg
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from core.logging import get_logger
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.user import UserContext
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.orchestrator")
|
||||
@@ -24,6 +25,7 @@ class RunnerLike(Protocol):
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@@ -47,6 +49,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
context_messages: list[Msg],
|
||||
user_context: UserContext,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
@@ -66,6 +69,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
pipeline=self._pipeline,
|
||||
run_input=run_input,
|
||||
system_agent_mode=system_agent_mode,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.pipeline_spec import (
|
||||
ContextPolicy,
|
||||
ContextWindowMode,
|
||||
ExecutorKind,
|
||||
PipelineSpec,
|
||||
StageSpec,
|
||||
)
|
||||
|
||||
|
||||
def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
|
||||
normalized = mode.strip().lower()
|
||||
if normalized == "worker":
|
||||
return PipelineSpec(
|
||||
mode="worker",
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="router",
|
||||
executor_kind=ExecutorKind.SINGLE_SHOT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="router",
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
count=20,
|
||||
),
|
||||
),
|
||||
StageSpec(
|
||||
stage_name="worker",
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="worker",
|
||||
window_mode=ContextWindowMode.NUMBER,
|
||||
count=20,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if normalized == "memory":
|
||||
return PipelineSpec(
|
||||
mode="memory",
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="memory",
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="memory",
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
count=20,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
raise ValueError(f"unsupported pipeline mode: {normalized}")
|
||||
@@ -10,26 +10,40 @@ from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||||
from core.agentscope.runtime.model_tracking import TrackingChatModel
|
||||
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
|
||||
from core.agentscope.runtime.tool_selection_registry import TOOL_SELECTION_REGISTRY
|
||||
from core.agentscope.tools.toolkit import build_stage_toolkit
|
||||
from core.agentscope.utils import patch_agentscope_json_repair_compat
|
||||
from core.agentscope.utils import (
|
||||
finalize_json_response,
|
||||
patch_agentscope_json_repair_compat,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
)
|
||||
from schemas.agent.forwarded_props import (
|
||||
ClientTimeContext,
|
||||
parse_forwarded_props_client_time,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
RouterAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import (
|
||||
AgentType,
|
||||
ContextMessagesConfig,
|
||||
ContextBuildStrategy,
|
||||
SystemAgentLLMConfig,
|
||||
)
|
||||
from schemas.user import UserContext
|
||||
from services.litellm.service import LiteLLMService
|
||||
from sqlalchemy import select
|
||||
@@ -46,6 +60,7 @@ class SystemAgentRuntimeConfig:
|
||||
api_base_url: str
|
||||
api_key: str
|
||||
llm_config: SystemAgentLLMConfig
|
||||
extra_context: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -68,33 +83,83 @@ class AgentScopeRunner:
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
system_agent_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
stage_agent_type = self._resolve_stage_agent_type(system_agent_mode)
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
stage_agent_types = [
|
||||
self._parse_agent_type(stage_name=stage.stage_name)
|
||||
for stage in pipeline_spec.stages
|
||||
]
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
stage_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=stage_agent_type,
|
||||
)
|
||||
if stage_agent_types == [AgentType.ROUTER, AgentType.WORKER]:
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_stage_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
stage_config=worker_config,
|
||||
)
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
|
||||
if stage_agent_types[0] == AgentType.MEMORY:
|
||||
if memory_job_config is None:
|
||||
raise RuntimeError("memory job config is required")
|
||||
stage_config = await self._build_memory_stage_config(
|
||||
session=session,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
else:
|
||||
stage_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=stage_agent_types[0],
|
||||
)
|
||||
stage_toolkit = self._build_stage_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
stage_config=stage_config,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
stage_output = await self._execute_single_stage_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
input_messages=context_messages,
|
||||
toolkit=stage_toolkit,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
|
||||
return {
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
stage_config.agent_type.value: stage_output.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
def _build_stage_toolkit(
|
||||
@@ -113,11 +178,15 @@ class AgentScopeRunner:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_stage_agent_type(system_agent_mode: str) -> AgentType:
|
||||
mode = system_agent_mode.strip().lower() if system_agent_mode else "worker"
|
||||
if mode == AgentType.MEMORY.value:
|
||||
def _parse_agent_type(*, stage_name: str) -> AgentType:
|
||||
normalized = stage_name.strip().lower()
|
||||
if normalized == AgentType.ROUTER.value:
|
||||
return AgentType.ROUTER
|
||||
if normalized == AgentType.WORKER.value:
|
||||
return AgentType.WORKER
|
||||
if normalized == AgentType.MEMORY.value:
|
||||
return AgentType.MEMORY
|
||||
return AgentType.WORKER
|
||||
raise ValueError(f"unsupported stage name: {stage_name}")
|
||||
|
||||
async def _load_stage_config(
|
||||
self,
|
||||
@@ -130,28 +199,60 @@ class AgentScopeRunner:
|
||||
agent_type=agent_type,
|
||||
)
|
||||
|
||||
async def _execute_worker_step(
|
||||
async def _execute_router_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> AgentOutput:
|
||||
step_name = stage_config.agent_type.value
|
||||
worker_output_model = AgentOutput
|
||||
) -> RouterAgentOutput:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return router_output
|
||||
|
||||
async def _execute_worker_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
router_output: RouterAgentOutput,
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> WorkerAgentOutputLite:
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
input_messages=self._build_worker_input_messages(
|
||||
router_output=router_output
|
||||
),
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
@@ -163,11 +264,48 @@ class AgentScopeRunner:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return worker_output
|
||||
|
||||
async def _execute_single_stage_step(
|
||||
self,
|
||||
*,
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
user_context: UserContext,
|
||||
input_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> AgentOutput:
|
||||
step_name = stage_config.agent_type.value
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
event_type="STEP_STARTED",
|
||||
)
|
||||
stage_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
input_messages=input_messages,
|
||||
toolkit=toolkit,
|
||||
run_input=run_input,
|
||||
stage_config=stage_config,
|
||||
worker_output_model=AgentOutput,
|
||||
pipeline=pipeline,
|
||||
runtime_client_time=runtime_client_time,
|
||||
)
|
||||
stage_output = AgentOutput.model_validate(stage_result.payload)
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
step_name=step_name,
|
||||
event_type="STEP_FINISHED",
|
||||
)
|
||||
return stage_output
|
||||
|
||||
async def _load_system_agent_config(
|
||||
self,
|
||||
*,
|
||||
@@ -193,6 +331,50 @@ class AgentScopeRunner:
|
||||
api_base_url=factory.request_url,
|
||||
api_key=self._resolve_provider_api_key(factory_name=factory.name),
|
||||
llm_config=SystemAgentLLMConfig.model_validate(system_agent.config or {}),
|
||||
extra_context=None,
|
||||
)
|
||||
|
||||
async def _build_memory_stage_config(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
memory_job_config: AutomationJobConfig,
|
||||
) -> SystemAgentRuntimeConfig:
|
||||
stmt = (
|
||||
select(Llm, LlmFactory)
|
||||
.join(LlmFactory, Llm.factory_id == LlmFactory.id)
|
||||
.where(Llm.model_code == memory_job_config.model_code)
|
||||
)
|
||||
row = (await session.execute(stmt)).one_or_none()
|
||||
if row is None:
|
||||
raise RuntimeError(
|
||||
f"memory model not found: {memory_job_config.model_code}"
|
||||
)
|
||||
llm, factory = row
|
||||
llm_config = SystemAgentLLMConfig(
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
timeout_seconds=30,
|
||||
visibility_consumer_bit=18,
|
||||
context_messages=ContextMessagesConfig(
|
||||
mode=(
|
||||
ContextBuildStrategy.DAY
|
||||
if memory_job_config.context.window_mode.value == "day"
|
||||
else ContextBuildStrategy.NUMBER
|
||||
),
|
||||
count=memory_job_config.context.window_count,
|
||||
),
|
||||
enabled_tools=memory_job_config.enabled_tools,
|
||||
)
|
||||
return SystemAgentRuntimeConfig(
|
||||
agent_type=AgentType.MEMORY,
|
||||
model_code=llm.model_code,
|
||||
api_base_url=factory.request_url,
|
||||
api_key=self._resolve_provider_api_key(factory_name=factory.name),
|
||||
llm_config=llm_config,
|
||||
extra_context=(
|
||||
f"[Memory Input Template]\n{memory_job_config.input_template.strip()}"
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -211,19 +393,63 @@ class AgentScopeRunner:
|
||||
raise RuntimeError(f"provider api key missing for factory: {factory_name}")
|
||||
return api_key
|
||||
|
||||
async def _run_worker_stage(
|
||||
async def _run_router_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> StageExecutionResult:
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
response, payload = await finalize_json_response(
|
||||
model=tracking_model,
|
||||
formatter=OpenAIChatFormatter(),
|
||||
base_messages=[
|
||||
Msg(
|
||||
"system",
|
||||
build_system_prompt(
|
||||
agent_type=AgentType.ROUTER,
|
||||
llm_config=stage_config.llm_config,
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
runtime_client_time=runtime_client_time,
|
||||
tools=None,
|
||||
),
|
||||
"system",
|
||||
),
|
||||
*context_messages,
|
||||
],
|
||||
output_model=RouterAgentOutput,
|
||||
retries=0,
|
||||
)
|
||||
response_msg = Msg(
|
||||
name="router",
|
||||
role="assistant",
|
||||
content=list(getattr(response, "content", [])),
|
||||
metadata=payload,
|
||||
)
|
||||
return StageExecutionResult(
|
||||
message=response_msg,
|
||||
payload=payload,
|
||||
response_metadata=self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
usage_summary=tracking_model.usage_summary(),
|
||||
),
|
||||
)
|
||||
|
||||
async def _run_worker_stage(
|
||||
self,
|
||||
*,
|
||||
user_context: UserContext,
|
||||
input_messages: list[Msg],
|
||||
toolkit: Any,
|
||||
run_input: RunAgentInput,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
worker_output_model: type[AgentOutput],
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
) -> StageExecutionResult:
|
||||
worker_input = list(context_messages)
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
emitter = PipelineStageEmitter(
|
||||
pipeline=pipeline,
|
||||
@@ -241,6 +467,7 @@ class AgentScopeRunner:
|
||||
user_context=user_context,
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
runtime_client_time=runtime_client_time,
|
||||
extra_context=stage_config.extra_context,
|
||||
tools=None,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
@@ -248,7 +475,7 @@ class AgentScopeRunner:
|
||||
emitter=emitter,
|
||||
)
|
||||
response_msg = await agent.reply_json(
|
||||
worker_input, output_model=worker_output_model
|
||||
input_messages, output_model=worker_output_model
|
||||
)
|
||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||
response_metadata = self._litellm_service.build_usage_metadata(
|
||||
@@ -265,6 +492,19 @@ class AgentScopeRunner:
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
def _build_worker_input_messages(
|
||||
self,
|
||||
*,
|
||||
router_output: RouterAgentOutput,
|
||||
) -> list[Msg]:
|
||||
return [
|
||||
Msg(
|
||||
name=AgentType.ROUTER.value,
|
||||
role="user",
|
||||
content=build_worker_contract_prompt(router_output=router_output),
|
||||
)
|
||||
]
|
||||
|
||||
def _build_model(
|
||||
self, *, stage_config: SystemAgentRuntimeConfig
|
||||
) -> TrackingChatModel:
|
||||
@@ -272,8 +512,8 @@ class AgentScopeRunner:
|
||||
"temperature": stage_config.llm_config.temperature,
|
||||
"max_tokens": stage_config.llm_config.max_tokens,
|
||||
"timeout": stage_config.llm_config.timeout_seconds,
|
||||
"extra_body": {"enable_thinking": False},
|
||||
}
|
||||
generate_kwargs["extra_body"] = {"enable_thinking": False}
|
||||
|
||||
model = OpenAIChatModel(
|
||||
model_name=stage_config.model_code,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from datetime import timezone
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
@@ -14,26 +15,49 @@ from core.agentscope.events import (
|
||||
)
|
||||
from core.agentscope.runtime.context_service import AgentContextService
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.schemas.agui_input import parse_run_input
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
SqlAlchemyAutomationSchedulerRepository,
|
||||
utc_now,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from models.automation_jobs import AutomationJob
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
|
||||
from schemas.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.users.dependencies import get_user_service
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
_MAX_CONTEXT_ATTACHMENTS = 3
|
||||
|
||||
|
||||
class _BulkQueueAdapter:
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str:
|
||||
del dedup_key
|
||||
result = await run_command_task_bulk.kiq(command)
|
||||
return str(result.task_id)
|
||||
|
||||
|
||||
def _serialize_tool_agent_output(
|
||||
*,
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None,
|
||||
@@ -79,13 +103,29 @@ async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
thread_id: str,
|
||||
system_agent_mode: str,
|
||||
context_mode: str,
|
||||
memory_job_config: AutomationJobConfig | None = None,
|
||||
) -> list[Msg]:
|
||||
context_service = AgentContextService(repository=AgentRepository(session))
|
||||
result = await context_service.load_context_messages(
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=system_agent_mode,
|
||||
)
|
||||
if memory_job_config is not None:
|
||||
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
if memory_job_config.context.window_mode.value == "day":
|
||||
result = await context_service.load_by_day_window(
|
||||
thread_id=thread_id,
|
||||
day_count=memory_job_config.context.window_count,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
else:
|
||||
result = await context_service.load_by_user_message_window(
|
||||
thread_id=thread_id,
|
||||
user_message_limit=memory_job_config.context.window_count,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
else:
|
||||
result = await context_service.load_context_messages(
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=context_mode,
|
||||
)
|
||||
if not result:
|
||||
return []
|
||||
|
||||
@@ -166,11 +206,33 @@ async def _build_recent_context_messages(
|
||||
return converted
|
||||
|
||||
|
||||
async def _load_memory_job_config(
|
||||
*,
|
||||
session: Any,
|
||||
owner_id: UUID,
|
||||
automation_job_id: str,
|
||||
) -> AutomationJobConfig:
|
||||
try:
|
||||
job_uuid = UUID(automation_job_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError("automation_job_id is invalid") from exc
|
||||
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_uuid)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
raise ValueError("automation job not found")
|
||||
return AutomationJobConfig.model_validate(row.config or {})
|
||||
|
||||
|
||||
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
command_type = str(command.get("command", "run")).strip().lower()
|
||||
raw_owner_id = command.get("owner_id")
|
||||
run_input_raw = command.get("run_input")
|
||||
system_agent_mode = str(command.get("system_agent_mode", "worker")).strip().lower()
|
||||
|
||||
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
|
||||
raise ValueError("owner_id is required")
|
||||
@@ -178,6 +240,15 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
raise ValueError("run_input is required")
|
||||
|
||||
run_input = parse_run_input(run_input_raw)
|
||||
system_agent_mode = parse_forwarded_props_agent_type(
|
||||
getattr(run_input, "forwarded_props", None)
|
||||
)
|
||||
raw_automation_job_id = command.get("automation_job_id")
|
||||
if system_agent_mode == "memory" and (
|
||||
not isinstance(raw_automation_job_id, str) or not raw_automation_job_id
|
||||
):
|
||||
raise ValueError("automation_job_id is required for memory mode")
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
owner_id = UUID(raw_owner_id)
|
||||
@@ -189,6 +260,14 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
user_context = await _build_user_context(owner_id=owner_id, session=session)
|
||||
memory_job_config: AutomationJobConfig | None = None
|
||||
if system_agent_mode == "memory":
|
||||
assert isinstance(raw_automation_job_id, str)
|
||||
memory_job_config = await _load_memory_job_config(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
automation_job_id=raw_automation_job_id,
|
||||
)
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
bus = RedisStreamBus(
|
||||
@@ -211,7 +290,8 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages = await _build_recent_context_messages(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
system_agent_mode=system_agent_mode,
|
||||
context_mode=pipeline_spec.stages[0].context_policy.consumer_agent_type,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
|
||||
await runtime.run(
|
||||
@@ -219,6 +299,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages=context_messages,
|
||||
user_context=user_context,
|
||||
system_agent_mode=system_agent_mode,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
logger.info(
|
||||
"agentscope runtime task completed",
|
||||
@@ -233,6 +314,35 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
}
|
||||
|
||||
|
||||
async def run_automation_scheduler_scan(
|
||||
*,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, int]:
|
||||
now = utc_now()
|
||||
safe_limit = (
|
||||
max(int(limit), 1)
|
||||
if isinstance(limit, int)
|
||||
else int(config.automation_scheduler.batch_limit)
|
||||
)
|
||||
async with AsyncSessionLocal() as session:
|
||||
repository = SqlAlchemyAutomationSchedulerRepository(session=session)
|
||||
service = AutomationSchedulerService(
|
||||
repository=repository,
|
||||
queue=_BulkQueueAdapter(),
|
||||
)
|
||||
result = await service.scan_and_dispatch(now_utc=now, limit=safe_limit)
|
||||
logger.info(
|
||||
"automation scheduler scan completed",
|
||||
scanned=result.scanned,
|
||||
dispatched=result.dispatched,
|
||||
now_utc=now.astimezone(timezone.utc).isoformat(),
|
||||
)
|
||||
return {
|
||||
"scanned": int(result.scanned),
|
||||
"dispatched": int(result.dispatched),
|
||||
}
|
||||
|
||||
|
||||
@default_broker.task(task_name="tasks.agentscope.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
@@ -246,3 +356,8 @@ async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object
|
||||
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
|
||||
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
|
||||
|
||||
@default_broker.task(task_name="tasks.automation.scan_due_jobs")
|
||||
async def scan_due_automation_jobs_task(limit: int | None = None) -> dict[str, int]:
|
||||
return await run_automation_scheduler_scan(limit=limit)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from core.agentscope.tools.tool_config import resolve_tool_names_by_groups
|
||||
from core.agentscope.tools.tool_config import resolve_tool_function_names
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
ToolNameResolver = Callable[[Any], set[str] | None]
|
||||
@@ -23,20 +23,19 @@ class ToolSelectionRegistry:
|
||||
return resolver(stage_config)
|
||||
|
||||
|
||||
def _default_group_resolver(stage_config: Any) -> set[str] | None:
|
||||
raw_groups = getattr(stage_config.llm_config, "enabled_tool_groups", [])
|
||||
groups = raw_groups if isinstance(raw_groups, list) else []
|
||||
if not groups:
|
||||
def _default_tool_resolver(stage_config: Any) -> set[str] | None:
|
||||
enabled_tools = getattr(stage_config.llm_config, "enabled_tools", [])
|
||||
if not enabled_tools:
|
||||
return None
|
||||
return resolve_tool_names_by_groups(set(groups))
|
||||
return resolve_tool_function_names(set(enabled_tools))
|
||||
|
||||
|
||||
TOOL_SELECTION_REGISTRY = ToolSelectionRegistry()
|
||||
TOOL_SELECTION_REGISTRY.register(
|
||||
agent_type=AgentType.WORKER,
|
||||
resolver=_default_group_resolver,
|
||||
resolver=_default_tool_resolver,
|
||||
)
|
||||
TOOL_SELECTION_REGISTRY.register(
|
||||
agent_type=AgentType.MEMORY,
|
||||
resolver=_default_group_resolver,
|
||||
resolver=_default_tool_resolver,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.agentscope.tools.utils.calendar_domain import (
|
||||
map_calendar_exception,
|
||||
merge_schedule_metadata_for_update,
|
||||
parse_iso_datetime,
|
||||
resolve_share_target_email_map,
|
||||
resolve_share_target_phone_map,
|
||||
schedule_event_to_dict,
|
||||
)
|
||||
from core.agentscope.tools.utils.calendar_ui import (
|
||||
@@ -580,11 +580,11 @@ async def calendar_share(
|
||||
)
|
||||
target_uuid = UUID(event_id)
|
||||
|
||||
email_map = resolve_share_target_email_map(
|
||||
phone_map = resolve_share_target_phone_map(
|
||||
[invitee.user_id for invitee in invitees]
|
||||
)
|
||||
|
||||
if not email_map:
|
||||
if not phone_map:
|
||||
return calendar_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
@@ -599,8 +599,8 @@ async def calendar_share(
|
||||
normalized_user_id = str(UUID(invitee.user_id.strip()))
|
||||
except ValueError:
|
||||
continue
|
||||
email = email_map.get(normalized_user_id)
|
||||
if email is None:
|
||||
phone = phone_map.get(normalized_user_id)
|
||||
if phone is None:
|
||||
continue
|
||||
permission = {
|
||||
"permission_view": invitee.permission_view,
|
||||
@@ -608,15 +608,15 @@ async def calendar_share(
|
||||
"permission_invite": invitee.permission_invite,
|
||||
}
|
||||
await service.share(
|
||||
target_uuid, ScheduleItemShareRequest(email=email, **permission)
|
||||
target_uuid, ScheduleItemShareRequest(phone=phone, **permission)
|
||||
)
|
||||
invited.append(email)
|
||||
invited.append(phone)
|
||||
if not invited:
|
||||
return calendar_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
code="NOT_FOUND",
|
||||
message="邀请目标均无有效邮箱",
|
||||
message="邀请目标均无有效手机号",
|
||||
retryable=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from agentscope.tool import ToolResponse
|
||||
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
|
||||
from core.agentscope.tools.utils import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
@@ -46,22 +46,22 @@ def _lookup_error_output(
|
||||
async def _resolve_identity(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
user_email: str | None,
|
||||
user_phone: str | None,
|
||||
user_name: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve user identity by email or username."""
|
||||
email = user_email.strip().lower() if isinstance(user_email, str) else ""
|
||||
"""Resolve user identity by phone or username."""
|
||||
phone = user_phone.strip() if isinstance(user_phone, str) else ""
|
||||
name = user_name.strip() if isinstance(user_name, str) else ""
|
||||
|
||||
if bool(email) == bool(name):
|
||||
if bool(phone) == bool(name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="请提供 email 或 username 其中之一",
|
||||
detail="请提供 phone 或 username 其中之一",
|
||||
)
|
||||
|
||||
if email:
|
||||
if phone:
|
||||
auth_gateway = SupabaseAuthGateway()
|
||||
user = await auth_gateway.get_user_by_email(email)
|
||||
user = await auth_gateway.get_user_by_phone(phone)
|
||||
user_id = UUID(user.id)
|
||||
|
||||
stmt = (
|
||||
@@ -73,9 +73,9 @@ async def _resolve_identity(
|
||||
|
||||
return {
|
||||
"userId": str(user_id),
|
||||
"email": user.email,
|
||||
"phone": user.phone,
|
||||
"username": username,
|
||||
"matchedBy": "email",
|
||||
"matchedBy": "phone",
|
||||
}
|
||||
|
||||
stmt = (
|
||||
@@ -90,20 +90,20 @@ async def _resolve_identity(
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
users = list_auth_users()
|
||||
email_value = find_auth_email_by_user_id(users=users, user_id=profile.id)
|
||||
phone_value = find_auth_phone_by_user_id(users=users, user_id=profile.id)
|
||||
|
||||
return {
|
||||
"userId": str(profile.id),
|
||||
"email": email_value,
|
||||
"phone": phone_value,
|
||||
"username": profile.username,
|
||||
"matchedBy": "username",
|
||||
}
|
||||
|
||||
|
||||
async def user_lookup(
|
||||
user_email: Annotated[
|
||||
user_phone: Annotated[
|
||||
str | None,
|
||||
Field(description="User email address to look up."),
|
||||
Field(description="User phone to look up."),
|
||||
] = None,
|
||||
user_name: Annotated[
|
||||
str | None,
|
||||
@@ -112,16 +112,16 @@ async def user_lookup(
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
"""Look up user identity by email or username.
|
||||
"""Look up user identity by phone or username.
|
||||
|
||||
Args:
|
||||
user_email: User email address for lookup.
|
||||
user_phone: User phone for lookup.
|
||||
user_name: Username for lookup.
|
||||
|
||||
Returns:
|
||||
ToolResponse with serialized ToolAgentOutput payload.
|
||||
"""
|
||||
tool_call_args = {"user_email": user_email, "user_name": user_name}
|
||||
tool_call_args = {"user_phone": user_phone, "user_name": user_name}
|
||||
|
||||
if session is None or owner_id is None:
|
||||
return _lookup_error_output(
|
||||
@@ -134,17 +134,17 @@ async def user_lookup(
|
||||
try:
|
||||
resolved = await _resolve_identity(
|
||||
session=cast(AsyncSession, session),
|
||||
user_email=user_email,
|
||||
user_phone=user_phone,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
username = str(resolved.get("username") or "")
|
||||
email = str(resolved.get("email") or "")
|
||||
phone = str(resolved.get("phone") or "")
|
||||
user_id = str(resolved.get("userId") or "")
|
||||
matched_by = str(resolved.get("matchedBy") or "")
|
||||
summary = (
|
||||
f"status=success matched_by={matched_by} user_id={user_id} "
|
||||
f"username={username} has_email={str(bool(email)).lower()}"
|
||||
f"username={username} has_phone={str(bool(phone)).lower()}"
|
||||
)
|
||||
return _dump_tool_output(
|
||||
ToolAgentOutput(
|
||||
|
||||
@@ -6,7 +6,15 @@ from enum import Enum
|
||||
|
||||
class ToolGroup(str, Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXECUTE = "execute"
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class AgentTool(str, Enum):
|
||||
CALENDAR_READ = "calendar.read"
|
||||
CALENDAR_WRITE = "calendar.write"
|
||||
CALENDAR_SHARE = "calendar.share"
|
||||
USER_LOOKUP = "user.lookup"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -29,30 +37,51 @@ TOOL_CONFIGS: dict[str, ToolConfig] = {
|
||||
),
|
||||
"user_lookup": ToolConfig(
|
||||
name="user_lookup",
|
||||
group=ToolGroup.READ,
|
||||
group=ToolGroup.MEMORY,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.WRITE,
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"calendar_share": ToolConfig(
|
||||
name="calendar_share",
|
||||
group=ToolGroup.WRITE,
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
}
|
||||
|
||||
AGENT_TOOL_TO_FUNCTION_NAME: dict[AgentTool, str] = {
|
||||
AgentTool.CALENDAR_READ: "calendar_read",
|
||||
AgentTool.CALENDAR_WRITE: "calendar_write",
|
||||
AgentTool.CALENDAR_SHARE: "calendar_share",
|
||||
AgentTool.USER_LOOKUP: "user_lookup",
|
||||
}
|
||||
|
||||
def get_tool_config(tool_name: str) -> ToolConfig:
|
||||
config = TOOL_CONFIGS.get(tool_name)
|
||||
if config is None:
|
||||
raise ValueError(f"unknown tool: {tool_name}")
|
||||
return config
|
||||
TOOL_NAME_ALIASES: dict[str, AgentTool] = {
|
||||
AgentTool.CALENDAR_READ.value: AgentTool.CALENDAR_READ,
|
||||
"calendar_read": AgentTool.CALENDAR_READ,
|
||||
AgentTool.CALENDAR_WRITE.value: AgentTool.CALENDAR_WRITE,
|
||||
"calendar_write": AgentTool.CALENDAR_WRITE,
|
||||
AgentTool.CALENDAR_SHARE.value: AgentTool.CALENDAR_SHARE,
|
||||
"calendar_share": AgentTool.CALENDAR_SHARE,
|
||||
AgentTool.USER_LOOKUP.value: AgentTool.USER_LOOKUP,
|
||||
"user_lookup": AgentTool.USER_LOOKUP,
|
||||
}
|
||||
|
||||
|
||||
def resolve_tool_names_by_groups(groups: set[ToolGroup]) -> set[str]:
|
||||
if not groups:
|
||||
return set()
|
||||
return {name for name, config in TOOL_CONFIGS.items() if config.group in groups}
|
||||
def parse_agent_tool(value: object) -> AgentTool:
|
||||
if isinstance(value, AgentTool):
|
||||
return value
|
||||
raw_value = str(value or "").strip().lower()
|
||||
if not raw_value:
|
||||
raise ValueError("enabled tool value cannot be empty")
|
||||
tool = TOOL_NAME_ALIASES.get(raw_value)
|
||||
if tool is None:
|
||||
raise ValueError(f"unknown enabled tool: {raw_value}")
|
||||
return tool
|
||||
|
||||
|
||||
def resolve_tool_function_names(tools: set[AgentTool]) -> set[str]:
|
||||
return {AGENT_TOOL_TO_FUNCTION_NAME[tool] for tool in tools}
|
||||
|
||||
@@ -7,10 +7,15 @@ from core.agentscope.tools.tool_call_context import (
|
||||
reset_current_tool_call_id,
|
||||
set_current_tool_call_id,
|
||||
)
|
||||
from core.agentscope.tools.tool_config import (
|
||||
AGENT_TOOL_TO_FUNCTION_NAME,
|
||||
TOOL_CONFIGS,
|
||||
ToolConfig,
|
||||
parse_agent_tool,
|
||||
)
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
build_error_response,
|
||||
)
|
||||
from core.agentscope.tools.tool_config import ToolConfig, TOOL_CONFIGS
|
||||
|
||||
|
||||
def register_tool_middlewares(
|
||||
@@ -59,6 +64,18 @@ def create_approval_middleware(
|
||||
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
|
||||
| None = None,
|
||||
) -> Callable[..., AsyncGenerator[Any, None]]:
|
||||
def _resolve_tool_config(*, tool_name: str) -> ToolConfig | None:
|
||||
config = config_by_name.get(tool_name)
|
||||
if config is not None:
|
||||
return config
|
||||
try:
|
||||
normalized_tool_name = AGENT_TOOL_TO_FUNCTION_NAME[
|
||||
parse_agent_tool(tool_name)
|
||||
]
|
||||
except ValueError:
|
||||
return None
|
||||
return config_by_name.get(normalized_tool_name)
|
||||
|
||||
def _resolve_tool_call_id(tool_call: dict[str, Any]) -> str:
|
||||
raw_tool_call_id = tool_call.get("id")
|
||||
if isinstance(raw_tool_call_id, str) and raw_tool_call_id.strip():
|
||||
@@ -81,7 +98,7 @@ def create_approval_middleware(
|
||||
yield response
|
||||
return
|
||||
|
||||
config = config_by_name.get(tool_name)
|
||||
config = _resolve_tool_config(tool_name=tool_name)
|
||||
if config is None or not config.approval.required:
|
||||
async for response in await next_handler(**kwargs):
|
||||
yield response
|
||||
@@ -134,15 +151,3 @@ def create_approval_middleware(
|
||||
yield pending_response
|
||||
|
||||
return approval_middleware
|
||||
|
||||
|
||||
def create_hitl_middleware(
|
||||
*,
|
||||
meta_by_name: dict[str, ToolConfig],
|
||||
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
|
||||
| None = None,
|
||||
) -> Callable[..., AsyncGenerator[Any, None]]:
|
||||
return create_approval_middleware(
|
||||
config_by_name=meta_by_name,
|
||||
approval_resolver=approval_resolver,
|
||||
)
|
||||
|
||||
@@ -13,8 +13,6 @@ from core.agentscope.tools.custom.calendar import (
|
||||
from core.agentscope.tools.custom.user_lookup import user_lookup
|
||||
from core.agentscope.tools.tool_config import (
|
||||
TOOL_CONFIGS,
|
||||
ToolGroup,
|
||||
resolve_tool_names_by_groups,
|
||||
)
|
||||
from core.agentscope.tools.tool_middleware import register_tool_middlewares
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -28,46 +26,36 @@ TOOL_FUNCTIONS: dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
AGENT_TYPE_TO_GROUPS: dict[AgentType, set[ToolGroup]] = {
|
||||
AgentType.WORKER: {ToolGroup.READ, ToolGroup.WRITE},
|
||||
AgentType.MEMORY: {ToolGroup.READ, ToolGroup.WRITE},
|
||||
AGENT_TYPE_TO_DEFAULT_TOOLS: dict[AgentType, set[str]] = {
|
||||
AgentType.WORKER: {
|
||||
"calendar_read",
|
||||
"calendar_write",
|
||||
"calendar_share",
|
||||
"user_lookup",
|
||||
},
|
||||
AgentType.MEMORY: {"calendar_read", "user_lookup"},
|
||||
}
|
||||
|
||||
|
||||
def _resolve_enabled_tools(
|
||||
*,
|
||||
groups: set[ToolGroup] | None,
|
||||
enabled_tool_names: set[str] | None,
|
||||
) -> set[str]:
|
||||
if enabled_tool_names is not None:
|
||||
unknown = enabled_tool_names - set(TOOL_FUNCTIONS)
|
||||
if unknown:
|
||||
raise ValueError(f"unknown tools in enabled_tool_names: {sorted(unknown)}")
|
||||
return set(enabled_tool_names)
|
||||
|
||||
if groups is None:
|
||||
return set(TOOL_FUNCTIONS)
|
||||
|
||||
resolved = resolve_tool_names_by_groups(groups)
|
||||
unknown = resolved - set(TOOL_FUNCTIONS)
|
||||
def _validate_enabled_tool_names(enabled_tool_names: set[str]) -> set[str]:
|
||||
unknown = enabled_tool_names - set(TOOL_FUNCTIONS)
|
||||
if unknown:
|
||||
raise ValueError(f"tool config contains unknown tools: {sorted(unknown)}")
|
||||
return resolved
|
||||
raise ValueError(f"unknown tools in enabled_tool_names: {sorted(unknown)}")
|
||||
return enabled_tool_names
|
||||
|
||||
|
||||
def build_toolkit(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
groups: set[ToolGroup] | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
):
|
||||
toolkit = Toolkit()
|
||||
enabled_names = _resolve_enabled_tools(
|
||||
groups=groups,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
)
|
||||
if enabled_tool_names is None:
|
||||
enabled_names = set(TOOL_FUNCTIONS)
|
||||
else:
|
||||
enabled_names = _validate_enabled_tool_names(set(enabled_tool_names))
|
||||
|
||||
preset_kwargs = cast(
|
||||
dict[str, JSONSerializableObject],
|
||||
@@ -100,15 +88,13 @@ def build_stage_toolkit(
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
enable_hitl: bool | None = None,
|
||||
):
|
||||
groups = AGENT_TYPE_TO_GROUPS.get(agent_type)
|
||||
if groups is None:
|
||||
default_tools = AGENT_TYPE_TO_DEFAULT_TOOLS.get(agent_type)
|
||||
if default_tools is None:
|
||||
raise ValueError(f"unknown agent_type: {agent_type}")
|
||||
|
||||
stage_enabled_names = resolve_tool_names_by_groups(set(groups))
|
||||
selected_names = (
|
||||
stage_enabled_names
|
||||
set(default_tools)
|
||||
if enabled_tool_names is None
|
||||
else stage_enabled_names | set(enabled_tool_names)
|
||||
else _validate_enabled_tool_names(set(enabled_tool_names))
|
||||
)
|
||||
|
||||
return build_toolkit(
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from core.agentscope.tools.utils.auth_helpers import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"list_auth_users",
|
||||
"find_auth_email_by_user_id",
|
||||
"find_auth_phone_by_user_id",
|
||||
]
|
||||
|
||||
@@ -25,12 +25,12 @@ def list_auth_users() -> list[Any]:
|
||||
return users
|
||||
|
||||
|
||||
def find_auth_email_by_user_id(*, users: list[Any], user_id: UUID) -> str | None:
|
||||
"""Find auth email by user id from fetched user list."""
|
||||
def find_auth_phone_by_user_id(*, users: list[Any], user_id: UUID) -> str | None:
|
||||
"""Find auth phone by user id from fetched user list."""
|
||||
target = str(user_id)
|
||||
for user in users:
|
||||
if str(getattr(user, "id", "")) == target:
|
||||
email = getattr(user, "email", None)
|
||||
if isinstance(email, str) and email.strip():
|
||||
return email.strip()
|
||||
phone = getattr(user, "phone", None)
|
||||
if isinstance(phone, str) and phone.strip():
|
||||
return phone.strip()
|
||||
return None
|
||||
|
||||
@@ -9,7 +9,7 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.tools.utils.auth_helpers import (
|
||||
find_auth_email_by_user_id,
|
||||
find_auth_phone_by_user_id,
|
||||
list_auth_users,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
@@ -125,7 +125,7 @@ def parse_iso_datetime(value: str | None) -> datetime | None:
|
||||
return parsed.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def resolve_share_target_email_map(invitee_user_ids: list[str]) -> dict[str, str]:
|
||||
def resolve_share_target_phone_map(invitee_user_ids: list[str]) -> dict[str, str]:
|
||||
users = list_auth_users()
|
||||
resolved: dict[str, str] = {}
|
||||
for raw_user_id in invitee_user_ids:
|
||||
@@ -138,7 +138,7 @@ def resolve_share_target_email_map(invitee_user_ids: list[str]) -> dict[str, str
|
||||
user_uuid = UUID(normalized_user_id)
|
||||
except ValueError:
|
||||
continue
|
||||
email = find_auth_email_by_user_id(users=users, user_id=user_uuid)
|
||||
if email:
|
||||
resolved[str(user_uuid)] = email.lower()
|
||||
phone = find_auth_phone_by_user_id(users=users, user_id=user_uuid)
|
||||
if phone:
|
||||
resolved[str(user_uuid)] = phone
|
||||
return resolved
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
DispatchResult,
|
||||
SqlAlchemyAutomationSchedulerRepository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutomationSchedulerService",
|
||||
"DispatchResult",
|
||||
"SqlAlchemyAutomationSchedulerRepository",
|
||||
]
|
||||
@@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob, ScheduleType
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
||||
|
||||
logger = get_logger("core.automation.scheduler")
|
||||
|
||||
|
||||
class QueueLike(Protocol):
|
||||
async def enqueue(
|
||||
self,
|
||||
*,
|
||||
command: dict[str, object],
|
||||
dedup_key: str | None,
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class AutomationSchedulerRepositoryLike(Protocol):
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[DueAutomationJob]: ...
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig: ...
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID: ...
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DispatchResult:
|
||||
scanned: int
|
||||
dispatched: int
|
||||
|
||||
|
||||
class AutomationSchedulerService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: AutomationSchedulerRepositoryLike,
|
||||
queue: QueueLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._queue = queue
|
||||
|
||||
async def scan_and_dispatch(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> DispatchResult:
|
||||
safe_limit = max(int(limit), 1)
|
||||
due_jobs = await self._repository.list_due_jobs(
|
||||
now_utc=now_utc, limit=safe_limit
|
||||
)
|
||||
dispatched = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
config = await self._repository.get_job_config(job_id=job.id)
|
||||
thread_id = await self._repository.ensure_latest_chat_session(
|
||||
owner_id=job.owner_id
|
||||
)
|
||||
command = self._build_dispatch_command(
|
||||
job=job,
|
||||
thread_id=thread_id,
|
||||
input_text=config.input_template,
|
||||
now_utc=now_utc,
|
||||
)
|
||||
await self._queue.enqueue(command=command, dedup_key=None)
|
||||
await self._repository.mark_job_dispatched(
|
||||
job_id=job.id,
|
||||
next_run_at=_compute_next_run_at(
|
||||
current_next_run_at=job.next_run_at,
|
||||
now_utc=now_utc,
|
||||
schedule_type=job.schedule_type,
|
||||
),
|
||||
last_run_at=now_utc,
|
||||
)
|
||||
await self._repository.commit()
|
||||
dispatched += 1
|
||||
except Exception as exc:
|
||||
await self._repository.rollback()
|
||||
logger.exception(
|
||||
"automation job dispatch failed",
|
||||
job_id=str(job.id),
|
||||
owner_id=str(job.owner_id),
|
||||
error=str(exc),
|
||||
)
|
||||
return DispatchResult(scanned=len(due_jobs), dispatched=dispatched)
|
||||
|
||||
def _build_dispatch_command(
|
||||
self,
|
||||
*,
|
||||
job: DueAutomationJob,
|
||||
thread_id: UUID,
|
||||
input_text: str,
|
||||
now_utc: datetime,
|
||||
) -> dict[str, object]:
|
||||
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
|
||||
payload = SchedulerDispatchCommand(
|
||||
owner_id=job.owner_id,
|
||||
automation_job_id=job.id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
input_text=input_text.strip(),
|
||||
)
|
||||
return {
|
||||
"command": "run",
|
||||
"owner_id": str(payload.owner_id),
|
||||
"automation_job_id": str(payload.automation_job_id),
|
||||
"queue": "bulk",
|
||||
"run_input": {
|
||||
"threadId": str(payload.thread_id),
|
||||
"runId": payload.run_id,
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"role": "user",
|
||||
"content": payload.input_text,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"agent_type": "memory",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SqlAlchemyAutomationSchedulerRepository:
|
||||
def __init__(self, *, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[DueAutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.status == "active")
|
||||
.where(AutomationJob.next_run_at <= now_utc)
|
||||
.order_by(AutomationJob.next_run_at.asc())
|
||||
.limit(max(limit, 1))
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return [
|
||||
DueAutomationJob(
|
||||
id=row.id,
|
||||
owner_id=row.owner_id,
|
||||
schedule_type=row.schedule_type,
|
||||
timezone=row.timezone,
|
||||
next_run_at=row.next_run_at,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_job_config(self, *, job_id: UUID) -> AutomationJobConfig:
|
||||
stmt = select(AutomationJob.config).where(AutomationJob.id == job_id)
|
||||
config_payload = (await self._session.execute(stmt)).scalar_one()
|
||||
return AutomationJobConfig.model_validate(config_payload or {})
|
||||
|
||||
async def ensure_latest_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == owner_id)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.where(AgentChatSession.session_type == SessionType.CHAT)
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
)
|
||||
self._session.add(session)
|
||||
await self._session.flush()
|
||||
return session.id
|
||||
|
||||
async def mark_job_dispatched(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None:
|
||||
stmt = select(AutomationJob).where(AutomationJob.id == job_id)
|
||||
row = (await self._session.execute(stmt)).scalar_one()
|
||||
row.next_run_at = next_run_at
|
||||
row.last_run_at = last_run_at
|
||||
await self._session.flush()
|
||||
|
||||
async def commit(self) -> None:
|
||||
await self._session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
await self._session.rollback()
|
||||
|
||||
|
||||
def _compute_next_run_at(
|
||||
*,
|
||||
current_next_run_at: datetime,
|
||||
now_utc: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> datetime:
|
||||
delta = timedelta(days=1 if schedule_type == ScheduleType.DAILY else 7)
|
||||
next_run_at = current_next_run_at
|
||||
while next_run_at <= now_utc:
|
||||
next_run_at = next_run_at + delta
|
||||
return next_run_at
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
@@ -1,27 +1,30 @@
|
||||
agents:
|
||||
- agent_type: worker
|
||||
llm_model_code: qwen3.5-35b-a3b
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
context_messages:
|
||||
mode: number
|
||||
count: 20
|
||||
enabled_tool_groups:
|
||||
- read
|
||||
- write
|
||||
- agent_type: router
|
||||
llm_model_code: qwen3.5-flash
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
visibility_consumer_bit: 16
|
||||
context_messages:
|
||||
mode: day
|
||||
count: 2
|
||||
enabled_tools: []
|
||||
|
||||
- agent_type: memory
|
||||
llm_model_code: qwen3.5-flash
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
context_messages:
|
||||
mode: day
|
||||
count: 2
|
||||
enabled_tool_groups:
|
||||
- read
|
||||
- agent_type: worker
|
||||
llm_model_code: qwen3.5-flash
|
||||
status: active
|
||||
config:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
visibility_consumer_bit: 17
|
||||
context_messages:
|
||||
mode: number
|
||||
count: 20
|
||||
enabled_tools:
|
||||
- calendar.read
|
||||
- calendar.write
|
||||
- calendar.share
|
||||
- user.lookup
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
JSON,
|
||||
Enum as SqlEnum,
|
||||
ForeignKey,
|
||||
@@ -59,6 +60,11 @@ class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
|
||||
output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
cost: Mapped[Decimal] = mapped_column(Numeric(12, 6), nullable=False, default=0)
|
||||
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
visibility_mask: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=0,
|
||||
)
|
||||
metadata_json: Mapped[dict[str, object] | None] = mapped_column(
|
||||
"metadata", JSON().with_variant(JSONB, "postgresql"), nullable=True
|
||||
)
|
||||
|
||||
@@ -4,8 +4,8 @@ import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import DateTime, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy import DateTime, JSON, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
@@ -36,9 +36,10 @@ class AutomationJob(TimestampMixin, SoftDeleteMixin, Base):
|
||||
String(255),
|
||||
nullable=False,
|
||||
)
|
||||
prompt: Mapped[str] = mapped_column(
|
||||
Text,
|
||||
config: Mapped[dict[str, object]] = mapped_column(
|
||||
JSON().with_variant(JSONB, "postgresql"),
|
||||
nullable=False,
|
||||
default=dict,
|
||||
)
|
||||
schedule_type: Mapped[ScheduleType] = mapped_column(
|
||||
String(20),
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class AgentConsumerBinding(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
bit: int = Field(..., ge=16, le=63)
|
||||
|
||||
@field_validator("agent_type")
|
||||
@classmethod
|
||||
def _normalize_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("agent_type must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class ConsumerRegistry(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
bindings: list[AgentConsumerBinding] = Field(default_factory=list)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_unique_bindings(self) -> "ConsumerRegistry":
|
||||
by_agent: set[str] = set()
|
||||
by_bit: set[int] = set()
|
||||
for item in self.bindings:
|
||||
if item.agent_type in by_agent:
|
||||
raise ValueError(f"duplicate agent_type binding: {item.agent_type}")
|
||||
if item.bit in by_bit:
|
||||
raise ValueError(f"duplicate visibility bit binding: {item.bit}")
|
||||
by_agent.add(item.agent_type)
|
||||
by_bit.add(item.bit)
|
||||
return self
|
||||
|
||||
def resolve_agent_bit(self, *, agent_type: str) -> int:
|
||||
target = agent_type.strip().lower()
|
||||
for item in self.bindings:
|
||||
if item.agent_type == target:
|
||||
return item.bit
|
||||
raise ValueError(f"agent visibility bit not configured: {target}")
|
||||
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class ExecutorKind(str, Enum):
|
||||
SINGLE_SHOT = "single_shot"
|
||||
REACT = "react"
|
||||
|
||||
|
||||
class ContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class ContextPolicy(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
consumer_agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
window_mode: ContextWindowMode = ContextWindowMode.NUMBER
|
||||
count: int = Field(default=20, ge=1, le=200)
|
||||
|
||||
@field_validator("consumer_agent_type")
|
||||
@classmethod
|
||||
def _normalize_consumer_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("consumer_agent_type must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class StageSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
stage_name: str = Field(..., min_length=1, max_length=64)
|
||||
executor_kind: ExecutorKind
|
||||
default_visibility_mask: int = Field(..., ge=0, le=(1 << 63) - 1)
|
||||
context_policy: ContextPolicy
|
||||
|
||||
@field_validator("stage_name")
|
||||
@classmethod
|
||||
def _normalize_stage_name(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("stage_name must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class PipelineSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
mode: str = Field(..., min_length=1, max_length=64)
|
||||
stages: list[StageSpec] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("mode")
|
||||
@classmethod
|
||||
def _normalize_mode(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("mode must not be empty")
|
||||
return normalized
|
||||
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class SystemVisibilityBit(IntEnum):
|
||||
UI_HISTORY = 0
|
||||
UI_REALTIME = 1
|
||||
|
||||
|
||||
class VisibilityMask(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
value: int = Field(..., ge=0, le=(1 << 63) - 1)
|
||||
|
||||
@classmethod
|
||||
def from_bits(cls, *, bits: list[int]) -> "VisibilityMask":
|
||||
mask = 0
|
||||
for bit in bits:
|
||||
validate_visibility_bit(bit=bit)
|
||||
mask |= 1 << bit
|
||||
return cls(value=mask)
|
||||
|
||||
def contains(self, *, bit: int) -> bool:
|
||||
validate_visibility_bit(bit=bit)
|
||||
return bool(self.value & (1 << bit))
|
||||
|
||||
|
||||
class VisibilityBitRef(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
bit: int = Field(..., ge=0, le=63)
|
||||
|
||||
@field_validator("bit")
|
||||
@classmethod
|
||||
def _validate_bit(cls, value: int) -> int:
|
||||
validate_visibility_bit(bit=value)
|
||||
return value
|
||||
|
||||
|
||||
def validate_visibility_bit(*, bit: int) -> None:
|
||||
if bit < 0 or bit > 63:
|
||||
raise ValueError("visibility bit must be in range [0, 63]")
|
||||
|
||||
|
||||
def bit_mask(*, bit: int) -> int:
|
||||
validate_visibility_bit(bit=bit)
|
||||
return 1 << bit
|
||||
@@ -0,0 +1,20 @@
|
||||
from schemas.automation.config import (
|
||||
AutomationAgentType,
|
||||
AutomationContextSource,
|
||||
AutomationContextWindowMode,
|
||||
AutomationJobConfig,
|
||||
AutomationMemoryContextConfig,
|
||||
default_memory_job_config,
|
||||
)
|
||||
from schemas.automation.scheduler import DueAutomationJob, SchedulerDispatchCommand
|
||||
|
||||
__all__ = [
|
||||
"AutomationAgentType",
|
||||
"AutomationContextSource",
|
||||
"AutomationContextWindowMode",
|
||||
"AutomationJobConfig",
|
||||
"AutomationMemoryContextConfig",
|
||||
"default_memory_job_config",
|
||||
"DueAutomationJob",
|
||||
"SchedulerDispatchCommand",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
|
||||
|
||||
class AutomationAgentType(str, Enum):
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class AutomationContextSource(str, Enum):
|
||||
LATEST_CHAT = "latest_chat"
|
||||
|
||||
|
||||
class AutomationContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class AutomationMemoryContextConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
source: AutomationContextSource = AutomationContextSource.LATEST_CHAT
|
||||
window_mode: AutomationContextWindowMode = AutomationContextWindowMode.DAY
|
||||
window_count: int = Field(default=2, ge=1, le=200)
|
||||
|
||||
|
||||
class AutomationJobConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: AutomationAgentType = AutomationAgentType.MEMORY
|
||||
model_code: str = Field(default="qwen3.5-flash", min_length=1, max_length=64)
|
||||
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
|
||||
input_template: str = Field(..., min_length=1, max_length=4000)
|
||||
context: AutomationMemoryContextConfig = Field(
|
||||
default_factory=AutomationMemoryContextConfig
|
||||
)
|
||||
|
||||
@field_validator("model_code")
|
||||
@classmethod
|
||||
def _validate_model_code(cls, value: str) -> str:
|
||||
normalized = value.strip()
|
||||
if normalized != "qwen3.5-flash":
|
||||
raise ValueError("model_code must be qwen3.5-flash")
|
||||
return normalized
|
||||
|
||||
|
||||
def default_memory_job_config() -> AutomationJobConfig:
|
||||
return AutomationJobConfig(
|
||||
agent_type=AutomationAgentType.MEMORY,
|
||||
model_code="qwen3.5-flash",
|
||||
enabled_tools=[AgentTool.CALENDAR_READ, AgentTool.USER_LOOKUP],
|
||||
input_template="请基于最近聊天上下文生成一段可执行的记忆总结与建议。",
|
||||
context=AutomationMemoryContextConfig(
|
||||
source=AutomationContextSource.LATEST_CHAT,
|
||||
window_mode=AutomationContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from models.automation_jobs import ScheduleType
|
||||
|
||||
|
||||
class DueAutomationJob(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
schedule_type: ScheduleType
|
||||
timezone: str = Field(..., min_length=1, max_length=50)
|
||||
next_run_at: datetime
|
||||
|
||||
|
||||
class SchedulerDispatchCommand(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
owner_id: UUID
|
||||
automation_job_id: UUID
|
||||
thread_id: UUID
|
||||
run_id: str = Field(..., min_length=1, max_length=128)
|
||||
input_text: str = Field(..., min_length=1, max_length=4000)
|
||||
Reference in New Issue
Block a user