chore: 后端 agent 和 users 模块代码更新优化

This commit is contained in:
qzl
2026-03-10 17:44:29 +08:00
parent 8da9377ed9
commit 2049184456
9 changed files with 294 additions and 81 deletions
+66 -8
View File
@@ -3,7 +3,9 @@ from __future__ import annotations
from collections.abc import AsyncIterator
import asyncio
from datetime import date
import os
import re
import tempfile
import time
from typing import Annotated, Union
@@ -29,6 +31,21 @@ _LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_RUNS_PER_MINUTE = 30
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
_WAV_HEADER_MIN_BYTES = 12
_ALLOWED_AUDIO_CONTENT_TYPES = {
"audio/wav",
"audio/x-wav",
"audio/wave",
}
def _looks_like_wav_header(header: bytes) -> bool:
if len(header) < _WAV_HEADER_MIN_BYTES:
return False
return header[0:4] == b"RIFF" and header[8:12] == b"WAVE"
async def _allow_run_request(*, user_id: str) -> bool:
@@ -220,15 +237,52 @@ async def get_user_history_snapshot(
)
async def transcribe(
audio: UploadFile,
request: Request,
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> Union[AsrTranscribeResponse, JSONResponse]:
del current_user
temp_path: str | None = None
try:
audio_data = await audio.read()
if not audio_data:
raise ValueError("Empty audio file")
if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
raise ValueError("Unsupported audio format")
transcript = await asr_service.transcribe(
audio_data, audio.filename or "unknown"
content_length = request.headers.get("content-length")
if content_length is not None:
try:
declared_length = int(content_length)
except ValueError:
declared_length = None
if (
declared_length is not None
and declared_length
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
):
raise ValueError("Audio file too large")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
temp_path = tmp_file.name
total_bytes = 0
header = bytearray()
while True:
chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
if not chunk:
break
total_bytes += len(chunk)
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
raise ValueError("Audio file too large")
if len(header) < _WAV_HEADER_MIN_BYTES:
required = _WAV_HEADER_MIN_BYTES - len(header)
header.extend(chunk[:required])
tmp_file.write(chunk)
if total_bytes == 0:
raise ValueError("Empty audio file")
if not _looks_like_wav_header(bytes(header)):
raise ValueError("Unsupported audio format")
transcript = await asr_service.transcribe_file(
temp_path, audio.filename or "unknown"
)
return AsrTranscribeResponse(transcript=transcript)
@@ -238,8 +292,12 @@ async def transcribe(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(exc)},
)
except RuntimeError as exc:
except RuntimeError:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": str(exc)},
status_code=status.HTTP_502_BAD_GATEWAY,
content={"detail": "ASR service unavailable"},
)
finally:
await audio.close()
if temp_path and os.path.exists(temp_path):
os.unlink(temp_path)
+63 -48
View File
@@ -1,8 +1,6 @@
from __future__ import annotations
import asyncio
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date
from typing import Any, Protocol
@@ -106,16 +104,13 @@ class AgentService:
else:
ensure_session_owner(owner_id=owner, current_user=current_user)
try:
task_id = await self._queue.enqueue(
command={
"command": "run",
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
)
except Exception: # noqa: BLE001
raise
task_id = await self._queue.enqueue(
command={
"command": "run",
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
)
return TaskAccepted(
task_id=task_id,
thread_id=thread_id,
@@ -234,57 +229,46 @@ class AsrService:
self._api_key = dashscope_key
return self._api_key
@contextmanager
def _temp_wav(self, audio_data: bytes):
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(audio_data)
tmp_path = tmp.name
try:
yield tmp_path
finally:
import os
if os.path.exists(tmp_path):
os.unlink(tmp_path)
async def transcribe(self, audio_data: bytes, filename: str) -> str:
async def transcribe_file(self, file_path: str, filename: str) -> str:
try:
dashscope.api_key = self._get_api_key()
with self._temp_wav(audio_data) as tmp_path:
loop = asyncio.get_event_loop()
loop = asyncio.get_event_loop()
class SyncCallback(RecognitionCallback):
error: str | None = None
class SyncCallback(RecognitionCallback):
error: str | None = None
def on_error(self, result: Any) -> None:
self.error = str(result)
def on_error(self, result: Any) -> None:
self.error = str(result)
callback = SyncCallback()
recognizer = Recognition(
model="fun-asr-realtime-2026-02-28",
callback=callback,
format="wav",
sample_rate=16000,
)
callback = SyncCallback()
recognizer = Recognition(
model="fun-asr-realtime-2026-02-28",
callback=callback,
format="wav",
sample_rate=16000,
)
result: Any = await loop.run_in_executor(
None,
lambda: recognizer.call(file=tmp_path),
)
result: Any = await loop.run_in_executor(
None,
lambda: recognizer.call(file=file_path),
)
if callback.error:
raise RuntimeError(f"ASR error: {callback.error}")
if result.status_code != 200:
raise RuntimeError(f"ASR transcription failed: {result.message}")
status_code = self._extract_field(result, "status_code")
if status_code != 200:
message = self._extract_field(result, "message")
raise RuntimeError(f"ASR transcription failed: {message}")
if result.output is None or result.output.sentence is None:
sentence = self._extract_sentence_payload(result)
if sentence is None:
request_id = self._extract_field(result, "request_id")
logger.warning(
"ASR returned empty result", extra={"request_id": result.request_id}
"ASR returned empty result", extra={"request_id": request_id}
)
return ""
sentence = result.output.sentence
if isinstance(sentence, dict):
transcription = sentence.get("text", "")
elif isinstance(sentence, list):
@@ -300,9 +284,40 @@ class AsrService:
)
return transcription
except asyncio.CancelledError:
raise
except RuntimeError:
raise
except Exception as exc:
logger.exception("ASR transcription error")
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
def _extract_sentence_payload(self, result: Any) -> Any | None:
if isinstance(result, dict):
output = result.get("output")
if isinstance(output, dict):
return output.get("sentence")
if output is not None:
return getattr(output, "sentence", None)
return result.get("sentence")
get_sentence = getattr(result, "get_sentence", None)
if callable(get_sentence):
sentence = get_sentence()
if sentence is not None:
return sentence
output = getattr(result, "output", None)
if output is None:
return None
if isinstance(output, dict):
return output.get("sentence")
return getattr(output, "sentence", None)
def _extract_field(self, result: Any, field: str) -> Any | None:
if isinstance(result, dict):
return result.get(field)
return getattr(result, field, None)
asr_service = AsrService()
-5
View File
@@ -1,21 +1,16 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, Request, Response
from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service
from v1.users.dependencies import get_current_user
from v1.auth.schemas import (
PasswordResetConfirmRequest,
SessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
SessionResponse,
UserByEmailResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
+8 -10
View File
@@ -9,7 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.jwt_verifier import (
JwtVerifier,
TokenValidationError,
TokenVerifierUnavailableError,
)
from core.auth.models import CurrentUser
from core.config.settings import config
@@ -35,17 +34,19 @@ def get_auth_gateway() -> SupabaseAuthGateway:
def get_jwt_verifier() -> JwtVerifier:
global _jwt_verifier
if _jwt_verifier is None:
jwks_url = config.supabase.jwks_url
issuer = config.supabase.jwt_issuer
audience = config.supabase.jwt_audience
if not jwks_url or not issuer or not audience:
jwt_secret = (
config.supabase.jwt_secret.get_secret_value()
if config.supabase.jwt_secret is not None
else None
)
if not issuer or not jwt_secret:
logger.error("JWT validation failed: verifier config not configured")
raise HTTPException(status_code=503, detail="JWT verifier not configured")
_jwt_verifier = JwtVerifier(
jwks_url=jwks_url,
issuer=issuer,
audience=audience,
apikey=config.supabase.anon_key,
jwt_secret=jwt_secret,
jwt_algorithm=config.supabase.jwt_algorithm,
)
return _jwt_verifier
@@ -64,9 +65,6 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
payload = get_jwt_verifier().verify(token)
except HTTPException:
raise
except TokenVerifierUnavailableError:
logger.error("JWT validation failed: verifier unavailable")
raise HTTPException(status_code=503, detail="JWT verifier unavailable")
except TokenValidationError as exc:
logger.warning(
"JWT validation failed",
+1
View File
@@ -15,6 +15,7 @@ from pydantic import (
class UserResponse(BaseModel):
id: str
username: str
email: str | None = None
avatar_url: str | None = None
bio: str | None = None
+4
View File
@@ -90,9 +90,11 @@ class UserService(BaseService):
if user is None:
raise HTTPException(status_code=404, detail="User not found")
email = self._current_user.email if self._current_user else None
return UserResponse(
id=str(user.id),
username=user.username,
email=email,
avatar_url=user.avatar_url,
bio=user.bio,
)
@@ -131,9 +133,11 @@ class UserService(BaseService):
error=str(exc),
)
email = self._current_user.email if self._current_user else None
return UserResponse(
id=str(user.id),
username=user.username,
email=email,
avatar_url=user.avatar_url,
bio=user.bio,
)