from __future__ import annotations import asyncio from dataclasses import dataclass from datetime import date import hashlib from typing import Any, Protocol from urllib.parse import urlparse import dashscope from ag_ui.core import RunAgentInput, StateSnapshotEvent from dashscope.audio.asr import Recognition, RecognitionCallback from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from core.auth.models import CurrentUser from core.agentscope.schemas.agui_input import extract_latest_user_payload from core.config.settings import config from core.logging import get_logger logger = get_logger(__name__) _ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"} _MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024 _MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024 @dataclass(frozen=True) class TaskAccepted: task_id: str thread_id: str run_id: str created: bool class AgentRepositoryLike(Protocol): async def get_session_owner(self, *, session_id: str) -> str: ... async def create_session_for_user( self, *, user_id: str, session_id: str | None = None ) -> str: ... async def commit(self) -> None: ... async def rollback(self) -> None: ... async def get_history_day( self, *, session_id: str, before: date | None ) -> dict[str, object] | None: ... async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ... async def persist_user_message( self, *, session_id: str, run_id: str, content_text: str, metadata: dict[str, object] | None, ) -> None: ... class QueueClientLike(Protocol): async def enqueue( self, *, command: dict[str, object], dedup_key: str | None ) -> str: ... class EventStreamLike(Protocol): async def read( self, *, session_id: str, last_event_id: str | None, ) -> list[dict[str, object]]: ... class AttachmentStorageLike(Protocol): async def upload_bytes( self, *, bucket: str, path: str, content: bytes, content_type: str, ) -> str: ... async def download_bytes(self, *, bucket: str, path: str) -> bytes: ... async def create_signed_url( self, *, bucket: str, path: str, expires_in_seconds: int, ) -> str: ... def parse_signed_url(self, url: str) -> tuple[str, str]: ... def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None: if owner_id != str(current_user.id): raise HTTPException(status_code=403, detail="Forbidden") class AgentService: _repository: AgentRepositoryLike _queue: QueueClientLike _stream: EventStreamLike _attachment_storage: AttachmentStorageLike | None _SIGNED_URL_EXPIRES_IN_SECONDS = 3600 def __init__( self, *, repository: AgentRepositoryLike, queue: QueueClientLike, stream: EventStreamLike, attachment_storage: AttachmentStorageLike | None = None, ) -> None: self._repository = repository self._queue = queue self._stream = stream self._attachment_storage = attachment_storage async def enqueue_run( self, *, run_input: RunAgentInput, current_user: CurrentUser, ) -> TaskAccepted: created = False thread_id = run_input.thread_id run_id = run_input.run_id try: owner = await self._repository.get_session_owner(session_id=thread_id) except HTTPException as exc: if exc.status_code != 404: raise try: await self._repository.create_session_for_user( user_id=str(current_user.id), session_id=thread_id, ) await self._repository.commit() created = True except IntegrityError: await self._repository.rollback() owner = await self._repository.get_session_owner(session_id=thread_id) ensure_session_owner(owner_id=owner, current_user=current_user) else: ensure_session_owner(owner_id=owner, current_user=current_user) user_message_text, user_message_metadata = await self._prepare_user_message( run_input=run_input, current_user=current_user, ) await self._repository.persist_user_message( session_id=thread_id, run_id=run_id, content_text=user_message_text, metadata=user_message_metadata, ) await self._repository.commit() task_id = await self._queue.enqueue( command={ "command": "run", "owner_id": str(current_user.id), "run_input": run_input.model_dump(mode="json", by_alias=True), }, dedup_key=None, ) return TaskAccepted( task_id=task_id, thread_id=thread_id, run_id=run_id, created=created, ) async def _prepare_user_message( self, *, run_input: RunAgentInput, current_user: CurrentUser, ) -> tuple[str, dict[str, object] | None]: from schemas.messages.chat_message import UserMessageAttachments text, content_blocks = extract_latest_user_payload(run_input) user_attachments: UserMessageAttachments | None = None for block in content_blocks: if not isinstance(block, dict): continue block_type = block.get("type") if block_type != "binary": continue url = block.get("url") mime_type = block.get("mimeType") if not isinstance(url, str) or not url: continue if not isinstance(mime_type, str): mime_type = "application/octet-stream" if self._attachment_storage is None: raise HTTPException( status_code=503, detail="Attachment storage unavailable", ) try: bucket, path = self._validate_binary_signed_url( url=url, thread_id=run_input.thread_id, current_user=current_user, ) user_attachments = UserMessageAttachments( bucket=bucket, path=path, mime_type=mime_type, ) break except HTTPException: raise except Exception as exc: # noqa: BLE001 logger.warning("Failed to parse signed URL", url=url, error=str(exc)) raise HTTPException(status_code=422, detail="Invalid signed image url") metadata: dict[str, object] | None = None if user_attachments is not None: metadata = { "user_message_attachments": user_attachments.model_dump(by_alias=True), } return text, metadata async def upload_attachment( self, *, thread_id: str, filename: str | None, content_type: str | None, payload: bytes, current_user: CurrentUser, ) -> dict[str, str]: try: owner = await self._repository.get_session_owner(session_id=thread_id) except HTTPException as exc: if exc.status_code != 404: raise try: await self._repository.create_session_for_user( user_id=str(current_user.id), session_id=thread_id, ) await self._repository.commit() except IntegrityError: await self._repository.rollback() owner = await self._repository.get_session_owner(session_id=thread_id) ensure_session_owner(owner_id=owner, current_user=current_user) else: ensure_session_owner(owner_id=owner, current_user=current_user) if self._attachment_storage is None: raise HTTPException( status_code=503, detail="Attachment storage unavailable" ) if not isinstance(content_type, str): raise HTTPException(status_code=422, detail="Unsupported attachment type") mime_type = content_type.lower() if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES: raise HTTPException(status_code=422, detail="Unsupported attachment type") if not payload: raise HTTPException(status_code=422, detail="Empty attachment") if len(payload) > _MAX_ATTACHMENT_BYTES: raise HTTPException(status_code=413, detail="Attachment too large") suffix = _mime_to_suffix(mime_type) checksum = hashlib.sha1(payload).hexdigest()[:16] filename_seed = filename if isinstance(filename, str) and filename else "upload" filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8] path = ( f"agent-inputs/{current_user.id}/{thread_id}/uploads/" f"{filename_hash}-{checksum}.{suffix}" ) bucket_name = config.storage.bucket try: stored_path = await self._attachment_storage.upload_bytes( bucket=bucket_name, path=path, content=payload, content_type=mime_type, ) signed_url = await self._attachment_storage.create_signed_url( bucket=bucket_name, path=stored_path, expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS, ) except Exception: # noqa: BLE001 logger.exception( "Attachment upload failed", extra={ "bucket": bucket_name, "path": path, "mime_type": mime_type, "thread_id": thread_id, }, ) raise HTTPException(status_code=502, detail="Failed to upload attachment") return { "bucket": bucket_name, "path": stored_path, "mimeType": mime_type, "url": signed_url, } async def create_attachment_signed_url( self, *, bucket: str, path: str, current_user: CurrentUser, ) -> dict[str, str]: if self._attachment_storage is None: raise HTTPException( status_code=503, detail="Attachment storage unavailable" ) normalized_bucket = bucket.strip() if normalized_bucket != config.storage.bucket: raise HTTPException(status_code=422, detail="Invalid attachment bucket") normalized_path = path.strip() expected_prefix = f"agent-inputs/{current_user.id}/" if not _is_safe_attachment_path( normalized_path, expected_prefix=expected_prefix ): raise HTTPException(status_code=422, detail="Invalid attachment path scope") try: signed_url = await self._attachment_storage.create_signed_url( bucket=normalized_bucket, path=normalized_path, expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS, ) except Exception: # noqa: BLE001 logger.exception( "Attachment signed URL generation failed", extra={ "bucket": normalized_bucket, "path": normalized_path, "user_id": str(current_user.id), }, ) raise HTTPException(status_code=502, detail="Failed to generate signed URL") return { "bucket": normalized_bucket, "path": normalized_path, "url": signed_url, } async def enqueue_resume( self, *, thread_id: str, run_input: RunAgentInput, current_user: CurrentUser, ) -> TaskAccepted: owner = await self._repository.get_session_owner(session_id=thread_id) ensure_session_owner(owner_id=owner, current_user=current_user) dedup_key = f"resume:{thread_id}:{run_input.run_id}" task_id = await self._queue.enqueue( command={ "command": "resume", "owner_id": str(current_user.id), "run_input": run_input.model_dump(mode="json", by_alias=True), }, dedup_key=dedup_key, ) return TaskAccepted( task_id=task_id, thread_id=thread_id, run_id=run_input.run_id, created=False, ) async def stream_events( self, *, thread_id: str, last_event_id: str | None, current_user: CurrentUser, ) -> list[dict[str, object]]: owner = await self._repository.get_session_owner(session_id=thread_id) ensure_session_owner(owner_id=owner, current_user=current_user) return await self._stream.read( session_id=thread_id, last_event_id=last_event_id, ) async def get_history_snapshot( self, *, thread_id: str, before: date | None, current_user: CurrentUser, ) -> dict[str, object]: owner = await self._repository.get_session_owner(session_id=thread_id) ensure_session_owner(owner_id=owner, current_user=current_user) day_payload = await self._repository.get_history_day( session_id=thread_id, before=before, ) snapshot = { "scope": "history_day", "threadId": thread_id, "day": day_payload["day"] if day_payload else None, "hasMore": day_payload["hasMore"] if day_payload else False, "messages": day_payload["messages"] if day_payload else [], } event = StateSnapshotEvent(snapshot=snapshot).model_dump( mode="json", by_alias=True, exclude_none=True, ) event["threadId"] = thread_id return event async def get_user_history_snapshot( self, *, current_user: CurrentUser, thread_id: str | None, before: date | None, ) -> dict[str, object]: target_thread_id = thread_id if target_thread_id is None: target_thread_id = await self._repository.get_latest_session_id_for_user( user_id=str(current_user.id) ) if target_thread_id is None: return StateSnapshotEvent( snapshot={ "scope": "history_day", "threadId": None, "day": None, "hasMore": False, "messages": [], } ).model_dump(mode="json", by_alias=True, exclude_none=True) return await self.get_history_snapshot( thread_id=target_thread_id, before=before, current_user=current_user, ) def _validate_binary_signed_url( self, *, url: str, thread_id: str, current_user: CurrentUser, ) -> tuple[str, str]: if self._attachment_storage is None: raise HTTPException( status_code=503, detail="Attachment storage unavailable" ) parsed = urlparse(url) expected_host = urlparse(config.supabase.url).netloc if parsed.netloc != expected_host: raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_HOST") try: bucket, path = self._attachment_storage.parse_signed_url(url) except Exception as exc: # noqa: BLE001 raise HTTPException( status_code=422, detail="Invalid signed image url" ) from exc if bucket != config.storage.bucket: raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_BUCKET") expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/" if not _is_safe_attachment_path(path, expected_prefix=expected_prefix): raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_PATH_SCOPE") return bucket, path class AsrService: def __init__(self) -> None: self._api_key: str | None = None def _get_api_key(self) -> str: if self._api_key is None: dashscope_key = config.llm.provider_keys.get("dashscope") if not dashscope_key: raise ValueError( "DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment." ) self._api_key = dashscope_key return self._api_key async def transcribe_file(self, file_path: str, filename: str) -> str: try: dashscope.api_key = self._get_api_key() loop = asyncio.get_event_loop() class SyncCallback(RecognitionCallback): error: str | None = None def on_error(self, result: Any) -> None: self.error = str(result) callback = SyncCallback() recognizer = Recognition( model="fun-asr-realtime-2026-02-28", callback=callback, format="wav", sample_rate=16000, ) result: Any = await loop.run_in_executor( None, lambda: recognizer.call(file=file_path), ) if callback.error: raise RuntimeError(f"ASR error: {callback.error}") status_code = self._extract_field(result, "status_code") if status_code != 200: message = self._extract_field(result, "message") raise RuntimeError(f"ASR transcription failed: {message}") sentence = self._extract_sentence_payload(result) if sentence is None: request_id = self._extract_field(result, "request_id") logger.warning( "ASR returned empty result", extra={"request_id": request_id} ) return "" if isinstance(sentence, dict): transcription = sentence.get("text", "") elif isinstance(sentence, list): transcription = " ".join( item.get("text", "") for item in sentence if isinstance(item, dict) ) else: transcription = str(sentence) if sentence else "" logger.info( "ASR transcription completed", extra={"filename": filename, "transcript_length": len(transcription)}, ) return transcription except asyncio.CancelledError: raise except RuntimeError: raise except Exception as exc: logger.exception("ASR transcription error") raise RuntimeError(f"ASR transcription failed: {exc}") from exc def _extract_sentence_payload(self, result: Any) -> Any | None: if isinstance(result, dict): output = result.get("output") if isinstance(output, dict): return output.get("sentence") if output is not None: return getattr(output, "sentence", None) return result.get("sentence") get_sentence = getattr(result, "get_sentence", None) if callable(get_sentence): sentence = get_sentence() if sentence is not None: return sentence output = getattr(result, "output", None) if output is None: return None if isinstance(output, dict): return output.get("sentence") return getattr(output, "sentence", None) def _extract_field(self, result: Any, field: str) -> Any | None: if isinstance(result, dict): return result.get(field) return getattr(result, field, None) asr_service = AsrService() def _mime_to_suffix(mime_type: str) -> str: mapping = { "image/png": "png", "image/jpeg": "jpg", "image/webp": "webp", } return mapping.get(mime_type.lower(), "bin") def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool: normalized = path.strip() if not normalized: return False if normalized.startswith("/"): return False if ".." in normalized: return False return normalized.startswith(expected_prefix)