chore: checkpoint current backend/runtime changes

This commit is contained in:
qzl
2026-03-06 17:28:17 +08:00
parent 2c59fe5ee2
commit b6087fd195
32 changed files with 1641 additions and 469 deletions
+69 -27
View File
@@ -1,57 +1,98 @@
from __future__ import annotations
from typing import Any, cast
import asyncio
from typing import Any
from uuid import UUID
from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.queue.tasks import run_command_task
from core.agent.infrastructure.queue.tasks import (
run_command_task,
run_command_task_bulk,
run_command_task_critical,
)
from core.config.settings import config
from core.db import get_db
from services.base.redis import redis_service
from services.base.redis import get_or_init_redis_client
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
DEDUP_WAIT_RETRIES = 20
DEDUP_WAIT_SECONDS = 0.05
DEDUP_LOCK_SECONDS = 300
DEDUP_INFLIGHT_MARKER = "__inflight__"
class CeleryQueueClient:
class TaskiqQueueClient:
def __init__(self) -> None:
self._redis = redis_service.get_client()
self._redis: Redis | None = None
async def _get_redis(self) -> Redis:
if self._redis is None:
self._redis = await get_or_init_redis_client()
return self._redis
@staticmethod
def _select_queue_task(command: dict[str, object]) -> Any:
queue = str(command.get("queue", "default")).strip().lower()
if queue == "critical":
return run_command_task_critical
if queue == "bulk":
return run_command_task_bulk
return run_command_task
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
redis_client = await self._get_redis()
redis_key = None
if dedup_key:
redis_key = f"agent:dedup:{dedup_key}"
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
locked = await redis_client.set(
redis_key,
DEDUP_INFLIGHT_MARKER,
nx=True,
ex=DEDUP_LOCK_SECONDS,
)
if not locked:
existing = await self._redis.get(redis_key)
if existing and existing != "__inflight__":
return existing
for _ in range(DEDUP_WAIT_RETRIES):
existing = await redis_client.get(redis_key)
if existing and existing != DEDUP_INFLIGHT_MARKER:
return existing
await asyncio.sleep(DEDUP_WAIT_SECONDS)
raise RuntimeError("duplicate request is still in progress")
payload = dict(command)
if dedup_key:
payload["dedup_key"] = dedup_key
delay = getattr(run_command_task, "delay")
result = delay(payload)
task_id = str(result.id)
if redis_key is not None:
await self._redis.set(redis_key, task_id, ex=300)
return task_id
queue_task = self._select_queue_task(payload)
try:
result = await queue_task.kiq(payload)
task_id = str(result.task_id)
if redis_key is not None:
await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
return task_id
except Exception:
if redis_key is not None:
await redis_client.delete(redis_key)
raise
class RedisEventStream:
def __init__(self) -> None:
settings = cast(Any, config)
client = redis_service.get_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
read_count=settings.agent_runtime.redis_stream_read_count,
block_ms=settings.agent_runtime.redis_stream_block_ms,
)
self._store: RedisStreamEventStore | None = None
async def _get_store(self) -> RedisStreamEventStore:
if self._store is None:
client = await get_or_init_redis_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
return self._store
async def read(
self,
@@ -59,7 +100,8 @@ class RedisEventStream:
session_id: str,
last_event_id: str | None,
) -> list[dict[str, Any]]:
rows = await self._store.read_events(
store = await self._get_store()
rows = await store.read_events(
session_id=UUID(session_id),
last_event_id=last_event_id,
)
@@ -69,6 +111,6 @@ class RedisEventStream:
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
return AgentService(
repository=AgentRepository(session),
queue=CeleryQueueClient(),
queue=TaskiqQueueClient(),
stream=RedisEventStream(),
)
+28 -20
View File
@@ -5,10 +5,10 @@ from collections.abc import Mapping
from typing import Any, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from supabase import AuthError
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
@@ -28,17 +28,16 @@ logger = get_logger("v1.auth.gateway")
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
_admin_client: Any
def _get_client(self) -> Any:
return supabase_service.get_client()
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
self._admin_client = create_client(settings.url, settings.service_role_key)
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
client = self._get_client()
metadata: dict[str, Any] = {"username": request.username}
if request.invite_code:
metadata["invite_code"] = request.invite_code
@@ -50,7 +49,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
if request.redirect_to:
payload["options"] = {"email_redirect_to": request.redirect_to}
try:
sign_up = cast(Any, self._client.auth.sign_up)
sign_up = cast(Any, client.auth.sign_up)
await asyncio.to_thread(sign_up, payload)
return VerificationCreateResponse(email=request.email)
except AuthError as exc:
@@ -62,13 +61,14 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def verify_verification(
self, request: VerificationVerifyRequest
) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {
"type": "signup",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(response, "Invalid verification code")
except AuthError as exc:
@@ -78,17 +78,19 @@ class SupabaseAuthGateway(AuthServiceGateway):
) from exc
async def resend_verification(self, request: VerificationResendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {"type": "signup", "email": request.email}
try:
resend = cast(Any, self._client.auth.resend)
resend = cast(Any, client.auth.resend)
await asyncio.to_thread(resend, payload)
except AuthError as exc:
logger.warning("Signup resend failed", error_type=type(exc).__name__)
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
sign_in = cast(Any, client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
@@ -96,9 +98,10 @@ class SupabaseAuthGateway(AuthServiceGateway):
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
@@ -111,20 +114,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def delete_session(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
client = self._get_client()
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
await asyncio.to_thread(client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
raise HTTPException(
@@ -132,7 +136,8 @@ class SupabaseAuthGateway(AuthServiceGateway):
) from exc
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
normalized_email = email.lower()
user = next(
(
@@ -157,8 +162,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
client = self._get_client()
try:
reset_email = cast(Any, self._client.auth.reset_password_email)
reset_email = cast(Any, client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {"redirect_to": request.redirect_to}
@@ -174,13 +180,15 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
client = self._get_client()
admin_client = self._get_admin_client()
verify_payload: dict[str, Any] = {
"type": "recovery",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
@@ -190,7 +198,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
status_code=401, detail="Invalid or expired verification code"
)
await asyncio.to_thread(
self._admin_client.auth.admin.update_user_by_id,
admin_client.auth.admin.update_user_by_id,
user_id,
{"password": request.new_password},
)