Files
eryao/backend/src/v1/auth/gateway.py
T

393 lines
13 KiB
Python
Raw Normal View History

from __future__ import annotations
import asyncio
import re
import time
from typing import Any, cast
from pydantic import ValidationError
from supabase import AuthError
from core.config.settings import config
from core.http.errors import ApiProblemError
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.dev_email_session import create_dev_email_session
from v1.auth.schemas import (
AuthUser,
EmailSessionCreateRequest,
OtpSendRequest,
SessionRefreshRequest,
SessionResponse,
UserByIdResponse,
UserByEmailResponse,
)
from v1.auth.service import AuthServiceGateway
# Module logger; the calls below attach context as keyword arguments
# (for example ``error_type=...``).
logger = get_logger("v1.auth.gateway")
# Detail text used for the 503 AUTH_SERVICE_UNAVAILABLE problems raised below
# and handed to the dev email-session helper.
AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable"
def _auth_error(*, status_code: int, code: str, detail: str) -> ApiProblemError:
    """Build (but do not raise) the ``ApiProblemError`` for an auth failure.

    All arguments are keyword-only and are forwarded unchanged to
    ``ApiProblemError``; the caller decides when to raise the result.
    """
    problem_fields = {
        "status_code": status_code,
        "code": code,
        "detail": detail,
    }
    return ApiProblemError(**problem_fields)
class SupabaseAuthGateway(AuthServiceGateway):
    """Supabase-backed implementation of ``AuthServiceGateway``.

    Supabase client calls are executed through ``asyncio.to_thread``.
    ``AuthError`` failures from the OTP and session operations are re-raised
    as ``ApiProblemError``. User lookups read from in-memory indexes of the
    admin user list, which are rebuilt once the cache TTL has elapsed.
    """

    def __init__(self) -> None:
        """Start with empty user indexes that are considered expired."""
        self._user_lookup_cache_ttl_seconds: int = 60
        # Compared against time.monotonic(); 0.0 marks the cache as expired so
        # the first lookup loads it.
        self._user_lookup_cache_expires_at: float = 0.0
        self._users_by_email: dict[str, Any] = {}
        self._users_by_id: dict[str, Any] = {}

    def _get_client(self) -> Any:
        """Return the client provided by ``supabase_service.get_client()``."""
        return supabase_service.get_client()

    def _get_admin_client(self) -> Any:
        """Return the client provided by ``supabase_service.get_admin_client()``."""
        return supabase_service.get_admin_client()

    @staticmethod
    def _translate_auth_error(
        exc: AuthError,
        *,
        log_message: str,
        status_code: int,
        code: str,
        detail: str,
    ) -> ApiProblemError:
        """Log an ``AuthError`` and return the ``ApiProblemError`` to raise for it.

        Failures classified by ``_is_auth_upstream_unavailable`` always become
        503 ``AUTH_SERVICE_UNAVAILABLE``; anything else becomes the
        operation-specific problem described by ``status_code``, ``code`` and
        ``detail``. Only the exception type is logged, never its message.
        """
        logger.warning(log_message, error_type=type(exc).__name__)
        if _is_auth_upstream_unavailable(exc):
            return _auth_error(
                status_code=503,
                code="AUTH_SERVICE_UNAVAILABLE",
                detail=AUTH_UNAVAILABLE_DETAIL,
            )
        return _auth_error(status_code=status_code, code=code, detail=detail)

    @staticmethod
    def _user_not_found() -> ApiProblemError:
        """Return the 404 ``AUTH_USER_NOT_FOUND`` problem used by user lookups."""
        return _auth_error(
            status_code=404,
            code="AUTH_USER_NOT_FOUND",
            detail="User not found",
        )

    async def send_otp(self, request: OtpSendRequest) -> None:
        """Ask Supabase to send a one-time password to ``request.email``.

        The request sets ``should_create_user`` so unknown addresses are
        allowed to sign up.

        Raises:
            ApiProblemError: 503 when the upstream looks unavailable, otherwise
                429 ``AUTH_TOO_MANY_REQUESTS`` for any other ``AuthError``.
        """
        client = self._get_client()
        payload: dict[str, Any] = {
            "email": request.email,
            "options": {"should_create_user": True},
        }
        try:
            sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
            await asyncio.to_thread(sign_in_with_otp, payload)
        except AuthError as exc:
            raise self._translate_auth_error(
                exc,
                log_message="Send otp failed",
                status_code=429,
                code="AUTH_TOO_MANY_REQUESTS",
                detail="Too many requests",
            ) from exc

    async def create_email_session(
        self, request: EmailSessionCreateRequest
    ) -> SessionResponse:
        """Exchange an emailed verification code for a session.

        When ``config.runtime.environment`` is ``"dev"`` the work is delegated
        to ``create_dev_email_session``; otherwise the code is verified with
        Supabase ``verify_otp``.

        Raises:
            ApiProblemError: 503 when the upstream looks unavailable, otherwise
                401 ``AUTH_VERIFICATION_CODE_INVALID``.
        """
        if config.runtime.environment == "dev":
            return await create_dev_email_session(
                request=request,
                client=self._get_client(),
                admin_client=self._get_admin_client(),
                auth_unavailable_detail=AUTH_UNAVAILABLE_DETAIL,
                is_auth_upstream_unavailable=_is_auth_upstream_unavailable,
                map_auth_response=_map_auth_response,
            )
        client = self._get_client()
        payload: dict[str, Any] = {
            "type": "email",
            "email": request.email,
            "token": request.token,
        }
        try:
            verify_otp = cast(Any, client.auth.verify_otp)
            response = await asyncio.to_thread(verify_otp, payload)
            return _map_auth_response(
                response,
                "Invalid verification code",
                "AUTH_VERIFICATION_CODE_INVALID",
            )
        except AuthError as exc:
            raise self._translate_auth_error(
                exc,
                log_message="Create email session failed",
                status_code=401,
                code="AUTH_VERIFICATION_CODE_INVALID",
                detail="Invalid verification code",
            ) from exc

    async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
        """Exchange ``request.refresh_token`` for a new session.

        Raises:
            ApiProblemError: 503 when the upstream looks unavailable, otherwise
                401 ``AUTH_REFRESH_TOKEN_INVALID``.
        """
        client = self._get_client()
        try:
            response = await asyncio.to_thread(
                client.auth.refresh_session,
                request.refresh_token,
            )
            return _map_auth_response(
                response,
                "Invalid refresh token",
                "AUTH_REFRESH_TOKEN_INVALID",
            )
        except AuthError as exc:
            raise self._translate_auth_error(
                exc,
                log_message="Refresh failed",
                status_code=401,
                code="AUTH_REFRESH_TOKEN_INVALID",
                detail="Invalid refresh token",
            ) from exc

    async def delete_session(self, refresh_token: str | None) -> None:
        """Sign out the session that owns ``refresh_token``.

        The refresh token is first exchanged for a fresh session, that session
        is set on the client, and ``sign_out`` is then called.

        Raises:
            ApiProblemError: 401 ``AUTH_REFRESH_TOKEN_MISSING`` when no token is
                given, 401 ``AUTH_REFRESH_TOKEN_INVALID`` when the refresh
                yields no session or fails, 503 when the upstream looks
                unavailable.
        """
        if not refresh_token:
            raise _auth_error(
                status_code=401,
                code="AUTH_REFRESH_TOKEN_MISSING",
                detail="Missing refresh token",
            )
        client = self._get_client()
        try:
            response = await asyncio.to_thread(
                client.auth.refresh_session,
                refresh_token,
            )
            session = getattr(response, "session", None)
            if session is None:
                raise _auth_error(
                    status_code=401,
                    code="AUTH_REFRESH_TOKEN_INVALID",
                    detail="Invalid refresh token",
                )
            # NOTE(review): set_session mutates auth state on the client
            # returned by _get_client(); if that client is shared between
            # requests, concurrent logouts could interfere - confirm.
            await asyncio.to_thread(
                client.auth.set_session,
                str(session.access_token),
                str(session.refresh_token),
            )
            await asyncio.to_thread(client.auth.sign_out)
        except AuthError as exc:
            raise self._translate_auth_error(
                exc,
                log_message="Logout failed",
                status_code=401,
                code="AUTH_REFRESH_TOKEN_INVALID",
                detail="Invalid refresh token",
            ) from exc

    async def get_user_by_email(self, email: str) -> UserByEmailResponse:
        """Look up a user by e-mail address in the cached user index.

        Raises:
            ApiProblemError: 404 ``AUTH_USER_NOT_FOUND`` when ``email`` does not
                normalize to a valid address, when no cached user has that
                address, or when the cached user's own e-mail is not valid.
        """
        normalized_email = _normalize_email(email)
        if not normalized_email:
            raise self._user_not_found()
        await self._refresh_user_lookup_cache_if_needed()
        user = self._users_by_email.get(normalized_email)
        if user is None:
            raise self._user_not_found()
        user_email = _normalize_email(getattr(user, "email", ""))
        if not user_email:
            raise self._user_not_found()
        confirmed_at = getattr(user, "email_confirmed_at", None)
        return UserByEmailResponse(
            id=str(getattr(user, "id", "")),
            email=user_email,
            created_at=str(getattr(user, "created_at", "")),
            email_confirmed_at=str(confirmed_at) if confirmed_at else None,
        )

    async def get_user_by_id(self, user_id: str) -> UserByIdResponse:
        """Look up a single user by id in the cached user index.

        Raises:
            ApiProblemError: 404 ``AUTH_USER_NOT_FOUND`` when the id is unknown.
        """
        users = await self.get_users_by_ids([user_id])
        # get_users_by_ids keys its result by the stripped id, so the lookup
        # must use the same normalization or padded ids would always miss.
        resolved = users.get(user_id.strip())
        if resolved is None:
            raise self._user_not_found()
        return resolved

    async def get_users_by_ids(
        self, user_ids: list[str]
    ) -> dict[str, UserByIdResponse]:
        """Resolve several user ids at once from the cached user index.

        Returns:
            A mapping from each whitespace-stripped id that was found to its
            ``UserByIdResponse``. Blank and unknown ids are omitted.
        """
        await self._refresh_user_lookup_cache_if_needed()
        resolved: dict[str, UserByIdResponse] = {}
        for raw_user_id in user_ids:
            normalized_user_id = raw_user_id.strip()
            if not normalized_user_id:
                continue
            user = self._users_by_id.get(normalized_user_id)
            if user is None:
                continue
            # Cached entries may wrap the record in a ``user`` attribute;
            # fall back to the entry itself when they do not.
            user_attrs = getattr(user, "user", user)
            confirmed_at = getattr(user_attrs, "email_confirmed_at", None)
            resolved[normalized_user_id] = UserByIdResponse(
                id=str(getattr(user_attrs, "id", "")),
                email=getattr(user_attrs, "email", None),
                created_at=str(getattr(user_attrs, "created_at", "")),
                email_confirmed_at=str(confirmed_at) if confirmed_at else None,
            )
        return resolved

    async def search_user_ids_by_email(self, query: str, limit: int = 20) -> list[str]:
        """Return the id of the user whose e-mail equals ``query``, if any.

        Only an exact match on the normalized address is performed, so the
        result holds at most one id. ``limit`` is accepted but not used by
        this implementation.
        """
        normalized_query = _normalize_email(query)
        if not normalized_query:
            return []
        await self._refresh_user_lookup_cache_if_needed()
        matched_user = self._users_by_email.get(normalized_query)
        if matched_user is None:
            return []
        user_id = str(getattr(matched_user, "id", ""))
        return [user_id] if user_id else []

    async def _refresh_user_lookup_cache_if_needed(self) -> None:
        """Rebuild the id and e-mail indexes when the cache TTL has elapsed.

        The full user list is fetched with the admin client in a worker
        thread. Both indexes are built locally and swapped in only after the
        fetch succeeds, so a failed fetch leaves the previous data in place.
        """
        now = time.monotonic()
        if now < self._user_lookup_cache_expires_at:
            return
        admin_client = self._get_admin_client()
        # NOTE(review): errors raised by the listing propagate unchanged (no
        # ApiProblemError translation here) - confirm callers handle that.
        users = await asyncio.to_thread(_list_auth_users, admin_client)
        users_by_email: dict[str, Any] = {}
        users_by_id: dict[str, Any] = {}
        for candidate in users:
            candidate_id = str(getattr(candidate, "id", "")).strip()
            if candidate_id:
                users_by_id[candidate_id] = candidate
            candidate_email = _normalize_email(getattr(candidate, "email", ""))
            if candidate_email:
                users_by_email[candidate_email] = candidate
        self._users_by_id = users_by_id
        self._users_by_email = users_by_email
        self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
    """Heuristically decide whether ``exc`` reflects an upstream outage.

    An integer ``status`` (or ``status_code`` when ``status`` is ``None``) in
    the 5xx range is conclusive. Otherwise the lower-cased ``code`` attribute
    and the lower-cased exception text are searched for known outage markers.
    """
    status = getattr(exc, "status", None)
    if status is None:
        status = getattr(exc, "status_code", None)
    if isinstance(status, int) and status in range(500, 600):
        return True

    error_code = getattr(exc, "code", None)
    haystacks = (
        "" if error_code is None else str(error_code).lower(),
        str(exc).lower(),
    )
    markers = (
        "request_timeout",
        "timed out",
        "timeout",
        "gateway timeout",
        "bad_gateway",
        "service_unavailable",
        "internal_server_error",
        "unexpected_failure",
        "upstream",
        "500",
        "502",
        "503",
        "504",
        "5xx",
    )
    for marker in markers:
        for haystack in haystacks:
            if marker in haystack:
                return True
    return False
def _map_auth_response(
    response: object, failure_message: str, failure_code: str
) -> SessionResponse:
    """Convert a Supabase auth response into a ``SessionResponse``.

    Every rejection is reported as a 401 ``ApiProblemError`` carrying
    ``failure_code`` and ``failure_message``. A response is rejected when it
    lacks a session or user, when the user's e-mail does not normalize to a
    valid address, or when ``AuthUser`` validation fails.
    """

    def _rejection() -> ApiProblemError:
        return _auth_error(
            status_code=401,
            code=failure_code,
            detail=failure_message,
        )

    auth_session = getattr(response, "session", None)
    raw_user = getattr(response, "user", None)
    if auth_session is None or raw_user is None:
        raise _rejection()

    normalized = _normalize_email(getattr(raw_user, "email", None))
    if not normalized:
        raise _rejection()

    try:
        mapped_user = AuthUser(id=str(raw_user.id), email=str(normalized))
    except ValidationError as exc:
        logger.warning(
            "Auth response returned invalid email format",
            error_type=type(exc).__name__,
        )
        raise _rejection() from exc

    access_token = str(auth_session.access_token)
    refresh_token = str(auth_session.refresh_token)
    # A missing or zero expires_in is reported as 0.
    expires_in = int(auth_session.expires_in or 0)
    token_type = str(auth_session.token_type)
    return SessionResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=expires_in,
        token_type=token_type,
        user=mapped_user,
    )
def _list_auth_users(
    client: Any, *, per_page: int = 100, max_pages: int = 100
) -> list[Any]:
    """Collect users from ``client.auth.admin.list_users`` page by page.

    Pages are requested starting at 1 until a page comes back with fewer than
    ``per_page`` entries or ``max_pages`` pages have been read. When the page
    cap is reached the result is silently truncated to
    ``per_page * max_pages`` users.

    Args:
        client: Object exposing ``auth.admin.list_users(page=..., per_page=...)``.
        per_page: Page size requested from the API and used to detect the
            last page.
        max_pages: Upper bound on the number of pages fetched.

    Returns:
        All users gathered, in the order the pages returned them.

    Raises:
        ValueError: If ``per_page`` is smaller than 1, because a short page
            could then never be detected.
    """
    if per_page < 1:
        raise ValueError("per_page must be at least 1")

    users: list[Any] = []
    for page in range(1, max_pages + 1):
        response = client.auth.admin.list_users(page=page, per_page=per_page)
        if isinstance(response, list):
            batch = response
        else:
            # A response object may carry the page in ``users``; treat a
            # missing or None attribute as an empty page instead of crashing.
            batch = list(getattr(response, "users", None) or [])
        users.extend(batch)
        if len(batch) < per_page:
            break
    return users
_EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")


def _normalize_email(raw_email: object) -> str | None:
    """Return ``raw_email`` trimmed and lower-cased, or ``None`` if unusable.

    ``None`` is returned for non-string input and for strings that, after
    trimming and lower-casing, do not fully match ``_EMAIL_PATTERN``.
    """
    if isinstance(raw_email, str):
        candidate = raw_email.strip().lower()
        if _EMAIL_PATTERN.fullmatch(candidate) is not None:
            return candidate
    return None