chore: 迁移到 social-app 架构,集成 Supabase 和 taskiq worker
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
close_registered_services,
|
||||
initialize_registered_services,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
resolve_registered_services,
|
||||
)
|
||||
from services.base.supabase import SupabaseService, supabase_service
|
||||
|
||||
__all__ = [
|
||||
"BaseServiceProvider",
|
||||
"RedisService",
|
||||
"ServiceRegistry",
|
||||
"SupabaseService",
|
||||
"close_registered_services",
|
||||
"get_or_init_redis_client",
|
||||
"initialize_registered_services",
|
||||
"redis_service",
|
||||
"register_service",
|
||||
"register_service_instance",
|
||||
"resolve_registered_services",
|
||||
"supabase_service",
|
||||
]
|
||||
@@ -0,0 +1,136 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from core.config.settings import RedisSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class RedisService(BaseServiceProvider):
    """Service provider that owns a single asyncio Redis client.

    The client is built from ``RedisSettings`` (``config.redis`` when no
    settings are passed in) and is only stored on the instance after a
    successful ``PING``. ``_loop_id`` records ``id()`` of the event loop that
    was running when the client was initialized.
    """

    def __init__(self, settings: RedisSettings | None = None) -> None:
        """Store the settings; no connection is attempted here."""
        super().__init__("redis")
        self._settings = settings or config.redis
        self._client: Optional[redis.Redis] = None
        self._loop_id: int | None = None

    def _build_client(self) -> redis.Redis:
        """Create a client from the configured URL, timeouts and pool size."""
        return redis.from_url(
            self._settings.url,
            decode_responses=True,
            socket_connect_timeout=self._settings.socket_connect_timeout,
            socket_timeout=self._settings.socket_timeout,
            max_connections=self._settings.max_connections,
        )

    def _require_client(self) -> redis.Redis:
        """Return the current client.

        Raises:
            RuntimeError: if no client has been initialized.
        """
        client = self._client
        if client is None:
            raise RuntimeError("Redis client is not initialized")
        return client

    @staticmethod
    async def _resolve(value: Any) -> Any:
        """Await ``value`` when it is awaitable, otherwise return it as is."""
        if inspect.isawaitable(value):
            return await value
        return value

    async def _discard_client(self, client: redis.Redis) -> None:
        """Best-effort close of a client that will not be kept."""
        try:
            await client.aclose()
        except Exception as exc:  # noqa: BLE001
            self.logger.debug(
                "Failed to close discarded Redis client", error=str(exc)
            )

    async def initialize(self, **_: Any) -> bool:
        """Build a client, verify it with ``PING`` and publish it.

        Returns:
            True on success. On any exception the failure is logged, the
            service state is reset and False is returned.
        """
        client: Optional[redis.Redis] = None
        try:
            client = self._build_client()
            await self._resolve(client.ping())
            self._client = client
            self._loop_id = _current_loop_id()
            self._set_initialized(True)
            self.logger.info("Redis service initialized")
            return True
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Redis service initialization failed", error=str(exc))
            # The freshly built client is never handed out on failure, so
            # close it here instead of leaving its connection pool behind.
            if client is not None:
                await self._discard_client(client)
            self._client = None
            self._loop_id = None
            self._set_initialized(False)
            return False

    async def close(self) -> bool:
        """Close the client if there is one.

        Returns:
            True when there was nothing to close or the close succeeded,
            False when closing raised. The instance state is reset either way
            once a client was present.
        """
        client = self._client
        if client is None:
            self._loop_id = None
            return True
        try:
            await client.aclose()
            self.logger.info("Redis service closed")
            return True
        except Exception as exc:  # noqa: BLE001
            self.logger.exception("Redis service close failed", error=str(exc))
            return False
        finally:
            self._client = None
            self._loop_id = None
            self._set_initialized(False)

    async def health_check(self) -> Dict[str, Any]:
        """Report health based on ``PING`` plus a few fields from ``INFO``.

        Returns:
            A dict with ``status`` ("healthy"/"unhealthy") and ``details``.
            Exceptions are logged and reported as unhealthy, not raised.
        """
        client = self._client
        if client is None:
            return {"status": "unhealthy", "details": {"error": "not initialized"}}
        try:
            ping = await self._resolve(client.ping())
            info = await self._resolve(client.info())
            return {
                "status": "healthy" if ping else "unhealthy",
                "details": {
                    "ping": ping,
                    "redis_version": info.get("redis_version"),
                    "connected_clients": info.get("connected_clients"),
                    "used_memory": info.get("used_memory_human"),
                    "uptime_in_seconds": info.get("uptime_in_seconds"),
                },
            }
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Redis health check failed", error=str(exc))
            return {"status": "unhealthy", "details": {"error": str(exc)}}

    def get_client(self) -> redis.Redis:
        """Return the initialized client; raises RuntimeError when absent."""
        return self._require_client()
|
||||
|
||||
|
||||
def _current_loop_id() -> int | None:
|
||||
try:
|
||||
return id(asyncio.get_running_loop())
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
|
||||
async def get_or_init_redis_client() -> redis.Redis:
    """Return the shared Redis client, initializing the service when needed.

    When the service is initialized but its recorded loop id differs from the
    currently running loop, the stored client is dropped and the service is
    initialized again.

    Raises:
        RuntimeError: if initialization reports failure.
    """
    running_loop = _current_loop_id()
    owner_loop = redis_service._loop_id
    loop_changed = (
        redis_service.is_initialized
        and owner_loop is not None
        and running_loop is not None
        and owner_loop != running_loop
    )
    if loop_changed:
        redis_service.logger.warning(
            "Redis client bound to different event loop; reinitializing",
            previous_loop_id=owner_loop,
            current_loop_id=running_loop,
        )
        # Forget the stale client so the branch below rebuilds it.
        redis_service._client = None
        redis_service._loop_id = None
        redis_service._set_initialized(False)

    if redis_service.is_initialized:
        return redis_service.get_client()

    if not await redis_service.initialize():
        raise RuntimeError("Redis service initialization failed")
    return redis_service.get_client()
|
||||
|
||||
|
||||
# Module-level RedisService instance, registered in ServiceRegistry under "redis".
redis_service: RedisService = register_service_instance("redis", RedisService())
|
||||
|
||||
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
|
||||
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar
|
||||
|
||||
from core.logging import get_logger
|
||||
|
||||
|
||||
class BaseServiceProvider(ABC):
    """Abstract base for services with an async initialize/close lifecycle.

    Subclasses implement ``initialize``, ``close`` and ``health_check``; this
    base keeps the service name, an "initialized" flag and a logger bound to
    the service name.
    """

    def __init__(self, service_name: str) -> None:
        """Record the name, start uninitialized and bind a logger."""
        self.service_name = service_name
        self._initialized = False
        base_logger = get_logger("services.base")
        self.logger = base_logger.bind(service=service_name)

    @abstractmethod
    async def initialize(self, **kwargs: Any) -> bool:
        """Bring the service up; return whether it succeeded."""
        raise NotImplementedError

    @abstractmethod
    async def close(self) -> bool:
        """Shut the service down; return whether it succeeded."""
        raise NotImplementedError

    @abstractmethod
    async def health_check(self) -> Dict[str, Any]:
        """Return a health report for the service."""
        raise NotImplementedError

    @property
    def is_initialized(self) -> bool:
        """Current value of the initialized flag."""
        return self._initialized

    def _set_initialized(self, value: bool) -> None:
        """Set the initialized flag."""
        self._initialized = value

    def get_service_info(self) -> Dict[str, Any]:
        """Return the service name, initialized flag and concrete class name."""
        info: Dict[str, Any] = {}
        info["name"] = self.service_name
        info["initialized"] = self._initialized
        info["type"] = self.__class__.__name__
        return info
|
||||
|
||||
|
||||
class ServiceRegistry:
    """Class-level registry mapping service names to factories."""

    _services: Dict[str, Callable[..., BaseServiceProvider]] = {}

    @classmethod
    def register(
        cls, service_name: str, factory: Callable[..., BaseServiceProvider]
    ) -> None:
        """Register ``factory`` under ``service_name``, replacing any previous one."""
        # Rebind to a fresh dict rather than mutating the existing mapping.
        updated = dict(cls._services)
        updated[service_name] = factory
        cls._services = updated

    @classmethod
    def get_service_factory(
        cls, service_name: str
    ) -> Optional[Callable[..., BaseServiceProvider]]:
        """Return the factory registered for ``service_name``, or None."""
        return cls._services.get(service_name)

    @classmethod
    def list_services(cls) -> list[str]:
        """Return the registered service names in sorted order."""
        return sorted(cls._services)

    @classmethod
    def create_service(
        cls, service_name: str, **kwargs: Any
    ) -> Optional[BaseServiceProvider]:
        """Alias of ``get_service``."""
        return cls.get_service(service_name, **kwargs)

    @classmethod
    def get_service(
        cls, service_name: str, **kwargs: Any
    ) -> Optional[BaseServiceProvider]:
        """Call the registered factory with ``kwargs``; None when unregistered."""
        factory = cls.get_service_factory(service_name)
        return factory(**kwargs) if factory else None
|
||||
|
||||
|
||||
def register_service(service_name: str) -> Callable[[type], type]:
    """Class decorator that registers the decorated class as a factory.

    The class itself is stored in ``ServiceRegistry`` under ``service_name``
    and returned unchanged.
    """

    def _register(service_class: type) -> type:
        ServiceRegistry.register(service_name, service_class)
        return service_class

    return _register
|
||||
|
||||
|
||||
# Type variable so register_service_instance returns the concrete service type.
TService = TypeVar("TService", bound=BaseServiceProvider)


def register_service_instance(service_name: str, service: TService) -> TService:
    """Register an existing instance and return it.

    The registered factory takes no arguments and always returns ``service``.
    """

    def _provide() -> TService:
        return service

    ServiceRegistry.register(service_name, _provide)
    return service
|
||||
|
||||
|
||||
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
    """Look up each name in ``ServiceRegistry`` and return the services in order.

    Raises:
        RuntimeError: for the first name that has no registered service.
    """
    resolved: list[BaseServiceProvider] = []
    for name in service_names:
        instance = ServiceRegistry.get_service(name)
        if instance is None:
            raise RuntimeError(f"Service is not registered: {name}")
        resolved.append(instance)
    return resolved
|
||||
|
||||
|
||||
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
    """Close the given services in reverse order.

    A service that raises or returns a falsy value from ``close`` is logged
    and the remaining services are still closed.

    Returns:
        True only when every service closed successfully.
    """
    lifecycle_logger = get_logger("services.base.lifecycle")
    failures = 0
    for service in reversed(services):
        try:
            closed = await service.close()
        except Exception as exc:  # noqa: BLE001
            lifecycle_logger.warning(
                "Failed to close service",
                service=service.service_name,
                error=str(exc),
            )
            failures += 1
        else:
            if not closed:
                lifecycle_logger.warning(
                    "Service close returned false",
                    service=service.service_name,
                )
                failures += 1
    return failures == 0
|
||||
|
||||
|
||||
async def initialize_registered_services(
    service_names: list[str],
) -> tuple[bool, list[BaseServiceProvider]]:
    """Initialize the named services in order, rolling back on the first failure.

    Returns:
        ``(True, services)`` when every service initialized. ``(False, [])``
        when a name cannot be resolved, or when a service raises or returns a
        falsy value from ``initialize``; in the latter case the services that
        were already initialized are closed first.
    """
    lifecycle_logger = get_logger("services.base.lifecycle")
    try:
        pending = resolve_registered_services(service_names)
    except RuntimeError as exc:
        lifecycle_logger.error("Failed to resolve registered services", error=str(exc))
        return False, []

    started: list[BaseServiceProvider] = []
    for service in pending:
        succeeded = False
        try:
            succeeded = await service.initialize()
        except Exception as exc:  # noqa: BLE001
            lifecycle_logger.warning(
                "Service initialization raised exception",
                service=service.service_name,
                error=str(exc),
            )

        if succeeded:
            started.append(service)
            continue

        lifecycle_logger.error(
            "Service initialization failed, rolling back",
            service=service.service_name,
        )
        await close_registered_services(started)
        return False, []

    return True, started
|
||||
@@ -0,0 +1,304 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from supabase import create_client
|
||||
from storage3.exceptions import StorageApiError
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class SupabaseService(BaseServiceProvider):
    """Service provider holding an anon-key and a service-role Supabase client.

    Besides the lifecycle hooks it offers storage helpers (upload, download,
    signed URLs). The storage calls are made on the admin client and are run
    in a worker thread through ``asyncio.to_thread``.
    """

    def __init__(self, settings: SupabaseSettings | None = None) -> None:
        """Store the settings; the clients are created later."""
        super().__init__("supabase")
        self._settings = settings or config.supabase
        self._client: Any = None
        self._admin_client: Any = None

    async def initialize(self, **_: Any) -> bool:
        """Create both clients and check the configured storage buckets.

        Returns:
            True on success. On any exception the failure is logged, both
            clients are cleared and False is returned.
        """
        try:
            self._init_clients()
            await self._ensure_storage_bucket()
            self._set_initialized(True)
            self.logger.info("Supabase service initialized")
            return True
        except Exception as exc:  # noqa: BLE001
            self.logger.warning(
                "Supabase service initialization failed", error=str(exc)
            )
            self._client = None
            self._admin_client = None
            self._set_initialized(False)
            return False

    async def close(self) -> bool:
        """Drop both client references and mark the service uninitialized."""
        self._client = None
        self._admin_client = None
        self._set_initialized(False)
        self.logger.info("Supabase service closed")
        return True

    async def health_check(self) -> dict[str, Any]:
        """Probe both clients and report the result.

        The anon client is probed with ``auth.get_session`` and the admin
        client with ``auth.admin.list_users``. Exceptions are logged and
        reported as unhealthy, not raised.
        """
        client = self._client
        admin_client = self._admin_client
        if client is None or admin_client is None:
            return {"status": "unhealthy", "details": {"error": "not initialized"}}
        try:
            await asyncio.to_thread(client.auth.get_session)
            await asyncio.to_thread(
                admin_client.auth.admin.list_users, page=1, per_page=1
            )
            return {
                "status": "healthy",
                "details": {
                    "anon_client": "ready",
                    "admin_client": "ready",
                },
            }
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Supabase health check failed", error=str(exc))
            return {"status": "unhealthy", "details": {"error": str(exc)}}

    def get_client(self) -> Any:
        """Return the anon-key client, creating the clients lazily if needed."""
        return self._require_client()

    def get_admin_client(self) -> Any:
        """Return the service-role client, creating the clients lazily if needed."""
        return self._require_admin_client()

    def _ensure_clients(self) -> None:
        """Create both clients when either one is missing.

        This lazy path marks the service initialized but does not run the
        storage bucket check that ``initialize`` performs.
        """
        if self._client is None or self._admin_client is None:
            self._init_clients()
            self._set_initialized(True)
            self.logger.info("Supabase service lazily initialized")

    def _require_client(self) -> Any:
        """Return the anon-key client or raise RuntimeError when unavailable."""
        self._ensure_clients()
        client = self._client
        if client is None:
            raise RuntimeError("Supabase client is not initialized")
        return client

    def _require_admin_client(self) -> Any:
        """Return the service-role client or raise RuntimeError when unavailable."""
        self._ensure_clients()
        admin_client = self._admin_client
        if admin_client is None:
            raise RuntimeError("Supabase admin client is not initialized")
        return admin_client

    def _init_clients(self) -> None:
        """Create the anon-key client and the service-role client."""
        self._client = create_client(
            self._settings.url,
            self._settings.anon_key,
        )
        self._admin_client = create_client(
            self._settings.url,
            self._settings.service_role_key,
        )

    async def _ensure_storage_bucket(self) -> None:
        """Create the attachment (private) and avatar (public) buckets if missing.

        Best effort: a missing storage API or a failed creation is logged and
        does not raise.
        """
        storage = getattr(self._admin_client, "storage", None)
        if storage is None:
            self.logger.warning("Storage client unavailable, skipping bucket check")
            return

        get_bucket = getattr(storage, "get_bucket", None)
        if not callable(get_bucket):
            self.logger.warning("Storage get_bucket unavailable, skipping bucket check")
            return

        # (bucket name, public flag) pairs taken from the storage configuration.
        buckets = [
            (config.storage.attachment.bucket, False),
            (config.storage.avatar.bucket, True),
        ]

        def _check_and_create() -> None:
            for bucket_name, is_public in buckets:
                try:
                    get_bucket(bucket_name)
                    self.logger.debug(
                        "Storage bucket already exists", bucket=bucket_name
                    )
                except Exception:  # noqa: BLE001
                    # Any lookup failure is treated as "bucket missing".
                    create_bucket = getattr(storage, "create_bucket", None)
                    if not callable(create_bucket):
                        self.logger.warning(
                            "Storage create_bucket unavailable, skipping bucket creation"
                        )
                        return
                    try:
                        create_bucket(bucket_name, options={"public": is_public})
                        self.logger.info(
                            "Storage bucket created",
                            bucket=bucket_name,
                            public=is_public,
                        )
                    except Exception as exc:  # noqa: BLE001
                        msg = str(exc).lower()
                        if "already exists" in msg or "duplicate" in msg:
                            self.logger.debug(
                                "Storage bucket already exists (race)",
                                bucket=bucket_name,
                            )
                            continue
                        self.logger.warning(
                            "Failed to create storage bucket",
                            bucket=bucket_name,
                            error=str(exc),
                        )

        await asyncio.to_thread(_check_and_create)

    def _get_storage(self) -> Any:
        """Get the storage client from admin client."""
        client = self.get_admin_client()
        storage = getattr(client, "storage", None)
        if storage is None:
            raise RuntimeError("Supabase storage client unavailable")
        return storage

    def _get_bucket_client(self, bucket: str) -> Any:
        """Get a bucket client for the specified bucket."""
        storage = self._get_storage()
        from_bucket = getattr(storage, "from_", None)
        if not callable(from_bucket):
            raise RuntimeError("Supabase storage bucket accessor unavailable")
        return from_bucket(bucket)

    def _validate_bucket(self, bucket: str) -> None:
        """Validate that the bucket matches one of configured storage buckets."""
        allowed_buckets = {
            config.storage.attachment.bucket,
            config.storage.avatar.bucket,
        }
        if bucket not in allowed_buckets:
            raise RuntimeError("Invalid storage bucket")

    def _ensure_bucket_client(self, bucket: str) -> Any:
        """Validate bucket and return authenticated bucket client."""
        self._validate_bucket(bucket)
        return self._get_bucket_client(bucket)

    def _is_bucket_not_found_error(self, exc: Exception) -> bool:
        """Check if the exception indicates a bucket was not found."""
        # NOTE(review): both branches currently apply the same text match.
        if isinstance(exc, StorageApiError):
            message = str(exc).lower()
            return "bucket" in message and "not found" in message
        message = str(exc).lower()
        return "bucket" in message and "not found" in message

    async def upload_bytes(
        self,
        *,
        bucket: str,
        path: str,
        content: bytes,
        content_type: str,
    ) -> str:
        """Upload ``content`` to ``bucket`` at ``path`` (upsert) and return ``path``.

        When the upload fails with a "bucket not found" error, the bucket's
        existence is re-checked and the upload is retried once. Any other
        error is re-raised.
        """

        def _upload() -> object:
            bucket_client = self._ensure_bucket_client(bucket)
            upload = getattr(bucket_client, "upload", None)
            if not callable(upload):
                raise RuntimeError("Supabase storage upload is unavailable")
            return upload(
                path,
                content,
                {
                    "content-type": content_type,
                    "upsert": "true",
                },
            )

        try:
            await asyncio.to_thread(_upload)
        except Exception as exc:  # noqa: BLE001
            if not self._is_bucket_not_found_error(exc):
                raise
            await self._ensure_bucket_exists(bucket=bucket)
            await asyncio.to_thread(_upload)
        return path

    async def _ensure_bucket_exists(self, *, bucket: str) -> None:
        """Verify that ``bucket`` exists; this does not create it.

        Raises:
            RuntimeError: if ``get_bucket`` is unavailable or reports that the
                bucket was not found. Other lookup errors propagate unchanged.
        """

        def _ensure() -> None:
            storage = self._get_storage()
            get_bucket = getattr(storage, "get_bucket", None)
            if not callable(get_bucket):
                raise RuntimeError("Supabase storage get_bucket is unavailable")
            try:
                get_bucket(bucket)
            except Exception as exc:  # noqa: BLE001
                if self._is_bucket_not_found_error(exc):
                    raise RuntimeError(
                        f"Storage bucket '{bucket}' does not exist"
                    ) from exc
                raise

        await asyncio.to_thread(_ensure)

    async def download_bytes(self, *, bucket: str, path: str) -> bytes:
        """Download the object at ``path`` from ``bucket`` as ``bytes``.

        Raises:
            RuntimeError: if the bucket is not allowed, download is
                unavailable, or the payload is not bytes-like.
        """

        def _download() -> object:
            bucket_client = self._ensure_bucket_client(bucket)
            download = getattr(bucket_client, "download", None)
            if not callable(download):
                raise RuntimeError("Supabase storage download is unavailable")
            return download(path)

        raw = await asyncio.to_thread(_download)
        if isinstance(raw, bytes):
            return raw
        if isinstance(raw, bytearray):
            return bytes(raw)
        if isinstance(raw, memoryview):
            return raw.tobytes()
        raise RuntimeError("Invalid attachment payload")

    async def create_signed_url(
        self,
        *,
        bucket: str,
        path: str,
        expires_in_seconds: int,
    ) -> str:
        """Create a signed URL for ``path`` in ``bucket``.

        The client result may be a string or a dict carrying the URL under
        ``signedURL``, ``signedUrl`` or ``url``.

        Raises:
            RuntimeError: if signing is unavailable or no URL can be extracted.
        """

        def _create_signed_url() -> object:
            bucket_client = self._ensure_bucket_client(bucket)
            signer = getattr(bucket_client, "create_signed_url", None)
            if not callable(signer):
                raise RuntimeError("Supabase storage signed url is unavailable")
            return signer(path, expires_in_seconds)

        raw = await asyncio.to_thread(_create_signed_url)
        if isinstance(raw, str):
            return raw
        if isinstance(raw, dict):
            signed_url = raw.get("signedURL") or raw.get("signedUrl") or raw.get("url")
            if isinstance(signed_url, str) and signed_url:
                return signed_url
        raise RuntimeError("Invalid signed url payload")

    def parse_signed_url(self, url: str) -> tuple[str, str]:
        """Split a signed storage URL into ``(bucket, object_path)``.

        The URL path must start with ``/storage/v1/object/sign/<bucket>``.
        The object path is returned exactly as it appears in the URL (it is
        not percent-decoded) and is empty when nothing follows the bucket.

        Raises:
            RuntimeError: if the path does not match that layout.
        """
        from urllib.parse import urlparse

        segments = urlparse(url).path.strip("/").split("/")
        expected_prefix = ["storage", "v1", "object", "sign"]
        # Four prefix segments plus the bucket; with fewer there is no bucket
        # segment to read, which previously surfaced as an IndexError.
        if len(segments) < 5 or segments[:4] != expected_prefix:
            raise RuntimeError("Invalid signed URL format")

        bucket = segments[4]
        path = "/".join(segments[5:])
        return bucket, path
|
||||
|
||||
|
||||
# Module-level SupabaseService instance, registered in ServiceRegistry under "supabase".
supabase_service: SupabaseService = register_service_instance("supabase", SupabaseService())
|
||||
|
||||
__all__ = ["SupabaseService", "supabase_service"]
|
||||
@@ -0,0 +1,4 @@
|
||||
from .factory import get_cache_store
|
||||
from .interfaces import CacheStore
|
||||
|
||||
__all__ = ["CacheStore", "get_cache_store"]
|
||||
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .interfaces import CacheStore
|
||||
from .redis_store import RedisCacheStore
|
||||
|
||||
# Lazily created, module-wide cache store instance.
_cache_store: CacheStore | None = None


def get_cache_store() -> CacheStore:
    """Return the shared cache store, creating a RedisCacheStore on first use."""
    global _cache_store
    store = _cache_store
    if store is None:
        store = RedisCacheStore()
        _cache_store = store
    return store
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class CacheStore(Protocol):
    """Structural interface for the async hash/set cache operations used here."""

    async def hgetall(self, key: str, /) -> dict[str, str]:
        """Return all field/value pairs of the hash at ``key``."""

    async def hset(self, key: str, /, mapping: dict[str, str]) -> int:
        """Write ``mapping`` into the hash at ``key``; returns an integer result."""

    async def hincrby(self, key: str, field: str, amount: int = 1, /) -> int:
        """Increment ``field`` of the hash at ``key`` by ``amount``."""

    async def expire(self, key: str, ttl_seconds: int, /) -> int:
        """Set a time-to-live of ``ttl_seconds`` on ``key``."""

    async def delete(self, *keys: str) -> int:
        """Delete the given keys; returns an integer result."""

    async def sadd(self, key: str, *members: str) -> int:
        """Add ``members`` to the set at ``key``; returns an integer result."""

    async def smembers(self, key: str, /) -> set[str]:
        """Return the members of the set at ``key``."""
|
||||
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
from .interfaces import CacheStore
|
||||
|
||||
|
||||
def _to_text(value: Any) -> str | None:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
return value.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def _maybe_await(value: Any) -> Any:
|
||||
if inspect.isawaitable(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
|
||||
class RedisCacheStore(CacheStore):
    """CacheStore implementation backed by the shared Redis client.

    Every call obtains the client through ``get_or_init_redis_client`` and
    tolerates both awaitable and plain return values from the client.
    """

    async def hgetall(self, key: str) -> dict[str, str]:
        """Return the hash at ``key`` as text.

        Entries whose field or value is not str / UTF-8 bytes are skipped.
        A non-dict reply yields an empty dict.
        """
        client = await get_or_init_redis_client()
        raw = await _maybe_await(client.hgetall(key))
        if not isinstance(raw, dict):
            return {}

        decoded: dict[str, str] = {}
        for raw_key, raw_value in raw.items():
            key_text = _to_text(raw_key)
            value_text = _to_text(raw_value)
            if key_text is None or value_text is None:
                continue
            decoded[key_text] = value_text
        return decoded

    async def hset(self, key: str, mapping: dict[str, str]) -> int:
        """Write ``mapping`` into the hash at ``key``; returns the client result as int."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.hset(key, mapping=mapping))
        return int(result)

    async def hincrby(self, key: str, field: str, amount: int = 1) -> int:
        """Increment ``field`` of the hash at ``key`` by ``amount``."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.hincrby(key, field, amount))
        return int(result)

    async def expire(self, key: str, ttl_seconds: int) -> int:
        """Set a TTL on ``key``; returns the client result as int."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.expire(key, ttl_seconds))
        return int(result)

    async def delete(self, *keys: str) -> int:
        """Delete ``keys``; returns 0 without touching Redis when none are given."""
        if not keys:
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.delete(*keys))
        return int(result)

    async def sadd(self, key: str, *members: str) -> int:
        """Add ``members`` to the set at ``key``; returns 0 when none are given."""
        if not members:
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.sadd(key, *members))
        return int(result)

    async def smembers(self, key: str) -> set[str]:
        """Return the members of the set at ``key`` as text.

        Members that are not str / UTF-8 bytes are skipped. A reply that is
        not a set, list or tuple yields an empty set.
        """
        client = await get_or_init_redis_client()
        raw = await _maybe_await(client.smembers(key))
        if not isinstance(raw, (set, list, tuple)):
            return set()
        # Filter on "is not None" (as hgetall does) so that a legitimate
        # empty-string member is kept rather than dropped as falsy.
        return {text for item in raw if (text := _to_text(item)) is not None}
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.llm_pricing.service import LlmPricingService
|
||||
|
||||
__all__ = ["LlmPricingService"]
|
||||
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from core.config.initial.init_data import load_llm_catalog
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PricingTier:
    """Immutable per-token prices that apply up to a prompt-size limit."""

    # Largest prompt token count this tier is selected for.
    max_prompt_tokens: int
    # Price per uncached prompt token.
    input_cost_per_token: float
    # Price per completion token.
    output_cost_per_token: float
    # Price per cached prompt token.
    cache_hit_cost_per_token: float
|
||||
|
||||
|
||||
class LlmPricingService:
    """Compute LLM call costs from the tiered prices in the LLM catalog.

    Prices are loaded once at construction time and keyed by lower-cased
    model code.
    """

    _pricing_by_model: dict[str, tuple[PricingTier, ...]]

    def __init__(self) -> None:
        """Load the catalog and build the per-model tier table."""
        self._pricing_by_model = self._build_pricing_map()

    @staticmethod
    def _build_pricing_map() -> dict[str, tuple[PricingTier, ...]]:
        """Build ``{model_code: tiers}`` from ``load_llm_catalog()``.

        Entries that are not dicts, have no usable ``pricing_tiers`` list or
        have no model code are skipped. Tiers are sorted by
        ``max_prompt_tokens`` ascending; missing or falsy price fields
        become 0.
        """
        catalog = load_llm_catalog()
        pricing_by_model: dict[str, tuple[PricingTier, ...]] = {}
        for model in catalog.get("llms", []):
            if not isinstance(model, dict):
                continue
            raw_code = model.get("model_code")
            # A missing/None code must not be stringified into the key "none".
            model_code = "" if raw_code is None else str(raw_code).strip().lower()
            raw_tiers = model.get("pricing_tiers")
            if not isinstance(raw_tiers, list) or not raw_tiers:
                continue

            tiers = [
                PricingTier(
                    max_prompt_tokens=int(item.get("max_prompt_tokens", 0) or 0),
                    input_cost_per_token=float(
                        item.get("input_cost_per_token", 0.0) or 0.0
                    ),
                    output_cost_per_token=float(
                        item.get("output_cost_per_token", 0.0) or 0.0
                    ),
                    cache_hit_cost_per_token=float(
                        item.get("cache_hit_cost_per_token", 0.0) or 0.0
                    ),
                )
                for item in raw_tiers
                if isinstance(item, dict)
            ]
            if not tiers:
                continue
            ordered_tiers = tuple(
                sorted(tiers, key=lambda item: item.max_prompt_tokens)
            )
            if model_code:
                pricing_by_model[model_code] = ordered_tiers
        return pricing_by_model

    def calculate_cost(
        self,
        *,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        cached_prompt_tokens: int = 0,
    ) -> float:
        """Return the catalog-based cost of one call.

        Token counts are clamped to be non-negative and the cached count is
        capped at the prompt count. The first tier whose ``max_prompt_tokens``
        covers the prompt is used, falling back to the last tier. Cached
        tokens are billed at the cache-hit rate when it is positive,
        otherwise at the input rate.

        Raises:
            ValueError: if ``model`` has no pricing entry.
        """
        tiers = self._pricing_by_model.get(model.strip().lower())
        if tiers is None:
            raise ValueError(f"unknown model pricing: {model}")

        normalized_prompt_tokens = max(int(prompt_tokens), 0)
        normalized_completion_tokens = max(int(completion_tokens), 0)
        normalized_cached_tokens = min(
            max(int(cached_prompt_tokens), 0), normalized_prompt_tokens
        )
        uncached_prompt_tokens = normalized_prompt_tokens - normalized_cached_tokens

        selected_tier = next(
            (
                tier
                for tier in tiers
                if normalized_prompt_tokens <= tier.max_prompt_tokens
            ),
            tiers[-1],
        )

        cached_token_rate = (
            selected_tier.cache_hit_cost_per_token
            if selected_tier.cache_hit_cost_per_token > 0
            else selected_tier.input_cost_per_token
        )

        return float(
            uncached_prompt_tokens * selected_tier.input_cost_per_token
            + normalized_cached_tokens * cached_token_rate
            + normalized_completion_tokens * selected_tier.output_cost_per_token
        )

    @staticmethod
    def _non_negative_int(summary: dict[str, Any], key: str, default: int = 0) -> int:
        """Read ``summary[key]`` as an int clamped at 0 (falsy values become 0)."""
        return max(int(summary.get(key, default) or 0), 0)

    @staticmethod
    def _flag(summary: dict[str, Any], key: str) -> bool:
        """Read ``summary[key]`` as an integer flag (missing or falsy is False)."""
        return bool(int(summary.get(key, 0) or 0))

    def build_usage_metadata(
        self,
        *,
        model: str,
        usage_summary: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Normalize a usage summary into a camelCase metadata dict with a cost.

        The provider-reported ``direct_cost`` is used only when usage is
        complete, the cost was observed and complete, and it parses to a
        finite non-negative number. Otherwise the cost is computed from the
        catalog and ``costSource`` names the reason for the fallback.

        Raises:
            ValueError: from ``calculate_cost`` when the fallback is needed
                and ``model`` has no pricing entry.
        """
        summary = usage_summary or {}
        read = self._non_negative_int

        input_tokens = read(summary, "input_tokens")
        output_tokens = read(summary, "output_tokens")
        total_tokens = read(summary, "total_tokens", input_tokens + output_tokens)
        latency_ms = read(summary, "latency_ms")
        cached_prompt_tokens = read(summary, "cached_prompt_tokens")
        prompt_cache_hit_tokens = read(
            summary, "prompt_cache_hit_tokens", cached_prompt_tokens
        )
        prompt_cache_miss_tokens = read(
            summary,
            "prompt_cache_miss_tokens",
            max(input_tokens - prompt_cache_hit_tokens, 0),
        )
        reasoning_tokens = read(summary, "reasoning_tokens")
        direct_cost_observed = self._flag(summary, "direct_cost_observed")
        direct_cost_complete = self._flag(summary, "direct_cost_complete")
        model_call_records = read(summary, "model_call_records")
        usage_records = read(summary, "usage_records")
        # Usage counts as complete when no call records were tracked or every
        # call record has a matching usage record.
        usage_complete = model_call_records == 0 or model_call_records == usage_records
        direct_cost = self._coerce_non_negative_float(summary.get("direct_cost"))

        if (
            usage_complete
            and direct_cost_observed
            and direct_cost_complete
            and direct_cost is not None
        ):
            cost = direct_cost
            cost_source = "provider"
        else:
            cost = self.calculate_cost(
                model=model,
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                cached_prompt_tokens=cached_prompt_tokens,
            )
            if not usage_complete:
                cost_source = "incomplete_usage_fallback"
            elif direct_cost_observed and not direct_cost_complete:
                cost_source = "catalog_fallback_incomplete_provider_cost"
            else:
                cost_source = "catalog_fallback"

        return {
            "model": model,
            "inputTokens": input_tokens,
            "outputTokens": output_tokens,
            "totalTokens": total_tokens,
            "cachedPromptTokens": cached_prompt_tokens,
            "promptCacheHitTokens": prompt_cache_hit_tokens,
            "promptCacheMissTokens": prompt_cache_miss_tokens,
            "reasoningTokens": reasoning_tokens,
            "cost": cost,
            "costSource": cost_source,
            "usageComplete": usage_complete,
            "latencyMs": latency_ms,
        }

    @staticmethod
    def _coerce_non_negative_float(value: Any) -> float | None:
        """Parse ``value`` as a finite, non-negative float.

        Returns:
            The parsed float, or None when ``value`` is None, cannot be
            parsed, is negative, or is NaN/infinite.
        """
        import math

        if value is None:
            return None
        try:
            parsed = float(value)
        except (TypeError, ValueError):
            return None
        # NaN compares false against 0, so it needs an explicit finiteness
        # check to be rejected alongside negative values.
        if not math.isfinite(parsed) or parsed < 0:
            return None
        return parsed
|
||||
Reference in New Issue
Block a user