refactor: align backend layout and supabase infra
Consolidate backend modules/tests under the backend package while syncing Supabase compose/env config and related plans.
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def pytest_configure() -> None:
|
||||
root = Path(__file__).resolve().parents[2]
|
||||
src_path = root / "backend" / "src"
|
||||
if str(src_path) not in sys.path:
|
||||
sys.path.insert(0, str(src_path))
|
||||
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
class FakeE2EAuthService(AuthService):
|
||||
def __init__(self) -> None:
|
||||
self._user = AuthUser(id="user-1", email="user@example.com")
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
return AuthTokenResponse(
|
||||
access_token="access-1",
|
||||
refresh_token="refresh-1",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
return AuthTokenResponse(
|
||||
access_token="access-2",
|
||||
refresh_token="refresh-2",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
return AuthTokenResponse(
|
||||
access_token="access-3",
|
||||
refresh_token="refresh-3",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_auth_flow_e2e() -> None:
|
||||
app.dependency_overrides[get_auth_service] = lambda: FakeE2EAuthService()
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
signup = request_context.post(
|
||||
"/api/v1/auth/signup",
|
||||
data=json.dumps(
|
||||
{"email": "user@example.com", "password": "secret123"}
|
||||
),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert signup.status == 200
|
||||
assert signup.json()["access_token"] == "access-1"
|
||||
|
||||
login = request_context.post(
|
||||
"/api/v1/auth/login",
|
||||
data=json.dumps(
|
||||
{"email": "user@example.com", "password": "secret123"}
|
||||
),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert login.status == 200
|
||||
assert login.json()["access_token"] == "access-2"
|
||||
|
||||
refresh = request_context.post(
|
||||
"/api/v1/auth/refresh",
|
||||
data=json.dumps({"refresh_token": "refresh-2"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert refresh.status == 200
|
||||
assert refresh.json()["access_token"] == "access-3"
|
||||
|
||||
logout = request_context.post(
|
||||
"/api/v1/auth/logout",
|
||||
data=json.dumps({"refresh_token": "refresh-3"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert logout.status == 204
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from v1.infra.dependencies import get_qdrant_service, get_redis_service
|
||||
|
||||
|
||||
class _FakeService:
|
||||
def __init__(self) -> None:
|
||||
self._initialized = True
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, object]:
|
||||
return {"status": "healthy", "details": {}}
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_infra_health_e2e() -> None:
|
||||
app.dependency_overrides[get_redis_service] = lambda: _FakeService()
|
||||
app.dependency_overrides[get_qdrant_service] = lambda: _FakeService()
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
response = request_context.get("/api/v1/infra/health")
|
||||
assert response.status == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "healthy"
|
||||
assert "redis" in body["services"]
|
||||
assert "qdrant" in body["services"]
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
app.dependency_overrides = {}
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.middleware import (
|
||||
RequestContextMiddleware,
|
||||
register_exception_handlers,
|
||||
)
|
||||
|
||||
|
||||
def _read_json_lines(path: Path) -> list[dict[str, object]]:
|
||||
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(app: FastAPI, host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_e2e_error_logging(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "size",
|
||||
"log_rotation_max_bytes": 2048,
|
||||
}
|
||||
)
|
||||
configure_logging(settings.model_copy(update={"runtime": runtime}))
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestContextMiddleware) # type: ignore[arg-type]
|
||||
register_exception_handlers(app)
|
||||
|
||||
@app.get("/boom")
|
||||
async def boom() -> dict[str, str]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(app, host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
response = request_context.get(
|
||||
"/boom",
|
||||
headers={"X-Request-ID": "e2e-5000"},
|
||||
)
|
||||
assert response.status == 500
|
||||
request_context.dispose()
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
|
||||
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
|
||||
entry = next(
|
||||
item for item in error_entries if item.get("message") == "Unhandled exception"
|
||||
)
|
||||
|
||||
assert entry["request_id"] == "e2e-5000"
|
||||
exception = str(entry["exception"])
|
||||
assert "Traceback" in exception
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_mobile_health_e2e() -> None:
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
response = request_context.get("/api/v1/health")
|
||||
assert response.status == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "ok"
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.dependencies import get_current_user, get_profile_service
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
|
||||
|
||||
class FakeProfileService:
|
||||
"""Fake service for E2E testing."""
|
||||
|
||||
def __init__(self, profile: ProfileResponse) -> None:
|
||||
self._profile = profile
|
||||
|
||||
async def get_me(self) -> ProfileResponse:
|
||||
return self._profile
|
||||
|
||||
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
|
||||
return ProfileResponse(
|
||||
id=self._profile.id,
|
||||
username=self._profile.username,
|
||||
display_name=(
|
||||
update.display_name
|
||||
if update.display_name is not None
|
||||
else self._profile.display_name
|
||||
),
|
||||
avatar_url=(
|
||||
update.avatar_url
|
||||
if update.avatar_url is not None
|
||||
else self._profile.avatar_url
|
||||
),
|
||||
bio=update.bio if update.bio is not None else self._profile.bio,
|
||||
)
|
||||
|
||||
async def get_by_username(self, username: str) -> ProfileResponse:
|
||||
return self._profile
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_profile_flow_e2e() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = ProfileResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = lambda: FakeProfileService(profile) # type: ignore[return-value]
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(id=user_id)
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
me = request_context.get("/api/v1/profile/me")
|
||||
assert me.status == 200
|
||||
assert me.json()["username"] == "demo"
|
||||
|
||||
updated = request_context.patch(
|
||||
"/api/v1/profile/me",
|
||||
data=json.dumps({"display_name": "Updated"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert updated.status == 200
|
||||
assert updated.json()["display_name"] == "Updated"
|
||||
|
||||
public = request_context.get("/api/v1/profile/demo")
|
||||
assert public.status == 200
|
||||
assert public.json()["username"] == "demo"
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import Settings
|
||||
from services.base.qdrant import QdrantService
|
||||
from services.base.redis import RedisService
|
||||
|
||||
|
||||
def _can_connect(host: str, port: int) -> bool:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(0.2)
|
||||
return sock.connect_ex((host, port)) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_service_health_check_integration() -> None:
|
||||
host = "127.0.0.1"
|
||||
port = 6379
|
||||
if not _can_connect(host, port):
|
||||
pytest.skip("Redis is not running on localhost:6379")
|
||||
|
||||
config = Settings()
|
||||
settings = config.redis.model_copy(update={"host": host, "port": port})
|
||||
service = RedisService(settings=settings)
|
||||
|
||||
assert await service.initialize() is True
|
||||
health = await service.health_check()
|
||||
assert health["status"] == "healthy"
|
||||
assert await service.close() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qdrant_service_health_check_integration() -> None:
|
||||
host = "127.0.0.1"
|
||||
port = 6333
|
||||
if not _can_connect(host, port):
|
||||
pytest.skip("Qdrant is not running on localhost:6333")
|
||||
|
||||
config = Settings()
|
||||
settings = config.qdrant.model_copy(update={"host": host, "port": port})
|
||||
service = QdrantService(settings=settings)
|
||||
|
||||
assert await service.initialize() is True
|
||||
health = await service.health_check()
|
||||
assert health["status"] == "healthy"
|
||||
assert await service.close() is True
|
||||
@@ -0,0 +1,178 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
class FakeAuthService(AuthService):
|
||||
def __init__(self, token_response: AuthTokenResponse) -> None:
|
||||
self._token_response = token_response
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
return self._token_response
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
|
||||
def _get_service() -> AuthService:
|
||||
return service
|
||||
|
||||
return _get_service
|
||||
|
||||
|
||||
def test_signup_returns_token_response() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/signup",
|
||||
json={"email": "user@example.com", "password": "secret123"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["access_token"] == "access"
|
||||
assert body["refresh_token"] == "refresh"
|
||||
assert body["user"]["email"] == "user@example.com"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_login_invalid_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "user@example.com", "password": "wrongpw"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
assert body["detail"] == "Invalid credentials"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_refresh_invalid_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
assert body["detail"] == "Invalid refresh token"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_logout_returns_no_content() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": "refresh"},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
assert response.content == b""
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_validation_error_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/auth/signup", json={})
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unprocessable Content"
|
||||
assert body["status"] == 422
|
||||
assert body["detail"] == "Invalid request"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -0,0 +1,126 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.logger import get_logger
|
||||
from core.logging.middleware import (
|
||||
RequestContextMiddleware,
|
||||
register_exception_handlers,
|
||||
)
|
||||
|
||||
|
||||
def _read_json_lines(path: Path) -> list[dict[str, object]]:
|
||||
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
|
||||
|
||||
|
||||
def _configure_test_logging(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "size",
|
||||
"log_rotation_max_bytes": 2048,
|
||||
}
|
||||
)
|
||||
test_settings = settings.model_copy(update={"runtime": runtime})
|
||||
|
||||
configure_logging(test_settings)
|
||||
|
||||
|
||||
def test_middleware_binds_request_context(tmp_path: Path) -> None:
|
||||
_configure_test_logging(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestContextMiddleware) # type: ignore[arg-type]
|
||||
|
||||
@app.get("/ok")
|
||||
async def ok() -> dict[str, str]:
|
||||
logger = get_logger("tests.ok")
|
||||
logger.info("request accepted", context_key="context_value")
|
||||
return {"status": "ok"}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/ok", headers={"X-Request-ID": "req-1234"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["X-Request-ID"] == "req-1234"
|
||||
|
||||
log_entries = _read_json_lines(Path(tmp_path) / "app.log")
|
||||
entry = next(
|
||||
item for item in log_entries if item.get("message") == "request accepted"
|
||||
)
|
||||
assert entry["message"] == "request accepted"
|
||||
assert entry["request_id"] == "req-1234"
|
||||
assert entry["method"] == "GET"
|
||||
assert entry["path"] == "/ok"
|
||||
assert entry["context_key"] == "context_value"
|
||||
|
||||
logging.shutdown()
|
||||
|
||||
|
||||
def test_exception_handler_logs_stack_and_sends_500(tmp_path: Path) -> None:
|
||||
_configure_test_logging(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestContextMiddleware)
|
||||
register_exception_handlers(app)
|
||||
|
||||
@app.get("/boom")
|
||||
async def boom() -> dict[str, str]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/boom", headers={"X-Request-ID": "req-5000"})
|
||||
|
||||
assert response.status_code == 500
|
||||
assert response.json()["detail"] == "Internal Server Error"
|
||||
|
||||
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
|
||||
assert error_entries
|
||||
entry = error_entries[-1]
|
||||
assert entry["level"] == "error"
|
||||
assert entry["request_id"] == "req-5000"
|
||||
exception = str(entry["exception"])
|
||||
assert "Traceback" in exception
|
||||
assert "test_fastapi_logging_integration" in exception
|
||||
|
||||
logging.shutdown()
|
||||
|
||||
|
||||
def test_invalid_request_id_is_replaced_and_used_in_error_context(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
_configure_test_logging(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestContextMiddleware)
|
||||
register_exception_handlers(app)
|
||||
|
||||
@app.get("/boom")
|
||||
async def boom() -> dict[str, str]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/boom", headers={"X-Request-ID": "bad"})
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
response_request_id = response.headers["X-Request-ID"]
|
||||
assert response_request_id != "bad"
|
||||
|
||||
error_entries = _read_json_lines(Path(tmp_path) / "errors" / "error.log")
|
||||
assert error_entries
|
||||
entry = error_entries[-1]
|
||||
assert entry["request_id"] == response_request_id
|
||||
exception = str(entry["exception"])
|
||||
assert "Traceback" in exception
|
||||
|
||||
logging.shutdown()
|
||||
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
|
||||
|
||||
def test_app_health_returns_envelope() -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "ok"
|
||||
|
||||
|
||||
def test_mobile_router_health_returns_envelope() -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "ok"
|
||||
|
||||
|
||||
def test_not_found_returns_error_envelope() -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/missing-route")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["type"] == "about:blank"
|
||||
assert body["title"] == "Not Found"
|
||||
assert body["status"] == 404
|
||||
assert body["detail"] == "Not Found"
|
||||
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.dependencies import get_current_user, get_profile_service
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
|
||||
class FakeProfileService:
|
||||
"""Fake service for integration testing."""
|
||||
|
||||
def __init__(self, profile: ProfileResponse) -> None:
|
||||
self._profile = profile
|
||||
|
||||
async def get_me(self) -> ProfileResponse:
|
||||
if self._profile.id is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return self._profile
|
||||
|
||||
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
|
||||
if self._profile.id is None:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return ProfileResponse(
|
||||
id=self._profile.id,
|
||||
username=self._profile.username,
|
||||
display_name=(
|
||||
update.display_name
|
||||
if update.display_name is not None
|
||||
else self._profile.display_name
|
||||
),
|
||||
avatar_url=(
|
||||
update.avatar_url
|
||||
if update.avatar_url is not None
|
||||
else self._profile.avatar_url
|
||||
),
|
||||
bio=update.bio if update.bio is not None else self._profile.bio,
|
||||
)
|
||||
|
||||
async def get_by_username(self, username: str) -> ProfileResponse:
|
||||
if username != self._profile.username:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return self._profile
|
||||
|
||||
|
||||
def _override_profile_service(
|
||||
service: FakeProfileService,
|
||||
) -> Callable[[], ProfileService]:
|
||||
def _get_service() -> ProfileService:
|
||||
return service # type: ignore[return-value]
|
||||
|
||||
return _get_service
|
||||
|
||||
|
||||
def _override_current_user(user_id: UUID) -> Callable[[], CurrentUser]:
|
||||
def _get_user() -> CurrentUser:
|
||||
return CurrentUser(id=user_id)
|
||||
|
||||
return _get_user
|
||||
|
||||
|
||||
def test_get_me_returns_profile() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = ProfileResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/profile/me")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["username"] == "demo"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_patch_me_updates_profile() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = ProfileResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.patch(
|
||||
"/api/v1/profile/me",
|
||||
json={"display_name": "Updated"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["display_name"] == "Updated"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_get_profile_by_username() -> None:
|
||||
profile = ProfileResponse(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/profile/demo")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["username"] == "demo"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_profile_not_found_returns_problem_details() -> None:
|
||||
profile = ProfileResponse(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/profile/unknown")
|
||||
assert response.status_code == 404
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Not Found"
|
||||
assert body["status"] == 404
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_patch_me_validation_error_returns_problem_details() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = ProfileResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_profile_service] = _override_profile_service(
|
||||
FakeProfileService(profile)
|
||||
)
|
||||
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.patch("/api/v1/profile/me", json={})
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unprocessable Content"
|
||||
assert body["status"] == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
|
||||
|
||||
def test_require_current_user_raises_when_missing() -> None:
|
||||
service = BaseService(current_user=None)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
service.require_current_user()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
def test_require_current_user_returns_user() -> None:
|
||||
user = CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
|
||||
service = BaseService(current_user=user)
|
||||
|
||||
result = service.require_current_user()
|
||||
|
||||
assert result.id == user.id
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin
|
||||
from core.db.base_repository import BaseRepository
|
||||
|
||||
|
||||
class Widget(SoftDeleteMixin, Base):
|
||||
__tablename__ = "widgets"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_filters_soft_deleted(db_session: AsyncSession) -> None:
|
||||
repository = BaseRepository(db_session, Widget)
|
||||
widget_id = uuid4()
|
||||
|
||||
widget = Widget(id=widget_id, name="widget")
|
||||
db_session.add(widget)
|
||||
await db_session.commit()
|
||||
|
||||
found = await repository.get_by_id(widget_id)
|
||||
assert found is not None
|
||||
|
||||
deleted = await repository.soft_delete_by_id(widget_id)
|
||||
assert deleted is not None
|
||||
assert deleted.deleted_at is not None
|
||||
|
||||
missing = await repository.get_by_id(widget_id)
|
||||
assert missing is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_sets_timestamp(db_session: AsyncSession) -> None:
|
||||
repository = BaseRepository(db_session, Widget)
|
||||
widget_id = uuid4()
|
||||
|
||||
widget = Widget(id=widget_id, name="widget")
|
||||
db_session.add(widget)
|
||||
await db_session.commit()
|
||||
|
||||
deleted = await repository.soft_delete_by_id(widget_id)
|
||||
assert deleted is not None
|
||||
assert isinstance(deleted.deleted_at, datetime)
|
||||
deleted_at = deleted.deleted_at
|
||||
if deleted_at.tzinfo is None:
|
||||
deleted_at = deleted_at.replace(tzinfo=timezone.utc)
|
||||
assert deleted_at <= datetime.now(timezone.utc)
|
||||
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from models.profile import Profile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
"""Create in-memory SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
"""Create a database session for testing."""
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_create(db_session: AsyncSession) -> None:
|
||||
"""Test creating a Profile model."""
|
||||
profile_id = uuid4()
|
||||
profile = Profile(
|
||||
id=profile_id,
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(profile)
|
||||
|
||||
assert profile.id == profile_id
|
||||
assert profile.username == "testuser"
|
||||
assert profile.display_name == "Test User"
|
||||
assert profile.created_at is not None
|
||||
assert profile.updated_at is not None
|
||||
assert profile.deleted_at is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_get_by_id(db_session: AsyncSession) -> None:
|
||||
"""Test retrieving a Profile by ID."""
|
||||
profile_id = uuid4()
|
||||
profile = Profile(
|
||||
id=profile_id,
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.get(Profile, profile_id)
|
||||
assert result is not None
|
||||
assert result.username == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_get_by_username(db_session: AsyncSession) -> None:
|
||||
"""Test retrieving a Profile by username."""
|
||||
profile = Profile(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Profile).where(Profile.username == "testuser")
|
||||
)
|
||||
found = result.scalar_one()
|
||||
assert found is not None
|
||||
assert found.username == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_update(db_session: AsyncSession) -> None:
|
||||
"""Test updating a Profile."""
|
||||
profile = Profile(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
bio="Old bio",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
profile.display_name = "Updated User"
|
||||
profile.bio = "New bio"
|
||||
await db_session.commit()
|
||||
await db_session.refresh(profile)
|
||||
|
||||
assert profile.display_name == "Updated User"
|
||||
assert profile.bio == "New bio"
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import QdrantSettings
|
||||
from services.base.qdrant import QdrantService
|
||||
|
||||
|
||||
class _FakeCollection:
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
|
||||
class _FakeCollections:
|
||||
def __init__(self) -> None:
|
||||
self.collections = [_FakeCollection("default")]
|
||||
|
||||
|
||||
class _FakeQdrantClient:
|
||||
def get_collections(self) -> _FakeCollections:
|
||||
return _FakeCollections()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
def _build_client(_: QdrantService) -> _FakeQdrantClient:
|
||||
return _FakeQdrantClient()
|
||||
|
||||
monkeypatch.setattr(QdrantService, "_build_client", _build_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is True
|
||||
assert service.is_initialized is True
|
||||
|
||||
health = await service.health_check()
|
||||
assert health["status"] == "healthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
def _build_client(_: QdrantService) -> _FakeQdrantClient:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(QdrantService, "_build_client", _build_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is False
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_unhealthy_when_not_initialized() -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
health = await service.health_check()
|
||||
|
||||
assert health["status"] == "unhealthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_is_idempotent() -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
assert await service.close() is True
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
def test_get_client_raises_before_init() -> None:
|
||||
service = QdrantService(settings=QdrantSettings(host="localhost", port=6333))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import RedisSettings
|
||||
from services.base.redis import RedisService
|
||||
|
||||
|
||||
class _FakeRedisClient:
|
||||
def __init__(self) -> None:
|
||||
self.closed = False
|
||||
|
||||
async def ping(self) -> bool:
|
||||
return True
|
||||
|
||||
async def info(self) -> dict[str, object]:
|
||||
return {
|
||||
"redis_version": "7.2",
|
||||
"connected_clients": 1,
|
||||
"used_memory_human": "1M",
|
||||
"uptime_in_seconds": 10,
|
||||
}
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
def _build_client(_: RedisService) -> _FakeRedisClient:
|
||||
return _FakeRedisClient()
|
||||
|
||||
monkeypatch.setattr(RedisService, "_build_client", _build_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is True
|
||||
assert service.is_initialized is True
|
||||
|
||||
health = await service.health_check()
|
||||
assert health["status"] == "healthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
def _build_client(_: RedisService) -> _FakeRedisClient:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(RedisService, "_build_client", _build_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is False
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_is_idempotent() -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
assert await service.close() is True
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_uninitialized() -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
health = await service.health_check()
|
||||
|
||||
assert health["status"] == "unhealthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
client = _FakeRedisClient()
|
||||
|
||||
def _build_client(_: RedisService) -> _FakeRedisClient:
|
||||
return client
|
||||
|
||||
monkeypatch.setattr(RedisService, "_build_client", _build_client)
|
||||
|
||||
assert await service.initialize() is True
|
||||
assert await service.close() is True
|
||||
assert client.closed is True
|
||||
assert service.is_initialized is False
|
||||
|
||||
|
||||
def test_get_client_raises_before_init() -> None:
|
||||
service = RedisService(settings=RedisSettings(host="localhost", port=6379))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
)
|
||||
|
||||
|
||||
class _DummyService(BaseServiceProvider):
|
||||
def __init__(self, name: str = "dummy") -> None:
|
||||
super().__init__(name)
|
||||
|
||||
async def initialize(self, **_: object) -> bool:
|
||||
self._set_initialized(True)
|
||||
return True
|
||||
|
||||
async def close(self) -> bool:
|
||||
self._set_initialized(False)
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, object]:
|
||||
return {"status": "healthy", "details": {}}
|
||||
|
||||
|
||||
def test_register_service_and_create_service() -> None:
|
||||
@register_service("dummy-service")
|
||||
class _RegisteredService(_DummyService):
|
||||
pass
|
||||
|
||||
created = ServiceRegistry.create_service("dummy-service")
|
||||
|
||||
assert created is not None
|
||||
assert created.get_service_info()["name"] == "dummy"
|
||||
|
||||
|
||||
def test_register_service_instance_returns_same_instance() -> None:
|
||||
instance = _DummyService("singleton")
|
||||
|
||||
returned = register_service_instance("dummy-singleton", instance)
|
||||
created = ServiceRegistry.create_service("dummy-singleton")
|
||||
|
||||
assert returned is instance
|
||||
assert created is instance
|
||||
|
||||
|
||||
def test_create_service_returns_none_for_missing() -> None:
|
||||
assert ServiceRegistry.create_service("missing-service") is None
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.logging import celery as celery_logging
|
||||
from core.logging.context import clear_context, get_context
|
||||
|
||||
|
||||
class DummyTask:
|
||||
name: str = "tasks.sample"
|
||||
|
||||
|
||||
def test_celery_prerun_binds_task_context() -> None:
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
|
||||
context = get_context()
|
||||
|
||||
assert context["task_id"] == "task-123"
|
||||
assert context["task_name"] == "tasks.sample"
|
||||
|
||||
clear_context()
|
||||
|
||||
|
||||
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
|
||||
called = {"value": False}
|
||||
|
||||
def fake_configure_logging(settings: object | None = None) -> None:
|
||||
called["value"] = True
|
||||
|
||||
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_setup_logging()
|
||||
|
||||
assert called["value"] is True
|
||||
|
||||
|
||||
def test_configure_celery_app_disables_hijack() -> None:
|
||||
app = Celery("test")
|
||||
|
||||
celery_logging.configure_celery_app(app)
|
||||
|
||||
assert app.conf.worker_hijack_root_logger is False
|
||||
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
import structlog
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.config import build_logging_config, configure_logging
|
||||
|
||||
|
||||
def _get_handlers(config: dict[str, object]) -> dict[str, dict[str, object]]:
|
||||
return cast(dict[str, dict[str, object]], config["handlers"])
|
||||
|
||||
|
||||
def test_build_logging_config_time_rotation(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "time",
|
||||
}
|
||||
)
|
||||
|
||||
config = build_logging_config(runtime)
|
||||
handlers = _get_handlers(config)
|
||||
|
||||
assert handlers["file"]["class"] == "logging.handlers.TimedRotatingFileHandler"
|
||||
assert handlers["error"]["class"] == "logging.handlers.TimedRotatingFileHandler"
|
||||
assert handlers["error"]["level"] == "ERROR"
|
||||
|
||||
|
||||
def test_build_logging_config_size_rotation(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "size",
|
||||
"log_rotation_max_bytes": 2048,
|
||||
}
|
||||
)
|
||||
|
||||
config = build_logging_config(runtime)
|
||||
handlers = _get_handlers(config)
|
||||
|
||||
assert handlers["file"]["class"] == "logging.handlers.RotatingFileHandler"
|
||||
assert handlers["error"]["class"] == "logging.handlers.RotatingFileHandler"
|
||||
assert handlers["file"]["maxBytes"] == 2048
|
||||
|
||||
|
||||
def test_build_logging_config_plain_formatter_when_disabled(tmp_path: Path) -> None:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_json": False,
|
||||
}
|
||||
)
|
||||
|
||||
config = build_logging_config(runtime)
|
||||
handlers = _get_handlers(config)
|
||||
|
||||
assert handlers["file"]["formatter"] == "plain"
|
||||
assert handlers["error"]["formatter"] == "plain"
|
||||
|
||||
|
||||
def _read_last_log_entry(log_path: Path) -> dict[str, object]:
|
||||
assert log_path.exists(), f"Expected log file at {log_path}"
|
||||
entries = [
|
||||
json.loads(line) for line in log_path.read_text().splitlines() if line.strip()
|
||||
]
|
||||
assert entries, "Expected at least one log entry in app.log"
|
||||
return entries[-1]
|
||||
|
||||
|
||||
def _flush_root_handlers() -> None:
|
||||
root_logger = logging.getLogger()
|
||||
for handler in root_logger.handlers:
|
||||
if hasattr(handler, "flush"):
|
||||
handler.flush()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configured_logging(tmp_path: Path) -> Iterator[Path]:
|
||||
settings = Settings()
|
||||
runtime = settings.runtime.model_copy(
|
||||
update={
|
||||
"log_dir": str(tmp_path),
|
||||
"log_error_dir": str(tmp_path / "errors"),
|
||||
"log_rotation": "size",
|
||||
"log_rotation_max_bytes": 2048,
|
||||
"log_json": True,
|
||||
}
|
||||
)
|
||||
root_logger = logging.getLogger()
|
||||
original_handlers = root_logger.handlers[:]
|
||||
original_level = root_logger.level
|
||||
|
||||
configure_logging(settings.model_copy(update={"runtime": runtime}))
|
||||
|
||||
yield tmp_path
|
||||
|
||||
for handler in root_logger.handlers:
|
||||
handler.close()
|
||||
root_logger.handlers = original_handlers
|
||||
root_logger.setLevel(original_level)
|
||||
structlog.reset_defaults()
|
||||
|
||||
|
||||
def test_stdlib_logging_redacts_sensitive_fields(configured_logging: Path) -> None:
|
||||
logger = logging.getLogger("tests.stdlib")
|
||||
logger.info("login", extra={"password": "secret", "token": "abc"})
|
||||
|
||||
_flush_root_handlers()
|
||||
|
||||
log_path = configured_logging / "app.log"
|
||||
entry = _read_last_log_entry(log_path)
|
||||
|
||||
assert entry["password"] == "[REDACTED]"
|
||||
assert entry["token"] == "[REDACTED]"
|
||||
|
||||
|
||||
def test_structlog_redacts_sensitive_fields(configured_logging: Path) -> None:
|
||||
logger = structlog.get_logger("tests.structlog")
|
||||
logger.info("login", password="secret", token="abc")
|
||||
|
||||
_flush_root_handlers()
|
||||
|
||||
log_path = configured_logging / "app.log"
|
||||
entry = _read_last_log_entry(log_path)
|
||||
|
||||
assert entry["password"] == "[REDACTED]"
|
||||
assert entry["token"] == "[REDACTED]"
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.logging.filters import build_sensitive_data_processor
|
||||
|
||||
|
||||
def test_redact_sensitive_fields_masks_values() -> None:
|
||||
processor = build_sensitive_data_processor(
|
||||
["password", "token", "api_key", "cookie"]
|
||||
)
|
||||
|
||||
event: dict[str, object] = {
|
||||
"message": "login",
|
||||
"password": "secret",
|
||||
"access_token": "token-123",
|
||||
"apiKey": "apikey-123",
|
||||
"set-cookie": "cookie-1",
|
||||
"nested": {"token": "abc", "safe": "ok"},
|
||||
"list": [{"password": "x"}],
|
||||
}
|
||||
|
||||
redacted = processor(None, "info", event)
|
||||
|
||||
assert redacted["password"] == "[REDACTED]"
|
||||
assert redacted["access_token"] == "[REDACTED]"
|
||||
assert redacted["apiKey"] == "[REDACTED]"
|
||||
assert redacted["set-cookie"] == "[REDACTED]"
|
||||
assert redacted["nested"]["token"] == "[REDACTED]"
|
||||
assert redacted["nested"]["safe"] == "ok"
|
||||
assert redacted["list"][0]["password"] == "[REDACTED]"
|
||||
assert event["password"] == "secret"
|
||||
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_runtime_settings_defaults() -> None:
|
||||
settings = Settings()
|
||||
|
||||
assert settings.runtime.log_json is True
|
||||
assert settings.runtime.log_rotation == "time"
|
||||
assert settings.runtime.log_rotation_when == "midnight"
|
||||
assert settings.runtime.log_rotation_interval == 1
|
||||
assert settings.runtime.log_rotation_backup_count == 14
|
||||
assert settings.runtime.log_rotation_max_bytes == 10_000_000
|
||||
assert settings.runtime.log_dir == "logs"
|
||||
assert settings.runtime.log_error_dir == "logs/errors"
|
||||
assert settings.runtime.log_file_name == "app.log"
|
||||
assert settings.runtime.log_error_file_name == "error.log"
|
||||
assert "password" in settings.runtime.log_sensitive_fields
|
||||
|
||||
|
||||
def test_runtime_settings_env_override(monkeypatch: MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_DIR", "var/logs")
|
||||
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ERROR_DIR", "var/logs/errors")
|
||||
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ROTATION", "size")
|
||||
monkeypatch.setenv("SOCIAL_RUNTIME__LOG_ROTATION_MAX_BYTES", "2048")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.runtime.log_dir == "var/logs"
|
||||
assert settings.runtime.log_error_dir == "var/logs/errors"
|
||||
assert settings.runtime.log_rotation == "size"
|
||||
assert settings.runtime.log_rotation_max_bytes == 2048
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.http.response import ProblemDetails, build_problem_details
|
||||
|
||||
|
||||
def test_problem_details_defaults() -> None:
|
||||
result = build_problem_details(status_code=401, detail="Unauthorized")
|
||||
|
||||
assert isinstance(result, ProblemDetails)
|
||||
assert result.type == "about:blank"
|
||||
assert result.title == "Unauthorized"
|
||||
assert result.status == 401
|
||||
assert result.detail == "Unauthorized"
|
||||
assert result.instance is None
|
||||
|
||||
|
||||
def test_problem_details_overrides() -> None:
|
||||
result = build_problem_details(
|
||||
status_code=409,
|
||||
detail="Conflict",
|
||||
type_value="https://example.com/problems/conflict",
|
||||
title="Conflict",
|
||||
instance="/api/mobile/auth/signup",
|
||||
)
|
||||
|
||||
assert result.type == "https://example.com/problems/conflict"
|
||||
assert result.title == "Conflict"
|
||||
assert result.status == 409
|
||||
assert result.detail == "Conflict"
|
||||
assert result.instance == "/api/mobile/auth/signup"
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_social_prefixed_supabase_env_populates_settings(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "public.example")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__USER", "user")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__PASSWORD", "pass")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.supabase.public_url == "https://public.example:8443"
|
||||
assert settings.supabase.api_external_url == "https://public.example:8443"
|
||||
assert settings.supabase.anon_key == "anon-key"
|
||||
assert settings.supabase.service_role_key == "service-key"
|
||||
assert settings.supabase.jwt_secret == "jwt-secret"
|
||||
|
||||
supabase_settings = settings.model_dump()["supabase"]
|
||||
assert supabase_settings["public_url"] == "https://public.example:8443"
|
||||
assert supabase_settings["api_external_url"] == "https://public.example:8443"
|
||||
assert supabase_settings["anon_key"] == "anon-key"
|
||||
assert supabase_settings["service_role_key"] == "service-key"
|
||||
assert supabase_settings["jwt_secret"] == "jwt-secret"
|
||||
assert settings.database_url == "postgresql+asyncpg://user:pass@db:5432/app"
|
||||
|
||||
|
||||
def test_social_prefixed_api_external_url_is_loaded(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "api.example")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.supabase.api_external_url == "https://api.example:8443"
|
||||
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
|
||||
|
||||
def test_signup_requires_valid_email() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SignupRequest(email="not-an-email", password="secret123")
|
||||
|
||||
|
||||
def test_login_requires_valid_email() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
LoginRequest(email="invalid", password="secret123")
|
||||
|
||||
|
||||
def test_refresh_requires_token() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RefreshRequest(refresh_token="")
|
||||
|
||||
|
||||
def test_auth_token_response_maps_user() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
|
||||
assert response.user.id == "user-1"
|
||||
assert response.user.email == "user@example.com"
|
||||
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.auth.models import (
|
||||
AuthTokenResponse,
|
||||
AuthUser,
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
SignupRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService, AuthServiceGateway
|
||||
|
||||
|
||||
class FakeGateway(AuthServiceGateway):
|
||||
def __init__(self, response: AuthTokenResponse) -> None:
|
||||
self._response = response
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
return self._response
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
return self._response
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
return self._response
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signup_maps_response() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
service = AuthService(gateway=FakeGateway(token_response))
|
||||
|
||||
result = await service.signup(
|
||||
SignupRequest(email="user@example.com", password="secret123")
|
||||
)
|
||||
|
||||
assert result.access_token == "access"
|
||||
assert result.refresh_token == "refresh"
|
||||
assert result.user.id == "user-1"
|
||||
|
||||
|
||||
class LogoutAssertingGateway(AuthServiceGateway):
|
||||
def __init__(self, expected_refresh_token: str) -> None:
|
||||
self._expected_refresh_token = expected_refresh_token
|
||||
|
||||
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def login(self, request: LoginRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def logout(self, refresh_token: str | None) -> None:
|
||||
assert refresh_token == self._expected_refresh_token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_forwards_refresh_token() -> None:
|
||||
service = AuthService(gateway=LogoutAssertingGateway("refresh-token"))
|
||||
|
||||
await service.logout("refresh-token")
|
||||
@@ -0,0 +1,285 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.dependencies import get_current_user
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for JWT validation in get_current_user dependency."""
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_secret(self) -> str:
|
||||
return "super-secret-jwt-token-with-at-least-32-characters"
|
||||
|
||||
@pytest.fixture
|
||||
def valid_user_id(self) -> str:
|
||||
return "00000000-0000-0000-0000-000000000123"
|
||||
|
||||
@pytest.fixture
|
||||
def valid_payload(self, valid_user_id: str) -> dict[str, Any]:
|
||||
"""Valid JWT payload with all required claims."""
|
||||
now = int(time.time())
|
||||
return {
|
||||
"sub": valid_user_id,
|
||||
"aud": "authenticated",
|
||||
"iss": "http://localhost:8001/auth/v1",
|
||||
"exp": now + 3600, # 1 hour from now
|
||||
"iat": now,
|
||||
}
|
||||
|
||||
def _create_token(self, payload: dict[str, Any], secret: str) -> str:
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
def test_valid_token_returns_current_user(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
valid_user_id: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Valid JWT with correct aud/iss/exp should return CurrentUser."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
authorization = f"Bearer {token}"
|
||||
|
||||
result = get_current_user(authorization=authorization)
|
||||
|
||||
assert isinstance(result, CurrentUser)
|
||||
assert result.id == UUID(valid_user_id)
|
||||
|
||||
def test_missing_authorization_raises_401(self) -> None:
|
||||
"""Missing Authorization header should raise 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Unauthorized"
|
||||
|
||||
def test_invalid_scheme_raises_401(self) -> None:
|
||||
"""Non-Bearer scheme should raise 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization="Basic dXNlcjpwYXNz")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_expired_token_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Expired JWT should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["exp"] = int(time.time()) - 3600 # 1 hour ago
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_invalid_audience_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with wrong audience should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["aud"] = "wrong-audience"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_invalid_issuer_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with wrong issuer should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["iss"] = "http://malicious-site.com/auth/v1"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_missing_subject_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT without 'sub' claim should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
del valid_payload["sub"]
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_wrong_secret_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT signed with wrong secret should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
token = self._create_token(
|
||||
valid_payload, "wrong-secret-key-that-is-long-enough"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_jwt_secret_not_configured_raises_503(
|
||||
self, valid_payload: dict[str, Any], monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Missing JWT secret in config should raise 503."""
|
||||
monkeypatch.setattr("v1.profile.dependencies.config.supabase.jwt_secret", None)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization="Bearer some-token")
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "JWT secret not configured"
|
||||
|
||||
def test_invalid_uuid_in_subject_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with non-UUID 'sub' claim should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["sub"] = "not-a-valid-uuid"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
@@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.profile import Profile
|
||||
from v1.profile.repository import ProfileRepository
|
||||
from v1.profile.schemas import ProfileUpdateRequest
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
|
||||
def _create_mock_profile(
|
||||
user_id: UUID = UUID("00000000-0000-0000-0000-000000000001"),
|
||||
username: str = "demo",
|
||||
display_name: str | None = "Demo User",
|
||||
avatar_url: str | None = None,
|
||||
bio: str | None = None,
|
||||
) -> Profile:
|
||||
"""Create a mock Profile ORM object."""
|
||||
profile = MagicMock(spec=Profile)
|
||||
profile.id = user_id
|
||||
profile.username = username
|
||||
profile.display_name = display_name
|
||||
profile.avatar_url = avatar_url
|
||||
profile.bio = bio
|
||||
return profile
|
||||
|
||||
|
||||
class FakeRepo:
|
||||
"""Fake repository for testing that conforms to ProfileRepository protocol."""
|
||||
|
||||
def __init__(self, profile: Profile | None) -> None:
|
||||
self._profile = profile
|
||||
|
||||
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
|
||||
if self._profile and user_id == self._profile.id:
|
||||
return self._profile
|
||||
return None
|
||||
|
||||
async def get_by_username(self, username: str) -> Profile | None:
|
||||
if self._profile and username == self._profile.username:
|
||||
return self._profile
|
||||
return None
|
||||
|
||||
async def update_by_user_id(
|
||||
self, user_id: UUID, update_data: dict[str, str | None]
|
||||
) -> Profile | None:
|
||||
if not self._profile or user_id != self._profile.id:
|
||||
return None
|
||||
# Apply updates to mock
|
||||
for key, value in update_data.items():
|
||||
if hasattr(self._profile, key):
|
||||
setattr(self._profile, key, value)
|
||||
return self._profile
|
||||
|
||||
|
||||
# Verify FakeRepo implements the protocol
|
||||
_repo_check: ProfileRepository = FakeRepo(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session() -> AsyncMock:
|
||||
"""Create a mock AsyncSession."""
|
||||
session = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_returns_profile(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id, username="demo")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
result = await service.get_me()
|
||||
|
||||
assert result.username == "demo"
|
||||
assert result.id == str(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_not_found_raises_404(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(None),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_me()
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me_updates_fields(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id, username="demo")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
result = await service.update_me(ProfileUpdateRequest(display_name="Updated"))
|
||||
|
||||
assert result.display_name == "Updated"
|
||||
mock_session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me_no_fields_raises_400(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id)
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
# Create a request with all None values by bypassing validation
|
||||
update = MagicMock(spec=ProfileUpdateRequest)
|
||||
update.display_name = None
|
||||
update.avatar_url = None
|
||||
update.bio = None
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.update_me(update)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_username_returns_profile(mock_session: AsyncMock) -> None:
|
||||
profile = _create_mock_profile(username="demo")
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
result = await service.get_by_username("demo")
|
||||
|
||||
assert result.username == "demo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_username_not_found_raises_404(mock_session: AsyncMock) -> None:
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(None),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_by_username("unknown")
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
|
||||
|
||||
def test_profile_response_maps_fields() -> None:
|
||||
response = ProfileResponse(
|
||||
id="user-1",
|
||||
username="demo",
|
||||
display_name="Demo User",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
|
||||
assert response.id == "user-1"
|
||||
assert response.username == "demo"
|
||||
|
||||
|
||||
def test_profile_update_requires_one_field() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest()
|
||||
|
||||
|
||||
def test_profile_update_accepts_valid_https_url() -> None:
|
||||
request = ProfileUpdateRequest(avatar_url="https://example.com/avatar.png")
|
||||
assert request.avatar_url == "https://example.com/avatar.png"
|
||||
|
||||
|
||||
def test_profile_update_accepts_valid_http_url() -> None:
|
||||
request = ProfileUpdateRequest(
|
||||
avatar_url="http://localhost:8001/storage/avatar.png"
|
||||
)
|
||||
assert request.avatar_url == "http://localhost:8001/storage/avatar.png"
|
||||
|
||||
|
||||
def test_profile_update_rejects_invalid_url() -> None:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProfileUpdateRequest(avatar_url="not-a-valid-url")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert len(errors) == 1
|
||||
assert "avatar_url" in str(errors[0]["loc"])
|
||||
|
||||
|
||||
def test_profile_update_rejects_javascript_url() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest(avatar_url="javascript:alert('xss')")
|
||||
|
||||
|
||||
def test_profile_update_rejects_data_url() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest(avatar_url="data:text/html,<script>alert('xss')</script>")
|
||||
|
||||
|
||||
def test_profile_update_accepts_none_avatar_url_with_other_field() -> None:
|
||||
request = ProfileUpdateRequest(display_name="Test", avatar_url=None)
|
||||
assert request.avatar_url is None
|
||||
assert request.display_name == "Test"
|
||||
Reference in New Issue
Block a user