feat: 应用名称更新为灵可析并增强 Chat 功能
- 更新 Android/iOS 应用名称和图标为灵可析 - Chat 支持取消正在运行的 Agent 对话 - 改进 ChatBloc 状态管理(区分发送/等待/流式/取消状态) - HomeScreen 支持外部注入 ChatBloc 和显示等待指示器 - 后端 Agent 运行服务优化(消息处理、usage 追踪) - 补充相关单元测试和 Widget 测试
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
<application
|
<application
|
||||||
android:usesCleartextTraffic="true"
|
android:usesCleartextTraffic="true"
|
||||||
android:label="social_app"
|
android:label="灵可析"
|
||||||
android:name="${applicationName}"
|
android:name="${applicationName}"
|
||||||
android:icon="@mipmap/ic_launcher">
|
android:icon="@mipmap/ic_launcher">
|
||||||
<activity
|
<activity
|
||||||
|
|||||||
|
After Width: | Height: | Size: 38 KiB |
|
After Width: | Height: | Size: 19 KiB |
|
After Width: | Height: | Size: 62 KiB |
|
After Width: | Height: | Size: 130 KiB |
|
After Width: | Height: | Size: 225 KiB |
@@ -0,0 +1,9 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||||
|
<background android:drawable="@color/ic_launcher_background"/>
|
||||||
|
<foreground>
|
||||||
|
<inset
|
||||||
|
android:drawable="@drawable/ic_launcher_foreground"
|
||||||
|
android:inset="16%" />
|
||||||
|
</foreground>
|
||||||
|
</adaptive-icon>
|
||||||
|
Before Width: | Height: | Size: 544 B After Width: | Height: | Size: 9.7 KiB |
|
Before Width: | Height: | Size: 442 B After Width: | Height: | Size: 5.1 KiB |
|
Before Width: | Height: | Size: 721 B After Width: | Height: | Size: 16 KiB |
|
Before Width: | Height: | Size: 1.0 KiB After Width: | Height: | Size: 31 KiB |
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 50 KiB |
@@ -0,0 +1,4 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<resources>
|
||||||
|
<color name="ic_launcher_background">#FFFFFF</color>
|
||||||
|
</resources>
|
||||||
@@ -427,7 +427,7 @@
|
|||||||
isa = XCBuildConfiguration;
|
isa = XCBuildConfiguration;
|
||||||
buildSettings = {
|
buildSettings = {
|
||||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||||
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
|
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = AppIcon;
|
||||||
CLANG_ANALYZER_NONNULL = YES;
|
CLANG_ANALYZER_NONNULL = YES;
|
||||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
|
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
|
||||||
CLANG_CXX_LIBRARY = "libc++";
|
CLANG_CXX_LIBRARY = "libc++";
|
||||||
@@ -484,7 +484,7 @@
|
|||||||
isa = XCBuildConfiguration;
|
isa = XCBuildConfiguration;
|
||||||
buildSettings = {
|
buildSettings = {
|
||||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||||
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
|
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = AppIcon;
|
||||||
CLANG_ANALYZER_NONNULL = YES;
|
CLANG_ANALYZER_NONNULL = YES;
|
||||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
|
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
|
||||||
CLANG_CXX_LIBRARY = "libc++";
|
CLANG_CXX_LIBRARY = "libc++";
|
||||||
|
|||||||
@@ -1,122 +1 @@
|
|||||||
{
|
{"images":[{"size":"20x20","idiom":"iphone","filename":"Icon-App-20x20@2x.png","scale":"2x"},{"size":"20x20","idiom":"iphone","filename":"Icon-App-20x20@3x.png","scale":"3x"},{"size":"29x29","idiom":"iphone","filename":"Icon-App-29x29@1x.png","scale":"1x"},{"size":"29x29","idiom":"iphone","filename":"Icon-App-29x29@2x.png","scale":"2x"},{"size":"29x29","idiom":"iphone","filename":"Icon-App-29x29@3x.png","scale":"3x"},{"size":"40x40","idiom":"iphone","filename":"Icon-App-40x40@2x.png","scale":"2x"},{"size":"40x40","idiom":"iphone","filename":"Icon-App-40x40@3x.png","scale":"3x"},{"size":"57x57","idiom":"iphone","filename":"Icon-App-57x57@1x.png","scale":"1x"},{"size":"57x57","idiom":"iphone","filename":"Icon-App-57x57@2x.png","scale":"2x"},{"size":"60x60","idiom":"iphone","filename":"Icon-App-60x60@2x.png","scale":"2x"},{"size":"60x60","idiom":"iphone","filename":"Icon-App-60x60@3x.png","scale":"3x"},{"size":"20x20","idiom":"ipad","filename":"Icon-App-20x20@1x.png","scale":"1x"},{"size":"20x20","idiom":"ipad","filename":"Icon-App-20x20@2x.png","scale":"2x"},{"size":"29x29","idiom":"ipad","filename":"Icon-App-29x29@1x.png","scale":"1x"},{"size":"29x29","idiom":"ipad","filename":"Icon-App-29x29@2x.png","scale":"2x"},{"size":"40x40","idiom":"ipad","filename":"Icon-App-40x40@1x.png","scale":"1x"},{"size":"40x40","idiom":"ipad","filename":"Icon-App-40x40@2x.png","scale":"2x"},{"size":"50x50","idiom":"ipad","filename":"Icon-App-50x50@1x.png","scale":"1x"},{"size":"50x50","idiom":"ipad","filename":"Icon-App-50x50@2x.png","scale":"2x"},{"size":"72x72","idiom":"ipad","filename":"Icon-App-72x72@1x.png","scale":"1x"},{"size":"72x72","idiom":"ipad","filename":"Icon-App-72x72@2x.png","scale":"2x"},{"size":"76x76","idiom":"ipad","filename":"Icon-App-76x76@1x.png","scale":"1x"},{"size":"76x76","idiom":"ipad","filename":"Icon-App-76x76@2x.png","scale":"2x"},{"size":"83.5x83.5","idiom":"ipad","filename":"Icon-App-83.5x83.5@2x.png","scale":"2x"},{"size":"1024x1024","idiom":"ios-marketing","filename":"Icon-App-1024x1024@1x.png","scale":"1x"}],"info":{"version":1,"author":"xcode"}}
|
||||||
"images" : [
|
|
||||||
{
|
|
||||||
"size" : "20x20",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-20x20@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "20x20",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-20x20@3x.png",
|
|
||||||
"scale" : "3x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "29x29",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-29x29@1x.png",
|
|
||||||
"scale" : "1x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "29x29",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-29x29@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "29x29",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-29x29@3x.png",
|
|
||||||
"scale" : "3x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "40x40",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-40x40@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "40x40",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-40x40@3x.png",
|
|
||||||
"scale" : "3x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "60x60",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-60x60@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "60x60",
|
|
||||||
"idiom" : "iphone",
|
|
||||||
"filename" : "Icon-App-60x60@3x.png",
|
|
||||||
"scale" : "3x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "20x20",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-20x20@1x.png",
|
|
||||||
"scale" : "1x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "20x20",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-20x20@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "29x29",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-29x29@1x.png",
|
|
||||||
"scale" : "1x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "29x29",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-29x29@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "40x40",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-40x40@1x.png",
|
|
||||||
"scale" : "1x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "40x40",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-40x40@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "76x76",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-76x76@1x.png",
|
|
||||||
"scale" : "1x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "76x76",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-76x76@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "83.5x83.5",
|
|
||||||
"idiom" : "ipad",
|
|
||||||
"filename" : "Icon-App-83.5x83.5@2x.png",
|
|
||||||
"scale" : "2x"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"size" : "1024x1024",
|
|
||||||
"idiom" : "ios-marketing",
|
|
||||||
"filename" : "Icon-App-1024x1024@1x.png",
|
|
||||||
"scale" : "1x"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"info" : {
|
|
||||||
"version" : 1,
|
|
||||||
"author" : "xcode"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Before Width: | Height: | Size: 11 KiB After Width: | Height: | Size: 1.3 MiB |
|
Before Width: | Height: | Size: 295 B After Width: | Height: | Size: 1.2 KiB |
|
Before Width: | Height: | Size: 406 B After Width: | Height: | Size: 3.8 KiB |
|
Before Width: | Height: | Size: 450 B After Width: | Height: | Size: 7.2 KiB |
|
Before Width: | Height: | Size: 282 B After Width: | Height: | Size: 2.3 KiB |
|
Before Width: | Height: | Size: 462 B After Width: | Height: | Size: 6.8 KiB |
|
Before Width: | Height: | Size: 704 B After Width: | Height: | Size: 13 KiB |
|
Before Width: | Height: | Size: 406 B After Width: | Height: | Size: 3.8 KiB |
|
Before Width: | Height: | Size: 586 B After Width: | Height: | Size: 12 KiB |
|
Before Width: | Height: | Size: 862 B After Width: | Height: | Size: 22 KiB |
|
After Width: | Height: | Size: 5.4 KiB |
|
After Width: | Height: | Size: 17 KiB |
|
After Width: | Height: | Size: 6.7 KiB |
|
After Width: | Height: | Size: 21 KiB |
|
Before Width: | Height: | Size: 862 B After Width: | Height: | Size: 22 KiB |
|
Before Width: | Height: | Size: 1.6 KiB After Width: | Height: | Size: 45 KiB |
|
After Width: | Height: | Size: 9.7 KiB |
|
After Width: | Height: | Size: 31 KiB |
|
Before Width: | Height: | Size: 762 B After Width: | Height: | Size: 11 KiB |
|
Before Width: | Height: | Size: 1.2 KiB After Width: | Height: | Size: 34 KiB |
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 40 KiB |
@@ -5,7 +5,7 @@
|
|||||||
<key>CFBundleDevelopmentRegion</key>
|
<key>CFBundleDevelopmentRegion</key>
|
||||||
<string>$(DEVELOPMENT_LANGUAGE)</string>
|
<string>$(DEVELOPMENT_LANGUAGE)</string>
|
||||||
<key>CFBundleDisplayName</key>
|
<key>CFBundleDisplayName</key>
|
||||||
<string>Social App</string>
|
<string>灵可析</string>
|
||||||
<key>CFBundleExecutable</key>
|
<key>CFBundleExecutable</key>
|
||||||
<string>$(EXECUTABLE_NAME)</string>
|
<string>$(EXECUTABLE_NAME)</string>
|
||||||
<key>CFBundleIdentifier</key>
|
<key>CFBundleIdentifier</key>
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
<key>CFBundleInfoDictionaryVersion</key>
|
<key>CFBundleInfoDictionaryVersion</key>
|
||||||
<string>6.0</string>
|
<string>6.0</string>
|
||||||
<key>CFBundleName</key>
|
<key>CFBundleName</key>
|
||||||
<string>social_app</string>
|
<string>灵可析</string>
|
||||||
<key>CFBundlePackageType</key>
|
<key>CFBundlePackageType</key>
|
||||||
<string>APPL</string>
|
<string>APPL</string>
|
||||||
<key>CFBundleShortVersionString</key>
|
<key>CFBundleShortVersionString</key>
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class AgUiService {
|
|||||||
final MockHistoryService _historyService;
|
final MockHistoryService _historyService;
|
||||||
final Map<String, List<String>> _mockSseLinesByThread = {};
|
final Map<String, List<String>> _mockSseLinesByThread = {};
|
||||||
final Map<String, String> _lastEventIdByThread = {};
|
final Map<String, String> _lastEventIdByThread = {};
|
||||||
|
int _activeStreamToken = 0;
|
||||||
|
|
||||||
String? _threadId;
|
String? _threadId;
|
||||||
bool _hasMoreHistory = false;
|
bool _hasMoreHistory = false;
|
||||||
@@ -41,6 +42,7 @@ class AgUiService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Future<void> sendMessage(String content) async {
|
Future<void> sendMessage(String content) async {
|
||||||
|
final streamToken = ++_activeStreamToken;
|
||||||
final runInput = _buildRunInput(content: content);
|
final runInput = _buildRunInput(content: content);
|
||||||
final response = await _apiClient.post<Map<String, dynamic>>(
|
final response = await _apiClient.post<Map<String, dynamic>>(
|
||||||
'/api/v1/agent/runs',
|
'/api/v1/agent/runs',
|
||||||
@@ -55,7 +57,7 @@ class AgUiService {
|
|||||||
throw StateError('Missing threadId in /agent/runs response');
|
throw StateError('Missing threadId in /agent/runs response');
|
||||||
}
|
}
|
||||||
_threadId = threadId;
|
_threadId = threadId;
|
||||||
await _streamEventsFromApi(threadId);
|
await _streamEventsFromApi(threadId, streamToken: streamToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<void> loadHistory({DateTime? beforeDate}) async {
|
Future<void> loadHistory({DateTime? beforeDate}) async {
|
||||||
@@ -105,6 +107,7 @@ class AgUiService {
|
|||||||
required String toolName,
|
required String toolName,
|
||||||
required Map<String, dynamic> args,
|
required Map<String, dynamic> args,
|
||||||
}) async {
|
}) async {
|
||||||
|
final streamToken = ++_activeStreamToken;
|
||||||
final threadId = _threadId;
|
final threadId = _threadId;
|
||||||
if (threadId == null || threadId.isEmpty) {
|
if (threadId == null || threadId.isEmpty) {
|
||||||
throw StateError('Missing threadId for resume');
|
throw StateError('Missing threadId for resume');
|
||||||
@@ -150,7 +153,7 @@ class AgUiService {
|
|||||||
_threadId = responseThreadId;
|
_threadId = responseThreadId;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
await _streamEventsFromApi(threadId);
|
await _streamEventsFromApi(threadId, streamToken: streamToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasEarlierHistory(DateTime fromDate) {
|
bool hasEarlierHistory(DateTime fromDate) {
|
||||||
@@ -160,7 +163,14 @@ class AgUiService {
|
|||||||
return _hasMoreHistory;
|
return _hasMoreHistory;
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<void> _streamEventsFromApi(String threadId) async {
|
Future<void> cancelCurrentRun() async {
|
||||||
|
_activeStreamToken += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Future<void> _streamEventsFromApi(
|
||||||
|
String threadId, {
|
||||||
|
required int streamToken,
|
||||||
|
}) async {
|
||||||
final lastEventId = _lastEventIdByThread[threadId];
|
final lastEventId = _lastEventIdByThread[threadId];
|
||||||
final headers = <String, String>{'Accept': 'text/event-stream'};
|
final headers = <String, String>{'Accept': 'text/event-stream'};
|
||||||
if (lastEventId != null && lastEventId.isNotEmpty) {
|
if (lastEventId != null && lastEventId.isNotEmpty) {
|
||||||
@@ -175,6 +185,9 @@ class AgUiService {
|
|||||||
String? eventId;
|
String? eventId;
|
||||||
final dataBuffer = StringBuffer();
|
final dataBuffer = StringBuffer();
|
||||||
await for (final line in sseLines) {
|
await for (final line in sseLines) {
|
||||||
|
if (streamToken != _activeStreamToken) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
if (line.isEmpty) {
|
if (line.isEmpty) {
|
||||||
if (dataBuffer.isNotEmpty) {
|
if (dataBuffer.isNotEmpty) {
|
||||||
final raw = dataBuffer.toString();
|
final raw = dataBuffer.toString();
|
||||||
|
|||||||
@@ -11,7 +11,11 @@ import '../../data/services/ag_ui_service.dart';
|
|||||||
|
|
||||||
class ChatState {
|
class ChatState {
|
||||||
final List<ChatListItem> items;
|
final List<ChatListItem> items;
|
||||||
final bool isLoading;
|
final bool isSending;
|
||||||
|
final bool isWaitingFirstToken;
|
||||||
|
final bool isStreaming;
|
||||||
|
final bool isCancelling;
|
||||||
|
final bool isLoadingHistory;
|
||||||
final String? currentMessageId;
|
final String? currentMessageId;
|
||||||
final String? error;
|
final String? error;
|
||||||
final DateTime? oldestLoadedDate;
|
final DateTime? oldestLoadedDate;
|
||||||
@@ -19,18 +23,33 @@ class ChatState {
|
|||||||
|
|
||||||
const ChatState({
|
const ChatState({
|
||||||
this.items = const [],
|
this.items = const [],
|
||||||
this.isLoading = false,
|
this.isSending = false,
|
||||||
|
this.isWaitingFirstToken = false,
|
||||||
|
this.isStreaming = false,
|
||||||
|
this.isCancelling = false,
|
||||||
|
this.isLoadingHistory = false,
|
||||||
this.currentMessageId,
|
this.currentMessageId,
|
||||||
this.error,
|
this.error,
|
||||||
this.oldestLoadedDate,
|
this.oldestLoadedDate,
|
||||||
this.hasEarlierHistory = false,
|
this.hasEarlierHistory = false,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
bool get isLoading =>
|
||||||
|
isSending ||
|
||||||
|
isWaitingFirstToken ||
|
||||||
|
isStreaming ||
|
||||||
|
isCancelling ||
|
||||||
|
isLoadingHistory;
|
||||||
|
|
||||||
static const _unset = Object();
|
static const _unset = Object();
|
||||||
|
|
||||||
ChatState copyWith({
|
ChatState copyWith({
|
||||||
List<ChatListItem>? items,
|
List<ChatListItem>? items,
|
||||||
bool? isLoading,
|
bool? isSending,
|
||||||
|
bool? isWaitingFirstToken,
|
||||||
|
bool? isStreaming,
|
||||||
|
bool? isCancelling,
|
||||||
|
bool? isLoadingHistory,
|
||||||
Object? currentMessageId = _unset,
|
Object? currentMessageId = _unset,
|
||||||
Object? error = _unset,
|
Object? error = _unset,
|
||||||
Object? oldestLoadedDate = _unset,
|
Object? oldestLoadedDate = _unset,
|
||||||
@@ -38,7 +57,11 @@ class ChatState {
|
|||||||
}) {
|
}) {
|
||||||
return ChatState(
|
return ChatState(
|
||||||
items: items ?? this.items,
|
items: items ?? this.items,
|
||||||
isLoading: isLoading ?? this.isLoading,
|
isSending: isSending ?? this.isSending,
|
||||||
|
isWaitingFirstToken: isWaitingFirstToken ?? this.isWaitingFirstToken,
|
||||||
|
isStreaming: isStreaming ?? this.isStreaming,
|
||||||
|
isCancelling: isCancelling ?? this.isCancelling,
|
||||||
|
isLoadingHistory: isLoadingHistory ?? this.isLoadingHistory,
|
||||||
currentMessageId: currentMessageId == _unset
|
currentMessageId: currentMessageId == _unset
|
||||||
? this.currentMessageId
|
? this.currentMessageId
|
||||||
: currentMessageId as String?,
|
: currentMessageId as String?,
|
||||||
@@ -72,12 +95,36 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
void _handleEvent(AgUiEvent event) {
|
void _handleEvent(AgUiEvent event) {
|
||||||
switch (event.type) {
|
switch (event.type) {
|
||||||
case AgUiEventType.runStarted:
|
case AgUiEventType.runStarted:
|
||||||
emit(state.copyWith(isLoading: true, error: null));
|
emit(
|
||||||
|
state.copyWith(
|
||||||
|
isSending: false,
|
||||||
|
isWaitingFirstToken: true,
|
||||||
|
isCancelling: false,
|
||||||
|
error: null,
|
||||||
|
),
|
||||||
|
);
|
||||||
case AgUiEventType.runFinished:
|
case AgUiEventType.runFinished:
|
||||||
emit(state.copyWith(isLoading: false, currentMessageId: null));
|
emit(
|
||||||
|
state.copyWith(
|
||||||
|
isSending: false,
|
||||||
|
isWaitingFirstToken: false,
|
||||||
|
isStreaming: false,
|
||||||
|
isCancelling: false,
|
||||||
|
currentMessageId: null,
|
||||||
|
),
|
||||||
|
);
|
||||||
case AgUiEventType.runError:
|
case AgUiEventType.runError:
|
||||||
final errorEvent = event as RunErrorEvent;
|
final errorEvent = event as RunErrorEvent;
|
||||||
emit(state.copyWith(isLoading: false, error: errorEvent.message));
|
emit(
|
||||||
|
state.copyWith(
|
||||||
|
isSending: false,
|
||||||
|
isWaitingFirstToken: false,
|
||||||
|
isStreaming: false,
|
||||||
|
isCancelling: false,
|
||||||
|
currentMessageId: null,
|
||||||
|
error: errorEvent.message,
|
||||||
|
),
|
||||||
|
);
|
||||||
case AgUiEventType.textMessageStart:
|
case AgUiEventType.textMessageStart:
|
||||||
_handleTextMessageStart(event as TextMessageStartEvent);
|
_handleTextMessageStart(event as TextMessageStartEvent);
|
||||||
case AgUiEventType.textMessageContent:
|
case AgUiEventType.textMessageContent:
|
||||||
@@ -115,6 +162,8 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
state.copyWith(
|
state.copyWith(
|
||||||
items: [...state.items, newMessage],
|
items: [...state.items, newMessage],
|
||||||
currentMessageId: startEvent.messageId,
|
currentMessageId: startEvent.messageId,
|
||||||
|
isWaitingFirstToken: false,
|
||||||
|
isStreaming: true,
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -136,7 +185,13 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
}
|
}
|
||||||
return item;
|
return item;
|
||||||
}).toList();
|
}).toList();
|
||||||
emit(state.copyWith(items: updatedItems, currentMessageId: null));
|
emit(
|
||||||
|
state.copyWith(
|
||||||
|
items: updatedItems,
|
||||||
|
currentMessageId: null,
|
||||||
|
isStreaming: false,
|
||||||
|
),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
void _handleToolCallStart(ToolCallStartEvent startEvent) {
|
void _handleToolCallStart(ToolCallStartEvent startEvent) {
|
||||||
@@ -319,20 +374,50 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
timestamp: DateTime.now(),
|
timestamp: DateTime.now(),
|
||||||
sender: MessageSender.user,
|
sender: MessageSender.user,
|
||||||
);
|
);
|
||||||
emit(state.copyWith(items: [...state.items, userMessage]));
|
emit(
|
||||||
await _service.sendMessage(content);
|
state.copyWith(
|
||||||
|
items: [...state.items, userMessage],
|
||||||
|
isSending: true,
|
||||||
|
isWaitingFirstToken: true,
|
||||||
|
isStreaming: false,
|
||||||
|
isCancelling: false,
|
||||||
|
error: null,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
try {
|
||||||
|
await _service.sendMessage(content);
|
||||||
|
} catch (error) {
|
||||||
|
emit(
|
||||||
|
state.copyWith(
|
||||||
|
isSending: false,
|
||||||
|
isWaitingFirstToken: false,
|
||||||
|
isStreaming: false,
|
||||||
|
isCancelling: false,
|
||||||
|
error: error.toString(),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<void> loadHistory() async {
|
Future<void> loadHistory() async {
|
||||||
if (state.isLoading) return;
|
if (state.isLoadingHistory) return;
|
||||||
await _service.loadHistory();
|
emit(state.copyWith(isLoadingHistory: true));
|
||||||
|
try {
|
||||||
|
await _service.loadHistory();
|
||||||
|
} finally {
|
||||||
|
emit(state.copyWith(isLoadingHistory: false));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<void> loadMoreHistory() async {
|
Future<void> loadMoreHistory() async {
|
||||||
if (state.isLoading || !state.hasEarlierHistory) return;
|
if (state.isLoadingHistory || !state.hasEarlierHistory) return;
|
||||||
if (state.oldestLoadedDate == null) return;
|
if (state.oldestLoadedDate == null) return;
|
||||||
|
emit(state.copyWith(isLoadingHistory: true));
|
||||||
await _service.loadHistory(beforeDate: state.oldestLoadedDate);
|
try {
|
||||||
|
await _service.loadHistory(beforeDate: state.oldestLoadedDate);
|
||||||
|
} finally {
|
||||||
|
emit(state.copyWith(isLoadingHistory: false));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<void> approveToolCall(String toolCallId) async {
|
Future<void> approveToolCall(String toolCallId) async {
|
||||||
@@ -355,7 +440,16 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
}
|
}
|
||||||
return item;
|
return item;
|
||||||
}).toList();
|
}).toList();
|
||||||
emit(state.copyWith(items: updatedItems, isLoading: true, error: null));
|
emit(
|
||||||
|
state.copyWith(
|
||||||
|
items: updatedItems,
|
||||||
|
isSending: false,
|
||||||
|
isWaitingFirstToken: true,
|
||||||
|
isStreaming: false,
|
||||||
|
isCancelling: false,
|
||||||
|
error: null,
|
||||||
|
),
|
||||||
|
);
|
||||||
try {
|
try {
|
||||||
await _service.approveToolCall(
|
await _service.approveToolCall(
|
||||||
toolCallId: target.callId,
|
toolCallId: target.callId,
|
||||||
@@ -375,7 +469,10 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
emit(
|
emit(
|
||||||
state.copyWith(
|
state.copyWith(
|
||||||
items: failedItems,
|
items: failedItems,
|
||||||
isLoading: false,
|
isSending: false,
|
||||||
|
isWaitingFirstToken: false,
|
||||||
|
isStreaming: false,
|
||||||
|
isCancelling: false,
|
||||||
error: error.toString(),
|
error: error.toString(),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
@@ -386,6 +483,31 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
return _service.transcribeAudio(filePath);
|
return _service.transcribeAudio(filePath);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Future<bool> cancelCurrentRun() async {
|
||||||
|
if (!(state.isWaitingFirstToken ||
|
||||||
|
state.isStreaming ||
|
||||||
|
state.isCancelling)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
emit(state.copyWith(isCancelling: true, error: null));
|
||||||
|
try {
|
||||||
|
await _service.cancelCurrentRun();
|
||||||
|
emit(
|
||||||
|
state.copyWith(
|
||||||
|
isSending: false,
|
||||||
|
isWaitingFirstToken: false,
|
||||||
|
isStreaming: false,
|
||||||
|
isCancelling: false,
|
||||||
|
currentMessageId: null,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
emit(state.copyWith(isCancelling: false, error: error.toString()));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void clearError() {
|
void clearError() {
|
||||||
emit(state.copyWith(error: null));
|
emit(state.copyWith(error: null));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ const _rippleDurationMs = 1200;
|
|||||||
const _recordingDotSize = 10.0;
|
const _recordingDotSize = 10.0;
|
||||||
const _transcribingSpinnerSize = 18.0;
|
const _transcribingSpinnerSize = 18.0;
|
||||||
const _transcribingStrokeWidth = 2.0;
|
const _transcribingStrokeWidth = 2.0;
|
||||||
|
const _inputActionButtonKey = ValueKey('home_input_action_button');
|
||||||
|
const _inputActionIconKey = ValueKey('home_input_action_icon');
|
||||||
|
|
||||||
/// 颜色常量
|
/// 颜色常量
|
||||||
const _chatBgColor = Color(0xFFF8FAFC);
|
const _chatBgColor = Color(0xFFF8FAFC);
|
||||||
@@ -40,6 +42,7 @@ class HomeScreen extends StatefulWidget {
|
|||||||
final VoiceRecorder? voiceRecorder;
|
final VoiceRecorder? voiceRecorder;
|
||||||
final Future<String> Function(String filePath)? onTranscribeAudio;
|
final Future<String> Function(String filePath)? onTranscribeAudio;
|
||||||
final Future<void> Function(String transcript)? onAutoSendTranscript;
|
final Future<void> Function(String transcript)? onAutoSendTranscript;
|
||||||
|
final ChatBloc? chatBloc;
|
||||||
final bool autoLoadHistory;
|
final bool autoLoadHistory;
|
||||||
|
|
||||||
const HomeScreen({
|
const HomeScreen({
|
||||||
@@ -47,6 +50,7 @@ class HomeScreen extends StatefulWidget {
|
|||||||
this.voiceRecorder,
|
this.voiceRecorder,
|
||||||
this.onTranscribeAudio,
|
this.onTranscribeAudio,
|
||||||
this.onAutoSendTranscript,
|
this.onAutoSendTranscript,
|
||||||
|
this.chatBloc,
|
||||||
this.autoLoadHistory = true,
|
this.autoLoadHistory = true,
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -72,7 +76,7 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
void initState() {
|
void initState() {
|
||||||
super.initState();
|
super.initState();
|
||||||
_messageController.addListener(_onMessageChanged);
|
_messageController.addListener(_onMessageChanged);
|
||||||
_chatBloc = ChatBloc();
|
_chatBloc = widget.chatBloc ?? ChatBloc();
|
||||||
_voiceRecorder = widget.voiceRecorder ?? RecordVoiceRecorder();
|
_voiceRecorder = widget.voiceRecorder ?? RecordVoiceRecorder();
|
||||||
_transcribeAudio =
|
_transcribeAudio =
|
||||||
widget.onTranscribeAudio ?? _chatBloc.transcribeAudioFile;
|
widget.onTranscribeAudio ?? _chatBloc.transcribeAudioFile;
|
||||||
@@ -93,7 +97,9 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
_scrollController.dispose();
|
_scrollController.dispose();
|
||||||
_listeningAnimationController.dispose();
|
_listeningAnimationController.dispose();
|
||||||
_voiceRecorder.dispose();
|
_voiceRecorder.dispose();
|
||||||
_chatBloc.close();
|
if (widget.chatBloc == null) {
|
||||||
|
_chatBloc.close();
|
||||||
|
}
|
||||||
RouteNavigationTool.instance.clearNavigator();
|
RouteNavigationTool.instance.clearNavigator();
|
||||||
super.dispose();
|
super.dispose();
|
||||||
}
|
}
|
||||||
@@ -131,7 +137,7 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
children: [
|
children: [
|
||||||
_buildHeader(context),
|
_buildHeader(context),
|
||||||
Expanded(child: _buildChatArea(context, state)),
|
Expanded(child: _buildChatArea(context, state)),
|
||||||
_buildInputContainer(context),
|
_buildInputContainer(context, state),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@@ -185,49 +191,100 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
}
|
}
|
||||||
|
|
||||||
Widget _buildChatArea(BuildContext context, ChatState state) {
|
Widget _buildChatArea(BuildContext context, ChatState state) {
|
||||||
if (state.isLoading && state.items.isEmpty) {
|
final showWaitingIndicator =
|
||||||
|
state.isWaitingFirstToken || state.isStreaming || state.isCancelling;
|
||||||
|
|
||||||
|
if (state.isLoadingHistory && state.items.isEmpty) {
|
||||||
return const Center(child: CircularProgressIndicator());
|
return const Center(child: CircularProgressIndicator());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.items.isEmpty) {
|
if (state.items.isEmpty) {
|
||||||
return const Center(
|
return Column(
|
||||||
child: Text(
|
crossAxisAlignment: CrossAxisAlignment.stretch,
|
||||||
'开始对话吧',
|
children: [
|
||||||
style: TextStyle(fontSize: 16, color: AppColors.slate400),
|
const Expanded(
|
||||||
),
|
child: Center(
|
||||||
|
child: Text(
|
||||||
|
'开始对话吧',
|
||||||
|
style: TextStyle(fontSize: 16, color: AppColors.slate400),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
if (showWaitingIndicator) _buildWaitingIndicator(),
|
||||||
|
],
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return RefreshIndicator(
|
return Column(
|
||||||
onRefresh: () => _onRefresh(context),
|
crossAxisAlignment: CrossAxisAlignment.stretch,
|
||||||
child: ListView.builder(
|
children: [
|
||||||
controller: _scrollController,
|
Expanded(
|
||||||
physics: const AlwaysScrollableScrollPhysics(),
|
child: RefreshIndicator(
|
||||||
padding: const EdgeInsets.all(_defaultPadding),
|
onRefresh: () => _onRefresh(context),
|
||||||
itemCount: state.items.length + (state.hasEarlierHistory ? 1 : 0),
|
child: ListView.builder(
|
||||||
itemBuilder: (context, index) {
|
controller: _scrollController,
|
||||||
if (index == 0 && state.hasEarlierHistory) {
|
physics: const AlwaysScrollableScrollPhysics(),
|
||||||
return _buildLoadMoreButton(context, state.isLoading);
|
padding: const EdgeInsets.all(_defaultPadding),
|
||||||
}
|
itemCount: state.items.length + (state.hasEarlierHistory ? 1 : 0),
|
||||||
|
itemBuilder: (context, index) {
|
||||||
|
if (index == 0 && state.hasEarlierHistory) {
|
||||||
|
return _buildLoadMoreButton(context, state.isLoadingHistory);
|
||||||
|
}
|
||||||
|
|
||||||
final itemIndex = state.hasEarlierHistory ? index - 1 : index;
|
final itemIndex = state.hasEarlierHistory ? index - 1 : index;
|
||||||
final item = state.items[itemIndex];
|
final item = state.items[itemIndex];
|
||||||
|
|
||||||
final showDateDivider =
|
final showDateDivider =
|
||||||
itemIndex == 0 ||
|
itemIndex == 0 ||
|
||||||
!_isSameDay(state.items[itemIndex - 1].timestamp, item.timestamp);
|
!_isSameDay(
|
||||||
|
state.items[itemIndex - 1].timestamp,
|
||||||
|
item.timestamp,
|
||||||
|
);
|
||||||
|
|
||||||
return Column(
|
return Column(
|
||||||
crossAxisAlignment: CrossAxisAlignment.stretch,
|
crossAxisAlignment: CrossAxisAlignment.stretch,
|
||||||
children: [
|
children: [
|
||||||
if (showDateDivider) _buildDateDivider(item.timestamp),
|
if (showDateDivider) _buildDateDivider(item.timestamp),
|
||||||
Padding(
|
Padding(
|
||||||
padding: const EdgeInsets.only(bottom: _itemSpacing),
|
padding: const EdgeInsets.only(bottom: _itemSpacing),
|
||||||
child: _buildChatItem(item),
|
child: _buildChatItem(item),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
if (showWaitingIndicator) _buildWaitingIndicator(),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Widget _buildWaitingIndicator() {
|
||||||
|
return Padding(
|
||||||
|
padding: const EdgeInsets.fromLTRB(
|
||||||
|
_defaultPadding,
|
||||||
|
0,
|
||||||
|
_defaultPadding,
|
||||||
|
_defaultPadding,
|
||||||
|
),
|
||||||
|
child: Row(
|
||||||
|
crossAxisAlignment: CrossAxisAlignment.center,
|
||||||
|
children: const [
|
||||||
|
SizedBox(
|
||||||
|
width: _transcribingSpinnerSize,
|
||||||
|
height: _transcribingSpinnerSize,
|
||||||
|
child: CircularProgressIndicator(
|
||||||
|
strokeWidth: _transcribingStrokeWidth,
|
||||||
|
color: AppColors.blue600,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
SizedBox(width: 8),
|
||||||
|
Text(
|
||||||
|
'正在思考...',
|
||||||
|
style: TextStyle(fontSize: 14, color: AppColors.slate500),
|
||||||
|
),
|
||||||
|
],
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -406,7 +463,9 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
return UiSchemaRenderer.render(item.uiCard);
|
return UiSchemaRenderer.render(item.uiCard);
|
||||||
}
|
}
|
||||||
|
|
||||||
Widget _buildInputContainer(BuildContext context) {
|
Widget _buildInputContainer(BuildContext context, ChatState state) {
|
||||||
|
final isWaitingAgent =
|
||||||
|
state.isWaitingFirstToken || state.isStreaming || state.isCancelling;
|
||||||
return Container(
|
return Container(
|
||||||
padding: const EdgeInsets.all(_inputPadding),
|
padding: const EdgeInsets.all(_inputPadding),
|
||||||
color: _chatBgColor,
|
color: _chatBgColor,
|
||||||
@@ -471,10 +530,13 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
),
|
),
|
||||||
const SizedBox(width: 8),
|
const SizedBox(width: 8),
|
||||||
GestureDetector(
|
GestureDetector(
|
||||||
|
key: _inputActionButtonKey,
|
||||||
onTap: _isTranscribing
|
onTap: _isTranscribing
|
||||||
? null
|
? null
|
||||||
: _isRecording
|
: _isRecording
|
||||||
? () => _stopRecording(autoSendAfterTranscribe: true)
|
? () => _stopRecording(autoSendAfterTranscribe: true)
|
||||||
|
: isWaitingAgent
|
||||||
|
? () => _onStopGenerating(context)
|
||||||
: _hasMessage
|
: _hasMessage
|
||||||
? () => _sendMessage(context)
|
? () => _sendMessage(context)
|
||||||
: _startRecording,
|
: _startRecording,
|
||||||
@@ -488,11 +550,14 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
: Icon(
|
: Icon(
|
||||||
_isRecording || _hasMessage
|
key: _inputActionIconKey,
|
||||||
|
_isRecording || isWaitingAgent
|
||||||
|
? LucideIcons.square
|
||||||
|
: _hasMessage
|
||||||
? LucideIcons.send
|
? LucideIcons.send
|
||||||
: LucideIcons.mic,
|
: LucideIcons.mic,
|
||||||
size: _iconSize,
|
size: _iconSize,
|
||||||
color: _isRecording || _hasMessage
|
color: _isRecording || isWaitingAgent || _hasMessage
|
||||||
? AppColors.blue600
|
? AppColors.blue600
|
||||||
: AppColors.slate500,
|
: AppColors.slate500,
|
||||||
),
|
),
|
||||||
@@ -511,7 +576,7 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
if (content.isEmpty) return;
|
if (content.isEmpty) return;
|
||||||
FocusScope.of(context).unfocus();
|
FocusScope.of(context).unfocus();
|
||||||
_messageController.clear();
|
_messageController.clear();
|
||||||
context.read<ChatBloc>().sendMessage(content);
|
await context.read<ChatBloc>().sendMessage(content);
|
||||||
|
|
||||||
WidgetsBinding.instance.addPostFrameCallback((_) {
|
WidgetsBinding.instance.addPostFrameCallback((_) {
|
||||||
if (_scrollController.hasClients) {
|
if (_scrollController.hasClients) {
|
||||||
@@ -524,6 +589,16 @@ class _HomeScreenState extends State<HomeScreen>
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Future<void> _onStopGenerating(BuildContext context) async {
|
||||||
|
final canceled = await context.read<ChatBloc>().cancelCurrentRun();
|
||||||
|
if (!mounted) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (canceled) {
|
||||||
|
Toast.show(context, '已停止等待回复', type: ToastType.info);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Widget _buildListeningIndicator() {
|
Widget _buildListeningIndicator() {
|
||||||
return SizedBox(
|
return SizedBox(
|
||||||
height: _inputMinHeight,
|
height: _inputMinHeight,
|
||||||
|
|||||||
@@ -31,8 +31,17 @@ dev_dependencies:
|
|||||||
mocktail: ^1.0.4
|
mocktail: ^1.0.4
|
||||||
json_serializable: ^6.7.1
|
json_serializable: ^6.7.1
|
||||||
build_runner: ^2.4.8
|
build_runner: ^2.4.8
|
||||||
|
flutter_launcher_icons: ^0.14.0
|
||||||
|
|
||||||
flutter:
|
flutter:
|
||||||
uses-material-design: true
|
uses-material-design: true
|
||||||
assets:
|
assets:
|
||||||
- assets/images/
|
- assets/images/
|
||||||
|
|
||||||
|
flutter_launcher_icons:
|
||||||
|
android: true
|
||||||
|
ios: true
|
||||||
|
image_path: "assets/images/logo.png"
|
||||||
|
adaptive_icon_background: "#FFFFFF"
|
||||||
|
adaptive_icon_foreground: "assets/images/logo.png"
|
||||||
|
remove_alpha_ios: true
|
||||||
|
|||||||
@@ -12,6 +12,15 @@ class MockAgUiService extends AgUiService {
|
|||||||
Future<void> sendMessage(String content) async {}
|
Future<void> sendMessage(String content) async {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class _ThrowingAgUiService extends AgUiService {
|
||||||
|
_ThrowingAgUiService() : super(onEvent: (_) {});
|
||||||
|
|
||||||
|
@override
|
||||||
|
Future<void> sendMessage(String content) async {
|
||||||
|
throw StateError('network down');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
late ChatBloc chatBloc;
|
late ChatBloc chatBloc;
|
||||||
late AgUiService service;
|
late AgUiService service;
|
||||||
@@ -29,6 +38,9 @@ void main() {
|
|||||||
test('initial state is empty', () {
|
test('initial state is empty', () {
|
||||||
expect(chatBloc.state.items, isEmpty);
|
expect(chatBloc.state.items, isEmpty);
|
||||||
expect(chatBloc.state.isLoading, false);
|
expect(chatBloc.state.isLoading, false);
|
||||||
|
expect(chatBloc.state.isSending, false);
|
||||||
|
expect(chatBloc.state.isWaitingFirstToken, false);
|
||||||
|
expect(chatBloc.state.isStreaming, false);
|
||||||
expect(chatBloc.state.currentMessageId, isNull);
|
expect(chatBloc.state.currentMessageId, isNull);
|
||||||
expect(chatBloc.state.error, isNull);
|
expect(chatBloc.state.error, isNull);
|
||||||
});
|
});
|
||||||
@@ -40,6 +52,12 @@ void main() {
|
|||||||
expect: () => [
|
expect: () => [
|
||||||
isA<ChatState>()
|
isA<ChatState>()
|
||||||
.having((state) => state.items.length, 'items length', 1)
|
.having((state) => state.items.length, 'items length', 1)
|
||||||
|
.having((state) => state.isSending, 'isSending', true)
|
||||||
|
.having(
|
||||||
|
(state) => state.isWaitingFirstToken,
|
||||||
|
'isWaitingFirstToken',
|
||||||
|
true,
|
||||||
|
)
|
||||||
.having(
|
.having(
|
||||||
(state) => state.items.first,
|
(state) => state.items.first,
|
||||||
'first item',
|
'first item',
|
||||||
@@ -56,15 +74,13 @@ void main() {
|
|||||||
'textMessageStart event adds AI message with streaming',
|
'textMessageStart event adds AI message with streaming',
|
||||||
build: () => chatBloc,
|
build: () => chatBloc,
|
||||||
act: (bloc) {
|
act: (bloc) {
|
||||||
bloc.emit(chatBloc.state.copyWith(isLoading: true));
|
bloc.emit(chatBloc.state.copyWith(isStreaming: true));
|
||||||
service.onEvent(
|
service.onEvent(
|
||||||
TextMessageStartEvent(messageId: 'msg_1', role: 'assistant'),
|
TextMessageStartEvent(messageId: 'msg_1', role: 'assistant'),
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
expect: () => [
|
expect: () => [
|
||||||
isA<ChatState>()
|
isA<ChatState>().having((s) => s.isStreaming, 'isStreaming', true),
|
||||||
.having((s) => s.isLoading, 'isLoading', true)
|
|
||||||
.having((s) => s.isLoading, 'isLoading', true),
|
|
||||||
isA<ChatState>()
|
isA<ChatState>()
|
||||||
.having((s) => s.items.length, 'items length', 1)
|
.having((s) => s.items.length, 'items length', 1)
|
||||||
.having((s) => s.currentMessageId, 'currentMessageId', 'msg_1')
|
.having((s) => s.currentMessageId, 'currentMessageId', 'msg_1')
|
||||||
@@ -128,6 +144,7 @@ void main() {
|
|||||||
expect: () => [
|
expect: () => [
|
||||||
isA<ChatState>()
|
isA<ChatState>()
|
||||||
.having((s) => s.currentMessageId, 'currentMessageId', isNull)
|
.having((s) => s.currentMessageId, 'currentMessageId', isNull)
|
||||||
|
.having((s) => s.isStreaming, 'isStreaming', false)
|
||||||
.having(
|
.having(
|
||||||
(s) => (s.items.first as TextMessageItem).isStreaming,
|
(s) => (s.items.first as TextMessageItem).isStreaming,
|
||||||
'isStreaming',
|
'isStreaming',
|
||||||
@@ -145,6 +162,7 @@ void main() {
|
|||||||
expect: () => [
|
expect: () => [
|
||||||
isA<ChatState>()
|
isA<ChatState>()
|
||||||
.having((s) => s.isLoading, 'isLoading', true)
|
.having((s) => s.isLoading, 'isLoading', true)
|
||||||
|
.having((s) => s.isWaitingFirstToken, 'isWaitingFirstToken', true)
|
||||||
.having((s) => s.error, 'error', isNull),
|
.having((s) => s.error, 'error', isNull),
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
@@ -152,7 +170,7 @@ void main() {
|
|||||||
blocTest<ChatBloc, ChatState>(
|
blocTest<ChatBloc, ChatState>(
|
||||||
'runFinished sets isLoading to false',
|
'runFinished sets isLoading to false',
|
||||||
build: () => chatBloc,
|
build: () => chatBloc,
|
||||||
seed: () => const ChatState(isLoading: true),
|
seed: () => const ChatState(isWaitingFirstToken: true),
|
||||||
act: (bloc) {
|
act: (bloc) {
|
||||||
service.onEvent(RunFinishedEvent(threadId: 't1', runId: 'r1'));
|
service.onEvent(RunFinishedEvent(threadId: 't1', runId: 'r1'));
|
||||||
},
|
},
|
||||||
@@ -166,7 +184,7 @@ void main() {
|
|||||||
blocTest<ChatBloc, ChatState>(
|
blocTest<ChatBloc, ChatState>(
|
||||||
'runError sets error message',
|
'runError sets error message',
|
||||||
build: () => chatBloc,
|
build: () => chatBloc,
|
||||||
seed: () => const ChatState(isLoading: true),
|
seed: () => const ChatState(isWaitingFirstToken: true),
|
||||||
act: (bloc) {
|
act: (bloc) {
|
||||||
service.onEvent(
|
service.onEvent(
|
||||||
RunErrorEvent(message: 'Something went wrong', code: 'ERR'),
|
RunErrorEvent(message: 'Something went wrong', code: 'ERR'),
|
||||||
@@ -175,10 +193,40 @@ void main() {
|
|||||||
expect: () => [
|
expect: () => [
|
||||||
isA<ChatState>()
|
isA<ChatState>()
|
||||||
.having((s) => s.isLoading, 'isLoading', false)
|
.having((s) => s.isLoading, 'isLoading', false)
|
||||||
|
.having((s) => s.currentMessageId, 'currentMessageId', isNull)
|
||||||
.having((s) => s.error, 'error', 'Something went wrong'),
|
.having((s) => s.error, 'error', 'Something went wrong'),
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
blocTest<ChatBloc, ChatState>(
|
||||||
|
'cancelCurrentRun exits waiting states',
|
||||||
|
build: () => chatBloc,
|
||||||
|
seed: () => const ChatState(isWaitingFirstToken: true),
|
||||||
|
act: (bloc) => bloc.cancelCurrentRun(),
|
||||||
|
expect: () => [
|
||||||
|
isA<ChatState>().having((s) => s.isCancelling, 'isCancelling', true),
|
||||||
|
isA<ChatState>()
|
||||||
|
.having((s) => s.isWaitingFirstToken, 'isWaitingFirstToken', false)
|
||||||
|
.having((s) => s.isStreaming, 'isStreaming', false)
|
||||||
|
.having((s) => s.isCancelling, 'isCancelling', false),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
|
blocTest<ChatBloc, ChatState>(
|
||||||
|
'sendMessage failure emits error and exits waiting state',
|
||||||
|
build: () => ChatBloc(service: _ThrowingAgUiService()),
|
||||||
|
act: (bloc) => bloc.sendMessage('hello'),
|
||||||
|
expect: () => [
|
||||||
|
isA<ChatState>()
|
||||||
|
.having((s) => s.isSending, 'isSending', true)
|
||||||
|
.having((s) => s.isWaitingFirstToken, 'isWaitingFirstToken', true),
|
||||||
|
isA<ChatState>()
|
||||||
|
.having((s) => s.isSending, 'isSending', false)
|
||||||
|
.having((s) => s.isWaitingFirstToken, 'isWaitingFirstToken', false)
|
||||||
|
.having((s) => s.error, 'error', contains('network down')),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
blocTest<ChatBloc, ChatState>(
|
blocTest<ChatBloc, ChatState>(
|
||||||
'clearError removes error',
|
'clearError removes error',
|
||||||
build: () => chatBloc,
|
build: () => chatBloc,
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ import 'package:flutter/material.dart';
|
|||||||
import 'package:flutter_test/flutter_test.dart';
|
import 'package:flutter_test/flutter_test.dart';
|
||||||
import 'package:lucide_icons/lucide_icons.dart';
|
import 'package:lucide_icons/lucide_icons.dart';
|
||||||
import 'package:social_app/core/api/api_exception.dart';
|
import 'package:social_app/core/api/api_exception.dart';
|
||||||
|
import 'package:social_app/features/chat/data/models/ag_ui_event.dart';
|
||||||
|
import 'package:social_app/features/chat/data/services/ag_ui_service.dart';
|
||||||
|
import 'package:social_app/features/chat/presentation/bloc/chat_bloc.dart';
|
||||||
import 'package:social_app/features/home/data/voice_recorder.dart';
|
import 'package:social_app/features/home/data/voice_recorder.dart';
|
||||||
import 'package:social_app/features/home/ui/screens/home_screen.dart';
|
import 'package:social_app/features/home/ui/screens/home_screen.dart';
|
||||||
|
|
||||||
@@ -29,7 +32,26 @@ class _FakeVoiceRecorder implements VoiceRecorder {
|
|||||||
Future<void> dispose() async {}
|
Future<void> dispose() async {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class _WaitingAgUiService extends AgUiService {
|
||||||
|
_WaitingAgUiService() : super(onEvent: (_) {});
|
||||||
|
|
||||||
|
final Completer<void> _pending = Completer<void>();
|
||||||
|
|
||||||
|
@override
|
||||||
|
Future<void> sendMessage(String content) async {
|
||||||
|
onEvent(RunStartedEvent(threadId: 't1', runId: 'r1'));
|
||||||
|
return _pending.future;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
IconData _inputActionIcon(WidgetTester tester) {
|
||||||
|
final icon = tester.widget<Icon>(
|
||||||
|
find.byKey(const ValueKey('home_input_action_icon')),
|
||||||
|
);
|
||||||
|
return icon.icon!;
|
||||||
|
}
|
||||||
|
|
||||||
group('HomeScreen Widget Tests', () {
|
group('HomeScreen Widget Tests', () {
|
||||||
testWidgets('displays input field', (WidgetTester tester) async {
|
testWidgets('displays input field', (WidgetTester tester) async {
|
||||||
await tester.pumpWidget(
|
await tester.pumpWidget(
|
||||||
@@ -79,8 +101,7 @@ void main() {
|
|||||||
|
|
||||||
expect(fakeRecorder.started, true);
|
expect(fakeRecorder.started, true);
|
||||||
expect(find.text('正在聆听...'), findsOneWidget);
|
expect(find.text('正在聆听...'), findsOneWidget);
|
||||||
expect(find.byIcon(LucideIcons.square), findsOneWidget);
|
expect(_inputActionIcon(tester), LucideIcons.square);
|
||||||
expect(find.byIcon(LucideIcons.send), findsOneWidget);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
testWidgets('tap send while recording transcribes and auto sends message', (
|
testWidgets('tap send while recording transcribes and auto sends message', (
|
||||||
@@ -105,9 +126,9 @@ void main() {
|
|||||||
);
|
);
|
||||||
await tester.pumpAndSettle();
|
await tester.pumpAndSettle();
|
||||||
|
|
||||||
await tester.tap(find.byIcon(LucideIcons.mic));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump();
|
await tester.pump();
|
||||||
await tester.tap(find.byIcon(LucideIcons.send));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump(const Duration(milliseconds: 300));
|
await tester.pump(const Duration(milliseconds: 300));
|
||||||
|
|
||||||
expect(sentTranscript, '语音自动发送');
|
expect(sentTranscript, '语音自动发送');
|
||||||
@@ -127,18 +148,19 @@ void main() {
|
|||||||
expect(filePath.endsWith('.wav'), true);
|
expect(filePath.endsWith('.wav'), true);
|
||||||
return '语音转文字结果';
|
return '语音转文字结果';
|
||||||
},
|
},
|
||||||
|
onAutoSendTranscript: (_) async {},
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
await tester.pumpAndSettle();
|
await tester.pumpAndSettle();
|
||||||
|
|
||||||
await tester.tap(find.byIcon(LucideIcons.mic));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump();
|
await tester.pump();
|
||||||
await tester.tap(find.byIcon(LucideIcons.square));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump();
|
await tester.pump();
|
||||||
|
|
||||||
expect(find.text('语音识别中...'), findsOneWidget);
|
expect(find.text('语音识别中...'), findsOneWidget);
|
||||||
expect(find.byType(CircularProgressIndicator), findsOneWidget);
|
expect(find.byType(CircularProgressIndicator), findsAtLeastNWidgets(1));
|
||||||
});
|
});
|
||||||
|
|
||||||
testWidgets('tap stop shows readable unauthorized message', (
|
testWidgets('tap stop shows readable unauthorized message', (
|
||||||
@@ -158,9 +180,9 @@ void main() {
|
|||||||
);
|
);
|
||||||
await tester.pumpAndSettle();
|
await tester.pumpAndSettle();
|
||||||
|
|
||||||
await tester.tap(find.byIcon(LucideIcons.mic));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump();
|
await tester.pump();
|
||||||
await tester.tap(find.byIcon(LucideIcons.square));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump(const Duration(milliseconds: 300));
|
await tester.pump(const Duration(milliseconds: 300));
|
||||||
|
|
||||||
expect(find.text('请重新登录'), findsOneWidget);
|
expect(find.text('请重新登录'), findsOneWidget);
|
||||||
@@ -182,9 +204,9 @@ void main() {
|
|||||||
);
|
);
|
||||||
await tester.pumpAndSettle();
|
await tester.pumpAndSettle();
|
||||||
|
|
||||||
await tester.tap(find.byIcon(LucideIcons.mic));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump();
|
await tester.pump();
|
||||||
await tester.tap(find.byIcon(LucideIcons.square));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump(const Duration(milliseconds: 300));
|
await tester.pump(const Duration(milliseconds: 300));
|
||||||
|
|
||||||
expect(find.text('未识别到有效语音,请靠近麦克风并连续说话后重试'), findsOneWidget);
|
expect(find.text('未识别到有效语音,请靠近麦克风并连续说话后重试'), findsOneWidget);
|
||||||
@@ -203,14 +225,15 @@ void main() {
|
|||||||
voiceRecorder: fakeRecorder,
|
voiceRecorder: fakeRecorder,
|
||||||
autoLoadHistory: false,
|
autoLoadHistory: false,
|
||||||
onTranscribeAudio: (_) => completer.future,
|
onTranscribeAudio: (_) => completer.future,
|
||||||
|
onAutoSendTranscript: (_) async {},
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
await tester.pumpAndSettle();
|
await tester.pumpAndSettle();
|
||||||
|
|
||||||
await tester.tap(find.byIcon(LucideIcons.mic));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump();
|
await tester.pump();
|
||||||
await tester.tap(find.byIcon(LucideIcons.square));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump();
|
await tester.pump();
|
||||||
|
|
||||||
expect(find.text('语音识别中...'), findsOneWidget);
|
expect(find.text('语音识别中...'), findsOneWidget);
|
||||||
@@ -237,7 +260,7 @@ void main() {
|
|||||||
);
|
);
|
||||||
expect(editableBefore.widget.focusNode.hasFocus, isTrue);
|
expect(editableBefore.widget.focusNode.hasFocus, isTrue);
|
||||||
|
|
||||||
await tester.tap(find.byIcon(LucideIcons.send));
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
await tester.pump();
|
await tester.pump();
|
||||||
|
|
||||||
final editableAfter = tester.state<EditableTextState>(
|
final editableAfter = tester.state<EditableTextState>(
|
||||||
@@ -247,5 +270,33 @@ void main() {
|
|||||||
|
|
||||||
await tester.pump(const Duration(milliseconds: 300));
|
await tester.pump(const Duration(milliseconds: 300));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
testWidgets('shows stop icon and waiting indicator while waiting agent', (
|
||||||
|
WidgetTester tester,
|
||||||
|
) async {
|
||||||
|
final chatBloc = ChatBloc(service: _WaitingAgUiService());
|
||||||
|
await tester.pumpWidget(
|
||||||
|
MaterialApp(
|
||||||
|
home: HomeScreen(autoLoadHistory: false, chatBloc: chatBloc),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
await tester.pumpAndSettle();
|
||||||
|
|
||||||
|
await tester.enterText(find.byType(TextField), 'hello');
|
||||||
|
await tester.pump();
|
||||||
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
|
await tester.pump();
|
||||||
|
|
||||||
|
expect(_inputActionIcon(tester), LucideIcons.square);
|
||||||
|
expect(find.text('正在思考...'), findsOneWidget);
|
||||||
|
|
||||||
|
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||||
|
await tester.pump();
|
||||||
|
|
||||||
|
expect(find.text('已停止等待回复'), findsOneWidget);
|
||||||
|
await tester.pump(const Duration(seconds: 3));
|
||||||
|
|
||||||
|
await chatBloc.close();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from core.agent.domain.agui_input import (
|
|||||||
)
|
)
|
||||||
from core.agent.application.runtime_loop_service import RuntimeLoopService
|
from core.agent.application.runtime_loop_service import RuntimeLoopService
|
||||||
from core.agent.application.runtime_data_service import RuntimeDataService
|
from core.agent.application.runtime_data_service import RuntimeDataService
|
||||||
from core.agent.application.session_state_persistence import SessionStatePersistence
|
|
||||||
from core.agent.application.session_state_persistence import (
|
from core.agent.application.session_state_persistence import (
|
||||||
|
SessionStatePersistence,
|
||||||
ToolResultStorage,
|
ToolResultStorage,
|
||||||
persist_tool_result_payload,
|
persist_tool_result_payload,
|
||||||
)
|
)
|
||||||
@@ -179,7 +179,6 @@ class RunService:
|
|||||||
seq=next_seq,
|
seq=next_seq,
|
||||||
role=AgentChatMessageRole.USER,
|
role=AgentChatMessageRole.USER,
|
||||||
content=user_input,
|
content=user_input,
|
||||||
model_code=model_code,
|
|
||||||
metadata=MessageMetadataUserInput().model_dump(),
|
metadata=MessageMetadataUserInput().model_dump(),
|
||||||
)
|
)
|
||||||
pending_tool_call_id: str | None = None
|
pending_tool_call_id: str | None = None
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Any, Callable
|
|||||||
|
|
||||||
from crewai import Agent, Crew, LLM, Process, Task
|
from crewai import Agent, Crew, LLM, Process, Task
|
||||||
from crewai.agents import parser as crew_parser
|
from crewai.agents import parser as crew_parser
|
||||||
from litellm import completion, completion_cost
|
from litellm import completion
|
||||||
|
|
||||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||||
from core.agent.infrastructure.config.resolver import ResolvedAgentConfig
|
from core.agent.infrastructure.config.resolver import ResolvedAgentConfig
|
||||||
@@ -17,7 +17,11 @@ from core.agent.infrastructure.crewai.runtime_tools import (
|
|||||||
PendingFrontendToolCall,
|
PendingFrontendToolCall,
|
||||||
resolve_stage_crewai_tools,
|
resolve_stage_crewai_tools,
|
||||||
)
|
)
|
||||||
from core.agent.infrastructure.litellm.usage_tracker import UsageCost
|
from core.agent.infrastructure.litellm.pricing import calculate_tiered_model_cost
|
||||||
|
from core.agent.infrastructure.litellm.usage_tracker import (
|
||||||
|
UsageCost,
|
||||||
|
extract_usage_and_cost,
|
||||||
|
)
|
||||||
from core.agent.prompt import runtime_stage_prompts
|
from core.agent.prompt import runtime_stage_prompts
|
||||||
from core.logging import get_logger
|
from core.logging import get_logger
|
||||||
|
|
||||||
@@ -25,6 +29,31 @@ from core.logging import get_logger
|
|||||||
logger = get_logger("core.agent.infrastructure.crewai.runtime_stage_runner")
|
logger = get_logger("core.agent.infrastructure.crewai.runtime_stage_runner")
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMUsageCaptureCallback:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.captured_usage: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_usage(usage_payload: object) -> dict[str, Any] | None:
|
||||||
|
if isinstance(usage_payload, dict):
|
||||||
|
return usage_payload
|
||||||
|
model_dump = getattr(usage_payload, "model_dump", None)
|
||||||
|
if callable(model_dump):
|
||||||
|
dumped = model_dump()
|
||||||
|
if isinstance(dumped, dict):
|
||||||
|
return dumped
|
||||||
|
return None
|
||||||
|
|
||||||
|
def log_success_event(self, **kwargs: Any) -> None:
|
||||||
|
response_obj = kwargs.get("response_obj")
|
||||||
|
if not isinstance(response_obj, dict):
|
||||||
|
return
|
||||||
|
normalized = self._normalize_usage(response_obj.get("usage"))
|
||||||
|
if normalized is None:
|
||||||
|
return
|
||||||
|
self.captured_usage = normalized
|
||||||
|
|
||||||
|
|
||||||
def _tool_names(tools_payload: list[dict[str, object]]) -> list[str]:
|
def _tool_names(tools_payload: list[dict[str, object]]) -> list[str]:
|
||||||
names: list[str] = []
|
names: list[str] = []
|
||||||
for item in tools_payload:
|
for item in tools_payload:
|
||||||
@@ -69,24 +98,37 @@ def _output_diagnostics(*, text: str, tool_names: list[str]) -> dict[str, object
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_usage_from_captured_payload(
|
||||||
|
*,
|
||||||
|
captured_usage: dict[str, Any],
|
||||||
|
model: str,
|
||||||
|
) -> UsageCost:
|
||||||
|
usage = extract_usage_and_cost(
|
||||||
|
{
|
||||||
|
"model": model,
|
||||||
|
"usage": captured_usage,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return usage
|
||||||
|
|
||||||
|
|
||||||
def extract_usage_from_crew_output(*, output: object, model: str) -> UsageCost:
|
def extract_usage_from_crew_output(*, output: object, model: str) -> UsageCost:
|
||||||
token_usage = getattr(output, "token_usage", None)
|
token_usage = getattr(output, "token_usage", None)
|
||||||
prompt_tokens = int(getattr(token_usage, "prompt_tokens", 0) or 0)
|
prompt_tokens = int(getattr(token_usage, "prompt_tokens", 0) or 0)
|
||||||
completion_tokens = int(getattr(token_usage, "completion_tokens", 0) or 0)
|
completion_tokens = int(getattr(token_usage, "completion_tokens", 0) or 0)
|
||||||
total_tokens = int(getattr(token_usage, "total_tokens", 0) or 0)
|
total_tokens = int(getattr(token_usage, "total_tokens", 0) or 0)
|
||||||
|
cached_prompt_tokens = int(getattr(token_usage, "cached_prompt_tokens", 0) or 0)
|
||||||
if total_tokens == 0:
|
if total_tokens == 0:
|
||||||
total_tokens = prompt_tokens + completion_tokens
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
try:
|
cost = float(
|
||||||
cost = float(
|
calculate_tiered_model_cost(
|
||||||
completion_cost(
|
model_name=model,
|
||||||
model=model,
|
prompt_tokens=prompt_tokens,
|
||||||
prompt_tokens=prompt_tokens,
|
completion_tokens=completion_tokens,
|
||||||
completion_tokens=completion_tokens,
|
cached_prompt_tokens=cached_prompt_tokens,
|
||||||
)
|
|
||||||
or 0.0
|
|
||||||
)
|
)
|
||||||
except Exception:
|
or 0.0
|
||||||
cost = 0.0
|
)
|
||||||
return UsageCost(
|
return UsageCost(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
@@ -134,32 +176,32 @@ def run_stage_with_crewai(
|
|||||||
content = getattr(message, "content", None)
|
content = getattr(message, "content", None)
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
raw_text = content
|
raw_text = content
|
||||||
usage_obj = getattr(response_any, "usage", None)
|
|
||||||
prompt_tokens = int(getattr(usage_obj, "prompt_tokens", 0) or 0)
|
|
||||||
completion_tokens = int(getattr(usage_obj, "completion_tokens", 0) or 0)
|
|
||||||
total_tokens = int(getattr(usage_obj, "total_tokens", 0) or 0)
|
|
||||||
if total_tokens == 0:
|
|
||||||
total_tokens = prompt_tokens + completion_tokens
|
|
||||||
try:
|
try:
|
||||||
cost = float(
|
response_dict = (
|
||||||
completion_cost(
|
response_any.model_dump()
|
||||||
model=litellm_model,
|
if hasattr(response_any, "model_dump")
|
||||||
prompt_tokens=prompt_tokens,
|
else dict(response_any)
|
||||||
completion_tokens=completion_tokens,
|
|
||||||
)
|
|
||||||
or 0.0
|
|
||||||
)
|
)
|
||||||
|
if "model" not in response_dict:
|
||||||
|
response_dict["model"] = litellm_model
|
||||||
|
usage = extract_usage_and_cost(response_dict)
|
||||||
except Exception:
|
except Exception:
|
||||||
cost = 0.0
|
usage_obj = getattr(response_any, "usage", None)
|
||||||
usage = UsageCost(
|
prompt_tokens = int(getattr(usage_obj, "prompt_tokens", 0) or 0)
|
||||||
prompt_tokens=prompt_tokens,
|
completion_tokens = int(getattr(usage_obj, "completion_tokens", 0) or 0)
|
||||||
completion_tokens=completion_tokens,
|
total_tokens = int(getattr(usage_obj, "total_tokens", 0) or 0)
|
||||||
total_tokens=total_tokens,
|
if total_tokens == 0:
|
||||||
cost=cost,
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
)
|
usage = UsageCost(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
cost=0.0,
|
||||||
|
)
|
||||||
return raw_text, usage, [], None
|
return raw_text, usage, [], None
|
||||||
|
|
||||||
calls: list[dict[str, Any]] = []
|
calls: list[dict[str, Any]] = []
|
||||||
|
usage_callback = LiteLLMUsageCaptureCallback()
|
||||||
crew_tools = resolve_stage_crewai_tools(
|
crew_tools = resolve_stage_crewai_tools(
|
||||||
tools_payload=tools_payload,
|
tools_payload=tools_payload,
|
||||||
calls=calls,
|
calls=calls,
|
||||||
@@ -173,6 +215,8 @@ def run_stage_with_crewai(
|
|||||||
temperature=llm_config.temperature,
|
temperature=llm_config.temperature,
|
||||||
max_tokens=llm_config.max_tokens,
|
max_tokens=llm_config.max_tokens,
|
||||||
timeout=llm_config.timeout_seconds,
|
timeout=llm_config.timeout_seconds,
|
||||||
|
stream=True,
|
||||||
|
callbacks=[usage_callback],
|
||||||
)
|
)
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
role=agent_template.role,
|
role=agent_template.role,
|
||||||
@@ -218,7 +262,14 @@ def run_stage_with_crewai(
|
|||||||
],
|
],
|
||||||
pending_tool=str(pending.payload.get("name")),
|
pending_tool=str(pending.payload.get("name")),
|
||||||
)
|
)
|
||||||
return "", UsageCost(0, 0, 0, 0.0), calls, pending.payload
|
if usage_callback.captured_usage is not None:
|
||||||
|
usage = extract_usage_from_captured_payload(
|
||||||
|
captured_usage=usage_callback.captured_usage,
|
||||||
|
model=litellm_model,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
usage = UsageCost(0, 0, 0, 0.0)
|
||||||
|
return "", usage, calls, pending.payload
|
||||||
|
|
||||||
output_text = extract_crew_output_text(output)
|
output_text = extract_crew_output_text(output)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -231,5 +282,11 @@ def run_stage_with_crewai(
|
|||||||
],
|
],
|
||||||
diagnostics=_output_diagnostics(text=output_text, tool_names=stage_tool_names),
|
diagnostics=_output_diagnostics(text=output_text, tool_names=stage_tool_names),
|
||||||
)
|
)
|
||||||
usage = extract_usage_from_crew_output(output=output, model=litellm_model)
|
if usage_callback.captured_usage is not None:
|
||||||
|
usage = extract_usage_from_captured_payload(
|
||||||
|
captured_usage=usage_callback.captured_usage,
|
||||||
|
model=litellm_model,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
usage = extract_usage_from_crew_output(output=output, model=litellm_model)
|
||||||
return output_text, usage, calls, None
|
return output_text, usage, calls, None
|
||||||
|
|||||||
@@ -36,9 +36,22 @@ QWEN35_FLASH_TIERED_PRICING: tuple[TieredModelPricing, ...] = (
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
DEEPSEEK_CHAT_TIERED_PRICING: tuple[TieredModelPricing, ...] = (
|
||||||
|
TieredModelPricing(
|
||||||
|
max_prompt_tokens=10_000_000,
|
||||||
|
input_cost_per_token=2.0 / 1_000_000,
|
||||||
|
output_cost_per_token=3.0 / 1_000_000,
|
||||||
|
cache_create_cost_per_token=2.0 / 1_000_000,
|
||||||
|
cache_hit_cost_per_token=0.2 / 1_000_000,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_MODEL_TIERED_PRICING: dict[str, tuple[TieredModelPricing, ...]] = {
|
_MODEL_TIERED_PRICING: dict[str, tuple[TieredModelPricing, ...]] = {
|
||||||
"dashscope/qwen3.5-flash": QWEN35_FLASH_TIERED_PRICING,
|
"dashscope/qwen3.5-flash": QWEN35_FLASH_TIERED_PRICING,
|
||||||
|
"qwen3.5-flash": QWEN35_FLASH_TIERED_PRICING,
|
||||||
|
"deepseek/deepseek-chat": DEEPSEEK_CHAT_TIERED_PRICING,
|
||||||
|
"deepseek-chat": DEEPSEEK_CHAT_TIERED_PRICING,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -61,12 +74,21 @@ def calculate_tiered_model_cost(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
prompt_tokens: int,
|
prompt_tokens: int,
|
||||||
completion_tokens: int,
|
completion_tokens: int,
|
||||||
|
cached_prompt_tokens: int = 0,
|
||||||
) -> float | None:
|
) -> float | None:
|
||||||
tier = get_tiered_pricing(model_name=model_name, prompt_tokens=prompt_tokens)
|
tier = get_tiered_pricing(model_name=model_name, prompt_tokens=prompt_tokens)
|
||||||
if tier is None:
|
if tier is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return (
|
normalized_prompt_tokens = max(int(prompt_tokens), 0)
|
||||||
prompt_tokens * tier.input_cost_per_token
|
normalized_completion_tokens = max(int(completion_tokens), 0)
|
||||||
+ completion_tokens * tier.output_cost_per_token
|
normalized_cached_tokens = min(
|
||||||
|
max(int(cached_prompt_tokens), 0), normalized_prompt_tokens
|
||||||
|
)
|
||||||
|
uncached_prompt_tokens = normalized_prompt_tokens - normalized_cached_tokens
|
||||||
|
|
||||||
|
return (
|
||||||
|
uncached_prompt_tokens * tier.input_cost_per_token
|
||||||
|
+ normalized_cached_tokens * tier.cache_hit_cost_per_token
|
||||||
|
+ normalized_completion_tokens * tier.output_cost_per_token
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from litellm import completion_cost
|
|
||||||
|
|
||||||
from core.agent.infrastructure.litellm.pricing import calculate_tiered_model_cost
|
from core.agent.infrastructure.litellm.pricing import calculate_tiered_model_cost
|
||||||
|
|
||||||
|
|
||||||
@@ -26,25 +24,19 @@ def extract_usage_and_cost(response: dict[str, Any]) -> UsageCost:
|
|||||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||||
total_tokens = int(usage.get("total_tokens", prompt_tokens + completion_tokens))
|
total_tokens = int(usage.get("total_tokens", prompt_tokens + completion_tokens))
|
||||||
model_name = str(response.get("model", "")).strip().lower()
|
model_name = str(response.get("model", "")).strip().lower()
|
||||||
|
prompt_tokens_details = usage.get("prompt_tokens_details")
|
||||||
|
cached_prompt_tokens = 0
|
||||||
|
if isinstance(prompt_tokens_details, dict):
|
||||||
|
cached_prompt_tokens = int(prompt_tokens_details.get("cached_tokens", 0) or 0)
|
||||||
|
|
||||||
try:
|
local_cost = calculate_tiered_model_cost(
|
||||||
cost = completion_cost(completion_response=response)
|
model_name=model_name,
|
||||||
if cost is None:
|
prompt_tokens=prompt_tokens,
|
||||||
raise ValueError("unable to calculate litellm completion cost")
|
completion_tokens=completion_tokens,
|
||||||
return UsageCost(
|
cached_prompt_tokens=cached_prompt_tokens,
|
||||||
prompt_tokens=prompt_tokens,
|
)
|
||||||
completion_tokens=completion_tokens,
|
if local_cost is None:
|
||||||
total_tokens=total_tokens,
|
raise ValueError("unable to calculate custom completion cost")
|
||||||
cost=float(cost),
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
local_cost = calculate_tiered_model_cost(
|
|
||||||
model_name=model_name,
|
|
||||||
prompt_tokens=prompt_tokens,
|
|
||||||
completion_tokens=completion_tokens,
|
|
||||||
)
|
|
||||||
if local_cost is None:
|
|
||||||
raise ValueError("unable to calculate litellm completion cost") from exc
|
|
||||||
|
|
||||||
return UsageCost(
|
return UsageCost(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
|
|||||||
@@ -5,15 +5,14 @@ import pytest
|
|||||||
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
|
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
|
||||||
|
|
||||||
|
|
||||||
def test_usage_tracker_extracts_tokens_and_cost(
|
def test_usage_tracker_uses_custom_pricing_for_qwen35() -> None:
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"core.agent.infrastructure.litellm.usage_tracker.completion_cost",
|
|
||||||
lambda completion_response: 0.123,
|
|
||||||
)
|
|
||||||
response = {
|
response = {
|
||||||
"usage": {"prompt_tokens": 11, "completion_tokens": 7, "total_tokens": 18},
|
"model": "dashscope/qwen3.5-flash",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 11,
|
||||||
|
"completion_tokens": 7,
|
||||||
|
"total_tokens": 18,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
usage = extract_usage_and_cost(response)
|
usage = extract_usage_and_cost(response)
|
||||||
@@ -21,7 +20,8 @@ def test_usage_tracker_extracts_tokens_and_cost(
|
|||||||
assert usage.prompt_tokens == 11
|
assert usage.prompt_tokens == 11
|
||||||
assert usage.completion_tokens == 7
|
assert usage.completion_tokens == 7
|
||||||
assert usage.total_tokens == 18
|
assert usage.total_tokens == 18
|
||||||
assert usage.cost == 0.123
|
assert usage.cost == pytest.approx(0.0000162)
|
||||||
|
assert usage.cost_source == "custom_pricing"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -33,19 +33,10 @@ def test_usage_tracker_extracts_tokens_and_cost(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
|
def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
prompt_tokens: int,
|
prompt_tokens: int,
|
||||||
completion_tokens: int,
|
completion_tokens: int,
|
||||||
expected_cost: float,
|
expected_cost: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
def _raise_unmapped(*, completion_response): # type: ignore[no-untyped-def]
|
|
||||||
del completion_response
|
|
||||||
raise Exception("This model isn't mapped yet")
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"core.agent.infrastructure.litellm.usage_tracker.completion_cost",
|
|
||||||
_raise_unmapped,
|
|
||||||
)
|
|
||||||
response = {
|
response = {
|
||||||
"model": "dashscope/qwen3.5-flash",
|
"model": "dashscope/qwen3.5-flash",
|
||||||
"usage": {
|
"usage": {
|
||||||
@@ -59,3 +50,22 @@ def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
|
|||||||
|
|
||||||
assert usage.cost == pytest.approx(expected_cost)
|
assert usage.cost == pytest.approx(expected_cost)
|
||||||
assert usage.cost_source == "custom_pricing"
|
assert usage.cost_source == "custom_pricing"
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_tracker_uses_cached_pricing_for_deepseek_chat() -> None:
|
||||||
|
response = {
|
||||||
|
"model": "deepseek/deepseek-chat",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 1_000_000,
|
||||||
|
"completion_tokens": 100_000,
|
||||||
|
"total_tokens": 1_100_000,
|
||||||
|
"prompt_tokens_details": {
|
||||||
|
"cached_tokens": 400_000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = extract_usage_and_cost(response)
|
||||||
|
|
||||||
|
assert usage.cost == pytest.approx(1.58)
|
||||||
|
assert usage.cost_source == "custom_pricing"
|
||||||
|
|||||||
@@ -1058,6 +1058,128 @@ async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
|||||||
assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
|
assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_service_does_not_persist_model_code_for_user_message(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
session_id = uuid4()
|
||||||
|
user_id = uuid4()
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
message_calls: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
class _FakeDbSession:
|
||||||
|
async def commit(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class _FakeSessionFactory:
|
||||||
|
def __call__(self) -> "_FakeSessionFactory":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aenter__(self) -> _FakeDbSession:
|
||||||
|
return _FakeDbSession()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
del exc_type, exc, tb
|
||||||
|
return False
|
||||||
|
|
||||||
|
class _FakeSessionRepository:
|
||||||
|
def __init__(self, session: object) -> None:
|
||||||
|
del session
|
||||||
|
|
||||||
|
async def lock_session_for_update(self, *, session_id: object):
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
status=AgentChatSessionStatus.PENDING,
|
||||||
|
message_count=0,
|
||||||
|
total_tokens=0,
|
||||||
|
total_cost=0,
|
||||||
|
state_snapshot=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def next_message_seq(self, *, session_id: object):
|
||||||
|
del session_id
|
||||||
|
return 1
|
||||||
|
|
||||||
|
async def update_runtime_state(self, **kwargs) -> None:
|
||||||
|
captured["update_runtime_state"] = kwargs
|
||||||
|
|
||||||
|
class _FakeMessageRepository:
|
||||||
|
def __init__(self, session: object) -> None:
|
||||||
|
del session
|
||||||
|
|
||||||
|
async def append_message(self, **kwargs) -> None:
|
||||||
|
message_calls.append(kwargs)
|
||||||
|
|
||||||
|
class _FakeRuntime:
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_input: str,
|
||||||
|
system_prompt: str | None = None,
|
||||||
|
tools: list[dict[str, object]] | None = None,
|
||||||
|
):
|
||||||
|
del user_input, system_prompt, tools
|
||||||
|
return {
|
||||||
|
"assistant_text": "ok",
|
||||||
|
"prompt_tokens": 1,
|
||||||
|
"completion_tokens": 1,
|
||||||
|
"total_tokens": 2,
|
||||||
|
"cost": "0.001",
|
||||||
|
"agui_events": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _fake_load_agent_model_selection(self, _session):
|
||||||
|
del self
|
||||||
|
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||||
|
|
||||||
|
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||||
|
del self, session, session_id
|
||||||
|
return SimpleNamespace(
|
||||||
|
user_id=user_id,
|
||||||
|
username="demo-user",
|
||||||
|
bio=None,
|
||||||
|
settings=SimpleNamespace(
|
||||||
|
preferences=SimpleNamespace(
|
||||||
|
interface_language="zh-CN",
|
||||||
|
ai_language="zh-CN",
|
||||||
|
timezone="Asia/Shanghai",
|
||||||
|
country="CN",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.agent.application.run_service.SessionRepository",
|
||||||
|
_FakeSessionRepository,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.agent.application.run_service.MessageRepository",
|
||||||
|
_FakeMessageRepository,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.agent.application.run_service.create_runtime",
|
||||||
|
lambda **_kwargs: _FakeRuntime(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||||
|
_fake_load_agent_model_selection,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||||
|
_fake_load_user_agent_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||||
|
await service.run(
|
||||||
|
run_input=_build_run_input(thread_id=str(session_id), text="hello")
|
||||||
|
)
|
||||||
|
|
||||||
|
user_message = message_calls[0]
|
||||||
|
assert user_message["role"] == AgentChatMessageRole.USER
|
||||||
|
assert "model_code" not in user_message
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_load_user_agent_context_parses_profile_settings_v1() -> None:
|
async def test_load_user_agent_context_parses_profile_settings_v1() -> None:
|
||||||
session_id = uuid4()
|
session_id = uuid4()
|
||||||
|
|||||||
@@ -0,0 +1,72 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.agent.infrastructure.crewai.runtime_stage_runner import (
|
||||||
|
LiteLLMUsageCaptureCallback,
|
||||||
|
extract_usage_from_captured_payload,
|
||||||
|
extract_usage_from_crew_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_usage_from_crew_output_uses_custom_deepseek_pricing() -> None:
|
||||||
|
output = SimpleNamespace(
|
||||||
|
token_usage=SimpleNamespace(
|
||||||
|
prompt_tokens=1_000_000,
|
||||||
|
completion_tokens=100_000,
|
||||||
|
total_tokens=1_100_000,
|
||||||
|
cached_prompt_tokens=400_000,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = extract_usage_from_crew_output(
|
||||||
|
output=output,
|
||||||
|
model="deepseek/deepseek-chat",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert usage.prompt_tokens == 1_000_000
|
||||||
|
assert usage.completion_tokens == 100_000
|
||||||
|
assert usage.total_tokens == 1_100_000
|
||||||
|
assert usage.cost == pytest.approx(1.58)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_usage_from_captured_payload_uses_custom_pricing() -> None:
|
||||||
|
usage = extract_usage_from_captured_payload(
|
||||||
|
captured_usage={
|
||||||
|
"prompt_tokens": 1_000_000,
|
||||||
|
"completion_tokens": 100_000,
|
||||||
|
"total_tokens": 1_100_000,
|
||||||
|
"prompt_tokens_details": {"cached_tokens": 400_000},
|
||||||
|
},
|
||||||
|
model="deepseek/deepseek-chat",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert usage.prompt_tokens == 1_000_000
|
||||||
|
assert usage.completion_tokens == 100_000
|
||||||
|
assert usage.total_tokens == 1_100_000
|
||||||
|
assert usage.cost == pytest.approx(1.58)
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_capture_callback_extracts_nested_usage_payload() -> None:
|
||||||
|
callback = LiteLLMUsageCaptureCallback()
|
||||||
|
|
||||||
|
callback.log_success_event(
|
||||||
|
kwargs={},
|
||||||
|
response_obj={
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 15,
|
||||||
|
"completion_tokens": 9,
|
||||||
|
"total_tokens": 24,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
start_time=0,
|
||||||
|
end_time=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert callback.captured_usage == {
|
||||||
|
"prompt_tokens": 15,
|
||||||
|
"completion_tokens": 9,
|
||||||
|
"total_tokens": 24,
|
||||||
|
}
|
||||||