chore: checkpoint current backend/runtime changes

This commit is contained in:
qzl
2026-03-06 17:28:17 +08:00
parent 2c59fe5ee2
commit b6087fd195
32 changed files with 1641 additions and 469 deletions
@@ -2,7 +2,7 @@ from __future__ import annotations
import pytest
from core.agent.infrastructure.queue.tasks import run_agent_task
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
class _FakeRunService:
@@ -67,3 +67,35 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
)
assert events == ["RUN_STARTED", "RUN_ERROR"]
@pytest.mark.asyncio
async def test_run_agent_task_rejects_invalid_command() -> None:
with pytest.raises(ValueError, match="invalid command type"):
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
@pytest.mark.asyncio
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
with pytest.raises(ValueError, match="tool_call_id is required"):
await run_agent_task(
{
"command": "resume",
"session_id": "00000000-0000-0000-0000-000000000001",
}
)
@pytest.mark.asyncio
async def test_build_redis_publisher_init_fail_raises_runtime_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from core.agent.infrastructure.queue import tasks
async def _fake_get_client() -> object:
raise RuntimeError("Redis service initialization failed")
monkeypatch.setattr(tasks, "get_or_init_redis_client", _fake_get_client)
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
await _build_redis_publisher()
@@ -0,0 +1,10 @@
from __future__ import annotations
from core.config.settings import Settings
def test_taskiq_uses_redis_url_by_default() -> None:
settings = Settings()
assert settings.taskiq_broker_url.startswith("redis://")
assert settings.taskiq_result_backend_url.startswith("redis://")
@@ -0,0 +1,37 @@
from __future__ import annotations
import importlib
import sys
import pytest
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
def test_taskiq_broker_is_configured() -> None:
assert broker is not None
assert default_broker is broker
assert critical_broker is not None
assert bulk_broker is not None
def test_taskiq_app_configures_logging_on_import(
monkeypatch: pytest.MonkeyPatch,
) -> None:
sys.modules.pop("core.taskiq.app", None)
sys.modules.pop("core.taskiq", None)
called = {"count": 0, "args": None}
def _fake_configure_logging(*args: object, **__: object) -> None:
called["count"] += 1
called["args"] = args
monkeypatch.setattr("core.logging.configure_logging", _fake_configure_logging)
importlib.import_module("core.taskiq.app")
from core.config.settings import config
assert called["count"] == 1
assert called["args"] == (config,)
@@ -0,0 +1,20 @@
from __future__ import annotations
from pathlib import Path
ROOT_DIR = Path(__file__).resolve().parents[4]
APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
def test_worker_commands_use_taskiq() -> None:
content = APP_SCRIPT.read_text(encoding="utf-8")
removed_runner = "uv run c" "elery"
assert "uv run taskiq worker" in content
assert "core.taskiq.app:critical_broker" in content
assert "core.taskiq.app:default_broker" in content
assert "core.taskiq.app:bulk_broker" in content
assert 'pgrep -f "taskiq.*worker"' in content
assert 'pkill -f "taskiq.*worker"' in content
assert removed_runner not in content
@@ -3,7 +3,7 @@ from __future__ import annotations
import pytest
from core.config.settings import RedisSettings
from services.base.redis import RedisService
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
class _FakeRedisClient:
@@ -96,3 +96,35 @@ def test_get_client_raises_before_init() -> None:
with pytest.raises(RuntimeError):
service.get_client()
@pytest.mark.asyncio
async def test_get_or_init_redis_client_initializes_when_needed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_client = _FakeRedisClient()
async def _fake_initialize() -> bool:
return True
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
monkeypatch.setattr(redis_service, "get_client", lambda: fake_client)
client = await get_or_init_redis_client()
assert client is fake_client
@pytest.mark.asyncio
async def test_get_or_init_redis_client_raises_when_init_fails(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _fake_initialize() -> bool:
return False
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
await get_or_init_redis_client()
@@ -1,8 +1,12 @@
from __future__ import annotations
import pytest
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
close_registered_services,
initialize_registered_services,
register_service,
register_service_instance,
)
@@ -35,6 +39,17 @@ def test_register_service_and_create_service() -> None:
assert created.get_service_info()["name"] == "dummy"
def test_register_service_and_get_service() -> None:
@register_service("dummy-service-get")
class _RegisteredService(_DummyService):
pass
resolved = ServiceRegistry.get_service("dummy-service-get")
assert resolved is not None
assert resolved.get_service_info()["name"] == "dummy"
def test_register_service_instance_returns_same_instance() -> None:
instance = _DummyService("singleton")
@@ -47,3 +62,77 @@ def test_register_service_instance_returns_same_instance() -> None:
def test_create_service_returns_none_for_missing() -> None:
assert ServiceRegistry.create_service("missing-service") is None
def test_get_service_returns_none_for_missing() -> None:
assert ServiceRegistry.get_service("missing-service") is None
class _LifecycleService(BaseServiceProvider):
def __init__(self, name: str, recorder: list[str], fail_on_init: bool = False) -> None:
super().__init__(name)
self._recorder = recorder
self._fail_on_init = fail_on_init
async def initialize(self, **_: object) -> bool:
self._recorder.append(f"init:{self.service_name}")
if self._fail_on_init:
return False
self._set_initialized(True)
return True
async def close(self) -> bool:
self._recorder.append(f"close:{self.service_name}")
self._set_initialized(False)
return True
async def health_check(self) -> dict[str, object]:
return {"status": "healthy", "details": {}}
@pytest.mark.asyncio
async def test_initialize_registered_services_success() -> None:
recorder: list[str] = []
first = register_service_instance(
"lifecycle-success-first", _LifecycleService("first", recorder)
)
second = register_service_instance(
"lifecycle-success-second", _LifecycleService("second", recorder)
)
initialized, services = await initialize_registered_services(
["lifecycle-success-first", "lifecycle-success-second"]
)
assert initialized is True
assert services == [first, second]
assert recorder == ["init:first", "init:second"]
@pytest.mark.asyncio
async def test_initialize_registered_services_failure_rolls_back() -> None:
recorder: list[str] = []
register_service_instance("lifecycle-fail-first", _LifecycleService("first", recorder))
register_service_instance(
"lifecycle-fail-second", _LifecycleService("second", recorder, fail_on_init=True)
)
initialized, services = await initialize_registered_services(
["lifecycle-fail-first", "lifecycle-fail-second"]
)
assert initialized is False
assert services == []
assert recorder == ["init:first", "init:second", "close:first"]
@pytest.mark.asyncio
async def test_close_registered_services_closes_in_reverse_order() -> None:
recorder: list[str] = []
first = _LifecycleService("first", recorder)
second = _LifecycleService("second", recorder)
closed = await close_registered_services([first, second])
assert closed is True
assert recorder == ["close:second", "close:first"]
@@ -0,0 +1,111 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.config.settings import SupabaseSettings
from services.base.supabase import SupabaseService
@pytest.mark.asyncio
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
anon_client = MagicMock()
admin_client = MagicMock()
create_calls: list[tuple[str, str]] = []
def _fake_create_client(url: str, key: str) -> object:
create_calls.append((url, key))
return anon_client if len(create_calls) == 1 else admin_client
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
result = await service.initialize()
assert result is True
assert service.is_initialized is True
assert service.get_client() is anon_client
assert service.get_admin_client() is admin_client
assert len(create_calls) == 2
@pytest.mark.asyncio
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
def _fake_create_client(_: str, __: str) -> object:
raise RuntimeError("boom")
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
result = await service.initialize()
assert result is False
assert service.is_initialized is False
with pytest.raises(RuntimeError):
service.get_client()
@pytest.mark.asyncio
async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
def _fake_create_client(_: str, __: str) -> object:
return MagicMock()
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
assert await service.initialize() is True
assert await service.close() is True
assert service.is_initialized is False
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
@pytest.mark.asyncio
async def test_health_check_uninitialized() -> None:
service = SupabaseService(settings=SupabaseSettings())
health = await service.health_check()
assert health["status"] == "unhealthy"
@pytest.mark.asyncio
async def test_health_check_initialized(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
anon_client = MagicMock()
anon_client.auth.get_session = MagicMock(return_value=None)
admin_list_users = MagicMock(return_value=SimpleNamespace(users=[]))
admin_client = MagicMock()
admin_client.auth.admin = SimpleNamespace(list_users=admin_list_users)
create_sequence = [anon_client, admin_client]
def _fake_create_client(_: str, __: str) -> object:
return create_sequence.pop(0)
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
assert await service.initialize() is True
health = await service.health_check()
assert health["status"] == "healthy"
admin_list_users.assert_called_once_with(page=1, per_page=1)
def test_get_client_raises_before_init() -> None:
service = SupabaseService(settings=SupabaseSettings())
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
import pytest
import app as app_module
@pytest.mark.asyncio
async def test_lifespan_uses_registered_services(monkeypatch: pytest.MonkeyPatch) -> None:
initialized_services = [object(), object()]
calls: dict[str, object] = {}
async def _fake_initialize(service_names: list[str]) -> tuple[bool, list[object]]:
calls["init_names"] = service_names
return True, initialized_services
async def _fake_close(services: list[object]) -> bool:
calls["close_services"] = services
return True
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
monkeypatch.setattr(app_module, "close_registered_services", _fake_close)
context = app_module.lifespan(app_module.app)
await context.__aenter__()
await context.__aexit__(None, None, None)
assert calls["init_names"] == ["redis", "supabase"]
assert calls["close_services"] == initialized_services
@pytest.mark.asyncio
async def test_lifespan_raises_when_initialization_failed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _fake_initialize(_: list[str]) -> tuple[bool, list[object]]:
return False, []
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
context = app_module.lifespan(app_module.app)
with pytest.raises(RuntimeError, match="Service initialization failed"):
await context.__aenter__()
-45
View File
@@ -1,45 +0,0 @@
from __future__ import annotations
from celery import Celery
from pytest import MonkeyPatch
from core.logging import celery as celery_logging
from core.logging.context import clear_context, get_context
class DummyTask:
name: str = "tasks.sample"
def test_celery_prerun_binds_task_context() -> None:
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
context = get_context()
assert context["task_id"] == "task-123"
assert context["task_name"] == "tasks.sample"
clear_context()
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
called = {"value": False}
def fake_configure_logging(settings: object | None = None) -> None:
called["value"] = True
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_setup_logging()
assert called["value"] is True
def test_configure_celery_app_disables_hijack() -> None:
app = Celery("test")
celery_logging.configure_celery_app(app)
assert app.conf.worker_hijack_root_logger is False
@@ -0,0 +1,227 @@
from __future__ import annotations
import pytest
from v1.agent.dependencies import TaskiqQueueClient
class _FakeRedis:
def __init__(self) -> None:
self.store: dict[str, str] = {}
self.delete_calls: list[str] = []
async def set(
self,
key: str,
value: str,
*,
nx: bool = False,
ex: int | None = None,
) -> bool:
del ex
if nx and key in self.store:
return False
self.store[key] = value
return True
async def get(self, key: str) -> str | None:
return self.store.get(key)
async def delete(self, key: str) -> int:
self.delete_calls.append(key)
existed = 1 if key in self.store else 0
self.store.pop(key, None)
return existed
class _FakeAsyncResult:
def __init__(self, task_id: str) -> None:
self.task_id = task_id
@pytest.mark.asyncio
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
assert payload["command"] == "run"
return _FakeAsyncResult("task-123")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
assert resolved_client["value"] is True
assert task_id == "task-123"
@pytest.mark.asyncio
async def test_enqueue_resume_dedup_returns_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
return _FakeAsyncResult("new-task-id")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
dedup_key = "resume:session-1:call-1"
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert resolved_client["value"] is True
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
attempts = {"count": 0}
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_get(key: str) -> str | None:
attempts["count"] += 1
if attempts["count"] > 1:
fake_redis.store[key] = "existing-task-id"
return fake_redis.store.get(key)
async def _fake_sleep(_: float) -> None:
return None
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise AssertionError("should not enqueue when dedup task id appears")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(fake_redis, "get", _fake_get)
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise RuntimeError("enqueue failed")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
with pytest.raises(RuntimeError, match="enqueue failed"):
await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert redis_key in fake_redis.delete_calls
@pytest.mark.asyncio
async def test_enqueue_uses_critical_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("critical-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "critical"},
dedup_key=None,
)
assert task_id == "critical-task-id"
@pytest.mark.asyncio
async def test_enqueue_uses_bulk_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("bulk-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "bulk"},
dedup_key=None,
)
assert task_id == "bulk-task-id"
+40 -28
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
@@ -12,37 +12,44 @@ from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
class TestSupabaseAuthGateway:
@pytest.fixture
def gateway(self) -> SupabaseAuthGateway:
with patch("v1.auth.gateway.create_client") as mock_create:
mock_client = MagicMock()
mock_admin_client = MagicMock()
mock_create.side_effect = [mock_client, mock_admin_client]
return SupabaseAuthGateway()
def gateway(
self, monkeypatch: pytest.MonkeyPatch
) -> tuple[SupabaseAuthGateway, MagicMock, MagicMock]:
mock_client = MagicMock()
mock_admin_client = MagicMock()
monkeypatch.setattr("v1.auth.gateway.supabase_service.get_client", lambda: mock_client)
monkeypatch.setattr(
"v1.auth.gateway.supabase_service.get_admin_client",
lambda: mock_admin_client,
)
return SupabaseAuthGateway(), mock_client, mock_admin_client
@pytest.mark.asyncio
async def test_request_password_reset_calls_email_with_string(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_with_redirect(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with(
"test@example.com",
@@ -51,64 +58,68 @@ class TestSupabaseAuthGateway:
@pytest.mark.asyncio
async def test_request_password_reset_swallows_auth_error(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
result = await gateway.request_password_reset(request)
result = await sut.request_password_reset(request)
mock_reset_email.assert_called_once()
assert result is None
@pytest.mark.asyncio
async def test_request_password_reset_extracts_email_from_mapping(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest.model_construct(
email={"email": "test@example.com"},
redirect_to=None,
)
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_rejects_invalid_email_shape(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, _, _ = gateway
request = PasswordResetRequest.model_construct(
email={"unexpected": "value"},
redirect_to=None,
)
with pytest.raises(HTTPException) as exc_info:
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid email"
@pytest.mark.asyncio
async def test_confirm_password_reset_updates_password_by_user_id(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, mock_admin_client = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id="user-1"),
)
mock_verify_otp = MagicMock(return_value=verify_response)
gateway._client.auth.verify_otp = mock_verify_otp
mock_client.auth.verify_otp = mock_verify_otp
mock_update_user_by_id = MagicMock()
gateway._admin_client.auth.admin = SimpleNamespace(
mock_admin_client.auth.admin = SimpleNamespace(
update_user_by_id=mock_update_user_by_id
)
@@ -118,7 +129,7 @@ class TestSupabaseAuthGateway:
new_password="newpassword123",
)
await gateway.confirm_password_reset(request)
await sut.confirm_password_reset(request)
mock_verify_otp.assert_called_once_with(
{
@@ -134,13 +145,14 @@ class TestSupabaseAuthGateway:
@pytest.mark.asyncio
async def test_confirm_password_reset_raises_when_user_id_missing(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id=""),
)
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
request = PasswordResetConfirmRequest(
email="test@example.com",
@@ -149,7 +161,7 @@ class TestSupabaseAuthGateway:
)
with pytest.raises(HTTPException) as exc_info:
await gateway.confirm_password_reset(request)
await sut.confirm_password_reset(request)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid or expired verification code"