Files
social-app/backend/src/core/agent/litellm_client.py
T
qzl 80cbb3512f refactor: 切换到 litellm,删除未使用的代码
- 添加 litellm 依赖,统一 LLM 调用层
- 新增 litellm_client.py 支持多厂商
- 更新 llm_catalog.yaml 添加 litellm_model 映射
- 删除旧的 cost_tracker.py (litellm 内置 cost 追踪)
- 删除未使用的 multimodal.py 和 storage_adapter.py
- 删除空文件 crewai/__init__.py, tools/__init__.py
- 更新测试以适配新代码
2026-03-03 17:52:34 +08:00

131 lines
4.0 KiB
Python

from __future__ import annotations
import os
from dataclasses import dataclass
from decimal import Decimal
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class LLMConfig:
    """Immutable description of one model entry resolved from the LLM catalog.

    Field order is part of the positional constructor signature and must not
    be rearranged.
    """

    # Identifier used to look the model up in the catalog.
    model_code: str
    # Name of the factory (vendor) entry the model belongs to; it also selects
    # which environment variable supplies the API key.
    factory_name: str
    # Model string handed to litellm as the ``model`` argument.
    litellm_model: str
    # Endpoint handed to litellm as ``api_base``.
    request_url: str
@dataclass(frozen=True)
class LLMResponse:
    """Immutable result of a single chat completion call."""

    # Text of the completion; an empty string when the provider returned none.
    content: str
    # Usage data dumped from the completion response; empty dict when absent.
    usage: dict[str, Any]
class LiteLLMClient:
    """Client that sends chat requests for one configured model through litellm.

    The client pairs an ``LLMConfig`` with an API key and offers a blocking
    (``chat``) and an awaitable (``achat``) entry point. Both build the same
    request and convert the reply into an ``LLMResponse``.
    """

    # Maps a catalog factory name to the environment variable holding its key.
    _API_KEY_ENV_VARS = {
        "dashscope": "DASHSCOPE_API_KEY",
        "minimax": "MINIMAX_API_KEY",
        "moonshot": "MOONSHOT_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "volcengine-ark": "ARK_API_KEY",
        "z-ai": "ZAI_API_KEY",
    }

    def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
        """Store the model config and resolve the API key.

        Args:
            config: Model settings, e.g. as returned by ``load_config``.
            api_key: Explicit key. When omitted or empty, the key is read from
                the environment variable mapped to ``config.factory_name``.

        Raises:
            ValueError: If no key is given and none can be resolved from the
                environment.
        """
        self._config = config
        self._api_key = api_key or self._get_api_key(config.factory_name)

    @staticmethod
    def _get_api_key(factory_name: str) -> str:
        """Return the API key for *factory_name* from the environment.

        Raises:
            ValueError: If the factory has no environment-variable mapping, or
                the mapped variable is unset or empty.
        """
        env_key = LiteLLMClient._API_KEY_ENV_VARS.get(factory_name)
        if not env_key:
            raise ValueError(f"No API key mapping for factory: {factory_name}")
        key = os.environ.get(env_key)
        if not key:
            raise ValueError(f"Environment variable {env_key} is not set")
        return key

    @staticmethod
    def load_config(
        model_code: str,
        static_root: Path | None = None,
    ) -> LLMConfig:
        """Look up *model_code* in ``llm_catalog.yaml`` and build its config.

        Args:
            model_code: Value matched against each entry's ``model_code`` key.
            static_root: Directory containing ``llm_catalog.yaml``. Defaults to
                ``config/static/database`` under the third parent directory of
                this file.

        Returns:
            The ``LLMConfig`` for the first matching entry. ``litellm_model``
            falls back to *model_code* when the entry does not define it.

        Raises:
            ValueError: If no entry matches *model_code* (including when the
                catalog is empty), or the entry's factory is not listed.
            KeyError: If a factory lacks ``name``/``request_url`` or the
                matching entry lacks ``factory_name``.
        """
        root = static_root or (
            Path(__file__).resolve().parents[3] / "config" / "static" / "database"
        )
        yaml_path = root / "llm_catalog.yaml"
        with yaml_path.open("r", encoding="utf-8") as catalog_file:
            # safe_load returns None for an empty document; treat that as an
            # empty catalog so the lookup fails with the usual ValueError.
            data = yaml.safe_load(catalog_file) or {}

        # ``or []`` also covers sections that are present but null in the YAML.
        factories = {
            factory["name"]: factory for factory in data.get("factories") or []
        }
        for llm in data.get("llms") or []:
            if llm.get("model_code") != model_code:
                continue
            factory_name = llm["factory_name"]
            factory = factories.get(factory_name)
            if not factory:
                raise ValueError(f"Factory not found: {factory_name}")
            return LLMConfig(
                model_code=model_code,
                factory_name=factory_name,
                litellm_model=llm.get("litellm_model", model_code),
                request_url=factory["request_url"],
            )
        raise ValueError(f"Model not found: {model_code}")

    def _completion_kwargs(
        self,
        messages: list[dict[str, str]],
        temperature: float,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Build the keyword arguments shared by the sync and async calls."""
        return {
            "model": self._config.litellm_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "api_base": self._config.request_url,
            "api_key": self._api_key,
        }

    @staticmethod
    def _to_response(response: Any) -> LLMResponse:
        """Convert a litellm completion result into an ``LLMResponse``."""
        # Only the first choice is used; missing content becomes "".
        content = response.choices[0].message.content or ""
        usage = response.usage.model_dump() if response.usage else {}
        return LLMResponse(content=content, usage=usage)

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Send *messages* to the configured model and wait for the reply.

        Args:
            messages: Chat messages forwarded unchanged to litellm.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Token limit forwarded to litellm; ``None`` is passed
                through as-is.

        Returns:
            The first choice's text and the usage data of the completion.
        """
        # Imported lazily so the module can be loaded without litellm.
        import litellm

        response = litellm.completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)

    async def achat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Awaitable counterpart of ``chat`` with the same arguments and result."""
        # Imported lazily so the module can be loaded without litellm.
        import litellm

        response = await litellm.acompletion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)
def get_model_cost(usage: dict[str, Any]) -> Decimal:
    """Return the cost stored in a usage mapping as a ``Decimal``.

    Args:
        usage: Usage mapping, e.g. the ``usage`` field of an ``LLMResponse``.

    Returns:
        The value under the ``"cost"`` key converted through its string form,
        or ``Decimal("0")`` when the key is missing or ``None``.
    """
    # NOTE(review): whether the usage dict produced upstream actually carries
    # a "cost" key is not visible from this function -- confirm.
    raw_cost = usage.get("cost")
    # Converting via str() keeps a float's short repr rather than its full
    # binary expansion.
    return Decimal("0") if raw_cost is None else Decimal(str(raw_cost))