refactor: align backend layout and supabase infra
Consolidate backend modules/tests under the backend package while syncing Supabase compose/env config and related plans.
This commit is contained in:
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from core.config.settings import config
|
||||
from core.http.models import HealthResponse
|
||||
from core.http.response import build_problem_details
|
||||
from core.logging import configure_logging, get_logger
|
||||
from v1.router import router as mobile_router
|
||||
|
||||
|
||||
configure_logging(config)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.cors.allow_origins,
|
||||
allow_credentials=config.cors.allow_credentials,
|
||||
allow_methods=config.cors.allow_methods,
|
||||
allow_headers=config.cors.allow_headers,
|
||||
)
|
||||
app.include_router(mobile_router)
|
||||
logger = get_logger("api.app")
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health() -> HealthResponse:
|
||||
return HealthResponse(status="ok")
|
||||
|
||||
|
||||
def _build_http_error_response(
|
||||
request: Request,
|
||||
exc: Exception,
|
||||
status_code: int,
|
||||
detail: object,
|
||||
) -> JSONResponse:
|
||||
instance = request.url.path
|
||||
detail_text = detail if isinstance(detail, str) else "Request failed"
|
||||
logger.warning(
|
||||
"HTTP error",
|
||||
status_code=status_code,
|
||||
detail=detail_text,
|
||||
detail_extra=detail,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
)
|
||||
problem = build_problem_details(
|
||||
status_code=status_code,
|
||||
detail=detail_text,
|
||||
instance=instance,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(
|
||||
request: Request,
|
||||
exc: HTTPException,
|
||||
) -> JSONResponse:
|
||||
return _build_http_error_response(
|
||||
request=request,
|
||||
exc=exc,
|
||||
status_code=exc.status_code,
|
||||
detail=exc.detail,
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def starlette_http_exception_handler(
|
||||
request: Request,
|
||||
exc: StarletteHTTPException,
|
||||
) -> JSONResponse:
|
||||
return _build_http_error_response(
|
||||
request=request,
|
||||
exc=exc,
|
||||
status_code=exc.status_code,
|
||||
detail=exc.detail,
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: RequestValidationError,
|
||||
) -> JSONResponse:
|
||||
instance = request.url.path
|
||||
logger.warning(
|
||||
"Request validation error",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
errors=exc.errors(),
|
||||
)
|
||||
problem = build_problem_details(
|
||||
status_code=422,
|
||||
detail="Invalid request",
|
||||
instance=instance,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(
|
||||
request: Request,
|
||||
exc: Exception,
|
||||
) -> JSONResponse:
|
||||
instance = request.url.path
|
||||
logger.exception(
|
||||
"Unhandled error",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
)
|
||||
problem = build_problem_details(
|
||||
status_code=500,
|
||||
detail="Internal Server Error",
|
||||
instance=instance,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CurrentUser:
|
||||
id: UUID
|
||||
@@ -0,0 +1,3 @@
|
||||
from .settings import Settings, config
|
||||
|
||||
__all__ = ["Settings", "config"]
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
from urllib.parse import quote
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class RuntimeSettings(BaseModel):
|
||||
environment: Literal["dev", "test", "prod"] = "dev"
|
||||
debug: bool = True
|
||||
log_level: str = "INFO"
|
||||
log_json: bool = True
|
||||
log_rotation: Literal["time", "size", "none"] = "time"
|
||||
log_rotation_when: str = "midnight"
|
||||
log_rotation_interval: int = 1
|
||||
log_rotation_backup_count: int = 14
|
||||
log_rotation_max_bytes: int = 10_000_000
|
||||
log_dir: str = "logs"
|
||||
log_error_dir: str = "logs/errors"
|
||||
log_file_name: str = "app.log"
|
||||
log_error_file_name: str = "error.log"
|
||||
log_sensitive_fields: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"api_key",
|
||||
"authorization",
|
||||
"cookie",
|
||||
"client_ip",
|
||||
"user_id",
|
||||
]
|
||||
)
|
||||
sql_log_queries: bool = False
|
||||
|
||||
|
||||
class AppSettings(BaseModel):
|
||||
host: str = "0.0.0.0"
|
||||
port: int = Field(default=8000, ge=1, le=65535)
|
||||
reload: bool = True
|
||||
|
||||
|
||||
class CorsSettings(BaseModel):
|
||||
allow_origins: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"http://localhost",
|
||||
"http://localhost:3000",
|
||||
]
|
||||
)
|
||||
allow_credentials: bool = True
|
||||
allow_methods: list[str] = Field(default_factory=lambda: ["*"])
|
||||
allow_headers: list[str] = Field(default_factory=lambda: ["*"])
|
||||
|
||||
|
||||
class RedisSettings(BaseModel):
|
||||
host: str = "redis"
|
||||
port: int = 6379
|
||||
password: str | None = None
|
||||
db: int = 0
|
||||
socket_connect_timeout: float = 1.0
|
||||
socket_timeout: float = 1.0
|
||||
max_connections: int = 10
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
if self.password:
|
||||
password = quote(self.password, safe="")
|
||||
return f"redis://:{password}@{self.host}:{self.port}/{self.db}"
|
||||
return f"redis://{self.host}:{self.port}/{self.db}"
|
||||
|
||||
|
||||
class QdrantSettings(BaseModel):
|
||||
host: str = "qdrant"
|
||||
port: int = 6333
|
||||
grpc_port: int = 6334
|
||||
api_key: str | None = None
|
||||
https: bool = False
|
||||
prefer_grpc: bool = True
|
||||
timeout: int = 5
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
scheme = "https" if self.https else "http"
|
||||
return f"{scheme}://{self.host}:{self.port}"
|
||||
|
||||
|
||||
class SupabaseSettings(BaseModel):
|
||||
public_scheme: str = "http"
|
||||
public_host: str = "localhost"
|
||||
kong_http_port: int = 8000
|
||||
anon_key: str = "CHANGE_ME"
|
||||
service_role_key: str = "CHANGE_ME"
|
||||
jwt_secret: str | None = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def public_url(self) -> str:
|
||||
return f"{self.public_scheme}://{self.public_host}:{self.kong_http_port}"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def api_external_url(self) -> str:
|
||||
return self.public_url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return self.public_url
|
||||
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
name: str = "postgres"
|
||||
user: str = "postgres"
|
||||
password: str = "CHANGE_ME"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
password = quote(self.password, safe="")
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.user}:{password}"
|
||||
f"@{self.host}:{self.port}/{self.name}"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_env_file() -> str:
|
||||
current = Path(__file__).resolve()
|
||||
for parent in [current, *current.parents]:
|
||||
candidate = parent / ".env"
|
||||
if candidate.is_file():
|
||||
return str(candidate)
|
||||
return ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
runtime: RuntimeSettings = RuntimeSettings()
|
||||
app: AppSettings = AppSettings()
|
||||
cors: CorsSettings = CorsSettings()
|
||||
redis: RedisSettings = RedisSettings()
|
||||
qdrant: QdrantSettings = QdrantSettings()
|
||||
supabase: SupabaseSettings = SupabaseSettings()
|
||||
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return self.database.url
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_file=_resolve_env_file(),
|
||||
env_prefix="SOCIAL_",
|
||||
env_nested_delimiter="__",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
config = Settings()
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.session import AsyncSessionLocal, engine, get_db
|
||||
|
||||
__all__ = ["AsyncSessionLocal", "engine", "get_db"]
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all ORM models."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Adds created_at and updated_at timestamps."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
"""Adds soft delete timestamp column."""
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import Select, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db.base import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
_session: AsyncSession
|
||||
_model: type[ModelType]
|
||||
|
||||
def __init__(self, session: AsyncSession, model: type[ModelType]) -> None:
|
||||
self._session = session
|
||||
self._model = model
|
||||
|
||||
def _deleted_at_column(self) -> Any | None:
|
||||
return getattr(self._model, "deleted_at", None)
|
||||
|
||||
def _apply_soft_delete_filter(self, stmt: Select) -> Select:
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is None:
|
||||
return stmt
|
||||
return stmt.where(deleted_at.is_(None))
|
||||
|
||||
async def get_by_id(self, entity_id: Any) -> ModelType | None:
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = select(self._model).where(id_column == entity_id)
|
||||
stmt = self._apply_soft_delete_filter(stmt)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_one(self, *filters: Any) -> ModelType | None:
|
||||
stmt = select(self._model).where(*filters)
|
||||
stmt = self._apply_soft_delete_filter(stmt)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_by_id(
|
||||
self, entity_id: Any, update_data: dict[str, Any]
|
||||
) -> ModelType | None:
|
||||
if not update_data:
|
||||
return await self.get_by_id(entity_id)
|
||||
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = update(self._model).where(id_column == entity_id)
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is not None:
|
||||
stmt = stmt.where(deleted_at.is_(None))
|
||||
stmt = stmt.values(**update_data).returning(self._model)
|
||||
|
||||
try:
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
async def soft_delete_by_id(self, entity_id: Any) -> ModelType | None:
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is None:
|
||||
raise ValueError("Soft delete is not supported for this model")
|
||||
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = (
|
||||
update(self._model)
|
||||
.where(id_column == entity_id)
|
||||
.where(deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.now(timezone.utc))
|
||||
.returning(self._model)
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
|
||||
|
||||
class BaseService:
|
||||
_current_user: CurrentUser | None
|
||||
|
||||
def __init__(self, current_user: CurrentUser | None) -> None:
|
||||
self._current_user = current_user
|
||||
|
||||
def require_current_user(self) -> CurrentUser:
|
||||
if self._current_user is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
return self._current_user
|
||||
|
||||
def require_user_id(self) -> UUID:
|
||||
return self.require_current_user().id
|
||||
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.config.settings import config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
config.database_url,
|
||||
echo=config.runtime.sql_log_queries,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
AsyncSessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Dependency that provides a database session.
|
||||
|
||||
The session is automatically closed when the request completes.
|
||||
Note: The caller (service layer) is responsible for commit/rollback.
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
yield session
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.http.response import ProblemDetails, build_problem_details
|
||||
|
||||
__all__ = ["ProblemDetails", "build_problem_details"]
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProblemDetails(BaseModel):
|
||||
type: str = "about:blank"
|
||||
title: str
|
||||
status: int
|
||||
detail: str
|
||||
instance: str | None = None
|
||||
|
||||
|
||||
def build_problem_details(
|
||||
*,
|
||||
status_code: int,
|
||||
detail: str,
|
||||
type_value: str = "about:blank",
|
||||
title: str | None = None,
|
||||
instance: str | None = None,
|
||||
) -> ProblemDetails:
|
||||
resolved_title = title or HTTPStatus(status_code).phrase
|
||||
return ProblemDetails(
|
||||
type=type_value,
|
||||
title=resolved_title,
|
||||
status=status_code,
|
||||
detail=detail,
|
||||
instance=instance,
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.logging import celery
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context, get_context
|
||||
from core.logging.logger import get_logger
|
||||
|
||||
__all__ = [
|
||||
"bind_context",
|
||||
"celery",
|
||||
"clear_context",
|
||||
"configure_logging",
|
||||
"get_context",
|
||||
"get_logger",
|
||||
]
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery, signals
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CelerySignalHandlers:
|
||||
on_setup_logging: Callable[..., None]
|
||||
on_after_setup_task_logger: Callable[..., None]
|
||||
on_task_prerun: Callable[..., None]
|
||||
on_task_postrun: Callable[..., None]
|
||||
|
||||
|
||||
def build_celery_signal_handlers(
|
||||
settings: Settings | None = None,
|
||||
) -> CelerySignalHandlers:
|
||||
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
|
||||
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
|
||||
def on_task_prerun(*_args: object, **kwargs: object) -> None:
|
||||
task_id = cast(str | None, kwargs.get("task_id"))
|
||||
task = kwargs.get("task")
|
||||
task_name = getattr(task, "name", None)
|
||||
bind_context(task_id=task_id, task_name=task_name)
|
||||
|
||||
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
|
||||
clear_context()
|
||||
|
||||
return CelerySignalHandlers(
|
||||
on_setup_logging=on_setup_logging,
|
||||
on_after_setup_task_logger=on_after_setup_task_logger,
|
||||
on_task_prerun=on_task_prerun,
|
||||
on_task_postrun=on_task_postrun,
|
||||
)
|
||||
|
||||
|
||||
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
|
||||
app.conf.worker_hijack_root_logger = False
|
||||
|
||||
handlers = build_celery_signal_handlers(settings)
|
||||
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
|
||||
signals.after_setup_task_logger.connect(
|
||||
handlers.on_after_setup_task_logger, weak=False
|
||||
)
|
||||
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
|
||||
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
|
||||
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from logging.config import dictConfig
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
|
||||
from core.config.settings import RuntimeSettings, Settings
|
||||
from core.logging.formatters import (
|
||||
build_plain_formatter,
|
||||
build_processor_formatter,
|
||||
ensure_message_key,
|
||||
)
|
||||
from core.logging.filters import build_sensitive_data_processor
|
||||
from core.logging.handlers import build_file_handler_config
|
||||
|
||||
|
||||
def _ensure_log_dirs(runtime: RuntimeSettings) -> None:
|
||||
Path(runtime.log_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(runtime.log_error_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
|
||||
log_dir = Path(runtime.log_dir)
|
||||
error_dir = Path(runtime.log_error_dir)
|
||||
formatter_name = "json" if runtime.log_json else "plain"
|
||||
|
||||
file_handler = build_file_handler_config(
|
||||
runtime,
|
||||
file_path=log_dir / runtime.log_file_name,
|
||||
level=runtime.log_level,
|
||||
formatter=formatter_name,
|
||||
)
|
||||
error_handler = build_file_handler_config(
|
||||
runtime,
|
||||
file_path=error_dir / runtime.log_error_file_name,
|
||||
level="ERROR",
|
||||
formatter=formatter_name,
|
||||
filters=["error_only"],
|
||||
)
|
||||
|
||||
return {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"filters": {
|
||||
"error_only": {
|
||||
"()": "core.logging.filters.ErrorLevelFilter",
|
||||
}
|
||||
},
|
||||
"formatters": {
|
||||
"json": {
|
||||
"()": build_processor_formatter,
|
||||
"sensitive_fields": runtime.log_sensitive_fields,
|
||||
},
|
||||
"plain": {
|
||||
"()": build_plain_formatter,
|
||||
"sensitive_fields": runtime.log_sensitive_fields,
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"file": file_handler,
|
||||
"error": error_handler,
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["file", "error"],
|
||||
"level": runtime.log_level,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging(settings: Settings | None = None) -> None:
|
||||
active_settings = settings or Settings()
|
||||
runtime = active_settings.runtime
|
||||
|
||||
try:
|
||||
_ensure_log_dirs(runtime)
|
||||
dictConfig(build_logging_config(runtime))
|
||||
except (OSError, ValueError) as exc:
|
||||
logging.basicConfig(level=runtime.log_level)
|
||||
logging.getLogger(__name__).error("Logging setup failed", exc_info=exc)
|
||||
|
||||
structlog.configure(
|
||||
processors=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
build_sensitive_data_processor(runtime.log_sensitive_fields),
|
||||
ensure_message_key,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
||||
],
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from structlog import contextvars
|
||||
|
||||
|
||||
def bind_context(**values: object) -> None:
|
||||
contextvars.bind_contextvars(**values)
|
||||
|
||||
|
||||
def clear_context() -> None:
|
||||
contextvars.clear_contextvars()
|
||||
|
||||
|
||||
def get_context() -> dict[str, object]:
|
||||
return contextvars.get_contextvars()
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from structlog.types import EventDict
|
||||
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"[^a-z0-9]")
|
||||
|
||||
|
||||
def _normalize_key(value: str) -> str:
|
||||
return _NORMALIZE_PATTERN.sub("", value.lower())
|
||||
|
||||
|
||||
def _is_sensitive_key(key: object, sensitive_fields: set[str]) -> bool:
|
||||
normalized_key = _normalize_key(str(key))
|
||||
return normalized_key in sensitive_fields or any(
|
||||
fragment in normalized_key for fragment in sensitive_fields
|
||||
)
|
||||
|
||||
|
||||
def _redact_value(value: object, sensitive_fields: set[str]) -> object:
|
||||
if isinstance(value, dict):
|
||||
typed_value = cast(dict[str, object], value)
|
||||
return {
|
||||
key: (
|
||||
"[REDACTED]"
|
||||
if _is_sensitive_key(key, sensitive_fields)
|
||||
else _redact_value(inner, sensitive_fields)
|
||||
)
|
||||
for key, inner in typed_value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_redact_value(item, sensitive_fields) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def build_sensitive_data_processor(
|
||||
sensitive_fields: list[str],
|
||||
) -> Callable[[object, str, EventDict], EventDict]:
|
||||
normalized = {_normalize_key(field) for field in sensitive_fields}
|
||||
|
||||
def processor(
|
||||
_logger: object, _method_name: str, event_dict: EventDict
|
||||
) -> EventDict:
|
||||
return cast(EventDict, _redact_value(event_dict, normalized))
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
class ErrorLevelFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return record.levelno >= logging.ERROR
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from structlog.dev import ConsoleRenderer
|
||||
from structlog.processors import JSONRenderer
|
||||
from structlog.stdlib import ProcessorFormatter
|
||||
from structlog.types import EventDict
|
||||
import structlog
|
||||
|
||||
from core.logging.filters import build_sensitive_data_processor
|
||||
|
||||
|
||||
def ensure_message_key(
|
||||
_logger: object, _method_name: str, event_dict: EventDict
|
||||
) -> EventDict:
|
||||
if "message" in event_dict:
|
||||
return event_dict
|
||||
if "event" not in event_dict:
|
||||
return event_dict
|
||||
|
||||
without_event = {key: value for key, value in event_dict.items() if key != "event"}
|
||||
return {**without_event, "message": event_dict["event"]}
|
||||
|
||||
|
||||
def build_processor_formatter(
|
||||
sensitive_fields: list[str] | None = None,
|
||||
) -> ProcessorFormatter:
|
||||
redact = build_sensitive_data_processor(sensitive_fields or [])
|
||||
return ProcessorFormatter(
|
||||
foreign_pre_chain=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
structlog.stdlib.ExtraAdder(),
|
||||
ensure_message_key,
|
||||
],
|
||||
processors=[
|
||||
redact,
|
||||
ensure_message_key,
|
||||
ProcessorFormatter.remove_processors_meta,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
JSONRenderer(sort_keys=True),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def build_plain_formatter(
|
||||
sensitive_fields: list[str] | None = None,
|
||||
) -> ProcessorFormatter:
|
||||
redact = build_sensitive_data_processor(sensitive_fields or [])
|
||||
return ProcessorFormatter(
|
||||
foreign_pre_chain=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
structlog.stdlib.ExtraAdder(),
|
||||
ensure_message_key,
|
||||
],
|
||||
processors=[
|
||||
redact,
|
||||
ensure_message_key,
|
||||
ProcessorFormatter.remove_processors_meta,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
ConsoleRenderer(colors=False),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from core.config.settings import RuntimeSettings
|
||||
|
||||
|
||||
def build_file_handler_config(
|
||||
runtime: RuntimeSettings,
|
||||
file_path: Path,
|
||||
level: str,
|
||||
formatter: str,
|
||||
filters: list[str] | None = None,
|
||||
) -> dict[str, object]:
|
||||
filter_list = list(filters or [])
|
||||
base_config: dict[str, object] = {
|
||||
"level": level,
|
||||
"formatter": formatter,
|
||||
"filename": str(file_path),
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
|
||||
if filter_list:
|
||||
base_config = {**base_config, "filters": filter_list}
|
||||
|
||||
if runtime.log_rotation == "time":
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.handlers.TimedRotatingFileHandler",
|
||||
"when": runtime.log_rotation_when,
|
||||
"interval": runtime.log_rotation_interval,
|
||||
"backupCount": runtime.log_rotation_backup_count,
|
||||
}
|
||||
|
||||
if runtime.log_rotation == "size":
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"maxBytes": runtime.log_rotation_max_bytes,
|
||||
"backupCount": runtime.log_rotation_backup_count,
|
||||
}
|
||||
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.FileHandler",
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
|
||||
return structlog.get_logger(name)
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import MutableMapping
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from core.logging.context import bind_context, clear_context
|
||||
from core.logging.logger import get_logger
|
||||
|
||||
|
||||
class RequestContextMiddleware:
|
||||
app: ASGIApp
|
||||
_header_name: str
|
||||
_request_id_pattern: re.Pattern[str]
|
||||
|
||||
def __init__(self, app: ASGIApp, header_name: str = "X-Request-ID") -> None:
|
||||
self.app = app
|
||||
self._header_name = header_name
|
||||
self._request_id_pattern = re.compile(r"^[A-Za-z0-9_-]{8,64}$")
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope.get("type") != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = StarletteRequest(scope, receive=receive)
|
||||
request_id = self._normalize_request_id(request.headers.get(self._header_name))
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
request.state.request_id = request_id
|
||||
|
||||
bind_context(
|
||||
request_id=request_id,
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
client_ip=client_ip,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
async def send_wrapper(message: MutableMapping[str, object]) -> None:
|
||||
if message.get("type") == "http.response.start":
|
||||
raw_headers = message.get("headers")
|
||||
headers = list(cast(list[tuple[bytes, bytes]], raw_headers or []))
|
||||
header_key = self._header_name.lower().encode()
|
||||
if not any(item[0].lower() == header_key for item in headers):
|
||||
headers.append((header_key, request_id.encode()))
|
||||
message = {**message, "headers": headers}
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
finally:
|
||||
clear_context()
|
||||
|
||||
def _normalize_request_id(self, request_id: str | None) -> str:
|
||||
if request_id and self._request_id_pattern.match(request_id):
|
||||
return request_id
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
logger = get_logger("core.logging.exception")
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> Response:
|
||||
request_id = getattr(request.state, "request_id", None)
|
||||
logger.exception(
|
||||
"Unhandled exception",
|
||||
error_type=exc.__class__.__name__,
|
||||
request_id=request_id,
|
||||
)
|
||||
headers = {"X-Request-ID": request_id} if request_id else None
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal Server Error"},
|
||||
headers=headers,
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from models.profile import Profile
|
||||
|
||||
__all__ = ["Profile"]
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
|
||||
|
||||
class Profile(TimestampMixin, SoftDeleteMixin, Base):
|
||||
"""User profile model.
|
||||
|
||||
Note: The `id` column references auth.users(id) in Supabase.
|
||||
This is a business table managed by SQLAlchemy, with the foreign key
|
||||
relationship to Supabase's auth schema handled at the database level.
|
||||
"""
|
||||
|
||||
__tablename__: str = "profiles"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
)
|
||||
username: Mapped[str] = mapped_column(
|
||||
String(30),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
display_name: Mapped[str | None] = mapped_column(
|
||||
String(50),
|
||||
nullable=True,
|
||||
)
|
||||
avatar_url: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
)
|
||||
bio: Mapped[str | None] = mapped_column(
|
||||
String(200),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.qdrant import QdrantService, qdrant_service
|
||||
from services.base.redis import RedisService, redis_service
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseServiceProvider",
|
||||
"QdrantService",
|
||||
"RedisService",
|
||||
"ServiceRegistry",
|
||||
"qdrant_service",
|
||||
"redis_service",
|
||||
"register_service",
|
||||
"register_service_instance",
|
||||
]
|
||||
@@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
from core.config.settings import QdrantSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class QdrantService(BaseServiceProvider):
|
||||
def __init__(self, settings: QdrantSettings | None = None) -> None:
|
||||
super().__init__("qdrant")
|
||||
self._settings = settings or config.qdrant
|
||||
self._client: Optional[QdrantClient] = None
|
||||
|
||||
def _build_client(self) -> QdrantClient:
|
||||
return QdrantClient(
|
||||
url=self._settings.url,
|
||||
api_key=self._settings.api_key,
|
||||
timeout=self._settings.timeout,
|
||||
prefer_grpc=self._settings.prefer_grpc,
|
||||
)
|
||||
|
||||
def _require_client(self) -> QdrantClient:
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Qdrant client is not initialized")
|
||||
return client
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
client = self._build_client()
|
||||
collections = await asyncio.to_thread(client.get_collections)
|
||||
self.logger.info(
|
||||
"Qdrant service initialized",
|
||||
collections_count=len(collections.collections),
|
||||
)
|
||||
self._client = client
|
||||
self._set_initialized(True)
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Qdrant service initialization failed", error=str(exc))
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def close(self) -> bool:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return True
|
||||
try:
|
||||
close = getattr(client, "close", None)
|
||||
if callable(close):
|
||||
await asyncio.to_thread(close)
|
||||
self.logger.info("Qdrant service closed")
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.exception("Qdrant service close failed", error=str(exc))
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
collections = await asyncio.to_thread(client.get_collections)
|
||||
return {
|
||||
"status": "healthy",
|
||||
"details": {
|
||||
"connected": True,
|
||||
"collections_count": len(collections.collections),
|
||||
"collections": [
|
||||
collection.name for collection in collections.collections[:5]
|
||||
],
|
||||
},
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Qdrant health check failed", error=str(exc))
|
||||
return {"status": "unhealthy", "details": {"error": str(exc)}}
|
||||
|
||||
def get_client(self) -> QdrantClient:
|
||||
return self._require_client()
|
||||
|
||||
|
||||
qdrant_service: QdrantService = register_service_instance("qdrant", QdrantService())
|
||||
|
||||
__all__ = ["QdrantService", "qdrant_service"]
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from core.config.settings import RedisSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class RedisService(BaseServiceProvider):
|
||||
def __init__(self, settings: RedisSettings | None = None) -> None:
|
||||
super().__init__("redis")
|
||||
self._settings = settings or config.redis
|
||||
self._client: Optional[redis.Redis] = None
|
||||
|
||||
def _build_client(self) -> redis.Redis:
|
||||
return redis.from_url(
|
||||
self._settings.url,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=self._settings.socket_connect_timeout,
|
||||
socket_timeout=self._settings.socket_timeout,
|
||||
max_connections=self._settings.max_connections,
|
||||
)
|
||||
|
||||
def _require_client(self) -> redis.Redis:
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Redis client is not initialized")
|
||||
return client
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
client = self._build_client()
|
||||
ping_result = client.ping()
|
||||
if inspect.isawaitable(ping_result):
|
||||
await ping_result
|
||||
self._client = client
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Redis service initialized")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Redis service initialization failed", error=str(exc))
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def close(self) -> bool:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return True
|
||||
try:
|
||||
await client.aclose()
|
||||
self.logger.info("Redis service closed")
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.exception("Redis service close failed", error=str(exc))
|
||||
return False
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
ping_result = client.ping()
|
||||
ping = (
|
||||
await ping_result if inspect.isawaitable(ping_result) else ping_result
|
||||
)
|
||||
info_result = client.info()
|
||||
info = (
|
||||
await info_result if inspect.isawaitable(info_result) else info_result
|
||||
)
|
||||
return {
|
||||
"status": "healthy" if ping else "unhealthy",
|
||||
"details": {
|
||||
"ping": ping,
|
||||
"redis_version": info.get("redis_version"),
|
||||
"connected_clients": info.get("connected_clients"),
|
||||
"used_memory": info.get("used_memory_human"),
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds"),
|
||||
},
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Redis health check failed", error=str(exc))
|
||||
return {"status": "unhealthy", "details": {"error": str(exc)}}
|
||||
|
||||
def get_client(self) -> redis.Redis:
|
||||
return self._require_client()
|
||||
|
||||
|
||||
redis_service: RedisService = register_service_instance("redis", RedisService())
|
||||
|
||||
__all__ = ["RedisService", "redis_service"]
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar
|
||||
|
||||
from core.logging import get_logger
|
||||
|
||||
|
||||
class BaseServiceProvider(ABC):
|
||||
def __init__(self, service_name: str) -> None:
|
||||
self.service_name = service_name
|
||||
self._initialized = False
|
||||
self.logger = get_logger("services.base").bind(service=service_name)
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self, **kwargs: Any) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def _set_initialized(self, value: bool) -> None:
|
||||
self._initialized = value
|
||||
|
||||
def get_service_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.service_name,
|
||||
"initialized": self._initialized,
|
||||
"type": self.__class__.__name__,
|
||||
}
|
||||
|
||||
|
||||
class ServiceRegistry:
|
||||
_services: Dict[str, Callable[..., BaseServiceProvider]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, service_name: str, factory: Callable[..., BaseServiceProvider]
|
||||
) -> None:
|
||||
cls._services = {**cls._services, service_name: factory}
|
||||
|
||||
@classmethod
|
||||
def get_service_factory(
|
||||
cls, service_name: str
|
||||
) -> Optional[Callable[..., BaseServiceProvider]]:
|
||||
return cls._services.get(service_name)
|
||||
|
||||
@classmethod
|
||||
def list_services(cls) -> list[str]:
|
||||
return sorted(cls._services.keys())
|
||||
|
||||
@classmethod
|
||||
def create_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
factory = cls.get_service_factory(service_name)
|
||||
if not factory:
|
||||
return None
|
||||
return factory(**kwargs)
|
||||
|
||||
|
||||
def register_service(service_name: str) -> Callable[[type], type]:
|
||||
def decorator(service_class: type) -> type:
|
||||
ServiceRegistry.register(service_name, service_class)
|
||||
return service_class
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
TService = TypeVar("TService", bound=BaseServiceProvider)
|
||||
|
||||
|
||||
def register_service_instance(service_name: str, service: TService) -> TService:
|
||||
ServiceRegistry.register(service_name, lambda: service)
|
||||
return service
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from v1.auth.service import AuthService, SupabaseAuthGateway
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
return AuthService(gateway=SupabaseAuthGateway())
|
||||
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=6)
|
||||
display_name: str | None = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=6)
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
refresh_token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class AuthTokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
user: AuthUser
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
LoginRequest,
|
||||
LogoutRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/signup", response_model=AuthTokenResponse)
|
||||
async def signup(
|
||||
payload: SignupRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> AuthTokenResponse:
|
||||
return await service.signup(payload)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokenResponse)
|
||||
async def login(
|
||||
payload: LoginRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> AuthTokenResponse:
|
||||
return await service.login(payload)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokenResponse)
|
||||
async def refresh(
|
||||
payload: RefreshRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> AuthTokenResponse:
|
||||
return await service.refresh(payload)
|
||||
|
||||
|
||||
@router.post("/logout", status_code=204)
|
||||
async def logout(
|
||||
payload: LogoutRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await service.logout(payload.refresh_token)
|
||||
return Response(status_code=204)
|
||||
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError, create_client
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
from core.logging import get_logger
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("v1.auth.service")
|
||||
|
||||
|
||||
class AuthServiceGateway(Protocol):
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SupabaseAuthGateway(AuthServiceGateway):
|
||||
_client: Any
|
||||
|
||||
def __init__(self) -> None:
|
||||
settings: SupabaseSettings = config.supabase
|
||||
self._client = create_client(settings.url, settings.anon_key)
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
payload: dict[str, Any] = {
|
||||
"email": request.email,
|
||||
"password": request.password,
|
||||
}
|
||||
if request.display_name:
|
||||
payload = {
|
||||
**payload,
|
||||
"data": {"display_name": request.display_name},
|
||||
}
|
||||
try:
|
||||
sign_up = cast(Any, self._client.auth.sign_up)
|
||||
response = await asyncio.to_thread(sign_up, payload)
|
||||
return _map_auth_response(response, "Authentication failed")
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Authentication failed"
|
||||
) from exc
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
payload: dict[str, Any] = {"email": request.email, "password": request.password}
|
||||
try:
|
||||
sign_in = cast(Any, self._client.auth.sign_in_with_password)
|
||||
response = await asyncio.to_thread(sign_in, payload)
|
||||
return _map_auth_response(response, "Invalid credentials")
|
||||
except AuthError as exc:
|
||||
logger.warning("Login failed", error=str(exc))
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
request.refresh_token,
|
||||
)
|
||||
return _map_auth_response(response, "Invalid refresh token")
|
||||
except AuthError as exc:
|
||||
logger.warning("Refresh failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid refresh token"
|
||||
) from exc
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=401, detail="Missing refresh token")
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
refresh_token,
|
||||
)
|
||||
session = getattr(response, "session", None)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
await asyncio.to_thread(
|
||||
self._client.auth.set_session,
|
||||
str(session.access_token),
|
||||
str(session.refresh_token),
|
||||
)
|
||||
await asyncio.to_thread(self._client.auth.sign_out)
|
||||
except AuthError as exc:
|
||||
logger.warning("Logout failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid refresh token"
|
||||
) from exc
|
||||
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
|
||||
def __init__(self, gateway: AuthServiceGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
return await self._gateway.signup(request)
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
return await self._gateway.login(request)
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
return await self._gateway.refresh(request)
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.logout(refresh_token)
|
||||
|
||||
|
||||
def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse:
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
if session is None or user is None:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
email = getattr(user, "email", None)
|
||||
if not email:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
auth_user = AuthUser(id=str(user.id), email=str(email))
|
||||
return AuthTokenResponse(
|
||||
access_token=str(session.access_token),
|
||||
refresh_token=str(session.refresh_token),
|
||||
expires_in=int(session.expires_in or 0),
|
||||
token_type=str(session.token_type),
|
||||
user=auth_user,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.redis import RedisService, redis_service
|
||||
from services.base.qdrant import QdrantService, qdrant_service
|
||||
|
||||
|
||||
def get_redis_service() -> RedisService:
|
||||
return redis_service
|
||||
|
||||
|
||||
def get_qdrant_service() -> QdrantService:
|
||||
return qdrant_service
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from services.base.qdrant import QdrantService
|
||||
from services.base.redis import RedisService
|
||||
from v1.infra.dependencies import get_qdrant_service, get_redis_service
|
||||
from v1.infra.schemas import InfraHealthResponse, ServiceHealth
|
||||
|
||||
|
||||
router = APIRouter(prefix="/infra", tags=["infra"])
|
||||
|
||||
|
||||
@router.get("/health", response_model=InfraHealthResponse)
|
||||
async def infra_health(
|
||||
redis_service: RedisService = Depends(get_redis_service),
|
||||
qdrant_service: QdrantService = Depends(get_qdrant_service),
|
||||
) -> InfraHealthResponse:
|
||||
if not redis_service.is_initialized:
|
||||
await redis_service.initialize()
|
||||
if not qdrant_service.is_initialized:
|
||||
await qdrant_service.initialize()
|
||||
|
||||
redis_health = await redis_service.health_check()
|
||||
qdrant_health = await qdrant_service.health_check()
|
||||
status = (
|
||||
"healthy"
|
||||
if redis_health["status"] == "healthy" and qdrant_health["status"] == "healthy"
|
||||
else "unhealthy"
|
||||
)
|
||||
|
||||
return InfraHealthResponse(
|
||||
status=status,
|
||||
services={
|
||||
"redis": ServiceHealth(**redis_health),
|
||||
"qdrant": ServiceHealth(**qdrant_health),
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ServiceHealth(BaseModel):
|
||||
status: Literal["healthy", "unhealthy"]
|
||||
details: Dict[str, Any]
|
||||
|
||||
|
||||
class InfraHealthResponse(BaseModel):
|
||||
status: Literal["healthy", "unhealthy"]
|
||||
services: Dict[str, ServiceHealth]
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, Header, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from core.logging import get_logger
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.repository import SQLAlchemyProfileRepository
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
logger = get_logger("v1.profile.dependencies")
|
||||
|
||||
|
||||
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
|
||||
if not authorization:
|
||||
logger.warning("JWT validation failed: missing authorization header")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
scheme, _, token = authorization.partition(" ")
|
||||
if scheme.lower() != "bearer" or not token:
|
||||
logger.warning("JWT validation failed: invalid authorization scheme")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
secret = config.supabase.jwt_secret
|
||||
if not secret:
|
||||
logger.error("JWT validation failed: secret not configured")
|
||||
raise HTTPException(status_code=503, detail="JWT secret not configured")
|
||||
|
||||
supabase_url = config.supabase.public_url.rstrip("/")
|
||||
expected_issuer = f"{supabase_url}/auth/v1"
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
secret,
|
||||
algorithms=["HS256"],
|
||||
audience="authenticated",
|
||||
issuer=expected_issuer,
|
||||
options={
|
||||
"verify_aud": True,
|
||||
"verify_iss": True,
|
||||
"verify_exp": True,
|
||||
"require": ["sub", "aud", "iss", "exp"],
|
||||
},
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("JWT validation failed: token expired")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.warning("JWT validation failed: invalid audience")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.warning("JWT validation failed: invalid issuer")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.InvalidSignatureError:
|
||||
logger.warning("JWT validation failed: invalid signature")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.DecodeError:
|
||||
logger.warning("JWT validation failed: malformed token")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.PyJWTError as exc:
|
||||
logger.warning(
|
||||
"JWT validation failed: unknown error", error_type=type(exc).__name__
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||
|
||||
subject = payload.get("sub")
|
||||
if not isinstance(subject, str) or not subject:
|
||||
logger.warning("JWT validation failed: missing or invalid subject claim")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
try:
|
||||
user_id = UUID(subject)
|
||||
except ValueError:
|
||||
logger.warning("JWT validation failed: invalid UUID in subject")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
logger.debug("JWT validation successful", user_id=str(user_id))
|
||||
return CurrentUser(id=user_id)
|
||||
|
||||
|
||||
def get_profile_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> ProfileService:
|
||||
repository = SQLAlchemyProfileRepository(session)
|
||||
return ProfileService(repository=repository, session=session, current_user=user)
|
||||
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from core.logging import get_logger
|
||||
from models.profile import Profile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.profile.repository")
|
||||
|
||||
|
||||
class ProfileRepository(Protocol):
|
||||
"""Protocol defining the profile repository interface."""
|
||||
|
||||
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
|
||||
"""Get profile by user ID."""
|
||||
...
|
||||
|
||||
async def get_by_username(self, username: str) -> Profile | None:
|
||||
"""Get profile by username."""
|
||||
...
|
||||
|
||||
async def update_by_user_id(
|
||||
self, user_id: UUID, update_data: dict[str, str | None]
|
||||
) -> Profile | None:
|
||||
"""Update profile by user ID. Returns updated profile or None if not found."""
|
||||
...
|
||||
|
||||
|
||||
class SQLAlchemyProfileRepository(BaseRepository[Profile]):
|
||||
"""SQLAlchemy implementation of ProfileRepository.
|
||||
|
||||
Note: This repository only performs CRUD operations.
|
||||
- No commit (only flush) - service layer handles transactions
|
||||
- No auth logic - service layer handles authorization
|
||||
- No HTTP exceptions - returns None or raises SQLAlchemyError
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session, Profile)
|
||||
|
||||
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
|
||||
try:
|
||||
return await self.get_by_id(user_id)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("Profile lookup failed", user_id=str(user_id))
|
||||
raise
|
||||
|
||||
async def get_by_username(self, username: str) -> Profile | None:
|
||||
try:
|
||||
return await self.get_one(Profile.username == username)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("Profile lookup failed", username=username)
|
||||
raise
|
||||
|
||||
async def update_by_user_id(
|
||||
self, user_id: UUID, update_data: dict[str, str | None]
|
||||
) -> Profile | None:
|
||||
if not update_data:
|
||||
return await self.get_by_user_id(user_id)
|
||||
|
||||
try:
|
||||
return await self.update_by_id(user_id, update_data)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("Profile update failed", user_id=str(user_id))
|
||||
raise
|
||||
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
|
||||
from v1.profile.dependencies import get_profile_service
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
router = APIRouter(prefix="/profile", tags=["profile"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=ProfileResponse)
|
||||
async def get_me(
|
||||
service: Annotated[ProfileService, Depends(get_profile_service)],
|
||||
) -> ProfileResponse:
|
||||
return await service.get_me()
|
||||
|
||||
|
||||
@router.patch("/me", response_model=ProfileResponse)
|
||||
async def update_me(
|
||||
payload: ProfileUpdateRequest,
|
||||
service: Annotated[ProfileService, Depends(get_profile_service)],
|
||||
) -> ProfileResponse:
|
||||
return await service.update_me(payload)
|
||||
|
||||
|
||||
@router.get("/{username}", response_model=ProfileResponse)
|
||||
async def get_by_username(
|
||||
username: Annotated[
|
||||
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
|
||||
],
|
||||
service: Annotated[ProfileService, Depends(get_profile_service)],
|
||||
) -> ProfileResponse:
|
||||
return await service.get_by_username(username)
|
||||
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class ProfileResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
display_name: str | None = None
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
|
||||
|
||||
class ProfileUpdateRequest(BaseModel):
|
||||
display_name: str | None = Field(default=None, max_length=50)
|
||||
avatar_url: str | None = Field(default=None)
|
||||
bio: str | None = Field(default=None, max_length=200)
|
||||
|
||||
@field_validator("avatar_url", mode="before")
|
||||
@classmethod
|
||||
def validate_avatar_url(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
parsed = AnyHttpUrl(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError("avatar_url must use http or https scheme")
|
||||
return str(parsed)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_one_field(self) -> "ProfileUpdateRequest":
|
||||
if self.display_name is None and self.avatar_url is None and self.bio is None:
|
||||
raise ValueError("At least one field must be provided")
|
||||
return self
|
||||
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from v1.profile.repository import ProfileRepository
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.profile.service")
|
||||
|
||||
|
||||
class ProfileService(BaseService):
|
||||
"""Profile service handling business logic and transactions.
|
||||
|
||||
Responsibilities:
|
||||
- Authorization checks
|
||||
- Transaction boundary (commit/rollback)
|
||||
- Converting ORM models to response schemas
|
||||
"""
|
||||
|
||||
_repository: ProfileRepository
|
||||
_session: AsyncSession
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: ProfileRepository,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def get_me(self) -> ProfileResponse:
|
||||
user_id = self.require_user_id()
|
||||
try:
|
||||
profile = await self._repository.get_by_user_id(user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Profile store unavailable")
|
||||
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return ProfileResponse(
|
||||
id=str(profile.id),
|
||||
username=profile.username,
|
||||
display_name=profile.display_name,
|
||||
avatar_url=profile.avatar_url,
|
||||
bio=profile.bio,
|
||||
)
|
||||
|
||||
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
|
||||
user_id = self.require_user_id()
|
||||
update_data: dict[str, str | None] = {
|
||||
key: value
|
||||
for key, value in {
|
||||
"display_name": update.display_name,
|
||||
"avatar_url": update.avatar_url,
|
||||
"bio": update.bio,
|
||||
}.items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
|
||||
try:
|
||||
profile = await self._repository.update_by_user_id(user_id, update_data)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(status_code=503, detail="Profile store unavailable")
|
||||
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
return ProfileResponse(
|
||||
id=str(profile.id),
|
||||
username=profile.username,
|
||||
display_name=profile.display_name,
|
||||
avatar_url=profile.avatar_url,
|
||||
bio=profile.bio,
|
||||
)
|
||||
|
||||
async def get_by_username(self, username: str) -> ProfileResponse:
|
||||
try:
|
||||
profile = await self._repository.get_by_username(username)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Profile store unavailable")
|
||||
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return ProfileResponse(
|
||||
id=str(profile.id),
|
||||
username=profile.username,
|
||||
display_name=profile.display_name,
|
||||
avatar_url=profile.avatar_url,
|
||||
bio=profile.bio,
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from core.http.models import HealthResponse
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.infra.router import router as infra_router
|
||||
from v1.profile.router import router as profile_router
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
router.include_router(infra_router)
|
||||
router.include_router(profile_router)
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health() -> HealthResponse:
|
||||
return HealthResponse(status="ok")
|
||||
Reference in New Issue
Block a user