"""Service-layer types for agent runs and speech recognition."""

from __future__ import annotations

import asyncio
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date
from typing import Any, Protocol

import dashscope
from ag_ui.core import RunAgentInput, StateSnapshotEvent
from dashscope.audio.asr import Recognition, RecognitionCallback
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError

from core.auth.models import CurrentUser
from core.config.settings import config
from core.logging import get_logger

logger = get_logger(__name__)


@dataclass(frozen=True)
class TaskAccepted:
    """Immutable receipt describing a command that was handed to the queue."""

    # Identifier returned by the queue client for the enqueued command.
    task_id: str
    # Session (thread) the command targets.
    thread_id: str
    # Run identifier taken from the submitted run input.
    run_id: str
    # True when the session was created as part of accepting this command.
    created: bool


class AgentRepositoryLike(Protocol):
    """Structural interface of the persistence layer used by ``AgentService``."""

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner identifier of ``session_id``.

        ``AgentService.enqueue_run`` treats an ``HTTPException`` with status
        404 from this call as "session does not exist yet".
        """

    async def create_session_for_user(
        self,
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> str:
        """Create a session for ``user_id``, optionally with a given id."""

    async def commit(self) -> None:
        """Commit the pending unit of work."""

    async def rollback(self) -> None:
        """Roll back the pending unit of work."""

    async def get_history_day(
        self,
        *,
        session_id: str,
        before: date | None,
    ) -> dict[str, object] | None:
        """Return one day of history for ``session_id``, or ``None``.

        ``AgentService.get_history_snapshot`` reads the keys ``"day"``,
        ``"hasMore"`` and ``"messages"`` from a non-empty result.
        """

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the latest session id of ``user_id``, or ``None``."""


class QueueClientLike(Protocol):
    """Structural interface of the command queue."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Submit ``command`` and return the resulting task identifier."""


class EventStreamLike(Protocol):
    """Structural interface of the per-session event stream."""

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]:
        """Return events of ``session_id`` given the ``last_event_id`` cursor."""
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Reject access when the session does not belong to the current user.

    Args:
        owner_id: Owner identifier stored for the session.
        current_user: Authenticated caller; its ``id`` is compared as a string.

    Raises:
        HTTPException: Status 403 when ``owner_id`` differs from
            ``str(current_user.id)``.
    """
    if owner_id != str(current_user.id):
        raise HTTPException(status_code=403, detail="Forbidden")


class AgentService:
    """Ownership-checked facade over the repository, queue and event stream."""

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
    ) -> None:
        """Store the collaborators; no I/O is performed here."""
        self._repository = repository
        self._queue = queue
        self._stream = stream

    async def _ensure_session_for_run(
        self,
        *,
        thread_id: str,
        current_user: CurrentUser,
    ) -> bool:
        """Verify ownership of ``thread_id``, creating the session if missing.

        Returns:
            True when this call created (and committed) the session, False
            when the session already existed and belongs to the caller.

        Raises:
            HTTPException: 403 when the session belongs to another user, or
                any non-404 error raised by the repository lookup.
        """
        try:
            owner = await self._repository.get_session_owner(session_id=thread_id)
        except HTTPException as exc:
            # Only "not found" means we should create the session ourselves.
            if exc.status_code != 404:
                raise
        else:
            ensure_session_owner(owner_id=owner, current_user=current_user)
            return False

        try:
            await self._repository.create_session_for_user(
                user_id=str(current_user.id),
                session_id=thread_id,
            )
            await self._repository.commit()
        except IntegrityError:
            # The insert conflicted (the session now exists); fall back to
            # the regular ownership check against the stored owner.
            await self._repository.rollback()
            owner = await self._repository.get_session_owner(session_id=thread_id)
            ensure_session_owner(owner_id=owner, current_user=current_user)
            return False
        return True

    async def enqueue_run(
        self,
        *,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a ``run`` command for the thread named in ``run_input``.

        The session is created on demand when it does not exist yet.

        Returns:
            A ``TaskAccepted`` whose ``created`` flag tells whether the
            session was created by this call.

        Raises:
            HTTPException: 403 when the session belongs to another user.
        """
        thread_id = run_input.thread_id
        run_id = run_input.run_id

        created = await self._ensure_session_for_run(
            thread_id=thread_id,
            current_user=current_user,
        )

        # Queue errors propagate unchanged to the caller.
        task_id = await self._queue.enqueue(
            command={
                "command": "run",
                "run_input": run_input.model_dump(mode="json", by_alias=True),
            },
            dedup_key=None,
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_id,
            created=created,
        )

    async def enqueue_resume(
        self,
        *,
        thread_id: str,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a ``resume`` command for an existing, owned session.

        The command is enqueued with the dedup key
        ``resume:<thread_id>:<run_id>``.

        Raises:
            HTTPException: 403 when the session belongs to another user.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)

        task_id = await self._queue.enqueue(
            command={
                "command": "resume",
                "run_input": run_input.model_dump(mode="json", by_alias=True),
            },
            dedup_key=f"resume:{thread_id}:{run_input.run_id}",
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_input.run_id,
            created=False,
        )

    async def stream_events(
        self,
        *,
        thread_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return stream events of ``thread_id`` after an ownership check.

        Raises:
            HTTPException: 403 when the session belongs to another user.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        return await self._stream.read(
            session_id=thread_id,
            last_event_id=last_event_id,
        )

    async def get_history_snapshot(
        self,
        *,
        thread_id: str,
        before: date | None,
        current_user: CurrentUser,
    ) -> dict[str, object]:
        """Build a serialized state-snapshot event for one day of history.

        Returns:
            The JSON-mode dump of a ``StateSnapshotEvent`` with a top-level
            ``threadId`` key added.

        Raises:
            HTTPException: 403 when the session belongs to another user.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)

        day_payload = await self._repository.get_history_day(
            session_id=thread_id,
            before=before,
        )
        if day_payload:
            day = day_payload["day"]
            has_more = day_payload["hasMore"]
            messages = day_payload["messages"]
        else:
            # No (or empty) payload: report an empty day with nothing further.
            day, has_more, messages = None, False, []

        event = StateSnapshotEvent(
            snapshot={
                "scope": "history_day",
                "threadId": thread_id,
                "day": day,
                "hasMore": has_more,
                "messages": messages,
            }
        ).model_dump(mode="json", by_alias=True, exclude_none=True)
        event["threadId"] = thread_id
        return event

    async def get_user_history_snapshot(
        self,
        *,
        current_user: CurrentUser,
        thread_id: str | None,
        before: date | None,
    ) -> dict[str, object]:
        """Return a history snapshot for ``thread_id`` or the latest session.

        When ``thread_id`` is ``None`` the user's latest session is looked up;
        if the user has none, an empty snapshot event is returned.
        """
        target_thread_id = thread_id
        if target_thread_id is None:
            target_thread_id = await self._repository.get_latest_session_id_for_user(
                user_id=str(current_user.id)
            )
        if target_thread_id is None:
            return StateSnapshotEvent(
                snapshot={
                    "scope": "history_day",
                    "threadId": None,
                    "day": None,
                    "hasMore": False,
                    "messages": [],
                }
            ).model_dump(mode="json", by_alias=True, exclude_none=True)
        return await self.get_history_snapshot(
            thread_id=target_thread_id,
            before=before,
            current_user=current_user,
        )


class AsrService:
    """Speech-to-text wrapper around the DashScope ``Recognition`` API."""

    def __init__(self) -> None:
        """Start with no cached API key; it is resolved on first use."""
        self._api_key: str | None = None

    def _get_api_key(self) -> str:
        """Return the DashScope API key, reading it from config once.

        Raises:
            ValueError: When no ``dashscope`` provider key is configured.
        """
        if self._api_key is None:
            dashscope_key = config.llm.provider_keys.get("dashscope")
            if not dashscope_key:
                raise ValueError(
                    "DASHSCOPE_API_KEY not configured. "
                    "Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
                )
            self._api_key = dashscope_key
        return self._api_key

    @contextmanager
    def _temp_wav(self, audio_data: bytes):
        """Yield the path of a temporary ``.wav`` file holding ``audio_data``.

        The file is closed before the path is yielded and is removed when the
        context exits, including when writing the data fails.
        """
        import os

        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        tmp_path = tmp.name
        try:
            # Write inside the try so a failed write cannot leak the file.
            with tmp:
                tmp.write(audio_data)
            yield tmp_path
        finally:
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                # Already removed by someone else; nothing left to clean up.
                pass

    async def transcribe(self, audio_data: bytes, filename: str) -> str:
        """Transcribe WAV audio bytes and return the recognized text.

        Args:
            audio_data: Raw audio bytes, written to a temporary ``.wav`` file.
            filename: Original file name; used only for logging.

        Returns:
            The transcription, or ``""`` when the service returns no sentence.

        Raises:
            RuntimeError: For every failure (missing API key, callback error,
                non-200 status, unexpected exception); the original exception
                is attached as ``__cause__``.
        """
        try:
            dashscope.api_key = self._get_api_key()

            with self._temp_wav(audio_data) as tmp_path:
                # We are inside a coroutine, so a loop is guaranteed to run.
                loop = asyncio.get_running_loop()

                class SyncCallback(RecognitionCallback):
                    """Remember the last error reported by the recognizer."""

                    error: str | None = None

                    def on_error(self, result: Any) -> None:
                        self.error = str(result)

                callback = SyncCallback()
                recognizer = Recognition(
                    model="fun-asr-realtime-2026-02-28",
                    callback=callback,
                    format="wav",
                    sample_rate=16000,
                )

                # The recognizer call is synchronous; run it in the default
                # executor so the event loop is not blocked.
                result: Any = await loop.run_in_executor(
                    None,
                    lambda: recognizer.call(file=tmp_path),
                )

                if callback.error:
                    raise RuntimeError(f"ASR error: {callback.error}")

                if result.status_code != 200:
                    # The outer handler adds the "ASR transcription failed:"
                    # prefix, so it is not repeated here.
                    raise RuntimeError(
                        f"status {result.status_code}: {result.message}"
                    )

                if result.output is None or result.output.sentence is None:
                    logger.warning(
                        "ASR returned empty result",
                        extra={"request_id": result.request_id},
                    )
                    return ""

                sentence = result.output.sentence
                if isinstance(sentence, dict):
                    transcription = sentence.get("text", "")
                elif isinstance(sentence, list):
                    transcription = " ".join(
                        item.get("text", "")
                        for item in sentence
                        if isinstance(item, dict)
                    )
                else:
                    transcription = str(sentence) if sentence else ""

                # "filename" is a reserved LogRecord attribute: the stdlib
                # logging module raises KeyError when ``extra`` tries to
                # overwrite it, so a non-colliding key is used instead.
                logger.info(
                    "ASR transcription completed",
                    extra={
                        "audio_filename": filename,
                        "transcript_length": len(transcription),
                    },
                )
                return transcription

        except Exception as exc:
            logger.exception("ASR transcription error")
            raise RuntimeError(f"ASR transcription failed: {exc}") from exc


asr_service = AsrService()