feat: 支持 agent 运行取消功能
This commit is contained in:
@@ -50,6 +50,8 @@ class AgUiService {
|
|||||||
Completer<void>? _activeSseDoneCompleter;
|
Completer<void>? _activeSseDoneCompleter;
|
||||||
|
|
||||||
String? _threadId;
|
String? _threadId;
|
||||||
|
String? _activeThreadIdForRun;
|
||||||
|
String? _activeRunId;
|
||||||
bool _hasMoreHistory = false;
|
bool _hasMoreHistory = false;
|
||||||
|
|
||||||
AgUiService({EventCallback? onEvent, required IApiClient apiClient})
|
AgUiService({EventCallback? onEvent, required IApiClient apiClient})
|
||||||
@@ -83,11 +85,20 @@ class AgUiService {
|
|||||||
throw StateError('Missing runId in /agent/runs response');
|
throw StateError('Missing runId in /agent/runs response');
|
||||||
}
|
}
|
||||||
_threadId = threadId;
|
_threadId = threadId;
|
||||||
await _streamEventsFromApi(
|
_activeThreadIdForRun = threadId;
|
||||||
threadId,
|
_activeRunId = runId;
|
||||||
expectedRunId: runId,
|
try {
|
||||||
streamToken: streamToken,
|
await _streamEventsFromApi(
|
||||||
);
|
threadId,
|
||||||
|
expectedRunId: runId,
|
||||||
|
streamToken: streamToken,
|
||||||
|
);
|
||||||
|
} finally {
|
||||||
|
if (_activeThreadIdForRun == threadId && _activeRunId == runId) {
|
||||||
|
_activeThreadIdForRun = null;
|
||||||
|
_activeRunId = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
return SendMessageResult(
|
return SendMessageResult(
|
||||||
uploadedAttachments: runInputPayload.uploadedAttachments,
|
uploadedAttachments: runInputPayload.uploadedAttachments,
|
||||||
);
|
);
|
||||||
@@ -151,6 +162,19 @@ class AgUiService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Future<void> cancelCurrentRun() async {
|
Future<void> cancelCurrentRun() async {
|
||||||
|
final activeThreadId = _activeThreadIdForRun;
|
||||||
|
final activeRunId = _activeRunId;
|
||||||
|
if (activeThreadId != null && activeRunId != null) {
|
||||||
|
final encodedRunId = Uri.encodeQueryComponent(activeRunId);
|
||||||
|
await _apiClient.post<Map<String, dynamic>>(
|
||||||
|
'/api/v1/agent/runs/$activeThreadId/cancel?runId=$encodedRunId',
|
||||||
|
);
|
||||||
|
_activeThreadIdForRun = null;
|
||||||
|
_activeRunId = null;
|
||||||
|
_activeStreamToken += 1;
|
||||||
|
await _cancelActiveSseSubscription();
|
||||||
|
return;
|
||||||
|
}
|
||||||
_activeStreamToken += 1;
|
_activeStreamToken += 1;
|
||||||
await _cancelActiveSseSubscription();
|
await _cancelActiveSseSubscription();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -123,10 +123,16 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
);
|
);
|
||||||
case AgUiEventType.runError:
|
case AgUiEventType.runError:
|
||||||
final errorEvent = event as RunErrorEvent;
|
final errorEvent = event as RunErrorEvent;
|
||||||
|
final isCanceledByUser = errorEvent.code == 'RUN_CANCELED';
|
||||||
emit(
|
emit(
|
||||||
_resetRunState(
|
_resetRunState(
|
||||||
error: errorEvent.message,
|
error: isCanceledByUser ? null : errorEvent.message,
|
||||||
).copyWith(items: _markActiveToolCallsFailed(state.items)),
|
).copyWith(
|
||||||
|
items: _markActiveToolCallsFailed(
|
||||||
|
state.items,
|
||||||
|
reason: isCanceledByUser ? '本次运行已取消' : '本次运行已失败',
|
||||||
|
),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
case AgUiEventType.stepStarted:
|
case AgUiEventType.stepStarted:
|
||||||
_handleStepStarted(event as StepStartedEvent);
|
_handleStepStarted(event as StepStartedEvent);
|
||||||
@@ -286,7 +292,10 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
return items.where((item) => item is! ToolCallItem).toList();
|
return items.where((item) => item is! ToolCallItem).toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
List<ChatListItem> _markActiveToolCallsFailed(List<ChatListItem> items) {
|
List<ChatListItem> _markActiveToolCallsFailed(
|
||||||
|
List<ChatListItem> items, {
|
||||||
|
required String reason,
|
||||||
|
}) {
|
||||||
return items.map((item) {
|
return items.map((item) {
|
||||||
if (item is! ToolCallItem) {
|
if (item is! ToolCallItem) {
|
||||||
return item;
|
return item;
|
||||||
@@ -297,10 +306,7 @@ class ChatBloc extends Cubit<ChatState> {
|
|||||||
if (item.status == ToolCallStatus.completed) {
|
if (item.status == ToolCallStatus.completed) {
|
||||||
return item;
|
return item;
|
||||||
}
|
}
|
||||||
return item.copyWith(
|
return item.copyWith(status: ToolCallStatus.error, errorMessage: reason);
|
||||||
status: ToolCallStatus.error,
|
|
||||||
errorMessage: '本次运行已失败',
|
|
||||||
);
|
|
||||||
}).toList();
|
}).toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ extension _HomeScreenInteractions on _HomeScreenState {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (canceled) {
|
if (canceled) {
|
||||||
Toast.show(context, '已停止等待回复', type: ToastType.info);
|
Toast.show(context, '已请求停止', type: ToastType.info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class _FakeApiClient implements IApiClient {
|
|||||||
final List<String> sseLines;
|
final List<String> sseLines;
|
||||||
final Stream<String> Function()? sseLineStreamFactory;
|
final Stream<String> Function()? sseLineStreamFactory;
|
||||||
final String Function()? runIdFactory;
|
final String Function()? runIdFactory;
|
||||||
|
final List<String> postPaths = <String>[];
|
||||||
|
|
||||||
@override
|
@override
|
||||||
Future<Response<T>> delete<T>(String path, {data, Options? options}) {
|
Future<Response<T>> delete<T>(String path, {data, Options? options}) {
|
||||||
@@ -51,6 +52,19 @@ class _FakeApiClient implements IApiClient {
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
Future<Response<T>> post<T>(String path, {data, Options? options}) async {
|
Future<Response<T>> post<T>(String path, {data, Options? options}) async {
|
||||||
|
postPaths.add(path);
|
||||||
|
if (path.contains('/cancel?runId=')) {
|
||||||
|
final payload = <String, dynamic>{
|
||||||
|
'threadId': 'thread-1',
|
||||||
|
'runId': 'run-new',
|
||||||
|
'accepted': true,
|
||||||
|
};
|
||||||
|
return Response<T>(
|
||||||
|
requestOptions: RequestOptions(path: path),
|
||||||
|
data: payload as T,
|
||||||
|
statusCode: 202,
|
||||||
|
);
|
||||||
|
}
|
||||||
final runIdFactory = this.runIdFactory;
|
final runIdFactory = this.runIdFactory;
|
||||||
final payload = <String, dynamic>{
|
final payload = <String, dynamic>{
|
||||||
'taskId': 'task-1',
|
'taskId': 'task-1',
|
||||||
@@ -192,6 +206,39 @@ void main() {
|
|||||||
await streamController.close();
|
await streamController.close();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test(
|
||||||
|
'cancelCurrentRun calls backend cancel endpoint for active run',
|
||||||
|
() async {
|
||||||
|
final streamController = StreamController<String>();
|
||||||
|
final fakeApi = _FakeApiClient(
|
||||||
|
sseLines: const <String>[],
|
||||||
|
sseLineStreamFactory: () => streamController.stream,
|
||||||
|
);
|
||||||
|
final service = AgUiService(apiClient: fakeApi);
|
||||||
|
|
||||||
|
final sendFuture = service.sendMessage('hello');
|
||||||
|
await Future<void>.delayed(Duration.zero);
|
||||||
|
for (final line in _buildSseEvent(
|
||||||
|
id: '51',
|
||||||
|
type: AgUiEventTypeWire.runStarted,
|
||||||
|
payload:
|
||||||
|
'{"type":"RUN_STARTED","threadId":"thread-1","runId":"run-new"}',
|
||||||
|
)) {
|
||||||
|
streamController.add(line);
|
||||||
|
}
|
||||||
|
await Future<void>.delayed(Duration.zero);
|
||||||
|
|
||||||
|
await service.cancelCurrentRun();
|
||||||
|
await sendFuture;
|
||||||
|
|
||||||
|
expect(
|
||||||
|
fakeApi.postPaths,
|
||||||
|
contains('/api/v1/agent/runs/thread-1/cancel?runId=run-new'),
|
||||||
|
);
|
||||||
|
await streamController.close();
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
test(
|
test(
|
||||||
'new sendMessage cancels previous SSE subscription explicitly',
|
'new sendMessage cancels previous SSE subscription explicitly',
|
||||||
() async {
|
() async {
|
||||||
|
|||||||
@@ -206,6 +206,28 @@ void main() {
|
|||||||
expect(bloc.state.error, 'runtime execution failed');
|
expect(bloc.state.error, 'runtime execution failed');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('run canceled error clears error and marks tool as canceled', () {
|
||||||
|
service.emitEvent(
|
||||||
|
ToolCallStartEvent(
|
||||||
|
toolCallId: 'tool-cancel',
|
||||||
|
toolCallName: 'ocr_image',
|
||||||
|
),
|
||||||
|
);
|
||||||
|
service.emitEvent(ToolCallEndEvent(toolCallId: 'tool-cancel'));
|
||||||
|
|
||||||
|
service.emitEvent(
|
||||||
|
RunErrorEvent(message: 'run canceled by user', code: 'RUN_CANCELED'),
|
||||||
|
);
|
||||||
|
|
||||||
|
final toolItem = bloc.state.items.whereType<ToolCallItem>().single;
|
||||||
|
expect(toolItem.status, ToolCallStatus.error);
|
||||||
|
expect(toolItem.errorMessage, '本次运行已取消');
|
||||||
|
expect(bloc.state.error, isNull);
|
||||||
|
expect(bloc.state.isWaitingFirstToken, isFalse);
|
||||||
|
expect(bloc.state.isStreaming, isFalse);
|
||||||
|
expect(bloc.state.isCancelling, isFalse);
|
||||||
|
});
|
||||||
|
|
||||||
test('text event with ui schema is rendered into chat items', () {
|
test('text event with ui schema is rendered into chat items', () {
|
||||||
service.emitEvent(RunStartedEvent(threadId: 'thread-1', runId: 'run-1'));
|
service.emitEvent(RunStartedEvent(threadId: 'thread-1', runId: 'run-1'));
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Protocol
|
import asyncio
|
||||||
|
from typing import Any, Awaitable, Callable, Protocol
|
||||||
|
|
||||||
from ag_ui.core.types import RunAgentInput
|
from ag_ui.core.types import RunAgentInput
|
||||||
from agentscope.message import Msg
|
from agentscope.message import Msg
|
||||||
@@ -28,6 +29,7 @@ class RunnerLike(Protocol):
|
|||||||
runtime_config: RuntimeConfig,
|
runtime_config: RuntimeConfig,
|
||||||
user_memory: UserMemoryContent | None,
|
user_memory: UserMemoryContent | None,
|
||||||
work_memory: WorkProfileContent | None,
|
work_memory: WorkProfileContent | None,
|
||||||
|
cancel_checker: Callable[[], Awaitable[bool]] | None = None,
|
||||||
) -> dict[str, Any]: ...
|
) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
|
||||||
@@ -53,6 +55,7 @@ class AgentScopeRuntimeOrchestrator:
|
|||||||
runtime_config: RuntimeConfig,
|
runtime_config: RuntimeConfig,
|
||||||
user_memory: UserMemoryContent | None = None,
|
user_memory: UserMemoryContent | None = None,
|
||||||
work_memory: WorkProfileContent | None = None,
|
work_memory: WorkProfileContent | None = None,
|
||||||
|
cancel_checker: Callable[[], Awaitable[bool]] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
thread_id = run_input.thread_id
|
thread_id = run_input.thread_id
|
||||||
run_id = run_input.run_id
|
run_id = run_input.run_id
|
||||||
@@ -74,6 +77,7 @@ class AgentScopeRuntimeOrchestrator:
|
|||||||
runtime_config=runtime_config,
|
runtime_config=runtime_config,
|
||||||
user_memory=user_memory,
|
user_memory=user_memory,
|
||||||
work_memory=work_memory,
|
work_memory=work_memory,
|
||||||
|
cancel_checker=cancel_checker,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._pipeline.emit(
|
await self._pipeline.emit(
|
||||||
@@ -85,6 +89,23 @@ class AgentScopeRuntimeOrchestrator:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
return result if isinstance(result, dict) else {}
|
return result if isinstance(result, dict) else {}
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info(
|
||||||
|
"agentscope runtime execution canceled",
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
await self._pipeline.emit(
|
||||||
|
session_id=thread_id,
|
||||||
|
event={
|
||||||
|
"type": "RUN_ERROR",
|
||||||
|
"threadId": thread_id,
|
||||||
|
"runId": run_id,
|
||||||
|
"message": "run canceled by user",
|
||||||
|
"code": "RUN_CANCELED",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"agentscope runtime execution failed",
|
"agentscope runtime execution failed",
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from ag_ui.core.types import RunAgentInput
|
from ag_ui.core.types import RunAgentInput
|
||||||
@@ -64,6 +66,8 @@ class AgentScopeRunner:
|
|||||||
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
|
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
|
||||||
patch_agentscope_json_repair_compat()
|
patch_agentscope_json_repair_compat()
|
||||||
self._litellm_service: LiteLLMService = litellm_service or LiteLLMService()
|
self._litellm_service: LiteLLMService = litellm_service or LiteLLMService()
|
||||||
|
self._active_agent: JsonReActAgent | None = None
|
||||||
|
self._active_agent_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
@@ -75,51 +79,99 @@ class AgentScopeRunner:
|
|||||||
runtime_config: RuntimeConfig,
|
runtime_config: RuntimeConfig,
|
||||||
user_memory: UserMemoryContent | None = None,
|
user_memory: UserMemoryContent | None = None,
|
||||||
work_memory: WorkProfileContent | None = None,
|
work_memory: WorkProfileContent | None = None,
|
||||||
|
cancel_checker: Callable[[], Awaitable[bool]] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
owner_id = UUID(user_context.id)
|
owner_id = UUID(user_context.id)
|
||||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||||
runtime_mode = self._resolve_runtime_mode(run_input=run_input)
|
runtime_mode = self._resolve_runtime_mode(run_input=run_input)
|
||||||
|
stop_cancel_watch = asyncio.Event()
|
||||||
|
cancel_watch_task: asyncio.Task[None] | None = None
|
||||||
|
run_task = asyncio.current_task()
|
||||||
|
|
||||||
async with AsyncSessionLocal() as session:
|
if cancel_checker is not None and run_task is not None:
|
||||||
router_config = await self._load_stage_config(
|
cancel_watch_task = asyncio.create_task(
|
||||||
session=session,
|
self._watch_cancel_signal(
|
||||||
agent_type=AgentType.ROUTER,
|
cancel_checker=cancel_checker,
|
||||||
)
|
stop_signal=stop_cancel_watch,
|
||||||
worker_config = await self._load_stage_config(
|
run_task=run_task,
|
||||||
session=session,
|
)
|
||||||
agent_type=AgentType.WORKER,
|
|
||||||
)
|
|
||||||
worker_toolkit = self._build_toolkit(
|
|
||||||
session=session,
|
|
||||||
owner_id=owner_id,
|
|
||||||
enabled_tools=runtime_config.enabled_tools,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
router_output = await self._execute_router_step(
|
try:
|
||||||
pipeline=pipeline,
|
async with AsyncSessionLocal() as session:
|
||||||
run_input=run_input,
|
router_config = await self._load_stage_config(
|
||||||
user_context=user_context,
|
session=session,
|
||||||
context_messages=context_messages,
|
agent_type=AgentType.ROUTER,
|
||||||
stage_config=router_config,
|
)
|
||||||
runtime_client_time=runtime_client_time,
|
worker_config = await self._load_stage_config(
|
||||||
runtime_mode=runtime_mode,
|
session=session,
|
||||||
user_memory=user_memory,
|
agent_type=AgentType.WORKER,
|
||||||
)
|
)
|
||||||
worker_output = await self._execute_worker_step(
|
worker_toolkit = self._build_toolkit(
|
||||||
pipeline=pipeline,
|
session=session,
|
||||||
run_input=run_input,
|
owner_id=owner_id,
|
||||||
user_context=user_context,
|
enabled_tools=runtime_config.enabled_tools,
|
||||||
router_output=router_output,
|
)
|
||||||
toolkit=worker_toolkit,
|
|
||||||
stage_config=worker_config,
|
router_output = await self._execute_router_step(
|
||||||
runtime_client_time=runtime_client_time,
|
pipeline=pipeline,
|
||||||
runtime_mode=runtime_mode,
|
run_input=run_input,
|
||||||
work_memory=work_memory,
|
user_context=user_context,
|
||||||
)
|
context_messages=context_messages,
|
||||||
return {
|
stage_config=router_config,
|
||||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
runtime_client_time=runtime_client_time,
|
||||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
runtime_mode=runtime_mode,
|
||||||
}
|
user_memory=user_memory,
|
||||||
|
)
|
||||||
|
if cancel_checker is not None and await cancel_checker():
|
||||||
|
raise asyncio.CancelledError("run canceled by user")
|
||||||
|
worker_output = await self._execute_worker_step(
|
||||||
|
pipeline=pipeline,
|
||||||
|
run_input=run_input,
|
||||||
|
user_context=user_context,
|
||||||
|
router_output=router_output,
|
||||||
|
toolkit=worker_toolkit,
|
||||||
|
stage_config=worker_config,
|
||||||
|
runtime_client_time=runtime_client_time,
|
||||||
|
runtime_mode=runtime_mode,
|
||||||
|
work_memory=work_memory,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||||
|
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
stop_cancel_watch.set()
|
||||||
|
if cancel_watch_task is not None:
|
||||||
|
cancel_watch_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await cancel_watch_task
|
||||||
|
|
||||||
|
async def _watch_cancel_signal(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cancel_checker: Callable[[], Awaitable[bool]],
|
||||||
|
stop_signal: asyncio.Event,
|
||||||
|
run_task: asyncio.Task[object],
|
||||||
|
) -> None:
|
||||||
|
while not stop_signal.is_set():
|
||||||
|
should_cancel = False
|
||||||
|
try:
|
||||||
|
should_cancel = await cancel_checker()
|
||||||
|
except Exception:
|
||||||
|
should_cancel = False
|
||||||
|
|
||||||
|
if should_cancel:
|
||||||
|
async with self._active_agent_lock:
|
||||||
|
active_agent = self._active_agent
|
||||||
|
if active_agent is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await active_agent.interrupt()
|
||||||
|
if not run_task.done():
|
||||||
|
run_task.cancel("run canceled by user")
|
||||||
|
return
|
||||||
|
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
def _build_toolkit(
|
def _build_toolkit(
|
||||||
self,
|
self,
|
||||||
@@ -373,9 +425,16 @@ class AgentScopeRunner:
|
|||||||
model=tracking_model,
|
model=tracking_model,
|
||||||
emitter=emitter,
|
emitter=emitter,
|
||||||
)
|
)
|
||||||
response_msg = await agent.reply_json(
|
async with self._active_agent_lock:
|
||||||
input_messages, output_model=worker_output_model
|
self._active_agent = agent
|
||||||
)
|
try:
|
||||||
|
response_msg = await agent.reply_json(
|
||||||
|
input_messages, output_model=worker_output_model
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
async with self._active_agent_lock:
|
||||||
|
if self._active_agent is agent:
|
||||||
|
self._active_agent = None
|
||||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||||
response_metadata = self._litellm_service.build_usage_metadata(
|
response_metadata = self._litellm_service.build_usage_metadata(
|
||||||
model=stage_config.model_code,
|
model=stage_config.model_code,
|
||||||
|
|||||||
@@ -257,6 +257,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
|||||||
thread_id = run_input.thread_id
|
thread_id = run_input.thread_id
|
||||||
run_id = run_input.run_id
|
run_id = run_input.run_id
|
||||||
owner_id = UUID(raw_owner_id)
|
owner_id = UUID(raw_owner_id)
|
||||||
|
cancel_key = f"agent:cancel:{thread_id}:{run_id}"
|
||||||
|
|
||||||
if command_type != "run":
|
if command_type != "run":
|
||||||
raise ValueError("invalid command type")
|
raise ValueError("invalid command type")
|
||||||
@@ -278,6 +279,15 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
|||||||
work_memory: WorkProfileContent | None = memories_result.get("work_memory")
|
work_memory: WorkProfileContent | None = memories_result.get("work_memory")
|
||||||
|
|
||||||
redis_client = await get_or_init_redis_client()
|
redis_client = await get_or_init_redis_client()
|
||||||
|
|
||||||
|
async def _cancel_checker() -> bool:
|
||||||
|
exists_fn = getattr(redis_client, "exists", None)
|
||||||
|
if not callable(exists_fn):
|
||||||
|
return False
|
||||||
|
exists_call = cast(Any, exists_fn)(cancel_key)
|
||||||
|
result = await exists_call
|
||||||
|
return bool(result)
|
||||||
|
|
||||||
bus = RedisStreamBus(
|
bus = RedisStreamBus(
|
||||||
client=redis_client,
|
client=redis_client,
|
||||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||||
@@ -302,14 +312,21 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
|||||||
context_config=runtime_config.context,
|
context_config=runtime_config.context,
|
||||||
)
|
)
|
||||||
|
|
||||||
await runtime.run(
|
try:
|
||||||
run_input=run_input,
|
await runtime.run(
|
||||||
context_messages=context_messages,
|
run_input=run_input,
|
||||||
user_context=user_context,
|
context_messages=context_messages,
|
||||||
runtime_config=runtime_config,
|
user_context=user_context,
|
||||||
user_memory=user_memory,
|
runtime_config=runtime_config,
|
||||||
work_memory=work_memory,
|
user_memory=user_memory,
|
||||||
)
|
work_memory=work_memory,
|
||||||
|
cancel_checker=_cancel_checker,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
delete_fn = getattr(redis_client, "delete", None)
|
||||||
|
if callable(delete_fn):
|
||||||
|
delete_call = cast(Any, delete_fn)(cancel_key)
|
||||||
|
await delete_call
|
||||||
logger.info(
|
logger.info(
|
||||||
"agentscope runtime task completed",
|
"agentscope runtime task completed",
|
||||||
command_type=command_type,
|
command_type=command_type,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
@@ -22,6 +24,7 @@ DEDUP_WAIT_RETRIES = 20
|
|||||||
DEDUP_WAIT_SECONDS = 0.05
|
DEDUP_WAIT_SECONDS = 0.05
|
||||||
DEDUP_LOCK_SECONDS = 300
|
DEDUP_LOCK_SECONDS = 300
|
||||||
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
||||||
|
RUN_CANCEL_SIGNAL_TTL_SECONDS = 1800
|
||||||
|
|
||||||
|
|
||||||
def _event_stream_block_ms() -> int:
|
def _event_stream_block_ms() -> int:
|
||||||
@@ -87,6 +90,29 @@ class TaskiqQueueClient:
|
|||||||
await redis_client.delete(redis_key)
|
await redis_client.delete(redis_key)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def request_cancel(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
requested_by: str,
|
||||||
|
) -> None:
|
||||||
|
redis_client = await self._get_redis()
|
||||||
|
cancel_key = f"agent:cancel:{thread_id}:{run_id}"
|
||||||
|
payload = json.dumps(
|
||||||
|
{
|
||||||
|
"requested_by": requested_by,
|
||||||
|
"requested_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
},
|
||||||
|
ensure_ascii=True,
|
||||||
|
separators=(",", ":"),
|
||||||
|
)
|
||||||
|
await redis_client.set(
|
||||||
|
cancel_key,
|
||||||
|
payload,
|
||||||
|
ex=RUN_CANCEL_SIGNAL_TTL_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RedisEventStream:
|
class RedisEventStream:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from v1.agent.schemas import (
|
|||||||
AttachmentReference,
|
AttachmentReference,
|
||||||
AttachmentSignedUrlResponse,
|
AttachmentSignedUrlResponse,
|
||||||
AttachmentUploadResponse,
|
AttachmentUploadResponse,
|
||||||
|
CancelRunResponse,
|
||||||
HistorySnapshotResponse,
|
HistorySnapshotResponse,
|
||||||
TaskAcceptedResponse,
|
TaskAcceptedResponse,
|
||||||
)
|
)
|
||||||
@@ -147,6 +148,34 @@ async def enqueue_run(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/runs/{thread_id}/cancel",
|
||||||
|
response_model=CancelRunResponse,
|
||||||
|
status_code=status.HTTP_202_ACCEPTED,
|
||||||
|
)
|
||||||
|
async def cancel_run(
|
||||||
|
thread_id: str,
|
||||||
|
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||||
|
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||||
|
run_id: str = Query(
|
||||||
|
alias="runId",
|
||||||
|
min_length=1,
|
||||||
|
max_length=128,
|
||||||
|
pattern=r"^[A-Za-z0-9_-]+$",
|
||||||
|
),
|
||||||
|
) -> CancelRunResponse:
|
||||||
|
canceled = await service.cancel_run(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
current_user=current_user,
|
||||||
|
)
|
||||||
|
return CancelRunResponse(
|
||||||
|
threadId=canceled.thread_id,
|
||||||
|
runId=canceled.run_id,
|
||||||
|
accepted=canceled.accepted,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/runs/{thread_id}/events")
|
@router.get("/runs/{thread_id}/events")
|
||||||
async def stream_events(
|
async def stream_events(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|||||||
@@ -49,6 +49,14 @@ class QueueClientLike(Protocol):
|
|||||||
self, *, command: dict[str, object], dedup_key: str | None
|
self, *, command: dict[str, object], dedup_key: str | None
|
||||||
) -> str: ...
|
) -> str: ...
|
||||||
|
|
||||||
|
async def request_cancel(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
requested_by: str,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class EventStreamLike(Protocol):
|
class EventStreamLike(Protocol):
|
||||||
async def read(
|
async def read(
|
||||||
@@ -90,6 +98,13 @@ class TaskAccepted:
|
|||||||
created: bool
|
created: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CancelRequested:
|
||||||
|
thread_id: str
|
||||||
|
run_id: str
|
||||||
|
accepted: bool
|
||||||
|
|
||||||
|
|
||||||
class TaskAcceptedResponse(BaseModel):
|
class TaskAcceptedResponse(BaseModel):
|
||||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||||
|
|
||||||
@@ -99,6 +114,14 @@ class TaskAcceptedResponse(BaseModel):
|
|||||||
created: bool
|
created: bool
|
||||||
|
|
||||||
|
|
||||||
|
class CancelRunResponse(BaseModel):
|
||||||
|
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||||
|
|
||||||
|
thread_id: str = Field(alias="threadId")
|
||||||
|
run_id: str = Field(alias="runId")
|
||||||
|
accepted: bool
|
||||||
|
|
||||||
|
|
||||||
class AsrTranscribeResponse(BaseModel):
|
class AsrTranscribeResponse(BaseModel):
|
||||||
transcript: str = Field(description="Transcribed text from audio")
|
transcript: str = Field(description="Transcribed text from audio")
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from schemas.domain.chat_message import (
|
|||||||
from v1.agent.schemas import (
|
from v1.agent.schemas import (
|
||||||
AgentRepositoryLike,
|
AgentRepositoryLike,
|
||||||
AttachmentStorageLike,
|
AttachmentStorageLike,
|
||||||
|
CancelRequested,
|
||||||
EventStreamLike,
|
EventStreamLike,
|
||||||
HistorySnapshotResponse,
|
HistorySnapshotResponse,
|
||||||
QueueClientLike,
|
QueueClientLike,
|
||||||
@@ -157,6 +158,26 @@ class AgentService:
|
|||||||
created=created,
|
created=created,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def cancel_run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
current_user: CurrentUser,
|
||||||
|
) -> CancelRequested:
|
||||||
|
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||||
|
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||||
|
await self._queue.request_cancel(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
requested_by=str(current_user.id),
|
||||||
|
)
|
||||||
|
return CancelRequested(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
accepted=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def _append_context_cache_user_message(
|
async def _append_context_cache_user_message(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from v1.users.dependencies import get_current_user
|
|||||||
class _FakeAgentService:
|
class _FakeAgentService:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._stream_called = False
|
self._stream_called = False
|
||||||
|
self.cancel_calls: list[tuple[str, str, str]] = []
|
||||||
|
|
||||||
async def enqueue_run(
|
async def enqueue_run(
|
||||||
self,
|
self,
|
||||||
@@ -102,6 +103,16 @@ class _FakeAgentService:
|
|||||||
"url": "https://signed.example/temp-url.png",
|
"url": "https://signed.example/temp-url.png",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def cancel_run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
current_user: CurrentUser,
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
self.cancel_calls.append((thread_id, run_id, str(current_user.id)))
|
||||||
|
return SimpleNamespace(thread_id=thread_id, run_id=run_id, accepted=True)
|
||||||
|
|
||||||
|
|
||||||
class _FailingStreamAgentService(_FakeAgentService):
|
class _FailingStreamAgentService(_FakeAgentService):
|
||||||
async def stream_events(
|
async def stream_events(
|
||||||
@@ -306,6 +317,29 @@ def test_stream_rejects_invalid_last_event_id() -> None:
|
|||||||
app.dependency_overrides = {}
|
app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_run_returns_202_and_payload() -> None:
|
||||||
|
service = _FakeAgentService()
|
||||||
|
app.dependency_overrides[get_agent_service] = lambda: service
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||||
|
id=uuid4(), phone="+8613812345678"
|
||||||
|
)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/cancel",
|
||||||
|
params={"runId": "run-99"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 202
|
||||||
|
payload = response.json()
|
||||||
|
assert payload["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||||
|
assert payload["runId"] == "run-99"
|
||||||
|
assert payload["accepted"] is True
|
||||||
|
assert service.cancel_calls
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
def test_history_returns_state_snapshot() -> None:
|
def test_history_returns_state_snapshot() -> None:
|
||||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|||||||
@@ -219,21 +219,24 @@ async def test_store_persists_router_step_output_for_cost_tracking(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
|
||||||
assert append_kwargs["seq"] == 11
|
|
||||||
assert append_kwargs["content"] == ""
|
|
||||||
assert append_kwargs["model_code"] == "doubao-seed-1-6-250615"
|
|
||||||
assert append_kwargs["input_tokens"] == 12
|
|
||||||
assert append_kwargs["output_tokens"] == 8
|
|
||||||
assert append_kwargs["latency_ms"] == 320
|
|
||||||
assert append_kwargs["cost"] == Decimal("0.01")
|
|
||||||
assert append_kwargs["visibility_mask"] == 0
|
|
||||||
|
|
||||||
metadata = cast(dict[str, Any], append_kwargs["metadata"])
|
@pytest.mark.asyncio
|
||||||
assert sorted(metadata.keys()) == ["agent_type", "router_agent_output", "run_id"]
|
async def test_store_marks_session_failed_for_run_canceled_error(
|
||||||
assert metadata["agent_type"] == "router"
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
assert metadata["router_agent_output"]["execution_mode"] == "tool_assisted"
|
) -> None:
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
|
||||||
|
_patch_repositories(monkeypatch, captured, fake_chat_session)
|
||||||
|
|
||||||
assert captured["message_delta"] == 1
|
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||||
assert captured["token_delta"] == 20
|
await store.persist(
|
||||||
assert captured["cost_delta"] == Decimal("0.01")
|
{
|
||||||
|
"type": "RUN_ERROR",
|
||||||
|
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||||
|
"runId": "run-cancel-1",
|
||||||
|
"message": "run canceled by user",
|
||||||
|
"code": "RUN_CANCELED",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["status"] == _SessionStatus.FAILED
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -72,3 +73,31 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
|
|||||||
assert result["worker"]["answer"] == "done"
|
assert result["worker"]["answer"] == "done"
|
||||||
event_types = [item["event"]["type"] for item in pipeline.events]
|
event_types = [item["event"]["type"] for item in pipeline.events]
|
||||||
assert event_types == ["RUN_STARTED", "RUN_FINISHED"]
|
assert event_types == ["RUN_STARTED", "RUN_FINISHED"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrator_emits_run_canceled_error_on_cancelled_error() -> None:
|
||||||
|
class _CanceledRunner:
|
||||||
|
async def execute(self, **kwargs: object) -> dict[str, Any]:
|
||||||
|
del kwargs
|
||||||
|
raise asyncio.CancelledError("run canceled by user")
|
||||||
|
|
||||||
|
pipeline = _FakePipeline()
|
||||||
|
orchestrator = AgentScopeRuntimeOrchestrator(
|
||||||
|
pipeline=pipeline, runner=_CanceledRunner()
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await orchestrator.run(
|
||||||
|
run_input=_run_input(),
|
||||||
|
context_messages=[],
|
||||||
|
user_context=_user_context(),
|
||||||
|
runtime_config=_runtime_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [item["event"]["type"] for item in pipeline.events] == [
|
||||||
|
"RUN_STARTED",
|
||||||
|
"RUN_ERROR",
|
||||||
|
]
|
||||||
|
run_error_event = pipeline.events[-1]["event"]
|
||||||
|
assert run_error_event["code"] == "RUN_CANCELED"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import pytest
|
import pytest
|
||||||
from ag_ui.core import RunAgentInput
|
from ag_ui.core import RunAgentInput
|
||||||
|
|
||||||
@@ -265,3 +266,73 @@ async def test_execute_runs_router_then_worker(
|
|||||||
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
|
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
|
||||||
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
|
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
|
||||||
assert result["worker"]["answer"] == "ok"
|
assert result["worker"]["answer"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_raises_cancelled_error_before_worker_when_cancel_requested(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
class _FakePipeline:
|
||||||
|
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||||
|
del session_id, event
|
||||||
|
return "1-0"
|
||||||
|
|
||||||
|
class _FakeSessionCtx:
|
||||||
|
async def __aenter__(self) -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||||
|
del exc_type, exc, tb
|
||||||
|
|
||||||
|
runner = AgentScopeRunner()
|
||||||
|
|
||||||
|
async def _fake_load_stage_config(*, session: object, agent_type: AgentType):
|
||||||
|
del session
|
||||||
|
return runner_module.SystemAgentRuntimeConfig(
|
||||||
|
agent_type=agent_type,
|
||||||
|
model_code="demo",
|
||||||
|
api_base_url="https://example.com",
|
||||||
|
api_key="test",
|
||||||
|
llm_config=runner_module.SystemAgentLLMConfig(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
|
||||||
|
del kwargs
|
||||||
|
return RouterAgentOutput(
|
||||||
|
normalized_task_input=NormalizedTaskInput(
|
||||||
|
user_text="安排会议",
|
||||||
|
context_summary="",
|
||||||
|
),
|
||||||
|
key_entities=[],
|
||||||
|
constraints=[],
|
||||||
|
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||||
|
execution_mode=ExecutionMode.TOOL_ASSISTED,
|
||||||
|
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
|
||||||
|
ui=RouterUiDecision(
|
||||||
|
ui_mode=UiMode.NONE,
|
||||||
|
ui_decision_reason="单任务",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite:
|
||||||
|
del kwargs
|
||||||
|
raise AssertionError("worker should not run after cancel")
|
||||||
|
|
||||||
|
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||||
|
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
|
||||||
|
monkeypatch.setattr(runner, "_build_toolkit", lambda **kwargs: object())
|
||||||
|
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
|
||||||
|
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
|
||||||
|
|
||||||
|
async def _cancel_checker() -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await runner.execute(
|
||||||
|
user_context=_user_context(),
|
||||||
|
context_messages=[],
|
||||||
|
pipeline=_FakePipeline(),
|
||||||
|
run_input=_run_input(),
|
||||||
|
runtime_config=_runtime_config(),
|
||||||
|
cancel_checker=_cancel_checker,
|
||||||
|
)
|
||||||
|
|||||||
@@ -166,6 +166,69 @@ async def test_run_agentscope_task_injects_runtime_config(
|
|||||||
assert captured_config["runtime_config"] is not None
|
assert captured_config["runtime_config"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_agentscope_task_injects_cancel_checker(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
captured: dict[str, Any] = {}
|
||||||
|
|
||||||
|
class _FakeRuntime:
|
||||||
|
def __init__(self, **kwargs: object) -> None:
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
async def run(self, **kwargs: object) -> object:
|
||||||
|
checker = kwargs.get("cancel_checker")
|
||||||
|
assert callable(checker)
|
||||||
|
captured["cancelled"] = await checker() # type: ignore[misc]
|
||||||
|
return object()
|
||||||
|
|
||||||
|
class _FakeRedis:
|
||||||
|
async def exists(self, key: str) -> int:
|
||||||
|
captured["cancel_key"] = key
|
||||||
|
return 1
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> int:
|
||||||
|
captured["deleted_key"] = key
|
||||||
|
return 1
|
||||||
|
|
||||||
|
async def _fake_get_redis_client() -> object:
|
||||||
|
return _FakeRedis()
|
||||||
|
|
||||||
|
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||||
|
del kwargs
|
||||||
|
return []
|
||||||
|
|
||||||
|
monkeypatch.setattr(tasks_module, "AgentScopeRuntimeOrchestrator", _FakeRuntime)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
tasks_module,
|
||||||
|
"get_or_init_redis_client",
|
||||||
|
_fake_get_redis_client,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||||
|
monkeypatch.setattr(tasks_module, "_build_user_context", _fake_user_context)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
tasks_module,
|
||||||
|
"_build_recent_context_messages",
|
||||||
|
_empty_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
await tasks_module.run_agentscope_task(
|
||||||
|
{
|
||||||
|
"command": "run",
|
||||||
|
"owner_id": str(uuid4()),
|
||||||
|
"run_input": _run_input_payload(),
|
||||||
|
"runtime_config": {
|
||||||
|
"enabled_tools": [],
|
||||||
|
"context": {"window_mode": "day", "window_count": 2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["cancelled"] is True
|
||||||
|
assert isinstance(captured["cancel_key"], str)
|
||||||
|
assert captured["deleted_key"] == captured["cancel_key"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_agentscope_task_requires_owner_id() -> None:
|
async def test_run_agentscope_task_requires_owner_id() -> None:
|
||||||
with pytest.raises(ValueError, match="owner_id is required"):
|
with pytest.raises(ValueError, match="owner_id is required"):
|
||||||
|
|||||||
@@ -100,6 +100,7 @@ class _FakeRepository:
|
|||||||
class _FakeQueue:
|
class _FakeQueue:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.commands: list[dict[str, object]] = []
|
self.commands: list[dict[str, object]] = []
|
||||||
|
self.cancel_requests: list[dict[str, str]] = []
|
||||||
|
|
||||||
async def enqueue(
|
async def enqueue(
|
||||||
self, *, command: dict[str, object], dedup_key: str | None
|
self, *, command: dict[str, object], dedup_key: str | None
|
||||||
@@ -108,6 +109,21 @@ class _FakeQueue:
|
|||||||
self.commands.append(command)
|
self.commands.append(command)
|
||||||
return "task-1"
|
return "task-1"
|
||||||
|
|
||||||
|
async def request_cancel(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
requested_by: str,
|
||||||
|
) -> None:
|
||||||
|
self.cancel_requests.append(
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"run_id": run_id,
|
||||||
|
"requested_by": requested_by,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _FakeStream:
|
class _FakeStream:
|
||||||
async def read(
|
async def read(
|
||||||
@@ -469,3 +485,56 @@ async def test_get_history_snapshot_filters_out_tool_messages() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert [message.role for message in snapshot.messages] == ["user", "assistant"]
|
assert [message.role for message in snapshot.messages] == ["user", "assistant"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_run_requests_queue_cancel_for_owner() -> None:
|
||||||
|
queue = _FakeQueue()
|
||||||
|
service = AgentService(
|
||||||
|
repository=_FakeRepository(),
|
||||||
|
queue=queue,
|
||||||
|
stream=_FakeStream(),
|
||||||
|
attachment_storage=_FakeAttachmentStorage(),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await service.cancel_run(
|
||||||
|
thread_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
run_id="run-cancel-1",
|
||||||
|
current_user=_user(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.accepted is True
|
||||||
|
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
|
||||||
|
assert result.run_id == "run-cancel-1"
|
||||||
|
assert queue.cancel_requests == [
|
||||||
|
{
|
||||||
|
"thread_id": "00000000-0000-0000-0000-000000000001",
|
||||||
|
"run_id": "run-cancel-1",
|
||||||
|
"requested_by": "00000000-0000-0000-0000-000000000001",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_run_rejects_non_owner() -> None:
|
||||||
|
queue = _FakeQueue()
|
||||||
|
service = AgentService(
|
||||||
|
repository=_FakeRepository(),
|
||||||
|
queue=queue,
|
||||||
|
stream=_FakeStream(),
|
||||||
|
attachment_storage=_FakeAttachmentStorage(),
|
||||||
|
)
|
||||||
|
other_user = CurrentUser(
|
||||||
|
id=UUID("00000000-0000-0000-0000-000000000099"),
|
||||||
|
phone="+8613812340000",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await service.cancel_run(
|
||||||
|
thread_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
run_id="run-cancel-2",
|
||||||
|
current_user=other_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
assert queue.cancel_requests == []
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
**Architecture:** 使用“协作取消 + 主任务中断”方案:API 层写入 Redis cancel 信号,runtime 在 worker 进程内并行 watcher 监听信号,命中后先调用 active agent 的 `interrupt()` 做优雅收尾,再 `cancel()` 当前 run 主任务做硬兜底。终态统一通过 `RUN_ERROR` 事件落库,复用现有 `FAILED` 会话语义,避免数据库枚举迁移。
|
**Architecture:** 使用“协作取消 + 主任务中断”方案:API 层写入 Redis cancel 信号,runtime 在 worker 进程内并行 watcher 监听信号,命中后先调用 active agent 的 `interrupt()` 做优雅收尾,再 `cancel()` 当前 run 主任务做硬兜底。终态统一通过 `RUN_ERROR` 事件落库,复用现有 `FAILED` 会话语义,避免数据库枚举迁移。
|
||||||
|
|
||||||
**Tech Stack:** FastAPI, TaskIQ, Redis, AgentScope, SQLAlchemy, Pytest, Ruff, BasedPyright
|
**Tech Stack:** FastAPI, TaskIQ, Redis, AgentScope, SQLAlchemy, Flutter, Pytest, Ruff, BasedPyright
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -390,3 +390,81 @@ git commit -m "feat: support run cancellation with RUN_CANCELED failed semantics
|
|||||||
- 回滚 `router/service/dependencies` cancel 新接口
|
- 回滚 `router/service/dependencies` cancel 新接口
|
||||||
- 回滚 `runner/orchestrator/tasks` cancel 注入逻辑
|
- 回滚 `runner/orchestrator/tasks` cancel 注入逻辑
|
||||||
- 保持原 `POST /runs` 与 SSE 流程不变
|
- 保持原 `POST /runs` 与 SSE 流程不变
|
||||||
|
|
||||||
|
### Task 7: 前端接入 cancel API(发送后“停止生成”按钮走后端真实取消)
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `apps/lib/features/chat/data/services/ag_ui_service.dart`
|
||||||
|
- Modify: `apps/lib/features/chat/presentation/bloc/chat_bloc.dart`
|
||||||
|
- Modify: `apps/lib/features/chat/data/models/ag_ui_event.dart`
|
||||||
|
- Modify: `apps/lib/features/home/ui/screens/home_screen_interactions.dart`
|
||||||
|
- Test: `apps/test/features/chat/data/services/ag_ui_service_test.dart`
|
||||||
|
- Test: `apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart`
|
||||||
|
|
||||||
|
**Step 1: 在 AgUiService 维护当前运行态标识**
|
||||||
|
|
||||||
|
在 `AgUiService` 增加字段:
|
||||||
|
- `_activeThreadIdForRun: String?`
|
||||||
|
- `_activeRunId: String?`
|
||||||
|
|
||||||
|
并在 `sendMessage` 成功拿到 `/runs` 响应后设置这两个字段;在收到目标 run 的终态事件(`RUN_FINISHED` / `RUN_ERROR`)后清理。
|
||||||
|
|
||||||
|
**Step 2: 将 cancelCurrentRun 从“仅断 SSE”升级为“先调用后端 cancel,再本地收流”**
|
||||||
|
|
||||||
|
`AgUiService.cancelCurrentRun()` 改为:
|
||||||
|
1. 若 `_activeThreadIdForRun` 或 `_activeRunId` 为空:退化为当前行为(仅关闭 SSE)
|
||||||
|
2. 否则先调用:
|
||||||
|
|
||||||
|
```text
|
||||||
|
POST /api/v1/agent/runs/{threadId}/cancel?runId={runId}
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 请求成功后再执行 `_cancelActiveSseSubscription()`(避免继续占用本地连接)
|
||||||
|
4. 不论后端是否即时生效,都清理本地 active run 字段,防止重复 cancel
|
||||||
|
|
||||||
|
说明:这一步就是把“发送消息后的停止按钮”真正连到后端取消能力。
|
||||||
|
|
||||||
|
**Step 3: 错误语义细化(前端展示友好)**
|
||||||
|
|
||||||
|
在 `chat_bloc.dart` 处理 `RunErrorEvent` 时:
|
||||||
|
- 如果 `errorEvent.code == 'RUN_CANCELED'`,错误文案不按失败提示展示(可置空或显示“已停止生成”)
|
||||||
|
- 仍执行 `_resetRunState` 与 tool 卡片收尾,保持 UI 一致性
|
||||||
|
|
||||||
|
**Step 4: 保持现有按钮入口,不改交互入口路径**
|
||||||
|
|
||||||
|
`home_screen_interactions.dart` 里的 `_onStopGenerating -> _chatBloc.cancelCurrentRun()` 已经是正确入口,继续复用。
|
||||||
|
|
||||||
|
仅调整 Toast 文案策略:
|
||||||
|
- 请求已发出:`已请求停止`
|
||||||
|
- 收到 `RUN_ERROR(code=RUN_CANCELED)`:最终态 `已停止生成`
|
||||||
|
|
||||||
|
**Step 5: 写 AgUiService 测试(先红)**
|
||||||
|
|
||||||
|
在 `ag_ui_service_test.dart` 增加:
|
||||||
|
- `cancelCurrentRun` 会调用新端点 `/api/v1/agent/runs/{threadId}/cancel`
|
||||||
|
- query 参数包含 `runId`
|
||||||
|
- 调用后会关闭当前 SSE subscription
|
||||||
|
|
||||||
|
**Step 6: 写 ChatBloc 测试(先红)**
|
||||||
|
|
||||||
|
在 `chat_bloc_attachment_sync_test.dart` 增加:
|
||||||
|
- 收到 `RunErrorEvent(message: 'run canceled by user', code: 'RUN_CANCELED')` 后:
|
||||||
|
- `isWaitingFirstToken/isStreaming/isCancelling` 全部归零
|
||||||
|
- 不显示普通失败文案(或显示取消态文案,按你们最终文案策略断言)
|
||||||
|
|
||||||
|
**Step 7: 运行 Flutter 测试**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
flutter test apps/test/features/chat/data/services/ag_ui_service_test.dart apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: PASS
|
||||||
|
|
||||||
|
**Step 8: 前端接入提交**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add apps/lib/features/chat/data/services/ag_ui_service.dart apps/lib/features/chat/presentation/bloc/chat_bloc.dart apps/lib/features/chat/data/models/ag_ui_event.dart apps/lib/features/home/ui/screens/home_screen_interactions.dart apps/test/features/chat/data/services/ag_ui_service_test.dart apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart
|
||||||
|
git commit -m "feat: wire stop-generating button to backend run cancel API"
|
||||||
|
```
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ Base URL: `/api/v1/agent`
|
|||||||
| 方法 | 路径 | 说明 |
|
| 方法 | 路径 | 说明 |
|
||||||
|---|---|---|
|
|---|---|---|
|
||||||
| POST | `/runs` | 创建一次 agent run(异步入队) |
|
| POST | `/runs` | 创建一次 agent run(异步入队) |
|
||||||
|
| POST | `/runs/{thread_id}/cancel` | 请求取消指定 run |
|
||||||
| GET | `/runs/{thread_id}/events` | 订阅 SSE 事件流 |
|
| GET | `/runs/{thread_id}/events` | 订阅 SSE 事件流 |
|
||||||
| GET | `/history` | 获取历史快照(按天分页) |
|
| GET | `/history` | 获取历史快照(按天分页) |
|
||||||
| POST | `/attachments` | 上传用户图片附件 |
|
| POST | `/attachments` | 上传用户图片附件 |
|
||||||
@@ -95,7 +96,43 @@ Base URL: `/api/v1/agent`
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 3) GET `/history`
|
## 3) POST `/runs/{thread_id}/cancel`
|
||||||
|
|
||||||
|
请求取消指定 run。该接口返回 `202` 仅表示取消请求已被后端接收,不保证运行已在同一时刻停止。
|
||||||
|
|
||||||
|
### Path
|
||||||
|
|
||||||
|
| 参数 | 类型 | 说明 |
|
||||||
|
|---|---|---|
|
||||||
|
| `thread_id` | string | 会话 ID |
|
||||||
|
|
||||||
|
### Query
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必填 | 说明 |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `runId` | string | 是 | 需要取消的 run ID |
|
||||||
|
|
||||||
|
### Response
|
||||||
|
|
||||||
|
`202 Accepted`
|
||||||
|
|
||||||
|
```ts
|
||||||
|
{
|
||||||
|
threadId: string;
|
||||||
|
runId: string;
|
||||||
|
accepted: true;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 错误码
|
||||||
|
|
||||||
|
- `401` 未认证
|
||||||
|
- `403` 非会话所有者
|
||||||
|
- `422` 参数非法
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4) GET `/history`
|
||||||
|
|
||||||
返回历史快照(`HistorySnapshotResponse`),不是 SSE 包装事件。
|
返回历史快照(`HistorySnapshotResponse`),不是 SSE 包装事件。
|
||||||
|
|
||||||
@@ -146,7 +183,7 @@ tool 消息在存储层用于运行时上下文续接,不在 `/history` 对外
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 4) POST `/attachments`
|
## 5) POST `/attachments`
|
||||||
|
|
||||||
上传图片附件,返回可直接用于 `RunAgentInput.messages[].content[].url` 的签名链接。
|
上传图片附件,返回可直接用于 `RunAgentInput.messages[].content[].url` 的签名链接。
|
||||||
|
|
||||||
@@ -182,7 +219,7 @@ tool 消息在存储层用于运行时上下文续接,不在 `/history` 对外
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 5) GET `/attachments/signed-url`
|
## 6) GET `/attachments/signed-url`
|
||||||
|
|
||||||
对已有存储对象重新签名。
|
对已有存储对象重新签名。
|
||||||
|
|
||||||
@@ -205,7 +242,7 @@ tool 消息在存储层用于运行时上下文续接,不在 `/history` 对外
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 6) POST `/transcribe`
|
## 7) POST `/transcribe`
|
||||||
|
|
||||||
WAV 音频转写。
|
WAV 音频转写。
|
||||||
|
|
||||||
|
|||||||
@@ -75,6 +75,20 @@ data: <json>
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
取消语义(当前实现):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "RUN_ERROR",
|
||||||
|
"threadId": "...",
|
||||||
|
"runId": "...",
|
||||||
|
"message": "run canceled by user",
|
||||||
|
"code": "RUN_CANCELED"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
说明:`RUN_CANCELED` 表示用户主动中断,本阶段后端仍复用会话 `failed` 状态以保持兼容。
|
||||||
|
|
||||||
### 3.2 阶段事件
|
### 3.2 阶段事件
|
||||||
|
|
||||||
#### `STEP_STARTED`
|
#### `STEP_STARTED`
|
||||||
|
|||||||
Reference in New Issue
Block a user