feat(agent): support multimodal intent input and ASR transcribe endpoint

This commit is contained in:
zl-q
2026-03-08 17:34:28 +08:00
parent 5ada60e834
commit 1060503a2d
11 changed files with 422 additions and 74 deletions
+37 -5
View File
@@ -5,12 +5,12 @@ import asyncio
from datetime import date
import re
import time
from typing import Annotated
from typing import Annotated, Union
from ag_ui.core import RunAgentInput
from fastapi import APIRouter, Depends, Header, Query, Request, status
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from core.agent.infrastructure.agui.stream import to_sse_event
from core.agent.domain.agui_input import (
@@ -20,8 +20,8 @@ from core.agent.domain.agui_input import (
from core.auth.models import CurrentUser
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import TaskAcceptedResponse
from v1.agent.service import AgentService
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
from v1.agent.service import AgentService, asr_service
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
@@ -211,3 +211,35 @@ async def get_user_history_snapshot(
thread_id=thread_id,
before=before,
)
@router.post(
    "/transcribe",
    response_model=AsrTranscribeResponse,
    status_code=status.HTTP_200_OK,
)
async def transcribe(
    audio: UploadFile,
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> Union[AsrTranscribeResponse, JSONResponse]:
    """Transcribe an uploaded audio file to text.

    The upload is read fully into memory and handed to ``asr_service``.
    ``current_user`` is not referenced in the body; it is declared so that
    the ``get_current_user`` dependency is resolved for this route.

    Returns:
        ``AsrTranscribeResponse`` on success. On failure a ``JSONResponse``
        with a ``detail`` message: 400 for a ``ValueError`` (for example an
        empty upload) and 500 for a ``RuntimeError``.
    """
    try:
        payload = await audio.read()
        if not payload:
            raise ValueError("Empty audio file")
        source_name = audio.filename or "unknown"
        text = await asr_service.transcribe(payload, source_name)
        # Built inside the try so a ValueError raised while constructing the
        # response model is still reported through the 400 branch below.
        return AsrTranscribeResponse(transcript=text)
    except (ValueError, RuntimeError) as exc:
        if isinstance(exc, ValueError):
            failure_status = status.HTTP_400_BAD_REQUEST
        else:
            failure_status = status.HTTP_500_INTERNAL_SERVER_ERROR
        return JSONResponse(
            status_code=failure_status,
            content={"detail": str(exc)},
        )
+4
View File
@@ -10,3 +10,7 @@ class TaskAcceptedResponse(BaseModel):
thread_id: str = Field(alias="threadId")
run_id: str = Field(alias="runId")
created: bool
class AsrTranscribeResponse(BaseModel):
    """Response body carrying the text transcribed from an audio upload."""

    transcript: str = Field(description="Transcribed text from audio")
+99 -3
View File
@@ -1,15 +1,23 @@
from __future__ import annotations
import asyncio
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date
from typing import Protocol
from typing import Any, Protocol
from ag_ui.core import StateSnapshotEvent
from ag_ui.core import RunAgentInput
import dashscope
from ag_ui.core import RunAgentInput, StateSnapshotEvent
from dashscope.audio.asr import Recognition, RecognitionCallback
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
from core.config.settings import config
from core.logging import get_logger
logger = get_logger(__name__)
@dataclass(frozen=True)
@@ -210,3 +218,91 @@ class AgentService:
before=before,
current_user=current_user,
)
class AsrService:
    """Speech-to-text service backed by the DashScope ``Recognition`` API.

    The DashScope API key is read from ``config.llm.provider_keys`` on first
    use and cached on the instance.
    """

    def __init__(self) -> None:
        # Filled in by _get_api_key() on first use.
        self._api_key: str | None = None

    def _get_api_key(self) -> str:
        """Return the DashScope API key, reading it from config on first use.

        Raises:
            ValueError: If no ``dashscope`` provider key is configured.
        """
        if self._api_key is None:
            dashscope_key = config.llm.provider_keys.get("dashscope")
            if not dashscope_key:
                raise ValueError(
                    "DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
                )
            self._api_key = dashscope_key
        return self._api_key

    @contextmanager
    def _temp_wav(self, audio_data: bytes):
        """Write *audio_data* to a temporary ``.wav`` file and yield its path.

        The file is closed (and therefore flushed) before the path is
        yielded, and it is removed when the context exits -- including when
        writing the data fails.
        """
        import os

        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        tmp_path = tmp.name
        try:
            with tmp:
                tmp.write(audio_data)
            yield tmp_path
        finally:
            # EAFP: the file may already be gone; that is not an error here.
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                pass

    @staticmethod
    def _extract_transcript(sentence: Any) -> str:
        """Flatten a recognition ``sentence`` payload into plain text.

        A dict contributes its ``"text"`` value, a list contributes the
        space-joined ``"text"`` values of its dict items (other items are
        skipped), and anything else is stringified. Missing or ``None`` text
        is treated as an empty string.
        """
        if isinstance(sentence, dict):
            return sentence.get("text") or ""
        if isinstance(sentence, list):
            return " ".join(
                item.get("text") or "" for item in sentence if isinstance(item, dict)
            )
        return str(sentence) if sentence else ""

    async def transcribe(self, audio_data: bytes, filename: str) -> str:
        """Transcribe *audio_data* and return the recognized text.

        Args:
            audio_data: Raw audio bytes; they are written to a temporary
                ``.wav`` file that is passed to the recognizer.
            filename: Original file name, used only for logging.

        Returns:
            The transcript, or ``""`` when the recognizer returns no output
            or no sentence.

        Raises:
            RuntimeError: On any failure (missing API key, recognizer error,
                non-200 status, unexpected exception).
        """
        try:
            # NOTE: this sets the key on the dashscope module itself.
            dashscope.api_key = self._get_api_key()

            class SyncCallback(RecognitionCallback):
                error: str | None = None

                def on_error(self, result: Any) -> None:
                    self.error = str(result)

            callback = SyncCallback()
            recognizer = Recognition(
                model="fun-asr-realtime-2026-02-28",
                callback=callback,
                format="wav",
                sample_rate=16000,
            )

            # The temp file only has to exist while the recognizer reads it.
            with self._temp_wav(audio_data) as tmp_path:
                # We are inside a coroutine, so a loop is guaranteed to be
                # running; the blocking SDK call goes to the default executor.
                loop = asyncio.get_running_loop()
                result: Any = await loop.run_in_executor(
                    None,
                    lambda: recognizer.call(file=tmp_path),
                )

            if callback.error:
                raise RuntimeError(f"ASR error: {callback.error}")

            if result.status_code != 200:
                raise RuntimeError(f"ASR transcription failed: {result.message}")

            if result.output is None or result.output.sentence is None:
                logger.warning(
                    "ASR returned empty result", extra={"request_id": result.request_id}
                )
                return ""

            transcription = self._extract_transcript(result.output.sentence)

            # "filename" is a reserved LogRecord attribute: passing it in
            # ``extra`` makes stdlib logging raise KeyError, so use a
            # non-clashing key.
            logger.info(
                "ASR transcription completed",
                extra={
                    "audio_filename": filename,
                    "transcript_length": len(transcription),
                },
            )
            return transcription
        except RuntimeError:
            # Already the exception type callers expect; re-wrapping would
            # only duplicate the "ASR transcription failed" prefix.
            logger.exception("ASR transcription error")
            raise
        except Exception as exc:
            logger.exception("ASR transcription error")
            raise RuntimeError(f"ASR transcription failed: {exc}") from exc
# Shared module-level AsrService instance, created at import time.
asr_service = AsrService()