feat: 应用名称更新为灵可析并增强 Chat 功能
- 更新 Android/iOS 应用名称和图标为灵可析 - Chat 支持取消正在运行的 Agent 对话 - 改进 ChatBloc 状态管理(区分发送/等待/流式/取消状态) - HomeScreen 支持外部注入 ChatBloc 和显示等待指示器 - 后端 Agent 运行服务优化(消息处理、usage 追踪) - 补充相关单元测试和 Widget 测试
This commit is contained in:
@@ -13,8 +13,8 @@ from core.agent.domain.agui_input import (
|
||||
)
|
||||
from core.agent.application.runtime_loop_service import RuntimeLoopService
|
||||
from core.agent.application.runtime_data_service import RuntimeDataService
|
||||
from core.agent.application.session_state_persistence import SessionStatePersistence
|
||||
from core.agent.application.session_state_persistence import (
|
||||
SessionStatePersistence,
|
||||
ToolResultStorage,
|
||||
persist_tool_result_payload,
|
||||
)
|
||||
@@ -179,7 +179,6 @@ class RunService:
|
||||
seq=next_seq,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content=user_input,
|
||||
model_code=model_code,
|
||||
metadata=MessageMetadataUserInput().model_dump(),
|
||||
)
|
||||
pending_tool_call_id: str | None = None
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Callable
|
||||
|
||||
from crewai import Agent, Crew, LLM, Process, Task
|
||||
from crewai.agents import parser as crew_parser
|
||||
from litellm import completion, completion_cost
|
||||
from litellm import completion
|
||||
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.infrastructure.config.resolver import ResolvedAgentConfig
|
||||
@@ -17,7 +17,11 @@ from core.agent.infrastructure.crewai.runtime_tools import (
|
||||
PendingFrontendToolCall,
|
||||
resolve_stage_crewai_tools,
|
||||
)
|
||||
from core.agent.infrastructure.litellm.usage_tracker import UsageCost
|
||||
from core.agent.infrastructure.litellm.pricing import calculate_tiered_model_cost
|
||||
from core.agent.infrastructure.litellm.usage_tracker import (
|
||||
UsageCost,
|
||||
extract_usage_and_cost,
|
||||
)
|
||||
from core.agent.prompt import runtime_stage_prompts
|
||||
from core.logging import get_logger
|
||||
|
||||
@@ -25,6 +29,31 @@ from core.logging import get_logger
|
||||
logger = get_logger("core.agent.infrastructure.crewai.runtime_stage_runner")
|
||||
|
||||
|
||||
class LiteLLMUsageCaptureCallback:
|
||||
def __init__(self) -> None:
|
||||
self.captured_usage: dict[str, Any] | None = None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_usage(usage_payload: object) -> dict[str, Any] | None:
|
||||
if isinstance(usage_payload, dict):
|
||||
return usage_payload
|
||||
model_dump = getattr(usage_payload, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
dumped = model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return dumped
|
||||
return None
|
||||
|
||||
def log_success_event(self, **kwargs: Any) -> None:
|
||||
response_obj = kwargs.get("response_obj")
|
||||
if not isinstance(response_obj, dict):
|
||||
return
|
||||
normalized = self._normalize_usage(response_obj.get("usage"))
|
||||
if normalized is None:
|
||||
return
|
||||
self.captured_usage = normalized
|
||||
|
||||
|
||||
def _tool_names(tools_payload: list[dict[str, object]]) -> list[str]:
|
||||
names: list[str] = []
|
||||
for item in tools_payload:
|
||||
@@ -69,24 +98,37 @@ def _output_diagnostics(*, text: str, tool_names: list[str]) -> dict[str, object
|
||||
}
|
||||
|
||||
|
||||
def extract_usage_from_captured_payload(
|
||||
*,
|
||||
captured_usage: dict[str, Any],
|
||||
model: str,
|
||||
) -> UsageCost:
|
||||
usage = extract_usage_and_cost(
|
||||
{
|
||||
"model": model,
|
||||
"usage": captured_usage,
|
||||
}
|
||||
)
|
||||
return usage
|
||||
|
||||
|
||||
def extract_usage_from_crew_output(*, output: object, model: str) -> UsageCost:
|
||||
token_usage = getattr(output, "token_usage", None)
|
||||
prompt_tokens = int(getattr(token_usage, "prompt_tokens", 0) or 0)
|
||||
completion_tokens = int(getattr(token_usage, "completion_tokens", 0) or 0)
|
||||
total_tokens = int(getattr(token_usage, "total_tokens", 0) or 0)
|
||||
cached_prompt_tokens = int(getattr(token_usage, "cached_prompt_tokens", 0) or 0)
|
||||
if total_tokens == 0:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
try:
|
||||
cost = float(
|
||||
completion_cost(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
or 0.0
|
||||
cost = float(
|
||||
calculate_tiered_model_cost(
|
||||
model_name=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cached_prompt_tokens=cached_prompt_tokens,
|
||||
)
|
||||
except Exception:
|
||||
cost = 0.0
|
||||
or 0.0
|
||||
)
|
||||
return UsageCost(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
@@ -134,32 +176,32 @@ def run_stage_with_crewai(
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
raw_text = content
|
||||
usage_obj = getattr(response_any, "usage", None)
|
||||
prompt_tokens = int(getattr(usage_obj, "prompt_tokens", 0) or 0)
|
||||
completion_tokens = int(getattr(usage_obj, "completion_tokens", 0) or 0)
|
||||
total_tokens = int(getattr(usage_obj, "total_tokens", 0) or 0)
|
||||
if total_tokens == 0:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
try:
|
||||
cost = float(
|
||||
completion_cost(
|
||||
model=litellm_model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
or 0.0
|
||||
response_dict = (
|
||||
response_any.model_dump()
|
||||
if hasattr(response_any, "model_dump")
|
||||
else dict(response_any)
|
||||
)
|
||||
if "model" not in response_dict:
|
||||
response_dict["model"] = litellm_model
|
||||
usage = extract_usage_and_cost(response_dict)
|
||||
except Exception:
|
||||
cost = 0.0
|
||||
usage = UsageCost(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost=cost,
|
||||
)
|
||||
usage_obj = getattr(response_any, "usage", None)
|
||||
prompt_tokens = int(getattr(usage_obj, "prompt_tokens", 0) or 0)
|
||||
completion_tokens = int(getattr(usage_obj, "completion_tokens", 0) or 0)
|
||||
total_tokens = int(getattr(usage_obj, "total_tokens", 0) or 0)
|
||||
if total_tokens == 0:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = UsageCost(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost=0.0,
|
||||
)
|
||||
return raw_text, usage, [], None
|
||||
|
||||
calls: list[dict[str, Any]] = []
|
||||
usage_callback = LiteLLMUsageCaptureCallback()
|
||||
crew_tools = resolve_stage_crewai_tools(
|
||||
tools_payload=tools_payload,
|
||||
calls=calls,
|
||||
@@ -173,6 +215,8 @@ def run_stage_with_crewai(
|
||||
temperature=llm_config.temperature,
|
||||
max_tokens=llm_config.max_tokens,
|
||||
timeout=llm_config.timeout_seconds,
|
||||
stream=True,
|
||||
callbacks=[usage_callback],
|
||||
)
|
||||
agent = Agent(
|
||||
role=agent_template.role,
|
||||
@@ -218,7 +262,14 @@ def run_stage_with_crewai(
|
||||
],
|
||||
pending_tool=str(pending.payload.get("name")),
|
||||
)
|
||||
return "", UsageCost(0, 0, 0, 0.0), calls, pending.payload
|
||||
if usage_callback.captured_usage is not None:
|
||||
usage = extract_usage_from_captured_payload(
|
||||
captured_usage=usage_callback.captured_usage,
|
||||
model=litellm_model,
|
||||
)
|
||||
else:
|
||||
usage = UsageCost(0, 0, 0, 0.0)
|
||||
return "", usage, calls, pending.payload
|
||||
|
||||
output_text = extract_crew_output_text(output)
|
||||
logger.info(
|
||||
@@ -231,5 +282,11 @@ def run_stage_with_crewai(
|
||||
],
|
||||
diagnostics=_output_diagnostics(text=output_text, tool_names=stage_tool_names),
|
||||
)
|
||||
usage = extract_usage_from_crew_output(output=output, model=litellm_model)
|
||||
if usage_callback.captured_usage is not None:
|
||||
usage = extract_usage_from_captured_payload(
|
||||
captured_usage=usage_callback.captured_usage,
|
||||
model=litellm_model,
|
||||
)
|
||||
else:
|
||||
usage = extract_usage_from_crew_output(output=output, model=litellm_model)
|
||||
return output_text, usage, calls, None
|
||||
|
||||
@@ -36,9 +36,22 @@ QWEN35_FLASH_TIERED_PRICING: tuple[TieredModelPricing, ...] = (
|
||||
),
|
||||
)
|
||||
|
||||
DEEPSEEK_CHAT_TIERED_PRICING: tuple[TieredModelPricing, ...] = (
|
||||
TieredModelPricing(
|
||||
max_prompt_tokens=10_000_000,
|
||||
input_cost_per_token=2.0 / 1_000_000,
|
||||
output_cost_per_token=3.0 / 1_000_000,
|
||||
cache_create_cost_per_token=2.0 / 1_000_000,
|
||||
cache_hit_cost_per_token=0.2 / 1_000_000,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_MODEL_TIERED_PRICING: dict[str, tuple[TieredModelPricing, ...]] = {
|
||||
"dashscope/qwen3.5-flash": QWEN35_FLASH_TIERED_PRICING,
|
||||
"qwen3.5-flash": QWEN35_FLASH_TIERED_PRICING,
|
||||
"deepseek/deepseek-chat": DEEPSEEK_CHAT_TIERED_PRICING,
|
||||
"deepseek-chat": DEEPSEEK_CHAT_TIERED_PRICING,
|
||||
}
|
||||
|
||||
|
||||
@@ -61,12 +74,21 @@ def calculate_tiered_model_cost(
|
||||
model_name: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cached_prompt_tokens: int = 0,
|
||||
) -> float | None:
|
||||
tier = get_tiered_pricing(model_name=model_name, prompt_tokens=prompt_tokens)
|
||||
if tier is None:
|
||||
return None
|
||||
|
||||
return (
|
||||
prompt_tokens * tier.input_cost_per_token
|
||||
+ completion_tokens * tier.output_cost_per_token
|
||||
normalized_prompt_tokens = max(int(prompt_tokens), 0)
|
||||
normalized_completion_tokens = max(int(completion_tokens), 0)
|
||||
normalized_cached_tokens = min(
|
||||
max(int(cached_prompt_tokens), 0), normalized_prompt_tokens
|
||||
)
|
||||
uncached_prompt_tokens = normalized_prompt_tokens - normalized_cached_tokens
|
||||
|
||||
return (
|
||||
uncached_prompt_tokens * tier.input_cost_per_token
|
||||
+ normalized_cached_tokens * tier.cache_hit_cost_per_token
|
||||
+ normalized_completion_tokens * tier.output_cost_per_token
|
||||
)
|
||||
|
||||
@@ -3,8 +3,6 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from litellm import completion_cost
|
||||
|
||||
from core.agent.infrastructure.litellm.pricing import calculate_tiered_model_cost
|
||||
|
||||
|
||||
@@ -26,25 +24,19 @@ def extract_usage_and_cost(response: dict[str, Any]) -> UsageCost:
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
total_tokens = int(usage.get("total_tokens", prompt_tokens + completion_tokens))
|
||||
model_name = str(response.get("model", "")).strip().lower()
|
||||
prompt_tokens_details = usage.get("prompt_tokens_details")
|
||||
cached_prompt_tokens = 0
|
||||
if isinstance(prompt_tokens_details, dict):
|
||||
cached_prompt_tokens = int(prompt_tokens_details.get("cached_tokens", 0) or 0)
|
||||
|
||||
try:
|
||||
cost = completion_cost(completion_response=response)
|
||||
if cost is None:
|
||||
raise ValueError("unable to calculate litellm completion cost")
|
||||
return UsageCost(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost=float(cost),
|
||||
)
|
||||
except Exception as exc:
|
||||
local_cost = calculate_tiered_model_cost(
|
||||
model_name=model_name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
if local_cost is None:
|
||||
raise ValueError("unable to calculate litellm completion cost") from exc
|
||||
local_cost = calculate_tiered_model_cost(
|
||||
model_name=model_name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cached_prompt_tokens=cached_prompt_tokens,
|
||||
)
|
||||
if local_cost is None:
|
||||
raise ValueError("unable to calculate custom completion cost")
|
||||
|
||||
return UsageCost(
|
||||
prompt_tokens=prompt_tokens,
|
||||
|
||||
@@ -5,15 +5,14 @@ import pytest
|
||||
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
|
||||
|
||||
|
||||
def test_usage_tracker_extracts_tokens_and_cost(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.litellm.usage_tracker.completion_cost",
|
||||
lambda completion_response: 0.123,
|
||||
)
|
||||
def test_usage_tracker_uses_custom_pricing_for_qwen35() -> None:
|
||||
response = {
|
||||
"usage": {"prompt_tokens": 11, "completion_tokens": 7, "total_tokens": 18},
|
||||
"model": "dashscope/qwen3.5-flash",
|
||||
"usage": {
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
},
|
||||
}
|
||||
|
||||
usage = extract_usage_and_cost(response)
|
||||
@@ -21,7 +20,8 @@ def test_usage_tracker_extracts_tokens_and_cost(
|
||||
assert usage.prompt_tokens == 11
|
||||
assert usage.completion_tokens == 7
|
||||
assert usage.total_tokens == 18
|
||||
assert usage.cost == 0.123
|
||||
assert usage.cost == pytest.approx(0.0000162)
|
||||
assert usage.cost_source == "custom_pricing"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -33,19 +33,10 @@ def test_usage_tracker_extracts_tokens_and_cost(
|
||||
],
|
||||
)
|
||||
def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
expected_cost: float,
|
||||
) -> None:
|
||||
def _raise_unmapped(*, completion_response): # type: ignore[no-untyped-def]
|
||||
del completion_response
|
||||
raise Exception("This model isn't mapped yet")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.litellm.usage_tracker.completion_cost",
|
||||
_raise_unmapped,
|
||||
)
|
||||
response = {
|
||||
"model": "dashscope/qwen3.5-flash",
|
||||
"usage": {
|
||||
@@ -59,3 +50,22 @@ def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
|
||||
|
||||
assert usage.cost == pytest.approx(expected_cost)
|
||||
assert usage.cost_source == "custom_pricing"
|
||||
|
||||
|
||||
def test_usage_tracker_uses_cached_pricing_for_deepseek_chat() -> None:
|
||||
response = {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"usage": {
|
||||
"prompt_tokens": 1_000_000,
|
||||
"completion_tokens": 100_000,
|
||||
"total_tokens": 1_100_000,
|
||||
"prompt_tokens_details": {
|
||||
"cached_tokens": 400_000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
usage = extract_usage_and_cost(response)
|
||||
|
||||
assert usage.cost == pytest.approx(1.58)
|
||||
assert usage.cost_source == "custom_pricing"
|
||||
|
||||
@@ -1058,6 +1058,128 @@ async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_does_not_persist_model_code_for_user_message(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: dict[str, object] = {}
|
||||
message_calls: list[dict[str, object]] = []
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.PENDING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot=None,
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object):
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
captured["update_runtime_state"] = kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
message_calls.append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
):
|
||||
del user_input, system_prompt, tools
|
||||
return {
|
||||
"assistant_text": "ok",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
del self
|
||||
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
user_id=user_id,
|
||||
username="demo-user",
|
||||
bio=None,
|
||||
settings=SimpleNamespace(
|
||||
preferences=SimpleNamespace(
|
||||
interface_language="zh-CN",
|
||||
ai_language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
country="CN",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.create_runtime",
|
||||
lambda **_kwargs: _FakeRuntime(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
await service.run(
|
||||
run_input=_build_run_input(thread_id=str(session_id), text="hello")
|
||||
)
|
||||
|
||||
user_message = message_calls[0]
|
||||
assert user_message["role"] == AgentChatMessageRole.USER
|
||||
assert "model_code" not in user_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_user_agent_context_parses_profile_settings_v1() -> None:
|
||||
session_id = uuid4()
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.crewai.runtime_stage_runner import (
|
||||
LiteLLMUsageCaptureCallback,
|
||||
extract_usage_from_captured_payload,
|
||||
extract_usage_from_crew_output,
|
||||
)
|
||||
|
||||
|
||||
def test_extract_usage_from_crew_output_uses_custom_deepseek_pricing() -> None:
|
||||
output = SimpleNamespace(
|
||||
token_usage=SimpleNamespace(
|
||||
prompt_tokens=1_000_000,
|
||||
completion_tokens=100_000,
|
||||
total_tokens=1_100_000,
|
||||
cached_prompt_tokens=400_000,
|
||||
)
|
||||
)
|
||||
|
||||
usage = extract_usage_from_crew_output(
|
||||
output=output,
|
||||
model="deepseek/deepseek-chat",
|
||||
)
|
||||
|
||||
assert usage.prompt_tokens == 1_000_000
|
||||
assert usage.completion_tokens == 100_000
|
||||
assert usage.total_tokens == 1_100_000
|
||||
assert usage.cost == pytest.approx(1.58)
|
||||
|
||||
|
||||
def test_extract_usage_from_captured_payload_uses_custom_pricing() -> None:
|
||||
usage = extract_usage_from_captured_payload(
|
||||
captured_usage={
|
||||
"prompt_tokens": 1_000_000,
|
||||
"completion_tokens": 100_000,
|
||||
"total_tokens": 1_100_000,
|
||||
"prompt_tokens_details": {"cached_tokens": 400_000},
|
||||
},
|
||||
model="deepseek/deepseek-chat",
|
||||
)
|
||||
|
||||
assert usage.prompt_tokens == 1_000_000
|
||||
assert usage.completion_tokens == 100_000
|
||||
assert usage.total_tokens == 1_100_000
|
||||
assert usage.cost == pytest.approx(1.58)
|
||||
|
||||
|
||||
def test_usage_capture_callback_extracts_nested_usage_payload() -> None:
|
||||
callback = LiteLLMUsageCaptureCallback()
|
||||
|
||||
callback.log_success_event(
|
||||
kwargs={},
|
||||
response_obj={
|
||||
"usage": {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 9,
|
||||
"total_tokens": 24,
|
||||
}
|
||||
},
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
assert callback.captured_usage == {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 9,
|
||||
"total_tokens": 24,
|
||||
}
|
||||
Reference in New Issue
Block a user