feat: 支持 agent 运行取消功能

This commit is contained in:
qzl
2026-03-25 18:33:25 +08:00
parent 599c597e69
commit 96fc4a1e77
21 changed files with 778 additions and 85 deletions
@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from typing import Any
import pytest
@@ -72,3 +73,31 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
assert result["worker"]["answer"] == "done"
event_types = [item["event"]["type"] for item in pipeline.events]
assert event_types == ["RUN_STARTED", "RUN_FINISHED"]
@pytest.mark.asyncio
async def test_orchestrator_emits_run_canceled_error_on_cancelled_error() -> None:
class _CanceledRunner:
async def execute(self, **kwargs: object) -> dict[str, Any]:
del kwargs
raise asyncio.CancelledError("run canceled by user")
pipeline = _FakePipeline()
orchestrator = AgentScopeRuntimeOrchestrator(
pipeline=pipeline, runner=_CanceledRunner()
)
with pytest.raises(asyncio.CancelledError):
await orchestrator.run(
run_input=_run_input(),
context_messages=[],
user_context=_user_context(),
runtime_config=_runtime_config(),
)
assert [item["event"]["type"] for item in pipeline.events] == [
"RUN_STARTED",
"RUN_ERROR",
]
run_error_event = pipeline.events[-1]["event"]
assert run_error_event["code"] == "RUN_CANCELED"
@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import pytest
from ag_ui.core import RunAgentInput
@@ -265,3 +266,73 @@ async def test_execute_runs_router_then_worker(
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
assert result["worker"]["answer"] == "ok"
@pytest.mark.asyncio
async def test_execute_raises_cancelled_error_before_worker_when_cancel_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
del session_id, event
return "1-0"
class _FakeSessionCtx:
async def __aenter__(self) -> object:
return object()
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
del exc_type, exc, tb
runner = AgentScopeRunner()
async def _fake_load_stage_config(*, session: object, agent_type: AgentType):
del session
return runner_module.SystemAgentRuntimeConfig(
agent_type=agent_type,
model_code="demo",
api_base_url="https://example.com",
api_key="test",
llm_config=runner_module.SystemAgentLLMConfig(),
)
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
del kwargs
return RouterAgentOutput(
normalized_task_input=NormalizedTaskInput(
user_text="安排会议",
context_summary="",
),
key_entities=[],
constraints=[],
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
execution_mode=ExecutionMode.TOOL_ASSISTED,
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
ui=RouterUiDecision(
ui_mode=UiMode.NONE,
ui_decision_reason="单任务",
),
)
async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite:
del kwargs
raise AssertionError("worker should not run after cancel")
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
monkeypatch.setattr(runner, "_build_toolkit", lambda **kwargs: object())
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
async def _cancel_checker() -> bool:
return True
with pytest.raises(asyncio.CancelledError):
await runner.execute(
user_context=_user_context(),
context_messages=[],
pipeline=_FakePipeline(),
run_input=_run_input(),
runtime_config=_runtime_config(),
cancel_checker=_cancel_checker,
)
@@ -166,6 +166,69 @@ async def test_run_agentscope_task_injects_runtime_config(
assert captured_config["runtime_config"] is not None
@pytest.mark.asyncio
async def test_run_agentscope_task_injects_cancel_checker(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, Any] = {}
class _FakeRuntime:
def __init__(self, **kwargs: object) -> None:
del kwargs
async def run(self, **kwargs: object) -> object:
checker = kwargs.get("cancel_checker")
assert callable(checker)
captured["cancelled"] = await checker() # type: ignore[misc]
return object()
class _FakeRedis:
async def exists(self, key: str) -> int:
captured["cancel_key"] = key
return 1
async def delete(self, key: str) -> int:
captured["deleted_key"] = key
return 1
async def _fake_get_redis_client() -> object:
return _FakeRedis()
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return []
monkeypatch.setattr(tasks_module, "AgentScopeRuntimeOrchestrator", _FakeRuntime)
monkeypatch.setattr(
tasks_module,
"get_or_init_redis_client",
_fake_get_redis_client,
)
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
monkeypatch.setattr(tasks_module, "_build_user_context", _fake_user_context)
monkeypatch.setattr(
tasks_module,
"_build_recent_context_messages",
_empty_context,
)
await tasks_module.run_agentscope_task(
{
"command": "run",
"owner_id": str(uuid4()),
"run_input": _run_input_payload(),
"runtime_config": {
"enabled_tools": [],
"context": {"window_mode": "day", "window_count": 2},
},
}
)
assert captured["cancelled"] is True
assert isinstance(captured["cancel_key"], str)
assert captured["deleted_key"] == captured["cancel_key"]
@pytest.mark.asyncio
async def test_run_agentscope_task_requires_owner_id() -> None:
with pytest.raises(ValueError, match="owner_id is required"):