refactor: align backend layout and supabase infra

Consolidate backend modules/tests under the backend package while syncing Supabase compose/env config and related plans.
This commit is contained in:
qzl
2026-02-05 15:13:06 +08:00
parent 3cfcb11240
commit ad06fe7de4
111 changed files with 5540 additions and 1362 deletions
View File
+133
View File
@@ -0,0 +1,133 @@
from __future__ import annotations
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from core.config.settings import config
from core.http.models import HealthResponse
from core.http.response import build_problem_details
from core.logging import configure_logging, get_logger
from v1.router import router as mobile_router
configure_logging(config)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=config.cors.allow_origins,
allow_credentials=config.cors.allow_credentials,
allow_methods=config.cors.allow_methods,
allow_headers=config.cors.allow_headers,
)
app.include_router(mobile_router)
logger = get_logger("api.app")
@app.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(status="ok")
def _build_http_error_response(
request: Request,
exc: Exception,
status_code: int,
detail: object,
) -> JSONResponse:
instance = request.url.path
detail_text = detail if isinstance(detail, str) else "Request failed"
logger.warning(
"HTTP error",
status_code=status_code,
detail=detail_text,
detail_extra=detail,
path=request.url.path,
method=request.method,
)
problem = build_problem_details(
status_code=status_code,
detail=detail_text,
instance=instance,
)
return JSONResponse(
status_code=status_code,
content=problem.model_dump(),
media_type="application/problem+json",
)
@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request,
exc: HTTPException,
) -> JSONResponse:
return _build_http_error_response(
request=request,
exc=exc,
status_code=exc.status_code,
detail=exc.detail,
)
@app.exception_handler(StarletteHTTPException)
async def starlette_http_exception_handler(
request: Request,
exc: StarletteHTTPException,
) -> JSONResponse:
return _build_http_error_response(
request=request,
exc=exc,
status_code=exc.status_code,
detail=exc.detail,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request,
exc: RequestValidationError,
) -> JSONResponse:
instance = request.url.path
logger.warning(
"Request validation error",
path=request.url.path,
method=request.method,
errors=exc.errors(),
)
problem = build_problem_details(
status_code=422,
detail="Invalid request",
instance=instance,
)
return JSONResponse(
status_code=422,
content=problem.model_dump(),
media_type="application/problem+json",
)
@app.exception_handler(Exception)
async def unhandled_exception_handler(
request: Request,
exc: Exception,
) -> JSONResponse:
instance = request.url.path
logger.exception(
"Unhandled error",
path=request.url.path,
method=request.method,
)
problem = build_problem_details(
status_code=500,
detail="Internal Server Error",
instance=instance,
)
return JSONResponse(
status_code=500,
content=problem.model_dump(),
media_type="application/problem+json",
)
View File
+9
View File
@@ -0,0 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID
@dataclass(frozen=True)
class CurrentUser:
id: UUID
+3
View File
@@ -0,0 +1,3 @@
from .settings import Settings, config
__all__ = ["Settings", "config"]
+166
View File
@@ -0,0 +1,166 @@
from __future__ import annotations
from pathlib import Path
from typing import ClassVar, Literal
from urllib.parse import quote
from pydantic import BaseModel, Field, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict
class RuntimeSettings(BaseModel):
environment: Literal["dev", "test", "prod"] = "dev"
debug: bool = True
log_level: str = "INFO"
log_json: bool = True
log_rotation: Literal["time", "size", "none"] = "time"
log_rotation_when: str = "midnight"
log_rotation_interval: int = 1
log_rotation_backup_count: int = 14
log_rotation_max_bytes: int = 10_000_000
log_dir: str = "logs"
log_error_dir: str = "logs/errors"
log_file_name: str = "app.log"
log_error_file_name: str = "error.log"
log_sensitive_fields: list[str] = Field(
default_factory=lambda: [
"password",
"secret",
"token",
"api_key",
"authorization",
"cookie",
"client_ip",
"user_id",
]
)
sql_log_queries: bool = False
class AppSettings(BaseModel):
host: str = "0.0.0.0"
port: int = Field(default=8000, ge=1, le=65535)
reload: bool = True
class CorsSettings(BaseModel):
allow_origins: list[str] = Field(
default_factory=lambda: [
"http://localhost",
"http://localhost:3000",
]
)
allow_credentials: bool = True
allow_methods: list[str] = Field(default_factory=lambda: ["*"])
allow_headers: list[str] = Field(default_factory=lambda: ["*"])
class RedisSettings(BaseModel):
host: str = "redis"
port: int = 6379
password: str | None = None
db: int = 0
socket_connect_timeout: float = 1.0
socket_timeout: float = 1.0
max_connections: int = 10
@computed_field
@property
def url(self) -> str:
if self.password:
password = quote(self.password, safe="")
return f"redis://:{password}@{self.host}:{self.port}/{self.db}"
return f"redis://{self.host}:{self.port}/{self.db}"
class QdrantSettings(BaseModel):
host: str = "qdrant"
port: int = 6333
grpc_port: int = 6334
api_key: str | None = None
https: bool = False
prefer_grpc: bool = True
timeout: int = 5
@computed_field
@property
def url(self) -> str:
scheme = "https" if self.https else "http"
return f"{scheme}://{self.host}:{self.port}"
class SupabaseSettings(BaseModel):
public_scheme: str = "http"
public_host: str = "localhost"
kong_http_port: int = 8000
anon_key: str = "CHANGE_ME"
service_role_key: str = "CHANGE_ME"
jwt_secret: str | None = None
@computed_field
@property
def public_url(self) -> str:
return f"{self.public_scheme}://{self.public_host}:{self.kong_http_port}"
@computed_field
@property
def api_external_url(self) -> str:
return self.public_url
@computed_field
@property
def url(self) -> str:
return self.public_url
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 5432
name: str = "postgres"
user: str = "postgres"
password: str = "CHANGE_ME"
@computed_field
@property
def url(self) -> str:
password = quote(self.password, safe="")
return (
f"postgresql+asyncpg://{self.user}:{password}"
f"@{self.host}:{self.port}/{self.name}"
)
def _resolve_env_file() -> str:
current = Path(__file__).resolve()
for parent in [current, *current.parents]:
candidate = parent / ".env"
if candidate.is_file():
return str(candidate)
return ".env"
class Settings(BaseSettings):
runtime: RuntimeSettings = RuntimeSettings()
app: AppSettings = AppSettings()
cors: CorsSettings = CorsSettings()
redis: RedisSettings = RedisSettings()
qdrant: QdrantSettings = QdrantSettings()
supabase: SupabaseSettings = SupabaseSettings()
database: DatabaseSettings = DatabaseSettings()
@computed_field
@property
def database_url(self) -> str:
return self.database.url
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_file=_resolve_env_file(),
env_prefix="SOCIAL_",
env_nested_delimiter="__",
case_sensitive=False,
extra="ignore",
)
config = Settings()
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
from core.db.session import AsyncSessionLocal, engine, get_db
__all__ = ["AsyncSessionLocal", "engine", "get_db"]
+37
View File
@@ -0,0 +1,37 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
"""Base class for all ORM models."""
pass
class TimestampMixin:
"""Adds created_at and updated_at timestamps."""
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
class SoftDeleteMixin:
"""Adds soft delete timestamp column."""
deleted_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
)
+84
View File
@@ -0,0 +1,84 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar
from sqlalchemy import Select, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from core.db.base import Base
ModelType = TypeVar("ModelType", bound=Base)
class BaseRepository(Generic[ModelType]):
_session: AsyncSession
_model: type[ModelType]
def __init__(self, session: AsyncSession, model: type[ModelType]) -> None:
self._session = session
self._model = model
def _deleted_at_column(self) -> Any | None:
return getattr(self._model, "deleted_at", None)
def _apply_soft_delete_filter(self, stmt: Select) -> Select:
deleted_at = self._deleted_at_column()
if deleted_at is None:
return stmt
return stmt.where(deleted_at.is_(None))
async def get_by_id(self, entity_id: Any) -> ModelType | None:
id_column = getattr(self._model, "id")
stmt = select(self._model).where(id_column == entity_id)
stmt = self._apply_soft_delete_filter(stmt)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def get_one(self, *filters: Any) -> ModelType | None:
stmt = select(self._model).where(*filters)
stmt = self._apply_soft_delete_filter(stmt)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def update_by_id(
self, entity_id: Any, update_data: dict[str, Any]
) -> ModelType | None:
if not update_data:
return await self.get_by_id(entity_id)
id_column = getattr(self._model, "id")
stmt = update(self._model).where(id_column == entity_id)
deleted_at = self._deleted_at_column()
if deleted_at is not None:
stmt = stmt.where(deleted_at.is_(None))
stmt = stmt.values(**update_data).returning(self._model)
try:
result = await self._session.execute(stmt)
await self._session.flush()
return result.scalar_one_or_none()
except SQLAlchemyError:
raise
async def soft_delete_by_id(self, entity_id: Any) -> ModelType | None:
deleted_at = self._deleted_at_column()
if deleted_at is None:
raise ValueError("Soft delete is not supported for this model")
id_column = getattr(self._model, "id")
stmt = (
update(self._model)
.where(id_column == entity_id)
.where(deleted_at.is_(None))
.values(deleted_at=datetime.now(timezone.utc))
.returning(self._model)
)
try:
result = await self._session.execute(stmt)
await self._session.flush()
return result.scalar_one_or_none()
except SQLAlchemyError:
raise
+22
View File
@@ -0,0 +1,22 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from core.auth.models import CurrentUser
class BaseService:
_current_user: CurrentUser | None
def __init__(self, current_user: CurrentUser | None) -> None:
self._current_user = current_user
def require_current_user(self) -> CurrentUser:
if self._current_user is None:
raise HTTPException(status_code=401, detail="Unauthorized")
return self._current_user
def require_user_id(self) -> UUID:
return self.require_current_user().id
+34
View File
@@ -0,0 +1,34 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.config.settings import config
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
engine: AsyncEngine = create_async_engine(
config.database_url,
echo=config.runtime.sql_log_queries,
pool_pre_ping=True,
)
AsyncSessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""Dependency that provides a database session.
The session is automatically closed when the request completes.
Note: The caller (service layer) is responsible for commit/rollback.
"""
async with AsyncSessionLocal() as session:
yield session
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
from core.http.response import ProblemDetails, build_problem_details
__all__ = ["ProblemDetails", "build_problem_details"]
+7
View File
@@ -0,0 +1,7 @@
from __future__ import annotations
from pydantic import BaseModel
class HealthResponse(BaseModel):
status: str
+31
View File
@@ -0,0 +1,31 @@
from __future__ import annotations
from http import HTTPStatus
from pydantic import BaseModel
class ProblemDetails(BaseModel):
type: str = "about:blank"
title: str
status: int
detail: str
instance: str | None = None
def build_problem_details(
*,
status_code: int,
detail: str,
type_value: str = "about:blank",
title: str | None = None,
instance: str | None = None,
) -> ProblemDetails:
resolved_title = title or HTTPStatus(status_code).phrase
return ProblemDetails(
type=type_value,
title=resolved_title,
status=status_code,
detail=detail,
instance=instance,
)
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
from core.logging import celery
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context, get_context
from core.logging.logger import get_logger
__all__ = [
"bind_context",
"celery",
"clear_context",
"configure_logging",
"get_context",
"get_logger",
]
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import cast
from celery import Celery, signals
from core.config.settings import Settings
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context
@dataclass(frozen=True)
class CelerySignalHandlers:
on_setup_logging: Callable[..., None]
on_after_setup_task_logger: Callable[..., None]
on_task_prerun: Callable[..., None]
on_task_postrun: Callable[..., None]
def build_celery_signal_handlers(
settings: Settings | None = None,
) -> CelerySignalHandlers:
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_task_prerun(*_args: object, **kwargs: object) -> None:
task_id = cast(str | None, kwargs.get("task_id"))
task = kwargs.get("task")
task_name = getattr(task, "name", None)
bind_context(task_id=task_id, task_name=task_name)
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
clear_context()
return CelerySignalHandlers(
on_setup_logging=on_setup_logging,
on_after_setup_task_logger=on_after_setup_task_logger,
on_task_prerun=on_task_prerun,
on_task_postrun=on_task_postrun,
)
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
app.conf.worker_hijack_root_logger = False
handlers = build_celery_signal_handlers(settings)
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
signals.after_setup_task_logger.connect(
handlers.on_after_setup_task_logger, weak=False
)
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
+103
View File
@@ -0,0 +1,103 @@
from __future__ import annotations
import logging
from logging.config import dictConfig
from pathlib import Path
import structlog
from core.config.settings import RuntimeSettings, Settings
from core.logging.formatters import (
build_plain_formatter,
build_processor_formatter,
ensure_message_key,
)
from core.logging.filters import build_sensitive_data_processor
from core.logging.handlers import build_file_handler_config
def _ensure_log_dirs(runtime: RuntimeSettings) -> None:
Path(runtime.log_dir).mkdir(parents=True, exist_ok=True)
Path(runtime.log_error_dir).mkdir(parents=True, exist_ok=True)
def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
log_dir = Path(runtime.log_dir)
error_dir = Path(runtime.log_error_dir)
formatter_name = "json" if runtime.log_json else "plain"
file_handler = build_file_handler_config(
runtime,
file_path=log_dir / runtime.log_file_name,
level=runtime.log_level,
formatter=formatter_name,
)
error_handler = build_file_handler_config(
runtime,
file_path=error_dir / runtime.log_error_file_name,
level="ERROR",
formatter=formatter_name,
filters=["error_only"],
)
return {
"version": 1,
"disable_existing_loggers": False,
"filters": {
"error_only": {
"()": "core.logging.filters.ErrorLevelFilter",
}
},
"formatters": {
"json": {
"()": build_processor_formatter,
"sensitive_fields": runtime.log_sensitive_fields,
},
"plain": {
"()": build_plain_formatter,
"sensitive_fields": runtime.log_sensitive_fields,
},
},
"handlers": {
"file": file_handler,
"error": error_handler,
},
"root": {
"handlers": ["file", "error"],
"level": runtime.log_level,
},
}
def configure_logging(settings: Settings | None = None) -> None:
active_settings = settings or Settings()
runtime = active_settings.runtime
try:
_ensure_log_dirs(runtime)
dictConfig(build_logging_config(runtime))
except (OSError, ValueError) as exc:
logging.basicConfig(level=runtime.log_level)
logging.getLogger(__name__).error("Logging setup failed", exc_info=exc)
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
build_sensitive_data_processor(runtime.log_sensitive_fields),
ensure_message_key,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
cache_logger_on_first_use=True,
)
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
from structlog import contextvars
def bind_context(**values: object) -> None:
contextvars.bind_contextvars(**values)
def clear_context() -> None:
contextvars.clear_contextvars()
def get_context() -> dict[str, object]:
return contextvars.get_contextvars()
+56
View File
@@ -0,0 +1,56 @@
from __future__ import annotations
import logging
import re
from collections.abc import Callable
from typing import cast
from structlog.types import EventDict
_NORMALIZE_PATTERN = re.compile(r"[^a-z0-9]")
def _normalize_key(value: str) -> str:
return _NORMALIZE_PATTERN.sub("", value.lower())
def _is_sensitive_key(key: object, sensitive_fields: set[str]) -> bool:
normalized_key = _normalize_key(str(key))
return normalized_key in sensitive_fields or any(
fragment in normalized_key for fragment in sensitive_fields
)
def _redact_value(value: object, sensitive_fields: set[str]) -> object:
if isinstance(value, dict):
typed_value = cast(dict[str, object], value)
return {
key: (
"[REDACTED]"
if _is_sensitive_key(key, sensitive_fields)
else _redact_value(inner, sensitive_fields)
)
for key, inner in typed_value.items()
}
if isinstance(value, list):
return [_redact_value(item, sensitive_fields) for item in value]
return value
def build_sensitive_data_processor(
sensitive_fields: list[str],
) -> Callable[[object, str, EventDict], EventDict]:
normalized = {_normalize_key(field) for field in sensitive_fields}
def processor(
_logger: object, _method_name: str, event_dict: EventDict
) -> EventDict:
return cast(EventDict, _redact_value(event_dict, normalized))
return processor
class ErrorLevelFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return record.levelno >= logging.ERROR
+81
View File
@@ -0,0 +1,81 @@
from __future__ import annotations
from structlog.dev import ConsoleRenderer
from structlog.processors import JSONRenderer
from structlog.stdlib import ProcessorFormatter
from structlog.types import EventDict
import structlog
from core.logging.filters import build_sensitive_data_processor
def ensure_message_key(
_logger: object, _method_name: str, event_dict: EventDict
) -> EventDict:
if "message" in event_dict:
return event_dict
if "event" not in event_dict:
return event_dict
without_event = {key: value for key, value in event_dict.items() if key != "event"}
return {**without_event, "message": event_dict["event"]}
def build_processor_formatter(
sensitive_fields: list[str] | None = None,
) -> ProcessorFormatter:
redact = build_sensitive_data_processor(sensitive_fields or [])
return ProcessorFormatter(
foreign_pre_chain=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
structlog.stdlib.ExtraAdder(),
ensure_message_key,
],
processors=[
redact,
ensure_message_key,
ProcessorFormatter.remove_processors_meta,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
JSONRenderer(sort_keys=True),
],
)
def build_plain_formatter(
sensitive_fields: list[str] | None = None,
) -> ProcessorFormatter:
redact = build_sensitive_data_processor(sensitive_fields or [])
return ProcessorFormatter(
foreign_pre_chain=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
structlog.stdlib.ExtraAdder(),
ensure_message_key,
],
processors=[
redact,
ensure_message_key,
ProcessorFormatter.remove_processors_meta,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
ConsoleRenderer(colors=False),
],
)
+46
View File
@@ -0,0 +1,46 @@
from __future__ import annotations
from pathlib import Path
from core.config.settings import RuntimeSettings
def build_file_handler_config(
runtime: RuntimeSettings,
file_path: Path,
level: str,
formatter: str,
filters: list[str] | None = None,
) -> dict[str, object]:
filter_list = list(filters or [])
base_config: dict[str, object] = {
"level": level,
"formatter": formatter,
"filename": str(file_path),
"encoding": "utf-8",
}
if filter_list:
base_config = {**base_config, "filters": filter_list}
if runtime.log_rotation == "time":
return {
**base_config,
"class": "logging.handlers.TimedRotatingFileHandler",
"when": runtime.log_rotation_when,
"interval": runtime.log_rotation_interval,
"backupCount": runtime.log_rotation_backup_count,
}
if runtime.log_rotation == "size":
return {
**base_config,
"class": "logging.handlers.RotatingFileHandler",
"maxBytes": runtime.log_rotation_max_bytes,
"backupCount": runtime.log_rotation_backup_count,
}
return {
**base_config,
"class": "logging.FileHandler",
}
+7
View File
@@ -0,0 +1,7 @@
from __future__ import annotations
import structlog
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
return structlog.get_logger(name)
+84
View File
@@ -0,0 +1,84 @@
from __future__ import annotations
import re
from collections.abc import MutableMapping
from typing import cast
from uuid import uuid4
from fastapi import FastAPI, Request
from starlette.requests import Request as StarletteRequest
from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
from core.logging.context import bind_context, clear_context
from core.logging.logger import get_logger
class RequestContextMiddleware:
app: ASGIApp
_header_name: str
_request_id_pattern: re.Pattern[str]
def __init__(self, app: ASGIApp, header_name: str = "X-Request-ID") -> None:
self.app = app
self._header_name = header_name
self._request_id_pattern = re.compile(r"^[A-Za-z0-9_-]{8,64}$")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope.get("type") != "http":
await self.app(scope, receive, send)
return
request = StarletteRequest(scope, receive=receive)
request_id = self._normalize_request_id(request.headers.get(self._header_name))
client_ip = request.client.host if request.client else None
user_id = getattr(request.state, "user_id", None)
request.state.request_id = request_id
bind_context(
request_id=request_id,
method=request.method,
path=request.url.path,
client_ip=client_ip,
user_id=user_id,
)
async def send_wrapper(message: MutableMapping[str, object]) -> None:
if message.get("type") == "http.response.start":
raw_headers = message.get("headers")
headers = list(cast(list[tuple[bytes, bytes]], raw_headers or []))
header_key = self._header_name.lower().encode()
if not any(item[0].lower() == header_key for item in headers):
headers.append((header_key, request_id.encode()))
message = {**message, "headers": headers}
await send(message)
try:
await self.app(scope, receive, send_wrapper)
finally:
clear_context()
def _normalize_request_id(self, request_id: str | None) -> str:
if request_id and self._request_id_pattern.match(request_id):
return request_id
return str(uuid4())
def register_exception_handlers(app: FastAPI) -> None:
logger = get_logger("core.logging.exception")
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> Response:
request_id = getattr(request.state, "request_id", None)
logger.exception(
"Unhandled exception",
error_type=exc.__class__.__name__,
request_id=request_id,
)
headers = {"X-Request-ID": request_id} if request_id else None
return JSONResponse(
status_code=500,
content={"detail": "Internal Server Error"},
headers=headers,
)
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
from models.profile import Profile
__all__ = ["Profile"]
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
import uuid
from sqlalchemy import String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class Profile(TimestampMixin, SoftDeleteMixin, Base):
"""User profile model.
Note: The `id` column references auth.users(id) in Supabase.
This is a business table managed by SQLAlchemy, with the foreign key
relationship to Supabase's auth schema handled at the database level.
"""
__tablename__: str = "profiles"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
)
username: Mapped[str] = mapped_column(
String(30),
unique=True,
nullable=False,
index=True,
)
display_name: Mapped[str | None] = mapped_column(
String(50),
nullable=True,
)
avatar_url: Mapped[str | None] = mapped_column(
Text,
nullable=True,
)
bio: Mapped[str | None] = mapped_column(
String(200),
nullable=True,
)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+21
View File
@@ -0,0 +1,21 @@
from __future__ import annotations
from services.base.qdrant import QdrantService, qdrant_service
from services.base.redis import RedisService, redis_service
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
register_service,
register_service_instance,
)
__all__ = [
"BaseServiceProvider",
"QdrantService",
"RedisService",
"ServiceRegistry",
"qdrant_service",
"redis_service",
"register_service",
"register_service_instance",
]
+94
View File
@@ -0,0 +1,94 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
from qdrant_client import QdrantClient
from core.config.settings import QdrantSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class QdrantService(BaseServiceProvider):
def __init__(self, settings: QdrantSettings | None = None) -> None:
super().__init__("qdrant")
self._settings = settings or config.qdrant
self._client: Optional[QdrantClient] = None
def _build_client(self) -> QdrantClient:
return QdrantClient(
url=self._settings.url,
api_key=self._settings.api_key,
timeout=self._settings.timeout,
prefer_grpc=self._settings.prefer_grpc,
)
def _require_client(self) -> QdrantClient:
client = self._client
if client is None:
raise RuntimeError("Qdrant client is not initialized")
return client
async def initialize(self, **_: Any) -> bool:
try:
client = self._build_client()
collections = await asyncio.to_thread(client.get_collections)
self.logger.info(
"Qdrant service initialized",
collections_count=len(collections.collections),
)
self._client = client
self._set_initialized(True)
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Qdrant service initialization failed", error=str(exc))
self._client = None
self._set_initialized(False)
return False
async def close(self) -> bool:
client = self._client
if client is None:
return True
try:
close = getattr(client, "close", None)
if callable(close):
await asyncio.to_thread(close)
self.logger.info("Qdrant service closed")
self._client = None
self._set_initialized(False)
return True
except Exception as exc: # noqa: BLE001
self.logger.exception("Qdrant service close failed", error=str(exc))
self._client = None
self._set_initialized(False)
return False
async def health_check(self) -> Dict[str, Any]:
client = self._client
if client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
collections = await asyncio.to_thread(client.get_collections)
return {
"status": "healthy",
"details": {
"connected": True,
"collections_count": len(collections.collections),
"collections": [
collection.name for collection in collections.collections[:5]
],
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Qdrant health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> QdrantClient:
return self._require_client()
qdrant_service: QdrantService = register_service_instance("qdrant", QdrantService())
__all__ = ["QdrantService", "qdrant_service"]
+97
View File
@@ -0,0 +1,97 @@
from __future__ import annotations
import inspect
from typing import Any, Dict, Optional
import redis.asyncio as redis
from core.config.settings import RedisSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class RedisService(BaseServiceProvider):
def __init__(self, settings: RedisSettings | None = None) -> None:
super().__init__("redis")
self._settings = settings or config.redis
self._client: Optional[redis.Redis] = None
def _build_client(self) -> redis.Redis:
return redis.from_url(
self._settings.url,
decode_responses=True,
socket_connect_timeout=self._settings.socket_connect_timeout,
socket_timeout=self._settings.socket_timeout,
max_connections=self._settings.max_connections,
)
def _require_client(self) -> redis.Redis:
client = self._client
if client is None:
raise RuntimeError("Redis client is not initialized")
return client
async def initialize(self, **_: Any) -> bool:
try:
client = self._build_client()
ping_result = client.ping()
if inspect.isawaitable(ping_result):
await ping_result
self._client = client
self._set_initialized(True)
self.logger.info("Redis service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Redis service initialization failed", error=str(exc))
self._client = None
self._set_initialized(False)
return False
async def close(self) -> bool:
client = self._client
if client is None:
return True
try:
await client.aclose()
self.logger.info("Redis service closed")
self._client = None
self._set_initialized(False)
return True
except Exception as exc: # noqa: BLE001
self.logger.exception("Redis service close failed", error=str(exc))
return False
async def health_check(self) -> Dict[str, Any]:
client = self._client
if client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
ping_result = client.ping()
ping = (
await ping_result if inspect.isawaitable(ping_result) else ping_result
)
info_result = client.info()
info = (
await info_result if inspect.isawaitable(info_result) else info_result
)
return {
"status": "healthy" if ping else "unhealthy",
"details": {
"ping": ping,
"redis_version": info.get("redis_version"),
"connected_clients": info.get("connected_clients"),
"used_memory": info.get("used_memory_human"),
"uptime_in_seconds": info.get("uptime_in_seconds"),
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Redis health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> redis.Redis:
return self._require_client()
redis_service: RedisService = register_service_instance("redis", RedisService())
__all__ = ["RedisService", "redis_service"]
@@ -0,0 +1,84 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, TypeVar
from core.logging import get_logger
class BaseServiceProvider(ABC):
def __init__(self, service_name: str) -> None:
self.service_name = service_name
self._initialized = False
self.logger = get_logger("services.base").bind(service=service_name)
@abstractmethod
async def initialize(self, **kwargs: Any) -> bool:
raise NotImplementedError
@abstractmethod
async def close(self) -> bool:
raise NotImplementedError
@abstractmethod
async def health_check(self) -> Dict[str, Any]:
raise NotImplementedError
@property
def is_initialized(self) -> bool:
return self._initialized
def _set_initialized(self, value: bool) -> None:
self._initialized = value
def get_service_info(self) -> Dict[str, Any]:
return {
"name": self.service_name,
"initialized": self._initialized,
"type": self.__class__.__name__,
}
class ServiceRegistry:
_services: Dict[str, Callable[..., BaseServiceProvider]] = {}
@classmethod
def register(
cls, service_name: str, factory: Callable[..., BaseServiceProvider]
) -> None:
cls._services = {**cls._services, service_name: factory}
@classmethod
def get_service_factory(
cls, service_name: str
) -> Optional[Callable[..., BaseServiceProvider]]:
return cls._services.get(service_name)
@classmethod
def list_services(cls) -> list[str]:
return sorted(cls._services.keys())
@classmethod
def create_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
factory = cls.get_service_factory(service_name)
if not factory:
return None
return factory(**kwargs)
def register_service(service_name: str) -> Callable[[type], type]:
def decorator(service_class: type) -> type:
ServiceRegistry.register(service_name, service_class)
return service_class
return decorator
TService = TypeVar("TService", bound=BaseServiceProvider)
def register_service_instance(service_name: str, service: TService) -> TService:
ServiceRegistry.register(service_name, lambda: service)
return service
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+7
View File
@@ -0,0 +1,7 @@
from __future__ import annotations
from v1.auth.service import AuthService, SupabaseAuthGateway
def get_auth_service() -> AuthService:
return AuthService(gateway=SupabaseAuthGateway())
+35
View File
@@ -0,0 +1,35 @@
from __future__ import annotations
from pydantic import BaseModel, EmailStr, Field
class SignupRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=6)
display_name: str | None = None
class LoginRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=6)
class RefreshRequest(BaseModel):
refresh_token: str = Field(min_length=1)
class LogoutRequest(BaseModel):
refresh_token: str = Field(min_length=1)
class AuthUser(BaseModel):
id: str
email: EmailStr
class AuthTokenResponse(BaseModel):
access_token: str
refresh_token: str
expires_in: int
token_type: str
user: AuthUser
+49
View File
@@ -0,0 +1,49 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Response
from v1.auth.dependencies import get_auth_service
from v1.auth.models import (
AuthTokenResponse,
LoginRequest,
LogoutRequest,
RefreshRequest,
SignupRequest,
)
from v1.auth.service import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/signup", response_model=AuthTokenResponse)
async def signup(
payload: SignupRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
return await service.signup(payload)
@router.post("/login", response_model=AuthTokenResponse)
async def login(
payload: LoginRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
return await service.login(payload)
@router.post("/refresh", response_model=AuthTokenResponse)
async def refresh(
payload: RefreshRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
return await service.refresh(payload)
@router.post("/logout", status_code=204)
async def logout(
payload: LogoutRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await service.logout(payload.refresh_token)
return Response(status_code=204)
+147
View File
@@ -0,0 +1,147 @@
from __future__ import annotations
import asyncio
from typing import Any, Protocol, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from v1.auth.models import (
AuthTokenResponse,
AuthUser,
LoginRequest,
RefreshRequest,
SignupRequest,
)
logger = get_logger("v1.auth.service")
class AuthServiceGateway(Protocol):
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
raise NotImplementedError
async def login(self, request: LoginRequest) -> AuthTokenResponse:
raise NotImplementedError
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
raise NotImplementedError
async def logout(self, refresh_token: str | None) -> None:
raise NotImplementedError
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {
"email": request.email,
"password": request.password,
}
if request.display_name:
payload = {
**payload,
"data": {"display_name": request.display_name},
}
try:
sign_up = cast(Any, self._client.auth.sign_up)
response = await asyncio.to_thread(sign_up, payload)
return _map_auth_response(response, "Authentication failed")
except AuthError as exc:
logger.warning("Signup failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Authentication failed"
) from exc
async def login(self, request: LoginRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
logger.warning("Login failed", error=str(exc))
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
except AuthError as exc:
logger.warning("Refresh failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def logout(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
class AuthService:
_gateway: AuthServiceGateway
def __init__(self, gateway: AuthServiceGateway) -> None:
self._gateway = gateway
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
return await self._gateway.signup(request)
async def login(self, request: LoginRequest) -> AuthTokenResponse:
return await self._gateway.login(request)
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
return await self._gateway.refresh(request)
async def logout(self, refresh_token: str | None) -> None:
await self._gateway.logout(refresh_token)
def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise HTTPException(status_code=401, detail=failure_message)
email = getattr(user, "email", None)
if not email:
raise HTTPException(status_code=401, detail=failure_message)
auth_user = AuthUser(id=str(user.id), email=str(email))
return AuthTokenResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+12
View File
@@ -0,0 +1,12 @@
from __future__ import annotations
from services.base.redis import RedisService, redis_service
from services.base.qdrant import QdrantService, qdrant_service
def get_redis_service() -> RedisService:
return redis_service
def get_qdrant_service() -> QdrantService:
return qdrant_service
+38
View File
@@ -0,0 +1,38 @@
from __future__ import annotations
from fastapi import APIRouter, Depends
from services.base.qdrant import QdrantService
from services.base.redis import RedisService
from v1.infra.dependencies import get_qdrant_service, get_redis_service
from v1.infra.schemas import InfraHealthResponse, ServiceHealth
router = APIRouter(prefix="/infra", tags=["infra"])
@router.get("/health", response_model=InfraHealthResponse)
async def infra_health(
redis_service: RedisService = Depends(get_redis_service),
qdrant_service: QdrantService = Depends(get_qdrant_service),
) -> InfraHealthResponse:
if not redis_service.is_initialized:
await redis_service.initialize()
if not qdrant_service.is_initialized:
await qdrant_service.initialize()
redis_health = await redis_service.health_check()
qdrant_health = await qdrant_service.health_check()
status = (
"healthy"
if redis_health["status"] == "healthy" and qdrant_health["status"] == "healthy"
else "unhealthy"
)
return InfraHealthResponse(
status=status,
services={
"redis": ServiceHealth(**redis_health),
"qdrant": ServiceHealth(**qdrant_health),
},
)
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
from typing import Any, Dict, Literal
from pydantic import BaseModel
class ServiceHealth(BaseModel):
status: Literal["healthy", "unhealthy"]
details: Dict[str, Any]
class InfraHealthResponse(BaseModel):
status: Literal["healthy", "unhealthy"]
services: Dict[str, ServiceHealth]
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+93
View File
@@ -0,0 +1,93 @@
from __future__ import annotations
from typing import Annotated
from uuid import UUID
import jwt
from fastapi import Depends, Header, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from core.config.settings import config
from core.db import get_db
from core.logging import get_logger
from core.auth.models import CurrentUser
from v1.profile.repository import SQLAlchemyProfileRepository
from v1.profile.service import ProfileService
logger = get_logger("v1.profile.dependencies")
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
if not authorization:
logger.warning("JWT validation failed: missing authorization header")
raise HTTPException(status_code=401, detail="Unauthorized")
scheme, _, token = authorization.partition(" ")
if scheme.lower() != "bearer" or not token:
logger.warning("JWT validation failed: invalid authorization scheme")
raise HTTPException(status_code=401, detail="Unauthorized")
secret = config.supabase.jwt_secret
if not secret:
logger.error("JWT validation failed: secret not configured")
raise HTTPException(status_code=503, detail="JWT secret not configured")
supabase_url = config.supabase.public_url.rstrip("/")
expected_issuer = f"{supabase_url}/auth/v1"
try:
payload = jwt.decode(
token,
secret,
algorithms=["HS256"],
audience="authenticated",
issuer=expected_issuer,
options={
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"require": ["sub", "aud", "iss", "exp"],
},
)
except jwt.ExpiredSignatureError:
logger.warning("JWT validation failed: token expired")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidAudienceError:
logger.warning("JWT validation failed: invalid audience")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidIssuerError:
logger.warning("JWT validation failed: invalid issuer")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidSignatureError:
logger.warning("JWT validation failed: invalid signature")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.DecodeError:
logger.warning("JWT validation failed: malformed token")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.PyJWTError as exc:
logger.warning(
"JWT validation failed: unknown error", error_type=type(exc).__name__
)
raise HTTPException(status_code=401, detail="Unauthorized") from exc
subject = payload.get("sub")
if not isinstance(subject, str) or not subject:
logger.warning("JWT validation failed: missing or invalid subject claim")
raise HTTPException(status_code=401, detail="Unauthorized")
try:
user_id = UUID(subject)
except ValueError:
logger.warning("JWT validation failed: invalid UUID in subject")
raise HTTPException(status_code=401, detail="Unauthorized")
logger.debug("JWT validation successful", user_id=str(user_id))
return CurrentUser(id=user_id)
def get_profile_service(
session: Annotated[AsyncSession, Depends(get_db)],
user: Annotated[CurrentUser, Depends(get_current_user)],
) -> ProfileService:
repository = SQLAlchemyProfileRepository(session)
return ProfileService(repository=repository, session=session, current_user=user)
+72
View File
@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository
from core.logging import get_logger
from models.profile import Profile
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.profile.repository")
class ProfileRepository(Protocol):
"""Protocol defining the profile repository interface."""
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
"""Get profile by user ID."""
...
async def get_by_username(self, username: str) -> Profile | None:
"""Get profile by username."""
...
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
"""Update profile by user ID. Returns updated profile or None if not found."""
...
class SQLAlchemyProfileRepository(BaseRepository[Profile]):
"""SQLAlchemy implementation of ProfileRepository.
Note: This repository only performs CRUD operations.
- No commit (only flush) - service layer handles transactions
- No auth logic - service layer handles authorization
- No HTTP exceptions - returns None or raises SQLAlchemyError
"""
def __init__(self, session: AsyncSession) -> None:
super().__init__(session, Profile)
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
try:
return await self.get_by_id(user_id)
except SQLAlchemyError:
logger.exception("Profile lookup failed", user_id=str(user_id))
raise
async def get_by_username(self, username: str) -> Profile | None:
try:
return await self.get_one(Profile.username == username)
except SQLAlchemyError:
logger.exception("Profile lookup failed", username=username)
raise
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
if not update_data:
return await self.get_by_user_id(user_id)
try:
return await self.update_by_id(user_id, update_data)
except SQLAlchemyError:
logger.exception("Profile update failed", user_id=str(user_id))
raise
+36
View File
@@ -0,0 +1,36 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, Path
from v1.profile.dependencies import get_profile_service
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
from v1.profile.service import ProfileService
router = APIRouter(prefix="/profile", tags=["profile"])
@router.get("/me", response_model=ProfileResponse)
async def get_me(
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.get_me()
@router.patch("/me", response_model=ProfileResponse)
async def update_me(
payload: ProfileUpdateRequest,
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.update_me(payload)
@router.get("/{username}", response_model=ProfileResponse)
async def get_by_username(
username: Annotated[
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
],
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.get_by_username(username)
+33
View File
@@ -0,0 +1,33 @@
from __future__ import annotations
from pydantic import AnyHttpUrl, BaseModel, Field, field_validator, model_validator
class ProfileResponse(BaseModel):
id: str
username: str
display_name: str | None = None
avatar_url: str | None = None
bio: str | None = None
class ProfileUpdateRequest(BaseModel):
display_name: str | None = Field(default=None, max_length=50)
avatar_url: str | None = Field(default=None)
bio: str | None = Field(default=None, max_length=200)
@field_validator("avatar_url", mode="before")
@classmethod
def validate_avatar_url(cls, v: str | None) -> str | None:
if v is None:
return None
parsed = AnyHttpUrl(v)
if parsed.scheme not in ("http", "https"):
raise ValueError("avatar_url must use http or https scheme")
return str(parsed)
@model_validator(mode="after")
def require_one_field(self) -> "ProfileUpdateRequest":
if self.display_name is None and self.avatar_url is None and self.bio is None:
raise ValueError("At least one field must be provided")
return self
+106
View File
@@ -0,0 +1,106 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from v1.profile.repository import ProfileRepository
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.profile.service")
class ProfileService(BaseService):
"""Profile service handling business logic and transactions.
Responsibilities:
- Authorization checks
- Transaction boundary (commit/rollback)
- Converting ORM models to response schemas
"""
_repository: ProfileRepository
_session: AsyncSession
def __init__(
self,
repository: ProfileRepository,
session: AsyncSession,
current_user: CurrentUser | None,
) -> None:
super().__init__(current_user=current_user)
self._repository = repository
self._session = session
async def get_me(self) -> ProfileResponse:
user_id = self.require_user_id()
try:
profile = await self._repository.get_by_user_id(user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
user_id = self.require_user_id()
update_data: dict[str, str | None] = {
key: value
for key, value in {
"display_name": update.display_name,
"avatar_url": update.avatar_url,
"bio": update.bio,
}.items()
if value is not None
}
if not update_data:
raise HTTPException(status_code=400, detail="No fields to update")
try:
profile = await self._repository.update_by_user_id(user_id, update_data)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
async def get_by_username(self, username: str) -> ProfileResponse:
try:
profile = await self._repository.get_by_username(username)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.auth.router import router as auth_router
from v1.infra.router import router as infra_router
from v1.profile.router import router as profile_router
router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)
router.include_router(infra_router)
router.include_router(profile_router)
@router.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(status="ok")