refactor: align backend layout and supabase infra
Consolidate backend modules/tests under the backend package while syncing Supabase compose/env config and related plans.
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CurrentUser:
|
||||
id: UUID
|
||||
@@ -0,0 +1,3 @@
|
||||
from .settings import Settings, config
|
||||
|
||||
__all__ = ["Settings", "config"]
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
from urllib.parse import quote
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class RuntimeSettings(BaseModel):
|
||||
environment: Literal["dev", "test", "prod"] = "dev"
|
||||
debug: bool = True
|
||||
log_level: str = "INFO"
|
||||
log_json: bool = True
|
||||
log_rotation: Literal["time", "size", "none"] = "time"
|
||||
log_rotation_when: str = "midnight"
|
||||
log_rotation_interval: int = 1
|
||||
log_rotation_backup_count: int = 14
|
||||
log_rotation_max_bytes: int = 10_000_000
|
||||
log_dir: str = "logs"
|
||||
log_error_dir: str = "logs/errors"
|
||||
log_file_name: str = "app.log"
|
||||
log_error_file_name: str = "error.log"
|
||||
log_sensitive_fields: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"api_key",
|
||||
"authorization",
|
||||
"cookie",
|
||||
"client_ip",
|
||||
"user_id",
|
||||
]
|
||||
)
|
||||
sql_log_queries: bool = False
|
||||
|
||||
|
||||
class AppSettings(BaseModel):
|
||||
host: str = "0.0.0.0"
|
||||
port: int = Field(default=8000, ge=1, le=65535)
|
||||
reload: bool = True
|
||||
|
||||
|
||||
class CorsSettings(BaseModel):
|
||||
allow_origins: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"http://localhost",
|
||||
"http://localhost:3000",
|
||||
]
|
||||
)
|
||||
allow_credentials: bool = True
|
||||
allow_methods: list[str] = Field(default_factory=lambda: ["*"])
|
||||
allow_headers: list[str] = Field(default_factory=lambda: ["*"])
|
||||
|
||||
|
||||
class RedisSettings(BaseModel):
|
||||
host: str = "redis"
|
||||
port: int = 6379
|
||||
password: str | None = None
|
||||
db: int = 0
|
||||
socket_connect_timeout: float = 1.0
|
||||
socket_timeout: float = 1.0
|
||||
max_connections: int = 10
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
if self.password:
|
||||
password = quote(self.password, safe="")
|
||||
return f"redis://:{password}@{self.host}:{self.port}/{self.db}"
|
||||
return f"redis://{self.host}:{self.port}/{self.db}"
|
||||
|
||||
|
||||
class QdrantSettings(BaseModel):
|
||||
host: str = "qdrant"
|
||||
port: int = 6333
|
||||
grpc_port: int = 6334
|
||||
api_key: str | None = None
|
||||
https: bool = False
|
||||
prefer_grpc: bool = True
|
||||
timeout: int = 5
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
scheme = "https" if self.https else "http"
|
||||
return f"{scheme}://{self.host}:{self.port}"
|
||||
|
||||
|
||||
class SupabaseSettings(BaseModel):
|
||||
public_scheme: str = "http"
|
||||
public_host: str = "localhost"
|
||||
kong_http_port: int = 8000
|
||||
anon_key: str = "CHANGE_ME"
|
||||
service_role_key: str = "CHANGE_ME"
|
||||
jwt_secret: str | None = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def public_url(self) -> str:
|
||||
return f"{self.public_scheme}://{self.public_host}:{self.kong_http_port}"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def api_external_url(self) -> str:
|
||||
return self.public_url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return self.public_url
|
||||
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
name: str = "postgres"
|
||||
user: str = "postgres"
|
||||
password: str = "CHANGE_ME"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
password = quote(self.password, safe="")
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.user}:{password}"
|
||||
f"@{self.host}:{self.port}/{self.name}"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_env_file() -> str:
|
||||
current = Path(__file__).resolve()
|
||||
for parent in [current, *current.parents]:
|
||||
candidate = parent / ".env"
|
||||
if candidate.is_file():
|
||||
return str(candidate)
|
||||
return ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
runtime: RuntimeSettings = RuntimeSettings()
|
||||
app: AppSettings = AppSettings()
|
||||
cors: CorsSettings = CorsSettings()
|
||||
redis: RedisSettings = RedisSettings()
|
||||
qdrant: QdrantSettings = QdrantSettings()
|
||||
supabase: SupabaseSettings = SupabaseSettings()
|
||||
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return self.database.url
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_file=_resolve_env_file(),
|
||||
env_prefix="SOCIAL_",
|
||||
env_nested_delimiter="__",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
config = Settings()
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.session import AsyncSessionLocal, engine, get_db
|
||||
|
||||
__all__ = ["AsyncSessionLocal", "engine", "get_db"]
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all ORM models."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Adds created_at and updated_at timestamps."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
"""Adds soft delete timestamp column."""
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import Select, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db.base import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
_session: AsyncSession
|
||||
_model: type[ModelType]
|
||||
|
||||
def __init__(self, session: AsyncSession, model: type[ModelType]) -> None:
|
||||
self._session = session
|
||||
self._model = model
|
||||
|
||||
def _deleted_at_column(self) -> Any | None:
|
||||
return getattr(self._model, "deleted_at", None)
|
||||
|
||||
def _apply_soft_delete_filter(self, stmt: Select) -> Select:
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is None:
|
||||
return stmt
|
||||
return stmt.where(deleted_at.is_(None))
|
||||
|
||||
async def get_by_id(self, entity_id: Any) -> ModelType | None:
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = select(self._model).where(id_column == entity_id)
|
||||
stmt = self._apply_soft_delete_filter(stmt)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_one(self, *filters: Any) -> ModelType | None:
|
||||
stmt = select(self._model).where(*filters)
|
||||
stmt = self._apply_soft_delete_filter(stmt)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_by_id(
|
||||
self, entity_id: Any, update_data: dict[str, Any]
|
||||
) -> ModelType | None:
|
||||
if not update_data:
|
||||
return await self.get_by_id(entity_id)
|
||||
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = update(self._model).where(id_column == entity_id)
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is not None:
|
||||
stmt = stmt.where(deleted_at.is_(None))
|
||||
stmt = stmt.values(**update_data).returning(self._model)
|
||||
|
||||
try:
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
async def soft_delete_by_id(self, entity_id: Any) -> ModelType | None:
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is None:
|
||||
raise ValueError("Soft delete is not supported for this model")
|
||||
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = (
|
||||
update(self._model)
|
||||
.where(id_column == entity_id)
|
||||
.where(deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.now(timezone.utc))
|
||||
.returning(self._model)
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
|
||||
|
||||
class BaseService:
|
||||
_current_user: CurrentUser | None
|
||||
|
||||
def __init__(self, current_user: CurrentUser | None) -> None:
|
||||
self._current_user = current_user
|
||||
|
||||
def require_current_user(self) -> CurrentUser:
|
||||
if self._current_user is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
return self._current_user
|
||||
|
||||
def require_user_id(self) -> UUID:
|
||||
return self.require_current_user().id
|
||||
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.config.settings import config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
config.database_url,
|
||||
echo=config.runtime.sql_log_queries,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
AsyncSessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Dependency that provides a database session.
|
||||
|
||||
The session is automatically closed when the request completes.
|
||||
Note: The caller (service layer) is responsible for commit/rollback.
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
yield session
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.http.response import ProblemDetails, build_problem_details
|
||||
|
||||
__all__ = ["ProblemDetails", "build_problem_details"]
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProblemDetails(BaseModel):
|
||||
type: str = "about:blank"
|
||||
title: str
|
||||
status: int
|
||||
detail: str
|
||||
instance: str | None = None
|
||||
|
||||
|
||||
def build_problem_details(
|
||||
*,
|
||||
status_code: int,
|
||||
detail: str,
|
||||
type_value: str = "about:blank",
|
||||
title: str | None = None,
|
||||
instance: str | None = None,
|
||||
) -> ProblemDetails:
|
||||
resolved_title = title or HTTPStatus(status_code).phrase
|
||||
return ProblemDetails(
|
||||
type=type_value,
|
||||
title=resolved_title,
|
||||
status=status_code,
|
||||
detail=detail,
|
||||
instance=instance,
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.logging import celery
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context, get_context
|
||||
from core.logging.logger import get_logger
|
||||
|
||||
__all__ = [
|
||||
"bind_context",
|
||||
"celery",
|
||||
"clear_context",
|
||||
"configure_logging",
|
||||
"get_context",
|
||||
"get_logger",
|
||||
]
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery, signals
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CelerySignalHandlers:
|
||||
on_setup_logging: Callable[..., None]
|
||||
on_after_setup_task_logger: Callable[..., None]
|
||||
on_task_prerun: Callable[..., None]
|
||||
on_task_postrun: Callable[..., None]
|
||||
|
||||
|
||||
def build_celery_signal_handlers(
|
||||
settings: Settings | None = None,
|
||||
) -> CelerySignalHandlers:
|
||||
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
|
||||
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
|
||||
def on_task_prerun(*_args: object, **kwargs: object) -> None:
|
||||
task_id = cast(str | None, kwargs.get("task_id"))
|
||||
task = kwargs.get("task")
|
||||
task_name = getattr(task, "name", None)
|
||||
bind_context(task_id=task_id, task_name=task_name)
|
||||
|
||||
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
|
||||
clear_context()
|
||||
|
||||
return CelerySignalHandlers(
|
||||
on_setup_logging=on_setup_logging,
|
||||
on_after_setup_task_logger=on_after_setup_task_logger,
|
||||
on_task_prerun=on_task_prerun,
|
||||
on_task_postrun=on_task_postrun,
|
||||
)
|
||||
|
||||
|
||||
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
|
||||
app.conf.worker_hijack_root_logger = False
|
||||
|
||||
handlers = build_celery_signal_handlers(settings)
|
||||
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
|
||||
signals.after_setup_task_logger.connect(
|
||||
handlers.on_after_setup_task_logger, weak=False
|
||||
)
|
||||
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
|
||||
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
|
||||
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from logging.config import dictConfig
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
|
||||
from core.config.settings import RuntimeSettings, Settings
|
||||
from core.logging.formatters import (
|
||||
build_plain_formatter,
|
||||
build_processor_formatter,
|
||||
ensure_message_key,
|
||||
)
|
||||
from core.logging.filters import build_sensitive_data_processor
|
||||
from core.logging.handlers import build_file_handler_config
|
||||
|
||||
|
||||
def _ensure_log_dirs(runtime: RuntimeSettings) -> None:
|
||||
Path(runtime.log_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(runtime.log_error_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
|
||||
log_dir = Path(runtime.log_dir)
|
||||
error_dir = Path(runtime.log_error_dir)
|
||||
formatter_name = "json" if runtime.log_json else "plain"
|
||||
|
||||
file_handler = build_file_handler_config(
|
||||
runtime,
|
||||
file_path=log_dir / runtime.log_file_name,
|
||||
level=runtime.log_level,
|
||||
formatter=formatter_name,
|
||||
)
|
||||
error_handler = build_file_handler_config(
|
||||
runtime,
|
||||
file_path=error_dir / runtime.log_error_file_name,
|
||||
level="ERROR",
|
||||
formatter=formatter_name,
|
||||
filters=["error_only"],
|
||||
)
|
||||
|
||||
return {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"filters": {
|
||||
"error_only": {
|
||||
"()": "core.logging.filters.ErrorLevelFilter",
|
||||
}
|
||||
},
|
||||
"formatters": {
|
||||
"json": {
|
||||
"()": build_processor_formatter,
|
||||
"sensitive_fields": runtime.log_sensitive_fields,
|
||||
},
|
||||
"plain": {
|
||||
"()": build_plain_formatter,
|
||||
"sensitive_fields": runtime.log_sensitive_fields,
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"file": file_handler,
|
||||
"error": error_handler,
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["file", "error"],
|
||||
"level": runtime.log_level,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging(settings: Settings | None = None) -> None:
|
||||
active_settings = settings or Settings()
|
||||
runtime = active_settings.runtime
|
||||
|
||||
try:
|
||||
_ensure_log_dirs(runtime)
|
||||
dictConfig(build_logging_config(runtime))
|
||||
except (OSError, ValueError) as exc:
|
||||
logging.basicConfig(level=runtime.log_level)
|
||||
logging.getLogger(__name__).error("Logging setup failed", exc_info=exc)
|
||||
|
||||
structlog.configure(
|
||||
processors=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
build_sensitive_data_processor(runtime.log_sensitive_fields),
|
||||
ensure_message_key,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
||||
],
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from structlog import contextvars
|
||||
|
||||
|
||||
def bind_context(**values: object) -> None:
|
||||
contextvars.bind_contextvars(**values)
|
||||
|
||||
|
||||
def clear_context() -> None:
|
||||
contextvars.clear_contextvars()
|
||||
|
||||
|
||||
def get_context() -> dict[str, object]:
|
||||
return contextvars.get_contextvars()
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from structlog.types import EventDict
|
||||
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"[^a-z0-9]")
|
||||
|
||||
|
||||
def _normalize_key(value: str) -> str:
|
||||
return _NORMALIZE_PATTERN.sub("", value.lower())
|
||||
|
||||
|
||||
def _is_sensitive_key(key: object, sensitive_fields: set[str]) -> bool:
|
||||
normalized_key = _normalize_key(str(key))
|
||||
return normalized_key in sensitive_fields or any(
|
||||
fragment in normalized_key for fragment in sensitive_fields
|
||||
)
|
||||
|
||||
|
||||
def _redact_value(value: object, sensitive_fields: set[str]) -> object:
|
||||
if isinstance(value, dict):
|
||||
typed_value = cast(dict[str, object], value)
|
||||
return {
|
||||
key: (
|
||||
"[REDACTED]"
|
||||
if _is_sensitive_key(key, sensitive_fields)
|
||||
else _redact_value(inner, sensitive_fields)
|
||||
)
|
||||
for key, inner in typed_value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_redact_value(item, sensitive_fields) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def build_sensitive_data_processor(
|
||||
sensitive_fields: list[str],
|
||||
) -> Callable[[object, str, EventDict], EventDict]:
|
||||
normalized = {_normalize_key(field) for field in sensitive_fields}
|
||||
|
||||
def processor(
|
||||
_logger: object, _method_name: str, event_dict: EventDict
|
||||
) -> EventDict:
|
||||
return cast(EventDict, _redact_value(event_dict, normalized))
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
class ErrorLevelFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return record.levelno >= logging.ERROR
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from structlog.dev import ConsoleRenderer
|
||||
from structlog.processors import JSONRenderer
|
||||
from structlog.stdlib import ProcessorFormatter
|
||||
from structlog.types import EventDict
|
||||
import structlog
|
||||
|
||||
from core.logging.filters import build_sensitive_data_processor
|
||||
|
||||
|
||||
def ensure_message_key(
|
||||
_logger: object, _method_name: str, event_dict: EventDict
|
||||
) -> EventDict:
|
||||
if "message" in event_dict:
|
||||
return event_dict
|
||||
if "event" not in event_dict:
|
||||
return event_dict
|
||||
|
||||
without_event = {key: value for key, value in event_dict.items() if key != "event"}
|
||||
return {**without_event, "message": event_dict["event"]}
|
||||
|
||||
|
||||
def build_processor_formatter(
|
||||
sensitive_fields: list[str] | None = None,
|
||||
) -> ProcessorFormatter:
|
||||
redact = build_sensitive_data_processor(sensitive_fields or [])
|
||||
return ProcessorFormatter(
|
||||
foreign_pre_chain=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
structlog.stdlib.ExtraAdder(),
|
||||
ensure_message_key,
|
||||
],
|
||||
processors=[
|
||||
redact,
|
||||
ensure_message_key,
|
||||
ProcessorFormatter.remove_processors_meta,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
JSONRenderer(sort_keys=True),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def build_plain_formatter(
|
||||
sensitive_fields: list[str] | None = None,
|
||||
) -> ProcessorFormatter:
|
||||
redact = build_sensitive_data_processor(sensitive_fields or [])
|
||||
return ProcessorFormatter(
|
||||
foreign_pre_chain=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
structlog.stdlib.ExtraAdder(),
|
||||
ensure_message_key,
|
||||
],
|
||||
processors=[
|
||||
redact,
|
||||
ensure_message_key,
|
||||
ProcessorFormatter.remove_processors_meta,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
ConsoleRenderer(colors=False),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from core.config.settings import RuntimeSettings
|
||||
|
||||
|
||||
def build_file_handler_config(
|
||||
runtime: RuntimeSettings,
|
||||
file_path: Path,
|
||||
level: str,
|
||||
formatter: str,
|
||||
filters: list[str] | None = None,
|
||||
) -> dict[str, object]:
|
||||
filter_list = list(filters or [])
|
||||
base_config: dict[str, object] = {
|
||||
"level": level,
|
||||
"formatter": formatter,
|
||||
"filename": str(file_path),
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
|
||||
if filter_list:
|
||||
base_config = {**base_config, "filters": filter_list}
|
||||
|
||||
if runtime.log_rotation == "time":
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.handlers.TimedRotatingFileHandler",
|
||||
"when": runtime.log_rotation_when,
|
||||
"interval": runtime.log_rotation_interval,
|
||||
"backupCount": runtime.log_rotation_backup_count,
|
||||
}
|
||||
|
||||
if runtime.log_rotation == "size":
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"maxBytes": runtime.log_rotation_max_bytes,
|
||||
"backupCount": runtime.log_rotation_backup_count,
|
||||
}
|
||||
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.FileHandler",
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
|
||||
return structlog.get_logger(name)
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import MutableMapping
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from core.logging.context import bind_context, clear_context
|
||||
from core.logging.logger import get_logger
|
||||
|
||||
|
||||
class RequestContextMiddleware:
|
||||
app: ASGIApp
|
||||
_header_name: str
|
||||
_request_id_pattern: re.Pattern[str]
|
||||
|
||||
def __init__(self, app: ASGIApp, header_name: str = "X-Request-ID") -> None:
|
||||
self.app = app
|
||||
self._header_name = header_name
|
||||
self._request_id_pattern = re.compile(r"^[A-Za-z0-9_-]{8,64}$")
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope.get("type") != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = StarletteRequest(scope, receive=receive)
|
||||
request_id = self._normalize_request_id(request.headers.get(self._header_name))
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
request.state.request_id = request_id
|
||||
|
||||
bind_context(
|
||||
request_id=request_id,
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
client_ip=client_ip,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
async def send_wrapper(message: MutableMapping[str, object]) -> None:
|
||||
if message.get("type") == "http.response.start":
|
||||
raw_headers = message.get("headers")
|
||||
headers = list(cast(list[tuple[bytes, bytes]], raw_headers or []))
|
||||
header_key = self._header_name.lower().encode()
|
||||
if not any(item[0].lower() == header_key for item in headers):
|
||||
headers.append((header_key, request_id.encode()))
|
||||
message = {**message, "headers": headers}
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
finally:
|
||||
clear_context()
|
||||
|
||||
def _normalize_request_id(self, request_id: str | None) -> str:
|
||||
if request_id and self._request_id_pattern.match(request_id):
|
||||
return request_id
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
logger = get_logger("core.logging.exception")
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> Response:
|
||||
request_id = getattr(request.state, "request_id", None)
|
||||
logger.exception(
|
||||
"Unhandled exception",
|
||||
error_type=exc.__class__.__name__,
|
||||
request_id=request_id,
|
||||
)
|
||||
headers = {"X-Request-ID": request_id} if request_id else None
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal Server Error"},
|
||||
headers=headers,
|
||||
)
|
||||
Reference in New Issue
Block a user