chore: 后端 agent 和 users 模块代码更新优化
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Any, Protocol
|
||||
@@ -106,16 +104,13 @@ class AgentService:
|
||||
else:
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
try:
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
raise
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
return TaskAccepted(
|
||||
task_id=task_id,
|
||||
thread_id=thread_id,
|
||||
@@ -234,57 +229,46 @@ class AsrService:
|
||||
self._api_key = dashscope_key
|
||||
return self._api_key
|
||||
|
||||
@contextmanager
|
||||
def _temp_wav(self, audio_data: bytes):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
tmp.write(audio_data)
|
||||
tmp_path = tmp.name
|
||||
try:
|
||||
yield tmp_path
|
||||
finally:
|
||||
import os
|
||||
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
async def transcribe(self, audio_data: bytes, filename: str) -> str:
|
||||
async def transcribe_file(self, file_path: str, filename: str) -> str:
|
||||
try:
|
||||
dashscope.api_key = self._get_api_key()
|
||||
|
||||
with self._temp_wav(audio_data) as tmp_path:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
class SyncCallback(RecognitionCallback):
|
||||
error: str | None = None
|
||||
class SyncCallback(RecognitionCallback):
|
||||
error: str | None = None
|
||||
|
||||
def on_error(self, result: Any) -> None:
|
||||
self.error = str(result)
|
||||
def on_error(self, result: Any) -> None:
|
||||
self.error = str(result)
|
||||
|
||||
callback = SyncCallback()
|
||||
recognizer = Recognition(
|
||||
model="fun-asr-realtime-2026-02-28",
|
||||
callback=callback,
|
||||
format="wav",
|
||||
sample_rate=16000,
|
||||
)
|
||||
callback = SyncCallback()
|
||||
recognizer = Recognition(
|
||||
model="fun-asr-realtime-2026-02-28",
|
||||
callback=callback,
|
||||
format="wav",
|
||||
sample_rate=16000,
|
||||
)
|
||||
|
||||
result: Any = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: recognizer.call(file=tmp_path),
|
||||
)
|
||||
result: Any = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: recognizer.call(file=file_path),
|
||||
)
|
||||
|
||||
if callback.error:
|
||||
raise RuntimeError(f"ASR error: {callback.error}")
|
||||
if result.status_code != 200:
|
||||
raise RuntimeError(f"ASR transcription failed: {result.message}")
|
||||
status_code = self._extract_field(result, "status_code")
|
||||
if status_code != 200:
|
||||
message = self._extract_field(result, "message")
|
||||
raise RuntimeError(f"ASR transcription failed: {message}")
|
||||
|
||||
if result.output is None or result.output.sentence is None:
|
||||
sentence = self._extract_sentence_payload(result)
|
||||
if sentence is None:
|
||||
request_id = self._extract_field(result, "request_id")
|
||||
logger.warning(
|
||||
"ASR returned empty result", extra={"request_id": result.request_id}
|
||||
"ASR returned empty result", extra={"request_id": request_id}
|
||||
)
|
||||
return ""
|
||||
|
||||
sentence = result.output.sentence
|
||||
if isinstance(sentence, dict):
|
||||
transcription = sentence.get("text", "")
|
||||
elif isinstance(sentence, list):
|
||||
@@ -300,9 +284,40 @@ class AsrService:
|
||||
)
|
||||
return transcription
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("ASR transcription error")
|
||||
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
|
||||
|
||||
def _extract_sentence_payload(self, result: Any) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
output = result.get("output")
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
if output is not None:
|
||||
return getattr(output, "sentence", None)
|
||||
return result.get("sentence")
|
||||
|
||||
get_sentence = getattr(result, "get_sentence", None)
|
||||
if callable(get_sentence):
|
||||
sentence = get_sentence()
|
||||
if sentence is not None:
|
||||
return sentence
|
||||
|
||||
output = getattr(result, "output", None)
|
||||
if output is None:
|
||||
return None
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
return getattr(output, "sentence", None)
|
||||
|
||||
def _extract_field(self, result: Any, field: str) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return getattr(result, field, None)
|
||||
|
||||
|
||||
asr_service = AsrService()
|
||||
|
||||
Reference in New Issue
Block a user