refactor(agent): restructure visibility masks, task queues, and memory service
Visibility mask refactoring: - Replace dead UI_REALTIME bit with CONTEXT_ASSEMBLY (bit 1) - Remove visibility_consumer_bit from SystemAgentLLMConfig and system_agents.yaml - Simplify _resolve_user_message_visibility_mask: chat->UI_HISTORY|CONTEXT_ASSEMBLY, automation->0 - Simplify _resolve_stage_visibility_mask: memory->UI_HISTORY, router/worker->UI_HISTORY|CONTEXT_ASSEMBLY - Remove stage_visibility_bit_map from store.py Task queue renaming: - Replace default_broker/bulk_broker/critical_broker with worker_agent_broker/worker_automation_broker - Queue names: 'default'/'bulk'/'critical' -> 'agent'/'automation' - Rename run_command_task -> run_command_task_agent/run_command_task_automation - AgentService derives queue from runtime_mode: chat->agent, automation->automation Architecture cleanup: - Move context_service.py from runtime/ to agentscope/services/ - Add MemoryService in v1/memory/ following repository/service pattern - Move consumer_registry.py and pipeline_spec.py from schemas/agent to agentscope/schemas/ - Delete dead code: registry_builder.py, VisibilityBitRef - Delete superseded plan docs
This commit is contained in:
@@ -8,12 +8,10 @@ from core.agentscope.events.persistence import MessageRepository, SessionReposit
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.agent.runtime_models import AgentOutput, ToolAgentOutput
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
class EventStore(Protocol):
|
||||
@@ -48,9 +46,6 @@ class SqlAlchemyEventStore:
|
||||
async with self._session_factory() as session:
|
||||
session_repo = SessionRepository(session)
|
||||
message_repo = MessageRepository(session)
|
||||
stage_visibility_bit_map = await self._load_stage_visibility_bit_map(
|
||||
session=session
|
||||
)
|
||||
chat_session = await session_repo.get_session(session_id=session_id)
|
||||
if chat_session is None:
|
||||
return
|
||||
@@ -83,7 +78,6 @@ class SqlAlchemyEventStore:
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
await self._persist_tool_call_result(
|
||||
@@ -92,7 +86,6 @@ class SqlAlchemyEventStore:
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
@@ -105,7 +98,6 @@ class SqlAlchemyEventStore:
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> None:
|
||||
message_id_raw = self._event_value(event, "messageId")
|
||||
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
|
||||
@@ -146,17 +138,7 @@ class SqlAlchemyEventStore:
|
||||
|
||||
try:
|
||||
worker_output = AgentOutput.model_validate(worker_output_payload)
|
||||
raw_agent_type = self._event_value(event, "stage")
|
||||
normalized_agent_type = (
|
||||
str(raw_agent_type).strip().lower()
|
||||
if isinstance(raw_agent_type, str)
|
||||
else AgentType.WORKER.value
|
||||
)
|
||||
agent_type = (
|
||||
AgentType.MEMORY
|
||||
if normalized_agent_type == AgentType.MEMORY.value
|
||||
else AgentType.WORKER
|
||||
)
|
||||
agent_type = AgentType.WORKER
|
||||
metadata_model = AgentChatMessageMetadata(
|
||||
run_id=run_id_value,
|
||||
agent_type=agent_type,
|
||||
@@ -199,7 +181,6 @@ class SqlAlchemyEventStore:
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -226,7 +207,6 @@ class SqlAlchemyEventStore:
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> None:
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else None
|
||||
@@ -272,7 +252,6 @@ class SqlAlchemyEventStore:
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
stage_visibility_bit_map=stage_visibility_bit_map,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -301,39 +280,16 @@ class SqlAlchemyEventStore:
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
stage_visibility_bit_map: dict[str, int],
|
||||
) -> int:
|
||||
base = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
raw_stage = self._event_value(event, "stage")
|
||||
if not isinstance(raw_stage, str):
|
||||
return base
|
||||
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
normalized_stage = raw_stage.strip().lower()
|
||||
bit = stage_visibility_bit_map.get(normalized_stage)
|
||||
if bit is None and normalized_stage == AgentType.MEMORY.value:
|
||||
bit = 18
|
||||
if bit is None:
|
||||
return base
|
||||
return base | bit_mask(bit=bit)
|
||||
|
||||
async def _load_stage_visibility_bit_map(
|
||||
self,
|
||||
*,
|
||||
session: Any,
|
||||
) -> dict[str, int]:
|
||||
stmt = select(SystemAgents.agent_type, SystemAgents.config).where(
|
||||
SystemAgents.agent_type.in_(
|
||||
[AgentType.ROUTER.value, AgentType.WORKER.value, AgentType.MEMORY.value]
|
||||
)
|
||||
if normalized_stage == "memory":
|
||||
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) | bit_mask(
|
||||
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
|
||||
)
|
||||
rows = (await session.execute(stmt)).all()
|
||||
bit_map: dict[str, int] = {}
|
||||
for agent_type, raw_config in rows:
|
||||
if not isinstance(agent_type, str):
|
||||
continue
|
||||
config_payload = raw_config if isinstance(raw_config, dict) else {}
|
||||
llm_config = SystemAgentLLMConfig.model_validate(config_payload)
|
||||
bit_map[agent_type.strip().lower()] = llm_config.visibility_consumer_bit
|
||||
return bit_map
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.pipeline_spec import (
|
||||
ContextPolicy,
|
||||
ContextWindowMode,
|
||||
ExecutorKind,
|
||||
PipelineSpec,
|
||||
StageSpec,
|
||||
)
|
||||
from core.agentscope.schemas.pipeline_spec import ExecutorKind, PipelineSpec, StageSpec
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
|
||||
@@ -17,23 +12,13 @@ def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="router",
|
||||
agent_type=AgentType.ROUTER,
|
||||
executor_kind=ExecutorKind.SINGLE_SHOT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="router",
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
count=20,
|
||||
),
|
||||
),
|
||||
StageSpec(
|
||||
stage_name="worker",
|
||||
agent_type=AgentType.WORKER,
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="worker",
|
||||
window_mode=ContextWindowMode.NUMBER,
|
||||
count=20,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -44,13 +29,8 @@ def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
|
||||
stages=[
|
||||
StageSpec(
|
||||
stage_name="memory",
|
||||
agent_type=AgentType.MEMORY,
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=0,
|
||||
context_policy=ContextPolicy(
|
||||
consumer_agent_type="memory",
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
count=20,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
+4
-1
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
|
||||
from core.agentscope.schemas.consumer_registry import (
|
||||
AgentConsumerBinding,
|
||||
ConsumerRegistry,
|
||||
)
|
||||
|
||||
|
||||
def build_consumer_registry(
|
||||
@@ -88,10 +88,7 @@ class AgentScopeRunner:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
|
||||
stage_agent_types = [
|
||||
self._parse_agent_type(stage_name=stage.stage_name)
|
||||
for stage in pipeline_spec.stages
|
||||
]
|
||||
stage_agent_types = [stage.agent_type for stage in pipeline_spec.stages]
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
if stage_agent_types == [AgentType.ROUTER, AgentType.WORKER]:
|
||||
@@ -177,17 +174,6 @@ class AgentScopeRunner:
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_agent_type(*, stage_name: str) -> AgentType:
|
||||
normalized = stage_name.strip().lower()
|
||||
if normalized == AgentType.ROUTER.value:
|
||||
return AgentType.ROUTER
|
||||
if normalized == AgentType.WORKER.value:
|
||||
return AgentType.WORKER
|
||||
if normalized == AgentType.MEMORY.value:
|
||||
return AgentType.MEMORY
|
||||
raise ValueError(f"unsupported stage name: {stage_name}")
|
||||
|
||||
async def _load_stage_config(
|
||||
self,
|
||||
*,
|
||||
@@ -355,7 +341,6 @@ class AgentScopeRunner:
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
timeout_seconds=30,
|
||||
visibility_consumer_bit=18,
|
||||
context_messages=ContextMessagesConfig(
|
||||
mode=(
|
||||
ContextBuildStrategy.DAY
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.agentscope.events import (
|
||||
RedisStreamBus,
|
||||
SqlAlchemyEventStore,
|
||||
)
|
||||
from core.agentscope.runtime.context_service import AgentContextService
|
||||
from core.agentscope.services.context_service import AgentContextService
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
from core.agentscope.schemas.agui_input import parse_run_input
|
||||
@@ -20,8 +20,7 @@ from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from models.automation_jobs import AutomationJob
|
||||
from core.taskiq.app import worker_agent_broker, worker_automation_broker
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from schemas.messages.chat_message import (
|
||||
@@ -33,8 +32,10 @@ from schemas.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.memory.repository import MemoryRepository
|
||||
from v1.memory.service import MemoryService
|
||||
from v1.users.dependencies import get_user_service
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
_MAX_CONTEXT_ATTACHMENTS = 3
|
||||
@@ -188,29 +189,6 @@ async def _build_recent_context_messages(
|
||||
return converted
|
||||
|
||||
|
||||
async def _load_memory_job_config(
|
||||
*,
|
||||
session: Any,
|
||||
owner_id: UUID,
|
||||
automation_job_id: str,
|
||||
) -> AutomationJobConfig:
|
||||
try:
|
||||
job_uuid = UUID(automation_job_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError("automation_job_id is invalid") from exc
|
||||
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_uuid)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
raise ValueError("automation job not found")
|
||||
return AutomationJobConfig.model_validate(row.config or {})
|
||||
|
||||
|
||||
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
command_type = str(command.get("command", "run")).strip().lower()
|
||||
raw_owner_id = command.get("owner_id")
|
||||
@@ -245,10 +223,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
memory_job_config: AutomationJobConfig | None = None
|
||||
if system_agent_mode == "memory":
|
||||
assert isinstance(raw_automation_job_id, str)
|
||||
memory_job_config = await _load_memory_job_config(
|
||||
session=session,
|
||||
job_uuid = UUID(raw_automation_job_id)
|
||||
memory_service = MemoryService(MemoryRepository(session))
|
||||
memory_job_config = await memory_service.get_memory_job_config(
|
||||
job_id=job_uuid,
|
||||
owner_id=owner_id,
|
||||
automation_job_id=raw_automation_job_id,
|
||||
)
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
@@ -272,7 +251,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages = await _build_recent_context_messages(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
context_mode=pipeline_spec.stages[0].context_policy.consumer_agent_type,
|
||||
context_mode=pipeline_spec.stages[0].agent_type.value,
|
||||
memory_job_config=memory_job_config,
|
||||
)
|
||||
|
||||
@@ -296,16 +275,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
}
|
||||
|
||||
|
||||
@default_broker.task(task_name="tasks.agentscope.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
@worker_agent_broker.task(task_name="tasks.agentscope.run_command.agent")
|
||||
async def run_command_task_agent(command: dict[str, object]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
|
||||
|
||||
@critical_broker.task(task_name="tasks.agentscope.run_command.critical")
|
||||
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
|
||||
|
||||
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
|
||||
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
|
||||
@worker_automation_broker.task(task_name="tasks.agentscope.run_command.automation")
|
||||
async def run_command_task_automation(command: dict[str, object]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
|
||||
+3
-23
@@ -4,40 +4,20 @@ from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
class ExecutorKind(str, Enum):
|
||||
SINGLE_SHOT = "single_shot"
|
||||
REACT = "react"
|
||||
|
||||
|
||||
class ContextWindowMode(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class ContextPolicy(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
consumer_agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
window_mode: ContextWindowMode = ContextWindowMode.NUMBER
|
||||
count: int = Field(default=20, ge=1, le=200)
|
||||
|
||||
@field_validator("consumer_agent_type")
|
||||
@classmethod
|
||||
def _normalize_consumer_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("consumer_agent_type must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
class StageSpec(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
stage_name: str = Field(..., min_length=1, max_length=64)
|
||||
agent_type: AgentType
|
||||
executor_kind: ExecutorKind
|
||||
default_visibility_mask: int = Field(..., ge=0, le=(1 << 63) - 1)
|
||||
context_policy: ContextPolicy
|
||||
|
||||
@field_validator("stage_name")
|
||||
@classmethod
|
||||
+2
-2
@@ -5,7 +5,7 @@ from typing import Protocol
|
||||
|
||||
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.visibility import bit_mask
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
|
||||
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
|
||||
_DEFAULT_ROUTER_CONTEXT_DAY_COUNT = 20
|
||||
@@ -61,7 +61,7 @@ class AgentContextService:
|
||||
|
||||
normalized_config = self._normalize_system_agent_config(raw_llm_config)
|
||||
context_config = normalized_config.context_messages
|
||||
visibility_mask = bit_mask(bit=normalized_config.visibility_consumer_bit)
|
||||
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
|
||||
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
|
||||
return await context_loader(
|
||||
self,
|
||||
@@ -2,12 +2,12 @@ from __future__ import annotations
|
||||
|
||||
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker
|
||||
from core.taskiq.app import worker_automation_broker
|
||||
|
||||
logger = get_logger("core.automation.tasks")
|
||||
|
||||
|
||||
@bulk_broker.task(task_name="tasks.automation.scan_due_jobs")
|
||||
@worker_automation_broker.task(task_name="tasks.automation.scan_due_jobs")
|
||||
async def scan_due_automation_jobs_task(limit: int | None = None) -> dict[str, int]:
|
||||
from core.automation.scheduler import run_automation_scheduler_scan
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ agents:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
visibility_consumer_bit: 16
|
||||
context_messages:
|
||||
mode: day
|
||||
count: 2
|
||||
@@ -19,7 +18,6 @@ agents:
|
||||
temperature: 0.7
|
||||
max_tokens: null
|
||||
timeout_seconds: 30
|
||||
visibility_consumer_bit: 17
|
||||
context_messages:
|
||||
mode: number
|
||||
count: 20
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
|
||||
from core.taskiq.app import broker, worker_agent_broker, worker_automation_broker
|
||||
|
||||
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
|
||||
__all__ = ["broker", "worker_agent_broker", "worker_automation_broker"]
|
||||
|
||||
@@ -21,11 +21,9 @@ def _build_broker(queue_name: str) -> ListQueueBroker:
|
||||
)
|
||||
|
||||
|
||||
default_broker = _build_broker("default")
|
||||
critical_broker = _build_broker("critical")
|
||||
bulk_broker = _build_broker("bulk")
|
||||
worker_agent_broker = _build_broker("agent")
|
||||
worker_automation_broker = _build_broker("automation")
|
||||
|
||||
# Backward-compatible export name for existing imports/tests.
|
||||
broker = default_broker
|
||||
broker = worker_agent_broker
|
||||
|
||||
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
|
||||
__all__ = ["broker", "worker_agent_broker", "worker_automation_broker"]
|
||||
|
||||
@@ -1,17 +1,10 @@
|
||||
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
|
||||
from schemas.agent.forwarded_props import (
|
||||
ClientTimeContext,
|
||||
ForwardedPropsPayload,
|
||||
parse_forwarded_props_agent_type,
|
||||
parse_forwarded_props_client_time,
|
||||
parse_forwarded_props_runtime_mode,
|
||||
)
|
||||
from schemas.agent.pipeline_spec import (
|
||||
ContextPolicy,
|
||||
ContextWindowMode,
|
||||
ExecutorKind,
|
||||
PipelineSpec,
|
||||
StageSpec,
|
||||
)
|
||||
from schemas.agent.forwarded_props import RuntimeMode
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
ConstraintItem,
|
||||
@@ -45,28 +38,22 @@ from schemas.agent.ui_hints import (
|
||||
__all__ = [
|
||||
"AgentType",
|
||||
"AgentOutput",
|
||||
"AgentConsumerBinding",
|
||||
"ConstraintItem",
|
||||
"ConsumerRegistry",
|
||||
"ContextPolicy",
|
||||
"ContextWindowMode",
|
||||
"ExecutionMode",
|
||||
"ExecutorKind",
|
||||
"ForwardedPropsPayload",
|
||||
"KeyEntity",
|
||||
"NormalizedTaskInput",
|
||||
"PipelineSpec",
|
||||
"ResultTyping",
|
||||
"ClientTimeContext",
|
||||
"ResultType",
|
||||
"RouterAgentOutput",
|
||||
"RouterUiDecision",
|
||||
"RunStatus",
|
||||
"RuntimeMode",
|
||||
"TaskType",
|
||||
"TaskTyping",
|
||||
"SystemAgentLLMConfig",
|
||||
"SystemVisibilityBit",
|
||||
"StageSpec",
|
||||
"ToolAgentOutput",
|
||||
"ToolStatus",
|
||||
"UiMode",
|
||||
@@ -79,7 +66,7 @@ __all__ = [
|
||||
"WorkerAgentOutputLite",
|
||||
"WorkerAgentOutputRich",
|
||||
"bit_mask",
|
||||
"parse_forwarded_props_agent_type",
|
||||
"parse_forwarded_props_client_time",
|
||||
"parse_forwarded_props_runtime_mode",
|
||||
"resolve_worker_output_model",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import re
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
@@ -59,20 +60,17 @@ class ClientTimeContext(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class RuntimeMode(str, Enum):
|
||||
CHAT = "chat"
|
||||
AUTOMATION = "automation"
|
||||
|
||||
|
||||
class ForwardedPropsPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
agent_type: str = Field(..., min_length=1, max_length=64)
|
||||
runtime_mode: RuntimeMode
|
||||
client_time: ClientTimeContext | None = None
|
||||
|
||||
@field_validator("agent_type")
|
||||
@classmethod
|
||||
def validate_agent_type(cls, value: str) -> str:
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("invalid forwarded_props.agent_type")
|
||||
return normalized
|
||||
|
||||
|
||||
def parse_forwarded_props(forwarded_props: object) -> ForwardedPropsPayload:
|
||||
if not isinstance(forwarded_props, dict):
|
||||
@@ -90,6 +88,6 @@ def parse_forwarded_props_client_time(
|
||||
return payload.client_time
|
||||
|
||||
|
||||
def parse_forwarded_props_agent_type(forwarded_props: object) -> str:
|
||||
def parse_forwarded_props_runtime_mode(forwarded_props: object) -> RuntimeMode:
|
||||
payload = parse_forwarded_props(forwarded_props)
|
||||
return payload.agent_type
|
||||
return payload.runtime_mode
|
||||
|
||||
@@ -2,15 +2,13 @@ from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool, parse_agent_tool
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
ROUTER = "router"
|
||||
WORKER = "worker"
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class ContextBuildStrategy(str, Enum):
|
||||
@@ -30,7 +28,6 @@ class SystemAgentLLMConfig(BaseModel):
|
||||
context_messages: ContextMessagesConfig = Field(
|
||||
default_factory=ContextMessagesConfig
|
||||
)
|
||||
visibility_consumer_bit: int = Field(default=16, ge=16, le=63)
|
||||
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
|
||||
|
||||
@field_validator("enabled_tools", mode="before")
|
||||
@@ -42,10 +39,13 @@ class SystemAgentLLMConfig(BaseModel):
|
||||
raise ValueError("enabled_tools must be a list")
|
||||
normalized: list[AgentTool] = []
|
||||
for item in value:
|
||||
raw_item = str(item or "").strip()
|
||||
if not raw_item:
|
||||
continue
|
||||
tool = parse_agent_tool(raw_item)
|
||||
if isinstance(item, AgentTool):
|
||||
tool = item
|
||||
else:
|
||||
raw_item = str(item or "").strip()
|
||||
if not raw_item:
|
||||
continue
|
||||
tool = parse_agent_tool(raw_item)
|
||||
if tool not in normalized:
|
||||
normalized.append(tool)
|
||||
return normalized
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
class SystemVisibilityBit(IntEnum):
|
||||
UI_HISTORY = 0
|
||||
UI_REALTIME = 1
|
||||
CONTEXT_ASSEMBLY = 1
|
||||
|
||||
|
||||
class VisibilityMask(BaseModel):
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import dashscope
|
||||
from dashscope.audio.asr import Recognition, RecognitionCallback
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsrService:
|
||||
def __init__(self) -> None:
|
||||
self._api_key: str | None = None
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self._api_key is None:
|
||||
dashscope_key = config.llm.provider_keys.get("dashscope")
|
||||
if not dashscope_key:
|
||||
raise ValueError(
|
||||
"DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
|
||||
)
|
||||
self._api_key = dashscope_key
|
||||
return self._api_key
|
||||
|
||||
async def transcribe_file(self, file_path: str, filename: str) -> str:
|
||||
try:
|
||||
dashscope.api_key = self._get_api_key()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
class SyncCallback(RecognitionCallback):
|
||||
error: str | None = None
|
||||
|
||||
def on_error(self, result: Any) -> None:
|
||||
self.error = str(result)
|
||||
|
||||
callback = SyncCallback()
|
||||
recognizer = Recognition(
|
||||
model="fun-asr-realtime-2026-02-28",
|
||||
callback=callback,
|
||||
format="wav",
|
||||
sample_rate=16000,
|
||||
)
|
||||
|
||||
result: Any = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: recognizer.call(file=file_path),
|
||||
)
|
||||
|
||||
if callback.error:
|
||||
raise RuntimeError(f"ASR error: {callback.error}")
|
||||
status_code = self._extract_field(result, "status_code")
|
||||
if status_code != 200:
|
||||
message = self._extract_field(result, "message")
|
||||
raise RuntimeError(f"ASR transcription failed: {message}")
|
||||
|
||||
sentence = self._extract_sentence_payload(result)
|
||||
if sentence is None:
|
||||
request_id = self._extract_field(result, "request_id")
|
||||
logger.warning(
|
||||
"ASR returned empty result", extra={"request_id": request_id}
|
||||
)
|
||||
return ""
|
||||
|
||||
if isinstance(sentence, dict):
|
||||
transcription = sentence.get("text", "")
|
||||
elif isinstance(sentence, list):
|
||||
transcription = " ".join(
|
||||
item.get("text", "") for item in sentence if isinstance(item, dict)
|
||||
)
|
||||
else:
|
||||
transcription = str(sentence) if sentence else ""
|
||||
|
||||
logger.info(
|
||||
"ASR transcription completed",
|
||||
extra={"filename": filename, "transcript_length": len(transcription)},
|
||||
)
|
||||
return transcription
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("ASR transcription error")
|
||||
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
|
||||
|
||||
def _extract_sentence_payload(self, result: Any) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
output = result.get("output")
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
if output is not None:
|
||||
return getattr(output, "sentence", None)
|
||||
return result.get("sentence")
|
||||
|
||||
get_sentence = getattr(result, "get_sentence", None)
|
||||
if callable(get_sentence):
|
||||
sentence = get_sentence()
|
||||
if sentence is not None:
|
||||
return sentence
|
||||
|
||||
output = getattr(result, "output", None)
|
||||
if output is None:
|
||||
return None
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
return getattr(output, "sentence", None)
|
||||
|
||||
def _extract_field(self, result: Any, field: str) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return getattr(result, field, None)
|
||||
|
||||
|
||||
asr_service = AsrService()
|
||||
@@ -44,17 +44,14 @@ class TaskiqQueueClient:
|
||||
@staticmethod
|
||||
def _select_queue_task(command: dict[str, object]) -> Any:
|
||||
from core.agentscope.runtime.tasks import (
|
||||
run_command_task,
|
||||
run_command_task_bulk,
|
||||
run_command_task_critical,
|
||||
run_command_task_agent,
|
||||
run_command_task_automation,
|
||||
)
|
||||
|
||||
queue = str(command.get("queue", "default")).strip().lower()
|
||||
if queue == "critical":
|
||||
return run_command_task_critical
|
||||
if queue == "bulk":
|
||||
return run_command_task_bulk
|
||||
return run_command_task
|
||||
queue = str(command.get("queue", "agent")).strip().lower()
|
||||
if queue == "automation":
|
||||
return run_command_task_automation
|
||||
return run_command_task_agent
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import date
|
||||
from typing import Annotated, Union
|
||||
from typing import Annotated
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.agentscope.events import to_sse_event
|
||||
@@ -28,7 +28,7 @@ from fastapi import (
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import (
|
||||
@@ -39,7 +39,8 @@ from v1.agent.schemas import (
|
||||
HistorySnapshotResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
from v1.agent.service import AgentService, asr_service
|
||||
from v1.agent.asr import asr_service
|
||||
from v1.agent.service import AgentService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
@@ -73,15 +74,13 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
|
||||
count = await redis.incr(key)
|
||||
if count == 1:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
elif count > _MAX_SSE_CONNECTIONS_PER_USER:
|
||||
await redis.decr(key)
|
||||
return False
|
||||
else:
|
||||
ttl = await redis.ttl(key)
|
||||
if int(ttl) < 0:
|
||||
if ttl < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
|
||||
after_decr = await redis.decr(key)
|
||||
if int(after_decr) <= 0:
|
||||
await redis.delete(key)
|
||||
return False
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
@@ -97,13 +96,18 @@ async def _release_sse_slot(*, user_id: str) -> None:
|
||||
redis = await get_or_init_redis_client()
|
||||
key = f"agent:sse-active:{user_id}"
|
||||
count = await redis.decr(key)
|
||||
if int(count) <= 0:
|
||||
if count <= 0:
|
||||
await redis.delete(key)
|
||||
return None
|
||||
ttl = await redis.ttl(key)
|
||||
if int(ttl) < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
except Exception: # noqa: BLE001
|
||||
else:
|
||||
ttl = await redis.ttl(key)
|
||||
if ttl < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"SSE slot release failed",
|
||||
user_id=user_id,
|
||||
reason=str(exc),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -176,6 +180,11 @@ async def stream_events(
|
||||
last_event_id=cursor,
|
||||
current_user=current_user,
|
||||
)
|
||||
except TimeoutError:
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"SSE stream read failed",
|
||||
@@ -183,11 +192,6 @@ async def stream_events(
|
||||
user_id=str(current_user.id),
|
||||
reason=str(exc),
|
||||
)
|
||||
if "Timeout reading from" in str(exc):
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
break
|
||||
|
||||
if not rows:
|
||||
@@ -291,12 +295,12 @@ async def create_attachment_signed_url(
|
||||
async def transcribe(
|
||||
audio: UploadFile,
|
||||
request: Request,
|
||||
_current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> Union[AsrTranscribeResponse, JSONResponse]:
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> AsrTranscribeResponse:
|
||||
temp_path: str | None = None
|
||||
try:
|
||||
if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
|
||||
raise ValueError("Unsupported audio format")
|
||||
raise HTTPException(status_code=400, detail="Unsupported audio format")
|
||||
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length is not None:
|
||||
@@ -309,7 +313,7 @@ async def transcribe(
|
||||
and declared_length
|
||||
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
|
||||
):
|
||||
raise ValueError("Audio file too large")
|
||||
raise HTTPException(status_code=400, detail="Audio file too large")
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||
temp_path = tmp_file.name
|
||||
@@ -322,16 +326,16 @@ async def transcribe(
|
||||
break
|
||||
total_bytes += len(chunk)
|
||||
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
|
||||
raise ValueError("Audio file too large")
|
||||
raise HTTPException(status_code=400, detail="Audio file too large")
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
required = _WAV_HEADER_MIN_BYTES - len(header)
|
||||
header.extend(chunk[:required])
|
||||
tmp_file.write(chunk)
|
||||
|
||||
if total_bytes == 0:
|
||||
raise ValueError("Empty audio file")
|
||||
raise HTTPException(status_code=400, detail="Empty audio file")
|
||||
if not _looks_like_wav_header(bytes(header)):
|
||||
raise ValueError("Unsupported audio format")
|
||||
raise HTTPException(status_code=400, detail="Unsupported audio format")
|
||||
|
||||
transcript = await asr_service.transcribe_file(
|
||||
temp_path, audio.filename or "unknown"
|
||||
@@ -339,17 +343,14 @@ async def transcribe(
|
||||
|
||||
return AsrTranscribeResponse(transcript=transcript)
|
||||
|
||||
except ValueError as exc:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except RuntimeError:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
content={"detail": "ASR service unavailable"},
|
||||
)
|
||||
raise HTTPException(status_code=502, detail="ASR service unavailable")
|
||||
finally:
|
||||
await audio.close()
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.unlink(temp_path)
|
||||
if temp_path:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -1,12 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from schemas.agent.ui_schema import UiSchemaRenderer
|
||||
|
||||
|
||||
class AgentRepositoryLike(Protocol):
|
||||
async def get_session_owner(self, *, session_id: str) -> str: ...
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
async def get_history_day(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
|
||||
|
||||
async def persist_user_message(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: Any,
|
||||
visibility_mask: int,
|
||||
) -> None: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class EventStreamLike(Protocol):
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
|
||||
|
||||
class AttachmentStorageLike(Protocol):
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str: ...
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str: ...
|
||||
|
||||
def parse_signed_url(self, url: str) -> tuple[str, str]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
thread_id: str
|
||||
run_id: str
|
||||
created: bool
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
|
||||
+37
-288
@@ -1,15 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
import hashlib
|
||||
from typing import Any, Protocol
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import dashscope
|
||||
from ag_ui.core import RunAgentInput
|
||||
from dashscope.audio.asr import Recognition, RecognitionCallback
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
@@ -17,102 +13,32 @@ from core.auth.models import CurrentUser
|
||||
from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.forwarded_props import (
|
||||
parse_forwarded_props_runtime_mode,
|
||||
RuntimeMode,
|
||||
)
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachment,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from v1.agent.schemas import HistorySnapshotResponse
|
||||
from v1.agent.schemas import (
|
||||
AgentRepositoryLike,
|
||||
AttachmentStorageLike,
|
||||
EventStreamLike,
|
||||
HistorySnapshotResponse,
|
||||
QueueClientLike,
|
||||
TaskAccepted,
|
||||
)
|
||||
from v1.agent.utils import (
|
||||
MAX_ATTACHMENT_BYTES,
|
||||
MAX_ATTACHMENTS_PER_MESSAGE,
|
||||
is_safe_attachment_path,
|
||||
mime_to_suffix,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
|
||||
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
|
||||
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
|
||||
_MAX_ATTACHMENTS_PER_MESSAGE = 3
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
thread_id: str
|
||||
run_id: str
|
||||
created: bool
|
||||
|
||||
|
||||
class AgentRepositoryLike(Protocol):
|
||||
async def get_session_owner(self, *, session_id: str) -> str: ...
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
async def get_history_day(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
|
||||
|
||||
async def persist_user_message(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
visibility_mask: int,
|
||||
) -> None: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class EventStreamLike(Protocol):
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
|
||||
|
||||
class AttachmentStorageLike(Protocol):
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str: ...
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str: ...
|
||||
|
||||
def parse_signed_url(self, url: str) -> tuple[str, str]: ...
|
||||
|
||||
|
||||
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
|
||||
@@ -152,14 +78,9 @@ class AgentService:
|
||||
run_id = run_input.run_id
|
||||
forwarded_props = getattr(run_input, "forwarded_props", None)
|
||||
try:
|
||||
agent_type = parse_forwarded_props_agent_type(forwarded_props)
|
||||
runtime_mode = parse_forwarded_props_runtime_mode(forwarded_props)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
if agent_type == "memory":
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="memory mode is automation-only",
|
||||
)
|
||||
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
@@ -185,7 +106,7 @@ class AgentService:
|
||||
current_user=current_user,
|
||||
)
|
||||
visibility_mask = await self._resolve_user_message_visibility_mask(
|
||||
agent_type=agent_type
|
||||
runtime_mode=runtime_mode
|
||||
)
|
||||
await self._repository.persist_user_message(
|
||||
session_id=thread_id,
|
||||
@@ -195,6 +116,7 @@ class AgentService:
|
||||
)
|
||||
await self._repository.commit()
|
||||
|
||||
queue = "automation" if runtime_mode == RuntimeMode.AUTOMATION else "agent"
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
@@ -202,6 +124,7 @@ class AgentService:
|
||||
"run_input": run_input.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
"queue": queue,
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
@@ -212,60 +135,14 @@ class AgentService:
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def _resolve_user_message_visibility_mask(self, *, agent_type: str) -> int:
|
||||
normalized_agent_type = agent_type.strip().lower()
|
||||
history_bit_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
|
||||
if normalized_agent_type == "memory":
|
||||
return bit_mask(bit=18)
|
||||
|
||||
agent_config = await self._repository.get_system_agent_config(
|
||||
agent_type=normalized_agent_type
|
||||
)
|
||||
if agent_config is None:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="invalid forwarded_props.agent_type"
|
||||
async def _resolve_user_message_visibility_mask(
|
||||
self, *, runtime_mode: RuntimeMode
|
||||
) -> int:
|
||||
if runtime_mode == RuntimeMode.CHAT:
|
||||
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) | bit_mask(
|
||||
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
|
||||
)
|
||||
llm_config = SystemAgentLLMConfig.model_validate(
|
||||
(agent_config.get("config") if isinstance(agent_config, dict) else {}) or {}
|
||||
)
|
||||
agent_mask = bit_mask(bit=llm_config.visibility_consumer_bit)
|
||||
|
||||
if normalized_agent_type == "worker":
|
||||
router_config = await self._repository.get_system_agent_config(
|
||||
agent_type="router"
|
||||
)
|
||||
worker_config = await self._repository.get_system_agent_config(
|
||||
agent_type="worker"
|
||||
)
|
||||
if router_config is None or worker_config is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="system agent visibility config missing",
|
||||
)
|
||||
router_mask = bit_mask(
|
||||
bit=SystemAgentLLMConfig.model_validate(
|
||||
(
|
||||
router_config.get("config")
|
||||
if isinstance(router_config, dict)
|
||||
else {}
|
||||
)
|
||||
or {}
|
||||
).visibility_consumer_bit
|
||||
)
|
||||
worker_mask = bit_mask(
|
||||
bit=SystemAgentLLMConfig.model_validate(
|
||||
(
|
||||
worker_config.get("config")
|
||||
if isinstance(worker_config, dict)
|
||||
else {}
|
||||
)
|
||||
or {}
|
||||
).visibility_consumer_bit
|
||||
)
|
||||
return history_bit_mask | router_mask | worker_mask
|
||||
|
||||
return history_bit_mask | agent_mask
|
||||
return 0
|
||||
|
||||
async def _prepare_user_message(
|
||||
self,
|
||||
@@ -309,7 +186,7 @@ class AgentService:
|
||||
mime_type=mime_type,
|
||||
)
|
||||
)
|
||||
if len(user_attachments) > _MAX_ATTACHMENTS_PER_MESSAGE:
|
||||
if len(user_attachments) > MAX_ATTACHMENTS_PER_MESSAGE:
|
||||
raise HTTPException(status_code=422, detail="Too many attachments")
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -360,14 +237,14 @@ class AgentService:
|
||||
if not isinstance(content_type, str):
|
||||
raise HTTPException(status_code=422, detail="Unsupported attachment type")
|
||||
mime_type = content_type.lower()
|
||||
if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES:
|
||||
if mime_type not in {"image/png", "image/jpeg", "image/webp"}:
|
||||
raise HTTPException(status_code=422, detail="Unsupported attachment type")
|
||||
if not payload:
|
||||
raise HTTPException(status_code=422, detail="Empty attachment")
|
||||
if len(payload) > _MAX_ATTACHMENT_BYTES:
|
||||
if len(payload) > MAX_ATTACHMENT_BYTES:
|
||||
raise HTTPException(status_code=413, detail="Attachment too large")
|
||||
|
||||
suffix = _mime_to_suffix(mime_type)
|
||||
suffix = mime_to_suffix(mime_type)
|
||||
checksum = hashlib.sha1(payload).hexdigest()[:16]
|
||||
filename_seed = filename if isinstance(filename, str) and filename else "upload"
|
||||
filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
|
||||
@@ -424,7 +301,7 @@ class AgentService:
|
||||
|
||||
normalized_path = path.strip()
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/"
|
||||
if not _is_safe_attachment_path(
|
||||
if not is_safe_attachment_path(
|
||||
normalized_path, expected_prefix=expected_prefix
|
||||
):
|
||||
raise HTTPException(status_code=422, detail="Invalid attachment path scope")
|
||||
@@ -503,7 +380,7 @@ class AgentService:
|
||||
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
)
|
||||
for attachment in attachments:
|
||||
if not _is_safe_attachment_path(
|
||||
if not is_safe_attachment_path(
|
||||
attachment.path,
|
||||
expected_prefix=expected_prefix,
|
||||
):
|
||||
@@ -586,134 +463,6 @@ class AgentService:
|
||||
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_BUCKET")
|
||||
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
|
||||
if not is_safe_attachment_path(path, expected_prefix=expected_prefix):
|
||||
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_PATH_SCOPE")
|
||||
return bucket, path
|
||||
|
||||
|
||||
class AsrService:
|
||||
def __init__(self) -> None:
|
||||
self._api_key: str | None = None
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self._api_key is None:
|
||||
dashscope_key = config.llm.provider_keys.get("dashscope")
|
||||
if not dashscope_key:
|
||||
raise ValueError(
|
||||
"DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
|
||||
)
|
||||
self._api_key = dashscope_key
|
||||
return self._api_key
|
||||
|
||||
async def transcribe_file(self, file_path: str, filename: str) -> str:
|
||||
try:
|
||||
dashscope.api_key = self._get_api_key()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
class SyncCallback(RecognitionCallback):
|
||||
error: str | None = None
|
||||
|
||||
def on_error(self, result: Any) -> None:
|
||||
self.error = str(result)
|
||||
|
||||
callback = SyncCallback()
|
||||
recognizer = Recognition(
|
||||
model="fun-asr-realtime-2026-02-28",
|
||||
callback=callback,
|
||||
format="wav",
|
||||
sample_rate=16000,
|
||||
)
|
||||
|
||||
result: Any = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: recognizer.call(file=file_path),
|
||||
)
|
||||
|
||||
if callback.error:
|
||||
raise RuntimeError(f"ASR error: {callback.error}")
|
||||
status_code = self._extract_field(result, "status_code")
|
||||
if status_code != 200:
|
||||
message = self._extract_field(result, "message")
|
||||
raise RuntimeError(f"ASR transcription failed: {message}")
|
||||
|
||||
sentence = self._extract_sentence_payload(result)
|
||||
if sentence is None:
|
||||
request_id = self._extract_field(result, "request_id")
|
||||
logger.warning(
|
||||
"ASR returned empty result", extra={"request_id": request_id}
|
||||
)
|
||||
return ""
|
||||
|
||||
if isinstance(sentence, dict):
|
||||
transcription = sentence.get("text", "")
|
||||
elif isinstance(sentence, list):
|
||||
transcription = " ".join(
|
||||
item.get("text", "") for item in sentence if isinstance(item, dict)
|
||||
)
|
||||
else:
|
||||
transcription = str(sentence) if sentence else ""
|
||||
|
||||
logger.info(
|
||||
"ASR transcription completed",
|
||||
extra={"filename": filename, "transcript_length": len(transcription)},
|
||||
)
|
||||
return transcription
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("ASR transcription error")
|
||||
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
|
||||
|
||||
def _extract_sentence_payload(self, result: Any) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
output = result.get("output")
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
if output is not None:
|
||||
return getattr(output, "sentence", None)
|
||||
return result.get("sentence")
|
||||
|
||||
get_sentence = getattr(result, "get_sentence", None)
|
||||
if callable(get_sentence):
|
||||
sentence = get_sentence()
|
||||
if sentence is not None:
|
||||
return sentence
|
||||
|
||||
output = getattr(result, "output", None)
|
||||
if output is None:
|
||||
return None
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
return getattr(output, "sentence", None)
|
||||
|
||||
def _extract_field(self, result: Any, field: str) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return getattr(result, field, None)
|
||||
|
||||
|
||||
asr_service = AsrService()
|
||||
|
||||
|
||||
def _mime_to_suffix(mime_type: str) -> str:
|
||||
mapping = {
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpg",
|
||||
"image/webp": "webp",
|
||||
}
|
||||
return mapping.get(mime_type.lower(), "bin")
|
||||
|
||||
|
||||
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return normalized.startswith(expected_prefix)
|
||||
|
||||
@@ -14,6 +14,11 @@ from schemas.messages.chat_message import (
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
|
||||
ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
|
||||
MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
|
||||
MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
|
||||
MAX_ATTACHMENTS_PER_MESSAGE = 3
|
||||
|
||||
|
||||
def convert_message_to_history(
|
||||
message: AgentChatMessage,
|
||||
@@ -124,3 +129,23 @@ def _compile_worker_ui_hints(
|
||||
return compiled
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def mime_to_suffix(mime_type: str) -> str:
|
||||
mapping = {
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpg",
|
||||
"image/webp": "webp",
|
||||
}
|
||||
return mapping.get(mime_type.lower(), "bin")
|
||||
|
||||
|
||||
def is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return normalized.startswith(expected_prefix)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from v1.memory.service import MemoryService
|
||||
|
||||
__all__ = ["MemoryService"]
|
||||
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.automation_jobs import AutomationJob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class MemoryRepositoryLike(Protocol):
|
||||
async def get_job_by_id_and_owner(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJob | None: ...
|
||||
|
||||
|
||||
class MemoryRepository(BaseRepository[AutomationJob]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=AutomationJob)
|
||||
|
||||
async def get_job_by_id_and_owner(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJob | None:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from v1.memory.repository import MemoryRepositoryLike
|
||||
|
||||
|
||||
class MemoryService:
|
||||
_repository: MemoryRepositoryLike
|
||||
|
||||
def __init__(self, repository: MemoryRepositoryLike) -> None:
|
||||
self._repository = repository
|
||||
|
||||
async def get_memory_job_config(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJobConfig:
|
||||
job = await self._repository.get_job_by_id_and_owner(
|
||||
job_id=job_id, owner_id=owner_id
|
||||
)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Automation job not found")
|
||||
return AutomationJobConfig.model_validate(job.config or {})
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.runtime.consumer_registry import build_consumer_registry
|
||||
from core.agentscope.runtime.registry_builder import build_consumer_registry
|
||||
|
||||
|
||||
def test_build_consumer_registry_from_system_agent_configs() -> None:
|
||||
|
||||
@@ -45,17 +45,6 @@ def _user_context() -> UserContext:
|
||||
)
|
||||
|
||||
|
||||
def test_parse_agent_type_supports_known_stages() -> None:
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="router") == AgentType.ROUTER
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="worker") == AgentType.WORKER
|
||||
assert AgentScopeRunner._parse_agent_type(stage_name="memory") == AgentType.MEMORY
|
||||
|
||||
|
||||
def test_parse_agent_type_rejects_unknown_stage() -> None:
|
||||
with pytest.raises(ValueError, match="unsupported stage name"):
|
||||
AgentScopeRunner._parse_agent_type(stage_name="planner")
|
||||
|
||||
|
||||
def test_build_worker_input_messages_only_contains_router_contract() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
router_output = RouterAgentOutput(
|
||||
|
||||
@@ -5,14 +5,13 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
|
||||
from core.taskiq.app import broker, worker_agent_broker, worker_automation_broker
|
||||
|
||||
|
||||
def test_taskiq_broker_is_configured() -> None:
|
||||
assert broker is not None
|
||||
assert default_broker is broker
|
||||
assert critical_broker is not None
|
||||
assert bulk_broker is not None
|
||||
assert worker_agent_broker is broker
|
||||
assert worker_automation_broker is not None
|
||||
|
||||
|
||||
def test_taskiq_app_configures_logging_on_import(
|
||||
|
||||
@@ -2,13 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
|
||||
from schemas.agent.pipeline_spec import (
|
||||
ContextPolicy,
|
||||
ExecutorKind,
|
||||
PipelineSpec,
|
||||
StageSpec,
|
||||
from core.agentscope.schemas.consumer_registry import (
|
||||
AgentConsumerBinding,
|
||||
ConsumerRegistry,
|
||||
)
|
||||
from core.agentscope.schemas.pipeline_spec import ExecutorKind, PipelineSpec, StageSpec
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
def test_consumer_registry_rejects_duplicate_bits() -> None:
|
||||
@@ -29,9 +28,9 @@ def test_pipeline_spec_requires_non_empty_stages() -> None:
|
||||
def test_stage_spec_normalizes_stage_name() -> None:
|
||||
spec = StageSpec(
|
||||
stage_name=" Worker ",
|
||||
agent_type=AgentType.WORKER,
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
default_visibility_mask=1,
|
||||
context_policy=ContextPolicy(consumer_agent_type="worker", count=20),
|
||||
)
|
||||
|
||||
assert spec.stage_name == "worker"
|
||||
assert spec.agent_type == AgentType.WORKER
|
||||
|
||||
@@ -159,7 +159,7 @@ def _user() -> CurrentUser:
|
||||
)
|
||||
|
||||
|
||||
def _build_run_input(*, urls: list[str], agent_type: str = "worker") -> RunAgentInput:
|
||||
def _build_run_input(*, urls: list[str], runtime_mode: str = "chat") -> RunAgentInput:
|
||||
content: list[dict[str, str]] = [{"type": "text", "text": "hello"}]
|
||||
for url in urls:
|
||||
content.append({"type": "binary", "mimeType": "image/png", "url": url})
|
||||
@@ -177,7 +177,7 @@ def _build_run_input(*, urls: list[str], agent_type: str = "worker") -> RunAgent
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": agent_type},
|
||||
"forwardedProps": {"runtime_mode": runtime_mode},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -275,7 +275,7 @@ async def test_enqueue_run_rejects_unknown_agent_type(monkeypatch) -> None:
|
||||
urls=[
|
||||
f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
|
||||
],
|
||||
agent_type="planner",
|
||||
runtime_mode="planner",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -285,7 +285,7 @@ async def test_enqueue_run_rejects_unknown_agent_type(monkeypatch) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_run_rejects_memory_mode_for_api(monkeypatch) -> None:
|
||||
async def test_enqueue_run_rejects_invalid_runtime_mode(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
@@ -296,24 +296,12 @@ async def test_enqueue_run_rejects_memory_mode_for_api(monkeypatch) -> None:
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
base_url = str(config.supabase.url).rstrip("/")
|
||||
safe_path = quote(
|
||||
"agent-inputs/00000000-0000-0000-0000-000000000001/"
|
||||
"00000000-0000-0000-0000-000000000001/uploads/a.png"
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
urls=[
|
||||
f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
|
||||
],
|
||||
agent_type="memory",
|
||||
)
|
||||
run_input = _build_run_input(urls=[], runtime_mode="planner")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "memory mode is automation-only"
|
||||
assert repository.created_session_calls == 0
|
||||
assert repository.persisted_user_messages == []
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user