from __future__ import annotations

import os
from dataclasses import dataclass
from decimal import Decimal
from pathlib import Path
from typing import Any

import yaml

# Provider (factory) name -> environment variable that holds its API key.
_API_KEY_ENV_VARS: dict[str, str] = {
    "dashscope": "DASHSCOPE_API_KEY",
    "minimax": "MINIMAX_API_KEY",
    "moonshot": "MOONSHOT_API_KEY",
    "deepseek": "DEEPSEEK_API_KEY",
    "volcengine-ark": "ARK_API_KEY",
    "z-ai": "ZAI_API_KEY",
}


@dataclass(frozen=True)
class LLMConfig:
    """Resolved settings needed to call one catalog model through LiteLLM."""

    model_code: str  # catalog identifier this config was looked up by
    factory_name: str  # provider name; also selects the API-key env var
    litellm_model: str  # value passed to litellm as ``model=``
    request_url: str  # value passed to litellm as ``api_base=``


@dataclass(frozen=True)
class LLMResponse:
    """Text and usage metadata extracted from a chat completion."""

    content: str  # first choice's message content ("" when the provider returned None)
    usage: dict[str, Any]  # ``response.usage.model_dump()``, or {} when usage is absent


class LiteLLMClient:
    """Chat client that forwards requests to ``litellm`` for one configured model."""

    def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
        """Bind the client to *config*.

        Args:
            config: Model/provider settings, e.g. from :meth:`load_config`.
            api_key: Explicit API key. When omitted (or empty) the key is read
                from the provider's environment variable.

        Raises:
            ValueError: If no key is given and none can be resolved from the
                environment.
        """
        self._config = config
        self._api_key = api_key or self._get_api_key(config.factory_name)

    @staticmethod
    def _get_api_key(factory_name: str) -> str:
        """Return the API key for *factory_name* from its environment variable.

        Raises:
            ValueError: If the provider has no known env-var mapping, or the
                mapped variable is unset/empty.
        """
        env_key = _API_KEY_ENV_VARS.get(factory_name)
        if not env_key:
            raise ValueError(f"No API key mapping for factory: {factory_name}")
        key = os.environ.get(env_key)
        if not key:
            raise ValueError(f"Environment variable {env_key} is not set")
        return key

    @staticmethod
    def load_config(
        model_code: str,
        static_root: Path | None = None,
    ) -> LLMConfig:
        """Look up *model_code* in ``llm_catalog.yaml`` and build its config.

        Args:
            model_code: The ``model_code`` of the catalog entry to load.
            static_root: Directory containing ``llm_catalog.yaml``. Defaults to
                ``<parents[3] of this file>/config/static/database``.

        Returns:
            The resolved :class:`LLMConfig`. ``litellm_model`` falls back to
            *model_code* when the entry omits it or sets it to null.

        Raises:
            ValueError: If the model is not in the catalog, or it references a
                factory that is not defined.
        """
        root = static_root or (
            Path(__file__).resolve().parents[3] / "config" / "static" / "database"
        )
        yaml_path = root / "llm_catalog.yaml"
        with yaml_path.open("r", encoding="utf-8") as fh:
            # safe_load yields None for an empty document; treat it as an empty catalog.
            data = yaml.safe_load(fh) or {}

        # ``or []`` also covers keys that are present but explicitly null.
        factories = {entry["name"]: entry for entry in data.get("factories") or []}
        for llm in data.get("llms") or []:
            if llm.get("model_code") != model_code:
                continue
            factory_name = llm["factory_name"]
            factory = factories.get(factory_name)
            if not factory:
                raise ValueError(f"Factory not found: {factory_name}")
            return LLMConfig(
                model_code=model_code,
                factory_name=factory_name,
                # A missing *or* null litellm_model falls back to the catalog code.
                litellm_model=llm.get("litellm_model") or model_code,
                request_url=factory["request_url"],
            )
        raise ValueError(f"Model not found: {model_code}")

    def _completion_kwargs(
        self,
        messages: list[dict[str, str]],
        temperature: float,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Build the keyword arguments shared by the sync and async litellm calls."""
        return {
            "model": self._config.litellm_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "api_base": self._config.request_url,
            "api_key": self._api_key,
        }

    @staticmethod
    def _to_response(response: Any) -> LLMResponse:
        """Extract the first choice's content and the usage dict from *response*."""
        content = response.choices[0].message.content or ""
        usage_obj = getattr(response, "usage", None)
        usage = usage_obj.model_dump() if usage_obj else {}
        return LLMResponse(content=content, usage=usage)

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Send *messages* via ``litellm.completion`` and return the reply."""
        import litellm  # imported lazily so importing this module stays cheap

        response = litellm.completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)

    async def achat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Async variant of :meth:`chat` using ``litellm.acompletion``."""
        import litellm  # imported lazily so importing this module stays cheap

        response = await litellm.acompletion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)


def get_model_cost(usage: dict[str, Any]) -> Decimal:
    """Return ``usage["cost"]`` as a Decimal, or ``Decimal("0")`` if missing/None.

    NOTE(review): which providers actually populate a ``cost`` key in the usage
    dict is not visible here — confirm against the litellm response schema.
    """
    cost = usage.get("cost")
    if cost is None:
        return Decimal("0")
    # Go through str() so float costs keep their printed value (no binary noise).
    return Decimal(str(cost))