feat: 支持 agent 运行取消功能
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@@ -72,3 +73,31 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
|
||||
assert result["worker"]["answer"] == "done"
|
||||
event_types = [item["event"]["type"] for item in pipeline.events]
|
||||
assert event_types == ["RUN_STARTED", "RUN_FINISHED"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_emits_run_canceled_error_on_cancelled_error() -> None:
|
||||
class _CanceledRunner:
|
||||
async def execute(self, **kwargs: object) -> dict[str, Any]:
|
||||
del kwargs
|
||||
raise asyncio.CancelledError("run canceled by user")
|
||||
|
||||
pipeline = _FakePipeline()
|
||||
orchestrator = AgentScopeRuntimeOrchestrator(
|
||||
pipeline=pipeline, runner=_CanceledRunner()
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await orchestrator.run(
|
||||
run_input=_run_input(),
|
||||
context_messages=[],
|
||||
user_context=_user_context(),
|
||||
runtime_config=_runtime_config(),
|
||||
)
|
||||
|
||||
assert [item["event"]["type"] for item in pipeline.events] == [
|
||||
"RUN_STARTED",
|
||||
"RUN_ERROR",
|
||||
]
|
||||
run_error_event = pipeline.events[-1]["event"]
|
||||
assert run_error_event["code"] == "RUN_CANCELED"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
@@ -265,3 +266,73 @@ async def test_execute_runs_router_then_worker(
|
||||
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
|
||||
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
|
||||
assert result["worker"]["answer"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_raises_cancelled_error_before_worker_when_cancel_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
async def _fake_load_stage_config(*, session: object, agent_type: AgentType):
|
||||
del session
|
||||
return runner_module.SystemAgentRuntimeConfig(
|
||||
agent_type=agent_type,
|
||||
model_code="demo",
|
||||
api_base_url="https://example.com",
|
||||
api_key="test",
|
||||
llm_config=runner_module.SystemAgentLLMConfig(),
|
||||
)
|
||||
|
||||
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
|
||||
del kwargs
|
||||
return RouterAgentOutput(
|
||||
normalized_task_input=NormalizedTaskInput(
|
||||
user_text="安排会议",
|
||||
context_summary="",
|
||||
),
|
||||
key_entities=[],
|
||||
constraints=[],
|
||||
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||
execution_mode=ExecutionMode.TOOL_ASSISTED,
|
||||
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
|
||||
ui=RouterUiDecision(
|
||||
ui_mode=UiMode.NONE,
|
||||
ui_decision_reason="单任务",
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite:
|
||||
del kwargs
|
||||
raise AssertionError("worker should not run after cancel")
|
||||
|
||||
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
|
||||
monkeypatch.setattr(runner, "_build_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
|
||||
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
|
||||
|
||||
async def _cancel_checker() -> bool:
|
||||
return True
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
runtime_config=_runtime_config(),
|
||||
cancel_checker=_cancel_checker,
|
||||
)
|
||||
|
||||
@@ -166,6 +166,69 @@ async def test_run_agentscope_task_injects_runtime_config(
|
||||
assert captured_config["runtime_config"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_injects_cancel_checker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
del kwargs
|
||||
|
||||
async def run(self, **kwargs: object) -> object:
|
||||
checker = kwargs.get("cancel_checker")
|
||||
assert callable(checker)
|
||||
captured["cancelled"] = await checker() # type: ignore[misc]
|
||||
return object()
|
||||
|
||||
class _FakeRedis:
|
||||
async def exists(self, key: str) -> int:
|
||||
captured["cancel_key"] = key
|
||||
return 1
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
captured["deleted_key"] = key
|
||||
return 1
|
||||
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return _FakeRedis()
|
||||
|
||||
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentScopeRuntimeOrchestrator", _FakeRuntime)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"get_or_init_redis_client",
|
||||
_fake_get_redis_client,
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(tasks_module, "_build_user_context", _fake_user_context)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"_build_recent_context_messages",
|
||||
_empty_context,
|
||||
)
|
||||
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": _run_input_payload(),
|
||||
"runtime_config": {
|
||||
"enabled_tools": [],
|
||||
"context": {"window_mode": "day", "window_count": 2},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert captured["cancelled"] is True
|
||||
assert isinstance(captured["cancel_key"], str)
|
||||
assert captured["deleted_key"] == captured["cancel_key"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_requires_owner_id() -> None:
|
||||
with pytest.raises(ValueError, match="owner_id is required"):
|
||||
|
||||
Reference in New Issue
Block a user