feat: 支持 agent 运行取消功能

This commit is contained in:
qzl
2026-03-25 18:33:25 +08:00
parent 599c597e69
commit 96fc4a1e77
21 changed files with 778 additions and 85 deletions
+101 -42
View File
@@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
import contextlib
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from uuid import UUID
from ag_ui.core.types import RunAgentInput
@@ -64,6 +66,8 @@ class AgentScopeRunner:
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
patch_agentscope_json_repair_compat()
self._litellm_service: LiteLLMService = litellm_service or LiteLLMService()
self._active_agent: JsonReActAgent | None = None
self._active_agent_lock = asyncio.Lock()
async def execute(
self,
@@ -75,51 +79,99 @@ class AgentScopeRunner:
runtime_config: RuntimeConfig,
user_memory: UserMemoryContent | None = None,
work_memory: WorkProfileContent | None = None,
cancel_checker: Callable[[], Awaitable[bool]] | None = None,
) -> dict[str, Any]:
owner_id = UUID(user_context.id)
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
runtime_mode = self._resolve_runtime_mode(run_input=run_input)
stop_cancel_watch = asyncio.Event()
cancel_watch_task: asyncio.Task[None] | None = None
run_task = asyncio.current_task()
async with AsyncSessionLocal() as session:
router_config = await self._load_stage_config(
session=session,
agent_type=AgentType.ROUTER,
)
worker_config = await self._load_stage_config(
session=session,
agent_type=AgentType.WORKER,
)
worker_toolkit = self._build_toolkit(
session=session,
owner_id=owner_id,
enabled_tools=runtime_config.enabled_tools,
if cancel_checker is not None and run_task is not None:
cancel_watch_task = asyncio.create_task(
self._watch_cancel_signal(
cancel_checker=cancel_checker,
stop_signal=stop_cancel_watch,
run_task=run_task,
)
)
router_output = await self._execute_router_step(
pipeline=pipeline,
run_input=run_input,
user_context=user_context,
context_messages=context_messages,
stage_config=router_config,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
user_memory=user_memory,
)
worker_output = await self._execute_worker_step(
pipeline=pipeline,
run_input=run_input,
user_context=user_context,
router_output=router_output,
toolkit=worker_toolkit,
stage_config=worker_config,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
work_memory=work_memory,
)
return {
"router": router_output.model_dump(mode="json", exclude_none=True),
"worker": worker_output.model_dump(mode="json", exclude_none=True),
}
try:
async with AsyncSessionLocal() as session:
router_config = await self._load_stage_config(
session=session,
agent_type=AgentType.ROUTER,
)
worker_config = await self._load_stage_config(
session=session,
agent_type=AgentType.WORKER,
)
worker_toolkit = self._build_toolkit(
session=session,
owner_id=owner_id,
enabled_tools=runtime_config.enabled_tools,
)
router_output = await self._execute_router_step(
pipeline=pipeline,
run_input=run_input,
user_context=user_context,
context_messages=context_messages,
stage_config=router_config,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
user_memory=user_memory,
)
if cancel_checker is not None and await cancel_checker():
raise asyncio.CancelledError("run canceled by user")
worker_output = await self._execute_worker_step(
pipeline=pipeline,
run_input=run_input,
user_context=user_context,
router_output=router_output,
toolkit=worker_toolkit,
stage_config=worker_config,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
work_memory=work_memory,
)
return {
"router": router_output.model_dump(mode="json", exclude_none=True),
"worker": worker_output.model_dump(mode="json", exclude_none=True),
}
finally:
stop_cancel_watch.set()
if cancel_watch_task is not None:
cancel_watch_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await cancel_watch_task
async def _watch_cancel_signal(
self,
*,
cancel_checker: Callable[[], Awaitable[bool]],
stop_signal: asyncio.Event,
run_task: asyncio.Task[object],
) -> None:
while not stop_signal.is_set():
should_cancel = False
try:
should_cancel = await cancel_checker()
except Exception:
should_cancel = False
if should_cancel:
async with self._active_agent_lock:
active_agent = self._active_agent
if active_agent is not None:
with contextlib.suppress(Exception):
await active_agent.interrupt()
if not run_task.done():
run_task.cancel("run canceled by user")
return
await asyncio.sleep(0.2)
def _build_toolkit(
self,
@@ -373,9 +425,16 @@ class AgentScopeRunner:
model=tracking_model,
emitter=emitter,
)
response_msg = await agent.reply_json(
input_messages, output_model=worker_output_model
)
async with self._active_agent_lock:
self._active_agent = agent
try:
response_msg = await agent.reply_json(
input_messages, output_model=worker_output_model
)
finally:
async with self._active_agent_lock:
if self._active_agent is agent:
self._active_agent = None
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
response_metadata = self._litellm_service.build_usage_metadata(
model=stage_config.model_code,