refactor: 切换到 litellm,删除未使用的代码

- 添加 litellm 依赖,统一 LLM 调用层
- 新增 litellm_client.py 支持多厂商
- 更新 llm_catalog.yaml 添加 litellm_model 映射
- 删除旧的 cost_tracker.py (litellm 内置 cost 追踪)
- 删除未使用的 multimodal.py 和 storage_adapter.py
- 删除空文件 crewai/__init__.py, tools/__init__.py
- 更新测试以适配新代码
This commit is contained in:
qzl
2026-03-03 17:52:34 +08:00
parent a4f684466c
commit 80cbb3512f
15 changed files with 200 additions and 578 deletions
-25
View File
@@ -1,25 +0,0 @@
# Deletion Log
## 2026-03-03 feature polish
- Scope: `backend/src/v1/agent/*` and `backend/tests/unit/v1/agent/*` only.
- Candidate review source: scoped `refactor-cleaner` run for the directories above.
### Executed cleanup
1. Merged duplicated newline validation logic in `backend/src/v1/agent/service.py`.
- Before: duplicated checks in `prepare_resume`, `stream_run`, `stream_resume`.
- After: centralized `_validate_no_newlines` helper.
- Behavior impact: none (same validation semantics).
2. Merged duplicated SSE event string formatting in `backend/src/v1/agent/service.py`.
- Before: repeated `f"data: {json.dumps(...)}\n\n"` fragments.
- After: centralized `_sse_data` helper.
- Behavior impact: none (same payload format).
### Candidates not deleted (insufficient evidence)
- `backend/src/v1/agent/crewai_flow.py`
- Reason: the candidate report flagged this file as possible dead code, but nothing was deleted in this polish pass because cross-module usage could not be ruled out with sufficient certainty.
- Legacy `run()` path in `backend/src/v1/agent/service.py`
- Reason: potentially still relied on by non-scope code paths; deletion deferred.
-69
View File
@@ -1,69 +0,0 @@
from __future__ import annotations
from decimal import Decimal
from typing import Any, Mapping
def _to_non_negative_int(value: Any, *, field: str) -> int:
    """Coerce *value* to an int that is zero or greater.

    Accepts real ``int`` values (``bool`` is rejected explicitly) and strings
    made up only of digits.

    Args:
        value: Raw value taken from a usage record.
        field: Field name used in the error message.

    Raises:
        ValueError: If *value* is not an accepted integer form or is negative.
    """
    # bool is a subclass of int, so it has to be excluded by hand.
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    is_digit_text = isinstance(value, str) and value.isdigit()
    if not (is_plain_int or is_digit_text):
        raise ValueError(f"{field} must be an integer")

    number = value if is_plain_int else int(value)
    if number < 0:
        raise ValueError(f"{field} cannot be negative")
    return number
def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
    """Coerce *value* to a finite ``Decimal`` that is zero or greater.

    The value is converted through ``str`` so floats keep their printed
    representation instead of their binary expansion.

    Args:
        value: Raw value taken from a usage record.
        field: Field name used in the error message.

    Raises:
        ValueError: If *value* cannot be parsed as a decimal number, is not
            finite (NaN / Infinity), or is negative.
    """
    try:
        converted = Decimal(str(value))
    except ArithmeticError as exc:
        # decimal.InvalidOperation derives from ArithmeticError; surface it as
        # ValueError so every validation failure has the same exception type.
        raise ValueError(f"{field} must be a decimal number") from exc
    # NaN cannot be ordered and Infinity would poison any running sum.
    if not converted.is_finite():
        raise ValueError(f"{field} must be a finite number")
    if converted < 0:
        raise ValueError(f"{field} cannot be negative")
    return converted
class CostTracker:
    """Accumulate token counts and cost across usage records.

    All totals are kept for the single currency chosen at construction time.
    """

    def __init__(self, *, currency: str = "USD") -> None:
        self._input_tokens = 0
        self._output_tokens = 0
        self._total_tokens = 0
        self._cost = Decimal("0")
        self._currency = currency

    def add_usage(self, usage: Mapping[str, Any]) -> None:
        """Validate one usage record and add it to the running totals.

        ``input_tokens`` / ``output_tokens`` are used when present; otherwise
        ``prompt_tokens`` / ``completion_tokens`` are read. A missing
        ``total_tokens`` is derived as input + output.

        Every field, including the currency, is validated before any total is
        changed, so a rejected record leaves the tracker untouched.

        Raises:
            ValueError: On a currency mismatch or an invalid token/cost value.
        """
        currency = usage.get("currency")
        if currency is not None and str(currency) != self._currency:
            raise ValueError("currency mismatch")

        normalized_input = _to_non_negative_int(
            usage.get("input_tokens", usage.get("prompt_tokens", 0)),
            field="input_tokens",
        )
        normalized_output = _to_non_negative_int(
            usage.get("output_tokens", usage.get("completion_tokens", 0)),
            field="output_tokens",
        )
        total_tokens = usage.get("total_tokens")
        if total_tokens is None:
            normalized_total = normalized_input + normalized_output
        else:
            normalized_total = _to_non_negative_int(
                total_tokens, field="total_tokens"
            )
        normalized_cost = _to_non_negative_decimal(
            usage.get("cost", "0"), field="cost"
        )

        # Nothing below can raise, so the update is all-or-nothing.
        self._input_tokens += normalized_input
        self._output_tokens += normalized_output
        self._total_tokens += normalized_total
        self._cost += normalized_cost

    def snapshot(self) -> dict[str, Any]:
        """Return the current totals as a new dict (``cost`` is a Decimal)."""
        return {
            "input_tokens": self._input_tokens,
            "output_tokens": self._output_tokens,
            "total_tokens": self._total_tokens,
            "cost": self._cost,
            "currency": self._currency,
        }
@@ -1 +0,0 @@
from __future__ import annotations
+130
View File
@@ -0,0 +1,130 @@
from __future__ import annotations
import os
from dataclasses import dataclass
from decimal import Decimal
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class LLMConfig:
    """Immutable settings for one model resolved from the LLM catalog."""

    # Catalog identifier the model was looked up by.
    model_code: str
    # Name of the catalog factory entry the model belongs to.
    factory_name: str
    # Model string handed to litellm as ``model``.
    litellm_model: str
    # Endpoint handed to litellm as ``api_base``.
    request_url: str
@dataclass(frozen=True)
class LLMResponse:
    """Immutable result of a single chat completion call."""

    # Text of the first choice's message ("" when the model returned none).
    content: str
    # Usage figures dumped from the litellm response; empty when not reported.
    usage: dict[str, Any]
class LiteLLMClient:
    """Thin wrapper around litellm's completion API for one configured model."""

    # Catalog factory name -> environment variable that holds its API key.
    _API_KEY_ENV_VARS = {
        "dashscope": "DASHSCOPE_API_KEY",
        "minimax": "MINIMAX_API_KEY",
        "moonshot": "MOONSHOT_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "volcengine-ark": "ARK_API_KEY",
        "z-ai": "ZAI_API_KEY",
    }

    def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
        """Store *config* and resolve the API key.

        An explicit, non-empty *api_key* wins; otherwise the key is read from
        the environment variable mapped to ``config.factory_name``.
        """
        self._config = config
        self._api_key = api_key or self._get_api_key(config.factory_name)

    @staticmethod
    def _get_api_key(factory_name: str) -> str:
        """Return the API key for *factory_name* from the environment.

        Raises:
            ValueError: If the factory has no mapped variable, or the variable
                is unset or empty.
        """
        env_key = LiteLLMClient._API_KEY_ENV_VARS.get(factory_name)
        if not env_key:
            raise ValueError(f"No API key mapping for factory: {factory_name}")
        key = os.environ.get(env_key)
        if not key:
            raise ValueError(f"Environment variable {env_key} is not set")
        return key

    @staticmethod
    def load_config(
        model_code: str,
        static_root: Path | None = None,
    ) -> LLMConfig:
        """Look up *model_code* in ``llm_catalog.yaml`` and build its config.

        Args:
            model_code: Catalog identifier of the model.
            static_root: Directory containing ``llm_catalog.yaml``. Defaults to
                ``config/static/database`` three directories above this file.

        Raises:
            ValueError: If the model, or the factory it names, is not in the
                catalog.
        """
        root = static_root or (
            Path(__file__).resolve().parents[3] / "config" / "static" / "database"
        )
        yaml_path = root / "llm_catalog.yaml"
        with yaml_path.open("r", encoding="utf-8") as f:
            # safe_load returns None for an empty document; treat that as an
            # empty catalog so the caller gets "Model not found", not a crash.
            data = yaml.safe_load(f) or {}

        # A key that is present but empty also loads as None.
        factories = {
            entry["name"]: entry for entry in data.get("factories") or []
        }
        for llm in data.get("llms") or []:
            if llm.get("model_code") != model_code:
                continue
            factory_name = llm["factory_name"]
            factory = factories.get(factory_name)
            if not factory:
                raise ValueError(f"Factory not found: {factory_name}")
            return LLMConfig(
                model_code=model_code,
                factory_name=factory_name,
                # Without an explicit mapping the catalog code is sent as-is.
                litellm_model=llm.get("litellm_model", model_code),
                request_url=factory["request_url"],
            )
        raise ValueError(f"Model not found: {model_code}")

    def _completion_kwargs(
        self,
        messages: list[dict[str, str]],
        temperature: float,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Build the keyword arguments shared by the sync and async calls."""
        return {
            "model": self._config.litellm_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "api_base": self._config.request_url,
            "api_key": self._api_key,
        }

    @staticmethod
    def _extract_usage(response: Any) -> dict[str, Any]:
        """Return the response's usage as a dict, including cost when known.

        The dumped usage object normally carries only token counts. litellm
        exposes its computed cost under ``_hidden_params["response_cost"]``,
        so it is copied into ``usage["cost"]`` unless the usage already
        reports a cost of its own.
        """
        raw_usage = getattr(response, "usage", None)
        usage = raw_usage.model_dump() if raw_usage else {}
        if usage.get("cost") is None:
            hidden = getattr(response, "_hidden_params", None)
            response_cost = (
                hidden.get("response_cost") if hasattr(hidden, "get") else None
            )
            if response_cost is not None:
                usage["cost"] = response_cost
        return usage

    def _to_response(self, response: Any) -> LLMResponse:
        """Convert a litellm completion response into an ``LLMResponse``."""
        content = response.choices[0].message.content or ""  # type: ignore[attr-defined]
        return LLMResponse(content=content, usage=self._extract_usage(response))

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Run a blocking chat completion and return its text and usage."""
        # Imported lazily so this module can be imported without litellm.
        import litellm

        response = litellm.completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)

    async def achat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Async counterpart of :meth:`chat`."""
        import litellm

        response = await litellm.acompletion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)
def get_model_cost(usage: dict[str, Any]) -> Decimal:
    """Return the ``cost`` entry of *usage* as a Decimal.

    A missing or ``None`` cost yields ``Decimal("0")``. Any other value is
    converted through ``str`` so floats keep their printed form.
    """
    raw_cost = usage.get("cost")
    return Decimal("0") if raw_cost is None else Decimal(str(raw_cost))
-112
View File
@@ -1,112 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
import hashlib
from pathlib import Path
from typing import Protocol
from core.agent.storage_adapter import StorageAdapter
# MIME types accepted for attachments, grouped by media family.
_ALLOWED_MIME_TYPES = {
    *("audio/mpeg", "audio/wav", "audio/x-wav"),
    *("image/jpeg", "image/png", "image/webp"),
    *("application/pdf", "text/plain"),
}
class _AsrTool(Protocol):
    """Structural type for a tool that can transcribe audio bytes."""

    def transcribe(
        self,
        *,
        audio_bytes: bytes,
        filename: str,
    ) -> dict[str, object]:
        """Return a result mapping for the given audio payload."""
        ...
@dataclass(frozen=True)
class AttachmentInput:
    """Immutable raw attachment handed to the processor."""

    # Original file name; its suffix becomes the stored object's extension.
    filename: str
    # Declared MIME type, checked against the allow-list.
    mime_type: str
    # Full attachment payload.
    content: bytes
    # Where the attachment came from.
    origin: str = "user_upload"
@dataclass(frozen=True)
class ProcessedAttachmentContext:
    """Immutable result of processing a batch of attachments."""

    # One metadata dict per processed attachment, in input order.
    attachments: list[dict[str, object]]
    # Non-empty preview texts collected from the attachments.
    preview_texts: list[str]
class MultimodalProcessor:
    """Validate attachments and build their storage metadata and previews."""

    def __init__(
        self,
        *,
        storage: StorageAdapter,
        asr_tool: _AsrTool,
        max_file_size_mb: int = 20,
    ) -> None:
        self._storage = storage
        self._asr_tool = asr_tool
        # The limit is kept in bytes so it compares directly with len(content).
        self._max_size_bytes = max_file_size_mb * 1024 * 1024

    def process(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        attachments: list[AttachmentInput],
    ) -> ProcessedAttachmentContext:
        """Validate every attachment and describe it for storage.

        Returns:
            The metadata of each attachment, in order, together with the
            non-empty preview texts that were produced.

        Raises:
            ValueError: If an attachment has a disallowed MIME type or is
                larger than the configured limit.
        """
        records: list[dict[str, object]] = []
        previews: list[str] = []

        for item in attachments:
            self._validate_attachment(item)

            digest = hashlib.sha256(item.content).hexdigest()
            suffix = Path(item.filename).suffix.strip(".").lower()
            target_path = self._storage.build_object_path(
                user_id=user_id,
                session_id=session_id,
                message_seq=message_seq,
                checksum_sha256=digest,
                # Files without a suffix are stored with a generic extension.
                extension=suffix or "bin",
            )

            preview = self._build_preview_text(item)
            if preview:
                previews.append(preview)

            records.append(
                self._storage.build_attachment_metadata(
                    object_path=target_path,
                    mime_type=item.mime_type,
                    size=len(item.content),
                    checksum_sha256=digest,
                    origin=item.origin,
                    preview_text=preview,
                )
            )

        return ProcessedAttachmentContext(
            attachments=records,
            preview_texts=previews,
        )

    def _validate_attachment(self, attachment: AttachmentInput) -> None:
        """Raise ``ValueError`` for a disallowed type or an oversized payload."""
        if attachment.mime_type not in _ALLOWED_MIME_TYPES:
            raise ValueError("Unsupported MIME type")
        payload_size = len(attachment.content)
        if payload_size > self._max_size_bytes:
            raise ValueError("Attachment exceeds max file size")

    def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
        """Return preview text for the attachment, or ``None`` if it has none.

        Plain text yields its first 200 decoded characters. Audio yields the
        ``"text"`` entry of the transcription result when that is a string.
        """
        mime = attachment.mime_type
        if mime == "text/plain":
            decoded = attachment.content.decode("utf-8", errors="ignore")
            return decoded[:200]
        if not mime.startswith("audio/"):
            return None

        transcript = self._asr_tool.transcribe(
            audio_bytes=attachment.content,
            filename=attachment.filename,
        ).get("text")
        return transcript if isinstance(transcript, str) else None
+31 -2
View File
@@ -2,10 +2,11 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass
from decimal import Decimal
from typing import Any, Awaitable, Callable
from core.agent.cost_tracker import CostTracker
from core.agent import events
from core.agent.litellm_client import get_model_cost
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
@@ -20,6 +21,34 @@ class OrchestratorResult:
error: str | None
class _UsageTracker:
    """Running totals of token usage and cost for one orchestrator run."""

    def __init__(self) -> None:
        self._input_tokens = 0
        self._output_tokens = 0
        self._total_tokens = 0
        self._cost = Decimal("0")

    def add_usage(self, usage: dict[str, Any]) -> None:
        """Add one usage record to the totals.

        ``prompt_tokens`` / ``completion_tokens`` are used when truthy,
        otherwise ``input_tokens`` / ``output_tokens``. A missing or zero
        ``total_tokens`` is replaced by input + output.
        """
        prompt_count = usage.get("prompt_tokens", 0) or usage.get("input_tokens", 0)
        completion_count = usage.get("completion_tokens", 0) or usage.get(
            "output_tokens", 0
        )
        reported_total = usage.get("total_tokens")

        self._input_tokens += prompt_count
        self._output_tokens += completion_count
        if reported_total:
            self._total_tokens += reported_total
        else:
            self._total_tokens += prompt_count + completion_count
        self._cost += get_model_cost(usage)

    def snapshot(self) -> dict[str, Any]:
        """Return the totals as a new dict; ``cost`` is rendered as a string."""
        return dict(
            input_tokens=self._input_tokens,
            output_tokens=self._output_tokens,
            total_tokens=self._total_tokens,
            cost=str(self._cost),
        )
class AgentChatOrchestrator:
def __init__(
self,
@@ -36,7 +65,7 @@ class AgentChatOrchestrator:
return asyncio.run(self.run(run_id=run_id, user_message=user_message))
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
tracker = CostTracker()
tracker = _UsageTracker()
emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
context: dict[str, Any] = {}
-46
View File
@@ -1,46 +0,0 @@
from __future__ import annotations
class StorageAdapter:
    """Builds object paths and metadata records for stored attachments."""

    _bucket: str

    def __init__(self, bucket: str) -> None:
        self._bucket = bucket

    @property
    def bucket(self) -> str:
        """Name of the bucket this adapter was created for."""
        return self._bucket

    def build_object_path(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        checksum_sha256: str,
        extension: str,
    ) -> str:
        """Return ``agent-chat/<user>/<session>/<seq>/<checksum>.<ext>``.

        The extension is lower-cased and stripped of leading/trailing dots.
        """
        suffix = extension.strip(".").lower()
        return "agent-chat/{}/{}/{}/{}.{}".format(
            user_id, session_id, message_seq, checksum_sha256, suffix
        )

    def build_attachment_metadata(
        self,
        *,
        object_path: str,
        mime_type: str,
        size: int,
        checksum_sha256: str,
        origin: str,
        preview_text: str | None = None,
    ) -> dict[str, object]:
        """Bundle the given attachment attributes into a metadata dict."""
        return dict(
            object_path=object_path,
            mime_type=mime_type,
            size=size,
            checksum_sha256=checksum_sha256,
            origin=origin,
            preview_text=preview_text,
        )
-1
View File
@@ -1 +0,0 @@
from __future__ import annotations
@@ -27,6 +27,8 @@ llms:
# 你原来的两个保留
- model_code: qwen3.5-flash
factory_name: dashscope
litellm_model: dashscope/qwen-turbo
- model_code: deepseek-v3.2
factory_name: deepseek
litellm_model: deepseek/deepseek-chat
@@ -2,81 +2,47 @@ from __future__ import annotations
from decimal import Decimal
import pytest
from core.agent.cost_tracker import CostTracker
from core.agent.litellm_client import get_model_cost
def test_normalize_usage_and_cost_aggregation() -> None:
    """Records using either key spelling are summed into one snapshot."""
    prompt_style_usage = {
        "prompt_tokens": 7,
        "completion_tokens": 5,
        "cost": "0.002500",
    }
    input_style_usage = {
        "input_tokens": 5,
        "output_tokens": 3,
        "cost": "0.003000",
        "currency": "USD",
    }
    tracker = CostTracker()
    for record in (prompt_style_usage, input_style_usage):
        tracker.add_usage(record)

    expected = {
        "input_tokens": 12,
        "output_tokens": 8,
        "total_tokens": 20,
        "cost": Decimal("0.005500"),
        "currency": "USD",
    }
    snapshot = tracker.snapshot()
    for key, wanted in expected.items():
        assert snapshot[key] == wanted
def test_get_model_cost_returns_decimal() -> None:
    """A string cost is returned as the equivalent Decimal."""
    reported = dict(
        prompt_tokens=7,
        completion_tokens=5,
        total_tokens=12,
        cost="0.002500",
    )
    assert get_model_cost(reported) == Decimal("0.002500")
def test_add_usage_rejects_negative_values() -> None:
    """Negative token counts and negative costs both raise ValueError."""
    tracker = CostTracker()
    for bad_usage in ({"input_tokens": -1}, {"cost": "-0.010000"}):
        with pytest.raises(ValueError):
            tracker.add_usage(bad_usage)
def test_get_model_cost_with_no_cost() -> None:
    """Usage without a ``cost`` key is priced at zero."""
    token_only_usage = dict(prompt_tokens=7, completion_tokens=5, total_tokens=12)
    assert get_model_cost(token_only_usage) == Decimal("0")
def test_snapshot_is_zero_before_any_usage() -> None:
    """A fresh tracker reports zero totals in the default currency."""
    snapshot = CostTracker().snapshot()
    for counter in ("input_tokens", "output_tokens", "total_tokens"):
        assert snapshot[counter] == 0
    assert snapshot["cost"] == Decimal("0")
    assert snapshot["currency"] == "USD"
def test_get_model_cost_with_zero_cost() -> None:
    """An explicit zero cost stays zero."""
    free_usage = dict(
        prompt_tokens=7,
        completion_tokens=5,
        total_tokens=12,
        cost="0",
    )
    assert get_model_cost(free_usage) == Decimal("0")
def test_add_usage_rejects_currency_mismatch() -> None:
    """A record in a different currency than the tracker's is rejected."""
    base_usage = {"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"}
    tracker = CostTracker(currency="USD")
    tracker.add_usage(base_usage)

    foreign_usage = {**base_usage, "currency": "CNY"}
    with pytest.raises(ValueError):
        tracker.add_usage(foreign_usage)
def test_add_usage_rejects_non_integral_token_values() -> None:
    """Float and bool token counts raise ValueError."""
    tracker = CostTracker()
    for bad_usage in ({"input_tokens": 1.5}, {"output_tokens": True}):
        with pytest.raises(ValueError):
            tracker.add_usage(bad_usage)
def test_get_model_cost_with_numeric_cost() -> None:
    """A float cost converts to the Decimal of its printed form."""
    float_cost_usage = dict(
        prompt_tokens=7,
        completion_tokens=5,
        total_tokens=12,
        cost=0.0025,
    )
    assert get_model_cost(float_cost_usage) == Decimal("0.0025")
@@ -1,89 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.multimodal import AttachmentInput, MultimodalProcessor
from core.agent.storage_adapter import StorageAdapter
from core.agent.tools.asr_fun_asr import FunASRTool
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
    """An audio attachment is transcribed and described with its object path."""

    def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        assert (audio_bytes, filename) == (b"audio", "voice.wav")
        return {"text": "hello world", "request_id": "req-1"}

    voice_note = AttachmentInput(
        filename="voice.wav",
        mime_type="audio/wav",
        content=b"audio",
    )
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
        max_file_size_mb=1,
    )

    result = processor.process(
        user_id="u1",
        session_id="s1",
        message_seq=4,
        attachments=[voice_note],
    )

    assert len(result.attachments) == 1
    metadata = result.attachments[0]
    expected_path = (
        "agent-chat/u1/s1/4/"
        "6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
    )
    assert metadata["object_path"] == expected_path
    assert metadata["mime_type"] == "audio/wav"
    assert result.preview_texts == ["hello world"]
def test_multimodal_rejects_unsupported_mime_type() -> None:
    """A MIME type outside the allow-list raises ValueError."""
    executable = AttachmentInput(
        filename="malware.exe",
        mime_type="application/octet-stream",
        content=b"bad",
    )
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[executable],
        )
def test_multimodal_rejects_attachment_over_max_size() -> None:
    """A payload one byte over the 1 MB limit raises ValueError."""
    one_megabyte = 1024 * 1024
    too_big = AttachmentInput(
        filename="big.wav",
        mime_type="audio/wav",
        content=b"x" * (one_megabyte + 1),
    )
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
        max_file_size_mb=1,
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[too_big],
        )
@@ -1,37 +0,0 @@
from __future__ import annotations
from core.agent.storage_adapter import StorageAdapter
def test_build_object_path_uses_expected_pattern() -> None:
    """The object path is agent-chat/<user>/<session>/<seq>/<checksum>.<ext>."""
    path_parts = dict(
        user_id="u1",
        session_id="s1",
        message_seq=3,
        checksum_sha256="abc123",
        extension="wav",
    )
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    assert adapter.build_object_path(**path_parts) == "agent-chat/u1/s1/3/abc123.wav"
def test_build_attachment_metadata_contains_required_fields() -> None:
    """Every supplied attribute is echoed back under its own key."""
    supplied = {
        "object_path": "agent-chat/u1/s1/3/abc123.wav",
        "mime_type": "audio/wav",
        "size": 1024,
        "checksum_sha256": "abc123",
        "origin": "user_upload",
        "preview_text": "hello",
    }
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    metadata = adapter.build_attachment_metadata(**supplied)

    for key, wanted in supplied.items():
        assert metadata[key] == wanted
@@ -1,89 +0,0 @@
from __future__ import annotations
import pytest
from core.agent.multimodal import AttachmentInput, MultimodalProcessor
from core.agent.storage_adapter import StorageAdapter
from core.agent.tools.asr_fun_asr import FunASRTool
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
    """An audio attachment is transcribed and described with its object path."""

    def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        assert (audio_bytes, filename) == (b"audio", "voice.wav")
        return {"text": "hello world", "request_id": "req-1"}

    voice_note = AttachmentInput(
        filename="voice.wav",
        mime_type="audio/wav",
        content=b"audio",
    )
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
        max_file_size_mb=1,
    )

    result = processor.process(
        user_id="u1",
        session_id="s1",
        message_seq=4,
        attachments=[voice_note],
    )

    assert len(result.attachments) == 1
    metadata = result.attachments[0]
    expected_path = (
        "agent-chat/u1/s1/4/"
        "6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
    )
    assert metadata["object_path"] == expected_path
    assert metadata["mime_type"] == "audio/wav"
    assert result.preview_texts == ["hello world"]
def test_multimodal_rejects_unsupported_mime_type() -> None:
    """A MIME type outside the allow-list raises ValueError."""
    executable = AttachmentInput(
        filename="malware.exe",
        mime_type="application/octet-stream",
        content=b"bad",
    )
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[executable],
        )
def test_multimodal_rejects_attachment_over_max_size() -> None:
    """A payload one byte over the 1 MB limit raises ValueError."""
    one_megabyte = 1024 * 1024
    too_big = AttachmentInput(
        filename="big.wav",
        mime_type="audio/wav",
        content=b"x" * (one_megabyte + 1),
    )
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
        max_file_size_mb=1,
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[too_big],
        )
@@ -1,37 +0,0 @@
from __future__ import annotations
from core.agent.storage_adapter import StorageAdapter
def test_build_object_path_uses_expected_pattern() -> None:
    """The object path is agent-chat/<user>/<session>/<seq>/<checksum>.<ext>."""
    path_parts = dict(
        user_id="u1",
        session_id="s1",
        message_seq=3,
        checksum_sha256="abc123",
        extension="wav",
    )
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    assert adapter.build_object_path(**path_parts) == "agent-chat/u1/s1/3/abc123.wav"
def test_build_attachment_metadata_contains_required_fields() -> None:
    """Every supplied attribute is echoed back under its own key."""
    supplied = {
        "object_path": "agent-chat/u1/s1/3/abc123.wav",
        "mime_type": "audio/wav",
        "size": 1024,
        "checksum_sha256": "abc123",
        "origin": "user_upload",
        "preview_text": "hello",
    }
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    metadata = adapter.build_attachment_metadata(**supplied)

    for key, wanted in supplied.items():
        assert metadata[key] == wanted
+1
View File
@@ -12,6 +12,7 @@ dependencies = [
"crewai-tools>=1.6.1",
"email-validator>=2.3.0",
"fastapi>=0.128.0",
"litellm>=1.52.0",
"pydantic>=2.11.0",
"pydantic-settings>=2.10.0",
"pyjwt>=2.10.1",