2026-03-05 15:34:37 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from collections.abc import AsyncIterator
|
|
|
|
|
import asyncio
|
2026-03-07 17:30:20 +08:00
|
|
|
from datetime import date
|
|
|
|
|
import re
|
|
|
|
|
import time
|
2026-03-08 17:34:28 +08:00
|
|
|
from typing import Annotated, Union
|
2026-03-05 15:34:37 +08:00
|
|
|
|
2026-03-07 17:30:20 +08:00
|
|
|
from ag_ui.core import RunAgentInput
|
2026-03-08 17:34:28 +08:00
|
|
|
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
|
2026-03-07 17:30:20 +08:00
|
|
|
from fastapi import HTTPException
|
2026-03-08 17:34:28 +08:00
|
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
2026-03-05 15:34:37 +08:00
|
|
|
|
|
|
|
|
from core.agent.infrastructure.agui.stream import to_sse_event
|
2026-03-08 16:01:16 +08:00
|
|
|
from core.agent.domain.agui_input import (
|
|
|
|
|
parse_run_input,
|
|
|
|
|
validate_run_request_messages_contract,
|
|
|
|
|
)
|
2026-03-05 15:34:37 +08:00
|
|
|
from core.auth.models import CurrentUser
|
2026-03-07 17:30:20 +08:00
|
|
|
from services.base.redis import get_or_init_redis_client
|
2026-03-05 15:34:37 +08:00
|
|
|
from v1.agent.dependencies import get_agent_service
|
2026-03-08 17:34:28 +08:00
|
|
|
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
|
|
|
|
|
from v1.agent.service import AgentService, asr_service
|
2026-03-05 15:34:37 +08:00
|
|
|
from v1.users.dependencies import get_current_user
|
|
|
|
|
|
|
|
|
|
# Router collecting every endpoint defined in this module under "/agent".
router = APIRouter(
    prefix="/agent",
    tags=["agent"],
)
|
2026-03-07 17:30:20 +08:00
|
|
|
# Accepted shape of a Last-Event-ID header value: two runs of decimal digits
# joined by a single hyphen (checked with fullmatch in stream_events).
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")

# Highest per-user counter value _allow_run_request still accepts within one
# minute bucket.
_RUNS_PER_MINUTE = 30

# Highest per-user counter value _acquire_sse_slot still accepts.
_MAX_SSE_CONNECTIONS_PER_USER = 3

# Expiry (seconds) applied to the per-user SSE counter key when it is created.
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _allow_run_request(*, user_id: str) -> bool:
    """Decide whether *user_id* may start another run in the current minute.

    A Redis counter keyed by the user and the current minute bucket
    (``int(time.time() // 60)``) is incremented and compared with
    ``_RUNS_PER_MINUTE``.

    Args:
        user_id: Identifier embedded in the Redis key.

    Returns:
        ``True`` while the counter for the current bucket does not exceed
        ``_RUNS_PER_MINUTE``. ``False`` once it does, or when any step raises:
        the limiter fails closed.
    """
    import logging

    try:
        redis = await get_or_init_redis_client()
        minute_bucket = int(time.time() // 60)
        key = f"agent:run-rate:{user_id}:{minute_bucket}"
        # Normalise once so the "first hit" test and the limit test agree on
        # the type of the counter.
        count = int(await redis.incr(key))
        if count == 1:
            # First hit in this bucket: give the key an expiry slightly longer
            # than the minute it covers so old buckets are discarded.
            await redis.expire(key, 70)
        return count <= _RUNS_PER_MINUTE
    except Exception:  # noqa: BLE001
        # Still fail closed, but leave a trace: without it a limiter failure
        # is indistinguishable from a caller genuinely exceeding the limit.
        logging.getLogger(__name__).warning(
            "Run rate-limit check failed for user %s; denying request",
            user_id,
            exc_info=True,
        )
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _acquire_sse_slot(*, user_id: str) -> bool:
    """Try to reserve one SSE connection slot for *user_id*.

    The per-user counter ``agent:sse-active:{user_id}`` is incremented. If the
    new value exceeds ``_MAX_SSE_CONNECTIONS_PER_USER`` the increment is undone
    and the slot is refused.

    Args:
        user_id: Identifier embedded in the Redis key.

    Returns:
        ``True`` when a slot was reserved. ``False`` when the user is already
        at the limit, or when any step raises: the guard fails closed.
    """
    import logging

    try:
        redis = await get_or_init_redis_client()
        key = f"agent:sse-active:{user_id}"
        count = int(await redis.incr(key))
        if count == 1:
            # The expiry is set only when the counter is created; later
            # acquisitions do not extend it.
            # NOTE(review): INCR and EXPIRE are separate calls. If EXPIRE
            # fails here the key is left without a TTL while this function
            # reports failure -- consider making the pair atomic.
            await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
        if count > _MAX_SSE_CONNECTIONS_PER_USER:
            # Over the limit: hand back the slot that was just taken.
            await redis.decr(key)
            return False
        return True
    except Exception:  # noqa: BLE001
        # Still fail closed, but record why; otherwise the caller's 429 looks
        # like an ordinary "too many connections" refusal.
        logging.getLogger(__name__).warning(
            "Could not acquire SSE slot for user %s; refusing connection",
            user_id,
            exc_info=True,
        )
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _release_sse_slot(*, user_id: str) -> None:
    """Give back one SSE connection slot held by *user_id*.

    Decrements ``agent:sse-active:{user_id}`` and deletes the key once the
    counter reaches zero or below. The release is best-effort: failures are
    logged and never raised.

    Args:
        user_id: Identifier embedded in the Redis key.
    """
    import logging

    try:
        redis = await get_or_init_redis_client()
        key = f"agent:sse-active:{user_id}"
        count = await redis.decr(key)
        if int(count) <= 0:
            # Nothing is held any more (or the counter went negative); drop
            # the key rather than leave a zero/negative value behind.
            await redis.delete(key)
    except Exception:  # noqa: BLE001
        # Never let cleanup break the caller, but make a missed release
        # visible: the slot stays counted until the key is removed.
        logging.getLogger(__name__).warning(
            "Could not release SSE slot for user %s",
            user_id,
            exc_info=True,
        )
    return None
|
2026-03-05 15:34:37 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post(
    "/runs",
    response_model=TaskAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_run(
    request: RunAgentInput,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Validate a run request, apply the per-user rate limit, and enqueue it.

    Args:
        request: Run input taken from the request body.
        service: Agent service that performs the enqueue.
        current_user: Authenticated caller.

    Returns:
        The identifiers and creation value of the enqueued task.

    Raises:
        HTTPException: 422 when dumping, parsing or the messages-contract
            check raises ``ValueError``; 429 when ``_allow_run_request``
            denies the caller.
    """
    try:
        raw_payload = request.model_dump(mode="json", by_alias=True)
        validate_run_request_messages_contract(parse_run_input(raw_payload))
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    # The rate limit is consulted only after the payload passed validation.
    if not await _allow_run_request(user_id=str(current_user.id)):
        raise HTTPException(status_code=429, detail="Too many run requests")

    queued = await service.enqueue_run(run_input=request, current_user=current_user)
    accepted = TaskAcceptedResponse(
        taskId=queued.task_id,
        threadId=queued.thread_id,
        runId=queued.run_id,
        created=queued.created,
    )
    return accepted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post(
    "/runs/{thread_id}/resume",
    response_model=TaskAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
    thread_id: str,
    request: RunAgentInput,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Validate a resume request for *thread_id* and enqueue it.

    Args:
        thread_id: Thread identifier from the URL path.
        request: Run input taken from the request body.
        service: Agent service that performs the enqueue.
        current_user: Authenticated caller.

    Returns:
        The identifiers and creation value of the enqueued task.

    Raises:
        HTTPException: 422 when the body's ``thread_id`` differs from the
            path value, or when dumping/parsing the input raises
            ``ValueError``.
    """
    # The path and the body must name the same thread.
    if thread_id != request.thread_id:
        raise HTTPException(status_code=422, detail="thread_id path/body mismatch")

    try:
        # Parsed for validation only; the result itself is not used.
        raw_payload = request.model_dump(mode="json", by_alias=True)
        parse_run_input(raw_payload)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    queued = await service.enqueue_resume(
        thread_id=thread_id,
        run_input=request,
        current_user=current_user,
    )
    accepted = TaskAcceptedResponse(
        taskId=queued.task_id,
        threadId=queued.thread_id,
        runId=queued.run_id,
        created=queued.created,
    )
    return accepted
|
|
|
|
|
|
|
|
|
|
|
2026-03-07 17:30:20 +08:00
|
|
|
@router.get("/runs/{thread_id}/events")
async def stream_events(
    request: Request,
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    """Stream the events of a thread to the client as Server-Sent Events.

    The service is polled repeatedly starting after ``last_event_id``; every
    well-formed row is emitted through ``to_sse_event``. The stream ends when
    the client disconnects or after ``idle_limit`` consecutive polls without
    progress.

    Args:
        request: Incoming request, used to detect client disconnects.
        thread_id: Thread whose events are streamed.
        service: Agent service providing ``stream_events``.
        current_user: Authenticated caller.
        last_event_id: Optional resume cursor from the ``Last-Event-ID``
            header; at most 32 characters and matching ``_LAST_EVENT_ID_RE``.
        idle_limit: Number of consecutive polls without progress tolerated
            before the stream is closed. Each such poll is followed by a
            0.2 s pause.

    Returns:
        A ``text/event-stream`` streaming response.

    Raises:
        HTTPException: 422 for a malformed ``Last-Event-ID``; 429 when no SSE
            slot could be acquired for the user.
    """
    if last_event_id is not None and (
        len(last_event_id) > 32 or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
    ):
        raise HTTPException(status_code=422, detail="Invalid Last-Event-ID")

    sse_slot_acquired = await _acquire_sse_slot(user_id=str(current_user.id))
    if not sse_slot_acquired:
        raise HTTPException(status_code=429, detail="Too many SSE connections")

    async def _event_iter() -> AsyncIterator[str]:
        cursor = last_event_id
        idle_polls = 0
        try:
            while not await request.is_disconnected() and idle_polls < idle_limit:
                rows = await service.stream_events(
                    thread_id=thread_id,
                    last_event_id=cursor,
                    current_user=current_user,
                )

                advanced = False
                for row in rows or ():
                    row_id = str(row.get("id", ""))
                    if not row_id:
                        # Without an id there is nothing to resume from.
                        continue
                    # Move past every identified row, even one whose payload
                    # is unusable; otherwise a malformed trailing row would be
                    # requested again on every poll.
                    cursor = row_id
                    advanced = True
                    event = row.get("event")
                    if isinstance(event, dict):
                        yield to_sse_event(row_id, event)

                if advanced:
                    idle_polls = 0
                    continue

                # Empty batch, or a batch that could not move the cursor:
                # treat it as idle so the loop pauses and eventually times out
                # instead of re-polling the same rows in a tight spin.
                idle_polls += 1
                yield ": keep-alive\n\n"
                await asyncio.sleep(0.2)
        finally:
            # Runs once the generator has started and then finishes or is
            # closed, so the slot taken above is returned.
            await _release_sse_slot(user_id=str(current_user.id))

    return StreamingResponse(
        _event_iter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
2026-03-07 17:30:20 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/runs/{thread_id}/history")
async def get_history_snapshot(
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    before: date | None = Query(default=None),
) -> dict[str, object]:
    """Return the history snapshot the agent service produces for one thread.

    Args:
        thread_id: Thread identifier from the URL path.
        service: Agent service that builds the snapshot.
        current_user: Authenticated caller.
        before: Optional date query parameter, passed through unchanged.

    Returns:
        Whatever ``service.get_history_snapshot`` returns.
    """
    snapshot = await service.get_history_snapshot(
        thread_id=thread_id,
        before=before,
        current_user=current_user,
    )
    return snapshot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/history")
async def get_user_history_snapshot(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str | None = Query(default=None, alias="threadId"),
    before: date | None = Query(default=None),
) -> dict[str, object]:
    """Return the history snapshot the agent service produces for the caller.

    Args:
        service: Agent service that builds the snapshot.
        current_user: Authenticated caller.
        thread_id: Optional ``threadId`` query parameter, passed through
            unchanged.
        before: Optional date query parameter, passed through unchanged.

    Returns:
        Whatever ``service.get_user_history_snapshot`` returns.
    """
    snapshot = await service.get_user_history_snapshot(
        current_user=current_user,
        thread_id=thread_id,
        before=before,
    )
    return snapshot
|
2026-03-08 17:34:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post(
    "/transcribe",
    response_model=AsrTranscribeResponse,
    status_code=status.HTTP_200_OK,
)
async def transcribe(
    audio: UploadFile,
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> Union[AsrTranscribeResponse, JSONResponse]:
    """Transcribe an uploaded audio file through ``asr_service``.

    Args:
        audio: Uploaded file; its full content is read and handed to the ASR
            service together with its filename (``"unknown"`` when absent).
        current_user: Resolved through ``get_current_user``; not otherwise
            used in the body.

    Returns:
        ``AsrTranscribeResponse`` with the transcript on success. On failure a
        ``JSONResponse`` carrying ``{"detail": <message>}``: status 400 for a
        ``ValueError`` (including an empty upload), status 500 for a
        ``RuntimeError``.
    """
    try:
        payload = await audio.read()
        if not payload:
            raise ValueError("Empty audio file")

        source_name = audio.filename or "unknown"
        text = await asr_service.transcribe(payload, source_name)

        return AsrTranscribeResponse(transcript=text)
    except (ValueError, RuntimeError) as exc:
        # ValueError is checked first, matching the order in which the two
        # error kinds are mapped to status codes.
        if isinstance(exc, ValueError):
            failure_status = status.HTTP_400_BAD_REQUEST
        else:
            failure_status = status.HTTP_500_INTERNAL_SERVER_ERROR
        return JSONResponse(
            status_code=failure_status,
            content={"detail": str(exc)},
        )
|