2026-04-02 16:36:35 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import asyncio
|
2026-04-02 18:39:35 +08:00
|
|
|
import re
|
2026-04-02 16:36:35 +08:00
|
|
|
import time
|
|
|
|
|
from typing import Any, cast
|
|
|
|
|
|
|
|
|
|
from pydantic import ValidationError
|
|
|
|
|
|
|
|
|
|
from supabase import AuthError
|
|
|
|
|
|
2026-04-02 18:39:35 +08:00
|
|
|
from core.config.settings import config
|
2026-04-02 16:36:35 +08:00
|
|
|
from core.http.errors import ApiProblemError
|
|
|
|
|
from core.logging import get_logger
|
|
|
|
|
from services.base.supabase import supabase_service
|
2026-04-02 18:39:35 +08:00
|
|
|
from v1.auth.dev_email_session import create_dev_email_session
|
2026-04-02 16:36:35 +08:00
|
|
|
from v1.auth.schemas import (
|
|
|
|
|
AuthUser,
|
2026-04-02 18:39:35 +08:00
|
|
|
EmailSessionCreateRequest,
|
2026-04-02 16:36:35 +08:00
|
|
|
OtpSendRequest,
|
|
|
|
|
SessionRefreshRequest,
|
|
|
|
|
SessionResponse,
|
|
|
|
|
UserByIdResponse,
|
2026-04-02 18:39:35 +08:00
|
|
|
UserByEmailResponse,
|
2026-04-02 16:36:35 +08:00
|
|
|
)
|
|
|
|
|
from v1.auth.service import AuthServiceGateway
|
|
|
|
|
|
|
|
|
|
# Module logger; calls in this module pass extra context as keyword fields
# (e.g. ``error_type=...``).
logger = get_logger("v1.auth.gateway")

# Detail text used for the 503 AUTH_SERVICE_UNAVAILABLE errors raised when
# Supabase Auth is judged to be unavailable.
AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _auth_error(*, status_code: int, code: str, detail: str) -> ApiProblemError:
    """Build the ``ApiProblemError`` used for every auth failure in this module.

    Args:
        status_code: HTTP status to report.
        code: Machine-readable error code (e.g. ``AUTH_USER_NOT_FOUND``).
        detail: Human-readable description.

    Returns:
        A new, un-raised ``ApiProblemError``; callers decide when to raise it.
    """
    problem = ApiProblemError(status_code=status_code, code=code, detail=detail)
    return problem
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SupabaseAuthGateway(AuthServiceGateway):
    """Auth gateway backed by Supabase Auth.

    Blocking Supabase client calls run on worker threads via
    ``asyncio.to_thread``. ``AuthError`` failures are translated into
    ``ApiProblemError``: failures classified by
    ``_is_auth_upstream_unavailable`` become 503 ``AUTH_SERVICE_UNAVAILABLE``,
    anything else becomes an operation-specific error.

    Email/id user lookups are served from an in-memory snapshot of all auth
    users, rebuilt through the admin API once the snapshot's TTL has elapsed.
    """

    def __init__(self) -> None:
        # Lifetime of the user lookup snapshot, in seconds.
        self._user_lookup_cache_ttl_seconds: int = 60
        # time.monotonic() deadline after which the snapshot is rebuilt;
        # 0.0 forces a rebuild on first use.
        self._user_lookup_cache_expires_at: float = 0.0
        # Snapshot indexes: normalized email -> user, stripped id -> user.
        self._users_by_email: dict[str, Any] = {}
        self._users_by_id: dict[str, Any] = {}
        # Serializes snapshot rebuilds. Created lazily inside a coroutine so the
        # gateway can be constructed before an event loop is running.
        self._user_lookup_lock: asyncio.Lock | None = None

    def _get_client(self) -> Any:
        """Return the regular Supabase client from ``supabase_service``."""
        return supabase_service.get_client()

    def _get_admin_client(self) -> Any:
        """Return the admin Supabase client from ``supabase_service``."""
        return supabase_service.get_admin_client()

    @staticmethod
    def _translate_auth_error(
        exc: AuthError,
        *,
        log_message: str,
        status_code: int,
        code: str,
        detail: str,
    ) -> ApiProblemError:
        """Log *exc* and build the ``ApiProblemError`` the caller should raise.

        Upstream outages map to 503 ``AUTH_SERVICE_UNAVAILABLE``; any other
        failure maps to the supplied ``status_code``/``code``/``detail``.
        Callers raise the result ``from exc`` to keep the original chained.
        """
        logger.warning(log_message, error_type=type(exc).__name__)
        if _is_auth_upstream_unavailable(exc):
            return _auth_error(
                status_code=503,
                code="AUTH_SERVICE_UNAVAILABLE",
                detail=AUTH_UNAVAILABLE_DETAIL,
            )
        return _auth_error(status_code=status_code, code=code, detail=detail)

    @staticmethod
    def _user_not_found() -> ApiProblemError:
        """Build the 404 ``AUTH_USER_NOT_FOUND`` error shared by user lookups."""
        return _auth_error(
            status_code=404,
            code="AUTH_USER_NOT_FOUND",
            detail="User not found",
        )

    @staticmethod
    def _email_confirmed_at(user: Any) -> str | None:
        """Return ``user.email_confirmed_at`` as a string, or None if unset/falsy."""
        confirmed_at = getattr(user, "email_confirmed_at", None)
        return str(confirmed_at) if confirmed_at else None

    async def send_otp(self, request: OtpSendRequest) -> None:
        """Request a one-time code email for ``request.email``.

        Calls Supabase ``sign_in_with_otp`` with ``should_create_user=True``.

        Raises:
            ApiProblemError: 503 when Supabase Auth is unavailable; 429
                ``AUTH_TOO_MANY_REQUESTS`` for any other ``AuthError``.
        """
        client = self._get_client()
        payload: dict[str, Any] = {
            "email": request.email,
            "options": {"should_create_user": True},
        }
        try:
            sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
            await asyncio.to_thread(sign_in_with_otp, payload)
        except AuthError as exc:
            # NOTE(review): every non-outage AuthError is reported as 429;
            # confirm rate limiting is the only other failure worth surfacing.
            raise self._translate_auth_error(
                exc,
                log_message="Send otp failed",
                status_code=429,
                code="AUTH_TOO_MANY_REQUESTS",
                detail="Too many requests",
            ) from exc

    async def create_email_session(
        self, request: EmailSessionCreateRequest
    ) -> SessionResponse:
        """Exchange an emailed one-time code for a session.

        When ``config.runtime.environment`` is ``"dev"`` the work is delegated
        to ``create_dev_email_session``; otherwise the code is checked with
        Supabase ``verify_otp`` (``type="email"``).

        Raises:
            ApiProblemError: 503 when Supabase Auth is unavailable; 401
                ``AUTH_VERIFICATION_CODE_INVALID`` when the code is rejected or
                the response lacks a usable session/user.
        """
        if config.runtime.environment == "dev":
            return await create_dev_email_session(
                request=request,
                client=self._get_client(),
                admin_client=self._get_admin_client(),
                auth_unavailable_detail=AUTH_UNAVAILABLE_DETAIL,
                is_auth_upstream_unavailable=_is_auth_upstream_unavailable,
                map_auth_response=_map_auth_response,
            )

        client = self._get_client()
        payload: dict[str, Any] = {
            "type": "email",
            "email": request.email,
            "token": request.token,
        }
        try:
            verify_otp = cast(Any, client.auth.verify_otp)
            response = await asyncio.to_thread(verify_otp, payload)
        except AuthError as exc:
            raise self._translate_auth_error(
                exc,
                log_message="Create email session failed",
                status_code=401,
                code="AUTH_VERIFICATION_CODE_INVALID",
                detail="Invalid verification code",
            ) from exc
        return _map_auth_response(
            response,
            "Invalid verification code",
            "AUTH_VERIFICATION_CODE_INVALID",
        )

    async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
        """Exchange ``request.refresh_token`` for a new session.

        Raises:
            ApiProblemError: 503 when Supabase Auth is unavailable; 401
                ``AUTH_REFRESH_TOKEN_INVALID`` when the token is rejected or the
                response lacks a usable session/user.
        """
        client = self._get_client()
        try:
            response = await asyncio.to_thread(
                client.auth.refresh_session,
                request.refresh_token,
            )
        except AuthError as exc:
            raise self._translate_auth_error(
                exc,
                log_message="Refresh failed",
                status_code=401,
                code="AUTH_REFRESH_TOKEN_INVALID",
                detail="Invalid refresh token",
            ) from exc
        return _map_auth_response(
            response,
            "Invalid refresh token",
            "AUTH_REFRESH_TOKEN_INVALID",
        )

    async def delete_session(self, refresh_token: str | None) -> None:
        """Sign out the session identified by *refresh_token*.

        The refresh token is exchanged for a fresh session, which is installed
        on the client with ``set_session`` before ``sign_out`` is called.

        Raises:
            ApiProblemError: 401 ``AUTH_REFRESH_TOKEN_MISSING`` when no token is
                given; 401 ``AUTH_REFRESH_TOKEN_INVALID`` when Supabase rejects
                it or returns no session; 503 when Supabase Auth is unavailable.
        """
        if not refresh_token:
            raise _auth_error(
                status_code=401,
                code="AUTH_REFRESH_TOKEN_MISSING",
                detail="Missing refresh token",
            )

        client = self._get_client()
        try:
            response = await asyncio.to_thread(
                client.auth.refresh_session,
                refresh_token,
            )
            session = getattr(response, "session", None)
            if session is None:
                raise _auth_error(
                    status_code=401,
                    code="AUTH_REFRESH_TOKEN_INVALID",
                    detail="Invalid refresh token",
                )
            # NOTE(review): set_session/sign_out change the session held by
            # ``client``; if get_client() returns a shared instance, concurrent
            # logouts could interleave between these calls -- confirm.
            await asyncio.to_thread(
                client.auth.set_session,
                str(session.access_token),
                str(session.refresh_token),
            )
            await asyncio.to_thread(client.auth.sign_out)
        except AuthError as exc:
            raise self._translate_auth_error(
                exc,
                log_message="Logout failed",
                status_code=401,
                code="AUTH_REFRESH_TOKEN_INVALID",
                detail="Invalid refresh token",
            ) from exc

    async def get_user_by_email(self, email: str) -> UserByEmailResponse:
        """Look up a user by exact (normalized) email in the cached snapshot.

        Raises:
            ApiProblemError: 404 ``AUTH_USER_NOT_FOUND`` when *email* does not
                normalize or no cached user has that email; 503 when the
                snapshot has to be rebuilt and the admin listing fails.
        """
        normalized_email = _normalize_email(email)
        if not normalized_email:
            raise self._user_not_found()

        await self._refresh_user_lookup_cache_if_needed()

        user = self._users_by_email.get(normalized_email)
        if user is None:
            raise self._user_not_found()

        user_email = _normalize_email(getattr(user, "email", ""))
        if not user_email:
            raise self._user_not_found()

        return UserByEmailResponse(
            id=str(getattr(user, "id", "")),
            email=user_email,
            created_at=str(getattr(user, "created_at", "")),
            email_confirmed_at=self._email_confirmed_at(user),
        )

    async def get_user_by_id(self, user_id: str) -> UserByIdResponse:
        """Look up a single user by id in the cached snapshot.

        Raises:
            ApiProblemError: 404 ``AUTH_USER_NOT_FOUND`` when no cached user has
                that id; 503 when the snapshot rebuild fails.
        """
        users = await self.get_users_by_ids([user_id])
        # get_users_by_ids keys its result by the *stripped* id, so look it up
        # the same way; otherwise a padded id would 404 for an existing user.
        resolved = users.get(user_id.strip())
        if resolved is None:
            raise self._user_not_found()
        return resolved

    async def get_users_by_ids(
        self, user_ids: list[str]
    ) -> dict[str, UserByIdResponse]:
        """Resolve many user ids against the cached snapshot.

        Ids are whitespace-stripped; blank and unknown ids are skipped. The
        result maps each stripped id that was found to its user record. When
        no non-blank ids are given, the snapshot is not refreshed at all.
        """
        wanted_ids = [raw_user_id.strip() for raw_user_id in user_ids]
        wanted_ids = [user_id for user_id in wanted_ids if user_id]
        if not wanted_ids:
            return {}

        await self._refresh_user_lookup_cache_if_needed()

        resolved: dict[str, UserByIdResponse] = {}
        for user_id in wanted_ids:
            user = self._users_by_id.get(user_id)
            if user is None:
                continue
            # Some payloads wrap the user object under ``.user``; unwrap if so.
            user_attrs = getattr(user, "user", user)
            resolved[user_id] = UserByIdResponse(
                id=str(getattr(user_attrs, "id", "")),
                email=getattr(user_attrs, "email", None),
                created_at=str(getattr(user_attrs, "created_at", "")),
                email_confirmed_at=self._email_confirmed_at(user_attrs),
            )
        return resolved

    async def search_user_ids_by_email(self, query: str, limit: int = 20) -> list[str]:
        """Return the id of the user whose email exactly matches *query*.

        Matching is by normalized email, so at most one id is returned;
        ``limit`` is currently unused. Returns ``[]`` when *query* does not
        normalize or nothing matches.
        """
        normalized_query = _normalize_email(query)
        if not normalized_query:
            return []

        await self._refresh_user_lookup_cache_if_needed()

        matched_user = self._users_by_email.get(normalized_query)
        if matched_user is None:
            return []
        user_id = str(getattr(matched_user, "id", ""))
        return [user_id] if user_id else []

    async def _refresh_user_lookup_cache_if_needed(self) -> None:
        """Rebuild the email/id user indexes once the snapshot has expired.

        Rebuilds are serialized with an ``asyncio.Lock`` and re-checked after
        acquiring it, so concurrent callers arriving after expiry share one
        admin listing instead of each paging through every user.

        Raises:
            ApiProblemError: 503 ``AUTH_SERVICE_UNAVAILABLE`` when the admin
                listing fails with an ``AuthError``.
        """
        if time.monotonic() < self._user_lookup_cache_expires_at:
            return

        if self._user_lookup_lock is None:
            # No await between this check and the assignment, so coroutines on
            # the same loop cannot race to create two locks.
            self._user_lookup_lock = asyncio.Lock()

        async with self._user_lookup_lock:
            now = time.monotonic()
            if now < self._user_lookup_cache_expires_at:
                # Another coroutine rebuilt the snapshot while we waited.
                return

            admin_client = self._get_admin_client()
            try:
                users = await asyncio.to_thread(_list_auth_users, admin_client)
            except AuthError as exc:
                raise self._translate_auth_error(
                    exc,
                    log_message="List auth users failed",
                    status_code=503,
                    code="AUTH_SERVICE_UNAVAILABLE",
                    detail=AUTH_UNAVAILABLE_DETAIL,
                ) from exc

            users_by_email: dict[str, Any] = {}
            users_by_id: dict[str, Any] = {}
            for candidate in users:
                candidate_id = str(getattr(candidate, "id", "")).strip()
                if candidate_id:
                    users_by_id[candidate_id] = candidate
                candidate_email = _normalize_email(getattr(candidate, "email", ""))
                if candidate_email:
                    users_by_email[candidate_email] = candidate

            # Swap in complete indexes at once so readers never see a partial build.
            self._users_by_id = users_by_id
            self._users_by_email = users_by_email
            self._user_lookup_cache_expires_at = (
                now + self._user_lookup_cache_ttl_seconds
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Substrings that, found in the lowercased error code or message, indicate that
# Supabase Auth (or something in front of it) failed rather than the request.
_UPSTREAM_UNAVAILABLE_MARKERS = (
    "request_timeout",
    "timed out",
    "timeout",
    "gateway timeout",
    "bad_gateway",
    "service_unavailable",
    "internal_server_error",
    "unexpected_failure",
    "upstream",
    "5xx",
)

# 5xx status numbers mentioned in the code/message. The digit look-arounds stop
# unrelated numbers such as "1500" or "5030" from counting as a 500/503.
_UPSTREAM_STATUS_NUMBER = re.compile(r"(?<!\d)(?:500|502|503|504)(?!\d)")


def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
    """Return True when *exc* looks like an outage of the auth upstream.

    Checks, in order:

    1. a 5xx ``status`` (falling back to ``status_code``), as an int or a
       string of digits;
    2. the lowercased ``code`` and message for any of
       ``_UPSTREAM_UNAVAILABLE_MARKERS``;
    3. the same texts for a standalone 500/502/503/504.

    Anything else is treated as a client-side auth failure.
    """
    raw_status = getattr(exc, "status", None)
    if raw_status is None:
        raw_status = getattr(exc, "status_code", None)
    if isinstance(raw_status, str) and raw_status.strip().isdigit():
        raw_status = int(raw_status.strip())
    if isinstance(raw_status, int) and 500 <= raw_status < 600:
        return True

    raw_code = getattr(exc, "code", None)
    code = str(raw_code).lower() if raw_code is not None else ""
    message = str(exc).lower()

    for text in (code, message):
        if any(marker in text for marker in _UPSTREAM_UNAVAILABLE_MARKERS):
            return True
        if _UPSTREAM_STATUS_NUMBER.search(text):
            return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _map_auth_response(
    response: object, failure_message: str, failure_code: str
) -> SessionResponse:
    """Convert a Supabase auth response into a ``SessionResponse``.

    Args:
        response: Object returned by a Supabase auth call; its ``session`` and
            ``user`` attributes are read.
        failure_message: ``detail`` of the 401 raised when mapping fails.
        failure_code: ``code`` of the 401 raised when mapping fails.

    Raises:
        ApiProblemError: 401 with the supplied code/message when the session or
            user is missing, the user's email does not normalize, or
            ``AuthUser`` validation fails.
    """

    def rejection() -> ApiProblemError:
        # Every failure mode of this mapping reports the same caller-chosen 401.
        return _auth_error(status_code=401, code=failure_code, detail=failure_message)

    session = getattr(response, "session", None)
    user = getattr(response, "user", None)
    if session is None or user is None:
        raise rejection()

    email = _normalize_email(getattr(user, "email", None))
    if not email:
        raise rejection()

    try:
        auth_user = AuthUser(id=str(user.id), email=str(email))
    except ValidationError as exc:
        logger.warning(
            "Auth response returned invalid email format",
            error_type=type(exc).__name__,
        )
        raise rejection() from exc

    # A missing expires_in is reported as 0 rather than failing the mapping.
    return SessionResponse(
        access_token=str(session.access_token),
        refresh_token=str(session.refresh_token),
        expires_in=int(session.expires_in or 0),
        token_type=str(session.token_type),
        user=auth_user,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _list_auth_users(
    client: Any, *, per_page: int = 100, max_pages: int = 100
) -> list[Any]:
    """Page through ``client.auth.admin.list_users`` and return every user seen.

    Paging stops at the first page holding fewer than *per_page* users. At
    most *max_pages* pages are requested. If that cap is reached while pages
    are still full, the listing may be incomplete, and a warning is logged
    rather than the truncation going unnoticed.

    Args:
        client: Supabase admin client (anything exposing
            ``auth.admin.list_users(page=..., per_page=...)``).
        per_page: Page size requested from the admin API.
        max_pages: Upper bound on the number of pages requested.

    Returns:
        Users in page order. Each page response may be a plain list or an
        object with a ``users`` attribute.

    Raises:
        ValueError: If *per_page* or *max_pages* is less than 1.
    """
    if per_page < 1 or max_pages < 1:
        raise ValueError("per_page and max_pages must be positive")

    users: list[Any] = []
    for page in range(1, max_pages + 1):
        response = client.auth.admin.list_users(page=page, per_page=per_page)
        batch = (
            list(response)
            if isinstance(response, list)
            else list(getattr(response, "users", []))
        )
        users.extend(batch)
        if len(batch) < per_page:
            return users

    # Every requested page was full, so more users may exist beyond the cap;
    # lookups built from this listing would not see them.
    logger.warning(
        "Auth user listing truncated",
        max_pages=max_pages,
        per_page=per_page,
        user_count=len(users),
    )
    return users
|
|
|
|
|
|
|
|
|
|
|
2026-04-02 18:39:35 +08:00
|
|
|
# ASCII address shape check: local part, "@", a domain containing at least one
# dot, and a top-level label of two or more letters.
_EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")


def _normalize_email(raw_email: object) -> str | None:
    """Return *raw_email* stripped and lower-cased, or ``None`` if unusable.

    Non-string input, and strings that do not match ``_EMAIL_PATTERN`` after
    normalization, both yield ``None``, so a falsy result means "no valid
    email".
    """
    candidate = raw_email.strip().lower() if isinstance(raw_email, str) else None
    if candidate is None or _EMAIL_PATTERN.fullmatch(candidate) is None:
        return None
    return candidate
|