refactor(agent): restructure visibility masks, task queues, and memory service

Visibility mask refactoring:
- Replace dead UI_REALTIME bit with CONTEXT_ASSEMBLY (bit 1)
- Remove visibility_consumer_bit from SystemAgentLLMConfig and system_agents.yaml
- Simplify _resolve_user_message_visibility_mask: chat->UI_HISTORY|CONTEXT_ASSEMBLY, automation->0
- Simplify _resolve_stage_visibility_mask: memory->UI_HISTORY, router/worker->UI_HISTORY|CONTEXT_ASSEMBLY
- Remove stage_visibility_bit_map from store.py

Task queue renaming:
- Replace default_broker/bulk_broker/critical_broker with worker_agent_broker/worker_automation_broker
- Queue names: 'default'/'bulk'/'critical' -> 'agent'/'automation'
- Rename run_command_task -> run_command_task_agent/run_command_task_automation
- AgentService derives queue from runtime_mode: chat->agent, automation->automation

Architecture cleanup:
- Move context_service.py from runtime/ to agentscope/services/
- Add MemoryService in v1/memory/ following repository/service pattern
- Move consumer_registry.py and pipeline_spec.py from schemas/agent to agentscope/schemas/
- Delete dead code: registry_builder.py, VisibilityBitRef
- Delete superseded plan docs
This commit is contained in:
zl-q
2026-03-22 20:35:55 +08:00
parent 20b9e70e84
commit 80ad5141a6
37 changed files with 628 additions and 2428 deletions
+7 -51
View File
@@ -8,12 +8,10 @@ from core.agentscope.events.persistence import MessageRepository, SessionReposit
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.system_agents import SystemAgents
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.agent.system_agent import AgentType
from schemas.agent.runtime_models import AgentOutput, ToolAgentOutput
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.messages.chat_message import AgentChatMessageMetadata
from sqlalchemy import select
class EventStore(Protocol):
@@ -48,9 +46,6 @@ class SqlAlchemyEventStore:
async with self._session_factory() as session:
session_repo = SessionRepository(session)
message_repo = MessageRepository(session)
stage_visibility_bit_map = await self._load_stage_visibility_bit_map(
session=session
)
chat_session = await session_repo.get_session(session_id=session_id)
if chat_session is None:
return
@@ -83,7 +78,6 @@ class SqlAlchemyEventStore:
chat_session=chat_session,
session_repo=session_repo,
message_repo=message_repo,
stage_visibility_bit_map=stage_visibility_bit_map,
)
elif event_type == "TOOL_CALL_RESULT":
await self._persist_tool_call_result(
@@ -92,7 +86,6 @@ class SqlAlchemyEventStore:
chat_session=chat_session,
session_repo=session_repo,
message_repo=message_repo,
stage_visibility_bit_map=stage_visibility_bit_map,
)
await session.commit()
@@ -105,7 +98,6 @@ class SqlAlchemyEventStore:
chat_session: Any,
session_repo: SessionRepository,
message_repo: MessageRepository,
stage_visibility_bit_map: dict[str, int],
) -> None:
message_id_raw = self._event_value(event, "messageId")
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
@@ -146,17 +138,7 @@ class SqlAlchemyEventStore:
try:
worker_output = AgentOutput.model_validate(worker_output_payload)
raw_agent_type = self._event_value(event, "stage")
normalized_agent_type = (
str(raw_agent_type).strip().lower()
if isinstance(raw_agent_type, str)
else AgentType.WORKER.value
)
agent_type = (
AgentType.MEMORY
if normalized_agent_type == AgentType.MEMORY.value
else AgentType.WORKER
)
agent_type = AgentType.WORKER
metadata_model = AgentChatMessageMetadata(
run_id=run_id_value,
agent_type=agent_type,
@@ -199,7 +181,6 @@ class SqlAlchemyEventStore:
latency_ms=latency_ms,
visibility_mask=self._resolve_stage_visibility_mask(
event=event,
stage_visibility_bit_map=stage_visibility_bit_map,
),
)
@@ -226,7 +207,6 @@ class SqlAlchemyEventStore:
chat_session: Any,
session_repo: SessionRepository,
message_repo: MessageRepository,
stage_visibility_bit_map: dict[str, int],
) -> None:
run_id = self._event_value(event, "runId")
run_id_value = run_id if isinstance(run_id, str) and run_id else None
@@ -272,7 +252,6 @@ class SqlAlchemyEventStore:
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
visibility_mask=self._resolve_stage_visibility_mask(
event=event,
stage_visibility_bit_map=stage_visibility_bit_map,
),
)
@@ -301,39 +280,16 @@ class SqlAlchemyEventStore:
self,
*,
event: dict[str, Any],
stage_visibility_bit_map: dict[str, int],
) -> int:
base = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
raw_stage = self._event_value(event, "stage")
if not isinstance(raw_stage, str):
return base
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
normalized_stage = raw_stage.strip().lower()
bit = stage_visibility_bit_map.get(normalized_stage)
if bit is None and normalized_stage == AgentType.MEMORY.value:
bit = 18
if bit is None:
return base
return base | bit_mask(bit=bit)
async def _load_stage_visibility_bit_map(
self,
*,
session: Any,
) -> dict[str, int]:
stmt = select(SystemAgents.agent_type, SystemAgents.config).where(
SystemAgents.agent_type.in_(
[AgentType.ROUTER.value, AgentType.WORKER.value, AgentType.MEMORY.value]
)
if normalized_stage == "memory":
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) | bit_mask(
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
)
rows = (await session.execute(stmt)).all()
bit_map: dict[str, int] = {}
for agent_type, raw_config in rows:
if not isinstance(agent_type, str):
continue
config_payload = raw_config if isinstance(raw_config, dict) else {}
llm_config = SystemAgentLLMConfig.model_validate(config_payload)
bit_map[agent_type.strip().lower()] = llm_config.visibility_consumer_bit
return bit_map
async def _update_session_state(
self,
@@ -1,12 +1,7 @@
from __future__ import annotations
from schemas.agent.pipeline_spec import (
ContextPolicy,
ContextWindowMode,
ExecutorKind,
PipelineSpec,
StageSpec,
)
from core.agentscope.schemas.pipeline_spec import ExecutorKind, PipelineSpec, StageSpec
from schemas.agent.system_agent import AgentType
def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
@@ -17,23 +12,13 @@ def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
stages=[
StageSpec(
stage_name="router",
agent_type=AgentType.ROUTER,
executor_kind=ExecutorKind.SINGLE_SHOT,
default_visibility_mask=0,
context_policy=ContextPolicy(
consumer_agent_type="router",
window_mode=ContextWindowMode.DAY,
count=20,
),
),
StageSpec(
stage_name="worker",
agent_type=AgentType.WORKER,
executor_kind=ExecutorKind.REACT,
default_visibility_mask=0,
context_policy=ContextPolicy(
consumer_agent_type="worker",
window_mode=ContextWindowMode.NUMBER,
count=20,
),
),
],
)
@@ -44,13 +29,8 @@ def build_default_pipeline_spec(*, mode: str) -> PipelineSpec:
stages=[
StageSpec(
stage_name="memory",
agent_type=AgentType.MEMORY,
executor_kind=ExecutorKind.REACT,
default_visibility_mask=0,
context_policy=ContextPolicy(
consumer_agent_type="memory",
window_mode=ContextWindowMode.DAY,
count=20,
),
)
],
)
@@ -1,6 +1,9 @@
from __future__ import annotations
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
from core.agentscope.schemas.consumer_registry import (
AgentConsumerBinding,
ConsumerRegistry,
)
def build_consumer_registry(
+1 -16
View File
@@ -88,10 +88,7 @@ class AgentScopeRunner:
owner_id = UUID(user_context.id)
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
stage_agent_types = [
self._parse_agent_type(stage_name=stage.stage_name)
for stage in pipeline_spec.stages
]
stage_agent_types = [stage.agent_type for stage in pipeline_spec.stages]
async with AsyncSessionLocal() as session:
if stage_agent_types == [AgentType.ROUTER, AgentType.WORKER]:
@@ -177,17 +174,6 @@ class AgentScopeRunner:
enabled_tool_names=enabled_tool_names,
)
@staticmethod
def _parse_agent_type(*, stage_name: str) -> AgentType:
normalized = stage_name.strip().lower()
if normalized == AgentType.ROUTER.value:
return AgentType.ROUTER
if normalized == AgentType.WORKER.value:
return AgentType.WORKER
if normalized == AgentType.MEMORY.value:
return AgentType.MEMORY
raise ValueError(f"unsupported stage name: {stage_name}")
async def _load_stage_config(
self,
*,
@@ -355,7 +341,6 @@ class AgentScopeRunner:
temperature=0.7,
max_tokens=None,
timeout_seconds=30,
visibility_consumer_bit=18,
context_messages=ContextMessagesConfig(
mode=(
ContextBuildStrategy.DAY
+14 -40
View File
@@ -12,7 +12,7 @@ from core.agentscope.events import (
RedisStreamBus,
SqlAlchemyEventStore,
)
from core.agentscope.runtime.context_service import AgentContextService
from core.agentscope.services.context_service import AgentContextService
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
from core.agentscope.schemas.agui_input import parse_run_input
@@ -20,8 +20,7 @@ from core.auth.models import CurrentUser
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from models.automation_jobs import AutomationJob
from core.taskiq.app import worker_agent_broker, worker_automation_broker
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.automation.config import AutomationJobConfig
from schemas.messages.chat_message import (
@@ -33,8 +32,10 @@ from schemas.user import UserContext
from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service
from v1.agent.repository import AgentRepository
from v1.memory.repository import MemoryRepository
from v1.memory.service import MemoryService
from v1.users.dependencies import get_user_service
from sqlalchemy import select
logger = get_logger("core.agentscope.runtime.tasks")
_MAX_CONTEXT_ATTACHMENTS = 3
@@ -188,29 +189,6 @@ async def _build_recent_context_messages(
return converted
async def _load_memory_job_config(
*,
session: Any,
owner_id: UUID,
automation_job_id: str,
) -> AutomationJobConfig:
try:
job_uuid = UUID(automation_job_id)
except ValueError as exc:
raise ValueError("automation_job_id is invalid") from exc
stmt = (
select(AutomationJob)
.where(AutomationJob.id == job_uuid)
.where(AutomationJob.owner_id == owner_id)
.where(AutomationJob.deleted_at.is_(None))
)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
raise ValueError("automation job not found")
return AutomationJobConfig.model_validate(row.config or {})
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
command_type = str(command.get("command", "run")).strip().lower()
raw_owner_id = command.get("owner_id")
@@ -245,10 +223,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
memory_job_config: AutomationJobConfig | None = None
if system_agent_mode == "memory":
assert isinstance(raw_automation_job_id, str)
memory_job_config = await _load_memory_job_config(
session=session,
job_uuid = UUID(raw_automation_job_id)
memory_service = MemoryService(MemoryRepository(session))
memory_job_config = await memory_service.get_memory_job_config(
job_id=job_uuid,
owner_id=owner_id,
automation_job_id=raw_automation_job_id,
)
redis_client = await get_or_init_redis_client()
@@ -272,7 +251,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
context_messages = await _build_recent_context_messages(
session=session,
thread_id=thread_id,
context_mode=pipeline_spec.stages[0].context_policy.consumer_agent_type,
context_mode=pipeline_spec.stages[0].agent_type.value,
memory_job_config=memory_job_config,
)
@@ -296,16 +275,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
}
@default_broker.task(task_name="tasks.agentscope.run_command")
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
@worker_agent_broker.task(task_name="tasks.agentscope.run_command.agent")
async def run_command_task_agent(command: dict[str, object]) -> dict[str, object]:
return await run_agentscope_task(command)
@critical_broker.task(task_name="tasks.agentscope.run_command.critical")
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
return await run_agentscope_task(command)
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
@worker_automation_broker.task(task_name="tasks.agentscope.run_command.automation")
async def run_command_task_automation(command: dict[str, object]) -> dict[str, object]:
return await run_agentscope_task(command)
@@ -4,40 +4,20 @@ from enum import Enum
from pydantic import BaseModel, ConfigDict, Field, field_validator
from schemas.agent.system_agent import AgentType
class ExecutorKind(str, Enum):
SINGLE_SHOT = "single_shot"
REACT = "react"
class ContextWindowMode(str, Enum):
DAY = "day"
NUMBER = "number"
class ContextPolicy(BaseModel):
model_config = ConfigDict(extra="forbid")
consumer_agent_type: str = Field(..., min_length=1, max_length=64)
window_mode: ContextWindowMode = ContextWindowMode.NUMBER
count: int = Field(default=20, ge=1, le=200)
@field_validator("consumer_agent_type")
@classmethod
def _normalize_consumer_agent_type(cls, value: str) -> str:
normalized = value.strip().lower()
if not normalized:
raise ValueError("consumer_agent_type must not be empty")
return normalized
class StageSpec(BaseModel):
model_config = ConfigDict(extra="forbid")
stage_name: str = Field(..., min_length=1, max_length=64)
agent_type: AgentType
executor_kind: ExecutorKind
default_visibility_mask: int = Field(..., ge=0, le=(1 << 63) - 1)
context_policy: ContextPolicy
@field_validator("stage_name")
@classmethod
@@ -5,7 +5,7 @@ from typing import Protocol
from core.agentscope.runtime.context_loader_registry import CONTEXT_LOADER_REGISTRY
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.agent.visibility import bit_mask
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
_DEFAULT_ROUTER_CONTEXT_DAY_COUNT = 20
@@ -61,7 +61,7 @@ class AgentContextService:
normalized_config = self._normalize_system_agent_config(raw_llm_config)
context_config = normalized_config.context_messages
visibility_mask = bit_mask(bit=normalized_config.visibility_consumer_bit)
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
context_loader = CONTEXT_LOADER_REGISTRY.resolve(mode=context_config.mode)
return await context_loader(
self,
+2 -2
View File
@@ -2,12 +2,12 @@ from __future__ import annotations
from core.logging import get_logger
from core.taskiq.app import bulk_broker
from core.taskiq.app import worker_automation_broker
logger = get_logger("core.automation.tasks")
@bulk_broker.task(task_name="tasks.automation.scan_due_jobs")
@worker_automation_broker.task(task_name="tasks.automation.scan_due_jobs")
async def scan_due_automation_jobs_task(limit: int | None = None) -> dict[str, int]:
from core.automation.scheduler import run_automation_scheduler_scan
@@ -6,7 +6,6 @@ agents:
temperature: 0.7
max_tokens: null
timeout_seconds: 30
visibility_consumer_bit: 16
context_messages:
mode: day
count: 2
@@ -19,7 +18,6 @@ agents:
temperature: 0.7
max_tokens: null
timeout_seconds: 30
visibility_consumer_bit: 17
context_messages:
mode: number
count: 20
+2 -2
View File
@@ -1,3 +1,3 @@
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
from core.taskiq.app import broker, worker_agent_broker, worker_automation_broker
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
__all__ = ["broker", "worker_agent_broker", "worker_automation_broker"]
+4 -6
View File
@@ -21,11 +21,9 @@ def _build_broker(queue_name: str) -> ListQueueBroker:
)
default_broker = _build_broker("default")
critical_broker = _build_broker("critical")
bulk_broker = _build_broker("bulk")
worker_agent_broker = _build_broker("agent")
worker_automation_broker = _build_broker("automation")
# Backward-compatible export name for existing imports/tests.
broker = default_broker
broker = worker_agent_broker
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
__all__ = ["broker", "worker_agent_broker", "worker_automation_broker"]
+4 -17
View File
@@ -1,17 +1,10 @@
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
from schemas.agent.forwarded_props import (
ClientTimeContext,
ForwardedPropsPayload,
parse_forwarded_props_agent_type,
parse_forwarded_props_client_time,
parse_forwarded_props_runtime_mode,
)
from schemas.agent.pipeline_spec import (
ContextPolicy,
ContextWindowMode,
ExecutorKind,
PipelineSpec,
StageSpec,
)
from schemas.agent.forwarded_props import RuntimeMode
from schemas.agent.runtime_models import (
AgentOutput,
ConstraintItem,
@@ -45,28 +38,22 @@ from schemas.agent.ui_hints import (
__all__ = [
"AgentType",
"AgentOutput",
"AgentConsumerBinding",
"ConstraintItem",
"ConsumerRegistry",
"ContextPolicy",
"ContextWindowMode",
"ExecutionMode",
"ExecutorKind",
"ForwardedPropsPayload",
"KeyEntity",
"NormalizedTaskInput",
"PipelineSpec",
"ResultTyping",
"ClientTimeContext",
"ResultType",
"RouterAgentOutput",
"RouterUiDecision",
"RunStatus",
"RuntimeMode",
"TaskType",
"TaskTyping",
"SystemAgentLLMConfig",
"SystemVisibilityBit",
"StageSpec",
"ToolAgentOutput",
"ToolStatus",
"UiMode",
@@ -79,7 +66,7 @@ __all__ = [
"WorkerAgentOutputLite",
"WorkerAgentOutputRich",
"bit_mask",
"parse_forwarded_props_agent_type",
"parse_forwarded_props_client_time",
"parse_forwarded_props_runtime_mode",
"resolve_worker_output_model",
]
+9 -11
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
import re
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
@@ -59,20 +60,17 @@ class ClientTimeContext(BaseModel):
return value
class RuntimeMode(str, Enum):
CHAT = "chat"
AUTOMATION = "automation"
class ForwardedPropsPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
agent_type: str = Field(..., min_length=1, max_length=64)
runtime_mode: RuntimeMode
client_time: ClientTimeContext | None = None
@field_validator("agent_type")
@classmethod
def validate_agent_type(cls, value: str) -> str:
normalized = value.strip().lower()
if not normalized:
raise ValueError("invalid forwarded_props.agent_type")
return normalized
def parse_forwarded_props(forwarded_props: object) -> ForwardedPropsPayload:
if not isinstance(forwarded_props, dict):
@@ -90,6 +88,6 @@ def parse_forwarded_props_client_time(
return payload.client_time
def parse_forwarded_props_agent_type(forwarded_props: object) -> str:
def parse_forwarded_props_runtime_mode(forwarded_props: object) -> RuntimeMode:
payload = parse_forwarded_props(forwarded_props)
return payload.agent_type
return payload.runtime_mode
+8 -8
View File
@@ -2,15 +2,13 @@ from __future__ import annotations
from enum import Enum
from pydantic import BaseModel, Field, field_validator
from core.agentscope.tools.tool_config import AgentTool, parse_agent_tool
from pydantic import BaseModel, Field, field_validator
class AgentType(str, Enum):
ROUTER = "router"
WORKER = "worker"
MEMORY = "memory"
class ContextBuildStrategy(str, Enum):
@@ -30,7 +28,6 @@ class SystemAgentLLMConfig(BaseModel):
context_messages: ContextMessagesConfig = Field(
default_factory=ContextMessagesConfig
)
visibility_consumer_bit: int = Field(default=16, ge=16, le=63)
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
@field_validator("enabled_tools", mode="before")
@@ -42,10 +39,13 @@ class SystemAgentLLMConfig(BaseModel):
raise ValueError("enabled_tools must be a list")
normalized: list[AgentTool] = []
for item in value:
raw_item = str(item or "").strip()
if not raw_item:
continue
tool = parse_agent_tool(raw_item)
if isinstance(item, AgentTool):
tool = item
else:
raw_item = str(item or "").strip()
if not raw_item:
continue
tool = parse_agent_tool(raw_item)
if tool not in normalized:
normalized.append(tool)
return normalized
+1 -1
View File
@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
class SystemVisibilityBit(IntEnum):
UI_HISTORY = 0
UI_REALTIME = 1
CONTEXT_ASSEMBLY = 1
class VisibilityMask(BaseModel):
+120
View File
@@ -0,0 +1,120 @@
from __future__ import annotations
import asyncio
from typing import Any
import dashscope
from dashscope.audio.asr import Recognition, RecognitionCallback
from core.config.settings import config
from core.logging import get_logger
logger = get_logger(__name__)
class AsrService:
def __init__(self) -> None:
self._api_key: str | None = None
def _get_api_key(self) -> str:
if self._api_key is None:
dashscope_key = config.llm.provider_keys.get("dashscope")
if not dashscope_key:
raise ValueError(
"DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
)
self._api_key = dashscope_key
return self._api_key
async def transcribe_file(self, file_path: str, filename: str) -> str:
try:
dashscope.api_key = self._get_api_key()
loop = asyncio.get_event_loop()
class SyncCallback(RecognitionCallback):
error: str | None = None
def on_error(self, result: Any) -> None:
self.error = str(result)
callback = SyncCallback()
recognizer = Recognition(
model="fun-asr-realtime-2026-02-28",
callback=callback,
format="wav",
sample_rate=16000,
)
result: Any = await loop.run_in_executor(
None,
lambda: recognizer.call(file=file_path),
)
if callback.error:
raise RuntimeError(f"ASR error: {callback.error}")
status_code = self._extract_field(result, "status_code")
if status_code != 200:
message = self._extract_field(result, "message")
raise RuntimeError(f"ASR transcription failed: {message}")
sentence = self._extract_sentence_payload(result)
if sentence is None:
request_id = self._extract_field(result, "request_id")
logger.warning(
"ASR returned empty result", extra={"request_id": request_id}
)
return ""
if isinstance(sentence, dict):
transcription = sentence.get("text", "")
elif isinstance(sentence, list):
transcription = " ".join(
item.get("text", "") for item in sentence if isinstance(item, dict)
)
else:
transcription = str(sentence) if sentence else ""
logger.info(
"ASR transcription completed",
extra={"filename": filename, "transcript_length": len(transcription)},
)
return transcription
except asyncio.CancelledError:
raise
except RuntimeError:
raise
except Exception as exc:
logger.exception("ASR transcription error")
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
def _extract_sentence_payload(self, result: Any) -> Any | None:
if isinstance(result, dict):
output = result.get("output")
if isinstance(output, dict):
return output.get("sentence")
if output is not None:
return getattr(output, "sentence", None)
return result.get("sentence")
get_sentence = getattr(result, "get_sentence", None)
if callable(get_sentence):
sentence = get_sentence()
if sentence is not None:
return sentence
output = getattr(result, "output", None)
if output is None:
return None
if isinstance(output, dict):
return output.get("sentence")
return getattr(output, "sentence", None)
def _extract_field(self, result: Any, field: str) -> Any | None:
if isinstance(result, dict):
return result.get(field)
return getattr(result, field, None)
asr_service = AsrService()
+6 -9
View File
@@ -44,17 +44,14 @@ class TaskiqQueueClient:
@staticmethod
def _select_queue_task(command: dict[str, object]) -> Any:
from core.agentscope.runtime.tasks import (
run_command_task,
run_command_task_bulk,
run_command_task_critical,
run_command_task_agent,
run_command_task_automation,
)
queue = str(command.get("queue", "default")).strip().lower()
if queue == "critical":
return run_command_task_critical
if queue == "bulk":
return run_command_task_bulk
return run_command_task
queue = str(command.get("queue", "agent")).strip().lower()
if queue == "automation":
return run_command_task_automation
return run_command_task_agent
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
+39 -38
View File
@@ -6,7 +6,7 @@ import re
import tempfile
from collections.abc import AsyncIterator
from datetime import date
from typing import Annotated, Union
from typing import Annotated
from ag_ui.core import RunAgentInput
from core.agentscope.events import to_sse_event
@@ -28,7 +28,7 @@ from fastapi import (
UploadFile,
status,
)
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
@@ -39,7 +39,8 @@ from v1.agent.schemas import (
HistorySnapshotResponse,
TaskAcceptedResponse,
)
from v1.agent.service import AgentService, asr_service
from v1.agent.asr import asr_service
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
@@ -73,15 +74,13 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
count = await redis.incr(key)
if count == 1:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
elif count > _MAX_SSE_CONNECTIONS_PER_USER:
await redis.decr(key)
return False
else:
ttl = await redis.ttl(key)
if int(ttl) < 0:
if ttl < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
after_decr = await redis.decr(key)
if int(after_decr) <= 0:
await redis.delete(key)
return False
return True
except Exception as exc: # noqa: BLE001
logger.warning(
@@ -97,13 +96,18 @@ async def _release_sse_slot(*, user_id: str) -> None:
redis = await get_or_init_redis_client()
key = f"agent:sse-active:{user_id}"
count = await redis.decr(key)
if int(count) <= 0:
if count <= 0:
await redis.delete(key)
return None
ttl = await redis.ttl(key)
if int(ttl) < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
except Exception: # noqa: BLE001
else:
ttl = await redis.ttl(key)
if ttl < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
except Exception as exc: # noqa: BLE001
logger.warning(
"SSE slot release failed",
user_id=user_id,
reason=str(exc),
)
return None
@@ -176,6 +180,11 @@ async def stream_events(
last_event_id=cursor,
current_user=current_user,
)
except TimeoutError:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
except Exception as exc: # noqa: BLE001
logger.warning(
"SSE stream read failed",
@@ -183,11 +192,6 @@ async def stream_events(
user_id=str(current_user.id),
reason=str(exc),
)
if "Timeout reading from" in str(exc):
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
break
if not rows:
@@ -291,12 +295,12 @@ async def create_attachment_signed_url(
async def transcribe(
audio: UploadFile,
request: Request,
_current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> Union[AsrTranscribeResponse, JSONResponse]:
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AsrTranscribeResponse:
temp_path: str | None = None
try:
if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
raise ValueError("Unsupported audio format")
raise HTTPException(status_code=400, detail="Unsupported audio format")
content_length = request.headers.get("content-length")
if content_length is not None:
@@ -309,7 +313,7 @@ async def transcribe(
and declared_length
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
):
raise ValueError("Audio file too large")
raise HTTPException(status_code=400, detail="Audio file too large")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
temp_path = tmp_file.name
@@ -322,16 +326,16 @@ async def transcribe(
break
total_bytes += len(chunk)
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
raise ValueError("Audio file too large")
raise HTTPException(status_code=400, detail="Audio file too large")
if len(header) < _WAV_HEADER_MIN_BYTES:
required = _WAV_HEADER_MIN_BYTES - len(header)
header.extend(chunk[:required])
tmp_file.write(chunk)
if total_bytes == 0:
raise ValueError("Empty audio file")
raise HTTPException(status_code=400, detail="Empty audio file")
if not _looks_like_wav_header(bytes(header)):
raise ValueError("Unsupported audio format")
raise HTTPException(status_code=400, detail="Unsupported audio format")
transcript = await asr_service.transcribe_file(
temp_path, audio.filename or "unknown"
@@ -339,17 +343,14 @@ async def transcribe(
return AsrTranscribeResponse(transcript=transcript)
except ValueError as exc:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(exc)},
)
except HTTPException:
raise
except RuntimeError:
return JSONResponse(
status_code=status.HTTP_502_BAD_GATEWAY,
content={"detail": "ASR service unavailable"},
)
raise HTTPException(status_code=502, detail="ASR service unavailable")
finally:
await audio.close()
if temp_path and os.path.exists(temp_path):
os.unlink(temp_path)
if temp_path:
try:
os.unlink(temp_path)
except OSError:
pass
+84 -1
View File
@@ -1,12 +1,95 @@
from __future__ import annotations
from typing import Literal
from dataclasses import dataclass
from datetime import date
from typing import Any, Literal, Protocol
from pydantic import BaseModel, ConfigDict, Field
from schemas.agent.ui_schema import UiSchemaRenderer
class AgentRepositoryLike(Protocol):
async def get_session_owner(self, *, session_id: str) -> str: ...
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
async def get_history_day(
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
async def persist_user_message(
self,
*,
session_id: str,
content: str,
metadata: Any,
visibility_mask: int,
) -> None: ...
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str: ...
class EventStreamLike(Protocol):
async def read(
self,
*,
session_id: str,
last_event_id: str | None,
) -> list[dict[str, object]]: ...
class AttachmentStorageLike(Protocol):
async def upload_bytes(
self,
*,
bucket: str,
path: str,
content: bytes,
content_type: str,
) -> str: ...
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str: ...
def parse_signed_url(self, url: str) -> tuple[str, str]: ...
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
thread_id: str
run_id: str
created: bool
class TaskAcceptedResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
+37 -288
View File
@@ -1,15 +1,11 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import date
import hashlib
from typing import Any, Protocol
from urllib.parse import urlparse
import dashscope
from ag_ui.core import RunAgentInput
from dashscope.audio.asr import Recognition, RecognitionCallback
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
@@ -17,102 +13,32 @@ from core.auth.models import CurrentUser
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.agent.forwarded_props import (
parse_forwarded_props_runtime_mode,
RuntimeMode,
)
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.messages.chat_message import (
AgentChatMessageMetadata,
UserMessageAttachment,
extract_user_message_attachments,
)
from v1.agent.schemas import HistorySnapshotResponse
from v1.agent.schemas import (
AgentRepositoryLike,
AttachmentStorageLike,
EventStreamLike,
HistorySnapshotResponse,
QueueClientLike,
TaskAccepted,
)
from v1.agent.utils import (
MAX_ATTACHMENT_BYTES,
MAX_ATTACHMENTS_PER_MESSAGE,
is_safe_attachment_path,
mime_to_suffix,
)
logger = get_logger(__name__)
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
_MAX_ATTACHMENTS_PER_MESSAGE = 3
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
thread_id: str
run_id: str
created: bool
class AgentRepositoryLike(Protocol):
async def get_session_owner(self, *, session_id: str) -> str: ...
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
async def get_history_day(
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
async def persist_user_message(
self,
*,
session_id: str,
content: str,
metadata: AgentChatMessageMetadata | None,
visibility_mask: int,
) -> None: ...
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str: ...
class EventStreamLike(Protocol):
async def read(
self,
*,
session_id: str,
last_event_id: str | None,
) -> list[dict[str, object]]: ...
class AttachmentStorageLike(Protocol):
async def upload_bytes(
self,
*,
bucket: str,
path: str,
content: bytes,
content_type: str,
) -> str: ...
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str: ...
def parse_signed_url(self, url: str) -> tuple[str, str]: ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
@@ -152,14 +78,9 @@ class AgentService:
run_id = run_input.run_id
forwarded_props = getattr(run_input, "forwarded_props", None)
try:
agent_type = parse_forwarded_props_agent_type(forwarded_props)
runtime_mode = parse_forwarded_props_runtime_mode(forwarded_props)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
if agent_type == "memory":
raise HTTPException(
status_code=422,
detail="memory mode is automation-only",
)
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
@@ -185,7 +106,7 @@ class AgentService:
current_user=current_user,
)
visibility_mask = await self._resolve_user_message_visibility_mask(
agent_type=agent_type
runtime_mode=runtime_mode
)
await self._repository.persist_user_message(
session_id=thread_id,
@@ -195,6 +116,7 @@ class AgentService:
)
await self._repository.commit()
queue = "automation" if runtime_mode == RuntimeMode.AUTOMATION else "agent"
task_id = await self._queue.enqueue(
command={
"command": "run",
@@ -202,6 +124,7 @@ class AgentService:
"run_input": run_input.model_dump(
mode="json", by_alias=True, exclude_none=True
),
"queue": queue,
},
dedup_key=None,
)
@@ -212,60 +135,14 @@ class AgentService:
created=created,
)
async def _resolve_user_message_visibility_mask(self, *, agent_type: str) -> int:
normalized_agent_type = agent_type.strip().lower()
history_bit_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
if normalized_agent_type == "memory":
return bit_mask(bit=18)
agent_config = await self._repository.get_system_agent_config(
agent_type=normalized_agent_type
)
if agent_config is None:
raise HTTPException(
status_code=422, detail="invalid forwarded_props.agent_type"
async def _resolve_user_message_visibility_mask(
self, *, runtime_mode: RuntimeMode
) -> int:
if runtime_mode == RuntimeMode.CHAT:
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) | bit_mask(
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
)
llm_config = SystemAgentLLMConfig.model_validate(
(agent_config.get("config") if isinstance(agent_config, dict) else {}) or {}
)
agent_mask = bit_mask(bit=llm_config.visibility_consumer_bit)
if normalized_agent_type == "worker":
router_config = await self._repository.get_system_agent_config(
agent_type="router"
)
worker_config = await self._repository.get_system_agent_config(
agent_type="worker"
)
if router_config is None or worker_config is None:
raise HTTPException(
status_code=500,
detail="system agent visibility config missing",
)
router_mask = bit_mask(
bit=SystemAgentLLMConfig.model_validate(
(
router_config.get("config")
if isinstance(router_config, dict)
else {}
)
or {}
).visibility_consumer_bit
)
worker_mask = bit_mask(
bit=SystemAgentLLMConfig.model_validate(
(
worker_config.get("config")
if isinstance(worker_config, dict)
else {}
)
or {}
).visibility_consumer_bit
)
return history_bit_mask | router_mask | worker_mask
return history_bit_mask | agent_mask
return 0
async def _prepare_user_message(
self,
@@ -309,7 +186,7 @@ class AgentService:
mime_type=mime_type,
)
)
if len(user_attachments) > _MAX_ATTACHMENTS_PER_MESSAGE:
if len(user_attachments) > MAX_ATTACHMENTS_PER_MESSAGE:
raise HTTPException(status_code=422, detail="Too many attachments")
except HTTPException:
raise
@@ -360,14 +237,14 @@ class AgentService:
if not isinstance(content_type, str):
raise HTTPException(status_code=422, detail="Unsupported attachment type")
mime_type = content_type.lower()
if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES:
if mime_type not in {"image/png", "image/jpeg", "image/webp"}:
raise HTTPException(status_code=422, detail="Unsupported attachment type")
if not payload:
raise HTTPException(status_code=422, detail="Empty attachment")
if len(payload) > _MAX_ATTACHMENT_BYTES:
if len(payload) > MAX_ATTACHMENT_BYTES:
raise HTTPException(status_code=413, detail="Attachment too large")
suffix = _mime_to_suffix(mime_type)
suffix = mime_to_suffix(mime_type)
checksum = hashlib.sha1(payload).hexdigest()[:16]
filename_seed = filename if isinstance(filename, str) and filename else "upload"
filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
@@ -424,7 +301,7 @@ class AgentService:
normalized_path = path.strip()
expected_prefix = f"agent-inputs/{current_user.id}/"
if not _is_safe_attachment_path(
if not is_safe_attachment_path(
normalized_path, expected_prefix=expected_prefix
):
raise HTTPException(status_code=422, detail="Invalid attachment path scope")
@@ -503,7 +380,7 @@ class AgentService:
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
)
for attachment in attachments:
if not _is_safe_attachment_path(
if not is_safe_attachment_path(
attachment.path,
expected_prefix=expected_prefix,
):
@@ -586,134 +463,6 @@ class AgentService:
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_BUCKET")
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
if not is_safe_attachment_path(path, expected_prefix=expected_prefix):
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_PATH_SCOPE")
return bucket, path
class AsrService:
def __init__(self) -> None:
self._api_key: str | None = None
def _get_api_key(self) -> str:
if self._api_key is None:
dashscope_key = config.llm.provider_keys.get("dashscope")
if not dashscope_key:
raise ValueError(
"DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
)
self._api_key = dashscope_key
return self._api_key
async def transcribe_file(self, file_path: str, filename: str) -> str:
try:
dashscope.api_key = self._get_api_key()
loop = asyncio.get_event_loop()
class SyncCallback(RecognitionCallback):
error: str | None = None
def on_error(self, result: Any) -> None:
self.error = str(result)
callback = SyncCallback()
recognizer = Recognition(
model="fun-asr-realtime-2026-02-28",
callback=callback,
format="wav",
sample_rate=16000,
)
result: Any = await loop.run_in_executor(
None,
lambda: recognizer.call(file=file_path),
)
if callback.error:
raise RuntimeError(f"ASR error: {callback.error}")
status_code = self._extract_field(result, "status_code")
if status_code != 200:
message = self._extract_field(result, "message")
raise RuntimeError(f"ASR transcription failed: {message}")
sentence = self._extract_sentence_payload(result)
if sentence is None:
request_id = self._extract_field(result, "request_id")
logger.warning(
"ASR returned empty result", extra={"request_id": request_id}
)
return ""
if isinstance(sentence, dict):
transcription = sentence.get("text", "")
elif isinstance(sentence, list):
transcription = " ".join(
item.get("text", "") for item in sentence if isinstance(item, dict)
)
else:
transcription = str(sentence) if sentence else ""
logger.info(
"ASR transcription completed",
extra={"filename": filename, "transcript_length": len(transcription)},
)
return transcription
except asyncio.CancelledError:
raise
except RuntimeError:
raise
except Exception as exc:
logger.exception("ASR transcription error")
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
def _extract_sentence_payload(self, result: Any) -> Any | None:
if isinstance(result, dict):
output = result.get("output")
if isinstance(output, dict):
return output.get("sentence")
if output is not None:
return getattr(output, "sentence", None)
return result.get("sentence")
get_sentence = getattr(result, "get_sentence", None)
if callable(get_sentence):
sentence = get_sentence()
if sentence is not None:
return sentence
output = getattr(result, "output", None)
if output is None:
return None
if isinstance(output, dict):
return output.get("sentence")
return getattr(output, "sentence", None)
def _extract_field(self, result: Any, field: str) -> Any | None:
if isinstance(result, dict):
return result.get(field)
return getattr(result, field, None)
asr_service = AsrService()
def _mime_to_suffix(mime_type: str) -> str:
mapping = {
"image/png": "png",
"image/jpeg": "jpg",
"image/webp": "webp",
}
return mapping.get(mime_type.lower(), "bin")
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return normalized.startswith(expected_prefix)
+25
View File
@@ -14,6 +14,11 @@ from schemas.messages.chat_message import (
extract_user_message_attachments,
)
ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
MAX_ATTACHMENTS_PER_MESSAGE = 3
def convert_message_to_history(
message: AgentChatMessage,
@@ -124,3 +129,23 @@ def _compile_worker_ui_hints(
return compiled
except Exception:
return None
def mime_to_suffix(mime_type: str) -> str:
mapping = {
"image/png": "png",
"image/jpeg": "jpg",
"image/webp": "webp",
}
return mapping.get(mime_type.lower(), "bin")
def is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return normalized.startswith(expected_prefix)
+3
View File
@@ -0,0 +1,3 @@
from v1.memory.service import MemoryService
__all__ = ["MemoryService"]
+35
View File
@@ -0,0 +1,35 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy import select
from core.db.base_repository import BaseRepository
from models.automation_jobs import AutomationJob
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
class MemoryRepositoryLike(Protocol):
async def get_job_by_id_and_owner(
self, *, job_id: UUID, owner_id: UUID
) -> AutomationJob | None: ...
class MemoryRepository(BaseRepository[AutomationJob]):
def __init__(self, session: AsyncSession) -> None:
super().__init__(session=session, model=AutomationJob)
async def get_job_by_id_and_owner(
self, *, job_id: UUID, owner_id: UUID
) -> AutomationJob | None:
stmt = (
select(AutomationJob)
.where(AutomationJob.id == job_id)
.where(AutomationJob.owner_id == owner_id)
.where(AutomationJob.deleted_at.is_(None))
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
+25
View File
@@ -0,0 +1,25 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from schemas.automation.config import AutomationJobConfig
from v1.memory.repository import MemoryRepositoryLike
class MemoryService:
_repository: MemoryRepositoryLike
def __init__(self, repository: MemoryRepositoryLike) -> None:
self._repository = repository
async def get_memory_job_config(
self, *, job_id: UUID, owner_id: UUID
) -> AutomationJobConfig:
job = await self._repository.get_job_by_id_and_owner(
job_id=job_id, owner_id=owner_id
)
if job is None:
raise HTTPException(status_code=404, detail="Automation job not found")
return AutomationJobConfig.model_validate(job.config or {})
@@ -2,7 +2,7 @@ from __future__ import annotations
import pytest
from core.agentscope.runtime.consumer_registry import build_consumer_registry
from core.agentscope.runtime.registry_builder import build_consumer_registry
def test_build_consumer_registry_from_system_agent_configs() -> None:
@@ -45,17 +45,6 @@ def _user_context() -> UserContext:
)
def test_parse_agent_type_supports_known_stages() -> None:
assert AgentScopeRunner._parse_agent_type(stage_name="router") == AgentType.ROUTER
assert AgentScopeRunner._parse_agent_type(stage_name="worker") == AgentType.WORKER
assert AgentScopeRunner._parse_agent_type(stage_name="memory") == AgentType.MEMORY
def test_parse_agent_type_rejects_unknown_stage() -> None:
with pytest.raises(ValueError, match="unsupported stage name"):
AgentScopeRunner._parse_agent_type(stage_name="planner")
def test_build_worker_input_messages_only_contains_router_contract() -> None:
runner = AgentScopeRunner()
router_output = RouterAgentOutput(
+3 -4
View File
@@ -5,14 +5,13 @@ import sys
import pytest
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
from core.taskiq.app import broker, worker_agent_broker, worker_automation_broker
def test_taskiq_broker_is_configured() -> None:
assert broker is not None
assert default_broker is broker
assert critical_broker is not None
assert bulk_broker is not None
assert worker_agent_broker is broker
assert worker_automation_broker is not None
def test_taskiq_app_configures_logging_on_import(
@@ -2,13 +2,12 @@ from __future__ import annotations
import pytest
from schemas.agent.consumer_registry import AgentConsumerBinding, ConsumerRegistry
from schemas.agent.pipeline_spec import (
ContextPolicy,
ExecutorKind,
PipelineSpec,
StageSpec,
from core.agentscope.schemas.consumer_registry import (
AgentConsumerBinding,
ConsumerRegistry,
)
from core.agentscope.schemas.pipeline_spec import ExecutorKind, PipelineSpec, StageSpec
from schemas.agent.system_agent import AgentType
def test_consumer_registry_rejects_duplicate_bits() -> None:
@@ -29,9 +28,9 @@ def test_pipeline_spec_requires_non_empty_stages() -> None:
def test_stage_spec_normalizes_stage_name() -> None:
spec = StageSpec(
stage_name=" Worker ",
agent_type=AgentType.WORKER,
executor_kind=ExecutorKind.REACT,
default_visibility_mask=1,
context_policy=ContextPolicy(consumer_agent_type="worker", count=20),
)
assert spec.stage_name == "worker"
assert spec.agent_type == AgentType.WORKER
+5 -17
View File
@@ -159,7 +159,7 @@ def _user() -> CurrentUser:
)
def _build_run_input(*, urls: list[str], agent_type: str = "worker") -> RunAgentInput:
def _build_run_input(*, urls: list[str], runtime_mode: str = "chat") -> RunAgentInput:
content: list[dict[str, str]] = [{"type": "text", "text": "hello"}]
for url in urls:
content.append({"type": "binary", "mimeType": "image/png", "url": url})
@@ -177,7 +177,7 @@ def _build_run_input(*, urls: list[str], agent_type: str = "worker") -> RunAgent
],
"tools": [],
"context": [],
"forwardedProps": {"agent_type": agent_type},
"forwardedProps": {"runtime_mode": runtime_mode},
}
)
@@ -275,7 +275,7 @@ async def test_enqueue_run_rejects_unknown_agent_type(monkeypatch) -> None:
urls=[
f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
],
agent_type="planner",
runtime_mode="planner",
)
with pytest.raises(HTTPException) as exc_info:
@@ -285,7 +285,7 @@ async def test_enqueue_run_rejects_unknown_agent_type(monkeypatch) -> None:
@pytest.mark.asyncio
async def test_enqueue_run_rejects_memory_mode_for_api(monkeypatch) -> None:
async def test_enqueue_run_rejects_invalid_runtime_mode(monkeypatch) -> None:
monkeypatch.setattr(
agent_service_module.config.storage, "bucket", "agent-test-bucket"
)
@@ -296,24 +296,12 @@ async def test_enqueue_run_rejects_memory_mode_for_api(monkeypatch) -> None:
stream=_FakeStream(),
attachment_storage=_FakeAttachmentStorage(),
)
base_url = str(config.supabase.url).rstrip("/")
safe_path = quote(
"agent-inputs/00000000-0000-0000-0000-000000000001/"
"00000000-0000-0000-0000-000000000001/uploads/a.png"
)
run_input = _build_run_input(
urls=[
f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
],
agent_type="memory",
)
run_input = _build_run_input(urls=[], runtime_mode="planner")
with pytest.raises(HTTPException) as exc_info:
await service.enqueue_run(run_input=run_input, current_user=_user())
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "memory mode is automation-only"
assert repository.created_session_calls == 0
assert repository.persisted_user_messages == []