feat: 支持 agent 运行取消功能
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Callable, Protocol
|
||||
|
||||
from ag_ui.core.types import RunAgentInput
|
||||
from agentscope.message import Msg
|
||||
@@ -28,6 +29,7 @@ class RunnerLike(Protocol):
|
||||
runtime_config: RuntimeConfig,
|
||||
user_memory: UserMemoryContent | None,
|
||||
work_memory: WorkProfileContent | None,
|
||||
cancel_checker: Callable[[], Awaitable[bool]] | None = None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@@ -53,6 +55,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
runtime_config: RuntimeConfig,
|
||||
user_memory: UserMemoryContent | None = None,
|
||||
work_memory: WorkProfileContent | None = None,
|
||||
cancel_checker: Callable[[], Awaitable[bool]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
@@ -74,6 +77,7 @@ class AgentScopeRuntimeOrchestrator:
|
||||
runtime_config=runtime_config,
|
||||
user_memory=user_memory,
|
||||
work_memory=work_memory,
|
||||
cancel_checker=cancel_checker,
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
@@ -85,6 +89,23 @@ class AgentScopeRuntimeOrchestrator:
|
||||
},
|
||||
)
|
||||
return result if isinstance(result, dict) else {}
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
"agentscope runtime execution canceled",
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
"type": "RUN_ERROR",
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"message": "run canceled by user",
|
||||
"code": "RUN_CANCELED",
|
||||
},
|
||||
)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"agentscope runtime execution failed",
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core.types import RunAgentInput
|
||||
@@ -64,6 +66,8 @@ class AgentScopeRunner:
|
||||
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
    # Invoked before any other setup.
    # NOTE(review): presumably installs a JSON-repair shim for agentscope — confirm.
    patch_agentscope_json_repair_compat()

    # Use the injected service when one is supplied, otherwise build a default one.
    service = litellm_service if litellm_service else LiteLLMService()
    self._litellm_service: LiteLLMService = service

    # Agent whose reply is currently in flight (None when idle); every access
    # in this class goes through ``_active_agent_lock``.
    self._active_agent: JsonReActAgent | None = None
    self._active_agent_lock = asyncio.Lock()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
@@ -75,51 +79,99 @@ class AgentScopeRunner:
|
||||
runtime_config: RuntimeConfig,
|
||||
user_memory: UserMemoryContent | None = None,
|
||||
work_memory: WorkProfileContent | None = None,
|
||||
cancel_checker: Callable[[], Awaitable[bool]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
runtime_mode = self._resolve_runtime_mode(run_input=run_input)
|
||||
stop_cancel_watch = asyncio.Event()
|
||||
cancel_watch_task: asyncio.Task[None] | None = None
|
||||
run_task = asyncio.current_task()
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tools=runtime_config.enabled_tools,
|
||||
if cancel_checker is not None and run_task is not None:
|
||||
cancel_watch_task = asyncio.create_task(
|
||||
self._watch_cancel_signal(
|
||||
cancel_checker=cancel_checker,
|
||||
stop_signal=stop_cancel_watch,
|
||||
run_task=run_task,
|
||||
)
|
||||
)
|
||||
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
user_memory=user_memory,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
router_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.ROUTER,
|
||||
)
|
||||
worker_config = await self._load_stage_config(
|
||||
session=session,
|
||||
agent_type=AgentType.WORKER,
|
||||
)
|
||||
worker_toolkit = self._build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
enabled_tools=runtime_config.enabled_tools,
|
||||
)
|
||||
|
||||
router_output = await self._execute_router_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
user_memory=user_memory,
|
||||
)
|
||||
if cancel_checker is not None and await cancel_checker():
|
||||
raise asyncio.CancelledError("run canceled by user")
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
user_context=user_context,
|
||||
router_output=router_output,
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
"worker": worker_output.model_dump(mode="json", exclude_none=True),
|
||||
}
|
||||
finally:
|
||||
stop_cancel_watch.set()
|
||||
if cancel_watch_task is not None:
|
||||
cancel_watch_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await cancel_watch_task
|
||||
|
||||
async def _watch_cancel_signal(
|
||||
self,
|
||||
*,
|
||||
cancel_checker: Callable[[], Awaitable[bool]],
|
||||
stop_signal: asyncio.Event,
|
||||
run_task: asyncio.Task[object],
|
||||
) -> None:
|
||||
while not stop_signal.is_set():
|
||||
should_cancel = False
|
||||
try:
|
||||
should_cancel = await cancel_checker()
|
||||
except Exception:
|
||||
should_cancel = False
|
||||
|
||||
if should_cancel:
|
||||
async with self._active_agent_lock:
|
||||
active_agent = self._active_agent
|
||||
if active_agent is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await active_agent.interrupt()
|
||||
if not run_task.done():
|
||||
run_task.cancel("run canceled by user")
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
def _build_toolkit(
|
||||
self,
|
||||
@@ -373,9 +425,16 @@ class AgentScopeRunner:
|
||||
model=tracking_model,
|
||||
emitter=emitter,
|
||||
)
|
||||
response_msg = await agent.reply_json(
|
||||
input_messages, output_model=worker_output_model
|
||||
)
|
||||
async with self._active_agent_lock:
|
||||
self._active_agent = agent
|
||||
try:
|
||||
response_msg = await agent.reply_json(
|
||||
input_messages, output_model=worker_output_model
|
||||
)
|
||||
finally:
|
||||
async with self._active_agent_lock:
|
||||
if self._active_agent is agent:
|
||||
self._active_agent = None
|
||||
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
|
||||
response_metadata = self._litellm_service.build_usage_metadata(
|
||||
model=stage_config.model_code,
|
||||
|
||||
@@ -257,6 +257,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
owner_id = UUID(raw_owner_id)
|
||||
cancel_key = f"agent:cancel:{thread_id}:{run_id}"
|
||||
|
||||
if command_type != "run":
|
||||
raise ValueError("invalid command type")
|
||||
@@ -278,6 +279,15 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
work_memory: WorkProfileContent | None = memories_result.get("work_memory")
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
|
||||
async def _cancel_checker() -> bool:
    """Return True when ``cancel_key`` currently exists in Redis."""
    probe = getattr(redis_client, "exists", None)
    if callable(probe):
        # ``cast(Any, ...)`` only quiets the type checker; the call is awaited as-is.
        return bool(await cast(Any, probe)(cancel_key))
    # Client exposes no usable ``exists``: report "not canceled".
    return False
|
||||
|
||||
bus = RedisStreamBus(
|
||||
client=redis_client,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
@@ -302,14 +312,21 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_config=runtime_config.context,
|
||||
)
|
||||
|
||||
await runtime.run(
|
||||
run_input=run_input,
|
||||
context_messages=context_messages,
|
||||
user_context=user_context,
|
||||
runtime_config=runtime_config,
|
||||
user_memory=user_memory,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
try:
|
||||
await runtime.run(
|
||||
run_input=run_input,
|
||||
context_messages=context_messages,
|
||||
user_context=user_context,
|
||||
runtime_config=runtime_config,
|
||||
user_memory=user_memory,
|
||||
work_memory=work_memory,
|
||||
cancel_checker=_cancel_checker,
|
||||
)
|
||||
finally:
|
||||
delete_fn = getattr(redis_client, "delete", None)
|
||||
if callable(delete_fn):
|
||||
delete_call = cast(Any, delete_fn)(cancel_key)
|
||||
await delete_call
|
||||
logger.info(
|
||||
"agentscope runtime task completed",
|
||||
command_type=command_type,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
@@ -22,6 +24,7 @@ DEDUP_WAIT_RETRIES = 20
|
||||
DEDUP_WAIT_SECONDS = 0.05
|
||||
DEDUP_LOCK_SECONDS = 300
|
||||
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
||||
RUN_CANCEL_SIGNAL_TTL_SECONDS = 1800
|
||||
|
||||
|
||||
def _event_stream_block_ms() -> int:
|
||||
@@ -87,6 +90,29 @@ class TaskiqQueueClient:
|
||||
await redis_client.delete(redis_key)
|
||||
raise
|
||||
|
||||
async def request_cancel(
    self,
    *,
    thread_id: str,
    run_id: str,
    requested_by: str,
) -> None:
    """Write the cancel marker for one run into Redis.

    The marker is stored under ``agent:cancel:{thread_id}:{run_id}`` as compact
    JSON carrying the requester and a UTC ISO-8601 timestamp, and expires after
    ``RUN_CANCEL_SIGNAL_TTL_SECONDS``.
    """
    client = await self._get_redis()
    signal_body = {
        "requested_by": requested_by,
        "requested_at": datetime.now(timezone.utc).isoformat(),
    }
    encoded = json.dumps(signal_body, ensure_ascii=True, separators=(",", ":"))
    await client.set(
        f"agent:cancel:{thread_id}:{run_id}",
        encoded,
        ex=RUN_CANCEL_SIGNAL_TTL_SECONDS,
    )
|
||||
|
||||
|
||||
class RedisEventStream:
|
||||
def __init__(self) -> None:
|
||||
|
||||
@@ -37,6 +37,7 @@ from v1.agent.schemas import (
|
||||
AttachmentReference,
|
||||
AttachmentSignedUrlResponse,
|
||||
AttachmentUploadResponse,
|
||||
CancelRunResponse,
|
||||
HistorySnapshotResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
@@ -147,6 +148,34 @@ async def enqueue_run(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/runs/{thread_id}/cancel",
    response_model=CancelRunResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def cancel_run(
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    run_id: str = Query(
        alias="runId",
        min_length=1,
        max_length=128,
        pattern=r"^[A-Za-z0-9_-]+$",
    ),
) -> CancelRunResponse:
    # Delegate the cancel request to the service layer, then map its result
    # onto the camelCase response model. Responds with 202.
    outcome = await service.cancel_run(
        thread_id=thread_id,
        run_id=run_id,
        current_user=current_user,
    )
    response = CancelRunResponse(
        threadId=outcome.thread_id,
        runId=outcome.run_id,
        accepted=outcome.accepted,
    )
    return response
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/events")
|
||||
async def stream_events(
|
||||
request: Request,
|
||||
|
||||
@@ -49,6 +49,14 @@ class QueueClientLike(Protocol):
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str: ...
|
||||
|
||||
async def request_cancel(
    self,
    *,
    thread_id: str,
    run_id: str,
    requested_by: str,
) -> None:
    """Ask for the run identified by ``thread_id`` / ``run_id`` to be canceled.

    Protocol stub: implementations receive the requester as ``requested_by``
    and return nothing.
    """
    ...
|
||||
|
||||
|
||||
class EventStreamLike(Protocol):
|
||||
async def read(
|
||||
@@ -90,6 +98,13 @@ class TaskAccepted:
|
||||
created: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CancelRequested:
    """Immutable result of a cancel request for a single run."""

    # Thread (session) the run belongs to.
    thread_id: str
    # Run the cancel request targets.
    run_id: str
    # Whether the cancel request was accepted.
    accepted: bool
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
@@ -99,6 +114,14 @@ class TaskAcceptedResponse(BaseModel):
|
||||
created: bool
|
||||
|
||||
|
||||
class CancelRunResponse(BaseModel):
    # Response body of the cancel endpoint. Fields may be populated by Python
    # name or by alias and are serialized under their camelCase aliases.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)

    # Exposed as "threadId".
    thread_id: str = Field(alias="threadId")
    # Exposed as "runId".
    run_id: str = Field(alias="runId")
    # Whether the cancel request was accepted.
    accepted: bool
|
||||
|
||||
|
||||
class AsrTranscribeResponse(BaseModel):
|
||||
transcript: str = Field(description="Transcribed text from audio")
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ from schemas.domain.chat_message import (
|
||||
from v1.agent.schemas import (
|
||||
AgentRepositoryLike,
|
||||
AttachmentStorageLike,
|
||||
CancelRequested,
|
||||
EventStreamLike,
|
||||
HistorySnapshotResponse,
|
||||
QueueClientLike,
|
||||
@@ -157,6 +158,26 @@ class AgentService:
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def cancel_run(
    self,
    *,
    thread_id: str,
    run_id: str,
    current_user: CurrentUser,
) -> CancelRequested:
    """Record a cancel request for a run after an ownership check.

    Looks up the session owner for ``thread_id``, passes it to
    ``ensure_session_owner`` together with ``current_user``, then forwards the
    request to the queue client with the user's id as the requester.

    Returns:
        A ``CancelRequested`` echoing the identifiers with ``accepted=True``.
    """
    session_owner = await self._repository.get_session_owner(session_id=thread_id)
    ensure_session_owner(owner_id=session_owner, current_user=current_user)

    requester = str(current_user.id)
    await self._queue.request_cancel(
        thread_id=thread_id,
        run_id=run_id,
        requested_by=requester,
    )

    # Always reported as accepted once the signal has been handed to the queue.
    return CancelRequested(thread_id=thread_id, run_id=run_id, accepted=True)
|
||||
|
||||
async def _append_context_cache_user_message(
|
||||
self,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user