fix: 后端 JWT 验证改为 HS256 方式提升认证可靠性

This commit is contained in:
qzl
2026-03-10 17:43:55 +08:00
parent 5d839192ab
commit 95d6927724
4 changed files with 177 additions and 310 deletions
+27 -30
View File
@@ -9,56 +9,53 @@ class TokenValidationError(Exception):
pass pass
class TokenVerifierUnavailableError(Exception):
pass
class JwtVerifier: class JwtVerifier:
_expected_audience = "authenticated"
def __init__( def __init__(
self, self,
jwks_url: str,
issuer: str, issuer: str,
audience: str, jwt_secret: str,
apikey: str, jwt_algorithm: str,
) -> None: ) -> None:
if jwt_algorithm != "HS256":
raise TokenValidationError("Unsupported JWT algorithm")
self._issuer: str = issuer self._issuer: str = issuer
self._audience: str = audience self._jwt_secret: str = jwt_secret
self._jwks_client: jwt.PyJWKClient = jwt.PyJWKClient( self._jwt_algorithm: str = jwt_algorithm
jwks_url,
headers={
"apikey": apikey,
"Authorization": f"Bearer {apikey}",
},
)
def verify(self, token: str) -> dict[str, Any]: def verify(self, token: str) -> dict[str, Any]:
try:
key = self._jwks_client.get_signing_key_from_jwt(token)
except jwt.PyJWKClientConnectionError as exc:
raise TokenVerifierUnavailableError("Unable to fetch JWKS") from exc
except jwt.PyJWKClientError as exc:
raise TokenValidationError("Unable to resolve signing key") from exc
try: try:
payload = jwt.decode( payload = jwt.decode(
token, token,
key.key, self._jwt_secret,
algorithms=["RS256"], algorithms=[self._jwt_algorithm],
audience=self._audience, options={"require": ["sub", "exp", "aud"], "verify_aud": False},
issuer=self._issuer,
options={"require": ["sub", "aud", "iss", "exp"]},
) )
except ( except (
jwt.ExpiredSignatureError, jwt.ExpiredSignatureError,
jwt.InvalidAudienceError,
jwt.InvalidIssuerError, jwt.InvalidIssuerError,
jwt.InvalidSignatureError, jwt.InvalidSignatureError,
jwt.InvalidAlgorithmError,
jwt.DecodeError, jwt.DecodeError,
jwt.PyJWTError, jwt.PyJWTError,
) as exc: ) as exc:
raise TokenValidationError("Token validation failed") from exc raise TokenValidationError("Token validation failed") from exc
if not isinstance(payload, dict): token_audience = payload.get("aud")
raise TokenValidationError("Token payload must be a JSON object") if isinstance(token_audience, str):
audience_match = token_audience == self._expected_audience
elif isinstance(token_audience, list):
audience_match = self._expected_audience in token_audience
else:
audience_match = False
if not audience_match:
raise TokenValidationError("Token audience mismatch")
token_issuer = payload.get("iss")
if token_issuer is not None and token_issuer != self._issuer:
raise TokenValidationError("Token issuer mismatch")
return cast(dict[str, Any], payload) return cast(dict[str, Any], payload)
+3 -5
View File
@@ -8,6 +8,7 @@ from pydantic import (
AnyHttpUrl, AnyHttpUrl,
BaseModel, BaseModel,
Field, Field,
SecretStr,
computed_field, computed_field,
field_validator, field_validator,
model_validator, model_validator,
@@ -126,9 +127,9 @@ class SupabaseSettings(BaseModel):
public_url: AnyHttpUrl public_url: AnyHttpUrl
anon_key: str = "CHANGE_ME" anon_key: str = "CHANGE_ME"
service_role_key: str = "CHANGE_ME" service_role_key: str = "CHANGE_ME"
jwt_audience: str = "authenticated" jwt_secret: SecretStr | None = Field(default=None, exclude=True)
jwt_algorithm: Literal["HS256"] = "HS256"
jwt_issuer: str | None = None jwt_issuer: str | None = None
jwks_url: str | None = None
@model_validator(mode="after") @model_validator(mode="after")
def compute_defaults(self) -> "SupabaseSettings": def compute_defaults(self) -> "SupabaseSettings":
@@ -136,9 +137,6 @@ class SupabaseSettings(BaseModel):
if self.jwt_issuer is None: if self.jwt_issuer is None:
self.jwt_issuer = f"{base}/auth/v1" self.jwt_issuer = f"{base}/auth/v1"
if self.jwks_url is None:
self.jwks_url = f"{self.jwt_issuer}/.well-known/jwks.json"
return self return self
@computed_field @computed_field
+137 -265
View File
@@ -1,8 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from typing import Any, cast
from uuid import uuid4 from uuid import uuid4
import jwt import jwt
@@ -10,300 +8,174 @@ import pytest
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from core.auth.jwt_verifier import ( from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
JwtVerifier,
TokenValidationError,
TokenVerifierUnavailableError,
)
def test_jwks_client_uses_supabase_auth_headers( def _build_hs256_token(
monkeypatch: pytest.MonkeyPatch, *,
) -> None: secret: str,
captured: dict[str, Any] = {} sub: str,
issuer: str | None = None,
class _FakePyJWKClient: audience: str | None = "authenticated",
def __init__( ) -> str:
self, now = datetime.now(UTC)
uri: str, payload = {
*, "sub": sub,
headers: dict[str, Any] | None = None, "exp": now + timedelta(minutes=5),
**_: Any,
) -> None:
captured["uri"] = uri
captured["headers"] = headers
monkeypatch.setattr("core.auth.jwt_verifier.jwt.PyJWKClient", _FakePyJWKClient)
JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer="https://example.supabase.co/auth/v1",
audience="authenticated",
apikey="anon-key-value",
)
assert (
captured["uri"] == "https://example.supabase.co/auth/v1/.well-known/jwks.json"
)
assert captured["headers"] == {
"apikey": "anon-key-value",
"Authorization": "Bearer anon-key-value",
} }
if audience is not None:
payload["aud"] = audience
if issuer is not None:
payload["iss"] = issuer
return jwt.encode(payload, secret, algorithm="HS256")
def _set_jwks_client(verifier: JwtVerifier, client: Any) -> None: def _build_rs256_token(
cast(Any, verifier)._jwks_client = client *, sub: str, issuer: str, audience: str = "authenticated"
) -> str:
now = datetime.now(UTC)
def _build_rsa_key_pair() -> tuple[str, str]: payload = {
"sub": sub,
"iss": issuer,
"aud": audience,
"exp": now + timedelta(minutes=5),
}
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private_pem = private_key.private_bytes( private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8, format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(), encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8") ).decode("utf-8")
public_pem = ( return jwt.encode(payload, private_pem, algorithm="RS256", headers={"kid": "kid-1"})
private_key.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
.decode("utf-8")
)
return private_pem, public_pem
def _build_token(*, private_key: str, sub: str, audience: str, issuer: str) -> str: def test_verify_hs256_token_success() -> None:
now = datetime.now(UTC)
payload = {
"sub": sub,
"aud": audience,
"iss": issuer,
"exp": now + timedelta(minutes=5),
}
return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": "kid-1"})
def _build_expired_token(
*, private_key: str, sub: str, audience: str, issuer: str
) -> str:
now = datetime.now(UTC)
payload = {
"sub": sub,
"aud": audience,
"iss": issuer,
"exp": now - timedelta(minutes=1),
}
return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": "kid-1"})
def _build_hs256_token(*, secret: str, sub: str, audience: str, issuer: str) -> str:
now = datetime.now(UTC)
payload = {
"sub": sub,
"aud": audience,
"iss": issuer,
"exp": now + timedelta(minutes=5),
}
return jwt.encode(payload, secret, algorithm="HS256", headers={"kid": "kid-1"})
def test_verify_token_with_jwks_success() -> None:
user_id = uuid4()
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
token = _build_token(
private_key=private_key,
sub=str(user_id),
audience=audience,
issuer=issuer,
)
verifier = JwtVerifier( verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json", issuer="https://example.supabase.co/auth/v1",
issuer=issuer, jwt_secret="test-secret",
audience=audience, jwt_algorithm="HS256",
apikey="anon-key",
) )
_set_jwks_client( token = _build_hs256_token(
verifier, secret="test-secret",
SimpleNamespace( sub=str(uuid4()),
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key) issuer="https://example.supabase.co/auth/v1",
),
) )
claims = verifier.verify(token) claims = verifier.verify(token)
assert claims["sub"] == str(user_id) assert "sub" in claims
def test_verify_token_rejects_invalid_issuer() -> None: def test_verify_rejects_invalid_issuer() -> None:
audience = "authenticated" verifier = JwtVerifier(
issuer = "https://example.supabase.co/auth/v1" issuer="https://example.supabase.co/auth/v1",
private_key, public_key = _build_rsa_key_pair() jwt_secret="test-secret",
token_with_wrong_iss = _build_token( jwt_algorithm="HS256",
private_key=private_key, )
token = _build_hs256_token(
secret="test-secret",
sub=str(uuid4()), sub=str(uuid4()),
audience=audience,
issuer="https://wrong-issuer.example.com/auth/v1", issuer="https://wrong-issuer.example.com/auth/v1",
) )
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
apikey="anon-key",
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError): with pytest.raises(TokenValidationError):
verifier.verify(token_with_wrong_iss) verifier.verify(token)
def test_verify_token_rejects_hs256_token() -> None: def test_verify_rejects_missing_audience() -> None:
audience = "authenticated" verifier = JwtVerifier(
issuer = "https://example.supabase.co/auth/v1" issuer="https://example.supabase.co/auth/v1",
_, public_key = _build_rsa_key_pair() jwt_secret="test-secret",
hs_token = _build_hs256_token( jwt_algorithm="HS256",
)
token = _build_hs256_token(
secret="test-secret", secret="test-secret",
sub=str(uuid4()), sub=str(uuid4()),
audience=audience, audience=None,
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
apikey="anon-key",
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
) )
with pytest.raises(TokenValidationError): with pytest.raises(TokenValidationError):
verifier.verify(hs_token)
def test_verify_token_rejects_expired_token() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
expired_token = _build_expired_token(
private_key=private_key,
sub=str(uuid4()),
audience=audience,
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
apikey="anon-key",
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError):
verifier.verify(expired_token)
def test_verify_token_rejects_invalid_audience() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
wrong_aud_token = _build_token(
private_key=private_key,
sub=str(uuid4()),
audience="anon",
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
apikey="anon-key",
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError):
verifier.verify(wrong_aud_token)
def test_verify_token_rejects_invalid_signature() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
valid_token = _build_token(
private_key=private_key,
sub=str(uuid4()),
audience=audience,
issuer=issuer,
)
tampered_token = f"{valid_token}x"
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
apikey="anon-key",
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError):
verifier.verify(tampered_token)
def test_verify_token_maps_jwks_connection_error() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, _ = _build_rsa_key_pair()
token = _build_token(
private_key=private_key,
sub=str(uuid4()),
audience=audience,
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
apikey="anon-key",
)
def _raise_connection_error(_: str) -> SimpleNamespace:
raise jwt.PyJWKClientConnectionError("network down")
_set_jwks_client(
verifier,
SimpleNamespace(get_signing_key_from_jwt=_raise_connection_error),
)
with pytest.raises(TokenVerifierUnavailableError):
verifier.verify(token) verifier.verify(token)
def test_verify_accepts_token_without_issuer_claim() -> None:
verifier = JwtVerifier(
issuer="https://example.supabase.co/auth/v1",
jwt_secret="test-secret",
jwt_algorithm="HS256",
)
token = _build_hs256_token(
secret="test-secret",
sub=str(uuid4()),
)
claims = verifier.verify(token)
assert "sub" in claims
def test_verify_accepts_list_audience() -> None:
verifier = JwtVerifier(
issuer="https://example.supabase.co/auth/v1",
jwt_secret="test-secret",
jwt_algorithm="HS256",
)
token = jwt.encode(
{
"sub": str(uuid4()),
"aud": ["anonymous", "authenticated"],
"exp": datetime.now(UTC) + timedelta(minutes=5),
},
"test-secret",
algorithm="HS256",
)
claims = verifier.verify(token)
assert claims["aud"] == ["anonymous", "authenticated"]
def test_verify_rejects_rs256_token() -> None:
verifier = JwtVerifier(
issuer="https://example.supabase.co/auth/v1",
jwt_secret="test-secret",
jwt_algorithm="HS256",
)
token = _build_rs256_token(
sub=str(uuid4()),
issuer="https://example.supabase.co/auth/v1",
)
with pytest.raises(TokenValidationError):
verifier.verify(token)
def test_verify_rejects_expired_token() -> None:
verifier = JwtVerifier(
issuer="https://example.supabase.co/auth/v1",
jwt_secret="test-secret",
jwt_algorithm="HS256",
)
now = datetime.now(UTC)
token = jwt.encode(
{
"sub": str(uuid4()),
"iss": "https://example.supabase.co/auth/v1",
"aud": "authenticated",
"exp": now - timedelta(minutes=1),
},
"test-secret",
algorithm="HS256",
)
with pytest.raises(TokenValidationError):
verifier.verify(token)
def test_verify_rejects_unsupported_algorithm_setting() -> None:
with pytest.raises(TokenValidationError):
JwtVerifier(
issuer="https://example.supabase.co/auth/v1",
jwt_secret="test-secret",
jwt_algorithm="RS256",
)
@@ -13,6 +13,8 @@ def test_social_prefixed_supabase_env_populates_settings(
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_URL", "https://public.example:8443") monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_URL", "https://public.example:8443")
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key") monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key") monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_ALGORITHM", "HS256")
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db") monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432") monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app") monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
@@ -24,6 +26,9 @@ def test_social_prefixed_supabase_env_populates_settings(
assert str(settings.supabase.public_url) == "https://public.example:8443/" assert str(settings.supabase.public_url) == "https://public.example:8443/"
assert settings.supabase.anon_key == "anon-key" assert settings.supabase.anon_key == "anon-key"
assert settings.supabase.service_role_key == "service-key" assert settings.supabase.service_role_key == "service-key"
assert settings.supabase.jwt_secret is not None
assert settings.supabase.jwt_secret.get_secret_value() == "jwt-secret"
assert settings.supabase.jwt_algorithm == "HS256"
supabase_settings = settings.model_dump()["supabase"] supabase_settings = settings.model_dump()["supabase"]
assert str(supabase_settings["public_url"]) == "https://public.example:8443/" assert str(supabase_settings["public_url"]) == "https://public.example:8443/"
@@ -42,17 +47,14 @@ def test_cloud_supabase_env_populates_settings(monkeypatch: MonkeyPatch) -> None
) )
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key") monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key") monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_AUDIENCE", "authenticated") monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_ALGORITHM", "HS256")
settings = Settings() settings = Settings()
assert str(settings.supabase.public_url) == "https://project.example.supabase.co/" assert str(settings.supabase.public_url) == "https://project.example.supabase.co/"
assert settings.supabase.jwt_audience == "authenticated" assert settings.supabase.jwt_algorithm == "HS256"
assert settings.supabase.jwt_issuer == "https://project.example.supabase.co/auth/v1" assert settings.supabase.jwt_issuer == "https://project.example.supabase.co/auth/v1"
assert (
settings.supabase.jwks_url
== "https://project.example.supabase.co/auth/v1/.well-known/jwks.json"
)
supabase_settings = settings.model_dump()["supabase"] supabase_settings = settings.model_dump()["supabase"]
assert "jwt_secret" not in supabase_settings assert "jwt_secret" not in supabase_settings
@@ -71,6 +73,8 @@ def test_public_url_with_trailing_slash_normalizes_correctly(
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_URL", "https://example.supabase.co/") monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_URL", "https://example.supabase.co/")
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key") monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key") monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_ALGORITHM", "HS256")
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db") monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432") monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app") monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
@@ -80,8 +84,4 @@ def test_public_url_with_trailing_slash_normalizes_correctly(
settings = Settings() settings = Settings()
assert settings.supabase.jwt_issuer == "https://example.supabase.co/auth/v1" assert settings.supabase.jwt_issuer == "https://example.supabase.co/auth/v1"
assert (
settings.supabase.jwks_url
== "https://example.supabase.co/auth/v1/.well-known/jwks.json"
)
assert settings.supabase.url == "https://example.supabase.co/" assert settings.supabase.url == "https://example.supabase.co/"