chore: 后端 agent 和 users 模块代码更新优化
This commit is contained in:
@@ -3,7 +3,9 @@ from __future__ import annotations
|
||||
from collections.abc import AsyncIterator
|
||||
import asyncio
|
||||
from datetime import date
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Annotated, Union
|
||||
|
||||
@@ -29,6 +31,21 @@ _LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
||||
_RUNS_PER_MINUTE = 30
|
||||
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
||||
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||
_WAV_HEADER_MIN_BYTES = 12
|
||||
_ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
"audio/wav",
|
||||
"audio/x-wav",
|
||||
"audio/wave",
|
||||
}
|
||||
|
||||
|
||||
def _looks_like_wav_header(header: bytes) -> bool:
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
return False
|
||||
return header[0:4] == b"RIFF" and header[8:12] == b"WAVE"
|
||||
|
||||
|
||||
async def _allow_run_request(*, user_id: str) -> bool:
|
||||
@@ -220,15 +237,52 @@ async def get_user_history_snapshot(
|
||||
)
|
||||
async def transcribe(
|
||||
audio: UploadFile,
|
||||
request: Request,
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> Union[AsrTranscribeResponse, JSONResponse]:
|
||||
del current_user
|
||||
temp_path: str | None = None
|
||||
try:
|
||||
audio_data = await audio.read()
|
||||
if not audio_data:
|
||||
raise ValueError("Empty audio file")
|
||||
if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
|
||||
raise ValueError("Unsupported audio format")
|
||||
|
||||
transcript = await asr_service.transcribe(
|
||||
audio_data, audio.filename or "unknown"
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length is not None:
|
||||
try:
|
||||
declared_length = int(content_length)
|
||||
except ValueError:
|
||||
declared_length = None
|
||||
if (
|
||||
declared_length is not None
|
||||
and declared_length
|
||||
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
|
||||
):
|
||||
raise ValueError("Audio file too large")
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||
temp_path = tmp_file.name
|
||||
|
||||
total_bytes = 0
|
||||
header = bytearray()
|
||||
while True:
|
||||
chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
|
||||
if not chunk:
|
||||
break
|
||||
total_bytes += len(chunk)
|
||||
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
|
||||
raise ValueError("Audio file too large")
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
required = _WAV_HEADER_MIN_BYTES - len(header)
|
||||
header.extend(chunk[:required])
|
||||
tmp_file.write(chunk)
|
||||
|
||||
if total_bytes == 0:
|
||||
raise ValueError("Empty audio file")
|
||||
if not _looks_like_wav_header(bytes(header)):
|
||||
raise ValueError("Unsupported audio format")
|
||||
|
||||
transcript = await asr_service.transcribe_file(
|
||||
temp_path, audio.filename or "unknown"
|
||||
)
|
||||
|
||||
return AsrTranscribeResponse(transcript=transcript)
|
||||
@@ -238,8 +292,12 @@ async def transcribe(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
except RuntimeError:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": str(exc)},
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
content={"detail": "ASR service unavailable"},
|
||||
)
|
||||
finally:
|
||||
await audio.close()
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.unlink(temp_path)
|
||||
|
||||
Reference in New Issue
Block a user