feat: 添加 Agent 步骤事件与图片附件功能

- 新增 stepStarted/stepFinished 事件类型支持
- 前端实现图片附件上传和预览功能
- 后端增强工具结果存储和事件处理
- 完善相关单元测试和集成测试
This commit is contained in:
zl-q
2026-03-12 09:29:57 +08:00
parent 87215f9d41
commit 7b8865e256
45 changed files with 3869 additions and 308 deletions
-1
View File
@@ -50,7 +50,6 @@ class ApiClient implements IApiClient {
};
}
@override
Future<Response<T>> get<T>(String path, {Options? options}) async {
try {
return await _dio.get<T>(path, options: options);
+8 -17
View File
@@ -40,7 +40,11 @@ class MockApiClient implements IApiClient {
_handlers[key] = handler;
}
void registerPatternHandler(RegExp pattern, String method, MockHandler handler) {
void registerPatternHandler(
RegExp pattern,
String method,
MockHandler handler,
) {
_patternHandlers.add(
_PatternRoute(
pattern: pattern,
@@ -96,11 +100,7 @@ class MockApiClient implements IApiClient {
final direct = _handlers[key];
if (direct != null) {
final response = direct(
MockRequest(
path: path,
method: 'SSE',
headers: headers,
),
MockRequest(path: path, method: 'SSE', headers: headers),
);
if (response is Stream<String>) {
return response;
@@ -118,11 +118,7 @@ class MockApiClient implements IApiClient {
continue;
}
final response = route.handler(
MockRequest(
path: path,
method: 'SSE',
headers: headers,
),
MockRequest(path: path, method: 'SSE', headers: headers),
);
if (response is Stream<String>) {
return response;
@@ -147,12 +143,7 @@ class MockApiClient implements IApiClient {
if (handler != null) {
final response = handler(
MockRequest(
path: path,
method: method,
data: data,
options: options,
),
MockRequest(path: path, method: method, data: data, options: options),
);
if (response is Response) {
return response as Response<T>;
@@ -9,6 +9,8 @@ class AgUiEventTypeWire {
static const runStarted = 'RUN_STARTED';
static const runFinished = 'RUN_FINISHED';
static const runError = 'RUN_ERROR';
static const stepStarted = 'STEP_STARTED';
static const stepFinished = 'STEP_FINISHED';
static const textMessageStart = 'TEXT_MESSAGE_START';
static const textMessageContent = 'TEXT_MESSAGE_CONTENT';
static const textMessageEnd = 'TEXT_MESSAGE_END';
@@ -25,6 +27,8 @@ enum AgUiEventType {
runStarted,
runFinished,
runError,
stepStarted,
stepFinished,
textMessageStart,
textMessageContent,
textMessageEnd,
@@ -43,6 +47,8 @@ const _wireToTypeMap = {
AgUiEventTypeWire.runStarted: AgUiEventType.runStarted,
AgUiEventTypeWire.runFinished: AgUiEventType.runFinished,
AgUiEventTypeWire.runError: AgUiEventType.runError,
AgUiEventTypeWire.stepStarted: AgUiEventType.stepStarted,
AgUiEventTypeWire.stepFinished: AgUiEventType.stepFinished,
AgUiEventTypeWire.textMessageStart: AgUiEventType.textMessageStart,
AgUiEventTypeWire.textMessageContent: AgUiEventType.textMessageContent,
AgUiEventTypeWire.textMessageEnd: AgUiEventType.textMessageEnd,
@@ -60,6 +66,8 @@ const _typeToWireMap = {
AgUiEventType.runStarted: AgUiEventTypeWire.runStarted,
AgUiEventType.runFinished: AgUiEventTypeWire.runFinished,
AgUiEventType.runError: AgUiEventTypeWire.runError,
AgUiEventType.stepStarted: AgUiEventTypeWire.stepStarted,
AgUiEventType.stepFinished: AgUiEventTypeWire.stepFinished,
AgUiEventType.textMessageStart: AgUiEventTypeWire.textMessageStart,
AgUiEventType.textMessageContent: AgUiEventTypeWire.textMessageContent,
AgUiEventType.textMessageEnd: AgUiEventTypeWire.textMessageEnd,
@@ -83,6 +91,8 @@ final _typeToFactory = {
AgUiEventType.runStarted: RunStartedEvent.fromJson,
AgUiEventType.runFinished: RunFinishedEvent.fromJson,
AgUiEventType.runError: RunErrorEvent.fromJson,
AgUiEventType.stepStarted: StepStartedEvent.fromJson,
AgUiEventType.stepFinished: StepFinishedEvent.fromJson,
AgUiEventType.textMessageStart: TextMessageStartEvent.fromJson,
AgUiEventType.textMessageContent: TextMessageContentEvent.fromJson,
AgUiEventType.textMessageEnd: TextMessageEndEvent.fromJson,
@@ -170,6 +180,34 @@ class RunErrorEvent extends AgUiEvent {
Map<String, dynamic> toJson() => _$RunErrorEventToJson(this);
}
/// Event signalling that a named agent step has begun.
///
/// Always carries the event type [AgUiEventType.stepStarted]; the only
/// payload is [stepName].
@JsonSerializable()
class StepStartedEvent extends AgUiEvent {
/// Name of the step that started, exactly as supplied by the producer.
final String stepName;
StepStartedEvent({required this.stepName})
: super(type: AgUiEventType.stepStarted);
/// Builds the event from its decoded JSON map via the generated helper.
factory StepStartedEvent.fromJson(Map<String, dynamic> json) =>
_$StepStartedEventFromJson(json);
/// Serialises the event through the generated helper.
@override
Map<String, dynamic> toJson() => _$StepStartedEventToJson(this);
}
/// Event signalling that a named agent step has completed.
///
/// Always carries the event type [AgUiEventType.stepFinished]; the only
/// payload is [stepName].
@JsonSerializable()
class StepFinishedEvent extends AgUiEvent {
/// Name of the step that finished, exactly as supplied by the producer.
final String stepName;
StepFinishedEvent({required this.stepName})
: super(type: AgUiEventType.stepFinished);
/// Builds the event from its decoded JSON map via the generated helper.
factory StepFinishedEvent.fromJson(Map<String, dynamic> json) =>
_$StepFinishedEventFromJson(json);
/// Serialises the event through the generated helper.
@override
Map<String, dynamic> toJson() => _$StepFinishedEventToJson(this);
}
@JsonSerializable()
class TextMessageStartEvent extends AgUiEvent {
final String messageId;
@@ -310,10 +348,33 @@ class ToolCallResultEvent extends AgUiEvent {
factory ToolCallResultEvent.fromJson(Map<String, dynamic> json) {
final rawContent = json['content'];
final content = rawContent is String ? rawContent : '';
final hasStructuredFields =
json['ui'] != null || json['result'] != null || json['error'] != null;
final content = switch (rawContent) {
String value when value.trim().startsWith('{') => value,
String value when value.trim().startsWith('[') => value,
String value when hasStructuredFields => jsonEncode({
'toolName': json['toolName'],
'result': json['result'],
'error': json['error'],
'ui': json['ui'],
'content': value,
}),
String value => value,
_ => jsonEncode({
'toolName': json['toolName'],
'result': json['result'],
'error': json['error'],
'ui': json['ui'],
'content': json['content'],
}),
};
final toolCallId =
json['toolCallId'] as String? ?? json['callId'] as String? ?? '';
final messageId = json['messageId'] as String? ?? 'tool-result-$toolCallId';
return ToolCallResultEvent(
messageId: json['messageId'] as String,
toolCallId: json['toolCallId'] as String,
messageId: messageId,
toolCallId: toolCallId,
content: content,
);
}
@@ -388,6 +449,7 @@ class SnapshotMessage {
final String? toolCallId;
final UiCard? ui;
final DateTime? timestamp;
final List<Map<String, dynamic>>? attachments;
SnapshotMessage({
required this.id,
@@ -396,6 +458,7 @@ class SnapshotMessage {
this.toolCallId,
this.ui,
this.timestamp,
this.attachments,
});
factory SnapshotMessage.fromJson(Map<String, dynamic> json) =>
@@ -14,6 +14,8 @@ const _$AgUiEventTypeEnumMap = {
AgUiEventType.runStarted: 'runStarted',
AgUiEventType.runFinished: 'runFinished',
AgUiEventType.runError: 'runError',
AgUiEventType.stepStarted: 'stepStarted',
AgUiEventType.stepFinished: 'stepFinished',
AgUiEventType.textMessageStart: 'textMessageStart',
AgUiEventType.textMessageContent: 'textMessageContent',
AgUiEventType.textMessageEnd: 'textMessageEnd',
@@ -53,6 +55,18 @@ RunErrorEvent _$RunErrorEventFromJson(Map<String, dynamic> json) =>
Map<String, dynamic> _$RunErrorEventToJson(RunErrorEvent instance) =>
<String, dynamic>{'message': instance.message, 'code': instance.code};
StepStartedEvent _$StepStartedEventFromJson(Map<String, dynamic> json) =>
StepStartedEvent(stepName: json['stepName'] as String);
Map<String, dynamic> _$StepStartedEventToJson(StepStartedEvent instance) =>
<String, dynamic>{'stepName': instance.stepName};
StepFinishedEvent _$StepFinishedEventFromJson(Map<String, dynamic> json) =>
StepFinishedEvent(stepName: json['stepName'] as String);
Map<String, dynamic> _$StepFinishedEventToJson(StepFinishedEvent instance) =>
<String, dynamic>{'stepName': instance.stepName};
TextMessageStartEvent _$TextMessageStartEventFromJson(
Map<String, dynamic> json,
) => TextMessageStartEvent(
@@ -170,6 +184,9 @@ SnapshotMessage _$SnapshotMessageFromJson(Map<String, dynamic> json) =>
timestamp: json['timestamp'] == null
? null
: DateTime.parse(json['timestamp'] as String),
attachments: (json['attachments'] as List<dynamic>?)
?.whereType<Map<String, dynamic>>()
.toList(),
);
Map<String, dynamic> _$SnapshotMessageToJson(SnapshotMessage instance) =>
@@ -180,4 +197,5 @@ Map<String, dynamic> _$SnapshotMessageToJson(SnapshotMessage instance) =>
'toolCallId': instance.toolCallId,
'ui': instance.ui,
'timestamp': instance.timestamp?.toIso8601String(),
'attachments': instance.attachments,
};
@@ -22,6 +22,7 @@ class TextMessageItem extends ChatListItem {
@override
final MessageSender sender;
final bool isStreaming;
final List<Map<String, dynamic>> attachments;
TextMessageItem({
required this.id,
@@ -29,6 +30,7 @@ class TextMessageItem extends ChatListItem {
required this.timestamp,
required this.sender,
this.isStreaming = false,
this.attachments = const [],
});
@override
@@ -40,12 +42,14 @@ class TextMessageItem extends ChatListItem {
DateTime? timestamp,
MessageSender? sender,
bool? isStreaming,
List<Map<String, dynamic>>? attachments,
}) => TextMessageItem(
id: id ?? this.id,
content: content ?? this.content,
timestamp: timestamp ?? this.timestamp,
sender: sender ?? this.sender,
isStreaming: isStreaming ?? this.isStreaming,
attachments: attachments ?? this.attachments,
);
}
@@ -1,6 +1,7 @@
import 'dart:async';
import 'dart:convert';
import 'dart:math';
import 'dart:typed_data';
import 'package:dio/dio.dart';
import 'package:image_picker/image_picker.dart';
@@ -80,6 +81,18 @@ class AgUiService {
onEvent(event);
}
/// Downloads the raw bytes of an attachment preview.
///
/// Issues a GET for [previewPath] requesting a byte response and returns
/// the body copied into a [Uint8List].
///
/// Throws a [StateError] when the response body is not a list of bytes.
Future<Uint8List> fetchAttachmentPreview(String previewPath) async {
  final bytesOptions = Options(responseType: ResponseType.bytes);
  final response = await _apiClient.get<List<int>>(
    previewPath,
    options: bytesOptions,
  );
  final body = response.data;
  if (body is! List<int>) {
    throw StateError('Invalid attachment payload');
  }
  return Uint8List.fromList(body);
}
Future<String> transcribeAudio(String filePath) async {
final formData = FormData.fromMap({
'audio': await MultipartFile.fromFile(
@@ -247,22 +260,27 @@ class AgUiService {
final runId = _nextId(_runIdPrefix);
final contentBlocks = <Map<String, dynamic>>[];
final attachmentMetadata = <Map<String, dynamic>>[];
if (content.isNotEmpty) {
contentBlocks.add({'type': 'text', 'text': content});
}
if (images != null && images.isNotEmpty) {
for (final image in images) {
final bytes = await image.readAsBytes();
final base64 = base64Encode(bytes);
final uploadedAttachments = await _uploadAttachments(
threadId: threadId,
images: images,
);
for (final attachment in uploadedAttachments) {
contentBlocks.add({
'type': 'image',
'source': {
'type': 'base64',
'media_type': 'image/jpeg',
'data': base64,
},
'type': 'binary',
'mimeType': attachment['mimeType'],
'url': attachment['url'],
});
attachmentMetadata.add({
'bucket': attachment['bucket'],
'path': attachment['path'],
'mimeType': attachment['mimeType'],
});
}
}
@@ -286,10 +304,64 @@ class AgUiService {
],
'tools': _buildTools(),
'context': <Map<String, dynamic>>[],
'forwardedProps': <String, dynamic>{},
'forwardedProps': {
if (attachmentMetadata.isNotEmpty) 'attachments': attachmentMetadata,
},
};
}
/// Uploads every image in [images] and returns the server's references.
///
/// Images are posted one at a time, in order, as multipart forms to
/// `/api/v1/agent/attachments` together with [threadId]. Each returned map
/// holds the `bucket`, `path`, `mimeType` and `url` reported by the server.
///
/// Throws a [StateError] when a response is not a JSON object, lacks an
/// `attachment` object, or when any of the four reference fields is missing,
/// not a string, or empty.
Future<List<Map<String, dynamic>>> _uploadAttachments({
  required String threadId,
  required List<XFile> images,
}) async {
  // Reads a mandatory, non-empty string field from a server reference.
  String requireText(Map<String, dynamic> reference, String key) {
    final value = reference[key];
    if (value is String && value.isNotEmpty) {
      return value;
    }
    throw StateError('Invalid attachment reference');
  }

  final uploaded = <Map<String, dynamic>>[];
  for (final image in images) {
    // The picker does not always report a MIME type; fall back to JPEG.
    final declaredMime = image.mimeType ?? 'image/jpeg';
    final bytes = await image.readAsBytes();
    final response = await _apiClient.post<Map<String, dynamic>>(
      '/api/v1/agent/attachments',
      data: FormData.fromMap({
        'threadId': threadId,
        'file': MultipartFile.fromBytes(
          bytes,
          filename: image.name,
          contentType: DioMediaType.parse(declaredMime),
        ),
      }),
    );

    final body = response.data;
    if (body is! Map<String, dynamic>) {
      throw StateError('Invalid /agent/attachments response');
    }
    final reference = body['attachment'];
    if (reference is! Map<String, dynamic>) {
      throw StateError('Missing attachment in /agent/attachments response');
    }

    uploaded.add({
      'bucket': requireText(reference, 'bucket'),
      'path': requireText(reference, 'path'),
      'mimeType': requireText(reference, 'mimeType'),
      'url': requireText(reference, 'url'),
    });
  }
  return uploaded;
}
List<Map<String, dynamic>> _buildTools() {
return [
{
@@ -360,6 +432,11 @@ class AgUiService {
'SSE',
_handleMockSse,
);
client.registerHandler(
'/api/v1/agent/attachments',
'POST',
_handleMockUploadAttachment,
);
client.registerHandler(
'/api/v1/agent/transcribe',
'POST',
@@ -371,6 +448,26 @@ class AgUiService {
return {'transcript': '这是模拟语音转写'};
}
/// Mock handler for `POST /api/v1/agent/attachments`.
///
/// Returns a fake attachment reference whose path embeds the thread id.
/// The thread id is taken from the request body when present: either a
/// plain JSON map or the multipart [FormData] that the real upload path
/// sends (previously only maps were inspected, so the `threadId` of an
/// actual upload was always ignored). When no usable id is supplied the
/// service's current thread id, or a fresh UUID, is used instead.
Map<String, dynamic> _handleMockUploadAttachment(MockRequest request) {
  final payload = request.data;

  String? threadId;
  if (payload is Map<String, dynamic>) {
    final raw = payload['threadId'];
    threadId = raw is String ? raw : null;
  } else if (payload is FormData) {
    // Multipart uploads carry scalar values as key/value field entries.
    for (final field in payload.fields) {
      if (field.key == 'threadId') {
        threadId = field.value;
        break;
      }
    }
  }

  final resolvedThreadId = (threadId != null && threadId.isNotEmpty)
      ? threadId
      : (_threadId ?? _newUuid());
  final path =
      'agent-inputs/mock/$resolvedThreadId/${_nextId('upload_')}.png';
  return {
    'attachment': {
      'bucket': 'mock-bucket',
      'path': path,
      'mimeType': 'image/png',
      'url': 'https://mock.local/$path',
    },
  };
}
Map<String, dynamic> _handleMockRun(MockRequest request) {
final payload = request.data;
final runInput = payload is Map<String, dynamic>
@@ -1,4 +1,5 @@
import 'dart:convert';
import 'dart:typed_data';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:image_picker/image_picker.dart';
@@ -10,6 +11,8 @@ import '../../data/models/ag_ui_event.dart';
import '../../data/models/chat_list_item.dart';
import '../../data/services/ag_ui_service.dart';
enum AgentStage { intent, execution, report }
class ChatState {
final List<ChatListItem> items;
final bool isSending;
@@ -21,6 +24,7 @@ class ChatState {
final String? error;
final DateTime? oldestLoadedDate;
final bool hasEarlierHistory;
final AgentStage? currentStage;
const ChatState({
this.items = const [],
@@ -33,6 +37,7 @@ class ChatState {
this.error,
this.oldestLoadedDate,
this.hasEarlierHistory = false,
this.currentStage,
});
bool get isLoading =>
@@ -55,6 +60,7 @@ class ChatState {
Object? error = _unset,
Object? oldestLoadedDate = _unset,
bool? hasEarlierHistory,
Object? currentStage = _unset,
}) {
return ChatState(
items: items ?? this.items,
@@ -71,6 +77,9 @@ class ChatState {
? this.oldestLoadedDate
: oldestLoadedDate as DateTime?,
hasEarlierHistory: hasEarlierHistory ?? this.hasEarlierHistory,
currentStage: currentStage == _unset
? this.currentStage
: currentStage as AgentStage?,
);
}
}
@@ -78,6 +87,9 @@ class ChatState {
class ChatBloc extends Cubit<ChatState> {
final AgUiService _service;
final Map<String, String> _toolCallArgsBuffer = {};
final Map<String, Uint8List> _attachmentPreviewCache = <String, Uint8List>{};
final Map<String, Future<Uint8List?>> _attachmentPreviewInflight =
<String, Future<Uint8List?>>{};
ChatBloc({AgUiService? service, IApiClient? apiClient})
: _service =
@@ -102,6 +114,7 @@ class ChatBloc extends Cubit<ChatState> {
isWaitingFirstToken: true,
isCancelling: false,
error: null,
currentStage: null,
),
);
case AgUiEventType.runFinished:
@@ -112,6 +125,7 @@ class ChatBloc extends Cubit<ChatState> {
isStreaming: false,
isCancelling: false,
currentMessageId: null,
currentStage: null,
),
);
case AgUiEventType.runError:
@@ -124,8 +138,13 @@ class ChatBloc extends Cubit<ChatState> {
isCancelling: false,
currentMessageId: null,
error: errorEvent.message,
currentStage: null,
),
);
case AgUiEventType.stepStarted:
_handleStepStarted(event as StepStartedEvent);
case AgUiEventType.stepFinished:
_handleStepFinished(event as StepFinishedEvent);
case AgUiEventType.textMessageStart:
_handleTextMessageStart(event as TextMessageStartEvent);
case AgUiEventType.textMessageContent:
@@ -151,6 +170,16 @@ class ChatBloc extends Cubit<ChatState> {
}
}
/// Publishes the stage matching the step that just started.
///
/// The step name is translated with [_stageFromName]; a name that does not
/// map to a stage results in the current stage being cleared.
void _handleStepStarted(StepStartedEvent event) {
  final startedStage = _stageFromName(event.stepName);
  emit(state.copyWith(currentStage: startedStage));
}
/// Clears the current stage when the step that owns it finishes.
///
/// Nothing is emitted when the finished step is not the one currently
/// shown, or when its name does not map to a known stage. Without the
/// latter guard an unrecognised step name (which maps to `null`) would
/// compare equal to an already-empty stage and trigger a redundant emit.
void _handleStepFinished(StepFinishedEvent event) {
  final finishedStage = _stageFromName(event.stepName);
  if (finishedStage == null || finishedStage != state.currentStage) {
    return;
  }
  emit(state.copyWith(currentStage: null));
}
void _handleTextMessageStart(TextMessageStartEvent startEvent) {
final newMessage = TextMessageItem(
id: startEvent.messageId,
@@ -327,6 +356,7 @@ class ChatBloc extends Cubit<ChatState> {
content: msg.content ?? '',
timestamp: timestamp,
sender: MessageSender.user,
attachments: msg.attachments ?? const [],
);
case 'assistant':
return TextMessageItem(
@@ -369,11 +399,20 @@ class ChatBloc extends Cubit<ChatState> {
}
Future<void> sendMessage(String content, {List<XFile>? images}) async {
final attachments = (images ?? const <XFile>[])
.map(
(image) => <String, dynamic>{
"path": image.path,
"mimeType": "image/*",
},
)
.toList();
final userMessage = TextMessageItem(
id: 'user-${DateTime.now().millisecondsSinceEpoch}',
content: content,
timestamp: DateTime.now(),
sender: MessageSender.user,
attachments: attachments,
);
emit(
state.copyWith(
@@ -509,7 +548,43 @@ class ChatBloc extends Cubit<ChatState> {
}
}
/// Loads the preview bytes for [previewPath], with caching and de-duplication.
///
/// Returns the cached bytes when a previous load succeeded. If a load for
/// the same path is already running, its future is returned so the service
/// is called only once. Completes with `null` when the fetch fails; failures
/// are not cached, so a later call retries.
///
/// The previous implementation chained `.catchError((_) => null)` onto a
/// `Future<Uint8List>`. A `catchError` handler must return a value of the
/// future's own type, so returning `null` there made a failed fetch complete
/// with an error instead of `null`. Using try/catch inside a
/// `Future<Uint8List?>` function gives the intended best-effort result.
Future<Uint8List?> loadAttachmentPreview(String previewPath) async {
  final cached = _attachmentPreviewCache[previewPath];
  if (cached != null) {
    return cached;
  }
  final pending = _attachmentPreviewInflight[previewPath];
  if (pending != null) {
    return pending;
  }

  Future<Uint8List?> fetchAndCache() async {
    try {
      final bytes = await _service.fetchAttachmentPreview(previewPath);
      _attachmentPreviewCache[previewPath] = bytes;
      return bytes;
    } catch (_) {
      // Deliberately best-effort: callers treat null as "no preview".
      return null;
    } finally {
      // Runs only after the await above, i.e. after the in-flight entry
      // below has been registered.
      _attachmentPreviewInflight.remove(previewPath);
    }
  }

  final future = fetchAndCache();
  _attachmentPreviewInflight[previewPath] = future;
  return future;
}
void clearError() {
emit(state.copyWith(error: null));
}
}
/// Maps a wire-level step name to its [AgentStage].
///
/// Returns `null` for any name other than `intent`, `execution` or `report`.
AgentStage? _stageFromName(String value) => switch (value) {
  'intent' => AgentStage.intent,
  'execution' => AgentStage.execution,
  'report' => AgentStage.report,
  _ => null,
};
@@ -1,4 +1,5 @@
import 'dart:io';
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
@@ -34,12 +35,15 @@ const _rippleDurationMs = 1200;
const _recordingDotSize = 10.0;
const _transcribingSpinnerSize = 18.0;
const _transcribingStrokeWidth = 2.0;
const _attachmentPreviewSize = 88.0;
const _attachmentPreviewRadius = 10.0;
const _attachmentPreviewGap = 8.0;
const _inputActionButtonKey = ValueKey('home_input_action_button');
const _inputActionIconKey = ValueKey('home_input_action_icon');
/// 颜色常量
const _chatBgColor = Color(0xFFF8FAFC);
const _userBubbleColor = Color(0xFFEAF1FB);
const _chatBgColor = AppColors.slate50;
const _userBubbleColor = AppColors.blue50;
class HomeScreen extends StatefulWidget {
final VoiceRecorder? voiceRecorder;
@@ -265,7 +269,8 @@ class _HomeScreenState extends State<HomeScreen>
),
),
),
if (showWaitingIndicator) _buildWaitingIndicator(),
if (showWaitingIndicator)
_buildWaitingIndicator(currentStage: state.currentStage),
],
);
}
@@ -310,12 +315,19 @@ class _HomeScreenState extends State<HomeScreen>
),
),
),
if (showWaitingIndicator) _buildWaitingIndicator(),
if (showWaitingIndicator)
_buildWaitingIndicator(currentStage: state.currentStage),
],
);
}
Widget _buildWaitingIndicator() {
Widget _buildWaitingIndicator({required AgentStage? currentStage}) {
final label = switch (currentStage) {
AgentStage.intent => '意图识别中',
AgentStage.execution => '任务执行中',
AgentStage.report => '结果总结中',
null => '正在思考...',
};
return Padding(
padding: const EdgeInsets.fromLTRB(
_defaultPadding,
@@ -325,7 +337,7 @@ class _HomeScreenState extends State<HomeScreen>
),
child: Row(
crossAxisAlignment: CrossAxisAlignment.center,
children: const [
children: [
SizedBox(
width: _transcribingSpinnerSize,
height: _transcribingSpinnerSize,
@@ -336,7 +348,7 @@ class _HomeScreenState extends State<HomeScreen>
),
SizedBox(width: 8),
Text(
'正在思考...',
label,
style: TextStyle(fontSize: 14, color: AppColors.slate500),
),
],
@@ -413,38 +425,152 @@ class _HomeScreenState extends State<HomeScreen>
Widget _buildMessageItem(TextMessageItem item) {
final isUser = item.sender == MessageSender.user;
return Row(
mainAxisAlignment: isUser
? MainAxisAlignment.end
: MainAxisAlignment.start,
crossAxisAlignment: CrossAxisAlignment.start,
final imageAttachments = _collectRenderableImageAttachments(
item.attachments,
);
final hasRenderableAttachments = imageAttachments.isNotEmpty;
return Column(
crossAxisAlignment: isUser
? CrossAxisAlignment.end
: CrossAxisAlignment.start,
children: [
Flexible(
child: Container(
padding: const EdgeInsets.symmetric(
horizontal: _messagePaddingH,
vertical: _messagePaddingV,
),
decoration: BoxDecoration(
color: isUser ? _userBubbleColor : AppColors.white,
borderRadius: BorderRadius.only(
topLeft: const Radius.circular(_cornerRadius),
topRight: const Radius.circular(_cornerRadius),
bottomLeft: Radius.circular(isUser ? _cornerRadius : 0),
bottomRight: Radius.circular(isUser ? 0 : _cornerRadius),
Row(
mainAxisAlignment: isUser
? MainAxisAlignment.end
: MainAxisAlignment.start,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
Flexible(
child: Container(
padding: const EdgeInsets.symmetric(
horizontal: _messagePaddingH,
vertical: _messagePaddingV,
),
decoration: BoxDecoration(
color: isUser ? _userBubbleColor : AppColors.white,
borderRadius: BorderRadius.only(
topLeft: const Radius.circular(_cornerRadius),
topRight: const Radius.circular(_cornerRadius),
bottomLeft: Radius.circular(isUser ? _cornerRadius : 0),
bottomRight: Radius.circular(isUser ? 0 : _cornerRadius),
),
border: isUser ? null : Border.all(color: AppColors.slate300),
),
child: Text(
item.content,
style: const TextStyle(
fontSize: 14,
color: AppColors.slate900,
),
),
),
border: isUser ? null : Border.all(color: AppColors.slate300),
),
child: Text(
item.content,
style: const TextStyle(fontSize: 14, color: AppColors.slate900),
if (item.attachments.isNotEmpty && !hasRenderableAttachments) ...[
const SizedBox(width: _itemSpacing / 2),
_buildAttachmentBadge(item.attachments.length),
],
],
),
if (hasRenderableAttachments)
Padding(
padding: const EdgeInsets.only(top: _attachmentPreviewGap),
child: _buildHistoryAttachmentPreviews(
item.attachments,
imageAttachments: imageAttachments,
),
),
),
],
);
}
/// Builds the thumbnail strip for a message's attachments.
///
/// [imageAttachments] may be passed when the caller has already filtered
/// the renderable images; otherwise they are derived from [attachments].
/// When nothing is renderable a count badge for all [attachments] is
/// returned instead.
Widget _buildHistoryAttachmentPreviews(
  List<Map<String, dynamic>> attachments, {
  List<Map<String, dynamic>>? imageAttachments,
}) {
  final previews =
      imageAttachments ?? _collectRenderableImageAttachments(attachments);
  if (previews.isNotEmpty) {
    return Wrap(
      spacing: _attachmentPreviewGap,
      runSpacing: _attachmentPreviewGap,
      crossAxisAlignment: WrapCrossAlignment.start,
      children: [
        for (final preview in previews) _buildHistoryAttachmentTile(preview),
      ],
    );
  }
  return _buildAttachmentBadge(attachments.length);
}
/// Returns the entries of [attachments] that can be shown as thumbnails,
/// preserving their original order.
List<Map<String, dynamic>> _collectRenderableImageAttachments(
  List<Map<String, dynamic>> attachments,
) {
  return [
    for (final attachment in attachments)
      if (_isRenderableImageAttachment(attachment)) attachment,
  ];
}
/// Whether [attachment] describes an image that has a preview to load.
///
/// Requires a string `mimeType` starting with `image/` and a non-empty
/// string `previewPath`.
bool _isRenderableImageAttachment(Map<String, dynamic> attachment) {
  final mime = attachment['mimeType'];
  if (mime is! String || !mime.startsWith('image/')) {
    return false;
  }
  final preview = attachment['previewPath'];
  return preview is String && preview.isNotEmpty;
}
/// Builds one square, rounded thumbnail for [attachment].
///
/// Falls back to a single-item badge when `previewPath` is missing, not a
/// string, or empty. Otherwise the preview bytes are requested from the chat
/// bloc: a spinner is shown while the request is pending, a "broken image"
/// icon when no bytes come back, and the decoded image once data is present.
Widget _buildHistoryAttachmentTile(Map<String, dynamic> attachment) {
final previewPath = attachment['previewPath'];
if (previewPath is! String || previewPath.isEmpty) {
return _buildAttachmentBadge(1);
}
return ClipRRect(
borderRadius: BorderRadius.circular(_attachmentPreviewRadius),
child: Container(
width: _attachmentPreviewSize,
height: _attachmentPreviewSize,
color: AppColors.slate100,
child: FutureBuilder<Uint8List?>(
// NOTE(review): this future is created each time the tile is built.
// FutureBuilder expects a future obtained earlier (e.g. held in State);
// a new instance on rebuild restarts the builder in the waiting state,
// which may flash the spinner for previews that were already loaded.
// Consider memoising the future per previewPath — TODO confirm.
future: _chatBloc.loadAttachmentPreview(previewPath),
builder: (context, snapshot) {
// Still loading: centred spinner.
if (snapshot.connectionState == ConnectionState.waiting) {
return const Center(
child: SizedBox(
width: _transcribingSpinnerSize,
height: _transcribingSpinnerSize,
child: CircularProgressIndicator(
strokeWidth: _transcribingStrokeWidth,
),
),
);
}
final data = snapshot.data;
// No bytes (null or empty): show a placeholder icon instead.
if (data == null || data.isEmpty) {
return const Center(
child: Icon(
LucideIcons.imageOff,
size: _iconSize,
color: AppColors.slate500,
),
);
}
return Image.memory(data, fit: BoxFit.cover, gaplessPlayback: true);
},
),
),
);
}
/// Builds a small rounded label showing how many image attachments a
/// message carries, used when no thumbnail can be rendered.
Widget _buildAttachmentBadge(int count) {
  const labelStyle = TextStyle(fontSize: 12, color: AppColors.slate600);
  final badgeDecoration = BoxDecoration(
    color: AppColors.slate200,
    borderRadius: BorderRadius.circular(8),
  );
  return Container(
    padding: const EdgeInsets.symmetric(horizontal: 8, vertical: 4),
    decoration: badgeDecoration,
    child: Text('图片附件 x$count', style: labelStyle),
  );
}
Widget _buildToolCallItem(ToolCallItem item) {
final (statusText, statusColor, statusIcon) = switch (item.status) {
ToolCallStatus.pending => (
@@ -1,4 +1,5 @@
import 'dart:convert';
import 'dart:typed_data';
import 'package:flutter_test/flutter_test.dart';
import 'package:image_picker/image_picker.dart';
@@ -265,6 +266,71 @@ void main() {
expect(first['content'], '只发送当前输入');
});
test('sendMessage uploads images then posts binary url blocks', () async {
final client = MockApiClient();
final service = AgUiService(onEvent: (_) {}, apiClient: client);
client.clearMocks();
var uploadCalls = 0;
final uploadedPath = 'agent-inputs/user/thread-1/upload-1.png';
client.registerHandler('/api/v1/agent/attachments', 'POST', (request) {
uploadCalls += 1;
return {
'attachment': {
'bucket': 'bucket-test',
'path': uploadedPath,
'mimeType': 'image/png',
'url': 'https://signed.example/$uploadedPath',
},
};
});
Map<String, dynamic>? postedRunInput;
client.registerHandler('/api/v1/agent/runs', 'POST', (request) {
postedRunInput = request.data as Map<String, dynamic>;
return {
'taskId': 'task-1',
'threadId': 'thread-1',
'runId': 'run-1',
'created': false,
};
});
client.registerHandler('/api/v1/agent/runs/thread-1/events', 'SSE', (_) {
return <String>[
'event: RUN_STARTED',
'data: {"type":"RUN_STARTED","threadId":"thread-1","runId":"run-1"}',
'',
'event: RUN_FINISHED',
'data: {"type":"RUN_FINISHED","threadId":"thread-1","runId":"run-1"}',
'',
];
});
final image = XFile.fromData(
Uint8List.fromList(<int>[1, 2, 3]),
mimeType: 'image/png',
name: 'demo.png',
);
await service.sendMessage('图文消息', images: [image]);
expect(uploadCalls, 1);
expect(postedRunInput, isNotNull);
final messages = postedRunInput!['messages'] as List<dynamic>;
final first = messages.first as Map<String, dynamic>;
final content = first['content'] as List<dynamic>;
expect((content.first as Map<String, dynamic>)['type'], 'text');
expect((content[1] as Map<String, dynamic>)['type'], 'binary');
expect(
(content[1] as Map<String, dynamic>)['url'],
'https://signed.example/$uploadedPath',
);
final forwardedProps =
postedRunInput!['forwardedProps'] as Map<String, dynamic>;
final attachments = forwardedProps['attachments'] as List<dynamic>;
expect((attachments.first as Map<String, dynamic>)['path'], uploadedPath);
});
test('approveToolCall posts only tool message to resume API', () async {
final client = MockApiClient();
final service = AgUiService(onEvent: (_) {}, apiClient: client);
@@ -482,5 +548,56 @@ void main() {
expect(seenLastEventIds[0], isNull);
expect(seenLastEventIds[1], '2-0');
});
test('stream parses backend TOOL_CALL_RESULT payload with ui field', () async {
final events = <AgUiEvent>[];
final client = MockApiClient();
final service = AgUiService(onEvent: events.add, apiClient: client);
client.clearMocks();
client.registerHandler('/api/v1/agent/runs', 'POST', (_) {
return {
'taskId': 'task-1',
'threadId': 'thread-1',
'runId': 'run-1',
'created': false,
};
});
client.registerHandler('/api/v1/agent/runs/thread-1/events', 'SSE', (_) {
return <String>[
'event: RUN_STARTED',
'data: {"type":"RUN_STARTED","threadId":"thread-1","runId":"run-1"}',
'',
'event: TOOL_CALL_RESULT',
'data: {"type":"TOOL_CALL_RESULT","messageId":"tool-result-1","toolCallId":"call-1","callId":"call-1","toolName":"calendar_write","result":{"type":"calendar_operation.v1","version":"v1","data":{"ok":true,"operation":"create"},"actions":[]},"ui":{"type":"calendar_operation.v1","version":"v1","data":{"ok":true,"operation":"create"},"actions":[]},"content":"已创建日程:项目评审(明天 10:00)"}',
'',
'event: RUN_FINISHED',
'data: {"type":"RUN_FINISHED","threadId":"thread-1","runId":"run-1"}',
'',
];
});
await service.sendMessage('创建日程');
final result = events.whereType<ToolCallResultEvent>().toList();
expect(result.length, 1);
expect(result.first.ui?.cardType, 'calendar_operation.v1');
});
test('fetchAttachmentPreview returns binary bytes', () async {
final client = MockApiClient();
final service = AgUiService(onEvent: (_) {}, apiClient: client);
client.clearMocks();
client.registerHandler(
'/api/v1/agent/runs/t1/attachments/m1/0',
'GET',
(_) => <int>[1, 2, 3, 4],
);
final data = await service.fetchAttachmentPreview(
'/api/v1/agent/runs/t1/attachments/m1/0',
);
expect(data, [1, 2, 3, 4]);
});
});
}
@@ -1,3 +1,5 @@
import 'dart:typed_data';
import 'package:bloc_test/bloc_test.dart';
import 'package:flutter_test/flutter_test.dart';
import 'package:image_picker/image_picker.dart';
@@ -9,8 +11,17 @@ import 'package:social_app/features/chat/presentation/bloc/chat_bloc.dart';
/// Test double for [AgUiService] that performs no network work.
///
/// Sending a message is a no-op, and preview fetches return fixed bytes
/// after a short delay while counting how often they were requested.
class MockAgUiService extends AgUiService {
MockAgUiService() : super(onEvent: (_) {});
/// Number of times [fetchAttachmentPreview] has been invoked.
int previewCalls = 0;
/// Does nothing; tests drive the bloc by emitting events directly.
@override
Future<void> sendMessage(String content, {List<XFile>? images}) async {}
/// Records the call, waits 10 ms so concurrent callers overlap, then
/// returns the fixed bytes `[1, 2, 3]` regardless of [previewPath].
@override
Future<Uint8List> fetchAttachmentPreview(String previewPath) async {
previewCalls += 1;
await Future<void>.delayed(const Duration(milliseconds: 10));
return Uint8List.fromList(<int>[1, 2, 3]);
}
}
class _ThrowingAgUiService extends AgUiService {
@@ -182,6 +193,23 @@ void main() {
],
);
blocTest<ChatBloc, ChatState>(
'step events update currentStage',
build: () => chatBloc,
act: (bloc) {
service.onEvent(StepStartedEvent(stepName: 'execution'));
service.onEvent(StepFinishedEvent(stepName: 'execution'));
},
expect: () => [
isA<ChatState>().having(
(s) => s.currentStage,
'currentStage',
AgentStage.execution,
),
isA<ChatState>().having((s) => s.currentStage, 'currentStage', isNull),
],
);
blocTest<ChatBloc, ChatState>(
'runError sets error message',
build: () => chatBloc,
@@ -325,5 +353,58 @@ void main() {
),
],
);
blocTest<ChatBloc, ChatState>(
'state snapshot user message keeps attachments',
build: () => chatBloc,
act: (bloc) {
service.onEvent(
StateSnapshotEvent(
snapshot: {
'scope': 'history_day',
'messages': [
{
'id': 'u1',
'role': 'user',
'content': '请分析这张图',
'attachments': [
{'bucket': 'b', 'path': 'p', 'mimeType': 'image/png'},
],
},
],
},
),
);
},
expect: () => [
isA<ChatState>().having(
(s) {
final item = s.items.first;
return item is TextMessageItem && item.attachments.length == 1;
},
'user attachment count',
true,
),
],
);
test(
'loadAttachmentPreview deduplicates in-flight and caches result',
() async {
final mock = service as MockAgUiService;
final results = await Future.wait<Uint8List?>([
chatBloc.loadAttachmentPreview('/api/preview/1'),
chatBloc.loadAttachmentPreview('/api/preview/1'),
]);
final secondRound = await chatBloc.loadAttachmentPreview(
'/api/preview/1',
);
expect(results.first, isNotNull);
expect(results.last, isNotNull);
expect(secondRound, isNotNull);
expect(mock.previewCalls, 1);
},
);
});
}
@@ -6,11 +6,14 @@ import 'package:flutter_test/flutter_test.dart';
import 'package:image_picker/image_picker.dart';
import 'package:lucide_icons/lucide_icons.dart';
import 'package:social_app/core/api/api_exception.dart';
import 'package:social_app/core/api/mock_api_client.dart';
import 'package:social_app/core/di/injection.dart';
import 'package:social_app/features/chat/data/models/ag_ui_event.dart';
import 'package:social_app/features/chat/data/services/ag_ui_service.dart';
import 'package:social_app/features/chat/presentation/bloc/chat_bloc.dart';
import 'package:social_app/features/home/data/voice_recorder.dart';
import 'package:social_app/features/home/ui/screens/home_screen.dart';
import 'package:social_app/features/messages/data/inbox_api.dart';
class _FakeVoiceRecorder implements VoiceRecorder {
bool started = false;
@@ -43,9 +46,19 @@ class _WaitingAgUiService extends AgUiService {
onEvent(RunStartedEvent(threadId: 't1', runId: 'r1'));
return _pending.future;
}
void emitStepStarted(String stepName) {
onEvent(StepStartedEvent(stepName: stepName));
}
}
void main() {
// Registers a mock-backed InboxApi in the service locator once; a
// registration that already exists is left untouched.
setUpAll(() {
  if (sl.isRegistered<InboxApi>()) {
    return;
  }
  sl.registerSingleton<InboxApi>(InboxApi(MockApiClient()));
});
IconData _inputActionIcon(WidgetTester tester) {
final icon = tester.widget<Icon>(
find.byKey(const ValueKey('home_input_action_icon')),
@@ -275,7 +288,8 @@ void main() {
testWidgets('shows stop icon and waiting indicator while waiting agent', (
WidgetTester tester,
) async {
final chatBloc = ChatBloc(service: _WaitingAgUiService());
final waitingService = _WaitingAgUiService();
final chatBloc = ChatBloc(service: waitingService);
await tester.pumpWidget(
MaterialApp(
home: HomeScreen(autoLoadHistory: false, chatBloc: chatBloc),
@@ -291,6 +305,12 @@ void main() {
expect(_inputActionIcon(tester), LucideIcons.square);
expect(find.text('正在思考...'), findsOneWidget);
waitingService.emitStepStarted('intent');
await tester.pump();
expect(find.text('意图识别中'), findsOneWidget);
expect(find.text('正在思考...'), findsNothing);
await tester.tap(find.byKey(const ValueKey('home_input_action_button')));
await tester.pump();
@@ -37,6 +37,21 @@ def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
data = event.get("data")
if isinstance(data, dict):
if event_type == "tool.result":
for key in (
"messageId",
"toolCallId",
"callId",
"toolName",
"stage",
"taskId",
"ui",
"content",
):
value = data.get(key)
if value is not None:
payload[key] = value
return payload
reserved = {"type", "threadId", "runId"}
data_map = cast(dict[str, Any], data)
payload.update({k: v for k, v in data_map.items() if k not in reserved})
+171 -15
View File
@@ -1,11 +1,14 @@
from __future__ import annotations
import json
import re
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
from uuid import UUID, uuid4
from core.agentscope.events.tool_result_summary import build_tool_content_summary
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
@@ -14,6 +17,16 @@ class EventStore(Protocol):
async def persist(self, event: dict[str, Any]) -> None: ...
class ToolResultStorageLike(Protocol):
    """Structural interface for a backend that can upload a JSON payload.

    Any object exposing an awaitable ``upload_json`` with the keyword-only
    parameters below satisfies this protocol.
    """

    async def upload_json(
        self, *, bucket: str, path: str, payload: dict[str, object]
    ) -> str:
        """Upload ``payload`` to ``path`` inside ``bucket`` and return a string.

        The meaning of the returned string is defined by the implementation.
        """
class NullEventStore:
async def persist(self, event: dict[str, Any]) -> None:
del event
@@ -21,9 +34,20 @@ class NullEventStore:
class SqlAlchemyEventStore:
_session_factory: Callable[[], Any]
_tool_result_storage: ToolResultStorageLike | None
_tool_result_bucket: str | None
_logger = get_logger("core.agentscope.events.store")
def __init__(self, *, session_factory: Any) -> None:
def __init__(
self,
*,
session_factory: Any,
tool_result_storage: ToolResultStorageLike | None = None,
tool_result_bucket: str | None = None,
) -> None:
self._session_factory = session_factory
self._tool_result_storage = tool_result_storage
self._tool_result_bucket = tool_result_bucket
self._message_buffers: dict[tuple[str, str], str] = {}
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
@@ -228,23 +252,89 @@ class SqlAlchemyEventStore:
if not isinstance(tool_name, str) or not tool_name:
return
payload = {
"args": event.get("args"),
"result": event.get("result"),
"error": event.get("error"),
"call_id": event.get("callId"),
}
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
metadata: dict[str, object] = {"tool_name": tool_name}
run_id = event.get("runId")
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
task_id = event.get("taskId")
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
call_id_value = event.get("callId")
if not isinstance(call_id_value, str) or not call_id_value:
call_id_value = (
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
if run_id_value
else f"{task_id_value}-{uuid4().hex[:8]}"
)
summary = build_tool_content_summary(
tool_name=tool_name,
args=event.get("args") if isinstance(event.get("args"), dict) else None,
result=event.get("result"),
error=event.get("error"),
)
raw_result_value = event.get("result")
raw_result: dict[str, object] = (
raw_result_value if isinstance(raw_result_value, dict) else {}
)
ui_candidate = raw_result.get("ui")
ui_schema = ui_candidate if isinstance(ui_candidate, dict) else None
result_type = raw_result.get("type")
result_data = raw_result.get("data")
if (
ui_schema is None
and isinstance(result_type, str)
and isinstance(result_data, dict)
):
ui_schema = raw_result
payload: dict[str, object] = {
"toolName": tool_name,
"ui_schema": ui_schema,
"result": _sanitize_result(raw_result),
"error": _sanitize_error(event.get("error")),
"callId": call_id_value,
"runId": run_id_value,
"taskId": task_id_value,
"content": summary,
}
metadata: dict[str, object] = {
"tool_name": tool_name,
"tool_call_id": call_id_value,
"summary_version": "v1",
}
if run_id_value:
metadata["run_id"] = run_id_value
stage = event.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
task_id = event.get("taskId")
if isinstance(task_id, str) and task_id:
metadata["task_id"] = task_id
if task_id_value:
metadata["task_id"] = task_id_value
if self._tool_result_storage is not None and self._tool_result_bucket:
safe_run = _sanitize_path_component(run_id_value or "run")
safe_call = _sanitize_path_component(call_id_value)
storage_path = f"tool-results/{session_id}/{safe_run}/{safe_call}.json"
try:
await self._tool_result_storage.upload_json(
bucket=self._tool_result_bucket,
path=storage_path,
payload=payload,
)
metadata["storage_bucket"] = self._tool_result_bucket
metadata["storage_path"] = storage_path
except Exception: # noqa: BLE001
metadata["storage_upload_failed"] = True
self._logger.warning(
"tool result storage upload failed",
session_id=str(session_id),
run_id=run_id_value,
call_id=call_id_value,
storage_path=storage_path,
)
content = summary or json.dumps(
payload, ensure_ascii=False, separators=(",", ":")
)
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
@@ -333,3 +423,69 @@ class SqlAlchemyEventStore:
except (InvalidOperation, TypeError, ValueError):
return Decimal("0")
return parsed if parsed >= 0 else Decimal("0")
def _sanitize_path_component(value: str) -> str:
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
compact = compact.strip(".-")
return compact or "id"
def _sanitize_error(value: object) -> str | None:
if isinstance(value, str) and value.strip():
return " ".join(value.split())[:300]
if isinstance(value, dict):
for key in ("message", "error", "detail"):
item = value.get(key)
if isinstance(item, str) and item.strip():
return " ".join(item.split())[:300]
return None
def _sanitize_result(value: object) -> dict[str, object]:
if not isinstance(value, dict):
return {}
def _is_sensitive_key(key: str) -> bool:
normalized = key.strip().lower().replace("-", "_")
if not normalized:
return False
exact = {
"password",
"token",
"secret",
"api_key",
"apikey",
"credential",
"authorization",
"auth",
}
if normalized in exact:
return True
patterns = (
"password",
"token",
"secret",
"auth",
"credential",
"api_key",
"apikey",
"authorization",
)
return any(pattern in normalized for pattern in patterns)
def _sanitize_value(item: object) -> object:
if isinstance(item, dict):
return _sanitize_result(item)
if isinstance(item, list):
return [_sanitize_value(entry) for entry in item]
return item
sanitized: dict[str, object] = {}
for key, item in value.items():
key_text = str(key)
if _is_sensitive_key(key_text):
sanitized[str(key)] = "[REDACTED]"
continue
sanitized[str(key)] = _sanitize_value(item)
return sanitized
@@ -0,0 +1,124 @@
from __future__ import annotations
from typing import Any
def build_tool_content_summary(
    *,
    tool_name: str,
    args: dict[str, Any] | None,
    result: Any,
    error: Any,
) -> str:
    """Build a short one-line summary of a tool call outcome.

    Precedence: an explicit ``error`` message, then a business failure
    reported inside ``result``, then a tool-specific success sentence, then
    the result's own ``content``/``message`` text, and finally a generic
    "completed" sentence. Every return value passes through ``_truncate``.

    Args:
        tool_name: Name of the tool that ran.
        args: Call arguments; anything that is not a dict is treated as empty.
        result: Tool result; anything that is not a dict is treated as empty.
        error: Error value, either a string or a dict with a message field.

    Returns:
        The summary text.
    """
    error_message = _extract_error_message(error)
    if error_message is not None:
        return _truncate(f"{tool_name} 执行失败:{error_message}")

    normalized_args = args if isinstance(args, dict) else {}
    normalized_result = result if isinstance(result, dict) else {}

    business_failure = _extract_business_failure_message(normalized_result)
    if business_failure is not None:
        return _truncate(f"{tool_name} 执行失败:{business_failure}")

    if tool_name == "calendar_write":
        # Prefer the title echoed by the tool, fall back to the requested one.
        title = _pick_first_str(normalized_result, ("title",)) or _pick_first_str(
            normalized_args, ("title",)
        )
        start_at = _pick_first_str(normalized_result, ("startAt", "start_at"))
        if title and start_at:
            return _truncate(f"已创建日程:{title}{start_at}")
        if title:
            return _truncate(f"已创建日程:{title}")

    if tool_name == "calendar_read":
        total = _extract_total(normalized_result)
        query = _pick_first_str(normalized_args, ("query",)) or "全部"
        if total is not None:
            # The closing parenthesis was missing, leaving an unbalanced "(".
            return _truncate(f"查询到 {total} 条日程({query})")

    if tool_name == "calendar_delete":
        target = _pick_first_str(normalized_result, ("title", "eventId", "event_id"))
        if target:
            return _truncate(f"已删除日程:{target}")

    if tool_name == "calendar_share":
        target = _pick_first_str(normalized_result, ("target", "user", "userName"))
        if target:
            return _truncate(f"已分享日程给 {target}")

    if tool_name == "user_resolve":
        target = _pick_first_str(normalized_result, ("name", "userName", "userId"))
        if target:
            return _truncate(f"已匹配用户:{target}")

    # No tool-specific sentence applied: use the result's own text if any.
    result_content = _pick_first_str(normalized_result, ("content", "message"))
    if result_content:
        return _truncate(result_content)
    return _truncate(f"{tool_name} 执行完成")
def _extract_error_message(error: Any) -> str | None:
if isinstance(error, str) and error.strip():
return error.strip()
if isinstance(error, dict):
for key in ("message", "error", "detail"):
value = error.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _pick_first_str(payload: dict[str, Any], keys: tuple[str, ...]) -> str | None:
for key in keys:
value = payload.get(key)
if isinstance(value, str):
normalized = " ".join(value.split())
if normalized:
return normalized
return None
def _extract_total(result: dict[str, Any]) -> int | None:
candidates: list[Any] = [result.get("total")]
data = result.get("data")
if isinstance(data, dict):
candidates.append(data.get("total"))
events = data.get("events")
if isinstance(events, list):
candidates.append(len(events))
for value in candidates:
if isinstance(value, bool):
continue
if isinstance(value, int) and value >= 0:
return value
if isinstance(value, str) and value.isdigit():
return int(value)
return None
def _extract_business_failure_message(result: dict[str, Any]) -> str | None:
top_ok = result.get("ok")
if top_ok is False:
top_message = _pick_first_str(result, ("message", "error", "detail"))
if top_message:
return top_message
data = result.get("data")
if isinstance(data, dict) and data.get("ok") is False:
data_message = _pick_first_str(data, ("message", "error", "detail"))
if data_message:
return data_message
code = _pick_first_str(data, ("code",))
if code:
return code
return None
def _truncate(text: str, limit: int = 80) -> str:
normalized = " ".join(text.split())
if len(normalized) <= limit:
return normalized
return normalized[: limit - 3] + "..."
@@ -42,11 +42,19 @@ def build_intent_user_prompt(
*, user_input: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(user_input, list):
context_messages = _conversation_context_messages(user_input)
context_hint = (
json.dumps(context_messages, ensure_ascii=True, separators=(",", ":"))
if context_messages
else "[]"
)
instruction_text = "\n\n".join(
[
INTENT_TASK_INSTRUCTION,
"[Output Schema]",
_schema_json(IntentOutput),
"[Conversation Context]",
context_hint,
"[User Input]",
"Use the following multimodal blocks as the latest user input.",
]
@@ -127,6 +135,56 @@ def _latest_user_content_blocks(
return []
def _conversation_context_messages(
user_input: list[dict[str, Any]],
) -> list[dict[str, str]]:
latest_user_index = -1
for index in range(len(user_input) - 1, -1, -1):
item = user_input[index]
if isinstance(item, dict) and item.get("role") == "user":
latest_user_index = index
break
if latest_user_index <= 0:
return []
context_items: list[dict[str, str]] = []
for item in user_input[:latest_user_index]:
if not isinstance(item, dict):
continue
role = item.get("role")
if role not in {"user", "assistant"}:
continue
content = item.get("content")
text = _content_to_text(content)
if text:
context_items.append({"role": str(role), "content": text})
if len(context_items) <= 12:
return context_items
return context_items[-12:]
def _content_to_text(content: Any) -> str:
if isinstance(content, str):
return " ".join(content.split())
if not isinstance(content, list):
return ""
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type == "text":
text = block.get("text")
if isinstance(text, str) and text.strip():
parts.append(" ".join(text.split()))
elif block_type in {"binary", "image"}:
parts.append("[image]")
return " ".join(parts).strip()
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
mime_type = item.get("mimeType")
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import re
from typing import Any, Protocol
from uuid import UUID
import structlog
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
@@ -31,7 +33,9 @@ class PipelineLike(Protocol):
class AgentRouteRuntime:
_orchestrator: OrchestratorLike
_pipeline: PipelineLike
_logger = get_logger("core.agentscope.runtime.agent_route_runtime")
_logger: structlog.stdlib.BoundLogger = get_logger(
"core.agentscope.runtime.agent_route_runtime"
)
def __init__(
self, *, orchestrator: OrchestratorLike, pipeline: PipelineLike
@@ -144,15 +148,6 @@ class AgentRouteRuntime:
"data": {"stepName": "execution"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "execution"},
},
)
await self._emit_stage_text(
thread_id=command.thread_id,
@@ -191,6 +186,15 @@ class AgentRouteRuntime:
task_id=task.task_id,
tool_calls=_task_tool_calls(task),
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "execution"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
@@ -294,6 +298,13 @@ class AgentRouteRuntime:
tool_name = tool_call.get("tool_name")
if not isinstance(tool_name, str) or not tool_name:
continue
call_id = f"{run_id}-{task_id}-{index}"
result_payload = _build_tool_result_event_payload(
tool_name=tool_name,
call_id=call_id,
raw_result=tool_call.get("result"),
raw_error=tool_call.get("error"),
)
await self._pipeline.emit(
session_id=thread_id,
event={
@@ -301,18 +312,175 @@ class AgentRouteRuntime:
"threadId": thread_id,
"runId": run_id,
"data": {
"callId": f"{run_id}-{task_id}-{index}",
"messageId": result_payload["messageId"],
"toolCallId": call_id,
"callId": call_id,
"stage": "execution",
"taskId": task_id,
"toolName": tool_name,
"args": tool_call.get("args", {}),
"result": tool_call.get("result"),
"error": tool_call.get("error"),
"args": _sanitize_result(tool_call.get("args", {})),
"result": result_payload["result"],
"error": result_payload["error"],
"ui": result_payload["ui"],
"content": result_payload["content"],
},
},
)
def _build_tool_result_event_payload(
    *,
    tool_name: str,
    call_id: str,
    raw_result: Any,
    raw_error: Any,
) -> dict[str, Any]:
    """Assemble the fields of a tool-result event from raw tool output.

    The result is normalized and sanitized, the error is reduced to text,
    a UI schema is taken from ``result["ui"]`` (or from the result itself
    when it has a string ``type`` and dict ``data``), and a display text
    is chosen: result text, else the error, else a generic sentence.

    Returns:
        A dict with ``messageId``, ``result``, ``error``, ``ui`` and
        ``content``.
    """
    sanitized = _sanitize_result(_normalize_tool_result(raw_result))
    error_text = _sanitize_error(raw_error)

    embedded_ui = sanitized.get("ui")
    ui: dict[str, Any] | None
    if isinstance(embedded_ui, dict):
        ui = embedded_ui
    elif isinstance(sanitized.get("type"), str) and isinstance(
        sanitized.get("data"), dict
    ):
        ui = sanitized
    else:
        ui = None

    content = _extract_result_text_content(sanitized)
    if content is None:
        content = error_text if isinstance(error_text, str) else f"{tool_name} 执行完成"

    return {
        "messageId": f"tool-result-{call_id}",
        "result": sanitized,
        "error": error_text,
        "ui": ui,
        "content": content,
    }
def _normalize_tool_result(raw_result: Any) -> dict[str, Any]:
if isinstance(raw_result, dict):
content = raw_result.get("content")
if isinstance(content, str):
parsed = _try_parse_json_object(content)
if parsed is not None:
return parsed
return raw_result
if isinstance(raw_result, str):
parsed = _try_parse_json_object(raw_result)
if parsed is not None:
return parsed
text = raw_result.strip()
if text:
return {"content": text}
if raw_result is not None:
return {"value": raw_result}
return {}
def _try_parse_json_object(value: str) -> dict[str, Any] | None:
raw = value.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
except ValueError:
return None
if not isinstance(parsed, dict):
return None
return parsed
def _extract_result_text_content(result: dict[str, Any]) -> str | None:
content = result.get("content")
if isinstance(content, str) and content.strip():
return content
data = result.get("data")
if isinstance(data, dict):
message = data.get("message")
if isinstance(message, str) and message.strip():
return message
return None
def _sanitize_error(value: Any) -> str | None:
if isinstance(value, str) and value.strip():
text = " ".join(value.split())
return _redact_sensitive_text(text)[:300]
if isinstance(value, dict):
for key in ("message", "error", "detail"):
item = value.get(key)
if isinstance(item, str) and item.strip():
text = " ".join(item.split())
return _redact_sensitive_text(text)[:300]
return None
def _sanitize_result(value: Any) -> dict[str, Any]:
if not isinstance(value, dict):
return {}
def _is_sensitive_key(key: str) -> bool:
normalized = key.strip().lower().replace("-", "_")
if not normalized:
return False
exact = {
"password",
"token",
"secret",
"api_key",
"apikey",
"credential",
"authorization",
"auth",
}
if normalized in exact:
return True
patterns = (
"password",
"token",
"secret",
"credential",
"api_key",
"apikey",
"authorization",
)
return any(pattern in normalized for pattern in patterns)
def _sanitize_value(item: Any) -> Any:
if isinstance(item, dict):
return _sanitize_result(item)
if isinstance(item, list):
return [_sanitize_value(entry) for entry in item]
if isinstance(item, str):
return _redact_sensitive_text(item)
return item
sanitized: dict[str, Any] = {}
for key, item in value.items():
key_text = str(key)
if _is_sensitive_key(key_text):
sanitized[key_text] = "[REDACTED]"
continue
sanitized[key_text] = _sanitize_value(item)
return sanitized
def _redact_sensitive_text(value: str) -> str:
redacted = value
key_value_patterns = (
r"(?i)(authorization)\s*[:=]\s*bearer\s+[^\s,;]+",
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s*[:=]\s*[^\s,;]+",
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s+[^\s,;]+",
)
for pattern in key_value_patterns:
redacted = re.sub(pattern, r"\1=[REDACTED]", redacted)
redacted = re.sub(r"(?i)bearer\s+[^\s,;]+", "Bearer [REDACTED]", redacted)
return redacted
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {}
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
@@ -52,6 +52,24 @@ def _tools_payload_from_schema(
return payload
def _merge_tool_schemas(
*schema_sets: list[dict[str, object]],
) -> list[dict[str, object]]:
merged: list[dict[str, object]] = []
seen_names: set[str] = set()
for schemas in schema_sets:
for schema in schemas:
function = schema.get("function")
if not isinstance(function, dict):
continue
name = function.get("name")
if not isinstance(name, str) or not name or name in seen_names:
continue
seen_names.add(name)
merged.append(schema)
return merged
class AgentScopeRuntimeOrchestrator:
_runner: Any
_config_loader: Callable[[AsyncSession], Awaitable[dict[str, RuntimeStageConfig]]]
@@ -96,10 +114,20 @@ class AgentScopeRuntimeOrchestrator:
enable_hitl=False,
)
intent_tools_schema = intent_toolkit.get_json_schemas()
execution_toolkit = build_stage_toolkit(
stage="execution",
session=session,
owner_id=owner_id,
user_token=user_token,
enable_hitl=True,
)
execution_tools_schema = execution_toolkit.get_json_schemas()
intent_prompt = build_system_prompt(
stage="intent",
user_context=user_context,
tools=_tools_payload_from_schema(intent_tools_schema),
tools=_tools_payload_from_schema(
_merge_tool_schemas(intent_tools_schema, execution_tools_schema)
),
)
intent_payload = await self._runner.run_json_stage(
stage_config=stage_config["intent"],
@@ -125,14 +153,6 @@ class AgentScopeRuntimeOrchestrator:
execution_output: ExecutionBatchOutput | None = None
if intent_output.route == "TASK_EXECUTION":
execution_toolkit = build_stage_toolkit(
stage="execution",
session=session,
owner_id=owner_id,
user_token=user_token,
enable_hitl=True,
)
execution_tools_schema = execution_toolkit.get_json_schemas()
execution_prompt = build_system_prompt(
stage="execution",
user_context=user_context,
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import json
import math
from time import perf_counter
from typing import Any, cast
@@ -106,6 +107,9 @@ class AgentScopeReActRunner:
stage_config=stage_config,
response=response,
latency_ms=latency_ms,
system_prompt=system_prompt,
user_prompt=user_prompt,
assistant_text=text_content,
)
except json.JSONDecodeError as exc:
logger.exception(
@@ -234,6 +238,9 @@ def _merge_stage_response_metadata(
stage_config: RuntimeStageConfig,
response: Any,
latency_ms: int,
system_prompt: str,
user_prompt: str | list[dict[str, Any]],
assistant_text: str,
) -> dict[str, Any]:
result = dict(payload)
existing = result.get("response_metadata")
@@ -247,6 +254,15 @@ def _merge_stage_response_metadata(
completion_tokens = _to_non_negative_int(
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
)
if prompt_tokens is None:
prompt_tokens = _estimate_token_count(
{
"system": system_prompt,
"user": user_prompt,
}
)
if completion_tokens is None:
completion_tokens = _estimate_token_count(assistant_text)
cost = _to_non_negative_float(
_read_value(usage, "cost")
or _read_value(_read_value(usage, "metadata"), "cost")
@@ -352,3 +368,16 @@ def _estimate_cost_by_pricing(
)
except ValueError:
return None
def _estimate_token_count(value: object) -> int:
try:
serialized = (
value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
)
except (TypeError, ValueError):
serialized = str(value)
normalized = serialized.strip()
if not normalized:
return 0
return max(1, math.ceil(len(normalized) / 4))
+7 -8
View File
@@ -21,6 +21,7 @@ from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from core.agentscope.tools.tool_result_storage import create_tool_result_storage
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from services.base.redis import get_or_init_redis_client
@@ -67,16 +68,10 @@ def _build_user_context(*, owner_id: UUID, run_input: RunCommand) -> UserAgentCo
def _extract_user_token(
*, command: dict[str, Any], run_input: RunCommand
) -> str | None:
del run_input
raw_token = command.get("user_token")
if isinstance(raw_token, str) and raw_token.strip():
return raw_token.strip()
forwarded = (
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
)
for key in ("accessToken", "userToken", "token"):
value = forwarded.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
@@ -162,7 +157,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
)
pipeline = AgentScopeEventPipeline(
codec=AgentScopeAgUiCodec(),
store=SqlAlchemyEventStore(session_factory=AsyncSessionLocal),
store=SqlAlchemyEventStore(
session_factory=AsyncSessionLocal,
tool_result_storage=create_tool_result_storage(),
tool_result_bucket=config.storage.bucket,
),
bus=bus,
)
runtime = route_runtime_type(
@@ -67,6 +67,7 @@ def validate_run_request_messages_contract(run_input: RunAgentInput) -> None:
message = run_input.messages[0]
if getattr(message, "role", None) != "user":
raise ValueError("RunAgentInput.messages[0].role must be user")
_validate_user_content_blocks(getattr(message, "content", None))
extract_latest_user_payload(run_input)
@@ -106,84 +107,76 @@ def extract_latest_user_payload(
text_parts.append(text)
blocks.append({"type": "text", "text": text})
continue
if item_type not in {"image", "binary"}:
if item_type != "binary":
continue
source_type: str | None = None
source_value: str | None = None
source_mime: str | None = None
if item_type == "binary":
source_mime = (
item.get("mimeType")
if isinstance(item, dict)
else getattr(item, "mime_type", None)
)
source_url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
source_data = (
item.get("data")
if isinstance(item, dict)
else getattr(item, "data", None)
)
if isinstance(source_url, str) and source_url:
source_type = "url"
source_value = source_url
elif isinstance(source_data, str) and source_data:
source_type = "data"
source_value = source_data
else:
source = getattr(item, "source", None)
source_type = (
source.get("type")
if isinstance(source, dict)
else getattr(source, "type", None)
)
source_value = (
source.get("value")
if isinstance(source, dict)
else getattr(source, "value", None)
)
source_mime = (
source.get("mimeType")
if isinstance(source, dict)
else getattr(source, "mimeType", None)
)
if (
source_type == "url"
and isinstance(source_value, str)
and source_value
):
source_url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
if isinstance(source_url, str) and source_url:
blocks.append(
{"type": "image_url", "image_url": {"url": source_value}}
)
elif (
source_type == "data"
and isinstance(source_value, str)
and source_value
):
mime_type = (
source_mime
if isinstance(source_mime, str) and source_mime
else "image/png"
)
blocks.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{source_value}"
},
}
{"type": "image_url", "image_url": {"url": source_url}}
)
combined = "".join(text_parts).strip()
if combined:
if combined or blocks:
return combined, blocks
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
def _validate_user_content_blocks(content: Any) -> None:
if isinstance(content, str):
if content.strip():
return
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
if not isinstance(content, list):
raise ValueError("RunAgentInput.messages[0].content must be string or list")
has_text = False
has_binary = False
for item in content:
item_type = getattr(item, "type", None)
if item_type == "text":
text = getattr(item, "text", None)
if isinstance(text, str) and text.strip():
has_text = True
continue
if item_type == "binary":
mime_type = (
item.get("mimeType")
if isinstance(item, dict)
else getattr(item, "mime_type", None)
)
url = (
item.get("url")
if isinstance(item, dict)
else getattr(item, "url", None)
)
data = (
item.get("data")
if isinstance(item, dict)
else getattr(item, "data", None)
)
if not isinstance(mime_type, str) or not mime_type.startswith("image/"):
raise ValueError("binary content requires image mimeType")
if not isinstance(url, str) or not url:
raise ValueError("binary content requires url")
if isinstance(data, str) and data:
raise ValueError("binary content data is not allowed")
has_binary = True
continue
raise ValueError("unsupported content block type")
if not has_text and not has_binary:
raise ValueError(
"RunAgentInput.messages requires at least one non-empty user message"
)
def extract_latest_tool_result(
run_input: RunAgentInput,
) -> tuple[str, dict[str, object]]:
+3 -2
View File
@@ -3,10 +3,11 @@ from __future__ import annotations
import logging
from logging.config import dictConfig
from pathlib import Path
from typing import cast
import structlog
from core.config.settings import PROJECT_ROOT, RuntimeSettings, Settings
from core.config.settings import PROJECT_ROOT, RuntimeSettings, Settings, config
from core.logging.formatters import (
build_plain_formatter,
build_processor_formatter,
@@ -77,7 +78,7 @@ def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
def configure_logging(settings: Settings | None = None) -> None:
active_settings = settings or Settings()
active_settings = settings if settings is not None else cast(Settings, config)
runtime = active_settings.runtime
try:
+25 -10
View File
@@ -19,19 +19,14 @@ class SupabaseService(BaseServiceProvider):
async def initialize(self, **_: Any) -> bool:
try:
self._client = create_client(
self._settings.url,
self._settings.anon_key,
)
self._admin_client = create_client(
self._settings.url,
self._settings.service_role_key,
)
self._init_clients()
self._set_initialized(True)
self.logger.info("Supabase service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase service initialization failed", error=str(exc))
self.logger.warning(
"Supabase service initialization failed", error=str(exc)
)
self._client = None
self._admin_client = None
self._set_initialized(False)
@@ -51,7 +46,9 @@ class SupabaseService(BaseServiceProvider):
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
await asyncio.to_thread(client.auth.get_session)
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
await asyncio.to_thread(
admin_client.auth.admin.list_users, page=1, per_page=1
)
return {
"status": "healthy",
"details": {
@@ -70,17 +67,35 @@ class SupabaseService(BaseServiceProvider):
return self._require_admin_client()
def _require_client(self) -> Any:
if self._client is None or self._admin_client is None:
self._init_clients()
self._set_initialized(True)
self.logger.info("Supabase service lazily initialized")
client = self._client
if client is None:
raise RuntimeError("Supabase client is not initialized")
return client
def _require_admin_client(self) -> Any:
if self._client is None or self._admin_client is None:
self._init_clients()
self._set_initialized(True)
self.logger.info("Supabase service lazily initialized")
admin_client = self._admin_client
if admin_client is None:
raise RuntimeError("Supabase admin client is not initialized")
return admin_client
def _init_clients(self) -> None:
    """Create the anon-key client and the service-role client.

    Both are built from ``self._settings``; the regular client is assigned
    before the admin client is created.
    """
    settings = self._settings
    self._client = create_client(settings.url, settings.anon_key)
    self._admin_client = create_client(settings.url, settings.service_role_key)
supabase_service: SupabaseService = register_service_instance(
"supabase", SupabaseService()
+91 -1
View File
@@ -3,11 +3,20 @@ from __future__ import annotations
import asyncio
from typing import Any
from storage3.exceptions import StorageApiError
from core.config.settings import config
from services.base.supabase import supabase_service
class AgentAttachmentStorage:
def _validate_bucket(self, *, bucket: str) -> None:
    """Accept only the configured storage bucket.

    Raises:
        RuntimeError: If ``bucket`` differs from ``config.storage.bucket``.
    """
    if bucket == config.storage.bucket:
        return
    raise RuntimeError("Invalid attachment bucket")
def _bucket_client(self, *, bucket: str) -> Any:
self._validate_bucket(bucket=bucket)
client = supabase_service.get_admin_client()
storage = getattr(client, "storage", None)
if storage is None:
@@ -39,9 +48,82 @@ class AgentAttachmentStorage:
},
)
await asyncio.to_thread(_upload)
try:
await asyncio.to_thread(_upload)
except Exception as exc: # noqa: BLE001
if not _is_bucket_not_found_error(exc):
raise
await self._ensure_bucket_exists(bucket=bucket)
await asyncio.to_thread(_upload)
return path
async def _ensure_bucket_exists(self, *, bucket: str) -> None:
    """Make sure ``bucket`` exists, creating it as non-public if needed.

    The blocking storage calls run in a worker thread. A successful
    ``get_bucket`` ends the check; any failure there falls through to
    ``create_bucket``. A creation error whose message contains
    "already exists" or "duplicate" is treated as success.

    Raises:
        RuntimeError: If the storage client or ``create_bucket`` is missing.
        Exception: Any other error raised by ``create_bucket``.
    """

    def _bucket_is_present(storage: Any) -> bool:
        lookup = getattr(storage, "get_bucket", None)
        if not callable(lookup):
            return False
        try:
            lookup(bucket)
        except Exception:  # noqa: BLE001
            return False
        return True

    def _create_if_missing() -> None:
        admin = supabase_service.get_admin_client()
        storage = getattr(admin, "storage", None)
        if storage is None:
            raise RuntimeError("Supabase storage client unavailable")
        if _bucket_is_present(storage):
            return

        creator = getattr(storage, "create_bucket", None)
        if not callable(creator):
            raise RuntimeError("Supabase storage create_bucket is unavailable")
        try:
            creator(bucket, options={"public": False})
        except Exception as exc:  # noqa: BLE001
            lowered = str(exc).lower()
            if "already exists" not in lowered and "duplicate" not in lowered:
                raise

    await asyncio.to_thread(_create_if_missing)
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
    """Download the object at ``path`` from ``bucket`` and return it as bytes.

    The blocking download runs in a worker thread. ``bytearray`` and
    ``memoryview`` results are converted to ``bytes``.

    Raises:
        RuntimeError: If the bucket client has no callable ``download`` or the
            result is not a bytes-like value handled here.
    """

    def _fetch() -> object:
        downloader = getattr(self._bucket_client(bucket=bucket), "download", None)
        if not callable(downloader):
            raise RuntimeError("Supabase storage download is unavailable")
        return downloader(path)

    blob = await asyncio.to_thread(_fetch)
    if isinstance(blob, bytes):
        return blob
    if isinstance(blob, bytearray):
        return bytes(blob)
    if isinstance(blob, memoryview):
        return blob.tobytes()
    raise RuntimeError("Invalid attachment payload")
async def create_signed_url(
    self,
    *,
    bucket: str,
    path: str,
    expires_in_seconds: int,
) -> str:
    """Create a signed URL for ``path`` in ``bucket``.

    The signer may return the URL directly as a string or inside a dict under
    ``signedURL``, ``signedUrl`` or ``url`` (checked in that order).

    Raises:
        RuntimeError: If the bucket client has no callable ``create_signed_url``
            or the result carries no usable URL string.
    """

    def _sign() -> object:
        sign = getattr(self._bucket_client(bucket=bucket), "create_signed_url", None)
        if not callable(sign):
            raise RuntimeError("Supabase storage signed url is unavailable")
        return sign(path, expires_in_seconds)

    result = await asyncio.to_thread(_sign)
    if isinstance(result, str):
        return result
    if isinstance(result, dict):
        # Take the first truthy value among the known keys, then type-check it.
        candidates = (result.get(key) for key in ("signedURL", "signedUrl", "url"))
        first_truthy = next((value for value in candidates if value), None)
        if isinstance(first_truthy, str):
            return first_truthy
    raise RuntimeError("Invalid signed url payload")
def create_attachment_storage() -> AgentAttachmentStorage | None:
try:
@@ -49,3 +131,11 @@ def create_attachment_storage() -> AgentAttachmentStorage | None:
except Exception:
return None
return AgentAttachmentStorage()
def _is_bucket_not_found_error(exc: Exception) -> bool:
if isinstance(exc, StorageApiError):
message = str(exc).lower()
return "bucket" in message and "not found" in message
message = str(exc).lower()
return "bucket" in message and "not found" in message
+138 -16
View File
@@ -9,6 +9,7 @@ from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.config.settings import config
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
@@ -200,6 +201,61 @@ class AgentRepository:
return None
return str(latest_id)
async def get_message_attachment_reference(
    self,
    *,
    session_id: str,
    message_id: str,
    attachment_index: int,
) -> dict[str, str] | None:
    """Look up one attachment reference stored in a message's metadata.

    Args:
        session_id: Session UUID, as a string, that the message must belong to.
        message_id: Message UUID as a string.
        attachment_index: Position within the message's ``attachments`` list.

    Returns:
        A dict with ``bucket``, ``path`` and ``mimeType``, or ``None`` when the
        message is missing or deleted, the index is out of range, or the entry
        is not a dict with three non-empty string fields.

    Raises:
        HTTPException: 422 if either id is not a valid UUID.
    """
    try:
        session_uuid = UUID(session_id)
        message_uuid = UUID(message_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=422, detail="Invalid message/session id"
        ) from exc

    query = select(AgentChatMessage).where(
        AgentChatMessage.id == message_uuid,
        AgentChatMessage.session_id == session_uuid,
        AgentChatMessage.deleted_at.is_(None),
    )
    row = (await self._session.execute(query)).scalar_one_or_none()
    if row is None:
        return None

    meta = row.metadata_json if isinstance(row.metadata_json, dict) else {}
    entries = meta.get("attachments")
    if not isinstance(entries, list):
        return None
    if not 0 <= attachment_index < len(entries):
        return None

    entry = entries[attachment_index]
    if not isinstance(entry, dict):
        return None

    reference = {key: entry.get(key) for key in ("bucket", "path", "mimeType")}
    # Every field must be a non-empty string for the reference to be usable.
    if not all(isinstance(value, str) and value for value in reference.values()):
        return None
    return reference
async def _to_snapshot_message(
self, message: AgentChatMessage
) -> dict[str, object]:
@@ -233,30 +289,65 @@ class AgentRepository:
storage_bucket = metadata.get("storage_bucket")
storage_path = metadata.get("storage_path")
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
try:
hydrated_content = await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
expected_bucket = config.storage.bucket
message_session_id = getattr(message, "session_id", None)
expected_prefix = (
f"tool-results/{message_session_id}/"
if message_session_id is not None
else None
)
tool_call_id = metadata.get("tool_call_id")
is_legacy_path = isinstance(
tool_call_id, str
) and storage_path.endswith(f"/{tool_call_id}.json")
if (
storage_bucket == expected_bucket
and _is_safe_storage_path(storage_path)
and (
(
expected_prefix is not None
and storage_path.startswith(expected_prefix)
)
or (
storage_path.startswith("tool-results/")
and is_legacy_path
)
)
except Exception:
hydrated_content = None
):
try:
hydrated_content = (
await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
)
)
except Exception:
hydrated_content = None
resolved_content = hydrated_content or parsed_content
payload["content"] = message.content
if resolved_content is not None:
result = resolved_content.get("result")
if isinstance(result, dict):
result_content = result.get("content")
if isinstance(result_content, str):
payload["content"] = result_content
ui = resolved_content.get("ui")
if not isinstance(ui, dict):
ui = resolved_content.get("ui_schema")
if isinstance(ui, dict):
payload["ui"] = ui
display_content = resolved_content.get("content")
if isinstance(display_content, str):
if not isinstance(display_content, str):
nested_result = resolved_content.get("result")
if isinstance(nested_result, dict):
nested_content = nested_result.get("content")
if isinstance(nested_content, str):
display_content = nested_content
if (
isinstance(display_content, str)
and display_content.strip()
and (
not payload["content"]
or _looks_like_offloaded_placeholder(str(payload["content"]))
)
):
payload["content"] = display_content
if "content" not in payload:
payload["content"] = message.content
else:
payload["content"] = message.content
metadata = message.metadata_json or {}
@@ -264,7 +355,22 @@ class AgentRepository:
metadata.get("attachments") if isinstance(metadata, dict) else None
)
if isinstance(attachments, list):
rendered = [item for item in attachments if isinstance(item, dict)]
rendered: list[dict[str, object]] = []
for index, item in enumerate(attachments):
if not isinstance(item, dict):
continue
mime_type = item.get("mimeType")
if not isinstance(mime_type, str) or not mime_type:
continue
rendered.append(
{
"mimeType": mime_type,
"previewPath": (
f"/api/v1/agent/runs/{message.session_id}/attachments/"
f"{message.id}/{index}"
),
}
)
if rendered:
payload["attachments"] = rendered
return payload
@@ -279,3 +385,19 @@ def _derive_session_title(content_text: str) -> str | None:
if not normalized:
return None
return normalized[:80]
def _is_safe_storage_path(path: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return True
def _looks_like_offloaded_placeholder(content: str) -> bool:
normalized = content.strip().lower()
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
+121 -2
View File
@@ -10,7 +10,17 @@ import time
from typing import Annotated, Union
from ag_ui.core import RunAgentInput
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
from fastapi import (
APIRouter,
Depends,
File,
Form,
Header,
Query,
Request,
status,
UploadFile,
)
from fastapi import HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
@@ -20,11 +30,18 @@ from core.agentscope.schemas.agui_input import (
parse_run_input,
validate_run_request_messages_contract,
)
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
from core.auth.models import CurrentUser
from core.config.settings import config
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
from v1.agent.schemas import (
AsrTranscribeResponse,
AttachmentReference,
AttachmentUploadResponse,
TaskAcceptedResponse,
)
from v1.agent.service import AgentService, asr_service
from v1.users.dependencies import get_current_user
@@ -38,6 +55,7 @@ _SSE_SLOT_TTL_SECONDS = 15 * 60
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
_WAV_HEADER_MIN_BYTES = 12
_ALLOWED_AUDIO_CONTENT_TYPES = {
"audio/wav",
@@ -46,6 +64,42 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
}
def _verified_access_token_for_user(
*,
authorization: str | None,
current_user: CurrentUser,
) -> str | None:
if not isinstance(authorization, str):
return None
normalized = authorization.strip()
if not normalized:
return None
if not normalized.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="Unauthorized")
token = normalized[7:].strip()
if not token:
raise HTTPException(status_code=401, detail="Unauthorized")
jwt_secret = config.supabase.jwt_secret
if jwt_secret is None:
raise HTTPException(status_code=503, detail="Auth verifier unavailable")
verifier = JwtVerifier(
issuer=str(config.supabase.jwt_issuer),
jwt_secret=jwt_secret.get_secret_value(),
jwt_algorithm=config.supabase.jwt_algorithm,
)
try:
payload = verifier.verify(token)
except TokenValidationError as exc:
raise HTTPException(status_code=401, detail="Unauthorized") from exc
subject = payload.get("sub")
if not isinstance(subject, str) or subject != str(current_user.id):
raise HTTPException(status_code=403, detail="Forbidden")
return token
def _looks_like_wav_header(header: bytes) -> bool:
if len(header) < _WAV_HEADER_MIN_BYTES:
return False
@@ -111,6 +165,7 @@ async def enqueue_run(
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
authorization: str | None = Header(default=None, alias="Authorization"),
) -> TaskAcceptedResponse:
try:
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
@@ -120,10 +175,15 @@ async def enqueue_run(
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
user_token = _verified_access_token_for_user(
authorization=authorization,
current_user=current_user,
)
task = await service.enqueue_run(
run_input=request,
current_user=current_user,
user_token=user_token,
)
return TaskAcceptedResponse(
taskId=task.task_id,
@@ -143,6 +203,7 @@ async def enqueue_resume(
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
authorization: str | None = Header(default=None, alias="Authorization"),
) -> TaskAcceptedResponse:
if request.thread_id != thread_id:
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
@@ -154,10 +215,15 @@ async def enqueue_resume(
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
user_token = _verified_access_token_for_user(
authorization=authorization,
current_user=current_user,
)
task = await service.enqueue_resume(
thread_id=thread_id,
run_input=request,
current_user=current_user,
user_token=user_token,
)
return TaskAcceptedResponse(
taskId=task.task_id,
@@ -253,6 +319,31 @@ async def get_history_snapshot(
)
@router.get("/runs/{thread_id}/attachments/{message_id}/{attachment_index}")
async def get_attachment_preview(
    thread_id: str,
    message_id: str,
    attachment_index: int,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> StreamingResponse:
    """Serve the bytes of one message attachment for the current user.

    The service resolves the attachment and returns its bytes and MIME type;
    the response is marked privately cacheable for five minutes.

    Raises:
        HTTPException: 422 when ``attachment_index`` is negative.
    """
    if attachment_index < 0:
        raise HTTPException(status_code=422, detail="Invalid attachment index")

    content, content_type = await service.get_attachment_preview(
        thread_id=thread_id,
        message_id=message_id,
        attachment_index=attachment_index,
        current_user=current_user,
    )
    cache_headers = {"Cache-Control": "private, max-age=300"}
    return StreamingResponse(
        iter([content]),
        media_type=content_type,
        headers=cache_headers,
    )
@router.get("/history")
async def get_user_history_snapshot(
service: Annotated[AgentService, Depends(get_agent_service)],
@@ -267,6 +358,34 @@ async def get_user_history_snapshot(
)
@router.post(
    "/attachments",
    response_model=AttachmentUploadResponse,
    status_code=status.HTTP_200_OK,
)
async def upload_attachment(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str = Form(alias="threadId"),
    file: UploadFile = File(),
) -> AttachmentUploadResponse:
    """Accept one multipart attachment for a thread and return its reference.

    Form fields: ``threadId`` (target thread) and ``file`` (the upload).

    Raises:
        HTTPException: 422 when the upload is empty; 413 when it exceeds
            ``_MAX_ATTACHMENT_UPLOAD_BYTES``.
    """
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large file into memory first.
    payload = await file.read(_MAX_ATTACHMENT_UPLOAD_BYTES + 1)
    if not payload:
        raise HTTPException(status_code=422, detail="Empty attachment")
    if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Attachment too large")

    attachment = await service.upload_attachment(
        thread_id=thread_id,
        filename=file.filename,
        content_type=file.content_type,
        payload=payload,
        current_user=current_user,
    )
    return AttachmentUploadResponse(
        attachment=AttachmentReference.model_validate(attachment),
    )
@router.post(
"/transcribe",
response_model=AsrTranscribeResponse,
+13
View File
@@ -14,3 +14,16 @@ class TaskAcceptedResponse(BaseModel):
# Response body carrying the text transcribed from an audio upload.
class AsrTranscribeResponse(BaseModel):
transcript: str = Field(description="Transcribed text from audio")
# Reference to a stored attachment: where it lives (bucket + path), its MIME
# type, and a URL for it.
class AttachmentReference(BaseModel):
# Accept input by field name or alias, and emit aliases when serializing.
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
bucket: str
path: str
# Uses the camelCase alias "mimeType".
mime_type: str = Field(alias="mimeType")
url: str
# Response body of the attachment upload endpoint: wraps the stored reference.
class AttachmentUploadResponse(BaseModel):
attachment: AttachmentReference
+297 -60
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import base64
from dataclasses import dataclass
from datetime import date
import hashlib
@@ -19,17 +18,22 @@ from core.config.settings import config
from core.logging import get_logger
logger = get_logger(__name__)
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
forwarded = run_input.forwarded_props
if not isinstance(forwarded, dict):
def _normalize_bearer_token(value: str | None) -> str | None:
if not isinstance(value, str):
return None
for key in ("accessToken", "userToken", "token"):
value = forwarded.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
normalized = value.strip()
if not normalized:
return None
lower = normalized.lower()
if lower.startswith("bearer "):
token = normalized[7:].strip()
return token or None
return normalized
@dataclass(frozen=True)
@@ -66,6 +70,14 @@ class AgentRepositoryLike(Protocol):
metadata: dict[str, object] | None,
) -> None: ...
# Protocol stub: resolve the attachment at `attachment_index` of a message in a
# session. Returns a string-to-string mapping describing it, or None.
async def get_message_attachment_reference(
self,
*,
session_id: str,
message_id: str,
attachment_index: int,
) -> dict[str, str] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -92,6 +104,16 @@ class AttachmentStorageLike(Protocol):
content_type: str,
) -> str: ...
# Protocol stub: fetch the raw bytes stored at `path` in `bucket`.
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
# Protocol stub: produce a URL string for `path` in `bucket`; the lifetime is
# requested in seconds via `expires_in_seconds`.
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str: ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
if owner_id != str(current_user.id):
@@ -104,6 +126,8 @@ class AgentService:
_stream: EventStreamLike
_attachment_storage: AttachmentStorageLike | None
_SIGNED_URL_EXPIRES_IN_SECONDS = 3600
def __init__(
self,
*,
@@ -122,6 +146,7 @@ class AgentService:
*,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
) -> TaskAccepted:
created = False
thread_id = run_input.thread_id
@@ -161,7 +186,7 @@ class AgentService:
command={
"command": "run",
"owner_id": str(current_user.id),
"user_token": _extract_user_token_from_run_input(run_input),
"user_token": _normalize_bearer_token(user_token),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
@@ -179,57 +204,115 @@ class AgentService:
run_input: RunAgentInput,
current_user: CurrentUser,
) -> tuple[str, dict[str, object] | None]:
text, content_blocks = extract_latest_user_payload(run_input)
text, _ = extract_latest_user_payload(run_input)
content_blocks = _extract_latest_user_content_blocks(run_input)
attachments: list[dict[str, object]] = []
if self._attachment_storage is not None:
for index, block in enumerate(content_blocks):
if not isinstance(block, dict):
continue
if block.get("type") != "image_url":
continue
image_value = block.get("image_url")
if not isinstance(image_value, dict):
continue
url = image_value.get("url")
if not isinstance(url, str) or not url.startswith("data:"):
continue
decoded = _decode_data_url(url)
if decoded is None:
continue
mime_type, payload = decoded
suffix = _mime_to_suffix(mime_type)
checksum = hashlib.sha1(payload).hexdigest()[:16]
path = (
f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}"
binary_blocks = [
block
for block in content_blocks
if isinstance(block, dict) and block.get("type") == "binary"
]
if binary_blocks:
if self._attachment_storage is None:
raise HTTPException(
status_code=503,
detail="Attachment storage unavailable",
)
bucket_name = config.storage.bucket
forwarded_props = (
run_input.forwarded_props
if isinstance(run_input.forwarded_props, dict)
else {}
)
raw_attachments = forwarded_props.get("attachments")
if not isinstance(raw_attachments, list):
raise HTTPException(
status_code=422, detail="Invalid attachments payload"
)
if len(raw_attachments) != len(binary_blocks):
raise HTTPException(
status_code=422, detail="Invalid attachments payload"
)
total_attachment_bytes = 0
expected_prefix = f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
for index, raw_attachment in enumerate(raw_attachments):
if not isinstance(raw_attachment, dict):
raise HTTPException(
status_code=422,
detail="Invalid attachment reference",
)
bucket = raw_attachment.get("bucket")
path = raw_attachment.get("path")
mime_type = raw_attachment.get("mimeType")
if (
not isinstance(bucket, str)
or not isinstance(path, str)
or not isinstance(mime_type, str)
):
raise HTTPException(
status_code=422,
detail="Invalid attachment reference",
)
if bucket != config.storage.bucket:
raise HTTPException(status_code=403, detail="Forbidden")
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
raise HTTPException(status_code=403, detail="Forbidden")
if mime_type.lower() not in _ALLOWED_ATTACHMENT_MIME_TYPES:
raise HTTPException(
status_code=422,
detail="Unsupported attachment type",
)
binary_block = binary_blocks[index]
binary_mime = binary_block.get("mimeType")
binary_url = binary_block.get("url")
if (
not isinstance(binary_mime, str)
or binary_mime != mime_type
or not isinstance(binary_url, str)
or not binary_url
):
raise HTTPException(
status_code=422,
detail="Invalid attachments payload",
)
try:
stored_path = await self._attachment_storage.upload_bytes(
bucket=bucket_name,
payload = await self._attachment_storage.download_bytes(
bucket=bucket,
path=path,
content=payload,
content_type=mime_type,
)
except Exception: # noqa: BLE001
logger.exception(
"Attachment upload failed",
"Attachment validation download failed",
extra={
"bucket": bucket_name,
"bucket": bucket,
"path": path,
"mime_type": mime_type,
"thread_id": run_input.thread_id,
"run_id": run_input.run_id,
},
)
raise HTTPException(
status_code=502,
detail="Failed to upload attachment",
detail="Failed to fetch attachment",
)
payload_size = len(payload)
if payload_size > _MAX_ATTACHMENT_BYTES:
raise HTTPException(
status_code=413,
detail="Attachment too large",
)
total_attachment_bytes += payload_size
if total_attachment_bytes > _MAX_TOTAL_ATTACHMENT_BYTES:
raise HTTPException(
status_code=413,
detail="Attachments too large",
)
attachments.append(
{
"bucket": bucket_name,
"path": stored_path,
"bucket": bucket,
"path": path,
"mimeType": mime_type,
}
)
@@ -238,12 +321,94 @@ class AgentService:
metadata["attachments"] = attachments
return text, metadata or None
async def upload_attachment(
    self,
    *,
    thread_id: str,
    filename: str | None,
    content_type: str | None,
    payload: bytes,
    current_user: CurrentUser,
) -> dict[str, str]:
    """Store one attachment for a thread and return a reference to it.

    The session is created for the caller if it does not exist yet; an existing
    session must be owned by the caller. The object path is derived from the
    user id, thread id, a hash of the filename and a hash of the content.

    Returns:
        A dict with ``bucket``, ``path``, ``mimeType`` and a signed ``url``.

    Raises:
        HTTPException: 503 without attachment storage; 422 for an unsupported
            content type or empty payload; 413 when the payload exceeds
            ``_MAX_ATTACHMENT_BYTES``; 502 when upload or URL signing fails.
            Errors from the session lookup other than 404 propagate unchanged.
    """
    repo = self._repository
    try:
        owner = await repo.get_session_owner(session_id=thread_id)
    except HTTPException as lookup_error:
        if lookup_error.status_code != 404:
            raise
        # Unknown thread: create the session for the caller.
        try:
            await repo.create_session_for_user(
                user_id=str(current_user.id),
                session_id=thread_id,
            )
            await repo.commit()
        except IntegrityError:
            # The insert collided with an existing row; check who owns it.
            await repo.rollback()
            owner = await repo.get_session_owner(session_id=thread_id)
            ensure_session_owner(owner_id=owner, current_user=current_user)
    else:
        ensure_session_owner(owner_id=owner, current_user=current_user)

    storage = self._attachment_storage
    if storage is None:
        raise HTTPException(
            status_code=503, detail="Attachment storage unavailable"
        )

    mime_type = content_type.lower() if isinstance(content_type, str) else None
    if mime_type not in _ALLOWED_ATTACHMENT_MIME_TYPES:
        raise HTTPException(status_code=422, detail="Unsupported attachment type")
    if not payload:
        raise HTTPException(status_code=422, detail="Empty attachment")
    if len(payload) > _MAX_ATTACHMENT_BYTES:
        raise HTTPException(status_code=413, detail="Attachment too large")

    extension = _mime_to_suffix(mime_type)
    content_digest = hashlib.sha1(payload).hexdigest()[:16]
    name_source = filename if isinstance(filename, str) and filename else "upload"
    name_digest = hashlib.sha1(name_source.encode("utf-8")).hexdigest()[:8]
    object_path = (
        f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
        f"{name_digest}-{content_digest}.{extension}"
    )
    target_bucket = config.storage.bucket

    try:
        stored_path = await storage.upload_bytes(
            bucket=target_bucket,
            path=object_path,
            content=payload,
            content_type=mime_type,
        )
        signed_url = await storage.create_signed_url(
            bucket=target_bucket,
            path=stored_path,
            expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
        )
    except Exception:  # noqa: BLE001
        logger.exception(
            "Attachment upload failed",
            extra={
                "bucket": target_bucket,
                "path": object_path,
                "mime_type": mime_type,
                "thread_id": thread_id,
            },
        )
        raise HTTPException(status_code=502, detail="Failed to upload attachment")

    return {
        "bucket": target_bucket,
        "path": stored_path,
        "mimeType": mime_type,
        "url": signed_url,
    }
async def enqueue_resume(
self,
*,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
@@ -253,7 +418,7 @@ class AgentService:
command={
"command": "resume",
"owner_id": str(current_user.id),
"user_token": _extract_user_token_from_run_input(run_input),
"user_token": _normalize_bearer_token(user_token),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
@@ -336,6 +501,63 @@ class AgentService:
current_user=current_user,
)
async def get_attachment_preview(
    self,
    *,
    thread_id: str,
    message_id: str,
    attachment_index: int,
    current_user: CurrentUser,
) -> tuple[bytes, str]:
    """Return the bytes and MIME type of one attachment of a message.

    The caller must own the thread, and the stored reference must point into
    the configured bucket under ``agent-inputs/<user id>/<thread id>/``.

    Raises:
        HTTPException: 503 without attachment storage; 404 when no usable
            reference exists; 403 when bucket or path fall outside the
            caller's area; 502 when the download fails.
    """
    owner = await self._repository.get_session_owner(session_id=thread_id)
    ensure_session_owner(owner_id=owner, current_user=current_user)

    storage = self._attachment_storage
    if storage is None:
        raise HTTPException(
            status_code=503, detail="Attachment storage unavailable"
        )

    reference = await self._repository.get_message_attachment_reference(
        session_id=thread_id,
        message_id=message_id,
        attachment_index=attachment_index,
    )
    if reference is None:
        raise HTTPException(status_code=404, detail="Attachment not found")

    bucket, path, mime_type = (
        reference.get(key) for key in ("bucket", "path", "mimeType")
    )
    if not (
        isinstance(bucket, str)
        and isinstance(path, str)
        and isinstance(mime_type, str)
    ):
        raise HTTPException(status_code=404, detail="Attachment not found")

    if bucket != config.storage.bucket:
        raise HTTPException(status_code=403, detail="Forbidden")
    owned_prefix = f"agent-inputs/{current_user.id}/{thread_id}/"
    if not _is_safe_attachment_path(path, expected_prefix=owned_prefix):
        raise HTTPException(status_code=403, detail="Forbidden")

    try:
        content = await storage.download_bytes(bucket=bucket, path=path)
    except Exception:  # noqa: BLE001
        logger.exception(
            "Attachment download failed",
            extra={
                "thread_id": thread_id,
                "message_id": message_id,
                "attachment_index": attachment_index,
                "bucket": bucket,
            },
        )
        raise HTTPException(status_code=502, detail="Failed to fetch attachment")
    return content, mime_type
class AsrService:
def __init__(self) -> None:
@@ -445,22 +667,26 @@ class AsrService:
asr_service = AsrService()
def _decode_data_url(data_url: str) -> tuple[str, bytes] | None:
if not data_url.startswith("data:"):
return None
header, sep, payload = data_url.partition(",")
if not sep:
return None
mime_type = "image/png"
if ";" in header:
maybe_mime = header[5:].split(";", 1)[0].strip()
if maybe_mime:
mime_type = maybe_mime
try:
decoded = base64.b64decode(payload, validate=True)
except ValueError:
return None
return mime_type, decoded
def _extract_latest_user_content_blocks(
run_input: RunAgentInput,
) -> list[dict[str, Any]]:
if not run_input.messages:
return []
latest = run_input.messages[-1]
content = getattr(latest, "content", None)
if not isinstance(content, list):
return []
blocks: list[dict[str, Any]] = []
for item in content:
if isinstance(item, dict):
blocks.append(item)
continue
model_dump = getattr(item, "model_dump", None)
if callable(model_dump):
dumped = model_dump(mode="json", by_alias=True, exclude_none=True)
if isinstance(dumped, dict):
blocks.append(dumped)
return blocks
def _mime_to_suffix(mime_type: str) -> str:
@@ -470,3 +696,14 @@ def _mime_to_suffix(mime_type: str) -> str:
"image/webp": "webp",
}
return mapping.get(mime_type.lower(), "bin")
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return normalized.startswith(expected_prefix)
@@ -1,16 +1,13 @@
from __future__ import annotations
from typing import Callable
from uuid import UUID
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.auth.dependencies import get_auth_service
from v1.users.dependencies import get_current_user
from v1.auth.rate_limit import reset_rate_limit_state
from v1.auth.schemas import (
AuthUser,
@@ -18,8 +18,14 @@ class _FakeAgentService:
def __init__(self) -> None:
self._stream_called = False
async def enqueue_run(self, *, run_input: RunAgentInput, current_user: CurrentUser):
del current_user
async def enqueue_run(
self,
*,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
):
del current_user, user_token
return SimpleNamespace(
task_id="task-run-1",
thread_id=run_input.thread_id,
@@ -33,8 +39,9 @@ class _FakeAgentService:
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
):
del thread_id, current_user
del thread_id, current_user, user_token
return SimpleNamespace(
task_id="task-resume-1",
thread_id=run_input.thread_id,
@@ -109,6 +116,23 @@ class _FakeAgentService:
},
}
async def upload_attachment(
    self,
    *,
    thread_id: str,
    filename: str | None,
    content_type: str | None,
    payload: bytes,
    current_user: CurrentUser,
) -> dict[str, str]:
    """Return a canned attachment reference whose path embeds ``thread_id``."""
    del filename, content_type, payload, current_user
    reference = dict(
        bucket="bucket-test",
        path=f"agent-inputs/user/{thread_id}/upload.png",
        mimeType="image/png",
        url="https://signed.example/upload.png",
    )
    return reference
class _FailingStreamAgentService(_FakeAgentService):
async def stream_events(
@@ -393,6 +417,31 @@ def test_resume_accepts_tool_message_without_user_message() -> None:
app.dependency_overrides = {}
def test_upload_attachment_returns_reference() -> None:
    """POST /attachments answers 200 with a reference for the given thread."""
    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    client = TestClient(app)
    upload = BytesIO(b"png")
    upload.name = "demo.png"

    try:
        response = client.post(
            "/api/v1/agent/attachments",
            data={"threadId": "00000000-0000-0000-0000-000000000001"},
            files={"file": ("demo.png", upload, "image/png")},
        )
        assert response.status_code == 200
        attachment = response.json()["attachment"]
        assert attachment["mimeType"] == "image/png"
        assert "00000000-0000-0000-0000-000000000001" in attachment["path"]
    finally:
        # Always drop the overrides so other tests see the real dependencies.
        app.dependency_overrides = {}
def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
@@ -40,3 +40,34 @@ def test_reserved_keys_in_data_cannot_override_wire_fields() -> None:
assert result["threadId"] == "thread-1"
assert result["runId"] == "run-1"
assert result["message"] == "ok"
def test_tool_result_wire_event_filters_sensitive_fields() -> None:
    """A tool.result event keeps display fields and drops args/result/error."""
    data = {
        "messageId": "tool-result-1",
        "toolCallId": "call-1",
        "callId": "call-1",
        "toolName": "calendar_write",
        "content": "summary",
        "ui": {"type": "calendar_operation.v1", "data": {"ok": True}},
        "args": {"token": "secret"},
        "result": {"raw": "secret"},
        "error": "stack trace",
    }
    internal = {
        "type": "tool.result",
        "threadId": "thread-1",
        "runId": "run-1",
        "data": data,
    }

    wire = to_agui_wire_event(internal)

    expected_fields = {
        "type": "TOOL_CALL_RESULT",
        "messageId": "tool-result-1",
        "toolCallId": "call-1",
        "toolName": "calendar_write",
        "content": "summary",
    }
    for key, value in expected_fields.items():
        assert wire[key] == value
    assert isinstance(wire.get("ui"), dict)
    for hidden in ("args", "result", "error"):
        assert hidden not in wire
@@ -28,6 +28,27 @@ class _FakeSessionCtx:
del exc_type, exc, tb
class _FakeToolResultStorage:
def __init__(self) -> None:
self.upload_calls: list[dict[str, object]] = []
async def upload_json(
self,
*,
bucket: str,
path: str,
payload: dict[str, object],
) -> str:
self.upload_calls.append(
{
"bucket": bucket,
"path": path,
"payload": payload,
}
)
return path
@pytest.mark.asyncio
async def test_store_marks_session_running_on_run_started(
monkeypatch: pytest.MonkeyPatch,
@@ -300,7 +321,12 @@ async def test_store_persists_tool_call_result_as_tool_message(
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
fake_storage = _FakeToolResultStorage()
store = store_module.SqlAlchemyEventStore(
session_factory=lambda: _FakeSessionCtx(),
tool_result_storage=fake_storage,
tool_result_bucket="agent-tool-results",
)
await store.persist(
{
"type": "TOOL_CALL_RESULT",
@@ -310,7 +336,7 @@ async def test_store_persists_tool_call_result_as_tool_message(
"taskId": "t1",
"stage": "execution",
"args": {"title": "A"},
"result": {"event_id": "evt-1"},
"result": {"event_id": "evt-1", "token": "secret"},
}
)
@@ -318,9 +344,94 @@ async def test_store_persists_tool_call_result_as_tool_message(
assert getattr(append_kwargs["role"], "value", None) == "tool"
assert append_kwargs["tool_name"] == "calendar_write"
assert append_kwargs["metadata"]["task_id"] == "t1"
tool_call_id = append_kwargs["metadata"]["tool_call_id"]
assert isinstance(tool_call_id, str)
assert tool_call_id.startswith("run-1-t1-")
assert append_kwargs["metadata"]["storage_bucket"] == "agent-tool-results"
assert isinstance(append_kwargs["metadata"]["storage_path"], str)
assert append_kwargs["content"].startswith("已创建日程")
assert len(fake_storage.upload_calls) == 1
uploaded = fake_storage.upload_calls[0]
assert uploaded["bucket"] == "agent-tool-results"
payload = cast(dict[str, Any], uploaded["payload"])
assert payload["toolName"] == "calendar_write"
assert "args" not in payload
assert isinstance(payload.get("result"), dict)
assert payload["result"]["token"] == "[REDACTED]"
assert captured["message_delta"] == 1
@pytest.mark.asyncio
async def test_store_sanitizes_nested_sensitive_fields_in_result_payload(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Persist a TOOL_CALL_RESULT and check the uploaded result is redacted.

    The result carries token-like keys at three depths (dict, nested dict and
    list item); each one is expected to be stored as ``"[REDACTED]"``.
    """
    recorded: dict[str, object] = {}
    session_row = SimpleNamespace(state_snapshot={}, message_count=0)

    class _StubSessionRepository:
        """Session repository stub that always hands back ``session_row``."""

        def __init__(self, session: object) -> None:
            del session

        async def get_session(self, *, session_id):  # noqa: ANN001
            del session_id
            return session_row

        async def lock_session_for_update(self, *, session_id):  # noqa: ANN001
            del session_id
            return session_row

        async def update_runtime_state(self, **kwargs):  # noqa: ANN003
            recorded.update(kwargs)

    class _StubMessageRepository:
        """Message repository stub that captures the append arguments."""

        def __init__(self, session: object) -> None:
            del session

        async def append_message(self, **kwargs):  # noqa: ANN003
            recorded["append_kwargs"] = kwargs

    # Swap the store module's collaborators for the stubs above.
    for attr_name, replacement in (
        ("SessionRepository", _StubSessionRepository),
        ("MessageRepository", _StubMessageRepository),
        ("AgentChatSessionStatus", _SessionStatus),
    ):
        monkeypatch.setattr(store_module, attr_name, replacement)

    fake_storage = _FakeToolResultStorage()
    store = store_module.SqlAlchemyEventStore(
        session_factory=lambda: _FakeSessionCtx(),
        tool_result_storage=fake_storage,
        tool_result_bucket="agent-tool-results",
    )

    tool_result_event = {
        "type": "TOOL_CALL_RESULT",
        "threadId": "00000000-0000-0000-0000-000000000001",
        "runId": "run-1",
        "toolName": "calendar_write",
        "result": {
            "data": {
                "ok": True,
                "accessToken": "secret-a",
                "nested": {"refresh_token": "secret-b"},
                "items": [{"authorizationHeader": "secret-c"}],
            }
        },
    }
    await store.persist(tool_result_event)

    first_upload = fake_storage.upload_calls[0]
    payload = cast(dict[str, Any], first_upload["payload"])
    stored_result = cast(dict[str, Any], payload["result"])
    data = cast(dict[str, Any], stored_result["data"])

    assert data["accessToken"] == "[REDACTED]"
    assert cast(dict[str, Any], data["nested"])["refresh_token"] == "[REDACTED]"
    first_item = cast(list[Any], data["items"])[0]
    assert isinstance(first_item, dict)
    assert first_item["authorizationHeader"] == "[REDACTED]"
@pytest.mark.asyncio
async def test_store_drops_buffer_when_session_missing(
monkeypatch: pytest.MonkeyPatch,
@@ -0,0 +1,73 @@
from __future__ import annotations
from core.agentscope.events.tool_result_summary import build_tool_content_summary
def test_summary_prioritizes_error() -> None:
    """With both a result and an error supplied, the summary reports the error."""
    summary = build_tool_content_summary(
        tool_name="calendar_write",
        args={"title": "A"},
        result={"message": "ignored"},
        error={"message": "denied"},
    )
    expected = "calendar_write 执行失败:denied"
    assert summary == expected
def test_summary_for_calendar_write() -> None:
    """A calendar_write summary combines the title arg and the result's startAt."""
    call = {
        "tool_name": "calendar_write",
        "args": {"title": "项目评审"},
        "result": {"startAt": "明天 10:00"},
        "error": None,
    }
    summary = build_tool_content_summary(**call)
    assert summary == "已创建日程:项目评审(明天 10:00)"
def test_summary_for_calendar_read() -> None:
    """A calendar_read summary reports ``data.total`` together with the query."""
    call = {
        "tool_name": "calendar_read",
        "args": {"query": "今天"},
        "result": {"data": {"total": 3}},
        "error": None,
    }
    summary = build_tool_content_summary(**call)
    assert summary == "查询到 3 条日程(今天)"
def test_summary_falls_back_to_result_content() -> None:
    """For an unknown tool the summary starts with result.content, capped at 80 chars."""
    long_content = "这是非常长的说明" * 20
    summary = build_tool_content_summary(
        tool_name="unknown_tool",
        args=None,
        result={"content": long_content},
        error=None,
    )
    assert summary.startswith("这是非常长的说明")
    assert len(summary) <= 80
def test_summary_default_done() -> None:
    """With no args, result or error, the summary is the generic "done" text."""
    summary = build_tool_content_summary(
        tool_name="unknown_tool",
        args=None,
        result=None,
        error=None,
    )
    expected = "unknown_tool 执行完成"
    assert summary == expected
def test_summary_marks_business_failure_when_ok_false() -> None:
    """A result whose ``data.ok`` is False is summarised as a failure with its message."""
    failed_result = {
        "type": "calendar_operation.v1",
        "data": {
            "ok": False,
            "code": "UNAUTHORIZED",
            "message": "calendar.write requires validated user token",
        },
    }
    summary = build_tool_content_summary(
        tool_name="calendar_write",
        args={"title": "上学"},
        result=failed_result,
        error=None,
    )
    expected = "calendar_write 执行失败:calendar.write requires validated user token"
    assert summary == expected
@@ -109,7 +109,6 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
"step.start",
"step.finish",
"step.start",
"step.finish",
"text.start",
"text.delta",
"text.end",
@@ -117,6 +116,7 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
"text.delta",
"text.end",
"tool.result",
"step.finish",
"step.start",
"text.start",
"text.delta",
@@ -127,10 +127,14 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
assert calls[1]["data"]["stepName"] == "intent"
assert calls[2]["data"]["stepName"] == "intent"
assert calls[3]["data"]["stepName"] == "execution"
assert calls[4]["data"]["stepName"] == "execution"
assert calls[5]["data"]["stage"] == "intent"
assert calls[8]["data"]["stage"] == "execution"
assert calls[11]["data"]["toolName"] == "calendar_write"
assert calls[4]["data"]["stage"] == "intent"
assert calls[7]["data"]["stage"] == "execution"
assert calls[10]["data"]["toolName"] == "calendar_write"
assert calls[10]["data"]["toolCallId"] == "run-1-t1-1"
assert calls[10]["data"]["messageId"] == "tool-result-run-1-t1-1"
tool_content = calls[10]["data"]["content"]
assert tool_content == "calendar_write 执行完成"
assert calls[11]["data"]["stepName"] == "execution"
assert calls[12]["data"]["stepName"] == "report"
assert calls[14]["data"]["delta"] == "hello world"
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
@@ -305,3 +309,300 @@ async def test_runtime_direct_response_finishes_without_report_stage() -> None:
]
assert calls[3]["data"]["stage"] == "intent"
assert calls[4]["data"]["delta"] == "direct-answer"
@pytest.mark.asyncio
async def test_runtime_tool_result_parses_json_string_ui_payload() -> None:
    """A tool result given as a JSON string is expected to surface as a dict ``ui``."""
    emitted: list[dict[str, Any]] = []

    class _RecordingPipeline:
        """Pipeline stub that stores each emitted event in ``emitted``."""

        async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
            assert session_id == "thread-1"
            emitted.append(event)
            return f"{len(emitted)}-0"

    class _CannedOrchestrator:
        """Orchestrator stub returning one successful task with one tool call."""

        async def run(self, **_: object) -> RuntimeOutput:
            intent = IntentOutput(
                route="TASK_EXECUTION",
                intent_summary="summary",
                direct_response=None,
                tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
                complexity="complex",
                response_metadata={},
            )
            tool_call = ExecutionToolCall(
                tool_name="calendar_write",
                args={"title": "A"},
                result='{"type":"calendar_card.v1","version":"v1","data":{"ok":true,"title":"A"},"actions":[]}',
            )
            task_output = ExecutionTaskOutput(
                task_id="t1",
                status="SUCCESS",
                execution_summary="execution-ok",
                execution_data={},
                user_feedback_needs=[],
                response_metadata={},
                tool_calls=[tool_call],
            )
            execution = ExecutionBatchOutput(
                task_results=[task_output],
                overall_status="SUCCESS",
                aggregate_summary="ok",
            )
            report = ReportOutput(assistant_text="hello world", response_metadata={})
            return RuntimeOutput(intent=intent, execution=execution, report=report)

    runtime = AgentRouteRuntime(
        orchestrator=_CannedOrchestrator(), pipeline=_RecordingPipeline()
    )
    await runtime.run(
        command=RunCommand(threadId="thread-1", runId="run-1", messages=[]),
        owner_id=uuid4(),
        user_token="token",
        user_context=_user_context(),
        session=cast(AsyncSession, object()),
    )

    tool_events = [evt for evt in emitted if evt.get("type") == "tool.result"]
    assert len(tool_events) == 1
    data = tool_events[0]["data"]
    assert isinstance(data, dict)
    ui = data.get("ui")
    assert isinstance(ui, dict)
    assert data["ui"]["type"] == "calendar_card.v1"
@pytest.mark.asyncio
async def test_runtime_tool_result_keeps_plain_text_content() -> None:
    """A plain-text tool result is expected to be emitted verbatim as ``content``."""
    emitted: list[dict[str, Any]] = []

    class _RecordingPipeline:
        """Pipeline stub that stores each emitted event in ``emitted``."""

        async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
            assert session_id == "thread-1"
            emitted.append(event)
            return f"{len(emitted)}-0"

    class _CannedOrchestrator:
        """Orchestrator stub whose single tool call returns a plain string."""

        async def run(self, **_: object) -> RuntimeOutput:
            intent = IntentOutput(
                route="TASK_EXECUTION",
                intent_summary="summary",
                direct_response=None,
                tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
                complexity="complex",
                response_metadata={},
            )
            tool_call = ExecutionToolCall(
                tool_name="calendar_write",
                args={"title": "A"},
                result="created successfully",
            )
            task_output = ExecutionTaskOutput(
                task_id="t1",
                status="SUCCESS",
                execution_summary="execution-ok",
                execution_data={},
                user_feedback_needs=[],
                response_metadata={},
                tool_calls=[tool_call],
            )
            execution = ExecutionBatchOutput(
                task_results=[task_output],
                overall_status="SUCCESS",
                aggregate_summary="ok",
            )
            report = ReportOutput(assistant_text="hello world", response_metadata={})
            return RuntimeOutput(intent=intent, execution=execution, report=report)

    runtime = AgentRouteRuntime(
        orchestrator=_CannedOrchestrator(), pipeline=_RecordingPipeline()
    )
    await runtime.run(
        command=RunCommand(threadId="thread-1", runId="run-1", messages=[]),
        owner_id=uuid4(),
        user_token="token",
        user_context=_user_context(),
        session=cast(AsyncSession, object()),
    )

    tool_events = [evt for evt in emitted if evt.get("type") == "tool.result"]
    assert len(tool_events) == 1
    data = tool_events[0]["data"]
    assert isinstance(data, dict)
    assert data["content"] == "created successfully"
@pytest.mark.asyncio
async def test_runtime_tool_result_sanitizes_sensitive_payload() -> None:
    """Token-like values in args, result and error are expected to be redacted.

    Non-sensitive fields (``author``) are expected to pass through unchanged.
    """
    emitted: list[dict[str, Any]] = []

    class _RecordingPipeline:
        """Pipeline stub that stores each emitted event in ``emitted``."""

        async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
            assert session_id == "thread-1"
            emitted.append(event)
            return f"{len(emitted)}-0"

    class _CannedOrchestrator:
        """Orchestrator stub whose tool call carries secrets in every field."""

        async def run(self, **_: object) -> RuntimeOutput:
            intent = IntentOutput(
                route="TASK_EXECUTION",
                intent_summary="summary",
                direct_response=None,
                tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
                complexity="complex",
                response_metadata={},
            )
            tool_call = ExecutionToolCall(
                tool_name="calendar_write",
                args={
                    "title": "A",
                    "accessToken": "arg-secret",
                    "author": "alice",
                },
                result={
                    "ok": True,
                    "accessToken": "secret-token",
                    "message": "Authorization: Bearer inline-token",
                    "nested": [{"authorizationHeader": "Bearer abc"}],
                },
                error="failed authorization=Bearer abc123 detail",
            )
            task_output = ExecutionTaskOutput(
                task_id="t1",
                status="SUCCESS",
                execution_summary="execution-ok",
                execution_data={},
                user_feedback_needs=[],
                response_metadata={},
                tool_calls=[tool_call],
            )
            execution = ExecutionBatchOutput(
                task_results=[task_output],
                overall_status="SUCCESS",
                aggregate_summary="ok",
            )
            report = ReportOutput(assistant_text="hello world", response_metadata={})
            return RuntimeOutput(intent=intent, execution=execution, report=report)

    runtime = AgentRouteRuntime(
        orchestrator=_CannedOrchestrator(), pipeline=_RecordingPipeline()
    )
    await runtime.run(
        command=RunCommand(threadId="thread-1", runId="run-1", messages=[]),
        owner_id=uuid4(),
        user_token="token",
        user_context=_user_context(),
        session=cast(AsyncSession, object()),
    )

    tool_events = [evt for evt in emitted if evt.get("type") == "tool.result"]
    assert len(tool_events) == 1
    data = tool_events[0]["data"]
    assert isinstance(data, dict)

    # Result payload: sensitive keys and inline bearer text are masked.
    assert isinstance(data["result"], dict)
    assert data["result"]["accessToken"] == "[REDACTED]"
    assert data["result"]["message"] == "Authorization=[REDACTED]"
    nested = data["result"]["nested"]
    assert isinstance(nested, list)
    assert nested[0]["authorizationHeader"] == "[REDACTED]"

    # Args: only the token-like key is masked.
    assert isinstance(data["args"], dict)
    assert data["args"]["accessToken"] == "[REDACTED]"
    assert data["args"]["author"] == "alice"

    assert data["error"] == "failed authorization=[REDACTED] detail"
@pytest.mark.asyncio
async def test_runtime_tool_result_keeps_non_object_result() -> None:
    """A list-valued tool result is expected to be emitted under ``result["value"]``."""
    emitted: list[dict[str, Any]] = []

    class _RecordingPipeline:
        """Pipeline stub that stores each emitted event in ``emitted``."""

        async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
            assert session_id == "thread-1"
            emitted.append(event)
            return f"{len(emitted)}-0"

    class _CannedOrchestrator:
        """Orchestrator stub whose single tool call returns a list."""

        async def run(self, **_: object) -> RuntimeOutput:
            intent = IntentOutput(
                route="TASK_EXECUTION",
                intent_summary="summary",
                direct_response=None,
                tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
                complexity="complex",
                response_metadata={},
            )
            tool_call = ExecutionToolCall(
                tool_name="calendar_write",
                args={"title": "A"},
                result=["evt-1", "evt-2"],
            )
            task_output = ExecutionTaskOutput(
                task_id="t1",
                status="SUCCESS",
                execution_summary="execution-ok",
                execution_data={},
                user_feedback_needs=[],
                response_metadata={},
                tool_calls=[tool_call],
            )
            execution = ExecutionBatchOutput(
                task_results=[task_output],
                overall_status="SUCCESS",
                aggregate_summary="ok",
            )
            report = ReportOutput(assistant_text="hello world", response_metadata={})
            return RuntimeOutput(intent=intent, execution=execution, report=report)

    runtime = AgentRouteRuntime(
        orchestrator=_CannedOrchestrator(), pipeline=_RecordingPipeline()
    )
    await runtime.run(
        command=RunCommand(threadId="thread-1", runId="run-1", messages=[]),
        owner_id=uuid4(),
        user_token="token",
        user_context=_user_context(),
        session=cast(AsyncSession, object()),
    )

    tool_events = [evt for evt in emitted if evt.get("type") == "tool.result"]
    assert len(tool_events) == 1
    data = tool_events[0]["data"]
    assert isinstance(data, dict)
    assert isinstance(data["result"], dict)
    assert data["result"]["value"] == ["evt-1", "evt-2"]
@@ -212,6 +212,9 @@ def test_merge_stage_response_metadata_estimates_cost_from_pricing(
model="qwen3.5-flash",
),
latency_ms=50,
system_prompt="system",
user_prompt="user",
assistant_text='{"route":"DIRECT_RESPONSE"}',
)
metadata = payload["response_metadata"]
@@ -50,6 +50,10 @@ async def test_run_agentscope_task_calls_runtime_run(
async def _fake_get_redis_client() -> object:
return object()
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return []
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
monkeypatch.setattr(
tasks_module,
@@ -60,7 +64,7 @@ async def test_run_agentscope_task_calls_runtime_run(
monkeypatch.setattr(
tasks_module,
"_build_recent_context_messages",
lambda **_: [],
_empty_context,
)
result = await tasks_module.run_agentscope_task(
@@ -101,6 +105,10 @@ async def test_run_agentscope_task_includes_recent_context_messages(
async def _fake_get_redis_client() -> object:
return object()
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return []
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
@@ -115,7 +123,7 @@ async def test_run_agentscope_task_includes_recent_context_messages(
monkeypatch.setattr(
tasks_module,
"_build_recent_context_messages",
lambda **_: [],
_empty_context,
)
monkeypatch.setattr(
tasks_module,
@@ -94,3 +94,46 @@ def test_validate_run_request_messages_contract_requires_single_user_message() -
match="RunAgentInput.messages must contain exactly one user message",
):
validate_run_request_messages_contract(run_input)
def test_validate_run_request_messages_contract_accepts_binary_url_blocks() -> None:
    """A user message mixing text with a URL-backed binary block passes validation."""
    payload = _base_payload()
    text_block = {"type": "text", "text": "请分析"}
    image_block = {
        "type": "binary",
        "mimeType": "image/png",
        "url": "https://signed.example/a.png",
    }
    payload["messages"] = [
        {"id": "u1", "role": "user", "content": [text_block, image_block]}
    ]

    # Passing means no exception is raised.
    validate_run_request_messages_contract(parse_run_input(payload))
def test_validate_run_request_messages_contract_rejects_binary_data_block() -> None:
    """A binary block that carries inline ``data`` instead of ``url`` is rejected."""
    payload = _base_payload()
    text_block = {"type": "text", "text": "请分析"}
    inline_block = {
        "type": "binary",
        "mimeType": "image/png",
        "data": "aGVsbG8=",
    }
    payload["messages"] = [
        {"id": "u1", "role": "user", "content": [text_block, inline_block]}
    ]

    # Parsing happens outside the raises-context: only validation should fail.
    run_input = parse_run_input(payload)
    with pytest.raises(ValueError, match="binary content requires url"):
        validate_run_request_messages_contract(run_input)
@@ -54,3 +54,20 @@ def test_build_intent_user_prompt_filters_non_image_binary_block() -> None:
assert isinstance(prompt, list)
image_blocks = [item for item in prompt if item.get("type") == "image"]
assert image_blocks == []
def test_build_intent_user_prompt_includes_previous_context_messages() -> None:
    """Earlier turns are expected to appear in the first prompt block's text."""
    conversation = [
        {"id": "u1", "role": "user", "content": "我的口令是蓝鲸42"},
        {"id": "a1", "role": "assistant", "content": "已记住"},
        {"id": "u2", "role": "user", "content": "请重复口令"},
    ]
    prompt = build_intent_user_prompt(user_input=conversation)

    assert isinstance(prompt, list)
    assert prompt
    instruction = prompt[0].get("text", "")
    assert isinstance(instruction, str)
    assert "[Conversation Context]" in instruction
    # The earlier user text is matched in its backslash-u escaped form.
    assert "\\u84dd\\u9cb842" in instruction
@@ -67,10 +67,8 @@ async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
assert await service.initialize() is True
assert await service.close() is True
assert service.is_initialized is False
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
assert service.get_client() is not None
assert service.get_admin_client() is not None
@pytest.mark.asyncio
@@ -117,7 +115,47 @@ def test_get_client_raises_before_init() -> None:
settings=SupabaseSettings(public_url="https://test.supabase.co")
)
assert service.get_client() is not None
assert service.get_admin_client() is not None
def test_get_client_raises_when_lazy_initialization_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Both client accessors raise RuntimeError when ``create_client`` fails."""
    settings = SupabaseSettings(public_url="https://test.supabase.co")
    service = SupabaseService(settings=settings)

    def _fake_create_client(_: str, __: str) -> object:
        raise RuntimeError("boom")

    monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)

    with pytest.raises(RuntimeError):
        service.get_client()

    with pytest.raises(RuntimeError):
        service.get_admin_client()
def test_get_admin_client_lazily_initializes_clients(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Calling get_admin_client on a fresh service creates both clients.

    The fake factory returns the anon mock on its first call and the admin
    mock on every later call; exactly two factory calls are expected.
    """
    settings = SupabaseSettings(public_url="https://test.supabase.co")
    service = SupabaseService(settings=settings)
    anon_client = MagicMock(name="anon")
    admin_client = MagicMock(name="admin")
    create_calls: list[tuple[str, str]] = []

    def _fake_create_client(url: str, key: str) -> object:
        create_calls.append((url, key))
        if len(create_calls) == 1:
            return anon_client
        return admin_client

    monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)

    resolved_admin = service.get_admin_client()

    assert resolved_admin is admin_client
    assert service.get_client() is anon_client
    assert service.is_initialized is True
    assert len(create_calls) == 2
@@ -0,0 +1,85 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
import v1.agent.attachment_storage as attachment_storage_module
class _FakeBucket:
def __init__(self) -> None:
self.upload_calls: list[tuple[str, bytes, dict[str, str]]] = []
self.download_calls: list[str] = []
def upload(self, path: str, content: bytes, options: dict[str, str]) -> object:
self.upload_calls.append((path, content, options))
return {"path": path}
def download(self, path: str) -> object:
self.download_calls.append(path)
return b"ok"
class _FakeStorage:
def __init__(self, bucket: _FakeBucket) -> None:
self._bucket = bucket
def from_(self, bucket: str) -> object:
del bucket
return self._bucket
@pytest.mark.asyncio
async def test_attachment_storage_rejects_unexpected_bucket(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Uploading to a bucket other than the configured one raises RuntimeError."""
    storage = attachment_storage_module.AgentAttachmentStorage()
    storage_config = attachment_storage_module.config.storage
    monkeypatch.setattr(storage_config, "bucket", "allowed-bucket")

    upload_args = {
        "bucket": "other-bucket",
        "path": "agent-inputs/u/t/r/file.png",
        "content": b"data",
        "content_type": "image/png",
    }
    with pytest.raises(RuntimeError, match="Invalid attachment bucket"):
        await storage.upload_bytes(**upload_args)
@pytest.mark.asyncio
async def test_attachment_storage_accepts_configured_bucket(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Upload then download through the configured bucket reaches the fake bucket."""
    storage = attachment_storage_module.AgentAttachmentStorage()
    fake_bucket = _FakeBucket()
    fake_client = SimpleNamespace(storage=_FakeStorage(fake_bucket))

    monkeypatch.setattr(
        attachment_storage_module.config.storage, "bucket", "allowed-bucket"
    )
    # Route admin-client lookups to the fake client wrapping fake_bucket.
    monkeypatch.setattr(
        attachment_storage_module.supabase_service,
        "get_admin_client",
        lambda: fake_client,
    )

    object_path = "agent-inputs/u/t/r/file.png"
    path = await storage.upload_bytes(
        bucket="allowed-bucket",
        path=object_path,
        content=b"data",
        content_type="image/png",
    )
    payload = await storage.download_bytes(bucket="allowed-bucket", path=path)

    assert path == "agent-inputs/u/t/r/file.png"
    assert payload == b"ok"
    assert len(fake_bucket.upload_calls) == 1
    assert fake_bucket.download_calls == ["agent-inputs/u/t/r/file.png"]
+182 -9
View File
@@ -6,6 +6,7 @@ from uuid import uuid4
import pytest
from core.config.settings import config
from models.agent_chat_message import AgentChatMessageRole
from v1.agent.repository import AgentRepository
@@ -62,7 +63,7 @@ async def test_tool_message_hydrates_content_from_object_storage() -> None:
content='{"offloaded":true}',
metadata_json={
"tool_call_id": "call-1",
"storage_bucket": "private",
"storage_bucket": config.storage.bucket,
"storage_path": "tool-results/run-1/call-1.json",
},
)
@@ -73,6 +74,43 @@ async def test_tool_message_hydrates_content_from_object_storage() -> None:
assert payload["content"] == "已跳转"
@pytest.mark.asyncio
async def test_tool_message_hydrates_ui_from_ui_schema_field() -> None:
    """The stored ``ui_schema`` is expected to appear as ``ui`` in the snapshot.

    The message's inline content is expected to be kept as-is.
    """
    stored_payload = {
        "toolName": "calendar_write",
        "ui_schema": {
            "type": "calendar_operation.v1",
            "version": "v1",
            "data": {"ok": True, "operation": "create"},
            "actions": [],
        },
    }
    repository = AgentRepository(
        session=SimpleNamespace(),  # type: ignore[arg-type]
        tool_result_storage=_FakeToolResultStorage(stored_payload),
    )
    metadata = {
        "tool_call_id": "call-3",
        "storage_bucket": config.storage.bucket,
        "storage_path": "tool-results/run-1/call-3.json",
    }
    message = SimpleNamespace(
        id=uuid4(),
        role=AgentChatMessageRole.TOOL,
        created_at=datetime.now(timezone.utc),
        content="已创建日程:项目评审(明天 10:00)",
        metadata_json=metadata,
    )

    payload = await repository._to_snapshot_message(message)  # type: ignore[arg-type]

    assert payload["toolCallId"] == "call-3"
    assert payload["content"] == "已创建日程:项目评审(明天 10:00)"
    ui = payload.get("ui")
    assert isinstance(ui, dict)
    assert ui["type"] == "calendar_operation.v1"
@pytest.mark.asyncio
async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
repository = AgentRepository(
@@ -86,7 +124,7 @@ async def test_tool_message_keeps_inline_content_when_storage_payload_missing()
content="inline-tool-content",
metadata_json={
"tool_call_id": "call-2",
"storage_bucket": "private",
"storage_bucket": config.storage.bucket,
"storage_path": "tool-results/run-1/call-2.json",
},
)
@@ -97,6 +135,111 @@ async def test_tool_message_keeps_inline_content_when_storage_payload_missing()
assert payload["content"] == "inline-tool-content"
@pytest.mark.asyncio
async def test_tool_message_skips_storage_when_path_not_matching_session() -> None:
    """With a foreign-session storage path, no ``ui`` is expected in the snapshot."""
    stored_payload = {
        "ui_schema": {
            "type": "calendar_operation.v1",
            "version": "v1",
            "data": {"ok": True},
            "actions": [],
        }
    }
    repository = AgentRepository(
        session=SimpleNamespace(),  # type: ignore[arg-type]
        tool_result_storage=_FakeToolResultStorage(stored_payload),
    )
    metadata = {
        "tool_call_id": "call-x",
        "storage_bucket": config.storage.bucket,
        "storage_path": "tool-results/foreign-session/call-y.json",
    }
    message = SimpleNamespace(
        id=uuid4(),
        session_id=uuid4(),
        role=AgentChatMessageRole.TOOL,
        created_at=datetime.now(timezone.utc),
        content="summary",
        metadata_json=metadata,
    )

    payload = await repository._to_snapshot_message(message)  # type: ignore[arg-type]

    assert payload["content"] == "summary"
    assert "ui" not in payload
@pytest.mark.asyncio
async def test_tool_message_rejects_path_traversal() -> None:
    """With ``..`` segments in the storage path, no ``ui`` is expected in the snapshot."""
    stored_payload = {
        "ui_schema": {
            "type": "calendar_operation.v1",
            "version": "v1",
            "data": {"ok": True},
            "actions": [],
        }
    }
    repository = AgentRepository(
        session=SimpleNamespace(),  # type: ignore[arg-type]
        tool_result_storage=_FakeToolResultStorage(stored_payload),
    )
    metadata = {
        "tool_call_id": "call-z",
        "storage_bucket": config.storage.bucket,
        "storage_path": "tool-results/ok/../../evil/call-z.json",
    }
    message = SimpleNamespace(
        id=uuid4(),
        session_id=uuid4(),
        role=AgentChatMessageRole.TOOL,
        created_at=datetime.now(timezone.utc),
        content="summary",
        metadata_json=metadata,
    )

    payload = await repository._to_snapshot_message(message)  # type: ignore[arg-type]

    assert payload["content"] == "summary"
    assert "ui" not in payload
@pytest.mark.asyncio
async def test_tool_message_supports_legacy_storage_path() -> None:
    """A ``tool-results/<run>/<call>.json`` path still hydrates content and ui."""
    stored_payload = {
        "ui_schema": {
            "type": "calendar_operation.v1",
            "version": "v1",
            "data": {"ok": True},
            "actions": [],
        },
        "content": "legacy content",
    }
    repository = AgentRepository(
        session=SimpleNamespace(),  # type: ignore[arg-type]
        tool_result_storage=_FakeToolResultStorage(stored_payload),
    )
    metadata = {
        "tool_call_id": "call-legacy",
        "storage_bucket": config.storage.bucket,
        "storage_path": "tool-results/old-run/call-legacy.json",
    }
    message = SimpleNamespace(
        id=uuid4(),
        session_id=uuid4(),
        role=AgentChatMessageRole.TOOL,
        created_at=datetime.now(timezone.utc),
        content='{"offloaded":true}',
        metadata_json=metadata,
    )

    payload = await repository._to_snapshot_message(message)  # type: ignore[arg-type]

    assert payload["content"] == "legacy content"
    ui = payload.get("ui")
    assert isinstance(ui, dict)
    assert ui["type"] == "calendar_operation.v1"
@pytest.mark.asyncio
async def test_user_message_snapshot_includes_renderable_attachments() -> None:
repository = AgentRepository(
@@ -104,6 +247,7 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
)
message = SimpleNamespace(
id=uuid4(),
session_id=uuid4(),
role=AgentChatMessageRole.USER,
created_at=datetime.now(timezone.utc),
content="请分析这张图",
@@ -122,13 +266,13 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
assert payload["role"] == "user"
assert payload["content"] == "请分析这张图"
assert payload["attachments"] == [
{
"bucket": "agent-chat-attachments",
"path": "agent-inputs/u1/t1/r1/m1/att-1.png",
"mimeType": "image/png",
}
]
attachments = payload.get("attachments")
assert isinstance(attachments, list)
assert len(attachments) == 1
first = attachments[0]
assert isinstance(first, dict)
assert first["mimeType"] == "image/png"
assert isinstance(first.get("previewPath"), str)
@pytest.mark.asyncio
@@ -174,3 +318,32 @@ async def test_persist_user_message_keeps_existing_session_title() -> None:
assert session_row.title == "已有标题"
assert session_row.message_count == 2
@pytest.mark.asyncio
async def test_get_message_attachment_reference_returns_item() -> None:
    """Index 0 resolves to the first attachment in the message metadata."""
    session_id = str(uuid4())
    message_id = str(uuid4())
    attachment = {
        "bucket": "bucket-test",
        "path": "agent-inputs/u/t/r/a.png",
        "mimeType": "image/png",
    }
    message = SimpleNamespace(metadata_json={"attachments": [attachment]})
    repository = AgentRepository(session=_FakeSession(message))  # type: ignore[arg-type]

    ref = await repository.get_message_attachment_reference(
        session_id=session_id,
        message_id=message_id,
        attachment_index=0,
    )

    assert ref is not None
    assert ref["bucket"] == "bucket-test"
    assert ref["mimeType"] == "image/png"
@@ -225,3 +225,44 @@ async def test_stream_events_retries_on_redis_timeout(
merged = "".join(chunks)
assert "event: RUN_FINISHED" in merged
@pytest.mark.asyncio
async def test_get_attachment_preview_rejects_negative_index() -> None:
    """A negative attachment index yields HTTP 422 without calling the service."""

    class _Service:
        """Service stub that fails the test if the router calls into it."""

        async def get_attachment_preview(self, **kwargs):  # noqa: ANN003
            del kwargs
            raise AssertionError("get_attachment_preview should not be called")

    request_args = {
        "thread_id": "00000000-0000-0000-0000-000000000001",
        "message_id": "00000000-0000-0000-0000-000000000010",
        "attachment_index": -1,
    }
    with pytest.raises(HTTPException) as exc_info:
        await agent_router.get_attachment_preview(
            **request_args,
            service=cast(Any, _Service()),
            current_user=CurrentUser(id=uuid4(), email="user@example.com"),
        )

    assert exc_info.value.status_code == 422
@pytest.mark.asyncio
async def test_get_attachment_preview_returns_streaming_response() -> None:
    """The response streams the service's bytes with the service's media type."""

    class _Service:
        """Service stub returning fixed PNG bytes and their mime type."""

        async def get_attachment_preview(self, **kwargs):  # noqa: ANN003
            del kwargs
            return b"png-bytes", "image/png"

    response = await agent_router.get_attachment_preview(
        thread_id="00000000-0000-0000-0000-000000000001",
        message_id="00000000-0000-0000-0000-000000000010",
        attachment_index=0,
        service=cast(Any, _Service()),
        current_user=CurrentUser(id=uuid4(), email="user@example.com"),
    )

    # Drain the body iterator so the full payload can be compared.
    chunks: list[bytes] = [
        cast(bytes, chunk) async for chunk in response.body_iterator
    ]

    assert response.media_type == "image/png"
    assert b"".join(chunks) == b"png-bytes"
+341 -13
View File
@@ -6,8 +6,10 @@ from uuid import UUID
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
import pytest
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
from core.config.settings import config
import v1.agent.service as agent_service_module
from v1.agent.service import AgentService, AsrService
@@ -74,12 +76,32 @@ class _FakeRepository:
}
)
async def get_message_attachment_reference(
    self,
    *,
    session_id: str,
    message_id: str,
    attachment_index: int,
) -> dict[str, str] | None:
    """Return a fixed PNG attachment reference for index 0, otherwise ``None``.

    ``session_id`` and ``message_id`` are accepted but ignored.
    """
    del session_id, message_id
    if attachment_index == 0:
        return {
            "bucket": config.storage.bucket,
            "path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/run-1/attachment-0-a.png",
            "mimeType": "image/png",
        }
    return None
class _FakeQueue:
def __init__(self) -> None:
self.commands: list[dict[str, object]] = []
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
del command, dedup_key
self.commands.append(command)
del dedup_key
return "task-1"
@@ -123,6 +145,33 @@ class _FakeAttachmentStorage:
)
return path
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
    """Log a download call on ``self.calls`` and return fixed PNG-like bytes."""
    # The "download" flag marks this entry as a download in self.calls.
    entry = dict(bucket=bucket, path=path, download=True)
    self.calls.append(entry)
    return b"png-bytes"
async def create_signed_url(
    self,
    *,
    bucket: str,
    path: str,
    expires_in_seconds: int,
) -> str:
    """Log a signing call on ``self.calls`` and return a deterministic fake URL."""
    entry = dict(
        bucket=bucket,
        path=path,
        signed=True,
        expires_in_seconds=expires_in_seconds,
    )
    self.calls.append(entry)
    signed_url = f"https://signed.example/{path}?exp={expires_in_seconds}"
    return signed_url
class _AlwaysFailAttachmentStorage:
async def upload_bytes(
@@ -136,6 +185,20 @@ class _AlwaysFailAttachmentStorage:
del bucket, path, content, content_type
raise RuntimeError("upload failed")
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
    """Always raise ``RuntimeError("download failed")``; arguments are ignored."""
    _ = (bucket, path)
    raise RuntimeError("download failed")
async def create_signed_url(
    self,
    *,
    bucket: str,
    path: str,
    expires_in_seconds: int,
) -> str:
    """Always raise ``RuntimeError("sign failed")``; arguments are ignored."""
    _ = (bucket, path, expires_in_seconds)
    raise RuntimeError("sign failed")
def _user() -> CurrentUser:
return CurrentUser(
@@ -186,9 +249,10 @@ async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
async def test_enqueue_run_creates_missing_thread_session() -> None:
repository = _FakeRepository()
queue = _FakeQueue()
service = AgentService(
repository=repository,
queue=_FakeQueue(),
queue=queue,
stream=_FakeStream(),
)
run_input = _build_run_input(
@@ -206,6 +270,30 @@ async def test_enqueue_run_creates_missing_thread_session() -> None:
assert accepted.created is True
assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
assert repository.committed is True
assert queue.commands[0]["user_token"] is None
async def test_enqueue_run_uses_explicit_user_token() -> None:
    """An explicit ``"Bearer access-token-1"`` is expected to be queued as ``"access-token-1"``."""
    queue = _FakeQueue()
    service = AgentService(
        repository=_FakeRepository(),
        queue=queue,
        stream=_FakeStream(),
    )

    await service.enqueue_run(
        run_input=_build_run_input(
            thread_id="00000000-0000-0000-0000-000000000001",
            run_id="run-1",
        ),
        current_user=_user(),
        user_token="Bearer access-token-1",
    )

    assert queue.commands
    first_command = queue.commands[0]
    assert first_command["user_token"] == "access-token-1"
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
@@ -270,7 +358,7 @@ async def test_enqueue_run_handles_session_create_race() -> None:
assert repository.rolled_back is True
async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
async def test_enqueue_run_uses_forwarded_attachments_and_injects_metadata(
monkeypatch,
) -> None:
monkeypatch.setattr(
@@ -297,15 +385,23 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
{"type": "text", "text": "帮我看下这张图"},
{
"type": "binary",
"data": "aGVsbG8=",
"mimeType": "image/png",
"url": "https://signed.example/upload.png",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {
"attachments": [
{
"bucket": "agent-test-bucket",
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
"mimeType": "image/png",
}
]
},
}
)
@@ -313,10 +409,9 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
assert accepted.task_id == "task-1"
assert len(attachment_storage.calls) == 1
upload = attachment_storage.calls[0]
assert upload["bucket"] == "agent-test-bucket"
assert upload["content"] == b"hello"
assert upload["content_type"] == "image/png"
download = attachment_storage.calls[0]
assert download["bucket"] == "agent-test-bucket"
assert download["download"] is True
assert repository.persisted_user_messages
persisted = repository.persisted_user_messages[0]
assert persisted["session_id"] == "00000000-0000-0000-0000-000000000001"
@@ -330,7 +425,7 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
assert isinstance(attachments[0]["path"], str)
async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
async def test_enqueue_run_raises_when_attachment_download_fails_without_fallback(
monkeypatch,
) -> None:
monkeypatch.setattr(
@@ -356,15 +451,23 @@ async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
{"type": "text", "text": "帮我看下这张图"},
{
"type": "binary",
"data": "aGVsbG8=",
"mimeType": "image/png",
"url": "https://signed.example/upload.png",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {
"attachments": [
{
"bucket": "agent-test-bucket",
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
"mimeType": "image/png",
}
]
},
}
)
@@ -373,11 +476,183 @@ async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
raise AssertionError("expected HTTPException")
except HTTPException as exc:
assert exc.status_code == 502
assert exc.detail == "Failed to upload attachment"
assert exc.detail == "Failed to fetch attachment"
assert repository.persisted_user_messages == []
async def test_enqueue_run_rejects_unsupported_attachment_type(
    monkeypatch,
) -> None:
    """A forwarded ``image/gif`` attachment is rejected with HTTP 422.

    The storage fake must record no calls for the rejected attachment.
    """
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )

    # Request payload: one user message with a binary GIF part, plus the
    # matching forwarded attachment reference.
    user_message = {
        "id": "u1",
        "role": "user",
        "content": [
            {"type": "text", "text": "请看附件"},
            {
                "type": "binary",
                "mimeType": "image/gif",
                "url": "https://signed.example/upload.gif",
            },
        ],
    }
    forwarded_attachment = {
        "bucket": "agent-test-bucket",
        "path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.gif",
        "mimeType": "image/gif",
    }
    run_input = RunAgentInput.model_validate(
        {
            "threadId": "00000000-0000-0000-0000-000000000001",
            "runId": "run-with-bad-image",
            "state": {},
            "messages": [user_message],
            "tools": [],
            "context": [],
            "forwardedProps": {"attachments": [forwarded_attachment]},
        }
    )

    storage_fake = _FakeAttachmentStorage()
    agent_service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=storage_fake,
    )

    with pytest.raises(HTTPException) as exc_info:
        await agent_service.enqueue_run(run_input=run_input, current_user=_user())

    error = exc_info.value
    assert error.status_code == 422
    assert error.detail == "Unsupported attachment type"
    assert storage_fake.calls == []
async def test_enqueue_run_rejects_attachment_too_large(
    monkeypatch,
) -> None:
    """With the size limit patched to 4 bytes, the attachment is rejected with HTTP 413.

    The storage fake must record exactly one call, flagged as a download.
    """
    # Patch the module-level size limit down to 4.
    monkeypatch.setattr(agent_service_module, "_MAX_ATTACHMENT_BYTES", 4)
    monkeypatch.setattr(
        agent_service_module.config.storage, "bucket", "agent-test-bucket"
    )

    user_message = {
        "id": "u1",
        "role": "user",
        "content": [
            {"type": "text", "text": "请看附件"},
            {
                "type": "binary",
                "mimeType": "image/png",
                "url": "https://signed.example/upload.png",
            },
        ],
    }
    forwarded_attachment = {
        "bucket": "agent-test-bucket",
        "path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
        "mimeType": "image/png",
    }
    run_input = RunAgentInput.model_validate(
        {
            "threadId": "00000000-0000-0000-0000-000000000001",
            "runId": "run-with-big-image",
            "state": {},
            "messages": [user_message],
            "tools": [],
            "context": [],
            "forwardedProps": {"attachments": [forwarded_attachment]},
        }
    )

    storage_fake = _FakeAttachmentStorage()
    agent_service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=storage_fake,
    )

    with pytest.raises(HTTPException) as exc_info:
        await agent_service.enqueue_run(run_input=run_input, current_user=_user())

    error = exc_info.value
    assert error.status_code == 413
    assert error.detail == "Attachment too large"
    # Exactly one storage interaction is expected, and it must be a download.
    assert len(storage_fake.calls) == 1
    assert storage_fake.calls[0]["download"] is True
async def test_enqueue_run_accepts_binary_url_and_persists_metadata() -> None:
    """A URL-based binary part is accepted and its attachment metadata persisted.

    Checks that the persisted user message carries the attachment path in its
    metadata, and that the queued run input still has the binary part with its
    original URL.
    """
    user_message = {
        "id": "u1",
        "role": "user",
        "content": [
            {"type": "text", "text": "请分析"},
            {
                "type": "binary",
                "mimeType": "image/png",
                "url": "https://signed.example/upload-1.png",
            },
        ],
    }
    forwarded_attachment = {
        "bucket": config.storage.bucket,
        "path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload-1.png",
        "mimeType": "image/png",
    }

    fake_repository = _FakeRepository()
    fake_queue = _FakeQueue()
    agent_service = AgentService(
        repository=fake_repository,
        queue=fake_queue,
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )
    run_input = RunAgentInput.model_validate(
        {
            "threadId": "00000000-0000-0000-0000-000000000001",
            "runId": "run-with-binary-url",
            "state": {},
            "messages": [user_message],
            "tools": [],
            "context": [],
            "forwardedProps": {"attachments": [forwarded_attachment]},
        }
    )

    accepted = await agent_service.enqueue_run(
        run_input=run_input, current_user=_user()
    )
    assert accepted.task_id == "task-1"

    # Persisted side: the latest stored user message exposes the attachment path.
    stored_metadata = fake_repository.persisted_user_messages[-1]["metadata"]
    assert isinstance(stored_metadata, dict)
    stored_attachments = stored_metadata.get("attachments")
    assert isinstance(stored_attachments, list)
    assert stored_attachments[0]["path"].endswith("upload-1.png")

    # Queue side: the binary content part is forwarded with its URL intact.
    queued_run_input = fake_queue.commands[-1]["run_input"]
    assert isinstance(queued_run_input, dict)
    queued_content = queued_run_input["messages"][0]["content"]
    assert isinstance(queued_content, list)
    binary_part = queued_content[1]
    assert binary_part["type"] == "binary"
    assert binary_part["url"] == "https://signed.example/upload-1.png"
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
service = AgentService(
repository=_FakeRepository(),
@@ -415,6 +690,59 @@ async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> Non
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
async def test_get_attachment_preview_returns_payload_and_mime() -> None:
    """``get_attachment_preview`` yields the attachment bytes and their MIME type."""
    agent_service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )

    body, content_type = await agent_service.get_attachment_preview(
        thread_id="00000000-0000-0000-0000-000000000001",
        message_id="00000000-0000-0000-0000-000000000010",
        attachment_index=0,
        current_user=_user(),
    )

    assert body == b"png-bytes"
    assert content_type == "image/png"
async def test_get_attachment_preview_rejects_invalid_path() -> None:
    """A stored attachment reference with an unexpected path yields HTTP 403.

    NOTE(review): the rejected path is under ``other-user/other-thread``, so the
    service presumably checks the path against the requesting user/thread;
    confirm against ``AgentService.get_attachment_preview``.
    """

    class _BadPathRepository(_FakeRepository):
        """Repository fake that always returns a reference with a foreign path."""

        async def get_message_attachment_reference(
            self,
            *,
            session_id: str,
            message_id: str,
            attachment_index: int,
        ) -> dict[str, str] | None:
            # Lookup arguments are irrelevant for this fake.
            del session_id, message_id, attachment_index
            return {
                "bucket": "bucket-test",
                "path": "agent-inputs/other-user/other-thread/run-1/a.png",
                "mimeType": "image/png",
            }

    agent_service = AgentService(
        repository=_BadPathRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=_FakeAttachmentStorage(),
    )

    with pytest.raises(HTTPException) as exc_info:
        await agent_service.get_attachment_preview(
            thread_id="00000000-0000-0000-0000-000000000001",
            message_id="00000000-0000-0000-0000-000000000010",
            attachment_index=0,
            current_user=_user(),
        )

    error = exc_info.value
    assert error.status_code == 403
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
result = SimpleNamespace(
status_code=200,
@@ -67,3 +67,60 @@ uv run pytest tests/integration/v1/agent/test_sse_flow_live.py::test_agent_runs_
2. **连续会话记忆测试** - 验证 session 是否从数据库读取历史上下文
3. **工具调用测试** - calendar 读/写/删/分享 + 用户查找 + 时间感知
4. **session 失败排查** - 找出最新失败原因并修复
## 9. 本轮进展与结论(2026-03-12)
### 9.1 反馈闭环状态
1. **intent/execution 阶段 tokens/cost 入库**:已解决。
2. **连续会话记忆(今天+昨天上下文)**:已解决。
3. **工具调用冒烟(读/写/删/分享 + user 查询 + 时间感知)**:部分解决。
4. **最新失败 session 根因定位与修复**:已解决。
5. **反馈同步到文档**:已完成(本节)。
### 9.2 关键修复
1. **stage telemetry 补齐**(intent/execution):
- usage 缺失时补 token 估算;
- 通过 `LiteLLMService.calculate_cost` 按项目定价估算 cost;
- 回填 `response_metadata.inputTokens/outputTokens/cost` 并落库。
2. **会话记忆上下文注入**:
- runtime 在执行前读取同一 session 最近两天(今天+昨天)的 user/assistant 消息;
- intent prompt 增加 `[Conversation Context]`,避免只看最新用户输入。
3. **工具调用稳定性修复**:
- tool 名统一为下划线(`calendar_read`/`calendar_write`/`user_resolve`),修复 OpenAI/LiteLLM tool name 正则错误;
- intent prompt 注入 intent+execution 合并工具 schema,避免误判“无可用写入工具”。
### 9.3 Live 证据
#### A) tokens/cost 入库(thread=`cb1681c2-c223-4ced-bcfd-76f7252ba2d8`)
- intent: `input_tokens=1541`、`output_tokens=37`、`cost=0.000382`
- execution: `input_tokens=2161`、`output_tokens=376`、`cost=0.005450`
- report: `input_tokens=3266`、`output_tokens=318`、`cost=0.007256`
- session 聚合:`total_tokens=13518`、`total_cost=0.019473`
#### B) 连续会话记忆(thread=`9c456736-d5e5-48a4-b9db-55f507baf573`)
- run `mem-1`:`请记住口令是蓝鲸42,只回复已记住。`
- run `mem-2`:`只回复我刚才让你记住的口令,不要解释。`
- assistant 回复:`蓝鲸42`(记忆命中)。
#### C) 工具调用 + 时间感知(thread=`cb1681c2-c223-4ced-bcfd-76f7252ba2d8`,run=`run-tool-1`)
- 事件序列含 execution 阶段与多次 `TOOL_CALL_RESULT`
- 工具调用结果:`calendar_write`、`calendar_read`(多次)
- assistant 回复包含时间感知信息(北京时间日期/星期/时刻)
### 9.4 最新失败 session 根因
- 失败样本:`d6bc4dbd-8361-4a39-bf09-12b3392e0e70`
- 根因:tool 名含点号(如 `calendar.write`)触发校验失败:
- `Invalid 'tools[0].function.name' ... expected pattern ^[a-zA-Z0-9_-]+$`
- 修复后:同类执行链路已可稳定进入 execution 并产出 `TOOL_CALL_RESULT`
### 9.5 当前未闭环项
- `user_resolve` + calendar **分享 + 删除** 组合链路的完整 live 证据还未补齐(本轮执行中断:`Tool execution aborted`)。
@@ -0,0 +1,173 @@
# Agent Tool UI Schema and Frontend Event Wiring Design
## Goal
修正 agent 工具结果的数据契约与前后端对接:
1. SSE `TOOL_CALL_RESULT` 继续携带可实时渲染的 `ui`
2. 落库时 `messages.content` 仅存关键摘要,完整工具结果(含 `ui schema`)存对象存储。
3. `messages.metadata` 仅存访问路径和索引字段,history 通过 metadata 回填完整工具卡片数据。
4. 前端正式接通 runs/events/history 三路,并统一实时与历史渲染行为。
## Constraints
- 暂缓冒烟测试,先完成工具数据修正与前后端接口对接。
- 保持现有前端 `UiSchemaRenderer` 可解析格式,不做破坏性协议改动。
- `resume` 新需求暂不扩展。
- 遵循 AG-UI 事件语义和现有 FastAPI 路由约定。
## Selected Approach
采用兼容增强方案:
- 事件流对前端保持兼容(`TOOL_CALL_RESULT` 含 `ui` + `content`)。
- 持久化与回放做结构化增强(storage + metadata 索引 + 摘要 content)。
- 前端实时与历史统一映射层,保证同类消息一致渲染。
## Design A: Unified Data Contract
### SSE Event Contract (Realtime)
`TOOL_CALL_RESULT` 事件继续包含前端当前可解析字段:
- `callId`
- `toolName`
- `args`
- `result`
- `error`
- `content` (关键结果摘要)
- `ui` (工具卡片 schema)
这保证前端实时流不需要等待 history 即可显示工具卡片。
### Persistence Contract (Database + Storage)
对 tool message 持久化采用双层:
- `messages.content`: 仅保存 `content_summary`(短文本,供低成本上下文和兜底展示)。
- 对象存储: 保存完整 payload(`ui`、`args`、`result`、`error`、时间戳、工具标识等)。
- `messages.metadata`: 只保存索引和访问路径:
- `tool_call_id`
- `tool_name`
- `run_id`
- `stage`
- `task_id`
- `storage_bucket`
- `storage_path`
- `summary_version`
### History Contract
history 序列化时:
1. 先通过 `metadata.storage_bucket/storage_path` 读取完整 payload。
2. 从 payload 回填 `ui`,并保留摘要 `content`
3. storage 读取失败时,回退 `messages.content`,确保历史可读。
## Design B: Frontend Wiring (runs/events/history)
### runs
- `POST /api/v1/agent/runs` 仅负责创建 run 与启动执行。
- 前端保留 `threadId/runId` 和本地流状态,不承载渲染业务。
### events
- SSE 作为唯一实时渲染来源。
- `TOOL_CALL_RESULT` 直接读取事件内 `ui` 渲染 `ToolResultItem`
- `STEP_STARTED/STEP_FINISHED` 显示三阶段状态(intent/execution/report)。
### history
- 通过 `/api/v1/agent/history` 或 `/api/v1/agent/runs/{threadId}/history` 回放。
- tool message 优先读 `ui`(由后端从 metadata+storage 回填)。
- user message 读取 `attachments` 渲染多模态内容。
### Consistency Rule
- 实时事件与历史快照统一进入同一 `ChatListItem` 映射层。
- `content` 只做兜底文本,不作为工具卡片主数据。
## Design C: Backend Implementation Details
### Modules to Change
- `backend/src/core/agentscope/events/store.py`
- 增加 tool result 的摘要生成与 storage 上传。
- `append_message` 时写入摘要 content + metadata 索引。
- `backend/src/core/agentscope/tools/tool_result_storage.py`
- 复用现有 `upload_json/read_json`,作为完整 payload 存取层。
- `backend/src/v1/agent/repository.py`
- `_to_snapshot_message` 对 tool message 优先按 metadata 读取 storage 并回填 `ui`
- `backend/src/core/agentscope/runtime/agent_route_runtime.py`
- 确保 `tool.result` 事件继续带 `ui` 和摘要 `content`
### Failure Fallback
- storage 写失败:不阻断主流程,至少保证 `messages.content` 可读,metadata 标记缺失。
- storage 读失败:history 返回摘要 `content`,`ui` 为空。
## Design D: content_summary Rule Engine
### Function
新增纯函数:
`build_tool_content_summary(tool_name, args, result, error) -> str`
### Rules (Priority)
1. 错误优先:有 `error` 直接输出失败摘要。
2. 工具专用模板:
- `calendar_write`: `已创建日程:{title}({start_time})`
- `calendar_read`: `查询到 {count} 条日程({date_range})`
- `calendar_delete`: `已删除日程:{title_or_id}`
- `calendar_share`: `已分享日程给 {target}`
- `user_resolve`: `已匹配用户:{name_or_id}`
3. 通用回退:优先 `result.content`,否则抽取常见键拼句。
4. 最终兜底:`{tool_name} 执行完成/执行失败`
5. 清洗:去换行与多空格,限制长度,避免大段 JSON。
### Summary Storage Policy
- `messages.content` 存摘要。
- `summary_version` 存入 metadata,支持未来摘要算法演进。
## Testing and Acceptance
### Backend
- 单元测试:
- `events/store`: tool result 摘要写入、metadata 路径写入、storage 异常回退。
- `v1/agent/repository`: history 按 metadata 回填 `ui`,storage 缺失回退 content。
- 摘要函数:覆盖成功/失败/缺字段/超长文本场景。
- 集成测试:
- `/runs` + `/events`:实时 `TOOL_CALL_RESULT` 含 `ui`。
- `/history`:返回 tool message 的 `ui` 来自 metadata+storage。
### Frontend
- 单元/组件测试:
- `AgUiService` 解析 `TOOL_CALL_RESULT` 的 `ui`。
- `ChatBloc`:实时事件与 history 快照都能产出 `ToolResultItem`
- `UiSchemaRenderer`:history 回放卡片渲染一致。
- user message 附件渲染(history)。
- 页面行为验证:
- events 到达即实时更新消息列表。
- step 三阶段状态正确切换。
- 上拉历史后工具卡片可正常显示。
## Risks and Mitigations
- 风险:storage 不可用导致 history 卡片缺失。
- 缓解:保底展示摘要 content,不阻断对话。
- 风险:事件格式变更导致前端实时解析失败。
- 缓解:维持现有 `ToolCallResultEvent` 字段,不做破坏性改名。
- 风险:摘要规则覆盖不足。
- 缓解:规则版本化 + 测试样例扩展。
## Out of Scope
- resume 扩展协议与交互策略。
- 新一轮 live 冒烟验收。
- 新 UI 风格重构,仅实现链路打通与数据契约修正。
@@ -0,0 +1,283 @@
# Agent UI Schema and Event Wiring Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 打通 agent 工具结果在实时事件与历史回放的一致渲染链路:SSE 实时带 UI,落库 content 存摘要,完整 UI schema 存 storage 并通过 metadata 回填。
**Architecture:** 后端在 `TOOL_CALL_RESULT` 持久化链路中引入“摘要 + 全量分离”策略:摘要写 `messages.content`,全量 payload 写对象存储,metadata 仅存索引路径;history 读取时按 metadata 反查 storage 回填 `ui`。前端复用现有 AG-UI 事件模型,实现 runs/events/history 三路统一映射到 `ChatListItem`,并补齐 step 事件渲染与 history 多模态渲染。
**Tech Stack:** FastAPI, SQLAlchemy, AgentScope runtime/events, Supabase Storage, Flutter (Bloc/Cubit), Dart models/tests, AG-UI events
---
### Task 1: Add Tool Summary Rule Engine (Backend)
**Files:**
- Create: `backend/src/core/agentscope/events/tool_result_summary.py`
- Test: `backend/tests/unit/core/agentscope/events/test_tool_result_summary.py`
**Step 1: Write the failing test**
```python
from core.agentscope.events.tool_result_summary import build_tool_content_summary
def test_calendar_write_summary() -> None:
text = build_tool_content_summary(
tool_name="calendar_write",
args={"title": "项目评审"},
result={"start_time": "明天 10:00"},
error=None,
)
assert text.startswith("已创建日程")
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_tool_result_summary.py -q`
Expected: FAIL with import/module/function missing.
**Step 3: Write minimal implementation**
```python
def build_tool_content_summary(*, tool_name: str, args, result, error) -> str:
if error:
return f"{tool_name} 执行失败"
if tool_name == "calendar_write":
return "已创建日程"
return f"{tool_name} 执行完成"
```
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_tool_result_summary.py -q`
Expected: PASS.
**Step 5: Extend tests for all rules and refactor**
Add cases for `calendar_read/calendar_delete/calendar_share/user_resolve/error/fallback/truncation` and implement full rule table.
**Step 6: Commit**
```bash
git add backend/src/core/agentscope/events/tool_result_summary.py backend/tests/unit/core/agentscope/events/test_tool_result_summary.py
git commit -m "feat: add deterministic tool result summary engine"
```
### Task 2: Persist Full Tool Payload to Storage and Keep Content Lightweight
**Files:**
- Modify: `backend/src/core/agentscope/events/store.py`
- Test: `backend/tests/unit/core/agentscope/events/test_store.py`
**Step 1: Write the failing tests**
Add tests asserting:
- `TOOL_CALL_RESULT` persists summary to `content`.
- metadata includes `storage_bucket/storage_path/tool_call_id`.
- uploaded payload includes full `ui/args/result/error`.
**Step 2: Run targeted tests (RED)**
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_store.py -q`
Expected: FAIL on new assertions.
**Step 3: Implement minimal storage write path**
In `_persist_tool_call_result`:
- build `full_payload` from event fields.
- call summary engine for `content`.
- upload payload via tool result storage (inject dependency if needed).
- store only path/index in metadata.
**Step 4: Run tests (GREEN)**
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_store.py -q`
Expected: PASS.
**Step 5: Add fallback test and implementation**
Add case where storage upload fails but tool message still persists with summary and no crash.
**Step 6: Commit**
```bash
git add backend/src/core/agentscope/events/store.py backend/tests/unit/core/agentscope/events/test_store.py
git commit -m "feat: store tool payload in object storage with metadata index"
```
### Task 3: Hydrate History Tool UI from Metadata Storage Path
**Files:**
- Modify: `backend/src/v1/agent/repository.py`
- Test: `backend/tests/unit/v1/agent/test_repository.py`
**Step 1: Write failing tests**
Add/adjust assertions:
- history tool payload resolves `ui` from storage payload.
- when storage missing, fallback to `messages.content` summary.
**Step 2: Run tests (RED)**
Run: `uv run pytest backend/tests/unit/v1/agent/test_repository.py -q`
Expected: FAIL on `ui` hydration and fallback assertions.
**Step 3: Implement minimal hydration logic**
In `_to_snapshot_message` for tool role:
- read storage via `metadata.storage_bucket/storage_path`.
- map hydrated payload fields to snapshot (`ui`, `content`, `toolCallId`).
- keep safe fallback when storage read fails.
**Step 4: Run tests (GREEN)**
Run: `uv run pytest backend/tests/unit/v1/agent/test_repository.py -q`
Expected: PASS.
**Step 5: Commit**
```bash
git add backend/src/v1/agent/repository.py backend/tests/unit/v1/agent/test_repository.py
git commit -m "fix: hydrate tool ui from metadata storage in history snapshots"
```
### Task 4: Keep SSE TOOL_CALL_RESULT Compatible with Existing Frontend Parsing
**Files:**
- Modify: `backend/src/core/agentscope/runtime/agent_route_runtime.py`
- Test: `backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py`
**Step 1: Write failing test**
Add assertion that emitted `TOOL_CALL_RESULT` data contains expected renderable fields (`callId/toolName/result/error` and `ui` path from result payload).
**Step 2: Run tests (RED)**
Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py -q`
Expected: FAIL on missing/incorrect payload fields.
**Step 3: Implement minimal payload normalization**
Normalize tool result event payload so frontend can keep current parsing without contract breaks.
**Step 4: Run tests (GREEN)**
Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py -q`
Expected: PASS.
**Step 5: Commit**
```bash
git add backend/src/core/agentscope/runtime/agent_route_runtime.py backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py
git commit -m "fix: preserve frontend-compatible tool result event payload"
```
### Task 5: Wire Frontend History + Events to Unified Rendering Path
**Files:**
- Modify: `apps/lib/features/chat/data/services/ag_ui_service.dart`
- Modify: `apps/lib/features/chat/presentation/bloc/chat_bloc.dart`
- Modify: `apps/lib/features/chat/data/models/tool_result.dart`
- Modify: `apps/lib/features/home/ui/screens/home_screen.dart`
- Test: `apps/test/features/chat/ag_ui_service_test.dart`
- Create/Modify: `apps/test/features/chat/chat_bloc_test.dart`
**Step 1: Write failing tests**
Add tests asserting:
- history tool message with `ui` becomes `ToolResultItem`.
- SSE `TOOL_CALL_RESULT` with `ui` renders same item shape.
- attachments in history user message are mapped for multimodal rendering.
**Step 2: Run tests (RED)**
Run: `cd apps && flutter test test/features/chat/ag_ui_service_test.dart`
Expected: FAIL on new mapping assertions.
**Step 3: Implement minimal mapping changes**
- In service/bloc, unify history and event mapping into same conversion path.
- Keep existing `UiSchemaRenderer` input format untouched.
- Ensure fallback to content text when `ui` missing.
**Step 4: Run tests (GREEN)**
Run: `cd apps && flutter test test/features/chat/ag_ui_service_test.dart`
Expected: PASS.
**Step 5: Commit**
```bash
git add apps/lib/features/chat/data/services/ag_ui_service.dart apps/lib/features/chat/presentation/bloc/chat_bloc.dart apps/lib/features/chat/data/models/tool_result.dart apps/lib/features/home/ui/screens/home_screen.dart apps/test/features/chat/ag_ui_service_test.dart apps/test/features/chat/chat_bloc_test.dart
git commit -m "feat: unify realtime and history tool card rendering"
```
### Task 6: Add Step Event Rendering for Intent/Execution/Report
**Files:**
- Modify: `apps/lib/features/chat/presentation/bloc/chat_bloc.dart`
- Modify: `apps/lib/features/home/ui/screens/home_screen.dart`
- Test: `apps/test/features/chat/chat_bloc_test.dart`
**Step 1: Write failing test**
Add test verifying `STEP_STARTED/STEP_FINISHED` transitions produce visible stage state.
**Step 2: Run tests (RED)**
Run: `cd apps && flutter test test/features/chat/chat_bloc_test.dart`
Expected: FAIL on missing stage state.
**Step 3: Implement minimal state and UI**
- Track current stage enum in `ChatState`.
- Render compact stage progress row in chat screen.
**Step 4: Run tests (GREEN)**
Run: `cd apps && flutter test test/features/chat/chat_bloc_test.dart`
Expected: PASS.
**Step 5: Commit**
```bash
git add apps/lib/features/chat/presentation/bloc/chat_bloc.dart apps/lib/features/home/ui/screens/home_screen.dart apps/test/features/chat/chat_bloc_test.dart
git commit -m "feat: render agent step progress from AG-UI events"
```
### Task 7: Verification Gate (Backend + Frontend)
**Files:**
- Modify (if needed): `docs/plans/2026-03-11-agent-multimodal-smoke-runbook.md`
**Step 1: Run backend targeted tests**
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_tool_result_summary.py backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/v1/agent/test_repository.py backend/tests/unit/core/agentscope/runtime/test_agent_route_runtime.py -q`
Expected: PASS.
**Step 2: Run frontend targeted tests**
Run: `cd apps && flutter test test/features/chat/ag_ui_service_test.dart test/features/chat/chat_bloc_test.dart`
Expected: PASS.
**Step 3: Run backend quality checks**
Run: `uv run ruff check backend/src backend/tests`
Expected: PASS.
**Step 4: Run backend type checks**
Run: `uv run basedpyright`
Expected: 0 errors.
**Step 5: Update runbook evidence**
Record changed contract, test evidence, and known follow-ups.
**Step 6: Commit**
```bash
git add docs/plans/2026-03-11-agent-multimodal-smoke-runbook.md
git commit -m "docs: record tool ui schema storage and rendering verification"
```