Files
social-app/backend/src/v1/auth/gateway.py
T

306 lines
11 KiB
Python
Raw Normal View History

from __future__ import annotations
import asyncio
from collections.abc import Mapping
from typing import Any, cast
from urllib.parse import urlparse
from fastapi import HTTPException
from supabase import AuthError
from core.config.settings import config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByEmailResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
)
from v1.auth.service import AuthServiceGateway
# Module-level logger for the auth gateway; call sites pass structured keyword fields
# (for example ``error_type=...``) alongside the message.
logger = get_logger("v1.auth.gateway")
class SupabaseAuthGateway(AuthServiceGateway):
    """Auth gateway implemented on top of the Supabase auth client.

    Every Supabase call is dispatched with ``asyncio.to_thread``. An
    ``AuthError`` raised by Supabase is logged with its exception type only and
    is either translated into an ``HTTPException`` or, for the resend and
    password-reset-request flows, swallowed so the caller always gets the same
    outcome.
    """

    def _get_client(self) -> Any:
        """Return the regular client provided by ``supabase_service``."""
        return supabase_service.get_client()

    def _get_admin_client(self) -> Any:
        """Return the admin client provided by ``supabase_service``."""
        return supabase_service.get_admin_client()

    async def create_verification(
        self, request: VerificationCreateRequest
    ) -> VerificationCreateResponse:
        """Sign the user up with Supabase and report the e-mail that was used.

        Args:
            request: Signup data (e-mail, password, username and optional
                invite code / redirect URL).

        Returns:
            A ``VerificationCreateResponse`` echoing ``request.email``.

        Raises:
            HTTPException: 422 when the redirect URL is rejected or Supabase
                raises ``AuthError`` for the signup.
        """
        client = self._get_client()

        metadata: dict[str, Any] = {"username": request.username}
        if request.invite_code:
            metadata["invite_code"] = request.invite_code

        # The Supabase Python client reads user metadata from options["data"];
        # a top-level "data" key in the credentials is ignored, which silently
        # dropped the username / invite code.
        options: dict[str, Any] = {"data": metadata}
        if request.redirect_to:
            options["email_redirect_to"] = _validate_redirect_url_or_raise(
                request.redirect_to
            )

        payload: dict[str, Any] = {
            "email": request.email,
            "password": request.password,
            "options": options,
        }

        try:
            sign_up = cast(Any, client.auth.sign_up)
            await asyncio.to_thread(sign_up, payload)
        except AuthError as exc:
            logger.warning("Signup failed", error_type=type(exc).__name__)
            raise HTTPException(
                status_code=422, detail="Invalid signup request"
            ) from exc
        return VerificationCreateResponse(email=request.email)

    async def verify_verification(
        self, request: VerificationVerifyRequest
    ) -> SessionResponse:
        """Verify a signup code and return the resulting session.

        Args:
            request: Verification type, e-mail and token. Only the ``"signup"``
                type is accepted here.

        Returns:
            The session mapped from the Supabase response.

        Raises:
            HTTPException: 422 for any type other than ``"signup"``; 401 when
                Supabase raises ``AuthError`` or the response carries no usable
                session/user.
        """
        if request.type != "signup":
            raise HTTPException(status_code=422, detail="Invalid request")

        client = self._get_client()
        payload: dict[str, Any] = {
            "type": request.type,
            "email": request.email,
            "token": request.token,
        }
        try:
            verify_otp = cast(Any, client.auth.verify_otp)
            response = await asyncio.to_thread(verify_otp, payload)
        except AuthError as exc:
            logger.warning("Signup verify failed", error_type=type(exc).__name__)
            raise HTTPException(
                status_code=401, detail="Invalid verification code"
            ) from exc
        return _map_auth_response(response, "Invalid verification code")

    async def resend_verification(self, request: VerificationResendRequest) -> None:
        """Resend a verification e-mail, or restart a password reset.

        ``"recovery"`` requests are delegated to ``request_password_reset``.
        For every other type the Supabase ``resend`` call is issued and an
        ``AuthError`` is only logged, never raised.

        Args:
            request: Verification type, e-mail and optional redirect URL.
        """
        if request.type == "recovery":
            await self.request_password_reset(
                PasswordResetRequest(
                    email=request.email,
                    redirect_to=request.redirect_to,
                )
            )
            return

        # Only acquired once we know this method itself talks to Supabase; the
        # recovery branch above obtains its own client.
        client = self._get_client()
        # NOTE(review): request.redirect_to is not forwarded for non-recovery
        # resends -- confirm whether that is intentional.
        payload: dict[str, Any] = {"type": request.type, "email": request.email}
        try:
            resend = cast(Any, client.auth.resend)
            await asyncio.to_thread(resend, payload)
        except AuthError as exc:
            # Deliberately swallowed: the failure is logged and not surfaced.
            logger.warning("Signup resend failed", error_type=type(exc).__name__)

    async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
        """Sign in with e-mail and password.

        Args:
            request: Login credentials.

        Returns:
            The session mapped from the Supabase response.

        Raises:
            HTTPException: 401 when Supabase raises ``AuthError`` or the
                response carries no usable session/user.
        """
        client = self._get_client()
        payload: dict[str, Any] = {
            "email": request.email,
            "password": request.password,
        }
        try:
            sign_in = cast(Any, client.auth.sign_in_with_password)
            response = await asyncio.to_thread(sign_in, payload)
        except AuthError as exc:
            logger.warning("Login failed", error_type=type(exc).__name__)
            raise HTTPException(status_code=401, detail="Invalid credentials") from exc
        return _map_auth_response(response, "Invalid credentials")

    async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
        """Exchange a refresh token for a new session.

        Args:
            request: Carries the refresh token.

        Returns:
            The session mapped from the Supabase response.

        Raises:
            HTTPException: 401 when Supabase raises ``AuthError`` or the
                response carries no usable session/user.
        """
        client = self._get_client()
        try:
            response = await asyncio.to_thread(
                client.auth.refresh_session,
                request.refresh_token,
            )
        except AuthError as exc:
            logger.warning("Refresh failed", error_type=type(exc).__name__)
            raise HTTPException(
                status_code=401, detail="Invalid refresh token"
            ) from exc
        return _map_auth_response(response, "Invalid refresh token")

    async def delete_session(self, refresh_token: str | None) -> None:
        """Sign out the session identified by *refresh_token*.

        The token is refreshed first, the refreshed session is set on the
        client, and then ``sign_out`` is called.

        Args:
            refresh_token: Refresh token of the session to end.

        Raises:
            HTTPException: 401 when the token is missing, the refresh yields no
                session, or Supabase raises ``AuthError``.
        """
        if not refresh_token:
            raise HTTPException(status_code=401, detail="Missing refresh token")

        client = self._get_client()
        # NOTE(review): these calls set the session on the object returned by
        # _get_client(); if that client is shared between requests, concurrent
        # logouts could interfere -- confirm how supabase_service scopes it.
        try:
            response = await asyncio.to_thread(
                client.auth.refresh_session,
                refresh_token,
            )
            session = getattr(response, "session", None)
            if session is None:
                raise HTTPException(status_code=401, detail="Invalid refresh token")
            await asyncio.to_thread(
                client.auth.set_session,
                str(session.access_token),
                str(session.refresh_token),
            )
            await asyncio.to_thread(client.auth.sign_out)
        except AuthError as exc:
            logger.warning("Logout failed", error_type=type(exc).__name__)
            raise HTTPException(
                status_code=401, detail="Invalid refresh token"
            ) from exc

    async def get_user_by_email(self, email: str) -> UserByEmailResponse:
        """Look up an auth user by e-mail (case-insensitive).

        All users are fetched through ``_list_auth_users`` and scanned for the
        first one whose e-mail matches.

        Args:
            email: Address to search for; surrounding whitespace is ignored.

        Returns:
            Id, e-mail and timestamps of the matching user.

        Raises:
            HTTPException: 404 when no user has that e-mail.
        """
        normalized_email = email.strip().lower()
        if not normalized_email:
            # A blank address can never identify a user; skip the full listing.
            raise HTTPException(status_code=404, detail="User not found")

        admin_client = self._get_admin_client()
        users = await asyncio.to_thread(_list_auth_users, admin_client)

        user = None
        for candidate in users:
            candidate_email = getattr(candidate, "email", None)
            # Users without an e-mail are skipped; stringifying None would
            # otherwise compare against the literal text "none".
            if candidate_email and str(candidate_email).lower() == normalized_email:
                user = candidate
                break
        if user is None:
            raise HTTPException(status_code=404, detail="User not found")

        confirmed_at = getattr(user, "email_confirmed_at", None)
        return UserByEmailResponse(
            id=str(getattr(user, "id", "")),
            email=str(getattr(user, "email", "")),
            created_at=str(getattr(user, "created_at", "")),
            email_confirmed_at=str(confirmed_at) if confirmed_at else None,
        )

    async def request_password_reset(self, request: PasswordResetRequest) -> None:
        """Ask Supabase to send a password-reset e-mail.

        An ``AuthError`` from Supabase is only logged, never raised.

        Args:
            request: E-mail and optional redirect URL.

        Raises:
            HTTPException: 422 when the e-mail cannot be coerced to a string or
                the redirect URL is rejected.
        """
        client = self._get_client()
        try:
            reset_email = cast(Any, client.auth.reset_password_email)
            email = _coerce_reset_email(request.email)
            if request.redirect_to:
                options: dict[str, str] = {
                    "redirect_to": _validate_redirect_url_or_raise(request.redirect_to)
                }
                await asyncio.to_thread(reset_email, email, options=options)
            else:
                await asyncio.to_thread(reset_email, email)
        except AuthError as exc:
            # Deliberately swallowed: the failure is logged and not surfaced.
            logger.warning(
                "Password reset request failed",
                error_type=type(exc).__name__,
            )

    async def confirm_password_reset(
        self, request: PasswordResetConfirmRequest
    ) -> None:
        """Verify a recovery code and set the user's new password.

        The code is verified with the regular client; the password is then
        updated through the admin client for the user id in the response.

        Args:
            request: E-mail, recovery token and the new password.

        Raises:
            HTTPException: 401 when verification yields no session or user id,
                or Supabase raises ``AuthError``.
        """
        client = self._get_client()
        admin_client = self._get_admin_client()
        verify_payload: dict[str, Any] = {
            "type": "recovery",
            "email": request.email,
            "token": request.token,
        }
        try:
            verify_otp = cast(Any, client.auth.verify_otp)
            response = await asyncio.to_thread(verify_otp, verify_payload)
            session = getattr(response, "session", None)
            user = getattr(response, "user", None)
            user_id = str(getattr(user, "id", "")) if user is not None else ""
            if session is None or not user_id:
                raise HTTPException(
                    status_code=401, detail="Invalid or expired verification code"
                )
            await asyncio.to_thread(
                admin_client.auth.admin.update_user_by_id,
                user_id,
                {"password": request.new_password},
            )
        except AuthError as exc:
            logger.warning(
                "Password reset confirm failed", error_type=type(exc).__name__
            )
            raise HTTPException(
                status_code=401, detail="Invalid or expired verification code"
            ) from exc
def _coerce_reset_email(value: object) -> str:
if isinstance(value, str):
return value
if isinstance(value, Mapping):
nested = value.get("email") or value.get("value")
if isinstance(nested, str):
return nested
raise HTTPException(status_code=422, detail="Invalid email")
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
    """Convert a Supabase auth response into a ``SessionResponse``.

    Args:
        response: Object expected to expose ``session`` and ``user``.
        failure_message: Detail text for the 401 raised on an unusable response.

    Raises:
        HTTPException: 401 when the session, the user, or the user's e-mail is
            missing.
    """
    session = getattr(response, "session", None)
    user = getattr(response, "user", None)

    # One guard covers all three "unusable response" cases; short-circuiting
    # keeps the e-mail lookup from running when there is no user.
    if session is None or user is None or not getattr(user, "email", None):
        raise HTTPException(status_code=401, detail=failure_message)

    return SessionResponse(
        access_token=str(session.access_token),
        refresh_token=str(session.refresh_token),
        # A missing/falsy expiry is reported as 0.
        expires_in=int(session.expires_in or 0),
        token_type=str(session.token_type),
        user=AuthUser(id=str(user.id), email=str(user.email)),
    )
def _validate_redirect_url_or_raise(url: str) -> str:
    """Return *url* unchanged if its origin is allow-listed.

    The URL must be http(s) with a network location, and its lower-cased
    ``scheme://netloc`` must equal the origin of ``config.supabase.site_url``
    or of one of ``config.supabase.additional_redirect_urls``.

    Raises:
        HTTPException: 422 when the URL is malformed or its origin is not
            allow-listed.
    """
    parts = urlparse(url)
    if parts.scheme not in {"http", "https"} or not parts.netloc:
        raise HTTPException(status_code=422, detail="Invalid redirect URL")

    allowed_origins = {_origin_of(config.supabase.site_url)}
    allowed_origins.update(
        _origin_of(entry) for entry in config.supabase.additional_redirect_urls
    )

    requested_origin = f"{parts.scheme}://{parts.netloc}".lower()
    if requested_origin in allowed_origins:
        return url
    raise HTTPException(status_code=422, detail="Invalid redirect URL")
def _origin_of(url: str) -> str:
parsed = urlparse(url.strip())
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
return ""
return f"{parsed.scheme}://{parsed.netloc}".lower()
def _list_auth_users(client: Any) -> list[Any]:
users: list[Any] = []
page = 1
max_pages = 100
while page <= max_pages:
response = client.auth.admin.list_users(page=page, per_page=100)
batch = list(getattr(response, "users", []))
users.extend(batch)
if len(batch) < 100:
break
page += 1
return users