feat: 支持 agent 运行取消功能
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core.types import RunAgentInput
|
||||
@@ -64,6 +66,8 @@ class AgentScopeRunner:
|
||||
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
|
||||
patch_agentscope_json_repair_compat()
|
||||
self._litellm_service: LiteLLMService = litellm_service or LiteLLMService()
|
||||
self._active_agent: JsonReActAgent | None = None
|
||||
self._active_agent_lock = asyncio.Lock()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
@@ -75,51 +79,99 @@ class AgentScopeRunner:
|
||||
runtime_config: RuntimeConfig,
|
||||
user_memory: UserMemoryContent | None = None,
|
||||
work_memory: WorkProfileContent | None = None,
|
||||
cancel_checker: Callable[[], Awaitable[bool]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
runtime_mode = self._resolve_runtime_mode(run_input=run_input)
|
||||
stop_cancel_watch = asyncio.Event()
|
||||
cancel_watch_task: asyncio.Task[None] | None = None
|
||||
run_task = asyncio.current_task()
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tools=runtime_config.enabled_tools,
|
||||
if cancel_checker is not None and run_task is not None:
|
||||
cancel_watch_task = asyncio.create_task(
|
||||
self._watch_cancel_signal(
|
||||
cancel_checker=cancel_checker,
|
||||
stop_signal=stop_cancel_watch,
|
||||
run_task=run_task,
|
||||
)
|
||||
)
|
||||
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
user_memory=user_memory,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tools=runtime_config.enabled_tools,
|
||||
)
|
||||
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
user_memory=user_memory,
|
||||
)
|
||||
if cancel_checker is not None and await cancel_checker():
|
||||
raise asyncio.CancelledError("run canceled by user")
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
finally:
|
||||
stop_cancel_watch.set()
|
||||
if cancel_watch_task is not None:
|
||||
cancel_watch_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await cancel_watch_task
|
||||
|
||||
async def _watch_cancel_signal(
|
||||
self,
|
||||
*,
|
||||
cancel_checker: Callable[[], Awaitable[bool]],
|
||||
stop_signal: asyncio.Event,
|
||||
run_task: asyncio.Task[object],
|
||||
) -> None:
|
||||
while not stop_signal.is_set():
|
||||
should_cancel = False
|
||||
try:
|
||||
should_cancel = await cancel_checker()
|
||||
except Exception:
|
||||
should_cancel = False
|
||||
|
||||
if should_cancel:
|
||||
async with self._active_agent_lock:
|
||||
active_agent = self._active_agent
|
||||
if active_agent is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await active_agent.interrupt()
|
||||
if not run_task.done():
|
||||
run_task.cancel("run canceled by user")
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
def _build_toolkit(
|
||||
self,
|
||||
@@ -373,9 +425,16 @@ class AgentScopeRunner:
|
||||
model=tracking_model,
|
||||
emitter=emitter,
|
||||
)
|
||||
response_msg = await agent.reply_json(
|
||||
input_messages, output_model=worker_output_model
|
||||
)
|
||||
async with self._active_agent_lock:
|
||||
self._active_agent = agent
|
||||
try:
|
||||
response_msg = await agent.reply_json(
|
||||
input_messages, output_model=worker_output_model
|
||||
)
|
||||
finally:
|
||||
async with self._active_agent_lock:
|
||||
if self._active_agent is agent:
|
||||
self._active_agent = None
|
||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||
response_metadata = self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
|
||||
Reference in New Issue
Block a user