chore: checkpoint current backend/runtime changes
This commit is contained in:
+13
-12
@@ -13,7 +13,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from core.config.settings import config
|
||||
from core.http.response import build_problem_details
|
||||
from core.logging import configure_logging, get_logger, log_service_banner
|
||||
from services.base.redis import redis_service
|
||||
from services.base import close_registered_services, initialize_registered_services
|
||||
from v1.router import router as mobile_router
|
||||
|
||||
|
||||
@@ -29,22 +29,23 @@ log_service_banner(
|
||||
)
|
||||
|
||||
logger = get_logger("api.app")
|
||||
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
initialized = await redis_service.initialize()
|
||||
initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER)
|
||||
if not initialized:
|
||||
logger.error("Redis service failed to initialize, aborting startup")
|
||||
raise RuntimeError("Redis service initialization failed")
|
||||
logger.info(
|
||||
"Redis service initialized",
|
||||
host=config.redis.host,
|
||||
db=config.redis.db,
|
||||
)
|
||||
yield
|
||||
await redis_service.close()
|
||||
logger.info("Redis service closed")
|
||||
logger.error("Service initialization failed, aborting startup")
|
||||
raise RuntimeError("Service initialization failed")
|
||||
logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
closed = await close_registered_services(services)
|
||||
if not closed:
|
||||
logger.warning("Failed to close all base services")
|
||||
logger.info("Base services closed", services=SERVICE_STARTUP_ORDER)
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, cast
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.celery.app import celery_app
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import redis_service
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
logger = get_logger("core.agent.infrastructure.queue.tasks")
|
||||
|
||||
@@ -29,13 +29,12 @@ class ResumeServiceLike(Protocol):
|
||||
|
||||
|
||||
async def _build_redis_publisher() -> PublishEvent:
|
||||
settings = cast(Any, config)
|
||||
client = redis_service.get_client()
|
||||
client = await get_or_init_redis_client()
|
||||
event_store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
read_count=settings.agent_runtime.redis_stream_read_count,
|
||||
block_ms=settings.agent_runtime.redis_stream_block_ms,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
read_count=config.agent_runtime.redis_stream_read_count,
|
||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
@@ -70,22 +69,27 @@ async def run_agent_task(
|
||||
raise ValueError("session_id is required")
|
||||
UUID(session_id)
|
||||
|
||||
tool_call_id = ""
|
||||
user_input = ""
|
||||
if command_type == "resume":
|
||||
tool_call_id = str(command.get("tool_call_id", ""))
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required")
|
||||
else:
|
||||
user_input = str(command.get("user_input", ""))
|
||||
if not user_input:
|
||||
raise ValueError("user_input is required")
|
||||
|
||||
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
|
||||
await publisher(start_event, {"session_id": session_id})
|
||||
|
||||
try:
|
||||
if command_type == "resume":
|
||||
tool_call_id = str(command.get("tool_call_id", ""))
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required")
|
||||
result = await service_resume.resume(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
else:
|
||||
user_input = str(command.get("user_input", ""))
|
||||
if not user_input:
|
||||
raise ValueError("user_input is required")
|
||||
result = await service_run.run(
|
||||
session_id=session_id,
|
||||
user_input=user_input,
|
||||
@@ -125,6 +129,16 @@ async def run_agent_task(
|
||||
raise
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.agent.run_command")
|
||||
@default_broker.task(task_name="tasks.agent.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
|
||||
|
||||
@critical_broker.task(task_name="tasks.agent.run_command.critical")
|
||||
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
|
||||
|
||||
@bulk_broker.task(task_name="tasks.agent.run_command.bulk")
|
||||
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals as celery_signals
|
||||
from kombu import Queue
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from core.logging.celery import configure_celery_app
|
||||
from services.base.redis import redis_service
|
||||
|
||||
logger = get_logger("core.celery")
|
||||
|
||||
|
||||
def _init_redis_on_worker_startup(**_: object) -> None:
|
||||
import asyncio
|
||||
|
||||
logger.info("Initializing Redis service for Celery worker")
|
||||
try:
|
||||
result = asyncio.run(redis_service.initialize())
|
||||
if result:
|
||||
logger.info("Redis service initialized for Celery worker")
|
||||
else:
|
||||
logger.warning("Redis service initialization returned False")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to initialize Redis for Celery worker", error=str(exc))
|
||||
|
||||
|
||||
def create_celery_app() -> Celery:
|
||||
"""Create and configure the Celery application."""
|
||||
celery_settings = config.celery
|
||||
|
||||
app = Celery(
|
||||
"social_app",
|
||||
broker=config.celery_broker_url,
|
||||
backend=config.celery_result_backend,
|
||||
include=["core.agent.infrastructure.queue.tasks"],
|
||||
)
|
||||
|
||||
app.conf.update(
|
||||
task_serializer=celery_settings.task_serializer,
|
||||
result_serializer=celery_settings.result_serializer,
|
||||
accept_content=celery_settings.accept_content,
|
||||
timezone=celery_settings.timezone,
|
||||
enable_utc=celery_settings.enable_utc,
|
||||
task_track_started=celery_settings.task_track_started,
|
||||
task_time_limit=celery_settings.task_time_limit,
|
||||
task_soft_time_limit=celery_settings.task_soft_time_limit,
|
||||
task_default_retry_delay=celery_settings.task_default_retry_delay,
|
||||
task_default_queue="default",
|
||||
task_create_missing_queues=False,
|
||||
task_queues=(
|
||||
Queue("default"),
|
||||
Queue("critical"),
|
||||
Queue("bulk"),
|
||||
),
|
||||
task_routes={
|
||||
"tasks.critical.*": {"queue": "critical"},
|
||||
"tasks.bulk.*": {"queue": "bulk"},
|
||||
},
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_prefetch_multiplier=1,
|
||||
)
|
||||
|
||||
configure_celery_app(app, settings=config)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
celery_app = create_celery_app()
|
||||
|
||||
celery_signals.worker_process_init.connect(_init_redis_on_worker_startup)
|
||||
@@ -63,19 +63,9 @@ class RuntimeSettings(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class CelerySettings(BaseModel):
|
||||
class TaskiqSettings(BaseModel):
|
||||
broker_url: str | None = None
|
||||
result_backend: str | None = None
|
||||
task_serializer: str = "json"
|
||||
result_serializer: str = "json"
|
||||
accept_content: list[str] = Field(default_factory=lambda: ["json"])
|
||||
timezone: str = "UTC"
|
||||
enable_utc: bool = True
|
||||
task_track_started: bool = True
|
||||
task_time_limit: int = 300
|
||||
task_soft_time_limit: int = 240
|
||||
task_default_retry_delay: int = 30
|
||||
task_max_retries: int = 3
|
||||
result_backend_url: str | None = None
|
||||
|
||||
|
||||
class CorsSettings(BaseModel):
|
||||
@@ -189,7 +179,7 @@ class Settings(BaseSettings):
|
||||
storage: StorageSettings = StorageSettings()
|
||||
llm: LlmSettings = LlmSettings()
|
||||
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
|
||||
celery: CelerySettings = CelerySettings()
|
||||
taskiq: TaskiqSettings = TaskiqSettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
@computed_field
|
||||
@@ -199,13 +189,13 @@ class Settings(BaseSettings):
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def celery_broker_url(self) -> str:
|
||||
return self.celery.broker_url or self.redis.url
|
||||
def taskiq_broker_url(self) -> str:
|
||||
return self.taskiq.broker_url or self.redis.url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def celery_result_backend(self) -> str:
|
||||
return self.celery.result_backend or self.redis.url
|
||||
def taskiq_result_backend_url(self) -> str:
|
||||
return self.taskiq.result_backend_url or self.redis.url
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_file=_resolve_env_file(),
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.logging import celery
|
||||
from core.logging.banner import log_service_banner
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context, get_context
|
||||
@@ -8,7 +7,6 @@ from core.logging.logger import get_logger
|
||||
|
||||
__all__ = [
|
||||
"bind_context",
|
||||
"celery",
|
||||
"clear_context",
|
||||
"configure_logging",
|
||||
"get_context",
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery, signals
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.banner import log_service_banner
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CelerySignalHandlers:
|
||||
on_setup_logging: Callable[..., None]
|
||||
on_after_setup_task_logger: Callable[..., None]
|
||||
on_task_prerun: Callable[..., None]
|
||||
on_task_postrun: Callable[..., None]
|
||||
|
||||
|
||||
def build_celery_signal_handlers(
|
||||
settings: Settings | None = None,
|
||||
) -> CelerySignalHandlers:
|
||||
active_settings = settings or Settings()
|
||||
|
||||
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
log_service_banner(
|
||||
service_name=active_settings.runtime.service_name,
|
||||
environment=active_settings.runtime.environment,
|
||||
)
|
||||
|
||||
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
|
||||
def on_task_prerun(*_args: object, **kwargs: object) -> None:
|
||||
task_id = cast(str | None, kwargs.get("task_id"))
|
||||
task = kwargs.get("task")
|
||||
task_name = getattr(task, "name", None)
|
||||
bind_context(task_id=task_id, task_name=task_name)
|
||||
|
||||
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
|
||||
clear_context()
|
||||
|
||||
return CelerySignalHandlers(
|
||||
on_setup_logging=on_setup_logging,
|
||||
on_after_setup_task_logger=on_after_setup_task_logger,
|
||||
on_task_prerun=on_task_prerun,
|
||||
on_task_postrun=on_task_postrun,
|
||||
)
|
||||
|
||||
|
||||
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
|
||||
app.conf.worker_hijack_root_logger = False
|
||||
|
||||
handlers = build_celery_signal_handlers(settings)
|
||||
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
|
||||
signals.after_setup_task_logger.connect(
|
||||
handlers.on_after_setup_task_logger, weak=False
|
||||
)
|
||||
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
|
||||
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
|
||||
@@ -0,0 +1,3 @@
|
||||
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
|
||||
|
||||
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import configure_logging
|
||||
|
||||
|
||||
configure_logging(config)
|
||||
|
||||
def _build_broker(queue_name: str) -> ListQueueBroker:
|
||||
return ListQueueBroker(
|
||||
url=config.taskiq_broker_url,
|
||||
queue_name=queue_name,
|
||||
).with_result_backend(
|
||||
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
|
||||
)
|
||||
|
||||
|
||||
default_broker = _build_broker("default")
|
||||
critical_broker = _build_broker("critical")
|
||||
bulk_broker = _build_broker("bulk")
|
||||
|
||||
# Backward-compatible export name for existing imports/tests.
|
||||
broker = default_broker
|
||||
|
||||
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
|
||||
@@ -1,18 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.redis import RedisService, redis_service
|
||||
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
close_registered_services,
|
||||
initialize_registered_services,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
resolve_registered_services,
|
||||
)
|
||||
from services.base.supabase import SupabaseService, supabase_service
|
||||
|
||||
__all__ = [
|
||||
"BaseServiceProvider",
|
||||
"RedisService",
|
||||
"ServiceRegistry",
|
||||
"SupabaseService",
|
||||
"close_registered_services",
|
||||
"get_or_init_redis_client",
|
||||
"initialize_registered_services",
|
||||
"redis_service",
|
||||
"register_service",
|
||||
"register_service_instance",
|
||||
"resolve_registered_services",
|
||||
"supabase_service",
|
||||
]
|
||||
|
||||
@@ -92,6 +92,14 @@ class RedisService(BaseServiceProvider):
|
||||
return self._require_client()
|
||||
|
||||
|
||||
async def get_or_init_redis_client() -> redis.Redis:
|
||||
if not redis_service.is_initialized:
|
||||
initialized = await redis_service.initialize()
|
||||
if not initialized:
|
||||
raise RuntimeError("Redis service initialization failed")
|
||||
return redis_service.get_client()
|
||||
|
||||
|
||||
redis_service: RedisService = register_service_instance("redis", RedisService())
|
||||
|
||||
__all__ = ["RedisService", "redis_service"]
|
||||
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
|
||||
|
||||
@@ -61,6 +61,12 @@ class ServiceRegistry:
|
||||
@classmethod
|
||||
def create_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
return cls.get_service(service_name, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
factory = cls.get_service_factory(service_name)
|
||||
if not factory:
|
||||
@@ -82,3 +88,71 @@ TService = TypeVar("TService", bound=BaseServiceProvider)
|
||||
def register_service_instance(service_name: str, service: TService) -> TService:
|
||||
ServiceRegistry.register(service_name, lambda: service)
|
||||
return service
|
||||
|
||||
|
||||
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
|
||||
services: list[BaseServiceProvider] = []
|
||||
for service_name in service_names:
|
||||
service = ServiceRegistry.get_service(service_name)
|
||||
if service is None:
|
||||
raise RuntimeError(f"Service is not registered: {service_name}")
|
||||
services.append(service)
|
||||
return services
|
||||
|
||||
|
||||
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
|
||||
lifecycle_logger = get_logger("services.base.lifecycle")
|
||||
all_closed = True
|
||||
for service in reversed(services):
|
||||
try:
|
||||
closed = await service.close()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
lifecycle_logger.warning(
|
||||
"Failed to close service",
|
||||
service=service.service_name,
|
||||
error=str(exc),
|
||||
)
|
||||
all_closed = False
|
||||
continue
|
||||
if not closed:
|
||||
lifecycle_logger.warning(
|
||||
"Service close returned false",
|
||||
service=service.service_name,
|
||||
)
|
||||
all_closed = False
|
||||
return all_closed
|
||||
|
||||
|
||||
async def initialize_registered_services(
|
||||
service_names: list[str],
|
||||
) -> tuple[bool, list[BaseServiceProvider]]:
|
||||
lifecycle_logger = get_logger("services.base.lifecycle")
|
||||
initialized_services: list[BaseServiceProvider] = []
|
||||
try:
|
||||
services = resolve_registered_services(service_names)
|
||||
except RuntimeError as exc:
|
||||
lifecycle_logger.error("Failed to resolve registered services", error=str(exc))
|
||||
return False, []
|
||||
|
||||
for service in services:
|
||||
try:
|
||||
initialized = await service.initialize()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
lifecycle_logger.warning(
|
||||
"Service initialization raised exception",
|
||||
service=service.service_name,
|
||||
error=str(exc),
|
||||
)
|
||||
initialized = False
|
||||
|
||||
if not initialized:
|
||||
lifecycle_logger.error(
|
||||
"Service initialization failed, rolling back",
|
||||
service=service.service_name,
|
||||
)
|
||||
await close_registered_services(initialized_services)
|
||||
return False, []
|
||||
|
||||
initialized_services.append(service)
|
||||
|
||||
return True, initialized_services
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from supabase import create_client
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class SupabaseService(BaseServiceProvider):
|
||||
def __init__(self, settings: SupabaseSettings | None = None) -> None:
|
||||
super().__init__("supabase")
|
||||
self._settings = settings or config.supabase
|
||||
self._client: Any = None
|
||||
self._admin_client: Any = None
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
self._client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.anon_key,
|
||||
)
|
||||
self._admin_client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.service_role_key,
|
||||
)
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service initialized")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Supabase service initialization failed", error=str(exc))
|
||||
self._client = None
|
||||
self._admin_client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def close(self) -> bool:
|
||||
self._client = None
|
||||
self._admin_client = None
|
||||
self._set_initialized(False)
|
||||
self.logger.info("Supabase service closed")
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
client = self._client
|
||||
admin_client = self._admin_client
|
||||
if client is None or admin_client is None:
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
await asyncio.to_thread(client.auth.get_session)
|
||||
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
|
||||
return {
|
||||
"status": "healthy",
|
||||
"details": {
|
||||
"anon_client": "ready",
|
||||
"admin_client": "ready",
|
||||
},
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Supabase health check failed", error=str(exc))
|
||||
return {"status": "unhealthy", "details": {"error": str(exc)}}
|
||||
|
||||
def get_client(self) -> Any:
|
||||
return self._require_client()
|
||||
|
||||
def get_admin_client(self) -> Any:
|
||||
return self._require_admin_client()
|
||||
|
||||
def _require_client(self) -> Any:
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Supabase client is not initialized")
|
||||
return client
|
||||
|
||||
def _require_admin_client(self) -> Any:
|
||||
admin_client = self._admin_client
|
||||
if admin_client is None:
|
||||
raise RuntimeError("Supabase admin client is not initialized")
|
||||
return admin_client
|
||||
|
||||
|
||||
supabase_service: SupabaseService = register_service_instance(
|
||||
"supabase", SupabaseService()
|
||||
)
|
||||
|
||||
__all__ = ["SupabaseService", "supabase_service"]
|
||||
@@ -1,57 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.agent.infrastructure.queue.tasks import run_command_task
|
||||
from core.agent.infrastructure.queue.tasks import (
|
||||
run_command_task,
|
||||
run_command_task_bulk,
|
||||
run_command_task_critical,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from services.base.redis import redis_service
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
DEDUP_WAIT_RETRIES = 20
|
||||
DEDUP_WAIT_SECONDS = 0.05
|
||||
DEDUP_LOCK_SECONDS = 300
|
||||
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
||||
|
||||
class CeleryQueueClient:
|
||||
|
||||
class TaskiqQueueClient:
|
||||
def __init__(self) -> None:
|
||||
self._redis = redis_service.get_client()
|
||||
self._redis: Redis | None = None
|
||||
|
||||
async def _get_redis(self) -> Redis:
|
||||
if self._redis is None:
|
||||
self._redis = await get_or_init_redis_client()
|
||||
return self._redis
|
||||
|
||||
@staticmethod
|
||||
def _select_queue_task(command: dict[str, object]) -> Any:
|
||||
queue = str(command.get("queue", "default")).strip().lower()
|
||||
if queue == "critical":
|
||||
return run_command_task_critical
|
||||
if queue == "bulk":
|
||||
return run_command_task_bulk
|
||||
return run_command_task
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
redis_client = await self._get_redis()
|
||||
redis_key = None
|
||||
if dedup_key:
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
|
||||
locked = await redis_client.set(
|
||||
redis_key,
|
||||
DEDUP_INFLIGHT_MARKER,
|
||||
nx=True,
|
||||
ex=DEDUP_LOCK_SECONDS,
|
||||
)
|
||||
if not locked:
|
||||
existing = await self._redis.get(redis_key)
|
||||
if existing and existing != "__inflight__":
|
||||
return existing
|
||||
for _ in range(DEDUP_WAIT_RETRIES):
|
||||
existing = await redis_client.get(redis_key)
|
||||
if existing and existing != DEDUP_INFLIGHT_MARKER:
|
||||
return existing
|
||||
await asyncio.sleep(DEDUP_WAIT_SECONDS)
|
||||
raise RuntimeError("duplicate request is still in progress")
|
||||
|
||||
payload = dict(command)
|
||||
if dedup_key:
|
||||
payload["dedup_key"] = dedup_key
|
||||
delay = getattr(run_command_task, "delay")
|
||||
result = delay(payload)
|
||||
task_id = str(result.id)
|
||||
if redis_key is not None:
|
||||
await self._redis.set(redis_key, task_id, ex=300)
|
||||
return task_id
|
||||
queue_task = self._select_queue_task(payload)
|
||||
try:
|
||||
result = await queue_task.kiq(payload)
|
||||
task_id = str(result.task_id)
|
||||
if redis_key is not None:
|
||||
await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
|
||||
return task_id
|
||||
except Exception:
|
||||
if redis_key is not None:
|
||||
await redis_client.delete(redis_key)
|
||||
raise
|
||||
|
||||
|
||||
class RedisEventStream:
|
||||
def __init__(self) -> None:
|
||||
settings = cast(Any, config)
|
||||
client = redis_service.get_client()
|
||||
self._store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
read_count=settings.agent_runtime.redis_stream_read_count,
|
||||
block_ms=settings.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
self._store: RedisStreamEventStore | None = None
|
||||
|
||||
async def _get_store(self) -> RedisStreamEventStore:
|
||||
if self._store is None:
|
||||
client = await get_or_init_redis_client()
|
||||
self._store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
read_count=config.agent_runtime.redis_stream_read_count,
|
||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
return self._store
|
||||
|
||||
async def read(
|
||||
self,
|
||||
@@ -59,7 +100,8 @@ class RedisEventStream:
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
rows = await self._store.read_events(
|
||||
store = await self._get_store()
|
||||
rows = await store.read_events(
|
||||
session_id=UUID(session_id),
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
@@ -69,6 +111,6 @@ class RedisEventStream:
|
||||
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
|
||||
return AgentService(
|
||||
repository=AgentRepository(session),
|
||||
queue=CeleryQueueClient(),
|
||||
queue=TaskiqQueueClient(),
|
||||
stream=RedisEventStream(),
|
||||
)
|
||||
|
||||
@@ -5,10 +5,10 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError, create_client
|
||||
from supabase import AuthError
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
@@ -28,17 +28,16 @@ logger = get_logger("v1.auth.gateway")
|
||||
|
||||
|
||||
class SupabaseAuthGateway(AuthServiceGateway):
|
||||
_client: Any
|
||||
_admin_client: Any
|
||||
def _get_client(self) -> Any:
|
||||
return supabase_service.get_client()
|
||||
|
||||
def __init__(self) -> None:
|
||||
settings: SupabaseSettings = config.supabase
|
||||
self._client = create_client(settings.url, settings.anon_key)
|
||||
self._admin_client = create_client(settings.url, settings.service_role_key)
|
||||
def _get_admin_client(self) -> Any:
|
||||
return supabase_service.get_admin_client()
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
client = self._get_client()
|
||||
metadata: dict[str, Any] = {"username": request.username}
|
||||
if request.invite_code:
|
||||
metadata["invite_code"] = request.invite_code
|
||||
@@ -50,7 +49,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
if request.redirect_to:
|
||||
payload["options"] = {"email_redirect_to": request.redirect_to}
|
||||
try:
|
||||
sign_up = cast(Any, self._client.auth.sign_up)
|
||||
sign_up = cast(Any, client.auth.sign_up)
|
||||
await asyncio.to_thread(sign_up, payload)
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
except AuthError as exc:
|
||||
@@ -62,13 +61,14 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": "signup",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, self._client.auth.verify_otp)
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, payload)
|
||||
return _map_auth_response(response, "Invalid verification code")
|
||||
except AuthError as exc:
|
||||
@@ -78,17 +78,19 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
) from exc
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"type": "signup", "email": request.email}
|
||||
try:
|
||||
resend = cast(Any, self._client.auth.resend)
|
||||
resend = cast(Any, client.auth.resend)
|
||||
await asyncio.to_thread(resend, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup resend failed", error_type=type(exc).__name__)
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"email": request.email, "password": request.password}
|
||||
try:
|
||||
sign_in = cast(Any, self._client.auth.sign_in_with_password)
|
||||
sign_in = cast(Any, client.auth.sign_in_with_password)
|
||||
response = await asyncio.to_thread(sign_in, payload)
|
||||
return _map_auth_response(response, "Invalid credentials")
|
||||
except AuthError as exc:
|
||||
@@ -96,9 +98,10 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
client.auth.refresh_session,
|
||||
request.refresh_token,
|
||||
)
|
||||
return _map_auth_response(response, "Invalid refresh token")
|
||||
@@ -111,20 +114,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=401, detail="Missing refresh token")
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
client.auth.refresh_session,
|
||||
refresh_token,
|
||||
)
|
||||
session = getattr(response, "session", None)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
await asyncio.to_thread(
|
||||
self._client.auth.set_session,
|
||||
client.auth.set_session,
|
||||
str(session.access_token),
|
||||
str(session.refresh_token),
|
||||
)
|
||||
await asyncio.to_thread(self._client.auth.sign_out)
|
||||
await asyncio.to_thread(client.auth.sign_out)
|
||||
except AuthError as exc:
|
||||
logger.warning("Logout failed", error_type=type(exc).__name__)
|
||||
raise HTTPException(
|
||||
@@ -132,7 +136,8 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
) from exc
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
|
||||
admin_client = self._get_admin_client()
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
normalized_email = email.lower()
|
||||
user = next(
|
||||
(
|
||||
@@ -157,8 +162,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
reset_email = cast(Any, self._client.auth.reset_password_email)
|
||||
reset_email = cast(Any, client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {"redirect_to": request.redirect_to}
|
||||
@@ -174,13 +180,15 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
client = self._get_client()
|
||||
admin_client = self._get_admin_client()
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "recovery",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, self._client.auth.verify_otp)
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
@@ -190,7 +198,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self._admin_client.auth.admin.update_user_by_id,
|
||||
admin_client.auth.admin.update_user_by_id,
|
||||
user_id,
|
||||
{"password": request.new_password},
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
|
||||
|
||||
|
||||
class _FakeRunService:
|
||||
@@ -67,3 +67,35 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
)
|
||||
|
||||
assert events == ["RUN_STARTED", "RUN_ERROR"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_rejects_invalid_command() -> None:
|
||||
with pytest.raises(ValueError, match="invalid command type"):
|
||||
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
|
||||
with pytest.raises(ValueError, match="tool_call_id is required"):
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_redis_publisher_init_fail_raises_runtime_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from core.agent.infrastructure.queue import tasks
|
||||
|
||||
async def _fake_get_client() -> object:
|
||||
raise RuntimeError("Redis service initialization failed")
|
||||
|
||||
monkeypatch.setattr(tasks, "get_or_init_redis_client", _fake_get_client)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
|
||||
await _build_redis_publisher()
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_taskiq_uses_redis_url_by_default() -> None:
|
||||
settings = Settings()
|
||||
|
||||
assert settings.taskiq_broker_url.startswith("redis://")
|
||||
assert settings.taskiq_result_backend_url.startswith("redis://")
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
|
||||
|
||||
|
||||
def test_taskiq_broker_is_configured() -> None:
|
||||
assert broker is not None
|
||||
assert default_broker is broker
|
||||
assert critical_broker is not None
|
||||
assert bulk_broker is not None
|
||||
|
||||
|
||||
def test_taskiq_app_configures_logging_on_import(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
sys.modules.pop("core.taskiq.app", None)
|
||||
sys.modules.pop("core.taskiq", None)
|
||||
|
||||
called = {"count": 0, "args": None}
|
||||
|
||||
def _fake_configure_logging(*args: object, **__: object) -> None:
|
||||
called["count"] += 1
|
||||
called["args"] = args
|
||||
|
||||
monkeypatch.setattr("core.logging.configure_logging", _fake_configure_logging)
|
||||
|
||||
importlib.import_module("core.taskiq.app")
|
||||
|
||||
from core.config.settings import config
|
||||
|
||||
assert called["count"] == 1
|
||||
assert called["args"] == (config,)
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[4]
|
||||
APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
|
||||
|
||||
|
||||
def test_worker_commands_use_taskiq() -> None:
|
||||
content = APP_SCRIPT.read_text(encoding="utf-8")
|
||||
removed_runner = "uv run c" "elery"
|
||||
|
||||
assert "uv run taskiq worker" in content
|
||||
assert "core.taskiq.app:critical_broker" in content
|
||||
assert "core.taskiq.app:default_broker" in content
|
||||
assert "core.taskiq.app:bulk_broker" in content
|
||||
assert 'pgrep -f "taskiq.*worker"' in content
|
||||
assert 'pkill -f "taskiq.*worker"' in content
|
||||
assert removed_runner not in content
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
from core.config.settings import RedisSettings
|
||||
from services.base.redis import RedisService
|
||||
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
|
||||
|
||||
|
||||
class _FakeRedisClient:
|
||||
@@ -96,3 +96,35 @@ def test_get_client_raises_before_init() -> None:
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_init_redis_client_initializes_when_needed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_client = _FakeRedisClient()
|
||||
|
||||
async def _fake_initialize() -> bool:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
|
||||
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
|
||||
monkeypatch.setattr(redis_service, "get_client", lambda: fake_client)
|
||||
|
||||
client = await get_or_init_redis_client()
|
||||
|
||||
assert client is fake_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_init_redis_client_raises_when_init_fails(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _fake_initialize() -> bool:
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
|
||||
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
|
||||
await get_or_init_redis_client()
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
close_registered_services,
|
||||
initialize_registered_services,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
)
|
||||
@@ -35,6 +39,17 @@ def test_register_service_and_create_service() -> None:
|
||||
assert created.get_service_info()["name"] == "dummy"
|
||||
|
||||
|
||||
def test_register_service_and_get_service() -> None:
|
||||
@register_service("dummy-service-get")
|
||||
class _RegisteredService(_DummyService):
|
||||
pass
|
||||
|
||||
resolved = ServiceRegistry.get_service("dummy-service-get")
|
||||
|
||||
assert resolved is not None
|
||||
assert resolved.get_service_info()["name"] == "dummy"
|
||||
|
||||
|
||||
def test_register_service_instance_returns_same_instance() -> None:
|
||||
instance = _DummyService("singleton")
|
||||
|
||||
@@ -47,3 +62,77 @@ def test_register_service_instance_returns_same_instance() -> None:
|
||||
|
||||
def test_create_service_returns_none_for_missing() -> None:
|
||||
assert ServiceRegistry.create_service("missing-service") is None
|
||||
|
||||
|
||||
def test_get_service_returns_none_for_missing() -> None:
|
||||
assert ServiceRegistry.get_service("missing-service") is None
|
||||
|
||||
|
||||
class _LifecycleService(BaseServiceProvider):
|
||||
def __init__(self, name: str, recorder: list[str], fail_on_init: bool = False) -> None:
|
||||
super().__init__(name)
|
||||
self._recorder = recorder
|
||||
self._fail_on_init = fail_on_init
|
||||
|
||||
async def initialize(self, **_: object) -> bool:
|
||||
self._recorder.append(f"init:{self.service_name}")
|
||||
if self._fail_on_init:
|
||||
return False
|
||||
self._set_initialized(True)
|
||||
return True
|
||||
|
||||
async def close(self) -> bool:
|
||||
self._recorder.append(f"close:{self.service_name}")
|
||||
self._set_initialized(False)
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, object]:
|
||||
return {"status": "healthy", "details": {}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_registered_services_success() -> None:
|
||||
recorder: list[str] = []
|
||||
first = register_service_instance(
|
||||
"lifecycle-success-first", _LifecycleService("first", recorder)
|
||||
)
|
||||
second = register_service_instance(
|
||||
"lifecycle-success-second", _LifecycleService("second", recorder)
|
||||
)
|
||||
|
||||
initialized, services = await initialize_registered_services(
|
||||
["lifecycle-success-first", "lifecycle-success-second"]
|
||||
)
|
||||
|
||||
assert initialized is True
|
||||
assert services == [first, second]
|
||||
assert recorder == ["init:first", "init:second"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_registered_services_failure_rolls_back() -> None:
|
||||
recorder: list[str] = []
|
||||
register_service_instance("lifecycle-fail-first", _LifecycleService("first", recorder))
|
||||
register_service_instance(
|
||||
"lifecycle-fail-second", _LifecycleService("second", recorder, fail_on_init=True)
|
||||
)
|
||||
|
||||
initialized, services = await initialize_registered_services(
|
||||
["lifecycle-fail-first", "lifecycle-fail-second"]
|
||||
)
|
||||
|
||||
assert initialized is False
|
||||
assert services == []
|
||||
assert recorder == ["init:first", "init:second", "close:first"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_registered_services_closes_in_reverse_order() -> None:
|
||||
recorder: list[str] = []
|
||||
first = _LifecycleService("first", recorder)
|
||||
second = _LifecycleService("second", recorder)
|
||||
|
||||
closed = await close_registered_services([first, second])
|
||||
|
||||
assert closed is True
|
||||
assert recorder == ["close:second", "close:first"]
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import SupabaseSettings
|
||||
from services.base.supabase import SupabaseService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
anon_client = MagicMock()
|
||||
admin_client = MagicMock()
|
||||
|
||||
create_calls: list[tuple[str, str]] = []
|
||||
|
||||
def _fake_create_client(url: str, key: str) -> object:
|
||||
create_calls.append((url, key))
|
||||
return anon_client if len(create_calls) == 1 else admin_client
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is True
|
||||
assert service.is_initialized is True
|
||||
assert service.get_client() is anon_client
|
||||
assert service.get_admin_client() is admin_client
|
||||
assert len(create_calls) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
def _fake_create_client(_: str, __: str) -> object:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
result = await service.initialize()
|
||||
|
||||
assert result is False
|
||||
assert service.is_initialized is False
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
def _fake_create_client(_: str, __: str) -> object:
|
||||
return MagicMock()
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
assert await service.initialize() is True
|
||||
assert await service.close() is True
|
||||
assert service.is_initialized is False
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_admin_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_uninitialized() -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
health = await service.health_check()
|
||||
|
||||
assert health["status"] == "unhealthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_initialized(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
anon_client = MagicMock()
|
||||
anon_client.auth.get_session = MagicMock(return_value=None)
|
||||
|
||||
admin_list_users = MagicMock(return_value=SimpleNamespace(users=[]))
|
||||
admin_client = MagicMock()
|
||||
admin_client.auth.admin = SimpleNamespace(list_users=admin_list_users)
|
||||
|
||||
create_sequence = [anon_client, admin_client]
|
||||
|
||||
def _fake_create_client(_: str, __: str) -> object:
|
||||
return create_sequence.pop(0)
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
assert await service.initialize() is True
|
||||
|
||||
health = await service.health_check()
|
||||
|
||||
assert health["status"] == "healthy"
|
||||
admin_list_users.assert_called_once_with(page=1, per_page=1)
|
||||
|
||||
|
||||
def test_get_client_raises_before_init() -> None:
|
||||
service = SupabaseService(settings=SupabaseSettings())
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_admin_client()
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import app as app_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_uses_registered_services(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
initialized_services = [object(), object()]
|
||||
calls: dict[str, object] = {}
|
||||
|
||||
async def _fake_initialize(service_names: list[str]) -> tuple[bool, list[object]]:
|
||||
calls["init_names"] = service_names
|
||||
return True, initialized_services
|
||||
|
||||
async def _fake_close(services: list[object]) -> bool:
|
||||
calls["close_services"] = services
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
|
||||
monkeypatch.setattr(app_module, "close_registered_services", _fake_close)
|
||||
|
||||
context = app_module.lifespan(app_module.app)
|
||||
await context.__aenter__()
|
||||
await context.__aexit__(None, None, None)
|
||||
|
||||
assert calls["init_names"] == ["redis", "supabase"]
|
||||
assert calls["close_services"] == initialized_services
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_raises_when_initialization_failed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _fake_initialize(_: list[str]) -> tuple[bool, list[object]]:
|
||||
return False, []
|
||||
|
||||
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
|
||||
|
||||
context = app_module.lifespan(app_module.app)
|
||||
with pytest.raises(RuntimeError, match="Service initialization failed"):
|
||||
await context.__aenter__()
|
||||
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.logging import celery as celery_logging
|
||||
from core.logging.context import clear_context, get_context
|
||||
|
||||
|
||||
class DummyTask:
|
||||
name: str = "tasks.sample"
|
||||
|
||||
|
||||
def test_celery_prerun_binds_task_context() -> None:
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
|
||||
context = get_context()
|
||||
|
||||
assert context["task_id"] == "task-123"
|
||||
assert context["task_name"] == "tasks.sample"
|
||||
|
||||
clear_context()
|
||||
|
||||
|
||||
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
|
||||
called = {"value": False}
|
||||
|
||||
def fake_configure_logging(settings: object | None = None) -> None:
|
||||
called["value"] = True
|
||||
|
||||
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_setup_logging()
|
||||
|
||||
assert called["value"] is True
|
||||
|
||||
|
||||
def test_configure_celery_app_disables_hijack() -> None:
|
||||
app = Celery("test")
|
||||
|
||||
celery_logging.configure_celery_app(app)
|
||||
|
||||
assert app.conf.worker_hijack_root_logger is False
|
||||
@@ -0,0 +1,227 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.dependencies import TaskiqQueueClient
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, str] = {}
|
||||
self.delete_calls: list[str] = []
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
*,
|
||||
nx: bool = False,
|
||||
ex: int | None = None,
|
||||
) -> bool:
|
||||
del ex
|
||||
if nx and key in self.store:
|
||||
return False
|
||||
self.store[key] = value
|
||||
return True
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self.store.get(key)
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
self.delete_calls.append(key)
|
||||
existed = 1 if key in self.store else 0
|
||||
self.store.pop(key, None)
|
||||
return existed
|
||||
|
||||
|
||||
class _FakeAsyncResult:
|
||||
def __init__(self, task_id: str) -> None:
|
||||
self.task_id = task_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
resolved_client = {"value": False}
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
assert payload["command"] == "run"
|
||||
return _FakeAsyncResult("task-123")
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
resolved_client["value"] = True
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
|
||||
|
||||
assert resolved_client["value"] is True
|
||||
assert task_id == "task-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_resume_dedup_returns_existing_task_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
resolved_client = {"value": False}
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
return _FakeAsyncResult("new-task-id")
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
resolved_client["value"] = True
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert resolved_client["value"] is True
|
||||
assert task_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
|
||||
attempts = {"count": 0}
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_get(key: str) -> str | None:
|
||||
attempts["count"] += 1
|
||||
if attempts["count"] > 1:
|
||||
fake_redis.store[key] = "existing-task-id"
|
||||
return fake_redis.store.get(key)
|
||||
|
||||
async def _fake_sleep(_: float) -> None:
|
||||
return None
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
raise AssertionError("should not enqueue when dedup task id appears")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(fake_redis, "get", _fake_get)
|
||||
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert task_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
raise RuntimeError("enqueue failed")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
with pytest.raises(RuntimeError, match="enqueue failed"):
|
||||
await client.enqueue(
|
||||
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert redis_key in fake_redis.delete_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_uses_critical_queue_when_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
raise AssertionError("default queue should not be selected")
|
||||
|
||||
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
return _FakeAsyncResult("critical-task-id")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
|
||||
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "run", "queue": "critical"},
|
||||
dedup_key=None,
|
||||
)
|
||||
|
||||
assert task_id == "critical-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_uses_bulk_queue_when_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
raise AssertionError("default queue should not be selected")
|
||||
|
||||
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
return _FakeAsyncResult("bulk-task-id")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
|
||||
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "run", "queue": "bulk"},
|
||||
dedup_key=None,
|
||||
)
|
||||
|
||||
assert task_id == "bulk-task-id"
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
@@ -12,37 +12,44 @@ from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
|
||||
|
||||
class TestSupabaseAuthGateway:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> SupabaseAuthGateway:
|
||||
with patch("v1.auth.gateway.create_client") as mock_create:
|
||||
mock_client = MagicMock()
|
||||
mock_admin_client = MagicMock()
|
||||
mock_create.side_effect = [mock_client, mock_admin_client]
|
||||
return SupabaseAuthGateway()
|
||||
def gateway(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> tuple[SupabaseAuthGateway, MagicMock, MagicMock]:
|
||||
mock_client = MagicMock()
|
||||
mock_admin_client = MagicMock()
|
||||
monkeypatch.setattr("v1.auth.gateway.supabase_service.get_client", lambda: mock_client)
|
||||
monkeypatch.setattr(
|
||||
"v1.auth.gateway.supabase_service.get_admin_client",
|
||||
lambda: mock_admin_client,
|
||||
)
|
||||
return SupabaseAuthGateway(), mock_client, mock_admin_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_calls_email_with_string(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_with_redirect(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(
|
||||
email="test@example.com",
|
||||
redirect_to="http://localhost:3000/reset-password",
|
||||
)
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with(
|
||||
"test@example.com",
|
||||
@@ -51,64 +58,68 @@ class TestSupabaseAuthGateway:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_swallows_auth_error(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
|
||||
result = await gateway.request_password_reset(request)
|
||||
result = await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_extracts_email_from_mapping(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"email": "test@example.com"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_rejects_invalid_email_shape(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, _, _ = gateway
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"unexpected": "value"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Invalid email"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_updates_password_by_user_id(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, mock_admin_client = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id="user-1"),
|
||||
)
|
||||
mock_verify_otp = MagicMock(return_value=verify_response)
|
||||
gateway._client.auth.verify_otp = mock_verify_otp
|
||||
mock_client.auth.verify_otp = mock_verify_otp
|
||||
|
||||
mock_update_user_by_id = MagicMock()
|
||||
gateway._admin_client.auth.admin = SimpleNamespace(
|
||||
mock_admin_client.auth.admin = SimpleNamespace(
|
||||
update_user_by_id=mock_update_user_by_id
|
||||
)
|
||||
|
||||
@@ -118,7 +129,7 @@ class TestSupabaseAuthGateway:
|
||||
new_password="newpassword123",
|
||||
)
|
||||
|
||||
await gateway.confirm_password_reset(request)
|
||||
await sut.confirm_password_reset(request)
|
||||
|
||||
mock_verify_otp.assert_called_once_with(
|
||||
{
|
||||
@@ -134,13 +145,14 @@ class TestSupabaseAuthGateway:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_raises_when_user_id_missing(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id=""),
|
||||
)
|
||||
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
|
||||
request = PasswordResetConfirmRequest(
|
||||
email="test@example.com",
|
||||
@@ -149,7 +161,7 @@ class TestSupabaseAuthGateway:
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gateway.confirm_password_reset(request)
|
||||
await sut.confirm_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid or expired verification code"
|
||||
|
||||
Reference in New Issue
Block a user