feat: 支持 agent 运行取消功能

This commit is contained in:
qzl
2026-03-25 18:33:25 +08:00
parent 599c597e69
commit 96fc4a1e77
21 changed files with 778 additions and 85 deletions
+26
View File
@@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
import json
from typing import Any
from fastapi import Depends
@@ -22,6 +24,7 @@ DEDUP_WAIT_RETRIES = 20
DEDUP_WAIT_SECONDS = 0.05
DEDUP_LOCK_SECONDS = 300
DEDUP_INFLIGHT_MARKER = "__inflight__"
RUN_CANCEL_SIGNAL_TTL_SECONDS = 1800
def _event_stream_block_ms() -> int:
@@ -87,6 +90,29 @@ class TaskiqQueueClient:
await redis_client.delete(redis_key)
raise
async def request_cancel(
self,
*,
thread_id: str,
run_id: str,
requested_by: str,
) -> None:
redis_client = await self._get_redis()
cancel_key = f"agent:cancel:{thread_id}:{run_id}"
payload = json.dumps(
{
"requested_by": requested_by,
"requested_at": datetime.now(timezone.utc).isoformat(),
},
ensure_ascii=True,
separators=(",", ":"),
)
await redis_client.set(
cancel_key,
payload,
ex=RUN_CANCEL_SIGNAL_TTL_SECONDS,
)
class RedisEventStream:
def __init__(self) -> None:
+29
View File
@@ -37,6 +37,7 @@ from v1.agent.schemas import (
AttachmentReference,
AttachmentSignedUrlResponse,
AttachmentUploadResponse,
CancelRunResponse,
HistorySnapshotResponse,
TaskAcceptedResponse,
)
@@ -147,6 +148,34 @@ async def enqueue_run(
)
@router.post(
"/runs/{thread_id}/cancel",
response_model=CancelRunResponse,
status_code=status.HTTP_202_ACCEPTED,
)
async def cancel_run(
thread_id: str,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
run_id: str = Query(
alias="runId",
min_length=1,
max_length=128,
pattern=r"^[A-Za-z0-9_-]+$",
),
) -> CancelRunResponse:
canceled = await service.cancel_run(
thread_id=thread_id,
run_id=run_id,
current_user=current_user,
)
return CancelRunResponse(
threadId=canceled.thread_id,
runId=canceled.run_id,
accepted=canceled.accepted,
)
@router.get("/runs/{thread_id}/events")
async def stream_events(
request: Request,
+23
View File
@@ -49,6 +49,14 @@ class QueueClientLike(Protocol):
self, *, command: dict[str, object], dedup_key: str | None
) -> str: ...
async def request_cancel(
self,
*,
thread_id: str,
run_id: str,
requested_by: str,
) -> None: ...
class EventStreamLike(Protocol):
async def read(
@@ -90,6 +98,13 @@ class TaskAccepted:
created: bool
@dataclass(frozen=True)
class CancelRequested:
thread_id: str
run_id: str
accepted: bool
class TaskAcceptedResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
@@ -99,6 +114,14 @@ class TaskAcceptedResponse(BaseModel):
created: bool
class CancelRunResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
thread_id: str = Field(alias="threadId")
run_id: str = Field(alias="runId")
accepted: bool
class AsrTranscribeResponse(BaseModel):
transcript: str = Field(description="Transcribed text from audio")
+21
View File
@@ -30,6 +30,7 @@ from schemas.domain.chat_message import (
from v1.agent.schemas import (
AgentRepositoryLike,
AttachmentStorageLike,
CancelRequested,
EventStreamLike,
HistorySnapshotResponse,
QueueClientLike,
@@ -157,6 +158,26 @@ class AgentService:
created=created,
)
async def cancel_run(
self,
*,
thread_id: str,
run_id: str,
current_user: CurrentUser,
) -> CancelRequested:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
await self._queue.request_cancel(
thread_id=thread_id,
run_id=run_id,
requested_by=str(current_user.id),
)
return CancelRequested(
thread_id=thread_id,
run_id=run_id,
accepted=True,
)
async def _append_context_cache_user_message(
self,
*,