feat: 添加 Agent 步骤事件与图片附件功能
- 新增 stepStarted/stepFinished 事件类型支持 - 前端实现图片附件上传和预览功能 - 后端增强工具结果存储和事件处理 - 完善相关单元测试和集成测试
This commit is contained in:
@@ -50,7 +50,6 @@ class ApiClient implements IApiClient {
|
||||
};
|
||||
}
|
||||
|
||||
@override
|
||||
Future<Response<T>> get<T>(String path, {Options? options}) async {
|
||||
try {
|
||||
return await _dio.get<T>(path, options: options);
|
||||
|
||||
@@ -40,7 +40,11 @@ class MockApiClient implements IApiClient {
|
||||
_handlers[key] = handler;
|
||||
}
|
||||
|
||||
void registerPatternHandler(RegExp pattern, String method, MockHandler handler) {
|
||||
void registerPatternHandler(
|
||||
RegExp pattern,
|
||||
String method,
|
||||
MockHandler handler,
|
||||
) {
|
||||
_patternHandlers.add(
|
||||
_PatternRoute(
|
||||
pattern: pattern,
|
||||
@@ -96,11 +100,7 @@ class MockApiClient implements IApiClient {
|
||||
final direct = _handlers[key];
|
||||
if (direct != null) {
|
||||
final response = direct(
|
||||
MockRequest(
|
||||
path: path,
|
||||
method: 'SSE',
|
||||
headers: headers,
|
||||
),
|
||||
MockRequest(path: path, method: 'SSE', headers: headers),
|
||||
);
|
||||
if (response is Stream<String>) {
|
||||
return response;
|
||||
@@ -118,11 +118,7 @@ class MockApiClient implements IApiClient {
|
||||
continue;
|
||||
}
|
||||
final response = route.handler(
|
||||
MockRequest(
|
||||
path: path,
|
||||
method: 'SSE',
|
||||
headers: headers,
|
||||
),
|
||||
MockRequest(path: path, method: 'SSE', headers: headers),
|
||||
);
|
||||
if (response is Stream<String>) {
|
||||
return response;
|
||||
@@ -147,12 +143,7 @@ class MockApiClient implements IApiClient {
|
||||
|
||||
if (handler != null) {
|
||||
final response = handler(
|
||||
MockRequest(
|
||||
path: path,
|
||||
method: method,
|
||||
data: data,
|
||||
options: options,
|
||||
),
|
||||
MockRequest(path: path, method: method, data: data, options: options),
|
||||
);
|
||||
if (response is Response) {
|
||||
return response as Response<T>;
|
||||
|
||||
@@ -9,6 +9,8 @@ class AgUiEventTypeWire {
|
||||
static const runStarted = 'RUN_STARTED';
|
||||
static const runFinished = 'RUN_FINISHED';
|
||||
static const runError = 'RUN_ERROR';
|
||||
static const stepStarted = 'STEP_STARTED';
|
||||
static const stepFinished = 'STEP_FINISHED';
|
||||
static const textMessageStart = 'TEXT_MESSAGE_START';
|
||||
static const textMessageContent = 'TEXT_MESSAGE_CONTENT';
|
||||
static const textMessageEnd = 'TEXT_MESSAGE_END';
|
||||
@@ -25,6 +27,8 @@ enum AgUiEventType {
|
||||
runStarted,
|
||||
runFinished,
|
||||
runError,
|
||||
stepStarted,
|
||||
stepFinished,
|
||||
textMessageStart,
|
||||
textMessageContent,
|
||||
textMessageEnd,
|
||||
@@ -43,6 +47,8 @@ const _wireToTypeMap = {
|
||||
AgUiEventTypeWire.runStarted: AgUiEventType.runStarted,
|
||||
AgUiEventTypeWire.runFinished: AgUiEventType.runFinished,
|
||||
AgUiEventTypeWire.runError: AgUiEventType.runError,
|
||||
AgUiEventTypeWire.stepStarted: AgUiEventType.stepStarted,
|
||||
AgUiEventTypeWire.stepFinished: AgUiEventType.stepFinished,
|
||||
AgUiEventTypeWire.textMessageStart: AgUiEventType.textMessageStart,
|
||||
AgUiEventTypeWire.textMessageContent: AgUiEventType.textMessageContent,
|
||||
AgUiEventTypeWire.textMessageEnd: AgUiEventType.textMessageEnd,
|
||||
@@ -60,6 +66,8 @@ const _typeToWireMap = {
|
||||
AgUiEventType.runStarted: AgUiEventTypeWire.runStarted,
|
||||
AgUiEventType.runFinished: AgUiEventTypeWire.runFinished,
|
||||
AgUiEventType.runError: AgUiEventTypeWire.runError,
|
||||
AgUiEventType.stepStarted: AgUiEventTypeWire.stepStarted,
|
||||
AgUiEventType.stepFinished: AgUiEventTypeWire.stepFinished,
|
||||
AgUiEventType.textMessageStart: AgUiEventTypeWire.textMessageStart,
|
||||
AgUiEventType.textMessageContent: AgUiEventTypeWire.textMessageContent,
|
||||
AgUiEventType.textMessageEnd: AgUiEventTypeWire.textMessageEnd,
|
||||
@@ -83,6 +91,8 @@ final _typeToFactory = {
|
||||
AgUiEventType.runStarted: RunStartedEvent.fromJson,
|
||||
AgUiEventType.runFinished: RunFinishedEvent.fromJson,
|
||||
AgUiEventType.runError: RunErrorEvent.fromJson,
|
||||
AgUiEventType.stepStarted: StepStartedEvent.fromJson,
|
||||
AgUiEventType.stepFinished: StepFinishedEvent.fromJson,
|
||||
AgUiEventType.textMessageStart: TextMessageStartEvent.fromJson,
|
||||
AgUiEventType.textMessageContent: TextMessageContentEvent.fromJson,
|
||||
AgUiEventType.textMessageEnd: TextMessageEndEvent.fromJson,
|
||||
@@ -170,6 +180,34 @@ class RunErrorEvent extends AgUiEvent {
|
||||
Map<String, dynamic> toJson() => _$RunErrorEventToJson(this);
|
||||
}
|
||||
|
||||
@JsonSerializable()
|
||||
class StepStartedEvent extends AgUiEvent {
|
||||
final String stepName;
|
||||
|
||||
StepStartedEvent({required this.stepName})
|
||||
: super(type: AgUiEventType.stepStarted);
|
||||
|
||||
factory StepStartedEvent.fromJson(Map<String, dynamic> json) =>
|
||||
_$StepStartedEventFromJson(json);
|
||||
|
||||
@override
|
||||
Map<String, dynamic> toJson() => _$StepStartedEventToJson(this);
|
||||
}
|
||||
|
||||
@JsonSerializable()
|
||||
class StepFinishedEvent extends AgUiEvent {
|
||||
final String stepName;
|
||||
|
||||
StepFinishedEvent({required this.stepName})
|
||||
: super(type: AgUiEventType.stepFinished);
|
||||
|
||||
factory StepFinishedEvent.fromJson(Map<String, dynamic> json) =>
|
||||
_$StepFinishedEventFromJson(json);
|
||||
|
||||
@override
|
||||
Map<String, dynamic> toJson() => _$StepFinishedEventToJson(this);
|
||||
}
|
||||
|
||||
@JsonSerializable()
|
||||
class TextMessageStartEvent extends AgUiEvent {
|
||||
final String messageId;
|
||||
@@ -310,10 +348,33 @@ class ToolCallResultEvent extends AgUiEvent {
|
||||
|
||||
factory ToolCallResultEvent.fromJson(Map<String, dynamic> json) {
|
||||
final rawContent = json['content'];
|
||||
final content = rawContent is String ? rawContent : '';
|
||||
final hasStructuredFields =
|
||||
json['ui'] != null || json['result'] != null || json['error'] != null;
|
||||
final content = switch (rawContent) {
|
||||
String value when value.trim().startsWith('{') => value,
|
||||
String value when value.trim().startsWith('[') => value,
|
||||
String value when hasStructuredFields => jsonEncode({
|
||||
'toolName': json['toolName'],
|
||||
'result': json['result'],
|
||||
'error': json['error'],
|
||||
'ui': json['ui'],
|
||||
'content': value,
|
||||
}),
|
||||
String value => value,
|
||||
_ => jsonEncode({
|
||||
'toolName': json['toolName'],
|
||||
'result': json['result'],
|
||||
'error': json['error'],
|
||||
'ui': json['ui'],
|
||||
'content': json['content'],
|
||||
}),
|
||||
};
|
||||
final toolCallId =
|
||||
json['toolCallId'] as String? ?? json['callId'] as String? ?? '';
|
||||
final messageId = json['messageId'] as String? ?? 'tool-result-$toolCallId';
|
||||
return ToolCallResultEvent(
|
||||
messageId: json['messageId'] as String,
|
||||
toolCallId: json['toolCallId'] as String,
|
||||
messageId: messageId,
|
||||
toolCallId: toolCallId,
|
||||
content: content,
|
||||
);
|
||||
}
|
||||
@@ -388,6 +449,7 @@ class SnapshotMessage {
|
||||
final String? toolCallId;
|
||||
final UiCard? ui;
|
||||
final DateTime? timestamp;
|
||||
final List<Map<String, dynamic>>? attachments;
|
||||
|
||||
SnapshotMessage({
|
||||
required this.id,
|
||||
@@ -396,6 +458,7 @@ class SnapshotMessage {
|
||||
this.toolCallId,
|
||||
this.ui,
|
||||
this.timestamp,
|
||||
this.attachments,
|
||||
});
|
||||
|
||||
factory SnapshotMessage.fromJson(Map<String, dynamic> json) =>
|
||||
|
||||
@@ -14,6 +14,8 @@ const _$AgUiEventTypeEnumMap = {
|
||||
AgUiEventType.runStarted: 'runStarted',
|
||||
AgUiEventType.runFinished: 'runFinished',
|
||||
AgUiEventType.runError: 'runError',
|
||||
AgUiEventType.stepStarted: 'stepStarted',
|
||||
AgUiEventType.stepFinished: 'stepFinished',
|
||||
AgUiEventType.textMessageStart: 'textMessageStart',
|
||||
AgUiEventType.textMessageContent: 'textMessageContent',
|
||||
AgUiEventType.textMessageEnd: 'textMessageEnd',
|
||||
@@ -53,6 +55,18 @@ RunErrorEvent _$RunErrorEventFromJson(Map<String, dynamic> json) =>
|
||||
Map<String, dynamic> _$RunErrorEventToJson(RunErrorEvent instance) =>
|
||||
<String, dynamic>{'message': instance.message, 'code': instance.code};
|
||||
|
||||
StepStartedEvent _$StepStartedEventFromJson(Map<String, dynamic> json) =>
|
||||
StepStartedEvent(stepName: json['stepName'] as String);
|
||||
|
||||
Map<String, dynamic> _$StepStartedEventToJson(StepStartedEvent instance) =>
|
||||
<String, dynamic>{'stepName': instance.stepName};
|
||||
|
||||
StepFinishedEvent _$StepFinishedEventFromJson(Map<String, dynamic> json) =>
|
||||
StepFinishedEvent(stepName: json['stepName'] as String);
|
||||
|
||||
Map<String, dynamic> _$StepFinishedEventToJson(StepFinishedEvent instance) =>
|
||||
<String, dynamic>{'stepName': instance.stepName};
|
||||
|
||||
TextMessageStartEvent _$TextMessageStartEventFromJson(
|
||||
Map<String, dynamic> json,
|
||||
) => TextMessageStartEvent(
|
||||
@@ -170,6 +184,9 @@ SnapshotMessage _$SnapshotMessageFromJson(Map<String, dynamic> json) =>
|
||||
timestamp: json['timestamp'] == null
|
||||
? null
|
||||
: DateTime.parse(json['timestamp'] as String),
|
||||
attachments: (json['attachments'] as List<dynamic>?)
|
||||
?.whereType<Map<String, dynamic>>()
|
||||
.toList(),
|
||||
);
|
||||
|
||||
Map<String, dynamic> _$SnapshotMessageToJson(SnapshotMessage instance) =>
|
||||
@@ -180,4 +197,5 @@ Map<String, dynamic> _$SnapshotMessageToJson(SnapshotMessage instance) =>
|
||||
'toolCallId': instance.toolCallId,
|
||||
'ui': instance.ui,
|
||||
'timestamp': instance.timestamp?.toIso8601String(),
|
||||
'attachments': instance.attachments,
|
||||
};
|
||||
|
||||
@@ -22,6 +22,7 @@ class TextMessageItem extends ChatListItem {
|
||||
@override
|
||||
final MessageSender sender;
|
||||
final bool isStreaming;
|
||||
final List<Map<String, dynamic>> attachments;
|
||||
|
||||
TextMessageItem({
|
||||
required this.id,
|
||||
@@ -29,6 +30,7 @@ class TextMessageItem extends ChatListItem {
|
||||
required this.timestamp,
|
||||
required this.sender,
|
||||
this.isStreaming = false,
|
||||
this.attachments = const [],
|
||||
});
|
||||
|
||||
@override
|
||||
@@ -40,12 +42,14 @@ class TextMessageItem extends ChatListItem {
|
||||
DateTime? timestamp,
|
||||
MessageSender? sender,
|
||||
bool? isStreaming,
|
||||
List<Map<String, dynamic>>? attachments,
|
||||
}) => TextMessageItem(
|
||||
id: id ?? this.id,
|
||||
content: content ?? this.content,
|
||||
timestamp: timestamp ?? this.timestamp,
|
||||
sender: sender ?? this.sender,
|
||||
isStreaming: isStreaming ?? this.isStreaming,
|
||||
attachments: attachments ?? this.attachments,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import 'dart:async';
|
||||
import 'dart:convert';
|
||||
import 'dart:math';
|
||||
import 'dart:typed_data';
|
||||
|
||||
import 'package:dio/dio.dart';
|
||||
import 'package:image_picker/image_picker.dart';
|
||||
@@ -80,6 +81,18 @@ class AgUiService {
|
||||
onEvent(event);
|
||||
}
|
||||
|
||||
Future<Uint8List> fetchAttachmentPreview(String previewPath) async {
|
||||
final response = await _apiClient.get<List<int>>(
|
||||
previewPath,
|
||||
options: Options(responseType: ResponseType.bytes),
|
||||
);
|
||||
final payload = response.data;
|
||||
if (payload is List<int>) {
|
||||
return Uint8List.fromList(payload);
|
||||
}
|
||||
throw StateError('Invalid attachment payload');
|
||||
}
|
||||
|
||||
Future<String> transcribeAudio(String filePath) async {
|
||||
final formData = FormData.fromMap({
|
||||
'audio': await MultipartFile.fromFile(
|
||||
@@ -247,22 +260,27 @@ class AgUiService {
|
||||
final runId = _nextId(_runIdPrefix);
|
||||
|
||||
final contentBlocks = <Map<String, dynamic>>[];
|
||||
final attachmentMetadata = <Map<String, dynamic>>[];
|
||||
|
||||
if (content.isNotEmpty) {
|
||||
contentBlocks.add({'type': 'text', 'text': content});
|
||||
}
|
||||
|
||||
if (images != null && images.isNotEmpty) {
|
||||
for (final image in images) {
|
||||
final bytes = await image.readAsBytes();
|
||||
final base64 = base64Encode(bytes);
|
||||
final uploadedAttachments = await _uploadAttachments(
|
||||
threadId: threadId,
|
||||
images: images,
|
||||
);
|
||||
for (final attachment in uploadedAttachments) {
|
||||
contentBlocks.add({
|
||||
'type': 'image',
|
||||
'source': {
|
||||
'type': 'base64',
|
||||
'media_type': 'image/jpeg',
|
||||
'data': base64,
|
||||
},
|
||||
'type': 'binary',
|
||||
'mimeType': attachment['mimeType'],
|
||||
'url': attachment['url'],
|
||||
});
|
||||
attachmentMetadata.add({
|
||||
'bucket': attachment['bucket'],
|
||||
'path': attachment['path'],
|
||||
'mimeType': attachment['mimeType'],
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -286,10 +304,64 @@ class AgUiService {
|
||||
],
|
||||
'tools': _buildTools(),
|
||||
'context': <Map<String, dynamic>>[],
|
||||
'forwardedProps': <String, dynamic>{},
|
||||
'forwardedProps': {
|
||||
if (attachmentMetadata.isNotEmpty) 'attachments': attachmentMetadata,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
Future<List<Map<String, dynamic>>> _uploadAttachments({
|
||||
required String threadId,
|
||||
required List<XFile> images,
|
||||
}) async {
|
||||
final attachments = <Map<String, dynamic>>[];
|
||||
for (final image in images) {
|
||||
final mimeType = image.mimeType ?? 'image/jpeg';
|
||||
final fileBytes = await image.readAsBytes();
|
||||
final formData = FormData.fromMap({
|
||||
'threadId': threadId,
|
||||
'file': MultipartFile.fromBytes(
|
||||
fileBytes,
|
||||
filename: image.name,
|
||||
contentType: DioMediaType.parse(mimeType),
|
||||
),
|
||||
});
|
||||
final response = await _apiClient.post<Map<String, dynamic>>(
|
||||
'/api/v1/agent/attachments',
|
||||
data: formData,
|
||||
);
|
||||
final payload = response.data;
|
||||
if (payload is! Map<String, dynamic>) {
|
||||
throw StateError('Invalid /agent/attachments response');
|
||||
}
|
||||
final attachment = payload['attachment'];
|
||||
if (attachment is! Map<String, dynamic>) {
|
||||
throw StateError('Missing attachment in /agent/attachments response');
|
||||
}
|
||||
final bucket = attachment['bucket'];
|
||||
final path = attachment['path'];
|
||||
final uploadedMime = attachment['mimeType'];
|
||||
final url = attachment['url'];
|
||||
if (bucket is! String ||
|
||||
path is! String ||
|
||||
uploadedMime is! String ||
|
||||
url is! String ||
|
||||
bucket.isEmpty ||
|
||||
path.isEmpty ||
|
||||
uploadedMime.isEmpty ||
|
||||
url.isEmpty) {
|
||||
throw StateError('Invalid attachment reference');
|
||||
}
|
||||
attachments.add({
|
||||
'bucket': bucket,
|
||||
'path': path,
|
||||
'mimeType': uploadedMime,
|
||||
'url': url,
|
||||
});
|
||||
}
|
||||
return attachments;
|
||||
}
|
||||
|
||||
List<Map<String, dynamic>> _buildTools() {
|
||||
return [
|
||||
{
|
||||
@@ -360,6 +432,11 @@ class AgUiService {
|
||||
'SSE',
|
||||
_handleMockSse,
|
||||
);
|
||||
client.registerHandler(
|
||||
'/api/v1/agent/attachments',
|
||||
'POST',
|
||||
_handleMockUploadAttachment,
|
||||
);
|
||||
client.registerHandler(
|
||||
'/api/v1/agent/transcribe',
|
||||
'POST',
|
||||
@@ -371,6 +448,26 @@ class AgUiService {
|
||||
return {'transcript': '这是模拟语音转写'};
|
||||
}
|
||||
|
||||
Map<String, dynamic> _handleMockUploadAttachment(MockRequest request) {
|
||||
final payload = request.data;
|
||||
final threadId = payload is Map<String, dynamic>
|
||||
? (payload['threadId'] as String?)
|
||||
: null;
|
||||
final resolvedThreadId = (threadId != null && threadId.isNotEmpty)
|
||||
? threadId
|
||||
: (_threadId ?? _newUuid());
|
||||
final path =
|
||||
'agent-inputs/mock/$resolvedThreadId/${_nextId('upload_')}.png';
|
||||
return {
|
||||
'attachment': {
|
||||
'bucket': 'mock-bucket',
|
||||
'path': path,
|
||||
'mimeType': 'image/png',
|
||||
'url': 'https://mock.local/$path',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
Map<String, dynamic> _handleMockRun(MockRequest request) {
|
||||
final payload = request.data;
|
||||
final runInput = payload is Map<String, dynamic>
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import 'dart:convert';
|
||||
import 'dart:typed_data';
|
||||
|
||||
import 'package:flutter_bloc/flutter_bloc.dart';
|
||||
import 'package:image_picker/image_picker.dart';
|
||||
@@ -10,6 +11,8 @@ import '../../data/models/ag_ui_event.dart';
|
||||
import '../../data/models/chat_list_item.dart';
|
||||
import '../../data/services/ag_ui_service.dart';
|
||||
|
||||
enum AgentStage { intent, execution, report }
|
||||
|
||||
class ChatState {
|
||||
final List<ChatListItem> items;
|
||||
final bool isSending;
|
||||
@@ -21,6 +24,7 @@ class ChatState {
|
||||
final String? error;
|
||||
final DateTime? oldestLoadedDate;
|
||||
final bool hasEarlierHistory;
|
||||
final AgentStage? currentStage;
|
||||
|
||||
const ChatState({
|
||||
this.items = const [],
|
||||
@@ -33,6 +37,7 @@ class ChatState {
|
||||
this.error,
|
||||
this.oldestLoadedDate,
|
||||
this.hasEarlierHistory = false,
|
||||
this.currentStage,
|
||||
});
|
||||
|
||||
bool get isLoading =>
|
||||
@@ -55,6 +60,7 @@ class ChatState {
|
||||
Object? error = _unset,
|
||||
Object? oldestLoadedDate = _unset,
|
||||
bool? hasEarlierHistory,
|
||||
Object? currentStage = _unset,
|
||||
}) {
|
||||
return ChatState(
|
||||
items: items ?? this.items,
|
||||
@@ -71,6 +77,9 @@ class ChatState {
|
||||
? this.oldestLoadedDate
|
||||
: oldestLoadedDate as DateTime?,
|
||||
hasEarlierHistory: hasEarlierHistory ?? this.hasEarlierHistory,
|
||||
currentStage: currentStage == _unset
|
||||
? this.currentStage
|
||||
: currentStage as AgentStage?,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -78,6 +87,9 @@ class ChatState {
|
||||
class ChatBloc extends Cubit<ChatState> {
|
||||
final AgUiService _service;
|
||||
final Map<String, String> _toolCallArgsBuffer = {};
|
||||
final Map<String, Uint8List> _attachmentPreviewCache = <String, Uint8List>{};
|
||||
final Map<String, Future<Uint8List?>> _attachmentPreviewInflight =
|
||||
<String, Future<Uint8List?>>{};
|
||||
|
||||
ChatBloc({AgUiService? service, IApiClient? apiClient})
|
||||
: _service =
|
||||
@@ -102,6 +114,7 @@ class ChatBloc extends Cubit<ChatState> {
|
||||
isWaitingFirstToken: true,
|
||||
isCancelling: false,
|
||||
error: null,
|
||||
currentStage: null,
|
||||
),
|
||||
);
|
||||
case AgUiEventType.runFinished:
|
||||
@@ -112,6 +125,7 @@ class ChatBloc extends Cubit<ChatState> {
|
||||
isStreaming: false,
|
||||
isCancelling: false,
|
||||
currentMessageId: null,
|
||||
currentStage: null,
|
||||
),
|
||||
);
|
||||
case AgUiEventType.runError:
|
||||
@@ -124,8 +138,13 @@ class ChatBloc extends Cubit<ChatState> {
|
||||
isCancelling: false,
|
||||
currentMessageId: null,
|
||||
error: errorEvent.message,
|
||||
currentStage: null,
|
||||
),
|
||||
);
|
||||
case AgUiEventType.stepStarted:
|
||||
_handleStepStarted(event as StepStartedEvent);
|
||||
case AgUiEventType.stepFinished:
|
||||
_handleStepFinished(event as StepFinishedEvent);
|
||||
case AgUiEventType.textMessageStart:
|
||||
_handleTextMessageStart(event as TextMessageStartEvent);
|
||||
case AgUiEventType.textMessageContent:
|
||||
@@ -151,6 +170,16 @@ class ChatBloc extends Cubit<ChatState> {
|
||||
}
|
||||
}
|
||||
|
||||
void _handleStepStarted(StepStartedEvent event) {
|
||||
emit(state.copyWith(currentStage: _stageFromName(event.stepName)));
|
||||
}
|
||||
|
||||
void _handleStepFinished(StepFinishedEvent event) {
|
||||
if (state.currentStage == _stageFromName(event.stepName)) {
|
||||
emit(state.copyWith(currentStage: null));
|
||||
}
|
||||
}
|
||||
|
||||
void _handleTextMessageStart(TextMessageStartEvent startEvent) {
|
||||
final newMessage = TextMessageItem(
|
||||
id: startEvent.messageId,
|
||||
@@ -327,6 +356,7 @@ class ChatBloc extends Cubit<ChatState> {
|
||||
content: msg.content ?? '',
|
||||
timestamp: timestamp,
|
||||
sender: MessageSender.user,
|
||||
attachments: msg.attachments ?? const [],
|
||||
);
|
||||
case 'assistant':
|
||||
return TextMessageItem(
|
||||
@@ -369,11 +399,20 @@ class ChatBloc extends Cubit<ChatState> {
|
||||
}
|
||||
|
||||
Future<void> sendMessage(String content, {List<XFile>? images}) async {
|
||||
final attachments = (images ?? const <XFile>[])
|
||||
.map(
|
||||
(image) => <String, dynamic>{
|
||||
"path": image.path,
|
||||
"mimeType": "image/*",
|
||||
},
|
||||
)
|
||||
.toList();
|
||||
final userMessage = TextMessageItem(
|
||||
id: 'user-${DateTime.now().millisecondsSinceEpoch}',
|
||||
content: content,
|
||||
timestamp: DateTime.now(),
|
||||
sender: MessageSender.user,
|
||||
attachments: attachments,
|
||||
);
|
||||
emit(
|
||||
state.copyWith(
|
||||
@@ -509,7 +548,43 @@ class ChatBloc extends Cubit<ChatState> {
|
||||
}
|
||||
}
|
||||
|
||||
Future<Uint8List?> loadAttachmentPreview(String previewPath) async {
|
||||
final cached = _attachmentPreviewCache[previewPath];
|
||||
if (cached != null) {
|
||||
return cached;
|
||||
}
|
||||
final pending = _attachmentPreviewInflight[previewPath];
|
||||
if (pending != null) {
|
||||
return pending;
|
||||
}
|
||||
final future = _service
|
||||
.fetchAttachmentPreview(previewPath)
|
||||
.then((bytes) {
|
||||
_attachmentPreviewCache[previewPath] = bytes;
|
||||
return bytes;
|
||||
})
|
||||
.catchError((_) => null)
|
||||
.whenComplete(() {
|
||||
_attachmentPreviewInflight.remove(previewPath);
|
||||
});
|
||||
_attachmentPreviewInflight[previewPath] = future;
|
||||
return future;
|
||||
}
|
||||
|
||||
void clearError() {
|
||||
emit(state.copyWith(error: null));
|
||||
}
|
||||
}
|
||||
|
||||
AgentStage? _stageFromName(String value) {
|
||||
switch (value) {
|
||||
case 'intent':
|
||||
return AgentStage.intent;
|
||||
case 'execution':
|
||||
return AgentStage.execution;
|
||||
case 'report':
|
||||
return AgentStage.report;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import 'dart:io';
|
||||
import 'dart:typed_data';
|
||||
|
||||
import 'package:flutter/material.dart';
|
||||
import 'package:flutter_bloc/flutter_bloc.dart';
|
||||
@@ -34,12 +35,15 @@ const _rippleDurationMs = 1200;
|
||||
const _recordingDotSize = 10.0;
|
||||
const _transcribingSpinnerSize = 18.0;
|
||||
const _transcribingStrokeWidth = 2.0;
|
||||
const _attachmentPreviewSize = 88.0;
|
||||
const _attachmentPreviewRadius = 10.0;
|
||||
const _attachmentPreviewGap = 8.0;
|
||||
const _inputActionButtonKey = ValueKey('home_input_action_button');
|
||||
const _inputActionIconKey = ValueKey('home_input_action_icon');
|
||||
|
||||
/// 颜色常量
|
||||
const _chatBgColor = Color(0xFFF8FAFC);
|
||||
const _userBubbleColor = Color(0xFFEAF1FB);
|
||||
const _chatBgColor = AppColors.slate50;
|
||||
const _userBubbleColor = AppColors.blue50;
|
||||
|
||||
class HomeScreen extends StatefulWidget {
|
||||
final VoiceRecorder? voiceRecorder;
|
||||
@@ -265,7 +269,8 @@ class _HomeScreenState extends State<HomeScreen>
|
||||
),
|
||||
),
|
||||
),
|
||||
if (showWaitingIndicator) _buildWaitingIndicator(),
|
||||
if (showWaitingIndicator)
|
||||
_buildWaitingIndicator(currentStage: state.currentStage),
|
||||
],
|
||||
);
|
||||
}
|
||||
@@ -310,12 +315,19 @@ class _HomeScreenState extends State<HomeScreen>
|
||||
),
|
||||
),
|
||||
),
|
||||
if (showWaitingIndicator) _buildWaitingIndicator(),
|
||||
if (showWaitingIndicator)
|
||||
_buildWaitingIndicator(currentStage: state.currentStage),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
Widget _buildWaitingIndicator() {
|
||||
Widget _buildWaitingIndicator({required AgentStage? currentStage}) {
|
||||
final label = switch (currentStage) {
|
||||
AgentStage.intent => '意图识别中',
|
||||
AgentStage.execution => '任务执行中',
|
||||
AgentStage.report => '结果总结中',
|
||||
null => '正在思考...',
|
||||
};
|
||||
return Padding(
|
||||
padding: const EdgeInsets.fromLTRB(
|
||||
_defaultPadding,
|
||||
@@ -325,7 +337,7 @@ class _HomeScreenState extends State<HomeScreen>
|
||||
),
|
||||
child: Row(
|
||||
crossAxisAlignment: CrossAxisAlignment.center,
|
||||
children: const [
|
||||
children: [
|
||||
SizedBox(
|
||||
width: _transcribingSpinnerSize,
|
||||
height: _transcribingSpinnerSize,
|
||||
@@ -336,7 +348,7 @@ class _HomeScreenState extends State<HomeScreen>
|
||||
),
|
||||
SizedBox(width: 8),
|
||||
Text(
|
||||
'正在思考...',
|
||||
label,
|
||||
style: TextStyle(fontSize: 14, color: AppColors.slate500),
|
||||
),
|
||||
],
|
||||
@@ -413,38 +425,152 @@ class _HomeScreenState extends State<HomeScreen>
|
||||
|
||||
Widget _buildMessageItem(TextMessageItem item) {
|
||||
final isUser = item.sender == MessageSender.user;
|
||||
return Row(
|
||||
mainAxisAlignment: isUser
|
||||
? MainAxisAlignment.end
|
||||
: MainAxisAlignment.start,
|
||||
crossAxisAlignment: CrossAxisAlignment.start,
|
||||
final imageAttachments = _collectRenderableImageAttachments(
|
||||
item.attachments,
|
||||
);
|
||||
final hasRenderableAttachments = imageAttachments.isNotEmpty;
|
||||
return Column(
|
||||
crossAxisAlignment: isUser
|
||||
? CrossAxisAlignment.end
|
||||
: CrossAxisAlignment.start,
|
||||
children: [
|
||||
Flexible(
|
||||
child: Container(
|
||||
padding: const EdgeInsets.symmetric(
|
||||
horizontal: _messagePaddingH,
|
||||
vertical: _messagePaddingV,
|
||||
),
|
||||
decoration: BoxDecoration(
|
||||
color: isUser ? _userBubbleColor : AppColors.white,
|
||||
borderRadius: BorderRadius.only(
|
||||
topLeft: const Radius.circular(_cornerRadius),
|
||||
topRight: const Radius.circular(_cornerRadius),
|
||||
bottomLeft: Radius.circular(isUser ? _cornerRadius : 0),
|
||||
bottomRight: Radius.circular(isUser ? 0 : _cornerRadius),
|
||||
Row(
|
||||
mainAxisAlignment: isUser
|
||||
? MainAxisAlignment.end
|
||||
: MainAxisAlignment.start,
|
||||
crossAxisAlignment: CrossAxisAlignment.start,
|
||||
children: [
|
||||
Flexible(
|
||||
child: Container(
|
||||
padding: const EdgeInsets.symmetric(
|
||||
horizontal: _messagePaddingH,
|
||||
vertical: _messagePaddingV,
|
||||
),
|
||||
decoration: BoxDecoration(
|
||||
color: isUser ? _userBubbleColor : AppColors.white,
|
||||
borderRadius: BorderRadius.only(
|
||||
topLeft: const Radius.circular(_cornerRadius),
|
||||
topRight: const Radius.circular(_cornerRadius),
|
||||
bottomLeft: Radius.circular(isUser ? _cornerRadius : 0),
|
||||
bottomRight: Radius.circular(isUser ? 0 : _cornerRadius),
|
||||
),
|
||||
border: isUser ? null : Border.all(color: AppColors.slate300),
|
||||
),
|
||||
child: Text(
|
||||
item.content,
|
||||
style: const TextStyle(
|
||||
fontSize: 14,
|
||||
color: AppColors.slate900,
|
||||
),
|
||||
),
|
||||
),
|
||||
border: isUser ? null : Border.all(color: AppColors.slate300),
|
||||
),
|
||||
child: Text(
|
||||
item.content,
|
||||
style: const TextStyle(fontSize: 14, color: AppColors.slate900),
|
||||
if (item.attachments.isNotEmpty && !hasRenderableAttachments) ...[
|
||||
const SizedBox(width: _itemSpacing / 2),
|
||||
_buildAttachmentBadge(item.attachments.length),
|
||||
],
|
||||
],
|
||||
),
|
||||
if (hasRenderableAttachments)
|
||||
Padding(
|
||||
padding: const EdgeInsets.only(top: _attachmentPreviewGap),
|
||||
child: _buildHistoryAttachmentPreviews(
|
||||
item.attachments,
|
||||
imageAttachments: imageAttachments,
|
||||
),
|
||||
),
|
||||
),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
Widget _buildHistoryAttachmentPreviews(
|
||||
List<Map<String, dynamic>> attachments, {
|
||||
List<Map<String, dynamic>>? imageAttachments,
|
||||
}) {
|
||||
final renderableAttachments =
|
||||
imageAttachments ?? _collectRenderableImageAttachments(attachments);
|
||||
if (renderableAttachments.isEmpty) {
|
||||
return _buildAttachmentBadge(attachments.length);
|
||||
}
|
||||
return Wrap(
|
||||
spacing: _attachmentPreviewGap,
|
||||
runSpacing: _attachmentPreviewGap,
|
||||
crossAxisAlignment: WrapCrossAlignment.start,
|
||||
children: renderableAttachments.map(_buildHistoryAttachmentTile).toList(),
|
||||
);
|
||||
}
|
||||
|
||||
List<Map<String, dynamic>> _collectRenderableImageAttachments(
|
||||
List<Map<String, dynamic>> attachments,
|
||||
) {
|
||||
return attachments.where(_isRenderableImageAttachment).toList();
|
||||
}
|
||||
|
||||
bool _isRenderableImageAttachment(Map<String, dynamic> attachment) {
|
||||
final mimeType = attachment['mimeType'];
|
||||
final previewPath = attachment['previewPath'];
|
||||
return mimeType is String &&
|
||||
mimeType.startsWith('image/') &&
|
||||
previewPath is String &&
|
||||
previewPath.isNotEmpty;
|
||||
}
|
||||
|
||||
Widget _buildHistoryAttachmentTile(Map<String, dynamic> attachment) {
|
||||
final previewPath = attachment['previewPath'];
|
||||
if (previewPath is! String || previewPath.isEmpty) {
|
||||
return _buildAttachmentBadge(1);
|
||||
}
|
||||
return ClipRRect(
|
||||
borderRadius: BorderRadius.circular(_attachmentPreviewRadius),
|
||||
child: Container(
|
||||
width: _attachmentPreviewSize,
|
||||
height: _attachmentPreviewSize,
|
||||
color: AppColors.slate100,
|
||||
child: FutureBuilder<Uint8List?>(
|
||||
future: _chatBloc.loadAttachmentPreview(previewPath),
|
||||
builder: (context, snapshot) {
|
||||
if (snapshot.connectionState == ConnectionState.waiting) {
|
||||
return const Center(
|
||||
child: SizedBox(
|
||||
width: _transcribingSpinnerSize,
|
||||
height: _transcribingSpinnerSize,
|
||||
child: CircularProgressIndicator(
|
||||
strokeWidth: _transcribingStrokeWidth,
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
final data = snapshot.data;
|
||||
if (data == null || data.isEmpty) {
|
||||
return const Center(
|
||||
child: Icon(
|
||||
LucideIcons.imageOff,
|
||||
size: _iconSize,
|
||||
color: AppColors.slate500,
|
||||
),
|
||||
);
|
||||
}
|
||||
return Image.memory(data, fit: BoxFit.cover, gaplessPlayback: true);
|
||||
},
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
Widget _buildAttachmentBadge(int count) {
|
||||
return Container(
|
||||
padding: const EdgeInsets.symmetric(horizontal: 8, vertical: 4),
|
||||
decoration: BoxDecoration(
|
||||
color: AppColors.slate200,
|
||||
borderRadius: BorderRadius.circular(8),
|
||||
),
|
||||
child: Text(
|
||||
'图片附件 x$count',
|
||||
style: const TextStyle(fontSize: 12, color: AppColors.slate600),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
Widget _buildToolCallItem(ToolCallItem item) {
|
||||
final (statusText, statusColor, statusIcon) = switch (item.status) {
|
||||
ToolCallStatus.pending => (
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import 'dart:convert';
|
||||
import 'dart:typed_data';
|
||||
|
||||
import 'package:flutter_test/flutter_test.dart';
|
||||
import 'package:image_picker/image_picker.dart';
|
||||
@@ -265,6 +266,71 @@ void main() {
|
||||
expect(first['content'], '只发送当前输入');
|
||||
});
|
||||
|
||||
test('sendMessage uploads images then posts binary url blocks', () async {
|
||||
final client = MockApiClient();
|
||||
final service = AgUiService(onEvent: (_) {}, apiClient: client);
|
||||
client.clearMocks();
|
||||
|
||||
var uploadCalls = 0;
|
||||
final uploadedPath = 'agent-inputs/user/thread-1/upload-1.png';
|
||||
client.registerHandler('/api/v1/agent/attachments', 'POST', (request) {
|
||||
uploadCalls += 1;
|
||||
return {
|
||||
'attachment': {
|
||||
'bucket': 'bucket-test',
|
||||
'path': uploadedPath,
|
||||
'mimeType': 'image/png',
|
||||
'url': 'https://signed.example/$uploadedPath',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
Map<String, dynamic>? postedRunInput;
|
||||
client.registerHandler('/api/v1/agent/runs', 'POST', (request) {
|
||||
postedRunInput = request.data as Map<String, dynamic>;
|
||||
return {
|
||||
'taskId': 'task-1',
|
||||
'threadId': 'thread-1',
|
||||
'runId': 'run-1',
|
||||
'created': false,
|
||||
};
|
||||
});
|
||||
client.registerHandler('/api/v1/agent/runs/thread-1/events', 'SSE', (_) {
|
||||
return <String>[
|
||||
'event: RUN_STARTED',
|
||||
'data: {"type":"RUN_STARTED","threadId":"thread-1","runId":"run-1"}',
|
||||
'',
|
||||
'event: RUN_FINISHED',
|
||||
'data: {"type":"RUN_FINISHED","threadId":"thread-1","runId":"run-1"}',
|
||||
'',
|
||||
];
|
||||
});
|
||||
|
||||
final image = XFile.fromData(
|
||||
Uint8List.fromList(<int>[1, 2, 3]),
|
||||
mimeType: 'image/png',
|
||||
name: 'demo.png',
|
||||
);
|
||||
|
||||
await service.sendMessage('图文消息', images: [image]);
|
||||
|
||||
expect(uploadCalls, 1);
|
||||
expect(postedRunInput, isNotNull);
|
||||
final messages = postedRunInput!['messages'] as List<dynamic>;
|
||||
final first = messages.first as Map<String, dynamic>;
|
||||
final content = first['content'] as List<dynamic>;
|
||||
expect((content.first as Map<String, dynamic>)['type'], 'text');
|
||||
expect((content[1] as Map<String, dynamic>)['type'], 'binary');
|
||||
expect(
|
||||
(content[1] as Map<String, dynamic>)['url'],
|
||||
'https://signed.example/$uploadedPath',
|
||||
);
|
||||
final forwardedProps =
|
||||
postedRunInput!['forwardedProps'] as Map<String, dynamic>;
|
||||
final attachments = forwardedProps['attachments'] as List<dynamic>;
|
||||
expect((attachments.first as Map<String, dynamic>)['path'], uploadedPath);
|
||||
});
|
||||
|
||||
test('approveToolCall posts only tool message to resume API', () async {
|
||||
final client = MockApiClient();
|
||||
final service = AgUiService(onEvent: (_) {}, apiClient: client);
|
||||
@@ -482,5 +548,56 @@ void main() {
|
||||
expect(seenLastEventIds[0], isNull);
|
||||
expect(seenLastEventIds[1], '2-0');
|
||||
});
|
||||
|
||||
test('stream parses backend TOOL_CALL_RESULT payload with ui field', () async {
|
||||
final events = <AgUiEvent>[];
|
||||
final client = MockApiClient();
|
||||
final service = AgUiService(onEvent: events.add, apiClient: client);
|
||||
client.clearMocks();
|
||||
client.registerHandler('/api/v1/agent/runs', 'POST', (_) {
|
||||
return {
|
||||
'taskId': 'task-1',
|
||||
'threadId': 'thread-1',
|
||||
'runId': 'run-1',
|
||||
'created': false,
|
||||
};
|
||||
});
|
||||
client.registerHandler('/api/v1/agent/runs/thread-1/events', 'SSE', (_) {
|
||||
return <String>[
|
||||
'event: RUN_STARTED',
|
||||
'data: {"type":"RUN_STARTED","threadId":"thread-1","runId":"run-1"}',
|
||||
'',
|
||||
'event: TOOL_CALL_RESULT',
|
||||
'data: {"type":"TOOL_CALL_RESULT","messageId":"tool-result-1","toolCallId":"call-1","callId":"call-1","toolName":"calendar_write","result":{"type":"calendar_operation.v1","version":"v1","data":{"ok":true,"operation":"create"},"actions":[]},"ui":{"type":"calendar_operation.v1","version":"v1","data":{"ok":true,"operation":"create"},"actions":[]},"content":"已创建日程:项目评审(明天 10:00)"}',
|
||||
'',
|
||||
'event: RUN_FINISHED',
|
||||
'data: {"type":"RUN_FINISHED","threadId":"thread-1","runId":"run-1"}',
|
||||
'',
|
||||
];
|
||||
});
|
||||
|
||||
await service.sendMessage('创建日程');
|
||||
|
||||
final result = events.whereType<ToolCallResultEvent>().toList();
|
||||
expect(result.length, 1);
|
||||
expect(result.first.ui?.cardType, 'calendar_operation.v1');
|
||||
});
|
||||
|
||||
test('fetchAttachmentPreview returns binary bytes', () async {
|
||||
final client = MockApiClient();
|
||||
final service = AgUiService(onEvent: (_) {}, apiClient: client);
|
||||
client.clearMocks();
|
||||
client.registerHandler(
|
||||
'/api/v1/agent/runs/t1/attachments/m1/0',
|
||||
'GET',
|
||||
(_) => <int>[1, 2, 3, 4],
|
||||
);
|
||||
|
||||
final data = await service.fetchAttachmentPreview(
|
||||
'/api/v1/agent/runs/t1/attachments/m1/0',
|
||||
);
|
||||
|
||||
expect(data, [1, 2, 3, 4]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import 'dart:typed_data';
|
||||
|
||||
import 'package:bloc_test/bloc_test.dart';
|
||||
import 'package:flutter_test/flutter_test.dart';
|
||||
import 'package:image_picker/image_picker.dart';
|
||||
@@ -9,8 +11,17 @@ import 'package:social_app/features/chat/presentation/bloc/chat_bloc.dart';
|
||||
class MockAgUiService extends AgUiService {
|
||||
MockAgUiService() : super(onEvent: (_) {});
|
||||
|
||||
int previewCalls = 0;
|
||||
|
||||
@override
|
||||
Future<void> sendMessage(String content, {List<XFile>? images}) async {}
|
||||
|
||||
@override
|
||||
Future<Uint8List> fetchAttachmentPreview(String previewPath) async {
|
||||
previewCalls += 1;
|
||||
await Future<void>.delayed(const Duration(milliseconds: 10));
|
||||
return Uint8List.fromList(<int>[1, 2, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
class _ThrowingAgUiService extends AgUiService {
|
||||
@@ -182,6 +193,23 @@ void main() {
|
||||
],
|
||||
);
|
||||
|
||||
blocTest<ChatBloc, ChatState>(
|
||||
'step events update currentStage',
|
||||
build: () => chatBloc,
|
||||
act: (bloc) {
|
||||
service.onEvent(StepStartedEvent(stepName: 'execution'));
|
||||
service.onEvent(StepFinishedEvent(stepName: 'execution'));
|
||||
},
|
||||
expect: () => [
|
||||
isA<ChatState>().having(
|
||||
(s) => s.currentStage,
|
||||
'currentStage',
|
||||
AgentStage.execution,
|
||||
),
|
||||
isA<ChatState>().having((s) => s.currentStage, 'currentStage', isNull),
|
||||
],
|
||||
);
|
||||
|
||||
blocTest<ChatBloc, ChatState>(
|
||||
'runError sets error message',
|
||||
build: () => chatBloc,
|
||||
@@ -325,5 +353,58 @@ void main() {
|
||||
),
|
||||
],
|
||||
);
|
||||
|
||||
blocTest<ChatBloc, ChatState>(
|
||||
'state snapshot user message keeps attachments',
|
||||
build: () => chatBloc,
|
||||
act: (bloc) {
|
||||
service.onEvent(
|
||||
StateSnapshotEvent(
|
||||
snapshot: {
|
||||
'scope': 'history_day',
|
||||
'messages': [
|
||||
{
|
||||
'id': 'u1',
|
||||
'role': 'user',
|
||||
'content': '请分析这张图',
|
||||
'attachments': [
|
||||
{'bucket': 'b', 'path': 'p', 'mimeType': 'image/png'},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
),
|
||||
);
|
||||
},
|
||||
expect: () => [
|
||||
isA<ChatState>().having(
|
||||
(s) {
|
||||
final item = s.items.first;
|
||||
return item is TextMessageItem && item.attachments.length == 1;
|
||||
},
|
||||
'user attachment count',
|
||||
true,
|
||||
),
|
||||
],
|
||||
);
|
||||
|
||||
test(
|
||||
'loadAttachmentPreview deduplicates in-flight and caches result',
|
||||
() async {
|
||||
final mock = service as MockAgUiService;
|
||||
final results = await Future.wait<Uint8List?>([
|
||||
chatBloc.loadAttachmentPreview('/api/preview/1'),
|
||||
chatBloc.loadAttachmentPreview('/api/preview/1'),
|
||||
]);
|
||||
final secondRound = await chatBloc.loadAttachmentPreview(
|
||||
'/api/preview/1',
|
||||
);
|
||||
|
||||
expect(results.first, isNotNull);
|
||||
expect(results.last, isNotNull);
|
||||
expect(secondRound, isNotNull);
|
||||
expect(mock.previewCalls, 1);
|
||||
},
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -6,11 +6,14 @@ import 'package:flutter_test/flutter_test.dart';
|
||||
import 'package:image_picker/image_picker.dart';
|
||||
import 'package:lucide_icons/lucide_icons.dart';
|
||||
import 'package:social_app/core/api/api_exception.dart';
|
||||
import 'package:social_app/core/api/mock_api_client.dart';
|
||||
import 'package:social_app/core/di/injection.dart';
|
||||
import 'package:social_app/features/chat/data/models/ag_ui_event.dart';
|
||||
import 'package:social_app/features/chat/data/services/ag_ui_service.dart';
|
||||
import 'package:social_app/features/chat/presentation/bloc/chat_bloc.dart';
|
||||
import 'package:social_app/features/home/data/voice_recorder.dart';
|
||||
import 'package:social_app/features/home/ui/screens/home_screen.dart';
|
||||
import 'package:social_app/features/messages/data/inbox_api.dart';
|
||||
|
||||
class _FakeVoiceRecorder implements VoiceRecorder {
|
||||
bool started = false;
|
||||
@@ -43,9 +46,19 @@ class _WaitingAgUiService extends AgUiService {
|
||||
onEvent(RunStartedEvent(threadId: 't1', runId: 'r1'));
|
||||
return _pending.future;
|
||||
}
|
||||
|
||||
void emitStepStarted(String stepName) {
|
||||
onEvent(StepStartedEvent(stepName: stepName));
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
setUpAll(() {
|
||||
if (!sl.isRegistered<InboxApi>()) {
|
||||
sl.registerSingleton<InboxApi>(InboxApi(MockApiClient()));
|
||||
}
|
||||
});
|
||||
|
||||
IconData _inputActionIcon(WidgetTester tester) {
|
||||
final icon = tester.widget<Icon>(
|
||||
find.byKey(const ValueKey('home_input_action_icon')),
|
||||
@@ -275,7 +288,8 @@ void main() {
|
||||
testWidgets('shows stop icon and waiting indicator while waiting agent', (
|
||||
WidgetTester tester,
|
||||
) async {
|
||||
final chatBloc = ChatBloc(service: _WaitingAgUiService());
|
||||
final waitingService = _WaitingAgUiService();
|
||||
final chatBloc = ChatBloc(service: waitingService);
|
||||
await tester.pumpWidget(
|
||||
MaterialApp(
|
||||
home: HomeScreen(autoLoadHistory: false, chatBloc: chatBloc),
|
||||
@@ -291,6 +305,12 @@ void main() {
|
||||
expect(_inputActionIcon(tester), LucideIcons.square);
|
||||
expect(find.text('正在思考...'), findsOneWidget);
|
||||
|
||||
waitingService.emitStepStarted('intent');
|
||||
await tester.pump();
|
||||
|
||||
expect(find.text('意图识别中'), findsOneWidget);
|
||||
expect(find.text('正在思考...'), findsNothing);
|
||||
|
||||
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
|
||||
await tester.pump();
|
||||
|
||||
|
||||
@@ -37,6 +37,21 @@ def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
if event_type == "tool.result":
|
||||
for key in (
|
||||
"messageId",
|
||||
"toolCallId",
|
||||
"callId",
|
||||
"toolName",
|
||||
"stage",
|
||||
"taskId",
|
||||
"ui",
|
||||
"content",
|
||||
):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
payload[key] = value
|
||||
return payload
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
data_map = cast(dict[str, Any], data)
|
||||
payload.update({k: v for k, v in data_map.items() if k not in reserved})
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from core.agentscope.events.tool_result_summary import build_tool_content_summary
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
|
||||
@@ -14,6 +17,16 @@ class EventStore(Protocol):
|
||||
async def persist(self, event: dict[str, Any]) -> None: ...
|
||||
|
||||
|
||||
class ToolResultStorageLike(Protocol):
|
||||
async def upload_json(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
payload: dict[str, object],
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class NullEventStore:
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
del event
|
||||
@@ -21,9 +34,20 @@ class NullEventStore:
|
||||
|
||||
class SqlAlchemyEventStore:
|
||||
_session_factory: Callable[[], Any]
|
||||
_tool_result_storage: ToolResultStorageLike | None
|
||||
_tool_result_bucket: str | None
|
||||
_logger = get_logger("core.agentscope.events.store")
|
||||
|
||||
def __init__(self, *, session_factory: Any) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_factory: Any,
|
||||
tool_result_storage: ToolResultStorageLike | None = None,
|
||||
tool_result_bucket: str | None = None,
|
||||
) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._tool_result_storage = tool_result_storage
|
||||
self._tool_result_bucket = tool_result_bucket
|
||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
@@ -228,23 +252,89 @@ class SqlAlchemyEventStore:
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
payload = {
|
||||
"args": event.get("args"),
|
||||
"result": event.get("result"),
|
||||
"error": event.get("error"),
|
||||
"call_id": event.get("callId"),
|
||||
}
|
||||
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||
metadata: dict[str, object] = {"tool_name": tool_name}
|
||||
run_id = event.get("runId")
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
|
||||
task_id = event.get("taskId")
|
||||
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
|
||||
call_id_value = event.get("callId")
|
||||
if not isinstance(call_id_value, str) or not call_id_value:
|
||||
call_id_value = (
|
||||
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
|
||||
if run_id_value
|
||||
else f"{task_id_value}-{uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
summary = build_tool_content_summary(
|
||||
tool_name=tool_name,
|
||||
args=event.get("args") if isinstance(event.get("args"), dict) else None,
|
||||
result=event.get("result"),
|
||||
error=event.get("error"),
|
||||
)
|
||||
|
||||
raw_result_value = event.get("result")
|
||||
raw_result: dict[str, object] = (
|
||||
raw_result_value if isinstance(raw_result_value, dict) else {}
|
||||
)
|
||||
ui_candidate = raw_result.get("ui")
|
||||
ui_schema = ui_candidate if isinstance(ui_candidate, dict) else None
|
||||
result_type = raw_result.get("type")
|
||||
result_data = raw_result.get("data")
|
||||
if (
|
||||
ui_schema is None
|
||||
and isinstance(result_type, str)
|
||||
and isinstance(result_data, dict)
|
||||
):
|
||||
ui_schema = raw_result
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"toolName": tool_name,
|
||||
"ui_schema": ui_schema,
|
||||
"result": _sanitize_result(raw_result),
|
||||
"error": _sanitize_error(event.get("error")),
|
||||
"callId": call_id_value,
|
||||
"runId": run_id_value,
|
||||
"taskId": task_id_value,
|
||||
"content": summary,
|
||||
}
|
||||
|
||||
metadata: dict[str, object] = {
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": call_id_value,
|
||||
"summary_version": "v1",
|
||||
}
|
||||
if run_id_value:
|
||||
metadata["run_id"] = run_id_value
|
||||
stage = event.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
task_id = event.get("taskId")
|
||||
if isinstance(task_id, str) and task_id:
|
||||
metadata["task_id"] = task_id
|
||||
if task_id_value:
|
||||
metadata["task_id"] = task_id_value
|
||||
|
||||
if self._tool_result_storage is not None and self._tool_result_bucket:
|
||||
safe_run = _sanitize_path_component(run_id_value or "run")
|
||||
safe_call = _sanitize_path_component(call_id_value)
|
||||
storage_path = f"tool-results/{session_id}/{safe_run}/{safe_call}.json"
|
||||
try:
|
||||
await self._tool_result_storage.upload_json(
|
||||
bucket=self._tool_result_bucket,
|
||||
path=storage_path,
|
||||
payload=payload,
|
||||
)
|
||||
metadata["storage_bucket"] = self._tool_result_bucket
|
||||
metadata["storage_path"] = storage_path
|
||||
except Exception: # noqa: BLE001
|
||||
metadata["storage_upload_failed"] = True
|
||||
self._logger.warning(
|
||||
"tool result storage upload failed",
|
||||
session_id=str(session_id),
|
||||
run_id=run_id_value,
|
||||
call_id=call_id_value,
|
||||
storage_path=storage_path,
|
||||
)
|
||||
|
||||
content = summary or json.dumps(
|
||||
payload, ensure_ascii=False, separators=(",", ":")
|
||||
)
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
@@ -333,3 +423,69 @@ class SqlAlchemyEventStore:
|
||||
except (InvalidOperation, TypeError, ValueError):
|
||||
return Decimal("0")
|
||||
return parsed if parsed >= 0 else Decimal("0")
|
||||
|
||||
|
||||
def _sanitize_path_component(value: str) -> str:
|
||||
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
|
||||
compact = compact.strip(".-")
|
||||
return compact or "id"
|
||||
|
||||
|
||||
def _sanitize_error(value: object) -> str | None:
|
||||
if isinstance(value, str) and value.strip():
|
||||
return " ".join(value.split())[:300]
|
||||
if isinstance(value, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
item = value.get(key)
|
||||
if isinstance(item, str) and item.strip():
|
||||
return " ".join(item.split())[:300]
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_result(value: object) -> dict[str, object]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
normalized = key.strip().lower().replace("-", "_")
|
||||
if not normalized:
|
||||
return False
|
||||
exact = {
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"credential",
|
||||
"authorization",
|
||||
"auth",
|
||||
}
|
||||
if normalized in exact:
|
||||
return True
|
||||
patterns = (
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"auth",
|
||||
"credential",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"authorization",
|
||||
)
|
||||
return any(pattern in normalized for pattern in patterns)
|
||||
|
||||
def _sanitize_value(item: object) -> object:
|
||||
if isinstance(item, dict):
|
||||
return _sanitize_result(item)
|
||||
if isinstance(item, list):
|
||||
return [_sanitize_value(entry) for entry in item]
|
||||
return item
|
||||
|
||||
sanitized: dict[str, object] = {}
|
||||
for key, item in value.items():
|
||||
key_text = str(key)
|
||||
if _is_sensitive_key(key_text):
|
||||
sanitized[str(key)] = "[REDACTED]"
|
||||
continue
|
||||
sanitized[str(key)] = _sanitize_value(item)
|
||||
return sanitized
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_tool_content_summary(
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any] | None,
|
||||
result: Any,
|
||||
error: Any,
|
||||
) -> str:
|
||||
error_message = _extract_error_message(error)
|
||||
if error_message is not None:
|
||||
return _truncate(f"{tool_name} 执行失败:{error_message}")
|
||||
|
||||
normalized_args = args if isinstance(args, dict) else {}
|
||||
normalized_result = result if isinstance(result, dict) else {}
|
||||
|
||||
business_failure = _extract_business_failure_message(normalized_result)
|
||||
if business_failure is not None:
|
||||
return _truncate(f"{tool_name} 执行失败:{business_failure}")
|
||||
|
||||
if tool_name == "calendar_write":
|
||||
title = _pick_first_str(normalized_result, ("title",)) or _pick_first_str(
|
||||
normalized_args, ("title",)
|
||||
)
|
||||
start_at = _pick_first_str(normalized_result, ("startAt", "start_at"))
|
||||
if title and start_at:
|
||||
return _truncate(f"已创建日程:{title}({start_at})")
|
||||
if title:
|
||||
return _truncate(f"已创建日程:{title}")
|
||||
|
||||
if tool_name == "calendar_read":
|
||||
total = _extract_total(normalized_result)
|
||||
query = _pick_first_str(normalized_args, ("query",)) or "全部"
|
||||
if total is not None:
|
||||
return _truncate(f"查询到 {total} 条日程({query})")
|
||||
|
||||
if tool_name == "calendar_delete":
|
||||
target = _pick_first_str(normalized_result, ("title", "eventId", "event_id"))
|
||||
if target:
|
||||
return _truncate(f"已删除日程:{target}")
|
||||
|
||||
if tool_name == "calendar_share":
|
||||
target = _pick_first_str(normalized_result, ("target", "user", "userName"))
|
||||
if target:
|
||||
return _truncate(f"已分享日程给 {target}")
|
||||
|
||||
if tool_name == "user_resolve":
|
||||
target = _pick_first_str(normalized_result, ("name", "userName", "userId"))
|
||||
if target:
|
||||
return _truncate(f"已匹配用户:{target}")
|
||||
|
||||
result_content = _pick_first_str(normalized_result, ("content", "message"))
|
||||
if result_content:
|
||||
return _truncate(result_content)
|
||||
|
||||
return _truncate(f"{tool_name} 执行完成")
|
||||
|
||||
|
||||
def _extract_error_message(error: Any) -> str | None:
|
||||
if isinstance(error, str) and error.strip():
|
||||
return error.strip()
|
||||
if isinstance(error, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
value = error.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _pick_first_str(payload: dict[str, Any], keys: tuple[str, ...]) -> str | None:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str):
|
||||
normalized = " ".join(value.split())
|
||||
if normalized:
|
||||
return normalized
|
||||
return None
|
||||
|
||||
|
||||
def _extract_total(result: dict[str, Any]) -> int | None:
|
||||
candidates: list[Any] = [result.get("total")]
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict):
|
||||
candidates.append(data.get("total"))
|
||||
events = data.get("events")
|
||||
if isinstance(events, list):
|
||||
candidates.append(len(events))
|
||||
for value in candidates:
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int) and value >= 0:
|
||||
return value
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
return int(value)
|
||||
return None
|
||||
|
||||
|
||||
def _extract_business_failure_message(result: dict[str, Any]) -> str | None:
|
||||
top_ok = result.get("ok")
|
||||
if top_ok is False:
|
||||
top_message = _pick_first_str(result, ("message", "error", "detail"))
|
||||
if top_message:
|
||||
return top_message
|
||||
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict) and data.get("ok") is False:
|
||||
data_message = _pick_first_str(data, ("message", "error", "detail"))
|
||||
if data_message:
|
||||
return data_message
|
||||
code = _pick_first_str(data, ("code",))
|
||||
if code:
|
||||
return code
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _truncate(text: str, limit: int = 80) -> str:
|
||||
normalized = " ".join(text.split())
|
||||
if len(normalized) <= limit:
|
||||
return normalized
|
||||
return normalized[: limit - 3] + "..."
|
||||
@@ -42,11 +42,19 @@ def build_intent_user_prompt(
|
||||
*, user_input: str | list[dict[str, Any]]
|
||||
) -> str | list[dict[str, Any]]:
|
||||
if isinstance(user_input, list):
|
||||
context_messages = _conversation_context_messages(user_input)
|
||||
context_hint = (
|
||||
json.dumps(context_messages, ensure_ascii=True, separators=(",", ":"))
|
||||
if context_messages
|
||||
else "[]"
|
||||
)
|
||||
instruction_text = "\n\n".join(
|
||||
[
|
||||
INTENT_TASK_INSTRUCTION,
|
||||
"[Output Schema]",
|
||||
_schema_json(IntentOutput),
|
||||
"[Conversation Context]",
|
||||
context_hint,
|
||||
"[User Input]",
|
||||
"Use the following multimodal blocks as the latest user input.",
|
||||
]
|
||||
@@ -127,6 +135,56 @@ def _latest_user_content_blocks(
|
||||
return []
|
||||
|
||||
|
||||
def _conversation_context_messages(
|
||||
user_input: list[dict[str, Any]],
|
||||
) -> list[dict[str, str]]:
|
||||
latest_user_index = -1
|
||||
for index in range(len(user_input) - 1, -1, -1):
|
||||
item = user_input[index]
|
||||
if isinstance(item, dict) and item.get("role") == "user":
|
||||
latest_user_index = index
|
||||
break
|
||||
|
||||
if latest_user_index <= 0:
|
||||
return []
|
||||
|
||||
context_items: list[dict[str, str]] = []
|
||||
for item in user_input[:latest_user_index]:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
role = item.get("role")
|
||||
if role not in {"user", "assistant"}:
|
||||
continue
|
||||
content = item.get("content")
|
||||
text = _content_to_text(content)
|
||||
if text:
|
||||
context_items.append({"role": str(role), "content": text})
|
||||
|
||||
if len(context_items) <= 12:
|
||||
return context_items
|
||||
return context_items[-12:]
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return " ".join(content.split())
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
text = block.get("text")
|
||||
if isinstance(text, str) and text.strip():
|
||||
parts.append(" ".join(text.split()))
|
||||
elif block_type in {"binary", "image"}:
|
||||
parts.append("[image]")
|
||||
return " ".join(parts).strip()
|
||||
|
||||
|
||||
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
|
||||
mime_type = item.get("mimeType")
|
||||
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
import structlog
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
@@ -31,7 +33,9 @@ class PipelineLike(Protocol):
|
||||
class AgentRouteRuntime:
|
||||
_orchestrator: OrchestratorLike
|
||||
_pipeline: PipelineLike
|
||||
_logger = get_logger("core.agentscope.runtime.agent_route_runtime")
|
||||
_logger: structlog.stdlib.BoundLogger = get_logger(
|
||||
"core.agentscope.runtime.agent_route_runtime"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, *, orchestrator: OrchestratorLike, pipeline: PipelineLike
|
||||
@@ -144,15 +148,6 @@ class AgentRouteRuntime:
|
||||
"data": {"stepName": "execution"},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.finish",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "execution"},
|
||||
},
|
||||
)
|
||||
|
||||
await self._emit_stage_text(
|
||||
thread_id=command.thread_id,
|
||||
@@ -191,6 +186,15 @@ class AgentRouteRuntime:
|
||||
task_id=task.task_id,
|
||||
tool_calls=_task_tool_calls(task),
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.finish",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "execution"},
|
||||
},
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
@@ -294,6 +298,13 @@ class AgentRouteRuntime:
|
||||
tool_name = tool_call.get("tool_name")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
continue
|
||||
call_id = f"{run_id}-{task_id}-{index}"
|
||||
result_payload = _build_tool_result_event_payload(
|
||||
tool_name=tool_name,
|
||||
call_id=call_id,
|
||||
raw_result=tool_call.get("result"),
|
||||
raw_error=tool_call.get("error"),
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
@@ -301,18 +312,175 @@ class AgentRouteRuntime:
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"callId": f"{run_id}-{task_id}-{index}",
|
||||
"messageId": result_payload["messageId"],
|
||||
"toolCallId": call_id,
|
||||
"callId": call_id,
|
||||
"stage": "execution",
|
||||
"taskId": task_id,
|
||||
"toolName": tool_name,
|
||||
"args": tool_call.get("args", {}),
|
||||
"result": tool_call.get("result"),
|
||||
"error": tool_call.get("error"),
|
||||
"args": _sanitize_result(tool_call.get("args", {})),
|
||||
"result": result_payload["result"],
|
||||
"error": result_payload["error"],
|
||||
"ui": result_payload["ui"],
|
||||
"content": result_payload["content"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_tool_result_event_payload(
|
||||
*,
|
||||
tool_name: str,
|
||||
call_id: str,
|
||||
raw_result: Any,
|
||||
raw_error: Any,
|
||||
) -> dict[str, Any]:
|
||||
result = _sanitize_result(_normalize_tool_result(raw_result))
|
||||
error = _sanitize_error(raw_error)
|
||||
|
||||
ui: dict[str, Any] | None = None
|
||||
direct_ui = result.get("ui")
|
||||
if isinstance(direct_ui, dict):
|
||||
ui = direct_ui
|
||||
elif isinstance(result.get("type"), str) and isinstance(result.get("data"), dict):
|
||||
ui = result
|
||||
|
||||
text_content = _extract_result_text_content(result)
|
||||
if text_content is None and isinstance(error, str):
|
||||
text_content = error
|
||||
if text_content is None:
|
||||
text_content = f"{tool_name} 执行完成"
|
||||
|
||||
return {
|
||||
"messageId": f"tool-result-{call_id}",
|
||||
"result": result,
|
||||
"error": error,
|
||||
"ui": ui,
|
||||
"content": text_content,
|
||||
}
|
||||
|
||||
|
||||
def _normalize_tool_result(raw_result: Any) -> dict[str, Any]:
|
||||
if isinstance(raw_result, dict):
|
||||
content = raw_result.get("content")
|
||||
if isinstance(content, str):
|
||||
parsed = _try_parse_json_object(content)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
return raw_result
|
||||
if isinstance(raw_result, str):
|
||||
parsed = _try_parse_json_object(raw_result)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
text = raw_result.strip()
|
||||
if text:
|
||||
return {"content": text}
|
||||
if raw_result is not None:
|
||||
return {"value": raw_result}
|
||||
return {}
|
||||
|
||||
|
||||
def _try_parse_json_object(value: str) -> dict[str, Any] | None:
|
||||
raw = value.strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except ValueError:
|
||||
return None
|
||||
if not isinstance(parsed, dict):
|
||||
return None
|
||||
return parsed
|
||||
|
||||
|
||||
def _extract_result_text_content(result: dict[str, Any]) -> str | None:
|
||||
content = result.get("content")
|
||||
if isinstance(content, str) and content.strip():
|
||||
return content
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict):
|
||||
message = data.get("message")
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_error(value: Any) -> str | None:
|
||||
if isinstance(value, str) and value.strip():
|
||||
text = " ".join(value.split())
|
||||
return _redact_sensitive_text(text)[:300]
|
||||
if isinstance(value, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
item = value.get(key)
|
||||
if isinstance(item, str) and item.strip():
|
||||
text = " ".join(item.split())
|
||||
return _redact_sensitive_text(text)[:300]
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_result(value: Any) -> dict[str, Any]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
normalized = key.strip().lower().replace("-", "_")
|
||||
if not normalized:
|
||||
return False
|
||||
exact = {
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"credential",
|
||||
"authorization",
|
||||
"auth",
|
||||
}
|
||||
if normalized in exact:
|
||||
return True
|
||||
patterns = (
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"credential",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"authorization",
|
||||
)
|
||||
return any(pattern in normalized for pattern in patterns)
|
||||
|
||||
def _sanitize_value(item: Any) -> Any:
|
||||
if isinstance(item, dict):
|
||||
return _sanitize_result(item)
|
||||
if isinstance(item, list):
|
||||
return [_sanitize_value(entry) for entry in item]
|
||||
if isinstance(item, str):
|
||||
return _redact_sensitive_text(item)
|
||||
return item
|
||||
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key, item in value.items():
|
||||
key_text = str(key)
|
||||
if _is_sensitive_key(key_text):
|
||||
sanitized[key_text] = "[REDACTED]"
|
||||
continue
|
||||
sanitized[key_text] = _sanitize_value(item)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _redact_sensitive_text(value: str) -> str:
|
||||
redacted = value
|
||||
key_value_patterns = (
|
||||
r"(?i)(authorization)\s*[:=]\s*bearer\s+[^\s,;]+",
|
||||
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s*[:=]\s*[^\s,;]+",
|
||||
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s+[^\s,;]+",
|
||||
)
|
||||
for pattern in key_value_patterns:
|
||||
redacted = re.sub(pattern, r"\1=[REDACTED]", redacted)
|
||||
redacted = re.sub(r"(?i)bearer\s+[^\s,;]+", "Bearer [REDACTED]", redacted)
|
||||
return redacted
|
||||
|
||||
|
||||
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {}
|
||||
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
|
||||
|
||||
@@ -52,6 +52,24 @@ def _tools_payload_from_schema(
|
||||
return payload
|
||||
|
||||
|
||||
def _merge_tool_schemas(
|
||||
*schema_sets: list[dict[str, object]],
|
||||
) -> list[dict[str, object]]:
|
||||
merged: list[dict[str, object]] = []
|
||||
seen_names: set[str] = set()
|
||||
for schemas in schema_sets:
|
||||
for schema in schemas:
|
||||
function = schema.get("function")
|
||||
if not isinstance(function, dict):
|
||||
continue
|
||||
name = function.get("name")
|
||||
if not isinstance(name, str) or not name or name in seen_names:
|
||||
continue
|
||||
seen_names.add(name)
|
||||
merged.append(schema)
|
||||
return merged
|
||||
|
||||
|
||||
class AgentScopeRuntimeOrchestrator:
|
||||
_runner: Any
|
||||
_config_loader: Callable[[AsyncSession], Awaitable[dict[str, RuntimeStageConfig]]]
|
||||
@@ -96,10 +114,20 @@ class AgentScopeRuntimeOrchestrator:
|
||||
enable_hitl=False,
|
||||
)
|
||||
intent_tools_schema = intent_toolkit.get_json_schemas()
|
||||
execution_toolkit = build_stage_toolkit(
|
||||
stage="execution",
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
enable_hitl=True,
|
||||
)
|
||||
execution_tools_schema = execution_toolkit.get_json_schemas()
|
||||
intent_prompt = build_system_prompt(
|
||||
stage="intent",
|
||||
user_context=user_context,
|
||||
tools=_tools_payload_from_schema(intent_tools_schema),
|
||||
tools=_tools_payload_from_schema(
|
||||
_merge_tool_schemas(intent_tools_schema, execution_tools_schema)
|
||||
),
|
||||
)
|
||||
intent_payload = await self._runner.run_json_stage(
|
||||
stage_config=stage_config["intent"],
|
||||
@@ -125,14 +153,6 @@ class AgentScopeRuntimeOrchestrator:
|
||||
|
||||
execution_output: ExecutionBatchOutput | None = None
|
||||
if intent_output.route == "TASK_EXECUTION":
|
||||
execution_toolkit = build_stage_toolkit(
|
||||
stage="execution",
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
enable_hitl=True,
|
||||
)
|
||||
execution_tools_schema = execution_toolkit.get_json_schemas()
|
||||
execution_prompt = build_system_prompt(
|
||||
stage="execution",
|
||||
user_context=user_context,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
from time import perf_counter
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -106,6 +107,9 @@ class AgentScopeReActRunner:
|
||||
stage_config=stage_config,
|
||||
response=response,
|
||||
latency_ms=latency_ms,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
assistant_text=text_content,
|
||||
)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception(
|
||||
@@ -234,6 +238,9 @@ def _merge_stage_response_metadata(
|
||||
stage_config: RuntimeStageConfig,
|
||||
response: Any,
|
||||
latency_ms: int,
|
||||
system_prompt: str,
|
||||
user_prompt: str | list[dict[str, Any]],
|
||||
assistant_text: str,
|
||||
) -> dict[str, Any]:
|
||||
result = dict(payload)
|
||||
existing = result.get("response_metadata")
|
||||
@@ -247,6 +254,15 @@ def _merge_stage_response_metadata(
|
||||
completion_tokens = _to_non_negative_int(
|
||||
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
|
||||
)
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = _estimate_token_count(
|
||||
{
|
||||
"system": system_prompt,
|
||||
"user": user_prompt,
|
||||
}
|
||||
)
|
||||
if completion_tokens is None:
|
||||
completion_tokens = _estimate_token_count(assistant_text)
|
||||
cost = _to_non_negative_float(
|
||||
_read_value(usage, "cost")
|
||||
or _read_value(_read_value(usage, "metadata"), "cost")
|
||||
@@ -352,3 +368,16 @@ def _estimate_cost_by_pricing(
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _estimate_token_count(value: object) -> int:
|
||||
try:
|
||||
serialized = (
|
||||
value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
serialized = str(value)
|
||||
normalized = serialized.strip()
|
||||
if not normalized:
|
||||
return 0
|
||||
return max(1, math.ceil(len(normalized) / 4))
|
||||
|
||||
@@ -21,6 +21,7 @@ from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from core.agentscope.tools.tool_result_storage import create_tool_result_storage
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
@@ -67,16 +68,10 @@ def _build_user_context(*, owner_id: UUID, run_input: RunCommand) -> UserAgentCo
|
||||
def _extract_user_token(
|
||||
*, command: dict[str, Any], run_input: RunCommand
|
||||
) -> str | None:
|
||||
del run_input
|
||||
raw_token = command.get("user_token")
|
||||
if isinstance(raw_token, str) and raw_token.strip():
|
||||
return raw_token.strip()
|
||||
forwarded = (
|
||||
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
|
||||
)
|
||||
for key in ("accessToken", "userToken", "token"):
|
||||
value = forwarded.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
@@ -162,7 +157,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
)
|
||||
pipeline = AgentScopeEventPipeline(
|
||||
codec=AgentScopeAgUiCodec(),
|
||||
store=SqlAlchemyEventStore(session_factory=AsyncSessionLocal),
|
||||
store=SqlAlchemyEventStore(
|
||||
session_factory=AsyncSessionLocal,
|
||||
tool_result_storage=create_tool_result_storage(),
|
||||
tool_result_bucket=config.storage.bucket,
|
||||
),
|
||||
bus=bus,
|
||||
)
|
||||
runtime = route_runtime_type(
|
||||
|
||||
@@ -67,6 +67,7 @@ def validate_run_request_messages_contract(run_input: RunAgentInput) -> None:
|
||||
message = run_input.messages[0]
|
||||
if getattr(message, "role", None) != "user":
|
||||
raise ValueError("RunAgentInput.messages[0].role must be user")
|
||||
_validate_user_content_blocks(getattr(message, "content", None))
|
||||
extract_latest_user_payload(run_input)
|
||||
|
||||
|
||||
@@ -106,84 +107,76 @@ def extract_latest_user_payload(
|
||||
text_parts.append(text)
|
||||
blocks.append({"type": "text", "text": text})
|
||||
continue
|
||||
if item_type not in {"image", "binary"}:
|
||||
if item_type != "binary":
|
||||
continue
|
||||
source_type: str | None = None
|
||||
source_value: str | None = None
|
||||
source_mime: str | None = None
|
||||
if item_type == "binary":
|
||||
source_mime = (
|
||||
item.get("mimeType")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "mime_type", None)
|
||||
)
|
||||
source_url = (
|
||||
item.get("url")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "url", None)
|
||||
)
|
||||
source_data = (
|
||||
item.get("data")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "data", None)
|
||||
)
|
||||
if isinstance(source_url, str) and source_url:
|
||||
source_type = "url"
|
||||
source_value = source_url
|
||||
elif isinstance(source_data, str) and source_data:
|
||||
source_type = "data"
|
||||
source_value = source_data
|
||||
else:
|
||||
source = getattr(item, "source", None)
|
||||
source_type = (
|
||||
source.get("type")
|
||||
if isinstance(source, dict)
|
||||
else getattr(source, "type", None)
|
||||
)
|
||||
source_value = (
|
||||
source.get("value")
|
||||
if isinstance(source, dict)
|
||||
else getattr(source, "value", None)
|
||||
)
|
||||
source_mime = (
|
||||
source.get("mimeType")
|
||||
if isinstance(source, dict)
|
||||
else getattr(source, "mimeType", None)
|
||||
)
|
||||
if (
|
||||
source_type == "url"
|
||||
and isinstance(source_value, str)
|
||||
and source_value
|
||||
):
|
||||
source_url = (
|
||||
item.get("url")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "url", None)
|
||||
)
|
||||
if isinstance(source_url, str) and source_url:
|
||||
blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": source_value}}
|
||||
)
|
||||
elif (
|
||||
source_type == "data"
|
||||
and isinstance(source_value, str)
|
||||
and source_value
|
||||
):
|
||||
mime_type = (
|
||||
source_mime
|
||||
if isinstance(source_mime, str) and source_mime
|
||||
else "image/png"
|
||||
)
|
||||
blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{source_value}"
|
||||
},
|
||||
}
|
||||
{"type": "image_url", "image_url": {"url": source_url}}
|
||||
)
|
||||
combined = "".join(text_parts).strip()
|
||||
if combined:
|
||||
if combined or blocks:
|
||||
return combined, blocks
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
|
||||
|
||||
def _validate_user_content_blocks(content: Any) -> None:
|
||||
if isinstance(content, str):
|
||||
if content.strip():
|
||||
return
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
if not isinstance(content, list):
|
||||
raise ValueError("RunAgentInput.messages[0].content must be string or list")
|
||||
|
||||
has_text = False
|
||||
has_binary = False
|
||||
for item in content:
|
||||
item_type = getattr(item, "type", None)
|
||||
if item_type == "text":
|
||||
text = getattr(item, "text", None)
|
||||
if isinstance(text, str) and text.strip():
|
||||
has_text = True
|
||||
continue
|
||||
if item_type == "binary":
|
||||
mime_type = (
|
||||
item.get("mimeType")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "mime_type", None)
|
||||
)
|
||||
url = (
|
||||
item.get("url")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "url", None)
|
||||
)
|
||||
data = (
|
||||
item.get("data")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "data", None)
|
||||
)
|
||||
if not isinstance(mime_type, str) or not mime_type.startswith("image/"):
|
||||
raise ValueError("binary content requires image mimeType")
|
||||
if not isinstance(url, str) or not url:
|
||||
raise ValueError("binary content requires url")
|
||||
if isinstance(data, str) and data:
|
||||
raise ValueError("binary content data is not allowed")
|
||||
has_binary = True
|
||||
continue
|
||||
raise ValueError("unsupported content block type")
|
||||
|
||||
if not has_text and not has_binary:
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
|
||||
|
||||
def extract_latest_tool_result(
|
||||
run_input: RunAgentInput,
|
||||
) -> tuple[str, dict[str, object]]:
|
||||
|
||||
@@ -3,10 +3,11 @@ from __future__ import annotations
|
||||
import logging
|
||||
from logging.config import dictConfig
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from core.config.settings import PROJECT_ROOT, RuntimeSettings, Settings
|
||||
from core.config.settings import PROJECT_ROOT, RuntimeSettings, Settings, config
|
||||
from core.logging.formatters import (
|
||||
build_plain_formatter,
|
||||
build_processor_formatter,
|
||||
@@ -77,7 +78,7 @@ def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
|
||||
|
||||
|
||||
def configure_logging(settings: Settings | None = None) -> None:
|
||||
active_settings = settings or Settings()
|
||||
active_settings = settings if settings is not None else cast(Settings, config)
|
||||
runtime = active_settings.runtime
|
||||
|
||||
try:
|
||||
|
||||
@@ -19,19 +19,14 @@ class SupabaseService(BaseServiceProvider):
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
self._client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.anon_key,
|
||||
)
|
||||
self._admin_client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.service_role_key,
|
||||
)
|
||||
self._init_clients()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service initialized")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Supabase service initialization failed", error=str(exc))
|
||||
self.logger.warning(
|
||||
"Supabase service initialization failed", error=str(exc)
|
||||
)
|
||||
self._client = None
|
||||
self._admin_client = None
|
||||
self._set_initialized(False)
|
||||
@@ -51,7 +46,9 @@ class SupabaseService(BaseServiceProvider):
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
await asyncio.to_thread(client.auth.get_session)
|
||||
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
|
||||
await asyncio.to_thread(
|
||||
admin_client.auth.admin.list_users, page=1, per_page=1
|
||||
)
|
||||
return {
|
||||
"status": "healthy",
|
||||
"details": {
|
||||
@@ -70,17 +67,35 @@ class SupabaseService(BaseServiceProvider):
|
||||
return self._require_admin_client()
|
||||
|
||||
def _require_client(self) -> Any:
|
||||
if self._client is None or self._admin_client is None:
|
||||
self._init_clients()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service lazily initialized")
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Supabase client is not initialized")
|
||||
return client
|
||||
|
||||
def _require_admin_client(self) -> Any:
|
||||
if self._client is None or self._admin_client is None:
|
||||
self._init_clients()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service lazily initialized")
|
||||
admin_client = self._admin_client
|
||||
if admin_client is None:
|
||||
raise RuntimeError("Supabase admin client is not initialized")
|
||||
return admin_client
|
||||
|
||||
def _init_clients(self) -> None:
|
||||
self._client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.anon_key,
|
||||
)
|
||||
self._admin_client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.service_role_key,
|
||||
)
|
||||
|
||||
|
||||
supabase_service: SupabaseService = register_service_instance(
|
||||
"supabase", SupabaseService()
|
||||
|
||||
@@ -3,11 +3,20 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from storage3.exceptions import StorageApiError
|
||||
|
||||
from core.config.settings import config
|
||||
from services.base.supabase import supabase_service
|
||||
|
||||
|
||||
class AgentAttachmentStorage:
|
||||
def _validate_bucket(self, *, bucket: str) -> None:
|
||||
expected = config.storage.bucket
|
||||
if bucket != expected:
|
||||
raise RuntimeError("Invalid attachment bucket")
|
||||
|
||||
def _bucket_client(self, *, bucket: str) -> Any:
|
||||
self._validate_bucket(bucket=bucket)
|
||||
client = supabase_service.get_admin_client()
|
||||
storage = getattr(client, "storage", None)
|
||||
if storage is None:
|
||||
@@ -39,9 +48,82 @@ class AgentAttachmentStorage:
|
||||
},
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_upload)
|
||||
try:
|
||||
await asyncio.to_thread(_upload)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if not _is_bucket_not_found_error(exc):
|
||||
raise
|
||||
await self._ensure_bucket_exists(bucket=bucket)
|
||||
await asyncio.to_thread(_upload)
|
||||
return path
|
||||
|
||||
async def _ensure_bucket_exists(self, *, bucket: str) -> None:
|
||||
def _ensure() -> None:
|
||||
client = supabase_service.get_admin_client()
|
||||
storage = getattr(client, "storage", None)
|
||||
if storage is None:
|
||||
raise RuntimeError("Supabase storage client unavailable")
|
||||
get_bucket = getattr(storage, "get_bucket", None)
|
||||
if callable(get_bucket):
|
||||
try:
|
||||
get_bucket(bucket)
|
||||
return
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
create_bucket = getattr(storage, "create_bucket", None)
|
||||
if not callable(create_bucket):
|
||||
raise RuntimeError("Supabase storage create_bucket is unavailable")
|
||||
try:
|
||||
create_bucket(bucket, options={"public": False})
|
||||
except Exception as exc: # noqa: BLE001
|
||||
message = str(exc).lower()
|
||||
if "already exists" in message or "duplicate" in message:
|
||||
return
|
||||
raise
|
||||
|
||||
await asyncio.to_thread(_ensure)
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
def _download() -> object:
|
||||
bucket_client = self._bucket_client(bucket=bucket)
|
||||
download = getattr(bucket_client, "download", None)
|
||||
if not callable(download):
|
||||
raise RuntimeError("Supabase storage download is unavailable")
|
||||
return download(path)
|
||||
|
||||
raw = await asyncio.to_thread(_download)
|
||||
if isinstance(raw, bytes):
|
||||
return raw
|
||||
if isinstance(raw, bytearray):
|
||||
return bytes(raw)
|
||||
if isinstance(raw, memoryview):
|
||||
return raw.tobytes()
|
||||
raise RuntimeError("Invalid attachment payload")
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str:
|
||||
def _create_signed_url() -> object:
|
||||
bucket_client = self._bucket_client(bucket=bucket)
|
||||
signer = getattr(bucket_client, "create_signed_url", None)
|
||||
if not callable(signer):
|
||||
raise RuntimeError("Supabase storage signed url is unavailable")
|
||||
return signer(path, expires_in_seconds)
|
||||
|
||||
raw = await asyncio.to_thread(_create_signed_url)
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
signed_url = raw.get("signedURL") or raw.get("signedUrl") or raw.get("url")
|
||||
if isinstance(signed_url, str) and signed_url:
|
||||
return signed_url
|
||||
raise RuntimeError("Invalid signed url payload")
|
||||
|
||||
|
||||
def create_attachment_storage() -> AgentAttachmentStorage | None:
|
||||
try:
|
||||
@@ -49,3 +131,11 @@ def create_attachment_storage() -> AgentAttachmentStorage | None:
|
||||
except Exception:
|
||||
return None
|
||||
return AgentAttachmentStorage()
|
||||
|
||||
|
||||
def _is_bucket_not_found_error(exc: Exception) -> bool:
|
||||
if isinstance(exc, StorageApiError):
|
||||
message = str(exc).lower()
|
||||
return "bucket" in message and "not found" in message
|
||||
message = str(exc).lower()
|
||||
return "bucket" in message and "not found" in message
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.config.settings import config
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
|
||||
@@ -200,6 +201,61 @@ class AgentRepository:
|
||||
return None
|
||||
return str(latest_id)
|
||||
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
message_uuid = UUID(message_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid message/session id"
|
||||
) from exc
|
||||
|
||||
stmt = (
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.id == message_uuid)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
)
|
||||
message = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if message is None:
|
||||
return None
|
||||
|
||||
metadata = (
|
||||
message.metadata_json if isinstance(message.metadata_json, dict) else {}
|
||||
)
|
||||
attachments_raw = metadata.get("attachments")
|
||||
if not isinstance(attachments_raw, list):
|
||||
return None
|
||||
if attachment_index < 0 or attachment_index >= len(attachments_raw):
|
||||
return None
|
||||
|
||||
attachment = attachments_raw[attachment_index]
|
||||
if not isinstance(attachment, dict):
|
||||
return None
|
||||
bucket = attachment.get("bucket")
|
||||
path = attachment.get("path")
|
||||
mime_type = attachment.get("mimeType")
|
||||
if (
|
||||
not isinstance(bucket, str)
|
||||
or not bucket
|
||||
or not isinstance(path, str)
|
||||
or not path
|
||||
or not isinstance(mime_type, str)
|
||||
or not mime_type
|
||||
):
|
||||
return None
|
||||
return {
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"mimeType": mime_type,
|
||||
}
|
||||
|
||||
async def _to_snapshot_message(
|
||||
self, message: AgentChatMessage
|
||||
) -> dict[str, object]:
|
||||
@@ -233,30 +289,65 @@ class AgentRepository:
|
||||
storage_bucket = metadata.get("storage_bucket")
|
||||
storage_path = metadata.get("storage_path")
|
||||
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
|
||||
try:
|
||||
hydrated_content = await self._tool_result_storage.read_json(
|
||||
bucket=storage_bucket,
|
||||
path=storage_path,
|
||||
expected_bucket = config.storage.bucket
|
||||
message_session_id = getattr(message, "session_id", None)
|
||||
expected_prefix = (
|
||||
f"tool-results/{message_session_id}/"
|
||||
if message_session_id is not None
|
||||
else None
|
||||
)
|
||||
tool_call_id = metadata.get("tool_call_id")
|
||||
is_legacy_path = isinstance(
|
||||
tool_call_id, str
|
||||
) and storage_path.endswith(f"/{tool_call_id}.json")
|
||||
if (
|
||||
storage_bucket == expected_bucket
|
||||
and _is_safe_storage_path(storage_path)
|
||||
and (
|
||||
(
|
||||
expected_prefix is not None
|
||||
and storage_path.startswith(expected_prefix)
|
||||
)
|
||||
or (
|
||||
storage_path.startswith("tool-results/")
|
||||
and is_legacy_path
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
hydrated_content = None
|
||||
):
|
||||
try:
|
||||
hydrated_content = (
|
||||
await self._tool_result_storage.read_json(
|
||||
bucket=storage_bucket,
|
||||
path=storage_path,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
hydrated_content = None
|
||||
|
||||
resolved_content = hydrated_content or parsed_content
|
||||
payload["content"] = message.content
|
||||
if resolved_content is not None:
|
||||
result = resolved_content.get("result")
|
||||
if isinstance(result, dict):
|
||||
result_content = result.get("content")
|
||||
if isinstance(result_content, str):
|
||||
payload["content"] = result_content
|
||||
ui = resolved_content.get("ui")
|
||||
if not isinstance(ui, dict):
|
||||
ui = resolved_content.get("ui_schema")
|
||||
if isinstance(ui, dict):
|
||||
payload["ui"] = ui
|
||||
display_content = resolved_content.get("content")
|
||||
if isinstance(display_content, str):
|
||||
if not isinstance(display_content, str):
|
||||
nested_result = resolved_content.get("result")
|
||||
if isinstance(nested_result, dict):
|
||||
nested_content = nested_result.get("content")
|
||||
if isinstance(nested_content, str):
|
||||
display_content = nested_content
|
||||
if (
|
||||
isinstance(display_content, str)
|
||||
and display_content.strip()
|
||||
and (
|
||||
not payload["content"]
|
||||
or _looks_like_offloaded_placeholder(str(payload["content"]))
|
||||
)
|
||||
):
|
||||
payload["content"] = display_content
|
||||
|
||||
if "content" not in payload:
|
||||
payload["content"] = message.content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
metadata = message.metadata_json or {}
|
||||
@@ -264,7 +355,22 @@ class AgentRepository:
|
||||
metadata.get("attachments") if isinstance(metadata, dict) else None
|
||||
)
|
||||
if isinstance(attachments, list):
|
||||
rendered = [item for item in attachments if isinstance(item, dict)]
|
||||
rendered: list[dict[str, object]] = []
|
||||
for index, item in enumerate(attachments):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
mime_type = item.get("mimeType")
|
||||
if not isinstance(mime_type, str) or not mime_type:
|
||||
continue
|
||||
rendered.append(
|
||||
{
|
||||
"mimeType": mime_type,
|
||||
"previewPath": (
|
||||
f"/api/v1/agent/runs/{message.session_id}/attachments/"
|
||||
f"{message.id}/{index}"
|
||||
),
|
||||
}
|
||||
)
|
||||
if rendered:
|
||||
payload["attachments"] = rendered
|
||||
return payload
|
||||
@@ -279,3 +385,19 @@ def _derive_session_title(content_text: str) -> str | None:
|
||||
if not normalized:
|
||||
return None
|
||||
return normalized[:80]
|
||||
|
||||
|
||||
def _is_safe_storage_path(path: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _looks_like_offloaded_placeholder(content: str) -> bool:
|
||||
normalized = content.strip().lower()
|
||||
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
|
||||
|
||||
@@ -10,7 +10,17 @@ import time
|
||||
from typing import Annotated, Union
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
Header,
|
||||
Query,
|
||||
Request,
|
||||
status,
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
@@ -20,11 +30,18 @@ from core.agentscope.schemas.agui_input import (
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
)
|
||||
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
|
||||
from v1.agent.schemas import (
|
||||
AsrTranscribeResponse,
|
||||
AttachmentReference,
|
||||
AttachmentUploadResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
from v1.agent.service import AgentService, asr_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
@@ -38,6 +55,7 @@ _SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
|
||||
_WAV_HEADER_MIN_BYTES = 12
|
||||
_ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
"audio/wav",
|
||||
@@ -46,6 +64,42 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def _verified_access_token_for_user(
|
||||
*,
|
||||
authorization: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> str | None:
|
||||
if not isinstance(authorization, str):
|
||||
return None
|
||||
normalized = authorization.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if not normalized.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
token = normalized[7:].strip()
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
jwt_secret = config.supabase.jwt_secret
|
||||
if jwt_secret is None:
|
||||
raise HTTPException(status_code=503, detail="Auth verifier unavailable")
|
||||
|
||||
verifier = JwtVerifier(
|
||||
issuer=str(config.supabase.jwt_issuer),
|
||||
jwt_secret=jwt_secret.get_secret_value(),
|
||||
jwt_algorithm=config.supabase.jwt_algorithm,
|
||||
)
|
||||
try:
|
||||
payload = verifier.verify(token)
|
||||
except TokenValidationError as exc:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||
|
||||
subject = payload.get("sub")
|
||||
if not isinstance(subject, str) or subject != str(current_user.id):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
return token
|
||||
|
||||
|
||||
def _looks_like_wav_header(header: bytes) -> bool:
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
return False
|
||||
@@ -111,6 +165,7 @@ async def enqueue_run(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
@@ -120,10 +175,15 @@ async def enqueue_run(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
task = await service.enqueue_run(
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -143,6 +203,7 @@ async def enqueue_resume(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
if request.thread_id != thread_id:
|
||||
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
|
||||
@@ -154,10 +215,15 @@ async def enqueue_resume(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
task = await service.enqueue_resume(
|
||||
thread_id=thread_id,
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -253,6 +319,31 @@ async def get_history_snapshot(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/attachments/{message_id}/{attachment_index}")
|
||||
async def get_attachment_preview(
|
||||
thread_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> StreamingResponse:
|
||||
if attachment_index < 0:
|
||||
raise HTTPException(status_code=422, detail="Invalid attachment index")
|
||||
payload, mime_type = await service.get_attachment_preview(
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
attachment_index=attachment_index,
|
||||
current_user=current_user,
|
||||
)
|
||||
return StreamingResponse(
|
||||
iter([payload]),
|
||||
media_type=mime_type,
|
||||
headers={
|
||||
"Cache-Control": "private, max-age=300",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_user_history_snapshot(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
@@ -267,6 +358,34 @@ async def get_user_history_snapshot(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/attachments",
|
||||
response_model=AttachmentUploadResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def upload_attachment(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
thread_id: str = Form(alias="threadId"),
|
||||
file: UploadFile = File(),
|
||||
) -> AttachmentUploadResponse:
|
||||
payload = await file.read()
|
||||
if not payload:
|
||||
raise HTTPException(status_code=422, detail="Empty attachment")
|
||||
if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
|
||||
raise HTTPException(status_code=413, detail="Attachment too large")
|
||||
attachment = await service.upload_attachment(
|
||||
thread_id=thread_id,
|
||||
filename=file.filename,
|
||||
content_type=file.content_type,
|
||||
payload=payload,
|
||||
current_user=current_user,
|
||||
)
|
||||
return AttachmentUploadResponse(
|
||||
attachment=AttachmentReference.model_validate(attachment),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcribe",
|
||||
response_model=AsrTranscribeResponse,
|
||||
|
||||
@@ -14,3 +14,16 @@ class TaskAcceptedResponse(BaseModel):
|
||||
|
||||
class AsrTranscribeResponse(BaseModel):
|
||||
transcript: str = Field(description="Transcribed text from audio")
|
||||
|
||||
|
||||
class AttachmentReference(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
bucket: str
|
||||
path: str
|
||||
mime_type: str = Field(alias="mimeType")
|
||||
url: str
|
||||
|
||||
|
||||
class AttachmentUploadResponse(BaseModel):
|
||||
attachment: AttachmentReference
|
||||
|
||||
+297
-60
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
import hashlib
|
||||
@@ -19,17 +18,22 @@ from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
|
||||
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
|
||||
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
|
||||
|
||||
|
||||
def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
|
||||
forwarded = run_input.forwarded_props
|
||||
if not isinstance(forwarded, dict):
|
||||
def _normalize_bearer_token(value: str | None) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
for key in ("accessToken", "userToken", "token"):
|
||||
value = forwarded.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
lower = normalized.lower()
|
||||
if lower.startswith("bearer "):
|
||||
token = normalized[7:].strip()
|
||||
return token or None
|
||||
return normalized
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -66,6 +70,14 @@ class AgentRepositoryLike(Protocol):
|
||||
metadata: dict[str, object] | None,
|
||||
) -> None: ...
|
||||
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
@@ -92,6 +104,16 @@ class AttachmentStorageLike(Protocol):
|
||||
content_type: str,
|
||||
) -> str: ...
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str: ...
|
||||
|
||||
|
||||
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
|
||||
if owner_id != str(current_user.id):
|
||||
@@ -104,6 +126,8 @@ class AgentService:
|
||||
_stream: EventStreamLike
|
||||
_attachment_storage: AttachmentStorageLike | None
|
||||
|
||||
_SIGNED_URL_EXPIRES_IN_SECONDS = 3600
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -122,6 +146,7 @@ class AgentService:
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
@@ -161,7 +186,7 @@ class AgentService:
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _extract_user_token_from_run_input(run_input),
|
||||
"user_token": _normalize_bearer_token(user_token),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
@@ -179,57 +204,115 @@ class AgentService:
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> tuple[str, dict[str, object] | None]:
|
||||
text, content_blocks = extract_latest_user_payload(run_input)
|
||||
text, _ = extract_latest_user_payload(run_input)
|
||||
content_blocks = _extract_latest_user_content_blocks(run_input)
|
||||
attachments: list[dict[str, object]] = []
|
||||
if self._attachment_storage is not None:
|
||||
for index, block in enumerate(content_blocks):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") != "image_url":
|
||||
continue
|
||||
image_value = block.get("image_url")
|
||||
if not isinstance(image_value, dict):
|
||||
continue
|
||||
url = image_value.get("url")
|
||||
if not isinstance(url, str) or not url.startswith("data:"):
|
||||
continue
|
||||
decoded = _decode_data_url(url)
|
||||
if decoded is None:
|
||||
continue
|
||||
mime_type, payload = decoded
|
||||
suffix = _mime_to_suffix(mime_type)
|
||||
checksum = hashlib.sha1(payload).hexdigest()[:16]
|
||||
path = (
|
||||
f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
|
||||
f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}"
|
||||
binary_blocks = [
|
||||
block
|
||||
for block in content_blocks
|
||||
if isinstance(block, dict) and block.get("type") == "binary"
|
||||
]
|
||||
if binary_blocks:
|
||||
if self._attachment_storage is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Attachment storage unavailable",
|
||||
)
|
||||
bucket_name = config.storage.bucket
|
||||
forwarded_props = (
|
||||
run_input.forwarded_props
|
||||
if isinstance(run_input.forwarded_props, dict)
|
||||
else {}
|
||||
)
|
||||
raw_attachments = forwarded_props.get("attachments")
|
||||
if not isinstance(raw_attachments, list):
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid attachments payload"
|
||||
)
|
||||
if len(raw_attachments) != len(binary_blocks):
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid attachments payload"
|
||||
)
|
||||
|
||||
total_attachment_bytes = 0
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
|
||||
for index, raw_attachment in enumerate(raw_attachments):
|
||||
if not isinstance(raw_attachment, dict):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Invalid attachment reference",
|
||||
)
|
||||
bucket = raw_attachment.get("bucket")
|
||||
path = raw_attachment.get("path")
|
||||
mime_type = raw_attachment.get("mimeType")
|
||||
if (
|
||||
not isinstance(bucket, str)
|
||||
or not isinstance(path, str)
|
||||
or not isinstance(mime_type, str)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Invalid attachment reference",
|
||||
)
|
||||
if bucket != config.storage.bucket:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
if mime_type.lower() not in _ALLOWED_ATTACHMENT_MIME_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Unsupported attachment type",
|
||||
)
|
||||
|
||||
binary_block = binary_blocks[index]
|
||||
binary_mime = binary_block.get("mimeType")
|
||||
binary_url = binary_block.get("url")
|
||||
if (
|
||||
not isinstance(binary_mime, str)
|
||||
or binary_mime != mime_type
|
||||
or not isinstance(binary_url, str)
|
||||
or not binary_url
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Invalid attachments payload",
|
||||
)
|
||||
|
||||
try:
|
||||
stored_path = await self._attachment_storage.upload_bytes(
|
||||
bucket=bucket_name,
|
||||
payload = await self._attachment_storage.download_bytes(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
content=payload,
|
||||
content_type=mime_type,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Attachment upload failed",
|
||||
"Attachment validation download failed",
|
||||
extra={
|
||||
"bucket": bucket_name,
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"mime_type": mime_type,
|
||||
"thread_id": run_input.thread_id,
|
||||
"run_id": run_input.run_id,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to upload attachment",
|
||||
detail="Failed to fetch attachment",
|
||||
)
|
||||
payload_size = len(payload)
|
||||
if payload_size > _MAX_ATTACHMENT_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Attachment too large",
|
||||
)
|
||||
total_attachment_bytes += payload_size
|
||||
if total_attachment_bytes > _MAX_TOTAL_ATTACHMENT_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Attachments too large",
|
||||
)
|
||||
|
||||
attachments.append(
|
||||
{
|
||||
"bucket": bucket_name,
|
||||
"path": stored_path,
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"mimeType": mime_type,
|
||||
}
|
||||
)
|
||||
@@ -238,12 +321,94 @@ class AgentService:
|
||||
metadata["attachments"] = attachments
|
||||
return text, metadata or None
|
||||
|
||||
async def upload_attachment(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
filename: str | None,
|
||||
content_type: str | None,
|
||||
payload: bytes,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, str]:
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 404:
|
||||
raise
|
||||
try:
|
||||
await self._repository.create_session_for_user(
|
||||
user_id=str(current_user.id),
|
||||
session_id=thread_id,
|
||||
)
|
||||
await self._repository.commit()
|
||||
except IntegrityError:
|
||||
await self._repository.rollback()
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
else:
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
if self._attachment_storage is None:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Attachment storage unavailable"
|
||||
)
|
||||
|
||||
if not isinstance(content_type, str):
|
||||
raise HTTPException(status_code=422, detail="Unsupported attachment type")
|
||||
mime_type = content_type.lower()
|
||||
if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES:
|
||||
raise HTTPException(status_code=422, detail="Unsupported attachment type")
|
||||
if not payload:
|
||||
raise HTTPException(status_code=422, detail="Empty attachment")
|
||||
if len(payload) > _MAX_ATTACHMENT_BYTES:
|
||||
raise HTTPException(status_code=413, detail="Attachment too large")
|
||||
|
||||
suffix = _mime_to_suffix(mime_type)
|
||||
checksum = hashlib.sha1(payload).hexdigest()[:16]
|
||||
filename_seed = filename if isinstance(filename, str) and filename else "upload"
|
||||
filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
|
||||
path = (
|
||||
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
f"{filename_hash}-{checksum}.{suffix}"
|
||||
)
|
||||
bucket_name = config.storage.bucket
|
||||
try:
|
||||
stored_path = await self._attachment_storage.upload_bytes(
|
||||
bucket=bucket_name,
|
||||
path=path,
|
||||
content=payload,
|
||||
content_type=mime_type,
|
||||
)
|
||||
signed_url = await self._attachment_storage.create_signed_url(
|
||||
bucket=bucket_name,
|
||||
path=stored_path,
|
||||
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Attachment upload failed",
|
||||
extra={
|
||||
"bucket": bucket_name,
|
||||
"path": path,
|
||||
"mime_type": mime_type,
|
||||
"thread_id": thread_id,
|
||||
},
|
||||
)
|
||||
raise HTTPException(status_code=502, detail="Failed to upload attachment")
|
||||
|
||||
return {
|
||||
"bucket": bucket_name,
|
||||
"path": stored_path,
|
||||
"mimeType": mime_type,
|
||||
"url": signed_url,
|
||||
}
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
) -> TaskAccepted:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
@@ -253,7 +418,7 @@ class AgentService:
|
||||
command={
|
||||
"command": "resume",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _extract_user_token_from_run_input(run_input),
|
||||
"user_token": _normalize_bearer_token(user_token),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
@@ -336,6 +501,63 @@ class AgentService:
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
async def get_attachment_preview(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
current_user: CurrentUser,
|
||||
) -> tuple[bytes, str]:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
if self._attachment_storage is None:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Attachment storage unavailable"
|
||||
)
|
||||
|
||||
ref = await self._repository.get_message_attachment_reference(
|
||||
session_id=thread_id,
|
||||
message_id=message_id,
|
||||
attachment_index=attachment_index,
|
||||
)
|
||||
if ref is None:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
bucket = ref.get("bucket")
|
||||
path = ref.get("path")
|
||||
mime_type = ref.get("mimeType")
|
||||
if (
|
||||
not isinstance(bucket, str)
|
||||
or not isinstance(path, str)
|
||||
or not isinstance(mime_type, str)
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
if bucket != config.storage.bucket:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/"
|
||||
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
try:
|
||||
payload = await self._attachment_storage.download_bytes(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Attachment download failed",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"message_id": message_id,
|
||||
"attachment_index": attachment_index,
|
||||
"bucket": bucket,
|
||||
},
|
||||
)
|
||||
raise HTTPException(status_code=502, detail="Failed to fetch attachment")
|
||||
return payload, mime_type
|
||||
|
||||
|
||||
class AsrService:
|
||||
def __init__(self) -> None:
|
||||
@@ -445,22 +667,26 @@ class AsrService:
|
||||
asr_service = AsrService()
|
||||
|
||||
|
||||
def _decode_data_url(data_url: str) -> tuple[str, bytes] | None:
|
||||
if not data_url.startswith("data:"):
|
||||
return None
|
||||
header, sep, payload = data_url.partition(",")
|
||||
if not sep:
|
||||
return None
|
||||
mime_type = "image/png"
|
||||
if ";" in header:
|
||||
maybe_mime = header[5:].split(";", 1)[0].strip()
|
||||
if maybe_mime:
|
||||
mime_type = maybe_mime
|
||||
try:
|
||||
decoded = base64.b64decode(payload, validate=True)
|
||||
except ValueError:
|
||||
return None
|
||||
return mime_type, decoded
|
||||
def _extract_latest_user_content_blocks(
|
||||
run_input: RunAgentInput,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not run_input.messages:
|
||||
return []
|
||||
latest = run_input.messages[-1]
|
||||
content = getattr(latest, "content", None)
|
||||
if not isinstance(content, list):
|
||||
return []
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
blocks.append(item)
|
||||
continue
|
||||
model_dump = getattr(item, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
dumped = model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
if isinstance(dumped, dict):
|
||||
blocks.append(dumped)
|
||||
return blocks
|
||||
|
||||
|
||||
def _mime_to_suffix(mime_type: str) -> str:
|
||||
@@ -470,3 +696,14 @@ def _mime_to_suffix(mime_type: str) -> str:
|
||||
"image/webp": "webp",
|
||||
}
|
||||
return mapping.get(mime_type.lower(), "bin")
|
||||
|
||||
|
||||
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return normalized.startswith(expected_prefix)
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
from v1.auth.rate_limit import reset_rate_limit_state
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
|
||||
@@ -18,8 +18,14 @@ class _FakeAgentService:
|
||||
def __init__(self) -> None:
|
||||
self._stream_called = False
|
||||
|
||||
async def enqueue_run(self, *, run_input: RunAgentInput, current_user: CurrentUser):
|
||||
del current_user
|
||||
async def enqueue_run(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
):
|
||||
del current_user, user_token
|
||||
return SimpleNamespace(
|
||||
task_id="task-run-1",
|
||||
thread_id=run_input.thread_id,
|
||||
@@ -33,8 +39,9 @@ class _FakeAgentService:
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
):
|
||||
del thread_id, current_user
|
||||
del thread_id, current_user, user_token
|
||||
return SimpleNamespace(
|
||||
task_id="task-resume-1",
|
||||
thread_id=run_input.thread_id,
|
||||
@@ -109,6 +116,23 @@ class _FakeAgentService:
|
||||
},
|
||||
}
|
||||
|
||||
async def upload_attachment(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
filename: str | None,
|
||||
content_type: str | None,
|
||||
payload: bytes,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, str]:
|
||||
del filename, content_type, payload, current_user
|
||||
return {
|
||||
"bucket": "bucket-test",
|
||||
"path": f"agent-inputs/user/{thread_id}/upload.png",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload.png",
|
||||
}
|
||||
|
||||
|
||||
class _FailingStreamAgentService(_FakeAgentService):
|
||||
async def stream_events(
|
||||
@@ -393,6 +417,31 @@ def test_resume_accepts_tool_message_without_user_message() -> None:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_upload_attachment_returns_reference() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
file_payload = BytesIO(b"png")
|
||||
file_payload.name = "demo.png"
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/attachments",
|
||||
data={"threadId": "00000000-0000-0000-0000-000000000001"},
|
||||
files={"file": ("demo.png", file_payload, "image/png")},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
attachment = body["attachment"]
|
||||
assert attachment["mimeType"] == "image/png"
|
||||
assert "00000000-0000-0000-0000-000000000001" in attachment["path"]
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
|
||||
@@ -40,3 +40,34 @@ def test_reserved_keys_in_data_cannot_override_wire_fields() -> None:
|
||||
assert result["threadId"] == "thread-1"
|
||||
assert result["runId"] == "run-1"
|
||||
assert result["message"] == "ok"
|
||||
|
||||
|
||||
def test_tool_result_wire_event_filters_sensitive_fields() -> None:
|
||||
internal = {
|
||||
"type": "tool.result",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"data": {
|
||||
"messageId": "tool-result-1",
|
||||
"toolCallId": "call-1",
|
||||
"callId": "call-1",
|
||||
"toolName": "calendar_write",
|
||||
"content": "summary",
|
||||
"ui": {"type": "calendar_operation.v1", "data": {"ok": True}},
|
||||
"args": {"token": "secret"},
|
||||
"result": {"raw": "secret"},
|
||||
"error": "stack trace",
|
||||
},
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "TOOL_CALL_RESULT"
|
||||
assert result["messageId"] == "tool-result-1"
|
||||
assert result["toolCallId"] == "call-1"
|
||||
assert result["toolName"] == "calendar_write"
|
||||
assert result["content"] == "summary"
|
||||
assert isinstance(result.get("ui"), dict)
|
||||
assert "args" not in result
|
||||
assert "result" not in result
|
||||
assert "error" not in result
|
||||
|
||||
@@ -28,6 +28,27 @@ class _FakeSessionCtx:
|
||||
del exc_type, exc, tb
|
||||
|
||||
|
||||
class _FakeToolResultStorage:
|
||||
def __init__(self) -> None:
|
||||
self.upload_calls: list[dict[str, object]] = []
|
||||
|
||||
async def upload_json(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
payload: dict[str, object],
|
||||
) -> str:
|
||||
self.upload_calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"payload": payload,
|
||||
}
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_marks_session_running_on_run_started(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@@ -300,7 +321,12 @@ async def test_store_persists_tool_call_result_as_tool_message(
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
fake_storage = _FakeToolResultStorage()
|
||||
store = store_module.SqlAlchemyEventStore(
|
||||
session_factory=lambda: _FakeSessionCtx(),
|
||||
tool_result_storage=fake_storage,
|
||||
tool_result_bucket="agent-tool-results",
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TOOL_CALL_RESULT",
|
||||
@@ -310,7 +336,7 @@ async def test_store_persists_tool_call_result_as_tool_message(
|
||||
"taskId": "t1",
|
||||
"stage": "execution",
|
||||
"args": {"title": "A"},
|
||||
"result": {"event_id": "evt-1"},
|
||||
"result": {"event_id": "evt-1", "token": "secret"},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -318,9 +344,94 @@ async def test_store_persists_tool_call_result_as_tool_message(
|
||||
assert getattr(append_kwargs["role"], "value", None) == "tool"
|
||||
assert append_kwargs["tool_name"] == "calendar_write"
|
||||
assert append_kwargs["metadata"]["task_id"] == "t1"
|
||||
tool_call_id = append_kwargs["metadata"]["tool_call_id"]
|
||||
assert isinstance(tool_call_id, str)
|
||||
assert tool_call_id.startswith("run-1-t1-")
|
||||
assert append_kwargs["metadata"]["storage_bucket"] == "agent-tool-results"
|
||||
assert isinstance(append_kwargs["metadata"]["storage_path"], str)
|
||||
assert append_kwargs["content"].startswith("已创建日程")
|
||||
assert len(fake_storage.upload_calls) == 1
|
||||
uploaded = fake_storage.upload_calls[0]
|
||||
assert uploaded["bucket"] == "agent-tool-results"
|
||||
payload = cast(dict[str, Any], uploaded["payload"])
|
||||
assert payload["toolName"] == "calendar_write"
|
||||
assert "args" not in payload
|
||||
assert isinstance(payload.get("result"), dict)
|
||||
assert payload["result"]["token"] == "[REDACTED]"
|
||||
assert captured["message_delta"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_sanitizes_nested_sensitive_fields_in_result_payload(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=0)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
fake_storage = _FakeToolResultStorage()
|
||||
store = store_module.SqlAlchemyEventStore(
|
||||
session_factory=lambda: _FakeSessionCtx(),
|
||||
tool_result_storage=fake_storage,
|
||||
tool_result_bucket="agent-tool-results",
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TOOL_CALL_RESULT",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"toolName": "calendar_write",
|
||||
"result": {
|
||||
"data": {
|
||||
"ok": True,
|
||||
"accessToken": "secret-a",
|
||||
"nested": {
|
||||
"refresh_token": "secret-b",
|
||||
},
|
||||
"items": [
|
||||
{"authorizationHeader": "secret-c"},
|
||||
],
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
payload = cast(dict[str, Any], fake_storage.upload_calls[0]["payload"])
|
||||
stored_result = cast(dict[str, Any], payload["result"])
|
||||
data = cast(dict[str, Any], stored_result["data"])
|
||||
assert data["accessToken"] == "[REDACTED]"
|
||||
nested = cast(dict[str, Any], data["nested"])
|
||||
assert nested["refresh_token"] == "[REDACTED]"
|
||||
items = cast(list[Any], data["items"])
|
||||
assert isinstance(items[0], dict)
|
||||
assert items[0]["authorizationHeader"] == "[REDACTED]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_drops_buffer_when_session_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agentscope.events.tool_result_summary import build_tool_content_summary
|
||||
|
||||
|
||||
def test_summary_prioritizes_error() -> None:
|
||||
text = build_tool_content_summary(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result={"message": "ignored"},
|
||||
error={"message": "denied"},
|
||||
)
|
||||
assert text == "calendar_write 执行失败:denied"
|
||||
|
||||
|
||||
def test_summary_for_calendar_write() -> None:
|
||||
text = build_tool_content_summary(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "项目评审"},
|
||||
result={"startAt": "明天 10:00"},
|
||||
error=None,
|
||||
)
|
||||
assert text == "已创建日程:项目评审(明天 10:00)"
|
||||
|
||||
|
||||
def test_summary_for_calendar_read() -> None:
|
||||
text = build_tool_content_summary(
|
||||
tool_name="calendar_read",
|
||||
args={"query": "今天"},
|
||||
result={"data": {"total": 3}},
|
||||
error=None,
|
||||
)
|
||||
assert text == "查询到 3 条日程(今天)"
|
||||
|
||||
|
||||
def test_summary_falls_back_to_result_content() -> None:
|
||||
text = build_tool_content_summary(
|
||||
tool_name="unknown_tool",
|
||||
args=None,
|
||||
result={"content": "这是非常长的说明" * 20},
|
||||
error=None,
|
||||
)
|
||||
assert text.startswith("这是非常长的说明")
|
||||
assert len(text) <= 80
|
||||
|
||||
|
||||
def test_summary_default_done() -> None:
|
||||
text = build_tool_content_summary(
|
||||
tool_name="unknown_tool",
|
||||
args=None,
|
||||
result=None,
|
||||
error=None,
|
||||
)
|
||||
assert text == "unknown_tool 执行完成"
|
||||
|
||||
|
||||
def test_summary_marks_business_failure_when_ok_false() -> None:
|
||||
text = build_tool_content_summary(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "上学"},
|
||||
result={
|
||||
"type": "calendar_operation.v1",
|
||||
"data": {
|
||||
"ok": False,
|
||||
"code": "UNAUTHORIZED",
|
||||
"message": "calendar.write requires validated user token",
|
||||
},
|
||||
},
|
||||
error=None,
|
||||
)
|
||||
assert (
|
||||
text == "calendar_write 执行失败:calendar.write requires validated user token"
|
||||
)
|
||||
@@ -109,7 +109,6 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
@@ -117,6 +116,7 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"tool.result",
|
||||
"step.finish",
|
||||
"step.start",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
@@ -127,10 +127,14 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
assert calls[1]["data"]["stepName"] == "intent"
|
||||
assert calls[2]["data"]["stepName"] == "intent"
|
||||
assert calls[3]["data"]["stepName"] == "execution"
|
||||
assert calls[4]["data"]["stepName"] == "execution"
|
||||
assert calls[5]["data"]["stage"] == "intent"
|
||||
assert calls[8]["data"]["stage"] == "execution"
|
||||
assert calls[11]["data"]["toolName"] == "calendar_write"
|
||||
assert calls[4]["data"]["stage"] == "intent"
|
||||
assert calls[7]["data"]["stage"] == "execution"
|
||||
assert calls[10]["data"]["toolName"] == "calendar_write"
|
||||
assert calls[10]["data"]["toolCallId"] == "run-1-t1-1"
|
||||
assert calls[10]["data"]["messageId"] == "tool-result-run-1-t1-1"
|
||||
tool_content = calls[10]["data"]["content"]
|
||||
assert tool_content == "calendar_write 执行完成"
|
||||
assert calls[11]["data"]["stepName"] == "execution"
|
||||
assert calls[12]["data"]["stepName"] == "report"
|
||||
assert calls[14]["data"]["delta"] == "hello world"
|
||||
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
|
||||
@@ -305,3 +309,300 @@ async def test_runtime_direct_response_finishes_without_report_stage() -> None:
|
||||
]
|
||||
assert calls[3]["data"]["stage"] == "intent"
|
||||
assert calls[4]["data"]["delta"] == "direct-answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_tool_result_parses_json_string_ui_payload() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result='{"type":"calendar_card.v1","version":"v1","data":{"ok":true,"title":"A"},"actions":[]}',
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
tool_events = [item for item in calls if item.get("type") == "tool.result"]
|
||||
assert len(tool_events) == 1
|
||||
data = tool_events[0]["data"]
|
||||
assert isinstance(data, dict)
|
||||
assert isinstance(data.get("ui"), dict)
|
||||
assert data["ui"]["type"] == "calendar_card.v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_tool_result_keeps_plain_text_content() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result="created successfully",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
tool_events = [item for item in calls if item.get("type") == "tool.result"]
|
||||
assert len(tool_events) == 1
|
||||
data = tool_events[0]["data"]
|
||||
assert isinstance(data, dict)
|
||||
assert data["content"] == "created successfully"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_tool_result_sanitizes_sensitive_payload() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={
|
||||
"title": "A",
|
||||
"accessToken": "arg-secret",
|
||||
"author": "alice",
|
||||
},
|
||||
result={
|
||||
"ok": True,
|
||||
"accessToken": "secret-token",
|
||||
"message": "Authorization: Bearer inline-token",
|
||||
"nested": [
|
||||
{
|
||||
"authorizationHeader": "Bearer abc",
|
||||
}
|
||||
],
|
||||
},
|
||||
error="failed authorization=Bearer abc123 detail",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
tool_events = [item for item in calls if item.get("type") == "tool.result"]
|
||||
assert len(tool_events) == 1
|
||||
data = tool_events[0]["data"]
|
||||
assert isinstance(data, dict)
|
||||
assert isinstance(data["result"], dict)
|
||||
assert data["result"]["accessToken"] == "[REDACTED]"
|
||||
assert data["result"]["message"] == "Authorization=[REDACTED]"
|
||||
nested = data["result"]["nested"]
|
||||
assert isinstance(nested, list)
|
||||
assert nested[0]["authorizationHeader"] == "[REDACTED]"
|
||||
assert isinstance(data["args"], dict)
|
||||
assert data["args"]["accessToken"] == "[REDACTED]"
|
||||
assert data["args"]["author"] == "alice"
|
||||
assert data["error"] == "failed authorization=[REDACTED] detail"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_tool_result_keeps_non_object_result() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result=["evt-1", "evt-2"],
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
tool_events = [item for item in calls if item.get("type") == "tool.result"]
|
||||
assert len(tool_events) == 1
|
||||
data = tool_events[0]["data"]
|
||||
assert isinstance(data, dict)
|
||||
assert isinstance(data["result"], dict)
|
||||
assert data["result"]["value"] == ["evt-1", "evt-2"]
|
||||
|
||||
@@ -212,6 +212,9 @@ def test_merge_stage_response_metadata_estimates_cost_from_pricing(
|
||||
model="qwen3.5-flash",
|
||||
),
|
||||
latency_ms=50,
|
||||
system_prompt="system",
|
||||
user_prompt="user",
|
||||
assistant_text='{"route":"DIRECT_RESPONSE"}',
|
||||
)
|
||||
|
||||
metadata = payload["response_metadata"]
|
||||
|
||||
@@ -50,6 +50,10 @@ async def test_run_agentscope_task_calls_runtime_run(
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return object()
|
||||
|
||||
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
@@ -60,7 +64,7 @@ async def test_run_agentscope_task_calls_runtime_run(
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"_build_recent_context_messages",
|
||||
lambda **_: [],
|
||||
_empty_context,
|
||||
)
|
||||
|
||||
result = await tasks_module.run_agentscope_task(
|
||||
@@ -101,6 +105,10 @@ async def test_run_agentscope_task_includes_recent_context_messages(
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return object()
|
||||
|
||||
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return []
|
||||
|
||||
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
|
||||
@@ -115,7 +123,7 @@ async def test_run_agentscope_task_includes_recent_context_messages(
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"_build_recent_context_messages",
|
||||
lambda **_: [],
|
||||
_empty_context,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
|
||||
@@ -94,3 +94,46 @@ def test_validate_run_request_messages_contract_requires_single_user_message() -
|
||||
match="RunAgentInput.messages must contain exactly one user message",
|
||||
):
|
||||
validate_run_request_messages_contract(run_input)
|
||||
|
||||
|
||||
def test_validate_run_request_messages_contract_accepts_binary_url_blocks() -> None:
|
||||
payload = _base_payload()
|
||||
payload["messages"] = [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请分析"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/a.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
run_input = parse_run_input(payload)
|
||||
|
||||
validate_run_request_messages_contract(run_input)
|
||||
|
||||
|
||||
def test_validate_run_request_messages_contract_rejects_binary_data_block() -> None:
|
||||
payload = _base_payload()
|
||||
payload["messages"] = [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请分析"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
run_input = parse_run_input(payload)
|
||||
|
||||
with pytest.raises(ValueError, match="binary content requires url"):
|
||||
validate_run_request_messages_contract(run_input)
|
||||
|
||||
@@ -54,3 +54,20 @@ def test_build_intent_user_prompt_filters_non_image_binary_block() -> None:
|
||||
assert isinstance(prompt, list)
|
||||
image_blocks = [item for item in prompt if item.get("type") == "image"]
|
||||
assert image_blocks == []
|
||||
|
||||
|
||||
def test_build_intent_user_prompt_includes_previous_context_messages() -> None:
|
||||
prompt = build_intent_user_prompt(
|
||||
user_input=[
|
||||
{"id": "u1", "role": "user", "content": "我的口令是蓝鲸42"},
|
||||
{"id": "a1", "role": "assistant", "content": "已记住"},
|
||||
{"id": "u2", "role": "user", "content": "请重复口令"},
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(prompt, list)
|
||||
assert prompt
|
||||
instruction = prompt[0].get("text", "")
|
||||
assert isinstance(instruction, str)
|
||||
assert "[Conversation Context]" in instruction
|
||||
assert "\\u84dd\\u9cb842" in instruction
|
||||
|
||||
@@ -67,10 +67,8 @@ async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert await service.initialize() is True
|
||||
assert await service.close() is True
|
||||
assert service.is_initialized is False
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_admin_client()
|
||||
assert service.get_client() is not None
|
||||
assert service.get_admin_client() is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -117,7 +115,47 @@ def test_get_client_raises_before_init() -> None:
|
||||
settings=SupabaseSettings(public_url="https://test.supabase.co")
|
||||
)
|
||||
|
||||
assert service.get_client() is not None
|
||||
assert service.get_admin_client() is not None
|
||||
|
||||
|
||||
def test_get_client_raises_when_lazy_initialization_fails(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = SupabaseService(
|
||||
settings=SupabaseSettings(public_url="https://test.supabase.co")
|
||||
)
|
||||
|
||||
def _fake_create_client(_: str, __: str) -> object:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_admin_client()
|
||||
|
||||
|
||||
def test_get_admin_client_lazily_initializes_clients(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = SupabaseService(
|
||||
settings=SupabaseSettings(public_url="https://test.supabase.co")
|
||||
)
|
||||
anon_client = MagicMock(name="anon")
|
||||
admin_client = MagicMock(name="admin")
|
||||
create_calls: list[tuple[str, str]] = []
|
||||
|
||||
def _fake_create_client(url: str, key: str) -> object:
|
||||
create_calls.append((url, key))
|
||||
return anon_client if len(create_calls) == 1 else admin_client
|
||||
|
||||
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
|
||||
|
||||
resolved_admin = service.get_admin_client()
|
||||
|
||||
assert resolved_admin is admin_client
|
||||
assert service.get_client() is anon_client
|
||||
assert service.is_initialized is True
|
||||
assert len(create_calls) == 2
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import v1.agent.attachment_storage as attachment_storage_module
|
||||
|
||||
|
||||
class _FakeBucket:
|
||||
def __init__(self) -> None:
|
||||
self.upload_calls: list[tuple[str, bytes, dict[str, str]]] = []
|
||||
self.download_calls: list[str] = []
|
||||
|
||||
def upload(self, path: str, content: bytes, options: dict[str, str]) -> object:
|
||||
self.upload_calls.append((path, content, options))
|
||||
return {"path": path}
|
||||
|
||||
def download(self, path: str) -> object:
|
||||
self.download_calls.append(path)
|
||||
return b"ok"
|
||||
|
||||
|
||||
class _FakeStorage:
|
||||
def __init__(self, bucket: _FakeBucket) -> None:
|
||||
self._bucket = bucket
|
||||
|
||||
def from_(self, bucket: str) -> object:
|
||||
del bucket
|
||||
return self._bucket
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attachment_storage_rejects_unexpected_bucket(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
storage = attachment_storage_module.AgentAttachmentStorage()
|
||||
monkeypatch.setattr(
|
||||
attachment_storage_module.config.storage,
|
||||
"bucket",
|
||||
"allowed-bucket",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid attachment bucket"):
|
||||
await storage.upload_bytes(
|
||||
bucket="other-bucket",
|
||||
path="agent-inputs/u/t/r/file.png",
|
||||
content=b"data",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attachment_storage_accepts_configured_bucket(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
storage = attachment_storage_module.AgentAttachmentStorage()
|
||||
fake_bucket = _FakeBucket()
|
||||
fake_client = SimpleNamespace(storage=_FakeStorage(fake_bucket))
|
||||
monkeypatch.setattr(
|
||||
attachment_storage_module.config.storage,
|
||||
"bucket",
|
||||
"allowed-bucket",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
attachment_storage_module.supabase_service,
|
||||
"get_admin_client",
|
||||
lambda: fake_client,
|
||||
)
|
||||
|
||||
path = await storage.upload_bytes(
|
||||
bucket="allowed-bucket",
|
||||
path="agent-inputs/u/t/r/file.png",
|
||||
content=b"data",
|
||||
content_type="image/png",
|
||||
)
|
||||
payload = await storage.download_bytes(
|
||||
bucket="allowed-bucket",
|
||||
path=path,
|
||||
)
|
||||
|
||||
assert path == "agent-inputs/u/t/r/file.png"
|
||||
assert payload == b"ok"
|
||||
assert len(fake_bucket.upload_calls) == 1
|
||||
assert fake_bucket.download_calls == ["agent-inputs/u/t/r/file.png"]
|
||||
@@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import config
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from v1.agent.repository import AgentRepository
|
||||
|
||||
@@ -62,7 +63,7 @@ async def test_tool_message_hydrates_content_from_object_storage() -> None:
|
||||
content='{"offloaded":true}',
|
||||
metadata_json={
|
||||
"tool_call_id": "call-1",
|
||||
"storage_bucket": "private",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-1.json",
|
||||
},
|
||||
)
|
||||
@@ -73,6 +74,43 @@ async def test_tool_message_hydrates_content_from_object_storage() -> None:
|
||||
assert payload["content"] == "已跳转"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_hydrates_ui_from_ui_schema_field() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"toolName": "calendar_write",
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True, "operation": "create"},
|
||||
"actions": [],
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="已创建日程:项目评审(明天 10:00)",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-3",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-3.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-3"
|
||||
assert payload["content"] == "已创建日程:项目评审(明天 10:00)"
|
||||
ui = payload.get("ui")
|
||||
assert isinstance(ui, dict)
|
||||
assert ui["type"] == "calendar_operation.v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
|
||||
repository = AgentRepository(
|
||||
@@ -86,7 +124,7 @@ async def test_tool_message_keeps_inline_content_when_storage_payload_missing()
|
||||
content="inline-tool-content",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-2",
|
||||
"storage_bucket": "private",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-2.json",
|
||||
},
|
||||
)
|
||||
@@ -97,6 +135,111 @@ async def test_tool_message_keeps_inline_content_when_storage_payload_missing()
|
||||
assert payload["content"] == "inline-tool-content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_skips_storage_when_path_not_matching_session() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="summary",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-x",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/foreign-session/call-y.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "summary"
|
||||
assert "ui" not in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_rejects_path_traversal() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="summary",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-z",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/ok/../../evil/call-z.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "summary"
|
||||
assert "ui" not in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_supports_legacy_storage_path() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
},
|
||||
"content": "legacy content",
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content='{"offloaded":true}',
|
||||
metadata_json={
|
||||
"tool_call_id": "call-legacy",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/old-run/call-legacy.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "legacy content"
|
||||
ui = payload.get("ui")
|
||||
assert isinstance(ui, dict)
|
||||
assert ui["type"] == "calendar_operation.v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_message_snapshot_includes_renderable_attachments() -> None:
|
||||
repository = AgentRepository(
|
||||
@@ -104,6 +247,7 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.USER,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="请分析这张图",
|
||||
@@ -122,13 +266,13 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
|
||||
|
||||
assert payload["role"] == "user"
|
||||
assert payload["content"] == "请分析这张图"
|
||||
assert payload["attachments"] == [
|
||||
{
|
||||
"bucket": "agent-chat-attachments",
|
||||
"path": "agent-inputs/u1/t1/r1/m1/att-1.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
attachments = payload.get("attachments")
|
||||
assert isinstance(attachments, list)
|
||||
assert len(attachments) == 1
|
||||
first = attachments[0]
|
||||
assert isinstance(first, dict)
|
||||
assert first["mimeType"] == "image/png"
|
||||
assert isinstance(first.get("previewPath"), str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -174,3 +318,32 @@ async def test_persist_user_message_keeps_existing_session_title() -> None:
|
||||
|
||||
assert session_row.title == "已有标题"
|
||||
assert session_row.message_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_message_attachment_reference_returns_item() -> None:
|
||||
session_id = str(uuid4())
|
||||
message_id = str(uuid4())
|
||||
message = SimpleNamespace(
|
||||
metadata_json={
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "bucket-test",
|
||||
"path": "agent-inputs/u/t/r/a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
fake_session = _FakeSession(message)
|
||||
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
ref = await repository.get_message_attachment_reference(
|
||||
session_id=session_id,
|
||||
message_id=message_id,
|
||||
attachment_index=0,
|
||||
)
|
||||
|
||||
assert ref is not None
|
||||
assert ref["bucket"] == "bucket-test"
|
||||
assert ref["mimeType"] == "image/png"
|
||||
|
||||
@@ -225,3 +225,44 @@ async def test_stream_events_retries_on_redis_timeout(
|
||||
|
||||
merged = "".join(chunks)
|
||||
assert "event: RUN_FINISHED" in merged
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_attachment_preview_rejects_negative_index() -> None:
|
||||
class _Service:
|
||||
async def get_attachment_preview(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
raise AssertionError("get_attachment_preview should not be called")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await agent_router.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=-1,
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_attachment_preview_returns_streaming_response() -> None:
|
||||
class _Service:
|
||||
async def get_attachment_preview(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
return b"png-bytes", "image/png"
|
||||
|
||||
response = await agent_router.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
)
|
||||
chunks: list[bytes] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(cast(bytes, chunk))
|
||||
|
||||
assert response.media_type == "image/png"
|
||||
assert b"".join(chunks) == b"png-bytes"
|
||||
|
||||
@@ -6,8 +6,10 @@ from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
import v1.agent.service as agent_service_module
|
||||
from v1.agent.service import AgentService, AsrService
|
||||
|
||||
@@ -74,12 +76,32 @@ class _FakeRepository:
|
||||
}
|
||||
)
|
||||
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None:
|
||||
del session_id, message_id
|
||||
if attachment_index != 0:
|
||||
return None
|
||||
return {
|
||||
"bucket": config.storage.bucket,
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/run-1/attachment-0-a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
self.commands: list[dict[str, object]] = []
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
del command, dedup_key
|
||||
self.commands.append(command)
|
||||
del dedup_key
|
||||
return "task-1"
|
||||
|
||||
|
||||
@@ -123,6 +145,33 @@ class _FakeAttachmentStorage:
|
||||
)
|
||||
return path
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
self.calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"download": True,
|
||||
}
|
||||
)
|
||||
return b"png-bytes"
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str:
|
||||
self.calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"signed": True,
|
||||
"expires_in_seconds": expires_in_seconds,
|
||||
}
|
||||
)
|
||||
return f"https://signed.example/{path}?exp={expires_in_seconds}"
|
||||
|
||||
|
||||
class _AlwaysFailAttachmentStorage:
|
||||
async def upload_bytes(
|
||||
@@ -136,6 +185,20 @@ class _AlwaysFailAttachmentStorage:
|
||||
del bucket, path, content, content_type
|
||||
raise RuntimeError("upload failed")
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
del bucket, path
|
||||
raise RuntimeError("download failed")
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str:
|
||||
del bucket, path, expires_in_seconds
|
||||
raise RuntimeError("sign failed")
|
||||
|
||||
|
||||
def _user() -> CurrentUser:
|
||||
return CurrentUser(
|
||||
@@ -186,9 +249,10 @@ async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
|
||||
async def test_enqueue_run_creates_missing_thread_session() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
@@ -206,6 +270,30 @@ async def test_enqueue_run_creates_missing_thread_session() -> None:
|
||||
assert accepted.created is True
|
||||
assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert repository.committed is True
|
||||
assert queue.commands[0]["user_token"] is None
|
||||
|
||||
|
||||
async def test_enqueue_run_uses_explicit_user_token() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
user_token="Bearer access-token-1",
|
||||
)
|
||||
|
||||
assert queue.commands
|
||||
assert queue.commands[0]["user_token"] == "access-token-1"
|
||||
|
||||
|
||||
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
@@ -270,7 +358,7 @@ async def test_enqueue_run_handles_session_create_race() -> None:
|
||||
assert repository.rolled_back is True
|
||||
|
||||
|
||||
async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
|
||||
async def test_enqueue_run_uses_forwarded_attachments_and_injects_metadata(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
@@ -297,15 +385,23 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
|
||||
{"type": "text", "text": "帮我看下这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"data": "aGVsbG8=",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-test-bucket",
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -313,10 +409,9 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
|
||||
|
||||
assert accepted.task_id == "task-1"
|
||||
assert len(attachment_storage.calls) == 1
|
||||
upload = attachment_storage.calls[0]
|
||||
assert upload["bucket"] == "agent-test-bucket"
|
||||
assert upload["content"] == b"hello"
|
||||
assert upload["content_type"] == "image/png"
|
||||
download = attachment_storage.calls[0]
|
||||
assert download["bucket"] == "agent-test-bucket"
|
||||
assert download["download"] is True
|
||||
assert repository.persisted_user_messages
|
||||
persisted = repository.persisted_user_messages[0]
|
||||
assert persisted["session_id"] == "00000000-0000-0000-0000-000000000001"
|
||||
@@ -330,7 +425,7 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
|
||||
assert isinstance(attachments[0]["path"], str)
|
||||
|
||||
|
||||
async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
|
||||
async def test_enqueue_run_raises_when_attachment_download_fails_without_fallback(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
@@ -356,15 +451,23 @@ async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
|
||||
{"type": "text", "text": "帮我看下这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"data": "aGVsbG8=",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-test-bucket",
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -373,11 +476,183 @@ async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
|
||||
raise AssertionError("expected HTTPException")
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 502
|
||||
assert exc.detail == "Failed to upload attachment"
|
||||
assert exc.detail == "Failed to fetch attachment"
|
||||
|
||||
assert repository.persisted_user_messages == []
|
||||
|
||||
|
||||
async def test_enqueue_run_rejects_unsupported_attachment_type(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-bad-image",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请看附件"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/gif",
|
||||
"url": "https://signed.example/upload.gif",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-test-bucket",
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.gif",
|
||||
"mimeType": "image/gif",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Unsupported attachment type"
|
||||
assert attachment_storage.calls == []
|
||||
|
||||
|
||||
async def test_enqueue_run_rejects_attachment_too_large(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(agent_service_module, "_MAX_ATTACHMENT_BYTES", 4)
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-big-image",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请看附件"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-test-bucket",
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 413
|
||||
assert exc_info.value.detail == "Attachment too large"
|
||||
assert len(attachment_storage.calls) == 1
|
||||
assert attachment_storage.calls[0]["download"] is True
|
||||
|
||||
|
||||
async def test_enqueue_run_accepts_binary_url_and_persists_metadata() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-binary-url",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请分析"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload-1.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": config.storage.bucket,
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload-1.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert accepted.task_id == "task-1"
|
||||
persisted = repository.persisted_user_messages[-1]
|
||||
metadata = persisted["metadata"]
|
||||
assert isinstance(metadata, dict)
|
||||
attachments = metadata.get("attachments")
|
||||
assert isinstance(attachments, list)
|
||||
assert attachments[0]["path"].endswith("upload-1.png")
|
||||
queue_input = queue.commands[-1]["run_input"]
|
||||
assert isinstance(queue_input, dict)
|
||||
content = queue_input["messages"][0]["content"]
|
||||
assert isinstance(content, list)
|
||||
assert content[1]["type"] == "binary"
|
||||
assert content[1]["url"] == "https://signed.example/upload-1.png"
|
||||
|
||||
|
||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
@@ -415,6 +690,59 @@ async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> Non
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
async def test_get_attachment_preview_returns_payload_and_mime() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
|
||||
payload, mime_type = await service.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert payload == b"png-bytes"
|
||||
assert mime_type == "image/png"
|
||||
|
||||
|
||||
async def test_get_attachment_preview_rejects_invalid_path() -> None:
|
||||
class _BadPathRepository(_FakeRepository):
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None:
|
||||
del session_id, message_id, attachment_index
|
||||
return {
|
||||
"bucket": "bucket-test",
|
||||
"path": "agent-inputs/other-user/other-thread/run-1/a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
service = AgentService(
|
||||
repository=_BadPathRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
|
||||
result = SimpleNamespace(
|
||||
status_code=200,
|
||||
|
||||
@@ -67,3 +67,60 @@ uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_
|
||||
2. **连续会话记忆测试** - 验证 session 是否从数据库读取历史上下文
|
||||
3. **工具调用测试** - calendar 读/写/删/分享 + 用户查找 + 时间感知
|
||||
4. **session 失败排查** - 找出最新失败原因并修复
|
||||
|
||||
## 9. 本轮进展与结论(2026-03-12)
|
||||
|
||||
### 9.1 反馈闭环状态
|
||||
|
||||
1. **intent/execution 阶段 tokens/cost 入库**:已解决。
|
||||
2. **连续会话记忆(今天+昨天上下文)**:已解决。
|
||||
3. **工具调用冒烟(读/写/删/分享 + user 查询 + 时间感知)**:部分解决。
|
||||
4. **最新失败 session 根因定位与修复**:已解决。
|
||||
5. **反馈同步到文档**:已完成(本节)。
|
||||
|
||||
### 9.2 关键修复
|
||||
|
||||
1. **stage telemetry 补齐**(intent/execution):
|
||||
- usage 缺失时补 token 估算;
|
||||
- 通过 `LiteLLMService.calculate_cost` 按项目定价估算 cost;
|
||||
- 回填 `response_metadata.inputTokens/outputTokens/cost` 并落库。
|
||||
|
||||
2. **会话记忆上下文注入**:
|
||||
- runtime 在执行前读取同一 session 最近两天(今天+昨天)的 user/assistant 消息;
|
||||
- intent prompt 增加 `[Conversation Context]`,避免只看最新用户输入。
|
||||
|
||||
3. **工具调用稳定性修复**:
|
||||
- tool 名统一为下划线(`calendar_read`/`calendar_write`/`user_resolve`),修复 OpenAI/LiteLLM tool name 正则错误;
|
||||
- intent prompt 注入 intent+execution 合并工具 schema,避免误判“无可用写入工具”。
|
||||
|
||||
### 9.3 Live 证据
|
||||
|
||||
#### A) tokens/cost 入库(thread=`cb1681c2-c223-4ced-bcfd-76f7252ba2d8`)
|
||||
|
||||
- intent: `input_tokens=1541`,`output_tokens=37`,`cost=0.000382`
|
||||
- execution: `input_tokens=2161`,`output_tokens=376`,`cost=0.005450`
|
||||
- report: `input_tokens=3266`,`output_tokens=318`,`cost=0.007256`
|
||||
- session 聚合:`total_tokens=13518`,`total_cost=0.019473`
|
||||
|
||||
#### B) 连续会话记忆(thread=`9c456736-d5e5-48a4-b9db-55f507baf573`)
|
||||
|
||||
- run `mem-1`:`请记住口令是蓝鲸42,只回复已记住。`
|
||||
- run `mem-2`:`只回复我刚才让你记住的口令,不要解释。`
|
||||
- assistant 回复:`蓝鲸42`(记忆命中)。
|
||||
|
||||
#### C) 工具调用 + 时间感知(thread=`cb1681c2-c223-4ced-bcfd-76f7252ba2d8`,run=`run-tool-1`)
|
||||
|
||||
- 事件序列含 execution 阶段与多次 `TOOL_CALL_RESULT`
|
||||
- 工具调用结果:`calendar_write`、`calendar_read`(多次)
|
||||
- assistant 回复包含时间感知信息(北京时间日期/星期/时刻)
|
||||
|
||||
### 9.4 最新失败 session 根因
|
||||
|
||||
- 失败样本:`d6bc4dbd-8361-4a39-bf09-12b3392e0e70`
|
||||
- 根因:tool 名含点号(如 `calendar.write`)触发校验失败:
|
||||
- `Invalid 'tools[0].function.name' ... expected pattern ^[a-zA-Z0-9_-]+$`
|
||||
- 修复后:同类执行链路已可稳定进入 execution 并产出 `TOOL_CALL_RESULT`。
|
||||
|
||||
### 9.5 当前未闭环项
|
||||
|
||||
- `user_resolve` + calendar **分享 + 删除** 组合链路的完整 live 证据还未补齐(本轮执行中断:`Tool execution aborted`)。
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
# Agent Tool UI Schema and Frontend Event Wiring Design
|
||||
|
||||
## Goal
|
||||
|
||||
修正 agent 工具结果的数据契约与前后端对接:
|
||||
|
||||
1. SSE `TOOL_CALL_RESULT` 继续携带可实时渲染的 `ui`。
|
||||
2. 落库时 `messages.content` 仅存关键摘要,完整工具结果(含 `ui schema`)存对象存储。
|
||||
3. `messages.metadata` 仅存访问路径和索引字段,history 通过 metadata 回填完整工具卡片数据。
|
||||
4. 前端正式接通 runs/events/history 三路,并统一实时与历史渲染行为。
|
||||
|
||||
## Constraints
|
||||
|
||||
- 暂缓冒烟测试,先完成工具数据修正与前后端接口对接。
|
||||
- 保持现有前端 `UiSchemaRenderer` 可解析格式,不做破坏性协议改动。
|
||||
- `resume` 新需求暂不扩展。
|
||||
- 遵循 AG-UI 事件语义和现有 FastAPI 路由约定。
|
||||
|
||||
## Selected Approach
|
||||
|
||||
采用兼容增强方案:
|
||||
|
||||
- 事件流对前端保持兼容(`TOOL_CALL_RESULT` 带 `ui` + `content`)。
|
||||
- 持久化与回放做结构化增强(storage + metadata 索引 + 摘要 content)。
|
||||
- 前端实时与历史统一映射层,保证同类消息一致渲染。
|
||||
|
||||
## Design A: Unified Data Contract
|
||||
|
||||
### SSE Event Contract (Realtime)
|
||||
|
||||
`TOOL_CALL_RESULT` 事件继续包含前端当前可解析字段:
|
||||
|
||||
- `callId`
|
||||
- `toolName`
|
||||
- `args`
|
||||
- `result`
|
||||
- `error`
|
||||
- `content` (关键结果摘要)
|
||||
- `ui` (工具卡片 schema)
|
||||
|
||||
这保证前端实时流不需要等待 history 即可显示工具卡片。
|
||||
|
||||
### Persistence Contract (Database + Storage)
|
||||
|
||||
对 tool message 持久化采用双层:
|
||||
|
||||
- `messages.content`: 仅保存 `content_summary`(短文本,供低成本上下文和兜底展示)。
|
||||
- 对象存储: 保存完整 payload(`ui`、`args`、`result`、`error`、时间戳、工具标识等)。
|
||||
- `messages.metadata`: 只保存索引和访问路径:
|
||||
- `tool_call_id`
|
||||
- `tool_name`
|
||||
- `run_id`
|
||||
- `stage`
|
||||
- `task_id`
|
||||
- `storage_bucket`
|
||||
- `storage_path`
|
||||
- `summary_version`
|
||||
|
||||
### History Contract
|
||||
|
||||
history 序列化时:
|
||||
|
||||
1. 先通过 `metadata.storage_bucket/storage_path` 读取完整 payload。
|
||||
2. 从 payload 回填 `ui`,并保留摘要 `content`。
|
||||
3. storage 读取失败时,回退 `messages.content`,确保历史可读。
|
||||
|
||||
## Design B: Frontend Wiring (runs/events/history)
|
||||
|
||||
### runs
|
||||
|
||||
- `POST /api/v1/agent/runs` 仅负责创建 run 与启动执行。
|
||||
- 前端保留 `threadId/runId` 和本地流状态,不承载渲染业务。
|
||||
|
||||
### events
|
||||
|
||||
- SSE 作为唯一实时渲染来源。
|
||||
- `TOOL_CALL_RESULT` 直接读取事件内 `ui` 渲染 `ToolResultItem`。
|
||||
- `STEP_STARTED/STEP_FINISHED` 显示三阶段状态(intent/execution/report)。
|
||||
|
||||
### history
|
||||
|
||||
- 通过 `/api/v1/agent/history` 或 `/api/v1/agent/runs/{threadId}/history` 回放。
|
||||
- tool message 优先读 `ui`(由后端从 metadata+storage 回填)。
|
||||
- user message 读取 `attachments` 渲染多模态内容。
|
||||
|
||||
### Consistency Rule
|
||||
|
||||
- 实时事件与历史快照统一进入同一 `ChatListItem` 映射层。
|
||||
- `content` 只做兜底文本,不作为工具卡片主数据。
|
||||
|
||||
## Design C: Backend Implementation Details
|
||||
|
||||
### Modules to Change
|
||||
|
||||
- `backend/src/core/agentscope/events/store.py`
|
||||
- 增加 tool result 的摘要生成与 storage 上传。
|
||||
- `append_message` 时写入摘要 content + metadata 索引。
|
||||
- `backend/src/core/agentscope/tools/tool_result_storage.py`
|
||||
- 复用现有 `upload_json/read_json`,作为完整 payload 存取层。
|
||||
- `backend/src/v1/agent/repository.py`
|
||||
- `_to_snapshot_message` 对 tool message 优先按 metadata 读取 storage 并回填 `ui`。
|
||||
- `backend/src/core/agentscope/runtime/agent_route_runtime.py`
|
||||
- 确保 `tool.result` 事件继续带 `ui` 和摘要 `content`。
|
||||
|
||||
### Failure Fallback
|
||||
|
||||
- storage 写失败:不阻断主流程,至少保证 `messages.content` 可读,metadata 标记缺失。
|
||||
- storage 读失败:history 返回摘要 `content`,`ui` 为空。
|
||||
|
||||
## Design D: content_summary Rule Engine
|
||||
|
||||
### Function
|
||||
|
||||
新增纯函数:
|
||||
|
||||
`build_tool_content_summary(tool_name, args, result, error) -> str`
|
||||
|
||||
### Rules (Priority)
|
||||
|
||||
1. 错误优先:有 `error` 直接输出失败摘要。
|
||||
2. 工具专用模板:
|
||||
- `calendar_write`: `已创建日程:{title}({start_time})`
|
||||
- `calendar_read`: `查询到 {count} 条日程({date_range})`
|
||||
- `calendar_delete`: `已删除日程:{title_or_id}`
|
||||
- `calendar_share`: `已分享日程给 {target}`
|
||||
- `user_resolve`: `已匹配用户:{name_or_id}`
|
||||
3. 通用回退:优先 `result.content`,否则抽取常见键拼句。
|
||||
4. 最终兜底:`{tool_name} 执行完成/执行失败`。
|
||||
5. 清洗:去换行与多空格,限制长度,避免大段 JSON。
|
||||
|
||||
### Summary Storage Policy
|
||||
|
||||
- `messages.content` 存摘要。
|
||||
- `summary_version` 存入 metadata,支持未来摘要算法演进。
|
||||
|
||||
## Testing and Acceptance
|
||||
|
||||
### Backend
|
||||
|
||||
- 单元测试:
|
||||
- `events/store`: tool result 摘要写入、metadata 路径写入、storage 异常回退。
|
||||
- `v1/agent/repository`: history 按 metadata 回填 `ui`;storage 缺失回退 content。
|
||||
- 摘要函数:覆盖成功/失败/缺字段/超长文本场景。
|
||||
- 集成测试:
|
||||
- `/runs` + `/events`:实时 `TOOL_CALL_RESULT` 带 `ui`。
|
||||
- `/history`:返回 tool message 的 `ui` 来自 metadata+storage。
|
||||
|
||||
### Frontend
|
||||
|
||||
- 单元/组件测试:
|
||||
- `AgUiService` 解析 `TOOL_CALL_RESULT` 的 `ui`。
|
||||
- `ChatBloc`:实时事件与 history 快照都能产出 `ToolResultItem`。
|
||||
- `UiSchemaRenderer`:history 回放卡片渲染一致。
|
||||
- user message 附件渲染(history)。
|
||||
- 页面行为验证:
|
||||
- events 到达即实时更新消息列表。
|
||||
- step 三阶段状态正确切换。
|
||||
- 上拉历史后工具卡片可正常显示。
|
||||
|
||||
## Risks and Mitigations
|
||||
|
||||
- 风险:storage 不可用导致 history 卡片缺失。
|
||||
- 缓解:保底展示摘要 content,不阻断对话。
|
||||
- 风险:事件格式变更导致前端实时解析失败。
|
||||
- 缓解:维持现有 `ToolCallResultEvent` 字段,不做破坏性改名。
|
||||
- 风险:摘要规则覆盖不足。
|
||||
- 缓解:规则版本化 + 测试样例扩展。
|
||||
|
||||
## Out of Scope
|
||||
|
||||
- resume 扩展协议与交互策略。
|
||||
- 新一轮 live 冒烟验收。
|
||||
- 新 UI 风格重构,仅实现链路打通与数据契约修正。
|
||||
@@ -0,0 +1,283 @@
|
||||
# Agent UI Schema and Event Wiring Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 打通 agent 工具结果在实时事件与历史回放的一致渲染链路:SSE 实时带 UI,落库 content 存摘要,完整 UI schema 存 storage 并通过 metadata 回填。
|
||||
|
||||
**Architecture:** 后端在 `TOOL_CALL_RESULT` 持久化链路中引入“摘要 + 全量分离”策略:摘要写 `messages.content`,全量 payload 写对象存储,metadata 仅存索引路径;history 读取时按 metadata 反查 storage 回填 `ui`。前端复用现有 AG-UI 事件模型,实现 runs/events/history 三路统一映射到 `ChatListItem`,并补齐 step 事件渲染与 history 多模态渲染。
|
||||
|
||||
**Tech Stack:** FastAPI, SQLAlchemy, AgentScope runtime/events, Supabase Storage, Flutter (Bloc/Cubit), Dart models/tests, AG-UI events
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add Tool Summary Rule Engine (Backend)
|
||||
|
||||
**Files:**
|
||||
- Create: `backend/src/core/agentscope/events/tool_result_summary.py`
|
||||
- Test: `backend/tests/unit/core/agentscope/events/test_tool_result_summary.py`
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
from core.agentscope.events.tool_result_summary import build_tool_content_summary
|
||||
|
||||
|
||||
def test_calendar_write_summary() -> None:
|
||||
text = build_tool_content_summary(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "项目评审"},
|
||||
result={"start_time": "明天 10:00"},
|
||||
error=None,
|
||||
)
|
||||
assert text.startswith("已创建日程")
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_tool_result_summary.py -q`
|
||||
Expected: FAIL with import/module/function missing.
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
def build_tool_content_summary(*, tool_name: str, args, result, error) -> str:
|
||||
if error:
|
||||
return f"{tool_name} 执行失败"
|
||||
if tool_name == "calendar_write":
|
||||
return "已创建日程"
|
||||
return f"{tool_name} 执行完成"
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_tool_result_summary.py -q`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 5: Extend tests for all rules and refactor**
|
||||
|
||||
Add cases for `calendar_read/calendar_delete/calendar_share/user_resolve/error/fallback/truncation` and implement full rule table.
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/core/agentscope/events/tool_result_summary.py backend/tests/unit/core/agentscope/events/test_tool_result_summary.py
|
||||
git commit -m "feat: add deterministic tool result summary engine"
|
||||
```
|
||||
|
||||
### Task 2: Persist Full Tool Payload to Storage and Keep Content Lightweight
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/core/agentscope/events/store.py`
|
||||
- Test: `backend/tests/unit/core/agentscope/events/test_store.py`
|
||||
|
||||
**Step 1: Write the failing tests**
|
||||
|
||||
Add tests asserting:
|
||||
- `TOOL_CALL_RESULT` persists summary to `content`.
|
||||
- metadata includes `storage_bucket/storage_path/tool_call_id`.
|
||||
- uploaded payload includes full `ui/args/result/error`.
|
||||
|
||||
**Step 2: Run targeted tests (RED)**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_store.py -q`
|
||||
Expected: FAIL on new assertions.
|
||||
|
||||
**Step 3: Implement minimal storage write path**
|
||||
|
||||
In `_persist_tool_call_result`:
|
||||
- build `full_payload` from event fields.
|
||||
- call summary engine for `content`.
|
||||
- upload payload via tool result storage (inject dependency if needed).
|
||||
- store only path/index in metadata.
|
||||
|
||||
**Step 4: Run tests (GREEN)**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_store.py -q`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 5: Add fallback test and implementation**
|
||||
|
||||
Add case where storage upload fails but tool message still persists with summary and no crash.
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/core/agentscope/events/store.py backend/tests/unit/core/agentscope/events/test_store.py
|
||||
git commit -m "feat: store tool payload in object storage with metadata index"
|
||||
```
|
||||
|
||||
### Task 3: Hydrate History Tool UI from Metadata Storage Path
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/repository.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_repository.py`
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
Add/adjust assertions:
|
||||
- history tool payload resolves `ui` from storage payload.
|
||||
- when storage missing, fallback to `messages.content` summary.
|
||||
|
||||
**Step 2: Run tests (RED)**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_repository.py -q`
|
||||
Expected: FAIL on `ui` hydration and fallback assertions.
|
||||
|
||||
**Step 3: Implement minimal hydration logic**
|
||||
|
||||
In `_to_snapshot_message` for tool role:
|
||||
- read storage via `metadata.storage_bucket/storage_path`.
|
||||
- map hydrated payload fields to snapshot (`ui`, `content`, `toolCallId`).
|
||||
- keep safe fallback when storage read fails.
|
||||
|
||||
**Step 4: Run tests (GREEN)**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_repository.py -q`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/agent/repository.py backend/tests/unit/v1/agent/test_repository.py
|
||||
git commit -m "fix: hydrate tool ui from metadata storage in history snapshots"
|
||||
```
|
||||
|
||||
### Task 4: Keep SSE TOOL_CALL_RESULT Compatible with Existing Frontend Parsing
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/core/agentscope/runtime/agent_route_runtime.py`
|
||||
- Test: `backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py`
|
||||
|
||||
**Step 1: Write failing test**
|
||||
|
||||
Add assertion that emitted `TOOL_CALL_RESULT` data contains expected renderable fields (`callId/toolName/result/error` and `ui` path from result payload).
|
||||
|
||||
**Step 2: Run tests (RED)**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py -q`
|
||||
Expected: FAIL on missing/incorrect payload fields.
|
||||
|
||||
**Step 3: Implement minimal payload normalization**
|
||||
|
||||
Normalize tool result event payload so frontend can keep current parsing without contract breaks.
|
||||
|
||||
**Step 4: Run tests (GREEN)**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py -q`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/core/agentscope/runtime/agent_route_runtime.py backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py
|
||||
git commit -m "fix: preserve frontend-compatible tool result event payload"
|
||||
```
|
||||
|
||||
### Task 5: Wire Frontend History + Events to Unified Rendering Path
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/features/chat/data/services/ag_ui_service.dart`
|
||||
- Modify: `apps/lib/features/chat/presentation/bloc/chat_bloc.dart`
|
||||
- Modify: `apps/lib/features/chat/data/models/tool_result.dart`
|
||||
- Modify: `apps/lib/features/home/ui/screens/home_screen.dart`
|
||||
- Test: `apps/test/features/chat/ag_ui_service_test.dart`
|
||||
- Create/Modify: `apps/test/features/chat/chat_bloc_test.dart`
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
Add tests asserting:
|
||||
- history tool message with `ui` becomes `ToolResultItem`.
|
||||
- SSE `TOOL_CALL_RESULT` with `ui` renders same item shape.
|
||||
- attachments in history user message are mapped for multimodal rendering.
|
||||
|
||||
**Step 2: Run tests (RED)**
|
||||
|
||||
Run: `cd apps && flutter test test/features/chat/ag_ui_service_test.dart`
|
||||
Expected: FAIL on new mapping assertions.
|
||||
|
||||
**Step 3: Implement minimal mapping changes**
|
||||
|
||||
- In service/bloc, unify history and event mapping into same conversion path.
|
||||
- Keep existing `UiSchemaRenderer` input format untouched.
|
||||
- Ensure fallback to content text when `ui` missing.
|
||||
|
||||
**Step 4: Run tests (GREEN)**
|
||||
|
||||
Run: `cd apps && flutter test test/features/chat/ag_ui_service_test.dart`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/chat/data/services/ag_ui_service.dart apps/lib/features/chat/presentation/bloc/chat_bloc.dart apps/lib/features/chat/data/models/tool_result.dart apps/lib/features/home/ui/screens/home_screen.dart apps/test/features/chat/ag_ui_service_test.dart apps/test/features/chat/chat_bloc_test.dart
|
||||
git commit -m "feat: unify realtime and history tool card rendering"
|
||||
```
|
||||
|
||||
### Task 6: Add Step Event Rendering for Intent/Execution/Report
|
||||
|
||||
**Files:**
|
||||
- Modify: `apps/lib/features/chat/presentation/bloc/chat_bloc.dart`
|
||||
- Modify: `apps/lib/features/home/ui/screens/home_screen.dart`
|
||||
- Test: `apps/test/features/chat/chat_bloc_test.dart`
|
||||
|
||||
**Step 1: Write failing test**
|
||||
|
||||
Add test verifying `STEP_STARTED/STEP_FINISHED` transitions produce visible stage state.
|
||||
|
||||
**Step 2: Run tests (RED)**
|
||||
|
||||
Run: `cd apps && flutter test test/features/chat/chat_bloc_test.dart`
|
||||
Expected: FAIL on missing stage state.
|
||||
|
||||
**Step 3: Implement minimal state and UI**
|
||||
|
||||
- Track current stage enum in `ChatState`.
|
||||
- Render compact stage progress row in chat screen.
|
||||
|
||||
**Step 4: Run tests (GREEN)**
|
||||
|
||||
Run: `cd apps && flutter test test/features/chat/chat_bloc_test.dart`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add apps/lib/features/chat/presentation/bloc/chat_bloc.dart apps/lib/features/home/ui/screens/home_screen.dart apps/test/features/chat/chat_bloc_test.dart
|
||||
git commit -m "feat: render agent step progress from AG-UI events"
|
||||
```
|
||||
|
||||
### Task 7: Verification Gate (Backend + Frontend)
|
||||
|
||||
**Files:**
|
||||
- Modify (if needed): `docs/plans/2026-03-11-agent-multimodal-smoke-runbook.md`
|
||||
|
||||
**Step 1: Run backend targeted tests**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_tool_result_summary.py backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/v1/agent/test_repository.py backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py -q`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 2: Run frontend targeted tests**
|
||||
|
||||
Run: `cd apps && flutter test test/features/chat/ag_ui_service_test.dart test/features/chat/chat_bloc_test.dart`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 3: Run backend quality checks**
|
||||
|
||||
Run: `uv run ruff check backend/src backend/tests`
|
||||
Expected: PASS.
|
||||
|
||||
**Step 4: Run backend type checks**
|
||||
|
||||
Run: `uv run basedpyright`
|
||||
Expected: 0 errors.
|
||||
|
||||
**Step 5: Update runbook evidence**
|
||||
|
||||
Record changed contract, test evidence, and known follow-ups.
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add docs/plans/2026-03-11-agent-multimodal-smoke-runbook.md
|
||||
git commit -m "docs: record tool ui schema storage and rendering verification"
|
||||
```
|
||||
Reference in New Issue
Block a user