chore: 后端 agent 和 users 模块代码更新优化
This commit is contained in:
@@ -354,17 +354,19 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
|
||||
async def mock_transcribe(audio_data: bytes, filename: str) -> str:
|
||||
async def mock_transcribe_file(file_path: str, filename: str) -> str:
|
||||
assert file_path.endswith(".wav")
|
||||
assert filename == "test.wav"
|
||||
return "这是测试转写结果"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"v1.agent.service.asr_service.transcribe",
|
||||
mock_transcribe,
|
||||
"v1.agent.service.asr_service.transcribe_file",
|
||||
mock_transcribe_file,
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
wav_content = b"fake-wav-file-content"
|
||||
wav_content = b"RIFF\x24\x80\x00\x00WAVEfmt "
|
||||
wav_file = BytesIO(wav_content)
|
||||
wav_file.name = "test.wav"
|
||||
|
||||
@@ -380,3 +382,68 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
||||
assert data["transcript"] == "这是测试转写结果"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
|
||||
monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4)
|
||||
|
||||
client = TestClient(app)
|
||||
oversized = BytesIO(b"12345")
|
||||
oversized.name = "test.wav"
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/transcribe",
|
||||
files={"audio": ("test.wav", oversized, "audio/wav")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Audio file too large"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_asr_transcribe_rejects_non_wav_audio() -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
fake_mp3 = BytesIO(b"fake-mp3")
|
||||
fake_mp3.name = "test.mp3"
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/transcribe",
|
||||
files={"audio": ("test.mp3", fake_mp3, "audio/mpeg")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Unsupported audio format"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_asr_transcribe_rejects_invalid_wav_payload() -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
fake_payload = BytesIO(b"not-a-wav")
|
||||
fake_payload.name = "test.wav"
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/transcribe",
|
||||
files={"audio": ("test.wav", fake_payload, "audio/wav")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Unsupported audio format"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -9,16 +9,15 @@ APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
|
||||
|
||||
def test_worker_commands_use_taskiq() -> None:
|
||||
content = APP_SCRIPT.read_text(encoding="utf-8")
|
||||
removed_runner = "uv run c" "elery"
|
||||
removed_runner = "uv run celery"
|
||||
|
||||
assert "uv run taskiq worker" in content
|
||||
assert "core.taskiq.app:critical_broker" in content
|
||||
assert "core.taskiq.app:default_broker" in content
|
||||
assert "core.taskiq.app:bulk_broker" in content
|
||||
assert 'pgrep -f "taskiq.*worker"' in content
|
||||
assert 'pkill -f "taskiq.*worker"' in content
|
||||
assert 'pgrep -f "gunicorn.*app:app"' in content
|
||||
assert 'pkill -f "gunicorn.*app:app"' in content
|
||||
assert 'pgrep -f "uv run taskiq worker core.taskiq.app:"' in content
|
||||
assert 'kill_pids_gracefully "taskiq workers"' in content
|
||||
assert "gunicorn" not in content
|
||||
assert removed_runner not in content
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.service import AgentService
|
||||
import v1.agent.service as agent_service_module
|
||||
from v1.agent.service import AgentService, AsrService
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
@@ -249,3 +251,77 @@ async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> Non
|
||||
)
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
|
||||
result = SimpleNamespace(
|
||||
status_code=200,
|
||||
message="ok",
|
||||
output={"sentence": {"text": "你好,世界"}},
|
||||
request_id="req-test",
|
||||
)
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == "你好,世界"
|
||||
|
||||
|
||||
async def test_asr_service_parses_sentence_when_result_is_dict(monkeypatch) -> None:
|
||||
result = {
|
||||
"status_code": 200,
|
||||
"message": "ok",
|
||||
"output": {"sentence": {"text": "字典结果"}},
|
||||
"request_id": "req-dict",
|
||||
}
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == "字典结果"
|
||||
|
||||
|
||||
async def test_asr_service_returns_empty_when_sentence_missing(monkeypatch) -> None:
|
||||
result = {
|
||||
"status_code": 200,
|
||||
"message": "ok",
|
||||
"output": {},
|
||||
}
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == ""
|
||||
|
||||
Reference in New Issue
Block a user