chore: 后端 agent 和 users 模块代码更新优化
This commit is contained in:
@@ -3,7 +3,9 @@ from __future__ import annotations
|
||||
from collections.abc import AsyncIterator
|
||||
import asyncio
|
||||
from datetime import date
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Annotated, Union
|
||||
|
||||
@@ -29,6 +31,21 @@ _LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
||||
_RUNS_PER_MINUTE = 30
|
||||
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
||||
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||
_WAV_HEADER_MIN_BYTES = 12
|
||||
_ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
"audio/wav",
|
||||
"audio/x-wav",
|
||||
"audio/wave",
|
||||
}
|
||||
|
||||
|
||||
def _looks_like_wav_header(header: bytes) -> bool:
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
return False
|
||||
return header[0:4] == b"RIFF" and header[8:12] == b"WAVE"
|
||||
|
||||
|
||||
async def _allow_run_request(*, user_id: str) -> bool:
|
||||
@@ -220,15 +237,52 @@ async def get_user_history_snapshot(
|
||||
)
|
||||
async def transcribe(
|
||||
audio: UploadFile,
|
||||
request: Request,
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> Union[AsrTranscribeResponse, JSONResponse]:
|
||||
del current_user
|
||||
temp_path: str | None = None
|
||||
try:
|
||||
audio_data = await audio.read()
|
||||
if not audio_data:
|
||||
raise ValueError("Empty audio file")
|
||||
if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
|
||||
raise ValueError("Unsupported audio format")
|
||||
|
||||
transcript = await asr_service.transcribe(
|
||||
audio_data, audio.filename or "unknown"
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length is not None:
|
||||
try:
|
||||
declared_length = int(content_length)
|
||||
except ValueError:
|
||||
declared_length = None
|
||||
if (
|
||||
declared_length is not None
|
||||
and declared_length
|
||||
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
|
||||
):
|
||||
raise ValueError("Audio file too large")
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||
temp_path = tmp_file.name
|
||||
|
||||
total_bytes = 0
|
||||
header = bytearray()
|
||||
while True:
|
||||
chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
|
||||
if not chunk:
|
||||
break
|
||||
total_bytes += len(chunk)
|
||||
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
|
||||
raise ValueError("Audio file too large")
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
required = _WAV_HEADER_MIN_BYTES - len(header)
|
||||
header.extend(chunk[:required])
|
||||
tmp_file.write(chunk)
|
||||
|
||||
if total_bytes == 0:
|
||||
raise ValueError("Empty audio file")
|
||||
if not _looks_like_wav_header(bytes(header)):
|
||||
raise ValueError("Unsupported audio format")
|
||||
|
||||
transcript = await asr_service.transcribe_file(
|
||||
temp_path, audio.filename or "unknown"
|
||||
)
|
||||
|
||||
return AsrTranscribeResponse(transcript=transcript)
|
||||
@@ -238,8 +292,12 @@ async def transcribe(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
except RuntimeError:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": str(exc)},
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
content={"detail": "ASR service unavailable"},
|
||||
)
|
||||
finally:
|
||||
await audio.close()
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.unlink(temp_path)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Any, Protocol
|
||||
@@ -106,16 +104,13 @@ class AgentService:
|
||||
else:
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
try:
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
raise
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
return TaskAccepted(
|
||||
task_id=task_id,
|
||||
thread_id=thread_id,
|
||||
@@ -234,57 +229,46 @@ class AsrService:
|
||||
self._api_key = dashscope_key
|
||||
return self._api_key
|
||||
|
||||
@contextmanager
|
||||
def _temp_wav(self, audio_data: bytes):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
tmp.write(audio_data)
|
||||
tmp_path = tmp.name
|
||||
try:
|
||||
yield tmp_path
|
||||
finally:
|
||||
import os
|
||||
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
async def transcribe(self, audio_data: bytes, filename: str) -> str:
|
||||
async def transcribe_file(self, file_path: str, filename: str) -> str:
|
||||
try:
|
||||
dashscope.api_key = self._get_api_key()
|
||||
|
||||
with self._temp_wav(audio_data) as tmp_path:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
class SyncCallback(RecognitionCallback):
|
||||
error: str | None = None
|
||||
class SyncCallback(RecognitionCallback):
|
||||
error: str | None = None
|
||||
|
||||
def on_error(self, result: Any) -> None:
|
||||
self.error = str(result)
|
||||
def on_error(self, result: Any) -> None:
|
||||
self.error = str(result)
|
||||
|
||||
callback = SyncCallback()
|
||||
recognizer = Recognition(
|
||||
model="fun-asr-realtime-2026-02-28",
|
||||
callback=callback,
|
||||
format="wav",
|
||||
sample_rate=16000,
|
||||
)
|
||||
callback = SyncCallback()
|
||||
recognizer = Recognition(
|
||||
model="fun-asr-realtime-2026-02-28",
|
||||
callback=callback,
|
||||
format="wav",
|
||||
sample_rate=16000,
|
||||
)
|
||||
|
||||
result: Any = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: recognizer.call(file=tmp_path),
|
||||
)
|
||||
result: Any = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: recognizer.call(file=file_path),
|
||||
)
|
||||
|
||||
if callback.error:
|
||||
raise RuntimeError(f"ASR error: {callback.error}")
|
||||
if result.status_code != 200:
|
||||
raise RuntimeError(f"ASR transcription failed: {result.message}")
|
||||
status_code = self._extract_field(result, "status_code")
|
||||
if status_code != 200:
|
||||
message = self._extract_field(result, "message")
|
||||
raise RuntimeError(f"ASR transcription failed: {message}")
|
||||
|
||||
if result.output is None or result.output.sentence is None:
|
||||
sentence = self._extract_sentence_payload(result)
|
||||
if sentence is None:
|
||||
request_id = self._extract_field(result, "request_id")
|
||||
logger.warning(
|
||||
"ASR returned empty result", extra={"request_id": result.request_id}
|
||||
"ASR returned empty result", extra={"request_id": request_id}
|
||||
)
|
||||
return ""
|
||||
|
||||
sentence = result.output.sentence
|
||||
if isinstance(sentence, dict):
|
||||
transcription = sentence.get("text", "")
|
||||
elif isinstance(sentence, list):
|
||||
@@ -300,9 +284,40 @@ class AsrService:
|
||||
)
|
||||
return transcription
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("ASR transcription error")
|
||||
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
|
||||
|
||||
def _extract_sentence_payload(self, result: Any) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
output = result.get("output")
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
if output is not None:
|
||||
return getattr(output, "sentence", None)
|
||||
return result.get("sentence")
|
||||
|
||||
get_sentence = getattr(result, "get_sentence", None)
|
||||
if callable(get_sentence):
|
||||
sentence = get_sentence()
|
||||
if sentence is not None:
|
||||
return sentence
|
||||
|
||||
output = getattr(result, "output", None)
|
||||
if output is None:
|
||||
return None
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
return getattr(output, "sentence", None)
|
||||
|
||||
def _extract_field(self, result: Any, field: str) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return getattr(result, field, None)
|
||||
|
||||
|
||||
asr_service = AsrService()
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
SessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByEmailResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from core.auth.jwt_verifier import (
|
||||
JwtVerifier,
|
||||
TokenValidationError,
|
||||
TokenVerifierUnavailableError,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
@@ -35,17 +34,19 @@ def get_auth_gateway() -> SupabaseAuthGateway:
|
||||
def get_jwt_verifier() -> JwtVerifier:
|
||||
global _jwt_verifier
|
||||
if _jwt_verifier is None:
|
||||
jwks_url = config.supabase.jwks_url
|
||||
issuer = config.supabase.jwt_issuer
|
||||
audience = config.supabase.jwt_audience
|
||||
if not jwks_url or not issuer or not audience:
|
||||
jwt_secret = (
|
||||
config.supabase.jwt_secret.get_secret_value()
|
||||
if config.supabase.jwt_secret is not None
|
||||
else None
|
||||
)
|
||||
if not issuer or not jwt_secret:
|
||||
logger.error("JWT validation failed: verifier config not configured")
|
||||
raise HTTPException(status_code=503, detail="JWT verifier not configured")
|
||||
_jwt_verifier = JwtVerifier(
|
||||
jwks_url=jwks_url,
|
||||
issuer=issuer,
|
||||
audience=audience,
|
||||
apikey=config.supabase.anon_key,
|
||||
jwt_secret=jwt_secret,
|
||||
jwt_algorithm=config.supabase.jwt_algorithm,
|
||||
)
|
||||
return _jwt_verifier
|
||||
|
||||
@@ -64,9 +65,6 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
|
||||
payload = get_jwt_verifier().verify(token)
|
||||
except HTTPException:
|
||||
raise
|
||||
except TokenVerifierUnavailableError:
|
||||
logger.error("JWT validation failed: verifier unavailable")
|
||||
raise HTTPException(status_code=503, detail="JWT verifier unavailable")
|
||||
except TokenValidationError as exc:
|
||||
logger.warning(
|
||||
"JWT validation failed",
|
||||
|
||||
@@ -15,6 +15,7 @@ from pydantic import (
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
email: str | None = None
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
|
||||
|
||||
@@ -90,9 +90,11 @@ class UserService(BaseService):
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
email = self._current_user.email if self._current_user else None
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
email=email,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
)
|
||||
@@ -131,9 +133,11 @@ class UserService(BaseService):
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
email = self._current_user.email if self._current_user else None
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
email=email,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user