refactor: align backend layout and supabase infra

Consolidate backend modules/tests under the backend package while syncing Supabase compose/env config and related plans.
This commit is contained in:
qzl
2026-02-05 15:13:06 +08:00
parent 3cfcb11240
commit ad06fe7de4
111 changed files with 5540 additions and 1362 deletions
View File
+9
View File
@@ -0,0 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID
@dataclass(frozen=True)
class CurrentUser:
id: UUID
+3
View File
@@ -0,0 +1,3 @@
from .settings import Settings, config
__all__ = ["Settings", "config"]
+166
View File
@@ -0,0 +1,166 @@
from __future__ import annotations
from pathlib import Path
from typing import ClassVar, Literal
from urllib.parse import quote
from pydantic import BaseModel, Field, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict
class RuntimeSettings(BaseModel):
environment: Literal["dev", "test", "prod"] = "dev"
debug: bool = True
log_level: str = "INFO"
log_json: bool = True
log_rotation: Literal["time", "size", "none"] = "time"
log_rotation_when: str = "midnight"
log_rotation_interval: int = 1
log_rotation_backup_count: int = 14
log_rotation_max_bytes: int = 10_000_000
log_dir: str = "logs"
log_error_dir: str = "logs/errors"
log_file_name: str = "app.log"
log_error_file_name: str = "error.log"
log_sensitive_fields: list[str] = Field(
default_factory=lambda: [
"password",
"secret",
"token",
"api_key",
"authorization",
"cookie",
"client_ip",
"user_id",
]
)
sql_log_queries: bool = False
class AppSettings(BaseModel):
host: str = "0.0.0.0"
port: int = Field(default=8000, ge=1, le=65535)
reload: bool = True
class CorsSettings(BaseModel):
allow_origins: list[str] = Field(
default_factory=lambda: [
"http://localhost",
"http://localhost:3000",
]
)
allow_credentials: bool = True
allow_methods: list[str] = Field(default_factory=lambda: ["*"])
allow_headers: list[str] = Field(default_factory=lambda: ["*"])
class RedisSettings(BaseModel):
host: str = "redis"
port: int = 6379
password: str | None = None
db: int = 0
socket_connect_timeout: float = 1.0
socket_timeout: float = 1.0
max_connections: int = 10
@computed_field
@property
def url(self) -> str:
if self.password:
password = quote(self.password, safe="")
return f"redis://:{password}@{self.host}:{self.port}/{self.db}"
return f"redis://{self.host}:{self.port}/{self.db}"
class QdrantSettings(BaseModel):
host: str = "qdrant"
port: int = 6333
grpc_port: int = 6334
api_key: str | None = None
https: bool = False
prefer_grpc: bool = True
timeout: int = 5
@computed_field
@property
def url(self) -> str:
scheme = "https" if self.https else "http"
return f"{scheme}://{self.host}:{self.port}"
class SupabaseSettings(BaseModel):
public_scheme: str = "http"
public_host: str = "localhost"
kong_http_port: int = 8000
anon_key: str = "CHANGE_ME"
service_role_key: str = "CHANGE_ME"
jwt_secret: str | None = None
@computed_field
@property
def public_url(self) -> str:
return f"{self.public_scheme}://{self.public_host}:{self.kong_http_port}"
@computed_field
@property
def api_external_url(self) -> str:
return self.public_url
@computed_field
@property
def url(self) -> str:
return self.public_url
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 5432
name: str = "postgres"
user: str = "postgres"
password: str = "CHANGE_ME"
@computed_field
@property
def url(self) -> str:
password = quote(self.password, safe="")
return (
f"postgresql+asyncpg://{self.user}:{password}"
f"@{self.host}:{self.port}/{self.name}"
)
def _resolve_env_file() -> str:
current = Path(__file__).resolve()
for parent in [current, *current.parents]:
candidate = parent / ".env"
if candidate.is_file():
return str(candidate)
return ".env"
class Settings(BaseSettings):
runtime: RuntimeSettings = RuntimeSettings()
app: AppSettings = AppSettings()
cors: CorsSettings = CorsSettings()
redis: RedisSettings = RedisSettings()
qdrant: QdrantSettings = QdrantSettings()
supabase: SupabaseSettings = SupabaseSettings()
database: DatabaseSettings = DatabaseSettings()
@computed_field
@property
def database_url(self) -> str:
return self.database.url
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_file=_resolve_env_file(),
env_prefix="SOCIAL_",
env_nested_delimiter="__",
case_sensitive=False,
extra="ignore",
)
config = Settings()
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
from core.db.session import AsyncSessionLocal, engine, get_db
__all__ = ["AsyncSessionLocal", "engine", "get_db"]
+37
View File
@@ -0,0 +1,37 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
"""Base class for all ORM models."""
pass
class TimestampMixin:
"""Adds created_at and updated_at timestamps."""
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
class SoftDeleteMixin:
"""Adds soft delete timestamp column."""
deleted_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
)
+84
View File
@@ -0,0 +1,84 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar
from sqlalchemy import Select, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from core.db.base import Base
ModelType = TypeVar("ModelType", bound=Base)
class BaseRepository(Generic[ModelType]):
_session: AsyncSession
_model: type[ModelType]
def __init__(self, session: AsyncSession, model: type[ModelType]) -> None:
self._session = session
self._model = model
def _deleted_at_column(self) -> Any | None:
return getattr(self._model, "deleted_at", None)
def _apply_soft_delete_filter(self, stmt: Select) -> Select:
deleted_at = self._deleted_at_column()
if deleted_at is None:
return stmt
return stmt.where(deleted_at.is_(None))
async def get_by_id(self, entity_id: Any) -> ModelType | None:
id_column = getattr(self._model, "id")
stmt = select(self._model).where(id_column == entity_id)
stmt = self._apply_soft_delete_filter(stmt)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def get_one(self, *filters: Any) -> ModelType | None:
stmt = select(self._model).where(*filters)
stmt = self._apply_soft_delete_filter(stmt)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def update_by_id(
self, entity_id: Any, update_data: dict[str, Any]
) -> ModelType | None:
if not update_data:
return await self.get_by_id(entity_id)
id_column = getattr(self._model, "id")
stmt = update(self._model).where(id_column == entity_id)
deleted_at = self._deleted_at_column()
if deleted_at is not None:
stmt = stmt.where(deleted_at.is_(None))
stmt = stmt.values(**update_data).returning(self._model)
try:
result = await self._session.execute(stmt)
await self._session.flush()
return result.scalar_one_or_none()
except SQLAlchemyError:
raise
async def soft_delete_by_id(self, entity_id: Any) -> ModelType | None:
deleted_at = self._deleted_at_column()
if deleted_at is None:
raise ValueError("Soft delete is not supported for this model")
id_column = getattr(self._model, "id")
stmt = (
update(self._model)
.where(id_column == entity_id)
.where(deleted_at.is_(None))
.values(deleted_at=datetime.now(timezone.utc))
.returning(self._model)
)
try:
result = await self._session.execute(stmt)
await self._session.flush()
return result.scalar_one_or_none()
except SQLAlchemyError:
raise
+22
View File
@@ -0,0 +1,22 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from core.auth.models import CurrentUser
class BaseService:
_current_user: CurrentUser | None
def __init__(self, current_user: CurrentUser | None) -> None:
self._current_user = current_user
def require_current_user(self) -> CurrentUser:
if self._current_user is None:
raise HTTPException(status_code=401, detail="Unauthorized")
return self._current_user
def require_user_id(self) -> UUID:
return self.require_current_user().id
+34
View File
@@ -0,0 +1,34 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.config.settings import config
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
engine: AsyncEngine = create_async_engine(
config.database_url,
echo=config.runtime.sql_log_queries,
pool_pre_ping=True,
)
AsyncSessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""Dependency that provides a database session.
The session is automatically closed when the request completes.
Note: The caller (service layer) is responsible for commit/rollback.
"""
async with AsyncSessionLocal() as session:
yield session
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
from core.http.response import ProblemDetails, build_problem_details
__all__ = ["ProblemDetails", "build_problem_details"]
+7
View File
@@ -0,0 +1,7 @@
from __future__ import annotations
from pydantic import BaseModel
class HealthResponse(BaseModel):
status: str
+31
View File
@@ -0,0 +1,31 @@
from __future__ import annotations
from http import HTTPStatus
from pydantic import BaseModel
class ProblemDetails(BaseModel):
type: str = "about:blank"
title: str
status: int
detail: str
instance: str | None = None
def build_problem_details(
*,
status_code: int,
detail: str,
type_value: str = "about:blank",
title: str | None = None,
instance: str | None = None,
) -> ProblemDetails:
resolved_title = title or HTTPStatus(status_code).phrase
return ProblemDetails(
type=type_value,
title=resolved_title,
status=status_code,
detail=detail,
instance=instance,
)
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
from core.logging import celery
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context, get_context
from core.logging.logger import get_logger
__all__ = [
"bind_context",
"celery",
"clear_context",
"configure_logging",
"get_context",
"get_logger",
]
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import cast
from celery import Celery, signals
from core.config.settings import Settings
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context
@dataclass(frozen=True)
class CelerySignalHandlers:
on_setup_logging: Callable[..., None]
on_after_setup_task_logger: Callable[..., None]
on_task_prerun: Callable[..., None]
on_task_postrun: Callable[..., None]
def build_celery_signal_handlers(
settings: Settings | None = None,
) -> CelerySignalHandlers:
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_task_prerun(*_args: object, **kwargs: object) -> None:
task_id = cast(str | None, kwargs.get("task_id"))
task = kwargs.get("task")
task_name = getattr(task, "name", None)
bind_context(task_id=task_id, task_name=task_name)
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
clear_context()
return CelerySignalHandlers(
on_setup_logging=on_setup_logging,
on_after_setup_task_logger=on_after_setup_task_logger,
on_task_prerun=on_task_prerun,
on_task_postrun=on_task_postrun,
)
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
app.conf.worker_hijack_root_logger = False
handlers = build_celery_signal_handlers(settings)
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
signals.after_setup_task_logger.connect(
handlers.on_after_setup_task_logger, weak=False
)
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
+103
View File
@@ -0,0 +1,103 @@
from __future__ import annotations
import logging
from logging.config import dictConfig
from pathlib import Path
import structlog
from core.config.settings import RuntimeSettings, Settings
from core.logging.formatters import (
build_plain_formatter,
build_processor_formatter,
ensure_message_key,
)
from core.logging.filters import build_sensitive_data_processor
from core.logging.handlers import build_file_handler_config
def _ensure_log_dirs(runtime: RuntimeSettings) -> None:
Path(runtime.log_dir).mkdir(parents=True, exist_ok=True)
Path(runtime.log_error_dir).mkdir(parents=True, exist_ok=True)
def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
log_dir = Path(runtime.log_dir)
error_dir = Path(runtime.log_error_dir)
formatter_name = "json" if runtime.log_json else "plain"
file_handler = build_file_handler_config(
runtime,
file_path=log_dir / runtime.log_file_name,
level=runtime.log_level,
formatter=formatter_name,
)
error_handler = build_file_handler_config(
runtime,
file_path=error_dir / runtime.log_error_file_name,
level="ERROR",
formatter=formatter_name,
filters=["error_only"],
)
return {
"version": 1,
"disable_existing_loggers": False,
"filters": {
"error_only": {
"()": "core.logging.filters.ErrorLevelFilter",
}
},
"formatters": {
"json": {
"()": build_processor_formatter,
"sensitive_fields": runtime.log_sensitive_fields,
},
"plain": {
"()": build_plain_formatter,
"sensitive_fields": runtime.log_sensitive_fields,
},
},
"handlers": {
"file": file_handler,
"error": error_handler,
},
"root": {
"handlers": ["file", "error"],
"level": runtime.log_level,
},
}
def configure_logging(settings: Settings | None = None) -> None:
active_settings = settings or Settings()
runtime = active_settings.runtime
try:
_ensure_log_dirs(runtime)
dictConfig(build_logging_config(runtime))
except (OSError, ValueError) as exc:
logging.basicConfig(level=runtime.log_level)
logging.getLogger(__name__).error("Logging setup failed", exc_info=exc)
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
build_sensitive_data_processor(runtime.log_sensitive_fields),
ensure_message_key,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
cache_logger_on_first_use=True,
)
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
from structlog import contextvars
def bind_context(**values: object) -> None:
contextvars.bind_contextvars(**values)
def clear_context() -> None:
contextvars.clear_contextvars()
def get_context() -> dict[str, object]:
return contextvars.get_contextvars()
+56
View File
@@ -0,0 +1,56 @@
from __future__ import annotations
import logging
import re
from collections.abc import Callable
from typing import cast
from structlog.types import EventDict
_NORMALIZE_PATTERN = re.compile(r"[^a-z0-9]")
def _normalize_key(value: str) -> str:
return _NORMALIZE_PATTERN.sub("", value.lower())
def _is_sensitive_key(key: object, sensitive_fields: set[str]) -> bool:
normalized_key = _normalize_key(str(key))
return normalized_key in sensitive_fields or any(
fragment in normalized_key for fragment in sensitive_fields
)
def _redact_value(value: object, sensitive_fields: set[str]) -> object:
if isinstance(value, dict):
typed_value = cast(dict[str, object], value)
return {
key: (
"[REDACTED]"
if _is_sensitive_key(key, sensitive_fields)
else _redact_value(inner, sensitive_fields)
)
for key, inner in typed_value.items()
}
if isinstance(value, list):
return [_redact_value(item, sensitive_fields) for item in value]
return value
def build_sensitive_data_processor(
sensitive_fields: list[str],
) -> Callable[[object, str, EventDict], EventDict]:
normalized = {_normalize_key(field) for field in sensitive_fields}
def processor(
_logger: object, _method_name: str, event_dict: EventDict
) -> EventDict:
return cast(EventDict, _redact_value(event_dict, normalized))
return processor
class ErrorLevelFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return record.levelno >= logging.ERROR
+81
View File
@@ -0,0 +1,81 @@
from __future__ import annotations
from structlog.dev import ConsoleRenderer
from structlog.processors import JSONRenderer
from structlog.stdlib import ProcessorFormatter
from structlog.types import EventDict
import structlog
from core.logging.filters import build_sensitive_data_processor
def ensure_message_key(
_logger: object, _method_name: str, event_dict: EventDict
) -> EventDict:
if "message" in event_dict:
return event_dict
if "event" not in event_dict:
return event_dict
without_event = {key: value for key, value in event_dict.items() if key != "event"}
return {**without_event, "message": event_dict["event"]}
def build_processor_formatter(
sensitive_fields: list[str] | None = None,
) -> ProcessorFormatter:
redact = build_sensitive_data_processor(sensitive_fields or [])
return ProcessorFormatter(
foreign_pre_chain=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
structlog.stdlib.ExtraAdder(),
ensure_message_key,
],
processors=[
redact,
ensure_message_key,
ProcessorFormatter.remove_processors_meta,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
JSONRenderer(sort_keys=True),
],
)
def build_plain_formatter(
sensitive_fields: list[str] | None = None,
) -> ProcessorFormatter:
redact = build_sensitive_data_processor(sensitive_fields or [])
return ProcessorFormatter(
foreign_pre_chain=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
structlog.stdlib.ExtraAdder(),
ensure_message_key,
],
processors=[
redact,
ensure_message_key,
ProcessorFormatter.remove_processors_meta,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
ConsoleRenderer(colors=False),
],
)
+46
View File
@@ -0,0 +1,46 @@
from __future__ import annotations
from pathlib import Path
from core.config.settings import RuntimeSettings
def build_file_handler_config(
runtime: RuntimeSettings,
file_path: Path,
level: str,
formatter: str,
filters: list[str] | None = None,
) -> dict[str, object]:
filter_list = list(filters or [])
base_config: dict[str, object] = {
"level": level,
"formatter": formatter,
"filename": str(file_path),
"encoding": "utf-8",
}
if filter_list:
base_config = {**base_config, "filters": filter_list}
if runtime.log_rotation == "time":
return {
**base_config,
"class": "logging.handlers.TimedRotatingFileHandler",
"when": runtime.log_rotation_when,
"interval": runtime.log_rotation_interval,
"backupCount": runtime.log_rotation_backup_count,
}
if runtime.log_rotation == "size":
return {
**base_config,
"class": "logging.handlers.RotatingFileHandler",
"maxBytes": runtime.log_rotation_max_bytes,
"backupCount": runtime.log_rotation_backup_count,
}
return {
**base_config,
"class": "logging.FileHandler",
}
+7
View File
@@ -0,0 +1,7 @@
from __future__ import annotations
import structlog
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
return structlog.get_logger(name)
+84
View File
@@ -0,0 +1,84 @@
from __future__ import annotations
import re
from collections.abc import MutableMapping
from typing import cast
from uuid import uuid4
from fastapi import FastAPI, Request
from starlette.requests import Request as StarletteRequest
from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
from core.logging.context import bind_context, clear_context
from core.logging.logger import get_logger
class RequestContextMiddleware:
app: ASGIApp
_header_name: str
_request_id_pattern: re.Pattern[str]
def __init__(self, app: ASGIApp, header_name: str = "X-Request-ID") -> None:
self.app = app
self._header_name = header_name
self._request_id_pattern = re.compile(r"^[A-Za-z0-9_-]{8,64}$")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope.get("type") != "http":
await self.app(scope, receive, send)
return
request = StarletteRequest(scope, receive=receive)
request_id = self._normalize_request_id(request.headers.get(self._header_name))
client_ip = request.client.host if request.client else None
user_id = getattr(request.state, "user_id", None)
request.state.request_id = request_id
bind_context(
request_id=request_id,
method=request.method,
path=request.url.path,
client_ip=client_ip,
user_id=user_id,
)
async def send_wrapper(message: MutableMapping[str, object]) -> None:
if message.get("type") == "http.response.start":
raw_headers = message.get("headers")
headers = list(cast(list[tuple[bytes, bytes]], raw_headers or []))
header_key = self._header_name.lower().encode()
if not any(item[0].lower() == header_key for item in headers):
headers.append((header_key, request_id.encode()))
message = {**message, "headers": headers}
await send(message)
try:
await self.app(scope, receive, send_wrapper)
finally:
clear_context()
def _normalize_request_id(self, request_id: str | None) -> str:
if request_id and self._request_id_pattern.match(request_id):
return request_id
return str(uuid4())
def register_exception_handlers(app: FastAPI) -> None:
logger = get_logger("core.logging.exception")
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> Response:
request_id = getattr(request.state, "request_id", None)
logger.exception(
"Unhandled exception",
error_type=exc.__class__.__name__,
request_id=request_id,
)
headers = {"X-Request-ID": request_id} if request_id else None
return JSONResponse(
status_code=500,
content={"detail": "Internal Server Error"},
headers=headers,
)