refactor: 统一认证端点并删除冗余 profile 模块
- 合并 auth 端点: /verifications/verify → /verify, /verifications/resend → /resend - 整合密码重置到 /verify 端点 (type=recovery) - 移除未使用的 /auth/users 端点 - 添加 redirect URL 白名单验证 (site_url + additional_redirect_urls) - 限流改用 Redis + IP 标识,替代内存锁 - 删除 v1/profile 死代码模块 - 更新前端 auth_api 适配新端点 - 添加 supabase site_url 和 additional_redirect_urls 配置
This commit is contained in:
@@ -14,6 +14,11 @@ def test_social_prefixed_supabase_env_populates_settings(
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
|
||||
monkeypatch.setenv("SOCIAL_SUPABASE__SITE_URL", "https://app.example.com")
|
||||
monkeypatch.setenv(
|
||||
"SOCIAL_SUPABASE__ADDITIONAL_REDIRECT_URLS",
|
||||
'["https://a.example.com", "https://b.example.com/path"]',
|
||||
)
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
|
||||
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
|
||||
@@ -26,10 +31,16 @@ def test_social_prefixed_supabase_env_populates_settings(
|
||||
assert settings.supabase.anon_key == "anon-key"
|
||||
assert settings.supabase.service_role_key == "service-key"
|
||||
assert settings.supabase.jwt_secret == "jwt-secret"
|
||||
assert settings.supabase.site_url == "https://app.example.com"
|
||||
assert settings.supabase.additional_redirect_urls == [
|
||||
"https://a.example.com",
|
||||
"https://b.example.com/path",
|
||||
]
|
||||
|
||||
supabase_settings = settings.model_dump()["supabase"]
|
||||
assert supabase_settings["public_url"] == "https://public.example:8443"
|
||||
assert supabase_settings["anon_key"] == "anon-key"
|
||||
assert supabase_settings["service_role_key"] == "service-key"
|
||||
assert supabase_settings["jwt_secret"] == "jwt-secret"
|
||||
assert supabase_settings["site_url"] == "https://app.example.com"
|
||||
assert settings.database_url == "postgresql+asyncpg://user:pass@db:5432/app"
|
||||
|
||||
@@ -7,7 +7,11 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
VerificationResendRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestSupabaseAuthGateway:
|
||||
@@ -56,6 +60,22 @@ class TestSupabaseAuthGateway:
|
||||
options={"redirect_to": "http://localhost:3000/reset-password"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_rejects_redirect_outside_allowlist(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, _, _ = gateway
|
||||
request = PasswordResetRequest(
|
||||
email="test@example.com",
|
||||
redirect_to="https://evil.example/reset",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Invalid redirect URL"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_swallows_auth_error(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
@@ -165,3 +185,24 @@ class TestSupabaseAuthGateway:
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid or expired verification code"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_resend_calls_reset_password_email(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
await sut.resend_verification(
|
||||
VerificationResendRequest(
|
||||
type="recovery",
|
||||
email="test@example.com",
|
||||
redirect_to="http://localhost:3000/reset-password",
|
||||
)
|
||||
)
|
||||
|
||||
mock_reset_email.assert_called_once_with(
|
||||
"test@example.com",
|
||||
options={"redirect_to": "http://localhost:3000/reset-password"},
|
||||
)
|
||||
|
||||
@@ -33,6 +33,25 @@ def test_signup_verify_requires_six_digit_token() -> None:
|
||||
VerificationVerifyRequest(email="user@example.com", token="abc123")
|
||||
|
||||
|
||||
def test_signup_verify_disallows_new_password() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationVerifyRequest(
|
||||
type="signup",
|
||||
email="user@example.com",
|
||||
token="123456",
|
||||
new_password="secret123",
|
||||
)
|
||||
|
||||
|
||||
def test_recovery_verify_requires_new_password() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationVerifyRequest(
|
||||
type="recovery",
|
||||
email="user@example.com",
|
||||
token="123456",
|
||||
)
|
||||
|
||||
|
||||
def test_signup_resend_requires_valid_email() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationResendRequest(email="invalid")
|
||||
|
||||
@@ -5,6 +5,8 @@ import pytest
|
||||
import v1.auth.gateway as auth_gateway_module
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
@@ -44,40 +46,15 @@ class FakeGateway(AuthServiceGateway):
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
return UserByEmailResponse(
|
||||
id="user-1",
|
||||
email=email,
|
||||
created_at="2026-02-24T00:00:00Z",
|
||||
email_confirmed_at=None,
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signup_maps_response() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
service = AuthService(gateway=FakeGateway(token_response))
|
||||
|
||||
start_result = await service.create_verification(
|
||||
VerificationCreateRequest(
|
||||
username="demo", email="user@example.com", password="secret123"
|
||||
)
|
||||
)
|
||||
assert start_result.email == "user@example.com"
|
||||
|
||||
result = await service.verify_verification(
|
||||
VerificationVerifyRequest(email="user@example.com", token="123456")
|
||||
)
|
||||
|
||||
assert result.access_token == "access"
|
||||
assert result.refresh_token == "refresh"
|
||||
assert result.user.id == "user-1"
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LogoutAssertingGateway(AuthServiceGateway):
|
||||
@@ -109,6 +86,14 @@ class LogoutAssertingGateway(AuthServiceGateway):
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_forwards_refresh_token() -> None:
|
||||
@@ -117,23 +102,6 @@ async def test_logout_forwards_refresh_token() -> None:
|
||||
await service.delete_session("refresh-token")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_forwards_to_gateway() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
service = AuthService(gateway=FakeGateway(token_response))
|
||||
|
||||
result = await service.get_user_by_email("user@example.com")
|
||||
|
||||
assert result.email == "user@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signup_resend_returns_none() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
@@ -182,7 +150,9 @@ async def test_supabase_signup_passes_username_in_metadata(
|
||||
class FakeClient:
|
||||
auth = FakeSupabaseAuth()
|
||||
|
||||
monkeypatch.setattr(auth_gateway_module, "create_client", lambda *_: FakeClient())
|
||||
monkeypatch.setattr(
|
||||
auth_gateway_module.supabase_service, "get_client", lambda: FakeClient()
|
||||
)
|
||||
|
||||
gateway = auth_gateway_module.SupabaseAuthGateway()
|
||||
await gateway.create_verification(
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.profile.dependencies import get_current_user
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for JWT validation in get_current_user dependency."""
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_secret(self) -> str:
|
||||
return "super-secret-jwt-token-with-at-least-32-characters"
|
||||
|
||||
@pytest.fixture
|
||||
def valid_user_id(self) -> str:
|
||||
return "00000000-0000-0000-0000-000000000123"
|
||||
|
||||
@pytest.fixture
|
||||
def valid_payload(self, valid_user_id: str) -> dict[str, Any]:
|
||||
"""Valid JWT payload with all required claims."""
|
||||
now = int(time.time())
|
||||
return {
|
||||
"sub": valid_user_id,
|
||||
"aud": "authenticated",
|
||||
"iss": "http://localhost:8001/auth/v1",
|
||||
"exp": now + 3600, # 1 hour from now
|
||||
"iat": now,
|
||||
}
|
||||
|
||||
def _create_token(self, payload: dict[str, Any], secret: str) -> str:
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
def test_valid_token_returns_current_user(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
valid_user_id: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Valid JWT with correct aud/iss/exp should return CurrentUser."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
authorization = f"Bearer {token}"
|
||||
|
||||
result = get_current_user(authorization=authorization)
|
||||
|
||||
assert isinstance(result, CurrentUser)
|
||||
assert result.id == UUID(valid_user_id)
|
||||
|
||||
def test_missing_authorization_raises_401(self) -> None:
|
||||
"""Missing Authorization header should raise 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Unauthorized"
|
||||
|
||||
def test_invalid_scheme_raises_401(self) -> None:
|
||||
"""Non-Bearer scheme should raise 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization="Basic dXNlcjpwYXNz")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_expired_token_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Expired JWT should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["exp"] = int(time.time()) - 3600 # 1 hour ago
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_invalid_audience_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with wrong audience should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["aud"] = "wrong-audience"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_invalid_issuer_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with wrong issuer should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["iss"] = "http://malicious-site.com/auth/v1"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_missing_subject_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT without 'sub' claim should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
del valid_payload["sub"]
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_wrong_secret_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT signed with wrong secret should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
token = self._create_token(
|
||||
valid_payload, "wrong-secret-key-that-is-long-enough"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_jwt_secret_not_configured_raises_503(
|
||||
self, valid_payload: dict[str, Any], monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Missing JWT secret in config should raise 503."""
|
||||
monkeypatch.setattr("v1.profile.dependencies.config.supabase.jwt_secret", None)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization="Bearer some-token")
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "JWT secret not configured"
|
||||
|
||||
def test_invalid_uuid_in_subject_raises_401(
|
||||
self,
|
||||
jwt_secret: str,
|
||||
valid_payload: dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""JWT with non-UUID 'sub' claim should raise 401."""
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_scheme",
|
||||
"http",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.public_host",
|
||||
"localhost",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"v1.profile.dependencies.config.supabase.kong_http_port",
|
||||
8001,
|
||||
)
|
||||
|
||||
valid_payload["sub"] = "not-a-valid-uuid"
|
||||
token = self._create_token(valid_payload, jwt_secret)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(authorization=f"Bearer {token}")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
@@ -1,170 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.profile import Profile
|
||||
from v1.profile.repository import ProfileRepository
|
||||
from v1.profile.schemas import ProfileUpdateRequest
|
||||
from v1.profile.service import ProfileService
|
||||
|
||||
|
||||
def _create_mock_profile(
|
||||
user_id: UUID = UUID("00000000-0000-0000-0000-000000000001"),
|
||||
username: str = "demo",
|
||||
avatar_url: str | None = None,
|
||||
bio: str | None = None,
|
||||
) -> Profile:
|
||||
"""Create a mock Profile ORM object."""
|
||||
profile = MagicMock(spec=Profile)
|
||||
profile.id = user_id
|
||||
profile.username = username
|
||||
profile.avatar_url = avatar_url
|
||||
profile.bio = bio
|
||||
return profile
|
||||
|
||||
|
||||
class FakeRepo:
|
||||
"""Fake repository for testing that conforms to ProfileRepository protocol."""
|
||||
|
||||
def __init__(self, profile: Profile | None) -> None:
|
||||
self._profile = profile
|
||||
|
||||
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
|
||||
if self._profile and user_id == self._profile.id:
|
||||
return self._profile
|
||||
return None
|
||||
|
||||
async def get_by_username(self, username: str) -> Profile | None:
|
||||
if self._profile and username == self._profile.username:
|
||||
return self._profile
|
||||
return None
|
||||
|
||||
async def update_by_user_id(
|
||||
self, user_id: UUID, update_data: dict[str, str | None]
|
||||
) -> Profile | None:
|
||||
if not self._profile or user_id != self._profile.id:
|
||||
return None
|
||||
# Apply updates to mock
|
||||
for key, value in update_data.items():
|
||||
if hasattr(self._profile, key):
|
||||
setattr(self._profile, key, value)
|
||||
return self._profile
|
||||
|
||||
|
||||
# Verify FakeRepo implements the protocol
|
||||
_repo_check: ProfileRepository = FakeRepo(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session() -> AsyncMock:
|
||||
"""Create a mock AsyncSession."""
|
||||
session = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_returns_profile(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id, username="demo")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
result = await service.get_me()
|
||||
|
||||
assert result.username == "demo"
|
||||
assert result.id == str(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_me_not_found_raises_404(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(None),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_me()
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me_updates_fields(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id, username="demo")
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
result = await service.update_me(ProfileUpdateRequest(username="updated"))
|
||||
|
||||
assert result.username == "updated"
|
||||
mock_session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me_no_fields_raises_400(mock_session: AsyncMock) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
profile = _create_mock_profile(user_id=user_id)
|
||||
user = CurrentUser(id=user_id)
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
# Create a request with all None values by bypassing validation
|
||||
update = MagicMock(spec=ProfileUpdateRequest)
|
||||
update.username = None
|
||||
update.avatar_url = None
|
||||
update.bio = None
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.update_me(update)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_username_returns_profile(mock_session: AsyncMock) -> None:
|
||||
profile = _create_mock_profile(username="demo")
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(profile),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
result = await service.get_by_username("demo")
|
||||
|
||||
assert result.username == "demo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_username_not_found_raises_404(mock_session: AsyncMock) -> None:
|
||||
service = ProfileService(
|
||||
repository=FakeRepo(None),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_by_username("unknown")
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
@@ -1,65 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
|
||||
|
||||
|
||||
def test_profile_response_maps_fields() -> None:
|
||||
response = ProfileResponse(
|
||||
id="user-1",
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
|
||||
assert response.id == "user-1"
|
||||
assert response.username == "demo"
|
||||
|
||||
|
||||
def test_profile_update_requires_one_field() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest()
|
||||
|
||||
|
||||
def test_profile_update_accepts_valid_https_url() -> None:
|
||||
request = ProfileUpdateRequest(avatar_url="https://example.com/avatar.png")
|
||||
assert request.avatar_url == "https://example.com/avatar.png"
|
||||
|
||||
|
||||
def test_profile_update_accepts_valid_http_url() -> None:
|
||||
request = ProfileUpdateRequest(
|
||||
avatar_url="http://localhost:8001/storage/avatar.png"
|
||||
)
|
||||
assert request.avatar_url == "http://localhost:8001/storage/avatar.png"
|
||||
|
||||
|
||||
def test_profile_update_rejects_invalid_url() -> None:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProfileUpdateRequest(avatar_url="not-a-valid-url")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert len(errors) == 1
|
||||
assert "avatar_url" in str(errors[0]["loc"])
|
||||
|
||||
|
||||
def test_profile_update_rejects_javascript_url() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest(avatar_url="javascript:alert('xss')")
|
||||
|
||||
|
||||
def test_profile_update_rejects_data_url() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest(avatar_url="data:text/html,<script>alert('xss')</script>")
|
||||
|
||||
|
||||
def test_profile_update_accepts_none_avatar_url_with_other_field() -> None:
|
||||
request = ProfileUpdateRequest(username="tester", avatar_url=None)
|
||||
assert request.avatar_url is None
|
||||
assert request.username == "tester"
|
||||
|
||||
|
||||
def test_profile_update_rejects_display_name_field() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ProfileUpdateRequest.model_validate({"display_name": "legacy"})
|
||||
Reference in New Issue
Block a user