refactor: align backend layout and supabase infra
Consolidate backend modules/tests under the backend package while syncing Supabase compose/env config and related plans.
This commit is contained in:
@@ -0,0 +1,111 @@
|
||||
## Python Environment
|
||||
|
||||
**MUST use uv for dependency management and virtual environment execution.**
|
||||
|
||||
- All Python commands: `uv run <command>`
|
||||
- Add dependencies: `uv add <package>`
|
||||
- All dependencies declared in `pyproject.toml`
|
||||
|
||||
## Logging
|
||||
|
||||
**MUST use project logger for all runtime logging.**
|
||||
|
||||
- Use project logger from `backend/src/core/logging/*`
|
||||
- Prohibit: print(), logging.info/warning/error directly
|
||||
- Required: structured logging with context
|
||||
- Log levels: DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
|
||||
## HTTP API Standards
|
||||
|
||||
**MUST follow RESTful conventions and RFC 7807 for error responses.**
|
||||
|
||||
- Errors must use `application/problem+json` with RFC 7807 fields
|
||||
- No custom response envelopes for HTTP APIs
|
||||
- Request and response validation must use Pydantic models
|
||||
|
||||
## Environment Variables
|
||||
|
||||
**Backend env access MUST go through** `backend/src/core/config/settings.py`.
|
||||
|
||||
- Only use `Settings()` / `config` from `core.config.settings`
|
||||
- Do not call `os.environ`, `os.getenv`, `dotenv`, or manual parsing in backend runtime code
|
||||
- Tests can set env vars via `monkeypatch.setenv`, and should read values via `Settings()` unless the test is explicitly validating env plumbing
|
||||
- Canonical principle: one source of truth per setting; no duplicate/derived env vars in backend code
|
||||
|
||||
## Code Quality Checks
|
||||
|
||||
**Git pre-commit hook enforces code quality before commit.**
|
||||
|
||||
Pre-commit hook automatically runs on backend/ directory:
|
||||
- `ruff check` - code style and linting
|
||||
- `basedpyright` - type checking with error level
|
||||
|
||||
If any error detected, commit is rejected. Fix errors before committing.
|
||||
Do not bypass or weaken checks (no ignores, disables, or config relaxations). Resolve the underlying issues.
|
||||
|
||||
## TDD First Policy
|
||||
|
||||
**Principle: tests before implementation.**
|
||||
|
||||
### Coverage Requirements
|
||||
- Minimum coverage: 80%
|
||||
- Required test types:
|
||||
- Unit: isolated functions, utilities, components
|
||||
- Integration: API endpoints, database operations
|
||||
- E2E: critical user flows (Playwright)
|
||||
|
||||
### Limited Exceptions
|
||||
- Docs-only changes (README, comments, formatting) may skip integration/E2E
|
||||
- Non-runtime config changes may skip E2E if no behavior changes
|
||||
- Any runtime code change requires unit + integration + E2E
|
||||
- If an exception is used, record the reason in the PR/test notes
|
||||
|
||||
### Mandatory TDD Workflow
|
||||
1. Write tests (RED) - they must fail
|
||||
2. Run tests - confirm failure
|
||||
3. Implement minimal code (GREEN) - only to pass
|
||||
4. Run tests - confirm success
|
||||
5. Refactor (IMPROVE)
|
||||
6. Verify coverage - must be 80%+
|
||||
|
||||
### Enforcement
|
||||
- Must use the `tdd-guide` agent for new features
|
||||
- Do not write implementation before tests
|
||||
- Do not lower coverage requirements
|
||||
- Must include unit, integration, and E2E tests
|
||||
|
||||
## Database Development Rules
|
||||
|
||||
### Core Principle
|
||||
- **Supabase**: authentication (JWT source of truth)
|
||||
- **Backend**: business authorization (service layer)
|
||||
- **SQLAlchemy ORM**: data access layer (async + asyncpg, service_role connection)
|
||||
|
||||
### Architecture
|
||||
Use `schemas / repository / service` pattern:
|
||||
- `schemas.py` — Pydantic models
|
||||
- `repository.py` — CRUD only, no auth, no commit (only flush), must receive session (never create session/engine)
|
||||
- `service.py` — authorization + business logic + transaction boundary (must commit/rollback)
|
||||
- `dependencies.py` — DI (`get_db`, `get_current_user`)
|
||||
|
||||
### Auth & Data Access
|
||||
- Backend must verify JWT signature and expiration (not just decode)
|
||||
- Extract `user_id` from JWT `sub` claim
|
||||
- Backend connects with **service_role** (bypasses RLS)
|
||||
- `owner_id` always derived from JWT, never from client
|
||||
- Scope queries by owner/org; public access must be explicit
|
||||
- service_role key is backend-only; never expose credentials
|
||||
- Prohibit calling Supabase Admin API (service_role key) from repository/service layers
|
||||
|
||||
### Migrations
|
||||
- **Alembic is the single source of truth** for schema migrations
|
||||
- ORM model changes → `alembic revision --autogenerate`
|
||||
- Raw SQL (policies, triggers, functions) → `op.execute()`
|
||||
- Migrations must be reversible; no reliance on generated IDs
|
||||
|
||||
### RLS Guidance
|
||||
- Backend does not rely on RLS for correctness (uses service_role)
|
||||
- **Backend-only tables**: RLS optional (skip to reduce maintenance)
|
||||
- **Client-direct tables**: must enable RLS with policies covering select/insert/update/delete
|
||||
- `alembic_version` must not be exposed to anonymous clients (revoke anon access)
|
||||
- Business tables that may be exposed to clients should enable defensive RLS even if the backend does not depend on it
|
||||
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
@@ -0,0 +1,149 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
|
||||
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the tzdata library which can be installed by adding
|
||||
# `alembic[tz]` to the pip requirements.
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
project_root = Path(__file__).resolve().parents[1]
|
||||
src_path = project_root / "src"
|
||||
if str(src_path) not in sys.path:
|
||||
sys.path = [str(src_path), *sys.path]
|
||||
|
||||
from core.config.settings import config # noqa: E402
|
||||
from core.db.base import Base # noqa: E402
|
||||
from models import Profile # noqa: F401,E402
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
alembic_config = context.config
|
||||
|
||||
if alembic_config.config_file_name is not None:
|
||||
fileConfig(alembic_config.config_file_name)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def _get_database_url() -> str:
|
||||
database_url = config.database_url
|
||||
if not database_url:
|
||||
raise RuntimeError(
|
||||
"DATABASE_URL is not configured. Set SOCIAL_INFRA__SUPABASE__DATABASE_URL."
|
||||
)
|
||||
return database_url
|
||||
|
||||
|
||||
def _build_config() -> dict[str, Any]:
|
||||
section = alembic_config.get_section(alembic_config.config_ini_section) or {}
|
||||
return {**section, "sqlalchemy.url": _get_database_url()}
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
url = _get_database_url()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def _do_run_migrations(connection: "Connection" | Any) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online() -> None:
|
||||
connectable = async_engine_from_config(
|
||||
_build_config(),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(_do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
asyncio.run(run_migrations_online())
|
||||
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
revision = "20260205_create_profiles_table"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"profiles",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("username", sa.String(length=30), nullable=False),
|
||||
sa.Column("display_name", sa.String(length=50), nullable=True),
|
||||
sa.Column("avatar_url", sa.Text(), nullable=True),
|
||||
sa.Column("bio", sa.String(length=200), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id", name="pk_profiles"),
|
||||
sa.UniqueConstraint("username", name="uq_profiles_username"),
|
||||
)
|
||||
op.create_index("ix_profiles_username", "profiles", ["username"])
|
||||
op.create_index("ix_profiles_deleted_at", "profiles", ["deleted_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_profiles_deleted_at", table_name="profiles")
|
||||
op.drop_index("ix_profiles_username", table_name="profiles")
|
||||
op.drop_table("profiles")
|
||||
@@ -0,0 +1,86 @@
|
||||
"""enable_rls_security_policies
|
||||
|
||||
Revision ID: 85d25a191d06
|
||||
Revises: 20260205_create_profiles_table
|
||||
Create Date: 2026-02-05 15:08:33.430692
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "85d25a191d06"
|
||||
down_revision: Union[str, Sequence[str], None] = "20260205_create_profiles_table"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Enable RLS security policies.
|
||||
|
||||
Security measures:
|
||||
1. Revoke anon role access to alembic_version (internal table)
|
||||
2. Enable RLS on profiles table
|
||||
3. Add defensive policies for profiles (deny all public access by default)
|
||||
|
||||
Architecture:
|
||||
- Backend uses service_role connection (bypasses RLS)
|
||||
- RLS provides defense-in-depth security layer
|
||||
- Prevents accidental direct PostgREST access
|
||||
"""
|
||||
|
||||
# 1. Revoke anon role access to alembic_version table
|
||||
op.execute("REVOKE ALL ON TABLE public.alembic_version FROM anon")
|
||||
op.execute("REVOKE ALL ON TABLE public.alembic_version FROM authenticated")
|
||||
|
||||
# 2. Enable RLS on profiles table
|
||||
op.execute("ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY")
|
||||
|
||||
# 3. Add defensive policies for profiles table
|
||||
# These policies deny all public access by default
|
||||
# Backend service_role connection bypasses these policies
|
||||
|
||||
# Deny all SELECT operations for anon and authenticated roles
|
||||
op.execute(
|
||||
"CREATE POLICY profiles_deny_public_select ON public.profiles "
|
||||
"FOR SELECT TO anon, authenticated USING (false)"
|
||||
)
|
||||
|
||||
# Deny all INSERT operations for anon and authenticated roles
|
||||
op.execute(
|
||||
"CREATE POLICY profiles_deny_public_insert ON public.profiles "
|
||||
"FOR INSERT TO anon, authenticated WITH CHECK (false)"
|
||||
)
|
||||
|
||||
# Deny all UPDATE operations for anon and authenticated roles
|
||||
op.execute(
|
||||
"CREATE POLICY profiles_deny_public_update ON public.profiles "
|
||||
"FOR UPDATE TO anon, authenticated USING (false) WITH CHECK (false)"
|
||||
)
|
||||
|
||||
# Deny all DELETE operations for anon and authenticated roles
|
||||
op.execute(
|
||||
"CREATE POLICY profiles_deny_public_delete ON public.profiles "
|
||||
"FOR DELETE TO anon, authenticated USING (false)"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Rollback RLS security policies."""
|
||||
|
||||
# 1. Drop all policies on profiles table
|
||||
op.execute("DROP POLICY IF EXISTS profiles_deny_public_select ON public.profiles")
|
||||
op.execute("DROP POLICY IF EXISTS profiles_deny_public_insert ON public.profiles")
|
||||
op.execute("DROP POLICY IF EXISTS profiles_deny_public_update ON public.profiles")
|
||||
op.execute("DROP POLICY IF EXISTS profiles_deny_public_delete ON public.profiles")
|
||||
|
||||
# 2. Disable RLS on profiles table
|
||||
op.execute("ALTER TABLE public.profiles DISABLE ROW LEVEL SECURITY")
|
||||
|
||||
# 3. Re-grant default privileges to anon role on alembic_version
|
||||
# (reverting to Alembic's default behavior)
|
||||
op.execute("GRANT SELECT ON TABLE public.alembic_version TO anon")
|
||||
op.execute("GRANT SELECT ON TABLE public.alembic_version TO authenticated")
|
||||
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from core.config.settings import config
|
||||
from core.http.models import HealthResponse
|
||||
from core.http.response import build_problem_details
|
||||
from core.logging import configure_logging, get_logger
|
||||
from v1.router import router as mobile_router
|
||||
|
||||
|
||||
configure_logging(config)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.cors.allow_origins,
|
||||
allow_credentials=config.cors.allow_credentials,
|
||||
allow_methods=config.cors.allow_methods,
|
||||
allow_headers=config.cors.allow_headers,
|
||||
)
|
||||
app.include_router(mobile_router)
|
||||
logger = get_logger("api.app")
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health() -> HealthResponse:
|
||||
return HealthResponse(status="ok")
|
||||
|
||||
|
||||
def _build_http_error_response(
|
||||
request: Request,
|
||||
exc: Exception,
|
||||
status_code: int,
|
||||
detail: object,
|
||||
) -> JSONResponse:
|
||||
instance = request.url.path
|
||||
detail_text = detail if isinstance(detail, str) else "Request failed"
|
||||
logger.warning(
|
||||
"HTTP error",
|
||||
status_code=status_code,
|
||||
detail=detail_text,
|
||||
detail_extra=detail,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
)
|
||||
problem = build_problem_details(
|
||||
status_code=status_code,
|
||||
detail=detail_text,
|
||||
instance=instance,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(
|
||||
request: Request,
|
||||
exc: HTTPException,
|
||||
) -> JSONResponse:
|
||||
return _build_http_error_response(
|
||||
request=request,
|
||||
exc=exc,
|
||||
status_code=exc.status_code,
|
||||
detail=exc.detail,
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def starlette_http_exception_handler(
|
||||
request: Request,
|
||||
exc: StarletteHTTPException,
|
||||
) -> JSONResponse:
|
||||
return _build_http_error_response(
|
||||
request=request,
|
||||
exc=exc,
|
||||
status_code=exc.status_code,
|
||||
detail=exc.detail,
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: RequestValidationError,
|
||||
) -> JSONResponse:
|
||||
instance = request.url.path
|
||||
logger.warning(
|
||||
"Request validation error",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
errors=exc.errors(),
|
||||
)
|
||||
problem = build_problem_details(
|
||||
status_code=422,
|
||||
detail="Invalid request",
|
||||
instance=instance,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(
|
||||
request: Request,
|
||||
exc: Exception,
|
||||
) -> JSONResponse:
|
||||
instance = request.url.path
|
||||
logger.exception(
|
||||
"Unhandled error",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
)
|
||||
problem = build_problem_details(
|
||||
status_code=500,
|
||||
detail="Internal Server Error",
|
||||
instance=instance,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=problem.model_dump(),
|
||||
media_type="application/problem+json",
|
||||
)
|
||||
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CurrentUser:
|
||||
id: UUID
|
||||
@@ -0,0 +1,3 @@
|
||||
from .settings import Settings, config
|
||||
|
||||
__all__ = ["Settings", "config"]
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
from urllib.parse import quote
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class RuntimeSettings(BaseModel):
|
||||
environment: Literal["dev", "test", "prod"] = "dev"
|
||||
debug: bool = True
|
||||
log_level: str = "INFO"
|
||||
log_json: bool = True
|
||||
log_rotation: Literal["time", "size", "none"] = "time"
|
||||
log_rotation_when: str = "midnight"
|
||||
log_rotation_interval: int = 1
|
||||
log_rotation_backup_count: int = 14
|
||||
log_rotation_max_bytes: int = 10_000_000
|
||||
log_dir: str = "logs"
|
||||
log_error_dir: str = "logs/errors"
|
||||
log_file_name: str = "app.log"
|
||||
log_error_file_name: str = "error.log"
|
||||
log_sensitive_fields: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"api_key",
|
||||
"authorization",
|
||||
"cookie",
|
||||
"client_ip",
|
||||
"user_id",
|
||||
]
|
||||
)
|
||||
sql_log_queries: bool = False
|
||||
|
||||
|
||||
class AppSettings(BaseModel):
|
||||
host: str = "0.0.0.0"
|
||||
port: int = Field(default=8000, ge=1, le=65535)
|
||||
reload: bool = True
|
||||
|
||||
|
||||
class CorsSettings(BaseModel):
|
||||
allow_origins: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"http://localhost",
|
||||
"http://localhost:3000",
|
||||
]
|
||||
)
|
||||
allow_credentials: bool = True
|
||||
allow_methods: list[str] = Field(default_factory=lambda: ["*"])
|
||||
allow_headers: list[str] = Field(default_factory=lambda: ["*"])
|
||||
|
||||
|
||||
class RedisSettings(BaseModel):
|
||||
host: str = "redis"
|
||||
port: int = 6379
|
||||
password: str | None = None
|
||||
db: int = 0
|
||||
socket_connect_timeout: float = 1.0
|
||||
socket_timeout: float = 1.0
|
||||
max_connections: int = 10
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
if self.password:
|
||||
password = quote(self.password, safe="")
|
||||
return f"redis://:{password}@{self.host}:{self.port}/{self.db}"
|
||||
return f"redis://{self.host}:{self.port}/{self.db}"
|
||||
|
||||
|
||||
class QdrantSettings(BaseModel):
|
||||
host: str = "qdrant"
|
||||
port: int = 6333
|
||||
grpc_port: int = 6334
|
||||
api_key: str | None = None
|
||||
https: bool = False
|
||||
prefer_grpc: bool = True
|
||||
timeout: int = 5
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
scheme = "https" if self.https else "http"
|
||||
return f"{scheme}://{self.host}:{self.port}"
|
||||
|
||||
|
||||
class SupabaseSettings(BaseModel):
|
||||
public_scheme: str = "http"
|
||||
public_host: str = "localhost"
|
||||
kong_http_port: int = 8000
|
||||
anon_key: str = "CHANGE_ME"
|
||||
service_role_key: str = "CHANGE_ME"
|
||||
jwt_secret: str | None = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def public_url(self) -> str:
|
||||
return f"{self.public_scheme}://{self.public_host}:{self.kong_http_port}"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def api_external_url(self) -> str:
|
||||
return self.public_url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return self.public_url
|
||||
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
name: str = "postgres"
|
||||
user: str = "postgres"
|
||||
password: str = "CHANGE_ME"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
password = quote(self.password, safe="")
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.user}:{password}"
|
||||
f"@{self.host}:{self.port}/{self.name}"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_env_file() -> str:
|
||||
current = Path(__file__).resolve()
|
||||
for parent in [current, *current.parents]:
|
||||
candidate = parent / ".env"
|
||||
if candidate.is_file():
|
||||
return str(candidate)
|
||||
return ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
runtime: RuntimeSettings = RuntimeSettings()
|
||||
app: AppSettings = AppSettings()
|
||||
cors: CorsSettings = CorsSettings()
|
||||
redis: RedisSettings = RedisSettings()
|
||||
qdrant: QdrantSettings = QdrantSettings()
|
||||
supabase: SupabaseSettings = SupabaseSettings()
|
||||
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return self.database.url
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_file=_resolve_env_file(),
|
||||
env_prefix="SOCIAL_",
|
||||
env_nested_delimiter="__",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
config = Settings()
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.session import AsyncSessionLocal, engine, get_db
|
||||
|
||||
__all__ = ["AsyncSessionLocal", "engine", "get_db"]
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all ORM models."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Adds created_at and updated_at timestamps."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
"""Adds soft delete timestamp column."""
|
||||
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import Select, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db.base import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
_session: AsyncSession
|
||||
_model: type[ModelType]
|
||||
|
||||
def __init__(self, session: AsyncSession, model: type[ModelType]) -> None:
|
||||
self._session = session
|
||||
self._model = model
|
||||
|
||||
def _deleted_at_column(self) -> Any | None:
|
||||
return getattr(self._model, "deleted_at", None)
|
||||
|
||||
def _apply_soft_delete_filter(self, stmt: Select) -> Select:
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is None:
|
||||
return stmt
|
||||
return stmt.where(deleted_at.is_(None))
|
||||
|
||||
async def get_by_id(self, entity_id: Any) -> ModelType | None:
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = select(self._model).where(id_column == entity_id)
|
||||
stmt = self._apply_soft_delete_filter(stmt)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_one(self, *filters: Any) -> ModelType | None:
|
||||
stmt = select(self._model).where(*filters)
|
||||
stmt = self._apply_soft_delete_filter(stmt)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_by_id(
|
||||
self, entity_id: Any, update_data: dict[str, Any]
|
||||
) -> ModelType | None:
|
||||
if not update_data:
|
||||
return await self.get_by_id(entity_id)
|
||||
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = update(self._model).where(id_column == entity_id)
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is not None:
|
||||
stmt = stmt.where(deleted_at.is_(None))
|
||||
stmt = stmt.values(**update_data).returning(self._model)
|
||||
|
||||
try:
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
async def soft_delete_by_id(self, entity_id: Any) -> ModelType | None:
|
||||
deleted_at = self._deleted_at_column()
|
||||
if deleted_at is None:
|
||||
raise ValueError("Soft delete is not supported for this model")
|
||||
|
||||
id_column = getattr(self._model, "id")
|
||||
stmt = (
|
||||
update(self._model)
|
||||
.where(id_column == entity_id)
|
||||
.where(deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.now(timezone.utc))
|
||||
.returning(self._model)
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return result.scalar_one_or_none()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
|
||||
|
||||
class BaseService:
|
||||
_current_user: CurrentUser | None
|
||||
|
||||
def __init__(self, current_user: CurrentUser | None) -> None:
|
||||
self._current_user = current_user
|
||||
|
||||
def require_current_user(self) -> CurrentUser:
|
||||
if self._current_user is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
return self._current_user
|
||||
|
||||
def require_user_id(self) -> UUID:
|
||||
return self.require_current_user().id
|
||||
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.config.settings import config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
config.database_url,
|
||||
echo=config.runtime.sql_log_queries,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
AsyncSessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Dependency that provides a database session.
|
||||
|
||||
The session is automatically closed when the request completes.
|
||||
Note: The caller (service layer) is responsible for commit/rollback.
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
yield session
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.http.response import ProblemDetails, build_problem_details
|
||||
|
||||
__all__ = ["ProblemDetails", "build_problem_details"]
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProblemDetails(BaseModel):
|
||||
type: str = "about:blank"
|
||||
title: str
|
||||
status: int
|
||||
detail: str
|
||||
instance: str | None = None
|
||||
|
||||
|
||||
def build_problem_details(
|
||||
*,
|
||||
status_code: int,
|
||||
detail: str,
|
||||
type_value: str = "about:blank",
|
||||
title: str | None = None,
|
||||
instance: str | None = None,
|
||||
) -> ProblemDetails:
|
||||
resolved_title = title or HTTPStatus(status_code).phrase
|
||||
return ProblemDetails(
|
||||
type=type_value,
|
||||
title=resolved_title,
|
||||
status=status_code,
|
||||
detail=detail,
|
||||
instance=instance,
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.logging import celery
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context, get_context
|
||||
from core.logging.logger import get_logger
|
||||
|
||||
__all__ = [
|
||||
"bind_context",
|
||||
"celery",
|
||||
"clear_context",
|
||||
"configure_logging",
|
||||
"get_context",
|
||||
"get_logger",
|
||||
]
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery, signals
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CelerySignalHandlers:
|
||||
on_setup_logging: Callable[..., None]
|
||||
on_after_setup_task_logger: Callable[..., None]
|
||||
on_task_prerun: Callable[..., None]
|
||||
on_task_postrun: Callable[..., None]
|
||||
|
||||
|
||||
def build_celery_signal_handlers(
|
||||
settings: Settings | None = None,
|
||||
) -> CelerySignalHandlers:
|
||||
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
|
||||
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
|
||||
def on_task_prerun(*_args: object, **kwargs: object) -> None:
|
||||
task_id = cast(str | None, kwargs.get("task_id"))
|
||||
task = kwargs.get("task")
|
||||
task_name = getattr(task, "name", None)
|
||||
bind_context(task_id=task_id, task_name=task_name)
|
||||
|
||||
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
|
||||
clear_context()
|
||||
|
||||
return CelerySignalHandlers(
|
||||
on_setup_logging=on_setup_logging,
|
||||
on_after_setup_task_logger=on_after_setup_task_logger,
|
||||
on_task_prerun=on_task_prerun,
|
||||
on_task_postrun=on_task_postrun,
|
||||
)
|
||||
|
||||
|
||||
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
|
||||
app.conf.worker_hijack_root_logger = False
|
||||
|
||||
handlers = build_celery_signal_handlers(settings)
|
||||
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
|
||||
signals.after_setup_task_logger.connect(
|
||||
handlers.on_after_setup_task_logger, weak=False
|
||||
)
|
||||
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
|
||||
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
|
||||
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from logging.config import dictConfig
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
|
||||
from core.config.settings import RuntimeSettings, Settings
|
||||
from core.logging.formatters import (
|
||||
build_plain_formatter,
|
||||
build_processor_formatter,
|
||||
ensure_message_key,
|
||||
)
|
||||
from core.logging.filters import build_sensitive_data_processor
|
||||
from core.logging.handlers import build_file_handler_config
|
||||
|
||||
|
||||
def _ensure_log_dirs(runtime: RuntimeSettings) -> None:
|
||||
Path(runtime.log_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(runtime.log_error_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
|
||||
log_dir = Path(runtime.log_dir)
|
||||
error_dir = Path(runtime.log_error_dir)
|
||||
formatter_name = "json" if runtime.log_json else "plain"
|
||||
|
||||
file_handler = build_file_handler_config(
|
||||
runtime,
|
||||
file_path=log_dir / runtime.log_file_name,
|
||||
level=runtime.log_level,
|
||||
formatter=formatter_name,
|
||||
)
|
||||
error_handler = build_file_handler_config(
|
||||
runtime,
|
||||
file_path=error_dir / runtime.log_error_file_name,
|
||||
level="ERROR",
|
||||
formatter=formatter_name,
|
||||
filters=["error_only"],
|
||||
)
|
||||
|
||||
return {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"filters": {
|
||||
"error_only": {
|
||||
"()": "core.logging.filters.ErrorLevelFilter",
|
||||
}
|
||||
},
|
||||
"formatters": {
|
||||
"json": {
|
||||
"()": build_processor_formatter,
|
||||
"sensitive_fields": runtime.log_sensitive_fields,
|
||||
},
|
||||
"plain": {
|
||||
"()": build_plain_formatter,
|
||||
"sensitive_fields": runtime.log_sensitive_fields,
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"file": file_handler,
|
||||
"error": error_handler,
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["file", "error"],
|
||||
"level": runtime.log_level,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging(settings: Settings | None = None) -> None:
|
||||
active_settings = settings or Settings()
|
||||
runtime = active_settings.runtime
|
||||
|
||||
try:
|
||||
_ensure_log_dirs(runtime)
|
||||
dictConfig(build_logging_config(runtime))
|
||||
except (OSError, ValueError) as exc:
|
||||
logging.basicConfig(level=runtime.log_level)
|
||||
logging.getLogger(__name__).error("Logging setup failed", exc_info=exc)
|
||||
|
||||
structlog.configure(
|
||||
processors=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
build_sensitive_data_processor(runtime.log_sensitive_fields),
|
||||
ensure_message_key,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
||||
],
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from structlog import contextvars
|
||||
|
||||
|
||||
def bind_context(**values: object) -> None:
|
||||
contextvars.bind_contextvars(**values)
|
||||
|
||||
|
||||
def clear_context() -> None:
|
||||
contextvars.clear_contextvars()
|
||||
|
||||
|
||||
def get_context() -> dict[str, object]:
|
||||
return contextvars.get_contextvars()
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from structlog.types import EventDict
|
||||
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"[^a-z0-9]")
|
||||
|
||||
|
||||
def _normalize_key(value: str) -> str:
|
||||
return _NORMALIZE_PATTERN.sub("", value.lower())
|
||||
|
||||
|
||||
def _is_sensitive_key(key: object, sensitive_fields: set[str]) -> bool:
|
||||
normalized_key = _normalize_key(str(key))
|
||||
return normalized_key in sensitive_fields or any(
|
||||
fragment in normalized_key for fragment in sensitive_fields
|
||||
)
|
||||
|
||||
|
||||
def _redact_value(value: object, sensitive_fields: set[str]) -> object:
|
||||
if isinstance(value, dict):
|
||||
typed_value = cast(dict[str, object], value)
|
||||
return {
|
||||
key: (
|
||||
"[REDACTED]"
|
||||
if _is_sensitive_key(key, sensitive_fields)
|
||||
else _redact_value(inner, sensitive_fields)
|
||||
)
|
||||
for key, inner in typed_value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_redact_value(item, sensitive_fields) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def build_sensitive_data_processor(
|
||||
sensitive_fields: list[str],
|
||||
) -> Callable[[object, str, EventDict], EventDict]:
|
||||
normalized = {_normalize_key(field) for field in sensitive_fields}
|
||||
|
||||
def processor(
|
||||
_logger: object, _method_name: str, event_dict: EventDict
|
||||
) -> EventDict:
|
||||
return cast(EventDict, _redact_value(event_dict, normalized))
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
class ErrorLevelFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return record.levelno >= logging.ERROR
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from structlog.dev import ConsoleRenderer
|
||||
from structlog.processors import JSONRenderer
|
||||
from structlog.stdlib import ProcessorFormatter
|
||||
from structlog.types import EventDict
|
||||
import structlog
|
||||
|
||||
from core.logging.filters import build_sensitive_data_processor
|
||||
|
||||
|
||||
def ensure_message_key(
|
||||
_logger: object, _method_name: str, event_dict: EventDict
|
||||
) -> EventDict:
|
||||
if "message" in event_dict:
|
||||
return event_dict
|
||||
if "event" not in event_dict:
|
||||
return event_dict
|
||||
|
||||
without_event = {key: value for key, value in event_dict.items() if key != "event"}
|
||||
return {**without_event, "message": event_dict["event"]}
|
||||
|
||||
|
||||
def build_processor_formatter(
|
||||
sensitive_fields: list[str] | None = None,
|
||||
) -> ProcessorFormatter:
|
||||
redact = build_sensitive_data_processor(sensitive_fields or [])
|
||||
return ProcessorFormatter(
|
||||
foreign_pre_chain=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
structlog.stdlib.ExtraAdder(),
|
||||
ensure_message_key,
|
||||
],
|
||||
processors=[
|
||||
redact,
|
||||
ensure_message_key,
|
||||
ProcessorFormatter.remove_processors_meta,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
JSONRenderer(sort_keys=True),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def build_plain_formatter(
|
||||
sensitive_fields: list[str] | None = None,
|
||||
) -> ProcessorFormatter:
|
||||
redact = build_sensitive_data_processor(sensitive_fields or [])
|
||||
return ProcessorFormatter(
|
||||
foreign_pre_chain=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.MODULE,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
structlog.stdlib.ExtraAdder(),
|
||||
ensure_message_key,
|
||||
],
|
||||
processors=[
|
||||
redact,
|
||||
ensure_message_key,
|
||||
ProcessorFormatter.remove_processors_meta,
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
ConsoleRenderer(colors=False),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from core.config.settings import RuntimeSettings
|
||||
|
||||
|
||||
def build_file_handler_config(
|
||||
runtime: RuntimeSettings,
|
||||
file_path: Path,
|
||||
level: str,
|
||||
formatter: str,
|
||||
filters: list[str] | None = None,
|
||||
) -> dict[str, object]:
|
||||
filter_list = list(filters or [])
|
||||
base_config: dict[str, object] = {
|
||||
"level": level,
|
||||
"formatter": formatter,
|
||||
"filename": str(file_path),
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
|
||||
if filter_list:
|
||||
base_config = {**base_config, "filters": filter_list}
|
||||
|
||||
if runtime.log_rotation == "time":
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.handlers.TimedRotatingFileHandler",
|
||||
"when": runtime.log_rotation_when,
|
||||
"interval": runtime.log_rotation_interval,
|
||||
"backupCount": runtime.log_rotation_backup_count,
|
||||
}
|
||||
|
||||
if runtime.log_rotation == "size":
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"maxBytes": runtime.log_rotation_max_bytes,
|
||||
"backupCount": runtime.log_rotation_backup_count,
|
||||
}
|
||||
|
||||
return {
|
||||
**base_config,
|
||||
"class": "logging.FileHandler",
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
|
||||
return structlog.get_logger(name)
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import MutableMapping
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from core.logging.context import bind_context, clear_context
|
||||
from core.logging.logger import get_logger
|
||||
|
||||
|
||||
class RequestContextMiddleware:
|
||||
app: ASGIApp
|
||||
_header_name: str
|
||||
_request_id_pattern: re.Pattern[str]
|
||||
|
||||
def __init__(self, app: ASGIApp, header_name: str = "X-Request-ID") -> None:
|
||||
self.app = app
|
||||
self._header_name = header_name
|
||||
self._request_id_pattern = re.compile(r"^[A-Za-z0-9_-]{8,64}$")
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope.get("type") != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = StarletteRequest(scope, receive=receive)
|
||||
request_id = self._normalize_request_id(request.headers.get(self._header_name))
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
request.state.request_id = request_id
|
||||
|
||||
bind_context(
|
||||
request_id=request_id,
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
client_ip=client_ip,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
async def send_wrapper(message: MutableMapping[str, object]) -> None:
|
||||
if message.get("type") == "http.response.start":
|
||||
raw_headers = message.get("headers")
|
||||
headers = list(cast(list[tuple[bytes, bytes]], raw_headers or []))
|
||||
header_key = self._header_name.lower().encode()
|
||||
if not any(item[0].lower() == header_key for item in headers):
|
||||
headers.append((header_key, request_id.encode()))
|
||||
message = {**message, "headers": headers}
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
finally:
|
||||
clear_context()
|
||||
|
||||
def _normalize_request_id(self, request_id: str | None) -> str:
|
||||
if request_id and self._request_id_pattern.match(request_id):
|
||||
return request_id
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
logger = get_logger("core.logging.exception")
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> Response:
|
||||
request_id = getattr(request.state, "request_id", None)
|
||||
logger.exception(
|
||||
"Unhandled exception",
|
||||
error_type=exc.__class__.__name__,
|
||||
request_id=request_id,
|
||||
)
|
||||
headers = {"X-Request-ID": request_id} if request_id else None
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal Server Error"},
|
||||
headers=headers,
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from models.profile import Profile
|
||||
|
||||
__all__ = ["Profile"]
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
|
||||
|
||||
class Profile(TimestampMixin, SoftDeleteMixin, Base):
|
||||
"""User profile model.
|
||||
|
||||
Note: The `id` column references auth.users(id) in Supabase.
|
||||
This is a business table managed by SQLAlchemy, with the foreign key
|
||||
relationship to Supabase's auth schema handled at the database level.
|
||||
"""
|
||||
|
||||
__tablename__: str = "profiles"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
)
|
||||
username: Mapped[str] = mapped_column(
|
||||
String(30),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
display_name: Mapped[str | None] = mapped_column(
|
||||
String(50),
|
||||
nullable=True,
|
||||
)
|
||||
avatar_url: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
)
|
||||
bio: Mapped[str | None] = mapped_column(
|
||||
String(200),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.qdrant import QdrantService, qdrant_service
|
||||
from services.base.redis import RedisService, redis_service
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseServiceProvider",
|
||||
"QdrantService",
|
||||
"RedisService",
|
||||
"ServiceRegistry",
|
||||
"qdrant_service",
|
||||
"redis_service",
|
||||
"register_service",
|
||||
"register_service_instance",
|
||||
]
|
||||
@@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
from core.config.settings import QdrantSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class QdrantService(BaseServiceProvider):
|
||||
def __init__(self, settings: QdrantSettings | None = None) -> None:
|
||||
super().__init__("qdrant")
|
||||
self._settings = settings or config.qdrant
|
||||
self._client: Optional[QdrantClient] = None
|
||||
|
||||
def _build_client(self) -> QdrantClient:
|
||||
return QdrantClient(
|
||||
url=self._settings.url,
|
||||
api_key=self._settings.api_key,
|
||||
timeout=self._settings.timeout,
|
||||
prefer_grpc=self._settings.prefer_grpc,
|
||||
)
|
||||
|
||||
def _require_client(self) -> QdrantClient:
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Qdrant client is not initialized")
|
||||
return client
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
client = self._build_client()
|
||||
collections = await asyncio.to_thread(client.get_collections)
|
||||
self.logger.info(
|
||||
"Qdrant service initialized",
|
||||
collections_count=len(collections.collections),
|
||||
)
|
||||
self._client = client
|
||||
self._set_initialized(True)
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Qdrant service initialization failed", error=str(exc))
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def close(self) -> bool:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return True
|
||||
try:
|
||||
close = getattr(client, "close", None)
|
||||
if callable(close):
|
||||
await asyncio.to_thread(close)
|
||||
self.logger.info("Qdrant service closed")
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.exception("Qdrant service close failed", error=str(exc))
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
collections = await asyncio.to_thread(client.get_collections)
|
||||
return {
|
||||
"status": "healthy",
|
||||
"details": {
|
||||
"connected": True,
|
||||
"collections_count": len(collections.collections),
|
||||
"collections": [
|
||||
collection.name for collection in collections.collections[:5]
|
||||
],
|
||||
},
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Qdrant health check failed", error=str(exc))
|
||||
return {"status": "unhealthy", "details": {"error": str(exc)}}
|
||||
|
||||
def get_client(self) -> QdrantClient:
|
||||
return self._require_client()
|
||||
|
||||
|
||||
qdrant_service: QdrantService = register_service_instance("qdrant", QdrantService())
|
||||
|
||||
__all__ = ["QdrantService", "qdrant_service"]
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from core.config.settings import RedisSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class RedisService(BaseServiceProvider):
|
||||
def __init__(self, settings: RedisSettings | None = None) -> None:
|
||||
super().__init__("redis")
|
||||
self._settings = settings or config.redis
|
||||
self._client: Optional[redis.Redis] = None
|
||||
|
||||
def _build_client(self) -> redis.Redis:
|
||||
return redis.from_url(
|
||||
self._settings.url,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=self._settings.socket_connect_timeout,
|
||||
socket_timeout=self._settings.socket_timeout,
|
||||
max_connections=self._settings.max_connections,
|
||||
)
|
||||
|
||||
def _require_client(self) -> redis.Redis:
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Redis client is not initialized")
|
||||
return client
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
client = self._build_client()
|
||||
ping_result = client.ping()
|
||||
if inspect.isawaitable(ping_result):
|
||||
await ping_result
|
||||
self._client = client
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Redis service initialized")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Redis service initialization failed", error=str(exc))
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def close(self) -> bool:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return True
|
||||
try:
|
||||
await client.aclose()
|
||||
self.logger.info("Redis service closed")
|
||||
self._client = None
|
||||
self._set_initialized(False)
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.exception("Redis service close failed", error=str(exc))
|
||||
return False
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
ping_result = client.ping()
|
||||
ping = (
|
||||
await ping_result if inspect.isawaitable(ping_result) else ping_result
|
||||
)
|
||||
info_result = client.info()
|
||||
info = (
|
||||
await info_result if inspect.isawaitable(info_result) else info_result
|
||||
)
|
||||
return {
|
||||
"status": "healthy" if ping else "unhealthy",
|
||||
"details": {
|
||||
"ping": ping,
|
||||
"redis_version": info.get("redis_version"),
|
||||
"connected_clients": info.get("connected_clients"),
|
||||
"used_memory": info.get("used_memory_human"),
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds"),
|
||||
},
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Redis health check failed", error=str(exc))
|
||||
return {"status": "unhealthy", "details": {"error": str(exc)}}
|
||||
|
||||
def get_client(self) -> redis.Redis:
|
||||
return self._require_client()
|
||||
|
||||
|
||||
redis_service: RedisService = register_service_instance("redis", RedisService())
|
||||
|
||||
__all__ = ["RedisService", "redis_service"]
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar
|
||||
|
||||
from core.logging import get_logger
|
||||
|
||||
|
||||
class BaseServiceProvider(ABC):
|
||||
def __init__(self, service_name: str) -> None:
|
||||
self.service_name = service_name
|
||||
self._initialized = False
|
||||
self.logger = get_logger("services.base").bind(service=service_name)
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self, **kwargs: Any) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def _set_initialized(self, value: bool) -> None:
|
||||
self._initialized = value
|
||||
|
||||
def get_service_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.service_name,
|
||||
"initialized": self._initialized,
|
||||
"type": self.__class__.__name__,
|
||||
}
|
||||
|
||||
|
||||
class ServiceRegistry:
|
||||
_services: Dict[str, Callable[..., BaseServiceProvider]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, service_name: str, factory: Callable[..., BaseServiceProvider]
|
||||
) -> None:
|
||||
cls._services = {**cls._services, service_name: factory}
|
||||
|
||||
@classmethod
|
||||
def get_service_factory(
|
||||
cls, service_name: str
|
||||
) -> Optional[Callable[..., BaseServiceProvider]]:
|
||||
return cls._services.get(service_name)
|
||||
|
||||
@classmethod
|
||||
def list_services(cls) -> list[str]:
|
||||
return sorted(cls._services.keys())
|
||||
|
||||
@classmethod
|
||||
def create_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
factory = cls.get_service_factory(service_name)
|
||||
if not factory:
|
||||
return None
|
||||
return factory(**kwargs)
|
||||
|
||||
|
||||
def register_service(service_name: str) -> Callable[[type], type]:
|
||||
def decorator(service_class: type) -> type:
|
||||
ServiceRegistry.register(service_name, service_class)
|
||||
return service_class
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
TService = TypeVar("TService", bound=BaseServiceProvider)
|
||||
|
||||
|
||||
def register_service_instance(service_name: str, service: TService) -> TService:
|
||||
ServiceRegistry.register(service_name, lambda: service)
|
||||
return service
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from v1.auth.service import AuthService, SupabaseAuthGateway
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
return AuthService(gateway=SupabaseAuthGateway())
|
||||
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=6)
|
||||
display_name: str | None = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=6)
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
refresh_token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class AuthTokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
user: AuthUser
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
LoginRequest,
|
||||
LogoutRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/signup", response_model=AuthTokenResponse)
|
||||
async def signup(
|
||||
payload: SignupRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> AuthTokenResponse:
|
||||
return await service.signup(payload)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokenResponse)
|
||||
async def login(
|
||||
payload: LoginRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> AuthTokenResponse:
|
||||
return await service.login(payload)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokenResponse)
|
||||
async def refresh(
|
||||
payload: RefreshRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> AuthTokenResponse:
|
||||
return await service.refresh(payload)
|
||||
|
||||
|
||||
@router.post("/logout", status_code=204)
|
||||
async def logout(
|
||||
payload: LogoutRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await service.logout(payload.refresh_token)
|
||||
return Response(status_code=204)
|
||||
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError, create_client
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
from core.logging import get_logger
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("v1.auth.service")
|
||||
|
||||
|
||||
class AuthServiceGateway(Protocol):
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SupabaseAuthGateway(AuthServiceGateway):
|
||||
_client: Any
|
||||
|
||||
def __init__(self) -> None:
|
||||
settings: SupabaseSettings = config.supabase
|
||||
self._client = create_client(settings.url, settings.anon_key)
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
payload: dict[str, Any] = {
|
||||
"email": request.email,
|
||||
"password": request.password,
|
||||
}
|
||||
if request.display_name:
|
||||
payload = {
|
||||
**payload,
|
||||
"data": {"display_name": request.display_name},
|
||||
}
|
||||
try:
|
||||
sign_up = cast(Any, self._client.auth.sign_up)
|
||||
response = await asyncio.to_thread(sign_up, payload)
|
||||
return _map_auth_response(response, "Authentication failed")
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Authentication failed"
|
||||
) from exc
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
payload: dict[str, Any] = {"email": request.email, "password": request.password}
|
||||
try:
|
||||
sign_in = cast(Any, self._client.auth.sign_in_with_password)
|
||||
response = await asyncio.to_thread(sign_in, payload)
|
||||
return _map_auth_response(response, "Invalid credentials")
|
||||
except AuthError as exc:
|
||||
logger.warning("Login failed", error=str(exc))
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
request.refresh_token,
|
||||
)
|
||||
return _map_auth_response(response, "Invalid refresh token")
|
||||
except AuthError as exc:
|
||||
logger.warning("Refresh failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid refresh token"
|
||||
) from exc
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=401, detail="Missing refresh token")
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
refresh_token,
|
||||
)
|
||||
session = getattr(response, "session", None)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
await asyncio.to_thread(
|
||||
self._client.auth.set_session,
|
||||
str(session.access_token),
|
||||
str(session.refresh_token),
|
||||
)
|
||||
await asyncio.to_thread(self._client.auth.sign_out)
|
||||
except AuthError as exc:
|
||||
logger.warning("Logout failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid refresh token"
|
||||
) from exc
|
||||
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
|
||||
def __init__(self, gateway: AuthServiceGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
return await self._gateway.signup(request)
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
return await self._gateway.login(request)
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
return await self._gateway.refresh(request)
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.logout(refresh_token)
|
||||
|
||||
|
||||
def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse:
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
if session is None or user is None:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
email = getattr(user, "email", None)
|
||||
if not email:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
auth_user = AuthUser(id=str(user.id), email=str(email))
|
||||
return AuthTokenResponse(
|
||||
access_token=str(session.access_token),
|
||||
refresh_token=str(session.refresh_token),
|
||||
expires_in=int(session.expires_in or 0),
|
||||
token_type=str(session.token_type),
|
||||
user=auth_user,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.redis import RedisService, redis_service
|
||||
from services.base.qdrant import QdrantService, qdrant_service
|
||||
|
||||
|
||||
def get_redis_service() -> RedisService:
|
||||
return redis_service
|
||||
|
||||
|
||||
def get_qdrant_service() -> QdrantService:
|
||||
return qdrant_service
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from services.base.qdrant import QdrantService
|
||||
from services.base.redis import RedisService
|
||||
from v1.infra.dependencies import get_qdrant_service, get_redis_service
|
||||
from v1.infra.schemas import InfraHealthResponse, ServiceHealth
|
||||
|
||||
|
||||
router = APIRouter(prefix="/infra", tags=["infra"])
|
||||
|
||||
|
||||
@router.get("/health", response_model=InfraHealthResponse)
|
||||
async def infra_health(
|
||||
redis_service: RedisService = Depends(get_redis_service),
|
||||
qdrant_service: QdrantService = Depends(get_qdrant_service),
|
||||
) -> InfraHealthResponse:
|
||||
if not redis_service.is_initialized:
|
||||
await redis_service.initialize()
|
||||
if not qdrant_service.is_initialized:
|
||||
await qdrant_service.initialize()
|
||||
|
||||
redis_health = await redis_service.health_check()
|
||||
qdrant_health = await qdrant_service.health_check()
|
||||
status = (
|
||||
"healthy"
|
||||
if redis_health["status"] == "healthy" and qdrant_health["status"] == "healthy"
|
||||
else "unhealthy"
|
||||
)
|
||||
|
||||
return InfraHealthResponse(
|
||||
status=status,
|
||||
services={
|
||||
"redis": ServiceHealth(**redis_health),
|
||||
"qdrant": ServiceHealth(**qdrant_health),
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ServiceHealth(BaseModel):
|
||||
status: Literal["healthy", "unhealthy"]
|
||||
details: Dict[str, Any]
|
||||
|
||||
|
||||
class InfraHealthResponse(BaseModel):
|
||||
status: Literal["healthy", "unhealthy"]
|
||||
services: Dict[str, ServiceHealth]
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, Header, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from core.logging import get_logger
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.repository import SQLAlchemyProfileRepository
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
logger = get_logger("v1.profile.dependencies")
|
||||
|
||||
|
||||
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
|
||||
if not authorization:
|
||||
logger.warning("JWT validation failed: missing authorization header")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
scheme, _, token = authorization.partition(" ")
|
||||
if scheme.lower() != "bearer" or not token:
|
||||
logger.warning("JWT validation failed: invalid authorization scheme")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
secret = config.supabase.jwt_secret
|
||||
if not secret:
|
||||
logger.error("JWT validation failed: secret not configured")
|
||||
raise HTTPException(status_code=503, detail="JWT secret not configured")
|
||||
|
||||
supabase_url = config.supabase.public_url.rstrip("/")
|
||||
expected_issuer = f"{supabase_url}/auth/v1"
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
secret,
|
||||
algorithms=["HS256"],
|
||||
audience="authenticated",
|
||||
issuer=expected_issuer,
|
||||
options={
|
||||
"verify_aud": True,
|
||||
"verify_iss": True,
|
||||
"verify_exp": True,
|
||||
"require": ["sub", "aud", "iss", "exp"],
|
||||
},
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("JWT validation failed: token expired")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.warning("JWT validation failed: invalid audience")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.warning("JWT validation failed: invalid issuer")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.InvalidSignatureError:
|
||||
logger.warning("JWT validation failed: invalid signature")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.DecodeError:
|
||||
logger.warning("JWT validation failed: malformed token")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
except jwt.PyJWTError as exc:
|
||||
logger.warning(
|
||||
"JWT validation failed: unknown error", error_type=type(exc).__name__
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||
|
||||
subject = payload.get("sub")
|
||||
if not isinstance(subject, str) or not subject:
|
||||
logger.warning("JWT validation failed: missing or invalid subject claim")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
try:
|
||||
user_id = UUID(subject)
|
||||
except ValueError:
|
||||
logger.warning("JWT validation failed: invalid UUID in subject")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
logger.debug("JWT validation successful", user_id=str(user_id))
|
||||
return CurrentUser(id=user_id)
|
||||
|
||||
|
||||
def get_profile_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> ProfileService:
|
||||
repository = SQLAlchemyProfileRepository(session)
|
||||
return ProfileService(repository=repository, session=session, current_user=user)
|
||||
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from core.logging import get_logger
|
||||
from models.profile import Profile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.profile.repository")
|
||||
|
||||
|
||||
class ProfileRepository(Protocol):
|
||||
"""Protocol defining the profile repository interface."""
|
||||
|
||||
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
|
||||
"""Get profile by user ID."""
|
||||
...
|
||||
|
||||
async def get_by_username(self, username: str) -> Profile | None:
|
||||
"""Get profile by username."""
|
||||
...
|
||||
|
||||
async def update_by_user_id(
|
||||
self, user_id: UUID, update_data: dict[str, str | None]
|
||||
) -> Profile | None:
|
||||
"""Update profile by user ID. Returns updated profile or None if not found."""
|
||||
...
|
||||
|
||||
|
||||
class SQLAlchemyProfileRepository(BaseRepository[Profile]):
|
||||
"""SQLAlchemy implementation of ProfileRepository.
|
||||
|
||||
Note: This repository only performs CRUD operations.
|
||||
- No commit (only flush) - service layer handles transactions
|
||||
- No auth logic - service layer handles authorization
|
||||
- No HTTP exceptions - returns None or raises SQLAlchemyError
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session, Profile)
|
||||
|
||||
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
|
||||
try:
|
||||
return await self.get_by_id(user_id)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("Profile lookup failed", user_id=str(user_id))
|
||||
raise
|
||||
|
||||
async def get_by_username(self, username: str) -> Profile | None:
|
||||
try:
|
||||
return await self.get_one(Profile.username == username)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("Profile lookup failed", username=username)
|
||||
raise
|
||||
|
||||
async def update_by_user_id(
|
||||
self, user_id: UUID, update_data: dict[str, str | None]
|
||||
) -> Profile | None:
|
||||
if not update_data:
|
||||
return await self.get_by_user_id(user_id)
|
||||
|
||||
try:
|
||||
return await self.update_by_id(user_id, update_data)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("Profile update failed", user_id=str(user_id))
|
||||
raise
|
||||
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
|
||||
from v1.profile.dependencies import get_profile_service
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
router = APIRouter(prefix="/profile", tags=["profile"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=ProfileResponse)
|
||||
async def get_me(
|
||||
service: Annotated[ProfileService, Depends(get_profile_service)],
|
||||
) -> ProfileResponse:
|
||||
return await service.get_me()
|
||||
|
||||
|
||||
@router.patch("/me", response_model=ProfileResponse)
|
||||
async def update_me(
|
||||
payload: ProfileUpdateRequest,
|
||||
service: Annotated[ProfileService, Depends(get_profile_service)],
|
||||
) -> ProfileResponse:
|
||||
return await service.update_me(payload)
|
||||
|
||||
|
||||
@router.get("/{username}", response_model=ProfileResponse)
|
||||
async def get_by_username(
|
||||
username: Annotated[
|
||||
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
|
||||
],
|
||||
service: Annotated[ProfileService, Depends(get_profile_service)],
|
||||
) -> ProfileResponse:
|
||||
return await service.get_by_username(username)
|
||||
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class ProfileResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
display_name: str | None = None
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
|
||||
|
||||
class ProfileUpdateRequest(BaseModel):
|
||||
display_name: str | None = Field(default=None, max_length=50)
|
||||
avatar_url: str | None = Field(default=None)
|
||||
bio: str | None = Field(default=None, max_length=200)
|
||||
|
||||
@field_validator("avatar_url", mode="before")
|
||||
@classmethod
|
||||
def validate_avatar_url(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
parsed = AnyHttpUrl(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError("avatar_url must use http or https scheme")
|
||||
return str(parsed)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_one_field(self) -> "ProfileUpdateRequest":
|
||||
if self.display_name is None and self.avatar_url is None and self.bio is None:
|
||||
raise ValueError("At least one field must be provided")
|
||||
return self
|
||||
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from v1.profile.repository import ProfileRepository
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.profile.service")
|
||||
|
||||
|
||||
class ProfileService(BaseService):
|
||||
"""Profile service handling business logic and transactions.
|
||||
|
||||
Responsibilities:
|
||||
- Authorization checks
|
||||
- Transaction boundary (commit/rollback)
|
||||
- Converting ORM models to response schemas
|
||||
"""
|
||||
|
||||
_repository: ProfileRepository
|
||||
_session: AsyncSession
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: ProfileRepository,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def get_me(self) -> ProfileResponse:
|
||||
user_id = self.require_user_id()
|
||||
try:
|
||||
profile = await self._repository.get_by_user_id(user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Profile store unavailable")
|
||||
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return ProfileResponse(
|
||||
id=str(profile.id),
|
||||
username=profile.username,
|
||||
display_name=profile.display_name,
|
||||
avatar_url=profile.avatar_url,
|
||||
bio=profile.bio,
|
||||
)
|
||||
|
||||
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
|
||||
user_id = self.require_user_id()
|
||||
update_data: dict[str, str | None] = {
|
||||
key: value
|
||||
for key, value in {
|
||||
"display_name": update.display_name,
|
||||
"avatar_url": update.avatar_url,
|
||||
"bio": update.bio,
|
||||
}.items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
|
||||
try:
|
||||
profile = await self._repository.update_by_user_id(user_id, update_data)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(status_code=503, detail="Profile store unavailable")
|
||||
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
return ProfileResponse(
|
||||
id=str(profile.id),
|
||||
username=profile.username,
|
||||
display_name=profile.display_name,
|
||||
avatar_url=profile.avatar_url,
|
||||
bio=profile.bio,
|
||||
)
|
||||
|
||||
async def get_by_username(self, username: str) -> ProfileResponse:
|
||||
try:
|
||||
profile = await self._repository.get_by_username(username)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Profile store unavailable")
|
||||
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return ProfileResponse(
|
||||
id=str(profile.id),
|
||||
username=profile.username,
|
||||
display_name=profile.display_name,
|
||||
avatar_url=profile.avatar_url,
|
||||
bio=profile.bio,
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from core.http.models import HealthResponse
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.infra.router import router as infra_router
|
||||
from v1.profile.router import router as profile_router
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
router.include_router(infra_router)
|
||||
router.include_router(profile_router)
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health() -> HealthResponse:
|
||||
return HealthResponse(status="ok")
|
||||
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def pytest_configure() -> None:
|
||||
root = Path(__file__).resolve().parents[2]
|
||||
src_path = root / "backend" / "src"
|
||||
if str(src_path) not in sys.path:
|
||||
sys.path.insert(0, str(src_path))
|
||||
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
class FakeE2EAuthService(AuthService):
|
||||
def __init__(self) -> None:
|
||||
self._user = AuthUser(id="user-1", email="user@example.com")
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
return AuthTokenResponse(
|
||||
access_token="access-1",
|
||||
refresh_token="refresh-1",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
return AuthTokenResponse(
|
||||
access_token="access-2",
|
||||
refresh_token="refresh-2",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
return AuthTokenResponse(
|
||||
access_token="access-3",
|
||||
refresh_token="refresh-3",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_auth_flow_e2e() -> None:
|
||||
app.dependency_overrides[get_auth_service] = lambda: FakeE2EAuthService()
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
signup = request_context.post(
|
||||
"/api/v1/auth/signup",
|
||||
data=json.dumps(
|
||||
{"email": "user@example.com", "password": "secret123"}
|
||||
),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert signup.status == 200
|
||||
assert signup.json()["access_token"] == "access-1"
|
||||
|
||||
login = request_context.post(
|
||||
"/api/v1/auth/login",
|
||||
data=json.dumps(
|
||||
{"email": "user@example.com", "password": "secret123"}
|
||||
),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert login.status == 200
|
||||
assert login.json()["access_token"] == "access-2"
|
||||
|
||||
refresh = request_context.post(
|
||||
"/api/v1/auth/refresh",
|
||||
data=json.dumps({"refresh_token": "refresh-2"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert refresh.status == 200
|
||||
assert refresh.json()["access_token"] == "access-3"
|
||||
|
||||
logout = request_context.post(
|
||||
"/api/v1/auth/logout",
|
||||
data=json.dumps({"refresh_token": "refresh-3"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert logout.status == 204
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from v1.infra.dependencies import get_qdrant_service, get_redis_service
|
||||
|
||||
|
||||
class _FakeService:
|
||||
def __init__(self) -> None:
|
||||
self._initialized = True
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, object]:
|
||||
return {"status": "healthy", "details": {}}
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_infra_health_e2e() -> None:
|
||||
app.dependency_overrides[get_redis_service] = lambda: _FakeService()
|
||||
app.dependency_overrides[get_qdrant_service] = lambda: _FakeService()
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
response = request_context.get("/api/v1/infra/health")
|
||||
assert response.status == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "healthy"
|
||||
assert "redis" in body["services"]
|
||||
assert "qdrant" in body["services"]
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
app.dependency_overrides = {}
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.middleware import (
|
||||
RequestContextMiddleware,
|
||||
register_exception_handlers,
|
||||
)
|
||||
|
||||
|
||||
def _read_json_lines(path: Path) -> list[dict[str, object]]:
|
||||
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(app: FastAPI, host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_e2e_error_logging(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "size",
|
||||
"log_rotation_max_bytes": 2048,
|
||||
}
|
||||
)
|
||||
configure_logging(settings.model_copy(update={"runtime": runtime}))
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestContextMiddleware) # type: ignore[arg-type]
|
||||
register_exception_handlers(app)
|
||||
|
||||
@app.get("/boom")
|
||||
async def boom() -> dict[str, str]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(app, host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
response = request_context.get(
|
||||
"/boom",
|
||||
headers={"X-Request-ID": "e2e-5000"},
|
||||
)
|
||||
assert response.status == 500
|
||||
request_context.dispose()
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
|
||||
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
|
||||
entry = next(
|
||||
item for item in error_entries if item.get("message") == "Unhandled exception"
|
||||
)
|
||||
|
||||
assert entry["request_id"] == "e2e-5000"
|
||||
exception = str(entry["exception"])
|
||||
assert "Traceback" in exception
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_mobile_health_e2e() -> None:
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
response = request_context.get("/api/v1/health")
|
||||
assert response.status == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "ok"
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.dependencies import get_current_user, get_profile_service
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
|
||||
|
||||
class FakeProfileService:
|
||||
"""Fake service for E2E testing."""
|
||||
|
||||
def __init__(self, profile: ProfileResponse) -> None:
|
||||
self._profile = profile
|
||||
|
||||
async def get_me(self) -> ProfileResponse:
|
||||
return self._profile
|
||||
|
||||
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
|
||||
return ProfileResponse(
|
||||
id=self._profile.id,
|
||||
username=self._profile.username,
|
||||
display_name=(
|
||||
update.display_name
|
||||
if update.display_name is not None
|
||||
else self._profile.display_name
|
||||
),
|
||||
avatar_url=(
|
||||
update.avatar_url
|
||||
if update.avatar_url is not None
|
||||
else self._profile.avatar_url
|
||||
),
|
||||
bio=update.bio if update.bio is not None else self._profile.bio,
|
||||
)
|
||||
|
||||
async def get_by_username(self, username: str) -> ProfileResponse:
|
||||
return self._profile
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_profile_flow_e2e() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = ProfileResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = lambda: FakeProfileService(profile) # type: ignore[return-value]
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(id=user_id)
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
me = request_context.get("/api/v1/profile/me")
|
||||
assert me.status == 200
|
||||
assert me.json()["username"] == "demo"
|
||||
|
||||
updated = request_context.patch(
|
||||
"/api/v1/profile/me",
|
||||
data=json.dumps({"display_name": "Updated"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert updated.status == 200
|
||||
assert updated.json()["display_name"] == "Updated"
|
||||
|
||||
public = request_context.get("/api/v1/profile/demo")
|
||||
assert public.status == 200
|
||||
assert public.json()["username"] == "demo"
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import Settings
|
||||
from services.base.qdrant import QdrantService
|
||||
from services.base.redis import RedisService
|
||||
|
||||
|
||||
def _can_connect(host: str, port: int) -> bool:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(0.2)
|
||||
return sock.connect_ex((host, port)) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_service_health_check_integration() -> None:
|
||||
host = "127.0.0.1"
|
||||
port = 6379
|
||||
if not _can_connect(host, port):
|
||||
pytest.skip("Redis is not running on localhost:6379")
|
||||
|
||||
config = Settings()
|
||||
settings = config.redis.model_copy(update={"host": host, "port": port})
|
||||
service = RedisService(settings=settings)
|
||||
|
||||
assert await service.initialize() is True
|
||||
health = await service.health_check()
|
||||
assert health["status"] == "healthy"
|
||||
assert await service.close() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qdrant_service_health_check_integration() -> None:
|
||||
host = "127.0.0.1"
|
||||
port = 6333
|
||||
if not _can_connect(host, port):
|
||||
pytest.skip("Qdrant is not running on localhost:6333")
|
||||
|
||||
config = Settings()
|
||||
settings = config.qdrant.model_copy(update={"host": host, "port": port})
|
||||
service = QdrantService(settings=settings)
|
||||
|
||||
assert await service.initialize() is True
|
||||
health = await service.health_check()
|
||||
assert health["status"] == "healthy"
|
||||
assert await service.close() is True
|
||||
@@ -0,0 +1,178 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
class FakeAuthService(AuthService):
|
||||
def __init__(self, token_response: AuthTokenResponse) -> None:
|
||||
self._token_response = token_response
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
return self._token_response
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
|
||||
def _get_service() -> AuthService:
|
||||
return service
|
||||
|
||||
return _get_service
|
||||
|
||||
|
||||
def test_signup_returns_token_response() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/signup",
|
||||
json={"email": "user@example.com", "password": "secret123"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["access_token"] == "access"
|
||||
assert body["refresh_token"] == "refresh"
|
||||
assert body["user"]["email"] == "user@example.com"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_login_invalid_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "user@example.com", "password": "wrongpw"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
assert body["detail"] == "Invalid credentials"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_refresh_invalid_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
assert body["detail"] == "Invalid refresh token"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_logout_returns_no_content() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": "refresh"},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
assert response.content == b""
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_validation_error_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/auth/signup", json={})
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unprocessable Content"
|
||||
assert body["status"] == 422
|
||||
assert body["detail"] == "Invalid request"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -0,0 +1,126 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.logger import get_logger
|
||||
from core.logging.middleware import (
|
||||
RequestContextMiddleware,
|
||||
register_exception_handlers,
|
||||
)
|
||||
|
||||
|
||||
def _read_json_lines(path: Path) -> list[dict[str, object]]:
|
||||
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
|
||||
|
||||
|
||||
def _configure_test_logging(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "size",
|
||||
"log_rotation_max_bytes": 2048,
|
||||
}
|
||||
)
|
||||
test_settings = settings.model_copy(update={"runtime": runtime})
|
||||
|
||||
configure_logging(test_settings)
|
||||
|
||||
|
||||
def test_middleware_binds_request_context(tmp_path: Path) -> None:
|
||||
_configure_test_logging(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestContextMiddleware) # type: ignore[arg-type]
|
||||
|
||||
@app.get("/ok")
|
||||
async def ok() -> dict[str, str]:
|
||||
logger = get_logger("tests.ok")
|
||||
logger.info("request accepted", context_key="context_value")
|
||||
return {"status": "ok"}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/ok", headers={"X-Request-ID": "req-1234"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["X-Request-ID"] == "req-1234"
|
||||
|
||||
log_entries = _read_json_lines(Path(tmp_path) / "app.log")
|
||||
entry = next(
|
||||
item for item in log_entries if item.get("message") == "request accepted"
|
||||
)
|
||||
assert entry["message"] == "request accepted"
|
||||
assert entry["request_id"] == "req-1234"
|
||||
assert entry["method"] == "GET"
|
||||
assert entry["path"] == "/ok"
|
||||
assert entry["context_key"] == "context_value"
|
||||
|
||||
logging.shutdown()
|
||||
|
||||
|
||||
def test_exception_handler_logs_stack_and_sends_500(tmp_path: Path) -> None:
|
||||
_configure_test_logging(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestContextMiddleware)
|
||||
register_exception_handlers(app)
|
||||
|
||||
@app.get("/boom")
|
||||
async def boom() -> dict[str, str]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/boom", headers={"X-Request-ID": "req-5000"})
|
||||
|
||||
assert response.status_code == 500
|
||||
assert response.json()["detail"] == "Internal Server Error"
|
||||
|
||||
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
|
||||
assert error_entries
|
||||
entry = error_entries[-1]
|
||||
assert entry["level"] == "error"
|
||||
assert entry["request_id"] == "req-5000"
|
||||
exception = str(entry["exception"])
|
||||
assert "Traceback" in exception
|
||||
assert "test_fastapi_logging_integration" in exception
|
||||
|
||||
logging.shutdown()
|
||||
|
||||
|
||||
def test_invalid_request_id_is_replaced_and_used_in_error_context(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
_configure_test_logging(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestContextMiddleware)
|
||||
register_exception_handlers(app)
|
||||
|
||||
@app.get("/boom")
|
||||
async def boom() -> dict[str, str]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/boom", headers={"X-Request-ID": "bad"})
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
response_request_id = response.headers["X-Request-ID"]
|
||||
assert response_request_id != "bad"
|
||||
|
||||
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
|
||||
assert error_entries
|
||||
entry = error_entries[-1]
|
||||
assert entry["request_id"] == response_request_id
|
||||
exception = str(entry["exception"])
|
||||
assert "Traceback" in exception
|
||||
|
||||
logging.shutdown()
|
||||
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
|
||||
|
||||
def test_app_health_returns_envelope() -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "ok"
|
||||
|
||||
|
||||
def test_mobile_router_health_returns_envelope() -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "ok"
|
||||
|
||||
|
||||
def test_not_found_returns_error_envelope() -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/missing-route")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["type"] == "about:blank"
|
||||
assert body["title"] == "Not Found"
|
||||
assert body["status"] == 404
|
||||
assert body["detail"] == "Not Found"
|
||||
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.dependencies import get_current_user, get_profile_service
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
|
||||
class FakeProfileService:
|
||||
"""Fake service for integration testing."""
|
||||
|
||||
def __init__(self, profile: ProfileResponse) -> None:
|
||||
self._profile = profile
|
||||
|
||||
async def get_me(self) -> ProfileResponse:
|
||||
if self._profile.id is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return self._profile
|
||||
|
||||
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
|
||||
if self._profile.id is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return ProfileResponse(
|
||||
id=self._profile.id,
|
||||
username=self._profile.username,
|
||||
display_name=(
|
||||
update.display_name
|
||||
if update.display_name is not None
|
||||
else self._profile.display_name
|
||||
),
|
||||
avatar_url=(
|
||||
update.avatar_url
|
||||
if update.avatar_url is not None
|
||||
else self._profile.avatar_url
|
||||
),
|
||||
bio=update.bio if update.bio is not None else self._profile.bio,
|
||||
)
|
||||
|
||||
async def get_by_username(self, username: str) -> ProfileResponse:
|
||||
if username != self._profile.username:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return self._profile
|
||||
|
||||
|
||||
def _override_profile_service(
|
||||
service: FakeProfileService,
|
||||
) -> Callable[[], ProfileService]:
|
||||
def _get_service() -> ProfileService:
|
||||
return service # type: ignore[return-value]
|
||||
|
||||
return _get_service
|
||||
|
||||
|
||||
def _override_current_user(user_id: UUID) -> Callable[[], CurrentUser]:
|
||||
def _get_user() -> CurrentUser:
|
||||
return CurrentUser(id=user_id)
|
||||
|
||||
return _get_user
|
||||
|
||||
|
||||
def test_get_me_returns_profile() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = ProfileResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/profile/me")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["username"] == "demo"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_patch_me_updates_profile() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = ProfileResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.patch(
|
||||
"/api/v1/profile/me",
|
||||
json={"display_name": "Updated"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["display_name"] == "Updated"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_get_profile_by_username() -> None:
|
||||
profile = ProfileResponse(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/profile/demo")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["username"] == "demo"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_profile_not_found_returns_problem_details() -> None:
|
||||
profile = ProfileResponse(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/profile/unknown")
|
||||
assert response.status_code == 404
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Not Found"
|
||||
assert body["status"] == 404
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_patch_me_validation_error_returns_problem_details() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = ProfileResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.patch("/api/v1/profile/me", json={})
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unprocessable Content"
|
||||
assert body["status"] == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
|
||||
|
||||
def test_require_current_user_raises_when_missing() -> None:
|
||||
service = BaseService(current_user=None)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
service.require_current_user()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
def test_require_current_user_returns_user() -> None:
|
||||
user = CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
|
||||
service = BaseService(current_user=user)
|
||||
|
||||
result = service.require_current_user()
|
||||
|
||||
assert result.id == user.id
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin
|
||||
from core.db.base_repository import BaseRepository
|
||||
|
||||
|
||||
class Widget(SoftDeleteMixin, Base):
|
||||
__tablename__ = "widgets"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_filters_soft_deleted(db_session: AsyncSession) -> None:
|
||||
repository = BaseRepository(db_session, Widget)
|
||||
widget_id = uuid4()
|
||||
|
||||
widget = Widget(id=widget_id, name="widget")
|
||||
db_session.add(widget)
|
||||
await db_session.commit()
|
||||
|
||||
found = await repository.get_by_id(widget_id)
|
||||
assert found is not None
|
||||
|
||||
deleted = await repository.soft_delete_by_id(widget_id)
|
||||
assert deleted is not None
|
||||
assert deleted.deleted_at is not None
|
||||
|
||||
missing = await repository.get_by_id(widget_id)
|
||||
assert missing is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_sets_timestamp(db_session: AsyncSession) -> None:
|
||||
repository = BaseRepository(db_session, Widget)
|
||||
widget_id = uuid4()
|
||||
|
||||
widget = Widget(id=widget_id, name="widget")
|
||||
db_session.add(widget)
|
||||
await db_session.commit()
|
||||
|
||||
deleted = await repository.soft_delete_by_id(widget_id)
|
||||
assert deleted is not None
|
||||
assert isinstance(deleted.deleted_at, datetime)
|
||||
deleted_at = deleted.deleted_at
|
||||
if deleted_at.tzinfo is None:
|
||||
deleted_at = deleted_at.replace(tzinfo=timezone.utc)
|
||||
assert deleted_at <= datetime.now(timezone.utc)
|
||||
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from models.profile import Profile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
"""Create in-memory SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
"""Create a database session for testing."""
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_create(db_session: AsyncSession) -> None:
|
||||
"""Test creating a Profile model."""
|
||||
profile_id = uuid4()
|
||||
profile = Profile(
|
||||
id=profile_id,
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(profile)
|
||||
|
||||
assert profile.id == profile_id
|
||||
assert profile.username == "testuser"
|
||||
assert profile.display_name == "Test User"
|
||||
assert profile.created_at is not None
|
||||
assert profile.updated_at is not None
|
||||
assert profile.deleted_at is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_get_by_id(db_session: AsyncSession) -> None:
|
||||
"""Test retrieving a Profile by ID."""
|
||||
profile_id = uuid4()
|
||||
profile = Profile(
|
||||
id=profile_id,
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.get(Profile, profile_id)
|
||||
assert result is not None
|
||||
assert result.username == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_get_by_username(db_session: AsyncSession) -> None:
|
||||
"""Test retrieving a Profile by username."""
|
||||
profile = Profile(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Profile).where(Profile.username == "testuser")
|
||||
)
|
||||
found = result.scalar_one()
|
||||
assert found is not None
|
||||
assert found.username == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_update(db_session: AsyncSession) -> None:
|
||||
"""Test updating a Profile."""
|
||||
profile = Profile(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
bio="Old bio",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
profile.display_name = "Updated User"
|
||||
profile.bio = "New bio"
|
||||
await db_session.commit()
|
||||
await db_session.refresh(profile)
|
||||
|
||||
assert profile.display_name == "Updated User"
|
||||
assert profile.bio == "New bio"
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import QdrantSettings
|
||||
from services.base.qdrant import QdrantService
|
||||
|
||||
|
||||
class _FakeCollection:
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
|
||||
class _FakeCollections:
|
||||
def __init__(self) -> None:
|
||||
self.collections = [_FakeCollection("default")]
|
||||
|
||||
|
||||
class _FakeQdrantClient:
|
||||
def get_collections(self) -> _FakeCollections:
|
||||
return _FakeCollections()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
def _build_client(_: QdrantService) -> _FakeQdrantClient:
|
||||
return _FakeQdrantClient()
|
||||
|
||||
monkeypatch.setattr(QdrantService, "_build_client", _build_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is True
|
||||
assert service.is_initialized is True
|
||||
|
||||
health = await service.health_check()
|
||||
assert health["status"] == "healthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
def _build_client(_: QdrantService) -> _FakeQdrantClient:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(QdrantService, "_build_client", _build_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is False
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_unhealthy_when_not_initialized() -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
health = await service.health_check()
|
||||
|
||||
assert health["status"] == "unhealthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_is_idempotent() -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
assert await service.close() is True
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
def test_get_client_raises_before_init() -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import RedisSettings
|
||||
from services.base.redis import RedisService
|
||||
|
||||
|
||||
class _FakeRedisClient:
|
||||
def __init__(self) -> None:
|
||||
self.closed = False
|
||||
|
||||
async def ping(self) -> bool:
|
||||
return True
|
||||
|
||||
async def info(self) -> dict[str, object]:
|
||||
return {
|
||||
"redis_version": "7.2",
|
||||
"connected_clients": 1,
|
||||
"used_memory_human": "1M",
|
||||
"uptime_in_seconds": 10,
|
||||
}
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
def _build_client(_: RedisService) -> _FakeRedisClient:
|
||||
return _FakeRedisClient()
|
||||
|
||||
monkeypatch.setattr(RedisService, "_build_client", _build_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is True
|
||||
assert service.is_initialized is True
|
||||
|
||||
health = await service.health_check()
|
||||
assert health["status"] == "healthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
def _build_client(_: RedisService) -> _FakeRedisClient:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(RedisService, "_build_client", _build_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is False
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_is_idempotent() -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
assert await service.close() is True
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_uninitialized() -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
health = await service.health_check()
|
||||
|
||||
assert health["status"] == "unhealthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
client = _FakeRedisClient()
|
||||
|
||||
def _build_client(_: RedisService) -> _FakeRedisClient:
|
||||
return client
|
||||
|
||||
monkeypatch.setattr(RedisService, "_build_client", _build_client)
|
||||
|
||||
assert await service.initialize() is True
|
||||
assert await service.close() is True
|
||||
assert client.closed is True
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
def test_get_client_raises_before_init() -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
)
|
||||
|
||||
|
||||
class _DummyService(BaseServiceProvider):
|
||||
def __init__(self, name: str = "dummy") -> None:
|
||||
super().__init__(name)
|
||||
|
||||
async def initialize(self, **_: object) -> bool:
|
||||
self._set_initialized(True)
|
||||
return True
|
||||
|
||||
async def close(self) -> bool:
|
||||
self._set_initialized(False)
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, object]:
|
||||
return {"status": "healthy", "details": {}}
|
||||
|
||||
|
||||
def test_register_service_and_create_service() -> None:
|
||||
@register_service("dummy-service")
|
||||
class _RegisteredService(_DummyService):
|
||||
pass
|
||||
|
||||
created = ServiceRegistry.create_service("dummy-service")
|
||||
|
||||
assert created is not None
|
||||
assert created.get_service_info()["name"] == "dummy"
|
||||
|
||||
|
||||
def test_register_service_instance_returns_same_instance() -> None:
|
||||
instance = _DummyService("singleton")
|
||||
|
||||
returned = register_service_instance("dummy-singleton", instance)
|
||||
created = ServiceRegistry.create_service("dummy-singleton")
|
||||
|
||||
assert returned is instance
|
||||
assert created is instance
|
||||
|
||||
|
||||
def test_create_service_returns_none_for_missing() -> None:
|
||||
assert ServiceRegistry.create_service("missing-service") is None
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.logging import celery as celery_logging
|
||||
from core.logging.context import clear_context, get_context
|
||||
|
||||
|
||||
class DummyTask:
|
||||
name: str = "tasks.sample"
|
||||
|
||||
|
||||
def test_celery_prerun_binds_task_context() -> None:
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
|
||||
context = get_context()
|
||||
|
||||
assert context["task_id"] == "task-123"
|
||||
assert context["task_name"] == "tasks.sample"
|
||||
|
||||
clear_context()
|
||||
|
||||
|
||||
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
|
||||
called = {"value": False}
|
||||
|
||||
def fake_configure_logging(settings: object | None = None) -> None:
|
||||
called["value"] = True
|
||||
|
||||
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_setup_logging()
|
||||
|
||||
assert called["value"] is True
|
||||
|
||||
|
||||
def test_configure_celery_app_disables_hijack() -> None:
|
||||
app = Celery("test")
|
||||
|
||||
celery_logging.configure_celery_app(app)
|
||||
|
||||
assert app.conf.worker_hijack_root_logger is False
|
||||
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
import structlog
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import build_logging_config, configure_logging
|
||||
|
||||
|
||||
def _get_handlers(config: dict[str, object]) -> dict[str, dict[str, object]]:
|
||||
return cast(dict[str, dict[str, object]], config["handlers"])
|
||||
|
||||
|
||||
def test_build_logging_config_time_rotation(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "time",
|
||||
}
|
||||
)
|
||||
|
||||
config = build_logging_config(runtime)
|
||||
handlers = _get_handlers(config)
|
||||
|
||||
assert handlers["file"]["class"] == "logging.handlers.TimedRotatingFileHandler"
|
||||
assert handlers["error"]["class"] == "logging.handlers.TimedRotatingFileHandler"
|
||||
assert handlers["error"]["level"] == "ERROR"
|
||||
|
||||
|
||||
def test_build_logging_config_size_rotation(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "size",
|
||||
"log_rotation_max_bytes": 2048,
|
||||
}
|
||||
)
|
||||
|
||||
config = build_logging_config(runtime)
|
||||
handlers = _get_handlers(config)
|
||||
|
||||
assert handlers["file"]["class"] == "logging.handlers.RotatingFileHandler"
|
||||
assert handlers["error"]["class"] == "logging.handlers.RotatingFileHandler"
|
||||
assert handlers["file"]["maxBytes"] == 2048
|
||||
|
||||
|
||||
def test_build_logging_config_plain_formatter_when_disabled(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_json": False,
|
||||
}
|
||||
)
|
||||
|
||||
config = build_logging_config(runtime)
|
||||
handlers = _get_handlers(config)
|
||||
|
||||
assert handlers["file"]["formatter"] == "plain"
|
||||
assert handlers["error"]["formatter"] == "plain"
|
||||
|
||||
|
||||
def _read_last_log_entry(log_path: Path) -> dict[str, object]:
|
||||
assert log_path.exists(), f"Expected log file at {log_path}"
|
||||
entries = [
|
||||
json.loads(line) for line in log_path.read_text().splitlines() if line.strip()
|
||||
]
|
||||
assert entries, "Expected at least one log entry in app.log"
|
||||
return entries[-1]
|
||||
|
||||
|
||||
def _flush_root_handlers() -> None:
|
||||
root_logger = logging.getLogger()
|
||||
for handler in root_logger.handlers:
|
||||
if hasattr(handler, "flush"):
|
||||
handler.flush()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configured_logging(tmp_path: Path) -> Iterator[Path]:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "size",
|
||||
"log_rotation_max_bytes": 2048,
|
||||
"log_json": True,
|
||||
}
|
||||
)
|
||||
root_logger = logging.getLogger()
|
||||
original_handlers = root_logger.handlers[:]
|
||||
original_level = root_logger.level
|
||||
|
||||
configure_logging(settings.model_copy(update={"runtime": runtime}))
|
||||
|
||||
yield tmp_path
|
||||
|
||||
for handler in root_logger.handlers:
|
||||
handler.close()
|
||||
root_logger.handlers = original_handlers
|
||||
root_logger.setLevel(original_level)
|
||||
structlog.reset_defaults()
|
||||
|
||||
|
||||
def test_stdlib_logging_redacts_sensitive_fields(configured_logging: Path) -> None:
|
||||
logger = logging.getLogger("tests.stdlib")
|
||||
logger.info("login", extra={"password": "secret", "token": "abc"})
|
||||
|
||||
_flush_root_handlers()
|
||||
|
||||
log_path = configured_logging / "app.log"
|
||||
entry = _read_last_log_entry(log_path)
|
||||
|
||||
assert entry["password"] == "[REDACTED]"
|
||||
assert entry["token"] == "[REDACTED]"
|
||||
|
||||
|
||||
def test_structlog_redacts_sensitive_fields(configured_logging: Path) -> None:
|
||||
logger = structlog.get_logger("tests.structlog")
|
||||
logger.info("login", password="secret", token="abc")
|
||||
|
||||
_flush_root_handlers()
|
||||
|
||||
log_path = configured_logging / "app.log"
|
||||
entry = _read_last_log_entry(log_path)
|
||||
|
||||
assert entry["password"] == "[REDACTED]"
|
||||
assert entry["token"] == "[REDACTED]"
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.logging.filters import build_sensitive_data_processor
|
||||
|
||||
|
||||
def test_redact_sensitive_fields_masks_values() -> None:
|
||||
processor = build_sensitive_data_processor(
|
||||
["password", "token", "api_key", "cookie"]
|
||||
)
|
||||
|
||||
event: dict[str, object] = {
|
||||
"message": "login",
|
||||
"password": "secret",
|
||||
"access_token": "token-123",
|
||||
"apiKey": "apikey-123",
|
||||
"set-cookie": "cookie-1",
|
||||
"nested": {"token": "abc", "safe": "ok"},
|
||||
"list": [{"password": "x"}],
|
||||
}
|
||||
|
||||
redacted = processor(None, "info", event)
|
||||
|
||||
assert redacted["password"] == "[REDACTED]"
|
||||
assert redacted["access_token"] == "[REDACTED]"
|
||||
assert redacted["apiKey"] == "[REDACTED]"
|
||||
assert redacted["set-cookie"] == "[REDACTED]"
|
||||
assert redacted["nested"]["token"] == "[REDACTED]"
|
||||
assert redacted["nested"]["safe"] == "ok"
|
||||
assert redacted["list"][0]["password"] == "[REDACTED]"
|
||||
assert event["password"] == "secret"
|
||||
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_runtime_settings_defaults() -> None:
|
||||
settings = Settings()
|
||||
|
||||
assert settings.runtime.log_json is True
|
||||
assert settings.runtime.log_rotation == "time"
|
||||
assert settings.runtime.log_rotation_when == "midnight"
|
||||
assert settings.runtime.log_rotation_interval == 1
|
||||
assert settings.runtime.log_rotation_backup_count == 14
|
||||
assert settings.runtime.log_rotation_max_bytes == 10_000_000
|
||||
assert settings.runtime.log_dir == "logs"
|
||||
assert settings.runtime.log_error_dir == "logs/errors"
|
||||
assert settings.runtime.log_file_name == "app.log"
|
||||
assert settings.runtime.log_error_file_name == "error.log"
|
||||
assert "password" in settings.runtime.log_sensitive_fields
|
||||
|
||||
|
||||
def test_runtime_settings_env_override(monkeypatch: MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_DIR", "var/logs")
|
||||
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ERROR_DIR", "var/logs/errors")
|
||||
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ROTATION", "size")
|
||||
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ROTATION_MAX_BYTES", "2048")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.runtime.log_dir == "var/logs"
|
||||
assert settings.runtime.log_error_dir == "var/logs/errors"
|
||||
assert settings.runtime.log_rotation == "size"
|
||||
assert settings.runtime.log_rotation_max_bytes == 2048
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.http.response import ProblemDetails, build_problem_details
|
||||
|
||||
|
||||
def test_problem_details_defaults() -> None:
|
||||
result = build_problem_details(status_code=401, detail="Unauthorized")
|
||||
|
||||
assert isinstance(result, ProblemDetails)
|
||||
assert result.type == "about:blank"
|
||||
assert result.title == "Unauthorized"
|
||||
assert result.status == 401
|
||||
assert result.detail == "Unauthorized"
|
||||
assert result.instance is None
|
||||
|
||||
|
||||
def test_problem_details_overrides() -> None:
|
||||
result = build_problem_details(
|
||||
status_code=409,
|
||||
detail="Conflict",
|
||||
type_value="https://example.com/problems/conflict",
|
||||
title="Conflict",
|
||||
instance="/api/mobile/auth/signup",
|
||||
)
|
||||
|
||||
assert result.type == "https://example.com/problems/conflict"
|
||||
assert result.title == "Conflict"
|
||||
assert result.status == 409
|
||||
assert result.detail == "Conflict"
|
||||
assert result.instance == "/api/mobile/auth/signup"
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_social_prefixed_supabase_env_populates_settings(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "public.example")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__USER", "user")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__PASSWORD", "pass")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.supabase.public_url == "https://public.example:8443"
|
||||
assert settings.supabase.api_external_url == "https://public.example:8443"
|
||||
assert settings.supabase.anon_key == "anon-key"
|
||||
assert settings.supabase.service_role_key == "service-key"
|
||||
assert settings.supabase.jwt_secret == "jwt-secret"
|
||||
|
||||
supabase_settings = settings.model_dump()["supabase"]
|
||||
assert supabase_settings["public_url"] == "https://public.example:8443"
|
||||
assert supabase_settings["api_external_url"] == "https://public.example:8443"
|
||||
assert supabase_settings["anon_key"] == "anon-key"
|
||||
assert supabase_settings["service_role_key"] == "service-key"
|
||||
assert supabase_settings["jwt_secret"] == "jwt-secret"
|
||||
assert settings.database_url == "postgresql+asyncpg://user:pass@db:5432/app"
|
||||
|
||||
|
||||
def test_social_prefixed_api_external_url_is_loaded(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "api.example")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.supabase.api_external_url == "https://api.example:8443"
|
||||
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
|
||||
|
||||
def test_signup_requires_valid_email() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SignupRequest(email="not-an-email", password="secret123")
|
||||
|
||||
|
||||
def test_login_requires_valid_email() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
LoginRequest(email="invalid", password="secret123")
|
||||
|
||||
|
||||
def test_refresh_requires_token() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RefreshRequest(refresh_token="")
|
||||
|
||||
|
||||
def test_auth_token_response_maps_user() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
|
||||
assert response.user.id == "user-1"
|
||||
assert response.user.email == "user@example.com"
|
||||
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService, AuthServiceGateway
|
||||
|
||||
|
||||
class FakeGateway(AuthServiceGateway):
|
||||
def __init__(self, response: AuthTokenResponse) -> None:
|
||||
self._response = response
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
return self._response
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
return self._response
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
return self._response
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signup_maps_response() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
service = AuthService(gateway=FakeGateway(token_response))
|
||||
|
||||
result = await service.signup(
|
||||
SignupRequest(email="user@example.com", password="secret123")
|
||||
)
|
||||
|
||||
assert result.access_token == "access"
|
||||
assert result.refresh_token == "refresh"
|
||||
assert result.user.id == "user-1"
|
||||
|
||||
|
||||
class LogoutAssertingGateway(AuthServiceGateway):
|
||||
def __init__(self, expected_refresh_token: str) -> None:
|
||||
self._expected_refresh_token = expected_refresh_token
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
assert refresh_token == self._expected_refresh_token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_forwards_refresh_token() -> None:
|
||||
service = AuthService(gateway=LogoutAssertingGateway("refresh-token"))
|
||||
|
||||
await service.logout("refresh-token")
|
||||
@@ -0,0 +1,285 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.dependencies import get_current_user
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for JWT validation in get_current_user dependency."""
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_secret(self) -> str:
|
||||
return "super-secret-jwt-token-with-at-least-32-characters"
|
||||
|
||||
@pytest.fixture
|
||||
def valid_user_id(self) -> str:
|
||||
return "00000000-0000-0000-0000-000000000123"
|
||||
|
||||
@pytest.fixture
|
||||
def valid_payload(self, valid_user_id: str) -> dict[str, Any]:
|
||||
"""Valid JWT payload with all required claims."""
|
||||
now = int(time.time())
|
||||
return {
|
||||
"sub": valid_user_id,
|
||||
"aud": "authenticated",
|
||||
"iss": "http://localhost:8001/auth/v1",
|
||||
"exp": now + 3600, # 1 hour from now
|
||||
"iat": now,
|
||||
}
|
||||
|
||||
def _create_token(self, payload: dict[str, Any], secret: str) -> str:
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
def test_valid_token_returns_current_user(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
valid_user_id: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Valid JWT with correct aud/iss/exp should return CurrentUser."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
authorization = f"Bearer {token}"
|
||||
|
||||
result = get_current_user(authorization=authorization)
|
||||
|
||||
assert isinstance(result, CurrentUser)
|
||||
assert result.id == UUID(valid_user_id)
|
||||
|
||||
def test_missing_authorization_raises_401(self) -> None:
|
||||
"""Missing Authorization header should raise 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Unauthorized"
|
||||
|
||||
def test_invalid_scheme_raises_401(self) -> None:
|
||||
"""Non-Bearer scheme should raise 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization="Basic dXNlcjpwYXNz")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_expired_token_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Expired JWT should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["exp"] = int(time.time()) - 3600 # 1 hour ago
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_invalid_audience_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with wrong audience should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["aud"] = "wrong-audience"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_invalid_issuer_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with wrong issuer should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["iss"] = "http://malicious-site.com/auth/v1"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_missing_subject_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT without 'sub' claim should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
del valid_payload["sub"]
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_wrong_secret_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT signed with wrong secret should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
token = self._create_token(
|
||||
valid_payload, "wrong-secret-key-that-is-long-enough"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_jwt_secret_not_configured_raises_503(
|
||||
self, valid_payload: dict[str, Any], monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Missing JWT secret in config should raise 503."""
|
||||
monkeypatch.setattr("v1.profile.dependencies.config.supabase.jwt_secret", None)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization="Bearer some-token")
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "JWT secret not configured"
|
||||
|
||||
def test_invalid_uuid_in_subject_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with non-UUID 'sub' claim should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["sub"] = "not-a-valid-uuid"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
@@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.profile import Profile
|
||||
from v1.profile.repository import ProfileRepository
|
||||
from v1.profile.schemas import ProfileUpdateRequest
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
|
||||
def _create_mock_profile(
|
||||
user_id: UUID = UUID("00000000-0000-0000-0000-000000000001"),
|
||||
username: str = "demo",
|
||||
display_name: str | None = "Demo User",
|
||||
avatar_url: str | None = None,
|
||||
bio: str | None = None,
|
||||
) -> Profile:
|
||||
"""Create a mock Profile ORM object."""
|
||||
profile = MagicMock(spec=Profile)
|
||||
profile.id = user_id
|
||||
profile.username = username
|
||||
profile.display_name = display_name
|
||||
profile.avatar_url = avatar_url
|
||||
profile.bio = bio
|
||||
return profile
|
||||
|
||||
|
||||
class FakeRepo:
|
||||
"""Fake repository for testing that conforms to ProfileRepository protocol."""
|
||||
|
||||
def __init__(self, profile: Profile | None) -> None:
|
||||
self._profile = profile
|
||||
|
||||
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
|
||||
if self._profile and user_id == self._profile.id:
|
||||
return self._profile
|
||||
return None
|
||||
|
||||
async def get_by_username(self, username: str) -> Profile | None:
|
||||
if self._profile and username == self._profile.username:
|
||||
return self._profile
|
||||
return None
|
||||
|
||||
async def update_by_user_id(
|
||||
self, user_id: UUID, update_data: dict[str, str | None]
|
||||
) -> Profile | None:
|
||||
if not self._profile or user_id != self._profile.id:
|
||||
return None
|
||||
# Apply updates to mock
|
||||
for key, value in update_data.items():
|
||||
if hasattr(self._profile, key):
|
||||
setattr(self._profile, key, value)
|
||||
return self._profile
|
||||
|
||||
|
||||
# Verify FakeRepo implements the protocol
|
||||
_repo_check: ProfileRepository = FakeRepo(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session() -> AsyncMock:
|
||||
"""Create a mock AsyncSession."""
|
||||
session = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_returns_profile(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id, username="demo")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
result = await service.get_me()
|
||||
|
||||
assert result.username == "demo"
|
||||
assert result.id == str(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_not_found_raises_404(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(None),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_me()
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me_updates_fields(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id, username="demo")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
result = await service.update_me(ProfileUpdateRequest(display_name="Updated"))
|
||||
|
||||
assert result.display_name == "Updated"
|
||||
mock_session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me_no_fields_raises_400(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id)
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
# Create a request with all None values by bypassing validation
|
||||
update = MagicMock(spec=ProfileUpdateRequest)
|
||||
update.display_name = None
|
||||
update.avatar_url = None
|
||||
update.bio = None
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.update_me(update)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_username_returns_profile(mock_session: AsyncMock) -> None:
|
||||
profile = _create_mock_profile(username="demo")
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
result = await service.get_by_username("demo")
|
||||
|
||||
assert result.username == "demo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_username_not_found_raises_404(mock_session: AsyncMock) -> None:
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(None),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_by_username("unknown")
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
|
||||
|
||||
def test_profile_response_maps_fields() -> None:
|
||||
response = ProfileResponse(
|
||||
id="user-1",
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
|
||||
assert response.id == "user-1"
|
||||
assert response.username == "demo"
|
||||
|
||||
|
||||
def test_profile_update_requires_one_field() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest()
|
||||
|
||||
|
||||
def test_profile_update_accepts_valid_https_url() -> None:
|
||||
request = ProfileUpdateRequest(avatar_url="https://example.com/avatar.png")
|
||||
assert request.avatar_url == "https://example.com/avatar.png"
|
||||
|
||||
|
||||
def test_profile_update_accepts_valid_http_url() -> None:
|
||||
request = ProfileUpdateRequest(
|
||||
avatar_url="http://localhost:8001/storage/avatar.png"
|
||||
)
|
||||
assert request.avatar_url == "http://localhost:8001/storage/avatar.png"
|
||||
|
||||
|
||||
def test_profile_update_rejects_invalid_url() -> None:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProfileUpdateRequest(avatar_url="not-a-valid-url")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert len(errors) == 1
|
||||
assert "avatar_url" in str(errors[0]["loc"])
|
||||
|
||||
|
||||
def test_profile_update_rejects_javascript_url() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest(avatar_url="javascript:alert('xss')")
|
||||
|
||||
|
||||
def test_profile_update_rejects_data_url() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest(avatar_url="data:text/html,<script>alert('xss')</script>")
|
||||
|
||||
|
||||
def test_profile_update_accepts_none_avatar_url_with_other_field() -> None:
|
||||
request = ProfileUpdateRequest(display_name="Test", avatar_url=None)
|
||||
assert request.avatar_url is None
|
||||
assert request.display_name == "Test"
|
||||
Reference in New Issue
Block a user