refactor: 切换到 litellm,删除未使用的代码
- 添加 litellm 依赖,统一 LLM 调用层
- 新增 litellm_client.py 支持多厂商
- 更新 llm_catalog.yaml 添加 litellm_model 映射
- 删除旧的 cost_tracker.py (litellm 内置 cost 追踪)
- 删除未使用的 multimodal.py 和 storage_adapter.py
- 删除空文件 crewai/__init__.py, tools/__init__.py
- 更新测试以适配新代码
This commit is contained in:
@@ -1,25 +0,0 @@
|
||||
# Deletion Log
|
||||
|
||||
## 2026-03-03 feature polish
|
||||
|
||||
- Scope: `backend/src/v1/agent/*` and `backend/tests/unit/v1/agent/*` only.
|
||||
- Candidate review source: scoped `refactor-cleaner` run for the directories above.
|
||||
|
||||
### Executed cleanup
|
||||
|
||||
1. Merged duplicated newline validation logic in `backend/src/v1/agent/service.py`.
|
||||
- Before: duplicated checks in `prepare_resume`, `stream_run`, `stream_resume`.
|
||||
- After: centralized `_validate_no_newlines` helper.
|
||||
- Behavior impact: none (same validation semantics).
|
||||
|
||||
2. Merged duplicated SSE event string formatting in `backend/src/v1/agent/service.py`.
|
||||
- Before: repeated `f"data: {json.dumps(...)}\n\n"` fragments.
|
||||
- After: centralized `_sse_data` helper.
|
||||
- Behavior impact: none (same payload format).
|
||||
|
||||
### Candidates not deleted (insufficient evidence)
|
||||
|
||||
- `backend/src/v1/agent/crewai_flow.py`
|
||||
- Reason: candidate report suggested possible dead code, but no deletion was done in this polish pass because cross-module usage certainty was insufficient.
|
||||
- Legacy `run()` path in `backend/src/v1/agent/service.py`
|
||||
- Reason: potentially still relied on by non-scope code paths; deletion deferred.
|
||||
@@ -1,69 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
def _to_non_negative_int(value: Any, *, field: str) -> int:
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{field} must be an integer")
|
||||
if isinstance(value, int):
|
||||
converted = value
|
||||
elif isinstance(value, str) and value.isdigit():
|
||||
converted = int(value)
|
||||
else:
|
||||
raise ValueError(f"{field} must be an integer")
|
||||
if converted < 0:
|
||||
raise ValueError(f"{field} cannot be negative")
|
||||
return converted
|
||||
|
||||
|
||||
def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
|
||||
converted = Decimal(str(value))
|
||||
if converted < 0:
|
||||
raise ValueError(f"{field} cannot be negative")
|
||||
return converted
|
||||
|
||||
|
||||
class CostTracker:
    """Accumulate token counts and cost over several usage records.

    All records must agree with the currency given at construction time.
    """

    def __init__(self, *, currency: str = "USD") -> None:
        """Start with zeroed counters.

        Args:
            currency: Currency code every recorded usage must match.
        """
        self._input_tokens = 0
        self._output_tokens = 0
        self._total_tokens = 0
        self._cost = Decimal("0")
        self._currency = currency

    def add_usage(self, usage: Mapping[str, Any]) -> None:
        """Validate one usage record and add it to the running totals.

        ``input_tokens``/``output_tokens`` fall back to ``prompt_tokens``/
        ``completion_tokens``. When ``total_tokens`` is absent it is derived
        as input + output.

        Args:
            usage: Mapping with optional token, ``cost`` and ``currency`` keys.

        Raises:
            ValueError: If the currency differs from the tracker's currency,
                or a token/cost value is malformed or negative. The totals
                are left untouched in that case.
        """
        # Reject a foreign currency before touching any counter; otherwise the
        # mismatching record would already be included in the totals.
        currency = usage.get("currency")
        if currency is not None and str(currency) != self._currency:
            raise ValueError("currency mismatch")

        input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0))
        output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0))
        total_tokens = usage.get("total_tokens")
        cost = usage.get("cost", "0")

        # Normalize everything first so a validation error cannot leave the
        # tracker partially updated.
        normalized_input = _to_non_negative_int(input_tokens, field="input_tokens")
        normalized_output = _to_non_negative_int(output_tokens, field="output_tokens")
        normalized_total = (
            _to_non_negative_int(total_tokens, field="total_tokens")
            if total_tokens is not None
            else normalized_input + normalized_output
        )
        normalized_cost = _to_non_negative_decimal(cost, field="cost")

        self._input_tokens += normalized_input
        self._output_tokens += normalized_output
        self._total_tokens += normalized_total
        self._cost += normalized_cost

    def snapshot(self) -> dict[str, Any]:
        """Return the current totals as a plain dict (``cost`` is a ``Decimal``)."""
        return {
            "input_tokens": self._input_tokens,
            "output_tokens": self._output_tokens,
            "total_tokens": self._total_tokens,
            "cost": self._cost,
            "currency": self._currency,
        }
|
||||
@@ -1 +0,0 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMConfig:
    """Immutable description of one catalog model and how to reach it."""

    # Identifier the model is looked up by in the catalog.
    model_code: str
    # Name of the factory (vendor) entry the model belongs to.
    factory_name: str
    # Model string handed to litellm.
    litellm_model: str
    # Base URL taken from the factory entry.
    request_url: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMResponse:
    """Immutable result of a single chat completion."""

    # Text of the first choice ("" when the provider returned no content).
    content: str
    # Usage figures dumped from the provider response ({} when absent).
    usage: dict[str, Any]
|
||||
|
||||
|
||||
class LiteLLMClient:
    """Chat client that sends requests for one configured model via litellm."""

    def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
        """Bind the client to ``config``.

        Args:
            config: Model/factory configuration to use for every request.
            api_key: Explicit API key. When falsy, the key is read from the
                environment variable mapped to ``config.factory_name``.

        Raises:
            ValueError: If no key is given and none can be resolved.
        """
        self._config = config
        self._api_key = api_key or self._get_api_key(config.factory_name)

    @staticmethod
    def _get_api_key(factory_name: str) -> str:
        """Return the API key for ``factory_name`` from the environment.

        Raises:
            ValueError: If the factory has no mapped environment variable or
                that variable is unset/empty.
        """
        key_map = {
            "dashscope": "DASHSCOPE_API_KEY",
            "minimax": "MINIMAX_API_KEY",
            "moonshot": "MOONSHOT_API_KEY",
            "deepseek": "DEEPSEEK_API_KEY",
            "volcengine-ark": "ARK_API_KEY",
            "z-ai": "ZAI_API_KEY",
        }
        env_key = key_map.get(factory_name)
        if not env_key:
            raise ValueError(f"No API key mapping for factory: {factory_name}")
        key = os.environ.get(env_key)
        if not key:
            raise ValueError(f"Environment variable {env_key} is not set")
        return key

    @staticmethod
    def load_config(
        model_code: str,
        static_root: Path | None = None,
    ) -> LLMConfig:
        """Look up ``model_code`` in ``llm_catalog.yaml`` and build its config.

        Args:
            model_code: Catalog identifier of the model.
            static_root: Directory containing ``llm_catalog.yaml``. Defaults
                to ``config/static/database`` three levels above this file.

        Returns:
            The matching ``LLMConfig``. ``litellm_model`` falls back to
            ``model_code`` when the catalog entry does not define it.

        Raises:
            ValueError: If the model, or the factory it names, is not in the
                catalog.
        """
        root = static_root or (
            Path(__file__).resolve().parents[3] / "config" / "static" / "database"
        )
        yaml_path = root / "llm_catalog.yaml"
        with yaml_path.open("r", encoding="utf-8") as catalog_file:
            # safe_load returns None for an empty document; treat that as an
            # empty catalog so the caller gets "Model not found" rather than
            # an AttributeError.
            data = yaml.safe_load(catalog_file) or {}

        # ``or []`` also covers keys that are present but null in the YAML.
        factories = {entry["name"]: entry for entry in data.get("factories") or []}

        for llm in data.get("llms") or []:
            if llm.get("model_code") != model_code:
                continue
            factory_name = llm["factory_name"]
            factory = factories.get(factory_name)
            if not factory:
                raise ValueError(f"Factory not found: {factory_name}")
            return LLMConfig(
                model_code=model_code,
                factory_name=factory_name,
                litellm_model=llm.get("litellm_model", model_code),
                request_url=factory["request_url"],
            )

        raise ValueError(f"Model not found: {model_code}")

    def _completion_kwargs(
        self,
        messages: list[dict[str, str]],
        temperature: float,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Build the keyword arguments shared by ``chat`` and ``achat``."""
        return {
            "model": self._config.litellm_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "api_base": self._config.request_url,
            "api_key": self._api_key,
        }

    @staticmethod
    def _to_response(response: Any) -> LLMResponse:
        """Extract first-choice content and the usage dict from a response."""
        # NOTE(review): downstream cost tracking reads usage["cost"]; confirm
        # the dumped usage actually carries that key for the providers in use.
        content = response.choices[0].message.content or ""  # type: ignore[attr-defined]
        usage = response.usage.model_dump() if response.usage else {}  # type: ignore[attr-defined]
        return LLMResponse(content=content, usage=usage)

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Send ``messages`` synchronously and return the first choice.

        Args:
            messages: Chat messages passed through to litellm unchanged.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Completion token limit forwarded to litellm.

        Returns:
            ``LLMResponse`` with the reply text and the usage dict.
        """
        # Imported lazily so the module can be imported without litellm loaded.
        import litellm

        response = litellm.completion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)

    async def achat(
        self,
        messages: list[dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Async counterpart of ``chat`` with the same arguments and result."""
        import litellm

        response = await litellm.acompletion(  # type: ignore[attr-defined]
            **self._completion_kwargs(messages, temperature, max_tokens)
        )
        return self._to_response(response)
|
||||
|
||||
|
||||
def get_model_cost(usage: dict[str, Any]) -> Decimal:
    """Return ``usage["cost"]`` as a ``Decimal``.

    A missing or ``None`` cost yields ``Decimal("0")``. Any other value is
    converted through ``str`` so floats keep their printed representation.
    """
    raw_cost = usage.get("cost")
    return Decimal("0") if raw_cost is None else Decimal(str(raw_cost))
|
||||
@@ -1,112 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from core.agent.storage_adapter import StorageAdapter
|
||||
|
||||
_ALLOWED_MIME_TYPES = {
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/x-wav",
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/webp",
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
}
|
||||
|
||||
|
||||
class _AsrTool(Protocol):
|
||||
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, object]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AttachmentInput:
    """Immutable raw attachment handed to the processor."""

    # File name as supplied with the attachment; its suffix is used as extension.
    filename: str
    # Declared MIME type, checked against the allow-list.
    mime_type: str
    # Raw file bytes.
    content: bytes
    # Where the attachment came from.
    origin: str = "user_upload"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProcessedAttachmentContext:
    """Immutable result of processing a batch of attachments."""

    # One metadata dict per processed attachment, in input order.
    attachments: list[dict[str, object]]
    # Non-empty preview texts collected while processing.
    preview_texts: list[str]
|
||||
|
||||
|
||||
class MultimodalProcessor:
    """Validate attachments and build their storage metadata and previews."""

    def __init__(
        self,
        *,
        storage: StorageAdapter,
        asr_tool: _AsrTool,
        max_file_size_mb: int = 20,
    ) -> None:
        """Configure the processor.

        Args:
            storage: Adapter used to build object paths and metadata dicts.
            asr_tool: Transcriber used to preview ``audio/*`` attachments.
            max_file_size_mb: Per-attachment size limit in MiB.
        """
        self._storage = storage
        self._asr_tool = asr_tool
        self._max_size_bytes = max_file_size_mb * 1024 * 1024

    def process(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        attachments: list[AttachmentInput],
    ) -> ProcessedAttachmentContext:
        """Validate and describe every attachment of one message.

        Args:
            user_id: Forwarded to the storage adapter for the object path.
            session_id: Forwarded to the storage adapter for the object path.
            message_seq: Forwarded to the storage adapter for the object path.
            attachments: Attachments to process, in order.

        Returns:
            Metadata for every attachment plus the non-empty preview texts.

        Raises:
            ValueError: If any attachment has an unsupported MIME type or
                exceeds the size limit.
        """
        pending = list(attachments)

        # Validate the whole batch before doing any work, so one bad
        # attachment does not trigger transcription of the ones before it.
        for attachment in pending:
            self._validate_attachment(attachment)

        metadata_list: list[dict[str, object]] = []
        preview_texts: list[str] = []

        for attachment in pending:
            checksum = hashlib.sha256(attachment.content).hexdigest()
            # Files without a suffix are stored with a generic "bin" extension.
            extension = Path(attachment.filename).suffix.strip(".").lower() or "bin"
            object_path = self._storage.build_object_path(
                user_id=user_id,
                session_id=session_id,
                message_seq=message_seq,
                checksum_sha256=checksum,
                extension=extension,
            )

            preview_text = self._build_preview_text(attachment)
            if preview_text:
                preview_texts.append(preview_text)

            metadata_list.append(
                self._storage.build_attachment_metadata(
                    object_path=object_path,
                    mime_type=attachment.mime_type,
                    size=len(attachment.content),
                    checksum_sha256=checksum,
                    origin=attachment.origin,
                    preview_text=preview_text,
                )
            )

        return ProcessedAttachmentContext(
            attachments=metadata_list,
            preview_texts=preview_texts,
        )

    def _validate_attachment(self, attachment: AttachmentInput) -> None:
        """Raise ``ValueError`` for a disallowed MIME type or oversized content."""
        if attachment.mime_type not in _ALLOWED_MIME_TYPES:
            raise ValueError("Unsupported MIME type")
        if len(attachment.content) > self._max_size_bytes:
            raise ValueError("Attachment exceeds max file size")

    def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
        """Return a text preview for audio and plain-text attachments.

        Audio is transcribed through the ASR tool and its ``"text"`` entry is
        returned when it is a string. Plain text yields its first 200 decoded
        characters. Every other type has no preview.
        """
        if attachment.mime_type.startswith("audio/"):
            result = self._asr_tool.transcribe(
                audio_bytes=attachment.content,
                filename=attachment.filename,
            )
            text = result.get("text")
            return text if isinstance(text, str) else None
        if attachment.mime_type == "text/plain":
            # Undecodable bytes are dropped rather than failing the upload.
            return attachment.content.decode("utf-8", errors="ignore")[:200]
        return None
|
||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from core.agent.cost_tracker import CostTracker
|
||||
from core.agent import events
|
||||
from core.agent.litellm_client import get_model_cost
|
||||
|
||||
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
|
||||
|
||||
@@ -20,6 +21,34 @@ class OrchestratorResult:
|
||||
error: str | None
|
||||
|
||||
|
||||
class _UsageTracker:
|
||||
def __init__(self) -> None:
|
||||
self._input_tokens = 0
|
||||
self._output_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._cost = Decimal("0")
|
||||
|
||||
def add_usage(self, usage: dict[str, Any]) -> None:
|
||||
input_tokens = usage.get("prompt_tokens", 0) or usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0) or usage.get(
|
||||
"output_tokens", 0
|
||||
)
|
||||
total = usage.get("total_tokens")
|
||||
|
||||
self._input_tokens += input_tokens
|
||||
self._output_tokens += output_tokens
|
||||
self._total_tokens += total if total else (input_tokens + output_tokens)
|
||||
self._cost += get_model_cost(usage)
|
||||
|
||||
def snapshot(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_tokens": self._input_tokens,
|
||||
"output_tokens": self._output_tokens,
|
||||
"total_tokens": self._total_tokens,
|
||||
"cost": str(self._cost),
|
||||
}
|
||||
|
||||
|
||||
class AgentChatOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -36,7 +65,7 @@ class AgentChatOrchestrator:
|
||||
return asyncio.run(self.run(run_id=run_id, user_message=user_message))
|
||||
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
tracker = CostTracker()
|
||||
tracker = _UsageTracker()
|
||||
emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
|
||||
context: dict[str, Any] = {}
|
||||
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class StorageAdapter:
    """Build object paths and attachment metadata for a named bucket."""

    _bucket: str

    def __init__(self, bucket: str) -> None:
        """Remember the bucket name."""
        self._bucket = bucket

    @property
    def bucket(self) -> str:
        """Bucket name given at construction time."""
        return self._bucket

    def build_object_path(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        checksum_sha256: str,
        extension: str,
    ) -> str:
        """Return ``agent-chat/<user>/<session>/<seq>/<checksum>.<ext>``.

        The extension is lower-cased and stripped of surrounding dots. An
        empty extension falls back to ``"bin"`` so the path never ends in a
        dangling ``"."``.
        """
        normalized_ext = extension.strip(".").lower() or "bin"
        return (
            f"agent-chat/{user_id}/{session_id}/{message_seq}/"
            f"{checksum_sha256}.{normalized_ext}"
        )

    def build_attachment_metadata(
        self,
        *,
        object_path: str,
        mime_type: str,
        size: int,
        checksum_sha256: str,
        origin: str,
        preview_text: str | None = None,
    ) -> dict[str, object]:
        """Bundle the given attachment attributes into a metadata dict."""
        return {
            "object_path": object_path,
            "mime_type": mime_type,
            "size": size,
            "checksum_sha256": checksum_sha256,
            "origin": origin,
            "preview_text": preview_text,
        }
|
||||
@@ -1 +0,0 @@
|
||||
from __future__ import annotations
|
||||
@@ -27,6 +27,8 @@ llms:
|
||||
# 你原来的两个保留
|
||||
- model_code: qwen3.5-flash
|
||||
factory_name: dashscope
|
||||
litellm_model: dashscope/qwen-turbo
|
||||
|
||||
- model_code: deepseek-v3.2
|
||||
factory_name: deepseek
|
||||
litellm_model: deepseek/deepseek-chat
|
||||
|
||||
@@ -2,81 +2,47 @@ from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cost_tracker import CostTracker
|
||||
from core.agent.litellm_client import get_model_cost
|
||||
|
||||
|
||||
def test_normalize_usage_and_cost_aggregation() -> None:
    """Both key spellings are accepted and summed into one snapshot."""
    openai_style = {
        "prompt_tokens": 7,
        "completion_tokens": 5,
        "cost": "0.002500",
    }
    native_style = {
        "input_tokens": 5,
        "output_tokens": 3,
        "cost": "0.003000",
        "currency": "USD",
    }

    tracker = CostTracker()
    for usage in (openai_style, native_style):
        tracker.add_usage(usage)

    totals = tracker.snapshot()

    assert totals["input_tokens"] == 12
    assert totals["output_tokens"] == 8
    assert totals["total_tokens"] == 20
    assert totals["cost"] == Decimal("0.005500")
    assert totals["currency"] == "USD"
|
||||
def test_get_model_cost_returns_decimal() -> None:
    """A string cost is returned as the equivalent ``Decimal``."""
    result = get_model_cost(
        {
            "prompt_tokens": 7,
            "completion_tokens": 5,
            "total_tokens": 12,
            "cost": "0.002500",
        }
    )
    assert result == Decimal("0.002500")
|
||||
|
||||
|
||||
def test_add_usage_rejects_negative_values() -> None:
    """Negative token counts and negative costs both raise ``ValueError``."""
    tracker = CostTracker()

    for bad_usage in ({"input_tokens": -1}, {"cost": "-0.010000"}):
        with pytest.raises(ValueError):
            tracker.add_usage(bad_usage)
|
||||
def test_get_model_cost_with_no_cost() -> None:
    """A usage dict without a ``cost`` key yields zero."""
    result = get_model_cost(
        {
            "prompt_tokens": 7,
            "completion_tokens": 5,
            "total_tokens": 12,
        }
    )
    assert result == Decimal("0")
|
||||
|
||||
|
||||
def test_snapshot_is_zero_before_any_usage() -> None:
    """A fresh tracker reports zero totals in the default currency."""
    totals = CostTracker().snapshot()

    assert totals["input_tokens"] == 0
    assert totals["output_tokens"] == 0
    assert totals["total_tokens"] == 0
    assert totals["cost"] == Decimal("0")
    assert totals["currency"] == "USD"
|
||||
def test_get_model_cost_with_zero_cost() -> None:
    """An explicit ``"0"`` cost yields zero."""
    result = get_model_cost(
        {
            "prompt_tokens": 7,
            "completion_tokens": 5,
            "total_tokens": 12,
            "cost": "0",
        }
    )
    assert result == Decimal("0")
|
||||
|
||||
|
||||
def test_add_usage_rejects_currency_mismatch() -> None:
    """Usage in a different currency than the tracker's raises ``ValueError``."""
    tracker = CostTracker(currency="USD")
    tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"})

    foreign_usage = {
        "input_tokens": 1,
        "output_tokens": 1,
        "cost": "0.001000",
        "currency": "CNY",
    }
    with pytest.raises(ValueError):
        tracker.add_usage(foreign_usage)
|
||||
|
||||
|
||||
def test_add_usage_rejects_non_integral_token_values() -> None:
    """Float and bool token counts both raise ``ValueError``."""
    tracker = CostTracker()

    for bad_usage in ({"input_tokens": 1.5}, {"output_tokens": True}):
        with pytest.raises(ValueError):
            tracker.add_usage(bad_usage)
|
||||
def test_get_model_cost_with_numeric_cost() -> None:
    """A float cost is converted via its printed representation."""
    result = get_model_cost(
        {
            "prompt_tokens": 7,
            "completion_tokens": 5,
            "total_tokens": 12,
            "cost": 0.0025,
        }
    )
    assert result == Decimal("0.0025")
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.multimodal import AttachmentInput, MultimodalProcessor
|
||||
from core.agent.storage_adapter import StorageAdapter
|
||||
from core.agent.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
    """An audio attachment is transcribed and described under a checksum path."""

    def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        assert audio_bytes == b"audio"
        assert filename == "voice.wav"
        return {"text": "hello world", "request_id": "req-1"}

    voice_note = AttachmentInput(
        filename="voice.wav",
        mime_type="audio/wav",
        content=b"audio",
    )
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
        max_file_size_mb=1,
    )

    result = processor.process(
        user_id="u1",
        session_id="s1",
        message_seq=4,
        attachments=[voice_note],
    )

    assert len(result.attachments) == 1
    metadata = result.attachments[0]
    expected_path = "agent-chat/u1/s1/4/6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
    assert metadata["object_path"] == expected_path
    assert metadata["mime_type"] == "audio/wav"
    assert result.preview_texts == ["hello world"]
|
||||
|
||||
|
||||
def test_multimodal_rejects_unsupported_mime_type() -> None:
    """A MIME type outside the allow-list raises ``ValueError``."""
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
    )
    executable = AttachmentInput(
        filename="malware.exe",
        mime_type="application/octet-stream",
        content=b"bad",
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[executable],
        )
|
||||
|
||||
|
||||
def test_multimodal_rejects_attachment_over_max_size() -> None:
    """Content one byte over the 1 MiB limit raises ``ValueError``."""
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
        max_file_size_mb=1,
    )
    too_big = AttachmentInput(
        filename="big.wav",
        mime_type="audio/wav",
        content=b"x" * (1024 * 1024 + 1),
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[too_big],
        )
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
def test_build_object_path_uses_expected_pattern() -> None:
    """The object path follows ``agent-chat/<user>/<session>/<seq>/<sum>.<ext>``."""
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    path_parts = {
        "user_id": "u1",
        "session_id": "s1",
        "message_seq": 3,
        "checksum_sha256": "abc123",
        "extension": "wav",
    }

    object_path = adapter.build_object_path(**path_parts)

    assert object_path == "agent-chat/u1/s1/3/abc123.wav"
|
||||
|
||||
|
||||
def test_build_attachment_metadata_contains_required_fields() -> None:
    """Every supplied attribute is echoed back under its own key."""
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    supplied = {
        "object_path": "agent-chat/u1/s1/3/abc123.wav",
        "mime_type": "audio/wav",
        "size": 1024,
        "checksum_sha256": "abc123",
        "origin": "user_upload",
        "preview_text": "hello",
    }

    metadata = adapter.build_attachment_metadata(**supplied)

    assert metadata["object_path"] == "agent-chat/u1/s1/3/abc123.wav"
    assert metadata["mime_type"] == "audio/wav"
    assert metadata["size"] == 1024
    assert metadata["checksum_sha256"] == "abc123"
    assert metadata["origin"] == "user_upload"
    assert metadata["preview_text"] == "hello"
|
||||
@@ -1,89 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.multimodal import AttachmentInput, MultimodalProcessor
|
||||
from core.agent.storage_adapter import StorageAdapter
|
||||
from core.agent.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
    """An audio attachment is transcribed and described under a checksum path."""

    def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        assert audio_bytes == b"audio"
        assert filename == "voice.wav"
        return {"text": "hello world", "request_id": "req-1"}

    voice_note = AttachmentInput(
        filename="voice.wav",
        mime_type="audio/wav",
        content=b"audio",
    )
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
        max_file_size_mb=1,
    )

    result = processor.process(
        user_id="u1",
        session_id="s1",
        message_seq=4,
        attachments=[voice_note],
    )

    assert len(result.attachments) == 1
    metadata = result.attachments[0]
    expected_path = "agent-chat/u1/s1/4/6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
    assert metadata["object_path"] == expected_path
    assert metadata["mime_type"] == "audio/wav"
    assert result.preview_texts == ["hello world"]
|
||||
|
||||
|
||||
def test_multimodal_rejects_unsupported_mime_type() -> None:
    """A MIME type outside the allow-list raises ``ValueError``."""
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
    )
    executable = AttachmentInput(
        filename="malware.exe",
        mime_type="application/octet-stream",
        content=b"bad",
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[executable],
        )
|
||||
|
||||
|
||||
def test_multimodal_rejects_attachment_over_max_size() -> None:
    """Content one byte over the 1 MiB limit raises ``ValueError``."""
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
        max_file_size_mb=1,
    )
    too_big = AttachmentInput(
        filename="big.wav",
        mime_type="audio/wav",
        content=b"x" * (1024 * 1024 + 1),
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[too_big],
        )
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
def test_build_object_path_uses_expected_pattern() -> None:
    """The object path follows ``agent-chat/<user>/<session>/<seq>/<sum>.<ext>``."""
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    path_parts = {
        "user_id": "u1",
        "session_id": "s1",
        "message_seq": 3,
        "checksum_sha256": "abc123",
        "extension": "wav",
    }

    object_path = adapter.build_object_path(**path_parts)

    assert object_path == "agent-chat/u1/s1/3/abc123.wav"
|
||||
|
||||
|
||||
def test_build_attachment_metadata_contains_required_fields() -> None:
    """Every supplied attribute is echoed back under its own key."""
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    supplied = {
        "object_path": "agent-chat/u1/s1/3/abc123.wav",
        "mime_type": "audio/wav",
        "size": 1024,
        "checksum_sha256": "abc123",
        "origin": "user_upload",
        "preview_text": "hello",
    }

    metadata = adapter.build_attachment_metadata(**supplied)

    assert metadata["object_path"] == "agent-chat/u1/s1/3/abc123.wav"
    assert metadata["mime_type"] == "audio/wav"
    assert metadata["size"] == 1024
    assert metadata["checksum_sha256"] == "abc123"
    assert metadata["origin"] == "user_upload"
    assert metadata["preview_text"] == "hello"
|
||||
@@ -12,6 +12,7 @@ dependencies = [
|
||||
"crewai-tools>=1.6.1",
|
||||
"email-validator>=2.3.0",
|
||||
"fastapi>=0.128.0",
|
||||
"litellm>=1.52.0",
|
||||
"pydantic>=2.11.0",
|
||||
"pydantic-settings>=2.10.0",
|
||||
"pyjwt>=2.10.1",
|
||||
|
||||
Reference in New Issue
Block a user