refactor: 迁移本地 Supabase 到云端,使用 JWKS 进行 JWT 验证

- 新增 JwtVerifier 支持 RS256 + JWKS 验证
- 简化 docker-compose,删除本地 Supabase 服务(kong/auth/storage等)
- 删除冗余的 Supabase 配置文件(volumes目录)
- 适配测试用例以支持新配置方式
- 更新运行时文档和迁移计划
This commit is contained in:
qzl
2026-03-09 18:03:04 +08:00
parent 3ac09475ad
commit 6fe2e7b6c3
24 changed files with 825 additions and 1403 deletions
+52
View File
@@ -0,0 +1,52 @@
from __future__ import annotations
from typing import Any, cast
import jwt
class TokenValidationError(Exception):
pass
class TokenVerifierUnavailableError(Exception):
pass
class JwtVerifier:
def __init__(self, jwks_url: str, issuer: str, audience: str) -> None:
self._issuer: str = issuer
self._audience: str = audience
self._jwks_client: jwt.PyJWKClient = jwt.PyJWKClient(jwks_url)
def verify(self, token: str) -> dict[str, Any]:
try:
key = self._jwks_client.get_signing_key_from_jwt(token)
except jwt.PyJWKClientConnectionError as exc:
raise TokenVerifierUnavailableError("Unable to fetch JWKS") from exc
except jwt.PyJWKClientError as exc:
raise TokenValidationError("Unable to resolve signing key") from exc
try:
payload = jwt.decode(
token,
key.key,
algorithms=["RS256"],
audience=self._audience,
issuer=self._issuer,
options={"require": ["sub", "aud", "iss", "exp"]},
)
except (
jwt.ExpiredSignatureError,
jwt.InvalidAudienceError,
jwt.InvalidIssuerError,
jwt.InvalidSignatureError,
jwt.DecodeError,
jwt.PyJWTError,
) as exc:
raise TokenValidationError("Token validation failed") from exc
if not isinstance(payload, dict):
raise TokenValidationError("Token payload must be a JSON object")
return cast(dict[str, Any], payload)
+30 -14
View File
@@ -4,7 +4,14 @@ from pathlib import Path
from typing import ClassVar, Literal
from urllib.parse import quote
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator
from pydantic import (
AnyHttpUrl,
BaseModel,
Field,
computed_field,
field_validator,
model_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -116,14 +123,14 @@ class RedisSettings(BaseModel):
class SupabaseSettings(BaseModel):
public_scheme: str = "http"
public_host: str = "localhost"
kong_http_port: int = 8000
site_url: str = "http://localhost:3000"
additional_redirect_urls: list[str] = Field(default_factory=list)
public_url: AnyHttpUrl
anon_key: str = "CHANGE_ME"
service_role_key: str = "CHANGE_ME"
jwt_secret: str | None = None
jwt_audience: str = "authenticated"
jwt_issuer: str | None = None
jwks_url: str | None = None
site_url: str | None = None
additional_redirect_urls: list[str] = Field(default_factory=list)
@field_validator("additional_redirect_urls", mode="before")
@classmethod
@@ -136,15 +143,24 @@ class SupabaseSettings(BaseModel):
return [str(item).strip() for item in value if str(item).strip()]
return []
@computed_field
@property
def public_url(self) -> str:
return f"{self.public_scheme}://{self.public_host}:{self.kong_http_port}"
@model_validator(mode="after")
def compute_defaults(self) -> "SupabaseSettings":
base = str(self.public_url).rstrip("/")
if self.jwt_issuer is None:
self.jwt_issuer = f"{base}/auth/v1"
if self.jwks_url is None:
self.jwks_url = f"{self.jwt_issuer}/.well-known/jwks.json"
if self.site_url is None:
self.site_url = "http://localhost:3000"
return self
@computed_field
@property
def url(self) -> str:
return self.public_url
return str(self.public_url)
class StorageSettings(BaseModel):
@@ -205,7 +221,7 @@ class Settings(BaseSettings):
runtime: RuntimeSettings = RuntimeSettings()
cors: CorsSettings = CorsSettings()
redis: RedisSettings = RedisSettings()
supabase: SupabaseSettings = SupabaseSettings()
supabase: SupabaseSettings = Field()
storage: StorageSettings = StorageSettings()
llm: LlmSettings = LlmSettings()
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
@@ -236,4 +252,4 @@ class Settings(BaseSettings):
)
config = Settings()
config = Settings() # type: ignore[reportCallIssue]
+28 -39
View File
@@ -3,10 +3,14 @@ from __future__ import annotations
from typing import Annotated
from uuid import UUID
import jwt
from fastapi import Depends, Header, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.jwt_verifier import (
JwtVerifier,
TokenValidationError,
TokenVerifierUnavailableError,
)
from core.auth.models import CurrentUser
from core.config.settings import config
from core.db import get_db
@@ -18,6 +22,7 @@ from v1.users.service import AuthLookupAdapter, UserService
logger = get_logger("v1.users.dependencies")
_auth_gateway: SupabaseAuthGateway | None = None
_jwt_verifier: JwtVerifier | None = None
def get_auth_gateway() -> SupabaseAuthGateway:
@@ -27,6 +32,19 @@ def get_auth_gateway() -> SupabaseAuthGateway:
return _auth_gateway
def get_jwt_verifier() -> JwtVerifier:
global _jwt_verifier
if _jwt_verifier is None:
jwks_url = config.supabase.jwks_url
issuer = config.supabase.jwt_issuer
audience = config.supabase.jwt_audience
if not jwks_url or not issuer or not audience:
logger.error("JWT validation failed: verifier config not configured")
raise HTTPException(status_code=503, detail="JWT verifier not configured")
_jwt_verifier = JwtVerifier(jwks_url=jwks_url, issuer=issuer, audience=audience)
return _jwt_verifier
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
if not authorization:
logger.warning("JWT validation failed: missing authorization header")
@@ -37,46 +55,17 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
logger.warning("JWT validation failed: invalid authorization scheme")
raise HTTPException(status_code=401, detail="Unauthorized")
secret = config.supabase.jwt_secret
if not secret:
logger.error("JWT validation failed: secret not configured")
raise HTTPException(status_code=503, detail="JWT secret not configured")
supabase_url = config.supabase.public_url.rstrip("/")
expected_issuer = f"{supabase_url}/auth/v1"
try:
payload = jwt.decode(
token,
secret,
algorithms=["HS256"],
audience="authenticated",
issuer=expected_issuer,
options={
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"require": ["sub", "aud", "iss", "exp"],
},
)
except jwt.ExpiredSignatureError:
logger.warning("JWT validation failed: token expired")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidAudienceError:
logger.warning("JWT validation failed: invalid audience")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidIssuerError:
logger.warning("JWT validation failed: invalid issuer")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidSignatureError:
logger.warning("JWT validation failed: invalid signature")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.DecodeError:
logger.warning("JWT validation failed: malformed token")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.PyJWTError as exc:
payload = get_jwt_verifier().verify(token)
except HTTPException:
raise
except TokenVerifierUnavailableError:
logger.error("JWT validation failed: verifier unavailable")
raise HTTPException(status_code=503, detail="JWT verifier unavailable")
except TokenValidationError as exc:
logger.warning(
"JWT validation failed: unknown error", error_type=type(exc).__name__
"JWT validation failed",
error_type=type(exc).__name__,
)
raise HTTPException(status_code=401, detail="Unauthorized") from exc
@@ -1,45 +1,43 @@
from __future__ import annotations
import os
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import httpx
import jwt
import pytest
from sqlalchemy import select
from core.config import config
from core.db.session import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.profile import Profile
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
async def _owner_id() -> UUID:
async with AsyncSessionLocal() as session:
owner_id = (await session.execute(select(Profile.id).limit(1))).scalar_one_or_none()
if owner_id is None:
pytest.skip("profile owner not found")
return owner_id
async def _live_access_token(client: httpx.AsyncClient) -> str:
email = os.getenv("AGENT_LIVE_EMAIL")
password = os.getenv("AGENT_LIVE_PASSWORD")
if not email or not password:
pytest.fail(
"AGENT_LIVE_INTEGRATION=1 requires AGENT_LIVE_EMAIL and AGENT_LIVE_PASSWORD"
)
response = await client.post(
f"{BASE_URL}/api/v1/auth/sessions",
json={"email": email, "password": password},
)
response_text = response.text.strip().replace("\n", " ")
truncated_text = response_text[:200]
if len(response_text) > 200:
truncated_text += "..."
def _jwt_for(user_id: UUID) -> str:
secret = config.supabase.jwt_secret
if not secret:
pytest.skip("JWT secret not configured")
issuer = f"{config.supabase.public_url.rstrip('/')}/auth/v1"
payload = {
"sub": str(user_id),
"role": "authenticated",
"aud": "authenticated",
"iss": issuer,
"iat": datetime.now(timezone.utc),
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
return jwt.encode(payload, secret, algorithm="HS256")
assert response.status_code == 200, (
f"live login failed: status={response.status_code}, response={truncated_text!r}"
)
token = response.json().get("access_token")
assert isinstance(token, str) and token
return token
@pytest.mark.asyncio
@@ -48,11 +46,10 @@ async def test_agent_sse_closed_loop_live() -> None:
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
owner_id = await _owner_id()
token = _jwt_for(owner_id)
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(timeout=30.0) as client:
token = await _live_access_token(client)
headers = {"Authorization": f"Bearer {token}"}
run_resp = await client.post(
f"{BASE_URL}/api/v1/agent/runs",
headers=headers,
@@ -76,9 +73,13 @@ async def test_agent_sse_closed_loop_live() -> None:
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
event_names: list[str] = []
async with client.stream("GET", events_url, headers=headers, timeout=20.0) as sse_resp:
async with client.stream(
"GET", events_url, headers=headers, timeout=20.0
) as sse_resp:
assert sse_resp.status_code == 200
assert sse_resp.headers.get("content-type", "").startswith("text/event-stream")
assert sse_resp.headers.get("content-type", "").startswith(
"text/event-stream"
)
async for line in sse_resp.aiter_lines():
if line.startswith("event:"):
event_names.append(line.split(":", 1)[1].strip())
@@ -94,6 +95,8 @@ async def test_agent_sse_closed_loop_live() -> None:
assert session_row.total_cost >= 0
rows = await session.execute(
select(AgentChatMessage).where(AgentChatMessage.session_id == UUID(thread_id))
select(AgentChatMessage).where(
AgentChatMessage.session_id == UUID(thread_id)
)
)
assert len(list(rows.scalars().all())) >= 1
@@ -0,0 +1,268 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from typing import Any, cast
from uuid import uuid4
import jwt
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from core.auth.jwt_verifier import (
JwtVerifier,
TokenValidationError,
TokenVerifierUnavailableError,
)
def _set_jwks_client(verifier: JwtVerifier, client: Any) -> None:
cast(Any, verifier)._jwks_client = client
def _build_rsa_key_pair() -> tuple[str, str]:
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")
public_pem = (
private_key.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
.decode("utf-8")
)
return private_pem, public_pem
def _build_token(*, private_key: str, sub: str, audience: str, issuer: str) -> str:
now = datetime.now(UTC)
payload = {
"sub": sub,
"aud": audience,
"iss": issuer,
"exp": now + timedelta(minutes=5),
}
return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": "kid-1"})
def _build_expired_token(
*, private_key: str, sub: str, audience: str, issuer: str
) -> str:
now = datetime.now(UTC)
payload = {
"sub": sub,
"aud": audience,
"iss": issuer,
"exp": now - timedelta(minutes=1),
}
return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": "kid-1"})
def _build_hs256_token(*, secret: str, sub: str, audience: str, issuer: str) -> str:
now = datetime.now(UTC)
payload = {
"sub": sub,
"aud": audience,
"iss": issuer,
"exp": now + timedelta(minutes=5),
}
return jwt.encode(payload, secret, algorithm="HS256", headers={"kid": "kid-1"})
def test_verify_token_with_jwks_success() -> None:
user_id = uuid4()
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
token = _build_token(
private_key=private_key,
sub=str(user_id),
audience=audience,
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
claims = verifier.verify(token)
assert claims["sub"] == str(user_id)
def test_verify_token_rejects_invalid_issuer() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
token_with_wrong_iss = _build_token(
private_key=private_key,
sub=str(uuid4()),
audience=audience,
issuer="https://wrong-issuer.example.com/auth/v1",
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError):
verifier.verify(token_with_wrong_iss)
def test_verify_token_rejects_hs256_token() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
_, public_key = _build_rsa_key_pair()
hs_token = _build_hs256_token(
secret="test-secret",
sub=str(uuid4()),
audience=audience,
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError):
verifier.verify(hs_token)
def test_verify_token_rejects_expired_token() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
expired_token = _build_expired_token(
private_key=private_key,
sub=str(uuid4()),
audience=audience,
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError):
verifier.verify(expired_token)
def test_verify_token_rejects_invalid_audience() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
wrong_aud_token = _build_token(
private_key=private_key,
sub=str(uuid4()),
audience="anon",
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError):
verifier.verify(wrong_aud_token)
def test_verify_token_rejects_invalid_signature() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, public_key = _build_rsa_key_pair()
valid_token = _build_token(
private_key=private_key,
sub=str(uuid4()),
audience=audience,
issuer=issuer,
)
tampered_token = f"{valid_token}x"
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
)
_set_jwks_client(
verifier,
SimpleNamespace(
get_signing_key_from_jwt=lambda _: SimpleNamespace(key=public_key)
),
)
with pytest.raises(TokenValidationError):
verifier.verify(tampered_token)
def test_verify_token_maps_jwks_connection_error() -> None:
audience = "authenticated"
issuer = "https://example.supabase.co/auth/v1"
private_key, _ = _build_rsa_key_pair()
token = _build_token(
private_key=private_key,
sub=str(uuid4()),
audience=audience,
issuer=issuer,
)
verifier = JwtVerifier(
jwks_url="https://example.supabase.co/auth/v1/.well-known/jwks.json",
issuer=issuer,
audience=audience,
)
def _raise_connection_error(_: str) -> SimpleNamespace:
raise jwt.PyJWKClientConnectionError("network down")
_set_jwks_client(
verifier,
SimpleNamespace(get_signing_key_from_jwt=_raise_connection_error),
)
with pytest.raises(TokenVerifierUnavailableError):
verifier.verify(token)
@@ -1,19 +1,18 @@
from __future__ import annotations
import pytest
from pydantic import ValidationError
from pytest import MonkeyPatch
from core.config.settings import Settings
from core.config.settings import Settings, SupabaseSettings
def test_social_prefixed_supabase_env_populates_settings(
monkeypatch: MonkeyPatch,
) -> None:
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_SCHEME", "https")
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_HOST", "public.example")
monkeypatch.setenv("SOCIAL_SUPABASE__KONG_HTTP_PORT", "8443")
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_URL", "https://public.example:8443")
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
monkeypatch.setenv("SOCIAL_SUPABASE__SITE_URL", "https://app.example.com")
monkeypatch.setenv(
"SOCIAL_SUPABASE__ADDITIONAL_REDIRECT_URLS",
@@ -27,10 +26,9 @@ def test_social_prefixed_supabase_env_populates_settings(
settings = Settings()
assert settings.supabase.public_url == "https://public.example:8443"
assert str(settings.supabase.public_url) == "https://public.example:8443/"
assert settings.supabase.anon_key == "anon-key"
assert settings.supabase.service_role_key == "service-key"
assert settings.supabase.jwt_secret == "jwt-secret"
assert settings.supabase.site_url == "https://app.example.com"
assert settings.supabase.additional_redirect_urls == [
"https://a.example.com",
@@ -38,9 +36,63 @@ def test_social_prefixed_supabase_env_populates_settings(
]
supabase_settings = settings.model_dump()["supabase"]
assert supabase_settings["public_url"] == "https://public.example:8443"
assert str(supabase_settings["public_url"]) == "https://public.example:8443/"
assert supabase_settings["anon_key"] == "anon-key"
assert supabase_settings["service_role_key"] == "service-key"
assert supabase_settings["jwt_secret"] == "jwt-secret"
assert supabase_settings["site_url"] == "https://app.example.com"
assert "jwt_secret" not in supabase_settings
assert "public_scheme" not in supabase_settings
assert "public_host" not in supabase_settings
assert "kong_http_port" not in supabase_settings
assert settings.database_url == "postgresql+asyncpg://user:pass@db:5432/app"
def test_cloud_supabase_env_populates_settings(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setenv(
"SOCIAL_SUPABASE__PUBLIC_URL", "https://project.example.supabase.co"
)
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_AUDIENCE", "authenticated")
settings = Settings()
assert str(settings.supabase.public_url) == "https://project.example.supabase.co/"
assert settings.supabase.jwt_audience == "authenticated"
assert settings.supabase.jwt_issuer == "https://project.example.supabase.co/auth/v1"
assert (
settings.supabase.jwks_url
== "https://project.example.supabase.co/auth/v1/.well-known/jwks.json"
)
supabase_settings = settings.model_dump()["supabase"]
assert "jwt_secret" not in supabase_settings
def test_missing_public_url_raises_validation_error() -> None:
with pytest.raises(ValidationError) as exc_info:
SupabaseSettings()
assert "public_url" in str(exc_info.value)
def test_public_url_with_trailing_slash_normalizes_correctly(
monkeypatch: MonkeyPatch,
) -> None:
monkeypatch.setenv("SOCIAL_SUPABASE__PUBLIC_URL", "https://example.supabase.co/")
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
monkeypatch.setenv("SOCIAL_DATABASE__USER", "user")
monkeypatch.setenv("SOCIAL_DATABASE__PASSWORD", "pass")
settings = Settings()
assert settings.supabase.jwt_issuer == "https://example.supabase.co/auth/v1"
assert (
settings.supabase.jwks_url
== "https://example.supabase.co/auth/v1/.well-known/jwks.json"
)
assert settings.supabase.url == "https://example.supabase.co/"