feat: 切换邮箱认证并重构前后端启动与门禁

This commit is contained in:
qzl
2026-04-02 18:39:35 +08:00
parent 92cdfd9fca
commit 31594558eb
116 changed files with 5608 additions and 628 deletions
+129
View File
@@ -0,0 +1,129 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from core.config.settings import config
from core.http.errors import ApiProblemError
from core.http.response import build_problem_details
from core.logging import configure_logging, get_logger, log_service_banner
from services.base import close_registered_services, initialize_registered_services
from v1.router import router as v1_router
class HealthResponse(BaseModel):
status: str
configure_logging(config)
log_service_banner(
service_name=config.runtime.service_name,
environment=config.runtime.environment,
)
logger = get_logger("api.app")
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER)
if not initialized:
logger.error("Service initialization failed, aborting startup")
raise RuntimeError("Service initialization failed")
logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER)
try:
yield
finally:
closed = await close_registered_services(services)
if not closed:
logger.warning("Failed to close all base services")
logger.info("Base services closed", services=SERVICE_STARTUP_ORDER)
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=config.cors.allow_origins,
allow_credentials=config.cors.allow_credentials,
allow_methods=config.cors.allow_methods,
allow_headers=config.cors.allow_headers,
)
app.include_router(v1_router)
@app.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(status="ok")
@app.exception_handler(ApiProblemError)
async def api_problem_exception_handler(
request: Request,
exc: ApiProblemError,
) -> JSONResponse:
problem = build_problem_details(
status_code=exc.status_code,
detail=exc.detail,
instance=request.url.path,
code=exc.code,
params=exc.params,
)
return JSONResponse(
status_code=exc.status_code,
content=problem.model_dump(),
media_type="application/problem+json",
)
@app.exception_handler(RequestValidationError)
async def request_validation_exception_handler(
request: Request,
exc: RequestValidationError,
) -> JSONResponse:
logger.warning(
"Request validation error",
path=request.url.path,
method=request.method,
errors=exc.errors(),
)
problem = build_problem_details(
status_code=422,
detail="Invalid request",
instance=request.url.path,
code="REQUEST_VALIDATION_ERROR",
params={"errors": exc.errors()},
)
return JSONResponse(
status_code=422,
content=problem.model_dump(),
media_type="application/problem+json",
)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
logger.exception(
"Unhandled error",
path=request.url.path,
method=request.method,
error_type=exc.__class__.__name__,
)
problem = build_problem_details(
status_code=500,
detail="Internal Server Error",
instance=request.url.path,
code="INTERNAL_ERROR",
)
return JSONResponse(
status_code=500,
content=problem.model_dump(),
media_type="application/problem+json",
)
+2 -3
View File
@@ -6,15 +6,14 @@ import sys
from pathlib import Path
from core.config.initial.init_data import initialize_data
from core.config.settings import config
from core.logging import get_logger
logger = get_logger("core.runtime.cli")
def _resolve_alembic_path() -> Path:
project_root = Path(__file__).parents[3]
alembic_path = project_root / "alembic" / "alembic.ini"
project_root = Path(__file__).parents[4]
alembic_path = project_root / "backend" / "alembic" / "alembic.ini"
if not alembic_path.exists():
raise FileNotFoundError(f"Alembic config not found at {alembic_path}")
return alembic_path
@@ -62,7 +62,6 @@ class ClientTimeContext(BaseModel):
class RuntimeMode(str, Enum):
CHAT = "chat"
AUTOMATION = "automation"
class ForwardedPropsPayload(BaseModel):
-5
View File
@@ -4,8 +4,6 @@ import asyncio
from typing import Any
from supabase import create_client
from storage3.exceptions import StorageApiError
from core.config.settings import SupabaseSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
@@ -183,9 +181,6 @@ class SupabaseService(BaseServiceProvider):
def _is_bucket_not_found_error(self, exc: Exception) -> bool:
"""Check if the exception indicates a bucket was not found."""
if isinstance(exc, StorageApiError):
message = str(exc).lower()
return "bucket" in message and "not found" in message
message = str(exc).lower()
return "bucket" in message and "not found" in message
@@ -1,35 +0,0 @@
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
import re
from typing import Any
import yaml
from schemas.domain.automation import AutomationJobConfig
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
def _automation_yaml_path(config_name: str) -> Path:
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
raise ValueError("invalid automation config name")
return (
Path(__file__).resolve().parents[2]
/ "core"
/ "config"
/ "static"
/ "automation"
/ f"{config_name}.yaml"
)
@lru_cache(maxsize=16)
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
path = _automation_yaml_path(config_name)
with path.open("r", encoding="utf-8") as file:
loaded: Any = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"invalid automation config format: {path}")
return AutomationJobConfig.model_validate(loaded)
+1 -16
View File
@@ -1,27 +1,12 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.db import get_db
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.registration_bootstrap import (
RegistrationAutomationBootstrapService,
RegistrationBootstrapRepository,
)
from v1.auth.service import AuthService
def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
bootstrapper = RegistrationAutomationBootstrapService(
repository=RegistrationBootstrapRepository(session=session),
session=session,
)
def get_auth_service() -> AuthService:
return AuthService(
gateway=SupabaseAuthGateway(),
registration_bootstrapper=bootstrapper,
)
+91
View File
@@ -0,0 +1,91 @@
from __future__ import annotations
import asyncio
from typing import Any, Callable, cast
from supabase import AuthError
from core.http.errors import ApiProblemError
from core.logging import get_logger
from v1.auth.schemas import EmailSessionCreateRequest, SessionResponse
logger = get_logger("v1.auth.dev_email_session")
def _auth_error(*, status_code: int, code: str, detail: str) -> ApiProblemError:
return ApiProblemError(status_code=status_code, code=code, detail=detail)
async def create_dev_email_session(
*,
request: EmailSessionCreateRequest,
client: Any,
admin_client: Any,
auth_unavailable_detail: str,
is_auth_upstream_unavailable: Callable[[AuthError], bool],
map_auth_response: Callable[[object, str, str], SessionResponse],
) -> SessionResponse:
generate_link_payload: dict[str, Any] = {
"type": "magiclink",
"email": request.email,
}
try:
generate_link = cast(Any, admin_client.auth.admin.generate_link)
link_response = await asyncio.to_thread(generate_link, generate_link_payload)
except AuthError as exc:
logger.warning(
"Dev email session link generation failed",
error_type=type(exc).__name__,
)
if is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=auth_unavailable_detail,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_VERIFICATION_CODE_INVALID",
detail="Invalid verification code",
) from exc
properties = getattr(link_response, "properties", None)
dev_token = str(getattr(properties, "email_otp", "")).strip()
if not dev_token:
raise _auth_error(
status_code=401,
code="AUTH_VERIFICATION_CODE_INVALID",
detail="Invalid verification code",
)
verify_payload: dict[str, Any] = {
"type": "email",
"email": request.email,
"token": dev_token,
}
try:
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
logger.info("Dev email session bypassed otp verification")
return map_auth_response(
response,
"Invalid verification code",
"AUTH_VERIFICATION_CODE_INVALID",
)
except AuthError as exc:
logger.warning(
"Dev email session verification failed",
error_type=type(exc).__name__,
)
if is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=auth_unavailable_detail,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_VERIFICATION_CODE_INVALID",
detail="Invalid verification code",
) from exc
+58 -96
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import re
import time
from typing import Any, cast
@@ -8,17 +9,19 @@ from pydantic import ValidationError
from supabase import AuthError
from core.config.settings import config
from core.http.errors import ApiProblemError
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.dev_email_session import create_dev_email_session
from v1.auth.schemas import (
AuthUser,
EmailSessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByIdResponse,
UserByPhoneResponse,
UserByEmailResponse,
)
from v1.auth.service import AuthServiceGateway
@@ -40,7 +43,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
def __init__(self) -> None:
self._user_lookup_cache_ttl_seconds: int = 60
self._user_lookup_cache_expires_at: float = 0.0
self._users_by_phone: dict[str, Any] = {}
self._users_by_email: dict[str, Any] = {}
self._users_by_id: dict[str, Any] = {}
def _get_client(self) -> Any:
@@ -52,7 +55,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def send_otp(self, request: OtpSendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {
"phone": request.phone,
"email": request.email,
"options": {"should_create_user": True},
}
try:
@@ -72,13 +75,23 @@ class SupabaseAuthGateway(AuthServiceGateway):
detail="Too many requests",
) from exc
async def create_phone_session(
self, request: PhoneSessionCreateRequest
async def create_email_session(
self, request: EmailSessionCreateRequest
) -> SessionResponse:
if config.runtime.environment == "dev":
return await create_dev_email_session(
request=request,
client=self._get_client(),
admin_client=self._get_admin_client(),
auth_unavailable_detail=AUTH_UNAVAILABLE_DETAIL,
is_auth_upstream_unavailable=_is_auth_upstream_unavailable,
map_auth_response=_map_auth_response,
)
client = self._get_client()
payload: dict[str, Any] = {
"type": "sms",
"phone": request.phone,
"type": "email",
"email": request.email,
"token": request.token,
}
try:
@@ -90,7 +103,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
"AUTH_VERIFICATION_CODE_INVALID",
)
except AuthError as exc:
logger.warning("Create phone session failed", error_type=type(exc).__name__)
logger.warning("Create email session failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
@@ -169,9 +182,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
detail="Invalid refresh token",
) from exc
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
normalized_phone = _normalize_phone(phone)
if not normalized_phone:
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
normalized_email = _normalize_email(email)
if not normalized_email:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
@@ -180,7 +193,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
await self._refresh_user_lookup_cache_if_needed()
user = self._users_by_phone.get(normalized_phone)
user = self._users_by_email.get(normalized_email)
if user is None:
raise _auth_error(
status_code=404,
@@ -188,21 +201,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
detail="User not found",
)
user_phone = _normalize_phone(getattr(user, "phone", ""))
if not user_phone:
user_email = _normalize_email(getattr(user, "email", ""))
if not user_email:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
return UserByPhoneResponse(
return UserByEmailResponse(
id=str(getattr(user, "id", "")),
phone=user_phone,
email=user_email,
created_at=str(getattr(user, "created_at", "")),
phone_confirmed_at=(
str(getattr(user, "phone_confirmed_at", ""))
if getattr(user, "phone_confirmed_at", None)
email_confirmed_at=(
str(getattr(user, "email_confirmed_at", ""))
if getattr(user, "email_confirmed_at", None)
else None
),
)
@@ -233,53 +246,27 @@ class SupabaseAuthGateway(AuthServiceGateway):
user_attrs = getattr(user, "user", user)
resolved[normalized_user_id] = UserByIdResponse(
id=str(getattr(user_attrs, "id", "")),
phone=getattr(user_attrs, "phone", None),
email=getattr(user_attrs, "email", None),
created_at=str(getattr(user_attrs, "created_at", "")),
phone_confirmed_at=(
str(getattr(user_attrs, "phone_confirmed_at", ""))
if getattr(user_attrs, "phone_confirmed_at", None)
email_confirmed_at=(
str(getattr(user_attrs, "email_confirmed_at", ""))
if getattr(user_attrs, "email_confirmed_at", None)
else None
),
)
return resolved
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
normalized_query = _normalize_phone_search_query(query)
async def search_user_ids_by_email(self, query: str, limit: int = 20) -> list[str]:
normalized_query = _normalize_email(query)
if not normalized_query:
return []
await self._refresh_user_lookup_cache_if_needed()
if normalized_query.startswith("+"):
matched_user = self._users_by_phone.get(normalized_query)
if matched_user is None:
return []
user_id = str(getattr(matched_user, "id", ""))
return [user_id] if user_id else []
digits = _digits_only(normalized_query)
if not digits:
matched_user = self._users_by_email.get(normalized_query)
if matched_user is None:
return []
matched_records: list[tuple[str, str]] = []
for cached_phone, candidate in self._users_by_phone.items():
candidate_digits = _digits_only(cached_phone)
if not candidate_digits.endswith(digits):
continue
user_id = str(getattr(candidate, "id", ""))
if user_id:
matched_records.append((cached_phone, user_id))
if not matched_records:
return []
unique_ids: list[str] = []
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
if user_id in unique_ids:
continue
unique_ids.append(user_id)
if len(unique_ids) >= max(1, limit):
break
return unique_ids
user_id = str(getattr(matched_user, "id", ""))
return [user_id] if user_id else []
async def _refresh_user_lookup_cache_if_needed(self) -> None:
now = time.monotonic()
@@ -288,17 +275,17 @@ class SupabaseAuthGateway(AuthServiceGateway):
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_phone: dict[str, Any] = {}
users_by_email: dict[str, Any] = {}
users_by_id: dict[str, Any] = {}
for candidate in users:
candidate_id = str(getattr(candidate, "id", "")).strip()
if candidate_id:
users_by_id[candidate_id] = candidate
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
if candidate_phone:
users_by_phone[candidate_phone] = candidate
candidate_email = _normalize_email(getattr(candidate, "email", ""))
if candidate_email:
users_by_email[candidate_email] = candidate
self._users_by_id = users_by_id
self._users_by_phone = users_by_phone
self._users_by_email = users_by_email
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
@@ -343,8 +330,8 @@ def _map_auth_response(
detail=failure_message,
)
phone = _normalize_phone(getattr(user, "phone", None))
if not phone:
email = _normalize_email(getattr(user, "email", None))
if not email:
raise _auth_error(
status_code=401,
code=failure_code,
@@ -352,10 +339,10 @@ def _map_auth_response(
)
try:
auth_user = AuthUser(id=str(user.id), phone=str(phone))
auth_user = AuthUser(id=str(user.id), email=str(email))
except ValidationError as exc:
logger.warning(
"Auth response returned invalid phone format",
"Auth response returned invalid email format",
error_type=type(exc).__name__,
)
raise _auth_error(
@@ -393,38 +380,13 @@ def _list_auth_users(client: Any) -> list[Any]:
return users
def _sanitize_phone_token(raw: object) -> str:
token = str(raw).strip()
for separator in (" ", "-", "(", ")"):
token = token.replace(separator, "")
return token
_EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
def _normalize_phone(raw_phone: object) -> str | None:
phone = _sanitize_phone_token(raw_phone)
if not phone:
def _normalize_email(raw_email: object) -> str | None:
if not isinstance(raw_email, str):
return None
if phone.startswith("00") and len(phone) > 2:
return f"+{phone[2:]}"
if phone.startswith("+"):
return phone
if phone.isdigit():
return f"+{phone}"
return None
def _normalize_phone_search_query(raw_query: str) -> str | None:
query = _sanitize_phone_token(raw_query)
if not query:
email = raw_email.strip().lower()
if not _EMAIL_PATTERN.fullmatch(email):
return None
if query.startswith("00") and len(query) > 2:
return f"+{query[2:]}"
if query.startswith("+"):
return query
if query.isdigit():
return query
return None
def _digits_only(value: str) -> str:
return "".join(ch for ch in value if ch.isdigit())
return email
@@ -1,239 +0,0 @@
from __future__ import annotations
from datetime import UTC, datetime, time, timedelta
from typing import Protocol
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
from models.automation_jobs import AutomationJob
from schemas.enums import AutomationJobStatus, MemoryType, ScheduleType
from models.profile import Profile
from schemas.domain.automation import AutomationJobConfig, ScheduleConfig
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from schemas.shared.user import parse_profile_settings
from v1.auth.automation_static_config import load_static_automation_job_config
from v1.auth.schemas import RegistrationBootstrapRequest
from v1.memories.repository import SQLAlchemyMemoriesRepository
logger = get_logger("v1.auth.registration_bootstrap")
class RegistrationBootstrapRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
self._memories_repository = SQLAlchemyMemoriesRepository(session)
async def get_profile_timezone(self, *, user_id: UUID) -> str:
stmt = select(Profile.settings).where(Profile.id == user_id)
settings = (await self._session.execute(stmt)).scalar_one_or_none()
parsed = parse_profile_settings(
settings if isinstance(settings, dict) else None
)
return parsed.preferences.timezone
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
next_run_at: datetime,
) -> bool:
stmt = (
insert(AutomationJob)
.values(
id=uuid4(),
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=title,
config=config.model_dump(mode="json"),
next_run_at=next_run_at,
timezone=timezone_name,
status=AutomationJobStatus.ACTIVE,
created_by=owner_id,
)
.on_conflict_do_nothing(
index_elements=["owner_id", "bootstrap_key"],
index_where=AutomationJob.deleted_at.is_(None)
& AutomationJob.bootstrap_key.is_not(None),
)
.returning(AutomationJob.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
await self._session.flush()
return inserted_id is not None
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool:
return await self._memories_repository.create_if_absent(
owner_id=owner_id,
memory_type=memory_type,
content=content,
)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
class RegistrationBootstrapRepositoryLike(Protocol):
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
next_run_at: datetime,
) -> bool: ...
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool: ...
class SessionLike(Protocol):
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
def compute_first_run_at_utc(
*,
now_utc: datetime,
timezone_name: str,
schedule: ScheduleConfig,
) -> datetime:
try:
timezone_obj = ZoneInfo(timezone_name)
except ZoneInfoNotFoundError:
timezone_obj = ZoneInfo("UTC")
local_now = now_utc.astimezone(timezone_obj)
run_clock = time(
hour=schedule.run_at.hour,
minute=schedule.run_at.minute,
tzinfo=timezone_obj,
)
if schedule.type == ScheduleType.DAILY:
candidate_local = datetime.combine(local_now.date(), run_clock)
if candidate_local <= local_now:
candidate_local = candidate_local + timedelta(days=1)
return candidate_local.astimezone(UTC)
weekdays = schedule.weekdays or []
if not weekdays:
raise ValueError("weekly schedule requires weekdays")
normalized_weekdays = sorted(set(weekdays))
for day_offset in range(0, 8):
candidate_day = local_now.date() + timedelta(days=day_offset)
if candidate_day.isoweekday() not in normalized_weekdays:
continue
candidate_local = datetime.combine(candidate_day, run_clock)
if candidate_local > local_now:
return candidate_local.astimezone(UTC)
fallback_day = local_now.date() + timedelta(days=7)
while fallback_day.isoweekday() not in normalized_weekdays:
fallback_day = fallback_day + timedelta(days=1)
fallback_local = datetime.combine(fallback_day, run_clock)
return fallback_local.astimezone(UTC)
class RegistrationAutomationBootstrapService:
def __init__(
self,
*,
repository: RegistrationBootstrapRepositoryLike,
session: SessionLike,
) -> None:
self._repository = repository
self._session = session
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
owner_id = request.user_id
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
definitions = [
{
"bootstrap_key": "memory_extraction",
"config_name": "memory_extraction",
"title": "记忆推送",
}
]
try:
inserted_any = False
created_or_updated_memory = False
user_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.USER,
content=UserMemoryContent().model_dump(mode="json"),
)
work_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.WORK,
content=WorkProfileContent().model_dump(mode="json"),
)
created_or_updated_memory = user_initialized or work_initialized
for definition in definitions:
bootstrap_key = str(definition["bootstrap_key"])
job_config = load_static_automation_job_config(
config_name=str(definition["config_name"])
)
schedule = job_config.schedule
if schedule is None:
raise ValueError(
f"bootstrap job {bootstrap_key} has no schedule configured"
)
next_run_at = compute_first_run_at_utc(
now_utc=datetime.now(UTC),
timezone_name=timezone_name,
schedule=schedule,
)
inserted = (
await self._repository.insert_bootstrap_automation_job_if_absent(
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=str(definition["title"]),
config=job_config,
timezone_name=timezone_name,
next_run_at=next_run_at,
)
)
inserted_any = inserted_any or inserted
if inserted_any or created_or_updated_memory:
await self._session.commit()
logger.info(
"user automation jobs bootstrapped",
user_id=user_id,
timezone=timezone_name,
memory_initialized=created_or_updated_memory,
)
except Exception:
await self._session.rollback()
raise
+10 -10
View File
@@ -6,8 +6,8 @@ from core.config.settings import config
from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service
from v1.auth.schemas import (
EmailSessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
SessionResponse,
@@ -26,8 +26,8 @@ async def send_otp(
) -> Response:
client_ip = _client_ip(request)
await enforce_rate_limit(
scope="otp_send_phone",
identifier=payload.phone,
scope="otp_send_email",
identifier=payload.email.lower(),
limit=3,
window_seconds=60,
)
@@ -41,26 +41,26 @@ async def send_otp(
return Response(status_code=204)
@router.post("/phone-session", response_model=SessionResponse)
async def create_phone_session(
payload: PhoneSessionCreateRequest,
@router.post("/email-session", response_model=SessionResponse)
async def create_email_session(
payload: EmailSessionCreateRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
client_ip = _client_ip(request)
await enforce_rate_limit(
scope="phone_session_phone",
identifier=payload.phone,
scope="session_email",
identifier=payload.email.lower(),
limit=6,
window_seconds=300,
)
await enforce_rate_limit(
scope="phone_session_ip",
scope="session_ip",
identifier=client_ip,
limit=20,
window_seconds=300,
)
return await service.create_phone_session(payload)
return await service.create_email_session(payload)
@router.post("/sessions/refresh", response_model=SessionResponse)
+12 -20
View File
@@ -1,24 +1,22 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
SUPABASE_PASSWORD_MIN_LENGTH = 6
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
SUPABASE_EMAIL_PATTERN = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
class OtpSendRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
class PhoneSessionCreateRequest(BaseModel):
class EmailSessionCreateRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
token: str = Field(pattern=r"^\d{6}$")
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
token: str = Field(min_length=6, max_length=6)
class SessionRefreshRequest(BaseModel):
@@ -31,7 +29,7 @@ class SessionDeleteRequest(BaseModel):
class AuthUser(BaseModel):
id: str
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
class SessionResponse(BaseModel):
@@ -42,25 +40,19 @@ class SessionResponse(BaseModel):
user: AuthUser
class UserByPhoneResponse(BaseModel):
class UserByEmailResponse(BaseModel):
id: str
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
created_at: str
phone_confirmed_at: str | None = None
email_confirmed_at: str | None = None
class UserByIdResponse(BaseModel):
id: str
phone: str | None = None
email: str | None = None
created_at: str
phone_confirmed_at: str | None = None
email_confirmed_at: str | None = None
class OtpSendResponse(BaseModel):
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class RegistrationBootstrapRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
user_id: UUID
email: str = Field(pattern=SUPABASE_EMAIL_PATTERN)
+6 -19
View File
@@ -3,8 +3,8 @@ from __future__ import annotations
from typing import Protocol
from v1.auth.schemas import (
EmailSessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
)
@@ -14,8 +14,8 @@ class AuthServiceGateway(Protocol):
async def send_otp(self, request: OtpSendRequest) -> None:
raise NotImplementedError
async def create_phone_session(
self, request: PhoneSessionCreateRequest
async def create_email_session(
self, request: EmailSessionCreateRequest
) -> SessionResponse:
raise NotImplementedError
@@ -28,36 +28,23 @@ class AuthServiceGateway(Protocol):
class AuthService:
_gateway: AuthServiceGateway
_registration_bootstrapper: RegistrationBootstrapper | None
def __init__(
self,
gateway: AuthServiceGateway,
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
) -> None:
self._gateway = gateway
self._registration_bootstrapper = registration_bootstrapper
async def send_otp(self, request: OtpSendRequest) -> None:
await self._gateway.send_otp(request)
async def create_phone_session(
self, request: PhoneSessionCreateRequest
async def create_email_session(
self, request: EmailSessionCreateRequest
) -> SessionResponse:
response = await self._gateway.create_phone_session(request)
if self._registration_bootstrapper is not None:
await self._registration_bootstrapper.ensure_user_automation_jobs(
user_id=response.user.id
)
return response
return await self._gateway.create_email_session(request)
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
return await self._gateway.refresh_session(request)
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
raise NotImplementedError
+9
View File
@@ -0,0 +1,9 @@
from __future__ import annotations
from fastapi import APIRouter
from v1.auth.router import router as auth_router
router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)