refactor(agent): restructure visibility masks, task queues, and memory service

Visibility mask refactoring:
- Replace dead UI_REALTIME bit with CONTEXT_ASSEMBLY (bit 1)
- Remove visibility_consumer_bit from SystemAgentLLMConfig and system_agents.yaml
- Simplify _resolve_user_message_visibility_mask: chat->UI_HISTORY|CONTEXT_ASSEMBLY, automation->0
- Simplify _resolve_stage_visibility_mask: memory->UI_HISTORY, router/worker->UI_HISTORY|CONTEXT_ASSEMBLY
- Remove stage_visibility_bit_map from store.py

Task queue renaming:
- Replace default_broker/bulk_broker/critical_broker with worker_agent_broker/worker_automation_broker
- Queue names: 'default'/'bulk'/'critical' -> 'agent'/'automation'
- Rename run_command_task -> run_command_task_agent/run_command_task_automation
- AgentService derives queue from runtime_mode: chat->agent, automation->automation

Architecture cleanup:
- Move context_service.py from runtime/ to agentscope/services/
- Add MemoryService in v1/memory/ following repository/service pattern
- Move consumer_registry.py and pipeline_spec.py from schemas/agent to agentscope/schemas/
- Delete dead code: registry_builder.py, VisibilityBitRef
- Delete superseded plan docs
This commit is contained in:
zl-q
2026-03-22 20:35:55 +08:00
parent 20b9e70e84
commit 80ad5141a6
37 changed files with 628 additions and 2428 deletions
+120
View File
@@ -0,0 +1,120 @@
from __future__ import annotations
import asyncio
from typing import Any
import dashscope
from dashscope.audio.asr import Recognition, RecognitionCallback
from core.config.settings import config
from core.logging import get_logger
logger = get_logger(__name__)
class AsrService:
def __init__(self) -> None:
self._api_key: str | None = None
def _get_api_key(self) -> str:
if self._api_key is None:
dashscope_key = config.llm.provider_keys.get("dashscope")
if not dashscope_key:
raise ValueError(
"DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
)
self._api_key = dashscope_key
return self._api_key
async def transcribe_file(self, file_path: str, filename: str) -> str:
try:
dashscope.api_key = self._get_api_key()
loop = asyncio.get_event_loop()
class SyncCallback(RecognitionCallback):
error: str | None = None
def on_error(self, result: Any) -> None:
self.error = str(result)
callback = SyncCallback()
recognizer = Recognition(
model="fun-asr-realtime-2026-02-28",
callback=callback,
format="wav",
sample_rate=16000,
)
result: Any = await loop.run_in_executor(
None,
lambda: recognizer.call(file=file_path),
)
if callback.error:
raise RuntimeError(f"ASR error: {callback.error}")
status_code = self._extract_field(result, "status_code")
if status_code != 200:
message = self._extract_field(result, "message")
raise RuntimeError(f"ASR transcription failed: {message}")
sentence = self._extract_sentence_payload(result)
if sentence is None:
request_id = self._extract_field(result, "request_id")
logger.warning(
"ASR returned empty result", extra={"request_id": request_id}
)
return ""
if isinstance(sentence, dict):
transcription = sentence.get("text", "")
elif isinstance(sentence, list):
transcription = " ".join(
item.get("text", "") for item in sentence if isinstance(item, dict)
)
else:
transcription = str(sentence) if sentence else ""
logger.info(
"ASR transcription completed",
extra={"filename": filename, "transcript_length": len(transcription)},
)
return transcription
except asyncio.CancelledError:
raise
except RuntimeError:
raise
except Exception as exc:
logger.exception("ASR transcription error")
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
def _extract_sentence_payload(self, result: Any) -> Any | None:
if isinstance(result, dict):
output = result.get("output")
if isinstance(output, dict):
return output.get("sentence")
if output is not None:
return getattr(output, "sentence", None)
return result.get("sentence")
get_sentence = getattr(result, "get_sentence", None)
if callable(get_sentence):
sentence = get_sentence()
if sentence is not None:
return sentence
output = getattr(result, "output", None)
if output is None:
return None
if isinstance(output, dict):
return output.get("sentence")
return getattr(output, "sentence", None)
def _extract_field(self, result: Any, field: str) -> Any | None:
if isinstance(result, dict):
return result.get(field)
return getattr(result, field, None)
asr_service = AsrService()
+6 -9
View File
@@ -44,17 +44,14 @@ class TaskiqQueueClient:
@staticmethod
def _select_queue_task(command: dict[str, object]) -> Any:
from core.agentscope.runtime.tasks import (
run_command_task,
run_command_task_bulk,
run_command_task_critical,
run_command_task_agent,
run_command_task_automation,
)
queue = str(command.get("queue", "default")).strip().lower()
if queue == "critical":
return run_command_task_critical
if queue == "bulk":
return run_command_task_bulk
return run_command_task
queue = str(command.get("queue", "agent")).strip().lower()
if queue == "automation":
return run_command_task_automation
return run_command_task_agent
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
+39 -38
View File
@@ -6,7 +6,7 @@ import re
import tempfile
from collections.abc import AsyncIterator
from datetime import date
from typing import Annotated, Union
from typing import Annotated
from ag_ui.core import RunAgentInput
from core.agentscope.events import to_sse_event
@@ -28,7 +28,7 @@ from fastapi import (
UploadFile,
status,
)
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
@@ -39,7 +39,8 @@ from v1.agent.schemas import (
HistorySnapshotResponse,
TaskAcceptedResponse,
)
from v1.agent.service import AgentService, asr_service
from v1.agent.asr import asr_service
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
@@ -73,15 +74,13 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
count = await redis.incr(key)
if count == 1:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
elif count > _MAX_SSE_CONNECTIONS_PER_USER:
await redis.decr(key)
return False
else:
ttl = await redis.ttl(key)
if int(ttl) < 0:
if ttl < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
after_decr = await redis.decr(key)
if int(after_decr) <= 0:
await redis.delete(key)
return False
return True
except Exception as exc: # noqa: BLE001
logger.warning(
@@ -97,13 +96,18 @@ async def _release_sse_slot(*, user_id: str) -> None:
redis = await get_or_init_redis_client()
key = f"agent:sse-active:{user_id}"
count = await redis.decr(key)
if int(count) <= 0:
if count <= 0:
await redis.delete(key)
return None
ttl = await redis.ttl(key)
if int(ttl) < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
except Exception: # noqa: BLE001
else:
ttl = await redis.ttl(key)
if ttl < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
except Exception as exc: # noqa: BLE001
logger.warning(
"SSE slot release failed",
user_id=user_id,
reason=str(exc),
)
return None
@@ -176,6 +180,11 @@ async def stream_events(
last_event_id=cursor,
current_user=current_user,
)
except TimeoutError:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
except Exception as exc: # noqa: BLE001
logger.warning(
"SSE stream read failed",
@@ -183,11 +192,6 @@ async def stream_events(
user_id=str(current_user.id),
reason=str(exc),
)
if "Timeout reading from" in str(exc):
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
break
if not rows:
@@ -291,12 +295,12 @@ async def create_attachment_signed_url(
async def transcribe(
audio: UploadFile,
request: Request,
_current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> Union[AsrTranscribeResponse, JSONResponse]:
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AsrTranscribeResponse:
temp_path: str | None = None
try:
if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
raise ValueError("Unsupported audio format")
raise HTTPException(status_code=400, detail="Unsupported audio format")
content_length = request.headers.get("content-length")
if content_length is not None:
@@ -309,7 +313,7 @@ async def transcribe(
and declared_length
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
):
raise ValueError("Audio file too large")
raise HTTPException(status_code=400, detail="Audio file too large")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
temp_path = tmp_file.name
@@ -322,16 +326,16 @@ async def transcribe(
break
total_bytes += len(chunk)
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
raise ValueError("Audio file too large")
raise HTTPException(status_code=400, detail="Audio file too large")
if len(header) < _WAV_HEADER_MIN_BYTES:
required = _WAV_HEADER_MIN_BYTES - len(header)
header.extend(chunk[:required])
tmp_file.write(chunk)
if total_bytes == 0:
raise ValueError("Empty audio file")
raise HTTPException(status_code=400, detail="Empty audio file")
if not _looks_like_wav_header(bytes(header)):
raise ValueError("Unsupported audio format")
raise HTTPException(status_code=400, detail="Unsupported audio format")
transcript = await asr_service.transcribe_file(
temp_path, audio.filename or "unknown"
@@ -339,17 +343,14 @@ async def transcribe(
return AsrTranscribeResponse(transcript=transcript)
except ValueError as exc:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(exc)},
)
except HTTPException:
raise
except RuntimeError:
return JSONResponse(
status_code=status.HTTP_502_BAD_GATEWAY,
content={"detail": "ASR service unavailable"},
)
raise HTTPException(status_code=502, detail="ASR service unavailable")
finally:
await audio.close()
if temp_path and os.path.exists(temp_path):
os.unlink(temp_path)
if temp_path:
try:
os.unlink(temp_path)
except OSError:
pass
+84 -1
View File
@@ -1,12 +1,95 @@
from __future__ import annotations
from typing import Literal
from dataclasses import dataclass
from datetime import date
from typing import Any, Literal, Protocol
from pydantic import BaseModel, ConfigDict, Field
from schemas.agent.ui_schema import UiSchemaRenderer
class AgentRepositoryLike(Protocol):
async def get_session_owner(self, *, session_id: str) -> str: ...
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
async def get_history_day(
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
async def persist_user_message(
self,
*,
session_id: str,
content: str,
metadata: Any,
visibility_mask: int,
) -> None: ...
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str: ...
class EventStreamLike(Protocol):
async def read(
self,
*,
session_id: str,
last_event_id: str | None,
) -> list[dict[str, object]]: ...
class AttachmentStorageLike(Protocol):
async def upload_bytes(
self,
*,
bucket: str,
path: str,
content: bytes,
content_type: str,
) -> str: ...
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str: ...
def parse_signed_url(self, url: str) -> tuple[str, str]: ...
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
thread_id: str
run_id: str
created: bool
class TaskAcceptedResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
+37 -288
View File
@@ -1,15 +1,11 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import date
import hashlib
from typing import Any, Protocol
from urllib.parse import urlparse
import dashscope
from ag_ui.core import RunAgentInput
from dashscope.audio.asr import Recognition, RecognitionCallback
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
@@ -17,102 +13,32 @@ from core.auth.models import CurrentUser
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.agent.forwarded_props import (
parse_forwarded_props_runtime_mode,
RuntimeMode,
)
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.messages.chat_message import (
AgentChatMessageMetadata,
UserMessageAttachment,
extract_user_message_attachments,
)
from v1.agent.schemas import HistorySnapshotResponse
from v1.agent.schemas import (
AgentRepositoryLike,
AttachmentStorageLike,
EventStreamLike,
HistorySnapshotResponse,
QueueClientLike,
TaskAccepted,
)
from v1.agent.utils import (
MAX_ATTACHMENT_BYTES,
MAX_ATTACHMENTS_PER_MESSAGE,
is_safe_attachment_path,
mime_to_suffix,
)
logger = get_logger(__name__)
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
_MAX_ATTACHMENTS_PER_MESSAGE = 3
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
thread_id: str
run_id: str
created: bool
class AgentRepositoryLike(Protocol):
async def get_session_owner(self, *, session_id: str) -> str: ...
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
async def get_history_day(
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
async def persist_user_message(
self,
*,
session_id: str,
content: str,
metadata: AgentChatMessageMetadata | None,
visibility_mask: int,
) -> None: ...
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str: ...
class EventStreamLike(Protocol):
async def read(
self,
*,
session_id: str,
last_event_id: str | None,
) -> list[dict[str, object]]: ...
class AttachmentStorageLike(Protocol):
async def upload_bytes(
self,
*,
bucket: str,
path: str,
content: bytes,
content_type: str,
) -> str: ...
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str: ...
def parse_signed_url(self, url: str) -> tuple[str, str]: ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
@@ -152,14 +78,9 @@ class AgentService:
run_id = run_input.run_id
forwarded_props = getattr(run_input, "forwarded_props", None)
try:
agent_type = parse_forwarded_props_agent_type(forwarded_props)
runtime_mode = parse_forwarded_props_runtime_mode(forwarded_props)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
if agent_type == "memory":
raise HTTPException(
status_code=422,
detail="memory mode is automation-only",
)
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
@@ -185,7 +106,7 @@ class AgentService:
current_user=current_user,
)
visibility_mask = await self._resolve_user_message_visibility_mask(
agent_type=agent_type
runtime_mode=runtime_mode
)
await self._repository.persist_user_message(
session_id=thread_id,
@@ -195,6 +116,7 @@ class AgentService:
)
await self._repository.commit()
queue = "automation" if runtime_mode == RuntimeMode.AUTOMATION else "agent"
task_id = await self._queue.enqueue(
command={
"command": "run",
@@ -202,6 +124,7 @@ class AgentService:
"run_input": run_input.model_dump(
mode="json", by_alias=True, exclude_none=True
),
"queue": queue,
},
dedup_key=None,
)
@@ -212,60 +135,14 @@ class AgentService:
created=created,
)
async def _resolve_user_message_visibility_mask(self, *, agent_type: str) -> int:
normalized_agent_type = agent_type.strip().lower()
history_bit_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
if normalized_agent_type == "memory":
return bit_mask(bit=18)
agent_config = await self._repository.get_system_agent_config(
agent_type=normalized_agent_type
)
if agent_config is None:
raise HTTPException(
status_code=422, detail="invalid forwarded_props.agent_type"
async def _resolve_user_message_visibility_mask(
self, *, runtime_mode: RuntimeMode
) -> int:
if runtime_mode == RuntimeMode.CHAT:
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) | bit_mask(
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
)
llm_config = SystemAgentLLMConfig.model_validate(
(agent_config.get("config") if isinstance(agent_config, dict) else {}) or {}
)
agent_mask = bit_mask(bit=llm_config.visibility_consumer_bit)
if normalized_agent_type == "worker":
router_config = await self._repository.get_system_agent_config(
agent_type="router"
)
worker_config = await self._repository.get_system_agent_config(
agent_type="worker"
)
if router_config is None or worker_config is None:
raise HTTPException(
status_code=500,
detail="system agent visibility config missing",
)
router_mask = bit_mask(
bit=SystemAgentLLMConfig.model_validate(
(
router_config.get("config")
if isinstance(router_config, dict)
else {}
)
or {}
).visibility_consumer_bit
)
worker_mask = bit_mask(
bit=SystemAgentLLMConfig.model_validate(
(
worker_config.get("config")
if isinstance(worker_config, dict)
else {}
)
or {}
).visibility_consumer_bit
)
return history_bit_mask | router_mask | worker_mask
return history_bit_mask | agent_mask
return 0
async def _prepare_user_message(
self,
@@ -309,7 +186,7 @@ class AgentService:
mime_type=mime_type,
)
)
if len(user_attachments) > _MAX_ATTACHMENTS_PER_MESSAGE:
if len(user_attachments) > MAX_ATTACHMENTS_PER_MESSAGE:
raise HTTPException(status_code=422, detail="Too many attachments")
except HTTPException:
raise
@@ -360,14 +237,14 @@ class AgentService:
if not isinstance(content_type, str):
raise HTTPException(status_code=422, detail="Unsupported attachment type")
mime_type = content_type.lower()
if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES:
if mime_type not in {"image/png", "image/jpeg", "image/webp"}:
raise HTTPException(status_code=422, detail="Unsupported attachment type")
if not payload:
raise HTTPException(status_code=422, detail="Empty attachment")
if len(payload) > _MAX_ATTACHMENT_BYTES:
if len(payload) > MAX_ATTACHMENT_BYTES:
raise HTTPException(status_code=413, detail="Attachment too large")
suffix = _mime_to_suffix(mime_type)
suffix = mime_to_suffix(mime_type)
checksum = hashlib.sha1(payload).hexdigest()[:16]
filename_seed = filename if isinstance(filename, str) and filename else "upload"
filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
@@ -424,7 +301,7 @@ class AgentService:
normalized_path = path.strip()
expected_prefix = f"agent-inputs/{current_user.id}/"
if not _is_safe_attachment_path(
if not is_safe_attachment_path(
normalized_path, expected_prefix=expected_prefix
):
raise HTTPException(status_code=422, detail="Invalid attachment path scope")
@@ -503,7 +380,7 @@ class AgentService:
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
)
for attachment in attachments:
if not _is_safe_attachment_path(
if not is_safe_attachment_path(
attachment.path,
expected_prefix=expected_prefix,
):
@@ -586,134 +463,6 @@ class AgentService:
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_BUCKET")
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
if not is_safe_attachment_path(path, expected_prefix=expected_prefix):
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_PATH_SCOPE")
return bucket, path
class AsrService:
def __init__(self) -> None:
self._api_key: str | None = None
def _get_api_key(self) -> str:
if self._api_key is None:
dashscope_key = config.llm.provider_keys.get("dashscope")
if not dashscope_key:
raise ValueError(
"DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
)
self._api_key = dashscope_key
return self._api_key
async def transcribe_file(self, file_path: str, filename: str) -> str:
try:
dashscope.api_key = self._get_api_key()
loop = asyncio.get_event_loop()
class SyncCallback(RecognitionCallback):
error: str | None = None
def on_error(self, result: Any) -> None:
self.error = str(result)
callback = SyncCallback()
recognizer = Recognition(
model="fun-asr-realtime-2026-02-28",
callback=callback,
format="wav",
sample_rate=16000,
)
result: Any = await loop.run_in_executor(
None,
lambda: recognizer.call(file=file_path),
)
if callback.error:
raise RuntimeError(f"ASR error: {callback.error}")
status_code = self._extract_field(result, "status_code")
if status_code != 200:
message = self._extract_field(result, "message")
raise RuntimeError(f"ASR transcription failed: {message}")
sentence = self._extract_sentence_payload(result)
if sentence is None:
request_id = self._extract_field(result, "request_id")
logger.warning(
"ASR returned empty result", extra={"request_id": request_id}
)
return ""
if isinstance(sentence, dict):
transcription = sentence.get("text", "")
elif isinstance(sentence, list):
transcription = " ".join(
item.get("text", "") for item in sentence if isinstance(item, dict)
)
else:
transcription = str(sentence) if sentence else ""
logger.info(
"ASR transcription completed",
extra={"filename": filename, "transcript_length": len(transcription)},
)
return transcription
except asyncio.CancelledError:
raise
except RuntimeError:
raise
except Exception as exc:
logger.exception("ASR transcription error")
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
def _extract_sentence_payload(self, result: Any) -> Any | None:
if isinstance(result, dict):
output = result.get("output")
if isinstance(output, dict):
return output.get("sentence")
if output is not None:
return getattr(output, "sentence", None)
return result.get("sentence")
get_sentence = getattr(result, "get_sentence", None)
if callable(get_sentence):
sentence = get_sentence()
if sentence is not None:
return sentence
output = getattr(result, "output", None)
if output is None:
return None
if isinstance(output, dict):
return output.get("sentence")
return getattr(output, "sentence", None)
def _extract_field(self, result: Any, field: str) -> Any | None:
if isinstance(result, dict):
return result.get(field)
return getattr(result, field, None)
asr_service = AsrService()
def _mime_to_suffix(mime_type: str) -> str:
mapping = {
"image/png": "png",
"image/jpeg": "jpg",
"image/webp": "webp",
}
return mapping.get(mime_type.lower(), "bin")
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return normalized.startswith(expected_prefix)
+25
View File
@@ -14,6 +14,11 @@ from schemas.messages.chat_message import (
extract_user_message_attachments,
)
ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
MAX_ATTACHMENTS_PER_MESSAGE = 3
def convert_message_to_history(
message: AgentChatMessage,
@@ -124,3 +129,23 @@ def _compile_worker_ui_hints(
return compiled
except Exception:
return None
def mime_to_suffix(mime_type: str) -> str:
mapping = {
"image/png": "png",
"image/jpeg": "jpg",
"image/webp": "webp",
}
return mapping.get(mime_type.lower(), "bin")
def is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return normalized.startswith(expected_prefix)
+3
View File
@@ -0,0 +1,3 @@
from v1.memory.service import MemoryService
__all__ = ["MemoryService"]
+35
View File
@@ -0,0 +1,35 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy import select
from core.db.base_repository import BaseRepository
from models.automation_jobs import AutomationJob
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
class MemoryRepositoryLike(Protocol):
async def get_job_by_id_and_owner(
self, *, job_id: UUID, owner_id: UUID
) -> AutomationJob | None: ...
class MemoryRepository(BaseRepository[AutomationJob]):
def __init__(self, session: AsyncSession) -> None:
super().__init__(session=session, model=AutomationJob)
async def get_job_by_id_and_owner(
self, *, job_id: UUID, owner_id: UUID
) -> AutomationJob | None:
stmt = (
select(AutomationJob)
.where(AutomationJob.id == job_id)
.where(AutomationJob.owner_id == owner_id)
.where(AutomationJob.deleted_at.is_(None))
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
+25
View File
@@ -0,0 +1,25 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from schemas.automation.config import AutomationJobConfig
from v1.memory.repository import MemoryRepositoryLike
class MemoryService:
_repository: MemoryRepositoryLike
def __init__(self, repository: MemoryRepositoryLike) -> None:
self._repository = repository
async def get_memory_job_config(
self, *, job_id: UUID, owner_id: UUID
) -> AutomationJobConfig:
job = await self._repository.get_job_by_id_and_owner(
job_id=job_id, owner_id=owner_id
)
if job is None:
raise HTTPException(status_code=404, detail="Automation job not found")
return AutomationJobConfig.model_validate(job.config or {})