feat(agent): support multimodal intent input and ASR transcribe endpoint

This commit is contained in:
zl-q
2026-03-08 17:34:28 +08:00
parent 5ada60e834
commit 1060503a2d
11 changed files with 422 additions and 74 deletions
+99 -3
View File
@@ -1,15 +1,23 @@
from __future__ import annotations
import asyncio
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date
from typing import Protocol
from typing import Any, Protocol
from ag_ui.core import StateSnapshotEvent
from ag_ui.core import RunAgentInput
import dashscope
from ag_ui.core import RunAgentInput, StateSnapshotEvent
from dashscope.audio.asr import Recognition, RecognitionCallback
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
from core.config.settings import config
from core.logging import get_logger
# Module-level logger named after this module; used by the services below.
logger = get_logger(__name__)
@dataclass(frozen=True)
@@ -210,3 +218,91 @@ class AgentService:
before=before,
current_user=current_user,
)
class AsrService:
    """Speech-to-text helper built on the DashScope ``Recognition`` client.

    The API key is looked up lazily on first use and cached on the instance.
    """

    class _TranscriptionError(RuntimeError):
        """Failure raised by this service with its final, caller-facing message.

        Subclasses ``RuntimeError`` so existing ``except RuntimeError`` callers
        keep working; it exists only so ``transcribe`` can tell its own
        failures apart from foreign exceptions and avoid prefixing them twice.
        """

    def __init__(self) -> None:
        # Filled in by _get_api_key() on first use.
        self._api_key: str | None = None

    def _get_api_key(self) -> str:
        """Return the DashScope API key, reading it from config on first call.

        Raises:
            ValueError: If ``config.llm.provider_keys`` has no non-empty
                ``"dashscope"`` entry.
        """
        if self._api_key is None:
            dashscope_key = config.llm.provider_keys.get("dashscope")
            if not dashscope_key:
                raise ValueError(
                    "DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
                )
            self._api_key = dashscope_key
        return self._api_key

    @contextmanager
    def _temp_wav(self, audio_data: bytes):
        """Write ``audio_data`` to a temporary ``.wav`` file and yield its path.

        The file is closed before the path is yielded and is removed when the
        ``with`` block exits, including when writing the data itself fails.
        """
        import os

        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        tmp_path = tmp.name
        try:
            # Close the handle before yielding so the recognizer can reopen
            # the file by path.
            with tmp:
                tmp.write(audio_data)
            yield tmp_path
        finally:
            # EAFP: avoids the exists()/unlink() race of a pre-check.
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                pass

    @staticmethod
    def _extract_text(sentence: Any) -> str:
        """Flatten a recognizer ``sentence`` payload into plain text.

        Accepts a single dict (its ``"text"`` value), a list of dicts (their
        ``"text"`` values joined with spaces; non-dict items are ignored), or
        any other value (``str(value)``, or ``""`` when falsy).
        """
        if isinstance(sentence, dict):
            return sentence.get("text", "")
        if isinstance(sentence, list):
            return " ".join(
                item.get("text", "") for item in sentence if isinstance(item, dict)
            )
        return str(sentence) if sentence else ""

    async def transcribe(self, audio_data: bytes, filename: str) -> str:
        """Transcribe ``audio_data`` and return the recognized text.

        Args:
            audio_data: Raw audio bytes. They are written to a ``.wav`` file
                and submitted with ``format="wav"`` and ``sample_rate=16000``.
                NOTE(review): presumably callers always upload 16 kHz WAV;
                confirm, since ``filename`` is not used to pick the format.
            filename: Original upload name; used for logging only.

        Returns:
            The transcription, or ``""`` when the service returns no sentence.

        Raises:
            RuntimeError: On any failure (missing key, recognizer error,
                non-200 status, unexpected exception). The message always
                starts with ``"ASR transcription failed: "``.
        """
        try:
            # The SDK reads the key from this module-level attribute.
            dashscope.api_key = self._get_api_key()

            with self._temp_wav(audio_data) as tmp_path:
                # get_running_loop() is the supported call inside a coroutine.
                loop = asyncio.get_running_loop()

                class SyncCallback(RecognitionCallback):
                    error: str | None = None

                    def on_error(self, result: Any) -> None:
                        self.error = str(result)

                callback = SyncCallback()
                recognizer = Recognition(
                    model="fun-asr-realtime-2026-02-28",
                    callback=callback,
                    format="wav",
                    sample_rate=16000,
                )

                # The SDK call is blocking; run it in the default executor so
                # the event loop stays responsive.
                result: Any = await loop.run_in_executor(
                    None,
                    lambda: recognizer.call(file=tmp_path),
                )

            if callback.error:
                raise self._TranscriptionError(
                    f"ASR transcription failed: ASR error: {callback.error}"
                )

            if result.status_code != 200:
                raise self._TranscriptionError(
                    f"ASR transcription failed: {result.message}"
                )

            if result.output is None or result.output.sentence is None:
                logger.warning(
                    "ASR returned empty result", extra={"request_id": result.request_id}
                )
                return ""

            transcription = self._extract_text(result.output.sentence)

            # "filename" is a reserved LogRecord attribute: passing it via
            # ``extra`` makes stdlib logging raise KeyError, which would turn
            # every successful transcription into a failure. Use a
            # non-colliding key instead.
            logger.info(
                "ASR transcription completed",
                extra={
                    "audio_filename": filename,
                    "transcript_length": len(transcription),
                },
            )
            return transcription

        except self._TranscriptionError:
            # Already carries the final message; do not prefix it again.
            logger.exception("ASR transcription error")
            raise
        except Exception as exc:
            logger.exception("ASR transcription error")
            raise RuntimeError(f"ASR transcription failed: {exc}") from exc
# Module-level AsrService instance created at import time.
asr_service = AsrService()