feat(agent): add sse run/resume endpoints with auth
This commit is contained in:
@@ -3,17 +3,33 @@ from __future__ import annotations
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import AgentChatRunRequest, AgentChatRunResponse
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
|
||||
|
||||
@router.post("", response_model=AgentChatRunResponse)
|
||||
async def run_agent_chat(
|
||||
payload: AgentChatRunRequest,
|
||||
@router.post("/runs")
|
||||
async def create_run(
|
||||
input_data: RunAgentInput,
|
||||
service: Annotated[AgentChatService, Depends(get_agent_service)],
|
||||
) -> AgentChatRunResponse:
|
||||
return await service.run(payload)
|
||||
) -> StreamingResponse:
|
||||
return StreamingResponse(
|
||||
service.stream_run(input_data),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs/{run_id}/resume")
|
||||
async def resume_run(
|
||||
run_id: str,
|
||||
input_data: RunAgentInput,
|
||||
service: Annotated[AgentChatService, Depends(get_agent_service)],
|
||||
) -> StreamingResponse:
|
||||
return StreamingResponse(
|
||||
service.stream_resume(run_id, input_data),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -22,6 +23,7 @@ from v1.agent.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
RunAgentInput,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -375,3 +377,19 @@ class AgentChatService(BaseService):
|
||||
session.state_snapshot = snapshot
|
||||
|
||||
return ResumeDecisionResult(applied=True)
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "' + input_data.runId + '"}\n\n'
|
||||
|
||||
async def stream_resume(
|
||||
self, run_id: str, input_data: RunAgentInput
|
||||
) -> AsyncGenerator[str, None]:
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "' + run_id + '"}\n\n'
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class FakeAgentService:
|
||||
async def stream_run(self, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n'
|
||||
|
||||
async def stream_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n'
|
||||
|
||||
|
||||
def _get_test_user() -> CurrentUser:
|
||||
return CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> TestClient:
|
||||
app.dependency_overrides[get_current_user] = _get_test_user
|
||||
app.dependency_overrides[get_agent_service] = lambda: FakeAgentService()
|
||||
yield TestClient(app)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestChatRoutes:
|
||||
def test_run_route_streams_sse_events(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs", json=payload)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
events = response.text.split("\n\n")
|
||||
assert 'data: {"type": "RUN_STARTED"' in events[0]
|
||||
assert 'data: {"type": "TEXT_MESSAGE_START"' in events[1]
|
||||
|
||||
def test_resume_route_streams_sse_events(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r2",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs/r1/resume", json=payload)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
events = response.text.split("\n\n")
|
||||
assert 'data: {"type": "RUN_STARTED"' in events[0]
|
||||
assert 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"' in events[2]
|
||||
Reference in New Issue
Block a user