chore: 迁移到 social-app 架构,集成 Supabase 和 taskiq worker

This commit is contained in:
qzl
2026-04-02 16:36:35 +08:00
parent 695adb7d6f
commit 92cdfd9fca
132 changed files with 5802 additions and 759 deletions
+430
View File
@@ -0,0 +1,430 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, cast
from pydantic import ValidationError
from supabase import AuthError
from core.http.errors import ApiProblemError
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByIdResponse,
UserByPhoneResponse,
)
from v1.auth.service import AuthServiceGateway
logger = get_logger("v1.auth.gateway")
AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable"
def _auth_error(
*,
status_code: int,
code: str,
detail: str,
) -> ApiProblemError:
return ApiProblemError(status_code=status_code, code=code, detail=detail)
class SupabaseAuthGateway(AuthServiceGateway):
def __init__(self) -> None:
self._user_lookup_cache_ttl_seconds: int = 60
self._user_lookup_cache_expires_at: float = 0.0
self._users_by_phone: dict[str, Any] = {}
self._users_by_id: dict[str, Any] = {}
def _get_client(self) -> Any:
return supabase_service.get_client()
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def send_otp(self, request: OtpSendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {
"phone": request.phone,
"options": {"should_create_user": True},
}
try:
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
await asyncio.to_thread(sign_in_with_otp, payload)
except AuthError as exc:
logger.warning("Send otp failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=429,
code="AUTH_TOO_MANY_REQUESTS",
detail="Too many requests",
) from exc
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {
"type": "sms",
"phone": request.phone,
"token": request.token,
}
try:
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(
response,
"Invalid verification code",
"AUTH_VERIFICATION_CODE_INVALID",
)
except AuthError as exc:
logger.warning("Create phone session failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_VERIFICATION_CODE_INVALID",
detail="Invalid verification code",
) from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
response = await asyncio.to_thread(
client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(
response,
"Invalid refresh token",
"AUTH_REFRESH_TOKEN_INVALID",
)
except AuthError as exc:
logger.warning("Refresh failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
) from exc
async def delete_session(self, refresh_token: str | None) -> None:
if not refresh_token:
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_MISSING",
detail="Missing refresh token",
)
client = self._get_client()
try:
response = await asyncio.to_thread(
client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
)
await asyncio.to_thread(
client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
) from exc
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
normalized_phone = _normalize_phone(phone)
if not normalized_phone:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
await self._refresh_user_lookup_cache_if_needed()
user = self._users_by_phone.get(normalized_phone)
if user is None:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
user_phone = _normalize_phone(getattr(user, "phone", ""))
if not user_phone:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
return UserByPhoneResponse(
id=str(getattr(user, "id", "")),
phone=user_phone,
created_at=str(getattr(user, "created_at", "")),
phone_confirmed_at=(
str(getattr(user, "phone_confirmed_at", ""))
if getattr(user, "phone_confirmed_at", None)
else None
),
)
async def get_user_by_id(self, user_id: str) -> UserByIdResponse:
users = await self.get_users_by_ids([user_id])
resolved = users.get(user_id)
if resolved is None:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
return resolved
async def get_users_by_ids(
self, user_ids: list[str]
) -> dict[str, UserByIdResponse]:
await self._refresh_user_lookup_cache_if_needed()
resolved: dict[str, UserByIdResponse] = {}
for raw_user_id in user_ids:
normalized_user_id = raw_user_id.strip()
if not normalized_user_id:
continue
user = self._users_by_id.get(normalized_user_id)
if user is None:
continue
user_attrs = getattr(user, "user", user)
resolved[normalized_user_id] = UserByIdResponse(
id=str(getattr(user_attrs, "id", "")),
phone=getattr(user_attrs, "phone", None),
created_at=str(getattr(user_attrs, "created_at", "")),
phone_confirmed_at=(
str(getattr(user_attrs, "phone_confirmed_at", ""))
if getattr(user_attrs, "phone_confirmed_at", None)
else None
),
)
return resolved
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
normalized_query = _normalize_phone_search_query(query)
if not normalized_query:
return []
await self._refresh_user_lookup_cache_if_needed()
if normalized_query.startswith("+"):
matched_user = self._users_by_phone.get(normalized_query)
if matched_user is None:
return []
user_id = str(getattr(matched_user, "id", ""))
return [user_id] if user_id else []
digits = _digits_only(normalized_query)
if not digits:
return []
matched_records: list[tuple[str, str]] = []
for cached_phone, candidate in self._users_by_phone.items():
candidate_digits = _digits_only(cached_phone)
if not candidate_digits.endswith(digits):
continue
user_id = str(getattr(candidate, "id", ""))
if user_id:
matched_records.append((cached_phone, user_id))
if not matched_records:
return []
unique_ids: list[str] = []
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
if user_id in unique_ids:
continue
unique_ids.append(user_id)
if len(unique_ids) >= max(1, limit):
break
return unique_ids
async def _refresh_user_lookup_cache_if_needed(self) -> None:
now = time.monotonic()
if now < self._user_lookup_cache_expires_at:
return
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_phone: dict[str, Any] = {}
users_by_id: dict[str, Any] = {}
for candidate in users:
candidate_id = str(getattr(candidate, "id", "")).strip()
if candidate_id:
users_by_id[candidate_id] = candidate
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
if candidate_phone:
users_by_phone[candidate_phone] = candidate
self._users_by_id = users_by_id
self._users_by_phone = users_by_phone
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
raw_status = getattr(exc, "status", None)
if raw_status is None:
raw_status = getattr(exc, "status_code", None)
if isinstance(raw_status, int) and 500 <= raw_status < 600:
return True
raw_code = getattr(exc, "code", None)
code = str(raw_code).lower() if raw_code is not None else ""
message = str(exc).lower()
indicators = (
"request_timeout",
"timed out",
"timeout",
"gateway timeout",
"bad_gateway",
"service_unavailable",
"internal_server_error",
"unexpected_failure",
"upstream",
"500",
"502",
"503",
"504",
"5xx",
)
return any(token in code or token in message for token in indicators)
def _map_auth_response(
response: object, failure_message: str, failure_code: str
) -> SessionResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
)
phone = _normalize_phone(getattr(user, "phone", None))
if not phone:
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
)
try:
auth_user = AuthUser(id=str(user.id), phone=str(phone))
except ValidationError as exc:
logger.warning(
"Auth response returned invalid phone format",
error_type=type(exc).__name__,
)
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
) from exc
return SessionResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
def _list_auth_users(client: Any) -> list[Any]:
users: list[Any] = []
page = 1
max_pages = 100
while page <= max_pages:
response = client.auth.admin.list_users(page=page, per_page=100)
batch = (
list(response)
if isinstance(response, list)
else list(getattr(response, "users", []))
)
users.extend(batch)
if len(batch) < 100:
break
page += 1
return users
def _sanitize_phone_token(raw: object) -> str:
token = str(raw).strip()
for separator in (" ", "-", "(", ")"):
token = token.replace(separator, "")
return token
def _normalize_phone(raw_phone: object) -> str | None:
phone = _sanitize_phone_token(raw_phone)
if not phone:
return None
if phone.startswith("00") and len(phone) > 2:
return f"+{phone[2:]}"
if phone.startswith("+"):
return phone
if phone.isdigit():
return f"+{phone}"
return None
def _normalize_phone_search_query(raw_query: str) -> str | None:
query = _sanitize_phone_token(raw_query)
if not query:
return None
if query.startswith("00") and len(query) > 2:
return f"+{query[2:]}"
if query.startswith("+"):
return query
if query.isdigit():
return query
return None
def _digits_only(value: str) -> str:
return "".join(ch for ch in value if ch.isdigit())