"""Tests for the agent router's Redis-backed guards and the resume endpoint.

The guard helpers (`_allow_run_request`, `_acquire_sse_slot`,
`_allow_transcribe_request`) are expected to fail closed, i.e. return False
when the Redis client cannot be obtained. `enqueue_resume` is expected to
reject a request without a tool message (422) and to reject a rate-limited
request (429). Both rejections must happen before the service is called.
"""

from __future__ import annotations

from types import SimpleNamespace
from typing import Any, cast
from uuid import uuid4

from ag_ui.core import RunAgentInput
from fastapi import HTTPException
import pytest

from core.auth.models import CurrentUser
from v1.agent import router as agent_router

# Thread id used by every resume request in this module.
_THREAD_ID = "00000000-0000-0000-0000-000000000001"


def _current_user() -> CurrentUser:
    """Return a fresh authenticated user for direct endpoint calls."""
    return CurrentUser(id=uuid4(), email="user@example.com")


def _patch_redis_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace `get_or_init_redis_client` with a stub that always raises.

    The stub accepts arbitrary arguments. Whatever call shape the router
    uses, the simulated outage (`RuntimeError`) is what gets raised, never a
    `TypeError` from an argument mismatch.
    """

    async def _raise_redis_error(*args: Any, **kwargs: Any) -> Any:
        del args, kwargs
        raise RuntimeError("redis unavailable")

    monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)


def _patch_run_rate_limit(monkeypatch: pytest.MonkeyPatch, *, allowed: bool) -> None:
    """Replace `_allow_run_request` with a stub returning *allowed*."""

    async def _decide(*, user_id: str) -> bool:
        del user_id
        return allowed

    monkeypatch.setattr(agent_router, "_allow_run_request", _decide)


class _UnreachableService:
    """Service double whose `enqueue_resume` must never be reached."""

    async def enqueue_resume(self, **kwargs: Any) -> Any:
        del kwargs
        raise AssertionError("enqueue_resume should not be called")


@pytest.mark.asyncio
async def test_allow_run_request_fails_closed_when_redis_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    _patch_redis_unavailable(monkeypatch)

    allowed = await agent_router._allow_run_request(user_id="user-1")

    assert allowed is False


@pytest.mark.asyncio
async def test_acquire_sse_slot_fails_closed_when_redis_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    _patch_redis_unavailable(monkeypatch)

    allowed = await agent_router._acquire_sse_slot(user_id="user-1")

    assert allowed is False


@pytest.mark.asyncio
async def test_allow_transcribe_request_fails_closed_when_redis_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    _patch_redis_unavailable(monkeypatch)

    allowed = await agent_router._allow_transcribe_request(user_id="user-1")

    assert allowed is False


def _resume_input_with_tool_message() -> RunAgentInput:
    """Build a resume request whose only message is a tool result with `toolCallId`."""
    return RunAgentInput.model_validate(
        {
            "threadId": _THREAD_ID,
            "runId": "run-resume-1",
            "state": {},
            "messages": [
                {
                    "id": "tool-1",
                    "role": "tool",
                    "toolCallId": "call-1",
                    "content": '{"toolName":"navigate_to_route","result":{"ok":true}}',
                }
            ],
            "tools": [],
            "context": [],
            "forwardedProps": {},
        }
    )


@pytest.mark.asyncio
async def test_enqueue_resume_rejects_without_tool_contract(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    request = RunAgentInput.model_validate(
        {
            "threadId": _THREAD_ID,
            "runId": "run-resume-invalid",
            "state": {},
            "messages": [
                {
                    "id": "u1",
                    "role": "user",
                    "content": "continue",
                }
            ],
            "tools": [],
            "context": [],
            "forwardedProps": {},
        }
    )
    # Allow the rate limiter so the tool-contract validation is the only
    # thing that can reject this request, independent of Redis availability.
    _patch_run_rate_limit(monkeypatch, allowed=True)

    with pytest.raises(HTTPException) as exc_info:
        await agent_router.enqueue_resume(
            thread_id=_THREAD_ID,
            request=request,
            service=cast(Any, _UnreachableService()),
            current_user=_current_user(),
        )

    assert exc_info.value.status_code == 422
    assert (
        exc_info.value.detail
        == "RunAgentInput.messages requires a tool message with toolCallId for resume"
    )


@pytest.mark.asyncio
async def test_enqueue_resume_rejects_when_rate_limited(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    request = _resume_input_with_tool_message()
    _patch_run_rate_limit(monkeypatch, allowed=False)

    with pytest.raises(HTTPException) as exc_info:
        await agent_router.enqueue_resume(
            thread_id=_THREAD_ID,
            request=request,
            service=cast(Any, _UnreachableService()),
            current_user=_current_user(),
        )

    assert exc_info.value.status_code == 429
    assert exc_info.value.detail == "Too many run requests"


@pytest.mark.asyncio
async def test_enqueue_resume_accepts_valid_tool_contract(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    request = _resume_input_with_tool_message()
    _patch_run_rate_limit(monkeypatch, allowed=True)

    class _Service:
        async def enqueue_resume(self, **kwargs: Any) -> Any:
            # Echo the forwarded identifiers so the assertions below verify
            # that the endpoint passes thread_id and run_input through.
            return SimpleNamespace(
                task_id="task-resume-1",
                thread_id=kwargs["thread_id"],
                run_id=kwargs["run_input"].run_id,
                created=False,
            )

    result = await agent_router.enqueue_resume(
        thread_id=_THREAD_ID,
        request=request,
        service=cast(Any, _Service()),
        current_user=_current_user(),
    )

    assert result.task_id == "task-resume-1"
    assert result.thread_id == _THREAD_ID
    assert result.run_id == "run-resume-1"