diff --git a/backend/src/core/auth/jwt_verifier.py b/backend/src/core/auth/jwt_verifier.py index d2b9be7..7937631 100644 --- a/backend/src/core/auth/jwt_verifier.py +++ b/backend/src/core/auth/jwt_verifier.py @@ -9,56 +9,53 @@ class TokenValidationError(Exception): pass -class TokenVerifierUnavailableError(Exception): - pass - - class JwtVerifier: + _expected_audience = "authenticated" + def __init__( self, - jwks_url: str, issuer: str, - audience: str, - apikey: str, + jwt_secret: str, + jwt_algorithm: str, ) -> None: + if jwt_algorithm != "HS256": + raise TokenValidationError("Unsupported JWT algorithm") + self._issuer: str = issuer - self._audience: str = audience - self._jwks_client: jwt.PyJWKClient = jwt.PyJWKClient( - jwks_url, - headers={ - "apikey": apikey, - "Authorization": f"Bearer {apikey}", - }, - ) + self._jwt_secret: str = jwt_secret + self._jwt_algorithm: str = jwt_algorithm def verify(self, token: str) -> dict[str, Any]: - try: - key = self._jwks_client.get_signing_key_from_jwt(token) - except jwt.PyJWKClientConnectionError as exc: - raise TokenVerifierUnavailableError("Unable to fetch JWKS") from exc - except jwt.PyJWKClientError as exc: - raise TokenValidationError("Unable to resolve signing key") from exc - try: payload = jwt.decode( token, - key.key, - algorithms=["RS256"], - audience=self._audience, - issuer=self._issuer, - options={"require": ["sub", "aud", "iss", "exp"]}, + self._jwt_secret, + algorithms=[self._jwt_algorithm], + options={"require": ["sub", "exp", "aud"], "verify_aud": False}, ) except ( jwt.ExpiredSignatureError, - jwt.InvalidAudienceError, jwt.InvalidIssuerError, jwt.InvalidSignatureError, + jwt.InvalidAlgorithmError, jwt.DecodeError, jwt.PyJWTError, ) as exc: raise TokenValidationError("Token validation failed") from exc - if not isinstance(payload, dict): - raise TokenValidationError("Token payload must be a JSON object") + token_audience = payload.get("aud") + if isinstance(token_audience, str): + audience_match = token_audience == self._expected_audience + elif isinstance(token_audience, list): + audience_match = self._expected_audience in token_audience + else: + audience_match = False + + if not audience_match: + raise TokenValidationError("Token audience mismatch") + + token_issuer = payload.get("iss") + if token_issuer is not None and token_issuer != self._issuer: + raise TokenValidationError("Token issuer mismatch") return cast(dict[str, Any], payload) diff --git a/backend/src/core/config/settings.py b/backend/src/core/config/settings.py index c395aa7..22b1959 100644 --- a/backend/src/core/config/settings.py +++ b/backend/src/core/config/settings.py @@ -8,6 +8,7 @@ from pydantic import ( AnyHttpUrl, BaseModel, Field, + SecretStr, computed_field, field_validator, model_validator, @@ -126,9 +127,9 @@ class SupabaseSettings(BaseModel): public_url: AnyHttpUrl anon_key: str = "CHANGE_ME" service_role_key: str = "CHANGE_ME" - jwt_audience: str = "authenticated" + jwt_secret: SecretStr | None = Field(default=None, exclude=True) + jwt_algorithm: Literal["HS256"] = "HS256" jwt_issuer: str | None = None - jwks_url: str | None = None @model_validator(mode="after") def compute_defaults(self) -> "SupabaseSettings": @@ -136,9 +137,6 @@ class SupabaseSettings(BaseModel): if self.jwt_issuer is None: self.jwt_issuer = f"{base}/auth/v1" - if self.jwks_url is None: - self.jwks_url = f"{self.jwt_issuer}/.well-known/jwks.json" - return self @computed_field diff --git a/backend/tests/unit/core/auth/test_jwt_verifier.py b/backend/tests/unit/core/auth/test_jwt_verifier.py index 015ef3e..dc75c2e 100644 --- a/backend/tests/unit/core/auth/test_jwt_verifier.py +++ b/backend/tests/unit/core/auth/test_jwt_verifier.py @@ -1,8 +1,6 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta -from types import SimpleNamespace -from typing import Any, cast from uuid import uuid4 import jwt @@ -10,300 +8,174 @@ import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -from core.auth.jwt_verifier import ( - JwtVerifier, - TokenValidationError, - TokenVerifierUnavailableError, -) +from core.auth.jwt_verifier import JwtVerifier, TokenValidationError -def test_jwks_client_uses_supabase_auth_headers( - monkeypatch: pytest.MonkeyPatch, -) -> None: - captured: dict[str, Any] = {} - - class _FakePyJWKClient: - def __init__( - self, - uri: str, - *, - headers: dict[str, Any] | None = None, - **_: Any, - ) -> None: - captured["uri"] = uri - captured["headers"] = headers - - monkeypatch.setattr("core.auth.jwt_verifier.jwt.PyJWKClient", _FakePyJWKClient) - - JwtVerifier( - jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", - issuer="https://example.supabase.co/auth/v1", - audience="authenticated", - apikey="anon-key-value", - ) - - assert ( - captured["uri"] == "https://example.supabase.co/auth/v1/.well-known/jwks.json" - ) - assert captured["headers"] == { - "apikey": "anon-key-value", - "Authorization": "Bearer anon-key-value", +def _build_hs256_token( + *, + secret: str, + sub: str, + issuer: str | None = None, + audience: str | None = "authenticated", +) -> str: + now = datetime.now(UTC) + payload = { + "sub": sub, + "exp": now + timedelta(minutes=5), } + if audience is not None: + payload["aud"] = audience + if issuer is not None: + payload["iss"] = issuer + return jwt.encode(payload, secret, algorithm="HS256") -def _set_jwks_client(verifier: JwtVerifier, client: Any) -> None: - cast(Any, verifier)._jwks_client = client - - -def _build_rsa_key_pair() -> tuple[str, str]: +def _build_rs256_token( + *, sub: str, issuer: str, audience: str = "authenticated" +) -> str: + now = datetime.now(UTC) + payload = { + "sub": sub, + "iss": issuer, + "aud": audience, + "exp": now + timedelta(minutes=5), + } private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) private_pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ).decode("utf-8") - public_pem = ( - private_key.public_key() - .public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) - .decode("utf-8") - ) - return private_pem, public_pem + return jwt.encode(payload, private_pem, algorithm="RS256", headers={"kid": "kid-1"}) -def _build_token(*, private_key: str, sub: str, audience: str, issuer: str) -> str: - now = datetime.now(UTC) - payload = { - "sub": sub, - "aud": audience, - "iss": issuer, - "exp": now + timedelta(minutes=5), - } - return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": "kid-1"}) - - -def _build_expired_token( - *, private_key: str, sub: str, audience: str, issuer: str -) -> str: - now = datetime.now(UTC) - payload = { - "sub": sub, - "aud": audience, - "iss": issuer, - "exp": now - timedelta(minutes=1), - } - return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": "kid-1"}) - - -def _build_hs256_token(*, secret: str, sub: str, audience: str, issuer: str) -> str: - now = datetime.now(UTC) - payload = { - "sub": sub, - "aud": audience, - "iss": issuer, - "exp": now + timedelta(minutes=5), - } - return jwt.encode(payload, secret, algorithm="HS256", headers={"kid": "kid-1"}) - - -def test_verify_token_with_jwks_success() -> None: - user_id = uuid4() - audience = "authenticated" - issuer = "https://example.supabase.co/auth/v1" - private_key, public_key = _build_rsa_key_pair() - token = _build_token( - private_key=private_key, - sub=str(user_id), - audience=audience, - issuer=issuer, - ) - +def test_verify_hs256_token_success() -> None: verifier = JwtVerifier( - jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", - issuer=issuer, - audience=audience, - apikey="anon-key", + issuer="https://example.supabase.co/auth/v1", + jwt_secret="test-secret", + jwt_algorithm="HS256", ) - _set_jwks_client( - verifier, - SimpleNamespace( - get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key) - ), + token = _build_hs256_token( + secret="test-secret", + sub=str(uuid4()), + issuer="https://example.supabase.co/auth/v1", ) claims = verifier.verify(token) - assert claims["sub"] == str(user_id) + assert "sub" in claims -def test_verify_token_rejects_invalid_issuer() -> None: - audience = "authenticated" - issuer = "https://example.supabase.co/auth/v1" - private_key, public_key = _build_rsa_key_pair() - token_with_wrong_iss = _build_token( - private_key=private_key, +def test_verify_rejects_invalid_issuer() -> None: + verifier = JwtVerifier( + issuer="https://example.supabase.co/auth/v1", + jwt_secret="test-secret", + jwt_algorithm="HS256", + ) + token = _build_hs256_token( + secret="test-secret", sub=str(uuid4()), - audience=audience, issuer="https://wrong-issuer.example.com/auth/v1", ) - verifier = JwtVerifier( - jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", - issuer=issuer, - audience=audience, - apikey="anon-key", - ) - _set_jwks_client( - verifier, - SimpleNamespace( - get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key) - ), - ) - with pytest.raises(TokenValidationError): - verifier.verify(token_with_wrong_iss) + verifier.verify(token) -def test_verify_token_rejects_hs256_token() -> None: - audience = "authenticated" - issuer = "https://example.supabase.co/auth/v1" - _, public_key = _build_rsa_key_pair() - hs_token = _build_hs256_token( +def test_verify_rejects_missing_audience() -> None: + verifier = JwtVerifier( + issuer="https://example.supabase.co/auth/v1", + jwt_secret="test-secret", + jwt_algorithm="HS256", + ) + token = _build_hs256_token( secret="test-secret", sub=str(uuid4()), - audience=audience, - issuer=issuer, - ) - - verifier = JwtVerifier( - jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", - issuer=issuer, - audience=audience, - apikey="anon-key", - ) - _set_jwks_client( - verifier, - SimpleNamespace( - get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key) - ), + audience=None, ) with pytest.raises(TokenValidationError): - verifier.verify(hs_token) - - -def test_verify_token_rejects_expired_token() -> None: - audience = "authenticated" - issuer = "https://example.supabase.co/auth/v1" - private_key, public_key = _build_rsa_key_pair() - expired_token = _build_expired_token( - private_key=private_key, - sub=str(uuid4()), - audience=audience, - issuer=issuer, - ) - - verifier = JwtVerifier( - jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", - issuer=issuer, - audience=audience, - apikey="anon-key", - ) - _set_jwks_client( - verifier, - SimpleNamespace( - get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key) - ), - ) - - with pytest.raises(TokenValidationError): - verifier.verify(expired_token) - - -def test_verify_token_rejects_invalid_audience() -> None: - audience = "authenticated" - issuer = "https://example.supabase.co/auth/v1" - private_key, public_key = _build_rsa_key_pair() - wrong_aud_token = _build_token( - private_key=private_key, - sub=str(uuid4()), - audience="anon", - issuer=issuer, - ) - - verifier = JwtVerifier( - jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", - issuer=issuer, - audience=audience, - apikey="anon-key", - ) - _set_jwks_client( - verifier, - SimpleNamespace( - get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key) - ), - ) - - with pytest.raises(TokenValidationError): - verifier.verify(wrong_aud_token) - - -def test_verify_token_rejects_invalid_signature() -> None: - audience = "authenticated" - issuer = "https://example.supabase.co/auth/v1" - private_key, public_key = _build_rsa_key_pair() - valid_token = _build_token( - private_key=private_key, - sub=str(uuid4()), - audience=audience, - issuer=issuer, - ) - tampered_token = f"{valid_token}x" - - verifier = JwtVerifier( - jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", - issuer=issuer, - audience=audience, - apikey="anon-key", - ) - _set_jwks_client( - verifier, - SimpleNamespace( - get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key) - ), - ) - - with pytest.raises(TokenValidationError): - verifier.verify(tampered_token) - - -def test_verify_token_maps_jwks_connection_error() -> None: - audience = "authenticated" - issuer = "https://example.supabase.co/auth/v1" - private_key, _ = _build_rsa_key_pair() - token = _build_token( - private_key=private_key, - sub=str(uuid4()), - audience=audience, - issuer=issuer, - ) - - verifier = JwtVerifier( - jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", - issuer=issuer, - audience=audience, - apikey="anon-key", - ) - - def _raise_connection_error(_: str) -> SimpleNamespace: - raise jwt.PyJWKClientConnectionError("network down") - - _set_jwks_client( - verifier, - SimpleNamespace(get_signing_key_from_jwt=_raise_connection_error), - ) - - with pytest.raises(TokenVerifierUnavailableError): verifier.verify(token) + + +def test_verify_accepts_token_without_issuer_claim() -> None: + verifier = JwtVerifier( + issuer="https://example.supabase.co/auth/v1", + jwt_secret="test-secret", + jwt_algorithm="HS256", + ) + token = _build_hs256_token( + secret="test-secret", + sub=str(uuid4()), + ) + + claims = verifier.verify(token) + + assert "sub" in claims + + +def test_verify_accepts_list_audience() -> None: + verifier = JwtVerifier( + issuer="https://example.supabase.co/auth/v1", + jwt_secret="test-secret", + jwt_algorithm="HS256", + ) + token = jwt.encode( + { + "sub": str(uuid4()), + "aud": ["anonymous", "authenticated"], + "exp": datetime.now(UTC) + timedelta(minutes=5), + }, + "test-secret", + algorithm="HS256", + ) + + claims = verifier.verify(token) + + assert claims["aud"] == ["anonymous", "authenticated"] + + +def test_verify_rejects_rs256_token() -> None: + verifier = JwtVerifier( + issuer="https://example.supabase.co/auth/v1", + jwt_secret="test-secret", + jwt_algorithm="HS256", + ) + token = _build_rs256_token( + sub=str(uuid4()), + issuer="https://example.supabase.co/auth/v1", + ) + + with pytest.raises(TokenValidationError): + verifier.verify(token) + + +def test_verify_rejects_expired_token() -> None: + verifier = JwtVerifier( + issuer="https://example.supabase.co/auth/v1", + jwt_secret="test-secret", + jwt_algorithm="HS256", + ) + now = datetime.now(UTC) + token = jwt.encode( + { + "sub": str(uuid4()), + "iss": "https://example.supabase.co/auth/v1", + "aud": "authenticated", + "exp": now - timedelta(minutes=1), + }, + "test-secret", + algorithm="HS256", + ) + + with pytest.raises(TokenValidationError): + verifier.verify(token) + + +def test_verify_rejects_unsupported_algorithm_setting() -> None: + with pytest.raises(TokenValidationError): + JwtVerifier( + issuer="https://example.supabase.co/auth/v1", + jwt_secret="test-secret", + jwt_algorithm="RS256", + ) diff --git a/backend/tests/unit/test_settings_supabase_env.py b/backend/tests/unit/test_settings_supabase_env.py index 9ecd738..1ed413d 100644 --- a/backend/tests/unit/test_settings_supabase_env.py +++ b/backend/tests/unit/test_settings_supabase_env.py @@ -13,6 +13,8 @@ def test_social_prefixed_supabase_env_populates_settings( monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_URL", "https://public.example:8443") monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key") monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key") + monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret") + monkeypatch.setenv("SOCIAL_SUPABASE__JWT_ALGORITHM", "HS256") monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db") monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432") monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app") @@ -24,6 +26,9 @@ def test_social_prefixed_supabase_env_populates_settings( assert str(settings.supabase.public_url) == "https://public.example:8443/" assert settings.supabase.anon_key == "anon-key" assert settings.supabase.service_role_key == "service-key" + assert settings.supabase.jwt_secret is not None + assert settings.supabase.jwt_secret.get_secret_value() == "jwt-secret" + assert settings.supabase.jwt_algorithm == "HS256" supabase_settings = settings.model_dump()["supabase"] assert str(supabase_settings["public_url"]) == "https://public.example:8443/" @@ -42,17 +47,14 @@ def test_cloud_supabase_env_populates_settings(monkeypatch: MonkeyPatch) -> None ) monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key") monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key") - monkeypatch.setenv("SOCIAL_SUPABASE__JWT_AUDIENCE", "authenticated") + monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret") + monkeypatch.setenv("SOCIAL_SUPABASE__JWT_ALGORITHM", "HS256") settings = Settings() assert str(settings.supabase.public_url) == "https://project.example.supabase.co/" - assert settings.supabase.jwt_audience == "authenticated" + assert settings.supabase.jwt_algorithm == "HS256" assert settings.supabase.jwt_issuer == "https://project.example.supabase.co/auth/v1" - assert ( - settings.supabase.jwks_url - == "https://project.example.supabase.co/auth/v1/.well-known/jwks.json" - ) supabase_settings = settings.model_dump()["supabase"] assert "jwt_secret" not in supabase_settings @@ -71,6 +73,8 @@ def test_public_url_with_trailing_slash_normalizes_correctly( monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_URL", "https://example.supabase.co/") monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key") monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key") + monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret") + monkeypatch.setenv("SOCIAL_SUPABASE__JWT_ALGORITHM", "HS256") monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db") monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432") monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app") @@ -80,8 +84,4 @@ def test_public_url_with_trailing_slash_normalizes_correctly( settings = Settings() assert settings.supabase.jwt_issuer == "https://example.supabase.co/auth/v1" - assert ( - settings.supabase.jwks_url - == "https://example.supabase.co/auth/v1/.well-known/jwks.json" - ) assert settings.supabase.url == "https://example.supabase.co/"