refactor: align backend layout and supabase infra

Consolidate backend modules/tests under the backend package while syncing Supabase compose/env config and related plans.
This commit is contained in:
qzl
2026-02-05 15:13:06 +08:00
parent 3cfcb11240
commit ad06fe7de4
111 changed files with 5540 additions and 1362 deletions
+111
View File
@@ -0,0 +1,111 @@
## Python Environment
**MUST use uv for dependency management and virtual environment execution.**
- All Python commands: `uv run <command>`
- Add dependencies: `uv add <package>`
- All dependencies declared in `pyproject.toml`
## Logging
**MUST use project logger for all runtime logging.**
- Use project logger from `backend/src/core/logging/*`
- Prohibit: print(), logging.info/warning/error directly
- Required: structured logging with context
- Log levels: DEBUG, INFO, WARNING, ERROR, CRITICAL
## HTTP API Standards
**MUST follow RESTful conventions and RFC 7807 for error responses.**
- Errors must use `application/problem+json` with RFC 7807 fields
- No custom response envelopes for HTTP APIs
- Request and response validation must use Pydantic models
## Environment Variables
**Backend env access MUST go through** `backend/src/core/config/settings.py`.
- Only use `Settings()` / `config` from `core.config.settings`
- Do not call `os.environ`, `os.getenv`, `dotenv`, or manual parsing in backend runtime code
- Tests can set env vars via `monkeypatch.setenv`, and should read values via `Settings()` unless the test is explicitly validating env plumbing
- Canonical principle: one source of truth per setting; no duplicate/derived env vars in backend code
## Code Quality Checks
**Git pre-commit hook enforces code quality before commit.**
Pre-commit hook automatically runs on backend/ directory:
- `ruff check` - code style and linting
- `basedpyright` - type checking with error level
If any error detected, commit is rejected. Fix errors before committing.
Do not bypass or weaken checks (no ignores, disables, or config relaxations). Resolve the underlying issues.
## TDD First Policy
**Principle: tests before implementation.**
### Coverage Requirements
- Minimum coverage: 80%
- Required test types:
- Unit: isolated functions, utilities, components
- Integration: API endpoints, database operations
- E2E: critical user flows (Playwright)
### Limited Exceptions
- Docs-only changes (README, comments, formatting) may skip integration/E2E
- Non-runtime config changes may skip E2E if no behavior changes
- Any runtime code change requires unit + integration + E2E
- If an exception is used, record the reason in the PR/test notes
### Mandatory TDD Workflow
1. Write tests (RED) - they must fail
2. Run tests - confirm failure
3. Implement minimal code (GREEN) - only to pass
4. Run tests - confirm success
5. Refactor (IMPROVE)
6. Verify coverage - must be 80%+
### Enforcement
- Must use the `tdd-guide` agent for new features
- Do not write implementation before tests
- Do not lower coverage requirements
- Must include unit, integration, and E2E tests
## Database Development Rules
### Core Principle
- **Supabase**: authentication (JWT source of truth)
- **Backend**: business authorization (service layer)
- **SQLAlchemy ORM**: data access layer (async + asyncpg, service_role connection)
### Architecture
Use `schemas / repository / service` pattern:
- `schemas.py` — Pydantic models
- `repository.py` — CRUD only, no auth, no commit (only flush), must receive session (never create session/engine)
- `service.py` — authorization + business logic + transaction boundary (must commit/rollback)
- `dependencies.py` — DI (`get_db`, `get_current_user`)
### Auth & Data Access
- Backend must verify JWT signature and expiration (not just decode)
- Extract `user_id` from JWT `sub` claim
- Backend connects with **service_role** (bypasses RLS)
- `owner_id` always derived from JWT, never from client
- Scope queries by owner/org; public access must be explicit
- service_role key is backend-only; never expose credentials
- Prohibit calling Supabase Admin API (service_role key) from repository/service layers
### Migrations
- **Alembic is the single source of truth** for schema migrations
- ORM model changes → `alembic revision --autogenerate`
- Raw SQL (policies, triggers, functions) → `op.execute()`
- Migrations must be reversible; no reliance on generated IDs
### RLS Guidance
- Backend does not rely on RLS for correctness (uses service_role)
- **Backend-only tables**: RLS optional (skip to reduce maintenance)
- **Client-direct tables**: must enable RLS with policies covering select/insert/update/delete
- `alembic_version` must not be exposed to anonymous clients (revoke anon access)
- Business tables that may be exposed to clients should enable defensive RLS even if the backend does not depend on it
+1
View File
@@ -0,0 +1 @@
Generic single-database configuration.
+149
View File
@@ -0,0 +1,149 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
+90
View File
@@ -0,0 +1,90 @@
from __future__ import annotations
import asyncio
import sys
from logging.config import fileConfig
from pathlib import Path
from typing import TYPE_CHECKING, Any
from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import async_engine_from_config
project_root = Path(__file__).resolve().parents[1]
src_path = project_root / "src"
if str(src_path) not in sys.path:
sys.path = [str(src_path), *sys.path]
from core.config.settings import config # noqa: E402
from core.db.base import Base # noqa: E402
from models import Profile # noqa: F401,E402
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
alembic_config = context.config
if alembic_config.config_file_name is not None:
fileConfig(alembic_config.config_file_name)
target_metadata = Base.metadata
def _get_database_url() -> str:
database_url = config.database_url
if not database_url:
raise RuntimeError(
"DATABASE_URL is not configured. Set SOCIAL_INFRA__SUPABASE__DATABASE_URL."
)
return database_url
def _build_config() -> dict[str, Any]:
section = alembic_config.get_section(alembic_config.config_ini_section) or {}
return {**section, "sqlalchemy.url": _get_database_url()}
def run_migrations_offline() -> None:
url = _get_database_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
compare_type=True,
compare_server_default=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def _do_run_migrations(connection: "Connection" | Any) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
connectable = async_engine_from_config(
_build_config(),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(_do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
+28
View File
@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}
@@ -0,0 +1,45 @@
from __future__ import annotations
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision = "20260205_create_profiles_table"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"profiles",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("username", sa.String(length=30), nullable=False),
sa.Column("display_name", sa.String(length=50), nullable=True),
sa.Column("avatar_url", sa.Text(), nullable=True),
sa.Column("bio", sa.String(length=200), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id", name="pk_profiles"),
sa.UniqueConstraint("username", name="uq_profiles_username"),
)
op.create_index("ix_profiles_username", "profiles", ["username"])
op.create_index("ix_profiles_deleted_at", "profiles", ["deleted_at"])
def downgrade() -> None:
op.drop_index("ix_profiles_deleted_at", table_name="profiles")
op.drop_index("ix_profiles_username", table_name="profiles")
op.drop_table("profiles")
@@ -0,0 +1,86 @@
"""enable_rls_security_policies
Revision ID: 85d25a191d06
Revises: 20260205_create_profiles_table
Create Date: 2026-02-05 15:08:33.430692
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "85d25a191d06"
down_revision: Union[str, Sequence[str], None] = "20260205_create_profiles_table"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Enable RLS security policies.
Security measures:
1. Revoke anon role access to alembic_version (internal table)
2. Enable RLS on profiles table
3. Add defensive policies for profiles (deny all public access by default)
Architecture:
- Backend uses service_role connection (bypasses RLS)
- RLS provides defense-in-depth security layer
- Prevents accidental direct PostgREST access
"""
# 1. Revoke anon role access to alembic_version table
op.execute("REVOKE ALL ON TABLE public.alembic_version FROM anon")
op.execute("REVOKE ALL ON TABLE public.alembic_version FROM authenticated")
# 2. Enable RLS on profiles table
op.execute("ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY")
# 3. Add defensive policies for profiles table
# These policies deny all public access by default
# Backend service_role connection bypasses these policies
# Deny all SELECT operations for anon and authenticated roles
op.execute(
"CREATE POLICY profiles_deny_public_select ON public.profiles "
"FOR SELECT TO anon, authenticated USING (false)"
)
# Deny all INSERT operations for anon and authenticated roles
op.execute(
"CREATE POLICY profiles_deny_public_insert ON public.profiles "
"FOR INSERT TO anon, authenticated WITH CHECK (false)"
)
# Deny all UPDATE operations for anon and authenticated roles
op.execute(
"CREATE POLICY profiles_deny_public_update ON public.profiles "
"FOR UPDATE TO anon, authenticated USING (false) WITH CHECK (false)"
)
# Deny all DELETE operations for anon and authenticated roles
op.execute(
"CREATE POLICY profiles_deny_public_delete ON public.profiles "
"FOR DELETE TO anon, authenticated USING (false)"
)
def downgrade() -> None:
"""Rollback RLS security policies."""
# 1. Drop all policies on profiles table
op.execute("DROP POLICY IF EXISTS profiles_deny_public_select ON public.profiles")
op.execute("DROP POLICY IF EXISTS profiles_deny_public_insert ON public.profiles")
op.execute("DROP POLICY IF EXISTS profiles_deny_public_update ON public.profiles")
op.execute("DROP POLICY IF EXISTS profiles_deny_public_delete ON public.profiles")
# 2. Disable RLS on profiles table
op.execute("ALTER TABLE public.profiles DISABLE ROW LEVEL SECURITY")
# 3. Re-grant default privileges to anon role on alembic_version
# (reverting to Alembic's default behavior)
op.execute("GRANT SELECT ON TABLE public.alembic_version TO anon")
op.execute("GRANT SELECT ON TABLE public.alembic_version TO authenticated")
View File
+133
View File
@@ -0,0 +1,133 @@
from __future__ import annotations
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from core.config.settings import config
from core.http.models import HealthResponse
from core.http.response import build_problem_details
from core.logging import configure_logging, get_logger
from v1.router import router as mobile_router
configure_logging(config)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=config.cors.allow_origins,
allow_credentials=config.cors.allow_credentials,
allow_methods=config.cors.allow_methods,
allow_headers=config.cors.allow_headers,
)
app.include_router(mobile_router)
logger = get_logger("api.app")
@app.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(status="ok")
def _build_http_error_response(
request: Request,
exc: Exception,
status_code: int,
detail: object,
) -> JSONResponse:
instance = request.url.path
detail_text = detail if isinstance(detail, str) else "Request failed"
logger.warning(
"HTTP error",
status_code=status_code,
detail=detail_text,
detail_extra=detail,
path=request.url.path,
method=request.method,
)
problem = build_problem_details(
status_code=status_code,
detail=detail_text,
instance=instance,
)
return JSONResponse(
status_code=status_code,
content=problem.model_dump(),
media_type="application/problem+json",
)
@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request,
exc: HTTPException,
) -> JSONResponse:
return _build_http_error_response(
request=request,
exc=exc,
status_code=exc.status_code,
detail=exc.detail,
)
@app.exception_handler(StarletteHTTPException)
async def starlette_http_exception_handler(
request: Request,
exc: StarletteHTTPException,
) -> JSONResponse:
return _build_http_error_response(
request=request,
exc=exc,
status_code=exc.status_code,
detail=exc.detail,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request,
exc: RequestValidationError,
) -> JSONResponse:
instance = request.url.path
logger.warning(
"Request validation error",
path=request.url.path,
method=request.method,
errors=exc.errors(),
)
problem = build_problem_details(
status_code=422,
detail="Invalid request",
instance=instance,
)
return JSONResponse(
status_code=422,
content=problem.model_dump(),
media_type="application/problem+json",
)
@app.exception_handler(Exception)
async def unhandled_exception_handler(
request: Request,
exc: Exception,
) -> JSONResponse:
instance = request.url.path
logger.exception(
"Unhandled error",
path=request.url.path,
method=request.method,
)
problem = build_problem_details(
status_code=500,
detail="Internal Server Error",
instance=instance,
)
return JSONResponse(
status_code=500,
content=problem.model_dump(),
media_type="application/problem+json",
)
View File
+9
View File
@@ -0,0 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID
@dataclass(frozen=True)
class CurrentUser:
id: UUID
+3
View File
@@ -0,0 +1,3 @@
from .settings import Settings, config
__all__ = ["Settings", "config"]
+166
View File
@@ -0,0 +1,166 @@
from __future__ import annotations
from pathlib import Path
from typing import ClassVar, Literal
from urllib.parse import quote
from pydantic import BaseModel, Field, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict
class RuntimeSettings(BaseModel):
environment: Literal["dev", "test", "prod"] = "dev"
debug: bool = True
log_level: str = "INFO"
log_json: bool = True
log_rotation: Literal["time", "size", "none"] = "time"
log_rotation_when: str = "midnight"
log_rotation_interval: int = 1
log_rotation_backup_count: int = 14
log_rotation_max_bytes: int = 10_000_000
log_dir: str = "logs"
log_error_dir: str = "logs/errors"
log_file_name: str = "app.log"
log_error_file_name: str = "error.log"
log_sensitive_fields: list[str] = Field(
default_factory=lambda: [
"password",
"secret",
"token",
"api_key",
"authorization",
"cookie",
"client_ip",
"user_id",
]
)
sql_log_queries: bool = False
class AppSettings(BaseModel):
host: str = "0.0.0.0"
port: int = Field(default=8000, ge=1, le=65535)
reload: bool = True
class CorsSettings(BaseModel):
allow_origins: list[str] = Field(
default_factory=lambda: [
"http://localhost",
"http://localhost:3000",
]
)
allow_credentials: bool = True
allow_methods: list[str] = Field(default_factory=lambda: ["*"])
allow_headers: list[str] = Field(default_factory=lambda: ["*"])
class RedisSettings(BaseModel):
host: str = "redis"
port: int = 6379
password: str | None = None
db: int = 0
socket_connect_timeout: float = 1.0
socket_timeout: float = 1.0
max_connections: int = 10
@computed_field
@property
def url(self) -> str:
if self.password:
password = quote(self.password, safe="")
return f"redis://:{password}@{self.host}:{self.port}/{self.db}"
return f"redis://{self.host}:{self.port}/{self.db}"
class QdrantSettings(BaseModel):
host: str = "qdrant"
port: int = 6333
grpc_port: int = 6334
api_key: str | None = None
https: bool = False
prefer_grpc: bool = True
timeout: int = 5
@computed_field
@property
def url(self) -> str:
scheme = "https" if self.https else "http"
return f"{scheme}://{self.host}:{self.port}"
class SupabaseSettings(BaseModel):
public_scheme: str = "http"
public_host: str = "localhost"
kong_http_port: int = 8000
anon_key: str = "CHANGE_ME"
service_role_key: str = "CHANGE_ME"
jwt_secret: str | None = None
@computed_field
@property
def public_url(self) -> str:
return f"{self.public_scheme}://{self.public_host}:{self.kong_http_port}"
@computed_field
@property
def api_external_url(self) -> str:
return self.public_url
@computed_field
@property
def url(self) -> str:
return self.public_url
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 5432
name: str = "postgres"
user: str = "postgres"
password: str = "CHANGE_ME"
@computed_field
@property
def url(self) -> str:
password = quote(self.password, safe="")
return (
f"postgresql+asyncpg://{self.user}:{password}"
f"@{self.host}:{self.port}/{self.name}"
)
def _resolve_env_file() -> str:
current = Path(__file__).resolve()
for parent in [current, *current.parents]:
candidate = parent / ".env"
if candidate.is_file():
return str(candidate)
return ".env"
class Settings(BaseSettings):
runtime: RuntimeSettings = RuntimeSettings()
app: AppSettings = AppSettings()
cors: CorsSettings = CorsSettings()
redis: RedisSettings = RedisSettings()
qdrant: QdrantSettings = QdrantSettings()
supabase: SupabaseSettings = SupabaseSettings()
database: DatabaseSettings = DatabaseSettings()
@computed_field
@property
def database_url(self) -> str:
return self.database.url
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_file=_resolve_env_file(),
env_prefix="SOCIAL_",
env_nested_delimiter="__",
case_sensitive=False,
extra="ignore",
)
config = Settings()
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
from core.db.session import AsyncSessionLocal, engine, get_db
__all__ = ["AsyncSessionLocal", "engine", "get_db"]
+37
View File
@@ -0,0 +1,37 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
"""Base class for all ORM models."""
pass
class TimestampMixin:
"""Adds created_at and updated_at timestamps."""
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
class SoftDeleteMixin:
"""Adds soft delete timestamp column."""
deleted_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
)
+84
View File
@@ -0,0 +1,84 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar
from sqlalchemy import Select, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from core.db.base import Base
ModelType = TypeVar("ModelType", bound=Base)
class BaseRepository(Generic[ModelType]):
_session: AsyncSession
_model: type[ModelType]
def __init__(self, session: AsyncSession, model: type[ModelType]) -> None:
self._session = session
self._model = model
def _deleted_at_column(self) -> Any | None:
return getattr(self._model, "deleted_at", None)
def _apply_soft_delete_filter(self, stmt: Select) -> Select:
deleted_at = self._deleted_at_column()
if deleted_at is None:
return stmt
return stmt.where(deleted_at.is_(None))
async def get_by_id(self, entity_id: Any) -> ModelType | None:
id_column = getattr(self._model, "id")
stmt = select(self._model).where(id_column == entity_id)
stmt = self._apply_soft_delete_filter(stmt)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def get_one(self, *filters: Any) -> ModelType | None:
stmt = select(self._model).where(*filters)
stmt = self._apply_soft_delete_filter(stmt)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def update_by_id(
self, entity_id: Any, update_data: dict[str, Any]
) -> ModelType | None:
if not update_data:
return await self.get_by_id(entity_id)
id_column = getattr(self._model, "id")
stmt = update(self._model).where(id_column == entity_id)
deleted_at = self._deleted_at_column()
if deleted_at is not None:
stmt = stmt.where(deleted_at.is_(None))
stmt = stmt.values(**update_data).returning(self._model)
try:
result = await self._session.execute(stmt)
await self._session.flush()
return result.scalar_one_or_none()
except SQLAlchemyError:
raise
async def soft_delete_by_id(self, entity_id: Any) -> ModelType | None:
deleted_at = self._deleted_at_column()
if deleted_at is None:
raise ValueError("Soft delete is not supported for this model")
id_column = getattr(self._model, "id")
stmt = (
update(self._model)
.where(id_column == entity_id)
.where(deleted_at.is_(None))
.values(deleted_at=datetime.now(timezone.utc))
.returning(self._model)
)
try:
result = await self._session.execute(stmt)
await self._session.flush()
return result.scalar_one_or_none()
except SQLAlchemyError:
raise
+22
View File
@@ -0,0 +1,22 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from core.auth.models import CurrentUser
class BaseService:
_current_user: CurrentUser | None
def __init__(self, current_user: CurrentUser | None) -> None:
self._current_user = current_user
def require_current_user(self) -> CurrentUser:
if self._current_user is None:
raise HTTPException(status_code=401, detail="Unauthorized")
return self._current_user
def require_user_id(self) -> UUID:
return self.require_current_user().id
+34
View File
@@ -0,0 +1,34 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.config.settings import config
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
engine: AsyncEngine = create_async_engine(
config.database_url,
echo=config.runtime.sql_log_queries,
pool_pre_ping=True,
)
AsyncSessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""Dependency that provides a database session.
The session is automatically closed when the request completes.
Note: The caller (service layer) is responsible for commit/rollback.
"""
async with AsyncSessionLocal() as session:
yield session
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
from core.http.response import ProblemDetails, build_problem_details
__all__ = ["ProblemDetails", "build_problem_details"]
+7
View File
@@ -0,0 +1,7 @@
from __future__ import annotations
from pydantic import BaseModel
class HealthResponse(BaseModel):
status: str
+31
View File
@@ -0,0 +1,31 @@
from __future__ import annotations
from http import HTTPStatus
from pydantic import BaseModel
class ProblemDetails(BaseModel):
type: str = "about:blank"
title: str
status: int
detail: str
instance: str | None = None
def build_problem_details(
*,
status_code: int,
detail: str,
type_value: str = "about:blank",
title: str | None = None,
instance: str | None = None,
) -> ProblemDetails:
resolved_title = title or HTTPStatus(status_code).phrase
return ProblemDetails(
type=type_value,
title=resolved_title,
status=status_code,
detail=detail,
instance=instance,
)
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
from core.logging import celery
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context, get_context
from core.logging.logger import get_logger
__all__ = [
"bind_context",
"celery",
"clear_context",
"configure_logging",
"get_context",
"get_logger",
]
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import cast
from celery import Celery, signals
from core.config.settings import Settings
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context
@dataclass(frozen=True)
class CelerySignalHandlers:
on_setup_logging: Callable[..., None]
on_after_setup_task_logger: Callable[..., None]
on_task_prerun: Callable[..., None]
on_task_postrun: Callable[..., None]
def build_celery_signal_handlers(
settings: Settings | None = None,
) -> CelerySignalHandlers:
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_task_prerun(*_args: object, **kwargs: object) -> None:
task_id = cast(str | None, kwargs.get("task_id"))
task = kwargs.get("task")
task_name = getattr(task, "name", None)
bind_context(task_id=task_id, task_name=task_name)
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
clear_context()
return CelerySignalHandlers(
on_setup_logging=on_setup_logging,
on_after_setup_task_logger=on_after_setup_task_logger,
on_task_prerun=on_task_prerun,
on_task_postrun=on_task_postrun,
)
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
app.conf.worker_hijack_root_logger = False
handlers = build_celery_signal_handlers(settings)
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
signals.after_setup_task_logger.connect(
handlers.on_after_setup_task_logger, weak=False
)
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
+103
View File
@@ -0,0 +1,103 @@
from __future__ import annotations
import logging
from logging.config import dictConfig
from pathlib import Path
import structlog
from core.config.settings import RuntimeSettings, Settings
from core.logging.formatters import (
build_plain_formatter,
build_processor_formatter,
ensure_message_key,
)
from core.logging.filters import build_sensitive_data_processor
from core.logging.handlers import build_file_handler_config
def _ensure_log_dirs(runtime: RuntimeSettings) -> None:
Path(runtime.log_dir).mkdir(parents=True, exist_ok=True)
Path(runtime.log_error_dir).mkdir(parents=True, exist_ok=True)
def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
log_dir = Path(runtime.log_dir)
error_dir = Path(runtime.log_error_dir)
formatter_name = "json" if runtime.log_json else "plain"
file_handler = build_file_handler_config(
runtime,
file_path=log_dir / runtime.log_file_name,
level=runtime.log_level,
formatter=formatter_name,
)
error_handler = build_file_handler_config(
runtime,
file_path=error_dir / runtime.log_error_file_name,
level="ERROR",
formatter=formatter_name,
filters=["error_only"],
)
return {
"version": 1,
"disable_existing_loggers": False,
"filters": {
"error_only": {
"()": "core.logging.filters.ErrorLevelFilter",
}
},
"formatters": {
"json": {
"()": build_processor_formatter,
"sensitive_fields": runtime.log_sensitive_fields,
},
"plain": {
"()": build_plain_formatter,
"sensitive_fields": runtime.log_sensitive_fields,
},
},
"handlers": {
"file": file_handler,
"error": error_handler,
},
"root": {
"handlers": ["file", "error"],
"level": runtime.log_level,
},
}
def configure_logging(settings: Settings | None = None) -> None:
active_settings = settings or Settings()
runtime = active_settings.runtime
try:
_ensure_log_dirs(runtime)
dictConfig(build_logging_config(runtime))
except (OSError, ValueError) as exc:
logging.basicConfig(level=runtime.log_level)
logging.getLogger(__name__).error("Logging setup failed", exc_info=exc)
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
build_sensitive_data_processor(runtime.log_sensitive_fields),
ensure_message_key,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
cache_logger_on_first_use=True,
)
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
from structlog import contextvars
def bind_context(**values: object) -> None:
contextvars.bind_contextvars(**values)
def clear_context() -> None:
contextvars.clear_contextvars()
def get_context() -> dict[str, object]:
return contextvars.get_contextvars()
+56
View File
@@ -0,0 +1,56 @@
from __future__ import annotations
import logging
import re
from collections.abc import Callable
from typing import cast
from structlog.types import EventDict
_NORMALIZE_PATTERN = re.compile(r"[^a-z0-9]")
def _normalize_key(value: str) -> str:
return _NORMALIZE_PATTERN.sub("", value.lower())
def _is_sensitive_key(key: object, sensitive_fields: set[str]) -> bool:
normalized_key = _normalize_key(str(key))
return normalized_key in sensitive_fields or any(
fragment in normalized_key for fragment in sensitive_fields
)
def _redact_value(value: object, sensitive_fields: set[str]) -> object:
if isinstance(value, dict):
typed_value = cast(dict[str, object], value)
return {
key: (
"[REDACTED]"
if _is_sensitive_key(key, sensitive_fields)
else _redact_value(inner, sensitive_fields)
)
for key, inner in typed_value.items()
}
if isinstance(value, list):
return [_redact_value(item, sensitive_fields) for item in value]
return value
def build_sensitive_data_processor(
sensitive_fields: list[str],
) -> Callable[[object, str, EventDict], EventDict]:
normalized = {_normalize_key(field) for field in sensitive_fields}
def processor(
_logger: object, _method_name: str, event_dict: EventDict
) -> EventDict:
return cast(EventDict, _redact_value(event_dict, normalized))
return processor
class ErrorLevelFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return record.levelno >= logging.ERROR
+81
View File
@@ -0,0 +1,81 @@
from __future__ import annotations
from structlog.dev import ConsoleRenderer
from structlog.processors import JSONRenderer
from structlog.stdlib import ProcessorFormatter
from structlog.types import EventDict
import structlog
from core.logging.filters import build_sensitive_data_processor
def ensure_message_key(
_logger: object, _method_name: str, event_dict: EventDict
) -> EventDict:
if "message" in event_dict:
return event_dict
if "event" not in event_dict:
return event_dict
without_event = {key: value for key, value in event_dict.items() if key != "event"}
return {**without_event, "message": event_dict["event"]}
def build_processor_formatter(
sensitive_fields: list[str] | None = None,
) -> ProcessorFormatter:
redact = build_sensitive_data_processor(sensitive_fields or [])
return ProcessorFormatter(
foreign_pre_chain=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
structlog.stdlib.ExtraAdder(),
ensure_message_key,
],
processors=[
redact,
ensure_message_key,
ProcessorFormatter.remove_processors_meta,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
JSONRenderer(sort_keys=True),
],
)
def build_plain_formatter(
sensitive_fields: list[str] | None = None,
) -> ProcessorFormatter:
redact = build_sensitive_data_processor(sensitive_fields or [])
return ProcessorFormatter(
foreign_pre_chain=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
structlog.stdlib.ExtraAdder(),
ensure_message_key,
],
processors=[
redact,
ensure_message_key,
ProcessorFormatter.remove_processors_meta,
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
ConsoleRenderer(colors=False),
],
)
+46
View File
@@ -0,0 +1,46 @@
from __future__ import annotations
from pathlib import Path
from core.config.settings import RuntimeSettings
def build_file_handler_config(
runtime: RuntimeSettings,
file_path: Path,
level: str,
formatter: str,
filters: list[str] | None = None,
) -> dict[str, object]:
filter_list = list(filters or [])
base_config: dict[str, object] = {
"level": level,
"formatter": formatter,
"filename": str(file_path),
"encoding": "utf-8",
}
if filter_list:
base_config = {**base_config, "filters": filter_list}
if runtime.log_rotation == "time":
return {
**base_config,
"class": "logging.handlers.TimedRotatingFileHandler",
"when": runtime.log_rotation_when,
"interval": runtime.log_rotation_interval,
"backupCount": runtime.log_rotation_backup_count,
}
if runtime.log_rotation == "size":
return {
**base_config,
"class": "logging.handlers.RotatingFileHandler",
"maxBytes": runtime.log_rotation_max_bytes,
"backupCount": runtime.log_rotation_backup_count,
}
return {
**base_config,
"class": "logging.FileHandler",
}
+7
View File
@@ -0,0 +1,7 @@
from __future__ import annotations
import structlog
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
return structlog.get_logger(name)
+84
View File
@@ -0,0 +1,84 @@
from __future__ import annotations
import re
from collections.abc import MutableMapping
from typing import cast
from uuid import uuid4
from fastapi import FastAPI, Request
from starlette.requests import Request as StarletteRequest
from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
from core.logging.context import bind_context, clear_context
from core.logging.logger import get_logger
class RequestContextMiddleware:
app: ASGIApp
_header_name: str
_request_id_pattern: re.Pattern[str]
def __init__(self, app: ASGIApp, header_name: str = "X-Request-ID") -> None:
self.app = app
self._header_name = header_name
self._request_id_pattern = re.compile(r"^[A-Za-z0-9_-]{8,64}$")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope.get("type") != "http":
await self.app(scope, receive, send)
return
request = StarletteRequest(scope, receive=receive)
request_id = self._normalize_request_id(request.headers.get(self._header_name))
client_ip = request.client.host if request.client else None
user_id = getattr(request.state, "user_id", None)
request.state.request_id = request_id
bind_context(
request_id=request_id,
method=request.method,
path=request.url.path,
client_ip=client_ip,
user_id=user_id,
)
async def send_wrapper(message: MutableMapping[str, object]) -> None:
if message.get("type") == "http.response.start":
raw_headers = message.get("headers")
headers = list(cast(list[tuple[bytes, bytes]], raw_headers or []))
header_key = self._header_name.lower().encode()
if not any(item[0].lower() == header_key for item in headers):
headers.append((header_key, request_id.encode()))
message = {**message, "headers": headers}
await send(message)
try:
await self.app(scope, receive, send_wrapper)
finally:
clear_context()
def _normalize_request_id(self, request_id: str | None) -> str:
if request_id and self._request_id_pattern.match(request_id):
return request_id
return str(uuid4())
def register_exception_handlers(app: FastAPI) -> None:
logger = get_logger("core.logging.exception")
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> Response:
request_id = getattr(request.state, "request_id", None)
logger.exception(
"Unhandled exception",
error_type=exc.__class__.__name__,
request_id=request_id,
)
headers = {"X-Request-ID": request_id} if request_id else None
return JSONResponse(
status_code=500,
content={"detail": "Internal Server Error"},
headers=headers,
)
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
from models.profile import Profile
__all__ = ["Profile"]
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
import uuid
from sqlalchemy import String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class Profile(TimestampMixin, SoftDeleteMixin, Base):
"""User profile model.
Note: The `id` column references auth.users(id) in Supabase.
This is a business table managed by SQLAlchemy, with the foreign key
relationship to Supabase's auth schema handled at the database level.
"""
__tablename__: str = "profiles"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
)
username: Mapped[str] = mapped_column(
String(30),
unique=True,
nullable=False,
index=True,
)
display_name: Mapped[str | None] = mapped_column(
String(50),
nullable=True,
)
avatar_url: Mapped[str | None] = mapped_column(
Text,
nullable=True,
)
bio: Mapped[str | None] = mapped_column(
String(200),
nullable=True,
)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+21
View File
@@ -0,0 +1,21 @@
from __future__ import annotations
from services.base.qdrant import QdrantService, qdrant_service
from services.base.redis import RedisService, redis_service
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
register_service,
register_service_instance,
)
__all__ = [
"BaseServiceProvider",
"QdrantService",
"RedisService",
"ServiceRegistry",
"qdrant_service",
"redis_service",
"register_service",
"register_service_instance",
]
+94
View File
@@ -0,0 +1,94 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
from qdrant_client import QdrantClient
from core.config.settings import QdrantSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class QdrantService(BaseServiceProvider):
def __init__(self, settings: QdrantSettings | None = None) -> None:
super().__init__("qdrant")
self._settings = settings or config.qdrant
self._client: Optional[QdrantClient] = None
def _build_client(self) -> QdrantClient:
return QdrantClient(
url=self._settings.url,
api_key=self._settings.api_key,
timeout=self._settings.timeout,
prefer_grpc=self._settings.prefer_grpc,
)
def _require_client(self) -> QdrantClient:
client = self._client
if client is None:
raise RuntimeError("Qdrant client is not initialized")
return client
async def initialize(self, **_: Any) -> bool:
try:
client = self._build_client()
collections = await asyncio.to_thread(client.get_collections)
self.logger.info(
"Qdrant service initialized",
collections_count=len(collections.collections),
)
self._client = client
self._set_initialized(True)
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Qdrant service initialization failed", error=str(exc))
self._client = None
self._set_initialized(False)
return False
async def close(self) -> bool:
client = self._client
if client is None:
return True
try:
close = getattr(client, "close", None)
if callable(close):
await asyncio.to_thread(close)
self.logger.info("Qdrant service closed")
self._client = None
self._set_initialized(False)
return True
except Exception as exc: # noqa: BLE001
self.logger.exception("Qdrant service close failed", error=str(exc))
self._client = None
self._set_initialized(False)
return False
async def health_check(self) -> Dict[str, Any]:
client = self._client
if client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
collections = await asyncio.to_thread(client.get_collections)
return {
"status": "healthy",
"details": {
"connected": True,
"collections_count": len(collections.collections),
"collections": [
collection.name for collection in collections.collections[:5]
],
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Qdrant health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> QdrantClient:
return self._require_client()
qdrant_service: QdrantService = register_service_instance("qdrant", QdrantService())
__all__ = ["QdrantService", "qdrant_service"]
+97
View File
@@ -0,0 +1,97 @@
from __future__ import annotations
import inspect
from typing import Any, Dict, Optional
import redis.asyncio as redis
from core.config.settings import RedisSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class RedisService(BaseServiceProvider):
def __init__(self, settings: RedisSettings | None = None) -> None:
super().__init__("redis")
self._settings = settings or config.redis
self._client: Optional[redis.Redis] = None
def _build_client(self) -> redis.Redis:
return redis.from_url(
self._settings.url,
decode_responses=True,
socket_connect_timeout=self._settings.socket_connect_timeout,
socket_timeout=self._settings.socket_timeout,
max_connections=self._settings.max_connections,
)
def _require_client(self) -> redis.Redis:
client = self._client
if client is None:
raise RuntimeError("Redis client is not initialized")
return client
async def initialize(self, **_: Any) -> bool:
try:
client = self._build_client()
ping_result = client.ping()
if inspect.isawaitable(ping_result):
await ping_result
self._client = client
self._set_initialized(True)
self.logger.info("Redis service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Redis service initialization failed", error=str(exc))
self._client = None
self._set_initialized(False)
return False
async def close(self) -> bool:
client = self._client
if client is None:
return True
try:
await client.aclose()
self.logger.info("Redis service closed")
self._client = None
self._set_initialized(False)
return True
except Exception as exc: # noqa: BLE001
self.logger.exception("Redis service close failed", error=str(exc))
return False
async def health_check(self) -> Dict[str, Any]:
client = self._client
if client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
ping_result = client.ping()
ping = (
await ping_result if inspect.isawaitable(ping_result) else ping_result
)
info_result = client.info()
info = (
await info_result if inspect.isawaitable(info_result) else info_result
)
return {
"status": "healthy" if ping else "unhealthy",
"details": {
"ping": ping,
"redis_version": info.get("redis_version"),
"connected_clients": info.get("connected_clients"),
"used_memory": info.get("used_memory_human"),
"uptime_in_seconds": info.get("uptime_in_seconds"),
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Redis health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> redis.Redis:
return self._require_client()
redis_service: RedisService = register_service_instance("redis", RedisService())
__all__ = ["RedisService", "redis_service"]
@@ -0,0 +1,84 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, TypeVar
from core.logging import get_logger
class BaseServiceProvider(ABC):
def __init__(self, service_name: str) -> None:
self.service_name = service_name
self._initialized = False
self.logger = get_logger("services.base").bind(service=service_name)
@abstractmethod
async def initialize(self, **kwargs: Any) -> bool:
raise NotImplementedError
@abstractmethod
async def close(self) -> bool:
raise NotImplementedError
@abstractmethod
async def health_check(self) -> Dict[str, Any]:
raise NotImplementedError
@property
def is_initialized(self) -> bool:
return self._initialized
def _set_initialized(self, value: bool) -> None:
self._initialized = value
def get_service_info(self) -> Dict[str, Any]:
return {
"name": self.service_name,
"initialized": self._initialized,
"type": self.__class__.__name__,
}
class ServiceRegistry:
_services: Dict[str, Callable[..., BaseServiceProvider]] = {}
@classmethod
def register(
cls, service_name: str, factory: Callable[..., BaseServiceProvider]
) -> None:
cls._services = {**cls._services, service_name: factory}
@classmethod
def get_service_factory(
cls, service_name: str
) -> Optional[Callable[..., BaseServiceProvider]]:
return cls._services.get(service_name)
@classmethod
def list_services(cls) -> list[str]:
return sorted(cls._services.keys())
@classmethod
def create_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
factory = cls.get_service_factory(service_name)
if not factory:
return None
return factory(**kwargs)
def register_service(service_name: str) -> Callable[[type], type]:
def decorator(service_class: type) -> type:
ServiceRegistry.register(service_name, service_class)
return service_class
return decorator
TService = TypeVar("TService", bound=BaseServiceProvider)
def register_service_instance(service_name: str, service: TService) -> TService:
ServiceRegistry.register(service_name, lambda: service)
return service
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+7
View File
@@ -0,0 +1,7 @@
from __future__ import annotations
from v1.auth.service import AuthService, SupabaseAuthGateway
def get_auth_service() -> AuthService:
return AuthService(gateway=SupabaseAuthGateway())
+35
View File
@@ -0,0 +1,35 @@
from __future__ import annotations
from pydantic import BaseModel, EmailStr, Field
class SignupRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=6)
display_name: str | None = None
class LoginRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=6)
class RefreshRequest(BaseModel):
refresh_token: str = Field(min_length=1)
class LogoutRequest(BaseModel):
refresh_token: str = Field(min_length=1)
class AuthUser(BaseModel):
id: str
email: EmailStr
class AuthTokenResponse(BaseModel):
access_token: str
refresh_token: str
expires_in: int
token_type: str
user: AuthUser
+49
View File
@@ -0,0 +1,49 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Response
from v1.auth.dependencies import get_auth_service
from v1.auth.models import (
AuthTokenResponse,
LoginRequest,
LogoutRequest,
RefreshRequest,
SignupRequest,
)
from v1.auth.service import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/signup", response_model=AuthTokenResponse)
async def signup(
payload: SignupRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
return await service.signup(payload)
@router.post("/login", response_model=AuthTokenResponse)
async def login(
payload: LoginRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
return await service.login(payload)
@router.post("/refresh", response_model=AuthTokenResponse)
async def refresh(
payload: RefreshRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
return await service.refresh(payload)
@router.post("/logout", status_code=204)
async def logout(
payload: LogoutRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await service.logout(payload.refresh_token)
return Response(status_code=204)
+147
View File
@@ -0,0 +1,147 @@
from __future__ import annotations
import asyncio
from typing import Any, Protocol, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from v1.auth.models import (
AuthTokenResponse,
AuthUser,
LoginRequest,
RefreshRequest,
SignupRequest,
)
logger = get_logger("v1.auth.service")
class AuthServiceGateway(Protocol):
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
raise NotImplementedError
async def login(self, request: LoginRequest) -> AuthTokenResponse:
raise NotImplementedError
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
raise NotImplementedError
async def logout(self, refresh_token: str | None) -> None:
raise NotImplementedError
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {
"email": request.email,
"password": request.password,
}
if request.display_name:
payload = {
**payload,
"data": {"display_name": request.display_name},
}
try:
sign_up = cast(Any, self._client.auth.sign_up)
response = await asyncio.to_thread(sign_up, payload)
return _map_auth_response(response, "Authentication failed")
except AuthError as exc:
logger.warning("Signup failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Authentication failed"
) from exc
async def login(self, request: LoginRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
logger.warning("Login failed", error=str(exc))
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
except AuthError as exc:
logger.warning("Refresh failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def logout(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
class AuthService:
_gateway: AuthServiceGateway
def __init__(self, gateway: AuthServiceGateway) -> None:
self._gateway = gateway
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
return await self._gateway.signup(request)
async def login(self, request: LoginRequest) -> AuthTokenResponse:
return await self._gateway.login(request)
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
return await self._gateway.refresh(request)
async def logout(self, refresh_token: str | None) -> None:
await self._gateway.logout(refresh_token)
def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise HTTPException(status_code=401, detail=failure_message)
email = getattr(user, "email", None)
if not email:
raise HTTPException(status_code=401, detail=failure_message)
auth_user = AuthUser(id=str(user.id), email=str(email))
return AuthTokenResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+12
View File
@@ -0,0 +1,12 @@
from __future__ import annotations
from services.base.redis import RedisService, redis_service
from services.base.qdrant import QdrantService, qdrant_service
def get_redis_service() -> RedisService:
return redis_service
def get_qdrant_service() -> QdrantService:
return qdrant_service
+38
View File
@@ -0,0 +1,38 @@
from __future__ import annotations
from fastapi import APIRouter, Depends
from services.base.qdrant import QdrantService
from services.base.redis import RedisService
from v1.infra.dependencies import get_qdrant_service, get_redis_service
from v1.infra.schemas import InfraHealthResponse, ServiceHealth
router = APIRouter(prefix="/infra", tags=["infra"])
@router.get("/health", response_model=InfraHealthResponse)
async def infra_health(
redis_service: RedisService = Depends(get_redis_service),
qdrant_service: QdrantService = Depends(get_qdrant_service),
) -> InfraHealthResponse:
if not redis_service.is_initialized:
await redis_service.initialize()
if not qdrant_service.is_initialized:
await qdrant_service.initialize()
redis_health = await redis_service.health_check()
qdrant_health = await qdrant_service.health_check()
status = (
"healthy"
if redis_health["status"] == "healthy" and qdrant_health["status"] == "healthy"
else "unhealthy"
)
return InfraHealthResponse(
status=status,
services={
"redis": ServiceHealth(**redis_health),
"qdrant": ServiceHealth(**qdrant_health),
},
)
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
from typing import Any, Dict, Literal
from pydantic import BaseModel
class ServiceHealth(BaseModel):
status: Literal["healthy", "unhealthy"]
details: Dict[str, Any]
class InfraHealthResponse(BaseModel):
status: Literal["healthy", "unhealthy"]
services: Dict[str, ServiceHealth]
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+93
View File
@@ -0,0 +1,93 @@
from __future__ import annotations
from typing import Annotated
from uuid import UUID
import jwt
from fastapi import Depends, Header, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from core.config.settings import config
from core.db import get_db
from core.logging import get_logger
from core.auth.models import CurrentUser
from v1.profile.repository import SQLAlchemyProfileRepository
from v1.profile.service import ProfileService
logger = get_logger("v1.profile.dependencies")
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
if not authorization:
logger.warning("JWT validation failed: missing authorization header")
raise HTTPException(status_code=401, detail="Unauthorized")
scheme, _, token = authorization.partition(" ")
if scheme.lower() != "bearer" or not token:
logger.warning("JWT validation failed: invalid authorization scheme")
raise HTTPException(status_code=401, detail="Unauthorized")
secret = config.supabase.jwt_secret
if not secret:
logger.error("JWT validation failed: secret not configured")
raise HTTPException(status_code=503, detail="JWT secret not configured")
supabase_url = config.supabase.public_url.rstrip("/")
expected_issuer = f"{supabase_url}/auth/v1"
try:
payload = jwt.decode(
token,
secret,
algorithms=["HS256"],
audience="authenticated",
issuer=expected_issuer,
options={
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"require": ["sub", "aud", "iss", "exp"],
},
)
except jwt.ExpiredSignatureError:
logger.warning("JWT validation failed: token expired")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidAudienceError:
logger.warning("JWT validation failed: invalid audience")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidIssuerError:
logger.warning("JWT validation failed: invalid issuer")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidSignatureError:
logger.warning("JWT validation failed: invalid signature")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.DecodeError:
logger.warning("JWT validation failed: malformed token")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.PyJWTError as exc:
logger.warning(
"JWT validation failed: unknown error", error_type=type(exc).__name__
)
raise HTTPException(status_code=401, detail="Unauthorized") from exc
subject = payload.get("sub")
if not isinstance(subject, str) or not subject:
logger.warning("JWT validation failed: missing or invalid subject claim")
raise HTTPException(status_code=401, detail="Unauthorized")
try:
user_id = UUID(subject)
except ValueError:
logger.warning("JWT validation failed: invalid UUID in subject")
raise HTTPException(status_code=401, detail="Unauthorized")
logger.debug("JWT validation successful", user_id=str(user_id))
return CurrentUser(id=user_id)
def get_profile_service(
session: Annotated[AsyncSession, Depends(get_db)],
user: Annotated[CurrentUser, Depends(get_current_user)],
) -> ProfileService:
repository = SQLAlchemyProfileRepository(session)
return ProfileService(repository=repository, session=session, current_user=user)
+72
View File
@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository
from core.logging import get_logger
from models.profile import Profile
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.profile.repository")
class ProfileRepository(Protocol):
"""Protocol defining the profile repository interface."""
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
"""Get profile by user ID."""
...
async def get_by_username(self, username: str) -> Profile | None:
"""Get profile by username."""
...
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
"""Update profile by user ID. Returns updated profile or None if not found."""
...
class SQLAlchemyProfileRepository(BaseRepository[Profile]):
"""SQLAlchemy implementation of ProfileRepository.
Note: This repository only performs CRUD operations.
- No commit (only flush) - service layer handles transactions
- No auth logic - service layer handles authorization
- No HTTP exceptions - returns None or raises SQLAlchemyError
"""
def __init__(self, session: AsyncSession) -> None:
super().__init__(session, Profile)
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
try:
return await self.get_by_id(user_id)
except SQLAlchemyError:
logger.exception("Profile lookup failed", user_id=str(user_id))
raise
async def get_by_username(self, username: str) -> Profile | None:
try:
return await self.get_one(Profile.username == username)
except SQLAlchemyError:
logger.exception("Profile lookup failed", username=username)
raise
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
if not update_data:
return await self.get_by_user_id(user_id)
try:
return await self.update_by_id(user_id, update_data)
except SQLAlchemyError:
logger.exception("Profile update failed", user_id=str(user_id))
raise
+36
View File
@@ -0,0 +1,36 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, Path
from v1.profile.dependencies import get_profile_service
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
from v1.profile.service import ProfileService
router = APIRouter(prefix="/profile", tags=["profile"])
@router.get("/me", response_model=ProfileResponse)
async def get_me(
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.get_me()
@router.patch("/me", response_model=ProfileResponse)
async def update_me(
payload: ProfileUpdateRequest,
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.update_me(payload)
@router.get("/{username}", response_model=ProfileResponse)
async def get_by_username(
username: Annotated[
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
],
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.get_by_username(username)
+33
View File
@@ -0,0 +1,33 @@
from __future__ import annotations
from pydantic import AnyHttpUrl, BaseModel, Field, field_validator, model_validator
class ProfileResponse(BaseModel):
id: str
username: str
display_name: str | None = None
avatar_url: str | None = None
bio: str | None = None
class ProfileUpdateRequest(BaseModel):
display_name: str | None = Field(default=None, max_length=50)
avatar_url: str | None = Field(default=None)
bio: str | None = Field(default=None, max_length=200)
@field_validator("avatar_url", mode="before")
@classmethod
def validate_avatar_url(cls, v: str | None) -> str | None:
if v is None:
return None
parsed = AnyHttpUrl(v)
if parsed.scheme not in ("http", "https"):
raise ValueError("avatar_url must use http or https scheme")
return str(parsed)
@model_validator(mode="after")
def require_one_field(self) -> "ProfileUpdateRequest":
if self.display_name is None and self.avatar_url is None and self.bio is None:
raise ValueError("At least one field must be provided")
return self
+106
View File
@@ -0,0 +1,106 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from v1.profile.repository import ProfileRepository
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.profile.service")
class ProfileService(BaseService):
"""Profile service handling business logic and transactions.
Responsibilities:
- Authorization checks
- Transaction boundary (commit/rollback)
- Converting ORM models to response schemas
"""
_repository: ProfileRepository
_session: AsyncSession
def __init__(
self,
repository: ProfileRepository,
session: AsyncSession,
current_user: CurrentUser | None,
) -> None:
super().__init__(current_user=current_user)
self._repository = repository
self._session = session
async def get_me(self) -> ProfileResponse:
user_id = self.require_user_id()
try:
profile = await self._repository.get_by_user_id(user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
user_id = self.require_user_id()
update_data: dict[str, str | None] = {
key: value
for key, value in {
"display_name": update.display_name,
"avatar_url": update.avatar_url,
"bio": update.bio,
}.items()
if value is not None
}
if not update_data:
raise HTTPException(status_code=400, detail="No fields to update")
try:
profile = await self._repository.update_by_user_id(user_id, update_data)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
async def get_by_username(self, username: str) -> ProfileResponse:
try:
profile = await self._repository.get_by_username(username)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.auth.router import router as auth_router
from v1.infra.router import router as infra_router
from v1.profile.router import router as profile_router
router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)
router.include_router(infra_router)
router.include_router(profile_router)
@router.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(status="ok")
+11
View File
@@ -0,0 +1,11 @@
from __future__ import annotations
import sys
from pathlib import Path
def pytest_configure() -> None:
root = Path(__file__).resolve().parents[2]
src_path = root / "backend" / "src"
if str(src_path) not in sys.path:
sys.path.insert(0, str(src_path))
+134
View File
@@ -0,0 +1,134 @@
from __future__ import annotations
import json
import socket
import threading
import time
from playwright.sync_api import sync_playwright
import uvicorn
from app import app
from v1.auth.dependencies import get_auth_service
from v1.auth.models import (
AuthTokenResponse,
AuthUser,
LoginRequest,
RefreshRequest,
SignupRequest,
)
from v1.auth.service import AuthService
class FakeE2EAuthService(AuthService):
def __init__(self) -> None:
self._user = AuthUser(id="user-1", email="user@example.com")
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
return AuthTokenResponse(
access_token="access-1",
refresh_token="refresh-1",
expires_in=3600,
token_type="bearer",
user=self._user,
)
async def login(self, request: LoginRequest) -> AuthTokenResponse:
return AuthTokenResponse(
access_token="access-2",
refresh_token="refresh-2",
expires_in=3600,
token_type="bearer",
user=self._user,
)
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
return AuthTokenResponse(
access_token="access-3",
refresh_token="refresh-3",
expires_in=3600,
token_type="bearer",
user=self._user,
)
async def logout(self, refresh_token: str | None) -> None:
return None
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex((host, port)) == 0:
return
time.sleep(0.05)
raise RuntimeError("Server did not start in time")
def _start_server(host: str, port: int):
config = uvicorn.Config(app, host=host, port=port, log_level="info")
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
_wait_for_port(host, port)
return server, thread
def test_auth_flow_e2e() -> None:
app.dependency_overrides[get_auth_service] = lambda: FakeE2EAuthService()
host = "127.0.0.1"
port = _find_free_port()
server, thread = _start_server(host, port)
try:
with sync_playwright() as playwright:
request_context = playwright.request.new_context(
base_url=f"http://{host}:{port}"
)
try:
signup = request_context.post(
"/api/v1/auth/signup",
data=json.dumps(
{"email": "user@example.com", "password": "secret123"}
),
headers={"Content-Type": "application/json"},
)
assert signup.status == 200
assert signup.json()["access_token"] == "access-1"
login = request_context.post(
"/api/v1/auth/login",
data=json.dumps(
{"email": "user@example.com", "password": "secret123"}
),
headers={"Content-Type": "application/json"},
)
assert login.status == 200
assert login.json()["access_token"] == "access-2"
refresh = request_context.post(
"/api/v1/auth/refresh",
data=json.dumps({"refresh_token": "refresh-2"}),
headers={"Content-Type": "application/json"},
)
assert refresh.status == 200
assert refresh.json()["access_token"] == "access-3"
logout = request_context.post(
"/api/v1/auth/logout",
data=json.dumps({"refresh_token": "refresh-3"}),
headers={"Content-Type": "application/json"},
)
assert logout.status == 204
finally:
request_context.dispose()
finally:
app.dependency_overrides = {}
server.should_exit = True
thread.join(timeout=5)
@@ -0,0 +1,79 @@
from __future__ import annotations
import socket
import threading
import time
from playwright.sync_api import sync_playwright
import uvicorn
from app import app
from v1.infra.dependencies import get_qdrant_service, get_redis_service
class _FakeService:
def __init__(self) -> None:
self._initialized = True
@property
def is_initialized(self) -> bool:
return self._initialized
async def initialize(self) -> bool:
return True
async def health_check(self) -> dict[str, object]:
return {"status": "healthy", "details": {}}
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex((host, port)) == 0:
return
time.sleep(0.05)
raise RuntimeError("Server did not start in time")
def _start_server(host: str, port: int):
config = uvicorn.Config(app, host=host, port=port, log_level="info")
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
_wait_for_port(host, port)
return server, thread
def test_infra_health_e2e() -> None:
app.dependency_overrides[get_redis_service] = lambda: _FakeService()
app.dependency_overrides[get_qdrant_service] = lambda: _FakeService()
host = "127.0.0.1"
port = _find_free_port()
server, thread = _start_server(host, port)
try:
with sync_playwright() as playwright:
request_context = playwright.request.new_context(
base_url=f"http://{host}:{port}"
)
try:
response = request_context.get("/api/v1/infra/health")
assert response.status == 200
body = response.json()
assert body["status"] == "healthy"
assert "redis" in body["services"]
assert "qdrant" in body["services"]
finally:
request_context.dispose()
finally:
server.should_exit = True
thread.join(timeout=5)
app.dependency_overrides = {}
+96
View File
@@ -0,0 +1,96 @@
from __future__ import annotations
import json
import socket
import threading
import time
from pathlib import Path
from fastapi import FastAPI
from playwright.sync_api import sync_playwright
import uvicorn
from core.config.settings import Settings
from core.logging.config import configure_logging
from core.logging.middleware import (
RequestContextMiddleware,
register_exception_handlers,
)
def _read_json_lines(path: Path) -> list[dict[str, object]]:
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex((host, port)) == 0:
return
time.sleep(0.05)
raise RuntimeError("Server did not start in time")
def _start_server(app: FastAPI, host: str, port: int):
config = uvicorn.Config(app, host=host, port=port, log_level="info")
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
_wait_for_port(host, port)
return server, thread
def test_e2e_error_logging(tmp_path: Path) -> None:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_rotation": "size",
"log_rotation_max_bytes": 2048,
}
)
configure_logging(settings.model_copy(update={"runtime": runtime}))
app = FastAPI()
app.add_middleware(RequestContextMiddleware) # type: ignore[arg-type]
register_exception_handlers(app)
@app.get("/boom")
async def boom() -> dict[str, str]:
raise RuntimeError("boom")
host = "127.0.0.1"
port = _find_free_port()
server, thread = _start_server(app, host, port)
try:
with sync_playwright() as playwright:
request_context = playwright.request.new_context(
base_url=f"http://{host}:{port}"
)
response = request_context.get(
"/boom",
headers={"X-Request-ID": "e2e-5000"},
)
assert response.status == 500
request_context.dispose()
finally:
server.should_exit = True
thread.join(timeout=5)
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
entry = next(
item for item in error_entries if item.get("message") == "Unhandled exception"
)
assert entry["request_id"] == "e2e-5000"
exception = str(entry["exception"])
assert "Traceback" in exception
@@ -0,0 +1,57 @@
from __future__ import annotations
import socket
import threading
import time
from playwright.sync_api import sync_playwright
import uvicorn
from app import app
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex((host, port)) == 0:
return
time.sleep(0.05)
raise RuntimeError("Server did not start in time")
def _start_server(host: str, port: int):
config = uvicorn.Config(app, host=host, port=port, log_level="info")
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
_wait_for_port(host, port)
return server, thread
def test_mobile_health_e2e() -> None:
host = "127.0.0.1"
port = _find_free_port()
server, thread = _start_server(host, port)
try:
with sync_playwright() as playwright:
request_context = playwright.request.new_context(
base_url=f"http://{host}:{port}"
)
try:
response = request_context.get("/api/v1/health")
assert response.status == 200
body = response.json()
assert body["status"] == "ok"
finally:
request_context.dispose()
finally:
server.should_exit = True
thread.join(timeout=5)
+115
View File
@@ -0,0 +1,115 @@
from __future__ import annotations
import json
import socket
import threading
import time
from uuid import UUID
from playwright.sync_api import sync_playwright
import uvicorn
from app import app
from core.auth.models import CurrentUser
from v1.profile.dependencies import get_current_user, get_profile_service
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
class FakeProfileService:
"""Fake service for E2E testing."""
def __init__(self, profile: ProfileResponse) -> None:
self._profile = profile
async def get_me(self) -> ProfileResponse:
return self._profile
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
return ProfileResponse(
id=self._profile.id,
username=self._profile.username,
display_name=(
update.display_name
if update.display_name is not None
else self._profile.display_name
),
avatar_url=(
update.avatar_url
if update.avatar_url is not None
else self._profile.avatar_url
),
bio=update.bio if update.bio is not None else self._profile.bio,
)
async def get_by_username(self, username: str) -> ProfileResponse:
return self._profile
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex((host, port)) == 0:
return
time.sleep(0.05)
raise RuntimeError("Server did not start in time")
def _start_server(host: str, port: int):
config = uvicorn.Config(app, host=host, port=port, log_level="info")
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
_wait_for_port(host, port)
return server, thread
def test_profile_flow_e2e() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = ProfileResponse(
id=str(user_id),
username="demo",
display_name="Demo User",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_profile_service] = lambda: FakeProfileService(profile) # type: ignore[return-value]
app.dependency_overrides[get_current_user] = lambda: CurrentUser(id=user_id)
host = "127.0.0.1"
port = _find_free_port()
server, thread = _start_server(host, port)
try:
with sync_playwright() as playwright:
request_context = playwright.request.new_context(
base_url=f"http://{host}:{port}"
)
try:
me = request_context.get("/api/v1/profile/me")
assert me.status == 200
assert me.json()["username"] == "demo"
updated = request_context.patch(
"/api/v1/profile/me",
data=json.dumps({"display_name": "Updated"}),
headers={"Content-Type": "application/json"},
)
assert updated.status == 200
assert updated.json()["display_name"] == "Updated"
public = request_context.get("/api/v1/profile/demo")
assert public.status == 200
assert public.json()["username"] == "demo"
finally:
request_context.dispose()
finally:
app.dependency_overrides = {}
server.should_exit = True
thread.join(timeout=5)
@@ -0,0 +1,49 @@
from __future__ import annotations
import socket
import pytest
from core.config.settings import Settings
from services.base.qdrant import QdrantService
from services.base.redis import RedisService
def _can_connect(host: str, port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(0.2)
return sock.connect_ex((host, port)) == 0
@pytest.mark.asyncio
async def test_redis_service_health_check_integration() -> None:
host = "127.0.0.1"
port = 6379
if not _can_connect(host, port):
pytest.skip("Redis is not running on localhost:6379")
config = Settings()
settings = config.redis.model_copy(update={"host": host, "port": port})
service = RedisService(settings=settings)
assert await service.initialize() is True
health = await service.health_check()
assert health["status"] == "healthy"
assert await service.close() is True
@pytest.mark.asyncio
async def test_qdrant_service_health_check_integration() -> None:
host = "127.0.0.1"
port = 6333
if not _can_connect(host, port):
pytest.skip("Qdrant is not running on localhost:6333")
config = Settings()
settings = config.qdrant.model_copy(update={"host": host, "port": port})
service = QdrantService(settings=settings)
assert await service.initialize() is True
health = await service.health_check()
assert health["status"] == "healthy"
assert await service.close() is True
@@ -0,0 +1,178 @@
from __future__ import annotations
from typing import Callable
from fastapi import HTTPException
from fastapi.testclient import TestClient
from app import app
from v1.auth.dependencies import get_auth_service
from v1.auth.models import (
AuthTokenResponse,
AuthUser,
LoginRequest,
RefreshRequest,
SignupRequest,
)
from v1.auth.service import AuthService
class FakeAuthService(AuthService):
def __init__(self, token_response: AuthTokenResponse) -> None:
self._token_response = token_response
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
return self._token_response
async def login(self, request: LoginRequest) -> AuthTokenResponse:
raise HTTPException(status_code=401, detail="Invalid credentials")
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
raise HTTPException(status_code=401, detail="Invalid refresh token")
async def logout(self, refresh_token: str | None) -> None:
return None
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
def _get_service() -> AuthService:
return service
return _get_service
def test_signup_returns_token_response() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/signup",
json={"email": "user@example.com", "password": "secret123"},
)
assert response.status_code == 200
body = response.json()
assert body["access_token"] == "access"
assert body["refresh_token"] == "refresh"
assert body["user"]["email"] == "user@example.com"
finally:
app.dependency_overrides = {}
def test_login_invalid_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/login",
json={"email": "user@example.com", "password": "wrongpw"},
)
assert response.status_code == 401
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unauthorized"
assert body["status"] == 401
assert body["detail"] == "Invalid credentials"
finally:
app.dependency_overrides = {}
def test_refresh_invalid_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid"},
)
assert response.status_code == 401
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unauthorized"
assert body["status"] == 401
assert body["detail"] == "Invalid refresh token"
finally:
app.dependency_overrides = {}
def test_logout_returns_no_content() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/logout",
json={"refresh_token": "refresh"},
)
assert response.status_code == 204
assert response.content == b""
finally:
app.dependency_overrides = {}
def test_signup_validation_error_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post("/api/v1/auth/signup", json={})
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Content"
assert body["status"] == 422
assert body["detail"] == "Invalid request"
finally:
app.dependency_overrides = {}
@@ -0,0 +1,126 @@
from __future__ import annotations
import json
import logging
from pathlib import Path
from fastapi import FastAPI
from fastapi.testclient import TestClient
from core.config.settings import Settings
from core.logging.config import configure_logging
from core.logging.logger import get_logger
from core.logging.middleware import (
RequestContextMiddleware,
register_exception_handlers,
)
def _read_json_lines(path: Path) -> list[dict[str, object]]:
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
def _configure_test_logging(tmp_path: Path) -> None:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_rotation": "size",
"log_rotation_max_bytes": 2048,
}
)
test_settings = settings.model_copy(update={"runtime": runtime})
configure_logging(test_settings)
def test_middleware_binds_request_context(tmp_path: Path) -> None:
_configure_test_logging(tmp_path)
app = FastAPI()
app.add_middleware(RequestContextMiddleware) # type: ignore[arg-type]
@app.get("/ok")
async def ok() -> dict[str, str]:
logger = get_logger("tests.ok")
logger.info("request accepted", context_key="context_value")
return {"status": "ok"}
client = TestClient(app)
response = client.get("/ok", headers={"X-Request-ID": "req-1234"})
assert response.status_code == 200
assert response.headers["X-Request-ID"] == "req-1234"
log_entries = _read_json_lines(Path(tmp_path) / "app.log")
entry = next(
item for item in log_entries if item.get("message") == "request accepted"
)
assert entry["message"] == "request accepted"
assert entry["request_id"] == "req-1234"
assert entry["method"] == "GET"
assert entry["path"] == "/ok"
assert entry["context_key"] == "context_value"
logging.shutdown()
def test_exception_handler_logs_stack_and_sends_500(tmp_path: Path) -> None:
_configure_test_logging(tmp_path)
app = FastAPI()
app.add_middleware(RequestContextMiddleware)
register_exception_handlers(app)
@app.get("/boom")
async def boom() -> dict[str, str]:
raise RuntimeError("boom")
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/boom", headers={"X-Request-ID": "req-5000"})
assert response.status_code == 500
assert response.json()["detail"] == "Internal Server Error"
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
assert error_entries
entry = error_entries[-1]
assert entry["level"] == "error"
assert entry["request_id"] == "req-5000"
exception = str(entry["exception"])
assert "Traceback" in exception
assert "test_fastapi_logging_integration" in exception
logging.shutdown()
def test_invalid_request_id_is_replaced_and_used_in_error_context(
tmp_path: Path,
) -> None:
_configure_test_logging(tmp_path)
app = FastAPI()
app.add_middleware(RequestContextMiddleware)
register_exception_handlers(app)
@app.get("/boom")
async def boom() -> dict[str, str]:
raise RuntimeError("boom")
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/boom", headers={"X-Request-ID": "bad"})
assert response.status_code == 500
response_request_id = response.headers["X-Request-ID"]
assert response_request_id != "bad"
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
assert error_entries
entry = error_entries[-1]
assert entry["request_id"] == response_request_id
exception = str(entry["exception"])
assert "Traceback" in exception
logging.shutdown()
@@ -0,0 +1,39 @@
from __future__ import annotations
from fastapi.testclient import TestClient
from app import app
def test_app_health_returns_envelope() -> None:
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
body = response.json()
assert body["status"] == "ok"
def test_mobile_router_health_returns_envelope() -> None:
client = TestClient(app)
response = client.get("/api/v1/health")
assert response.status_code == 200
body = response.json()
assert body["status"] == "ok"
def test_not_found_returns_error_envelope() -> None:
client = TestClient(app)
response = client.get("/missing-route")
assert response.status_code == 404
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["type"] == "about:blank"
assert body["title"] == "Not Found"
assert body["status"] == 404
assert body["detail"] == "Not Found"
@@ -0,0 +1,188 @@
from __future__ import annotations
from typing import Callable
from uuid import UUID
from fastapi import HTTPException
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.profile.dependencies import get_current_user, get_profile_service
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
from v1.profile.service import ProfileService
class FakeProfileService:
"""Fake service for integration testing."""
def __init__(self, profile: ProfileResponse) -> None:
self._profile = profile
async def get_me(self) -> ProfileResponse:
if self._profile.id is None:
raise HTTPException(status_code=404, detail="Profile not found")
return self._profile
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
if self._profile.id is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=self._profile.id,
username=self._profile.username,
display_name=(
update.display_name
if update.display_name is not None
else self._profile.display_name
),
avatar_url=(
update.avatar_url
if update.avatar_url is not None
else self._profile.avatar_url
),
bio=update.bio if update.bio is not None else self._profile.bio,
)
async def get_by_username(self, username: str) -> ProfileResponse:
if username != self._profile.username:
raise HTTPException(status_code=404, detail="Profile not found")
return self._profile
def _override_profile_service(
service: FakeProfileService,
) -> Callable[[], ProfileService]:
def _get_service() -> ProfileService:
return service # type: ignore[return-value]
return _get_service
def _override_current_user(user_id: UUID) -> Callable[[], CurrentUser]:
def _get_user() -> CurrentUser:
return CurrentUser(id=user_id)
return _get_user
def test_get_me_returns_profile() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = ProfileResponse(
id=str(user_id),
username="demo",
display_name="Demo User",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_profile_service] = _override_profile_service(
FakeProfileService(profile)
)
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
client = TestClient(app)
try:
response = client.get("/api/v1/profile/me")
assert response.status_code == 200
body = response.json()
assert body["username"] == "demo"
finally:
app.dependency_overrides = {}
def test_patch_me_updates_profile() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = ProfileResponse(
id=str(user_id),
username="demo",
display_name="Demo User",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_profile_service] = _override_profile_service(
FakeProfileService(profile)
)
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
client = TestClient(app)
try:
response = client.patch(
"/api/v1/profile/me",
json={"display_name": "Updated"},
)
assert response.status_code == 200
body = response.json()
assert body["display_name"] == "Updated"
finally:
app.dependency_overrides = {}
def test_get_profile_by_username() -> None:
profile = ProfileResponse(
id="00000000-0000-0000-0000-000000000001",
username="demo",
display_name="Demo User",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_profile_service] = _override_profile_service(
FakeProfileService(profile)
)
client = TestClient(app)
try:
response = client.get("/api/v1/profile/demo")
assert response.status_code == 200
body = response.json()
assert body["username"] == "demo"
finally:
app.dependency_overrides = {}
def test_profile_not_found_returns_problem_details() -> None:
profile = ProfileResponse(
id="00000000-0000-0000-0000-000000000001",
username="demo",
display_name="Demo User",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_profile_service] = _override_profile_service(
FakeProfileService(profile)
)
client = TestClient(app)
try:
response = client.get("/api/v1/profile/unknown")
assert response.status_code == 404
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Not Found"
assert body["status"] == 404
finally:
app.dependency_overrides = {}
def test_patch_me_validation_error_returns_problem_details() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = ProfileResponse(
id=str(user_id),
username="demo",
display_name="Demo User",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_profile_service] = _override_profile_service(
FakeProfileService(profile)
)
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
client = TestClient(app)
try:
response = client.patch("/api/v1/profile/me", json={})
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Content"
assert body["status"] == 422
finally:
app.dependency_overrides = {}
@@ -0,0 +1,27 @@
from __future__ import annotations
import pytest
from fastapi import HTTPException
from uuid import UUID
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
def test_require_current_user_raises_when_missing() -> None:
service = BaseService(current_user=None)
with pytest.raises(HTTPException) as exc_info:
service.require_current_user()
assert exc_info.value.status_code == 401
def test_require_current_user_returns_user() -> None:
user = CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
service = BaseService(current_user=user)
result = service.require_current_user()
assert result.id == user.id
@@ -0,0 +1,78 @@
from __future__ import annotations
from datetime import datetime, timezone
from uuid import UUID, uuid4
import pytest
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin
from core.db.base_repository import BaseRepository
class Widget(SoftDeleteMixin, Base):
__tablename__ = "widgets"
id: Mapped[UUID] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(50), nullable=False)
@pytest.fixture
async def db_engine():
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
async_session = async_sessionmaker(
bind=db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_get_by_id_filters_soft_deleted(db_session: AsyncSession) -> None:
repository = BaseRepository(db_session, Widget)
widget_id = uuid4()
widget = Widget(id=widget_id, name="widget")
db_session.add(widget)
await db_session.commit()
found = await repository.get_by_id(widget_id)
assert found is not None
deleted = await repository.soft_delete_by_id(widget_id)
assert deleted is not None
assert deleted.deleted_at is not None
missing = await repository.get_by_id(widget_id)
assert missing is None
@pytest.mark.asyncio
async def test_soft_delete_sets_timestamp(db_session: AsyncSession) -> None:
repository = BaseRepository(db_session, Widget)
widget_id = uuid4()
widget = Widget(id=widget_id, name="widget")
db_session.add(widget)
await db_session.commit()
deleted = await repository.soft_delete_by_id(widget_id)
assert deleted is not None
assert isinstance(deleted.deleted_at, datetime)
deleted_at = deleted.deleted_at
if deleted_at.tzinfo is None:
deleted_at = deleted_at.replace(tzinfo=timezone.utc)
assert deleted_at <= datetime.now(timezone.utc)
@@ -0,0 +1,114 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base
from models.profile import Profile
@pytest.fixture
async def db_engine():
"""Create in-memory SQLite engine for testing."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
"""Create a database session for testing."""
async_session = async_sessionmaker(
bind=db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_profile_model_create(db_session: AsyncSession) -> None:
"""Test creating a Profile model."""
profile_id = uuid4()
profile = Profile(
id=profile_id,
username="testuser",
display_name="Test User",
)
db_session.add(profile)
await db_session.commit()
await db_session.refresh(profile)
assert profile.id == profile_id
assert profile.username == "testuser"
assert profile.display_name == "Test User"
assert profile.created_at is not None
assert profile.updated_at is not None
assert profile.deleted_at is None
@pytest.mark.asyncio
async def test_profile_model_get_by_id(db_session: AsyncSession) -> None:
"""Test retrieving a Profile by ID."""
profile_id = uuid4()
profile = Profile(
id=profile_id,
username="testuser",
display_name="Test User",
)
db_session.add(profile)
await db_session.commit()
result = await db_session.get(Profile, profile_id)
assert result is not None
assert result.username == "testuser"
@pytest.mark.asyncio
async def test_profile_model_get_by_username(db_session: AsyncSession) -> None:
"""Test retrieving a Profile by username."""
profile = Profile(
id=uuid4(),
username="testuser",
display_name="Test User",
)
db_session.add(profile)
await db_session.commit()
result = await db_session.execute(
select(Profile).where(Profile.username == "testuser")
)
found = result.scalar_one()
assert found is not None
assert found.username == "testuser"
@pytest.mark.asyncio
async def test_profile_model_update(db_session: AsyncSession) -> None:
"""Test updating a Profile."""
profile = Profile(
id=uuid4(),
username="testuser",
display_name="Test User",
bio="Old bio",
)
db_session.add(profile)
await db_session.commit()
profile.display_name = "Updated User"
profile.bio = "New bio"
await db_session.commit()
await db_session.refresh(profile)
assert profile.display_name == "Updated User"
assert profile.bio == "New bio"
@@ -0,0 +1,78 @@
from __future__ import annotations
import pytest
from core.config.settings import QdrantSettings
from services.base.qdrant import QdrantService
class _FakeCollection:
def __init__(self, name: str) -> None:
self.name = name
class _FakeCollections:
def __init__(self) -> None:
self.collections = [_FakeCollection("default")]
class _FakeQdrantClient:
def get_collections(self) -> _FakeCollections:
return _FakeCollections()
@pytest.mark.asyncio
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
def _build_client(_: QdrantService) -> _FakeQdrantClient:
return _FakeQdrantClient()
monkeypatch.setattr(QdrantService, "_build_client", _build_client)
result = await service.initialize()
assert result is True
assert service.is_initialized is True
health = await service.health_check()
assert health["status"] == "healthy"
@pytest.mark.asyncio
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
def _build_client(_: QdrantService) -> _FakeQdrantClient:
raise RuntimeError("boom")
monkeypatch.setattr(QdrantService, "_build_client", _build_client)
result = await service.initialize()
assert result is False
assert service.is_initialized is False
@pytest.mark.asyncio
async def test_health_check_returns_unhealthy_when_not_initialized() -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
health = await service.health_check()
assert health["status"] == "unhealthy"
@pytest.mark.asyncio
async def test_close_is_idempotent() -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
assert await service.close() is True
assert service.is_initialized is False
def test_get_client_raises_before_init() -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
with pytest.raises(RuntimeError):
service.get_client()
@@ -0,0 +1,98 @@
from __future__ import annotations
import pytest
from core.config.settings import RedisSettings
from services.base.redis import RedisService
class _FakeRedisClient:
def __init__(self) -> None:
self.closed = False
async def ping(self) -> bool:
return True
async def info(self) -> dict[str, object]:
return {
"redis_version": "7.2",
"connected_clients": 1,
"used_memory_human": "1M",
"uptime_in_seconds": 10,
}
async def aclose(self) -> None:
self.closed = True
@pytest.mark.asyncio
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
def _build_client(_: RedisService) -> _FakeRedisClient:
return _FakeRedisClient()
monkeypatch.setattr(RedisService, "_build_client", _build_client)
result = await service.initialize()
assert result is True
assert service.is_initialized is True
health = await service.health_check()
assert health["status"] == "healthy"
@pytest.mark.asyncio
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
def _build_client(_: RedisService) -> _FakeRedisClient:
raise RuntimeError("boom")
monkeypatch.setattr(RedisService, "_build_client", _build_client)
result = await service.initialize()
assert result is False
assert service.is_initialized is False
@pytest.mark.asyncio
async def test_close_is_idempotent() -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
assert await service.close() is True
assert service.is_initialized is False
@pytest.mark.asyncio
async def test_health_check_uninitialized() -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
health = await service.health_check()
assert health["status"] == "unhealthy"
@pytest.mark.asyncio
async def test_close_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
client = _FakeRedisClient()
def _build_client(_: RedisService) -> _FakeRedisClient:
return client
monkeypatch.setattr(RedisService, "_build_client", _build_client)
assert await service.initialize() is True
assert await service.close() is True
assert client.closed is True
assert service.is_initialized is False
def test_get_client_raises_before_init() -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
with pytest.raises(RuntimeError):
service.get_client()
@@ -0,0 +1,49 @@
from __future__ import annotations
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
register_service,
register_service_instance,
)
class _DummyService(BaseServiceProvider):
def __init__(self, name: str = "dummy") -> None:
super().__init__(name)
async def initialize(self, **_: object) -> bool:
self._set_initialized(True)
return True
async def close(self) -> bool:
self._set_initialized(False)
return True
async def health_check(self) -> dict[str, object]:
return {"status": "healthy", "details": {}}
def test_register_service_and_create_service() -> None:
@register_service("dummy-service")
class _RegisteredService(_DummyService):
pass
created = ServiceRegistry.create_service("dummy-service")
assert created is not None
assert created.get_service_info()["name"] == "dummy"
def test_register_service_instance_returns_same_instance() -> None:
instance = _DummyService("singleton")
returned = register_service_instance("dummy-singleton", instance)
created = ServiceRegistry.create_service("dummy-singleton")
assert returned is instance
assert created is instance
def test_create_service_returns_none_for_missing() -> None:
assert ServiceRegistry.create_service("missing-service") is None
+45
View File
@@ -0,0 +1,45 @@
from __future__ import annotations
from celery import Celery
from pytest import MonkeyPatch
from core.logging import celery as celery_logging
from core.logging.context import clear_context, get_context
class DummyTask:
name: str = "tasks.sample"
def test_celery_prerun_binds_task_context() -> None:
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
context = get_context()
assert context["task_id"] == "task-123"
assert context["task_name"] == "tasks.sample"
clear_context()
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
called = {"value": False}
def fake_configure_logging(settings: object | None = None) -> None:
called["value"] = True
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_setup_logging()
assert called["value"] is True
def test_configure_celery_app_disables_hijack() -> None:
app = Celery("test")
celery_logging.configure_celery_app(app)
assert app.conf.worker_hijack_root_logger is False
+140
View File
@@ -0,0 +1,140 @@
from __future__ import annotations
import json
import logging
from collections.abc import Iterator
from pathlib import Path
from typing import cast
import pytest
import structlog
from core.config.settings import Settings
from core.logging.config import build_logging_config, configure_logging
def _get_handlers(config: dict[str, object]) -> dict[str, dict[str, object]]:
return cast(dict[str, dict[str, object]], config["handlers"])
def test_build_logging_config_time_rotation(tmp_path: Path) -> None:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_rotation": "time",
}
)
config = build_logging_config(runtime)
handlers = _get_handlers(config)
assert handlers["file"]["class"] == "logging.handlers.TimedRotatingFileHandler"
assert handlers["error"]["class"] == "logging.handlers.TimedRotatingFileHandler"
assert handlers["error"]["level"] == "ERROR"
def test_build_logging_config_size_rotation(tmp_path: Path) -> None:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_rotation": "size",
"log_rotation_max_bytes": 2048,
}
)
config = build_logging_config(runtime)
handlers = _get_handlers(config)
assert handlers["file"]["class"] == "logging.handlers.RotatingFileHandler"
assert handlers["error"]["class"] == "logging.handlers.RotatingFileHandler"
assert handlers["file"]["maxBytes"] == 2048
def test_build_logging_config_plain_formatter_when_disabled(tmp_path: Path) -> None:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_json": False,
}
)
config = build_logging_config(runtime)
handlers = _get_handlers(config)
assert handlers["file"]["formatter"] == "plain"
assert handlers["error"]["formatter"] == "plain"
def _read_last_log_entry(log_path: Path) -> dict[str, object]:
assert log_path.exists(), f"Expected log file at {log_path}"
entries = [
json.loads(line) for line in log_path.read_text().splitlines() if line.strip()
]
assert entries, "Expected at least one log entry in app.log"
return entries[-1]
def _flush_root_handlers() -> None:
root_logger = logging.getLogger()
for handler in root_logger.handlers:
if hasattr(handler, "flush"):
handler.flush()
@pytest.fixture
def configured_logging(tmp_path: Path) -> Iterator[Path]:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_rotation": "size",
"log_rotation_max_bytes": 2048,
"log_json": True,
}
)
root_logger = logging.getLogger()
original_handlers = root_logger.handlers[:]
original_level = root_logger.level
configure_logging(settings.model_copy(update={"runtime": runtime}))
yield tmp_path
for handler in root_logger.handlers:
handler.close()
root_logger.handlers = original_handlers
root_logger.setLevel(original_level)
structlog.reset_defaults()
def test_stdlib_logging_redacts_sensitive_fields(configured_logging: Path) -> None:
logger = logging.getLogger("tests.stdlib")
logger.info("login", extra={"password": "secret", "token": "abc"})
_flush_root_handlers()
log_path = configured_logging / "app.log"
entry = _read_last_log_entry(log_path)
assert entry["password"] == "[REDACTED]"
assert entry["token"] == "[REDACTED]"
def test_structlog_redacts_sensitive_fields(configured_logging: Path) -> None:
logger = structlog.get_logger("tests.structlog")
logger.info("login", password="secret", token="abc")
_flush_root_handlers()
log_path = configured_logging / "app.log"
entry = _read_last_log_entry(log_path)
assert entry["password"] == "[REDACTED]"
assert entry["token"] == "[REDACTED]"
@@ -0,0 +1,30 @@
from __future__ import annotations
from core.logging.filters import build_sensitive_data_processor
def test_redact_sensitive_fields_masks_values() -> None:
processor = build_sensitive_data_processor(
["password", "token", "api_key", "cookie"]
)
event: dict[str, object] = {
"message": "login",
"password": "secret",
"access_token": "token-123",
"apiKey": "apikey-123",
"set-cookie": "cookie-1",
"nested": {"token": "abc", "safe": "ok"},
"list": [{"password": "x"}],
}
redacted = processor(None, "info", event)
assert redacted["password"] == "[REDACTED]"
assert redacted["access_token"] == "[REDACTED]"
assert redacted["apiKey"] == "[REDACTED]"
assert redacted["set-cookie"] == "[REDACTED]"
assert redacted["nested"]["token"] == "[REDACTED]"
assert redacted["nested"]["safe"] == "ok"
assert redacted["list"][0]["password"] == "[REDACTED]"
assert event["password"] == "secret"
@@ -0,0 +1,35 @@
from __future__ import annotations
from pytest import MonkeyPatch
from core.config.settings import Settings
def test_runtime_settings_defaults() -> None:
settings = Settings()
assert settings.runtime.log_json is True
assert settings.runtime.log_rotation == "time"
assert settings.runtime.log_rotation_when == "midnight"
assert settings.runtime.log_rotation_interval == 1
assert settings.runtime.log_rotation_backup_count == 14
assert settings.runtime.log_rotation_max_bytes == 10_000_000
assert settings.runtime.log_dir == "logs"
assert settings.runtime.log_error_dir == "logs/errors"
assert settings.runtime.log_file_name == "app.log"
assert settings.runtime.log_error_file_name == "error.log"
assert "password" in settings.runtime.log_sensitive_fields
def test_runtime_settings_env_override(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_DIR", "var/logs")
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ERROR_DIR", "var/logs/errors")
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ROTATION", "size")
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ROTATION_MAX_BYTES", "2048")
settings = Settings()
assert settings.runtime.log_dir == "var/logs"
assert settings.runtime.log_error_dir == "var/logs/errors"
assert settings.runtime.log_rotation == "size"
assert settings.runtime.log_rotation_max_bytes == 2048
@@ -0,0 +1,30 @@
from __future__ import annotations
from core.http.response import ProblemDetails, build_problem_details
def test_problem_details_defaults() -> None:
result = build_problem_details(status_code=401, detail="Unauthorized")
assert isinstance(result, ProblemDetails)
assert result.type == "about:blank"
assert result.title == "Unauthorized"
assert result.status == 401
assert result.detail == "Unauthorized"
assert result.instance is None
def test_problem_details_overrides() -> None:
result = build_problem_details(
status_code=409,
detail="Conflict",
type_value="https://example.com/problems/conflict",
title="Conflict",
instance="/api/mobile/auth/signup",
)
assert result.type == "https://example.com/problems/conflict"
assert result.title == "Conflict"
assert result.status == 409
assert result.detail == "Conflict"
assert result.instance == "/api/mobile/auth/signup"
@@ -0,0 +1,49 @@
from __future__ import annotations
from pytest import MonkeyPatch
from core.config.settings import Settings
def test_social_prefixed_supabase_env_populates_settings(
monkeypatch: MonkeyPatch,
) -> None:
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "public.example")
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
monkeypatch.setenv("SOCIAL_DATABASE__USER", "user")
monkeypatch.setenv("SOCIAL_DATABASE__PASSWORD", "pass")
settings = Settings()
assert settings.supabase.public_url == "https://public.example:8443"
assert settings.supabase.api_external_url == "https://public.example:8443"
assert settings.supabase.anon_key == "anon-key"
assert settings.supabase.service_role_key == "service-key"
assert settings.supabase.jwt_secret == "jwt-secret"
supabase_settings = settings.model_dump()["supabase"]
assert supabase_settings["public_url"] == "https://public.example:8443"
assert supabase_settings["api_external_url"] == "https://public.example:8443"
assert supabase_settings["anon_key"] == "anon-key"
assert supabase_settings["service_role_key"] == "service-key"
assert supabase_settings["jwt_secret"] == "jwt-secret"
assert settings.database_url == "postgresql+asyncpg://user:pass@db:5432/app"
def test_social_prefixed_api_external_url_is_loaded(
monkeypatch: MonkeyPatch,
) -> None:
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "api.example")
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
settings = Settings()
assert settings.supabase.api_external_url == "https://api.example:8443"
@@ -0,0 +1,41 @@
from __future__ import annotations
import pytest
from pydantic import ValidationError
from v1.auth.models import (
AuthTokenResponse,
AuthUser,
LoginRequest,
RefreshRequest,
SignupRequest,
)
def test_signup_requires_valid_email() -> None:
with pytest.raises(ValidationError):
SignupRequest(email="not-an-email", password="secret123")
def test_login_requires_valid_email() -> None:
with pytest.raises(ValidationError):
LoginRequest(email="invalid", password="secret123")
def test_refresh_requires_token() -> None:
with pytest.raises(ValidationError):
RefreshRequest(refresh_token="")
def test_auth_token_response_maps_user() -> None:
user = AuthUser(id="user-1", email="user@example.com")
response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
assert response.user.id == "user-1"
assert response.user.email == "user@example.com"
@@ -0,0 +1,74 @@
from __future__ import annotations
import pytest
from v1.auth.models import (
AuthTokenResponse,
AuthUser,
LoginRequest,
RefreshRequest,
SignupRequest,
)
from v1.auth.service import AuthService, AuthServiceGateway
class FakeGateway(AuthServiceGateway):
def __init__(self, response: AuthTokenResponse) -> None:
self._response = response
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
return self._response
async def login(self, request: LoginRequest) -> AuthTokenResponse:
return self._response
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
return self._response
async def logout(self, refresh_token: str | None) -> None:
return None
@pytest.mark.asyncio
async def test_signup_maps_response() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
service = AuthService(gateway=FakeGateway(token_response))
result = await service.signup(
SignupRequest(email="user@example.com", password="secret123")
)
assert result.access_token == "access"
assert result.refresh_token == "refresh"
assert result.user.id == "user-1"
class LogoutAssertingGateway(AuthServiceGateway):
def __init__(self, expected_refresh_token: str) -> None:
self._expected_refresh_token = expected_refresh_token
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
raise NotImplementedError
async def login(self, request: LoginRequest) -> AuthTokenResponse:
raise NotImplementedError
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
raise NotImplementedError
async def logout(self, refresh_token: str | None) -> None:
assert refresh_token == self._expected_refresh_token
@pytest.mark.asyncio
async def test_logout_forwards_refresh_token() -> None:
service = AuthService(gateway=LogoutAssertingGateway("refresh-token"))
await service.logout("refresh-token")
@@ -0,0 +1,285 @@
from __future__ import annotations
import time
from typing import Any
from uuid import UUID
import jwt
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.profile.dependencies import get_current_user
class TestGetCurrentUser:
"""Tests for JWT validation in get_current_user dependency."""
@pytest.fixture
def jwt_secret(self) -> str:
return "super-secret-jwt-token-with-at-least-32-characters"
@pytest.fixture
def valid_user_id(self) -> str:
return "00000000-0000-0000-0000-000000000123"
@pytest.fixture
def valid_payload(self, valid_user_id: str) -> dict[str, Any]:
"""Valid JWT payload with all required claims."""
now = int(time.time())
return {
"sub": valid_user_id,
"aud": "authenticated",
"iss": "http://localhost:8001/auth/v1",
"exp": now + 3600, # 1 hour from now
"iat": now,
}
def _create_token(self, payload: dict[str, Any], secret: str) -> str:
return jwt.encode(payload, secret, algorithm="HS256")
def test_valid_token_returns_current_user(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
valid_user_id: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Valid JWT with correct aud/iss/exp should return CurrentUser."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
token = self._create_token(valid_payload, jwt_secret)
authorization = f"Bearer {token}"
result = get_current_user(authorization=authorization)
assert isinstance(result, CurrentUser)
assert result.id == UUID(valid_user_id)
def test_missing_authorization_raises_401(self) -> None:
"""Missing Authorization header should raise 401."""
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=None)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Unauthorized"
def test_invalid_scheme_raises_401(self) -> None:
"""Non-Bearer scheme should raise 401."""
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization="Basic dXNlcjpwYXNz")
assert exc_info.value.status_code == 401
def test_expired_token_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Expired JWT should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["exp"] = int(time.time()) - 3600 # 1 hour ago
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_invalid_audience_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with wrong audience should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["aud"] = "wrong-audience"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_invalid_issuer_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with wrong issuer should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["iss"] = "http://malicious-site.com/auth/v1"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_missing_subject_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT without 'sub' claim should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
del valid_payload["sub"]
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_wrong_secret_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT signed with wrong secret should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
token = self._create_token(
valid_payload, "wrong-secret-key-that-is-long-enough"
)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_jwt_secret_not_configured_raises_503(
self, valid_payload: dict[str, Any], monkeypatch: pytest.MonkeyPatch
) -> None:
"""Missing JWT secret in config should raise 503."""
monkeypatch.setattr("v1.profile.dependencies.config.supabase.jwt_secret", None)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization="Bearer some-token")
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "JWT secret not configured"
def test_invalid_uuid_in_subject_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with non-UUID 'sub' claim should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["sub"] = "not-a-valid-uuid"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
@@ -0,0 +1,172 @@
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from models.profile import Profile
from v1.profile.repository import ProfileRepository
from v1.profile.schemas import ProfileUpdateRequest
from v1.profile.service import ProfileService
def _create_mock_profile(
user_id: UUID = UUID("00000000-0000-0000-0000-000000000001"),
username: str = "demo",
display_name: str | None = "Demo User",
avatar_url: str | None = None,
bio: str | None = None,
) -> Profile:
"""Create a mock Profile ORM object."""
profile = MagicMock(spec=Profile)
profile.id = user_id
profile.username = username
profile.display_name = display_name
profile.avatar_url = avatar_url
profile.bio = bio
return profile
class FakeRepo:
"""Fake repository for testing that conforms to ProfileRepository protocol."""
def __init__(self, profile: Profile | None) -> None:
self._profile = profile
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
if self._profile and user_id == self._profile.id:
return self._profile
return None
async def get_by_username(self, username: str) -> Profile | None:
if self._profile and username == self._profile.username:
return self._profile
return None
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
if not self._profile or user_id != self._profile.id:
return None
# Apply updates to mock
for key, value in update_data.items():
if hasattr(self._profile, key):
setattr(self._profile, key, value)
return self._profile
# Verify FakeRepo implements the protocol
_repo_check: ProfileRepository = FakeRepo(None)
@pytest.fixture
def mock_session() -> AsyncMock:
"""Create a mock AsyncSession."""
session = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
return session
@pytest.mark.asyncio
async def test_get_me_returns_profile(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id, username="demo")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
result = await service.get_me()
assert result.username == "demo"
assert result.id == str(user_id)
@pytest.mark.asyncio
async def test_get_me_not_found_raises_404(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(None),
session=mock_session,
current_user=user,
)
with pytest.raises(HTTPException) as exc_info:
await service.get_me()
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_update_me_updates_fields(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id, username="demo")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
result = await service.update_me(ProfileUpdateRequest(display_name="Updated"))
assert result.display_name == "Updated"
mock_session.commit.assert_awaited_once()
@pytest.mark.asyncio
async def test_update_me_no_fields_raises_400(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id)
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
# Create a request with all None values by bypassing validation
update = MagicMock(spec=ProfileUpdateRequest)
update.display_name = None
update.avatar_url = None
update.bio = None
with pytest.raises(HTTPException) as exc_info:
await service.update_me(update)
assert exc_info.value.status_code == 400
@pytest.mark.asyncio
async def test_get_by_username_returns_profile(mock_session: AsyncMock) -> None:
profile = _create_mock_profile(username="demo")
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
result = await service.get_by_username("demo")
assert result.username == "demo"
@pytest.mark.asyncio
async def test_get_by_username_not_found_raises_404(mock_session: AsyncMock) -> None:
service = ProfileService(
repository=FakeRepo(None),
session=mock_session,
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
with pytest.raises(HTTPException) as exc_info:
await service.get_by_username("unknown")
assert exc_info.value.status_code == 404
@@ -0,0 +1,61 @@
from __future__ import annotations
import pytest
from pydantic import ValidationError
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
def test_profile_response_maps_fields() -> None:
response = ProfileResponse(
id="user-1",
username="demo",
display_name="Demo User",
avatar_url=None,
bio=None,
)
assert response.id == "user-1"
assert response.username == "demo"
def test_profile_update_requires_one_field() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest()
def test_profile_update_accepts_valid_https_url() -> None:
request = ProfileUpdateRequest(avatar_url="https://example.com/avatar.png")
assert request.avatar_url == "https://example.com/avatar.png"
def test_profile_update_accepts_valid_http_url() -> None:
request = ProfileUpdateRequest(
avatar_url="http://localhost:8001/storage/avatar.png"
)
assert request.avatar_url == "http://localhost:8001/storage/avatar.png"
def test_profile_update_rejects_invalid_url() -> None:
with pytest.raises(ValidationError) as exc_info:
ProfileUpdateRequest(avatar_url="not-a-valid-url")
errors = exc_info.value.errors()
assert len(errors) == 1
assert "avatar_url" in str(errors[0]["loc"])
def test_profile_update_rejects_javascript_url() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest(avatar_url="javascript:alert('xss')")
def test_profile_update_rejects_data_url() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest(avatar_url="data:text/html,<script>alert('xss')</script>")
def test_profile_update_accepts_none_avatar_url_with_other_field() -> None:
request = ProfileUpdateRequest(display_name="Test", avatar_url=None)
assert request.avatar_url is None
assert request.display_name == "Test"