chore: 后端 agent 和 users 模块代码更新优化

This commit is contained in:
qzl
2026-03-10 17:44:29 +08:00
parent 8da9377ed9
commit 2049184456
9 changed files with 294 additions and 81 deletions
+66 -8
View File
@@ -3,7 +3,9 @@ from __future__ import annotations
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
import asyncio import asyncio
from datetime import date from datetime import date
import os
import re import re
import tempfile
import time import time
from typing import Annotated, Union from typing import Annotated, Union
@@ -29,6 +31,21 @@ _LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_RUNS_PER_MINUTE = 30 _RUNS_PER_MINUTE = 30
_MAX_SSE_CONNECTIONS_PER_USER = 3 _MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60 _SSE_SLOT_TTL_SECONDS = 15 * 60
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
_WAV_HEADER_MIN_BYTES = 12
_ALLOWED_AUDIO_CONTENT_TYPES = {
"audio/wav",
"audio/x-wav",
"audio/wave",
}
def _looks_like_wav_header(header: bytes) -> bool:
    """Return True when *header* carries a RIFF/WAVE container signature.

    Only the two four-byte tags are compared: ``RIFF`` at offset 0 and
    ``WAVE`` at offset 8. The four bytes between them are not inspected.
    Inputs shorter than ``_WAV_HEADER_MIN_BYTES`` are rejected outright.
    """
    if len(header) >= _WAV_HEADER_MIN_BYTES:
        riff_tag, wave_tag = header[0:4], header[8:12]
        return riff_tag == b"RIFF" and wave_tag == b"WAVE"
    return False
async def _allow_run_request(*, user_id: str) -> bool: async def _allow_run_request(*, user_id: str) -> bool:
@@ -220,15 +237,52 @@ async def get_user_history_snapshot(
) )
async def transcribe( async def transcribe(
audio: UploadFile, audio: UploadFile,
request: Request,
current_user: Annotated[CurrentUser, Depends(get_current_user)], current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> Union[AsrTranscribeResponse, JSONResponse]: ) -> Union[AsrTranscribeResponse, JSONResponse]:
del current_user
temp_path: str | None = None
try: try:
audio_data = await audio.read() if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
if not audio_data: raise ValueError("Unsupported audio format")
raise ValueError("Empty audio file")
transcript = await asr_service.transcribe( content_length = request.headers.get("content-length")
audio_data, audio.filename or "unknown" if content_length is not None:
try:
declared_length = int(content_length)
except ValueError:
declared_length = None
if (
declared_length is not None
and declared_length
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
):
raise ValueError("Audio file too large")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
temp_path = tmp_file.name
total_bytes = 0
header = bytearray()
while True:
chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
if not chunk:
break
total_bytes += len(chunk)
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
raise ValueError("Audio file too large")
if len(header) < _WAV_HEADER_MIN_BYTES:
required = _WAV_HEADER_MIN_BYTES - len(header)
header.extend(chunk[:required])
tmp_file.write(chunk)
if total_bytes == 0:
raise ValueError("Empty audio file")
if not _looks_like_wav_header(bytes(header)):
raise ValueError("Unsupported audio format")
transcript = await asr_service.transcribe_file(
temp_path, audio.filename or "unknown"
) )
return AsrTranscribeResponse(transcript=transcript) return AsrTranscribeResponse(transcript=transcript)
@@ -238,8 +292,12 @@ async def transcribe(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(exc)}, content={"detail": str(exc)},
) )
except RuntimeError as exc: except RuntimeError:
return JSONResponse( return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_502_BAD_GATEWAY,
content={"detail": str(exc)}, content={"detail": "ASR service unavailable"},
) )
finally:
await audio.close()
if temp_path and os.path.exists(temp_path):
os.unlink(temp_path)
+63 -48
View File
@@ -1,8 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from datetime import date from datetime import date
from typing import Any, Protocol from typing import Any, Protocol
@@ -106,16 +104,13 @@ class AgentService:
else: else:
ensure_session_owner(owner_id=owner, current_user=current_user) ensure_session_owner(owner_id=owner, current_user=current_user)
try: task_id = await self._queue.enqueue(
task_id = await self._queue.enqueue( command={
command={ "command": "run",
"command": "run", "run_input": run_input.model_dump(mode="json", by_alias=True),
"run_input": run_input.model_dump(mode="json", by_alias=True), },
}, dedup_key=None,
dedup_key=None, )
)
except Exception: # noqa: BLE001
raise
return TaskAccepted( return TaskAccepted(
task_id=task_id, task_id=task_id,
thread_id=thread_id, thread_id=thread_id,
@@ -234,57 +229,46 @@ class AsrService:
self._api_key = dashscope_key self._api_key = dashscope_key
return self._api_key return self._api_key
@contextmanager async def transcribe_file(self, file_path: str, filename: str) -> str:
def _temp_wav(self, audio_data: bytes):
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(audio_data)
tmp_path = tmp.name
try:
yield tmp_path
finally:
import os
if os.path.exists(tmp_path):
os.unlink(tmp_path)
async def transcribe(self, audio_data: bytes, filename: str) -> str:
try: try:
dashscope.api_key = self._get_api_key() dashscope.api_key = self._get_api_key()
with self._temp_wav(audio_data) as tmp_path: loop = asyncio.get_event_loop()
loop = asyncio.get_event_loop()
class SyncCallback(RecognitionCallback): class SyncCallback(RecognitionCallback):
error: str | None = None error: str | None = None
def on_error(self, result: Any) -> None: def on_error(self, result: Any) -> None:
self.error = str(result) self.error = str(result)
callback = SyncCallback() callback = SyncCallback()
recognizer = Recognition( recognizer = Recognition(
model="fun-asr-realtime-2026-02-28", model="fun-asr-realtime-2026-02-28",
callback=callback, callback=callback,
format="wav", format="wav",
sample_rate=16000, sample_rate=16000,
) )
result: Any = await loop.run_in_executor( result: Any = await loop.run_in_executor(
None, None,
lambda: recognizer.call(file=tmp_path), lambda: recognizer.call(file=file_path),
) )
if callback.error: if callback.error:
raise RuntimeError(f"ASR error: {callback.error}") raise RuntimeError(f"ASR error: {callback.error}")
if result.status_code != 200: status_code = self._extract_field(result, "status_code")
raise RuntimeError(f"ASR transcription failed: {result.message}") if status_code != 200:
message = self._extract_field(result, "message")
raise RuntimeError(f"ASR transcription failed: {message}")
if result.output is None or result.output.sentence is None: sentence = self._extract_sentence_payload(result)
if sentence is None:
request_id = self._extract_field(result, "request_id")
logger.warning( logger.warning(
"ASR returned empty result", extra={"request_id": result.request_id} "ASR returned empty result", extra={"request_id": request_id}
) )
return "" return ""
sentence = result.output.sentence
if isinstance(sentence, dict): if isinstance(sentence, dict):
transcription = sentence.get("text", "") transcription = sentence.get("text", "")
elif isinstance(sentence, list): elif isinstance(sentence, list):
@@ -300,9 +284,40 @@ class AsrService:
) )
return transcription return transcription
except asyncio.CancelledError:
raise
except RuntimeError:
raise
except Exception as exc: except Exception as exc:
logger.exception("ASR transcription error") logger.exception("ASR transcription error")
raise RuntimeError(f"ASR transcription failed: {exc}") from exc raise RuntimeError(f"ASR transcription failed: {exc}") from exc
def _extract_sentence_payload(self, result: Any) -> Any | None:
if isinstance(result, dict):
output = result.get("output")
if isinstance(output, dict):
return output.get("sentence")
if output is not None:
return getattr(output, "sentence", None)
return result.get("sentence")
get_sentence = getattr(result, "get_sentence", None)
if callable(get_sentence):
sentence = get_sentence()
if sentence is not None:
return sentence
output = getattr(result, "output", None)
if output is None:
return None
if isinstance(output, dict):
return output.get("sentence")
return getattr(output, "sentence", None)
def _extract_field(self, result: Any, field: str) -> Any | None:
if isinstance(result, dict):
return result.get(field)
return getattr(result, field, None)
asr_service = AsrService() asr_service = AsrService()
-5
View File
@@ -1,21 +1,16 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, Request, Response from fastapi import APIRouter, Depends, Request, Response
from fastapi import HTTPException from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.auth.rate_limit import enforce_rate_limit from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service from v1.auth.dependencies import get_auth_service
from v1.users.dependencies import get_current_user
from v1.auth.schemas import ( from v1.auth.schemas import (
PasswordResetConfirmRequest, PasswordResetConfirmRequest,
SessionCreateRequest, SessionCreateRequest,
SessionDeleteRequest, SessionDeleteRequest,
SessionRefreshRequest, SessionRefreshRequest,
SessionResponse, SessionResponse,
UserByEmailResponse,
VerificationCreateRequest, VerificationCreateRequest,
VerificationCreateResponse, VerificationCreateResponse,
VerificationResendRequest, VerificationResendRequest,
+8 -10
View File
@@ -9,7 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.jwt_verifier import ( from core.auth.jwt_verifier import (
JwtVerifier, JwtVerifier,
TokenValidationError, TokenValidationError,
TokenVerifierUnavailableError,
) )
from core.auth.models import CurrentUser from core.auth.models import CurrentUser
from core.config.settings import config from core.config.settings import config
@@ -35,17 +34,19 @@ def get_auth_gateway() -> SupabaseAuthGateway:
def get_jwt_verifier() -> JwtVerifier: def get_jwt_verifier() -> JwtVerifier:
global _jwt_verifier global _jwt_verifier
if _jwt_verifier is None: if _jwt_verifier is None:
jwks_url = config.supabase.jwks_url
issuer = config.supabase.jwt_issuer issuer = config.supabase.jwt_issuer
audience = config.supabase.jwt_audience jwt_secret = (
if not jwks_url or not issuer or not audience: config.supabase.jwt_secret.get_secret_value()
if config.supabase.jwt_secret is not None
else None
)
if not issuer or not jwt_secret:
logger.error("JWT validation failed: verifier config not configured") logger.error("JWT validation failed: verifier config not configured")
raise HTTPException(status_code=503, detail="JWT verifier not configured") raise HTTPException(status_code=503, detail="JWT verifier not configured")
_jwt_verifier = JwtVerifier( _jwt_verifier = JwtVerifier(
jwks_url=jwks_url,
issuer=issuer, issuer=issuer,
audience=audience, jwt_secret=jwt_secret,
apikey=config.supabase.anon_key, jwt_algorithm=config.supabase.jwt_algorithm,
) )
return _jwt_verifier return _jwt_verifier
@@ -64,9 +65,6 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
payload = get_jwt_verifier().verify(token) payload = get_jwt_verifier().verify(token)
except HTTPException: except HTTPException:
raise raise
except TokenVerifierUnavailableError:
logger.error("JWT validation failed: verifier unavailable")
raise HTTPException(status_code=503, detail="JWT verifier unavailable")
except TokenValidationError as exc: except TokenValidationError as exc:
logger.warning( logger.warning(
"JWT validation failed", "JWT validation failed",
+1
View File
@@ -15,6 +15,7 @@ from pydantic import (
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: str
username: str username: str
email: str | None = None
avatar_url: str | None = None avatar_url: str | None = None
bio: str | None = None bio: str | None = None
+4
View File
@@ -90,9 +90,11 @@ class UserService(BaseService):
if user is None: if user is None:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
email = self._current_user.email if self._current_user else None
return UserResponse( return UserResponse(
id=str(user.id), id=str(user.id),
username=user.username, username=user.username,
email=email,
avatar_url=user.avatar_url, avatar_url=user.avatar_url,
bio=user.bio, bio=user.bio,
) )
@@ -131,9 +133,11 @@ class UserService(BaseService):
error=str(exc), error=str(exc),
) )
email = self._current_user.email if self._current_user else None
return UserResponse( return UserResponse(
id=str(user.id), id=str(user.id),
username=user.username, username=user.username,
email=email,
avatar_url=user.avatar_url, avatar_url=user.avatar_url,
bio=user.bio, bio=user.bio,
) )
@@ -354,17 +354,19 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
id=uuid4(), email="user@example.com" id=uuid4(), email="user@example.com"
) )
async def mock_transcribe(audio_data: bytes, filename: str) -> str: async def mock_transcribe_file(file_path: str, filename: str) -> str:
assert file_path.endswith(".wav")
assert filename == "test.wav"
return "这是测试转写结果" return "这是测试转写结果"
monkeypatch.setattr( monkeypatch.setattr(
"v1.agent.service.asr_service.transcribe", "v1.agent.service.asr_service.transcribe_file",
mock_transcribe, mock_transcribe_file,
) )
client = TestClient(app) client = TestClient(app)
wav_content = b"fake-wav-file-content" wav_content = b"RIFF\x24\x80\x00\x00WAVEfmt "
wav_file = BytesIO(wav_content) wav_file = BytesIO(wav_content)
wav_file.name = "test.wav" wav_file.name = "test.wav"
@@ -380,3 +382,68 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
assert data["transcript"] == "这是测试转写结果" assert data["transcript"] == "这是测试转写结果"
finally: finally:
app.dependency_overrides = {} app.dependency_overrides = {}
def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
    """An upload larger than the (patched) byte cap is answered with HTTP 400."""
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    # Shrink the cap to 4 bytes so the 5-byte body below exceeds it.
    monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4)

    http = TestClient(app)
    payload = BytesIO(b"12345")
    payload.name = "test.wav"
    upload = {"audio": ("test.wav", payload, "audio/wav")}

    try:
        reply = http.post("/api/v1/agent/transcribe", files=upload)

        assert reply.status_code == 400
        assert reply.json()["detail"] == "Audio file too large"
    finally:
        # Drop the auth override so later tests see the real dependency.
        app.dependency_overrides = {}
def test_asr_transcribe_rejects_non_wav_audio() -> None:
    """An upload declared as ``audio/mpeg`` is answered with HTTP 400."""
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )

    http = TestClient(app)
    payload = BytesIO(b"fake-mp3")
    payload.name = "test.mp3"
    upload = {"audio": ("test.mp3", payload, "audio/mpeg")}

    try:
        reply = http.post("/api/v1/agent/transcribe", files=upload)

        assert reply.status_code == 400
        assert reply.json()["detail"] == "Unsupported audio format"
    finally:
        # Drop the auth override so later tests see the real dependency.
        app.dependency_overrides = {}
def test_asr_transcribe_rejects_invalid_wav_payload() -> None:
    """A body declared as ``audio/wav`` whose bytes are not WAV gets HTTP 400."""
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )

    http = TestClient(app)
    # The content type is acceptable; only the payload bytes are wrong.
    payload = BytesIO(b"not-a-wav")
    payload.name = "test.wav"
    upload = {"audio": ("test.wav", payload, "audio/wav")}

    try:
        reply = http.post("/api/v1/agent/transcribe", files=upload)

        assert reply.status_code == 400
        assert reply.json()["detail"] == "Unsupported audio format"
    finally:
        # Drop the auth override so later tests see the real dependency.
        app.dependency_overrides = {}
@@ -9,16 +9,15 @@ APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
def test_worker_commands_use_taskiq() -> None: def test_worker_commands_use_taskiq() -> None:
content = APP_SCRIPT.read_text(encoding="utf-8") content = APP_SCRIPT.read_text(encoding="utf-8")
removed_runner = "uv run c" "elery" removed_runner = "uv run celery"
assert "uv run taskiq worker" in content assert "uv run taskiq worker" in content
assert "core.taskiq.app:critical_broker" in content assert "core.taskiq.app:critical_broker" in content
assert "core.taskiq.app:default_broker" in content assert "core.taskiq.app:default_broker" in content
assert "core.taskiq.app:bulk_broker" in content assert "core.taskiq.app:bulk_broker" in content
assert 'pgrep -f "taskiq.*worker"' in content assert 'pgrep -f "uv run taskiq worker core.taskiq.app:"' in content
assert 'pkill -f "taskiq.*worker"' in content assert 'kill_pids_gracefully "taskiq workers"' in content
assert 'pgrep -f "gunicorn.*app:app"' in content assert "gunicorn" not in content
assert 'pkill -f "gunicorn.*app:app"' in content
assert removed_runner not in content assert removed_runner not in content
+77 -1
View File
@@ -1,13 +1,15 @@
from __future__ import annotations from __future__ import annotations
from datetime import date from datetime import date
from types import SimpleNamespace
from uuid import UUID from uuid import UUID
from ag_ui.core import RunAgentInput from ag_ui.core import RunAgentInput
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser from core.auth.models import CurrentUser
from v1.agent.service import AgentService import v1.agent.service as agent_service_module
from v1.agent.service import AgentService, AsrService
class _FakeRepository: class _FakeRepository:
@@ -249,3 +251,77 @@ async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> Non
) )
assert event["type"] == "STATE_SNAPSHOT" assert event["type"] == "STATE_SNAPSHOT"
assert event["threadId"] == "00000000-0000-0000-0000-000000000001" assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
    """An object result whose ``output`` is a dict yields the sentence text."""
    canned_response = SimpleNamespace(
        status_code=200,
        message="ok",
        output={"sentence": {"text": "你好,世界"}},
        request_id="req-test",
    )

    class _StubRecognition:
        """Stand-in recognizer that ignores its inputs and returns the canned response."""

        def __init__(self, **kwargs) -> None:
            del kwargs

        def call(self, *, file: str):
            del file
            return canned_response

    monkeypatch.setattr(agent_service_module, "Recognition", _StubRecognition)
    monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")

    text = await AsrService().transcribe_file("/tmp/test.wav", "test.wav")

    assert text == "你好,世界"
async def test_asr_service_parses_sentence_when_result_is_dict(monkeypatch) -> None:
    """A plain-dict result yields the text found under ``output.sentence``."""
    canned_response = {
        "status_code": 200,
        "message": "ok",
        "output": {"sentence": {"text": "字典结果"}},
        "request_id": "req-dict",
    }

    class _StubRecognition:
        """Stand-in recognizer that ignores its inputs and returns the canned response."""

        def __init__(self, **kwargs) -> None:
            del kwargs

        def call(self, *, file: str):
            del file
            return canned_response

    monkeypatch.setattr(agent_service_module, "Recognition", _StubRecognition)
    monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")

    text = await AsrService().transcribe_file("/tmp/test.wav", "test.wav")

    assert text == "字典结果"
async def test_asr_service_returns_empty_when_sentence_missing(monkeypatch) -> None:
    """A successful result with no sentence produces an empty transcript."""
    # Status 200 with an empty "output" mapping: nothing to transcribe.
    canned_response = {
        "status_code": 200,
        "message": "ok",
        "output": {},
    }

    class _StubRecognition:
        """Stand-in recognizer that ignores its inputs and returns the canned response."""

        def __init__(self, **kwargs) -> None:
            del kwargs

        def call(self, *, file: str):
            del file
            return canned_response

    monkeypatch.setattr(agent_service_module, "Recognition", _StubRecognition)
    monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")

    text = await AsrService().transcribe_file("/tmp/test.wav", "test.wav")

    assert text == ""