feat: 添加好友功能并集成 LiteLLM 代理服务
- 新增好友搜索、添加、好友列表功能 - 集成 LiteLLM 代理服务及多模型定价配置 - 更新 iOS CocoaPods 配置 - 更新 .gitignore 和环境变量配置
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.litellm.service import (
|
||||
LiteLLMResponseWithCost,
|
||||
LiteLLMService,
|
||||
LiteLLMUsage,
|
||||
)
|
||||
|
||||
__all__ = ["LiteLLMService", "LiteLLMUsage", "LiteLLMResponseWithCost"]
|
||||
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
from litellm import completion
|
||||
|
||||
from core.config.settings import config
|
||||
from core.config.initial.init_data import load_llm_catalog
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PricingTier:
|
||||
max_prompt_tokens: int
|
||||
input_cost_per_token: float
|
||||
output_cost_per_token: float
|
||||
cache_hit_cost_per_token: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteLLMUsage:
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
cached_prompt_tokens: int
|
||||
cost: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteLLMResponseWithCost:
|
||||
response: dict[str, Any]
|
||||
usage: LiteLLMUsage
|
||||
|
||||
|
||||
class LiteLLMService:
|
||||
proxy_base_url: str
|
||||
proxy_api_key: str
|
||||
_pricing_by_model: dict[str, tuple[PricingTier, ...]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
proxy_base_url: str | None = None,
|
||||
proxy_api_key: str | None = None,
|
||||
) -> None:
|
||||
self.proxy_base_url = proxy_base_url or config.litellm.base_url
|
||||
self.proxy_api_key = proxy_api_key or config.litellm.api_key
|
||||
self._pricing_by_model = self._build_pricing_map()
|
||||
|
||||
@staticmethod
|
||||
def _build_pricing_map() -> dict[str, tuple[PricingTier, ...]]:
|
||||
catalog = load_llm_catalog()
|
||||
pricing_by_model: dict[str, tuple[PricingTier, ...]] = {}
|
||||
for model in catalog.get("llms", []):
|
||||
if not isinstance(model, dict):
|
||||
continue
|
||||
model_code = str(model.get("model_code", "")).strip().lower()
|
||||
litellm_model = str(model.get("litellm_model", "")).strip().lower()
|
||||
raw_tiers = model.get("pricing_tiers")
|
||||
if not isinstance(raw_tiers, list) or not raw_tiers:
|
||||
continue
|
||||
|
||||
tiers = [
|
||||
PricingTier(
|
||||
max_prompt_tokens=int(item.get("max_prompt_tokens", 0) or 0),
|
||||
input_cost_per_token=float(
|
||||
item.get("input_cost_per_token", 0.0) or 0.0
|
||||
),
|
||||
output_cost_per_token=float(
|
||||
item.get("output_cost_per_token", 0.0) or 0.0
|
||||
),
|
||||
cache_hit_cost_per_token=float(
|
||||
item.get("cache_hit_cost_per_token", 0.0) or 0.0
|
||||
),
|
||||
)
|
||||
for item in raw_tiers
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
if not tiers:
|
||||
continue
|
||||
ordered_tiers = tuple(
|
||||
sorted(tiers, key=lambda item: item.max_prompt_tokens)
|
||||
)
|
||||
if model_code:
|
||||
pricing_by_model[model_code] = ordered_tiers
|
||||
if litellm_model:
|
||||
pricing_by_model[litellm_model] = ordered_tiers
|
||||
return pricing_by_model
|
||||
|
||||
def calculate_cost(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cached_prompt_tokens: int = 0,
|
||||
) -> float:
|
||||
tiers = self._pricing_by_model.get(model.strip().lower())
|
||||
if tiers is None:
|
||||
raise ValueError(f"unknown model pricing: {model}")
|
||||
|
||||
normalized_prompt_tokens = max(int(prompt_tokens), 0)
|
||||
normalized_completion_tokens = max(int(completion_tokens), 0)
|
||||
normalized_cached_tokens = min(
|
||||
max(int(cached_prompt_tokens), 0), normalized_prompt_tokens
|
||||
)
|
||||
uncached_prompt_tokens = normalized_prompt_tokens - normalized_cached_tokens
|
||||
|
||||
selected_tier = tiers[-1]
|
||||
for tier in tiers:
|
||||
if normalized_prompt_tokens <= tier.max_prompt_tokens:
|
||||
selected_tier = tier
|
||||
break
|
||||
|
||||
return float(
|
||||
uncached_prompt_tokens * selected_tier.input_cost_per_token
|
||||
+ normalized_cached_tokens * selected_tier.cache_hit_cost_per_token
|
||||
+ normalized_completion_tokens * selected_tier.output_cost_per_token
|
||||
)
|
||||
|
||||
def run_completion_with_cost(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
messages: list[dict[str, Any]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
completion_fn: Callable[..., dict[str, Any]] | None = None,
|
||||
) -> LiteLLMResponseWithCost:
|
||||
caller = completion_fn or completion
|
||||
request_model = model if model.startswith("openai/") else f"openai/{model}"
|
||||
|
||||
response_any = caller(
|
||||
model=request_model,
|
||||
api_key=self.proxy_api_key,
|
||||
api_base=self.proxy_base_url,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
timeout=timeout,
|
||||
stream=False,
|
||||
)
|
||||
response = self._normalize_response(response_any)
|
||||
|
||||
usage_raw = response.get("usage")
|
||||
if not isinstance(usage_raw, dict):
|
||||
raise ValueError("missing usage in response")
|
||||
|
||||
prompt_tokens = int(usage_raw.get("prompt_tokens", 0) or 0)
|
||||
completion_tokens = int(usage_raw.get("completion_tokens", 0) or 0)
|
||||
total_tokens = int(
|
||||
usage_raw.get("total_tokens", prompt_tokens + completion_tokens) or 0
|
||||
)
|
||||
cached_prompt_tokens = 0
|
||||
prompt_tokens_details = usage_raw.get("prompt_tokens_details")
|
||||
if isinstance(prompt_tokens_details, dict):
|
||||
cached_prompt_tokens = int(
|
||||
prompt_tokens_details.get("cached_tokens", 0) or 0
|
||||
)
|
||||
|
||||
resolved_model = str(response.get("model", model)).strip()
|
||||
cost = self.calculate_cost(
|
||||
model=resolved_model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cached_prompt_tokens=cached_prompt_tokens,
|
||||
)
|
||||
return LiteLLMResponseWithCost(
|
||||
response=response,
|
||||
usage=LiteLLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cached_prompt_tokens=cached_prompt_tokens,
|
||||
cost=cost,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_response(response_any: Any) -> dict[str, Any]:
|
||||
if isinstance(response_any, dict):
|
||||
return response_any
|
||||
model_dump = getattr(response_any, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
dumped = model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return dumped
|
||||
raise ValueError("litellm response is not serializable")
|
||||
Reference in New Issue
Block a user