chore: 后端 agent 和 users 模块代码更新优化
This commit is contained in:
@@ -3,7 +3,9 @@ from __future__ import annotations
|
|||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from typing import Annotated, Union
|
from typing import Annotated, Union
|
||||||
|
|
||||||
@@ -29,6 +31,21 @@ _LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
|||||||
_RUNS_PER_MINUTE = 30
|
_RUNS_PER_MINUTE = 30
|
||||||
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
||||||
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||||
|
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||||
|
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||||
|
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||||
|
_WAV_HEADER_MIN_BYTES = 12
|
||||||
|
_ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||||
|
"audio/wav",
|
||||||
|
"audio/x-wav",
|
||||||
|
"audio/wave",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_like_wav_header(header: bytes) -> bool:
|
||||||
|
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||||
|
return False
|
||||||
|
return header[0:4] == b"RIFF" and header[8:12] == b"WAVE"
|
||||||
|
|
||||||
|
|
||||||
async def _allow_run_request(*, user_id: str) -> bool:
|
async def _allow_run_request(*, user_id: str) -> bool:
|
||||||
@@ -220,15 +237,52 @@ async def get_user_history_snapshot(
|
|||||||
)
|
)
|
||||||
async def transcribe(
|
async def transcribe(
|
||||||
audio: UploadFile,
|
audio: UploadFile,
|
||||||
|
request: Request,
|
||||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||||
) -> Union[AsrTranscribeResponse, JSONResponse]:
|
) -> Union[AsrTranscribeResponse, JSONResponse]:
|
||||||
|
del current_user
|
||||||
|
temp_path: str | None = None
|
||||||
try:
|
try:
|
||||||
audio_data = await audio.read()
|
if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
|
||||||
if not audio_data:
|
raise ValueError("Unsupported audio format")
|
||||||
raise ValueError("Empty audio file")
|
|
||||||
|
|
||||||
transcript = await asr_service.transcribe(
|
content_length = request.headers.get("content-length")
|
||||||
audio_data, audio.filename or "unknown"
|
if content_length is not None:
|
||||||
|
try:
|
||||||
|
declared_length = int(content_length)
|
||||||
|
except ValueError:
|
||||||
|
declared_length = None
|
||||||
|
if (
|
||||||
|
declared_length is not None
|
||||||
|
and declared_length
|
||||||
|
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
|
||||||
|
):
|
||||||
|
raise ValueError("Audio file too large")
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||||
|
temp_path = tmp_file.name
|
||||||
|
|
||||||
|
total_bytes = 0
|
||||||
|
header = bytearray()
|
||||||
|
while True:
|
||||||
|
chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
total_bytes += len(chunk)
|
||||||
|
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
|
||||||
|
raise ValueError("Audio file too large")
|
||||||
|
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||||
|
required = _WAV_HEADER_MIN_BYTES - len(header)
|
||||||
|
header.extend(chunk[:required])
|
||||||
|
tmp_file.write(chunk)
|
||||||
|
|
||||||
|
if total_bytes == 0:
|
||||||
|
raise ValueError("Empty audio file")
|
||||||
|
if not _looks_like_wav_header(bytes(header)):
|
||||||
|
raise ValueError("Unsupported audio format")
|
||||||
|
|
||||||
|
transcript = await asr_service.transcribe_file(
|
||||||
|
temp_path, audio.filename or "unknown"
|
||||||
)
|
)
|
||||||
|
|
||||||
return AsrTranscribeResponse(transcript=transcript)
|
return AsrTranscribeResponse(transcript=transcript)
|
||||||
@@ -238,8 +292,12 @@ async def transcribe(
|
|||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
content={"detail": str(exc)},
|
content={"detail": str(exc)},
|
||||||
)
|
)
|
||||||
except RuntimeError as exc:
|
except RuntimeError:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
content={"detail": str(exc)},
|
content={"detail": "ASR service unavailable"},
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
await audio.close()
|
||||||
|
if temp_path and os.path.exists(temp_path):
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import tempfile
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
@@ -106,16 +104,13 @@ class AgentService:
|
|||||||
else:
|
else:
|
||||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||||
|
|
||||||
try:
|
task_id = await self._queue.enqueue(
|
||||||
task_id = await self._queue.enqueue(
|
command={
|
||||||
command={
|
"command": "run",
|
||||||
"command": "run",
|
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
},
|
||||||
},
|
dedup_key=None,
|
||||||
dedup_key=None,
|
)
|
||||||
)
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
raise
|
|
||||||
return TaskAccepted(
|
return TaskAccepted(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -234,57 +229,46 @@ class AsrService:
|
|||||||
self._api_key = dashscope_key
|
self._api_key = dashscope_key
|
||||||
return self._api_key
|
return self._api_key
|
||||||
|
|
||||||
@contextmanager
|
async def transcribe_file(self, file_path: str, filename: str) -> str:
|
||||||
def _temp_wav(self, audio_data: bytes):
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
|
||||||
tmp.write(audio_data)
|
|
||||||
tmp_path = tmp.name
|
|
||||||
try:
|
|
||||||
yield tmp_path
|
|
||||||
finally:
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.path.exists(tmp_path):
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
|
|
||||||
async def transcribe(self, audio_data: bytes, filename: str) -> str:
|
|
||||||
try:
|
try:
|
||||||
dashscope.api_key = self._get_api_key()
|
dashscope.api_key = self._get_api_key()
|
||||||
|
|
||||||
with self._temp_wav(audio_data) as tmp_path:
|
loop = asyncio.get_event_loop()
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
class SyncCallback(RecognitionCallback):
|
class SyncCallback(RecognitionCallback):
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|
||||||
def on_error(self, result: Any) -> None:
|
def on_error(self, result: Any) -> None:
|
||||||
self.error = str(result)
|
self.error = str(result)
|
||||||
|
|
||||||
callback = SyncCallback()
|
callback = SyncCallback()
|
||||||
recognizer = Recognition(
|
recognizer = Recognition(
|
||||||
model="fun-asr-realtime-2026-02-28",
|
model="fun-asr-realtime-2026-02-28",
|
||||||
callback=callback,
|
callback=callback,
|
||||||
format="wav",
|
format="wav",
|
||||||
sample_rate=16000,
|
sample_rate=16000,
|
||||||
)
|
)
|
||||||
|
|
||||||
result: Any = await loop.run_in_executor(
|
result: Any = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: recognizer.call(file=tmp_path),
|
lambda: recognizer.call(file=file_path),
|
||||||
)
|
)
|
||||||
|
|
||||||
if callback.error:
|
if callback.error:
|
||||||
raise RuntimeError(f"ASR error: {callback.error}")
|
raise RuntimeError(f"ASR error: {callback.error}")
|
||||||
if result.status_code != 200:
|
status_code = self._extract_field(result, "status_code")
|
||||||
raise RuntimeError(f"ASR transcription failed: {result.message}")
|
if status_code != 200:
|
||||||
|
message = self._extract_field(result, "message")
|
||||||
|
raise RuntimeError(f"ASR transcription failed: {message}")
|
||||||
|
|
||||||
if result.output is None or result.output.sentence is None:
|
sentence = self._extract_sentence_payload(result)
|
||||||
|
if sentence is None:
|
||||||
|
request_id = self._extract_field(result, "request_id")
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ASR returned empty result", extra={"request_id": result.request_id}
|
"ASR returned empty result", extra={"request_id": request_id}
|
||||||
)
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
sentence = result.output.sentence
|
|
||||||
if isinstance(sentence, dict):
|
if isinstance(sentence, dict):
|
||||||
transcription = sentence.get("text", "")
|
transcription = sentence.get("text", "")
|
||||||
elif isinstance(sentence, list):
|
elif isinstance(sentence, list):
|
||||||
@@ -300,9 +284,40 @@ class AsrService:
|
|||||||
)
|
)
|
||||||
return transcription
|
return transcription
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except RuntimeError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("ASR transcription error")
|
logger.exception("ASR transcription error")
|
||||||
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
|
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
|
||||||
|
|
||||||
|
def _extract_sentence_payload(self, result: Any) -> Any | None:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
output = result.get("output")
|
||||||
|
if isinstance(output, dict):
|
||||||
|
return output.get("sentence")
|
||||||
|
if output is not None:
|
||||||
|
return getattr(output, "sentence", None)
|
||||||
|
return result.get("sentence")
|
||||||
|
|
||||||
|
get_sentence = getattr(result, "get_sentence", None)
|
||||||
|
if callable(get_sentence):
|
||||||
|
sentence = get_sentence()
|
||||||
|
if sentence is not None:
|
||||||
|
return sentence
|
||||||
|
|
||||||
|
output = getattr(result, "output", None)
|
||||||
|
if output is None:
|
||||||
|
return None
|
||||||
|
if isinstance(output, dict):
|
||||||
|
return output.get("sentence")
|
||||||
|
return getattr(output, "sentence", None)
|
||||||
|
|
||||||
|
def _extract_field(self, result: Any, field: str) -> Any | None:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result.get(field)
|
||||||
|
return getattr(result, field, None)
|
||||||
|
|
||||||
|
|
||||||
asr_service = AsrService()
|
asr_service = AsrService()
|
||||||
|
|||||||
@@ -1,21 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request, Response
|
from fastapi import APIRouter, Depends, Request, Response
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from core.auth.models import CurrentUser
|
|
||||||
from v1.auth.rate_limit import enforce_rate_limit
|
from v1.auth.rate_limit import enforce_rate_limit
|
||||||
from v1.auth.dependencies import get_auth_service
|
from v1.auth.dependencies import get_auth_service
|
||||||
from v1.users.dependencies import get_current_user
|
|
||||||
from v1.auth.schemas import (
|
from v1.auth.schemas import (
|
||||||
PasswordResetConfirmRequest,
|
PasswordResetConfirmRequest,
|
||||||
SessionCreateRequest,
|
SessionCreateRequest,
|
||||||
SessionDeleteRequest,
|
SessionDeleteRequest,
|
||||||
SessionRefreshRequest,
|
SessionRefreshRequest,
|
||||||
SessionResponse,
|
SessionResponse,
|
||||||
UserByEmailResponse,
|
|
||||||
VerificationCreateRequest,
|
VerificationCreateRequest,
|
||||||
VerificationCreateResponse,
|
VerificationCreateResponse,
|
||||||
VerificationResendRequest,
|
VerificationResendRequest,
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from core.auth.jwt_verifier import (
|
from core.auth.jwt_verifier import (
|
||||||
JwtVerifier,
|
JwtVerifier,
|
||||||
TokenValidationError,
|
TokenValidationError,
|
||||||
TokenVerifierUnavailableError,
|
|
||||||
)
|
)
|
||||||
from core.auth.models import CurrentUser
|
from core.auth.models import CurrentUser
|
||||||
from core.config.settings import config
|
from core.config.settings import config
|
||||||
@@ -35,17 +34,19 @@ def get_auth_gateway() -> SupabaseAuthGateway:
|
|||||||
def get_jwt_verifier() -> JwtVerifier:
|
def get_jwt_verifier() -> JwtVerifier:
|
||||||
global _jwt_verifier
|
global _jwt_verifier
|
||||||
if _jwt_verifier is None:
|
if _jwt_verifier is None:
|
||||||
jwks_url = config.supabase.jwks_url
|
|
||||||
issuer = config.supabase.jwt_issuer
|
issuer = config.supabase.jwt_issuer
|
||||||
audience = config.supabase.jwt_audience
|
jwt_secret = (
|
||||||
if not jwks_url or not issuer or not audience:
|
config.supabase.jwt_secret.get_secret_value()
|
||||||
|
if config.supabase.jwt_secret is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if not issuer or not jwt_secret:
|
||||||
logger.error("JWT validation failed: verifier config not configured")
|
logger.error("JWT validation failed: verifier config not configured")
|
||||||
raise HTTPException(status_code=503, detail="JWT verifier not configured")
|
raise HTTPException(status_code=503, detail="JWT verifier not configured")
|
||||||
_jwt_verifier = JwtVerifier(
|
_jwt_verifier = JwtVerifier(
|
||||||
jwks_url=jwks_url,
|
|
||||||
issuer=issuer,
|
issuer=issuer,
|
||||||
audience=audience,
|
jwt_secret=jwt_secret,
|
||||||
apikey=config.supabase.anon_key,
|
jwt_algorithm=config.supabase.jwt_algorithm,
|
||||||
)
|
)
|
||||||
return _jwt_verifier
|
return _jwt_verifier
|
||||||
|
|
||||||
@@ -64,9 +65,6 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
|
|||||||
payload = get_jwt_verifier().verify(token)
|
payload = get_jwt_verifier().verify(token)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except TokenVerifierUnavailableError:
|
|
||||||
logger.error("JWT validation failed: verifier unavailable")
|
|
||||||
raise HTTPException(status_code=503, detail="JWT verifier unavailable")
|
|
||||||
except TokenValidationError as exc:
|
except TokenValidationError as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"JWT validation failed",
|
"JWT validation failed",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from pydantic import (
|
|||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
username: str
|
username: str
|
||||||
|
email: str | None = None
|
||||||
avatar_url: str | None = None
|
avatar_url: str | None = None
|
||||||
bio: str | None = None
|
bio: str | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -90,9 +90,11 @@ class UserService(BaseService):
|
|||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
email = self._current_user.email if self._current_user else None
|
||||||
return UserResponse(
|
return UserResponse(
|
||||||
id=str(user.id),
|
id=str(user.id),
|
||||||
username=user.username,
|
username=user.username,
|
||||||
|
email=email,
|
||||||
avatar_url=user.avatar_url,
|
avatar_url=user.avatar_url,
|
||||||
bio=user.bio,
|
bio=user.bio,
|
||||||
)
|
)
|
||||||
@@ -131,9 +133,11 @@ class UserService(BaseService):
|
|||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
email = self._current_user.email if self._current_user else None
|
||||||
return UserResponse(
|
return UserResponse(
|
||||||
id=str(user.id),
|
id=str(user.id),
|
||||||
username=user.username,
|
username=user.username,
|
||||||
|
email=email,
|
||||||
avatar_url=user.avatar_url,
|
avatar_url=user.avatar_url,
|
||||||
bio=user.bio,
|
bio=user.bio,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -354,17 +354,19 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
|||||||
id=uuid4(), email="user@example.com"
|
id=uuid4(), email="user@example.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def mock_transcribe(audio_data: bytes, filename: str) -> str:
|
async def mock_transcribe_file(file_path: str, filename: str) -> str:
|
||||||
|
assert file_path.endswith(".wav")
|
||||||
|
assert filename == "test.wav"
|
||||||
return "这是测试转写结果"
|
return "这是测试转写结果"
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"v1.agent.service.asr_service.transcribe",
|
"v1.agent.service.asr_service.transcribe_file",
|
||||||
mock_transcribe,
|
mock_transcribe_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
wav_content = b"fake-wav-file-content"
|
wav_content = b"RIFF\x24\x80\x00\x00WAVEfmt "
|
||||||
wav_file = BytesIO(wav_content)
|
wav_file = BytesIO(wav_content)
|
||||||
wav_file.name = "test.wav"
|
wav_file.name = "test.wav"
|
||||||
|
|
||||||
@@ -380,3 +382,68 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
|||||||
assert data["transcript"] == "这是测试转写结果"
|
assert data["transcript"] == "这是测试转写结果"
|
||||||
finally:
|
finally:
|
||||||
app.dependency_overrides = {}
|
app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||||
|
id=uuid4(), email="user@example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
oversized = BytesIO(b"12345")
|
||||||
|
oversized.name = "test.wav"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/agent/transcribe",
|
||||||
|
files={"audio": ("test.wav", oversized, "audio/wav")},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json()["detail"] == "Audio file too large"
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_asr_transcribe_rejects_non_wav_audio() -> None:
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||||
|
id=uuid4(), email="user@example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
fake_mp3 = BytesIO(b"fake-mp3")
|
||||||
|
fake_mp3.name = "test.mp3"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/agent/transcribe",
|
||||||
|
files={"audio": ("test.mp3", fake_mp3, "audio/mpeg")},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json()["detail"] == "Unsupported audio format"
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_asr_transcribe_rejects_invalid_wav_payload() -> None:
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||||
|
id=uuid4(), email="user@example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
fake_payload = BytesIO(b"not-a-wav")
|
||||||
|
fake_payload.name = "test.wav"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/agent/transcribe",
|
||||||
|
files={"audio": ("test.wav", fake_payload, "audio/wav")},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json()["detail"] == "Unsupported audio format"
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides = {}
|
||||||
|
|||||||
@@ -9,16 +9,15 @@ APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
|
|||||||
|
|
||||||
def test_worker_commands_use_taskiq() -> None:
|
def test_worker_commands_use_taskiq() -> None:
|
||||||
content = APP_SCRIPT.read_text(encoding="utf-8")
|
content = APP_SCRIPT.read_text(encoding="utf-8")
|
||||||
removed_runner = "uv run c" "elery"
|
removed_runner = "uv run celery"
|
||||||
|
|
||||||
assert "uv run taskiq worker" in content
|
assert "uv run taskiq worker" in content
|
||||||
assert "core.taskiq.app:critical_broker" in content
|
assert "core.taskiq.app:critical_broker" in content
|
||||||
assert "core.taskiq.app:default_broker" in content
|
assert "core.taskiq.app:default_broker" in content
|
||||||
assert "core.taskiq.app:bulk_broker" in content
|
assert "core.taskiq.app:bulk_broker" in content
|
||||||
assert 'pgrep -f "taskiq.*worker"' in content
|
assert 'pgrep -f "uv run taskiq worker core.taskiq.app:"' in content
|
||||||
assert 'pkill -f "taskiq.*worker"' in content
|
assert 'kill_pids_gracefully "taskiq workers"' in content
|
||||||
assert 'pgrep -f "gunicorn.*app:app"' in content
|
assert "gunicorn" not in content
|
||||||
assert 'pkill -f "gunicorn.*app:app"' in content
|
|
||||||
assert removed_runner not in content
|
assert removed_runner not in content
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
from types import SimpleNamespace
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from ag_ui.core import RunAgentInput
|
from ag_ui.core import RunAgentInput
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from core.auth.models import CurrentUser
|
from core.auth.models import CurrentUser
|
||||||
from v1.agent.service import AgentService
|
import v1.agent.service as agent_service_module
|
||||||
|
from v1.agent.service import AgentService, AsrService
|
||||||
|
|
||||||
|
|
||||||
class _FakeRepository:
|
class _FakeRepository:
|
||||||
@@ -249,3 +251,77 @@ async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> Non
|
|||||||
)
|
)
|
||||||
assert event["type"] == "STATE_SNAPSHOT"
|
assert event["type"] == "STATE_SNAPSHOT"
|
||||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
|
||||||
|
result = SimpleNamespace(
|
||||||
|
status_code=200,
|
||||||
|
message="ok",
|
||||||
|
output={"sentence": {"text": "你好,世界"}},
|
||||||
|
request_id="req-test",
|
||||||
|
)
|
||||||
|
|
||||||
|
class _FakeRecognition:
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
def call(self, *, file: str):
|
||||||
|
del file
|
||||||
|
return result
|
||||||
|
|
||||||
|
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||||
|
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||||
|
service = AsrService()
|
||||||
|
|
||||||
|
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||||
|
|
||||||
|
assert transcript == "你好,世界"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_asr_service_parses_sentence_when_result_is_dict(monkeypatch) -> None:
|
||||||
|
result = {
|
||||||
|
"status_code": 200,
|
||||||
|
"message": "ok",
|
||||||
|
"output": {"sentence": {"text": "字典结果"}},
|
||||||
|
"request_id": "req-dict",
|
||||||
|
}
|
||||||
|
|
||||||
|
class _FakeRecognition:
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
def call(self, *, file: str):
|
||||||
|
del file
|
||||||
|
return result
|
||||||
|
|
||||||
|
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||||
|
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||||
|
service = AsrService()
|
||||||
|
|
||||||
|
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||||
|
|
||||||
|
assert transcript == "字典结果"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_asr_service_returns_empty_when_sentence_missing(monkeypatch) -> None:
|
||||||
|
result = {
|
||||||
|
"status_code": 200,
|
||||||
|
"message": "ok",
|
||||||
|
"output": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
class _FakeRecognition:
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
def call(self, *, file: str):
|
||||||
|
del file
|
||||||
|
return result
|
||||||
|
|
||||||
|
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||||
|
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||||
|
service = AsrService()
|
||||||
|
|
||||||
|
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||||
|
|
||||||
|
assert transcript == ""
|
||||||
|
|||||||
Reference in New Issue
Block a user