diff --git a/backend/src/v1/agent/router.py b/backend/src/v1/agent/router.py index 0488948..3ae5888 100644 --- a/backend/src/v1/agent/router.py +++ b/backend/src/v1/agent/router.py @@ -3,7 +3,9 @@ from __future__ import annotations from collections.abc import AsyncIterator import asyncio from datetime import date +import os import re +import tempfile import time from typing import Annotated, Union @@ -29,6 +31,21 @@ _LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$") _RUNS_PER_MINUTE = 30 _MAX_SSE_CONNECTIONS_PER_USER = 3 _SSE_SLOT_TTL_SECONDS = 15 * 60 +_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024 +_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024 +_MULTIPART_OVERHEAD_BYTES = 64 * 1024 +_WAV_HEADER_MIN_BYTES = 12 +_ALLOWED_AUDIO_CONTENT_TYPES = { + "audio/wav", + "audio/x-wav", + "audio/wave", +} + + +def _looks_like_wav_header(header: bytes) -> bool: + if len(header) < _WAV_HEADER_MIN_BYTES: + return False + return header[0:4] == b"RIFF" and header[8:12] == b"WAVE" async def _allow_run_request(*, user_id: str) -> bool: @@ -220,15 +237,52 @@ async def get_user_history_snapshot( ) async def transcribe( audio: UploadFile, + request: Request, current_user: Annotated[CurrentUser, Depends(get_current_user)], ) -> Union[AsrTranscribeResponse, JSONResponse]: + del current_user + temp_path: str | None = None try: - audio_data = await audio.read() - if not audio_data: - raise ValueError("Empty audio file") + if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES: + raise ValueError("Unsupported audio format") - transcript = await asr_service.transcribe( - audio_data, audio.filename or "unknown" + content_length = request.headers.get("content-length") + if content_length is not None: + try: + declared_length = int(content_length) + except ValueError: + declared_length = None + if ( + declared_length is not None + and declared_length + > _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES + ): + raise ValueError("Audio file too large") + + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: + temp_path = tmp_file.name + + total_bytes = 0 + header = bytearray() + while True: + chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES) + if not chunk: + break + total_bytes += len(chunk) + if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES: + raise ValueError("Audio file too large") + if len(header) < _WAV_HEADER_MIN_BYTES: + required = _WAV_HEADER_MIN_BYTES - len(header) + header.extend(chunk[:required]) + tmp_file.write(chunk) + + if total_bytes == 0: + raise ValueError("Empty audio file") + if not _looks_like_wav_header(bytes(header)): + raise ValueError("Unsupported audio format") + + transcript = await asr_service.transcribe_file( + temp_path, audio.filename or "unknown" ) return AsrTranscribeResponse(transcript=transcript) @@ -238,8 +292,12 @@ async def transcribe( status_code=status.HTTP_400_BAD_REQUEST, content={"detail": str(exc)}, ) - except RuntimeError as exc: + except RuntimeError: return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"detail": str(exc)}, + status_code=status.HTTP_502_BAD_GATEWAY, + content={"detail": "ASR service unavailable"}, ) + finally: + await audio.close() + if temp_path and os.path.exists(temp_path): + os.unlink(temp_path) diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py index c60e6ce..a3ac36b 100644 --- a/backend/src/v1/agent/service.py +++ b/backend/src/v1/agent/service.py @@ -1,8 +1,6 @@ from __future__ import annotations import asyncio -import tempfile -from contextlib import contextmanager from dataclasses import dataclass from datetime import date from typing import Any, Protocol @@ -106,16 +104,13 @@ class AgentService: else: ensure_session_owner(owner_id=owner, current_user=current_user) - try: - task_id = await self._queue.enqueue( - command={ - "command": "run", - "run_input": run_input.model_dump(mode="json", by_alias=True), - }, - dedup_key=None, - ) - except Exception: # noqa: BLE001 - raise + task_id = await self._queue.enqueue( + command={ + "command": "run", + "run_input": run_input.model_dump(mode="json", by_alias=True), + }, + dedup_key=None, + ) return TaskAccepted( task_id=task_id, thread_id=thread_id, @@ -234,57 +229,46 @@ class AsrService: self._api_key = dashscope_key return self._api_key - @contextmanager - def _temp_wav(self, audio_data: bytes): - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - tmp.write(audio_data) - tmp_path = tmp.name - try: - yield tmp_path - finally: - import os - - if os.path.exists(tmp_path): - os.unlink(tmp_path) - - async def transcribe(self, audio_data: bytes, filename: str) -> str: + async def transcribe_file(self, file_path: str, filename: str) -> str: try: dashscope.api_key = self._get_api_key() - with self._temp_wav(audio_data) as tmp_path: - loop = asyncio.get_event_loop() + loop = asyncio.get_event_loop() - class SyncCallback(RecognitionCallback): - error: str | None = None + class SyncCallback(RecognitionCallback): + error: str | None = None - def on_error(self, result: Any) -> None: - self.error = str(result) + def on_error(self, result: Any) -> None: + self.error = str(result) - callback = SyncCallback() - recognizer = Recognition( - model="fun-asr-realtime-2026-02-28", - callback=callback, - format="wav", - sample_rate=16000, - ) + callback = SyncCallback() + recognizer = Recognition( + model="fun-asr-realtime-2026-02-28", + callback=callback, + format="wav", + sample_rate=16000, + ) - result: Any = await loop.run_in_executor( - None, - lambda: recognizer.call(file=tmp_path), - ) + result: Any = await loop.run_in_executor( + None, + lambda: recognizer.call(file=file_path), + ) if callback.error: raise RuntimeError(f"ASR error: {callback.error}") - if result.status_code != 200: - raise RuntimeError(f"ASR transcription failed: {result.message}") + status_code = self._extract_field(result, "status_code") + if status_code != 200: + message = self._extract_field(result, "message") + raise RuntimeError(f"ASR transcription failed: {message}") - if result.output is None or result.output.sentence is None: + sentence = self._extract_sentence_payload(result) + if sentence is None: + request_id = self._extract_field(result, "request_id") logger.warning( - "ASR returned empty result", extra={"request_id": result.request_id} + "ASR returned empty result", extra={"request_id": request_id} ) return "" - sentence = result.output.sentence if isinstance(sentence, dict): transcription = sentence.get("text", "") elif isinstance(sentence, list): @@ -300,9 +284,40 @@ class AsrService: ) return transcription + except asyncio.CancelledError: + raise + except RuntimeError: + raise except Exception as exc: logger.exception("ASR transcription error") raise RuntimeError(f"ASR transcription failed: {exc}") from exc + def _extract_sentence_payload(self, result: Any) -> Any | None: + if isinstance(result, dict): + output = result.get("output") + if isinstance(output, dict): + return output.get("sentence") + if output is not None: + return getattr(output, "sentence", None) + return result.get("sentence") + + get_sentence = getattr(result, "get_sentence", None) + if callable(get_sentence): + sentence = get_sentence() + if sentence is not None: + return sentence + + output = getattr(result, "output", None) + if output is None: + return None + if isinstance(output, dict): + return output.get("sentence") + return getattr(output, "sentence", None) + + def _extract_field(self, result: Any, field: str) -> Any | None: + if isinstance(result, dict): + return result.get(field) + return getattr(result, field, None) + asr_service = AsrService() diff --git a/backend/src/v1/auth/router.py b/backend/src/v1/auth/router.py index ab987b0..5a04767 100644 --- a/backend/src/v1/auth/router.py +++ b/backend/src/v1/auth/router.py @@ -1,21 +1,16 @@ from __future__ import annotations -from typing import Annotated - from fastapi import APIRouter, Depends, Request, Response from fastapi import HTTPException -from core.auth.models import CurrentUser from v1.auth.rate_limit import enforce_rate_limit from v1.auth.dependencies import get_auth_service -from v1.users.dependencies import get_current_user from v1.auth.schemas import ( PasswordResetConfirmRequest, SessionCreateRequest, SessionDeleteRequest, SessionRefreshRequest, SessionResponse, - UserByEmailResponse, VerificationCreateRequest, VerificationCreateResponse, VerificationResendRequest, diff --git a/backend/src/v1/users/dependencies.py b/backend/src/v1/users/dependencies.py index ae20b1b..453ded0 100644 --- a/backend/src/v1/users/dependencies.py +++ b/backend/src/v1/users/dependencies.py @@ -9,7 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from core.auth.jwt_verifier import ( JwtVerifier, TokenValidationError, - TokenVerifierUnavailableError, ) from core.auth.models import CurrentUser from core.config.settings import config @@ -35,17 +34,19 @@ def get_auth_gateway() -> SupabaseAuthGateway: def get_jwt_verifier() -> JwtVerifier: global _jwt_verifier if _jwt_verifier is None: - jwks_url = config.supabase.jwks_url issuer = config.supabase.jwt_issuer - audience = config.supabase.jwt_audience - if not jwks_url or not issuer or not audience: + jwt_secret = ( + config.supabase.jwt_secret.get_secret_value() + if config.supabase.jwt_secret is not None + else None + ) + if not issuer or not jwt_secret: logger.error("JWT validation failed: verifier config not configured") raise HTTPException(status_code=503, detail="JWT verifier not configured") _jwt_verifier = JwtVerifier( - jwks_url=jwks_url, issuer=issuer, - audience=audience, - apikey=config.supabase.anon_key, + jwt_secret=jwt_secret, + jwt_algorithm=config.supabase.jwt_algorithm, ) return _jwt_verifier @@ -64,9 +65,6 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren payload = get_jwt_verifier().verify(token) except HTTPException: raise - except TokenVerifierUnavailableError: - logger.error("JWT validation failed: verifier unavailable") - raise HTTPException(status_code=503, detail="JWT verifier unavailable") except TokenValidationError as exc: logger.warning( "JWT validation failed", diff --git a/backend/src/v1/users/schemas.py b/backend/src/v1/users/schemas.py index 273209f..c171beb 100644 --- a/backend/src/v1/users/schemas.py +++ b/backend/src/v1/users/schemas.py @@ -15,6 +15,7 @@ from pydantic import ( class UserResponse(BaseModel): id: str username: str + email: str | None = None avatar_url: str | None = None bio: str | None = None diff --git a/backend/src/v1/users/service.py b/backend/src/v1/users/service.py index b5918b0..915be0c 100644 --- a/backend/src/v1/users/service.py +++ b/backend/src/v1/users/service.py @@ -90,9 +90,11 @@ class UserService(BaseService): if user is None: raise HTTPException(status_code=404, detail="User not found") + email = self._current_user.email if self._current_user else None return UserResponse( id=str(user.id), username=user.username, + email=email, avatar_url=user.avatar_url, bio=user.bio, ) @@ -131,9 +133,11 @@ class UserService(BaseService): error=str(exc), ) + email = self._current_user.email if self._current_user else None return UserResponse( id=str(user.id), username=user.username, + email=email, avatar_url=user.avatar_url, bio=user.bio, ) diff --git a/backend/tests/integration/v1/agent/test_routes.py b/backend/tests/integration/v1/agent/test_routes.py index 65db156..c70447f 100644 --- a/backend/tests/integration/v1/agent/test_routes.py +++ b/backend/tests/integration/v1/agent/test_routes.py @@ -354,17 +354,19 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None: id=uuid4(), email="user@example.com" ) - async def mock_transcribe(audio_data: bytes, filename: str) -> str: + async def mock_transcribe_file(file_path: str, filename: str) -> str: + assert file_path.endswith(".wav") + assert filename == "test.wav" return "这是测试转写结果" monkeypatch.setattr( - "v1.agent.service.asr_service.transcribe", - mock_transcribe, + "v1.agent.service.asr_service.transcribe_file", + mock_transcribe_file, ) client = TestClient(app) - wav_content = b"fake-wav-file-content" + wav_content = b"RIFF\x24\x80\x00\x00WAVEfmt " wav_file = BytesIO(wav_content) wav_file.name = "test.wav" @@ -380,3 +382,68 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None: assert data["transcript"] == "这是测试转写结果" finally: app.dependency_overrides = {} + + +def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None: + app.dependency_overrides[get_current_user] = lambda: CurrentUser( + id=uuid4(), email="user@example.com" + ) + + monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4) + + client = TestClient(app) + oversized = BytesIO(b"12345") + oversized.name = "test.wav" + + try: + response = client.post( + "/api/v1/agent/transcribe", + files={"audio": ("test.wav", oversized, "audio/wav")}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Audio file too large" + finally: + app.dependency_overrides = {} + + +def test_asr_transcribe_rejects_non_wav_audio() -> None: + app.dependency_overrides[get_current_user] = lambda: CurrentUser( + id=uuid4(), email="user@example.com" + ) + + client = TestClient(app) + fake_mp3 = BytesIO(b"fake-mp3") + fake_mp3.name = "test.mp3" + + try: + response = client.post( + "/api/v1/agent/transcribe", + files={"audio": ("test.mp3", fake_mp3, "audio/mpeg")}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Unsupported audio format" + finally: + app.dependency_overrides = {} + + +def test_asr_transcribe_rejects_invalid_wav_payload() -> None: + app.dependency_overrides[get_current_user] = lambda: CurrentUser( + id=uuid4(), email="user@example.com" + ) + + client = TestClient(app) + fake_payload = BytesIO(b"not-a-wav") + fake_payload.name = "test.wav" + + try: + response = client.post( + "/api/v1/agent/transcribe", + files={"audio": ("test.wav", fake_payload, "audio/wav")}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Unsupported audio format" + finally: + app.dependency_overrides = {} diff --git a/backend/tests/unit/infra/test_worker_runtime_script.py b/backend/tests/unit/infra/test_worker_runtime_script.py index c50ad6b..1c9ac8c 100644 --- a/backend/tests/unit/infra/test_worker_runtime_script.py +++ b/backend/tests/unit/infra/test_worker_runtime_script.py @@ -9,16 +9,15 @@ APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh" def test_worker_commands_use_taskiq() -> None: content = APP_SCRIPT.read_text(encoding="utf-8") - removed_runner = "uv run c" "elery" + removed_runner = "uv run celery" assert "uv run taskiq worker" in content assert "core.taskiq.app:critical_broker" in content assert "core.taskiq.app:default_broker" in content assert "core.taskiq.app:bulk_broker" in content - assert 'pgrep -f "taskiq.*worker"' in content - assert 'pkill -f "taskiq.*worker"' in content - assert 'pgrep -f "gunicorn.*app:app"' in content - assert 'pkill -f "gunicorn.*app:app"' in content + assert 'pgrep -f "uv run taskiq worker core.taskiq.app:"' in content + assert 'kill_pids_gracefully "taskiq workers"' in content + assert "gunicorn" not in content assert removed_runner not in content diff --git a/backend/tests/unit/v1/agent/test_service.py b/backend/tests/unit/v1/agent/test_service.py index 9e1b69f..4136109 100644 --- a/backend/tests/unit/v1/agent/test_service.py +++ b/backend/tests/unit/v1/agent/test_service.py @@ -1,13 +1,15 @@ from __future__ import annotations from datetime import date +from types import SimpleNamespace from uuid import UUID from ag_ui.core import RunAgentInput from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from core.auth.models import CurrentUser -from v1.agent.service import AgentService +import v1.agent.service as agent_service_module +from v1.agent.service import AgentService, AsrService class _FakeRepository: @@ -249,3 +251,77 @@ async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> Non ) assert event["type"] == "STATE_SNAPSHOT" assert event["threadId"] == "00000000-0000-0000-0000-000000000001" + + +async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None: + result = SimpleNamespace( + status_code=200, + message="ok", + output={"sentence": {"text": "你好,世界"}}, + request_id="req-test", + ) + + class _FakeRecognition: + def __init__(self, **kwargs) -> None: + del kwargs + + def call(self, *, file: str): + del file + return result + + monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition) + monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key") + service = AsrService() + + transcript = await service.transcribe_file("/tmp/test.wav", "test.wav") + + assert transcript == "你好,世界" + + +async def test_asr_service_parses_sentence_when_result_is_dict(monkeypatch) -> None: + result = { + "status_code": 200, + "message": "ok", + "output": {"sentence": {"text": "字典结果"}}, + "request_id": "req-dict", + } + + class _FakeRecognition: + def __init__(self, **kwargs) -> None: + del kwargs + + def call(self, *, file: str): + del file + return result + + monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition) + monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key") + service = AsrService() + + transcript = await service.transcribe_file("/tmp/test.wav", "test.wav") + + assert transcript == "字典结果" + + +async def test_asr_service_returns_empty_when_sentence_missing(monkeypatch) -> None: + result = { + "status_code": 200, + "message": "ok", + "output": {}, + } + + class _FakeRecognition: + def __init__(self, **kwargs) -> None: + del kwargs + + def call(self, *, file: str): + del file + return result + + monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition) + monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key") + service = AsrService() + + transcript = await service.transcribe_file("/tmp/test.wav", "test.wav") + + assert transcript == ""