refactor: 切换到 litellm,删除未使用的代码
- 添加 litellm 依赖,统一 LLM 调用层 - 新增 litellm_client.py 支持多厂商 - 更新 llm_catalog.yaml 添加 litellm_model 映射 - 删除旧的 cost_tracker.py (litellm 内置 cost 追踪) - 删除未使用的 multimodal.py 和 storage_adapter.py - 删除空文件 crewai/__init__.py, tools/__init__.py - 更新测试以适配新代码
This commit is contained in:
@@ -1,69 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
def _to_non_negative_int(value: Any, *, field: str) -> int:
    """Coerce ``value`` to a non-negative ``int``.

    Real integers and strings consisting only of decimal digits are
    accepted. Booleans are rejected even though ``bool`` subclasses ``int``.

    Args:
        value: Raw value to convert.
        field: Field name used as the prefix of error messages.

    Returns:
        The converted integer, guaranteed to be >= 0.

    Raises:
        ValueError: If ``value`` is not an integer or decimal-digit string,
            or if it is negative.
    """
    # bool must be tested first: isinstance(True, int) is True.
    if isinstance(value, bool):
        raise ValueError(f"{field} must be an integer")
    if isinstance(value, int):
        converted = value
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript digits that int() cannot parse, which would leak a
    # raw "invalid literal" error instead of the message below.
    elif isinstance(value, str) and value.isdecimal():
        converted = int(value)
    else:
        raise ValueError(f"{field} must be an integer")
    if converted < 0:
        raise ValueError(f"{field} cannot be negative")
    return converted
|
||||
|
||||
|
||||
def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
    """Coerce ``value`` to a finite, non-negative ``Decimal``.

    The value is converted through ``str()`` first, so floats are parsed
    from their printed representation.

    Args:
        value: Raw value to convert.
        field: Field name used as the prefix of error messages.

    Returns:
        The converted ``Decimal``, guaranteed to be finite and >= 0.

    Raises:
        ValueError: If ``value`` cannot be parsed as a decimal number, is
            NaN or infinite, or is negative.
    """
    # Local import: only Decimal is imported at module level.
    from decimal import InvalidOperation

    try:
        converted = Decimal(str(value))
    except InvalidOperation as exc:
        # Surface malformed input as ValueError, matching the integer helper.
        raise ValueError(f"{field} must be a decimal number") from exc
    # NaN would make the comparison below raise InvalidOperation, and
    # Infinity would pass it and poison any running total.
    if not converted.is_finite():
        raise ValueError(f"{field} must be a finite number")
    if converted < 0:
        raise ValueError(f"{field} cannot be negative")
    return converted
|
||||
|
||||
|
||||
class CostTracker:
    """Accumulate token counts and cost over several usage records.

    All records are expected to share the currency given at construction.
    """

    def __init__(self, *, currency: str = "USD") -> None:
        """Start with zeroed counters.

        Args:
            currency: Currency label every usage record must agree with.
        """
        self._input_tokens = 0
        self._output_tokens = 0
        self._total_tokens = 0
        self._cost = Decimal("0")
        self._currency = currency

    def add_usage(self, usage: Mapping[str, Any]) -> None:
        """Validate one usage record and add it to the running totals.

        Token counts are read from ``input_tokens``/``output_tokens`` and
        fall back to ``prompt_tokens``/``completion_tokens``. A missing
        ``total_tokens`` is replaced by input + output.

        Args:
            usage: Mapping with token counts and optional ``cost`` and
                ``currency`` entries.

        Raises:
            ValueError: If a field fails validation or the record's currency
                differs from the tracker's. The totals are left untouched
                in that case.
        """
        input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0))
        output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0))
        total_tokens = usage.get("total_tokens")
        cost = usage.get("cost", "0")
        currency = usage.get("currency")

        normalized_input = _to_non_negative_int(input_tokens, field="input_tokens")
        normalized_output = _to_non_negative_int(output_tokens, field="output_tokens")
        normalized_total = (
            _to_non_negative_int(total_tokens, field="total_tokens")
            if total_tokens is not None
            else normalized_input + normalized_output
        )
        normalized_cost = _to_non_negative_decimal(cost, field="cost")

        # Check the currency BEFORE mutating any counter; otherwise a record
        # in the wrong currency is added to the totals and then rejected.
        if currency is not None and str(currency) != self._currency:
            raise ValueError("currency mismatch")

        self._input_tokens += normalized_input
        self._output_tokens += normalized_output
        self._total_tokens += normalized_total
        self._cost += normalized_cost

    def snapshot(self) -> dict[str, Any]:
        """Return the current totals as a new dict (``cost`` is a ``Decimal``)."""
        return {
            "input_tokens": self._input_tokens,
            "output_tokens": self._output_tokens,
            "total_tokens": self._total_tokens,
            "cost": self._cost,
            "currency": self._currency,
        }
|
||||
@@ -1 +0,0 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMConfig:
    """Immutable connection settings for one catalog model."""

    # Identifier of the model in the catalog.
    model_code: str
    # Name of the factory entry the model belongs to.
    factory_name: str
    # Model string handed to litellm.
    litellm_model: str
    # URL passed to litellm as ``api_base``.
    request_url: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMResponse:
    """Result of a single chat completion call."""

    # Text of the first choice; empty string when the model returned none.
    content: str
    # Usage figures reported for the call; empty dict when none were reported.
    usage: dict[str, Any]
|
||||
|
||||
|
||||
class LiteLLMClient:
    """Chat-completion client that routes requests through litellm."""

    # Factory name -> environment variable holding that factory's API key.
    _API_KEY_ENV_VARS = {
        "dashscope": "DASHSCOPE_API_KEY",
        "minimax": "MINIMAX_API_KEY",
        "moonshot": "MOONSHOT_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "volcengine-ark": "ARK_API_KEY",
        "z-ai": "ZAI_API_KEY",
    }

    def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
        """Bind the client to a model configuration.

        Args:
            config: Model settings, e.g. from :meth:`load_config`.
            api_key: Explicit API key. When omitted or empty, the key is read
                from the environment variable mapped to ``config.factory_name``.

        Raises:
            ValueError: If no key is given and none can be resolved.
        """
        self._config = config
        self._api_key = api_key or self._get_api_key(config.factory_name)

    @staticmethod
    def _get_api_key(factory_name: str) -> str:
        """Return the API key for ``factory_name`` from the environment.

        Raises:
            ValueError: If the factory has no mapped environment variable, or
                that variable is unset or empty.
        """
        env_key = LiteLLMClient._API_KEY_ENV_VARS.get(factory_name)
        if not env_key:
            raise ValueError(f"No API key mapping for factory: {factory_name}")
        key = os.environ.get(env_key)
        if not key:
            raise ValueError(f"Environment variable {env_key} is not set")
        return key

    @staticmethod
    def load_config(
        model_code: str,
        static_root: Path | None = None,
    ) -> LLMConfig:
        """Build an :class:`LLMConfig` for ``model_code`` from ``llm_catalog.yaml``.

        Args:
            model_code: Catalog identifier of the model to look up.
            static_root: Directory containing ``llm_catalog.yaml``. Defaults
                to ``config/static/database`` three levels above this file.

        Returns:
            The configuration of the first catalog entry whose ``model_code``
            matches. ``litellm_model`` falls back to ``model_code`` when the
            entry does not define it.

        Raises:
            ValueError: If the model, or the factory it names, is not in the
                catalog.
        """
        root = static_root or (
            Path(__file__).resolve().parents[3] / "config" / "static" / "database"
        )
        yaml_path = root / "llm_catalog.yaml"
        with yaml_path.open("r", encoding="utf-8") as catalog_file:
            # safe_load returns None for an empty document; treat it as an
            # empty catalog so the caller gets "Model not found" instead of
            # an AttributeError.
            data = yaml.safe_load(catalog_file) or {}

        # "or []" also covers sections that are present but null in the YAML.
        factories = {entry["name"]: entry for entry in data.get("factories") or []}

        for llm in data.get("llms") or []:
            if llm.get("model_code") != model_code:
                continue
            factory_name = llm["factory_name"]
            factory = factories.get(factory_name)
            if not factory:
                raise ValueError(f"Factory not found: {factory_name}")
            return LLMConfig(
                model_code=model_code,
                factory_name=factory_name,
                litellm_model=llm.get("litellm_model", model_code),
                request_url=factory["request_url"],
            )

        raise ValueError(f"Model not found: {model_code}")

    def _completion_kwargs(
        self,
        messages: list[dict[str, str]],
        temperature: float,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments shared by ``chat`` and ``achat``."""
        return {
            "model": self._config.litellm_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "api_base": self._config.request_url,
            "api_key": self._api_key,
        }

    @staticmethod
    def _extract_usage(response: Any) -> dict[str, Any]:
        """Return the response's usage as a dict, including the cost if known.

        The dumped usage object carries token counts only. When the response
        exposes a ``response_cost`` through ``_hidden_params``, it is copied
        into the ``cost`` key unless the usage already has a non-null cost.
        """
        raw_usage = getattr(response, "usage", None)
        usage = raw_usage.model_dump() if raw_usage else {}

        hidden = getattr(response, "_hidden_params", None)
        # _hidden_params is read through .get() so both plain dicts and
        # dict-like objects work; a response without it has no cost.
        getter = getattr(hidden, "get", None)
        cost = getter("response_cost") if callable(getter) else None
        if cost is not None and usage.get("cost") is None:
            usage["cost"] = cost
        return usage

    @classmethod
    def _to_response(cls, response: Any) -> LLMResponse:
        """Convert a litellm completion response into an :class:`LLMResponse`."""
        content = response.choices[0].message.content or ""
        return LLMResponse(content=content, usage=cls._extract_usage(response))

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Run a blocking chat completion.

        Args:
            messages: Chat messages in the format litellm expects.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Optional completion token limit forwarded to litellm.

        Returns:
            The first choice's text (``""`` if empty) and the usage dict.
        """
        # Imported lazily so importing this module does not require litellm.
        import litellm

        response = litellm.completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)

    async def achat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Asynchronous counterpart of :meth:`chat` with the same contract."""
        # Imported lazily so importing this module does not require litellm.
        import litellm

        response = await litellm.acompletion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)
|
||||
|
||||
|
||||
def get_model_cost(usage: dict[str, Any]) -> Decimal:
    """Return the ``cost`` entry of ``usage`` as a ``Decimal``.

    A missing or ``None`` cost yields ``Decimal("0")``. Any other value is
    converted through its string form.
    """
    raw_cost = usage.get("cost")
    return Decimal("0") if raw_cost is None else Decimal(str(raw_cost))
|
||||
@@ -1,112 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from core.agent.storage_adapter import StorageAdapter
|
||||
|
||||
# MIME types accepted for attachments, grouped by media family.
_ALLOWED_MIME_TYPES = {
    # audio
    "audio/mpeg", "audio/wav", "audio/x-wav",
    # images
    "image/jpeg", "image/png", "image/webp",
    # documents and plain text
    "application/pdf", "text/plain",
}
|
||||
|
||||
|
||||
class _AsrTool(Protocol):
    """Structural type for objects that can transcribe audio bytes."""

    def transcribe(
        self,
        *,
        audio_bytes: bytes,
        filename: str,
    ) -> dict[str, object]:
        ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AttachmentInput:
    """One raw attachment: its name, declared MIME type and bytes."""

    # File name as supplied with the attachment.
    filename: str
    # Declared MIME type of the content.
    mime_type: str
    # Raw file bytes.
    content: bytes
    # Label describing where the attachment came from.
    origin: str = "user_upload"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProcessedAttachmentContext:
    """Outcome of processing a batch of attachments."""

    # One metadata dict per processed attachment.
    attachments: list[dict[str, object]]
    # Preview texts of the attachments that produced one.
    preview_texts: list[str]
|
||||
|
||||
|
||||
class MultimodalProcessor:
    """Validate attachments and derive storage metadata and preview text."""

    def __init__(
        self,
        *,
        storage: StorageAdapter,
        asr_tool: _AsrTool,
        max_file_size_mb: int = 20,
    ) -> None:
        """Configure the processor.

        Args:
            storage: Adapter used to build object paths and metadata dicts.
            asr_tool: Transcriber used to preview audio attachments.
            max_file_size_mb: Largest accepted attachment size, in MiB.
        """
        self._storage = storage
        self._asr_tool = asr_tool
        self._max_size_bytes = max_file_size_mb * 1024 * 1024

    def process(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        attachments: list[AttachmentInput],
    ) -> ProcessedAttachmentContext:
        """Process a batch of attachments belonging to one message.

        Args:
            user_id: Used as a component of the object path.
            session_id: Used as a component of the object path.
            message_seq: Used as a component of the object path.
            attachments: Attachments to validate and describe.

        Returns:
            One metadata dict per attachment, plus the preview texts of the
            attachments that produced a non-empty preview.

        Raises:
            ValueError: If any attachment has an unsupported MIME type or
                exceeds the size limit.
        """
        # Materialize once so the batch can be walked twice even if the
        # caller passed a one-shot iterable.
        batch = list(attachments)

        # Validate the whole batch up front: a bad attachment must fail the
        # request before any earlier one has been sent to the ASR tool.
        for attachment in batch:
            self._validate_attachment(attachment)

        metadata_list: list[dict[str, object]] = []
        preview_texts: list[str] = []

        for attachment in batch:
            checksum = hashlib.sha256(attachment.content).hexdigest()
            # Files without a suffix are stored with a "bin" extension.
            extension = Path(attachment.filename).suffix.strip(".").lower() or "bin"
            object_path = self._storage.build_object_path(
                user_id=user_id,
                session_id=session_id,
                message_seq=message_seq,
                checksum_sha256=checksum,
                extension=extension,
            )

            preview_text = self._build_preview_text(attachment)
            if preview_text:
                preview_texts.append(preview_text)

            metadata_list.append(
                self._storage.build_attachment_metadata(
                    object_path=object_path,
                    mime_type=attachment.mime_type,
                    size=len(attachment.content),
                    checksum_sha256=checksum,
                    origin=attachment.origin,
                    preview_text=preview_text,
                )
            )

        return ProcessedAttachmentContext(
            attachments=metadata_list,
            preview_texts=preview_texts,
        )

    def _validate_attachment(self, attachment: AttachmentInput) -> None:
        """Raise ``ValueError`` for a disallowed MIME type or oversized content."""
        if attachment.mime_type not in _ALLOWED_MIME_TYPES:
            raise ValueError("Unsupported MIME type")
        if len(attachment.content) > self._max_size_bytes:
            raise ValueError("Attachment exceeds max file size")

    def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
        """Return a preview for the attachment, or ``None`` if there is none.

        Audio is transcribed through the ASR tool and the result's ``text``
        entry is returned when it is a string. Plain text yields its first
        200 characters, decoded as UTF-8 with undecodable bytes dropped.
        Every other MIME type has no preview.
        """
        if attachment.mime_type.startswith("audio/"):
            result = self._asr_tool.transcribe(
                audio_bytes=attachment.content,
                filename=attachment.filename,
            )
            text = result.get("text")
            return text if isinstance(text, str) else None
        if attachment.mime_type == "text/plain":
            return attachment.content.decode("utf-8", errors="ignore")[:200]
        return None
|
||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from core.agent.cost_tracker import CostTracker
|
||||
from core.agent import events
|
||||
from core.agent.litellm_client import get_model_cost
|
||||
|
||||
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
|
||||
|
||||
@@ -20,6 +21,34 @@ class OrchestratorResult:
|
||||
error: str | None
|
||||
|
||||
|
||||
class _UsageTracker:
    """Running totals of token usage and cost across calls."""

    def __init__(self) -> None:
        """Start with zeroed counters."""
        self._input_tokens = 0
        self._output_tokens = 0
        self._total_tokens = 0
        self._cost = Decimal("0")

    def add_usage(self, usage: dict[str, Any]) -> None:
        """Add one usage dict to the totals.

        Token counts are read from ``prompt_tokens``/``completion_tokens``
        and fall back to ``input_tokens``/``output_tokens``. A missing, null
        or zero ``total_tokens`` is replaced by input + output.

        Args:
            usage: Usage dict of one call; the cost is obtained through
                ``get_model_cost``.
        """
        # The trailing "or 0" keeps explicit None values from reaching the
        # additions below, where they would raise TypeError.
        input_tokens = usage.get("prompt_tokens") or usage.get("input_tokens") or 0
        output_tokens = (
            usage.get("completion_tokens") or usage.get("output_tokens") or 0
        )
        total = usage.get("total_tokens") or (input_tokens + output_tokens)
        # Parse the cost before mutating anything so a failure there cannot
        # leave the token counters updated without the matching cost.
        cost = get_model_cost(usage)

        self._input_tokens += input_tokens
        self._output_tokens += output_tokens
        self._total_tokens += total
        self._cost += cost

    def snapshot(self) -> dict[str, Any]:
        """Return the current totals as a new dict (``cost`` is a string)."""
        return {
            "input_tokens": self._input_tokens,
            "output_tokens": self._output_tokens,
            "total_tokens": self._total_tokens,
            "cost": str(self._cost),
        }
|
||||
|
||||
|
||||
class AgentChatOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -36,7 +65,7 @@ class AgentChatOrchestrator:
|
||||
return asyncio.run(self.run(run_id=run_id, user_message=user_message))
|
||||
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
tracker = CostTracker()
|
||||
tracker = _UsageTracker()
|
||||
emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
|
||||
context: dict[str, Any] = {}
|
||||
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class StorageAdapter:
    """Build object paths and metadata dicts for attachments in one bucket."""

    _bucket: str

    def __init__(self, bucket: str) -> None:
        """Remember the bucket name this adapter is bound to."""
        self._bucket = bucket

    @property
    def bucket(self) -> str:
        """Name of the bucket given at construction."""
        return self._bucket

    def build_object_path(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        checksum_sha256: str,
        extension: str,
    ) -> str:
        """Return the object path for an attachment.

        The path has the form
        ``agent-chat/<user_id>/<session_id>/<message_seq>/<checksum>.<ext>``.

        Args:
            user_id: Path component.
            session_id: Path component.
            message_seq: Path component.
            checksum_sha256: Used as the file stem.
            extension: File extension, with or without dots, in any case.
                An empty extension is stored as ``bin``.

        Returns:
            The object path as a string.
        """
        # Fall back to "bin" so an empty extension cannot produce a name
        # ending in a dangling "."
        normalized_ext = extension.strip(".").lower() or "bin"
        return (
            f"agent-chat/{user_id}/{session_id}/{message_seq}/"
            f"{checksum_sha256}.{normalized_ext}"
        )

    def build_attachment_metadata(
        self,
        *,
        object_path: str,
        mime_type: str,
        size: int,
        checksum_sha256: str,
        origin: str,
        preview_text: str | None = None,
    ) -> dict[str, object]:
        """Bundle the given attachment attributes into a metadata dict.

        The returned dict contains exactly the six arguments under keys of
        the same names; ``preview_text`` is ``None`` when not supplied.
        """
        return {
            "object_path": object_path,
            "mime_type": mime_type,
            "size": size,
            "checksum_sha256": checksum_sha256,
            "origin": origin,
            "preview_text": preview_text,
        }
|
||||
@@ -1 +0,0 @@
|
||||
from __future__ import annotations
|
||||
@@ -27,6 +27,8 @@ llms:
|
||||
# 你原来的两个保留
|
||||
- model_code: qwen3.5-flash
|
||||
factory_name: dashscope
|
||||
litellm_model: dashscope/qwen-turbo
|
||||
|
||||
- model_code: deepseek-v3.2
|
||||
factory_name: deepseek
|
||||
litellm_model: deepseek/deepseek-chat
|
||||
|
||||
Reference in New Issue
Block a user