83 lines
3.1 KiB
Python
83 lines
3.1 KiB
Python
from __future__ import annotations

from collections.abc import Iterator
from uuid import UUID

import pytest
from fastapi.testclient import TestClient

from app import app
from core.auth.models import CurrentUser
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import RunAgentInput
from v1.users.dependencies import get_current_user
|
|
|
|
|
|
class FakeAgentService:
    """Agent-service double that replays fixed, pre-formatted SSE frames.

    The streaming methods yield a hard-coded, ordered sequence of
    ``data: ...`` strings. The ``run_id`` and ``input_data`` arguments are
    accepted but never read.
    """

    # Frames produced by ``stream_run``, in emission order.
    _RUN_FRAMES = (
        'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n',
        'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n',
        'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n',
        'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n',
        'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n',
    )

    # Frames produced by ``stream_resume``, in emission order.
    _RESUME_FRAMES = (
        'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n',
        'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n',
        'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n',
        'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n',
        'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n',
    )

    async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
        """Do nothing and resolve to ``None``."""
        return None

    async def stream_run(self, input_data: RunAgentInput):
        """Yield the five canned frames of a run whose message says "Hello"."""
        for frame in self._RUN_FRAMES:
            yield frame

    async def stream_resume(self, run_id: str, input_data: RunAgentInput):
        """Yield the five canned frames of a resumed run ("Resumed")."""
        for frame in self._RESUME_FRAMES:
            yield frame
|
|
|
|
|
|
def _get_test_user() -> CurrentUser:
    """Return a ``CurrentUser`` carrying a fixed, hard-coded UUID."""
    fixed_id = UUID("00000000-0000-0000-0000-000000000001")
    return CurrentUser(id=fixed_id)
|
|
|
|
|
|
@pytest.fixture
def client() -> Iterator[TestClient]:
    """Yield a ``TestClient`` for ``app`` with auth and agent service faked.

    While the fixture is active, ``get_current_user`` resolves to
    ``_get_test_user`` and ``get_agent_service`` resolves to a fresh
    ``FakeAgentService``.

    Yields:
        A ``TestClient`` bound to ``app``.
    """
    # Snapshot so teardown restores whatever was installed before this
    # fixture ran, instead of wiping overrides owned by other fixtures.
    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_current_user] = _get_test_user
    app.dependency_overrides[get_agent_service] = lambda: FakeAgentService()
    try:
        yield TestClient(app)
    finally:
        # Runs even if building the client raises, so the overrides set
        # above cannot leak into later tests.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
|
|
|
|
|
class TestChatRoutes:
    """Checks that the agent run and resume routes respond with SSE streams."""

    @staticmethod
    def _agent_input(**extra: object) -> dict[str, object]:
        """Return the minimal request body, with ``extra`` keys appended."""
        body: dict[str, object] = {
            "threadId": "t1",
            "runId": "r1",
            "state": {},
            "messages": [],
            "tools": [],
            "context": [],
            "forwardedProps": {},
        }
        body.update(extra)
        return body

    @staticmethod
    def _post_and_split(
        client: TestClient, url: str, body: dict[str, object]
    ) -> list[str]:
        """POST ``body`` to ``url``, check status and content type, split frames.

        The response text is split on the blank line that terminates each
        frame, so element ``i`` holds the text of the ``i``-th frame.
        """
        response = client.post(url, json=body)
        assert response.status_code == 200
        assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
        return response.text.split("\n\n")

    def test_run_route_streams_sse_events(self, client: TestClient):
        """The run route streams RUN_STARTED, then TEXT_MESSAGE_START."""
        frames = self._post_and_split(
            client, "/api/v1/agent/runs", self._agent_input()
        )

        assert 'data: {"type": "RUN_STARTED"' in frames[0]
        assert 'data: {"type": "TEXT_MESSAGE_START"' in frames[1]

    def test_resume_route_streams_sse_events(self, client: TestClient):
        """The resume route streams RUN_STARTED and the "Resumed" content frame."""
        resume_body = self._agent_input(
            resume={"interruptId": "int-1", "payload": {"decision": "approved"}},
        )
        frames = self._post_and_split(
            client, "/api/v1/agent/runs/r1/resume", resume_body
        )

        assert 'data: {"type": "RUN_STARTED"' in frames[0]
        assert 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"' in frames[2]
|