from __future__ import annotations from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest from fastapi import HTTPException from v1.auth.gateway import SupabaseAuthGateway from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest class TestSupabaseAuthGateway: @pytest.fixture def gateway(self) -> SupabaseAuthGateway: with patch("v1.auth.gateway.create_client") as mock_create: mock_client = MagicMock() mock_admin_client = MagicMock() mock_create.side_effect = [mock_client, mock_admin_client] return SupabaseAuthGateway() @pytest.mark.asyncio async def test_request_password_reset_calls_email_with_string( self, gateway: SupabaseAuthGateway ) -> None: mock_reset_email = MagicMock() gateway._client.auth.reset_password_email = mock_reset_email request = PasswordResetRequest(email="test@example.com") await gateway.request_password_reset(request) mock_reset_email.assert_called_once_with("test@example.com") @pytest.mark.asyncio async def test_request_password_reset_with_redirect( self, gateway: SupabaseAuthGateway ) -> None: mock_reset_email = MagicMock() gateway._client.auth.reset_password_email = mock_reset_email request = PasswordResetRequest( email="test@example.com", redirect_to="http://localhost:3000/reset-password", ) await gateway.request_password_reset(request) mock_reset_email.assert_called_once_with( "test@example.com", options={"redirect_to": "http://localhost:3000/reset-password"}, ) @pytest.mark.asyncio async def test_request_password_reset_swallows_auth_error( self, gateway: SupabaseAuthGateway ) -> None: from supabase import AuthError mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None)) gateway._client.auth.reset_password_email = mock_reset_email request = PasswordResetRequest(email="test@example.com") result = await gateway.request_password_reset(request) mock_reset_email.assert_called_once() assert result is None @pytest.mark.asyncio async def test_request_password_reset_extracts_email_from_mapping( self, gateway: SupabaseAuthGateway ) -> None: mock_reset_email = MagicMock() gateway._client.auth.reset_password_email = mock_reset_email request = PasswordResetRequest.model_construct( email={"email": "test@example.com"}, redirect_to=None, ) await gateway.request_password_reset(request) mock_reset_email.assert_called_once_with("test@example.com") @pytest.mark.asyncio async def test_request_password_reset_rejects_invalid_email_shape( self, gateway: SupabaseAuthGateway ) -> None: request = PasswordResetRequest.model_construct( email={"unexpected": "value"}, redirect_to=None, ) with pytest.raises(HTTPException) as exc_info: await gateway.request_password_reset(request) assert exc_info.value.status_code == 422 assert exc_info.value.detail == "Invalid email" @pytest.mark.asyncio async def test_confirm_password_reset_updates_password_by_user_id( self, gateway: SupabaseAuthGateway ) -> None: verify_response = SimpleNamespace( session=SimpleNamespace(access_token="access"), user=SimpleNamespace(id="user-1"), ) mock_verify_otp = MagicMock(return_value=verify_response) gateway._client.auth.verify_otp = mock_verify_otp mock_update_user_by_id = MagicMock() gateway._admin_client.auth.admin = SimpleNamespace( update_user_by_id=mock_update_user_by_id ) request = PasswordResetConfirmRequest( email="test@example.com", token="123456", new_password="newpassword123", ) await gateway.confirm_password_reset(request) mock_verify_otp.assert_called_once_with( { "type": "recovery", "email": "test@example.com", "token": "123456", } ) mock_update_user_by_id.assert_called_once_with( "user-1", {"password": "newpassword123"}, ) @pytest.mark.asyncio async def test_confirm_password_reset_raises_when_user_id_missing( self, gateway: SupabaseAuthGateway ) -> None: verify_response = SimpleNamespace( session=SimpleNamespace(access_token="access"), user=SimpleNamespace(id=""), ) gateway._client.auth.verify_otp = MagicMock(return_value=verify_response) request = PasswordResetConfirmRequest( email="test@example.com", token="123456", new_password="newpassword123", ) with pytest.raises(HTTPException) as exc_info: await gateway.confirm_password_reset(request) assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid or expired verification code"