From 96fc4a1e7728e25b65581461ed54eba1c924cf27 Mon Sep 17 00:00:00 2001 From: qzl Date: Wed, 25 Mar 2026 18:33:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20agent=20=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E5=8F=96=E6=B6=88=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat/data/services/ag_ui_service.dart | 34 ++++- .../chat/presentation/bloc/chat_bloc.dart | 20 ++- .../ui/screens/home_screen_interactions.dart | 2 +- .../data/services/ag_ui_service_test.dart | 47 ++++++ .../chat_bloc_attachment_sync_test.dart | 22 +++ .../core/agentscope/runtime/orchestrator.py | 23 ++- backend/src/core/agentscope/runtime/runner.py | 143 +++++++++++++----- backend/src/core/agentscope/runtime/tasks.py | 33 +++- backend/src/v1/agent/dependencies.py | 26 ++++ backend/src/v1/agent/router.py | 29 ++++ backend/src/v1/agent/schemas.py | 23 +++ backend/src/v1/agent/service.py | 21 +++ .../tests/integration/v1/agent/test_routes.py | 34 +++++ .../unit/core/agentscope/events/test_store.py | 35 +++-- .../agentscope/runtime/test_orchestrator.py | 29 ++++ .../core/agentscope/runtime/test_runner.py | 71 +++++++++ .../core/agentscope/runtime/test_tasks.py | 63 ++++++++ backend/tests/unit/v1/agent/test_service.py | 69 +++++++++ .../2026-03-25-agent-run-cancel-failed.md | 80 +++++++++- docs/protocols/agent/api-endpoints.md | 45 +++++- docs/protocols/agent/sse-events.md | 14 ++ 21 files changed, 778 insertions(+), 85 deletions(-) diff --git a/apps/lib/features/chat/data/services/ag_ui_service.dart b/apps/lib/features/chat/data/services/ag_ui_service.dart index 7edf4c2..24177dc 100644 --- a/apps/lib/features/chat/data/services/ag_ui_service.dart +++ b/apps/lib/features/chat/data/services/ag_ui_service.dart @@ -50,6 +50,8 @@ class AgUiService { Completer? _activeSseDoneCompleter; String? _threadId; + String? _activeThreadIdForRun; + String? _activeRunId; bool _hasMoreHistory = false; AgUiService({EventCallback? onEvent, required IApiClient apiClient}) @@ -83,11 +85,20 @@ class AgUiService { throw StateError('Missing runId in /agent/runs response'); } _threadId = threadId; - await _streamEventsFromApi( - threadId, - expectedRunId: runId, - streamToken: streamToken, - ); + _activeThreadIdForRun = threadId; + _activeRunId = runId; + try { + await _streamEventsFromApi( + threadId, + expectedRunId: runId, + streamToken: streamToken, + ); + } finally { + if (_activeThreadIdForRun == threadId && _activeRunId == runId) { + _activeThreadIdForRun = null; + _activeRunId = null; + } + } return SendMessageResult( uploadedAttachments: runInputPayload.uploadedAttachments, ); @@ -151,6 +162,19 @@ class AgUiService { } Future cancelCurrentRun() async { + final activeThreadId = _activeThreadIdForRun; + final activeRunId = _activeRunId; + if (activeThreadId != null && activeRunId != null) { + final encodedRunId = Uri.encodeQueryComponent(activeRunId); + await _apiClient.post>( + '/api/v1/agent/runs/$activeThreadId/cancel?runId=$encodedRunId', + ); + _activeThreadIdForRun = null; + _activeRunId = null; + _activeStreamToken += 1; + await _cancelActiveSseSubscription(); + return; + } _activeStreamToken += 1; await _cancelActiveSseSubscription(); } diff --git a/apps/lib/features/chat/presentation/bloc/chat_bloc.dart b/apps/lib/features/chat/presentation/bloc/chat_bloc.dart index ada3df1..61b75e3 100644 --- a/apps/lib/features/chat/presentation/bloc/chat_bloc.dart +++ b/apps/lib/features/chat/presentation/bloc/chat_bloc.dart @@ -123,10 +123,16 @@ class ChatBloc extends Cubit { ); case AgUiEventType.runError: final errorEvent = event as RunErrorEvent; + final isCanceledByUser = errorEvent.code == 'RUN_CANCELED'; emit( _resetRunState( - error: errorEvent.message, - ).copyWith(items: _markActiveToolCallsFailed(state.items)), + error: isCanceledByUser ? null : errorEvent.message, + ).copyWith( + items: _markActiveToolCallsFailed( + state.items, + reason: isCanceledByUser ? '本次运行已取消' : '本次运行已失败', + ), + ), ); case AgUiEventType.stepStarted: _handleStepStarted(event as StepStartedEvent); @@ -286,7 +292,10 @@ class ChatBloc extends Cubit { return items.where((item) => item is! ToolCallItem).toList(); } - List _markActiveToolCallsFailed(List items) { + List _markActiveToolCallsFailed( + List items, { + required String reason, + }) { return items.map((item) { if (item is! ToolCallItem) { return item; @@ -297,10 +306,7 @@ class ChatBloc extends Cubit { if (item.status == ToolCallStatus.completed) { return item; } - return item.copyWith( - status: ToolCallStatus.error, - errorMessage: '本次运行已失败', - ); + return item.copyWith(status: ToolCallStatus.error, errorMessage: reason); }).toList(); } diff --git a/apps/lib/features/home/ui/screens/home_screen_interactions.dart b/apps/lib/features/home/ui/screens/home_screen_interactions.dart index 553346e..c97df08 100644 --- a/apps/lib/features/home/ui/screens/home_screen_interactions.dart +++ b/apps/lib/features/home/ui/screens/home_screen_interactions.dart @@ -75,7 +75,7 @@ extension _HomeScreenInteractions on _HomeScreenState { return; } if (canceled) { - Toast.show(context, '已停止等待回复', type: ToastType.info); + Toast.show(context, '已请求停止', type: ToastType.info); } } diff --git a/apps/test/features/chat/data/services/ag_ui_service_test.dart b/apps/test/features/chat/data/services/ag_ui_service_test.dart index 0145d49..b745b96 100644 --- a/apps/test/features/chat/data/services/ag_ui_service_test.dart +++ b/apps/test/features/chat/data/services/ag_ui_service_test.dart @@ -16,6 +16,7 @@ class _FakeApiClient implements IApiClient { final List sseLines; final Stream Function()? sseLineStreamFactory; final String Function()? runIdFactory; + final List postPaths = []; @override Future> delete(String path, {data, Options? options}) { @@ -51,6 +52,19 @@ class _FakeApiClient implements IApiClient { @override Future> post(String path, {data, Options? options}) async { + postPaths.add(path); + if (path.contains('/cancel?runId=')) { + final payload = { + 'threadId': 'thread-1', + 'runId': 'run-new', + 'accepted': true, + }; + return Response( + requestOptions: RequestOptions(path: path), + data: payload as T, + statusCode: 202, + ); + } final runIdFactory = this.runIdFactory; final payload = { 'taskId': 'task-1', @@ -192,6 +206,39 @@ void main() { await streamController.close(); }); + test( + 'cancelCurrentRun calls backend cancel endpoint for active run', + () async { + final streamController = StreamController(); + final fakeApi = _FakeApiClient( + sseLines: const [], + sseLineStreamFactory: () => streamController.stream, + ); + final service = AgUiService(apiClient: fakeApi); + + final sendFuture = service.sendMessage('hello'); + await Future.delayed(Duration.zero); + for (final line in _buildSseEvent( + id: '51', + type: AgUiEventTypeWire.runStarted, + payload: + '{"type":"RUN_STARTED","threadId":"thread-1","runId":"run-new"}', + )) { + streamController.add(line); + } + await Future.delayed(Duration.zero); + + await service.cancelCurrentRun(); + await sendFuture; + + expect( + fakeApi.postPaths, + contains('/api/v1/agent/runs/thread-1/cancel?runId=run-new'), + ); + await streamController.close(); + }, + ); + test( 'new sendMessage cancels previous SSE subscription explicitly', () async { diff --git a/apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart b/apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart index 0df2956..70c7036 100644 --- a/apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart +++ b/apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart @@ -206,6 +206,28 @@ void main() { expect(bloc.state.error, 'runtime execution failed'); }); + test('run canceled error clears error and marks tool as canceled', () { + service.emitEvent( + ToolCallStartEvent( + toolCallId: 'tool-cancel', + toolCallName: 'ocr_image', + ), + ); + service.emitEvent(ToolCallEndEvent(toolCallId: 'tool-cancel')); + + service.emitEvent( + RunErrorEvent(message: 'run canceled by user', code: 'RUN_CANCELED'), + ); + + final toolItem = bloc.state.items.whereType().single; + expect(toolItem.status, ToolCallStatus.error); + expect(toolItem.errorMessage, '本次运行已取消'); + expect(bloc.state.error, isNull); + expect(bloc.state.isWaitingFirstToken, isFalse); + expect(bloc.state.isStreaming, isFalse); + expect(bloc.state.isCancelling, isFalse); + }); + test('text event with ui schema is rendered into chat items', () { service.emitEvent(RunStartedEvent(threadId: 'thread-1', runId: 'run-1')); diff --git a/backend/src/core/agentscope/runtime/orchestrator.py b/backend/src/core/agentscope/runtime/orchestrator.py index 55617ac..fce6f0b 100644 --- a/backend/src/core/agentscope/runtime/orchestrator.py +++ b/backend/src/core/agentscope/runtime/orchestrator.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Protocol +import asyncio +from typing import Any, Awaitable, Callable, Protocol from ag_ui.core.types import RunAgentInput from agentscope.message import Msg @@ -28,6 +29,7 @@ class RunnerLike(Protocol): runtime_config: RuntimeConfig, user_memory: UserMemoryContent | None, work_memory: WorkProfileContent | None, + cancel_checker: Callable[[], Awaitable[bool]] | None = None, ) -> dict[str, Any]: ... @@ -53,6 +55,7 @@ class AgentScopeRuntimeOrchestrator: runtime_config: RuntimeConfig, user_memory: UserMemoryContent | None = None, work_memory: WorkProfileContent | None = None, + cancel_checker: Callable[[], Awaitable[bool]] | None = None, ) -> dict[str, Any]: thread_id = run_input.thread_id run_id = run_input.run_id @@ -74,6 +77,7 @@ class AgentScopeRuntimeOrchestrator: runtime_config=runtime_config, user_memory=user_memory, work_memory=work_memory, + cancel_checker=cancel_checker, ) await self._pipeline.emit( @@ -85,6 +89,23 @@ class AgentScopeRuntimeOrchestrator: }, ) return result if isinstance(result, dict) else {} + except asyncio.CancelledError: + logger.info( + "agentscope runtime execution canceled", + thread_id=thread_id, + run_id=run_id, + ) + await self._pipeline.emit( + session_id=thread_id, + event={ + "type": "RUN_ERROR", + "threadId": thread_id, + "runId": run_id, + "message": "run canceled by user", + "code": "RUN_CANCELED", + }, + ) + raise except Exception: logger.exception( "agentscope runtime execution failed", diff --git a/backend/src/core/agentscope/runtime/runner.py b/backend/src/core/agentscope/runtime/runner.py index f66bfc2..5f7b117 100644 --- a/backend/src/core/agentscope/runtime/runner.py +++ b/backend/src/core/agentscope/runtime/runner.py @@ -1,8 +1,10 @@ from __future__ import annotations +import asyncio +import contextlib from dataclasses import dataclass from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Awaitable, Callable from uuid import UUID from ag_ui.core.types import RunAgentInput @@ -64,6 +66,8 @@ class AgentScopeRunner: def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None: patch_agentscope_json_repair_compat() self._litellm_service: LiteLLMService = litellm_service or LiteLLMService() + self._active_agent: JsonReActAgent | None = None + self._active_agent_lock = asyncio.Lock() async def execute( self, @@ -75,51 +79,99 @@ class AgentScopeRunner: runtime_config: RuntimeConfig, user_memory: UserMemoryContent | None = None, work_memory: WorkProfileContent | None = None, + cancel_checker: Callable[[], Awaitable[bool]] | None = None, ) -> dict[str, Any]: owner_id = UUID(user_context.id) runtime_client_time = self._resolve_runtime_client_time(run_input=run_input) runtime_mode = self._resolve_runtime_mode(run_input=run_input) + stop_cancel_watch = asyncio.Event() + cancel_watch_task: asyncio.Task[None] | None = None + run_task = asyncio.current_task() - async with AsyncSessionLocal() as session: - router_config = await self._load_stage_config( - session=session, - agent_type=AgentType.ROUTER, - ) - worker_config = await self._load_stage_config( - session=session, - agent_type=AgentType.WORKER, - ) - worker_toolkit = self._build_toolkit( - session=session, - owner_id=owner_id, - enabled_tools=runtime_config.enabled_tools, + if cancel_checker is not None and run_task is not None: + cancel_watch_task = asyncio.create_task( + self._watch_cancel_signal( + cancel_checker=cancel_checker, + stop_signal=stop_cancel_watch, + run_task=run_task, + ) ) - router_output = await self._execute_router_step( - pipeline=pipeline, - run_input=run_input, - user_context=user_context, - context_messages=context_messages, - stage_config=router_config, - runtime_client_time=runtime_client_time, - runtime_mode=runtime_mode, - user_memory=user_memory, - ) - worker_output = await self._execute_worker_step( - pipeline=pipeline, - run_input=run_input, - user_context=user_context, - router_output=router_output, - toolkit=worker_toolkit, - stage_config=worker_config, - runtime_client_time=runtime_client_time, - runtime_mode=runtime_mode, - work_memory=work_memory, - ) - return { - "router": router_output.model_dump(mode="json", exclude_none=True), - "worker": worker_output.model_dump(mode="json", exclude_none=True), - } + try: + async with AsyncSessionLocal() as session: + router_config = await self._load_stage_config( + session=session, + agent_type=AgentType.ROUTER, + ) + worker_config = await self._load_stage_config( + session=session, + agent_type=AgentType.WORKER, + ) + worker_toolkit = self._build_toolkit( + session=session, + owner_id=owner_id, + enabled_tools=runtime_config.enabled_tools, + ) + + router_output = await self._execute_router_step( + pipeline=pipeline, + run_input=run_input, + user_context=user_context, + context_messages=context_messages, + stage_config=router_config, + runtime_client_time=runtime_client_time, + runtime_mode=runtime_mode, + user_memory=user_memory, + ) + if cancel_checker is not None and await cancel_checker(): + raise asyncio.CancelledError("run canceled by user") + worker_output = await self._execute_worker_step( + pipeline=pipeline, + run_input=run_input, + user_context=user_context, + router_output=router_output, + toolkit=worker_toolkit, + stage_config=worker_config, + runtime_client_time=runtime_client_time, + runtime_mode=runtime_mode, + work_memory=work_memory, + ) + return { + "router": router_output.model_dump(mode="json", exclude_none=True), + "worker": worker_output.model_dump(mode="json", exclude_none=True), + } + finally: + stop_cancel_watch.set() + if cancel_watch_task is not None: + cancel_watch_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await cancel_watch_task + + async def _watch_cancel_signal( + self, + *, + cancel_checker: Callable[[], Awaitable[bool]], + stop_signal: asyncio.Event, + run_task: asyncio.Task[object], + ) -> None: + while not stop_signal.is_set(): + should_cancel = False + try: + should_cancel = await cancel_checker() + except Exception: + should_cancel = False + + if should_cancel: + async with self._active_agent_lock: + active_agent = self._active_agent + if active_agent is not None: + with contextlib.suppress(Exception): + await active_agent.interrupt() + if not run_task.done(): + run_task.cancel("run canceled by user") + return + + await asyncio.sleep(0.2) def _build_toolkit( self, @@ -373,9 +425,16 @@ class AgentScopeRunner: model=tracking_model, emitter=emitter, ) - response_msg = await agent.reply_json( - input_messages, output_model=worker_output_model - ) + async with self._active_agent_lock: + self._active_agent = agent + try: + response_msg = await agent.reply_json( + input_messages, output_model=worker_output_model + ) + finally: + async with self._active_agent_lock: + if self._active_agent is agent: + self._active_agent = None worker_payload = worker_output_model.model_validate(response_msg.metadata or {}) response_metadata = self._litellm_service.build_usage_metadata( model=stage_config.model_code, diff --git a/backend/src/core/agentscope/runtime/tasks.py b/backend/src/core/agentscope/runtime/tasks.py index 4c27621..4774990 100644 --- a/backend/src/core/agentscope/runtime/tasks.py +++ b/backend/src/core/agentscope/runtime/tasks.py @@ -257,6 +257,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]: thread_id = run_input.thread_id run_id = run_input.run_id owner_id = UUID(raw_owner_id) + cancel_key = f"agent:cancel:{thread_id}:{run_id}" if command_type != "run": raise ValueError("invalid command type") @@ -278,6 +279,15 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]: work_memory: WorkProfileContent | None = memories_result.get("work_memory") redis_client = await get_or_init_redis_client() + + async def _cancel_checker() -> bool: + exists_fn = getattr(redis_client, "exists", None) + if not callable(exists_fn): + return False + exists_call = cast(Any, exists_fn)(cancel_key) + result = await exists_call + return bool(result) + bus = RedisStreamBus( client=redis_client, stream_prefix=config.agent_runtime.redis_stream_prefix, @@ -302,14 +312,21 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]: context_config=runtime_config.context, ) - await runtime.run( - run_input=run_input, - context_messages=context_messages, - user_context=user_context, - runtime_config=runtime_config, - user_memory=user_memory, - work_memory=work_memory, - ) + try: + await runtime.run( + run_input=run_input, + context_messages=context_messages, + user_context=user_context, + runtime_config=runtime_config, + user_memory=user_memory, + work_memory=work_memory, + cancel_checker=_cancel_checker, + ) + finally: + delete_fn = getattr(redis_client, "delete", None) + if callable(delete_fn): + delete_call = cast(Any, delete_fn)(cancel_key) + await delete_call logger.info( "agentscope runtime task completed", command_type=command_type, diff --git a/backend/src/v1/agent/dependencies.py b/backend/src/v1/agent/dependencies.py index adf8350..f6c40d9 100644 --- a/backend/src/v1/agent/dependencies.py +++ b/backend/src/v1/agent/dependencies.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +from datetime import datetime, timezone +import json from typing import Any from fastapi import Depends @@ -22,6 +24,7 @@ DEDUP_WAIT_RETRIES = 20 DEDUP_WAIT_SECONDS = 0.05 DEDUP_LOCK_SECONDS = 300 DEDUP_INFLIGHT_MARKER = "__inflight__" +RUN_CANCEL_SIGNAL_TTL_SECONDS = 1800 def _event_stream_block_ms() -> int: @@ -87,6 +90,29 @@ class TaskiqQueueClient: await redis_client.delete(redis_key) raise + async def request_cancel( + self, + *, + thread_id: str, + run_id: str, + requested_by: str, + ) -> None: + redis_client = await self._get_redis() + cancel_key = f"agent:cancel:{thread_id}:{run_id}" + payload = json.dumps( + { + "requested_by": requested_by, + "requested_at": datetime.now(timezone.utc).isoformat(), + }, + ensure_ascii=True, + separators=(",", ":"), + ) + await redis_client.set( + cancel_key, + payload, + ex=RUN_CANCEL_SIGNAL_TTL_SECONDS, + ) + class RedisEventStream: def __init__(self) -> None: diff --git a/backend/src/v1/agent/router.py b/backend/src/v1/agent/router.py index 3879008..e14c534 100644 --- a/backend/src/v1/agent/router.py +++ b/backend/src/v1/agent/router.py @@ -37,6 +37,7 @@ from v1.agent.schemas import ( AttachmentReference, AttachmentSignedUrlResponse, AttachmentUploadResponse, + CancelRunResponse, HistorySnapshotResponse, TaskAcceptedResponse, ) @@ -147,6 +148,34 @@ async def enqueue_run( ) +@router.post( + "/runs/{thread_id}/cancel", + response_model=CancelRunResponse, + status_code=status.HTTP_202_ACCEPTED, +) +async def cancel_run( + thread_id: str, + service: Annotated[AgentService, Depends(get_agent_service)], + current_user: Annotated[CurrentUser, Depends(get_current_user)], + run_id: str = Query( + alias="runId", + min_length=1, + max_length=128, + pattern=r"^[A-Za-z0-9_-]+$", + ), +) -> CancelRunResponse: + canceled = await service.cancel_run( + thread_id=thread_id, + run_id=run_id, + current_user=current_user, + ) + return CancelRunResponse( + threadId=canceled.thread_id, + runId=canceled.run_id, + accepted=canceled.accepted, + ) + + @router.get("/runs/{thread_id}/events") async def stream_events( request: Request, diff --git a/backend/src/v1/agent/schemas.py b/backend/src/v1/agent/schemas.py index 88cd718..236224d 100644 --- a/backend/src/v1/agent/schemas.py +++ b/backend/src/v1/agent/schemas.py @@ -49,6 +49,14 @@ class QueueClientLike(Protocol): self, *, command: dict[str, object], dedup_key: str | None ) -> str: ... + async def request_cancel( + self, + *, + thread_id: str, + run_id: str, + requested_by: str, + ) -> None: ... + class EventStreamLike(Protocol): async def read( @@ -90,6 +98,13 @@ class TaskAccepted: created: bool +@dataclass(frozen=True) +class CancelRequested: + thread_id: str + run_id: str + accepted: bool + + class TaskAcceptedResponse(BaseModel): model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True) @@ -99,6 +114,14 @@ class TaskAcceptedResponse(BaseModel): created: bool +class CancelRunResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True) + + thread_id: str = Field(alias="threadId") + run_id: str = Field(alias="runId") + accepted: bool + + class AsrTranscribeResponse(BaseModel): transcript: str = Field(description="Transcribed text from audio") diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py index e1e94f3..91e046f 100644 --- a/backend/src/v1/agent/service.py +++ b/backend/src/v1/agent/service.py @@ -30,6 +30,7 @@ from schemas.domain.chat_message import ( from v1.agent.schemas import ( AgentRepositoryLike, AttachmentStorageLike, + CancelRequested, EventStreamLike, HistorySnapshotResponse, QueueClientLike, @@ -157,6 +158,26 @@ class AgentService: created=created, ) + async def cancel_run( + self, + *, + thread_id: str, + run_id: str, + current_user: CurrentUser, + ) -> CancelRequested: + owner = await self._repository.get_session_owner(session_id=thread_id) + ensure_session_owner(owner_id=owner, current_user=current_user) + await self._queue.request_cancel( + thread_id=thread_id, + run_id=run_id, + requested_by=str(current_user.id), + ) + return CancelRequested( + thread_id=thread_id, + run_id=run_id, + accepted=True, + ) + async def _append_context_cache_user_message( self, *, diff --git a/backend/tests/integration/v1/agent/test_routes.py b/backend/tests/integration/v1/agent/test_routes.py index d915cc2..64820c9 100644 --- a/backend/tests/integration/v1/agent/test_routes.py +++ b/backend/tests/integration/v1/agent/test_routes.py @@ -17,6 +17,7 @@ from v1.users.dependencies import get_current_user class _FakeAgentService: def __init__(self) -> None: self._stream_called = False + self.cancel_calls: list[tuple[str, str, str]] = [] async def enqueue_run( self, @@ -102,6 +103,16 @@ class _FakeAgentService: "url": "https://signed.example/temp-url.png", } + async def cancel_run( + self, + *, + thread_id: str, + run_id: str, + current_user: CurrentUser, + ) -> SimpleNamespace: + self.cancel_calls.append((thread_id, run_id, str(current_user.id))) + return SimpleNamespace(thread_id=thread_id, run_id=run_id, accepted=True) + class _FailingStreamAgentService(_FakeAgentService): async def stream_events( @@ -306,6 +317,29 @@ def test_stream_rejects_invalid_last_event_id() -> None: app.dependency_overrides = {} +def test_cancel_run_returns_202_and_payload() -> None: + service = _FakeAgentService() + app.dependency_overrides[get_agent_service] = lambda: service + app.dependency_overrides[get_current_user] = lambda: CurrentUser( + id=uuid4(), phone="+8613812345678" + ) + client = TestClient(app) + + try: + response = client.post( + "/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/cancel", + params={"runId": "run-99"}, + ) + assert response.status_code == 202 + payload = response.json() + assert payload["threadId"] == "00000000-0000-0000-0000-000000000001" + assert payload["runId"] == "run-99" + assert payload["accepted"] is True + assert service.cancel_calls + finally: + app.dependency_overrides = {} + + def test_history_returns_state_snapshot() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() client = TestClient(app) diff --git a/backend/tests/unit/core/agentscope/events/test_store.py b/backend/tests/unit/core/agentscope/events/test_store.py index b13094c..0a9102d 100644 --- a/backend/tests/unit/core/agentscope/events/test_store.py +++ b/backend/tests/unit/core/agentscope/events/test_store.py @@ -219,21 +219,24 @@ async def test_store_persists_router_step_output_for_cost_tracking( } ) - append_kwargs = cast(dict[str, Any], captured["append_kwargs"]) - assert append_kwargs["seq"] == 11 - assert append_kwargs["content"] == "" - assert append_kwargs["model_code"] == "doubao-seed-1-6-250615" - assert append_kwargs["input_tokens"] == 12 - assert append_kwargs["output_tokens"] == 8 - assert append_kwargs["latency_ms"] == 320 - assert append_kwargs["cost"] == Decimal("0.01") - assert append_kwargs["visibility_mask"] == 0 - metadata = cast(dict[str, Any], append_kwargs["metadata"]) - assert sorted(metadata.keys()) == ["agent_type", "router_agent_output", "run_id"] - assert metadata["agent_type"] == "router" - assert metadata["router_agent_output"]["execution_mode"] == "tool_assisted" +@pytest.mark.asyncio +async def test_store_marks_session_failed_for_run_canceled_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: dict[str, object] = {} + fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2) + _patch_repositories(monkeypatch, captured, fake_chat_session) - assert captured["message_delta"] == 1 - assert captured["token_delta"] == 20 - assert captured["cost_delta"] == Decimal("0.01") + store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx()) + await store.persist( + { + "type": "RUN_ERROR", + "threadId": "00000000-0000-0000-0000-000000000001", + "runId": "run-cancel-1", + "message": "run canceled by user", + "code": "RUN_CANCELED", + } + ) + + assert captured["status"] == _SessionStatus.FAILED diff --git a/backend/tests/unit/core/agentscope/runtime/test_orchestrator.py b/backend/tests/unit/core/agentscope/runtime/test_orchestrator.py index b2eb6e7..c4c77c5 100644 --- a/backend/tests/unit/core/agentscope/runtime/test_orchestrator.py +++ b/backend/tests/unit/core/agentscope/runtime/test_orchestrator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any import pytest @@ -72,3 +73,31 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None: assert result["worker"]["answer"] == "done" event_types = [item["event"]["type"] for item in pipeline.events] assert event_types == ["RUN_STARTED", "RUN_FINISHED"] + + +@pytest.mark.asyncio +async def test_orchestrator_emits_run_canceled_error_on_cancelled_error() -> None: + class _CanceledRunner: + async def execute(self, **kwargs: object) -> dict[str, Any]: + del kwargs + raise asyncio.CancelledError("run canceled by user") + + pipeline = _FakePipeline() + orchestrator = AgentScopeRuntimeOrchestrator( + pipeline=pipeline, runner=_CanceledRunner() + ) + + with pytest.raises(asyncio.CancelledError): + await orchestrator.run( + run_input=_run_input(), + context_messages=[], + user_context=_user_context(), + runtime_config=_runtime_config(), + ) + + assert [item["event"]["type"] for item in pipeline.events] == [ + "RUN_STARTED", + "RUN_ERROR", + ] + run_error_event = pipeline.events[-1]["event"] + assert run_error_event["code"] == "RUN_CANCELED" diff --git a/backend/tests/unit/core/agentscope/runtime/test_runner.py b/backend/tests/unit/core/agentscope/runtime/test_runner.py index e985c9d..2528aed 100644 --- a/backend/tests/unit/core/agentscope/runtime/test_runner.py +++ b/backend/tests/unit/core/agentscope/runtime/test_runner.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import pytest from ag_ui.core import RunAgentInput @@ -265,3 +266,73 @@ async def test_execute_runs_router_then_worker( assert load_calls == [AgentType.ROUTER, AgentType.WORKER] assert result["router"]["normalized_task_input"]["user_text"] == "安排会议" assert result["worker"]["answer"] == "ok" + + +@pytest.mark.asyncio +async def test_execute_raises_cancelled_error_before_worker_when_cancel_requested( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakePipeline: + async def emit(self, *, session_id: str, event: dict[str, object]) -> str: + del session_id, event + return "1-0" + + class _FakeSessionCtx: + async def __aenter__(self) -> object: + return object() + + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: + del exc_type, exc, tb + + runner = AgentScopeRunner() + + async def _fake_load_stage_config(*, session: object, agent_type: AgentType): + del session + return runner_module.SystemAgentRuntimeConfig( + agent_type=agent_type, + model_code="demo", + api_base_url="https://example.com", + api_key="test", + llm_config=runner_module.SystemAgentLLMConfig(), + ) + + async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput: + del kwargs + return RouterAgentOutput( + normalized_task_input=NormalizedTaskInput( + user_text="安排会议", + context_summary="", + ), + key_entities=[], + constraints=[], + task_typing=TaskTyping(primary=TaskType.SCHEDULING), + execution_mode=ExecutionMode.TOOL_ASSISTED, + result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT), + ui=RouterUiDecision( + ui_mode=UiMode.NONE, + ui_decision_reason="单任务", + ), + ) + + async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite: + del kwargs + raise AssertionError("worker should not run after cancel") + + monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx()) + monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config) + monkeypatch.setattr(runner, "_build_toolkit", lambda **kwargs: object()) + monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step) + monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step) + + async def _cancel_checker() -> bool: + return True + + with pytest.raises(asyncio.CancelledError): + await runner.execute( + user_context=_user_context(), + context_messages=[], + pipeline=_FakePipeline(), + run_input=_run_input(), + runtime_config=_runtime_config(), + cancel_checker=_cancel_checker, + ) diff --git a/backend/tests/unit/core/agentscope/runtime/test_tasks.py b/backend/tests/unit/core/agentscope/runtime/test_tasks.py index 0f73baa..33d3022 100644 --- a/backend/tests/unit/core/agentscope/runtime/test_tasks.py +++ b/backend/tests/unit/core/agentscope/runtime/test_tasks.py @@ -166,6 +166,69 @@ async def test_run_agentscope_task_injects_runtime_config( assert captured_config["runtime_config"] is not None +@pytest.mark.asyncio +async def test_run_agentscope_task_injects_cancel_checker( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: dict[str, Any] = {} + + class _FakeRuntime: + def __init__(self, **kwargs: object) -> None: + del kwargs + + async def run(self, **kwargs: object) -> object: + checker = kwargs.get("cancel_checker") + assert callable(checker) + captured["cancelled"] = await checker() # type: ignore[misc] + return object() + + class _FakeRedis: + async def exists(self, key: str) -> int: + captured["cancel_key"] = key + return 1 + + async def delete(self, key: str) -> int: + captured["deleted_key"] = key + return 1 + + async def _fake_get_redis_client() -> object: + return _FakeRedis() + + async def _empty_context(**kwargs: object) -> list[dict[str, Any]]: + del kwargs + return [] + + monkeypatch.setattr(tasks_module, "AgentScopeRuntimeOrchestrator", _FakeRuntime) + monkeypatch.setattr( + tasks_module, + "get_or_init_redis_client", + _fake_get_redis_client, + ) + monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx()) + monkeypatch.setattr(tasks_module, "_build_user_context", _fake_user_context) + monkeypatch.setattr( + tasks_module, + "_build_recent_context_messages", + _empty_context, + ) + + await tasks_module.run_agentscope_task( + { + "command": "run", + "owner_id": str(uuid4()), + "run_input": _run_input_payload(), + "runtime_config": { + "enabled_tools": [], + "context": {"window_mode": "day", "window_count": 2}, + }, + } + ) + + assert captured["cancelled"] is True + assert isinstance(captured["cancel_key"], str) + assert captured["deleted_key"] == captured["cancel_key"] + + @pytest.mark.asyncio async def test_run_agentscope_task_requires_owner_id() -> None: with pytest.raises(ValueError, match="owner_id is required"): diff --git a/backend/tests/unit/v1/agent/test_service.py b/backend/tests/unit/v1/agent/test_service.py index f6132d3..d4ffdab 100644 --- a/backend/tests/unit/v1/agent/test_service.py +++ b/backend/tests/unit/v1/agent/test_service.py @@ -100,6 +100,7 @@ class _FakeRepository: class _FakeQueue: def __init__(self) -> None: self.commands: list[dict[str, object]] = [] + self.cancel_requests: list[dict[str, str]] = [] async def enqueue( self, *, command: dict[str, object], dedup_key: str | None @@ -108,6 +109,21 @@ class _FakeQueue: self.commands.append(command) return "task-1" + async def request_cancel( + self, + *, + thread_id: str, + run_id: str, + requested_by: str, + ) -> None: + self.cancel_requests.append( + { + "thread_id": thread_id, + "run_id": run_id, + "requested_by": requested_by, + } + ) + class _FakeStream: async def read( @@ -469,3 +485,56 @@ async def test_get_history_snapshot_filters_out_tool_messages() -> None: ) assert [message.role for message in snapshot.messages] == ["user", "assistant"] + + +@pytest.mark.asyncio +async def test_cancel_run_requests_queue_cancel_for_owner() -> None: + queue = _FakeQueue() + service = AgentService( + repository=_FakeRepository(), + queue=queue, + stream=_FakeStream(), + attachment_storage=_FakeAttachmentStorage(), + ) + + result = await service.cancel_run( + thread_id="00000000-0000-0000-0000-000000000001", + run_id="run-cancel-1", + current_user=_user(), + ) + + assert result.accepted is True + assert result.thread_id == "00000000-0000-0000-0000-000000000001" + assert result.run_id == "run-cancel-1" + assert queue.cancel_requests == [ + { + "thread_id": "00000000-0000-0000-0000-000000000001", + "run_id": "run-cancel-1", + "requested_by": "00000000-0000-0000-0000-000000000001", + } + ] + + +@pytest.mark.asyncio +async def test_cancel_run_rejects_non_owner() -> None: + queue = _FakeQueue() + service = AgentService( + repository=_FakeRepository(), + queue=queue, + stream=_FakeStream(), + attachment_storage=_FakeAttachmentStorage(), + ) + other_user = CurrentUser( + id=UUID("00000000-0000-0000-0000-000000000099"), + phone="+8613812340000", + ) + + with pytest.raises(HTTPException) as exc_info: + await service.cancel_run( + thread_id="00000000-0000-0000-0000-000000000001", + run_id="run-cancel-2", + current_user=other_user, + ) + + assert exc_info.value.status_code == 403 + assert queue.cancel_requests == [] diff --git a/docs/plans/2026-03-25-agent-run-cancel-failed.md b/docs/plans/2026-03-25-agent-run-cancel-failed.md index 9d3654e..86dcbcc 100644 --- a/docs/plans/2026-03-25-agent-run-cancel-failed.md +++ b/docs/plans/2026-03-25-agent-run-cancel-failed.md @@ -6,7 +6,7 @@ **Architecture:** 使用“协作取消 + 主任务中断”方案:API 层写入 Redis cancel 信号,runtime 在 worker 进程内并行 watcher 监听信号,命中后先调用 active agent 的 `interrupt()` 做优雅收尾,再 `cancel()` 当前 run 主任务做硬兜底。终态统一通过 `RUN_ERROR` 事件落库,复用现有 `FAILED` 会话语义,避免数据库枚举迁移。 -**Tech Stack:** FastAPI, TaskIQ, Redis, AgentScope, SQLAlchemy, Pytest, Ruff, BasedPyright +**Tech Stack:** FastAPI, TaskIQ, Redis, AgentScope, SQLAlchemy, Flutter, Pytest, Ruff, BasedPyright --- @@ -390,3 +390,81 @@ git commit -m "feat: support run cancellation with RUN_CANCELED failed semantics - 回滚 `router/service/dependencies` cancel 新接口 - 回滚 `runner/orchestrator/tasks` cancel 注入逻辑 - 保持原 `POST /runs` 与 SSE 流程不变 + +### Task 7: 前端接入 cancel API(发送后“停止生成”按钮走后端真实取消) + +**Files:** +- Modify: `apps/lib/features/chat/data/services/ag_ui_service.dart` +- Modify: `apps/lib/features/chat/presentation/bloc/chat_bloc.dart` +- Modify: `apps/lib/features/chat/data/models/ag_ui_event.dart` +- Modify: `apps/lib/features/home/ui/screens/home_screen_interactions.dart` +- Test: `apps/test/features/chat/data/services/ag_ui_service_test.dart` +- Test: `apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart` + +**Step 1: 在 AgUiService 维护当前运行态标识** + +在 `AgUiService` 增加字段: +- `_activeThreadIdForRun: String?` +- `_activeRunId: String?` + +并在 `sendMessage` 成功拿到 `/runs` 响应后设置这两个字段;在收到目标 run 的终态事件(`RUN_FINISHED` / `RUN_ERROR`)后清理。 + +**Step 2: 将 cancelCurrentRun 从“仅断 SSE”升级为“先调用后端 cancel,再本地收流”** + +`AgUiService.cancelCurrentRun()` 改为: +1. 若 `_activeThreadIdForRun` 或 `_activeRunId` 为空:退化为当前行为(仅关闭 SSE) +2. 否则先调用: + +```text +POST /api/v1/agent/runs/{threadId}/cancel?runId={runId} +``` + +3. 请求成功后再执行 `_cancelActiveSseSubscription()`(避免继续占用本地连接) +4. 不论后端是否即时生效,都清理本地 active run 字段,防止重复 cancel + +说明:这一步就是把“发送消息后的停止按钮”真正连到后端取消能力。 + +**Step 3: 错误语义细化(前端展示友好)** + +在 `chat_bloc.dart` 处理 `RunErrorEvent` 时: +- 如果 `errorEvent.code == 'RUN_CANCELED'`,错误文案不按失败提示展示(可置空或显示“已停止生成”) +- 仍执行 `_resetRunState` 与 tool 卡片收尾,保持 UI 一致性 + +**Step 4: 保持现有按钮入口,不改交互入口路径** + +`home_screen_interactions.dart` 里的 `_onStopGenerating -> _chatBloc.cancelCurrentRun()` 已经是正确入口,继续复用。 + +仅调整 Toast 文案策略: +- 请求已发出:`已请求停止` +- 收到 `RUN_ERROR(code=RUN_CANCELED)`:最终态 `已停止生成` + +**Step 5: 写 AgUiService 测试(先红)** + +在 `ag_ui_service_test.dart` 增加: +- `cancelCurrentRun` 会调用新端点 `/api/v1/agent/runs/{threadId}/cancel` +- query 参数包含 `runId` +- 调用后会关闭当前 SSE subscription + +**Step 6: 写 ChatBloc 测试(先红)** + +在 `chat_bloc_attachment_sync_test.dart` 增加: +- 收到 `RunErrorEvent(message: 'run canceled by user', code: 'RUN_CANCELED')` 后: + - `isWaitingFirstToken/isStreaming/isCancelling` 全部归零 + - 不显示普通失败文案(或显示取消态文案,按你们最终文案策略断言) + +**Step 7: 运行 Flutter 测试** + +Run: + +```bash +flutter test apps/test/features/chat/data/services/ag_ui_service_test.dart apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart +``` + +Expected: PASS + +**Step 8: 前端接入提交** + +```bash +git add apps/lib/features/chat/data/services/ag_ui_service.dart apps/lib/features/chat/presentation/bloc/chat_bloc.dart apps/lib/features/chat/data/models/ag_ui_event.dart apps/lib/features/home/ui/screens/home_screen_interactions.dart apps/test/features/chat/data/services/ag_ui_service_test.dart apps/test/features/chat/presentation/chat_bloc_attachment_sync_test.dart +git commit -m "feat: wire stop-generating button to backend run cancel API" +``` diff --git a/docs/protocols/agent/api-endpoints.md b/docs/protocols/agent/api-endpoints.md index 65b1783..5d918c5 100644 --- a/docs/protocols/agent/api-endpoints.md +++ b/docs/protocols/agent/api-endpoints.md @@ -11,6 +11,7 @@ Base URL: `/api/v1/agent` | 方法 | 路径 | 说明 | |---|---|---| | POST | `/runs` | 创建一次 agent run(异步入队) | +| POST | `/runs/{thread_id}/cancel` | 请求取消指定 run | | GET | `/runs/{thread_id}/events` | 订阅 SSE 事件流 | | GET | `/history` | 获取历史快照(按天分页) | | POST | `/attachments` | 上传用户图片附件 | @@ -95,7 +96,43 @@ Base URL: `/api/v1/agent` --- -## 3) GET `/history` +## 3) POST `/runs/{thread_id}/cancel` + +请求取消指定 run。该接口返回 `202` 仅表示取消请求已被后端接收,不保证运行已在同一时刻停止。 + +### Path + +| 参数 | 类型 | 说明 | +|---|---|---| +| `thread_id` | string | 会话 ID | + +### Query + +| 参数 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `runId` | string | 是 | 需要取消的 run ID | + +### Response + +`202 Accepted` + +```ts +{ + threadId: string; + runId: string; + accepted: true; +} +``` + +### 错误码 + +- `401` 未认证 +- `403` 非会话所有者 +- `422` 参数非法 + +--- + +## 4) GET `/history` 返回历史快照(`HistorySnapshotResponse`),不是 SSE 包装事件。 @@ -146,7 +183,7 @@ tool 消息在存储层用于运行时上下文续接,不在 `/history` 对外 --- -## 4) POST `/attachments` +## 5) POST `/attachments` 上传图片附件,返回可直接用于 `RunAgentInput.messages[].content[].url` 的签名链接。 @@ -182,7 +219,7 @@ tool 消息在存储层用于运行时上下文续接,不在 `/history` 对外 --- -## 5) GET `/attachments/signed-url` +## 6) GET `/attachments/signed-url` 对已有存储对象重新签名。 @@ -205,7 +242,7 @@ tool 消息在存储层用于运行时上下文续接,不在 `/history` 对外 --- -## 6) POST `/transcribe` +## 7) POST `/transcribe` WAV 音频转写。 diff --git a/docs/protocols/agent/sse-events.md b/docs/protocols/agent/sse-events.md index 5b38dcf..2705f04 100644 --- a/docs/protocols/agent/sse-events.md +++ b/docs/protocols/agent/sse-events.md @@ -75,6 +75,20 @@ data: } ``` +取消语义(当前实现): + +```json +{ + "type": "RUN_ERROR", + "threadId": "...", + "runId": "...", + "message": "run canceled by user", + "code": "RUN_CANCELED" +} +``` + +说明:`RUN_CANCELED` 表示用户主动中断,本阶段后端仍复用会话 `failed` 状态以保持兼容。 + ### 3.2 阶段事件 #### `STEP_STARTED`