feat: 切换邮箱认证并重构前后端启动与门禁
This commit is contained in:
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.config.settings import config
|
||||
from core.http.errors import ApiProblemError
|
||||
from core.http.response import build_problem_details
|
||||
from core.logging import configure_logging, get_logger, log_service_banner
|
||||
from services.base import close_registered_services, initialize_registered_services
|
||||
from v1.router import router as v1_router
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
configure_logging(config)
|
||||
log_service_banner(
|
||||
service_name=config.runtime.service_name,
|
||||
environment=config.runtime.environment,
|
||||
)
|
||||
|
||||
logger = get_logger("api.app")
|
||||
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
|
||||
initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER)
|
||||
if not initialized:
|
||||
logger.error("Service initialization failed, aborting startup")
|
||||
raise RuntimeError("Service initialization failed")
|
||||
|
||||
logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
closed = await close_registered_services(services)
|
||||
if not closed:
|
||||
logger.warning("Failed to close all base services")
|
||||
logger.info("Base services closed", services=SERVICE_STARTUP_ORDER)
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.cors.allow_origins,
|
||||
allow_credentials=config.cors.allow_credentials,
|
||||
allow_methods=config.cors.allow_methods,
|
||||
allow_headers=config.cors.allow_headers,
|
||||
)
|
||||
app.include_router(v1_router)
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health() -> HealthResponse:
|
||||
return HealthResponse(status="ok")
|
||||
|
||||
|
||||
@app.exception_handler(ApiProblemError)
|
||||
async def api_problem_exception_handler(
|
||||
request: Request,
|
||||
exc: ApiProblemError,
|
||||
) -> JSONResponse:
|
||||
problem = build_problem_details(
|
||||
status_code=exc.status_code,
|
||||
detail=exc.detail,
|
||||
instance=request.url.path,
|
||||
code=exc.code,
|
||||
params=exc.params,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def request_validation_exception_handler(
|
||||
request: Request,
|
||||
exc: RequestValidationError,
|
||||
) -> JSONResponse:
|
||||
logger.warning(
|
||||
"Request validation error",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
errors=exc.errors(),
|
||||
)
|
||||
problem = build_problem_details(
|
||||
status_code=422,
|
||||
detail="Invalid request",
|
||||
instance=request.url.path,
|
||||
code="REQUEST_VALIDATION_ERROR",
|
||||
params={"errors": exc.errors()},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
logger.exception(
|
||||
"Unhandled error",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
problem = build_problem_details(
|
||||
status_code=500,
|
||||
detail="Internal Server Error",
|
||||
instance=request.url.path,
|
||||
code="INTERNAL_ERROR",
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
@@ -6,15 +6,14 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
from core.config.initial.init_data import initialize_data
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
|
||||
logger = get_logger("core.runtime.cli")
|
||||
|
||||
|
||||
def _resolve_alembic_path() -> Path:
|
||||
project_root = Path(__file__).parents[3]
|
||||
alembic_path = project_root / "alembic" / "alembic.ini"
|
||||
project_root = Path(__file__).parents[4]
|
||||
alembic_path = project_root / "backend" / "alembic" / "alembic.ini"
|
||||
if not alembic_path.exists():
|
||||
raise FileNotFoundError(f"Alembic config not found at {alembic_path}")
|
||||
return alembic_path
|
||||
|
||||
@@ -62,7 +62,6 @@ class ClientTimeContext(BaseModel):
|
||||
|
||||
class RuntimeMode(str, Enum):
|
||||
CHAT = "chat"
|
||||
AUTOMATION = "automation"
|
||||
|
||||
|
||||
class ForwardedPropsPayload(BaseModel):
|
||||
|
||||
@@ -4,8 +4,6 @@ import asyncio
|
||||
from typing import Any
|
||||
|
||||
from supabase import create_client
|
||||
from storage3.exceptions import StorageApiError
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
@@ -183,9 +181,6 @@ class SupabaseService(BaseServiceProvider):
|
||||
|
||||
def _is_bucket_not_found_error(self, exc: Exception) -> bool:
|
||||
"""Check if the exception indicates a bucket was not found."""
|
||||
if isinstance(exc, StorageApiError):
|
||||
message = str(exc).lower()
|
||||
return "bucket" in message and "not found" in message
|
||||
message = str(exc).lower()
|
||||
return "bucket" in message and "not found" in message
|
||||
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from schemas.domain.automation import AutomationJobConfig
|
||||
|
||||
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
|
||||
def _automation_yaml_path(config_name: str) -> Path:
|
||||
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
|
||||
raise ValueError("invalid automation config name")
|
||||
return (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "automation"
|
||||
/ f"{config_name}.yaml"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
|
||||
path = _automation_yaml_path(config_name)
|
||||
with path.open("r", encoding="utf-8") as file:
|
||||
loaded: Any = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"invalid automation config format: {path}")
|
||||
return AutomationJobConfig.model_validate(loaded)
|
||||
@@ -1,27 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db import get_db
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.auth.registration_bootstrap import (
|
||||
RegistrationAutomationBootstrapService,
|
||||
RegistrationBootstrapRepository,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
def get_auth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AuthService:
|
||||
bootstrapper = RegistrationAutomationBootstrapService(
|
||||
repository=RegistrationBootstrapRepository(session=session),
|
||||
session=session,
|
||||
)
|
||||
def get_auth_service() -> AuthService:
|
||||
return AuthService(
|
||||
gateway=SupabaseAuthGateway(),
|
||||
registration_bootstrapper=bootstrapper,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from supabase import AuthError
|
||||
|
||||
from core.http.errors import ApiProblemError
|
||||
from core.logging import get_logger
|
||||
from v1.auth.schemas import EmailSessionCreateRequest, SessionResponse
|
||||
|
||||
logger = get_logger("v1.auth.dev_email_session")
|
||||
|
||||
|
||||
def _auth_error(*, status_code: int, code: str, detail: str) -> ApiProblemError:
|
||||
return ApiProblemError(status_code=status_code, code=code, detail=detail)
|
||||
|
||||
|
||||
async def create_dev_email_session(
|
||||
*,
|
||||
request: EmailSessionCreateRequest,
|
||||
client: Any,
|
||||
admin_client: Any,
|
||||
auth_unavailable_detail: str,
|
||||
is_auth_upstream_unavailable: Callable[[AuthError], bool],
|
||||
map_auth_response: Callable[[object, str, str], SessionResponse],
|
||||
) -> SessionResponse:
|
||||
generate_link_payload: dict[str, Any] = {
|
||||
"type": "magiclink",
|
||||
"email": request.email,
|
||||
}
|
||||
|
||||
try:
|
||||
generate_link = cast(Any, admin_client.auth.admin.generate_link)
|
||||
link_response = await asyncio.to_thread(generate_link, generate_link_payload)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Dev email session link generation failed",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
if is_auth_upstream_unavailable(exc):
|
||||
raise _auth_error(
|
||||
status_code=503,
|
||||
code="AUTH_SERVICE_UNAVAILABLE",
|
||||
detail=auth_unavailable_detail,
|
||||
) from exc
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code="AUTH_VERIFICATION_CODE_INVALID",
|
||||
detail="Invalid verification code",
|
||||
) from exc
|
||||
|
||||
properties = getattr(link_response, "properties", None)
|
||||
dev_token = str(getattr(properties, "email_otp", "")).strip()
|
||||
if not dev_token:
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code="AUTH_VERIFICATION_CODE_INVALID",
|
||||
detail="Invalid verification code",
|
||||
)
|
||||
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "email",
|
||||
"email": request.email,
|
||||
"token": dev_token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
logger.info("Dev email session bypassed otp verification")
|
||||
return map_auth_response(
|
||||
response,
|
||||
"Invalid verification code",
|
||||
"AUTH_VERIFICATION_CODE_INVALID",
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Dev email session verification failed",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
if is_auth_upstream_unavailable(exc):
|
||||
raise _auth_error(
|
||||
status_code=503,
|
||||
code="AUTH_SERVICE_UNAVAILABLE",
|
||||
detail=auth_unavailable_detail,
|
||||
) from exc
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code="AUTH_VERIFICATION_CODE_INVALID",
|
||||
detail="Invalid verification code",
|
||||
) from exc
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -8,17 +9,19 @@ from pydantic import ValidationError
|
||||
|
||||
from supabase import AuthError
|
||||
|
||||
from core.config.settings import config
|
||||
from core.http.errors import ApiProblemError
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.dev_email_session import create_dev_email_session
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
EmailSessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByIdResponse,
|
||||
UserByPhoneResponse,
|
||||
UserByEmailResponse,
|
||||
)
|
||||
from v1.auth.service import AuthServiceGateway
|
||||
|
||||
@@ -40,7 +43,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
def __init__(self) -> None:
|
||||
self._user_lookup_cache_ttl_seconds: int = 60
|
||||
self._user_lookup_cache_expires_at: float = 0.0
|
||||
self._users_by_phone: dict[str, Any] = {}
|
||||
self._users_by_email: dict[str, Any] = {}
|
||||
self._users_by_id: dict[str, Any] = {}
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
@@ -52,7 +55,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"phone": request.phone,
|
||||
"email": request.email,
|
||||
"options": {"should_create_user": True},
|
||||
}
|
||||
try:
|
||||
@@ -72,13 +75,23 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
detail="Too many requests",
|
||||
) from exc
|
||||
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
async def create_email_session(
|
||||
self, request: EmailSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
if config.runtime.environment == "dev":
|
||||
return await create_dev_email_session(
|
||||
request=request,
|
||||
client=self._get_client(),
|
||||
admin_client=self._get_admin_client(),
|
||||
auth_unavailable_detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
is_auth_upstream_unavailable=_is_auth_upstream_unavailable,
|
||||
map_auth_response=_map_auth_response,
|
||||
)
|
||||
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": "sms",
|
||||
"phone": request.phone,
|
||||
"type": "email",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
@@ -90,7 +103,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
"AUTH_VERIFICATION_CODE_INVALID",
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning("Create phone session failed", error_type=type(exc).__name__)
|
||||
logger.warning("Create email session failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise _auth_error(
|
||||
status_code=503,
|
||||
@@ -169,9 +182,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
detail="Invalid refresh token",
|
||||
) from exc
|
||||
|
||||
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
|
||||
normalized_phone = _normalize_phone(phone)
|
||||
if not normalized_phone:
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
normalized_email = _normalize_email(email)
|
||||
if not normalized_email:
|
||||
raise _auth_error(
|
||||
status_code=404,
|
||||
code="AUTH_USER_NOT_FOUND",
|
||||
@@ -180,7 +193,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
|
||||
user = self._users_by_phone.get(normalized_phone)
|
||||
user = self._users_by_email.get(normalized_email)
|
||||
if user is None:
|
||||
raise _auth_error(
|
||||
status_code=404,
|
||||
@@ -188,21 +201,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
user_phone = _normalize_phone(getattr(user, "phone", ""))
|
||||
if not user_phone:
|
||||
user_email = _normalize_email(getattr(user, "email", ""))
|
||||
if not user_email:
|
||||
raise _auth_error(
|
||||
status_code=404,
|
||||
code="AUTH_USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
return UserByPhoneResponse(
|
||||
return UserByEmailResponse(
|
||||
id=str(getattr(user, "id", "")),
|
||||
phone=user_phone,
|
||||
email=user_email,
|
||||
created_at=str(getattr(user, "created_at", "")),
|
||||
phone_confirmed_at=(
|
||||
str(getattr(user, "phone_confirmed_at", ""))
|
||||
if getattr(user, "phone_confirmed_at", None)
|
||||
email_confirmed_at=(
|
||||
str(getattr(user, "email_confirmed_at", ""))
|
||||
if getattr(user, "email_confirmed_at", None)
|
||||
else None
|
||||
),
|
||||
)
|
||||
@@ -233,53 +246,27 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
user_attrs = getattr(user, "user", user)
|
||||
resolved[normalized_user_id] = UserByIdResponse(
|
||||
id=str(getattr(user_attrs, "id", "")),
|
||||
phone=getattr(user_attrs, "phone", None),
|
||||
email=getattr(user_attrs, "email", None),
|
||||
created_at=str(getattr(user_attrs, "created_at", "")),
|
||||
phone_confirmed_at=(
|
||||
str(getattr(user_attrs, "phone_confirmed_at", ""))
|
||||
if getattr(user_attrs, "phone_confirmed_at", None)
|
||||
email_confirmed_at=(
|
||||
str(getattr(user_attrs, "email_confirmed_at", ""))
|
||||
if getattr(user_attrs, "email_confirmed_at", None)
|
||||
else None
|
||||
),
|
||||
)
|
||||
return resolved
|
||||
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
normalized_query = _normalize_phone_search_query(query)
|
||||
async def search_user_ids_by_email(self, query: str, limit: int = 20) -> list[str]:
|
||||
normalized_query = _normalize_email(query)
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
if normalized_query.startswith("+"):
|
||||
matched_user = self._users_by_phone.get(normalized_query)
|
||||
if matched_user is None:
|
||||
return []
|
||||
user_id = str(getattr(matched_user, "id", ""))
|
||||
return [user_id] if user_id else []
|
||||
|
||||
digits = _digits_only(normalized_query)
|
||||
if not digits:
|
||||
matched_user = self._users_by_email.get(normalized_query)
|
||||
if matched_user is None:
|
||||
return []
|
||||
|
||||
matched_records: list[tuple[str, str]] = []
|
||||
for cached_phone, candidate in self._users_by_phone.items():
|
||||
candidate_digits = _digits_only(cached_phone)
|
||||
if not candidate_digits.endswith(digits):
|
||||
continue
|
||||
user_id = str(getattr(candidate, "id", ""))
|
||||
if user_id:
|
||||
matched_records.append((cached_phone, user_id))
|
||||
|
||||
if not matched_records:
|
||||
return []
|
||||
|
||||
unique_ids: list[str] = []
|
||||
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
|
||||
if user_id in unique_ids:
|
||||
continue
|
||||
unique_ids.append(user_id)
|
||||
if len(unique_ids) >= max(1, limit):
|
||||
break
|
||||
return unique_ids
|
||||
user_id = str(getattr(matched_user, "id", ""))
|
||||
return [user_id] if user_id else []
|
||||
|
||||
async def _refresh_user_lookup_cache_if_needed(self) -> None:
|
||||
now = time.monotonic()
|
||||
@@ -288,17 +275,17 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
|
||||
admin_client = self._get_admin_client()
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
users_by_phone: dict[str, Any] = {}
|
||||
users_by_email: dict[str, Any] = {}
|
||||
users_by_id: dict[str, Any] = {}
|
||||
for candidate in users:
|
||||
candidate_id = str(getattr(candidate, "id", "")).strip()
|
||||
if candidate_id:
|
||||
users_by_id[candidate_id] = candidate
|
||||
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
|
||||
if candidate_phone:
|
||||
users_by_phone[candidate_phone] = candidate
|
||||
candidate_email = _normalize_email(getattr(candidate, "email", ""))
|
||||
if candidate_email:
|
||||
users_by_email[candidate_email] = candidate
|
||||
self._users_by_id = users_by_id
|
||||
self._users_by_phone = users_by_phone
|
||||
self._users_by_email = users_by_email
|
||||
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
|
||||
|
||||
|
||||
@@ -343,8 +330,8 @@ def _map_auth_response(
|
||||
detail=failure_message,
|
||||
)
|
||||
|
||||
phone = _normalize_phone(getattr(user, "phone", None))
|
||||
if not phone:
|
||||
email = _normalize_email(getattr(user, "email", None))
|
||||
if not email:
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code=failure_code,
|
||||
@@ -352,10 +339,10 @@ def _map_auth_response(
|
||||
)
|
||||
|
||||
try:
|
||||
auth_user = AuthUser(id=str(user.id), phone=str(phone))
|
||||
auth_user = AuthUser(id=str(user.id), email=str(email))
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Auth response returned invalid phone format",
|
||||
"Auth response returned invalid email format",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
raise _auth_error(
|
||||
@@ -393,38 +380,13 @@ def _list_auth_users(client: Any) -> list[Any]:
|
||||
return users
|
||||
|
||||
|
||||
def _sanitize_phone_token(raw: object) -> str:
|
||||
token = str(raw).strip()
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
token = token.replace(separator, "")
|
||||
return token
|
||||
_EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
|
||||
|
||||
|
||||
def _normalize_phone(raw_phone: object) -> str | None:
|
||||
phone = _sanitize_phone_token(raw_phone)
|
||||
if not phone:
|
||||
def _normalize_email(raw_email: object) -> str | None:
|
||||
if not isinstance(raw_email, str):
|
||||
return None
|
||||
if phone.startswith("00") and len(phone) > 2:
|
||||
return f"+{phone[2:]}"
|
||||
if phone.startswith("+"):
|
||||
return phone
|
||||
if phone.isdigit():
|
||||
return f"+{phone}"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_phone_search_query(raw_query: str) -> str | None:
|
||||
query = _sanitize_phone_token(raw_query)
|
||||
if not query:
|
||||
email = raw_email.strip().lower()
|
||||
if not _EMAIL_PATTERN.fullmatch(email):
|
||||
return None
|
||||
if query.startswith("00") and len(query) > 2:
|
||||
return f"+{query[2:]}"
|
||||
if query.startswith("+"):
|
||||
return query
|
||||
if query.isdigit():
|
||||
return query
|
||||
return None
|
||||
|
||||
|
||||
def _digits_only(value: str) -> str:
|
||||
return "".join(ch for ch in value if ch.isdigit())
|
||||
return email
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, time, timedelta
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
from models.automation_jobs import AutomationJob
|
||||
from schemas.enums import AutomationJobStatus, MemoryType, ScheduleType
|
||||
from models.profile import Profile
|
||||
from schemas.domain.automation import AutomationJobConfig, ScheduleConfig
|
||||
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.shared.user import parse_profile_settings
|
||||
from v1.auth.automation_static_config import load_static_automation_job_config
|
||||
from v1.auth.schemas import RegistrationBootstrapRequest
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
|
||||
logger = get_logger("v1.auth.registration_bootstrap")
|
||||
|
||||
|
||||
class RegistrationBootstrapRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._memories_repository = SQLAlchemyMemoriesRepository(session)
|
||||
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str:
|
||||
stmt = select(Profile.settings).where(Profile.id == user_id)
|
||||
settings = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
parsed = parse_profile_settings(
|
||||
settings if isinstance(settings, dict) else None
|
||||
)
|
||||
return parsed.preferences.timezone
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
next_run_at: datetime,
|
||||
) -> bool:
|
||||
stmt = (
|
||||
insert(AutomationJob)
|
||||
.values(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=title,
|
||||
config=config.model_dump(mode="json"),
|
||||
next_run_at=next_run_at,
|
||||
timezone=timezone_name,
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
created_by=owner_id,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["owner_id", "bootstrap_key"],
|
||||
index_where=AutomationJob.deleted_at.is_(None)
|
||||
& AutomationJob.bootstrap_key.is_not(None),
|
||||
)
|
||||
.returning(AutomationJob.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
return inserted_id is not None
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool:
|
||||
return await self._memories_repository.create_if_absent(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
|
||||
|
||||
|
||||
class RegistrationBootstrapRepositoryLike(Protocol):
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
next_run_at: datetime,
|
||||
) -> bool: ...
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class SessionLike(Protocol):
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
def compute_first_run_at_utc(
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
schedule: ScheduleConfig,
|
||||
) -> datetime:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
except ZoneInfoNotFoundError:
|
||||
timezone_obj = ZoneInfo("UTC")
|
||||
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
run_clock = time(
|
||||
hour=schedule.run_at.hour,
|
||||
minute=schedule.run_at.minute,
|
||||
tzinfo=timezone_obj,
|
||||
)
|
||||
|
||||
if schedule.type == ScheduleType.DAILY:
|
||||
candidate_local = datetime.combine(local_now.date(), run_clock)
|
||||
if candidate_local <= local_now:
|
||||
candidate_local = candidate_local + timedelta(days=1)
|
||||
return candidate_local.astimezone(UTC)
|
||||
|
||||
weekdays = schedule.weekdays or []
|
||||
if not weekdays:
|
||||
raise ValueError("weekly schedule requires weekdays")
|
||||
|
||||
normalized_weekdays = sorted(set(weekdays))
|
||||
for day_offset in range(0, 8):
|
||||
candidate_day = local_now.date() + timedelta(days=day_offset)
|
||||
if candidate_day.isoweekday() not in normalized_weekdays:
|
||||
continue
|
||||
candidate_local = datetime.combine(candidate_day, run_clock)
|
||||
if candidate_local > local_now:
|
||||
return candidate_local.astimezone(UTC)
|
||||
|
||||
fallback_day = local_now.date() + timedelta(days=7)
|
||||
while fallback_day.isoweekday() not in normalized_weekdays:
|
||||
fallback_day = fallback_day + timedelta(days=1)
|
||||
fallback_local = datetime.combine(fallback_day, run_clock)
|
||||
return fallback_local.astimezone(UTC)
|
||||
|
||||
|
||||
class RegistrationAutomationBootstrapService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: RegistrationBootstrapRepositoryLike,
|
||||
session: SessionLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
|
||||
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
|
||||
owner_id = request.user_id
|
||||
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
|
||||
|
||||
definitions = [
|
||||
{
|
||||
"bootstrap_key": "memory_extraction",
|
||||
"config_name": "memory_extraction",
|
||||
"title": "记忆推送",
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
inserted_any = False
|
||||
created_or_updated_memory = False
|
||||
|
||||
user_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.USER,
|
||||
content=UserMemoryContent().model_dump(mode="json"),
|
||||
)
|
||||
work_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.WORK,
|
||||
content=WorkProfileContent().model_dump(mode="json"),
|
||||
)
|
||||
created_or_updated_memory = user_initialized or work_initialized
|
||||
|
||||
for definition in definitions:
|
||||
bootstrap_key = str(definition["bootstrap_key"])
|
||||
job_config = load_static_automation_job_config(
|
||||
config_name=str(definition["config_name"])
|
||||
)
|
||||
schedule = job_config.schedule
|
||||
if schedule is None:
|
||||
raise ValueError(
|
||||
f"bootstrap job {bootstrap_key} has no schedule configured"
|
||||
)
|
||||
next_run_at = compute_first_run_at_utc(
|
||||
now_utc=datetime.now(UTC),
|
||||
timezone_name=timezone_name,
|
||||
schedule=schedule,
|
||||
)
|
||||
inserted = (
|
||||
await self._repository.insert_bootstrap_automation_job_if_absent(
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=str(definition["title"]),
|
||||
config=job_config,
|
||||
timezone_name=timezone_name,
|
||||
next_run_at=next_run_at,
|
||||
)
|
||||
)
|
||||
inserted_any = inserted_any or inserted
|
||||
if inserted_any or created_or_updated_memory:
|
||||
await self._session.commit()
|
||||
logger.info(
|
||||
"user automation jobs bootstrapped",
|
||||
user_id=user_id,
|
||||
timezone=timezone_name,
|
||||
memory_initialized=created_or_updated_memory,
|
||||
)
|
||||
except Exception:
|
||||
await self._session.rollback()
|
||||
raise
|
||||
@@ -6,8 +6,8 @@ from core.config.settings import config
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.schemas import (
|
||||
EmailSessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
@@ -26,8 +26,8 @@ async def send_otp(
|
||||
) -> Response:
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope="otp_send_phone",
|
||||
identifier=payload.phone,
|
||||
scope="otp_send_email",
|
||||
identifier=payload.email.lower(),
|
||||
limit=3,
|
||||
window_seconds=60,
|
||||
)
|
||||
@@ -41,26 +41,26 @@ async def send_otp(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/phone-session", response_model=SessionResponse)
|
||||
async def create_phone_session(
|
||||
payload: PhoneSessionCreateRequest,
|
||||
@router.post("/email-session", response_model=SessionResponse)
|
||||
async def create_email_session(
|
||||
payload: EmailSessionCreateRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse:
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope="phone_session_phone",
|
||||
identifier=payload.phone,
|
||||
scope="session_email",
|
||||
identifier=payload.email.lower(),
|
||||
limit=6,
|
||||
window_seconds=300,
|
||||
)
|
||||
await enforce_rate_limit(
|
||||
scope="phone_session_ip",
|
||||
scope="session_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=300,
|
||||
)
|
||||
return await service.create_phone_session(payload)
|
||||
return await service.create_email_session(payload)
|
||||
|
||||
|
||||
@router.post("/sessions/refresh", response_model=SessionResponse)
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
|
||||
SUPABASE_EMAIL_PATTERN = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
|
||||
|
||||
class OtpSendRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
|
||||
|
||||
|
||||
class PhoneSessionCreateRequest(BaseModel):
|
||||
class EmailSessionCreateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
|
||||
token: str = Field(min_length=6, max_length=6)
|
||||
|
||||
|
||||
class SessionRefreshRequest(BaseModel):
|
||||
@@ -31,7 +29,7 @@ class SessionDeleteRequest(BaseModel):
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: str
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
@@ -42,25 +40,19 @@ class SessionResponse(BaseModel):
|
||||
user: AuthUser
|
||||
|
||||
|
||||
class UserByPhoneResponse(BaseModel):
|
||||
class UserByEmailResponse(BaseModel):
|
||||
id: str
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
|
||||
created_at: str
|
||||
phone_confirmed_at: str | None = None
|
||||
email_confirmed_at: str | None = None
|
||||
|
||||
|
||||
class UserByIdResponse(BaseModel):
|
||||
id: str
|
||||
phone: str | None = None
|
||||
email: str | None = None
|
||||
created_at: str
|
||||
phone_confirmed_at: str | None = None
|
||||
email_confirmed_at: str | None = None
|
||||
|
||||
|
||||
class OtpSendResponse(BaseModel):
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class RegistrationBootstrapRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
user_id: UUID
|
||||
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
from typing import Protocol
|
||||
|
||||
from v1.auth.schemas import (
|
||||
EmailSessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
)
|
||||
@@ -14,8 +14,8 @@ class AuthServiceGateway(Protocol):
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
async def create_email_session(
|
||||
self, request: EmailSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -28,36 +28,23 @@ class AuthServiceGateway(Protocol):
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
_registration_bootstrapper: RegistrationBootstrapper | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gateway: AuthServiceGateway,
|
||||
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
|
||||
) -> None:
|
||||
self._gateway = gateway
|
||||
self._registration_bootstrapper = registration_bootstrapper
|
||||
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
await self._gateway.send_otp(request)
|
||||
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
async def create_email_session(
|
||||
self, request: EmailSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
response = await self._gateway.create_phone_session(request)
|
||||
if self._registration_bootstrapper is not None:
|
||||
await self._registration_bootstrapper.ensure_user_automation_jobs(
|
||||
user_id=response.user.id
|
||||
)
|
||||
return response
|
||||
return await self._gateway.create_email_session(request)
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
return await self._gateway.refresh_session(request)
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from v1.auth.router import router as auth_router
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
Reference in New Issue
Block a user