refactor: 切换到 litellm,删除未使用的代码
- 添加 litellm 依赖,统一 LLM 调用层 - 新增 litellm_client.py 支持多厂商 - 更新 llm_catalog.yaml 添加 litellm_model 映射 - 删除旧的 cost_tracker.py (litellm 内置 cost 追踪) - 删除未使用的 multimodal.py 和 storage_adapter.py - 删除空文件 crewai/__init__.py, tools/__init__.py - 更新测试以适配新代码
This commit is contained in:
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig:
|
||||
model_code: str
|
||||
factory_name: str
|
||||
litellm_model: str
|
||||
request_url: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMResponse:
|
||||
content: str
|
||||
usage: dict[str, Any]
|
||||
|
||||
|
||||
class LiteLLMClient:
|
||||
def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
|
||||
self._config = config
|
||||
self._api_key = api_key or self._get_api_key(config.factory_name)
|
||||
|
||||
@staticmethod
|
||||
def _get_api_key(factory_name: str) -> str:
|
||||
key_map = {
|
||||
"dashscope": "DASHSCOPE_API_KEY",
|
||||
"minimax": "MINIMAX_API_KEY",
|
||||
"moonshot": "MOONSHOT_API_KEY",
|
||||
"deepseek": "DEEPSEEK_API_KEY",
|
||||
"volcengine-ark": "ARK_API_KEY",
|
||||
"z-ai": "ZAI_API_KEY",
|
||||
}
|
||||
env_key = key_map.get(factory_name)
|
||||
if not env_key:
|
||||
raise ValueError(f"No API key mapping for factory: {factory_name}")
|
||||
key = os.environ.get(env_key)
|
||||
if not key:
|
||||
raise ValueError(f"Environment variable {env_key} is not set")
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def load_config(
|
||||
model_code: str,
|
||||
static_root: Path | None = None,
|
||||
) -> LLMConfig:
|
||||
root = static_root or (
|
||||
Path(__file__).resolve().parents[3] / "config" / "static" / "database"
|
||||
)
|
||||
yaml_path = root / "llm_catalog.yaml"
|
||||
with yaml_path.open("r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
factories = {f["name"]: f for f in data.get("factories", [])}
|
||||
llms = data.get("llms", [])
|
||||
|
||||
for llm in llms:
|
||||
if llm.get("model_code") == model_code:
|
||||
factory_name = llm["factory_name"]
|
||||
factory = factories.get(factory_name)
|
||||
if not factory:
|
||||
raise ValueError(f"Factory not found: {factory_name}")
|
||||
return LLMConfig(
|
||||
model_code=model_code,
|
||||
factory_name=factory_name,
|
||||
litellm_model=llm.get("litellm_model", model_code),
|
||||
request_url=factory["request_url"],
|
||||
)
|
||||
|
||||
raise ValueError(f"Model not found: {model_code}")
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
) -> LLMResponse:
|
||||
import litellm
|
||||
|
||||
response = litellm.completion( # type: ignore[attr-defined]
|
||||
model=self._config.litellm_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_base=self._config.request_url,
|
||||
api_key=self._api_key,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content or "" # type: ignore[attr-defined]
|
||||
usage = response.usage.model_dump() if response.usage else {} # type: ignore[attr-defined]
|
||||
|
||||
return LLMResponse(content=content, usage=usage)
|
||||
|
||||
async def achat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
) -> LLMResponse:
|
||||
import litellm
|
||||
|
||||
response = await litellm.acompletion( # type: ignore[attr-defined]
|
||||
model=self._config.litellm_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_base=self._config.request_url,
|
||||
api_key=self._api_key,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content or "" # type: ignore[attr-defined]
|
||||
usage = response.usage.model_dump() if response.usage else {} # type: ignore[attr-defined]
|
||||
|
||||
return LLMResponse(content=content, usage=usage)
|
||||
|
||||
|
||||
def get_model_cost(usage: dict[str, Any]) -> Decimal:
|
||||
cost = usage.get("cost")
|
||||
if cost is None:
|
||||
return Decimal("0")
|
||||
return Decimal(str(cost))
|
||||
Reference in New Issue
Block a user