from __future__ import annotations import asyncio from typing import Any, Protocol, cast from fastapi import HTTPException from supabase import AuthError, create_client from core.config.settings import SupabaseSettings, config from core.logging import get_logger from v1.auth.models import ( AuthTokenResponse, AuthUser, LoginRequest, RefreshRequest, SignupRequest, ) logger = get_logger("v1.auth.service") class AuthServiceGateway(Protocol): async def signup(self, request: SignupRequest) -> AuthTokenResponse: raise NotImplementedError async def login(self, request: LoginRequest) -> AuthTokenResponse: raise NotImplementedError async def refresh(self, request: RefreshRequest) -> AuthTokenResponse: raise NotImplementedError async def logout(self, refresh_token: str | None) -> None: raise NotImplementedError class SupabaseAuthGateway(AuthServiceGateway): _client: Any def __init__(self) -> None: settings: SupabaseSettings = config.supabase self._client = create_client(settings.url, settings.anon_key) async def signup(self, request: SignupRequest) -> AuthTokenResponse: payload: dict[str, Any] = { "email": request.email, "password": request.password, } if request.display_name: payload = { **payload, "data": {"display_name": request.display_name}, } try: sign_up = cast(Any, self._client.auth.sign_up) response = await asyncio.to_thread(sign_up, payload) return _map_auth_response(response, "Authentication failed") except AuthError as exc: logger.warning("Signup failed", error=str(exc)) raise HTTPException( status_code=401, detail="Authentication failed" ) from exc async def login(self, request: LoginRequest) -> AuthTokenResponse: payload: dict[str, Any] = {"email": request.email, "password": request.password} try: sign_in = cast(Any, self._client.auth.sign_in_with_password) response = await asyncio.to_thread(sign_in, payload) return _map_auth_response(response, "Invalid credentials") except AuthError as exc: logger.warning("Login failed", error=str(exc)) raise HTTPException(status_code=401, detail="Invalid credentials") from exc async def refresh(self, request: RefreshRequest) -> AuthTokenResponse: try: response = await asyncio.to_thread( self._client.auth.refresh_session, request.refresh_token, ) return _map_auth_response(response, "Invalid refresh token") except AuthError as exc: logger.warning("Refresh failed", error=str(exc)) raise HTTPException( status_code=401, detail="Invalid refresh token" ) from exc async def logout(self, refresh_token: str | None) -> None: if not refresh_token: raise HTTPException(status_code=401, detail="Missing refresh token") try: response = await asyncio.to_thread( self._client.auth.refresh_session, refresh_token, ) session = getattr(response, "session", None) if session is None: raise HTTPException(status_code=401, detail="Invalid refresh token") await asyncio.to_thread( self._client.auth.set_session, str(session.access_token), str(session.refresh_token), ) await asyncio.to_thread(self._client.auth.sign_out) except AuthError as exc: logger.warning("Logout failed", error=str(exc)) raise HTTPException( status_code=401, detail="Invalid refresh token" ) from exc class AuthService: _gateway: AuthServiceGateway def __init__(self, gateway: AuthServiceGateway) -> None: self._gateway = gateway async def signup(self, request: SignupRequest) -> AuthTokenResponse: return await self._gateway.signup(request) async def login(self, request: LoginRequest) -> AuthTokenResponse: return await self._gateway.login(request) async def refresh(self, request: RefreshRequest) -> AuthTokenResponse: return await self._gateway.refresh(request) async def logout(self, refresh_token: str | None) -> None: await self._gateway.logout(refresh_token) def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse: session = getattr(response, "session", None) user = getattr(response, "user", None) if session is None or user is None: raise HTTPException(status_code=401, detail=failure_message) email = getattr(user, "email", None) if not email: raise HTTPException(status_code=401, detail=failure_message) auth_user = AuthUser(id=str(user.id), email=str(email)) return AuthTokenResponse( access_token=str(session.access_token), refresh_token=str(session.refresh_token), expires_in=int(session.expires_in or 0), token_type=str(session.token_type), user=auth_user, )