from __future__ import annotations

import asyncio
from dataclasses import dataclass
from datetime import date
from typing import Any, Protocol

import dashscope
from ag_ui.core import RunAgentInput, StateSnapshotEvent
from dashscope.audio.asr import Recognition, RecognitionCallback
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError

from core.auth.models import CurrentUser
from core.config.settings import config
from core.logging import get_logger

logger = get_logger(__name__)


def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
    """Return the caller token carried in ``run_input.forwarded_props``, if any.

    The keys ``accessToken``, ``userToken`` and ``token`` are probed in that
    priority order. The first value that is a string and non-blank after
    stripping wins, and is returned stripped. ``None`` is returned when
    ``forwarded_props`` is not a dict or no key qualifies.
    """
    props = run_input.forwarded_props
    if not isinstance(props, dict):
        return None
    # Lazily walk the keys in priority order so lookup stops at the first hit.
    stripped_candidates = (
        raw.strip()
        for raw in map(props.get, ("accessToken", "userToken", "token"))
        if isinstance(raw, str)
    )
    return next((token for token in stripped_candidates if token), None)


@dataclass(frozen=True)
class TaskAccepted:
    """Immutable receipt returned after a command has been queued."""

    # Identifier returned by the queue client's ``enqueue``.
    task_id: str
    # Session/thread the command targets.
    thread_id: str
    # Run identifier taken from the submitted ``RunAgentInput``.
    run_id: str
    # True only when the enqueue call also created the session.
    created: bool


class AgentRepositoryLike(Protocol):
    """Structural interface for the session/history persistence layer."""

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner id of ``session_id``.

        Callers in this module treat an ``HTTPException`` with status 404 as
        "session does not exist".
        """
        ...

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Create a session owned by ``user_id`` and return its id."""
        ...

    async def commit(self) -> None:
        """Commit the current unit of work."""
        ...

    async def rollback(self) -> None:
        """Roll back the current unit of work."""
        ...

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        """Return one day of history, or ``None`` when there is nothing to show.

        Callers read the ``day``, ``hasMore`` and ``messages`` keys.
        """
        ...

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the id of the user's most recent session, or ``None``."""
        ...


class QueueClientLike(Protocol):
    """Structural interface for the command queue."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Enqueue ``command`` and return a task id.

        ``dedup_key`` is optional; presumably duplicate keys are collapsed by
        the implementation (confirm).
        """
        ...


class EventStreamLike(Protocol):
    """Structural interface for reading a session's event stream."""

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]:
        """Return events for ``session_id``.

        Presumably these are the events after ``last_event_id`` (confirm).
        """
        ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Raise HTTP 403 unless ``current_user`` owns the session.

    Ownership is compared as strings: ``owner_id`` against ``str(current_user.id)``.
    """
    if owner_id != str(current_user.id):
        raise HTTPException(status_code=403, detail="Forbidden")


class AgentService:
    """Authorize agent-session requests, then queue commands or read session data.

    Session-scoped operations verify that ``current_user`` owns the session
    (or create it for them) before queuing work or returning stored data.
    """

    _repository: AgentRepositoryLike
    _queue: QueueClientLike
    _stream: EventStreamLike

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
    ) -> None:
        self._repository = repository
        self._queue = queue
        self._stream = stream

    async def enqueue_run(
        self,
        *,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a ``run`` command for ``run_input.thread_id``, creating the session if needed.

        If the repository reports the session missing (``HTTPException`` 404),
        a session owned by ``current_user`` is created and committed. A
        concurrent creator shows up as ``IntegrityError``. In that case the
        transaction is rolled back and the existing owner is checked instead.

        Returns:
            TaskAccepted with ``created=True`` only when this call created the session.

        Raises:
            HTTPException: 403 if another user owns the session. Any non-404
                ``HTTPException`` from the repository is re-raised unchanged.
        """
        created = False
        thread_id = run_input.thread_id
        run_id = run_input.run_id
        try:
            owner = await self._repository.get_session_owner(session_id=thread_id)
        except HTTPException as exc:
            if exc.status_code != 404:
                raise
            try:
                await self._repository.create_session_for_user(
                    user_id=str(current_user.id),
                    session_id=thread_id,
                )
                await self._repository.commit()
                created = True
            except IntegrityError:
                # Lost a creation race: someone else inserted this session id first.
                await self._repository.rollback()
                owner = await self._repository.get_session_owner(session_id=thread_id)
                ensure_session_owner(owner_id=owner, current_user=current_user)
        else:
            ensure_session_owner(owner_id=owner, current_user=current_user)
        task_id = await self._queue.enqueue(
            command={
                "command": "run",
                "owner_id": str(current_user.id),
                "user_token": _extract_user_token_from_run_input(run_input),
                "run_input": run_input.model_dump(mode="json", by_alias=True),
            },
            dedup_key=None,
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_id,
            created=created,
        )

    async def enqueue_resume(
        self,
        *,
        thread_id: str,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a ``resume`` command for an existing session owned by ``current_user``.

        The dedup key is ``resume:{thread_id}:{run_id}``.

        Raises:
            HTTPException: 400 if ``thread_id`` differs from ``run_input.thread_id``.
                403 if the session belongs to another user. Any error raised by
                ``get_session_owner`` propagates (e.g. for a missing session).
        """
        # The queued command carries no separate thread id, so a consumer can
        # only derive the target thread from run_input. The ownership check
        # below must therefore cover that same id, or a caller could authorize
        # against one thread and act on another.
        if run_input.thread_id != thread_id:
            raise HTTPException(
                status_code=400, detail="threadId does not match run input"
            )
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        dedup_key = f"resume:{thread_id}:{run_input.run_id}"
        task_id = await self._queue.enqueue(
            command={
                "command": "resume",
                "owner_id": str(current_user.id),
                "user_token": _extract_user_token_from_run_input(run_input),
                "run_input": run_input.model_dump(mode="json", by_alias=True),
            },
            dedup_key=dedup_key,
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_input.run_id,
            created=False,
        )

    async def stream_events(
        self,
        *,
        thread_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return events for ``thread_id`` from the event stream after an ownership check.

        Raises:
            HTTPException: 403 if the session belongs to another user.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        return await self._stream.read(
            session_id=thread_id,
            last_event_id=last_event_id,
        )

    async def get_history_snapshot(
        self,
        *,
        thread_id: str,
        before: date | None,
        current_user: CurrentUser,
    ) -> dict[str, object]:
        """Return a ``StateSnapshotEvent`` (JSON-ready dict) holding one day of history.

        The snapshot has scope ``history_day``. ``before`` is passed through to
        ``get_history_day``; presumably it selects the day preceding that date
        (confirm against the repository). The event dict also gets a top-level
        ``threadId``.

        Raises:
            HTTPException: 403 if the session belongs to another user.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        day_payload = await self._repository.get_history_day(
            session_id=thread_id,
            before=before,
        )
        snapshot = {
            "scope": "history_day",
            "threadId": thread_id,
            "day": day_payload["day"] if day_payload else None,
            "hasMore": day_payload["hasMore"] if day_payload else False,
            "messages": day_payload["messages"] if day_payload else [],
        }
        event = StateSnapshotEvent(snapshot=snapshot).model_dump(
            mode="json",
            by_alias=True,
            exclude_none=True,
        )
        event["threadId"] = thread_id
        return event

    async def get_user_history_snapshot(
        self,
        *,
        current_user: CurrentUser,
        thread_id: str | None,
        before: date | None,
    ) -> dict[str, object]:
        """Return a history snapshot for ``thread_id``, or for the user's latest session.

        When ``thread_id`` is ``None`` and the user has no session, an empty
        ``history_day`` snapshot is returned (no messages, ``hasMore`` False).
        """
        target_thread_id = thread_id
        if target_thread_id is None:
            target_thread_id = await self._repository.get_latest_session_id_for_user(
                user_id=str(current_user.id)
            )
        if target_thread_id is None:
            return StateSnapshotEvent(
                snapshot={
                    "scope": "history_day",
                    "threadId": None,
                    "day": None,
                    "hasMore": False,
                    "messages": [],
                }
            ).model_dump(mode="json", by_alias=True, exclude_none=True)
        return await self.get_history_snapshot(
            thread_id=target_thread_id,
            before=before,
            current_user=current_user,
        )


class AsrService:
    """Speech-to-text via DashScope's ``Recognition`` API."""

    def __init__(self) -> None:
        # Resolved lazily on first use, so constructing the service (e.g. the
        # module-level ``asr_service``) does not require the key to be configured.
        self._api_key: str | None = None

    def _get_api_key(self) -> str:
        """Return the cached DashScope key, reading ``config.llm.provider_keys`` once.

        Raises:
            ValueError: if no ``dashscope`` key is configured.
        """
        if self._api_key is None:
            dashscope_key = config.llm.provider_keys.get("dashscope")
            if not dashscope_key:
                raise ValueError(
                    "DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
                )
            self._api_key = dashscope_key
        return self._api_key

    async def transcribe_file(self, file_path: str, filename: str) -> str:
        """Transcribe the audio at ``file_path`` and return the transcript text.

        The recognizer is configured for WAV at 16000 Hz with the
        ``fun-asr-realtime-2026-02-28`` model. Presumably callers must supply
        matching audio (confirm). The blocking ``Recognition.call`` runs in the
        loop's default executor. ``filename`` is used only for logging.

        Returns:
            The transcript, or ``""`` when the response carries no sentence payload.

        Raises:
            RuntimeError: on a callback error, a non-200 status, or any other
                failure (including a missing API key), chained to the cause.
        """
        try:
            dashscope.api_key = self._get_api_key()
            loop = asyncio.get_running_loop()

            class SyncCallback(RecognitionCallback):
                error: str | None = None

                def on_error(self, result: Any) -> None:
                    self.error = str(result)

            callback = SyncCallback()
            recognizer = Recognition(
                model="fun-asr-realtime-2026-02-28",
                callback=callback,
                format="wav",
                sample_rate=16000,
            )
            result: Any = await loop.run_in_executor(
                None,
                lambda: recognizer.call(file=file_path),
            )
            if callback.error:
                raise RuntimeError(f"ASR error: {callback.error}")
            status_code = self._extract_field(result, "status_code")
            if status_code != 200:
                message = self._extract_field(result, "message")
                raise RuntimeError(f"ASR transcription failed: {message}")
            sentence = self._extract_sentence_payload(result)
            if sentence is None:
                request_id = self._extract_field(result, "request_id")
                logger.warning(
                    "ASR returned empty result", extra={"request_id": request_id}
                )
                return ""
            transcription = self._sentence_to_text(sentence)
            # "filename" is a reserved LogRecord attribute. With stdlib logging,
            # passing it via ``extra`` raises KeyError, which the broad handler
            # below would turn into a failed transcription, so use a distinct key.
            logger.info(
                "ASR transcription completed",
                extra={
                    "source_filename": filename,
                    "transcript_length": len(transcription),
                },
            )
            return transcription
        except asyncio.CancelledError:
            raise
        except RuntimeError:
            # Already carries a descriptive message; don't double-wrap.
            raise
        except Exception as exc:
            logger.exception("ASR transcription error")
            raise RuntimeError(f"ASR transcription failed: {exc}") from exc

    @staticmethod
    def _sentence_to_text(sentence: Any) -> str:
        """Flatten a sentence payload into plain text.

        A dict contributes its ``text`` entry. For a list, the ``text`` of each
        dict element is collected and joined with single spaces; non-dict
        elements and missing/empty texts are skipped. Any other value is
        ``str()``-ed when truthy, else ``""``.
        """
        if isinstance(sentence, dict):
            return str(sentence.get("text") or "")
        if isinstance(sentence, list):
            texts = (
                str(item.get("text") or "")
                for item in sentence
                if isinstance(item, dict)
            )
            return " ".join(text for text in texts if text)
        return str(sentence) if sentence else ""

    def _extract_sentence_payload(self, result: Any) -> Any | None:
        """Locate the sentence payload in a dict- or object-shaped ASR result.

        For dicts, ``output["sentence"]`` (or ``output.sentence``) is preferred,
        falling back to a top-level ``sentence`` when ``output`` is absent. For
        objects, a non-``None`` ``get_sentence()`` result wins; otherwise
        ``output.sentence`` / ``output["sentence"]`` is used.
        """
        if isinstance(result, dict):
            output = result.get("output")
            if isinstance(output, dict):
                return output.get("sentence")
            if output is not None:
                return getattr(output, "sentence", None)
            return result.get("sentence")
        get_sentence = getattr(result, "get_sentence", None)
        if callable(get_sentence):
            sentence = get_sentence()
            if sentence is not None:
                return sentence
        output = getattr(result, "output", None)
        if output is None:
            return None
        if isinstance(output, dict):
            return output.get("sentence")
        return getattr(output, "sentence", None)

    def _extract_field(self, result: Any, field: str) -> Any | None:
        """Read ``field`` from a dict (by key) or any object (by attribute); ``None`` if absent."""
        if isinstance(result, dict):
            return result.get(field)
        return getattr(result, field, None)


asr_service = AsrService()