chore: checkpoint current backend/runtime changes
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
|
||||
|
||||
|
||||
class _FakeRunService:
|
||||
@@ -67,3 +67,35 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
)
|
||||
|
||||
assert events == ["RUN_STARTED", "RUN_ERROR"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_rejects_invalid_command() -> None:
|
||||
with pytest.raises(ValueError, match="invalid command type"):
|
||||
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
|
||||
with pytest.raises(ValueError, match="tool_call_id is required"):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_redis_publisher_init_fail_raises_runtime_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from core.agent.infrastructure.queue import tasks
|
||||
|
||||
async def _fake_get_client() -> object:
|
||||
raise RuntimeError("Redis service initialization failed")
|
||||
|
||||
monkeypatch.setattr(tasks, "get_or_init_redis_client", _fake_get_client)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
|
||||
await _build_redis_publisher()
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_taskiq_uses_redis_url_by_default() -> None:
|
||||
settings = Settings()
|
||||
|
||||
assert settings.taskiq_broker_url.startswith("redis://")
|
||||
assert settings.taskiq_result_backend_url.startswith("redis://")
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
|
||||
|
||||
|
||||
def test_taskiq_broker_is_configured() -> None:
|
||||
assert broker is not None
|
||||
assert default_broker is broker
|
||||
assert critical_broker is not None
|
||||
assert bulk_broker is not None
|
||||
|
||||
|
||||
def test_taskiq_app_configures_logging_on_import(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
sys.modules.pop("core.taskiq.app", None)
|
||||
sys.modules.pop("core.taskiq", None)
|
||||
|
||||
called = {"count": 0, "args": None}
|
||||
|
||||
def _fake_configure_logging(*args: object, **__: object) -> None:
|
||||
called["count"] += 1
|
||||
called["args"] = args
|
||||
|
||||
monkeypatch.setattr("core.logging.configure_logging", _fake_configure_logging)
|
||||
|
||||
importlib.import_module("core.taskiq.app")
|
||||
|
||||
from core.config.settings import config
|
||||
|
||||
assert called["count"] == 1
|
||||
assert called["args"] == (config,)
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[4]
|
||||
APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
|
||||
|
||||
|
||||
def test_worker_commands_use_taskiq() -> None:
|
||||
content = APP_SCRIPT.read_text(encoding="utf-8")
|
||||
removed_runner = "uv run c" "elery"
|
||||
|
||||
assert "uv run taskiq worker" in content
|
||||
assert "core.taskiq.app:critical_broker" in content
|
||||
assert "core.taskiq.app:default_broker" in content
|
||||
assert "core.taskiq.app:bulk_broker" in content
|
||||
assert 'pgrep -f "taskiq.*worker"' in content
|
||||
assert 'pkill -f "taskiq.*worker"' in content
|
||||
assert removed_runner not in content
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
from core.config.settings import RedisSettings
|
||||
from services.base.redis import RedisService
|
||||
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
|
||||
|
||||
|
||||
class _FakeRedisClient:
|
||||
@@ -96,3 +96,35 @@ def test_get_client_raises_before_init() -> None:
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_init_redis_client_initializes_when_needed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_client = _FakeRedisClient()
|
||||
|
||||
async def _fake_initialize() -> bool:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
|
||||
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
|
||||
monkeypatch.setattr(redis_service, "get_client", lambda: fake_client)
|
||||
|
||||
client = await get_or_init_redis_client()
|
||||
|
||||
assert client is fake_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_init_redis_client_raises_when_init_fails(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _fake_initialize() -> bool:
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
|
||||
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
|
||||
await get_or_init_redis_client()
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
close_registered_services,
|
||||
initialize_registered_services,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
)
|
||||
@@ -35,6 +39,17 @@ def test_register_service_and_create_service() -> None:
|
||||
assert created.get_service_info()["name"] == "dummy"
|
||||
|
||||
|
||||
def test_register_service_and_get_service() -> None:
|
||||
@register_service("dummy-service-get")
|
||||
class _RegisteredService(_DummyService):
|
||||
pass
|
||||
|
||||
resolved = ServiceRegistry.get_service("dummy-service-get")
|
||||
|
||||
assert resolved is not None
|
||||
assert resolved.get_service_info()["name"] == "dummy"
|
||||
|
||||
|
||||
def test_register_service_instance_returns_same_instance() -> None:
|
||||
instance = _DummyService("singleton")
|
||||
|
||||
@@ -47,3 +62,77 @@ def test_register_service_instance_returns_same_instance() -> None:
|
||||
|
||||
def test_create_service_returns_none_for_missing() -> None:
|
||||
assert ServiceRegistry.create_service("missing-service") is None
|
||||
|
||||
|
||||
def test_get_service_returns_none_for_missing() -> None:
|
||||
assert ServiceRegistry.get_service("missing-service") is None
|
||||
|
||||
|
||||
class _LifecycleService(BaseServiceProvider):
|
||||
def __init__(self, name: str, recorder: list[str], fail_on_init: bool = False) -> None:
|
||||
super().__init__(name)
|
||||
self._recorder = recorder
|
||||
self._fail_on_init = fail_on_init
|
||||
|
||||
async def initialize(self, **_: object) -> bool:
|
||||
self._recorder.append(f"init:{self.service_name}")
|
||||
if self._fail_on_init:
|
||||
return False
|
||||
self._set_initialized(True)
|
||||
return True
|
||||
|
||||
async def close(self) -> bool:
|
||||
self._recorder.append(f"close:{self.service_name}")
|
||||
self._set_initialized(False)
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, object]:
|
||||
return {"status": "healthy", "details": {}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_registered_services_success() -> None:
|
||||
recorder: list[str] = []
|
||||
first = register_service_instance(
|
||||
"lifecycle-success-first", _LifecycleService("first", recorder)
|
||||
)
|
||||
second = register_service_instance(
|
||||
"lifecycle-success-second", _LifecycleService("second", recorder)
|
||||
)
|
||||
|
||||
initialized, services = await initialize_registered_services(
|
||||
["lifecycle-success-first", "lifecycle-success-second"]
|
||||
)
|
||||
|
||||
assert initialized is True
|
||||
assert services == [first, second]
|
||||
assert recorder == ["init:first", "init:second"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_registered_services_failure_rolls_back() -> None:
|
||||
recorder: list[str] = []
|
||||
register_service_instance("lifecycle-fail-first", _LifecycleService("first", recorder))
|
||||
register_service_instance(
|
||||
"lifecycle-fail-second", _LifecycleService("second", recorder, fail_on_init=True)
|
||||
)
|
||||
|
||||
initialized, services = await initialize_registered_services(
|
||||
["lifecycle-fail-first", "lifecycle-fail-second"]
|
||||
)
|
||||
|
||||
assert initialized is False
|
||||
assert services == []
|
||||
assert recorder == ["init:first", "init:second", "close:first"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_registered_services_closes_in_reverse_order() -> None:
|
||||
recorder: list[str] = []
|
||||
first = _LifecycleService("first", recorder)
|
||||
second = _LifecycleService("second", recorder)
|
||||
|
||||
closed = await close_registered_services([first, second])
|
||||
|
||||
assert closed is True
|
||||
assert recorder == ["close:second", "close:first"]
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import SupabaseSettings
|
||||
from services.base.supabase import SupabaseService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
anon_client = MagicMock()
|
||||
admin_client = MagicMock()
|
||||
|
||||
create_calls: list[tuple[str, str]] = []
|
||||
|
||||
def _fake_create_client(url: str, key: str) -> object:
|
||||
create_calls.append((url, key))
|
||||
return anon_client if len(create_calls) == 1 else admin_client
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is True
|
||||
assert service.is_initialized is True
|
||||
assert service.get_client() is anon_client
|
||||
assert service.get_admin_client() is admin_client
|
||||
assert len(create_calls) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
def _fake_create_client(_: str, __: str) -> object:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is False
|
||||
assert service.is_initialized is False
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
def _fake_create_client(_: str, __: str) -> object:
|
||||
return MagicMock()
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
assert await service.initialize() is True
|
||||
assert await service.close() is True
|
||||
assert service.is_initialized is False
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_admin_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_uninitialized() -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
health = await service.health_check()
|
||||
|
||||
assert health["status"] == "unhealthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_initialized(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
anon_client = MagicMock()
|
||||
anon_client.auth.get_session = MagicMock(return_value=None)
|
||||
|
||||
admin_list_users = MagicMock(return_value=SimpleNamespace(users=[]))
|
||||
admin_client = MagicMock()
|
||||
admin_client.auth.admin = SimpleNamespace(list_users=admin_list_users)
|
||||
|
||||
create_sequence = [anon_client, admin_client]
|
||||
|
||||
def _fake_create_client(_: str, __: str) -> object:
|
||||
return create_sequence.pop(0)
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
assert await service.initialize() is True
|
||||
|
||||
health = await service.health_check()
|
||||
|
||||
assert health["status"] == "healthy"
|
||||
admin_list_users.assert_called_once_with(page=1, per_page=1)
|
||||
|
||||
|
||||
def test_get_client_raises_before_init() -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_admin_client()
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import app as app_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_uses_registered_services(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
initialized_services = [object(), object()]
|
||||
calls: dict[str, object] = {}
|
||||
|
||||
async def _fake_initialize(service_names: list[str]) -> tuple[bool, list[object]]:
|
||||
calls["init_names"] = service_names
|
||||
return True, initialized_services
|
||||
|
||||
async def _fake_close(services: list[object]) -> bool:
|
||||
calls["close_services"] = services
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
|
||||
monkeypatch.setattr(app_module, "close_registered_services", _fake_close)
|
||||
|
||||
context = app_module.lifespan(app_module.app)
|
||||
await context.__aenter__()
|
||||
await context.__aexit__(None, None, None)
|
||||
|
||||
assert calls["init_names"] == ["redis", "supabase"]
|
||||
assert calls["close_services"] == initialized_services
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_raises_when_initialization_failed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _fake_initialize(_: list[str]) -> tuple[bool, list[object]]:
|
||||
return False, []
|
||||
|
||||
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
|
||||
|
||||
context = app_module.lifespan(app_module.app)
|
||||
with pytest.raises(RuntimeError, match="Service initialization failed"):
|
||||
await context.__aenter__()
|
||||
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.logging import celery as celery_logging
|
||||
from core.logging.context import clear_context, get_context
|
||||
|
||||
|
||||
class DummyTask:
|
||||
name: str = "tasks.sample"
|
||||
|
||||
|
||||
def test_celery_prerun_binds_task_context() -> None:
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
|
||||
context = get_context()
|
||||
|
||||
assert context["task_id"] == "task-123"
|
||||
assert context["task_name"] == "tasks.sample"
|
||||
|
||||
clear_context()
|
||||
|
||||
|
||||
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
|
||||
called = {"value": False}
|
||||
|
||||
def fake_configure_logging(settings: object | None = None) -> None:
|
||||
called["value"] = True
|
||||
|
||||
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_setup_logging()
|
||||
|
||||
assert called["value"] is True
|
||||
|
||||
|
||||
def test_configure_celery_app_disables_hijack() -> None:
|
||||
app = Celery("test")
|
||||
|
||||
celery_logging.configure_celery_app(app)
|
||||
|
||||
assert app.conf.worker_hijack_root_logger is False
|
||||
@@ -0,0 +1,227 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.dependencies import TaskiqQueueClient
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, str] = {}
|
||||
self.delete_calls: list[str] = []
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
*,
|
||||
nx: bool = False,
|
||||
ex: int | None = None,
|
||||
) -> bool:
|
||||
del ex
|
||||
if nx and key in self.store:
|
||||
return False
|
||||
self.store[key] = value
|
||||
return True
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self.store.get(key)
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
self.delete_calls.append(key)
|
||||
existed = 1 if key in self.store else 0
|
||||
self.store.pop(key, None)
|
||||
return existed
|
||||
|
||||
|
||||
class _FakeAsyncResult:
|
||||
def __init__(self, task_id: str) -> None:
|
||||
self.task_id = task_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
resolved_client = {"value": False}
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
assert payload["command"] == "run"
|
||||
return _FakeAsyncResult("task-123")
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
resolved_client["value"] = True
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
|
||||
|
||||
assert resolved_client["value"] is True
|
||||
assert task_id == "task-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_resume_dedup_returns_existing_task_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
resolved_client = {"value": False}
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
return _FakeAsyncResult("new-task-id")
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
resolved_client["value"] = True
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert resolved_client["value"] is True
|
||||
assert task_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
|
||||
attempts = {"count": 0}
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_get(key: str) -> str | None:
|
||||
attempts["count"] += 1
|
||||
if attempts["count"] > 1:
|
||||
fake_redis.store[key] = "existing-task-id"
|
||||
return fake_redis.store.get(key)
|
||||
|
||||
async def _fake_sleep(_: float) -> None:
|
||||
return None
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
raise AssertionError("should not enqueue when dedup task id appears")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(fake_redis, "get", _fake_get)
|
||||
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert task_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
raise RuntimeError("enqueue failed")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
with pytest.raises(RuntimeError, match="enqueue failed"):
|
||||
await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert redis_key in fake_redis.delete_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_uses_critical_queue_when_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
raise AssertionError("default queue should not be selected")
|
||||
|
||||
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
return _FakeAsyncResult("critical-task-id")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
|
||||
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "run", "queue": "critical"},
|
||||
dedup_key=None,
|
||||
)
|
||||
|
||||
assert task_id == "critical-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_uses_bulk_queue_when_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
raise AssertionError("default queue should not be selected")
|
||||
|
||||
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
return _FakeAsyncResult("bulk-task-id")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
|
||||
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "run", "queue": "bulk"},
|
||||
dedup_key=None,
|
||||
)
|
||||
|
||||
assert task_id == "bulk-task-id"
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
@@ -12,37 +12,44 @@ from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
|
||||
|
||||
class TestSupabaseAuthGateway:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> SupabaseAuthGateway:
|
||||
with patch("v1.auth.gateway.create_client") as mock_create:
|
||||
mock_client = MagicMock()
|
||||
mock_admin_client = MagicMock()
|
||||
mock_create.side_effect = [mock_client, mock_admin_client]
|
||||
return SupabaseAuthGateway()
|
||||
def gateway(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> tuple[SupabaseAuthGateway, MagicMock, MagicMock]:
|
||||
mock_client = MagicMock()
|
||||
mock_admin_client = MagicMock()
|
||||
monkeypatch.setattr("v1.auth.gateway.supabase_service.get_client", lambda: mock_client)
|
||||
monkeypatch.setattr(
|
||||
"v1.auth.gateway.supabase_service.get_admin_client",
|
||||
lambda: mock_admin_client,
|
||||
)
|
||||
return SupabaseAuthGateway(), mock_client, mock_admin_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_calls_email_with_string(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_with_redirect(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(
|
||||
email="test@example.com",
|
||||
redirect_to="http://localhost:3000/reset-password",
|
||||
)
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with(
|
||||
"test@example.com",
|
||||
@@ -51,64 +58,68 @@ class TestSupabaseAuthGateway:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_swallows_auth_error(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
|
||||
result = await gateway.request_password_reset(request)
|
||||
result = await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_extracts_email_from_mapping(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"email": "test@example.com"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_rejects_invalid_email_shape(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, _, _ = gateway
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"unexpected": "value"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Invalid email"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_updates_password_by_user_id(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, mock_admin_client = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id="user-1"),
|
||||
)
|
||||
mock_verify_otp = MagicMock(return_value=verify_response)
|
||||
gateway._client.auth.verify_otp = mock_verify_otp
|
||||
mock_client.auth.verify_otp = mock_verify_otp
|
||||
|
||||
mock_update_user_by_id = MagicMock()
|
||||
gateway._admin_client.auth.admin = SimpleNamespace(
|
||||
mock_admin_client.auth.admin = SimpleNamespace(
|
||||
update_user_by_id=mock_update_user_by_id
|
||||
)
|
||||
|
||||
@@ -118,7 +129,7 @@ class TestSupabaseAuthGateway:
|
||||
new_password="newpassword123",
|
||||
)
|
||||
|
||||
await gateway.confirm_password_reset(request)
|
||||
await sut.confirm_password_reset(request)
|
||||
|
||||
mock_verify_otp.assert_called_once_with(
|
||||
{
|
||||
@@ -134,13 +145,14 @@ class TestSupabaseAuthGateway:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_raises_when_user_id_missing(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id=""),
|
||||
)
|
||||
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
|
||||
request = PasswordResetConfirmRequest(
|
||||
email="test@example.com",
|
||||
@@ -149,7 +161,7 @@ class TestSupabaseAuthGateway:
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gateway.confirm_password_reset(request)
|
||||
await sut.confirm_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid or expired verification code"
|
||||
|
||||
Reference in New Issue
Block a user