Files
social-app/backend/src/v1/agent/service.py
T

710 lines
24 KiB
Python
Raw Normal View History

from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import date
import hashlib
from typing import Any, Protocol
import dashscope
from ag_ui.core import RunAgentInput, StateSnapshotEvent
from dashscope.audio.asr import Recognition, RecognitionCallback
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
# Module-level logger named after this module.
logger = get_logger(__name__)
# MIME types accepted for attachments; callers compare lower-cased values
# against this set.
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
# Upper bound on the size of a single attachment payload (5 MiB).
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
# Upper bound on the combined size of all attachments referenced by one
# run input (12 MiB).
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
def _normalize_bearer_token(value: str | None) -> str | None:
if not isinstance(value, str):
return None
normalized = value.strip()
if not normalized:
return None
lower = normalized.lower()
if lower.startswith("bearer "):
token = normalized[7:].strip()
return token or None
return normalized
@dataclass(frozen=True)
class TaskAccepted:
    """Immutable result returned once a run or resume command is enqueued."""

    # Identifier returned by the queue client's ``enqueue`` call.
    task_id: str
    # Session (thread) the command targets.
    thread_id: str
    # Run identifier taken from the submitted run input.
    run_id: str
    # True when the session was created while handling this request.
    created: bool
class AgentRepositoryLike(Protocol):
    """Structural interface of the persistence layer used by ``AgentService``."""

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner id of ``session_id``.

        ``AgentService`` treats an ``HTTPException`` with status 404 raised
        here as "session does not exist".
        """

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Create a session for ``user_id``, optionally with a fixed id."""

    async def commit(self) -> None:
        """Commit the current unit of work."""

    async def rollback(self) -> None:
        """Roll back the current unit of work."""

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        """Return one day of history, or ``None`` when there is none.

        The service reads the ``"day"``, ``"hasMore"`` and ``"messages"``
        keys of the returned mapping.
        """

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the latest session id of ``user_id``, or ``None``."""

    async def persist_user_message(
        self,
        *,
        session_id: str,
        run_id: str,
        content_text: str,
        metadata: dict[str, object] | None,
    ) -> None:
        """Store a user message together with its optional metadata."""

    async def get_message_attachment_reference(
        self,
        *,
        session_id: str,
        message_id: str,
        attachment_index: int,
    ) -> dict[str, str] | None:
        """Return the stored reference of one attachment, or ``None``.

        The service reads the ``"bucket"``, ``"path"`` and ``"mimeType"``
        keys of the returned mapping.
        """
class QueueClientLike(Protocol):
    """Structural interface of the task queue used by ``AgentService``."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Submit ``command`` and return the identifier of the queued task."""
class EventStreamLike(Protocol):
    """Structural interface of the event stream read by ``AgentService``."""

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]:
        """Return events of ``session_id`` given an optional last event id."""
class AttachmentStorageLike(Protocol):
    """Structural interface of the object storage used for attachments."""

    async def upload_bytes(
        self,
        *,
        bucket: str,
        path: str,
        content: bytes,
        content_type: str,
    ) -> str:
        """Store ``content`` and return a string the service uses as the stored path."""

    async def download_bytes(self, *, bucket: str, path: str) -> bytes:
        """Return the bytes stored at ``bucket``/``path``."""

    async def create_signed_url(
        self,
        *,
        bucket: str,
        path: str,
        expires_in_seconds: int,
    ) -> str:
        """Return a URL for ``bucket``/``path`` given an expiry in seconds."""
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Raise a 403 ``HTTPException`` unless ``current_user`` owns the session.

    Ownership is decided by comparing ``owner_id`` with the string form of
    ``current_user.id``.
    """
    if str(current_user.id) == owner_id:
        return
    raise HTTPException(status_code=403, detail="Forbidden")
class AgentService:
    """Application service for agent runs, event streams, history and attachments.

    Every public method first verifies that the calling user owns the target
    session (thread) before touching the repository, queue, stream or
    attachment storage.
    """

    _repository: AgentRepositoryLike
    _queue: QueueClientLike
    _stream: EventStreamLike
    _attachment_storage: AttachmentStorageLike | None
    # Expiry, in seconds, requested for the signed URL returned by
    # ``upload_attachment``.
    _SIGNED_URL_EXPIRES_IN_SECONDS = 3600

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
        attachment_storage: AttachmentStorageLike | None = None,
    ) -> None:
        """Store the collaborators; ``attachment_storage`` may be omitted."""
        self._repository = repository
        self._queue = queue
        self._stream = stream
        self._attachment_storage = attachment_storage

    async def _require_session_owner(
        self,
        *,
        thread_id: str,
        current_user: CurrentUser,
    ) -> None:
        """Load the owner of ``thread_id`` and raise 403 if it is another user."""
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)

    async def _ensure_owned_session(
        self,
        *,
        thread_id: str,
        current_user: CurrentUser,
    ) -> bool:
        """Make sure ``thread_id`` exists and belongs to ``current_user``.

        A session the repository reports as missing (``HTTPException`` 404)
        is created for the current user.

        Returns:
            ``True`` when this call created and committed the session,
            ``False`` otherwise.

        Raises:
            HTTPException: 403 when the session belongs to another user; any
                non-404 error from the owner lookup is propagated unchanged.
        """
        try:
            owner = await self._repository.get_session_owner(session_id=thread_id)
        except HTTPException as exc:
            if exc.status_code != 404:
                raise
        else:
            ensure_session_owner(owner_id=owner, current_user=current_user)
            return False

        created = False
        try:
            await self._repository.create_session_for_user(
                user_id=str(current_user.id),
                session_id=thread_id,
            )
            await self._repository.commit()
            created = True
        except IntegrityError:
            # Presumably another request created the same session in the
            # meantime; discard our attempt and fall through to the re-check.
            await self._repository.rollback()
        # Re-read the owner so that a session created by someone else is
        # still rejected.
        await self._require_session_owner(
            thread_id=thread_id,
            current_user=current_user,
        )
        return created

    async def enqueue_run(
        self,
        *,
        run_input: RunAgentInput,
        current_user: CurrentUser,
        user_token: str | None = None,
    ) -> TaskAccepted:
        """Persist the latest user message and enqueue a ``run`` command.

        The session named by ``run_input.thread_id`` is created on demand.

        Returns:
            ``TaskAccepted`` whose ``created`` flag tells whether the session
            was created by this call.

        Raises:
            HTTPException: 403 for a foreign session, plus the validation
                errors documented on ``_prepare_user_message``.
        """
        thread_id = run_input.thread_id
        run_id = run_input.run_id
        created = await self._ensure_owned_session(
            thread_id=thread_id,
            current_user=current_user,
        )
        user_message_text, user_message_metadata = await self._prepare_user_message(
            run_input=run_input,
            current_user=current_user,
        )
        await self._repository.persist_user_message(
            session_id=thread_id,
            run_id=run_id,
            content_text=user_message_text,
            metadata=user_message_metadata,
        )
        await self._repository.commit()
        task_id = await self._queue.enqueue(
            command={
                "command": "run",
                "owner_id": str(current_user.id),
                "user_token": _normalize_bearer_token(user_token),
                "run_input": run_input.model_dump(mode="json", by_alias=True),
            },
            dedup_key=None,
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_id,
            created=created,
        )

    async def _prepare_user_message(
        self,
        *,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> tuple[str, dict[str, object] | None]:
        """Extract the user text and validate any attachment references.

        When the latest message contains ``binary`` content blocks,
        ``run_input.forwarded_props["attachments"]`` must hold one reference
        per block (same order). Each reference is checked for bucket, path
        prefix, MIME type and size; the size check downloads the object.

        Returns:
            The message text and either ``{"attachments": [...]}`` or
            ``None`` when there are no attachments.

        Raises:
            HTTPException: 503 without attachment storage, 422 for malformed
                or mismatching references, 403 for a foreign bucket or path,
                502 when a download fails, 413 when a size limit is exceeded.
        """
        text, _ = extract_latest_user_payload(run_input)
        content_blocks = _extract_latest_user_content_blocks(run_input)
        attachments: list[dict[str, object]] = []
        binary_blocks = [
            block
            for block in content_blocks
            if isinstance(block, dict) and block.get("type") == "binary"
        ]
        if binary_blocks:
            if self._attachment_storage is None:
                raise HTTPException(
                    status_code=503,
                    detail="Attachment storage unavailable",
                )
            forwarded_props = (
                run_input.forwarded_props
                if isinstance(run_input.forwarded_props, dict)
                else {}
            )
            raw_attachments = forwarded_props.get("attachments")
            if not isinstance(raw_attachments, list):
                raise HTTPException(
                    status_code=422, detail="Invalid attachments payload"
                )
            if len(raw_attachments) != len(binary_blocks):
                raise HTTPException(
                    status_code=422, detail="Invalid attachments payload"
                )
            total_attachment_bytes = 0
            # References may only point below the caller's own thread folder.
            expected_prefix = f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
            for index, raw_attachment in enumerate(raw_attachments):
                if not isinstance(raw_attachment, dict):
                    raise HTTPException(
                        status_code=422,
                        detail="Invalid attachment reference",
                    )
                bucket = raw_attachment.get("bucket")
                path = raw_attachment.get("path")
                mime_type = raw_attachment.get("mimeType")
                if (
                    not isinstance(bucket, str)
                    or not isinstance(path, str)
                    or not isinstance(mime_type, str)
                ):
                    raise HTTPException(
                        status_code=422,
                        detail="Invalid attachment reference",
                    )
                if bucket != config.storage.bucket:
                    raise HTTPException(status_code=403, detail="Forbidden")
                if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
                    raise HTTPException(status_code=403, detail="Forbidden")
                if mime_type.lower() not in _ALLOWED_ATTACHMENT_MIME_TYPES:
                    raise HTTPException(
                        status_code=422,
                        detail="Unsupported attachment type",
                    )
                # The i-th reference must describe the i-th binary block.
                binary_block = binary_blocks[index]
                binary_mime = binary_block.get("mimeType")
                binary_url = binary_block.get("url")
                if (
                    not isinstance(binary_mime, str)
                    or binary_mime != mime_type
                    or not isinstance(binary_url, str)
                    or not binary_url
                ):
                    raise HTTPException(
                        status_code=422,
                        detail="Invalid attachments payload",
                    )
                try:
                    payload = await self._attachment_storage.download_bytes(
                        bucket=bucket,
                        path=path,
                    )
                except Exception as exc:  # noqa: BLE001
                    logger.exception(
                        "Attachment validation download failed",
                        extra={
                            "bucket": bucket,
                            "path": path,
                            "thread_id": run_input.thread_id,
                            "run_id": run_input.run_id,
                        },
                    )
                    raise HTTPException(
                        status_code=502,
                        detail="Failed to fetch attachment",
                    ) from exc
                payload_size = len(payload)
                if payload_size > _MAX_ATTACHMENT_BYTES:
                    raise HTTPException(
                        status_code=413,
                        detail="Attachment too large",
                    )
                total_attachment_bytes += payload_size
                if total_attachment_bytes > _MAX_TOTAL_ATTACHMENT_BYTES:
                    raise HTTPException(
                        status_code=413,
                        detail="Attachments too large",
                    )
                attachments.append(
                    {
                        "bucket": bucket,
                        "path": path,
                        "mimeType": mime_type,
                    }
                )
        metadata: dict[str, object] = {}
        if attachments:
            metadata["attachments"] = attachments
        return text, metadata or None

    async def upload_attachment(
        self,
        *,
        thread_id: str,
        filename: str | None,
        content_type: str | None,
        payload: bytes,
        current_user: CurrentUser,
    ) -> dict[str, str]:
        """Store an image attachment for ``thread_id`` and return its reference.

        The session is created on demand. The object name is derived from
        hashes of the file name and of the content.

        Returns:
            A mapping with ``bucket``, ``path``, ``mimeType`` and a signed
            ``url``.

        Raises:
            HTTPException: 403 for a foreign session or an unsafe storage
                path, 503 without attachment storage, 422 for an unsupported
                type or empty payload, 413 for an oversized payload, 502 when
                the storage call fails.
        """
        await self._ensure_owned_session(
            thread_id=thread_id,
            current_user=current_user,
        )
        if self._attachment_storage is None:
            raise HTTPException(
                status_code=503, detail="Attachment storage unavailable"
            )
        if not isinstance(content_type, str):
            raise HTTPException(status_code=422, detail="Unsupported attachment type")
        mime_type = content_type.lower()
        if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES:
            raise HTTPException(status_code=422, detail="Unsupported attachment type")
        if not payload:
            raise HTTPException(status_code=422, detail="Empty attachment")
        if len(payload) > _MAX_ATTACHMENT_BYTES:
            raise HTTPException(status_code=413, detail="Attachment too large")
        suffix = _mime_to_suffix(mime_type)
        # The SHA-1 digests below are only used to build the object name.
        checksum = hashlib.sha1(payload).hexdigest()[:16]
        filename_seed = filename if isinstance(filename, str) and filename else "upload"
        filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
        expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/"
        path = f"{expected_prefix}uploads/{filename_hash}-{checksum}.{suffix}"
        # ``thread_id`` is caller-supplied and becomes part of the storage
        # path. Apply the same check the read side uses, so that nothing is
        # written to a path that ``_prepare_user_message`` and
        # ``get_attachment_preview`` would later refuse (e.g. one with "..").
        if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
            raise HTTPException(status_code=403, detail="Forbidden")
        bucket_name = config.storage.bucket
        try:
            stored_path = await self._attachment_storage.upload_bytes(
                bucket=bucket_name,
                path=path,
                content=payload,
                content_type=mime_type,
            )
            signed_url = await self._attachment_storage.create_signed_url(
                bucket=bucket_name,
                path=stored_path,
                expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
            )
        except Exception as exc:  # noqa: BLE001
            logger.exception(
                "Attachment upload failed",
                extra={
                    "bucket": bucket_name,
                    "path": path,
                    "mime_type": mime_type,
                    "thread_id": thread_id,
                },
            )
            raise HTTPException(
                status_code=502, detail="Failed to upload attachment"
            ) from exc
        return {
            "bucket": bucket_name,
            "path": stored_path,
            "mimeType": mime_type,
            "url": signed_url,
        }

    async def enqueue_resume(
        self,
        *,
        thread_id: str,
        run_input: RunAgentInput,
        current_user: CurrentUser,
        user_token: str | None = None,
    ) -> TaskAccepted:
        """Enqueue a ``resume`` command for an existing, owned session.

        The command is enqueued with the dedup key
        ``resume:{thread_id}:{run_id}``.

        Raises:
            HTTPException: 403 when the session belongs to another user;
                errors from the owner lookup are propagated.
        """
        await self._require_session_owner(
            thread_id=thread_id,
            current_user=current_user,
        )
        dedup_key = f"resume:{thread_id}:{run_input.run_id}"
        task_id = await self._queue.enqueue(
            command={
                "command": "resume",
                "owner_id": str(current_user.id),
                "user_token": _normalize_bearer_token(user_token),
                "run_input": run_input.model_dump(mode="json", by_alias=True),
            },
            dedup_key=dedup_key,
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_input.run_id,
            created=False,
        )

    async def stream_events(
        self,
        *,
        thread_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return stream events of an owned session.

        Raises:
            HTTPException: 403 when the session belongs to another user.
        """
        await self._require_session_owner(
            thread_id=thread_id,
            current_user=current_user,
        )
        return await self._stream.read(
            session_id=thread_id,
            last_event_id=last_event_id,
        )

    async def get_history_snapshot(
        self,
        *,
        thread_id: str,
        before: date | None,
        current_user: CurrentUser,
    ) -> dict[str, object]:
        """Return one day of history of an owned session as a snapshot event.

        Returns:
            The JSON dump of a ``StateSnapshotEvent`` with an additional
            top-level ``threadId`` key.

        Raises:
            HTTPException: 403 when the session belongs to another user.
        """
        await self._require_session_owner(
            thread_id=thread_id,
            current_user=current_user,
        )
        day_payload = await self._repository.get_history_day(
            session_id=thread_id,
            before=before,
        )
        snapshot = {
            "scope": "history_day",
            "threadId": thread_id,
            "day": day_payload["day"] if day_payload else None,
            "hasMore": day_payload["hasMore"] if day_payload else False,
            "messages": day_payload["messages"] if day_payload else [],
        }
        event = StateSnapshotEvent(snapshot=snapshot).model_dump(
            mode="json",
            by_alias=True,
            exclude_none=True,
        )
        event["threadId"] = thread_id
        return event

    async def get_user_history_snapshot(
        self,
        *,
        current_user: CurrentUser,
        thread_id: str | None,
        before: date | None,
    ) -> dict[str, object]:
        """Return a history snapshot, defaulting to the user's latest session.

        When ``thread_id`` is ``None`` the latest session of the user is
        looked up; if the user has none, an empty snapshot event is returned.
        """
        target_thread_id = thread_id
        if target_thread_id is None:
            target_thread_id = await self._repository.get_latest_session_id_for_user(
                user_id=str(current_user.id)
            )
        if target_thread_id is None:
            return StateSnapshotEvent(
                snapshot={
                    "scope": "history_day",
                    "threadId": None,
                    "day": None,
                    "hasMore": False,
                    "messages": [],
                }
            ).model_dump(mode="json", by_alias=True, exclude_none=True)
        return await self.get_history_snapshot(
            thread_id=target_thread_id,
            before=before,
            current_user=current_user,
        )

    async def get_attachment_preview(
        self,
        *,
        thread_id: str,
        message_id: str,
        attachment_index: int,
        current_user: CurrentUser,
    ) -> tuple[bytes, str]:
        """Return the bytes and MIME type of one stored message attachment.

        Raises:
            HTTPException: 403 for a foreign session, bucket or path, 503
                without attachment storage, 404 when the reference is missing
                or malformed, 502 when the download fails.
        """
        await self._require_session_owner(
            thread_id=thread_id,
            current_user=current_user,
        )
        if self._attachment_storage is None:
            raise HTTPException(
                status_code=503, detail="Attachment storage unavailable"
            )
        ref = await self._repository.get_message_attachment_reference(
            session_id=thread_id,
            message_id=message_id,
            attachment_index=attachment_index,
        )
        if ref is None:
            raise HTTPException(status_code=404, detail="Attachment not found")
        bucket = ref.get("bucket")
        path = ref.get("path")
        mime_type = ref.get("mimeType")
        if (
            not isinstance(bucket, str)
            or not isinstance(path, str)
            or not isinstance(mime_type, str)
        ):
            raise HTTPException(status_code=404, detail="Attachment not found")
        if bucket != config.storage.bucket:
            raise HTTPException(status_code=403, detail="Forbidden")
        expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/"
        if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
            raise HTTPException(status_code=403, detail="Forbidden")
        try:
            payload = await self._attachment_storage.download_bytes(
                bucket=bucket,
                path=path,
            )
        except Exception as exc:  # noqa: BLE001
            logger.exception(
                "Attachment download failed",
                extra={
                    "thread_id": thread_id,
                    "message_id": message_id,
                    "attachment_index": attachment_index,
                    "bucket": bucket,
                },
            )
            raise HTTPException(
                status_code=502, detail="Failed to fetch attachment"
            ) from exc
        return payload, mime_type
class AsrService:
    """Speech-to-text helper built on DashScope's ``Recognition`` client."""

    def __init__(self) -> None:
        """Start without an API key; it is read from config on first use."""
        self._api_key: str | None = None

    def _get_api_key(self) -> str:
        """Return the DashScope API key, reading it from config once.

        Raises:
            ValueError: when no ``dashscope`` provider key is configured.
        """
        if self._api_key is None:
            dashscope_key = config.llm.provider_keys.get("dashscope")
            if not dashscope_key:
                raise ValueError(
                    "DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
                )
            self._api_key = dashscope_key
        return self._api_key

    async def transcribe_file(self, file_path: str, filename: str) -> str:
        """Transcribe the audio file at ``file_path`` and return its text.

        The blocking ``Recognition.call`` runs in the event loop's default
        executor. The recognizer is configured for ``wav`` input at 16000 Hz.

        Args:
            file_path: Path of the audio file handed to the recognizer.
            filename: Name used for logging only.

        Returns:
            The recognised text, or ``""`` when the result has no sentence.

        Raises:
            RuntimeError: when recognition reports an error, returns a
                non-200 status, or any other exception occurs (wrapped).
        """
        try:
            # NOTE: this sets the key on the ``dashscope`` module, i.e.
            # process-wide.
            dashscope.api_key = self._get_api_key()
            # We are inside a coroutine, so a loop is guaranteed to be
            # running; ``get_running_loop`` is the documented way to obtain
            # it, unlike ``get_event_loop``.
            loop = asyncio.get_running_loop()

            class SyncCallback(RecognitionCallback):
                """Records the last error reported by the recognizer."""

                error: str | None = None

                def on_error(self, result: Any) -> None:
                    self.error = str(result)

            callback = SyncCallback()
            recognizer = Recognition(
                model="fun-asr-realtime-2026-02-28",
                callback=callback,
                format="wav",
                sample_rate=16000,
            )
            result: Any = await loop.run_in_executor(
                None,
                lambda: recognizer.call(file=file_path),
            )
            if callback.error:
                raise RuntimeError(f"ASR error: {callback.error}")
            status_code = self._extract_field(result, "status_code")
            if status_code != 200:
                message = self._extract_field(result, "message")
                raise RuntimeError(f"ASR transcription failed: {message}")
            sentence = self._extract_sentence_payload(result)
            if sentence is None:
                request_id = self._extract_field(result, "request_id")
                logger.warning(
                    "ASR returned empty result", extra={"request_id": request_id}
                )
                return ""
            if isinstance(sentence, dict):
                # ``or ""`` also covers an explicit ``"text": None``.
                transcription = sentence.get("text") or ""
            elif isinstance(sentence, list):
                transcription = " ".join(
                    item.get("text") or "" for item in sentence if isinstance(item, dict)
                )
            else:
                transcription = str(sentence) if sentence else ""
            logger.info(
                "ASR transcription completed",
                # "filename" is a reserved LogRecord attribute: the standard
                # library logger raises KeyError when it appears in ``extra``,
                # which would turn every successful transcription into an
                # error. Use a non-clashing key instead.
                extra={
                    "audio_filename": filename,
                    "transcript_length": len(transcription),
                },
            )
            return transcription
        except asyncio.CancelledError:
            raise
        except RuntimeError:
            raise
        except Exception as exc:
            logger.exception("ASR transcription error")
            raise RuntimeError(f"ASR transcription failed: {exc}") from exc

    def _extract_sentence_payload(self, result: Any) -> Any | None:
        """Return the ``sentence`` payload of a recognition result, if any.

        Handles mapping results (``result["output"]["sentence"]`` or
        ``result["sentence"]``) as well as objects exposing ``get_sentence()``
        or an ``output`` attribute.
        """
        if isinstance(result, dict):
            output = result.get("output")
            if isinstance(output, dict):
                return output.get("sentence")
            if output is not None:
                return getattr(output, "sentence", None)
            return result.get("sentence")
        get_sentence = getattr(result, "get_sentence", None)
        if callable(get_sentence):
            sentence = get_sentence()
            if sentence is not None:
                return sentence
        output = getattr(result, "output", None)
        if output is None:
            return None
        if isinstance(output, dict):
            return output.get("sentence")
        return getattr(output, "sentence", None)

    def _extract_field(self, result: Any, field: str) -> Any | None:
        """Read ``field`` from a mapping or object result; ``None`` if absent."""
        if isinstance(result, dict):
            return result.get(field)
        return getattr(result, field, None)
# Module-level AsrService instance, created when this module is imported.
asr_service = AsrService()
def _extract_latest_user_content_blocks(
run_input: RunAgentInput,
) -> list[dict[str, Any]]:
if not run_input.messages:
return []
latest = run_input.messages[-1]
content = getattr(latest, "content", None)
if not isinstance(content, list):
return []
blocks: list[dict[str, Any]] = []
for item in content:
if isinstance(item, dict):
blocks.append(item)
continue
model_dump = getattr(item, "model_dump", None)
if callable(model_dump):
dumped = model_dump(mode="json", by_alias=True, exclude_none=True)
if isinstance(dumped, dict):
blocks.append(dumped)
return blocks
def _mime_to_suffix(mime_type: str) -> str:
mapping = {
"image/png": "png",
"image/jpeg": "jpg",
"image/webp": "webp",
}
return mapping.get(mime_type.lower(), "bin")
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return normalized.startswith(expected_prefix)