feat(agent): support multimodal intent input and ASR transcribe endpoint
This commit is contained in:
@@ -5,12 +5,12 @@ import asyncio
|
||||
from datetime import date
|
||||
import re
|
||||
import time
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Union
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request, status
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from core.agent.infrastructure.agui.stream import to_sse_event
|
||||
from core.agent.domain.agui_input import (
|
||||
@@ -20,8 +20,8 @@ from core.agent.domain.agui_input import (
|
||||
from core.auth.models import CurrentUser
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import TaskAcceptedResponse
|
||||
from v1.agent.service import AgentService
|
||||
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
|
||||
from v1.agent.service import AgentService, asr_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
# Router for the agent endpoints: routes registered on it are served under
# the "/agent" prefix and carry the "agent" tag.
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
@@ -211,3 +211,35 @@ async def get_user_history_snapshot(
|
||||
thread_id=thread_id,
|
||||
before=before,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/transcribe",
    response_model=AsrTranscribeResponse,
    status_code=status.HTTP_200_OK,
)
async def transcribe(
    audio: UploadFile,
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> Union[AsrTranscribeResponse, JSONResponse]:
    # Transcribe an uploaded audio file through ``asr_service``.
    #
    # ``current_user`` is not read in the body; it is declared so that the
    # ``get_current_user`` dependency is resolved for every request.
    #
    # Outcomes:
    #   * success      -> AsrTranscribeResponse carrying the transcript
    #   * ValueError   -> 400 JSON body {"detail": <message>} (e.g. empty upload)
    #   * RuntimeError -> 500 JSON body {"detail": <message>}
    #
    # NOTE: documented with comments rather than a docstring on purpose, since
    # FastAPI publishes a route's docstring as its OpenAPI description.
    try:
        payload = await audio.read()
        if not payload:
            raise ValueError("Empty audio file")

        source_name = audio.filename or "unknown"
        text = await asr_service.transcribe(payload, source_name)
        return AsrTranscribeResponse(transcript=text)
    except (ValueError, RuntimeError) as exc:
        # ValueError is checked first, so it takes precedence when deciding
        # between a client error and a server error.
        if isinstance(exc, ValueError):
            failure_code = status.HTTP_400_BAD_REQUEST
        else:
            failure_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return JSONResponse(
            status_code=failure_code,
            content={"detail": str(exc)},
        )
|
||||
|
||||
@@ -10,3 +10,7 @@ class TaskAcceptedResponse(BaseModel):
|
||||
thread_id: str = Field(alias="threadId")
|
||||
run_id: str = Field(alias="runId")
|
||||
created: bool
|
||||
|
||||
|
||||
class AsrTranscribeResponse(BaseModel):
    # Response body of the transcribe endpoint: the recognized text only.
    transcript: str = Field(
        description="Transcribed text from audio",
    )
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
from ag_ui.core import StateSnapshotEvent
|
||||
from ag_ui.core import RunAgentInput
|
||||
import dashscope
|
||||
from ag_ui.core import RunAgentInput, StateSnapshotEvent
|
||||
from dashscope.audio.asr import Recognition, RecognitionCallback
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
|
||||
# Module-level logger, obtained from the project's get_logger helper and
# keyed by this module's name.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -210,3 +218,91 @@ class AgentService:
|
||||
before=before,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
|
||||
class AsrService:
    """Speech-to-text helper built on the DashScope ``Recognition`` client.

    The DashScope API key is looked up in ``config.llm.provider_keys`` the
    first time it is needed and cached on the instance afterwards.
    """

    def __init__(self) -> None:
        # Resolved lazily by _get_api_key(), so constructing the service does
        # not fail when the key is missing from the configuration.
        self._api_key: str | None = None

    def _get_api_key(self) -> str:
        """Return the DashScope API key, reading it from config on first use.

        Raises:
            ValueError: If no ``dashscope`` provider key is configured.
        """
        if self._api_key is None:
            dashscope_key = config.llm.provider_keys.get("dashscope")
            if not dashscope_key:
                raise ValueError(
                    "DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
                )
            self._api_key = dashscope_key
        return self._api_key

    @contextmanager
    def _temp_wav(self, audio_data: bytes):
        """Write *audio_data* to a temporary ``.wav`` file and yield its path.

        The file is closed before the path is yielded and is removed when the
        context exits, whether normally or through an exception.
        """
        # Kept as a function-local import, as in the original implementation.
        import os

        tmp_path: str | None = None
        try:
            # delete=False: the file must outlive the handle so that another
            # component can open it by path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                # Record the path before writing so that a failed write still
                # gets cleaned up by the ``finally`` clause below.
                tmp_path = tmp.name
                tmp.write(audio_data)
            yield tmp_path
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except FileNotFoundError:
                    # Already removed by someone else; nothing left to do.
                    pass

    @staticmethod
    def _extract_text(sentence: Any) -> str:
        """Flatten a recognition ``sentence`` payload into plain text.

        A dict contributes its ``"text"`` entry, a list contributes the
        space-joined ``"text"`` entries of its dict items (other items are
        ignored), and any other value is stringified (falsy values give "").
        """
        if isinstance(sentence, dict):
            return sentence.get("text", "")
        if isinstance(sentence, list):
            return " ".join(
                item.get("text", "") for item in sentence if isinstance(item, dict)
            )
        return str(sentence) if sentence else ""

    async def transcribe(self, audio_data: bytes, filename: str) -> str:
        """Transcribe *audio_data* and return the recognized text.

        Args:
            audio_data: Raw audio bytes; they are written to a temporary
                ``.wav`` file that is handed to the recognizer.
            filename: Name of the uploaded file, used for logging only.

        Returns:
            The transcript, or ``""`` when the recognizer returns no sentence.

        Raises:
            RuntimeError: On any failure (missing API key, recognizer error,
                non-200 status, unexpected exception).
        """
        try:
            # NOTE: this assigns the key on the dashscope module itself, so it
            # is shared by every user of that module in the process.
            dashscope.api_key = self._get_api_key()

            with self._temp_wav(audio_data) as tmp_path:
                # get_running_loop() is the supported way to obtain the loop
                # from inside a coroutine.
                loop = asyncio.get_running_loop()

                class SyncCallback(RecognitionCallback):
                    # Keeps the last error reported by the recognizer so that
                    # it can be inspected once the call has returned.
                    error: str | None = None

                    def on_error(self, result: Any) -> None:
                        self.error = str(result)

                callback = SyncCallback()
                recognizer = Recognition(
                    model="fun-asr-realtime-2026-02-28",
                    callback=callback,
                    format="wav",
                    sample_rate=16000,
                )

                # recognizer.call() is synchronous; run it in the default
                # executor so the event loop is not blocked meanwhile.
                result: Any = await loop.run_in_executor(
                    None,
                    lambda: recognizer.call(file=tmp_path),
                )

            if callback.error:
                raise RuntimeError(f"ASR error: {callback.error}")
            if result.status_code != 200:
                raise RuntimeError(f"ASR transcription failed: {result.message}")

            if result.output is None or result.output.sentence is None:
                logger.warning(
                    "ASR returned empty result", extra={"request_id": result.request_id}
                )
                return ""

            transcription = self._extract_text(result.output.sentence)

            # The key is "audio_filename" rather than "filename": the latter is
            # a reserved LogRecord attribute and the standard logging module
            # raises KeyError when ``extra`` tries to overwrite it.
            # NOTE(review): assumes get_logger returns a stdlib-style logger.
            logger.info(
                "ASR transcription completed",
                extra={
                    "audio_filename": filename,
                    "transcript_length": len(transcription),
                },
            )
            return transcription

        except RuntimeError:
            # Already the exception type callers expect; re-raise unchanged so
            # the message is not prefixed a second time.
            logger.exception("ASR transcription error")
            raise
        except Exception as exc:
            logger.exception("ASR transcription error")
            raise RuntimeError(f"ASR transcription failed: {exc}") from exc
|
||||
|
||||
|
||||
# Module-level AsrService instance, created at import time; this is the
# `asr_service` name that the router module imports.
asr_service = AsrService()
|
||||
|
||||
Reference in New Issue
Block a user