380 lines
13 KiB
Python
380 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Mapping
|
|
from typing import Any, cast
|
|
from urllib.parse import urlparse
|
|
|
|
from fastapi import HTTPException
|
|
from supabase import AuthError
|
|
|
|
from core.config.settings import config
|
|
from core.logging import get_logger
|
|
from services.base.supabase import supabase_service
|
|
from v1.auth.schemas import (
|
|
AuthUser,
|
|
PasswordResetConfirmRequest,
|
|
PasswordResetRequest,
|
|
SessionCreateRequest,
|
|
SessionRefreshRequest,
|
|
SessionResponse,
|
|
UserByEmailResponse,
|
|
VerificationCreateRequest,
|
|
VerificationCreateResponse,
|
|
VerificationResendRequest,
|
|
VerificationVerifyRequest,
|
|
)
|
|
from v1.auth.service import AuthServiceGateway
|
|
|
|
# Client-facing ``detail`` used for every HTTP 503 raised in this module when
# the auth provider is judged to be unavailable.
AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable"

# Module-scoped logger; log calls in this file pass structured fields as
# keyword arguments (e.g. ``error_type=...``).
logger = get_logger("v1.auth.gateway")
|
|
|
|
|
|
class SupabaseAuthGateway(AuthServiceGateway):
    """``AuthServiceGateway`` implementation backed by the Supabase auth API.

    Supabase client calls are treated as blocking and are dispatched through
    ``asyncio.to_thread`` so the event loop stays free. ``AuthError`` raised by
    the client is translated into ``HTTPException``. The status is 503 when
    ``_is_auth_upstream_unavailable`` classifies the failure as a provider
    outage. Otherwise it is a fixed 4xx whose detail never echoes the
    provider's own message.
    """

    def _get_client(self) -> Any:
        """Return the regular (non-admin) Supabase client from ``supabase_service``."""
        return supabase_service.get_client()

    def _get_admin_client(self) -> Any:
        """Return the admin Supabase client from ``supabase_service``."""
        return supabase_service.get_admin_client()

    async def create_verification(
        self, request: VerificationCreateRequest
    ) -> VerificationCreateResponse:
        """Sign up a new email/password user, which triggers the verification email.

        ``username`` (and ``invite_code`` when present) are sent as user
        metadata. ``redirect_to``, when given, must pass
        ``_validate_redirect_url``.

        Raises:
            HTTPException: 422 for an invalid redirect URL or a rejected
                signup; 503 when the auth provider is unavailable.
        """
        client = self._get_client()

        metadata: dict[str, Any] = {"username": request.username}
        if request.invite_code:
            metadata["invite_code"] = request.invite_code

        # supabase-py's sign_up reads user metadata and the redirect URL from
        # the nested "options" mapping. A top-level "data" key is ignored,
        # which silently dropped username/invite_code.
        options: dict[str, Any] = {"data": metadata}
        if request.redirect_to:
            options["email_redirect_to"] = _validate_redirect_url(request.redirect_to)

        payload: dict[str, Any] = {
            "email": request.email,
            "password": request.password,
            "options": options,
        }
        try:
            sign_up = cast(Any, client.auth.sign_up)
            await asyncio.to_thread(sign_up, payload)
            return VerificationCreateResponse(email=request.email)
        except AuthError as exc:
            logger.warning("Signup failed", error_type=type(exc).__name__)
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc
            raise HTTPException(
                status_code=422, detail="Invalid signup request"
            ) from exc

    async def verify_verification(
        self, request: VerificationVerifyRequest
    ) -> SessionResponse:
        """Verify a signup OTP and return the resulting session.

        Only ``request.type == "signup"`` is accepted here.

        Raises:
            HTTPException: 422 for any other type; 401 when the code is
                rejected or no session/user is returned; 503 when the auth
                provider is unavailable.
        """
        if request.type != "signup":
            raise HTTPException(status_code=422, detail="Invalid request")

        client = self._get_client()
        payload: dict[str, Any] = {
            "type": request.type,
            "email": request.email,
            "token": request.token,
        }
        try:
            verify_otp = cast(Any, client.auth.verify_otp)
            response = await asyncio.to_thread(verify_otp, payload)
            return _map_auth_response(response, "Invalid verification code")
        except AuthError as exc:
            logger.warning("Signup verify failed", error_type=type(exc).__name__)
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc
            raise HTTPException(
                status_code=401, detail="Invalid verification code"
            ) from exc

    async def resend_verification(self, request: VerificationResendRequest) -> None:
        """Re-send a verification email.

        ``type == "recovery"`` is delegated to ``request_password_reset``.
        Any other type is passed to Supabase ``resend``. In both paths,
        ``redirect_to`` (when given) is validated with
        ``_validate_redirect_url`` and forwarded.

        Non-upstream ``AuthError``s are logged and swallowed, so the call
        returns normally. This is presumably so the endpoint does not reveal
        whether an address is registered; confirm against the API contract.

        Raises:
            HTTPException: 422 for an invalid redirect URL; 503 when the auth
                provider is unavailable.
        """
        if request.type == "recovery":
            await self.request_password_reset(
                PasswordResetRequest(
                    email=request.email,
                    redirect_to=request.redirect_to,
                )
            )
            return

        payload: dict[str, Any] = {"type": request.type, "email": request.email}
        if request.redirect_to:
            # Forward the (validated) redirect for non-recovery resends too,
            # matching how signup and recovery handle it. It used to be
            # silently dropped here.
            payload["options"] = {
                "email_redirect_to": _validate_redirect_url(request.redirect_to)
            }

        client = self._get_client()
        try:
            resend = cast(Any, client.auth.resend)
            await asyncio.to_thread(resend, payload)
        except AuthError as exc:
            logger.warning("Signup resend failed", error_type=type(exc).__name__)
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc

    async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
        """Sign in with email/password and return the new session.

        Raises:
            HTTPException: 401 for rejected credentials or a response without
                session/user/email; 503 when the auth provider is unavailable.
        """
        client = self._get_client()
        payload: dict[str, Any] = {"email": request.email, "password": request.password}
        try:
            sign_in = cast(Any, client.auth.sign_in_with_password)
            response = await asyncio.to_thread(sign_in, payload)
            return _map_auth_response(response, "Invalid credentials")
        except AuthError as exc:
            logger.warning("Login failed", error_type=type(exc).__name__)
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc
            raise HTTPException(status_code=401, detail="Invalid credentials") from exc

    async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
        """Exchange a refresh token for a new session.

        Raises:
            HTTPException: 401 when the token is rejected or the response lacks
                session/user/email; 503 when the auth provider is unavailable.
        """
        client = self._get_client()
        try:
            response = await asyncio.to_thread(
                client.auth.refresh_session,
                request.refresh_token,
            )
            return _map_auth_response(response, "Invalid refresh token")
        except AuthError as exc:
            logger.warning("Refresh failed", error_type=type(exc).__name__)
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc
            raise HTTPException(
                status_code=401, detail="Invalid refresh token"
            ) from exc

    async def delete_session(self, refresh_token: str | None) -> None:
        """Log out the session identified by *refresh_token*.

        The refresh token is exchanged for a fresh session. That session is
        installed on the client with ``set_session``, and ``sign_out`` is
        then called.

        NOTE(review): this mutates auth state on the client returned by
        ``_get_client()``. If that client is shared across concurrent
        requests, interleaved set_session/sign_out calls could act on the
        wrong session. Confirm how ``supabase_service`` scopes clients.

        NOTE(review): ``sign_out()`` is called without options, so the client
        library's default scope applies. Confirm it matches the intended
        logout semantics.

        Raises:
            HTTPException: 401 when the token is missing or rejected; 503 when
                the auth provider is unavailable.
        """
        if not refresh_token:
            raise HTTPException(status_code=401, detail="Missing refresh token")

        client = self._get_client()
        try:
            response = await asyncio.to_thread(
                client.auth.refresh_session,
                refresh_token,
            )
            session = getattr(response, "session", None)
            if session is None:
                raise HTTPException(status_code=401, detail="Invalid refresh token")
            await asyncio.to_thread(
                client.auth.set_session,
                str(session.access_token),
                str(session.refresh_token),
            )
            await asyncio.to_thread(client.auth.sign_out)
        except AuthError as exc:
            logger.warning("Logout failed", error_type=type(exc).__name__)
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc
            raise HTTPException(
                status_code=401, detail="Invalid refresh token"
            ) from exc

    async def get_user_by_email(self, email: str) -> UserByEmailResponse:
        """Look up an auth user by email (case-insensitive) through the admin API.

        Users are fetched with ``_list_auth_users`` and scanned locally.
        Users that have no email on record never match.

        Raises:
            HTTPException: 404 when no user matches; 503 when the auth
                provider is unavailable. Other ``AuthError``s propagate.
        """
        admin_client = self._get_admin_client()
        try:
            users = await asyncio.to_thread(_list_auth_users, admin_client)
        except AuthError as exc:
            # Map outages to 503 like every other method in this gateway.
            logger.warning("User lookup failed", error_type=type(exc).__name__)
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc
            raise

        normalized_email = email.lower()
        user = None
        for candidate in users:
            candidate_email = getattr(candidate, "email", None)
            # Skip users without an email. str(None) would read "none" and
            # could falsely match, and "" would match an empty query.
            if candidate_email and str(candidate_email).lower() == normalized_email:
                user = candidate
                break

        if user is None:
            raise HTTPException(status_code=404, detail="User not found")

        confirmed_at = getattr(user, "email_confirmed_at", None)
        return UserByEmailResponse(
            id=str(getattr(user, "id", "")),
            email=str(getattr(user, "email", "")),
            created_at=str(getattr(user, "created_at", "")),
            email_confirmed_at=str(confirmed_at) if confirmed_at else None,
        )

    async def request_password_reset(self, request: PasswordResetRequest) -> None:
        """Ask Supabase to send a password-reset email.

        ``request.email`` is normalized with ``_coerce_reset_email``.
        ``redirect_to``, when given, is validated and forwarded as the
        ``redirect_to`` option.

        Non-upstream ``AuthError``s are logged and swallowed, presumably to
        avoid revealing whether the address is registered; confirm with the
        API contract.

        Raises:
            HTTPException: 422 for an invalid email or redirect URL; 503 when
                the auth provider is unavailable.
        """
        client = self._get_client()
        # Input validation raises HTTPException, not AuthError, so it can run
        # before the try block without changing error mapping.
        email = _coerce_reset_email(request.email)
        options: dict[str, str] | None = None
        if request.redirect_to:
            options = {"redirect_to": _validate_redirect_url(request.redirect_to)}

        try:
            reset_email = cast(Any, client.auth.reset_password_email)
            if options is not None:
                await asyncio.to_thread(reset_email, email, options=options)
            else:
                await asyncio.to_thread(reset_email, email)
        except AuthError as exc:
            logger.warning(
                "Password reset request failed",
                error_type=type(exc).__name__,
            )
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc

    async def confirm_password_reset(
        self, request: PasswordResetConfirmRequest
    ) -> None:
        """Verify a recovery OTP, then set the user's new password via the admin API.

        NOTE(review): a non-upstream ``AuthError`` from the password update
        is reported with the same 401 "Invalid or expired verification code"
        as a bad OTP. Consider whether update failures deserve a distinct
        response.

        Raises:
            HTTPException: 401 when verification fails, returns no
                session/user, or a non-upstream ``AuthError`` occurs; 503 when
                the auth provider is unavailable.
        """
        client = self._get_client()
        admin_client = self._get_admin_client()
        verify_payload: dict[str, Any] = {
            "type": "recovery",
            "email": request.email,
            "token": request.token,
        }
        try:
            verify_otp = cast(Any, client.auth.verify_otp)
            response = await asyncio.to_thread(verify_otp, verify_payload)
            session = getattr(response, "session", None)
            user = getattr(response, "user", None)
            user_id = str(getattr(user, "id", "")) if user is not None else ""
            if session is None or not user_id:
                raise HTTPException(
                    status_code=401, detail="Invalid or expired verification code"
                )
            await asyncio.to_thread(
                admin_client.auth.admin.update_user_by_id,
                user_id,
                {"password": request.new_password},
            )
        except AuthError as exc:
            logger.warning(
                "Password reset confirm failed", error_type=type(exc).__name__
            )
            if _is_auth_upstream_unavailable(exc):
                raise HTTPException(
                    status_code=503,
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc
            raise HTTPException(
                status_code=401, detail="Invalid or expired verification code"
            ) from exc
|
|
|
|
|
|
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
|
|
raw_status = getattr(exc, "status", None)
|
|
if raw_status is None:
|
|
raw_status = getattr(exc, "status_code", None)
|
|
if isinstance(raw_status, int) and 500 <= raw_status < 600:
|
|
return True
|
|
|
|
raw_code = getattr(exc, "code", None)
|
|
code = str(raw_code).lower() if raw_code is not None else ""
|
|
message = str(exc).lower()
|
|
indicators = (
|
|
"request_timeout",
|
|
"timed out",
|
|
"timeout",
|
|
"gateway timeout",
|
|
"bad_gateway",
|
|
"service_unavailable",
|
|
"internal_server_error",
|
|
"unexpected_failure",
|
|
"upstream",
|
|
"500",
|
|
"502",
|
|
"503",
|
|
"504",
|
|
"5xx",
|
|
)
|
|
return any(token in code or token in message for token in indicators)
|
|
|
|
|
|
def _validate_redirect_url(url: str) -> str:
    """Return *url* unchanged if its origin is an allowed CORS origin.

    The URL must use http/https and have a network location. Its
    lower-cased ``scheme://netloc`` must equal one of the http(s) entries in
    ``config.cors.allow_origins`` after normalization. Non-http entries such
    as ``*`` are ignored, so a wildcard never authorizes a redirect.

    Raises:
        HTTPException: 422 "Invalid redirect URL" when any check fails.
    """
    parsed = urlparse(url)
    is_web_url = parsed.scheme in {"http", "https"} and bool(parsed.netloc)
    if is_web_url:
        requested_origin = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
        permitted = {
            _normalize_origin(entry)
            for entry in config.cors.allow_origins
            if _is_http_origin(entry)
        }
        if requested_origin in permitted:
            return url
    raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
|
|
|
|
|
def _normalize_origin(value: str) -> str:
|
|
parsed = urlparse(value)
|
|
return f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
|
|
|
|
|
|
def _is_http_origin(value: str) -> bool:
|
|
parsed = urlparse(value)
|
|
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
|
|
|
|
|
|
def _coerce_reset_email(value: object) -> str:
|
|
if isinstance(value, str):
|
|
return value
|
|
|
|
if isinstance(value, Mapping):
|
|
nested = value.get("email") or value.get("value")
|
|
if isinstance(nested, str):
|
|
return nested
|
|
|
|
raise HTTPException(status_code=422, detail="Invalid email")
|
|
|
|
|
|
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
    """Convert a Supabase auth response into a ``SessionResponse``.

    Args:
        response: Object exposing ``session`` and ``user`` attributes.
        failure_message: Detail used for the 401 when the response is
            unusable.

    Returns:
        The session tokens plus an ``AuthUser``. ``expires_in`` falls back to
        0 when the session reports a falsy value.

    Raises:
        HTTPException: 401 with *failure_message* when the session or user is
            missing, or the user has no email.
    """
    session = getattr(response, "session", None)
    user = getattr(response, "user", None)
    # Short-circuit keeps the email lookup until both objects are present.
    if session is None or user is None or not (email := getattr(user, "email", None)):
        raise HTTPException(status_code=401, detail=failure_message)

    account = AuthUser(id=str(user.id), email=str(email))
    return SessionResponse(
        access_token=str(session.access_token),
        refresh_token=str(session.refresh_token),
        expires_in=int(session.expires_in or 0),
        token_type=str(session.token_type),
        user=account,
    )
|
|
|
|
|
|
def _list_auth_users(client: Any) -> list[Any]:
|
|
users: list[Any] = []
|
|
page = 1
|
|
max_pages = 100
|
|
|
|
while page <= max_pages:
|
|
response = client.auth.admin.list_users(page=page, per_page=100)
|
|
batch = (
|
|
list(response)
|
|
if isinstance(response, list)
|
|
else list(getattr(response, "users", []))
|
|
)
|
|
users.extend(batch)
|
|
|
|
if len(batch) < 100:
|
|
break
|
|
page += 1
|
|
|
|
return users
|