from __future__ import annotations import asyncio import re import time from typing import Any, cast from pydantic import ValidationError from supabase import AuthError from core.config.settings import config from core.http.errors import ApiProblemError from core.logging import get_logger from services.base.supabase import supabase_service from v1.auth.dev_email_session import create_dev_email_session from v1.auth.schemas import ( AuthUser, EmailSessionCreateRequest, OtpSendRequest, SessionRefreshRequest, SessionResponse, UserByIdResponse, UserByEmailResponse, ) from v1.auth.service import AuthServiceGateway logger = get_logger("v1.auth.gateway") AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable" def _auth_error( *, status_code: int, code: str, detail: str, ) -> ApiProblemError: return ApiProblemError(status_code=status_code, code=code, detail=detail) class SupabaseAuthGateway(AuthServiceGateway): def __init__(self) -> None: self._user_lookup_cache_ttl_seconds: int = 60 self._user_lookup_cache_expires_at: float = 0.0 self._users_by_email: dict[str, Any] = {} self._users_by_id: dict[str, Any] = {} def _get_client(self) -> Any: return supabase_service.get_client() def _get_admin_client(self) -> Any: return supabase_service.get_admin_client() async def send_otp(self, request: OtpSendRequest) -> None: client = self._get_client() payload: dict[str, Any] = { "email": request.email, "options": {"should_create_user": True}, } try: sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp) await asyncio.to_thread(sign_in_with_otp, payload) except AuthError as exc: logger.warning("Send otp failed", error_type=type(exc).__name__) if _is_auth_upstream_unavailable(exc): raise _auth_error( status_code=503, code="AUTH_SERVICE_UNAVAILABLE", detail=AUTH_UNAVAILABLE_DETAIL, ) from exc raise _auth_error( status_code=429, code="AUTH_TOO_MANY_REQUESTS", detail="Too many requests", ) from exc async def create_email_session( self, request: EmailSessionCreateRequest ) -> SessionResponse: if config.runtime.environment == "dev": return await create_dev_email_session( request=request, client=self._get_client(), admin_client=self._get_admin_client(), auth_unavailable_detail=AUTH_UNAVAILABLE_DETAIL, is_auth_upstream_unavailable=_is_auth_upstream_unavailable, map_auth_response=_map_auth_response, ) client = self._get_client() payload: dict[str, Any] = { "type": "email", "email": request.email, "token": request.token, } try: verify_otp = cast(Any, client.auth.verify_otp) response = await asyncio.to_thread(verify_otp, payload) return _map_auth_response( response, "Invalid verification code", "AUTH_VERIFICATION_CODE_INVALID", ) except AuthError as exc: logger.warning("Create email session failed", error_type=type(exc).__name__) if _is_auth_upstream_unavailable(exc): raise _auth_error( status_code=503, code="AUTH_SERVICE_UNAVAILABLE", detail=AUTH_UNAVAILABLE_DETAIL, ) from exc raise _auth_error( status_code=401, code="AUTH_VERIFICATION_CODE_INVALID", detail="Invalid verification code", ) from exc async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse: client = self._get_client() try: response = await asyncio.to_thread( client.auth.refresh_session, request.refresh_token, ) return _map_auth_response( response, "Invalid refresh token", "AUTH_REFRESH_TOKEN_INVALID", ) except AuthError as exc: logger.warning("Refresh failed", error_type=type(exc).__name__) if _is_auth_upstream_unavailable(exc): raise _auth_error( status_code=503, code="AUTH_SERVICE_UNAVAILABLE", detail=AUTH_UNAVAILABLE_DETAIL, ) from exc raise _auth_error( status_code=401, code="AUTH_REFRESH_TOKEN_INVALID", detail="Invalid refresh token", ) from exc async def delete_session(self, refresh_token: str | None) -> None: if not refresh_token: raise _auth_error( status_code=401, code="AUTH_REFRESH_TOKEN_MISSING", detail="Missing refresh token", ) client = self._get_client() try: response = await asyncio.to_thread( client.auth.refresh_session, refresh_token, ) session = getattr(response, "session", None) if session is None: raise _auth_error( status_code=401, code="AUTH_REFRESH_TOKEN_INVALID", detail="Invalid refresh token", ) await asyncio.to_thread( client.auth.set_session, str(session.access_token), str(session.refresh_token), ) await asyncio.to_thread(client.auth.sign_out) except AuthError as exc: logger.warning("Logout failed", error_type=type(exc).__name__) if _is_auth_upstream_unavailable(exc): raise _auth_error( status_code=503, code="AUTH_SERVICE_UNAVAILABLE", detail=AUTH_UNAVAILABLE_DETAIL, ) from exc raise _auth_error( status_code=401, code="AUTH_REFRESH_TOKEN_INVALID", detail="Invalid refresh token", ) from exc async def get_user_by_email(self, email: str) -> UserByEmailResponse: normalized_email = _normalize_email(email) if not normalized_email: raise _auth_error( status_code=404, code="AUTH_USER_NOT_FOUND", detail="User not found", ) await self._refresh_user_lookup_cache_if_needed() user = self._users_by_email.get(normalized_email) if user is None: raise _auth_error( status_code=404, code="AUTH_USER_NOT_FOUND", detail="User not found", ) user_email = _normalize_email(getattr(user, "email", "")) if not user_email: raise _auth_error( status_code=404, code="AUTH_USER_NOT_FOUND", detail="User not found", ) return UserByEmailResponse( id=str(getattr(user, "id", "")), email=user_email, created_at=str(getattr(user, "created_at", "")), email_confirmed_at=( str(getattr(user, "email_confirmed_at", "")) if getattr(user, "email_confirmed_at", None) else None ), ) async def get_user_by_id(self, user_id: str) -> UserByIdResponse: users = await self.get_users_by_ids([user_id]) resolved = users.get(user_id) if resolved is None: raise _auth_error( status_code=404, code="AUTH_USER_NOT_FOUND", detail="User not found", ) return resolved async def get_users_by_ids( self, user_ids: list[str] ) -> dict[str, UserByIdResponse]: await self._refresh_user_lookup_cache_if_needed() resolved: dict[str, UserByIdResponse] = {} for raw_user_id in user_ids: normalized_user_id = raw_user_id.strip() if not normalized_user_id: continue user = self._users_by_id.get(normalized_user_id) if user is None: continue user_attrs = getattr(user, "user", user) resolved[normalized_user_id] = UserByIdResponse( id=str(getattr(user_attrs, "id", "")), email=getattr(user_attrs, "email", None), created_at=str(getattr(user_attrs, "created_at", "")), email_confirmed_at=( str(getattr(user_attrs, "email_confirmed_at", "")) if getattr(user_attrs, "email_confirmed_at", None) else None ), ) return resolved async def search_user_ids_by_email(self, query: str, limit: int = 20) -> list[str]: normalized_query = _normalize_email(query) if not normalized_query: return [] await self._refresh_user_lookup_cache_if_needed() matched_user = self._users_by_email.get(normalized_query) if matched_user is None: return [] user_id = str(getattr(matched_user, "id", "")) return [user_id] if user_id else [] async def _refresh_user_lookup_cache_if_needed(self) -> None: now = time.monotonic() if now < self._user_lookup_cache_expires_at: return admin_client = self._get_admin_client() users = await asyncio.to_thread(_list_auth_users, admin_client) users_by_email: dict[str, Any] = {} users_by_id: dict[str, Any] = {} for candidate in users: candidate_id = str(getattr(candidate, "id", "")).strip() if candidate_id: users_by_id[candidate_id] = candidate candidate_email = _normalize_email(getattr(candidate, "email", "")) if candidate_email: users_by_email[candidate_email] = candidate self._users_by_id = users_by_id self._users_by_email = users_by_email self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds def _is_auth_upstream_unavailable(exc: AuthError) -> bool: raw_status = getattr(exc, "status", None) if raw_status is None: raw_status = getattr(exc, "status_code", None) if isinstance(raw_status, int) and 500 <= raw_status < 600: return True raw_code = getattr(exc, "code", None) code = str(raw_code).lower() if raw_code is not None else "" message = str(exc).lower() indicators = ( "request_timeout", "timed out", "timeout", "gateway timeout", "bad_gateway", "service_unavailable", "internal_server_error", "unexpected_failure", "upstream", "500", "502", "503", "504", "5xx", ) return any(token in code or token in message for token in indicators) def _map_auth_response( response: object, failure_message: str, failure_code: str ) -> SessionResponse: session = getattr(response, "session", None) user = getattr(response, "user", None) if session is None or user is None: raise _auth_error( status_code=401, code=failure_code, detail=failure_message, ) email = _normalize_email(getattr(user, "email", None)) if not email: raise _auth_error( status_code=401, code=failure_code, detail=failure_message, ) try: auth_user = AuthUser(id=str(user.id), email=str(email)) except ValidationError as exc: logger.warning( "Auth response returned invalid email format", error_type=type(exc).__name__, ) raise _auth_error( status_code=401, code=failure_code, detail=failure_message, ) from exc return SessionResponse( access_token=str(session.access_token), refresh_token=str(session.refresh_token), expires_in=int(session.expires_in or 0), token_type=str(session.token_type), user=auth_user, ) def _list_auth_users(client: Any) -> list[Any]: users: list[Any] = [] page = 1 max_pages = 100 while page <= max_pages: response = client.auth.admin.list_users(page=page, per_page=100) batch = ( list(response) if isinstance(response, list) else list(getattr(response, "users", [])) ) users.extend(batch) if len(batch) < 100: break page += 1 return users _EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") def _normalize_email(raw_email: object) -> str | None: if not isinstance(raw_email, str): return None email = raw_email.strip().lower() if not _EMAIL_PATTERN.fullmatch(email): return None return email