343 lines
11 KiB
Python
343 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from dataclasses import dataclass
|
|
from datetime import date
|
|
from typing import Any, Protocol
|
|
|
|
import dashscope
|
|
from ag_ui.core import RunAgentInput, StateSnapshotEvent
|
|
from dashscope.audio.asr import Recognition, RecognitionCallback
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from core.auth.models import CurrentUser
|
|
from core.config.settings import config
|
|
from core.logging import get_logger
|
|
|
|
# Module-level logger, obtained by passing this module's name to the project's get_logger helper.
logger = get_logger(__name__)
|
|
|
|
|
|
def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
|
|
forwarded = run_input.forwarded_props
|
|
if not isinstance(forwarded, dict):
|
|
return None
|
|
for key in ("accessToken", "userToken", "token"):
|
|
value = forwarded.get(key)
|
|
if isinstance(value, str) and value.strip():
|
|
return value.strip()
|
|
return None
|
|
|
|
|
|
@dataclass(frozen=True)
class TaskAccepted:
    """Immutable receipt describing a task that was handed to the queue."""

    # Identifier returned by the queue client's ``enqueue`` call.
    task_id: str
    # Session (thread) the task belongs to.
    thread_id: str
    # Run identifier taken from the run input.
    run_id: str
    # True only when the session was newly created while accepting the task.
    created: bool
|
|
|
|
|
|
class AgentRepositoryLike(Protocol):
    """Structural interface for the session/history persistence used by ``AgentService``."""

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner id stored for ``session_id``.

        ``AgentService.enqueue_run`` treats an ``HTTPException`` with status
        404 raised here as "session does not exist yet".
        """

    async def create_session_for_user(
        self,
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> str:
        """Create a session for ``user_id``, optionally with a caller-chosen id."""

    async def commit(self) -> None:
        """Commit the pending unit of work."""

    async def rollback(self) -> None:
        """Roll back the pending unit of work."""

    async def get_history_day(
        self,
        *,
        session_id: str,
        before: date | None,
    ) -> dict[str, object] | None:
        """Return one day of history for ``session_id``, or ``None``.

        ``AgentService.get_history_snapshot`` reads the keys ``day``,
        ``hasMore`` and ``messages`` from a non-empty result.
        """

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the latest session id for ``user_id``, or ``None`` if there is none."""
|
|
|
|
|
|
class QueueClientLike(Protocol):
    """Structural interface for the task queue used by ``AgentService``."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Submit ``command`` and return a task id string.

        ``dedup_key`` may be ``None``; how a non-``None`` key is used is up to
        the implementation.
        """
|
|
|
|
|
|
class EventStreamLike(Protocol):
    """Structural interface for the event stream read by ``AgentService``."""

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]:
        """Return events for ``session_id`` as a list of dicts.

        ``last_event_id`` may be ``None``; its exact resume semantics are
        defined by the implementation.
        """
|
|
|
|
|
|
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Reject access unless ``current_user`` owns the session.

    Args:
        owner_id: Owner id recorded for the session.
        current_user: Authenticated user; ``str(current_user.id)`` is compared
            against ``owner_id``.

    Raises:
        HTTPException: Status 403 ("Forbidden") when the ids differ.
    """
    if owner_id == str(current_user.id):
        return
    raise HTTPException(status_code=403, detail="Forbidden")
|
|
|
|
|
|
class AgentService:
    """Application service for agent sessions.

    Verifies session ownership, enqueues ``run`` / ``resume`` commands, reads
    session events and builds history snapshot events.
    """

    _repository: AgentRepositoryLike
    _queue: QueueClientLike
    _stream: EventStreamLike

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
    ) -> None:
        """Store the repository, queue client and event stream collaborators."""
        self._repository, self._queue, self._stream = repository, queue, stream

    @staticmethod
    def _build_command(
        *,
        name: str,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> dict[str, object]:
        """Assemble the queue payload shared by the ``run`` and ``resume`` commands."""
        return {
            "command": name,
            "owner_id": str(current_user.id),
            "user_token": _extract_user_token_from_run_input(run_input),
            "run_input": run_input.model_dump(mode="json", by_alias=True),
        }

    async def _require_owner(self, *, thread_id: str, current_user: CurrentUser) -> None:
        """Load the owner of ``thread_id`` and raise 403 unless it is ``current_user``."""
        session_owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=session_owner, current_user=current_user)

    async def enqueue_run(
        self,
        *,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Enqueue a ``run`` command, creating the session on first use.

        If the owner lookup fails with a 404 ``HTTPException`` the session is
        created for ``current_user`` and committed. If that creation raises
        ``IntegrityError`` the transaction is rolled back and ownership is
        checked against the row that now exists. Any other ``HTTPException``
        propagates.

        Returns:
            A ``TaskAccepted`` whose ``created`` flag is true only when this
            call created the session.

        Raises:
            HTTPException: 403 when the session belongs to another user, or
                whatever non-404 error the owner lookup raised.
        """
        session_id = run_input.thread_id
        run_id = run_input.run_id
        was_created = False

        try:
            session_owner = await self._repository.get_session_owner(session_id=session_id)
        except HTTPException as lookup_error:
            if lookup_error.status_code != 404:
                raise
            try:
                await self._repository.create_session_for_user(
                    user_id=str(current_user.id),
                    session_id=session_id,
                )
                await self._repository.commit()
            except IntegrityError:
                # The insert conflicted with an existing row: discard our
                # attempt and validate ownership of the stored session.
                await self._repository.rollback()
                session_owner = await self._repository.get_session_owner(session_id=session_id)
                ensure_session_owner(owner_id=session_owner, current_user=current_user)
            else:
                was_created = True
        else:
            ensure_session_owner(owner_id=session_owner, current_user=current_user)

        payload = self._build_command(name="run", run_input=run_input, current_user=current_user)
        task_id = await self._queue.enqueue(command=payload, dedup_key=None)
        return TaskAccepted(
            task_id=task_id,
            thread_id=session_id,
            run_id=run_id,
            created=was_created,
        )

    async def enqueue_resume(
        self,
        *,
        thread_id: str,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Enqueue a ``resume`` command for an existing session owned by the caller.

        The task is enqueued with the dedup key
        ``resume:<thread_id>:<run_id>``.

        Raises:
            HTTPException: 403 when the session belongs to another user; errors
                from the owner lookup propagate unchanged.
        """
        await self._require_owner(thread_id=thread_id, current_user=current_user)

        resume_key = f"resume:{thread_id}:{run_input.run_id}"
        payload = self._build_command(name="resume", run_input=run_input, current_user=current_user)
        task_id = await self._queue.enqueue(command=payload, dedup_key=resume_key)

        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_input.run_id,
            created=False,
        )

    async def stream_events(
        self,
        *,
        thread_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return the events read from the stream for a session owned by the caller.

        Raises:
            HTTPException: 403 when the session belongs to another user.
        """
        await self._require_owner(thread_id=thread_id, current_user=current_user)
        events = await self._stream.read(session_id=thread_id, last_event_id=last_event_id)
        return events

    async def get_history_snapshot(
        self,
        *,
        thread_id: str,
        before: date | None,
        current_user: CurrentUser,
    ) -> dict[str, object]:
        """Build a ``history_day`` state-snapshot event for one session.

        When the repository returns no (or an empty) day payload the snapshot
        carries ``day=None``, ``hasMore=False`` and an empty ``messages`` list.
        The serialized event additionally gets a top-level ``threadId``.

        Raises:
            HTTPException: 403 when the session belongs to another user.
        """
        await self._require_owner(thread_id=thread_id, current_user=current_user)
        day_payload = await self._repository.get_history_day(session_id=thread_id, before=before)

        if day_payload:
            day = day_payload["day"]
            has_more = day_payload["hasMore"]
            messages = day_payload["messages"]
        else:
            day, has_more, messages = None, False, []

        snapshot = {
            "scope": "history_day",
            "threadId": thread_id,
            "day": day,
            "hasMore": has_more,
            "messages": messages,
        }
        serialized = StateSnapshotEvent(snapshot=snapshot).model_dump(
            mode="json",
            by_alias=True,
            exclude_none=True,
        )
        serialized["threadId"] = thread_id
        return serialized

    async def get_user_history_snapshot(
        self,
        *,
        current_user: CurrentUser,
        thread_id: str | None,
        before: date | None,
    ) -> dict[str, object]:
        """Build a history snapshot, defaulting to the user's latest session.

        With ``thread_id=None`` the repository is asked for the user's latest
        session id. If there is none, an empty ``history_day`` snapshot event
        is returned (without the top-level ``threadId`` key); otherwise the
        call is delegated to ``get_history_snapshot``.
        """
        resolved_thread_id = thread_id
        if resolved_thread_id is None:
            resolved_thread_id = await self._repository.get_latest_session_id_for_user(
                user_id=str(current_user.id)
            )

        if resolved_thread_id is not None:
            return await self.get_history_snapshot(
                thread_id=resolved_thread_id,
                before=before,
                current_user=current_user,
            )

        empty_snapshot = {
            "scope": "history_day",
            "threadId": None,
            "day": None,
            "hasMore": False,
            "messages": [],
        }
        return StateSnapshotEvent(snapshot=empty_snapshot).model_dump(
            mode="json",
            by_alias=True,
            exclude_none=True,
        )
|
|
|
|
|
|
class AsrService:
    """Speech-to-text helper built on DashScope's ``Recognition`` API.

    The DashScope API key is read from configuration on first use and cached
    on the instance.
    """

    def __init__(self) -> None:
        """Start with no cached API key; it is resolved lazily."""
        self._api_key: str | None = None

    def _get_api_key(self) -> str:
        """Return the DashScope API key, loading it from config on first use.

        Raises:
            ValueError: If ``config.llm.provider_keys`` has no non-empty
                ``dashscope`` entry.
        """
        if self._api_key is None:
            dashscope_key = config.llm.provider_keys.get("dashscope")
            if not dashscope_key:
                raise ValueError(
                    "DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
                )
            self._api_key = dashscope_key
        return self._api_key

    async def transcribe_file(self, file_path: str, filename: str) -> str:
        """Transcribe the audio file at ``file_path`` and return its text.

        The blocking ``Recognition.call`` runs in the event loop's default
        executor so the loop is not blocked.

        Args:
            file_path: Path of the audio file passed to the recognizer (the
                recognizer is configured for ``wav`` at 16000 Hz).
            filename: Name used only for logging.

        Returns:
            The transcript, or ``""`` when the service returned no sentence.

        Raises:
            RuntimeError: On any failure, including a callback error, a
                non-200 status, or a missing API key (other exceptions are
                wrapped).
        """
        try:
            dashscope.api_key = self._get_api_key()

            # We are inside a coroutine, so a loop is guaranteed to be running;
            # get_running_loop() is the supported accessor here
            # (get_event_loop() is deprecated for this use).
            loop = asyncio.get_running_loop()

            class SyncCallback(RecognitionCallback):
                # Holds the stringified error reported by the recognizer, if any.
                error: str | None = None

                def on_error(self, result: Any) -> None:
                    self.error = str(result)

            callback = SyncCallback()
            recognizer = Recognition(
                model="fun-asr-realtime-2026-02-28",
                callback=callback,
                format="wav",
                sample_rate=16000,
            )

            result: Any = await loop.run_in_executor(
                None,
                lambda: recognizer.call(file=file_path),
            )

            if callback.error:
                raise RuntimeError(f"ASR error: {callback.error}")
            status_code = self._extract_field(result, "status_code")
            if status_code != 200:
                message = self._extract_field(result, "message")
                raise RuntimeError(f"ASR transcription failed: {message}")

            sentence = self._extract_sentence_payload(result)
            if sentence is None:
                request_id = self._extract_field(result, "request_id")
                logger.warning(
                    "ASR returned empty result", extra={"request_id": request_id}
                )
                return ""

            transcription = self._sentence_to_text(sentence)

            # NOTE: the key must not be "filename". That name is a reserved
            # LogRecord attribute, and a standard-library logger raises
            # KeyError when ``extra`` tries to overwrite it, which would turn
            # a successful transcription into a failure.
            logger.info(
                "ASR transcription completed",
                extra={
                    "audio_filename": filename,
                    "transcript_length": len(transcription),
                },
            )
            return transcription

        except asyncio.CancelledError:
            # Never convert cancellation into a RuntimeError.
            raise
        except RuntimeError:
            # Already in the shape callers expect; re-raise untouched.
            raise
        except Exception as exc:
            logger.exception("ASR transcription error")
            raise RuntimeError(f"ASR transcription failed: {exc}") from exc

    @staticmethod
    def _sentence_to_text(sentence: Any) -> str:
        """Convert a sentence payload (dict, list of dicts, or other) to text.

        Always returns a ``str``: a missing or ``None`` ``text`` becomes
        ``""`` and non-string values are stringified, so callers can safely
        take ``len()`` of the result.
        """
        if isinstance(sentence, dict):
            return str(sentence.get("text") or "")
        if isinstance(sentence, list):
            return " ".join(
                str(item.get("text") or "")
                for item in sentence
                if isinstance(item, dict)
            )
        return str(sentence) if sentence else ""

    def _extract_sentence_payload(self, result: Any) -> Any | None:
        """Locate the ``sentence`` payload in a recognizer result.

        Supports dict results (``result["output"]["sentence"]``, an object
        ``output`` with a ``sentence`` attribute, or a top-level
        ``sentence``) and object results (``get_sentence()`` first, then
        ``result.output``). Returns ``None`` when nothing is found.
        """
        if isinstance(result, dict):
            output = result.get("output")
            if isinstance(output, dict):
                return output.get("sentence")
            if output is not None:
                return getattr(output, "sentence", None)
            return result.get("sentence")

        get_sentence = getattr(result, "get_sentence", None)
        if callable(get_sentence):
            sentence = get_sentence()
            if sentence is not None:
                return sentence

        output = getattr(result, "output", None)
        if output is None:
            return None
        if isinstance(output, dict):
            return output.get("sentence")
        return getattr(output, "sentence", None)

    def _extract_field(self, result: Any, field: str) -> Any | None:
        """Read ``field`` from a dict key or an object attribute; ``None`` if absent."""
        if isinstance(result, dict):
            return result.get(field)
        return getattr(result, field, None)
|
|
|
|
|
|
# Module-level AsrService instance, created when this module is imported.
asr_service = AsrService()
|