chore: checkpoint current backend/runtime changes

This commit is contained in:
qzl
2026-03-06 17:28:17 +08:00
parent 2c59fe5ee2
commit b6087fd195
32 changed files with 1641 additions and 469 deletions
+13 -12
View File
@@ -13,7 +13,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from core.config.settings import config
from core.http.response import build_problem_details
from core.logging import configure_logging, get_logger, log_service_banner
from services.base.redis import redis_service
from services.base import close_registered_services, initialize_registered_services
from v1.router import router as mobile_router
@@ -29,22 +29,23 @@ log_service_banner(
)
logger = get_logger("api.app")
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
initialized = await redis_service.initialize()
initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER)
if not initialized:
logger.error("Redis service failed to initialize, aborting startup")
raise RuntimeError("Redis service initialization failed")
logger.info(
"Redis service initialized",
host=config.redis.host,
db=config.redis.db,
)
yield
await redis_service.close()
logger.info("Redis service closed")
logger.error("Service initialization failed, aborting startup")
raise RuntimeError("Service initialization failed")
logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER)
try:
yield
finally:
closed = await close_registered_services(services)
if not closed:
logger.warning("Failed to close all base services")
logger.info("Base services closed", services=SERVICE_STARTUP_ORDER)
app = FastAPI(lifespan=lifespan)
@@ -1,15 +1,15 @@
from __future__ import annotations
from typing import Any, Protocol, cast
from typing import Any, Protocol
from uuid import UUID
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.celery.app import celery_app
from core.config.settings import config
from core.logging import get_logger
from services.base.redis import redis_service
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from services.base.redis import get_or_init_redis_client
logger = get_logger("core.agent.infrastructure.queue.tasks")
@@ -29,13 +29,12 @@ class ResumeServiceLike(Protocol):
async def _build_redis_publisher() -> PublishEvent:
settings = cast(Any, config)
client = redis_service.get_client()
client = await get_or_init_redis_client()
event_store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
read_count=settings.agent_runtime.redis_stream_read_count,
block_ms=settings.agent_runtime.redis_stream_block_ms,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
async def _publish(event_type: str, payload: dict[str, object]) -> None:
@@ -70,22 +69,27 @@ async def run_agent_task(
raise ValueError("session_id is required")
UUID(session_id)
tool_call_id = ""
user_input = ""
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
await publisher(start_event, {"session_id": session_id})
try:
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
result = await service_resume.resume(
session_id=session_id,
tool_call_id=tool_call_id,
)
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
result = await service_run.run(
session_id=session_id,
user_input=user_input,
@@ -125,6 +129,16 @@ async def run_agent_task(
raise
@celery_app.task(name="tasks.agent.run_command")
@default_broker.task(task_name="tasks.agent.run_command")
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
@critical_broker.task(task_name="tasks.agent.run_command.critical")
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
@bulk_broker.task(task_name="tasks.agent.run_command.bulk")
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
-73
View File
@@ -1,73 +0,0 @@
from __future__ import annotations
from celery import Celery
from celery import signals as celery_signals
from kombu import Queue
from core.config.settings import config
from core.logging import get_logger
from core.logging.celery import configure_celery_app
from services.base.redis import redis_service
logger = get_logger("core.celery")
def _init_redis_on_worker_startup(**_: object) -> None:
import asyncio
logger.info("Initializing Redis service for Celery worker")
try:
result = asyncio.run(redis_service.initialize())
if result:
logger.info("Redis service initialized for Celery worker")
else:
logger.warning("Redis service initialization returned False")
except Exception as exc: # noqa: BLE001
logger.error("Failed to initialize Redis for Celery worker", error=str(exc))
def create_celery_app() -> Celery:
"""Create and configure the Celery application."""
celery_settings = config.celery
app = Celery(
"social_app",
broker=config.celery_broker_url,
backend=config.celery_result_backend,
include=["core.agent.infrastructure.queue.tasks"],
)
app.conf.update(
task_serializer=celery_settings.task_serializer,
result_serializer=celery_settings.result_serializer,
accept_content=celery_settings.accept_content,
timezone=celery_settings.timezone,
enable_utc=celery_settings.enable_utc,
task_track_started=celery_settings.task_track_started,
task_time_limit=celery_settings.task_time_limit,
task_soft_time_limit=celery_settings.task_soft_time_limit,
task_default_retry_delay=celery_settings.task_default_retry_delay,
task_default_queue="default",
task_create_missing_queues=False,
task_queues=(
Queue("default"),
Queue("critical"),
Queue("bulk"),
),
task_routes={
"tasks.critical.*": {"queue": "critical"},
"tasks.bulk.*": {"queue": "bulk"},
},
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
)
configure_celery_app(app, settings=config)
return app
celery_app = create_celery_app()
celery_signals.worker_process_init.connect(_init_redis_on_worker_startup)
+7 -17
View File
@@ -63,19 +63,9 @@ class RuntimeSettings(BaseModel):
return self
class CelerySettings(BaseModel):
class TaskiqSettings(BaseModel):
broker_url: str | None = None
result_backend: str | None = None
task_serializer: str = "json"
result_serializer: str = "json"
accept_content: list[str] = Field(default_factory=lambda: ["json"])
timezone: str = "UTC"
enable_utc: bool = True
task_track_started: bool = True
task_time_limit: int = 300
task_soft_time_limit: int = 240
task_default_retry_delay: int = 30
task_max_retries: int = 3
result_backend_url: str | None = None
class CorsSettings(BaseModel):
@@ -189,7 +179,7 @@ class Settings(BaseSettings):
storage: StorageSettings = StorageSettings()
llm: LlmSettings = LlmSettings()
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
celery: CelerySettings = CelerySettings()
taskiq: TaskiqSettings = TaskiqSettings()
database: DatabaseSettings = DatabaseSettings()
@computed_field
@@ -199,13 +189,13 @@ class Settings(BaseSettings):
@computed_field
@property
def celery_broker_url(self) -> str:
return self.celery.broker_url or self.redis.url
def taskiq_broker_url(self) -> str:
return self.taskiq.broker_url or self.redis.url
@computed_field
@property
def celery_result_backend(self) -> str:
return self.celery.result_backend or self.redis.url
def taskiq_result_backend_url(self) -> str:
return self.taskiq.result_backend_url or self.redis.url
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_file=_resolve_env_file(),
-2
View File
@@ -1,6 +1,5 @@
from __future__ import annotations
from core.logging import celery
from core.logging.banner import log_service_banner
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context, get_context
@@ -8,7 +7,6 @@ from core.logging.logger import get_logger
__all__ = [
"bind_context",
"celery",
"clear_context",
"configure_logging",
"get_context",
-64
View File
@@ -1,64 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import cast
from celery import Celery, signals
from core.config.settings import Settings
from core.logging.banner import log_service_banner
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context
@dataclass(frozen=True)
class CelerySignalHandlers:
on_setup_logging: Callable[..., None]
on_after_setup_task_logger: Callable[..., None]
on_task_prerun: Callable[..., None]
on_task_postrun: Callable[..., None]
def build_celery_signal_handlers(
settings: Settings | None = None,
) -> CelerySignalHandlers:
active_settings = settings or Settings()
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
log_service_banner(
service_name=active_settings.runtime.service_name,
environment=active_settings.runtime.environment,
)
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_task_prerun(*_args: object, **kwargs: object) -> None:
task_id = cast(str | None, kwargs.get("task_id"))
task = kwargs.get("task")
task_name = getattr(task, "name", None)
bind_context(task_id=task_id, task_name=task_name)
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
clear_context()
return CelerySignalHandlers(
on_setup_logging=on_setup_logging,
on_after_setup_task_logger=on_after_setup_task_logger,
on_task_prerun=on_task_prerun,
on_task_postrun=on_task_postrun,
)
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
app.conf.worker_hijack_root_logger = False
handlers = build_celery_signal_handlers(settings)
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
signals.after_setup_task_logger.connect(
handlers.on_after_setup_task_logger, weak=False
)
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
+3
View File
@@ -0,0 +1,3 @@
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
from core.config.settings import config
from core.logging import configure_logging
configure_logging(config)
def _build_broker(queue_name: str) -> ListQueueBroker:
return ListQueueBroker(
url=config.taskiq_broker_url,
queue_name=queue_name,
).with_result_backend(
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
)
default_broker = _build_broker("default")
critical_broker = _build_broker("critical")
bulk_broker = _build_broker("bulk")
# Backward-compatible export name for existing imports/tests.
broker = default_broker
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
+11 -1
View File
@@ -1,18 +1,28 @@
from __future__ import annotations
from services.base.redis import RedisService, redis_service
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
close_registered_services,
initialize_registered_services,
register_service,
register_service_instance,
resolve_registered_services,
)
from services.base.supabase import SupabaseService, supabase_service
__all__ = [
"BaseServiceProvider",
"RedisService",
"ServiceRegistry",
"SupabaseService",
"close_registered_services",
"get_or_init_redis_client",
"initialize_registered_services",
"redis_service",
"register_service",
"register_service_instance",
"resolve_registered_services",
"supabase_service",
]
+9 -1
View File
@@ -92,6 +92,14 @@ class RedisService(BaseServiceProvider):
return self._require_client()
async def get_or_init_redis_client() -> redis.Redis:
if not redis_service.is_initialized:
initialized = await redis_service.initialize()
if not initialized:
raise RuntimeError("Redis service initialization failed")
return redis_service.get_client()
redis_service: RedisService = register_service_instance("redis", RedisService())
__all__ = ["RedisService", "redis_service"]
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
@@ -61,6 +61,12 @@ class ServiceRegistry:
@classmethod
def create_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
return cls.get_service(service_name, **kwargs)
@classmethod
def get_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
factory = cls.get_service_factory(service_name)
if not factory:
@@ -82,3 +88,71 @@ TService = TypeVar("TService", bound=BaseServiceProvider)
def register_service_instance(service_name: str, service: TService) -> TService:
ServiceRegistry.register(service_name, lambda: service)
return service
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
services: list[BaseServiceProvider] = []
for service_name in service_names:
service = ServiceRegistry.get_service(service_name)
if service is None:
raise RuntimeError(f"Service is not registered: {service_name}")
services.append(service)
return services
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
lifecycle_logger = get_logger("services.base.lifecycle")
all_closed = True
for service in reversed(services):
try:
closed = await service.close()
except Exception as exc: # noqa: BLE001
lifecycle_logger.warning(
"Failed to close service",
service=service.service_name,
error=str(exc),
)
all_closed = False
continue
if not closed:
lifecycle_logger.warning(
"Service close returned false",
service=service.service_name,
)
all_closed = False
return all_closed
async def initialize_registered_services(
service_names: list[str],
) -> tuple[bool, list[BaseServiceProvider]]:
lifecycle_logger = get_logger("services.base.lifecycle")
initialized_services: list[BaseServiceProvider] = []
try:
services = resolve_registered_services(service_names)
except RuntimeError as exc:
lifecycle_logger.error("Failed to resolve registered services", error=str(exc))
return False, []
for service in services:
try:
initialized = await service.initialize()
except Exception as exc: # noqa: BLE001
lifecycle_logger.warning(
"Service initialization raised exception",
service=service.service_name,
error=str(exc),
)
initialized = False
if not initialized:
lifecycle_logger.error(
"Service initialization failed, rolling back",
service=service.service_name,
)
await close_registered_services(initialized_services)
return False, []
initialized_services.append(service)
return True, initialized_services
+89
View File
@@ -0,0 +1,89 @@
from __future__ import annotations
import asyncio
from typing import Any
from supabase import create_client
from core.config.settings import SupabaseSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class SupabaseService(BaseServiceProvider):
def __init__(self, settings: SupabaseSettings | None = None) -> None:
super().__init__("supabase")
self._settings = settings or config.supabase
self._client: Any = None
self._admin_client: Any = None
async def initialize(self, **_: Any) -> bool:
try:
self._client = create_client(
self._settings.url,
self._settings.anon_key,
)
self._admin_client = create_client(
self._settings.url,
self._settings.service_role_key,
)
self._set_initialized(True)
self.logger.info("Supabase service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase service initialization failed", error=str(exc))
self._client = None
self._admin_client = None
self._set_initialized(False)
return False
async def close(self) -> bool:
self._client = None
self._admin_client = None
self._set_initialized(False)
self.logger.info("Supabase service closed")
return True
async def health_check(self) -> dict[str, Any]:
client = self._client
admin_client = self._admin_client
if client is None or admin_client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
await asyncio.to_thread(client.auth.get_session)
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
return {
"status": "healthy",
"details": {
"anon_client": "ready",
"admin_client": "ready",
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> Any:
return self._require_client()
def get_admin_client(self) -> Any:
return self._require_admin_client()
def _require_client(self) -> Any:
client = self._client
if client is None:
raise RuntimeError("Supabase client is not initialized")
return client
def _require_admin_client(self) -> Any:
admin_client = self._admin_client
if admin_client is None:
raise RuntimeError("Supabase admin client is not initialized")
return admin_client
supabase_service: SupabaseService = register_service_instance(
"supabase", SupabaseService()
)
__all__ = ["SupabaseService", "supabase_service"]
+69 -27
View File
@@ -1,57 +1,98 @@
from __future__ import annotations
from typing import Any, cast
import asyncio
from typing import Any
from uuid import UUID
from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.queue.tasks import run_command_task
from core.agent.infrastructure.queue.tasks import (
run_command_task,
run_command_task_bulk,
run_command_task_critical,
)
from core.config.settings import config
from core.db import get_db
from services.base.redis import redis_service
from services.base.redis import get_or_init_redis_client
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
DEDUP_WAIT_RETRIES = 20
DEDUP_WAIT_SECONDS = 0.05
DEDUP_LOCK_SECONDS = 300
DEDUP_INFLIGHT_MARKER = "__inflight__"
class CeleryQueueClient:
class TaskiqQueueClient:
def __init__(self) -> None:
self._redis = redis_service.get_client()
self._redis: Redis | None = None
async def _get_redis(self) -> Redis:
if self._redis is None:
self._redis = await get_or_init_redis_client()
return self._redis
@staticmethod
def _select_queue_task(command: dict[str, object]) -> Any:
queue = str(command.get("queue", "default")).strip().lower()
if queue == "critical":
return run_command_task_critical
if queue == "bulk":
return run_command_task_bulk
return run_command_task
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
redis_client = await self._get_redis()
redis_key = None
if dedup_key:
redis_key = f"agent:dedup:{dedup_key}"
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
locked = await redis_client.set(
redis_key,
DEDUP_INFLIGHT_MARKER,
nx=True,
ex=DEDUP_LOCK_SECONDS,
)
if not locked:
existing = await self._redis.get(redis_key)
if existing and existing != "__inflight__":
return existing
for _ in range(DEDUP_WAIT_RETRIES):
existing = await redis_client.get(redis_key)
if existing and existing != DEDUP_INFLIGHT_MARKER:
return existing
await asyncio.sleep(DEDUP_WAIT_SECONDS)
raise RuntimeError("duplicate request is still in progress")
payload = dict(command)
if dedup_key:
payload["dedup_key"] = dedup_key
delay = getattr(run_command_task, "delay")
result = delay(payload)
task_id = str(result.id)
if redis_key is not None:
await self._redis.set(redis_key, task_id, ex=300)
return task_id
queue_task = self._select_queue_task(payload)
try:
result = await queue_task.kiq(payload)
task_id = str(result.task_id)
if redis_key is not None:
await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
return task_id
except Exception:
if redis_key is not None:
await redis_client.delete(redis_key)
raise
class RedisEventStream:
def __init__(self) -> None:
settings = cast(Any, config)
client = redis_service.get_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
read_count=settings.agent_runtime.redis_stream_read_count,
block_ms=settings.agent_runtime.redis_stream_block_ms,
)
self._store: RedisStreamEventStore | None = None
async def _get_store(self) -> RedisStreamEventStore:
if self._store is None:
client = await get_or_init_redis_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
return self._store
async def read(
self,
@@ -59,7 +100,8 @@ class RedisEventStream:
session_id: str,
last_event_id: str | None,
) -> list[dict[str, Any]]:
rows = await self._store.read_events(
store = await self._get_store()
rows = await store.read_events(
session_id=UUID(session_id),
last_event_id=last_event_id,
)
@@ -69,6 +111,6 @@ class RedisEventStream:
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
return AgentService(
repository=AgentRepository(session),
queue=CeleryQueueClient(),
queue=TaskiqQueueClient(),
stream=RedisEventStream(),
)
+28 -20
View File
@@ -5,10 +5,10 @@ from collections.abc import Mapping
from typing import Any, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from supabase import AuthError
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
@@ -28,17 +28,16 @@ logger = get_logger("v1.auth.gateway")
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
_admin_client: Any
def _get_client(self) -> Any:
return supabase_service.get_client()
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
self._admin_client = create_client(settings.url, settings.service_role_key)
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
client = self._get_client()
metadata: dict[str, Any] = {"username": request.username}
if request.invite_code:
metadata["invite_code"] = request.invite_code
@@ -50,7 +49,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
if request.redirect_to:
payload["options"] = {"email_redirect_to": request.redirect_to}
try:
sign_up = cast(Any, self._client.auth.sign_up)
sign_up = cast(Any, client.auth.sign_up)
await asyncio.to_thread(sign_up, payload)
return VerificationCreateResponse(email=request.email)
except AuthError as exc:
@@ -62,13 +61,14 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def verify_verification(
self, request: VerificationVerifyRequest
) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {
"type": "signup",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(response, "Invalid verification code")
except AuthError as exc:
@@ -78,17 +78,19 @@ class SupabaseAuthGateway(AuthServiceGateway):
) from exc
async def resend_verification(self, request: VerificationResendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {"type": "signup", "email": request.email}
try:
resend = cast(Any, self._client.auth.resend)
resend = cast(Any, client.auth.resend)
await asyncio.to_thread(resend, payload)
except AuthError as exc:
logger.warning("Signup resend failed", error_type=type(exc).__name__)
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
sign_in = cast(Any, client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
@@ -96,9 +98,10 @@ class SupabaseAuthGateway(AuthServiceGateway):
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
@@ -111,20 +114,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def delete_session(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
client = self._get_client()
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
await asyncio.to_thread(client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
raise HTTPException(
@@ -132,7 +136,8 @@ class SupabaseAuthGateway(AuthServiceGateway):
) from exc
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
normalized_email = email.lower()
user = next(
(
@@ -157,8 +162,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
client = self._get_client()
try:
reset_email = cast(Any, self._client.auth.reset_password_email)
reset_email = cast(Any, client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {"redirect_to": request.redirect_to}
@@ -174,13 +180,15 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
client = self._get_client()
admin_client = self._get_admin_client()
verify_payload: dict[str, Any] = {
"type": "recovery",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
@@ -190,7 +198,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
status_code=401, detail="Invalid or expired verification code"
)
await asyncio.to_thread(
self._admin_client.auth.admin.update_user_by_id,
admin_client.auth.admin.update_user_by_id,
user_id,
{"password": request.new_password},
)
@@ -2,7 +2,7 @@ from __future__ import annotations
import pytest
from core.agent.infrastructure.queue.tasks import run_agent_task
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
class _FakeRunService:
@@ -67,3 +67,35 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
)
assert events == ["RUN_STARTED", "RUN_ERROR"]
@pytest.mark.asyncio
async def test_run_agent_task_rejects_invalid_command() -> None:
with pytest.raises(ValueError, match="invalid command type"):
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
@pytest.mark.asyncio
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
with pytest.raises(ValueError, match="tool_call_id is required"):
await run_agent_task(
{
"command": "resume",
"session_id": "00000000-0000-0000-0000-000000000001",
}
)
@pytest.mark.asyncio
async def test_build_redis_publisher_init_fail_raises_runtime_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from core.agent.infrastructure.queue import tasks
async def _fake_get_client() -> object:
raise RuntimeError("Redis service initialization failed")
monkeypatch.setattr(tasks, "get_or_init_redis_client", _fake_get_client)
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
await _build_redis_publisher()
@@ -0,0 +1,10 @@
from __future__ import annotations
from core.config.settings import Settings
def test_taskiq_uses_redis_url_by_default() -> None:
settings = Settings()
assert settings.taskiq_broker_url.startswith("redis://")
assert settings.taskiq_result_backend_url.startswith("redis://")
@@ -0,0 +1,37 @@
from __future__ import annotations
import importlib
import sys
import pytest
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
def test_taskiq_broker_is_configured() -> None:
assert broker is not None
assert default_broker is broker
assert critical_broker is not None
assert bulk_broker is not None
def test_taskiq_app_configures_logging_on_import(
monkeypatch: pytest.MonkeyPatch,
) -> None:
sys.modules.pop("core.taskiq.app", None)
sys.modules.pop("core.taskiq", None)
called = {"count": 0, "args": None}
def _fake_configure_logging(*args: object, **__: object) -> None:
called["count"] += 1
called["args"] = args
monkeypatch.setattr("core.logging.configure_logging", _fake_configure_logging)
importlib.import_module("core.taskiq.app")
from core.config.settings import config
assert called["count"] == 1
assert called["args"] == (config,)
@@ -0,0 +1,20 @@
from __future__ import annotations
from pathlib import Path
ROOT_DIR = Path(__file__).resolve().parents[4]
APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
def test_worker_commands_use_taskiq() -> None:
content = APP_SCRIPT.read_text(encoding="utf-8")
removed_runner = "uv run c" "elery"
assert "uv run taskiq worker" in content
assert "core.taskiq.app:critical_broker" in content
assert "core.taskiq.app:default_broker" in content
assert "core.taskiq.app:bulk_broker" in content
assert 'pgrep -f "taskiq.*worker"' in content
assert 'pkill -f "taskiq.*worker"' in content
assert removed_runner not in content
@@ -3,7 +3,7 @@ from __future__ import annotations
import pytest
from core.config.settings import RedisSettings
from services.base.redis import RedisService
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
class _FakeRedisClient:
@@ -96,3 +96,35 @@ def test_get_client_raises_before_init() -> None:
with pytest.raises(RuntimeError):
service.get_client()
@pytest.mark.asyncio
async def test_get_or_init_redis_client_initializes_when_needed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_client = _FakeRedisClient()
async def _fake_initialize() -> bool:
return True
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
monkeypatch.setattr(redis_service, "get_client", lambda: fake_client)
client = await get_or_init_redis_client()
assert client is fake_client
@pytest.mark.asyncio
async def test_get_or_init_redis_client_raises_when_init_fails(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _fake_initialize() -> bool:
return False
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
await get_or_init_redis_client()
@@ -1,8 +1,12 @@
from __future__ import annotations
import pytest
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
close_registered_services,
initialize_registered_services,
register_service,
register_service_instance,
)
@@ -35,6 +39,17 @@ def test_register_service_and_create_service() -> None:
assert created.get_service_info()["name"] == "dummy"
def test_register_service_and_get_service() -> None:
@register_service("dummy-service-get")
class _RegisteredService(_DummyService):
pass
resolved = ServiceRegistry.get_service("dummy-service-get")
assert resolved is not None
assert resolved.get_service_info()["name"] == "dummy"
def test_register_service_instance_returns_same_instance() -> None:
instance = _DummyService("singleton")
@@ -47,3 +62,77 @@ def test_register_service_instance_returns_same_instance() -> None:
def test_create_service_returns_none_for_missing() -> None:
assert ServiceRegistry.create_service("missing-service") is None
def test_get_service_returns_none_for_missing() -> None:
assert ServiceRegistry.get_service("missing-service") is None
class _LifecycleService(BaseServiceProvider):
def __init__(self, name: str, recorder: list[str], fail_on_init: bool = False) -> None:
super().__init__(name)
self._recorder = recorder
self._fail_on_init = fail_on_init
async def initialize(self, **_: object) -> bool:
self._recorder.append(f"init:{self.service_name}")
if self._fail_on_init:
return False
self._set_initialized(True)
return True
async def close(self) -> bool:
self._recorder.append(f"close:{self.service_name}")
self._set_initialized(False)
return True
async def health_check(self) -> dict[str, object]:
return {"status": "healthy", "details": {}}
@pytest.mark.asyncio
async def test_initialize_registered_services_success() -> None:
recorder: list[str] = []
first = register_service_instance(
"lifecycle-success-first", _LifecycleService("first", recorder)
)
second = register_service_instance(
"lifecycle-success-second", _LifecycleService("second", recorder)
)
initialized, services = await initialize_registered_services(
["lifecycle-success-first", "lifecycle-success-second"]
)
assert initialized is True
assert services == [first, second]
assert recorder == ["init:first", "init:second"]
@pytest.mark.asyncio
async def test_initialize_registered_services_failure_rolls_back() -> None:
recorder: list[str] = []
register_service_instance("lifecycle-fail-first", _LifecycleService("first", recorder))
register_service_instance(
"lifecycle-fail-second", _LifecycleService("second", recorder, fail_on_init=True)
)
initialized, services = await initialize_registered_services(
["lifecycle-fail-first", "lifecycle-fail-second"]
)
assert initialized is False
assert services == []
assert recorder == ["init:first", "init:second", "close:first"]
@pytest.mark.asyncio
async def test_close_registered_services_closes_in_reverse_order() -> None:
recorder: list[str] = []
first = _LifecycleService("first", recorder)
second = _LifecycleService("second", recorder)
closed = await close_registered_services([first, second])
assert closed is True
assert recorder == ["close:second", "close:first"]
@@ -0,0 +1,111 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.config.settings import SupabaseSettings
from services.base.supabase import SupabaseService
@pytest.mark.asyncio
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
anon_client = MagicMock()
admin_client = MagicMock()
create_calls: list[tuple[str, str]] = []
def _fake_create_client(url: str, key: str) -> object:
create_calls.append((url, key))
return anon_client if len(create_calls) == 1 else admin_client
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
result = await service.initialize()
assert result is True
assert service.is_initialized is True
assert service.get_client() is anon_client
assert service.get_admin_client() is admin_client
assert len(create_calls) == 2
@pytest.mark.asyncio
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
def _fake_create_client(_: str, __: str) -> object:
raise RuntimeError("boom")
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
result = await service.initialize()
assert result is False
assert service.is_initialized is False
with pytest.raises(RuntimeError):
service.get_client()
@pytest.mark.asyncio
async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
def _fake_create_client(_: str, __: str) -> object:
return MagicMock()
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
assert await service.initialize() is True
assert await service.close() is True
assert service.is_initialized is False
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
@pytest.mark.asyncio
async def test_health_check_uninitialized() -> None:
service = SupabaseService(settings=SupabaseSettings())
health = await service.health_check()
assert health["status"] == "unhealthy"
@pytest.mark.asyncio
async def test_health_check_initialized(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
anon_client = MagicMock()
anon_client.auth.get_session = MagicMock(return_value=None)
admin_list_users = MagicMock(return_value=SimpleNamespace(users=[]))
admin_client = MagicMock()
admin_client.auth.admin = SimpleNamespace(list_users=admin_list_users)
create_sequence = [anon_client, admin_client]
def _fake_create_client(_: str, __: str) -> object:
return create_sequence.pop(0)
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
assert await service.initialize() is True
health = await service.health_check()
assert health["status"] == "healthy"
admin_list_users.assert_called_once_with(page=1, per_page=1)
def test_get_client_raises_before_init() -> None:
service = SupabaseService(settings=SupabaseSettings())
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
import pytest
import app as app_module
@pytest.mark.asyncio
async def test_lifespan_uses_registered_services(monkeypatch: pytest.MonkeyPatch) -> None:
initialized_services = [object(), object()]
calls: dict[str, object] = {}
async def _fake_initialize(service_names: list[str]) -> tuple[bool, list[object]]:
calls["init_names"] = service_names
return True, initialized_services
async def _fake_close(services: list[object]) -> bool:
calls["close_services"] = services
return True
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
monkeypatch.setattr(app_module, "close_registered_services", _fake_close)
context = app_module.lifespan(app_module.app)
await context.__aenter__()
await context.__aexit__(None, None, None)
assert calls["init_names"] == ["redis", "supabase"]
assert calls["close_services"] == initialized_services
@pytest.mark.asyncio
async def test_lifespan_raises_when_initialization_failed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _fake_initialize(_: list[str]) -> tuple[bool, list[object]]:
return False, []
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
context = app_module.lifespan(app_module.app)
with pytest.raises(RuntimeError, match="Service initialization failed"):
await context.__aenter__()
-45
View File
@@ -1,45 +0,0 @@
from __future__ import annotations
from celery import Celery
from pytest import MonkeyPatch
from core.logging import celery as celery_logging
from core.logging.context import clear_context, get_context
class DummyTask:
name: str = "tasks.sample"
def test_celery_prerun_binds_task_context() -> None:
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
context = get_context()
assert context["task_id"] == "task-123"
assert context["task_name"] == "tasks.sample"
clear_context()
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
called = {"value": False}
def fake_configure_logging(settings: object | None = None) -> None:
called["value"] = True
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_setup_logging()
assert called["value"] is True
def test_configure_celery_app_disables_hijack() -> None:
app = Celery("test")
celery_logging.configure_celery_app(app)
assert app.conf.worker_hijack_root_logger is False
@@ -0,0 +1,227 @@
from __future__ import annotations
import pytest
from v1.agent.dependencies import TaskiqQueueClient
class _FakeRedis:
def __init__(self) -> None:
self.store: dict[str, str] = {}
self.delete_calls: list[str] = []
async def set(
self,
key: str,
value: str,
*,
nx: bool = False,
ex: int | None = None,
) -> bool:
del ex
if nx and key in self.store:
return False
self.store[key] = value
return True
async def get(self, key: str) -> str | None:
return self.store.get(key)
async def delete(self, key: str) -> int:
self.delete_calls.append(key)
existed = 1 if key in self.store else 0
self.store.pop(key, None)
return existed
class _FakeAsyncResult:
def __init__(self, task_id: str) -> None:
self.task_id = task_id
@pytest.mark.asyncio
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
assert payload["command"] == "run"
return _FakeAsyncResult("task-123")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
assert resolved_client["value"] is True
assert task_id == "task-123"
@pytest.mark.asyncio
async def test_enqueue_resume_dedup_returns_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
return _FakeAsyncResult("new-task-id")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
dedup_key = "resume:session-1:call-1"
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert resolved_client["value"] is True
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
attempts = {"count": 0}
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_get(key: str) -> str | None:
attempts["count"] += 1
if attempts["count"] > 1:
fake_redis.store[key] = "existing-task-id"
return fake_redis.store.get(key)
async def _fake_sleep(_: float) -> None:
return None
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise AssertionError("should not enqueue when dedup task id appears")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(fake_redis, "get", _fake_get)
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise RuntimeError("enqueue failed")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
with pytest.raises(RuntimeError, match="enqueue failed"):
await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert redis_key in fake_redis.delete_calls
@pytest.mark.asyncio
async def test_enqueue_uses_critical_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("critical-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "critical"},
dedup_key=None,
)
assert task_id == "critical-task-id"
@pytest.mark.asyncio
async def test_enqueue_uses_bulk_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("bulk-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "bulk"},
dedup_key=None,
)
assert task_id == "bulk-task-id"
+40 -28
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
@@ -12,37 +12,44 @@ from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
class TestSupabaseAuthGateway:
@pytest.fixture
def gateway(self) -> SupabaseAuthGateway:
with patch("v1.auth.gateway.create_client") as mock_create:
mock_client = MagicMock()
mock_admin_client = MagicMock()
mock_create.side_effect = [mock_client, mock_admin_client]
return SupabaseAuthGateway()
def gateway(
self, monkeypatch: pytest.MonkeyPatch
) -> tuple[SupabaseAuthGateway, MagicMock, MagicMock]:
mock_client = MagicMock()
mock_admin_client = MagicMock()
monkeypatch.setattr("v1.auth.gateway.supabase_service.get_client", lambda: mock_client)
monkeypatch.setattr(
"v1.auth.gateway.supabase_service.get_admin_client",
lambda: mock_admin_client,
)
return SupabaseAuthGateway(), mock_client, mock_admin_client
@pytest.mark.asyncio
async def test_request_password_reset_calls_email_with_string(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_with_redirect(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with(
"test@example.com",
@@ -51,64 +58,68 @@ class TestSupabaseAuthGateway:
@pytest.mark.asyncio
async def test_request_password_reset_swallows_auth_error(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
result = await gateway.request_password_reset(request)
result = await sut.request_password_reset(request)
mock_reset_email.assert_called_once()
assert result is None
@pytest.mark.asyncio
async def test_request_password_reset_extracts_email_from_mapping(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest.model_construct(
email={"email": "test@example.com"},
redirect_to=None,
)
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_rejects_invalid_email_shape(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, _, _ = gateway
request = PasswordResetRequest.model_construct(
email={"unexpected": "value"},
redirect_to=None,
)
with pytest.raises(HTTPException) as exc_info:
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid email"
@pytest.mark.asyncio
async def test_confirm_password_reset_updates_password_by_user_id(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, mock_admin_client = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id="user-1"),
)
mock_verify_otp = MagicMock(return_value=verify_response)
gateway._client.auth.verify_otp = mock_verify_otp
mock_client.auth.verify_otp = mock_verify_otp
mock_update_user_by_id = MagicMock()
gateway._admin_client.auth.admin = SimpleNamespace(
mock_admin_client.auth.admin = SimpleNamespace(
update_user_by_id=mock_update_user_by_id
)
@@ -118,7 +129,7 @@ class TestSupabaseAuthGateway:
new_password="newpassword123",
)
await gateway.confirm_password_reset(request)
await sut.confirm_password_reset(request)
mock_verify_otp.assert_called_once_with(
{
@@ -134,13 +145,14 @@ class TestSupabaseAuthGateway:
@pytest.mark.asyncio
async def test_confirm_password_reset_raises_when_user_id_missing(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id=""),
)
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
request = PasswordResetConfirmRequest(
email="test@example.com",
@@ -149,7 +161,7 @@ class TestSupabaseAuthGateway:
)
with pytest.raises(HTTPException) as exc_info:
await gateway.confirm_password_reset(request)
await sut.confirm_password_reset(request)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid or expired verification code"