refactor: 切换到 litellm,删除未使用的代码

- 添加 litellm 依赖,统一 LLM 调用层
- 新增 litellm_client.py 支持多厂商
- 更新 llm_catalog.yaml 添加 litellm_model 映射
- 删除旧的 cost_tracker.py (litellm 内置 cost 追踪)
- 删除未使用的 multimodal.py 和 storage_adapter.py
- 删除空文件 crewai/__init__.py, tools/__init__.py
- 更新测试以适配新代码
This commit is contained in:
qzl
2026-03-03 17:52:34 +08:00
parent a4f684466c
commit 80cbb3512f
15 changed files with 200 additions and 578 deletions
-25
View File
@@ -1,25 +0,0 @@
# Deletion Log
## 2026-03-03 feature polish
- Scope: `backend/src/v1/agent/*` and `backend/tests/unit/v1/agent/*` only.
- Candidate review source: scoped `refactor-cleaner` run for the directories above.
### Executed cleanup
1. Merged duplicated newline validation logic in `backend/src/v1/agent/service.py`.
- Before: duplicated checks in `prepare_resume`, `stream_run`, `stream_resume`.
- After: centralized `_validate_no_newlines` helper.
- Behavior impact: none (same validation semantics).
2. Merged duplicated SSE event string formatting in `backend/src/v1/agent/service.py`.
- Before: repeated `f"data: {json.dumps(...)}\n\n"` fragments.
- After: centralized `_sse_data` helper.
- Behavior impact: none (same payload format).
### Candidates not deleted (insufficient evidence)
- `backend/src/v1/agent/crewai_flow.py`
- Reason: candidate report suggested possible dead code, but no deletion was done in this polish pass because cross-module usage certainty was insufficient.
- Legacy `run()` path in `backend/src/v1/agent/service.py`
- Reason: potentially still relied on by non-scope code paths; deletion deferred.
-69
View File
@@ -1,69 +0,0 @@
from __future__ import annotations
from decimal import Decimal
from typing import Any, Mapping
def _to_non_negative_int(value: Any, *, field: str) -> int:
if isinstance(value, bool):
raise ValueError(f"{field} must be an integer")
if isinstance(value, int):
converted = value
elif isinstance(value, str) and value.isdigit():
converted = int(value)
else:
raise ValueError(f"{field} must be an integer")
if converted < 0:
raise ValueError(f"{field} cannot be negative")
return converted
def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
converted = Decimal(str(value))
if converted < 0:
raise ValueError(f"{field} cannot be negative")
return converted
class CostTracker:
def __init__(self, *, currency: str = "USD") -> None:
self._input_tokens = 0
self._output_tokens = 0
self._total_tokens = 0
self._cost = Decimal("0")
self._currency = currency
def add_usage(self, usage: Mapping[str, Any]) -> None:
input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0))
output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0))
total_tokens = usage.get("total_tokens")
cost = usage.get("cost", "0")
currency = usage.get("currency")
normalized_input = _to_non_negative_int(input_tokens, field="input_tokens")
normalized_output = _to_non_negative_int(output_tokens, field="output_tokens")
normalized_total = (
_to_non_negative_int(total_tokens, field="total_tokens")
if total_tokens is not None
else normalized_input + normalized_output
)
normalized_cost = _to_non_negative_decimal(cost, field="cost")
self._input_tokens += normalized_input
self._output_tokens += normalized_output
self._total_tokens += normalized_total
self._cost += normalized_cost
if currency is not None:
normalized_currency = str(currency)
if normalized_currency != self._currency:
raise ValueError("currency mismatch")
def snapshot(self) -> dict[str, Any]:
return {
"input_tokens": self._input_tokens,
"output_tokens": self._output_tokens,
"total_tokens": self._total_tokens,
"cost": self._cost,
"currency": self._currency,
}
@@ -1 +0,0 @@
from __future__ import annotations
+130
View File
@@ -0,0 +1,130 @@
from __future__ import annotations
import os
from dataclasses import dataclass
from decimal import Decimal
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class LLMConfig:
model_code: str
factory_name: str
litellm_model: str
request_url: str
@dataclass(frozen=True)
class LLMResponse:
content: str
usage: dict[str, Any]
class LiteLLMClient:
def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
self._config = config
self._api_key = api_key or self._get_api_key(config.factory_name)
@staticmethod
def _get_api_key(factory_name: str) -> str:
key_map = {
"dashscope": "DASHSCOPE_API_KEY",
"minimax": "MINIMAX_API_KEY",
"moonshot": "MOONSHOT_API_KEY",
"deepseek": "DEEPSEEK_API_KEY",
"volcengine-ark": "ARK_API_KEY",
"z-ai": "ZAI_API_KEY",
}
env_key = key_map.get(factory_name)
if not env_key:
raise ValueError(f"No API key mapping for factory: {factory_name}")
key = os.environ.get(env_key)
if not key:
raise ValueError(f"Environment variable {env_key} is not set")
return key
@staticmethod
def load_config(
model_code: str,
static_root: Path | None = None,
) -> LLMConfig:
root = static_root or (
Path(__file__).resolve().parents[3] / "config" / "static" / "database"
)
yaml_path = root / "llm_catalog.yaml"
with yaml_path.open("r", encoding="utf-8") as f:
data = yaml.safe_load(f)
factories = {f["name"]: f for f in data.get("factories", [])}
llms = data.get("llms", [])
for llm in llms:
if llm.get("model_code") == model_code:
factory_name = llm["factory_name"]
factory = factories.get(factory_name)
if not factory:
raise ValueError(f"Factory not found: {factory_name}")
return LLMConfig(
model_code=model_code,
factory_name=factory_name,
litellm_model=llm.get("litellm_model", model_code),
request_url=factory["request_url"],
)
raise ValueError(f"Model not found: {model_code}")
def chat(
self,
messages: list[dict[str, str]],
*,
temperature: float = 0.7,
max_tokens: int | None = None,
) -> LLMResponse:
import litellm
response = litellm.completion( # type: ignore[attr-defined]
model=self._config.litellm_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
api_base=self._config.request_url,
api_key=self._api_key,
)
content = response.choices[0].message.content or "" # type: ignore[attr-defined]
usage = response.usage.model_dump() if response.usage else {} # type: ignore[attr-defined]
return LLMResponse(content=content, usage=usage)
async def achat(
self,
messages: list[dict[str, str]],
*,
temperature: float = 0.7,
max_tokens: int | None = None,
) -> LLMResponse:
import litellm
response = await litellm.acompletion( # type: ignore[attr-defined]
model=self._config.litellm_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
api_base=self._config.request_url,
api_key=self._api_key,
)
content = response.choices[0].message.content or "" # type: ignore[attr-defined]
usage = response.usage.model_dump() if response.usage else {} # type: ignore[attr-defined]
return LLMResponse(content=content, usage=usage)
def get_model_cost(usage: dict[str, Any]) -> Decimal:
cost = usage.get("cost")
if cost is None:
return Decimal("0")
return Decimal(str(cost))
-112
View File
@@ -1,112 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
import hashlib
from pathlib import Path
from typing import Protocol
from core.agent.storage_adapter import StorageAdapter
_ALLOWED_MIME_TYPES = {
"audio/mpeg",
"audio/wav",
"audio/x-wav",
"image/jpeg",
"image/png",
"image/webp",
"application/pdf",
"text/plain",
}
class _AsrTool(Protocol):
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, object]: ...
@dataclass(frozen=True)
class AttachmentInput:
filename: str
mime_type: str
content: bytes
origin: str = "user_upload"
@dataclass(frozen=True)
class ProcessedAttachmentContext:
attachments: list[dict[str, object]]
preview_texts: list[str]
class MultimodalProcessor:
def __init__(
self,
*,
storage: StorageAdapter,
asr_tool: _AsrTool,
max_file_size_mb: int = 20,
) -> None:
self._storage = storage
self._asr_tool = asr_tool
self._max_size_bytes = max_file_size_mb * 1024 * 1024
def process(
self,
*,
user_id: str,
session_id: str,
message_seq: int,
attachments: list[AttachmentInput],
) -> ProcessedAttachmentContext:
metadata_list: list[dict[str, object]] = []
preview_texts: list[str] = []
for attachment in attachments:
self._validate_attachment(attachment)
checksum = hashlib.sha256(attachment.content).hexdigest()
extension = Path(attachment.filename).suffix.strip(".").lower() or "bin"
object_path = self._storage.build_object_path(
user_id=user_id,
session_id=session_id,
message_seq=message_seq,
checksum_sha256=checksum,
extension=extension,
)
preview_text = self._build_preview_text(attachment)
if preview_text:
preview_texts.append(preview_text)
metadata = self._storage.build_attachment_metadata(
object_path=object_path,
mime_type=attachment.mime_type,
size=len(attachment.content),
checksum_sha256=checksum,
origin=attachment.origin,
preview_text=preview_text,
)
metadata_list.append(metadata)
return ProcessedAttachmentContext(
attachments=metadata_list,
preview_texts=preview_texts,
)
def _validate_attachment(self, attachment: AttachmentInput) -> None:
if attachment.mime_type not in _ALLOWED_MIME_TYPES:
raise ValueError("Unsupported MIME type")
if len(attachment.content) > self._max_size_bytes:
raise ValueError("Attachment exceeds max file size")
def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
if attachment.mime_type.startswith("audio/"):
result = self._asr_tool.transcribe(
audio_bytes=attachment.content,
filename=attachment.filename,
)
text = result.get("text")
if isinstance(text, str):
return text
return None
if attachment.mime_type == "text/plain":
return attachment.content.decode("utf-8", errors="ignore")[:200]
return None
+31 -2
View File
@@ -2,10 +2,11 @@ from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from decimal import Decimal
from typing import Any, Awaitable, Callable from typing import Any, Awaitable, Callable
from core.agent.cost_tracker import CostTracker
from core.agent import events from core.agent import events
from core.agent.litellm_client import get_model_cost
StageCallable = Callable[..., Awaitable[dict[str, Any]]] StageCallable = Callable[..., Awaitable[dict[str, Any]]]
@@ -20,6 +21,34 @@ class OrchestratorResult:
error: str | None error: str | None
class _UsageTracker:
def __init__(self) -> None:
self._input_tokens = 0
self._output_tokens = 0
self._total_tokens = 0
self._cost = Decimal("0")
def add_usage(self, usage: dict[str, Any]) -> None:
input_tokens = usage.get("prompt_tokens", 0) or usage.get("input_tokens", 0)
output_tokens = usage.get("completion_tokens", 0) or usage.get(
"output_tokens", 0
)
total = usage.get("total_tokens")
self._input_tokens += input_tokens
self._output_tokens += output_tokens
self._total_tokens += total if total else (input_tokens + output_tokens)
self._cost += get_model_cost(usage)
def snapshot(self) -> dict[str, Any]:
return {
"input_tokens": self._input_tokens,
"output_tokens": self._output_tokens,
"total_tokens": self._total_tokens,
"cost": str(self._cost),
}
class AgentChatOrchestrator: class AgentChatOrchestrator:
def __init__( def __init__(
self, self,
@@ -36,7 +65,7 @@ class AgentChatOrchestrator:
return asyncio.run(self.run(run_id=run_id, user_message=user_message)) return asyncio.run(self.run(run_id=run_id, user_message=user_message))
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult: async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
tracker = CostTracker() tracker = _UsageTracker()
emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)] emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
context: dict[str, Any] = {} context: dict[str, Any] = {}
-46
View File
@@ -1,46 +0,0 @@
from __future__ import annotations
class StorageAdapter:
_bucket: str
def __init__(self, bucket: str) -> None:
self._bucket = bucket
@property
def bucket(self) -> str:
return self._bucket
def build_object_path(
self,
*,
user_id: str,
session_id: str,
message_seq: int,
checksum_sha256: str,
extension: str,
) -> str:
normalized_ext = extension.strip(".").lower()
return (
f"agent-chat/{user_id}/{session_id}/{message_seq}/"
f"{checksum_sha256}.{normalized_ext}"
)
def build_attachment_metadata(
self,
*,
object_path: str,
mime_type: str,
size: int,
checksum_sha256: str,
origin: str,
preview_text: str | None = None,
) -> dict[str, object]:
return {
"object_path": object_path,
"mime_type": mime_type,
"size": size,
"checksum_sha256": checksum_sha256,
"origin": origin,
"preview_text": preview_text,
}
-1
View File
@@ -1 +0,0 @@
from __future__ import annotations
@@ -27,6 +27,8 @@ llms:
# 你原来的两个保留 # 你原来的两个保留
- model_code: qwen3.5-flash - model_code: qwen3.5-flash
factory_name: dashscope factory_name: dashscope
litellm_model: dashscope/qwen-turbo
- model_code: deepseek-v3.2 - model_code: deepseek-v3.2
factory_name: deepseek factory_name: deepseek
litellm_model: deepseek/deepseek-chat
@@ -2,81 +2,47 @@ from __future__ import annotations
from decimal import Decimal from decimal import Decimal
import pytest from core.agent.litellm_client import get_model_cost
from core.agent.cost_tracker import CostTracker
def test_normalize_usage_and_cost_aggregation() -> None: def test_get_model_cost_returns_decimal() -> None:
tracker = CostTracker() usage = {
"prompt_tokens": 7,
tracker.add_usage( "completion_tokens": 5,
{ "total_tokens": 12,
"prompt_tokens": 7, "cost": "0.002500",
"completion_tokens": 5, }
"cost": "0.002500", cost = get_model_cost(usage)
} assert cost == Decimal("0.002500")
)
tracker.add_usage(
{
"input_tokens": 5,
"output_tokens": 3,
"cost": "0.003000",
"currency": "USD",
}
)
snapshot = tracker.snapshot()
assert snapshot["input_tokens"] == 12
assert snapshot["output_tokens"] == 8
assert snapshot["total_tokens"] == 20
assert snapshot["cost"] == Decimal("0.005500")
assert snapshot["currency"] == "USD"
def test_add_usage_rejects_negative_values() -> None: def test_get_model_cost_with_no_cost() -> None:
tracker = CostTracker() usage = {
"prompt_tokens": 7,
with pytest.raises(ValueError): "completion_tokens": 5,
tracker.add_usage({"input_tokens": -1}) "total_tokens": 12,
}
with pytest.raises(ValueError): cost = get_model_cost(usage)
tracker.add_usage({"cost": "-0.010000"}) assert cost == Decimal("0")
def test_snapshot_is_zero_before_any_usage() -> None: def test_get_model_cost_with_zero_cost() -> None:
tracker = CostTracker() usage = {
"prompt_tokens": 7,
snapshot = tracker.snapshot() "completion_tokens": 5,
"total_tokens": 12,
assert snapshot["input_tokens"] == 0 "cost": "0",
assert snapshot["output_tokens"] == 0 }
assert snapshot["total_tokens"] == 0 cost = get_model_cost(usage)
assert snapshot["cost"] == Decimal("0") assert cost == Decimal("0")
assert snapshot["currency"] == "USD"
def test_add_usage_rejects_currency_mismatch() -> None: def test_get_model_cost_with_numeric_cost() -> None:
tracker = CostTracker(currency="USD") usage = {
tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"}) "prompt_tokens": 7,
"completion_tokens": 5,
with pytest.raises(ValueError): "total_tokens": 12,
tracker.add_usage( "cost": 0.0025,
{ }
"input_tokens": 1, cost = get_model_cost(usage)
"output_tokens": 1, assert cost == Decimal("0.0025")
"cost": "0.001000",
"currency": "CNY",
}
)
def test_add_usage_rejects_non_integral_token_values() -> None:
tracker = CostTracker()
with pytest.raises(ValueError):
tracker.add_usage({"input_tokens": 1.5})
with pytest.raises(ValueError):
tracker.add_usage({"output_tokens": True})
@@ -1,89 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.multimodal import AttachmentInput, MultimodalProcessor
from core.agent.storage_adapter import StorageAdapter
from core.agent.tools.asr_fun_asr import FunASRTool
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
assert audio_bytes == b"audio"
assert filename == "voice.wav"
return {"text": "hello world", "request_id": "req-1"}
processor = MultimodalProcessor(
storage=storage,
asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
max_file_size_mb=1,
)
result = processor.process(
user_id="u1",
session_id="s1",
message_seq=4,
attachments=[
AttachmentInput(
filename="voice.wav",
mime_type="audio/wav",
content=b"audio",
)
],
)
assert len(result.attachments) == 1
metadata = result.attachments[0]
assert (
metadata["object_path"]
== "agent-chat/u1/s1/4/6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
)
assert metadata["mime_type"] == "audio/wav"
assert result.preview_texts == ["hello world"]
def test_multimodal_rejects_unsupported_mime_type() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
processor = MultimodalProcessor(
storage=storage, asr_tool=FunASRTool(lambda **_: {})
)
with pytest.raises(ValueError):
processor.process(
user_id="u1",
session_id="s1",
message_seq=1,
attachments=[
AttachmentInput(
filename="malware.exe",
mime_type="application/octet-stream",
content=b"bad",
)
],
)
def test_multimodal_rejects_attachment_over_max_size() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
processor = MultimodalProcessor(
storage=storage,
asr_tool=FunASRTool(lambda **_: {}),
max_file_size_mb=1,
)
oversized = b"x" * (1024 * 1024 + 1)
with pytest.raises(ValueError):
processor.process(
user_id="u1",
session_id="s1",
message_seq=1,
attachments=[
AttachmentInput(
filename="big.wav",
mime_type="audio/wav",
content=oversized,
)
],
)
@@ -1,37 +0,0 @@
from __future__ import annotations
from core.agent.storage_adapter import StorageAdapter
def test_build_object_path_uses_expected_pattern() -> None:
adapter = StorageAdapter(bucket="agent-chat-attachments")
path = adapter.build_object_path(
user_id="u1",
session_id="s1",
message_seq=3,
checksum_sha256="abc123",
extension="wav",
)
assert path == "agent-chat/u1/s1/3/abc123.wav"
def test_build_attachment_metadata_contains_required_fields() -> None:
adapter = StorageAdapter(bucket="agent-chat-attachments")
metadata = adapter.build_attachment_metadata(
object_path="agent-chat/u1/s1/3/abc123.wav",
mime_type="audio/wav",
size=1024,
checksum_sha256="abc123",
origin="user_upload",
preview_text="hello",
)
assert metadata["object_path"] == "agent-chat/u1/s1/3/abc123.wav"
assert metadata["mime_type"] == "audio/wav"
assert metadata["size"] == 1024
assert metadata["checksum_sha256"] == "abc123"
assert metadata["origin"] == "user_upload"
assert metadata["preview_text"] == "hello"
@@ -1,89 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.multimodal import AttachmentInput, MultimodalProcessor
from core.agent.storage_adapter import StorageAdapter
from core.agent.tools.asr_fun_asr import FunASRTool
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
assert audio_bytes == b"audio"
assert filename == "voice.wav"
return {"text": "hello world", "request_id": "req-1"}
processor = MultimodalProcessor(
storage=storage,
asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
max_file_size_mb=1,
)
result = processor.process(
user_id="u1",
session_id="s1",
message_seq=4,
attachments=[
AttachmentInput(
filename="voice.wav",
mime_type="audio/wav",
content=b"audio",
)
],
)
assert len(result.attachments) == 1
metadata = result.attachments[0]
assert (
metadata["object_path"]
== "agent-chat/u1/s1/4/6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
)
assert metadata["mime_type"] == "audio/wav"
assert result.preview_texts == ["hello world"]
def test_multimodal_rejects_unsupported_mime_type() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
processor = MultimodalProcessor(
storage=storage, asr_tool=FunASRTool(lambda **_: {})
)
with pytest.raises(ValueError):
processor.process(
user_id="u1",
session_id="s1",
message_seq=1,
attachments=[
AttachmentInput(
filename="malware.exe",
mime_type="application/octet-stream",
content=b"bad",
)
],
)
def test_multimodal_rejects_attachment_over_max_size() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
processor = MultimodalProcessor(
storage=storage,
asr_tool=FunASRTool(lambda **_: {}),
max_file_size_mb=1,
)
oversized = b"x" * (1024 * 1024 + 1)
with pytest.raises(ValueError):
processor.process(
user_id="u1",
session_id="s1",
message_seq=1,
attachments=[
AttachmentInput(
filename="big.wav",
mime_type="audio/wav",
content=oversized,
)
],
)
@@ -1,37 +0,0 @@
from __future__ import annotations
from core.agent.storage_adapter import StorageAdapter
def test_build_object_path_uses_expected_pattern() -> None:
adapter = StorageAdapter(bucket="agent-chat-attachments")
path = adapter.build_object_path(
user_id="u1",
session_id="s1",
message_seq=3,
checksum_sha256="abc123",
extension="wav",
)
assert path == "agent-chat/u1/s1/3/abc123.wav"
def test_build_attachment_metadata_contains_required_fields() -> None:
adapter = StorageAdapter(bucket="agent-chat-attachments")
metadata = adapter.build_attachment_metadata(
object_path="agent-chat/u1/s1/3/abc123.wav",
mime_type="audio/wav",
size=1024,
checksum_sha256="abc123",
origin="user_upload",
preview_text="hello",
)
assert metadata["object_path"] == "agent-chat/u1/s1/3/abc123.wav"
assert metadata["mime_type"] == "audio/wav"
assert metadata["size"] == 1024
assert metadata["checksum_sha256"] == "abc123"
assert metadata["origin"] == "user_upload"
assert metadata["preview_text"] == "hello"
+1
View File
@@ -12,6 +12,7 @@ dependencies = [
"crewai-tools>=1.6.1", "crewai-tools>=1.6.1",
"email-validator>=2.3.0", "email-validator>=2.3.0",
"fastapi>=0.128.0", "fastapi>=0.128.0",
"litellm>=1.52.0",
"pydantic>=2.11.0", "pydantic>=2.11.0",
"pydantic-settings>=2.10.0", "pydantic-settings>=2.10.0",
"pyjwt>=2.10.1", "pyjwt>=2.10.1",