feat: 重构 agentscope 缓存架构,新增消息和附件缓存

This commit is contained in:
qzl
2026-03-25 17:41:55 +08:00
parent d22ded21f8
commit 599c597e69
25 changed files with 1509 additions and 78 deletions
+1 -1
View File
@@ -92,5 +92,5 @@ SOCIAL_APP_VERSION__DOWNLOAD_BASE_URL=
############
# Test相关
############
SOCIAL_TEST__PHONE=+8613812345678
SOCIAL_TEST__PHONE=8613812345678
SOCIAL_TEST__PASSWORD=Test@123456
@@ -0,0 +1,21 @@
from .attachment_content_cache import (
AttachmentContentCache,
create_attachment_content_cache,
)
from .context_messages_cache import (
ContextMessagesCache,
create_context_messages_cache,
)
from .user_context_cache import (
UserContextCache,
create_user_context_cache,
)
__all__ = [
"AttachmentContentCache",
"ContextMessagesCache",
"UserContextCache",
"create_attachment_content_cache",
"create_context_messages_cache",
"create_user_context_cache",
]
@@ -0,0 +1,96 @@
from __future__ import annotations
from core.config.settings import config
from core.logging import get_logger
from services.caches import CacheStore, get_cache_store
logger = get_logger("core.agentscope.caches.attachment_content_cache")
class AttachmentContentCache:
    """Best-effort cache of base64-encoded attachment payloads.

    Entries are stored as hashes in the injected ``CacheStore`` under a key
    built from the bucket, object path and MIME type. Every store failure is
    logged and swallowed, so a cache problem surfaces to callers as a miss
    (``get``) or a no-op (``set``) rather than an exception.
    """

    def __init__(
        self,
        *,
        client: CacheStore,
        key_prefix: str,
        ttl_seconds: int,
        max_base64_bytes: int,
    ) -> None:
        """Store the backing client and the cache policy.

        Args:
            client: Store used for all hash operations.
            key_prefix: Prefix prepended to every cache key.
            ttl_seconds: Expiry applied to each entry after it is written.
            max_base64_bytes: Largest UTF-8 encoded payload that is cached.
        """
        self._client = client
        self._key_prefix = key_prefix
        self._ttl_seconds = ttl_seconds
        self._max_base64_bytes = max_base64_bytes

    async def get(
        self,
        *,
        bucket: str,
        path: str,
        mime_type: str,
    ) -> str | None:
        """Return the cached base64 payload, or ``None`` on a miss or error."""
        key = self._key(bucket=bucket, path=path, mime_type=mime_type)
        try:
            raw = await self._client.hgetall(key)
        except Exception as exc:
            logger.warning(
                "Failed to read attachment content cache",
                bucket=bucket,
                path=path,
                error=str(exc),
            )
            return None

        payload = raw.get("base64")
        if not isinstance(payload, str) or not payload:
            return None
        return payload

    async def set(
        self,
        *,
        bucket: str,
        path: str,
        mime_type: str,
        base64_data: str,
    ) -> None:
        """Cache ``base64_data`` unless it is empty or above the size limit.

        The entry is written and then given a TTL. If any step fails the key
        is removed again, so a write that succeeded before a failed ``expire``
        cannot leave an entry that never expires.
        """
        if not base64_data:
            # ``get`` reports an empty payload as a miss, so storing it is useless.
            return

        encoded_bytes = len(base64_data.encode("utf-8"))
        if encoded_bytes > self._max_base64_bytes:
            logger.info(
                "Skip attachment cache write due to size limit",
                bucket=bucket,
                path=path,
                encoded_bytes=encoded_bytes,
                max_bytes=self._max_base64_bytes,
            )
            return

        key = self._key(bucket=bucket, path=path, mime_type=mime_type)
        try:
            await self._client.hset(
                key,
                mapping={
                    "base64": base64_data,
                    "mime_type": mime_type,
                },
            )
            await self._client.expire(key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to write attachment content cache",
                bucket=bucket,
                path=path,
                error=str(exc),
            )
            await self._discard(key)

    async def _discard(self, key: str) -> None:
        """Delete ``key`` after a failed write; failures are only logged."""
        try:
            await self._client.delete(key)
        except Exception as exc:
            logger.warning(
                "Failed to discard partially written attachment cache entry",
                key=key,
                error=str(exc),
            )

    def _key(self, *, bucket: str, path: str, mime_type: str) -> str:
        """Build the cache key for one attachment and MIME type."""
        return f"{self._key_prefix}:{bucket}:{path}:mime:{mime_type}"
def create_attachment_content_cache() -> AttachmentContentCache:
    """Build an ``AttachmentContentCache`` from ``config.agent_runtime``.

    The key prefix, TTL and size limit come from the
    ``attachment_content_cache_*`` settings; the backing store is whatever
    ``get_cache_store()`` returns.
    """
    settings = config.agent_runtime
    store = get_cache_store()
    return AttachmentContentCache(
        client=store,
        key_prefix=settings.attachment_content_cache_prefix,
        ttl_seconds=settings.attachment_content_cache_ttl_seconds,
        max_base64_bytes=settings.attachment_content_cache_max_base64_bytes,
    )
@@ -0,0 +1,281 @@
from __future__ import annotations
from datetime import datetime, timezone
import json
from core.config.settings import config
from core.logging import get_logger
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.domain.automation import ContextWindowMode, MessageContextConfig
from services.caches import CacheStore, get_cache_store
logger = get_logger("core.agentscope.caches.context_messages_cache")
class ContextMessagesCache:
    """Best-effort cache of assembled context windows for a thread.

    Each entry is a hash with one JSON ``payload`` field holding the window
    mode, the window count and the message list. A set per
    ``(runtime_mode, thread_id)`` indexes the entry keys so that
    ``append_message`` can update every cached window of that thread.

    Store failures are logged and swallowed. An entry that cannot be kept in
    sync (failed write, failed append, unreadable payload) is deleted so that
    ``get`` never returns a window that is missing messages.
    """

    def __init__(
        self,
        *,
        client: CacheStore,
        key_prefix: str,
        ttl_seconds: int,
    ) -> None:
        """Store the backing client, the key prefix and the entry TTL."""
        self._client = client
        self._key_prefix = key_prefix
        self._ttl_seconds = ttl_seconds

    async def get(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
    ) -> list[dict[str, object]] | None:
        """Return the cached messages for this window, or ``None`` on a miss.

        A payload that cannot be decoded into a message list is deleted and
        reported as a miss. Non-dict items inside the list are dropped.
        """
        key = self._cache_key(
            thread_id=thread_id,
            runtime_mode=runtime_mode,
            context_config=context_config,
        )
        try:
            raw = await self._client.hgetall(key)
        except Exception as exc:
            logger.warning(
                "Failed to read context messages cache",
                thread_id=thread_id,
                error=str(exc),
            )
            return None

        payload = raw.get("payload")
        if not isinstance(payload, str) or not payload:
            return None

        messages = self._decode_messages(payload)
        if messages is None:
            await self._safe_delete(key)
            return None
        return messages

    async def set(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
        messages: list[dict[str, object]],
    ) -> None:
        """Store ``messages`` for this window and register the key in the index.

        If any store call fails the entry is deleted again: an entry without
        a TTL would never expire, and an entry missing from the index would
        not receive later appends.
        """
        key = self._cache_key(
            thread_id=thread_id,
            runtime_mode=runtime_mode,
            context_config=context_config,
        )
        index_key = self._index_key(thread_id=thread_id, runtime_mode=runtime_mode)
        payload = self._encode_payload(
            mode=context_config.window_mode.value,
            count=int(context_config.window_count),
            messages=[item for item in messages if isinstance(item, dict)],
        )
        try:
            await self._client.hset(key, mapping={"payload": payload})
            await self._client.expire(key, self._ttl_seconds)
            await self._client.sadd(index_key, key)
            await self._client.expire(index_key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to write context messages cache",
                thread_id=thread_id,
                error=str(exc),
            )
            await self._safe_delete(key)

    async def append_message(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        visibility_mask: int,
        message: dict[str, object],
    ) -> None:
        """Append ``message`` to every cached window of the thread.

        Messages without the context-assembly visibility bit are ignored.
        Each indexed entry is re-trimmed to its own window and rewritten with
        a fresh TTL. The index TTL is refreshed as well; otherwise the index
        could expire while entries are still alive, later appends would be
        dropped, and ``get`` would keep serving a stale window.
        """
        if not self._is_context_visible(visibility_mask=visibility_mask):
            return

        index_key = self._index_key(thread_id=thread_id, runtime_mode=runtime_mode)
        try:
            keys = await self._client.smembers(index_key)
        except Exception as exc:
            logger.warning(
                "Failed to read context cache index",
                thread_id=thread_id,
                error=str(exc),
            )
            return
        if not keys:
            return

        normalized = self._normalize_message(message)
        updated = False
        for key in keys:
            try:
                if await self._append_to_entry(key=key, message=normalized):
                    updated = True
            except Exception as exc:
                logger.warning(
                    "Failed to append context cache message",
                    key=key,
                    thread_id=thread_id,
                    error=str(exc),
                )
                # The entry now lacks this message; drop it instead of
                # letting ``get`` serve an incomplete window.
                await self._safe_delete(key)

        if not updated:
            return
        try:
            await self._client.expire(index_key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to refresh context cache index TTL",
                thread_id=thread_id,
                error=str(exc),
            )

    async def _append_to_entry(
        self,
        *,
        key: str,
        message: dict[str, object],
    ) -> bool:
        """Append ``message`` to one entry; return ``False`` if it is absent.

        Raises:
            ValueError: If the stored payload is not valid JSON or lacks the
                expected ``window_mode``/``window_count``/``messages`` fields.
        """
        # NOTE(review): this read-modify-write is not atomic; concurrent
        # appends to one thread could overwrite each other - confirm whether
        # callers serialize writes per thread.
        raw = await self._client.hgetall(key)
        payload_raw = raw.get("payload")
        if not isinstance(payload_raw, str) or not payload_raw:
            # No payload: the entry is gone, only its index member remains.
            return False

        decoded = json.loads(payload_raw)
        if not isinstance(decoded, dict):
            raise ValueError("malformed context cache payload")
        mode = decoded.get("window_mode")
        count = decoded.get("window_count")
        messages_raw = decoded.get("messages")
        if (
            not isinstance(mode, str)
            or not isinstance(count, int)
            or not isinstance(messages_raw, list)
        ):
            raise ValueError("malformed context cache payload")

        messages = [item for item in messages_raw if isinstance(item, dict)]
        messages.append(dict(message))
        trimmed = self._trim_messages(messages=messages, mode=mode, count=count)
        next_payload = self._encode_payload(mode=mode, count=count, messages=trimmed)
        await self._client.hset(key, mapping={"payload": next_payload})
        await self._client.expire(key, self._ttl_seconds)
        return True

    @staticmethod
    def _encode_payload(
        *,
        mode: str,
        count: int,
        messages: list[dict[str, object]],
    ) -> str:
        """Serialize one cache entry as compact, ASCII-only JSON."""
        return json.dumps(
            {
                "window_mode": mode,
                "window_count": count,
                "messages": messages,
            },
            ensure_ascii=True,
            separators=(",", ":"),
        )

    @staticmethod
    def _decode_messages(payload: str) -> list[dict[str, object]] | None:
        """Return the dict messages in ``payload``, or ``None`` if unreadable."""
        try:
            decoded = json.loads(payload)
        except (ValueError, RecursionError):
            return None
        if not isinstance(decoded, dict):
            return None
        messages = decoded.get("messages")
        if not isinstance(messages, list):
            return None
        return [item for item in messages if isinstance(item, dict)]

    def _cache_key(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
    ) -> str:
        """Build the entry key for one thread, runtime mode and window."""
        return (
            f"{self._key_prefix}:{thread_id}:rm:{runtime_mode}:"
            f"wm:{context_config.window_mode.value}:wc:{int(context_config.window_count)}"
        )

    def _index_key(self, *, thread_id: str, runtime_mode: str) -> str:
        """Build the key of the set that lists a thread's entry keys."""
        return f"{self._key_prefix}:index:{runtime_mode}:{thread_id}"

    async def _safe_delete(self, key: str) -> None:
        """Delete ``key``; a failure is logged and otherwise ignored."""
        try:
            await self._client.delete(key)
        except Exception as exc:
            logger.warning(
                "Failed to delete context messages cache key",
                key=key,
                error=str(exc),
            )

    def _trim_messages(
        self,
        *,
        messages: list[dict[str, object]],
        mode: str,
        count: int,
    ) -> list[dict[str, object]]:
        """Trim ``messages`` to the window described by ``mode`` and ``count``.

        ``count`` is clamped to at least 1. Any mode other than
        ``ContextWindowMode.NUMBER`` is trimmed by calendar day.
        """
        safe_count = max(int(count), 1)
        if mode == ContextWindowMode.NUMBER.value:
            return self._trim_by_user_window(messages=messages, count=safe_count)
        return self._trim_by_day_window(messages=messages, count=safe_count)

    def _trim_by_user_window(
        self,
        *,
        messages: list[dict[str, object]],
        count: int,
    ) -> list[dict[str, object]]:
        """Keep the tail of ``messages`` starting at the ``count``-th last user message."""
        kept_reversed: list[dict[str, object]] = []
        user_count = 0
        for item in reversed(messages):
            kept_reversed.append(item)
            if item.get("role") == "user":
                user_count += 1
                if user_count >= count:
                    break
        kept_reversed.reverse()
        return kept_reversed

    def _trim_by_day_window(
        self,
        *,
        messages: list[dict[str, object]],
        count: int,
    ) -> list[dict[str, object]]:
        """Keep trailing messages until more than ``count`` distinct days are seen."""
        kept_reversed: list[dict[str, object]] = []
        days_seen: set[str] = set()
        for item in reversed(messages):
            day_value = self._extract_day(item)
            if day_value not in days_seen:
                if len(days_seen) >= count:
                    break
                days_seen.add(day_value)
            kept_reversed.append(item)
        kept_reversed.reverse()
        return kept_reversed

    @staticmethod
    def _extract_day(message: dict[str, object]) -> str:
        """Return the ISO date of the message's ``timestamp``.

        A trailing ``Z`` is accepted as UTC. When the timestamp is missing or
        cannot be parsed, today's UTC date is returned instead.
        """
        raw = message.get("timestamp")
        if isinstance(raw, str) and raw:
            normalized = raw.replace("Z", "+00:00")
            try:
                return datetime.fromisoformat(normalized).date().isoformat()
            except ValueError:
                pass
        return datetime.now(timezone.utc).date().isoformat()

    @staticmethod
    def _normalize_message(message: dict[str, object]) -> dict[str, object]:
        """Return ``message`` reduced to role, content, timestamp and metadata.

        Missing fields default to role ``assistant``, empty content and the
        current UTC time. ``metadata`` is kept only when it is a dict.
        """
        normalized: dict[str, object] = {
            "role": str(message.get("role") or "assistant"),
            "content": str(message.get("content") or ""),
            "timestamp": str(
                message.get("timestamp")
                or datetime.now(timezone.utc).isoformat(timespec="seconds")
            ),
        }
        metadata = message.get("metadata")
        if isinstance(metadata, dict):
            normalized["metadata"] = metadata
        return normalized

    @staticmethod
    def _is_context_visible(*, visibility_mask: int) -> bool:
        """Return whether the mask has the ``CONTEXT_ASSEMBLY`` bit set."""
        required = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
        return (max(int(visibility_mask), 0) & required) != 0
def create_context_messages_cache() -> ContextMessagesCache:
    """Build a ``ContextMessagesCache`` from ``config.agent_runtime``.

    The key prefix and TTL come from the ``context_messages_cache_*``
    settings; the backing store is whatever ``get_cache_store()`` returns.
    """
    settings = config.agent_runtime
    store = get_cache_store()
    return ContextMessagesCache(
        client=store,
        key_prefix=settings.context_messages_cache_prefix,
        ttl_seconds=settings.context_messages_cache_ttl_seconds,
    )
@@ -1,49 +1,26 @@
from __future__ import annotations
import inspect
import json
from collections.abc import Iterable
from typing import Any, Protocol
from typing import Any
from uuid import UUID
import redis.asyncio as redis
from core.config.settings import config
from core.logging import get_logger
from schemas.shared.user import (
UserContext,
parse_profile_settings,
)
from services.caches import CacheStore, get_cache_store
logger = get_logger("core.agentscope.persistence.user_context_cache")
class RedisHashClient(Protocol):
def hgetall(self, name: str, /) -> Any: ...
def hset(self, name: str, /, *args: Any, **kwargs: Any) -> Any: ...
def hincrby(self, name: str, key: str, amount: int = 1, /) -> Any: ...
def expire(self, name: str, time: int, /) -> Any: ...
def delete(self, *names: str) -> Any: ...
def sadd(self, name: str, *values: str) -> Any: ...
def smembers(self, name: str) -> Any: ...
async def _maybe_await(value: Any) -> Any:
if inspect.isawaitable(value):
return await value
return value
logger = get_logger("core.agentscope.caches.user_context_cache")
class UserContextCache:
def __init__(
self,
*,
client: RedisHashClient,
client: CacheStore,
key_prefix: str,
ttl_seconds: int,
max_turns: int,
@@ -56,7 +33,7 @@ class UserContextCache:
async def get(self, *, session_id: UUID) -> UserContext | None:
key = self._key(session_id)
try:
raw = await _maybe_await(self._client.hgetall(key))
raw = await self._client.hgetall(key)
except Exception as exc:
logger.warning(
"Failed to read user context cache",
@@ -107,18 +84,16 @@ class UserContextCache:
index_key = self._user_sessions_key(user_id)
payload = self._serialize(context)
try:
await _maybe_await(
self._client.hset(
key,
mapping={
"payload": payload,
"turns_used": "0",
},
)
await self._client.hset(
key,
mapping={
"payload": payload,
"turns_used": "0",
},
)
await _maybe_await(self._client.expire(key, self._ttl_seconds))
await _maybe_await(self._client.sadd(index_key, key))
await _maybe_await(self._client.expire(index_key, self._ttl_seconds))
await self._client.expire(key, self._ttl_seconds)
await self._client.sadd(index_key, key)
await self._client.expire(index_key, self._ttl_seconds)
except Exception as exc:
logger.warning(
"Failed to write user context cache",
@@ -130,7 +105,7 @@ class UserContextCache:
async def invalidate_user(self, *, user_id: UUID) -> int:
index_key = self._user_sessions_key(user_id)
try:
members_raw = await _maybe_await(self._client.smembers(index_key))
members_raw = await self._client.smembers(index_key)
except Exception as exc:
logger.warning(
"Failed to read user context cache index",
@@ -147,7 +122,7 @@ class UserContextCache:
deleted = 0
try:
deleted_raw = await _maybe_await(self._client.delete(*members))
deleted_raw = await self._client.delete(*members)
deleted = self._parse_int(deleted_raw)
except Exception as exc:
logger.warning(
@@ -205,7 +180,7 @@ class UserContextCache:
async def _safe_delete(self, key: str) -> None:
try:
await _maybe_await(self._client.delete(key))
await self._client.delete(key)
except Exception as exc:
logger.warning(
"Failed to delete user context cache key", key=key, error=str(exc)
@@ -214,7 +189,7 @@ class UserContextCache:
async def _safe_hincrby(self, key: str, field: str, amount: int) -> None:
try:
await _maybe_await(self._client.hincrby(key, field, amount))
await self._client.hincrby(key, field, amount)
except Exception as exc:
logger.warning(
"Failed to update user context cache usage",
@@ -272,7 +247,7 @@ class UserContextCache:
def create_user_context_cache() -> UserContextCache:
client = redis.from_url(config.redis.url, decode_responses=True)
client = get_cache_store()
runtime_settings = config.agent_runtime
return UserContextCache(
client=client,
+95 -8
View File
@@ -1,11 +1,16 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
from core.agentscope.caches.context_messages_cache import (
create_context_messages_cache,
)
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.logging import get_logger
from schemas.agent.forwarded_props import RuntimeMode
from schemas.enums import AgentChatMessageRole, AgentChatSessionStatus
from schemas.agent.system_agent import AgentType
from schemas.agent.runtime_models import AgentOutput, RouterAgentOutput, ToolAgentOutput
@@ -174,7 +179,10 @@ class SqlAlchemyEventStore:
if locked_session is None:
return
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
await message_repo.append_message(
visibility_mask = self._resolve_stage_visibility_mask(
event=event,
)
persisted = await message_repo.append_message(
session_id=session_id,
seq=seq,
role=role,
@@ -186,9 +194,16 @@ class SqlAlchemyEventStore:
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
visibility_mask=self._resolve_stage_visibility_mask(
event=event,
),
visibility_mask=visibility_mask,
)
await self._append_context_cache_message(
session_id=session_id,
event=event,
visibility_mask=visibility_mask,
role=role.value,
content=content,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
timestamp=self._resolve_message_timestamp(persisted),
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -339,16 +354,26 @@ class SqlAlchemyEventStore:
if locked_session is None:
return
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
await message_repo.append_message(
visibility_mask = self._resolve_stage_visibility_mask(
event=event,
)
persisted = await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.TOOL,
content=content,
tool_name=tool_output.tool_name,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
visibility_mask=self._resolve_stage_visibility_mask(
event=event,
),
visibility_mask=visibility_mask,
)
await self._append_context_cache_message(
session_id=session_id,
event=event,
visibility_mask=visibility_mask,
role=AgentChatMessageRole.TOOL.value,
content=content,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
timestamp=self._resolve_message_timestamp(persisted),
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -377,6 +402,13 @@ class SqlAlchemyEventStore:
*,
event: dict[str, Any],
) -> int:
runtime_mode = self._event_value(event, "runtime_mode")
if (
isinstance(runtime_mode, str)
and runtime_mode.strip().lower() == RuntimeMode.AUTOMATION.value
):
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
raw_stage = self._event_value(event, "stage")
if not isinstance(raw_stage, str):
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
@@ -387,6 +419,61 @@ class SqlAlchemyEventStore:
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
)
async def _append_context_cache_message(
    self,
    *,
    session_id: UUID,
    event: dict[str, Any],
    visibility_mask: int,
    role: str,
    content: str,
    metadata: dict[str, object] | None,
    timestamp: str,
) -> None:
    """Mirror a message into the context-messages cache.

    Builds the cache message from the given fields (``metadata`` is included
    only when it is a dict) and appends it to the cache for this session.
    Any exception is logged as a warning and not re-raised.
    """
    extra = {"metadata": metadata} if isinstance(metadata, dict) else {}
    message_payload: dict[str, object] = {
        "role": role,
        "content": content,
        "timestamp": timestamp,
        **extra,
    }

    try:
        cache = create_context_messages_cache()
        thread_id = str(session_id)
        mode = self._resolve_runtime_mode(event=event)
        await cache.append_message(
            thread_id=thread_id,
            runtime_mode=mode,
            visibility_mask=visibility_mask,
            message=message_payload,
        )
    except Exception as exc:
        self._logger.warning(
            "Failed to append context cache message from event",
            thread_id=str(session_id),
            error=str(exc),
        )
@staticmethod
def _resolve_runtime_mode(*, event: dict[str, Any]) -> str:
    """Return the event's ``runtime_mode`` stripped and lower-cased.

    Falls back to ``RuntimeMode.CHAT.value`` when the field is not a string
    or is blank after stripping.
    """
    candidate = event.get("runtime_mode")
    normalized = candidate.strip().lower() if isinstance(candidate, str) else ""
    if normalized:
        return normalized
    return RuntimeMode.CHAT.value
@staticmethod
def _resolve_message_timestamp(message: Any) -> str:
created_at = getattr(message, "created_at", None)
if isinstance(created_at, str) and created_at:
return created_at
if isinstance(created_at, datetime):
try:
return created_at.astimezone(timezone.utc).isoformat()
except Exception:
pass
return datetime.now(timezone.utc).isoformat(timespec="seconds")
async def _update_session_state(
self,
*,
@@ -1,9 +0,0 @@
from core.agentscope.persistence.user_context_cache import (
UserContextCache,
create_user_context_cache,
)
__all__ = [
"UserContextCache",
"create_user_context_cache",
]
@@ -29,7 +29,9 @@ from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
from schemas.agent.forwarded_props import (
ClientTimeContext,
RuntimeMode,
parse_forwarded_props_client_time,
parse_forwarded_props_runtime_mode,
)
from schemas.agent.runtime_models import (
RouterAgentOutput,
@@ -76,6 +78,7 @@ class AgentScopeRunner:
) -> dict[str, Any]:
owner_id = UUID(user_context.id)
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
runtime_mode = self._resolve_runtime_mode(run_input=run_input)
async with AsyncSessionLocal() as session:
router_config = await self._load_stage_config(
@@ -99,6 +102,7 @@ class AgentScopeRunner:
context_messages=context_messages,
stage_config=router_config,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
user_memory=user_memory,
)
worker_output = await self._execute_worker_step(
@@ -109,6 +113,7 @@ class AgentScopeRunner:
toolkit=worker_toolkit,
stage_config=worker_config,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
work_memory=work_memory,
)
return {
@@ -171,6 +176,7 @@ class AgentScopeRunner:
context_messages: list[Msg],
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
runtime_mode: RuntimeMode,
user_memory: UserMemoryContent | None,
) -> RouterAgentOutput:
await self._emit_step_event(
@@ -178,6 +184,7 @@ class AgentScopeRunner:
run_input=run_input,
step_name=AgentType.ROUTER.value,
event_type="STEP_STARTED",
runtime_mode=runtime_mode,
)
router_result = await self._run_router_stage(
user_context=user_context,
@@ -193,6 +200,7 @@ class AgentScopeRunner:
run_input=run_input,
step_name=AgentType.ROUTER.value,
event_type="STEP_FINISHED",
runtime_mode=runtime_mode,
extra_event={
"_router_persist": {
"router_output": router_output.model_dump(
@@ -214,6 +222,7 @@ class AgentScopeRunner:
toolkit: Any,
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
runtime_mode: RuntimeMode,
work_memory: WorkProfileContent | None,
) -> WorkerAgentOutputLite:
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
@@ -222,6 +231,7 @@ class AgentScopeRunner:
run_input=run_input,
step_name=AgentType.WORKER.value,
event_type="STEP_STARTED",
runtime_mode=runtime_mode,
)
worker_result = await self._run_worker_stage(
user_context=user_context,
@@ -234,6 +244,7 @@ class AgentScopeRunner:
worker_output_model=worker_output_model,
pipeline=pipeline,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
work_memory=work_memory,
)
worker_output = worker_output_model.model_validate(worker_result.payload)
@@ -242,6 +253,7 @@ class AgentScopeRunner:
run_input=run_input,
step_name=AgentType.WORKER.value,
event_type="STEP_FINISHED",
runtime_mode=runtime_mode,
)
return worker_output
@@ -332,6 +344,7 @@ class AgentScopeRunner:
worker_output_model: type[WorkerAgentOutputLite],
pipeline: PipelineLike,
runtime_client_time: ClientTimeContext | None,
runtime_mode: RuntimeMode,
work_memory: WorkProfileContent | None,
) -> StageExecutionResult:
tracking_model = self._build_model(stage_config=stage_config)
@@ -340,6 +353,7 @@ class AgentScopeRunner:
session_id=run_input.thread_id,
run_id=run_input.run_id,
stage=stage_config.agent_type.value,
runtime_mode=runtime_mode.value,
emit_text_events=True,
emit_tool_events=True,
)
@@ -437,12 +451,14 @@ class AgentScopeRunner:
run_input: RunAgentInput,
step_name: str,
event_type: str,
runtime_mode: RuntimeMode,
extra_event: dict[str, Any] | None = None,
) -> None:
payload: dict[str, Any] = {
"type": event_type,
"threadId": run_input.thread_id,
"runId": run_input.run_id,
"runtime_mode": runtime_mode.value,
"stepName": step_name,
}
if extra_event:
@@ -459,6 +475,15 @@ class AgentScopeRunner:
getattr(run_input, "forwarded_props", None)
)
@staticmethod
def _resolve_runtime_mode(*, run_input: RunAgentInput) -> RuntimeMode:
    """Parse the runtime mode from ``run_input.forwarded_props``.

    Returns ``RuntimeMode.CHAT`` when parsing raises ``ValueError``.
    """
    try:
        forwarded = getattr(run_input, "forwarded_props", None)
        mode = parse_forwarded_props_runtime_mode(forwarded)
    except ValueError:
        mode = RuntimeMode.CHAT
    return mode
@staticmethod
def _resolve_provider_api_key(*, factory_name: str) -> str:
normalized_factory_name = factory_name.strip().upper()
@@ -20,6 +20,7 @@ class PipelineStageEmitter:
session_id: str,
run_id: str,
stage: str,
runtime_mode: str,
emit_text_events: bool,
emit_tool_events: bool,
) -> None:
@@ -27,6 +28,7 @@ class PipelineStageEmitter:
self._session_id = session_id
self._run_id = run_id
self._stage = stage
self._runtime_mode = runtime_mode
self._emit_text_events = emit_text_events
self._emit_tool_events = emit_tool_events
self._emitted_tool_calls: set[str] = set()
@@ -127,6 +129,7 @@ class PipelineStageEmitter:
"type": event_type,
"threadId": self._session_id,
"runId": self._run_id,
"runtime_mode": self._runtime_mode,
**payload,
},
)
+82 -9
View File
@@ -6,6 +6,13 @@ from typing import Any, cast
from uuid import UUID
from agentscope.message import Msg
from core.agentscope.caches import create_user_context_cache
from core.agentscope.caches.attachment_content_cache import (
create_attachment_content_cache,
)
from core.agentscope.caches.context_messages_cache import (
create_context_messages_cache,
)
from core.agentscope.events import (
AgentScopeAgUiCodec,
AgentScopeEventPipeline,
@@ -20,12 +27,16 @@ from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import worker_agent_broker, worker_automation_broker
from schemas.agent.forwarded_props import (
RuntimeMode,
parse_forwarded_props_runtime_mode,
)
from schemas.domain.automation import MessageContextConfig, RuntimeConfig
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from schemas.domain.chat_message import (
AgentChatMessageMetadata,
extract_user_message_attachments,
)
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from schemas.shared.user import UserContext
from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service
@@ -73,28 +84,56 @@ async def _build_user_context(
*,
owner_id: UUID,
session: Any,
session_id: str,
) -> UserContext:
cache = create_user_context_cache()
cached = await cache.get(session_id=UUID(session_id))
if cached:
return cached
current_user = CurrentUser(id=owner_id)
user_service = get_user_service(session=session, user=current_user)
return await user_service.get_me()
user_context = await user_service.get_me()
await cache.set(session_id=UUID(session_id), context=user_context)
return user_context
async def _build_recent_context_messages(
*,
session: Any,
thread_id: str,
runtime_mode: RuntimeMode = RuntimeMode.CHAT,
context_config: "MessageContextConfig",
) -> list[Msg]:
context_service = AgentContextService(repository=AgentRepository(session))
result = await context_service.load_context_messages(
context_cache = create_context_messages_cache()
attachment_cache = create_attachment_content_cache()
raw_messages = await context_cache.get(
thread_id=thread_id,
runtime_mode=runtime_mode.value,
context_config=context_config,
)
if not result:
return []
if raw_messages is None:
context_service = AgentContextService(repository=AgentRepository(session))
result = await context_service.load_context_messages(
thread_id=thread_id,
context_config=context_config,
)
if not result:
return []
messages_obj = result.get("messages")
if not isinstance(messages_obj, list):
return []
raw_messages = [item for item in messages_obj if isinstance(item, dict)]
await context_cache.set(
thread_id=thread_id,
runtime_mode=runtime_mode.value,
context_config=context_config,
messages=raw_messages,
)
raw_messages: list[dict[str, object]] = result.get("messages") or []
if not raw_messages:
return []
@@ -120,6 +159,24 @@ async def _build_recent_context_messages(
:_MAX_CONTEXT_ATTACHMENTS
]
for attachment in attachments:
mime_type = attachment.mime_type or "image/png"
cached_b64 = await attachment_cache.get(
bucket=attachment.bucket,
path=attachment.path,
mime_type=mime_type,
)
if cached_b64:
image_blocks.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": cached_b64,
},
}
)
continue
try:
image_bytes = await supabase_service.download_bytes(
bucket=attachment.bucket,
@@ -128,12 +185,18 @@ async def _build_recent_context_messages(
except Exception:
continue
b64_data = base64.b64encode(image_bytes).decode("utf-8")
await attachment_cache.set(
bucket=attachment.bucket,
path=attachment.path,
mime_type=mime_type,
base64_data=b64_data,
)
image_blocks.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": attachment.mime_type or "image/png",
"media_type": mime_type,
"data": b64_data,
},
}
@@ -183,6 +246,13 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
raise ValueError("run_input is required")
run_input = parse_run_input(run_input_raw)
runtime_mode = RuntimeMode.CHAT
try:
runtime_mode = parse_forwarded_props_runtime_mode(
getattr(run_input, "forwarded_props", None)
)
except ValueError:
runtime_mode = RuntimeMode.CHAT
runtime_config = RuntimeConfig.model_validate(runtime_config_raw or {})
thread_id = run_input.thread_id
run_id = run_input.run_id
@@ -195,7 +265,9 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
async with AsyncSessionLocal() as session:
current_user = CurrentUser(id=owner_id)
user_context = await _build_user_context(owner_id=owner_id, session=session)
user_context = await _build_user_context(
owner_id=owner_id, session=session, session_id=thread_id
)
memories_service = MemoriesService(
repository=SQLAlchemyMemoriesRepository(session),
session=session,
@@ -226,6 +298,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
context_messages = await _build_recent_context_messages(
session=session,
thread_id=thread_id,
runtime_mode=runtime_mode,
context_config=runtime_config.context,
)
+9
View File
@@ -161,6 +161,15 @@ class AgentRuntimeSettings(BaseModel):
user_context_cache_prefix: str = "agent:user-context"
user_context_cache_ttl_seconds: int = Field(default=600, ge=60, le=86400)
user_context_cache_max_turns: int = Field(default=6, ge=1, le=100)
context_messages_cache_prefix: str = "agent:context-messages"
context_messages_cache_ttl_seconds: int = Field(default=600, ge=30, le=86400)
attachment_content_cache_prefix: str = "agent:attachment-content"
attachment_content_cache_ttl_seconds: int = Field(default=1800, ge=30, le=86400)
attachment_content_cache_max_base64_bytes: int = Field(
default=6 * 1024 * 1024,
ge=1024,
le=64 * 1024 * 1024,
)
class AutomationSchedulerSettings(BaseModel):
+12 -1
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from schemas.agent.ui_hints import UiHintsPayload
@@ -86,6 +86,17 @@ class KeyEntity(BaseModel):
type: str
value: str | None = None
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: object) -> object:
    """Convert ``bool``/``int``/``float`` input for ``value`` to ``str``.

    Runs before field validation. Strings, ``None`` and every other type
    are returned unchanged.
    """
    if isinstance(value, str) or not isinstance(value, (bool, int, float)):
        return value
    return str(value)
class ConstraintItem(BaseModel):
model_config = ConfigDict(extra="forbid")
+4
View File
@@ -0,0 +1,4 @@
from .factory import get_cache_store
from .interfaces import CacheStore
__all__ = ["CacheStore", "get_cache_store"]
+13
View File
@@ -0,0 +1,13 @@
from __future__ import annotations
from .interfaces import CacheStore
from .redis_store import RedisCacheStore
# Module-level store instance, created on the first get_cache_store() call.
_cache_store: CacheStore | None = None


def get_cache_store() -> CacheStore:
    """Return the shared ``CacheStore``, creating a ``RedisCacheStore`` once."""
    global _cache_store
    store = _cache_store
    if store is None:
        store = _cache_store = RedisCacheStore()
    return store
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Protocol
class CacheStore(Protocol):
    """Structural interface for the async hash and set operations used by caches."""

    async def hgetall(self, key: str, /) -> dict[str, str]:
        """Return all fields and values of the hash at ``key``."""

    async def hset(self, key: str, /, mapping: dict[str, str]) -> int:
        """Write the fields in ``mapping`` to the hash at ``key``."""

    async def hincrby(self, key: str, field: str, amount: int = 1, /) -> int:
        """Increment hash field ``field`` at ``key`` by ``amount``."""

    async def expire(self, key: str, ttl_seconds: int, /) -> int:
        """Set a time-to-live of ``ttl_seconds`` on ``key``."""

    async def delete(self, *keys: str) -> int:
        """Delete the given keys."""

    async def sadd(self, key: str, *members: str) -> int:
        """Add ``members`` to the set at ``key``."""

    async def smembers(self, key: str, /) -> set[str]:
        """Return the members of the set at ``key``."""
@@ -0,0 +1,80 @@
from __future__ import annotations
import inspect
from typing import Any
from services.base.redis import get_or_init_redis_client
from .interfaces import CacheStore
def _to_text(value: Any) -> str | None:
    """Return ``value`` as ``str``, decoding ``bytes`` as UTF-8.

    Returns ``None`` for undecodable bytes and for any other type.
    """
    if isinstance(value, bytes):
        try:
            return value.decode("utf-8")
        except UnicodeDecodeError:
            return None
    return value if isinstance(value, str) else None
async def _maybe_await(value: Any) -> Any:
    """Await ``value`` when it is awaitable; otherwise return it unchanged."""
    return (await value) if inspect.isawaitable(value) else value
class RedisCacheStore(CacheStore):
    """``CacheStore`` implementation on top of the shared Redis client.

    The client is fetched through ``get_or_init_redis_client()`` on every
    call. Replies are passed through ``_to_text``/``int`` so callers receive
    ``str`` and ``int`` values. Calls with nothing to write or delete return
    ``0`` without contacting Redis.
    """

    async def hgetall(self, key: str) -> dict[str, str]:
        """Return the hash at ``key`` with text keys and values.

        Pairs whose key or value cannot be converted to text are skipped. A
        reply that is not a dict yields an empty dict.
        """
        client = await get_or_init_redis_client()
        raw = await _maybe_await(client.hgetall(key))
        if not isinstance(raw, dict):
            return {}

        decoded: dict[str, str] = {}
        for raw_field, raw_value in raw.items():
            field_text = _to_text(raw_field)
            value_text = _to_text(raw_value)
            if field_text is None or value_text is None:
                continue
            decoded[field_text] = value_text
        return decoded

    async def hset(self, key: str, mapping: dict[str, str]) -> int:
        """Write ``mapping`` to the hash at ``key``; an empty mapping is a no-op."""
        if not mapping:
            # redis-py rejects an HSET without field/value pairs; treat it as
            # "nothing written", like the empty-input guards on delete/sadd.
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.hset(key, mapping=mapping))
        return int(result)

    async def hincrby(self, key: str, field: str, amount: int = 1) -> int:
        """Increment hash field ``field`` at ``key`` by ``amount``."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.hincrby(key, field, amount))
        return int(result)

    async def expire(self, key: str, ttl_seconds: int) -> int:
        """Set a time-to-live of ``ttl_seconds`` on ``key``."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.expire(key, ttl_seconds))
        return int(result)

    async def delete(self, *keys: str) -> int:
        """Delete ``keys``; returns ``0`` immediately when none are given."""
        if not keys:
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.delete(*keys))
        return int(result)

    async def sadd(self, key: str, *members: str) -> int:
        """Add ``members`` to the set at ``key``; no members is a no-op."""
        if not members:
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.sadd(key, *members))
        return int(result)

    async def smembers(self, key: str) -> set[str]:
        """Return the members of the set at ``key`` as text.

        Members that cannot be converted to text, and empty strings, are
        dropped. A reply that is not a set, list or tuple yields an empty set.
        """
        client = await get_or_init_redis_client()
        raw = await _maybe_await(client.smembers(key))
        if not isinstance(raw, (set, frozenset, list, tuple)):
            return set()
        return {text for item in raw if (text := _to_text(item))}
+49 -1
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import date
from datetime import date, datetime, timezone
import hashlib
from urllib.parse import urlparse
@@ -10,6 +10,9 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
from core.agentscope.caches.context_messages_cache import (
create_context_messages_cache,
)
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
@@ -124,6 +127,13 @@ class AgentService:
visibility_mask=visibility_mask,
)
await self._repository.commit()
await self._append_context_cache_user_message(
thread_id=thread_id,
runtime_mode=runtime_mode,
visibility_mask=visibility_mask,
content=user_message_text,
metadata=user_message_metadata,
)
queue = "automation" if runtime_mode == RuntimeMode.AUTOMATION else "agent"
task_id = await self._queue.enqueue(
@@ -147,6 +157,44 @@ class AgentService:
created=created,
)
async def _append_context_cache_user_message(
    self,
    *,
    thread_id: str,
    runtime_mode: RuntimeMode,
    visibility_mask: int,
    content: str,
    metadata: AgentChatMessageMetadata | None,
) -> None:
    """Append the user's message to the context-messages cache, best effort.

    Builds a ``role="user"`` payload stamped with the current UTC time at
    second precision and attaches the JSON dump of ``metadata`` when one is
    given. Any exception raised while creating the cache or appending to it
    is logged as a warning and not re-raised.
    """
    if isinstance(metadata, AgentChatMessageMetadata):
        metadata_dump = metadata.model_dump(mode="json", exclude_none=True)
    else:
        metadata_dump = None

    stamped_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
    outgoing: dict[str, object] = {
        "role": "user",
        "content": content,
        "timestamp": stamped_at,
    }
    if isinstance(metadata_dump, dict):
        outgoing["metadata"] = metadata_dump

    try:
        cache = create_context_messages_cache()
        await cache.append_message(
            thread_id=thread_id,
            runtime_mode=runtime_mode.value,
            visibility_mask=visibility_mask,
            message=outgoing,
        )
    except Exception as exc:
        # The cache write must not fail the caller; record the failure and move on.
        logger.warning(
            "Failed to append user message to context cache",
            thread_id=thread_id,
            runtime_mode=runtime_mode.value,
            error=str(exc),
        )
async def _resolve_user_message_visibility_mask(
self, *, runtime_mode: RuntimeMode
) -> int:
+2 -2
View File
@@ -7,10 +7,10 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.agentscope.persistence.user_context_cache import (
from core.agentscope.caches.user_context_cache import (
create_user_context_cache,
)
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from schemas.shared.user import UserContext, parse_profile_settings
@@ -0,0 +1,228 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any, cast
import pytest
from core.agentscope.caches.context_messages_cache import ContextMessagesCache
from schemas.domain.automation import ContextWindowMode, MessageContextConfig
class _FakeCacheStore:
    """In-memory stand-in for the cache backend used by the tests in this module.

    Hashes and sets are kept in two plain dictionaries. ``hset`` replaces the
    whole hash stored for a key, while ``hincrby`` and ``expire`` change
    nothing and return fixed values.
    """

    def __init__(self) -> None:
        self.hash_store: dict[str, dict[str, str]] = {}
        self.set_store: dict[str, set[str]] = {}

    async def hgetall(self, key: str) -> dict[str, str]:
        """Return a copy of the hash stored at ``key`` (empty when absent)."""
        stored = self.hash_store.get(key)
        return {} if stored is None else dict(stored)

    async def hset(self, key: str, mapping: dict[str, str]) -> int:
        """Replace the hash at ``key`` with a copy of ``mapping``; always returns 1."""
        self.hash_store[key] = {**mapping}
        return 1

    async def hincrby(self, key: str, field: str, amount: int = 1) -> int:
        """Ignore the arguments and return 0."""
        _ = (key, field, amount)
        return 0

    async def expire(self, key: str, ttl_seconds: int) -> int:
        """Ignore the arguments and return 1."""
        _ = (key, ttl_seconds)
        return 1

    async def delete(self, *keys: str) -> int:
        """Drop ``keys`` from both stores and return how many keys were passed."""
        for name in keys:
            for bucket in (self.hash_store, self.set_store):
                bucket.pop(name, None)
        return len(keys)

    async def sadd(self, key: str, *members: str) -> int:
        """Add ``members`` to the set at ``key``; return how many were new."""
        bucket = self.set_store.setdefault(key, set())
        size_before = len(bucket)
        bucket.update(members)
        return len(bucket) - size_before

    async def smembers(self, key: str) -> set[str]:
        """Return a copy of the set stored at ``key`` (empty when absent)."""
        return set(self.set_store.get(key, ()))
@pytest.mark.asyncio
async def test_context_messages_cache_set_get_roundtrip() -> None:
    """Messages stored with ``set`` come back from ``get`` for the same arguments."""
    subject = ContextMessagesCache(
        client=_FakeCacheStore(),
        key_prefix="agent:context-messages",
        ttl_seconds=600,
    )
    window = MessageContextConfig(window_mode=ContextWindowMode.DAY, window_count=2)
    seeded: list[dict[str, object]] = [
        {"role": "user", "content": "hello", "timestamp": "2026-03-25T08:00:00+00:00"}
    ]

    await subject.set(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window,
        messages=seeded,
    )
    restored = await subject.get(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window,
    )

    assert restored is not None
    first = restored[0]
    assert first["content"] == "hello"
@pytest.mark.asyncio
async def test_context_messages_cache_append_skips_when_not_visible() -> None:
    """Appending with ``visibility_mask=1`` leaves the cached message list unchanged."""
    subject = ContextMessagesCache(
        client=_FakeCacheStore(),
        key_prefix="agent:context-messages",
        ttl_seconds=600,
    )
    window = MessageContextConfig(
        window_mode=ContextWindowMode.NUMBER,
        window_count=1,
    )
    seeded = cast(
        list[dict[str, Any]],
        [
            {
                "role": "user",
                "content": "q1",
                "timestamp": "2026-03-25T08:00:00+00:00",
            }
        ],
    )
    await subject.set(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window,
        messages=seeded,
    )

    # Mask 1 does not include bit 1, so the assistant message must be skipped.
    await subject.append_message(
        thread_id="thread-1",
        runtime_mode="chat",
        visibility_mask=1,
        message={"role": "assistant", "content": "a1"},
    )

    restored = await subject.get(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window,
    )
    assert restored is not None
    assert len(restored) == 1
@pytest.mark.asyncio
async def test_context_messages_cache_append_trims_number_window() -> None:
    """With a NUMBER window of 1, an append keeps only the latest user turn and its reply."""
    subject = ContextMessagesCache(
        client=_FakeCacheStore(),
        key_prefix="agent:context-messages",
        ttl_seconds=600,
    )
    window = MessageContextConfig(
        window_mode=ContextWindowMode.NUMBER,
        window_count=1,
    )
    seeded = cast(
        list[dict[str, Any]],
        [
            {
                "role": "user",
                "content": "q0",
                "timestamp": "2026-03-25T08:00:00+00:00",
            },
            {
                "role": "assistant",
                "content": "a0",
                "timestamp": "2026-03-25T08:01:00+00:00",
            },
            {
                "role": "user",
                "content": "q1",
                "timestamp": "2026-03-25T08:02:00+00:00",
            },
        ],
    )
    await subject.set(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window,
        messages=seeded,
    )

    await subject.append_message(
        thread_id="thread-1",
        runtime_mode="chat",
        visibility_mask=2,
        message={"role": "assistant", "content": "a1"},
    )

    restored = await subject.get(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window,
    )
    assert restored is not None
    contents = [str(entry["content"]) for entry in restored]
    assert contents == ["q1", "a1"]
@pytest.mark.asyncio
async def test_context_messages_cache_append_trims_day_window() -> None:
    """With a DAY window of 2, an append drops the message dated two days earlier."""
    subject = ContextMessagesCache(
        client=_FakeCacheStore(),
        key_prefix="agent:context-messages",
        ttl_seconds=600,
    )
    window = MessageContextConfig(window_mode=ContextWindowMode.DAY, window_count=2)

    # NOTE(review): the timestamps are pinned to 2026-03-25. If the cache trims
    # DAY windows against the wall clock rather than the newest message, this
    # test becomes date-dependent - confirm against the implementation.
    anchor = datetime(2026, 3, 25, 10, 0, tzinfo=timezone.utc)
    stamp_today = anchor.isoformat()
    stamp_minus_1 = (anchor - timedelta(days=1)).isoformat()
    stamp_minus_2 = (anchor - timedelta(days=2)).isoformat()

    seeded = cast(
        list[dict[str, Any]],
        [
            {
                "role": "assistant",
                "content": "d-2",
                "timestamp": stamp_minus_2,
            },
            {
                "role": "assistant",
                "content": "d-1",
                "timestamp": stamp_minus_1,
            },
        ],
    )
    await subject.set(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window,
        messages=seeded,
    )

    await subject.append_message(
        thread_id="thread-1",
        runtime_mode="chat",
        visibility_mask=2,
        message={
            "role": "assistant",
            "content": "d0",
            "timestamp": stamp_today,
        },
    )

    restored = await subject.get(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window,
    )
    assert restored is not None
    contents = [str(entry["content"]) for entry in restored]
    assert contents == ["d-1", "d0"]
@@ -145,6 +145,38 @@ async def test_store_persists_tool_output_with_summary_as_content(
assert append_kwargs["visibility_mask"] == (1 << 0)
@pytest.mark.asyncio
async def test_store_sets_history_only_visibility_for_automation_worker_output(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An automation-mode worker output is persisted with ``visibility_mask == 1 << 0``."""
    recorded: dict[str, object] = {}
    session_stub = SimpleNamespace(state_snapshot={}, message_count=3)
    _patch_repositories(monkeypatch, recorded, session_stub)

    event_store = store_module.SqlAlchemyEventStore(
        session_factory=lambda: _FakeSessionCtx()
    )
    worker_event = {
        "type": "TEXT_MESSAGE_END",
        "threadId": "00000000-0000-0000-0000-000000000001",
        "runId": "run-auto-1",
        "messageId": "assistant-auto-1",
        "role": "assistant",
        "stage": "worker",
        "runtime_mode": "automation",
        "status": "success",
        "answer": "automation-result",
        "key_points": [],
        "result_type": "summary",
        "suggested_actions": [],
        "error": None,
    }
    await event_store.persist(worker_event)

    persisted = cast(dict[str, Any], recorded["append_kwargs"])
    assert persisted["content"] == "automation-result"
    assert persisted["visibility_mask"] == (1 << 0)
@pytest.mark.asyncio
async def test_store_persists_router_step_output_for_cost_tracking(
monkeypatch: pytest.MonkeyPatch,
@@ -5,7 +5,7 @@ from uuid import uuid4
import pytest
from core.agentscope.persistence.user_context_cache import UserContextCache
from core.agentscope.caches.user_context_cache import UserContextCache
from schemas.shared.user import (
UserContext,
parse_profile_settings,
@@ -28,6 +28,7 @@ async def test_tool_result_event_uses_runtime_tool_call_id() -> None:
session_id="thread-1",
run_id="run-1",
stage="worker",
runtime_mode="chat",
emit_text_events=False,
emit_tool_events=True,
)
@@ -65,3 +66,4 @@ async def test_tool_result_event_uses_runtime_tool_call_id() -> None:
result_events = [e for e in pipeline.events if e.get("type") == "TOOL_CALL_RESULT"]
assert len(result_events) == 1
assert result_events[0]["tool_call_id"] == "runtime-call-123"
assert result_events[0]["runtime_mode"] == "chat"
@@ -0,0 +1,38 @@
from __future__ import annotations
from schemas.agent.runtime_models import RouterAgentOutput
def test_router_agent_output_coerces_key_entity_value_to_string() -> None:
    """A numeric ``key_entities[].value`` is validated into its string form."""
    numeric_entity = {
        "name": "priority",
        "type": "number",
        "value": 8,
    }
    raw_output = {
        "normalized_task_input": {
            "user_text": "test",
            "multimodal_summary": [],
            "context_summary": "",
        },
        "key_entities": [numeric_entity],
        "constraints": [],
        "task_typing": {
            "primary": "planning",
            "secondary": [],
        },
        "execution_mode": "onestep",
        "result_typing": {
            "primary": "summary",
            "secondary": [],
        },
        "ui": {
            "ui_mode": "none",
            "ui_decision_reason": "test",
        },
    }

    parsed = RouterAgentOutput.model_validate(raw_output)

    first_entity = parsed.key_entities[0]
    assert first_entity.value == "8"
@@ -0,0 +1,392 @@
# Agent Run Cancel (Failed Semantics) Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 为 `/api/v1/agent/runs` 增加可中断能力,在用户触发 cancel 后真正停止运行中的 agent 流程,并以 `RUN_ERROR(code=RUN_CANCELED)` 结束,最终将 session 状态落为 `failed`
**Architecture:** 使用“协作取消 + 主任务中断”方案:API 层写入 Redis cancel 信号,runtime 在 worker 进程内并行 watcher 监听信号,命中后先调用 active agent 的 `interrupt()` 做优雅收尾,再 `cancel()` 当前 run 主任务做硬兜底。终态统一通过 `RUN_ERROR` 事件落库,复用现有 `FAILED` 会话语义,避免数据库枚举迁移。
**Tech Stack:** FastAPI, TaskIQ, Redis, AgentScope, SQLAlchemy, Pytest, Ruff, BasedPyright
---
### Task 1: 先更新协议文档(接口与事件语义)
**Files:**
- Modify: `docs/protocols/agent/api-endpoints.md`
- Modify: `docs/protocols/agent/sse-events.md`
**Step 1: 在 API 文档新增 cancel 端点契约**
`api-endpoints.md` 的端点清单添加:
```md
| POST | `/runs/{thread_id}/cancel` | 请求取消指定 run |
```
并新增章节说明:
- 请求参数:`thread_id` + `runId`(建议 query)
- 返回:`202 Accepted` + `accepted: true`
- 语义:仅表示“取消请求已接收”,不保证已即时终止
**Step 2: 在 SSE 文档补充取消终态语义**
在 `sse-events.md` 的 `RUN_ERROR` 章节补充:
```json
{
"type": "RUN_ERROR",
"threadId": "...",
"runId": "...",
"message": "run canceled by user",
"code": "RUN_CANCELED"
}
```
并明确:
- `RUN_CANCELED` 是用户主动中断,不是系统异常
- 本阶段仍复用 session `failed`(向后兼容)
**Step 3: 文档自检**
检查文档是否同时覆盖:
- HTTP 行为
- SSE 终态事件
- 兼容策略(不引入新 session 状态)
**Step 4: 提交文档变更**
```bash
git add docs/protocols/agent/api-endpoints.md docs/protocols/agent/sse-events.md
git commit -m "docs: define agent run cancel API and RUN_CANCELED error semantics"
```
### Task 2: 打通 API 层到队列层的 cancel 信号写入
**Files:**
- Modify: `backend/src/v1/agent/schemas.py`
- Modify: `backend/src/v1/agent/dependencies.py`
- Modify: `backend/src/v1/agent/service.py`
- Modify: `backend/src/v1/agent/router.py`
- Test: `backend/tests/unit/v1/agent/test_service.py`
- Test: `backend/tests/integration/v1/agent/test_routes.py`
**Step 1: 增加 cancel 接口响应 schema**
`v1/agent/schemas.py` 增加:
```python
class CancelRunResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
thread_id: str = Field(alias="threadId")
run_id: str = Field(alias="runId")
accepted: bool
```
**Step 2: 扩展队列协议接口**
`QueueClientLike` 增加:
```python
async def request_cancel(self, *, thread_id: str, run_id: str, requested_by: str) -> None: ...
```
**Step 3: 在 `TaskiqQueueClient` 实现 request_cancel**
`v1/agent/dependencies.py` 中新增:
- cancel key 规范:`agent:cancel:{thread_id}:{run_id}`
- `SET key value EX <ttl>` 写入取消信号
- `value` 可写 json 字符串(包含 user_id/timestamp)
**Step 4: 在 service 层新增 cancel_run**
`v1/agent/service.py` 增加方法:
- 校验 session owner(复用 `get_session_owner + ensure_session_owner`)
- 调用 `self._queue.request_cancel(...)`
- 返回 `accepted` 结果 DTO
**Step 5: 在 router 新增 cancel 路由**
`v1/agent/router.py` 新增:
```python
@router.post("/runs/{thread_id}/cancel", response_model=CancelRunResponse, status_code=202)
async def cancel_run(...):
...
```
约束:
- `runId` 必填(建议 query)
- 非 owner 返回 403
- 参数非法返回 422
**Step 6: 写 service 单测(先红)**
`test_service.py` 添加:
- owner 可发起 cancel,`queue.request_cancel` 被调用
- 非 owner cancel 返回 403
**Step 7: 运行单测确认失败**
Run: `uv run pytest backend/tests/unit/v1/agent/test_service.py -k cancel -q`
Expected: 至少 1 个测试失败(新逻辑尚未实现)
**Step 8: 实现最小代码使测试通过**
按 Step 1-5 完成实现,避免额外重构。
**Step 9: 运行测试验证通过**
Run: `uv run pytest backend/tests/unit/v1/agent/test_service.py -k cancel -q`
Expected: PASS
**Step 10: 增加路由集成测试**
`test_routes.py` 增加:
- `POST /api/v1/agent/runs/{thread_id}/cancel?runId=...` 返回 202
- 响应字段别名正确(`threadId/runId/accepted`)
**Step 11: 运行路由测试**
Run: `uv run pytest backend/tests/integration/v1/agent/test_routes.py -k cancel -q`
Expected: PASS
**Step 12: 提交 API 层变更**
```bash
git add backend/src/v1/agent/schemas.py backend/src/v1/agent/dependencies.py backend/src/v1/agent/service.py backend/src/v1/agent/router.py backend/tests/unit/v1/agent/test_service.py backend/tests/integration/v1/agent/test_routes.py
git commit -m "feat: add agent run cancel endpoint and Redis cancel signal"
```
### Task 3: runtime runner 植入取消 watcher 与优雅中断
**Files:**
- Modify: `backend/src/core/agentscope/runtime/runner.py`
- Test: `backend/tests/unit/core/agentscope/runtime/test_runner.py`
**Step 1: 在 runner.execute 增加 cancel_checker 参数**
更新 `execute()` 签名:
```python
cancel_checker: Callable[[], Awaitable[bool]] | None = None
```
并保持默认 `None` 向后兼容。
**Step 2: 增加 active agent 引用与锁**
`AgentScopeRunner.__init__` 增加:
- `self._active_agent: JsonReActAgent | None = None`
- `self._active_agent_lock = asyncio.Lock()`
**Step 3: 在 `_run_worker_stage` 设置 active agent 生命周期**
`agent.reply_json(...)` 前后包裹:
- before: 记录 `self._active_agent = agent`
- finally: 清理引用
**Step 4: 新增 `_watch_cancel_signal` 协程**
行为:
- 循环调用 `cancel_checker()`
- 命中后先尝试 `await active_agent.interrupt()`
- 再 `run_task.cancel("run canceled by user")`
- 间隔 `await asyncio.sleep(0.2)`
**Step 5: 在 execute 启停 watcher**
- `run_task = asyncio.current_task()`
- 如果有 `cancel_checker`,则 `create_task(_watch_cancel_signal(...))`
- `finally` 中停止 watcher 并 `await` 回收
**Step 6: 补 stage 边界取消 gate(关键)**
在 router 结束后、worker 开始前检查一次 `cancel_checker()`
- 为 true 时抛 `asyncio.CancelledError`
目的:防止“router 已结束但仍进入 worker”。
**Step 7: 写 runner 单测(先红)**
新增测试用例:
- cancel 信号触发后,`execute` 抛出 `CancelledError`
- worker 未被继续执行(或中途被中断)
**Step 8: 运行 runner 测试**
Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_runner.py -k cancel -q`
Expected: PASS
**Step 9: 提交 runner 变更**
```bash
git add backend/src/core/agentscope/runtime/runner.py backend/tests/unit/core/agentscope/runtime/test_runner.py
git commit -m "feat: add cooperative cancellation watcher to agentscope runner"
```
### Task 4: orchestrator 与 task worker 处理 CancelledError 终态
**Files:**
- Modify: `backend/src/core/agentscope/runtime/orchestrator.py`
- Modify: `backend/src/core/agentscope/runtime/tasks.py`
- Test: `backend/tests/unit/core/agentscope/runtime/test_orchestrator.py`
- Test: `backend/tests/unit/core/agentscope/runtime/test_tasks.py`
**Step 1: orchestrator 单独捕获 CancelledError**
`orchestrator.run()` 添加:
```python
except asyncio.CancelledError:
await self._pipeline.emit(... RUN_ERROR code="RUN_CANCELED" ...)
raise
```
保留现有 `except Exception` 处理系统错误。
**Step 2: task 层构造 cancel_checker 并注入 runtime.run**
`tasks.py`
- 构造 key:`agent:cancel:{thread_id}:{run_id}`
- 定义 `async def cancel_checker() -> bool: return bool(await redis.exists(key))`
- 调用 `runtime.run(..., cancel_checker=cancel_checker)`
**Step 3: task 层补资源清理**
在 `run_agentscope_task` 的 `finally` 中:
- 删除 cancel key 或缩短 TTL
- 记录日志(仅必要字段)
**Step 4: 写 orchestrator 单测(先红)**
验证:
- 收到 `CancelledError` 时发 `RUN_ERROR`,且 `code == "RUN_CANCELED"`
**Step 5: 写 tasks 单测(先红)**
验证:
- runtime 收到的 `cancel_checker` 可用
- key 命中时上抛 `CancelledError` 路径成立
**Step 6: 运行测试**
Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py -k cancel -q`
Expected: PASS
**Step 7: 提交 runtime 编排层变更**
```bash
git add backend/src/core/agentscope/runtime/orchestrator.py backend/src/core/agentscope/runtime/tasks.py backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py
git commit -m "fix: emit RUN_CANCELED error when run task is interrupted"
```
### Task 5: 事件流与持久化一致性回归
**Files:**
- Modify: `backend/tests/unit/core/agentscope/events/test_store.py`
- Modify: `backend/tests/unit/core/agentscope/events/test_agui_codec.py`
- Modify: `backend/tests/integration/v1/agent/test_sse_flow_live.py`
**Step 1: 补 event store 行为测试**
新增断言:
- 当 `RUN_ERROR` 的 `code=RUN_CANCELED` 时,session 状态依然为 `FAILED`
**Step 2: 补 codec 测试**
新增断言:
- `RUN_ERROR` 的 `code` 字段能正确透传到 wire event
**Step 3: 补 SSE 集成测试**
场景:
- 触发 `/runs`
- 触发 `/runs/{thread_id}/cancel`
- SSE 最终出现 `RUN_ERROR(code=RUN_CANCELED)`
**Step 4: 运行事件相关测试**
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py backend/tests/integration/v1/agent/test_sse_flow_live.py -k "cancel or run_error" -q`
Expected: PASS
**Step 5: 提交事件层变更**
```bash
git add backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py backend/tests/integration/v1/agent/test_sse_flow_live.py
git commit -m "test: cover RUN_CANCELED propagation across store codec and SSE"
```
### Task 6: 全量验证与发布前检查
**Files:**
- Modify: `docs/protocols/agent/api-endpoints.md`(如需补充最终字段)
- Modify: `docs/protocols/agent/sse-events.md`(如需补充最终字段)
**Step 1: 运行受影响单元测试集合**
Run:
```bash
uv run pytest backend/tests/unit/v1/agent/test_service.py backend/tests/unit/core/agentscope/runtime/test_runner.py backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py -q
```
Expected: PASS
**Step 2: 运行受影响集成测试集合**
Run:
```bash
uv run pytest backend/tests/integration/v1/agent/test_routes.py backend/tests/integration/v1/agent/test_sse_flow_live.py -q
```
Expected: PASS
**Step 3: 运行静态检查**
Run:
```bash
uv run ruff check backend/src backend/tests
uv run basedpyright
```
Expected: PASS(无新增 lint/type 错误)
**Step 4: 手工验证路径**
手工流程:
- 发起 `/runs`
- 立刻调用 `/runs/{thread_id}/cancel?runId=...`
- 观察 SSE:应以 `RUN_ERROR(code=RUN_CANCELED)` 结束
- 检查 session:`status=failed`
**Step 5: 最终提交**
```bash
git add docs/protocols/agent/api-endpoints.md docs/protocols/agent/sse-events.md backend/src/v1/agent/*.py backend/src/core/agentscope/runtime/*.py backend/tests/unit/core/agentscope/runtime/*.py backend/tests/unit/core/agentscope/events/*.py backend/tests/unit/v1/agent/test_service.py backend/tests/integration/v1/agent/test_routes.py backend/tests/integration/v1/agent/test_sse_flow_live.py
git commit -m "feat: support run cancellation with RUN_CANCELED failed semantics"
```
---
## 风险与回滚
- 风险 1:cancel key 误命中导致误中断
- 缓解:key 粒度使用 `thread_id + run_id`,并设置 TTL
- 风险 2:中断时出现重复终态事件
- 缓解:在 orchestrator 保证 CancelledError 只走 `RUN_ERROR` 分支,避免继续发 `RUN_FINISHED`
- 风险 3:高并发下 Redis 轮询压力上升
- 缓解:轮询间隔 200ms,后续按并发量评估改为 pub/sub
回滚策略:
- 回滚 `router/service/dependencies` cancel 新接口
- 回滚 `runner/orchestrator/tasks` cancel 注入逻辑
- 保持原 `POST /runs` 与 SSE 流程不变
+5 -2
View File
@@ -364,7 +364,8 @@ cost = uncached_prompt_tokens * input_cost_per_token
| 0 | `UI_HISTORY` | `/history` API 投影可见的消息 |
| 1 | `CONTEXT_ASSEMBLY` | 运行时上下文装配(context assembly)可见 |
> 新消息入库时,`chat` 模式设置 `mask = UI_HISTORY | CONTEXT_ASSEMBLY`(值为 3),`automation` 模式设置 `mask = 0`
> 用户输入入库时,`chat` 模式设置 `mask = UI_HISTORY | CONTEXT_ASSEMBLY`(值为 3),`automation` 模式设置 `mask = 0`
> agent 运行产物入库时,`automation` 模式设置 `mask = UI_HISTORY`(值为 1),用于展示历史但不参与 context assembly。
### /history API
@@ -385,6 +386,7 @@ WHERE (visibility_mask & 2) != 0
**影响**
- `chat` 模式用户输入:mask=3 → 进入 `/history` ✅,进入 context assembly ✅
- `automation` 模式用户输入:mask=0 → 进入 `/history` ❌,进入 context assembly ❌
- `automation` 模式 agent 输出:mask=1 → 进入 `/history` ✅,进入 context assembly ❌
### Automation 模式上下文注入
@@ -396,7 +398,8 @@ WHERE (visibility_mask & 2) != 0
|------|--------|--------------|
| Pipeline | `router` -> `worker` | `router` -> `worker` |
| 用户输入 visibility_mask | `UI_HISTORY \| CONTEXT_ASSEMBLY` | `0` |
| 进入 /history | ✅ | ❌ |
| agent 输出 visibility_mask | `UI_HISTORY \| CONTEXT_ASSEMBLY`(memory stage 仅 `UI_HISTORY`) | `UI_HISTORY` |
| 进入 /history | ✅ | ✅(仅 agent 输出) |
| 进入 context assembly | ✅(自动) | ❌(通过 run_input 注入) |
| enabled_tools 来源 | `system_agents.yaml` worker 配置 | `AutomationJob.config.enabled_tools` |
| context 配置来源 | `system_agents.yaml` router context_messages | `AutomationJob.config.context` |