feat: 支持 agent 运行取消功能

This commit is contained in:
qzl
2026-03-25 18:33:25 +08:00
parent 599c597e69
commit 96fc4a1e77
21 changed files with 778 additions and 85 deletions
@@ -219,21 +219,24 @@ async def test_store_persists_router_step_output_for_cost_tracking(
}
)
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
assert append_kwargs["seq"] == 11
assert append_kwargs["content"] == ""
assert append_kwargs["model_code"] == "doubao-seed-1-6-250615"
assert append_kwargs["input_tokens"] == 12
assert append_kwargs["output_tokens"] == 8
assert append_kwargs["latency_ms"] == 320
assert append_kwargs["cost"] == Decimal("0.01")
assert append_kwargs["visibility_mask"] == 0
metadata = cast(dict[str, Any], append_kwargs["metadata"])
assert sorted(metadata.keys()) == ["agent_type", "router_agent_output", "run_id"]
assert metadata["agent_type"] == "router"
assert metadata["router_agent_output"]["execution_mode"] == "tool_assisted"
@pytest.mark.asyncio
async def test_store_marks_session_failed_for_run_canceled_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, object] = {}
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
_patch_repositories(monkeypatch, captured, fake_chat_session)
assert captured["message_delta"] == 1
assert captured["token_delta"] == 20
assert captured["cost_delta"] == Decimal("0.01")
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
await store.persist(
{
"type": "RUN_ERROR",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-cancel-1",
"message": "run canceled by user",
"code": "RUN_CANCELED",
}
)
assert captured["status"] == _SessionStatus.FAILED
@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from typing import Any
import pytest
@@ -72,3 +73,31 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
assert result["worker"]["answer"] == "done"
event_types = [item["event"]["type"] for item in pipeline.events]
assert event_types == ["RUN_STARTED", "RUN_FINISHED"]
@pytest.mark.asyncio
async def test_orchestrator_emits_run_canceled_error_on_cancelled_error() -> None:
class _CanceledRunner:
async def execute(self, **kwargs: object) -> dict[str, Any]:
del kwargs
raise asyncio.CancelledError("run canceled by user")
pipeline = _FakePipeline()
orchestrator = AgentScopeRuntimeOrchestrator(
pipeline=pipeline, runner=_CanceledRunner()
)
with pytest.raises(asyncio.CancelledError):
await orchestrator.run(
run_input=_run_input(),
context_messages=[],
user_context=_user_context(),
runtime_config=_runtime_config(),
)
assert [item["event"]["type"] for item in pipeline.events] == [
"RUN_STARTED",
"RUN_ERROR",
]
run_error_event = pipeline.events[-1]["event"]
assert run_error_event["code"] == "RUN_CANCELED"
@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import pytest
from ag_ui.core import RunAgentInput
@@ -265,3 +266,73 @@ async def test_execute_runs_router_then_worker(
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
assert result["worker"]["answer"] == "ok"
@pytest.mark.asyncio
async def test_execute_raises_cancelled_error_before_worker_when_cancel_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
del session_id, event
return "1-0"
class _FakeSessionCtx:
async def __aenter__(self) -> object:
return object()
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
del exc_type, exc, tb
runner = AgentScopeRunner()
async def _fake_load_stage_config(*, session: object, agent_type: AgentType):
del session
return runner_module.SystemAgentRuntimeConfig(
agent_type=agent_type,
model_code="demo",
api_base_url="https://example.com",
api_key="test",
llm_config=runner_module.SystemAgentLLMConfig(),
)
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
del kwargs
return RouterAgentOutput(
normalized_task_input=NormalizedTaskInput(
user_text="安排会议",
context_summary="",
),
key_entities=[],
constraints=[],
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
execution_mode=ExecutionMode.TOOL_ASSISTED,
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
ui=RouterUiDecision(
ui_mode=UiMode.NONE,
ui_decision_reason="单任务",
),
)
async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite:
del kwargs
raise AssertionError("worker should not run after cancel")
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
monkeypatch.setattr(runner, "_build_toolkit", lambda **kwargs: object())
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
async def _cancel_checker() -> bool:
return True
with pytest.raises(asyncio.CancelledError):
await runner.execute(
user_context=_user_context(),
context_messages=[],
pipeline=_FakePipeline(),
run_input=_run_input(),
runtime_config=_runtime_config(),
cancel_checker=_cancel_checker,
)
@@ -166,6 +166,69 @@ async def test_run_agentscope_task_injects_runtime_config(
assert captured_config["runtime_config"] is not None
@pytest.mark.asyncio
async def test_run_agentscope_task_injects_cancel_checker(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, Any] = {}
class _FakeRuntime:
def __init__(self, **kwargs: object) -> None:
del kwargs
async def run(self, **kwargs: object) -> object:
checker = kwargs.get("cancel_checker")
assert callable(checker)
captured["cancelled"] = await checker() # type: ignore[misc]
return object()
class _FakeRedis:
async def exists(self, key: str) -> int:
captured["cancel_key"] = key
return 1
async def delete(self, key: str) -> int:
captured["deleted_key"] = key
return 1
async def _fake_get_redis_client() -> object:
return _FakeRedis()
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return []
monkeypatch.setattr(tasks_module, "AgentScopeRuntimeOrchestrator", _FakeRuntime)
monkeypatch.setattr(
tasks_module,
"get_or_init_redis_client",
_fake_get_redis_client,
)
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
monkeypatch.setattr(tasks_module, "_build_user_context", _fake_user_context)
monkeypatch.setattr(
tasks_module,
"_build_recent_context_messages",
_empty_context,
)
await tasks_module.run_agentscope_task(
{
"command": "run",
"owner_id": str(uuid4()),
"run_input": _run_input_payload(),
"runtime_config": {
"enabled_tools": [],
"context": {"window_mode": "day", "window_count": 2},
},
}
)
assert captured["cancelled"] is True
assert isinstance(captured["cancel_key"], str)
assert captured["deleted_key"] == captured["cancel_key"]
@pytest.mark.asyncio
async def test_run_agentscope_task_requires_owner_id() -> None:
with pytest.raises(ValueError, match="owner_id is required"):
@@ -100,6 +100,7 @@ class _FakeRepository:
class _FakeQueue:
def __init__(self) -> None:
self.commands: list[dict[str, object]] = []
self.cancel_requests: list[dict[str, str]] = []
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
@@ -108,6 +109,21 @@ class _FakeQueue:
self.commands.append(command)
return "task-1"
async def request_cancel(
self,
*,
thread_id: str,
run_id: str,
requested_by: str,
) -> None:
self.cancel_requests.append(
{
"thread_id": thread_id,
"run_id": run_id,
"requested_by": requested_by,
}
)
class _FakeStream:
async def read(
@@ -469,3 +485,56 @@ async def test_get_history_snapshot_filters_out_tool_messages() -> None:
)
assert [message.role for message in snapshot.messages] == ["user", "assistant"]
@pytest.mark.asyncio
async def test_cancel_run_requests_queue_cancel_for_owner() -> None:
queue = _FakeQueue()
service = AgentService(
repository=_FakeRepository(),
queue=queue,
stream=_FakeStream(),
attachment_storage=_FakeAttachmentStorage(),
)
result = await service.cancel_run(
thread_id="00000000-0000-0000-0000-000000000001",
run_id="run-cancel-1",
current_user=_user(),
)
assert result.accepted is True
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
assert result.run_id == "run-cancel-1"
assert queue.cancel_requests == [
{
"thread_id": "00000000-0000-0000-0000-000000000001",
"run_id": "run-cancel-1",
"requested_by": "00000000-0000-0000-0000-000000000001",
}
]
@pytest.mark.asyncio
async def test_cancel_run_rejects_non_owner() -> None:
queue = _FakeQueue()
service = AgentService(
repository=_FakeRepository(),
queue=queue,
stream=_FakeStream(),
attachment_storage=_FakeAttachmentStorage(),
)
other_user = CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000099"),
phone="+8613812340000",
)
with pytest.raises(HTTPException) as exc_info:
await service.cancel_run(
thread_id="00000000-0000-0000-0000-000000000001",
run_id="run-cancel-2",
current_user=other_user,
)
assert exc_info.value.status_code == 403
assert queue.cancel_requests == []