80cbb3512f
- 添加 litellm 依赖,统一 LLM 调用层 - 新增 litellm_client.py 支持多厂商 - 更新 llm_catalog.yaml 添加 litellm_model 映射 - 删除旧的 cost_tracker.py (litellm 内置 cost 追踪) - 删除未使用的 multimodal.py 和 storage_adapter.py - 删除空文件 crewai/__init__.py, tools/__init__.py - 更新测试以适配新代码
131 lines
4.0 KiB
Python
131 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from dataclasses import dataclass
|
|
from decimal import Decimal
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
|
|
@dataclass(frozen=True)
class LLMConfig:
    """Immutable description of one model and the provider endpoint serving it.

    Instances are frozen: attributes cannot be reassigned after construction.
    """

    # Identifier of the model entry this configuration was built for.
    model_code: str
    # Name of the provider ("factory") the model belongs to.
    factory_name: str
    # Model name handed to litellm as the ``model`` argument.
    litellm_model: str
    # Endpoint handed to litellm as the ``api_base`` argument.
    request_url: str
|
|
|
|
|
|
@dataclass(frozen=True)
class LLMResponse:
    """Immutable result of a single completion request."""

    # Text of the reply.
    content: str
    # Usage information reported for the request, as a plain dictionary.
    usage: dict[str, Any]
|
|
|
|
|
|
class LiteLLMClient:
    """Wrapper around litellm's completion calls for one configured model.

    The client holds an ``LLMConfig`` plus the API key of its provider and
    offers a blocking (``chat``) and an awaitable (``achat``) entry point.
    Both build their request and parse the reply through the same private
    helpers, so they always behave alike.
    """

    def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
        """Store the configuration and resolve the API key.

        Args:
            config: Model and provider settings used for every request.
            api_key: Explicit key. When omitted or empty, the key is read from
                the environment variable mapped to ``config.factory_name``.

        Raises:
            ValueError: If no key is passed and the provider has no
                environment-variable mapping or that variable is unset/empty.
        """
        self._config = config
        self._api_key = api_key or self._get_api_key(config.factory_name)

    @staticmethod
    def _get_api_key(factory_name: str) -> str:
        """Return the API key for ``factory_name`` from the environment.

        Raises:
            ValueError: If the provider is unknown, or its environment
                variable is missing or empty.
        """
        # Provider name -> environment variable holding that provider's key.
        key_map = {
            "dashscope": "DASHSCOPE_API_KEY",
            "minimax": "MINIMAX_API_KEY",
            "moonshot": "MOONSHOT_API_KEY",
            "deepseek": "DEEPSEEK_API_KEY",
            "volcengine-ark": "ARK_API_KEY",
            "z-ai": "ZAI_API_KEY",
        }
        env_key = key_map.get(factory_name)
        if not env_key:
            raise ValueError(f"No API key mapping for factory: {factory_name}")
        key = os.environ.get(env_key)
        if not key:
            raise ValueError(f"Environment variable {env_key} is not set")
        return key

    @staticmethod
    def load_config(
        model_code: str,
        static_root: Path | None = None,
    ) -> LLMConfig:
        """Build the ``LLMConfig`` for ``model_code`` from ``llm_catalog.yaml``.

        The catalog is read as a mapping with two lists: ``factories`` (entries
        with ``name`` and ``request_url``) and ``llms`` (entries with
        ``model_code``, ``factory_name`` and an optional ``litellm_model``).
        The first ``llms`` entry whose ``model_code`` matches is used.

        Args:
            model_code: Model identifier to look up.
            static_root: Directory containing ``llm_catalog.yaml``. Defaults
                to ``config/static/database`` under the third parent directory
                of this file.

        Returns:
            The configuration for the matching model. ``litellm_model`` falls
            back to ``model_code`` when the entry does not define it.

        Raises:
            ValueError: If the model is not in the catalog (including an empty
                catalog), or its factory is not defined.
            KeyError: If a catalog entry lacks a required key.
            OSError: If the catalog file cannot be opened.
        """
        root = static_root or (
            Path(__file__).resolve().parents[3] / "config" / "static" / "database"
        )
        yaml_path = root / "llm_catalog.yaml"
        with yaml_path.open("r", encoding="utf-8") as f:
            # safe_load yields None for an empty document; treat that as an
            # empty catalog so the lookup fails with the regular ValueError.
            data = yaml.safe_load(f) or {}

        # "key:" with no value parses as None, hence "or []" instead of a
        # .get() default (which only applies when the key is absent).
        factories = {
            entry["name"]: entry for entry in data.get("factories") or []
        }

        for llm in data.get("llms") or []:
            if llm.get("model_code") != model_code:
                continue
            factory_name = llm["factory_name"]
            factory = factories.get(factory_name)
            if not factory:
                raise ValueError(f"Factory not found: {factory_name}")
            return LLMConfig(
                model_code=model_code,
                factory_name=factory_name,
                litellm_model=llm.get("litellm_model", model_code),
                request_url=factory["request_url"],
            )

        raise ValueError(f"Model not found: {model_code}")

    def _completion_kwargs(
        self,
        messages: list[dict[str, str]],
        temperature: float,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments shared by ``chat`` and ``achat``."""
        return {
            "model": self._config.litellm_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "api_base": self._config.request_url,
            "api_key": self._api_key,
        }

    @staticmethod
    def _to_response(response: Any) -> LLMResponse:
        """Convert a litellm completion result into an ``LLMResponse``.

        Only the first choice is read; missing (``None``) content becomes an
        empty string and a missing usage object becomes an empty dict.
        """
        content = response.choices[0].message.content or ""
        # NOTE(review): usage is the raw model_dump() of the usage object.
        # Which keys it carries (e.g. a "cost" entry) is decided by litellm;
        # confirm before relying on a specific key downstream.
        usage = response.usage.model_dump() if response.usage else {}
        return LLMResponse(content=content, usage=usage)

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Send ``messages`` to the configured model and wait for the reply.

        Args:
            messages: Chat messages forwarded unchanged to litellm.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Optional completion length limit forwarded to litellm.

        Returns:
            The first choice's text and the reported usage.
        """
        # Imported at call time, so litellm is only loaded once a request is
        # actually made.
        import litellm

        response = litellm.completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)

    async def achat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Awaitable counterpart of ``chat`` using ``litellm.acompletion``.

        Args:
            messages: Chat messages forwarded unchanged to litellm.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Optional completion length limit forwarded to litellm.

        Returns:
            The first choice's text and the reported usage.
        """
        # Imported at call time, so litellm is only loaded once a request is
        # actually made.
        import litellm

        response = await litellm.acompletion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)
|
|
|
|
|
|
def get_model_cost(usage: dict[str, Any]) -> Decimal:
    """Return the value stored under ``"cost"`` in ``usage`` as a ``Decimal``.

    Args:
        usage: Usage dictionary that may contain a ``"cost"`` entry.

    Returns:
        ``Decimal("0")`` when the entry is absent or ``None``; otherwise the
        entry converted via its string form, so a float such as ``0.1``
        becomes ``Decimal("0.1")`` rather than its binary expansion.
    """
    raw_cost = usage.get("cost")
    return Decimal("0") if raw_cost is None else Decimal(str(raw_cost))
|