feat: 添加 Agent 步骤事件与图片附件功能

- 新增 stepStarted/stepFinished 事件类型支持
- 前端实现图片附件上传和预览功能
- 后端增强工具结果存储和事件处理
- 完善相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-12 09:29:57 +08:00
parent 87215f9d41
commit 7b8865e256
45 changed files with 3869 additions and 308 deletions
+91 -1
View File
@@ -3,11 +3,20 @@ from __future__ import annotations
import asyncio
from typing import Any
from storage3.exceptions import StorageApiError
from core.config.settings import config
from services.base.supabase import supabase_service
class AgentAttachmentStorage:
def _validate_bucket(self, *, bucket: str) -> None:
expected = config.storage.bucket
if bucket != expected:
raise RuntimeError("Invalid attachment bucket")
def _bucket_client(self, *, bucket: str) -> Any:
self._validate_bucket(bucket=bucket)
client = supabase_service.get_admin_client()
storage = getattr(client, "storage", None)
if storage is None:
@@ -39,9 +48,82 @@ class AgentAttachmentStorage:
},
)
await asyncio.to_thread(_upload)
try:
await asyncio.to_thread(_upload)
except Exception as exc: # noqa: BLE001
if not _is_bucket_not_found_error(exc):
raise
await self._ensure_bucket_exists(bucket=bucket)
await asyncio.to_thread(_upload)
return path
async def _ensure_bucket_exists(self, *, bucket: str) -> None:
def _ensure() -> None:
client = supabase_service.get_admin_client()
storage = getattr(client, "storage", None)
if storage is None:
raise RuntimeError("Supabase storage client unavailable")
get_bucket = getattr(storage, "get_bucket", None)
if callable(get_bucket):
try:
get_bucket(bucket)
return
except Exception: # noqa: BLE001
pass
create_bucket = getattr(storage, "create_bucket", None)
if not callable(create_bucket):
raise RuntimeError("Supabase storage create_bucket is unavailable")
try:
create_bucket(bucket, options={"public": False})
except Exception as exc: # noqa: BLE001
message = str(exc).lower()
if "already exists" in message or "duplicate" in message:
return
raise
await asyncio.to_thread(_ensure)
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
def _download() -> object:
bucket_client = self._bucket_client(bucket=bucket)
download = getattr(bucket_client, "download", None)
if not callable(download):
raise RuntimeError("Supabase storage download is unavailable")
return download(path)
raw = await asyncio.to_thread(_download)
if isinstance(raw, bytes):
return raw
if isinstance(raw, bytearray):
return bytes(raw)
if isinstance(raw, memoryview):
return raw.tobytes()
raise RuntimeError("Invalid attachment payload")
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str:
def _create_signed_url() -> object:
bucket_client = self._bucket_client(bucket=bucket)
signer = getattr(bucket_client, "create_signed_url", None)
if not callable(signer):
raise RuntimeError("Supabase storage signed url is unavailable")
return signer(path, expires_in_seconds)
raw = await asyncio.to_thread(_create_signed_url)
if isinstance(raw, str):
return raw
if isinstance(raw, dict):
signed_url = raw.get("signedURL") or raw.get("signedUrl") or raw.get("url")
if isinstance(signed_url, str) and signed_url:
return signed_url
raise RuntimeError("Invalid signed url payload")
def create_attachment_storage() -> AgentAttachmentStorage | None:
try:
@@ -49,3 +131,11 @@ def create_attachment_storage() -> AgentAttachmentStorage | None:
except Exception:
return None
return AgentAttachmentStorage()
def _is_bucket_not_found_error(exc: Exception) -> bool:
if isinstance(exc, StorageApiError):
message = str(exc).lower()
return "bucket" in message and "not found" in message
message = str(exc).lower()
return "bucket" in message and "not found" in message
+138 -16
View File
@@ -9,6 +9,7 @@ from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.config.settings import config
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
@@ -200,6 +201,61 @@ class AgentRepository:
return None
return str(latest_id)
async def get_message_attachment_reference(
self,
*,
session_id: str,
message_id: str,
attachment_index: int,
) -> dict[str, str] | None:
try:
session_uuid = UUID(session_id)
message_uuid = UUID(message_id)
except ValueError as exc:
raise HTTPException(
status_code=422, detail="Invalid message/session id"
) from exc
stmt = (
select(AgentChatMessage)
.where(AgentChatMessage.id == message_uuid)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
)
message = (await self._session.execute(stmt)).scalar_one_or_none()
if message is None:
return None
metadata = (
message.metadata_json if isinstance(message.metadata_json, dict) else {}
)
attachments_raw = metadata.get("attachments")
if not isinstance(attachments_raw, list):
return None
if attachment_index < 0 or attachment_index >= len(attachments_raw):
return None
attachment = attachments_raw[attachment_index]
if not isinstance(attachment, dict):
return None
bucket = attachment.get("bucket")
path = attachment.get("path")
mime_type = attachment.get("mimeType")
if (
not isinstance(bucket, str)
or not bucket
or not isinstance(path, str)
or not path
or not isinstance(mime_type, str)
or not mime_type
):
return None
return {
"bucket": bucket,
"path": path,
"mimeType": mime_type,
}
async def _to_snapshot_message(
self, message: AgentChatMessage
) -> dict[str, object]:
@@ -233,30 +289,65 @@ class AgentRepository:
storage_bucket = metadata.get("storage_bucket")
storage_path = metadata.get("storage_path")
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
try:
hydrated_content = await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
expected_bucket = config.storage.bucket
message_session_id = getattr(message, "session_id", None)
expected_prefix = (
f"tool-results/{message_session_id}/"
if message_session_id is not None
else None
)
tool_call_id = metadata.get("tool_call_id")
is_legacy_path = isinstance(
tool_call_id, str
) and storage_path.endswith(f"/{tool_call_id}.json")
if (
storage_bucket == expected_bucket
and _is_safe_storage_path(storage_path)
and (
(
expected_prefix is not None
and storage_path.startswith(expected_prefix)
)
or (
storage_path.startswith("tool-results/")
and is_legacy_path
)
)
except Exception:
hydrated_content = None
):
try:
hydrated_content = (
await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
)
)
except Exception:
hydrated_content = None
resolved_content = hydrated_content or parsed_content
payload["content"] = message.content
if resolved_content is not None:
result = resolved_content.get("result")
if isinstance(result, dict):
result_content = result.get("content")
if isinstance(result_content, str):
payload["content"] = result_content
ui = resolved_content.get("ui")
if not isinstance(ui, dict):
ui = resolved_content.get("ui_schema")
if isinstance(ui, dict):
payload["ui"] = ui
display_content = resolved_content.get("content")
if isinstance(display_content, str):
if not isinstance(display_content, str):
nested_result = resolved_content.get("result")
if isinstance(nested_result, dict):
nested_content = nested_result.get("content")
if isinstance(nested_content, str):
display_content = nested_content
if (
isinstance(display_content, str)
and display_content.strip()
and (
not payload["content"]
or _looks_like_offloaded_placeholder(str(payload["content"]))
)
):
payload["content"] = display_content
if "content" not in payload:
payload["content"] = message.content
else:
payload["content"] = message.content
metadata = message.metadata_json or {}
@@ -264,7 +355,22 @@ class AgentRepository:
metadata.get("attachments") if isinstance(metadata, dict) else None
)
if isinstance(attachments, list):
rendered = [item for item in attachments if isinstance(item, dict)]
rendered: list[dict[str, object]] = []
for index, item in enumerate(attachments):
if not isinstance(item, dict):
continue
mime_type = item.get("mimeType")
if not isinstance(mime_type, str) or not mime_type:
continue
rendered.append(
{
"mimeType": mime_type,
"previewPath": (
f"/api/v1/agent/runs/{message.session_id}/attachments/"
f"{message.id}/{index}"
),
}
)
if rendered:
payload["attachments"] = rendered
return payload
@@ -279,3 +385,19 @@ def _derive_session_title(content_text: str) -> str | None:
if not normalized:
return None
return normalized[:80]
def _is_safe_storage_path(path: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return True
def _looks_like_offloaded_placeholder(content: str) -> bool:
normalized = content.strip().lower()
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
+121 -2
View File
@@ -10,7 +10,17 @@ import time
from typing import Annotated, Union
from ag_ui.core import RunAgentInput
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
from fastapi import (
APIRouter,
Depends,
File,
Form,
Header,
Query,
Request,
status,
UploadFile,
)
from fastapi import HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
@@ -20,11 +30,18 @@ from core.agentscope.schemas.agui_input import (
parse_run_input,
validate_run_request_messages_contract,
)
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
from core.auth.models import CurrentUser
from core.config.settings import config
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
from v1.agent.schemas import (
AsrTranscribeResponse,
AttachmentReference,
AttachmentUploadResponse,
TaskAcceptedResponse,
)
from v1.agent.service import AgentService, asr_service
from v1.users.dependencies import get_current_user
@@ -38,6 +55,7 @@ _SSE_SLOT_TTL_SECONDS = 15 * 60
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
_WAV_HEADER_MIN_BYTES = 12
_ALLOWED_AUDIO_CONTENT_TYPES = {
"audio/wav",
@@ -46,6 +64,42 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
}
def _verified_access_token_for_user(
*,
authorization: str | None,
current_user: CurrentUser,
) -> str | None:
if not isinstance(authorization, str):
return None
normalized = authorization.strip()
if not normalized:
return None
if not normalized.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="Unauthorized")
token = normalized[7:].strip()
if not token:
raise HTTPException(status_code=401, detail="Unauthorized")
jwt_secret = config.supabase.jwt_secret
if jwt_secret is None:
raise HTTPException(status_code=503, detail="Auth verifier unavailable")
verifier = JwtVerifier(
issuer=str(config.supabase.jwt_issuer),
jwt_secret=jwt_secret.get_secret_value(),
jwt_algorithm=config.supabase.jwt_algorithm,
)
try:
payload = verifier.verify(token)
except TokenValidationError as exc:
raise HTTPException(status_code=401, detail="Unauthorized") from exc
subject = payload.get("sub")
if not isinstance(subject, str) or subject != str(current_user.id):
raise HTTPException(status_code=403, detail="Forbidden")
return token
def _looks_like_wav_header(header: bytes) -> bool:
if len(header) < _WAV_HEADER_MIN_BYTES:
return False
@@ -111,6 +165,7 @@ async def enqueue_run(
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
authorization: str | None = Header(default=None, alias="Authorization"),
) -> TaskAcceptedResponse:
try:
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
@@ -120,10 +175,15 @@ async def enqueue_run(
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
user_token = _verified_access_token_for_user(
authorization=authorization,
current_user=current_user,
)
task = await service.enqueue_run(
run_input=request,
current_user=current_user,
user_token=user_token,
)
return TaskAcceptedResponse(
taskId=task.task_id,
@@ -143,6 +203,7 @@ async def enqueue_resume(
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
authorization: str | None = Header(default=None, alias="Authorization"),
) -> TaskAcceptedResponse:
if request.thread_id != thread_id:
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
@@ -154,10 +215,15 @@ async def enqueue_resume(
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
user_token = _verified_access_token_for_user(
authorization=authorization,
current_user=current_user,
)
task = await service.enqueue_resume(
thread_id=thread_id,
run_input=request,
current_user=current_user,
user_token=user_token,
)
return TaskAcceptedResponse(
taskId=task.task_id,
@@ -253,6 +319,31 @@ async def get_history_snapshot(
)
@router.get("/runs/{thread_id}/attachments/{message_id}/{attachment_index}")
async def get_attachment_preview(
thread_id: str,
message_id: str,
attachment_index: int,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> StreamingResponse:
if attachment_index < 0:
raise HTTPException(status_code=422, detail="Invalid attachment index")
payload, mime_type = await service.get_attachment_preview(
thread_id=thread_id,
message_id=message_id,
attachment_index=attachment_index,
current_user=current_user,
)
return StreamingResponse(
iter([payload]),
media_type=mime_type,
headers={
"Cache-Control": "private, max-age=300",
},
)
@router.get("/history")
async def get_user_history_snapshot(
service: Annotated[AgentService, Depends(get_agent_service)],
@@ -267,6 +358,34 @@ async def get_user_history_snapshot(
)
@router.post(
"/attachments",
response_model=AttachmentUploadResponse,
status_code=status.HTTP_200_OK,
)
async def upload_attachment(
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
thread_id: str = Form(alias="threadId"),
file: UploadFile = File(),
) -> AttachmentUploadResponse:
payload = await file.read()
if not payload:
raise HTTPException(status_code=422, detail="Empty attachment")
if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
raise HTTPException(status_code=413, detail="Attachment too large")
attachment = await service.upload_attachment(
thread_id=thread_id,
filename=file.filename,
content_type=file.content_type,
payload=payload,
current_user=current_user,
)
return AttachmentUploadResponse(
attachment=AttachmentReference.model_validate(attachment),
)
@router.post(
"/transcribe",
response_model=AsrTranscribeResponse,
+13
View File
@@ -14,3 +14,16 @@ class TaskAcceptedResponse(BaseModel):
class AsrTranscribeResponse(BaseModel):
transcript: str = Field(description="Transcribed text from audio")
class AttachmentReference(BaseModel):
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
bucket: str
path: str
mime_type: str = Field(alias="mimeType")
url: str
class AttachmentUploadResponse(BaseModel):
attachment: AttachmentReference
+297 -60
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import base64
from dataclasses import dataclass
from datetime import date
import hashlib
@@ -19,17 +18,22 @@ from core.config.settings import config
from core.logging import get_logger
logger = get_logger(__name__)
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
forwarded = run_input.forwarded_props
if not isinstance(forwarded, dict):
def _normalize_bearer_token(value: str | None) -> str | None:
if not isinstance(value, str):
return None
for key in ("accessToken", "userToken", "token"):
value = forwarded.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
normalized = value.strip()
if not normalized:
return None
lower = normalized.lower()
if lower.startswith("bearer "):
token = normalized[7:].strip()
return token or None
return normalized
@dataclass(frozen=True)
@@ -66,6 +70,14 @@ class AgentRepositoryLike(Protocol):
metadata: dict[str, object] | None,
) -> None: ...
async def get_message_attachment_reference(
self,
*,
session_id: str,
message_id: str,
attachment_index: int,
) -> dict[str, str] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -92,6 +104,16 @@ class AttachmentStorageLike(Protocol):
content_type: str,
) -> str: ...
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str: ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
if owner_id != str(current_user.id):
@@ -104,6 +126,8 @@ class AgentService:
_stream: EventStreamLike
_attachment_storage: AttachmentStorageLike | None
_SIGNED_URL_EXPIRES_IN_SECONDS = 3600
def __init__(
self,
*,
@@ -122,6 +146,7 @@ class AgentService:
*,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
) -> TaskAccepted:
created = False
thread_id = run_input.thread_id
@@ -161,7 +186,7 @@ class AgentService:
command={
"command": "run",
"owner_id": str(current_user.id),
"user_token": _extract_user_token_from_run_input(run_input),
"user_token": _normalize_bearer_token(user_token),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
@@ -179,57 +204,115 @@ class AgentService:
run_input: RunAgentInput,
current_user: CurrentUser,
) -> tuple[str, dict[str, object] | None]:
text, content_blocks = extract_latest_user_payload(run_input)
text, _ = extract_latest_user_payload(run_input)
content_blocks = _extract_latest_user_content_blocks(run_input)
attachments: list[dict[str, object]] = []
if self._attachment_storage is not None:
for index, block in enumerate(content_blocks):
if not isinstance(block, dict):
continue
if block.get("type") != "image_url":
continue
image_value = block.get("image_url")
if not isinstance(image_value, dict):
continue
url = image_value.get("url")
if not isinstance(url, str) or not url.startswith("data:"):
continue
decoded = _decode_data_url(url)
if decoded is None:
continue
mime_type, payload = decoded
suffix = _mime_to_suffix(mime_type)
checksum = hashlib.sha1(payload).hexdigest()[:16]
path = (
f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}"
binary_blocks = [
block
for block in content_blocks
if isinstance(block, dict) and block.get("type") == "binary"
]
if binary_blocks:
if self._attachment_storage is None:
raise HTTPException(
status_code=503,
detail="Attachment storage unavailable",
)
bucket_name = config.storage.bucket
forwarded_props = (
run_input.forwarded_props
if isinstance(run_input.forwarded_props, dict)
else {}
)
raw_attachments = forwarded_props.get("attachments")
if not isinstance(raw_attachments, list):
raise HTTPException(
status_code=422, detail="Invalid attachments payload"
)
if len(raw_attachments) != len(binary_blocks):
raise HTTPException(
status_code=422, detail="Invalid attachments payload"
)
total_attachment_bytes = 0
expected_prefix = f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
for index, raw_attachment in enumerate(raw_attachments):
if not isinstance(raw_attachment, dict):
raise HTTPException(
status_code=422,
detail="Invalid attachment reference",
)
bucket = raw_attachment.get("bucket")
path = raw_attachment.get("path")
mime_type = raw_attachment.get("mimeType")
if (
not isinstance(bucket, str)
or not isinstance(path, str)
or not isinstance(mime_type, str)
):
raise HTTPException(
status_code=422,
detail="Invalid attachment reference",
)
if bucket != config.storage.bucket:
raise HTTPException(status_code=403, detail="Forbidden")
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
raise HTTPException(status_code=403, detail="Forbidden")
if mime_type.lower() not in _ALLOWED_ATTACHMENT_MIME_TYPES:
raise HTTPException(
status_code=422,
detail="Unsupported attachment type",
)
binary_block = binary_blocks[index]
binary_mime = binary_block.get("mimeType")
binary_url = binary_block.get("url")
if (
not isinstance(binary_mime, str)
or binary_mime != mime_type
or not isinstance(binary_url, str)
or not binary_url
):
raise HTTPException(
status_code=422,
detail="Invalid attachments payload",
)
try:
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
payload = await self._attachment_storage.download_bytes(
bucket=bucket,
path=path,
content=payload,
content_type=mime_type,
)
except Exception: # noqa: BLE001
logger.exception(
"Attachment upload failed",
"Attachment validation download failed",
extra={
"bucket": bucket_name,
"bucket": bucket,
"path": path,
"mime_type": mime_type,
"thread_id": run_input.thread_id,
"run_id": run_input.run_id,
},
)
raise HTTPException(
status_code=502,
detail="Failed to upload attachment",
detail="Failed to fetch attachment",
)
payload_size = len(payload)
if payload_size > _MAX_ATTACHMENT_BYTES:
raise HTTPException(
status_code=413,
detail="Attachment too large",
)
total_attachment_bytes += payload_size
if total_attachment_bytes > _MAX_TOTAL_ATTACHMENT_BYTES:
raise HTTPException(
status_code=413,
detail="Attachments too large",
)
attachments.append(
{
"bucket": bucket_name,
"path": stored_path,
"bucket": bucket,
"path": path,
"mimeType": mime_type,
}
)
@@ -238,12 +321,94 @@ class AgentService:
metadata["attachments"] = attachments
return text, metadata or None
async def upload_attachment(
self,
*,
thread_id: str,
filename: str | None,
content_type: str | None,
payload: bytes,
current_user: CurrentUser,
) -> dict[str, str]:
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
except HTTPException as exc:
if exc.status_code != 404:
raise
try:
await self._repository.create_session_for_user(
user_id=str(current_user.id),
session_id=thread_id,
)
await self._repository.commit()
except IntegrityError:
await self._repository.rollback()
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
else:
ensure_session_owner(owner_id=owner, current_user=current_user)
if self._attachment_storage is None:
raise HTTPException(
status_code=503, detail="Attachment storage unavailable"
)
if not isinstance(content_type, str):
raise HTTPException(status_code=422, detail="Unsupported attachment type")
mime_type = content_type.lower()
if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES:
raise HTTPException(status_code=422, detail="Unsupported attachment type")
if not payload:
raise HTTPException(status_code=422, detail="Empty attachment")
if len(payload) > _MAX_ATTACHMENT_BYTES:
raise HTTPException(status_code=413, detail="Attachment too large")
suffix = _mime_to_suffix(mime_type)
checksum = hashlib.sha1(payload).hexdigest()[:16]
filename_seed = filename if isinstance(filename, str) and filename else "upload"
filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
path = (
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
f"{filename_hash}-{checksum}.{suffix}"
)
bucket_name = config.storage.bucket
try:
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
path=path,
content=payload,
content_type=mime_type,
)
signed_url = await self._attachment_storage.create_signed_url(
bucket=bucket_name,
path=stored_path,
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
)
except Exception: # noqa: BLE001
logger.exception(
"Attachment upload failed",
extra={
"bucket": bucket_name,
"path": path,
"mime_type": mime_type,
"thread_id": thread_id,
},
)
raise HTTPException(status_code=502, detail="Failed to upload attachment")
return {
"bucket": bucket_name,
"path": stored_path,
"mimeType": mime_type,
"url": signed_url,
}
async def enqueue_resume(
self,
*,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
@@ -253,7 +418,7 @@ class AgentService:
command={
"command": "resume",
"owner_id": str(current_user.id),
"user_token": _extract_user_token_from_run_input(run_input),
"user_token": _normalize_bearer_token(user_token),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
@@ -336,6 +501,63 @@ class AgentService:
current_user=current_user,
)
async def get_attachment_preview(
self,
*,
thread_id: str,
message_id: str,
attachment_index: int,
current_user: CurrentUser,
) -> tuple[bytes, str]:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
if self._attachment_storage is None:
raise HTTPException(
status_code=503, detail="Attachment storage unavailable"
)
ref = await self._repository.get_message_attachment_reference(
session_id=thread_id,
message_id=message_id,
attachment_index=attachment_index,
)
if ref is None:
raise HTTPException(status_code=404, detail="Attachment not found")
bucket = ref.get("bucket")
path = ref.get("path")
mime_type = ref.get("mimeType")
if (
not isinstance(bucket, str)
or not isinstance(path, str)
or not isinstance(mime_type, str)
):
raise HTTPException(status_code=404, detail="Attachment not found")
if bucket != config.storage.bucket:
raise HTTPException(status_code=403, detail="Forbidden")
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/"
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
raise HTTPException(status_code=403, detail="Forbidden")
try:
payload = await self._attachment_storage.download_bytes(
bucket=bucket,
path=path,
)
except Exception: # noqa: BLE001
logger.exception(
"Attachment download failed",
extra={
"thread_id": thread_id,
"message_id": message_id,
"attachment_index": attachment_index,
"bucket": bucket,
},
)
raise HTTPException(status_code=502, detail="Failed to fetch attachment")
return payload, mime_type
class AsrService:
def __init__(self) -> None:
@@ -445,22 +667,26 @@ class AsrService:
asr_service = AsrService()
def _decode_data_url(data_url: str) -> tuple[str, bytes] | None:
if not data_url.startswith("data:"):
return None
header, sep, payload = data_url.partition(",")
if not sep:
return None
mime_type = "image/png"
if ";" in header:
maybe_mime = header[5:].split(";", 1)[0].strip()
if maybe_mime:
mime_type = maybe_mime
try:
decoded = base64.b64decode(payload, validate=True)
except ValueError:
return None
return mime_type, decoded
def _extract_latest_user_content_blocks(
run_input: RunAgentInput,
) -> list[dict[str, Any]]:
if not run_input.messages:
return []
latest = run_input.messages[-1]
content = getattr(latest, "content", None)
if not isinstance(content, list):
return []
blocks: list[dict[str, Any]] = []
for item in content:
if isinstance(item, dict):
blocks.append(item)
continue
model_dump = getattr(item, "model_dump", None)
if callable(model_dump):
dumped = model_dump(mode="json", by_alias=True, exclude_none=True)
if isinstance(dumped, dict):
blocks.append(dumped)
return blocks
def _mime_to_suffix(mime_type: str) -> str:
@@ -470,3 +696,14 @@ def _mime_to_suffix(mime_type: str) -> str:
"image/webp": "webp",
}
return mapping.get(mime_type.lower(), "bin")
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return normalized.startswith(expected_prefix)