"""HTTP routes for the agent API.

Exposes run enqueue/cancel, a Server-Sent-Events stream of run events,
history snapshots, attachment upload / signed URLs and WAV transcription.
"""

from __future__ import annotations

import asyncio
import os
import re
import tempfile
from collections.abc import AsyncIterator
from datetime import date
from typing import Annotated

from ag_ui.core import RunAgentInput
from fastapi import (
    APIRouter,
    Depends,
    File,
    Form,
    Header,
    Query,
    Request,
    UploadFile,
    status,
)
from fastapi.responses import StreamingResponse
from redis.exceptions import TimeoutError as RedisTimeoutError

from core.http.errors import ApiProblemError, problem_payload
from core.agentscope.events import to_sse_event
from core.agentscope.schemas.agui_input import (
    parse_run_input,
    validate_run_request_messages_contract,
)
from core.auth.models import CurrentUser
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
    AsrTranscribeResponse,
    AttachmentReference,
    AttachmentSignedUrlResponse,
    AttachmentUploadResponse,
    CancelRunResponse,
    HistorySnapshotResponse,
    TaskAcceptedResponse,
)
from v1.agent.asr import asr_service
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user

router = APIRouter(prefix="/agent", tags=["agent"])
logger = get_logger("v1.agent.router")

# Accepted shape of the ``Last-Event-ID`` header: "<digits>-<digits>".
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
# Accepted shape of the ``runId`` query parameter on the SSE endpoint.
_RUN_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,128}$")

# Per-user cap on concurrently open SSE streams, tracked in a Redis counter.
_MAX_SSE_CONNECTIONS_PER_USER = 3
# Expiry applied to the per-user SSE counter key.
_SSE_SLOT_TTL_SECONDS = 15 * 60
# Pause between polls when a poll returned nothing (or timed out).
_SSE_IDLE_POLL_SLEEP_SECONDS = 0.2
# SSE comment frame emitted on every idle poll.
_SSE_KEEP_ALIVE_FRAME = ": keep-alive\n\n"
# Event types after which the SSE stream for a run is closed.
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}

_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
# Allowance added to the audio limit when checking the declared
# Content-Length of the multipart request.
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
# Number of leading bytes inspected for the RIFF/WAVE signature.
_WAV_HEADER_MIN_BYTES = 12
_ALLOWED_AUDIO_CONTENT_TYPES = {
    "audio/wav",
    "audio/x-wav",
    "audio/wave",
}


def _looks_like_wav_header(header: bytes) -> bool:
    """Return True when *header* starts with a RIFF/WAVE signature.

    Requires at least ``_WAV_HEADER_MIN_BYTES`` bytes, with ``RIFF`` at
    offset 0 and ``WAVE`` at offset 8.
    """
    if len(header) < _WAV_HEADER_MIN_BYTES:
        return False
    return header[0:4] == b"RIFF" and header[8:12] == b"WAVE"


async def _acquire_sse_slot(*, user_id: str) -> bool:
    """Try to reserve one SSE connection slot for *user_id*.

    Increments the per-user Redis counter. Returns False (after undoing the
    increment) when the counter exceeds ``_MAX_SSE_CONNECTIONS_PER_USER``;
    returns True otherwise. The key is given a TTL when it is first created,
    or whenever it is found without one.

    Any exception is logged and treated as success, so the limiter fails
    open.

    NOTE(review): a fail-open True is indistinguishable from a counted
    acquisition, and the caller always releases afterwards. If the failure
    happened before the increment, the later decrement can under-count other
    open streams of the same user - confirm whether that is acceptable.
    """
    try:
        redis = await get_or_init_redis_client()
        key = f"agent:sse-active:{user_id}"
        count = await redis.incr(key)
        if count == 1:
            # First active stream for this user: start the expiry clock.
            await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
        elif count > _MAX_SSE_CONNECTIONS_PER_USER:
            # Over the cap: roll back our increment and refuse.
            await redis.decr(key)
            return False
        else:
            # Re-arm the TTL if the key has none (negative TTL reply).
            ttl = await redis.ttl(key)
            if ttl < 0:
                await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
        return True
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "SSE slot acquire failed",
            user_id=user_id,
            reason=str(exc),
        )
        return True


async def _release_sse_slot(*, user_id: str) -> None:
    """Give back one SSE connection slot for *user_id*.

    Decrements the per-user counter; deletes the key once it reaches zero or
    below, otherwise makes sure the key still carries a TTL. Exceptions are
    logged and swallowed (best effort).
    """
    try:
        redis = await get_or_init_redis_client()
        key = f"agent:sse-active:{user_id}"
        count = await redis.decr(key)
        if count <= 0:
            await redis.delete(key)
        else:
            ttl = await redis.ttl(key)
            if ttl < 0:
                await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "SSE slot release failed",
            user_id=user_id,
            reason=str(exc),
        )
    return None


def _is_terminal_run_event(event: dict[str, object]) -> bool:
    """Return True when the event's ``type`` is one of the terminal types."""
    raw_event_type = event.get("type")
    return (
        isinstance(raw_event_type, str)
        and raw_event_type in _TERMINAL_RUN_EVENT_TYPES
    )


def _is_target_run_event(event: dict[str, object], *, target_run_id: str) -> bool:
    """Return True when the event's ``runId`` equals *target_run_id*."""
    run_id = event.get("runId")
    return isinstance(run_id, str) and run_id == target_run_id


def _audio_too_large_error() -> ApiProblemError:
    """Build the problem response used when the audio exceeds the size cap."""
    return ApiProblemError(
        status_code=400,
        detail=problem_payload(
            code="AGENT_AUDIO_TOO_LARGE",
            detail="Audio file too large",
            params={"maxBytes": _MAX_TRANSCRIBE_AUDIO_BYTES},
        ),
    )


def _audio_unsupported_format_error() -> ApiProblemError:
    """Build the problem response used when the audio is not accepted WAV."""
    return ApiProblemError(
        status_code=400,
        detail=problem_payload(
            code="AGENT_AUDIO_UNSUPPORTED_FORMAT",
            detail="Unsupported audio format",
        ),
    )


@router.post(
    "/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
async def enqueue_run(
    request: RunAgentInput,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Validate a run request and enqueue it.

    Raises:
        ApiProblemError: 422 ``AGENT_RUN_INPUT_INVALID`` when the payload
            cannot be parsed, 422 ``AGENT_RUN_MESSAGES_INVALID`` when the
            messages contract check fails.
    """
    try:
        # Round-trip through the project parser instead of rebinding the
        # ``request`` parameter to a differently-typed value.
        run_input = parse_run_input(
            request.model_dump(by_alias=True, exclude_none=True)
        )
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(code="AGENT_RUN_INPUT_INVALID", detail=str(exc)),
        ) from exc

    try:
        validate_run_request_messages_contract(run_input)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(code="AGENT_RUN_MESSAGES_INVALID", detail=str(exc)),
        ) from exc

    task = await service.enqueue_run(
        run_input=run_input,
        current_user=current_user,
    )
    return TaskAcceptedResponse(
        taskId=task.task_id,
        threadId=task.thread_id,
        runId=task.run_id,
        created=task.created,
    )


@router.post(
    "/runs/{thread_id}/cancel",
    response_model=CancelRunResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def cancel_run(
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    run_id: str = Query(
        alias="runId",
        min_length=1,
        max_length=128,
        pattern=r"^[A-Za-z0-9_-]+$",
    ),
) -> CancelRunResponse:
    """Forward a cancel request for a run to the agent service."""
    canceled = await service.cancel_run(
        thread_id=thread_id,
        run_id=run_id,
        current_user=current_user,
    )
    return CancelRunResponse(
        threadId=canceled.thread_id,
        runId=canceled.run_id,
        accepted=canceled.accepted,
    )


@router.get("/runs/{thread_id}/events")
async def stream_events(
    request: Request,
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    run_id: str | None = Query(default=None, alias="runId"),
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    """Stream the events of one run as Server-Sent Events.

    ``idle_limit`` is the number of consecutive idle polls tolerated before
    the stream is closed; each idle poll is followed by a
    ``_SSE_IDLE_POLL_SLEEP_SECONDS`` pause, so it is a poll count, not
    seconds.

    Raises:
        ApiProblemError: 422 for a missing/invalid ``runId`` or an invalid
            ``Last-Event-ID``; 429 when the per-user SSE slot cannot be
            acquired.
    """
    if run_id is None or _RUN_ID_RE.fullmatch(run_id) is None:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="AGENT_INVALID_RUN_ID",
                detail="Invalid runId",
            ),
        )
    if last_event_id is not None and (
        len(last_event_id) > 32 or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
    ):
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="AGENT_INVALID_LAST_EVENT_ID",
                detail="Invalid Last-Event-ID",
            ),
        )

    user_id = str(current_user.id)
    if not await _acquire_sse_slot(user_id=user_id):
        raise ApiProblemError(
            status_code=429,
            detail=problem_payload(
                code="AGENT_SSE_CONNECTION_LIMIT",
                detail="Too many SSE connections",
            ),
        )

    async def _event_iter() -> AsyncIterator[str]:
        # NOTE(review): the slot is released only from this generator's
        # ``finally``; if the response body is never iterated the slot is
        # reclaimed only by the key TTL - confirm this is acceptable.
        cursor = last_event_id
        idle_polls = 0
        terminal_event_reached = False
        try:
            while (
                not terminal_event_reached
                and not await request.is_disconnected()
                and idle_polls < idle_limit
            ):
                try:
                    rows = await service.stream_events(
                        thread_id=thread_id,
                        last_event_id=cursor,
                        current_user=current_user,
                    )
                except (TimeoutError, RedisTimeoutError):
                    # A read timeout is handled exactly like an empty poll.
                    rows = None
                except Exception as exc:  # noqa: BLE001
                    logger.warning(
                        "SSE stream read failed",
                        thread_id=thread_id,
                        user_id=user_id,
                        reason=str(exc),
                    )
                    break

                if not rows:
                    idle_polls += 1
                    yield _SSE_KEEP_ALIVE_FRAME
                    await asyncio.sleep(_SSE_IDLE_POLL_SLEEP_SECONDS)
                    continue

                idle_polls = 0
                for row in rows:
                    row_id = str(row.get("id", ""))
                    event = row.get("event")
                    if not row_id or not isinstance(event, dict):
                        continue
                    # Advance the cursor even for events of other runs so
                    # they are not fetched again on the next poll.
                    cursor = row_id
                    if not _is_target_run_event(event, target_run_id=run_id):
                        continue
                    yield to_sse_event(row_id, event)
                    if _is_terminal_run_event(event):
                        terminal_event_reached = True
                        break
        finally:
            await _release_sse_slot(user_id=user_id)

    return StreamingResponse(
        _event_iter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.get("/history", response_model=HistorySnapshotResponse)
async def get_user_history_snapshot(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str | None = Query(default=None, alias="threadId"),
    before: date | None = Query(default=None),
) -> HistorySnapshotResponse:
    """Return the history snapshot produced by the agent service."""
    return await service.get_user_history_snapshot(
        current_user=current_user,
        thread_id=thread_id,
        before=before,
    )


@router.post(
    "/attachments",
    response_model=AttachmentUploadResponse,
    status_code=status.HTTP_200_OK,
)
async def upload_attachment(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str = Form(alias="threadId"),
    file: UploadFile = File(),
) -> AttachmentUploadResponse:
    """Upload one attachment for a thread.

    Raises:
        ApiProblemError: 422 ``AGENT_ATTACHMENT_EMPTY`` for an empty file,
            413 ``AGENT_ATTACHMENT_TOO_LARGE`` when the file exceeds
            ``_MAX_ATTACHMENT_UPLOAD_BYTES``.
    """
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large body into memory.
    payload = await file.read(_MAX_ATTACHMENT_UPLOAD_BYTES + 1)
    if not payload:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_EMPTY",
                detail="Empty attachment",
            ),
        )
    if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
        raise ApiProblemError(
            status_code=413,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_TOO_LARGE",
                detail="Attachment too large",
                params={"maxBytes": _MAX_ATTACHMENT_UPLOAD_BYTES},
            ),
        )

    attachment = await service.upload_attachment(
        thread_id=thread_id,
        filename=file.filename,
        content_type=file.content_type,
        payload=payload,
        current_user=current_user,
    )
    return AttachmentUploadResponse(
        attachment=AttachmentReference.model_validate(attachment),
    )


@router.get(
    "/attachments/signed-url",
    response_model=AttachmentSignedUrlResponse,
    status_code=status.HTTP_200_OK,
)
async def create_attachment_signed_url(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    bucket: str = Query(min_length=1, max_length=100),
    path: str = Query(min_length=1, max_length=500),
) -> AttachmentSignedUrlResponse:
    """Ask the agent service for a signed URL to a stored attachment."""
    signed = await service.create_attachment_signed_url(
        bucket=bucket,
        path=path,
        current_user=current_user,
    )
    return AttachmentSignedUrlResponse(**signed)


@router.post(
    "/transcribe",
    response_model=AsrTranscribeResponse,
    status_code=status.HTTP_200_OK,
)
async def transcribe(
    audio: UploadFile,
    request: Request,
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AsrTranscribeResponse:
    """Transcribe an uploaded WAV file.

    The upload is copied in chunks to a temporary ``.wav`` file while its
    size and RIFF/WAVE signature are checked; the file is then passed to
    ``asr_service.transcribe_file`` and always removed afterwards.

    Raises:
        ApiProblemError: 400 for an unsupported content type / signature,
            an oversized or empty file; 502 ``AGENT_ASR_UNAVAILABLE`` when
            a ``RuntimeError`` is raised during processing.
    """
    temp_path: str | None = None
    try:
        if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
            raise _audio_unsupported_format_error()

        # Cheap early rejection based on the declared request size; an
        # unparsable header is ignored and the streamed check below decides.
        content_length = request.headers.get("content-length")
        if content_length is not None:
            declared_length: int | None
            try:
                declared_length = int(content_length)
            except ValueError:
                declared_length = None
            if (
                declared_length is not None
                and declared_length
                > _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
            ):
                raise _audio_too_large_error()

        total_bytes = 0
        header = bytearray()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            # Record the path first so ``finally`` can remove the file even
            # if copying fails part-way through.
            temp_path = tmp_file.name
            while True:
                chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
                if not chunk:
                    break
                total_bytes += len(chunk)
                if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
                    raise _audio_too_large_error()
                if len(header) < _WAV_HEADER_MIN_BYTES:
                    # Collect just enough leading bytes for the signature
                    # check, possibly across several chunks.
                    required = _WAV_HEADER_MIN_BYTES - len(header)
                    header.extend(chunk[:required])
                tmp_file.write(chunk)

        if total_bytes == 0:
            raise ApiProblemError(
                status_code=400,
                detail=problem_payload(
                    code="AGENT_AUDIO_EMPTY",
                    detail="Empty audio file",
                ),
            )
        if not _looks_like_wav_header(bytes(header)):
            raise _audio_unsupported_format_error()

        transcript = await asr_service.transcribe_file(
            temp_path, audio.filename or "unknown"
        )
        return AsrTranscribeResponse(transcript=transcript)
    except ApiProblemError:
        raise
    except RuntimeError as exc:
        # Keep the root cause visible in logs and in the exception chain.
        logger.warning(
            "ASR transcription failed",
            user_id=str(current_user.id),
            reason=str(exc),
        )
        raise ApiProblemError(
            status_code=502,
            detail=problem_payload(
                code="AGENT_ASR_UNAVAILABLE",
                detail="ASR service unavailable",
            ),
        ) from exc
    finally:
        await audio.close()
        if temp_path:
            try:
                os.unlink(temp_path)
            except OSError as exc:
                # Best effort: never fail the request over cleanup, but do
                # leave a trace so leaked temp files can be noticed.
                logger.warning(
                    "Temporary audio file cleanup failed",
                    path=temp_path,
                    reason=str(exc),
                )