105 lines
3.3 KiB
Python
105 lines
3.3 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from collections.abc import AsyncIterator
|
||
|
|
import asyncio
|
||
|
|
from typing import Annotated
|
||
|
|
|
||
|
|
from fastapi import APIRouter, Depends, Header, Query, Request, status
|
||
|
|
from fastapi.responses import StreamingResponse
|
||
|
|
|
||
|
|
from core.agent.infrastructure.agui.stream import to_sse_event
|
||
|
|
from core.auth.models import CurrentUser
|
||
|
|
from v1.agent.dependencies import get_agent_service
|
||
|
|
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
|
||
|
|
from v1.agent.service import AgentService
|
||
|
|
from v1.users.dependencies import get_current_user
|
||
|
|
|
||
|
|
# Router for the agent endpoints: every route registered on it is served
# under the "/agent" prefix and grouped under the "agent" OpenAPI tag.
router = APIRouter(tags=["agent"], prefix="/agent")
|
||
|
|
|
||
|
|
|
||
|
|
@router.post(
    "/runs",
    status_code=status.HTTP_202_ACCEPTED,
    response_model=TaskAcceptedResponse,
)
async def enqueue_run(
    request: RunRequest,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Queue a new agent run and acknowledge it with 202 Accepted.

    The ``session_id`` and ``prompt`` from the request body are handed to
    ``AgentService.enqueue_run`` on behalf of the authenticated user. The
    response echoes the queued task's ``task_id``, ``session_id`` and
    ``created`` fields.
    """
    queued = await service.enqueue_run(
        session_id=request.session_id,
        prompt=request.prompt,
        current_user=current_user,
    )
    accepted = TaskAcceptedResponse(
        task_id=queued.task_id,
        session_id=queued.session_id,
        created=queued.created,
    )
    return accepted
|
||
|
|
|
||
|
|
|
||
|
|
@router.post(
    "/runs/{session_id}/resume",
    status_code=status.HTTP_202_ACCEPTED,
    response_model=TaskAcceptedResponse,
)
async def enqueue_resume(
    session_id: str,
    request: ResumeRequest,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Queue a resume task for ``session_id`` and acknowledge it with 202.

    The ``session_id`` comes from the URL path and the ``tool_call_id`` from
    the request body. Both are forwarded to ``AgentService.enqueue_resume``
    together with the authenticated user. The response echoes the queued
    task's ``task_id``, ``session_id`` and ``created`` fields.
    """
    tool_call_id = request.tool_call_id
    resumed = await service.enqueue_resume(
        session_id=session_id,
        tool_call_id=tool_call_id,
        current_user=current_user,
    )
    response = TaskAcceptedResponse(
        task_id=resumed.task_id,
        session_id=resumed.session_id,
        created=resumed.created,
    )
    return response
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/runs/{session_id}/events")
async def stream_events(
    request: Request,
    session_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    """Stream a session's agent events to the client as Server-Sent Events.

    The generator repeatedly calls ``service.stream_events`` with a cursor.
    The cursor starts at the ``Last-Event-ID`` request header, so a
    reconnecting client's last seen id is used as the starting point. Each
    row with a usable id and a dict ``event`` is emitted via
    ``to_sse_event``.

    The stream ends when the client disconnects, or after ``idle_limit``
    consecutive polls that do not advance the cursor. Idle polls are spaced
    0.2 s apart, so ``idle_limit`` counts polls, not seconds. The default of
    300 closes an idle stream after roughly 60 s plus query time.

    NOTE(review): any exception raised by ``service.stream_events`` (e.g. an
    access check, if it performs one) surfaces only after the 200 response
    headers are sent. Consider an eager first fetch if callers need proper
    HTTP error statuses.
    """
    poll_interval_s = 0.2

    async def _event_iter() -> AsyncIterator[str]:
        cursor = last_event_id
        idle_polls = 0
        while not await request.is_disconnected() and idle_polls < idle_limit:
            rows = await service.stream_events(
                session_id=session_id,
                last_event_id=cursor,
                current_user=current_user,
            )
            previous_cursor = cursor
            for row in rows or ():
                raw_id = row.get("id")
                # A missing or None id is unusable. str(None) would produce
                # the truthy string "None" and leak it as cursor / SSE id.
                row_id = "" if raw_id is None else str(raw_id)
                if not row_id:
                    continue
                # Advance past every identifiable row, even one whose payload
                # is malformed, so it is not re-fetched on every later poll.
                cursor = row_id
                event = row.get("event")
                if isinstance(event, dict):
                    yield to_sse_event(row_id, event)

            if cursor != previous_cursor:
                # Real progress: poll again immediately for more rows.
                idle_polls = 0
                continue

            # No progress (no rows, or only rows without a usable id). Count
            # it as idle, keep the connection open with an SSE comment line,
            # and back off instead of spinning on the service.
            idle_polls += 1
            yield ": keep-alive\n\n"
            await asyncio.sleep(poll_interval_s)

    return StreamingResponse(
        _event_iter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|