feat: 支持 agent 运行取消功能
This commit is contained in:
@@ -219,21 +219,24 @@ async def test_store_persists_router_step_output_for_cost_tracking(
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert append_kwargs["seq"] == 11
|
||||
assert append_kwargs["content"] == ""
|
||||
assert append_kwargs["model_code"] == "doubao-seed-1-6-250615"
|
||||
assert append_kwargs["input_tokens"] == 12
|
||||
assert append_kwargs["output_tokens"] == 8
|
||||
assert append_kwargs["latency_ms"] == 320
|
||||
assert append_kwargs["cost"] == Decimal("0.01")
|
||||
assert append_kwargs["visibility_mask"] == 0
|
||||
|
||||
metadata = cast(dict[str, Any], append_kwargs["metadata"])
|
||||
assert sorted(metadata.keys()) == ["agent_type", "router_agent_output", "run_id"]
|
||||
assert metadata["agent_type"] == "router"
|
||||
assert metadata["router_agent_output"]["execution_mode"] == "tool_assisted"
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_marks_session_failed_for_run_canceled_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
|
||||
_patch_repositories(monkeypatch, captured, fake_chat_session)
|
||||
|
||||
assert captured["message_delta"] == 1
|
||||
assert captured["token_delta"] == 20
|
||||
assert captured["cost_delta"] == Decimal("0.01")
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
await store.persist(
|
||||
{
|
||||
"type": "RUN_ERROR",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-cancel-1",
|
||||
"message": "run canceled by user",
|
||||
"code": "RUN_CANCELED",
|
||||
}
|
||||
)
|
||||
|
||||
assert captured["status"] == _SessionStatus.FAILED
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@@ -72,3 +73,31 @@ async def test_orchestrator_emits_run_lifecycle_events() -> None:
|
||||
assert result["worker"]["answer"] == "done"
|
||||
event_types = [item["event"]["type"] for item in pipeline.events]
|
||||
assert event_types == ["RUN_STARTED", "RUN_FINISHED"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_emits_run_canceled_error_on_cancelled_error() -> None:
|
||||
class _CanceledRunner:
|
||||
async def execute(self, **kwargs: object) -> dict[str, Any]:
|
||||
del kwargs
|
||||
raise asyncio.CancelledError("run canceled by user")
|
||||
|
||||
pipeline = _FakePipeline()
|
||||
orchestrator = AgentScopeRuntimeOrchestrator(
|
||||
pipeline=pipeline, runner=_CanceledRunner()
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await orchestrator.run(
|
||||
run_input=_run_input(),
|
||||
context_messages=[],
|
||||
user_context=_user_context(),
|
||||
runtime_config=_runtime_config(),
|
||||
)
|
||||
|
||||
assert [item["event"]["type"] for item in pipeline.events] == [
|
||||
"RUN_STARTED",
|
||||
"RUN_ERROR",
|
||||
]
|
||||
run_error_event = pipeline.events[-1]["event"]
|
||||
assert run_error_event["code"] == "RUN_CANCELED"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
@@ -265,3 +266,73 @@ async def test_execute_runs_router_then_worker(
|
||||
assert load_calls == [AgentType.ROUTER, AgentType.WORKER]
|
||||
assert result["router"]["normalized_task_input"]["user_text"] == "安排会议"
|
||||
assert result["worker"]["answer"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_raises_cancelled_error_before_worker_when_cancel_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
del session_id, event
|
||||
return "1-0"
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
async def _fake_load_stage_config(*, session: object, agent_type: AgentType):
|
||||
del session
|
||||
return runner_module.SystemAgentRuntimeConfig(
|
||||
agent_type=agent_type,
|
||||
model_code="demo",
|
||||
api_base_url="https://example.com",
|
||||
api_key="test",
|
||||
llm_config=runner_module.SystemAgentLLMConfig(),
|
||||
)
|
||||
|
||||
async def _fake_execute_router_step(**kwargs: object) -> RouterAgentOutput:
|
||||
del kwargs
|
||||
return RouterAgentOutput(
|
||||
normalized_task_input=NormalizedTaskInput(
|
||||
user_text="安排会议",
|
||||
context_summary="",
|
||||
),
|
||||
key_entities=[],
|
||||
constraints=[],
|
||||
task_typing=TaskTyping(primary=TaskType.SCHEDULING),
|
||||
execution_mode=ExecutionMode.TOOL_ASSISTED,
|
||||
result_typing=ResultTyping(primary=ResultType.EXECUTION_REPORT),
|
||||
ui=RouterUiDecision(
|
||||
ui_mode=UiMode.NONE,
|
||||
ui_decision_reason="单任务",
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_execute_worker_step(**kwargs: object) -> WorkerAgentOutputLite:
|
||||
del kwargs
|
||||
raise AssertionError("worker should not run after cancel")
|
||||
|
||||
monkeypatch.setattr(runner_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(runner, "_load_stage_config", _fake_load_stage_config)
|
||||
monkeypatch.setattr(runner, "_build_toolkit", lambda **kwargs: object())
|
||||
monkeypatch.setattr(runner, "_execute_router_step", _fake_execute_router_step)
|
||||
monkeypatch.setattr(runner, "_execute_worker_step", _fake_execute_worker_step)
|
||||
|
||||
async def _cancel_checker() -> bool:
|
||||
return True
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=_FakePipeline(),
|
||||
run_input=_run_input(),
|
||||
runtime_config=_runtime_config(),
|
||||
cancel_checker=_cancel_checker,
|
||||
)
|
||||
|
||||
@@ -166,6 +166,69 @@ async def test_run_agentscope_task_injects_runtime_config(
|
||||
assert captured_config["runtime_config"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_injects_cancel_checker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
del kwargs
|
||||
|
||||
async def run(self, **kwargs: object) -> object:
|
||||
checker = kwargs.get("cancel_checker")
|
||||
assert callable(checker)
|
||||
captured["cancelled"] = await checker() # type: ignore[misc]
|
||||
return object()
|
||||
|
||||
class _FakeRedis:
|
||||
async def exists(self, key: str) -> int:
|
||||
captured["cancel_key"] = key
|
||||
return 1
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
captured["deleted_key"] = key
|
||||
return 1
|
||||
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return _FakeRedis()
|
||||
|
||||
async def _empty_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentScopeRuntimeOrchestrator", _FakeRuntime)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"get_or_init_redis_client",
|
||||
_fake_get_redis_client,
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(tasks_module, "_build_user_context", _fake_user_context)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"_build_recent_context_messages",
|
||||
_empty_context,
|
||||
)
|
||||
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": _run_input_payload(),
|
||||
"runtime_config": {
|
||||
"enabled_tools": [],
|
||||
"context": {"window_mode": "day", "window_count": 2},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert captured["cancelled"] is True
|
||||
assert isinstance(captured["cancel_key"], str)
|
||||
assert captured["deleted_key"] == captured["cancel_key"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_requires_owner_id() -> None:
|
||||
with pytest.raises(ValueError, match="owner_id is required"):
|
||||
|
||||
@@ -100,6 +100,7 @@ class _FakeRepository:
|
||||
class _FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
self.commands: list[dict[str, object]] = []
|
||||
self.cancel_requests: list[dict[str, str]] = []
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
@@ -108,6 +109,21 @@ class _FakeQueue:
|
||||
self.commands.append(command)
|
||||
return "task-1"
|
||||
|
||||
async def request_cancel(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
requested_by: str,
|
||||
) -> None:
|
||||
self.cancel_requests.append(
|
||||
{
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"requested_by": requested_by,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
async def read(
|
||||
@@ -469,3 +485,56 @@ async def test_get_history_snapshot_filters_out_tool_messages() -> None:
|
||||
)
|
||||
|
||||
assert [message.role for message in snapshot.messages] == ["user", "assistant"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_run_requests_queue_cancel_for_owner() -> None:
|
||||
queue = _FakeQueue()
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
|
||||
result = await service.cancel_run(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-cancel-1",
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert result.accepted is True
|
||||
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
|
||||
assert result.run_id == "run-cancel-1"
|
||||
assert queue.cancel_requests == [
|
||||
{
|
||||
"thread_id": "00000000-0000-0000-0000-000000000001",
|
||||
"run_id": "run-cancel-1",
|
||||
"requested_by": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_run_rejects_non_owner() -> None:
|
||||
queue = _FakeQueue()
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
other_user = CurrentUser(
|
||||
id=UUID("00000000-0000-0000-0000-000000000099"),
|
||||
phone="+8613812340000",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.cancel_run(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-cancel-2",
|
||||
current_user=other_user,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert queue.cancel_requests == []
|
||||
|
||||
Reference in New Issue
Block a user