387 lines
12 KiB
Python
387 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
import re
|
|
import tempfile
|
|
from collections.abc import AsyncIterator
|
|
from datetime import date
|
|
from typing import Annotated
|
|
|
|
from ag_ui.core import RunAgentInput
|
|
from core.agentscope.events import to_sse_event
|
|
from core.agentscope.schemas.agui_input import (
|
|
parse_run_input,
|
|
validate_run_request_messages_contract,
|
|
)
|
|
from core.auth.models import CurrentUser
|
|
from core.logging import get_logger
|
|
from redis.exceptions import TimeoutError as RedisTimeoutError
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
File,
|
|
Form,
|
|
Header,
|
|
HTTPException,
|
|
Query,
|
|
Request,
|
|
UploadFile,
|
|
status,
|
|
)
|
|
from fastapi.responses import StreamingResponse
|
|
from services.base.redis import get_or_init_redis_client
|
|
from v1.agent.dependencies import get_agent_service
|
|
from v1.agent.schemas import (
|
|
AsrTranscribeResponse,
|
|
AttachmentReference,
|
|
AttachmentSignedUrlResponse,
|
|
AttachmentUploadResponse,
|
|
CancelRunResponse,
|
|
HistorySnapshotResponse,
|
|
TaskAcceptedResponse,
|
|
)
|
|
from v1.agent.asr import asr_service
|
|
from v1.agent.service import AgentService
|
|
from v1.users.dependencies import get_current_user
|
|
|
|
# All endpoints in this module are mounted under /agent and grouped under the
# "agent" OpenAPI tag.
router = APIRouter(tags=["agent"], prefix="/agent")
logger = get_logger("v1.agent.router")
|
|
# Accepted shape of a client-supplied Last-Event-ID: "<digits>-<digits>"
# (presumably a Redis stream entry id — confirm against the event store).
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")

# Per-user cap on concurrently open SSE streams, and the TTL (seconds) put on
# the Redis counter key that tracks them.
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 60 * 15

# Event "type" values after which an SSE stream is closed.
_TERMINAL_RUN_EVENT_TYPES = {"RUN_ERROR", "RUN_FINISHED"}

# Size limits and read granularity for uploads, in bytes.
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024**2
_TRANSCRIBE_READ_CHUNK_BYTES = 1024**2
# Allowance for multipart framing when comparing Content-Length to the audio limit.
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024**2

# Bytes needed to see both the "RIFF" tag (0:4) and the "WAVE" form type (8:12).
_WAV_HEADER_MIN_BYTES = 12

# Content types accepted by the transcription endpoint.
_ALLOWED_AUDIO_CONTENT_TYPES = {"audio/wav", "audio/x-wav", "audio/wave"}
|
|
|
|
|
|
def _looks_like_wav_header(header: bytes) -> bool:
    """Return True when *header* carries a RIFF container tag and WAVE form type.

    Inputs shorter than ``_WAV_HEADER_MIN_BYTES`` are rejected outright.
    """
    if len(header) >= _WAV_HEADER_MIN_BYTES:
        # Bytes 0-3 are the chunk id, 4-7 the chunk size, 8-11 the form type.
        return header.startswith(b"RIFF") and header[8:12] == b"WAVE"
    return False
|
|
|
|
|
|
async def _acquire_sse_slot(*, user_id: str) -> bool:
    """Try to reserve one of *user_id*'s concurrent SSE slots.

    The slot count lives in the Redis key ``agent:sse-active:<user_id>``. The
    key is incremented, and if that pushes it past
    ``_MAX_SSE_CONNECTIONS_PER_USER`` the increment is undone and the request
    is refused. The key is (re)given a TTL of ``_SSE_SLOT_TTL_SECONDS`` when it
    is first created or when it is found without one.

    Returns:
        True if a slot was reserved, False if the user is at the limit.
        Any Redis error is logged and treated as success, so the limiter
        fails open.
    """
    slot_key = f"agent:sse-active:{user_id}"
    try:
        client = await get_or_init_redis_client()
        active = await client.incr(slot_key)

        if active == 1:
            # Freshly created counter: attach the TTL so it cannot live forever.
            await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
            return True

        if active > _MAX_SSE_CONNECTIONS_PER_USER:
            # Over the cap: roll back our increment and refuse.
            await client.decr(slot_key)
            return False

        # Repair a counter that lost its TTL (e.g. a crash between INCR and EXPIRE).
        if await client.ttl(slot_key) < 0:
            await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
        return True
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "SSE slot acquire failed",
            user_id=user_id,
            reason=str(exc),
        )
        return True
|
|
|
|
|
|
async def _release_sse_slot(*, user_id: str) -> None:
    """Give back one of *user_id*'s SSE slots reserved by ``_acquire_sse_slot``.

    Decrements ``agent:sse-active:<user_id>``. The key is deleted once the
    count reaches zero (or below); otherwise its TTL is restored if missing.
    Redis errors are logged and swallowed, so release is best-effort.
    """
    slot_key = f"agent:sse-active:{user_id}"
    try:
        client = await get_or_init_redis_client()
        remaining = await client.decr(slot_key)

        if remaining <= 0:
            # Last slot gone (or counter drifted negative): drop the key entirely.
            await client.delete(slot_key)
            return

        if await client.ttl(slot_key) < 0:
            await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "SSE slot release failed",
            user_id=user_id,
            reason=str(exc),
        )
|
|
|
|
|
|
def _is_terminal_run_event(event: dict[str, object]) -> bool:
    """Return True if *event*'s ``"type"`` is one of ``_TERMINAL_RUN_EVENT_TYPES``.

    A missing or non-string ``"type"`` is never terminal.
    """
    event_type = event.get("type")
    if not isinstance(event_type, str):
        return False
    return event_type in _TERMINAL_RUN_EVENT_TYPES
|
|
|
|
|
|
@router.post(
    "/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
async def enqueue_run(
    request: RunAgentInput,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Validate an AG-UI run request and hand it to ``service.enqueue_run``.

    The body is re-parsed with ``parse_run_input`` and then checked with
    ``validate_run_request_messages_contract``. A ``ValueError`` from either
    step is reported as HTTP 422 with the error text as detail. On success,
    responds 202 with the task, thread and run identifiers returned by the
    service.
    """
    try:
        run_input = parse_run_input(
            request.model_dump(by_alias=True, exclude_none=True)
        )
        validate_run_request_messages_contract(run_input)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    task = await service.enqueue_run(run_input=run_input, current_user=current_user)

    return TaskAcceptedResponse(
        taskId=task.task_id,
        threadId=task.thread_id,
        runId=task.run_id,
        created=task.created,
    )
|
|
|
|
|
|
@router.post(
    "/runs/{thread_id}/cancel",
    response_model=CancelRunResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def cancel_run(
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    run_id: str = Query(
        alias="runId",
        min_length=1,
        max_length=128,
        pattern=r"^[A-Za-z0-9_-]+$",
    ),
) -> CancelRunResponse:
    """Ask the agent service to cancel run ``runId`` on ``thread_id``.

    ``runId`` must be 1-128 characters from ``[A-Za-z0-9_-]``. The response
    echoes the thread/run ids and the ``accepted`` flag from the service result.
    """
    outcome = await service.cancel_run(
        thread_id=thread_id,
        run_id=run_id,
        current_user=current_user,
    )
    response = CancelRunResponse(
        threadId=outcome.thread_id,
        runId=outcome.run_id,
        accepted=outcome.accepted,
    )
    return response
|
|
|
|
|
|
@router.get("/runs/{thread_id}/events")
async def stream_events(
    request: Request,
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    """Stream a thread's run events to the client as Server-Sent Events.

    ``Last-Event-ID``, when given, must be ``<digits>-<digits>`` and at most 32
    characters (otherwise 422); streaming resumes after it. Each user may hold
    at most ``_MAX_SSE_CONNECTIONS_PER_USER`` streams (otherwise 429). The
    limiter fails open when Redis is unavailable.

    The generator polls ``service.stream_events`` and stops when any of these
    happens:
    - a terminal run event (``RUN_FINISHED`` / ``RUN_ERROR``) has been sent;
    - the client disconnects;
    - a non-timeout read error occurs (logged);
    - ``idle_limit`` consecutive polls make no progress.

    Every idle poll emits a ``: keep-alive`` SSE comment and waits 0.2 s.

    NOTE(review): the slot is released in the generator's ``finally``, which
    only runs once the generator has started. If the response body is never
    iterated, the slot stays counted until the Redis key's TTL expires —
    confirm against the ASGI server's behavior.
    """
    if last_event_id is not None and (
        len(last_event_id) > 32 or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
    ):
        raise HTTPException(status_code=422, detail="Invalid Last-Event-ID")

    user_id = str(current_user.id)
    if not await _acquire_sse_slot(user_id=user_id):
        raise HTTPException(status_code=429, detail="Too many SSE connections")

    async def _event_iter() -> AsyncIterator[str]:
        cursor = last_event_id
        idle_polls = 0
        terminal_event_reached = False
        try:
            while (
                not terminal_event_reached
                and not await request.is_disconnected()
                and idle_polls < idle_limit
            ):
                try:
                    rows = await service.stream_events(
                        thread_id=thread_id,
                        last_event_id=cursor,
                        current_user=current_user,
                    )
                except (TimeoutError, RedisTimeoutError):
                    # A read timeout is just an idle poll.
                    rows = []
                except Exception as exc:  # noqa: BLE001
                    logger.warning(
                        "SSE stream read failed",
                        thread_id=thread_id,
                        user_id=user_id,
                        reason=str(exc),
                    )
                    break

                previous_cursor = cursor
                for row in rows:
                    row_id = str(row.get("id", ""))
                    if not row_id:
                        continue
                    # Advance past every identified row, including malformed
                    # ones. Otherwise the next read returns the same row again
                    # and this loop spins without ever sleeping or yielding.
                    cursor = row_id
                    event = row.get("event")
                    if not isinstance(event, dict):
                        continue
                    yield to_sse_event(row_id, event)
                    if _is_terminal_run_event(event):
                        terminal_event_reached = True
                        break

                if cursor != previous_cursor:
                    # Progress was made: poll again immediately.
                    idle_polls = 0
                    continue

                # No new rows (or none we could advance past): count as idle.
                idle_polls += 1
                yield ": keep-alive\n\n"
                await asyncio.sleep(0.2)
        finally:
            await _release_sse_slot(user_id=user_id)

    return StreamingResponse(
        _event_iter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
@router.get("/history", response_model=HistorySnapshotResponse)
async def get_user_history_snapshot(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str | None = Query(default=None, alias="threadId"),
    before: date | None = Query(default=None),
) -> HistorySnapshotResponse:
    """Return the caller's history snapshot from the agent service.

    The optional ``threadId`` and ``before`` query parameters are passed
    through unchanged to ``service.get_user_history_snapshot``. Their filtering
    semantics are defined by the service.
    """
    snapshot = await service.get_user_history_snapshot(
        before=before,
        thread_id=thread_id,
        current_user=current_user,
    )
    return snapshot
|
|
|
|
|
|
@router.post(
    "/attachments",
    response_model=AttachmentUploadResponse,
    status_code=status.HTTP_200_OK,
)
async def upload_attachment(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str = Form(alias="threadId"),
    file: UploadFile = File(),
) -> AttachmentUploadResponse:
    """Store one uploaded file as an attachment for ``threadId``.

    At most ``_MAX_ATTACHMENT_UPLOAD_BYTES + 1`` bytes are read from the upload.
    This bounds memory use, while still detecting an oversized file, instead of
    loading the whole file before checking its size.

    Raises:
        HTTPException: 422 if the upload is empty; 413 if it is larger than
            ``_MAX_ATTACHMENT_UPLOAD_BYTES``.
    """
    try:
        # One byte past the limit is enough to tell "too large" apart from
        # "exactly at the limit".
        payload = await file.read(_MAX_ATTACHMENT_UPLOAD_BYTES + 1)
    finally:
        # Release the spooled upload promptly (mirrors the transcribe endpoint).
        await file.close()

    if not payload:
        raise HTTPException(status_code=422, detail="Empty attachment")
    if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Attachment too large")

    attachment = await service.upload_attachment(
        thread_id=thread_id,
        filename=file.filename,
        content_type=file.content_type,
        payload=payload,
        current_user=current_user,
    )
    return AttachmentUploadResponse(
        attachment=AttachmentReference.model_validate(attachment),
    )
|
|
|
|
|
|
@router.get(
    "/attachments/signed-url",
    response_model=AttachmentSignedUrlResponse,
    status_code=status.HTTP_200_OK,
)
async def create_attachment_signed_url(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    bucket: str = Query(min_length=1, max_length=100),
    path: str = Query(min_length=1, max_length=500),
) -> AttachmentSignedUrlResponse:
    """Ask the agent service for a signed URL to ``bucket``/``path``.

    ``bucket`` must be 1-100 characters and ``path`` 1-500 characters. The
    service result is unpacked as keyword arguments into
    ``AttachmentSignedUrlResponse``, so its keys are expected to match that
    model's fields (verify against the service).
    """
    signed_fields = await service.create_attachment_signed_url(
        current_user=current_user,
        bucket=bucket,
        path=path,
    )
    return AttachmentSignedUrlResponse(**signed_fields)
|
|
|
|
|
|
@router.post(
    "/transcribe",
    response_model=AsrTranscribeResponse,
    status_code=status.HTTP_200_OK,
)
async def transcribe(
    audio: UploadFile,
    request: Request,
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AsrTranscribeResponse:
    """Transcribe an uploaded WAV file with ``asr_service``.

    The upload is copied chunk by chunk into a named temporary file. This lets
    the size limit be enforced while reading, and lets the file be handed to
    ``asr_service.transcribe_file`` by path. The upload is always closed and
    the temporary file is always removed (best-effort).

    Raises:
        HTTPException: 400 if the content type is not an allowed WAV type,
            the declared Content-Length exceeds the audio limit plus multipart
            overhead, the streamed payload exceeds
            ``_MAX_TRANSCRIBE_AUDIO_BYTES``, the payload is empty, or it lacks
            a RIFF/WAVE header; 502 if the ASR call raises ``RuntimeError``.
    """
    temp_path: str | None = None
    try:
        if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
            raise HTTPException(status_code=400, detail="Unsupported audio format")

        # Cheap early rejection from the declared request size. The header
        # covers the whole multipart body, hence the overhead allowance; a
        # malformed value simply defers to the streaming check below.
        content_length = request.headers.get("content-length")
        if content_length is not None:
            try:
                declared_length: int | None = int(content_length)
            except ValueError:
                declared_length = None
            if (
                declared_length is not None
                and declared_length
                > _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
            ):
                raise HTTPException(status_code=400, detail="Audio file too large")

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            temp_path = tmp_file.name

            total_bytes = 0
            header = bytearray()
            while True:
                chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
                if not chunk:
                    break
                total_bytes += len(chunk)
                # The authoritative limit: Content-Length may be absent or wrong.
                if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
                    raise HTTPException(status_code=400, detail="Audio file too large")
                # Keep only the leading bytes needed to sniff the WAV signature.
                if len(header) < _WAV_HEADER_MIN_BYTES:
                    header.extend(chunk[: _WAV_HEADER_MIN_BYTES - len(header)])
                tmp_file.write(chunk)

        if total_bytes == 0:
            raise HTTPException(status_code=400, detail="Empty audio file")
        if not _looks_like_wav_header(bytes(header)):
            raise HTTPException(status_code=400, detail="Unsupported audio format")

        # Only the ASR call maps RuntimeError to 502; the cause is chained and
        # logged so upstream failures stay diagnosable.
        try:
            transcript = await asr_service.transcribe_file(
                temp_path, audio.filename or "unknown"
            )
        except RuntimeError as exc:
            logger.warning(
                "ASR transcription failed",
                user_id=str(current_user.id),
                reason=str(exc),
            )
            raise HTTPException(
                status_code=502, detail="ASR service unavailable"
            ) from exc

        return AsrTranscribeResponse(transcript=transcript)

    finally:
        await audio.close()
        if temp_path:
            try:
                os.unlink(temp_path)
            except OSError:
                # Best-effort cleanup: never let it mask the request's outcome.
                pass
|