refactor: align backend layout and supabase infra

Consolidate backend modules/tests under the backend package while syncing Supabase compose/env config and related plans.
This commit is contained in:
qzl
2026-02-05 15:13:06 +08:00
parent 3cfcb11240
commit ad06fe7de4
111 changed files with 5540 additions and 1362 deletions
@@ -0,0 +1,27 @@
from __future__ import annotations
import pytest
from fastapi import HTTPException
from uuid import UUID
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
def test_require_current_user_raises_when_missing() -> None:
service = BaseService(current_user=None)
with pytest.raises(HTTPException) as exc_info:
service.require_current_user()
assert exc_info.value.status_code == 401
def test_require_current_user_returns_user() -> None:
user = CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
service = BaseService(current_user=user)
result = service.require_current_user()
assert result.id == user.id
@@ -0,0 +1,78 @@
from __future__ import annotations
from datetime import datetime, timezone
from uuid import UUID, uuid4
import pytest
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin
from core.db.base_repository import BaseRepository
class Widget(SoftDeleteMixin, Base):
__tablename__ = "widgets"
id: Mapped[UUID] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(50), nullable=False)
@pytest.fixture
async def db_engine():
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
async_session = async_sessionmaker(
bind=db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_get_by_id_filters_soft_deleted(db_session: AsyncSession) -> None:
repository = BaseRepository(db_session, Widget)
widget_id = uuid4()
widget = Widget(id=widget_id, name="widget")
db_session.add(widget)
await db_session.commit()
found = await repository.get_by_id(widget_id)
assert found is not None
deleted = await repository.soft_delete_by_id(widget_id)
assert deleted is not None
assert deleted.deleted_at is not None
missing = await repository.get_by_id(widget_id)
assert missing is None
@pytest.mark.asyncio
async def test_soft_delete_sets_timestamp(db_session: AsyncSession) -> None:
repository = BaseRepository(db_session, Widget)
widget_id = uuid4()
widget = Widget(id=widget_id, name="widget")
db_session.add(widget)
await db_session.commit()
deleted = await repository.soft_delete_by_id(widget_id)
assert deleted is not None
assert isinstance(deleted.deleted_at, datetime)
deleted_at = deleted.deleted_at
if deleted_at.tzinfo is None:
deleted_at = deleted_at.replace(tzinfo=timezone.utc)
assert deleted_at <= datetime.now(timezone.utc)
@@ -0,0 +1,114 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base
from models.profile import Profile
@pytest.fixture
async def db_engine():
"""Create in-memory SQLite engine for testing."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
"""Create a database session for testing."""
async_session = async_sessionmaker(
bind=db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_profile_model_create(db_session: AsyncSession) -> None:
"""Test creating a Profile model."""
profile_id = uuid4()
profile = Profile(
id=profile_id,
username="testuser",
display_name="Test User",
)
db_session.add(profile)
await db_session.commit()
await db_session.refresh(profile)
assert profile.id == profile_id
assert profile.username == "testuser"
assert profile.display_name == "Test User"
assert profile.created_at is not None
assert profile.updated_at is not None
assert profile.deleted_at is None
@pytest.mark.asyncio
async def test_profile_model_get_by_id(db_session: AsyncSession) -> None:
"""Test retrieving a Profile by ID."""
profile_id = uuid4()
profile = Profile(
id=profile_id,
username="testuser",
display_name="Test User",
)
db_session.add(profile)
await db_session.commit()
result = await db_session.get(Profile, profile_id)
assert result is not None
assert result.username == "testuser"
@pytest.mark.asyncio
async def test_profile_model_get_by_username(db_session: AsyncSession) -> None:
"""Test retrieving a Profile by username."""
profile = Profile(
id=uuid4(),
username="testuser",
display_name="Test User",
)
db_session.add(profile)
await db_session.commit()
result = await db_session.execute(
select(Profile).where(Profile.username == "testuser")
)
found = result.scalar_one()
assert found is not None
assert found.username == "testuser"
@pytest.mark.asyncio
async def test_profile_model_update(db_session: AsyncSession) -> None:
"""Test updating a Profile."""
profile = Profile(
id=uuid4(),
username="testuser",
display_name="Test User",
bio="Old bio",
)
db_session.add(profile)
await db_session.commit()
profile.display_name = "Updated User"
profile.bio = "New bio"
await db_session.commit()
await db_session.refresh(profile)
assert profile.display_name == "Updated User"
assert profile.bio == "New bio"
@@ -0,0 +1,78 @@
from __future__ import annotations
import pytest
from core.config.settings import QdrantSettings
from services.base.qdrant import QdrantService
class _FakeCollection:
def __init__(self, name: str) -> None:
self.name = name
class _FakeCollections:
def __init__(self) -> None:
self.collections = [_FakeCollection("default")]
class _FakeQdrantClient:
def get_collections(self) -> _FakeCollections:
return _FakeCollections()
@pytest.mark.asyncio
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
def _build_client(_: QdrantService) -> _FakeQdrantClient:
return _FakeQdrantClient()
monkeypatch.setattr(QdrantService, "_build_client", _build_client)
result = await service.initialize()
assert result is True
assert service.is_initialized is True
health = await service.health_check()
assert health["status"] == "healthy"
@pytest.mark.asyncio
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
def _build_client(_: QdrantService) -> _FakeQdrantClient:
raise RuntimeError("boom")
monkeypatch.setattr(QdrantService, "_build_client", _build_client)
result = await service.initialize()
assert result is False
assert service.is_initialized is False
@pytest.mark.asyncio
async def test_health_check_returns_unhealthy_when_not_initialized() -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
health = await service.health_check()
assert health["status"] == "unhealthy"
@pytest.mark.asyncio
async def test_close_is_idempotent() -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
assert await service.close() is True
assert service.is_initialized is False
def test_get_client_raises_before_init() -> None:
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
with pytest.raises(RuntimeError):
service.get_client()
@@ -0,0 +1,98 @@
from __future__ import annotations
import pytest
from core.config.settings import RedisSettings
from services.base.redis import RedisService
class _FakeRedisClient:
def __init__(self) -> None:
self.closed = False
async def ping(self) -> bool:
return True
async def info(self) -> dict[str, object]:
return {
"redis_version": "7.2",
"connected_clients": 1,
"used_memory_human": "1M",
"uptime_in_seconds": 10,
}
async def aclose(self) -> None:
self.closed = True
@pytest.mark.asyncio
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
def _build_client(_: RedisService) -> _FakeRedisClient:
return _FakeRedisClient()
monkeypatch.setattr(RedisService, "_build_client", _build_client)
result = await service.initialize()
assert result is True
assert service.is_initialized is True
health = await service.health_check()
assert health["status"] == "healthy"
@pytest.mark.asyncio
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
def _build_client(_: RedisService) -> _FakeRedisClient:
raise RuntimeError("boom")
monkeypatch.setattr(RedisService, "_build_client", _build_client)
result = await service.initialize()
assert result is False
assert service.is_initialized is False
@pytest.mark.asyncio
async def test_close_is_idempotent() -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
assert await service.close() is True
assert service.is_initialized is False
@pytest.mark.asyncio
async def test_health_check_uninitialized() -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
health = await service.health_check()
assert health["status"] == "unhealthy"
@pytest.mark.asyncio
async def test_close_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
client = _FakeRedisClient()
def _build_client(_: RedisService) -> _FakeRedisClient:
return client
monkeypatch.setattr(RedisService, "_build_client", _build_client)
assert await service.initialize() is True
assert await service.close() is True
assert client.closed is True
assert service.is_initialized is False
def test_get_client_raises_before_init() -> None:
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
with pytest.raises(RuntimeError):
service.get_client()
@@ -0,0 +1,49 @@
from __future__ import annotations
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
register_service,
register_service_instance,
)
class _DummyService(BaseServiceProvider):
def __init__(self, name: str = "dummy") -> None:
super().__init__(name)
async def initialize(self, **_: object) -> bool:
self._set_initialized(True)
return True
async def close(self) -> bool:
self._set_initialized(False)
return True
async def health_check(self) -> dict[str, object]:
return {"status": "healthy", "details": {}}
def test_register_service_and_create_service() -> None:
@register_service("dummy-service")
class _RegisteredService(_DummyService):
pass
created = ServiceRegistry.create_service("dummy-service")
assert created is not None
assert created.get_service_info()["name"] == "dummy"
def test_register_service_instance_returns_same_instance() -> None:
instance = _DummyService("singleton")
returned = register_service_instance("dummy-singleton", instance)
created = ServiceRegistry.create_service("dummy-singleton")
assert returned is instance
assert created is instance
def test_create_service_returns_none_for_missing() -> None:
assert ServiceRegistry.create_service("missing-service") is None
+45
View File
@@ -0,0 +1,45 @@
from __future__ import annotations
from celery import Celery
from pytest import MonkeyPatch
from core.logging import celery as celery_logging
from core.logging.context import clear_context, get_context
class DummyTask:
name: str = "tasks.sample"
def test_celery_prerun_binds_task_context() -> None:
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
context = get_context()
assert context["task_id"] == "task-123"
assert context["task_name"] == "tasks.sample"
clear_context()
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
called = {"value": False}
def fake_configure_logging(settings: object | None = None) -> None:
called["value"] = True
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_setup_logging()
assert called["value"] is True
def test_configure_celery_app_disables_hijack() -> None:
app = Celery("test")
celery_logging.configure_celery_app(app)
assert app.conf.worker_hijack_root_logger is False
+140
View File
@@ -0,0 +1,140 @@
from __future__ import annotations
import json
import logging
from collections.abc import Iterator
from pathlib import Path
from typing import cast
import pytest
import structlog
from core.config.settings import Settings
from core.logging.config import build_logging_config, configure_logging
def _get_handlers(config: dict[str, object]) -> dict[str, dict[str, object]]:
return cast(dict[str, dict[str, object]], config["handlers"])
def test_build_logging_config_time_rotation(tmp_path: Path) -> None:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_rotation": "time",
}
)
config = build_logging_config(runtime)
handlers = _get_handlers(config)
assert handlers["file"]["class"] == "logging.handlers.TimedRotatingFileHandler"
assert handlers["error"]["class"] == "logging.handlers.TimedRotatingFileHandler"
assert handlers["error"]["level"] == "ERROR"
def test_build_logging_config_size_rotation(tmp_path: Path) -> None:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_rotation": "size",
"log_rotation_max_bytes": 2048,
}
)
config = build_logging_config(runtime)
handlers = _get_handlers(config)
assert handlers["file"]["class"] == "logging.handlers.RotatingFileHandler"
assert handlers["error"]["class"] == "logging.handlers.RotatingFileHandler"
assert handlers["file"]["maxBytes"] == 2048
def test_build_logging_config_plain_formatter_when_disabled(tmp_path: Path) -> None:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_json": False,
}
)
config = build_logging_config(runtime)
handlers = _get_handlers(config)
assert handlers["file"]["formatter"] == "plain"
assert handlers["error"]["formatter"] == "plain"
def _read_last_log_entry(log_path: Path) -> dict[str, object]:
assert log_path.exists(), f"Expected log file at {log_path}"
entries = [
json.loads(line) for line in log_path.read_text().splitlines() if line.strip()
]
assert entries, "Expected at least one log entry in app.log"
return entries[-1]
def _flush_root_handlers() -> None:
root_logger = logging.getLogger()
for handler in root_logger.handlers:
if hasattr(handler, "flush"):
handler.flush()
@pytest.fixture
def configured_logging(tmp_path: Path) -> Iterator[Path]:
settings = Settings()
runtime = settings.runtime.model_copy(
update={
"log_dir": str(tmp_path),
"log_error_dir": str(tmp_path / "errors"),
"log_rotation": "size",
"log_rotation_max_bytes": 2048,
"log_json": True,
}
)
root_logger = logging.getLogger()
original_handlers = root_logger.handlers[:]
original_level = root_logger.level
configure_logging(settings.model_copy(update={"runtime": runtime}))
yield tmp_path
for handler in root_logger.handlers:
handler.close()
root_logger.handlers = original_handlers
root_logger.setLevel(original_level)
structlog.reset_defaults()
def test_stdlib_logging_redacts_sensitive_fields(configured_logging: Path) -> None:
logger = logging.getLogger("tests.stdlib")
logger.info("login", extra={"password": "secret", "token": "abc"})
_flush_root_handlers()
log_path = configured_logging / "app.log"
entry = _read_last_log_entry(log_path)
assert entry["password"] == "[REDACTED]"
assert entry["token"] == "[REDACTED]"
def test_structlog_redacts_sensitive_fields(configured_logging: Path) -> None:
logger = structlog.get_logger("tests.structlog")
logger.info("login", password="secret", token="abc")
_flush_root_handlers()
log_path = configured_logging / "app.log"
entry = _read_last_log_entry(log_path)
assert entry["password"] == "[REDACTED]"
assert entry["token"] == "[REDACTED]"
@@ -0,0 +1,30 @@
from __future__ import annotations
from core.logging.filters import build_sensitive_data_processor
def test_redact_sensitive_fields_masks_values() -> None:
processor = build_sensitive_data_processor(
["password", "token", "api_key", "cookie"]
)
event: dict[str, object] = {
"message": "login",
"password": "secret",
"access_token": "token-123",
"apiKey": "apikey-123",
"set-cookie": "cookie-1",
"nested": {"token": "abc", "safe": "ok"},
"list": [{"password": "x"}],
}
redacted = processor(None, "info", event)
assert redacted["password"] == "[REDACTED]"
assert redacted["access_token"] == "[REDACTED]"
assert redacted["apiKey"] == "[REDACTED]"
assert redacted["set-cookie"] == "[REDACTED]"
assert redacted["nested"]["token"] == "[REDACTED]"
assert redacted["nested"]["safe"] == "ok"
assert redacted["list"][0]["password"] == "[REDACTED]"
assert event["password"] == "secret"
@@ -0,0 +1,35 @@
from __future__ import annotations
from pytest import MonkeyPatch
from core.config.settings import Settings
def test_runtime_settings_defaults() -> None:
settings = Settings()
assert settings.runtime.log_json is True
assert settings.runtime.log_rotation == "time"
assert settings.runtime.log_rotation_when == "midnight"
assert settings.runtime.log_rotation_interval == 1
assert settings.runtime.log_rotation_backup_count == 14
assert settings.runtime.log_rotation_max_bytes == 10_000_000
assert settings.runtime.log_dir == "logs"
assert settings.runtime.log_error_dir == "logs/errors"
assert settings.runtime.log_file_name == "app.log"
assert settings.runtime.log_error_file_name == "error.log"
assert "password" in settings.runtime.log_sensitive_fields
def test_runtime_settings_env_override(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_DIR", "var/logs")
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ERROR_DIR", "var/logs/errors")
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ROTATION", "size")
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ROTATION_MAX_BYTES", "2048")
settings = Settings()
assert settings.runtime.log_dir == "var/logs"
assert settings.runtime.log_error_dir == "var/logs/errors"
assert settings.runtime.log_rotation == "size"
assert settings.runtime.log_rotation_max_bytes == 2048
@@ -0,0 +1,30 @@
from __future__ import annotations
from core.http.response import ProblemDetails, build_problem_details
def test_problem_details_defaults() -> None:
result = build_problem_details(status_code=401, detail="Unauthorized")
assert isinstance(result, ProblemDetails)
assert result.type == "about:blank"
assert result.title == "Unauthorized"
assert result.status == 401
assert result.detail == "Unauthorized"
assert result.instance is None
def test_problem_details_overrides() -> None:
result = build_problem_details(
status_code=409,
detail="Conflict",
type_value="https://example.com/problems/conflict",
title="Conflict",
instance="/api/mobile/auth/signup",
)
assert result.type == "https://example.com/problems/conflict"
assert result.title == "Conflict"
assert result.status == 409
assert result.detail == "Conflict"
assert result.instance == "/api/mobile/auth/signup"
@@ -0,0 +1,49 @@
from __future__ import annotations
from pytest import MonkeyPatch
from core.config.settings import Settings
def test_social_prefixed_supabase_env_populates_settings(
monkeypatch: MonkeyPatch,
) -> None:
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "public.example")
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
monkeypatch.setenv("SOCIAL_DATABASE__USER", "user")
monkeypatch.setenv("SOCIAL_DATABASE__PASSWORD", "pass")
settings = Settings()
assert settings.supabase.public_url == "https://public.example:8443"
assert settings.supabase.api_external_url == "https://public.example:8443"
assert settings.supabase.anon_key == "anon-key"
assert settings.supabase.service_role_key == "service-key"
assert settings.supabase.jwt_secret == "jwt-secret"
supabase_settings = settings.model_dump()["supabase"]
assert supabase_settings["public_url"] == "https://public.example:8443"
assert supabase_settings["api_external_url"] == "https://public.example:8443"
assert supabase_settings["anon_key"] == "anon-key"
assert supabase_settings["service_role_key"] == "service-key"
assert supabase_settings["jwt_secret"] == "jwt-secret"
assert settings.database_url == "postgresql+asyncpg://user:pass@db:5432/app"
def test_social_prefixed_api_external_url_is_loaded(
monkeypatch: MonkeyPatch,
) -> None:
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "api.example")
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
settings = Settings()
assert settings.supabase.api_external_url == "https://api.example:8443"
@@ -0,0 +1,41 @@
from __future__ import annotations
import pytest
from pydantic import ValidationError
from v1.auth.models import (
AuthTokenResponse,
AuthUser,
LoginRequest,
RefreshRequest,
SignupRequest,
)
def test_signup_requires_valid_email() -> None:
with pytest.raises(ValidationError):
SignupRequest(email="not-an-email", password="secret123")
def test_login_requires_valid_email() -> None:
with pytest.raises(ValidationError):
LoginRequest(email="invalid", password="secret123")
def test_refresh_requires_token() -> None:
with pytest.raises(ValidationError):
RefreshRequest(refresh_token="")
def test_auth_token_response_maps_user() -> None:
user = AuthUser(id="user-1", email="user@example.com")
response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
assert response.user.id == "user-1"
assert response.user.email == "user@example.com"
@@ -0,0 +1,74 @@
from __future__ import annotations
import pytest
from v1.auth.models import (
AuthTokenResponse,
AuthUser,
LoginRequest,
RefreshRequest,
SignupRequest,
)
from v1.auth.service import AuthService, AuthServiceGateway
class FakeGateway(AuthServiceGateway):
def __init__(self, response: AuthTokenResponse) -> None:
self._response = response
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
return self._response
async def login(self, request: LoginRequest) -> AuthTokenResponse:
return self._response
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
return self._response
async def logout(self, refresh_token: str | None) -> None:
return None
@pytest.mark.asyncio
async def test_signup_maps_response() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
service = AuthService(gateway=FakeGateway(token_response))
result = await service.signup(
SignupRequest(email="user@example.com", password="secret123")
)
assert result.access_token == "access"
assert result.refresh_token == "refresh"
assert result.user.id == "user-1"
class LogoutAssertingGateway(AuthServiceGateway):
def __init__(self, expected_refresh_token: str) -> None:
self._expected_refresh_token = expected_refresh_token
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
raise NotImplementedError
async def login(self, request: LoginRequest) -> AuthTokenResponse:
raise NotImplementedError
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
raise NotImplementedError
async def logout(self, refresh_token: str | None) -> None:
assert refresh_token == self._expected_refresh_token
@pytest.mark.asyncio
async def test_logout_forwards_refresh_token() -> None:
service = AuthService(gateway=LogoutAssertingGateway("refresh-token"))
await service.logout("refresh-token")
@@ -0,0 +1,285 @@
from __future__ import annotations
import time
from typing import Any
from uuid import UUID
import jwt
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.profile.dependencies import get_current_user
class TestGetCurrentUser:
"""Tests for JWT validation in get_current_user dependency."""
@pytest.fixture
def jwt_secret(self) -> str:
return "super-secret-jwt-token-with-at-least-32-characters"
@pytest.fixture
def valid_user_id(self) -> str:
return "00000000-0000-0000-0000-000000000123"
@pytest.fixture
def valid_payload(self, valid_user_id: str) -> dict[str, Any]:
"""Valid JWT payload with all required claims."""
now = int(time.time())
return {
"sub": valid_user_id,
"aud": "authenticated",
"iss": "http://localhost:8001/auth/v1",
"exp": now + 3600, # 1 hour from now
"iat": now,
}
def _create_token(self, payload: dict[str, Any], secret: str) -> str:
return jwt.encode(payload, secret, algorithm="HS256")
def test_valid_token_returns_current_user(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
valid_user_id: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Valid JWT with correct aud/iss/exp should return CurrentUser."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
token = self._create_token(valid_payload, jwt_secret)
authorization = f"Bearer {token}"
result = get_current_user(authorization=authorization)
assert isinstance(result, CurrentUser)
assert result.id == UUID(valid_user_id)
def test_missing_authorization_raises_401(self) -> None:
"""Missing Authorization header should raise 401."""
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=None)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Unauthorized"
def test_invalid_scheme_raises_401(self) -> None:
"""Non-Bearer scheme should raise 401."""
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization="Basic dXNlcjpwYXNz")
assert exc_info.value.status_code == 401
def test_expired_token_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Expired JWT should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["exp"] = int(time.time()) - 3600 # 1 hour ago
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_invalid_audience_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with wrong audience should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["aud"] = "wrong-audience"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_invalid_issuer_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with wrong issuer should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["iss"] = "http://malicious-site.com/auth/v1"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_missing_subject_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT without 'sub' claim should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
del valid_payload["sub"]
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_wrong_secret_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT signed with wrong secret should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
token = self._create_token(
valid_payload, "wrong-secret-key-that-is-long-enough"
)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_jwt_secret_not_configured_raises_503(
self, valid_payload: dict[str, Any], monkeypatch: pytest.MonkeyPatch
) -> None:
"""Missing JWT secret in config should raise 503."""
monkeypatch.setattr("v1.profile.dependencies.config.supabase.jwt_secret", None)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization="Bearer some-token")
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "JWT secret not configured"
def test_invalid_uuid_in_subject_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with non-UUID 'sub' claim should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["sub"] = "not-a-valid-uuid"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
@@ -0,0 +1,172 @@
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from models.profile import Profile
from v1.profile.repository import ProfileRepository
from v1.profile.schemas import ProfileUpdateRequest
from v1.profile.service import ProfileService
def _create_mock_profile(
user_id: UUID = UUID("00000000-0000-0000-0000-000000000001"),
username: str = "demo",
display_name: str | None = "Demo User",
avatar_url: str | None = None,
bio: str | None = None,
) -> Profile:
"""Create a mock Profile ORM object."""
profile = MagicMock(spec=Profile)
profile.id = user_id
profile.username = username
profile.display_name = display_name
profile.avatar_url = avatar_url
profile.bio = bio
return profile
class FakeRepo:
"""Fake repository for testing that conforms to ProfileRepository protocol."""
def __init__(self, profile: Profile | None) -> None:
self._profile = profile
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
if self._profile and user_id == self._profile.id:
return self._profile
return None
async def get_by_username(self, username: str) -> Profile | None:
if self._profile and username == self._profile.username:
return self._profile
return None
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
if not self._profile or user_id != self._profile.id:
return None
# Apply updates to mock
for key, value in update_data.items():
if hasattr(self._profile, key):
setattr(self._profile, key, value)
return self._profile
# Verify FakeRepo implements the protocol
_repo_check: ProfileRepository = FakeRepo(None)
@pytest.fixture
def mock_session() -> AsyncMock:
"""Create a mock AsyncSession."""
session = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
return session
@pytest.mark.asyncio
async def test_get_me_returns_profile(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id, username="demo")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
result = await service.get_me()
assert result.username == "demo"
assert result.id == str(user_id)
@pytest.mark.asyncio
async def test_get_me_not_found_raises_404(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(None),
session=mock_session,
current_user=user,
)
with pytest.raises(HTTPException) as exc_info:
await service.get_me()
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_update_me_updates_fields(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id, username="demo")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
result = await service.update_me(ProfileUpdateRequest(display_name="Updated"))
assert result.display_name == "Updated"
mock_session.commit.assert_awaited_once()
@pytest.mark.asyncio
async def test_update_me_no_fields_raises_400(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id)
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
# Create a request with all None values by bypassing validation
update = MagicMock(spec=ProfileUpdateRequest)
update.display_name = None
update.avatar_url = None
update.bio = None
with pytest.raises(HTTPException) as exc_info:
await service.update_me(update)
assert exc_info.value.status_code == 400
@pytest.mark.asyncio
async def test_get_by_username_returns_profile(mock_session: AsyncMock) -> None:
profile = _create_mock_profile(username="demo")
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
result = await service.get_by_username("demo")
assert result.username == "demo"
@pytest.mark.asyncio
async def test_get_by_username_not_found_raises_404(mock_session: AsyncMock) -> None:
service = ProfileService(
repository=FakeRepo(None),
session=mock_session,
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
with pytest.raises(HTTPException) as exc_info:
await service.get_by_username("unknown")
assert exc_info.value.status_code == 404
@@ -0,0 +1,61 @@
from __future__ import annotations
import pytest
from pydantic import ValidationError
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
def test_profile_response_maps_fields() -> None:
response = ProfileResponse(
id="user-1",
username="demo",
display_name="Demo User",
avatar_url=None,
bio=None,
)
assert response.id == "user-1"
assert response.username == "demo"
def test_profile_update_requires_one_field() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest()
def test_profile_update_accepts_valid_https_url() -> None:
request = ProfileUpdateRequest(avatar_url="https://example.com/avatar.png")
assert request.avatar_url == "https://example.com/avatar.png"
def test_profile_update_accepts_valid_http_url() -> None:
request = ProfileUpdateRequest(
avatar_url="http://localhost:8001/storage/avatar.png"
)
assert request.avatar_url == "http://localhost:8001/storage/avatar.png"
def test_profile_update_rejects_invalid_url() -> None:
with pytest.raises(ValidationError) as exc_info:
ProfileUpdateRequest(avatar_url="not-a-valid-url")
errors = exc_info.value.errors()
assert len(errors) == 1
assert "avatar_url" in str(errors[0]["loc"])
def test_profile_update_rejects_javascript_url() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest(avatar_url="javascript:alert('xss')")
def test_profile_update_rejects_data_url() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest(avatar_url="data:text/html,<script>alert('xss')</script>")
def test_profile_update_accepts_none_avatar_url_with_other_field() -> None:
request = ProfileUpdateRequest(display_name="Test", avatar_url=None)
assert request.avatar_url is None
assert request.display_name == "Test"