Files
social-app/backend/src/core/agent/litellm_client.py
T

131 lines
4.0 KiB
Python
Raw Normal View History

from __future__ import annotations
import os
from dataclasses import dataclass
from decimal import Decimal
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class LLMConfig:
    """Immutable description of one LLM entry and the factory that serves it."""

    # Identifier used to look the model up.
    model_code: str
    # Name of the factory (provider) the model belongs to.
    factory_name: str
    # Model string handed to litellm as ``model``.
    litellm_model: str
    # Base URL handed to litellm as ``api_base``.
    request_url: str
@dataclass(frozen=True)
class LLMResponse:
    """Immutable result of a single chat completion call."""

    # Text of the reply.
    content: str
    # Usage information reported with the completion, as a plain dict.
    usage: dict[str, Any]
class LiteLLMClient:
    """Client that sends chat completions for one configured model via litellm.

    The model, endpoint and provider come from an ``LLMConfig``; the API key
    is either passed in explicitly or read from a provider-specific
    environment variable.
    """

    # Factory (provider) name -> environment variable holding its API key.
    _API_KEY_ENV_VARS = {
        "dashscope": "DASHSCOPE_API_KEY",
        "minimax": "MINIMAX_API_KEY",
        "moonshot": "MOONSHOT_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "volcengine-ark": "ARK_API_KEY",
        "z-ai": "ZAI_API_KEY",
    }

    def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
        """Store *config* and resolve the API key.

        Args:
            config: Model/factory settings used for every request.
            api_key: Explicit key. When omitted or empty, the key is read
                from the environment variable mapped to
                ``config.factory_name``.

        Raises:
            ValueError: If no key is given and the factory has no mapping or
                its environment variable is unset/empty.
        """
        self._config = config
        self._api_key = api_key or self._get_api_key(config.factory_name)

    @staticmethod
    def _get_api_key(factory_name: str) -> str:
        """Return the API key for *factory_name* from the environment.

        Raises:
            ValueError: If the factory is unknown, or its environment
                variable is missing or empty.
        """
        env_key = LiteLLMClient._API_KEY_ENV_VARS.get(factory_name)
        if not env_key:
            raise ValueError(f"No API key mapping for factory: {factory_name}")
        key = os.environ.get(env_key)
        if not key:
            raise ValueError(f"Environment variable {env_key} is not set")
        return key

    @staticmethod
    def load_config(
        model_code: str,
        static_root: Path | None = None,
    ) -> LLMConfig:
        """Build an ``LLMConfig`` for *model_code* from ``llm_catalog.yaml``.

        Args:
            model_code: Value matched against each ``llms[*].model_code``.
            static_root: Directory containing ``llm_catalog.yaml``. Defaults
                to ``config/static/database`` under the fourth ancestor
                directory of this file.

        Returns:
            The configuration of the first matching catalog entry.

        Raises:
            ValueError: If no entry matches *model_code* (including when the
                catalog is empty), or the entry names an unknown factory.
            OSError: If the catalog file cannot be opened.
        """
        root = static_root or (
            Path(__file__).resolve().parents[3] / "config" / "static" / "database"
        )
        yaml_path = root / "llm_catalog.yaml"
        with yaml_path.open("r", encoding="utf-8") as catalog_file:
            # safe_load yields None for an empty document; treat it as an
            # empty catalog so the caller gets "Model not found" instead of
            # an AttributeError.
            data = yaml.safe_load(catalog_file) or {}

        # ``or []`` also covers sections that are present but null.
        factories = {
            entry["name"]: entry for entry in data.get("factories") or []
        }
        for llm in data.get("llms") or []:
            if llm.get("model_code") != model_code:
                continue
            factory_name = llm["factory_name"]
            factory = factories.get(factory_name)
            if not factory:
                raise ValueError(f"Factory not found: {factory_name}")
            return LLMConfig(
                model_code=model_code,
                factory_name=factory_name,
                # Fall back to the model code when the key is absent or null,
                # so the field is always a usable string.
                litellm_model=llm.get("litellm_model") or model_code,
                request_url=factory["request_url"],
            )
        raise ValueError(f"Model not found: {model_code}")

    def _completion_kwargs(
        self,
        messages: list[dict[str, str]],
        temperature: float,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments shared by ``chat`` and ``achat``."""
        return {
            "model": self._config.litellm_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "api_base": self._config.request_url,
            "api_key": self._api_key,
        }

    @staticmethod
    def _to_response(response: Any) -> LLMResponse:
        """Convert a litellm completion result into an ``LLMResponse``."""
        # A missing/None message content is reported as an empty string.
        content = response.choices[0].message.content or ""
        usage = response.usage.model_dump() if response.usage else {}
        return LLMResponse(content=content, usage=usage)

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Run a synchronous chat completion.

        Args:
            messages: Chat messages, passed to litellm unchanged.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Token limit forwarded to litellm (``None`` is
                forwarded as-is).

        Returns:
            The first choice's content and the reported usage.
        """
        # Imported lazily so the module can be loaded without litellm.
        import litellm

        response = litellm.completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)

    async def achat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Run an asynchronous chat completion.

        Takes the same arguments and returns the same result as ``chat``,
        but awaits ``litellm.acompletion``.
        """
        # Imported lazily so the module can be loaded without litellm.
        import litellm

        response = await litellm.acompletion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)
def get_model_cost(usage: dict[str, Any]) -> Decimal:
    """Return the ``"cost"`` entry of *usage* as a ``Decimal``.

    A missing or ``None`` cost yields ``Decimal("0")``. Any other value is
    converted through its string form, so a float such as ``0.1`` becomes
    ``Decimal("0.1")`` rather than its binary expansion.
    """
    raw_cost = usage.get("cost")
    return Decimal("0") if raw_cost is None else Decimal(str(raw_cost))