feat: 重构 agentscope 缓存架构,新增消息和附件缓存
This commit is contained in:
+1
-1
@@ -92,5 +92,5 @@ SOCIAL_APP_VERSION__DOWNLOAD_BASE_URL=
|
||||
############
|
||||
# Test相关
|
||||
############
|
||||
SOCIAL_TEST__PHONE=+8613812345678
|
||||
SOCIAL_TEST__PHONE=8613812345678
|
||||
SOCIAL_TEST__PASSWORD=Test@123456
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from .attachment_content_cache import (
|
||||
AttachmentContentCache,
|
||||
create_attachment_content_cache,
|
||||
)
|
||||
from .context_messages_cache import (
|
||||
ContextMessagesCache,
|
||||
create_context_messages_cache,
|
||||
)
|
||||
from .user_context_cache import (
|
||||
UserContextCache,
|
||||
create_user_context_cache,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AttachmentContentCache",
|
||||
"ContextMessagesCache",
|
||||
"UserContextCache",
|
||||
"create_attachment_content_cache",
|
||||
"create_context_messages_cache",
|
||||
"create_user_context_cache",
|
||||
]
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.caches import CacheStore, get_cache_store
|
||||
|
||||
logger = get_logger("core.agentscope.caches.attachment_content_cache")
|
||||
|
||||
|
||||
class AttachmentContentCache:
    """Best-effort cache of base64-encoded attachment payloads.

    Each entry is a hash with ``base64`` and ``mime_type`` fields, stored
    under a key built from the bucket, path and MIME type. Store failures are
    logged and swallowed so that the cache never breaks its caller.
    """

    def __init__(
        self,
        *,
        client: CacheStore,
        key_prefix: str,
        ttl_seconds: int,
        max_base64_bytes: int,
    ) -> None:
        """Remember the backing store and the limits applied to every entry.

        Args:
            client: Async store used for all reads and writes.
            key_prefix: Leading component of every key this cache creates.
            ttl_seconds: Expiry applied to an entry each time it is written.
            max_base64_bytes: Largest encoded payload that will be cached.
        """
        self._client = client
        self._key_prefix = key_prefix
        self._ttl_seconds = ttl_seconds
        self._max_base64_bytes = max_base64_bytes

    async def get(
        self,
        *,
        bucket: str,
        path: str,
        mime_type: str,
    ) -> str | None:
        """Return the cached base64 payload, or ``None`` on a miss or failure."""
        key = self._key(bucket=bucket, path=path, mime_type=mime_type)
        try:
            raw = await self._client.hgetall(key)
        except Exception as exc:
            logger.warning(
                "Failed to read attachment content cache",
                bucket=bucket,
                path=path,
                error=str(exc),
            )
            return None

        # Guard against a store that answers a missing key with None instead
        # of an empty mapping; a cache miss must never raise.
        if not raw:
            return None

        payload = raw.get("base64")
        if not isinstance(payload, str) or not payload:
            return None
        return payload

    async def set(
        self,
        *,
        bucket: str,
        path: str,
        mime_type: str,
        base64_data: str,
    ) -> None:
        """Store a base64 payload unless it is empty or over the size limit."""
        # ``get`` treats an empty payload as a miss, so writing one is wasted.
        if not base64_data:
            return

        # Base64 text is ASCII, where the character count equals the UTF-8
        # byte count; only encode (and copy) the payload when it is not ASCII.
        if base64_data.isascii():
            encoded_bytes = len(base64_data)
        else:
            encoded_bytes = len(base64_data.encode("utf-8"))

        if encoded_bytes > self._max_base64_bytes:
            logger.info(
                "Skip attachment cache write due to size limit",
                bucket=bucket,
                path=path,
                encoded_bytes=encoded_bytes,
                max_bytes=self._max_base64_bytes,
            )
            return

        key = self._key(bucket=bucket, path=path, mime_type=mime_type)
        try:
            await self._client.hset(
                key,
                mapping={
                    "base64": base64_data,
                    "mime_type": mime_type,
                },
            )
            await self._client.expire(key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to write attachment content cache",
                bucket=bucket,
                path=path,
                error=str(exc),
            )
            # ``hset`` may have succeeded before ``expire`` failed; drop the
            # key so a payload cannot linger without a TTL.
            await self._discard(key)

    async def _discard(self, key: str) -> None:
        """Delete ``key``, logging instead of raising when the store fails."""
        try:
            await self._client.delete(key)
        except Exception as exc:
            logger.warning(
                "Failed to delete attachment content cache key",
                key=key,
                error=str(exc),
            )

    def _key(self, *, bucket: str, path: str, mime_type: str) -> str:
        """Build the store key for one attachment and MIME type."""
        return f"{self._key_prefix}:{bucket}:{path}:mime:{mime_type}"
|
||||
|
||||
|
||||
def create_attachment_content_cache() -> AttachmentContentCache:
    """Build an AttachmentContentCache from the agent-runtime settings.

    The backing store is whatever ``get_cache_store()`` returns; the key
    prefix, TTL and size limit are read from ``config.agent_runtime``.
    """
    settings = config.agent_runtime
    store = get_cache_store()
    return AttachmentContentCache(
        client=store,
        key_prefix=settings.attachment_content_cache_prefix,
        ttl_seconds=settings.attachment_content_cache_ttl_seconds,
        max_base64_bytes=settings.attachment_content_cache_max_base64_bytes,
    )
|
||||
@@ -0,0 +1,281 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.domain.automation import ContextWindowMode, MessageContextConfig
|
||||
from services.caches import CacheStore, get_cache_store
|
||||
|
||||
logger = get_logger("core.agentscope.caches.context_messages_cache")
|
||||
|
||||
|
||||
class ContextMessagesCache:
    """Best-effort cache of per-thread context message windows.

    Each entry is a hash whose ``payload`` field holds a JSON document with
    ``window_mode``, ``window_count`` and ``messages``. Entries are keyed by
    thread, runtime mode and window configuration. A per-thread index set
    records the entry keys so ``append_message`` can update every cached
    window of that thread. Store failures are logged and never raised.
    """

    def __init__(
        self,
        *,
        client: CacheStore,
        key_prefix: str,
        ttl_seconds: int,
    ) -> None:
        """Remember the backing store, key prefix and entry TTL."""
        self._client = client
        self._key_prefix = key_prefix
        self._ttl_seconds = ttl_seconds

    async def get(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
    ) -> list[dict[str, object]] | None:
        """Return the cached messages, or ``None`` on a miss or failure.

        An entry whose payload cannot be decoded into the expected shape is
        deleted and reported as a miss.
        """
        key = self._cache_key(
            thread_id=thread_id,
            runtime_mode=runtime_mode,
            context_config=context_config,
        )
        try:
            raw = await self._client.hgetall(key)
        except Exception as exc:
            logger.warning(
                "Failed to read context messages cache",
                thread_id=thread_id,
                error=str(exc),
            )
            return None

        # Guard against a store that answers a missing key with None instead
        # of an empty mapping; a cache miss must never raise.
        if not raw:
            return None

        payload = raw.get("payload")
        if not isinstance(payload, str) or not payload:
            return None

        try:
            decoded = json.loads(payload)
        except Exception:
            decoded = None

        messages = decoded.get("messages") if isinstance(decoded, dict) else None
        if not isinstance(messages, list):
            await self._safe_delete(key)
            return None

        return [item for item in messages if isinstance(item, dict)]

    async def set(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
        messages: list[dict[str, object]],
    ) -> None:
        """Store a message window and register its key in the thread index."""
        key = self._cache_key(
            thread_id=thread_id,
            runtime_mode=runtime_mode,
            context_config=context_config,
        )
        index_key = self._index_key(thread_id=thread_id, runtime_mode=runtime_mode)
        payload = json.dumps(
            {
                "window_mode": context_config.window_mode.value,
                "window_count": int(context_config.window_count),
                "messages": [item for item in messages if isinstance(item, dict)],
            },
            ensure_ascii=True,
            separators=(",", ":"),
        )
        try:
            await self._client.hset(key, mapping={"payload": payload})
            await self._client.expire(key, self._ttl_seconds)
            await self._client.sadd(index_key, key)
            await self._client.expire(index_key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to write context messages cache",
                thread_id=thread_id,
                error=str(exc),
            )
            # A partial write can leave an entry without a TTL or without an
            # index record (so appends would never reach it); drop it.
            await self._safe_delete(key)

    async def append_message(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        visibility_mask: int,
        message: dict[str, object],
    ) -> None:
        """Append one message to every cached window of a thread.

        Messages whose visibility mask lacks the context-assembly bit are
        ignored. An entry that cannot be updated is deleted rather than left
        behind without the new message.
        """
        if not self._is_context_visible(visibility_mask=visibility_mask):
            return

        index_key = self._index_key(thread_id=thread_id, runtime_mode=runtime_mode)
        try:
            keys = await self._client.smembers(index_key)
        except Exception as exc:
            logger.warning(
                "Failed to read context cache index",
                thread_id=thread_id,
                error=str(exc),
            )
            return

        if not keys:
            return

        normalized = self._normalize_message(message)
        any_refreshed = False
        for key in keys:
            try:
                if await self._append_to_entry(key=key, message=normalized):
                    any_refreshed = True
            except Exception as exc:
                logger.warning(
                    "Failed to append context cache message",
                    key=key,
                    thread_id=thread_id,
                    error=str(exc),
                )
                # The entry now lacks this message (or is malformed); remove
                # it so the next read rebuilds it instead of serving stale data.
                await self._safe_delete(key)

        if not any_refreshed:
            return

        # Entry TTLs were just extended. Extend the index TTL too, otherwise
        # the index expires first and later appends silently skip entries
        # that ``get`` is still serving.
        try:
            await self._client.expire(index_key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to refresh context cache index",
                thread_id=thread_id,
                error=str(exc),
            )

    async def _append_to_entry(self, *, key: str, message: dict[str, object]) -> bool:
        """Append ``message`` to one cached window and re-trim it.

        Returns ``True`` when the entry was rewritten and ``False`` when the
        entry holds no payload (nothing to update). Raises ``ValueError`` for
        a payload that does not have the expected shape.

        NOTE(review): this is a read-modify-write without locking; concurrent
        appends to the same entry can overwrite each other.
        """
        raw = await self._client.hgetall(key)
        payload_raw = raw.get("payload") if raw else None
        if not isinstance(payload_raw, str) or not payload_raw:
            return False

        decoded = json.loads(payload_raw)
        if not isinstance(decoded, dict):
            raise ValueError("cached context payload is not an object")

        mode = decoded.get("window_mode")
        count = decoded.get("window_count")
        messages_raw = decoded.get("messages")
        if (
            not isinstance(mode, str)
            or not isinstance(count, int)
            or not isinstance(messages_raw, list)
        ):
            raise ValueError("cached context payload has unexpected fields")

        messages = [item for item in messages_raw if isinstance(item, dict)]
        messages.append(dict(message))
        trimmed = self._trim_messages(messages=messages, mode=mode, count=count)

        next_payload = json.dumps(
            {
                "window_mode": mode,
                "window_count": count,
                "messages": trimmed,
            },
            ensure_ascii=True,
            separators=(",", ":"),
        )
        await self._client.hset(key, mapping={"payload": next_payload})
        await self._client.expire(key, self._ttl_seconds)
        return True

    def _cache_key(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
    ) -> str:
        """Build the entry key for one thread, runtime mode and window config."""
        return (
            f"{self._key_prefix}:{thread_id}:rm:{runtime_mode}:"
            f"wm:{context_config.window_mode.value}:wc:{int(context_config.window_count)}"
        )

    def _index_key(self, *, thread_id: str, runtime_mode: str) -> str:
        """Build the key of the set that lists a thread's entry keys."""
        return f"{self._key_prefix}:index:{runtime_mode}:{thread_id}"

    async def _safe_delete(self, key: str) -> None:
        """Delete ``key``, logging instead of raising when the store fails."""
        try:
            await self._client.delete(key)
        except Exception as exc:
            logger.warning(
                "Failed to delete context messages cache key",
                key=key,
                error=str(exc),
            )

    def _trim_messages(
        self,
        *,
        messages: list[dict[str, object]],
        mode: str,
        count: int,
    ) -> list[dict[str, object]]:
        """Trim ``messages`` to the configured window.

        ``count`` is clamped to at least 1. Any mode other than the NUMBER
        mode is trimmed by calendar day.
        """
        safe_count = max(int(count), 1)
        if mode == ContextWindowMode.NUMBER.value:
            return self._trim_by_user_window(messages=messages, count=safe_count)
        return self._trim_by_day_window(messages=messages, count=safe_count)

    def _trim_by_user_window(
        self,
        *,
        messages: list[dict[str, object]],
        count: int,
    ) -> list[dict[str, object]]:
        """Keep the tail of ``messages`` starting at the ``count``-th most recent user message."""
        selected_reversed: list[dict[str, object]] = []
        user_count = 0
        for item in reversed(messages):
            selected_reversed.append(item)
            if item.get("role") == "user":
                user_count += 1
                if user_count >= count:
                    break
        selected_reversed.reverse()
        return selected_reversed

    def _trim_by_day_window(
        self,
        *,
        messages: list[dict[str, object]],
        count: int,
    ) -> list[dict[str, object]]:
        """Keep the tail of ``messages`` that spans the ``count`` most recent distinct days."""
        selected_reversed: list[dict[str, object]] = []
        days_seen: set[str] = set()
        for item in reversed(messages):
            day_value = self._extract_day(item)
            if day_value not in days_seen:
                # Walking newest-to-oldest: stop at the first message that
                # would open one day too many.
                if len(days_seen) >= count:
                    break
                days_seen.add(day_value)
            selected_reversed.append(item)
        selected_reversed.reverse()
        return selected_reversed

    @staticmethod
    def _extract_day(message: dict[str, object]) -> str:
        """Return the ISO date of the message's ``timestamp``.

        Falls back to today's UTC date when the timestamp is missing or not
        parseable as ISO 8601.
        """
        raw = message.get("timestamp")
        if isinstance(raw, str) and raw:
            # ``fromisoformat`` on older Python versions rejects a "Z" suffix.
            normalized = raw.replace("Z", "+00:00")
            try:
                return datetime.fromisoformat(normalized).date().isoformat()
            except ValueError:
                pass
        return datetime.now(timezone.utc).date().isoformat()

    @staticmethod
    def _normalize_message(message: dict[str, object]) -> dict[str, object]:
        """Reduce a message to ``role``, ``content``, ``timestamp`` and optional ``metadata``.

        Missing or falsy values default to role ``"assistant"``, empty
        content and the current UTC time. ``metadata`` is kept only when it
        is a dict.
        """
        normalized: dict[str, object] = {
            "role": str(message.get("role") or "assistant"),
            "content": str(message.get("content") or ""),
            "timestamp": str(
                message.get("timestamp")
                or datetime.now(timezone.utc).isoformat(timespec="seconds")
            ),
        }
        metadata = message.get("metadata")
        if isinstance(metadata, dict):
            normalized["metadata"] = metadata
        return normalized

    @staticmethod
    def _is_context_visible(*, visibility_mask: int) -> bool:
        """Tell whether the mask has the CONTEXT_ASSEMBLY bit set (negative masks count as 0)."""
        required = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
        return (max(int(visibility_mask), 0) & required) != 0
|
||||
|
||||
|
||||
def create_context_messages_cache() -> ContextMessagesCache:
    """Build a ContextMessagesCache from the agent-runtime settings.

    The backing store is whatever ``get_cache_store()`` returns; the key
    prefix and TTL are read from ``config.agent_runtime``.
    """
    settings = config.agent_runtime
    store = get_cache_store()
    return ContextMessagesCache(
        client=store,
        key_prefix=settings.context_messages_cache_prefix,
        ttl_seconds=settings.context_messages_cache_ttl_seconds,
    )
|
||||
+19
-44
@@ -1,49 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Protocol
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import redis.asyncio as redis
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from schemas.shared.user import (
|
||||
UserContext,
|
||||
parse_profile_settings,
|
||||
)
|
||||
from services.caches import CacheStore, get_cache_store
|
||||
|
||||
logger = get_logger("core.agentscope.persistence.user_context_cache")
|
||||
|
||||
|
||||
class RedisHashClient(Protocol):
|
||||
def hgetall(self, name: str, /) -> Any: ...
|
||||
|
||||
def hset(self, name: str, /, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def hincrby(self, name: str, key: str, amount: int = 1, /) -> Any: ...
|
||||
|
||||
def expire(self, name: str, time: int, /) -> Any: ...
|
||||
|
||||
def delete(self, *names: str) -> Any: ...
|
||||
|
||||
def sadd(self, name: str, *values: str) -> Any: ...
|
||||
|
||||
def smembers(self, name: str) -> Any: ...
|
||||
|
||||
|
||||
async def _maybe_await(value: Any) -> Any:
|
||||
if inspect.isawaitable(value):
|
||||
return await value
|
||||
return value
|
||||
logger = get_logger("core.agentscope.caches.user_context_cache")
|
||||
|
||||
|
||||
class UserContextCache:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
client: RedisHashClient,
|
||||
client: CacheStore,
|
||||
key_prefix: str,
|
||||
ttl_seconds: int,
|
||||
max_turns: int,
|
||||
@@ -56,7 +33,7 @@ class UserContextCache:
|
||||
async def get(self, *, session_id: UUID) -> UserContext | None:
|
||||
key = self._key(session_id)
|
||||
try:
|
||||
raw = await _maybe_await(self._client.hgetall(key))
|
||||
raw = await self._client.hgetall(key)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to read user context cache",
|
||||
@@ -107,18 +84,16 @@ class UserContextCache:
|
||||
index_key = self._user_sessions_key(user_id)
|
||||
payload = self._serialize(context)
|
||||
try:
|
||||
await _maybe_await(
|
||||
self._client.hset(
|
||||
key,
|
||||
mapping={
|
||||
"payload": payload,
|
||||
"turns_used": "0",
|
||||
},
|
||||
)
|
||||
await self._client.hset(
|
||||
key,
|
||||
mapping={
|
||||
"payload": payload,
|
||||
"turns_used": "0",
|
||||
},
|
||||
)
|
||||
await _maybe_await(self._client.expire(key, self._ttl_seconds))
|
||||
await _maybe_await(self._client.sadd(index_key, key))
|
||||
await _maybe_await(self._client.expire(index_key, self._ttl_seconds))
|
||||
await self._client.expire(key, self._ttl_seconds)
|
||||
await self._client.sadd(index_key, key)
|
||||
await self._client.expire(index_key, self._ttl_seconds)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to write user context cache",
|
||||
@@ -130,7 +105,7 @@ class UserContextCache:
|
||||
async def invalidate_user(self, *, user_id: UUID) -> int:
|
||||
index_key = self._user_sessions_key(user_id)
|
||||
try:
|
||||
members_raw = await _maybe_await(self._client.smembers(index_key))
|
||||
members_raw = await self._client.smembers(index_key)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to read user context cache index",
|
||||
@@ -147,7 +122,7 @@ class UserContextCache:
|
||||
|
||||
deleted = 0
|
||||
try:
|
||||
deleted_raw = await _maybe_await(self._client.delete(*members))
|
||||
deleted_raw = await self._client.delete(*members)
|
||||
deleted = self._parse_int(deleted_raw)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
@@ -205,7 +180,7 @@ class UserContextCache:
|
||||
|
||||
async def _safe_delete(self, key: str) -> None:
|
||||
try:
|
||||
await _maybe_await(self._client.delete(key))
|
||||
await self._client.delete(key)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to delete user context cache key", key=key, error=str(exc)
|
||||
@@ -214,7 +189,7 @@ class UserContextCache:
|
||||
|
||||
async def _safe_hincrby(self, key: str, field: str, amount: int) -> None:
|
||||
try:
|
||||
await _maybe_await(self._client.hincrby(key, field, amount))
|
||||
await self._client.hincrby(key, field, amount)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to update user context cache usage",
|
||||
@@ -272,7 +247,7 @@ class UserContextCache:
|
||||
|
||||
|
||||
def create_user_context_cache() -> UserContextCache:
|
||||
client = redis.from_url(config.redis.url, decode_responses=True)
|
||||
client = get_cache_store()
|
||||
runtime_settings = config.agent_runtime
|
||||
return UserContextCache(
|
||||
client=client,
|
||||
@@ -1,11 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from core.agentscope.caches.context_messages_cache import (
|
||||
create_context_messages_cache,
|
||||
)
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.logging import get_logger
|
||||
from schemas.agent.forwarded_props import RuntimeMode
|
||||
from schemas.enums import AgentChatMessageRole, AgentChatSessionStatus
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.agent.runtime_models import AgentOutput, RouterAgentOutput, ToolAgentOutput
|
||||
@@ -174,7 +179,10 @@ class SqlAlchemyEventStore:
|
||||
if locked_session is None:
|
||||
return
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
await message_repo.append_message(
|
||||
visibility_mask = self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
)
|
||||
persisted = await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=role,
|
||||
@@ -186,9 +194,16 @@ class SqlAlchemyEventStore:
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
await self._append_context_cache_message(
|
||||
session_id=session_id,
|
||||
event=event,
|
||||
visibility_mask=visibility_mask,
|
||||
role=role.value,
|
||||
content=content,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
timestamp=self._resolve_message_timestamp(persisted),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -339,16 +354,26 @@ class SqlAlchemyEventStore:
|
||||
if locked_session is None:
|
||||
return
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
await message_repo.append_message(
|
||||
visibility_mask = self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
)
|
||||
persisted = await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
content=content,
|
||||
tool_name=tool_output.tool_name,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
visibility_mask=self._resolve_stage_visibility_mask(
|
||||
event=event,
|
||||
),
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
await self._append_context_cache_message(
|
||||
session_id=session_id,
|
||||
event=event,
|
||||
visibility_mask=visibility_mask,
|
||||
role=AgentChatMessageRole.TOOL.value,
|
||||
content=content,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
timestamp=self._resolve_message_timestamp(persisted),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -377,6 +402,13 @@ class SqlAlchemyEventStore:
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
) -> int:
|
||||
runtime_mode = self._event_value(event, "runtime_mode")
|
||||
if (
|
||||
isinstance(runtime_mode, str)
|
||||
and runtime_mode.strip().lower() == RuntimeMode.AUTOMATION.value
|
||||
):
|
||||
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
|
||||
raw_stage = self._event_value(event, "stage")
|
||||
if not isinstance(raw_stage, str):
|
||||
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
@@ -387,6 +419,61 @@ class SqlAlchemyEventStore:
|
||||
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
|
||||
)
|
||||
|
||||
async def _append_context_cache_message(
    self,
    *,
    session_id: UUID,
    event: dict[str, Any],
    visibility_mask: int,
    role: str,
    content: str,
    metadata: dict[str, object] | None,
    timestamp: str,
) -> None:
    """Push one message into the context messages cache for this session.

    The cache is addressed by the session id (as a string) and the runtime
    mode resolved from ``event``. Any failure is logged as a warning and
    suppressed so that the caller is never interrupted by the cache.
    """
    cache_entry: dict[str, object] = {
        "role": role,
        "content": content,
        "timestamp": timestamp,
    }
    # Metadata is optional; only attach it when it is a real dict.
    if isinstance(metadata, dict):
        cache_entry["metadata"] = metadata

    try:
        cache = create_context_messages_cache()
        await cache.append_message(
            thread_id=str(session_id),
            runtime_mode=self._resolve_runtime_mode(event=event),
            visibility_mask=visibility_mask,
            message=cache_entry,
        )
    except Exception as exc:
        self._logger.warning(
            "Failed to append context cache message from event",
            thread_id=str(session_id),
            error=str(exc),
        )
|
||||
|
||||
@staticmethod
def _resolve_runtime_mode(*, event: dict[str, Any]) -> str:
    """Return the event's ``runtime_mode`` stripped and lower-cased.

    Falls back to the CHAT runtime mode value when the field is absent, not
    a string, or blank.
    """
    candidate = event.get("runtime_mode")
    cleaned = candidate.strip().lower() if isinstance(candidate, str) else ""
    return cleaned or RuntimeMode.CHAT.value
|
||||
|
||||
@staticmethod
def _resolve_message_timestamp(message: Any) -> str:
    """Return an ISO-8601 timestamp string for ``message``.

    Reads ``message.created_at``: a non-empty string is returned unchanged
    and a ``datetime`` is converted to UTC. Anything else (or a failed
    conversion) yields the current UTC time at second precision.
    """
    stamp = getattr(message, "created_at", None)
    if isinstance(stamp, datetime):
        # NOTE(review): astimezone() treats a naive datetime as local time;
        # confirm that created_at is always timezone-aware.
        try:
            return stamp.astimezone(timezone.utc).isoformat()
        except Exception:
            pass
    elif isinstance(stamp, str) and stamp:
        return stamp
    return datetime.now(timezone.utc).isoformat(timespec="seconds")
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from core.agentscope.persistence.user_context_cache import (
|
||||
UserContextCache,
|
||||
create_user_context_cache,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"UserContextCache",
|
||||
"create_user_context_cache",
|
||||
]
|
||||
@@ -29,7 +29,9 @@ from models.llm_factory import LlmFactory
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.agent.forwarded_props import (
|
||||
ClientTimeContext,
|
||||
RuntimeMode,
|
||||
parse_forwarded_props_client_time,
|
||||
parse_forwarded_props_runtime_mode,
|
||||
)
|
||||
from schemas.agent.runtime_models import (
|
||||
RouterAgentOutput,
|
||||
@@ -76,6 +78,7 @@ class AgentScopeRunner:
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
runtime_mode = self._resolve_runtime_mode(run_input=run_input)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
router_config = await self._load_stage_config(
|
||||
@@ -99,6 +102,7 @@ class AgentScopeRunner:
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
user_memory=user_memory,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
@@ -109,6 +113,7 @@ class AgentScopeRunner:
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
return {
|
||||
@@ -171,6 +176,7 @@ class AgentScopeRunner:
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
runtime_mode: RuntimeMode,
|
||||
user_memory: UserMemoryContent | None,
|
||||
) -> RouterAgentOutput:
|
||||
await self._emit_step_event(
|
||||
@@ -178,6 +184,7 @@ class AgentScopeRunner:
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_STARTED",
|
||||
runtime_mode=runtime_mode,
|
||||
)
|
||||
router_result = await self._run_router_stage(
|
||||
user_context=user_context,
|
||||
@@ -193,6 +200,7 @@ class AgentScopeRunner:
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
runtime_mode=runtime_mode,
|
||||
extra_event={
|
||||
"_router_persist": {
|
||||
"router_output": router_output.model_dump(
|
||||
@@ -214,6 +222,7 @@ class AgentScopeRunner:
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
runtime_mode: RuntimeMode,
|
||||
work_memory: WorkProfileContent | None,
|
||||
) -> WorkerAgentOutputLite:
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
@@ -222,6 +231,7 @@ class AgentScopeRunner:
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_STARTED",
|
||||
runtime_mode=runtime_mode,
|
||||
)
|
||||
worker_result = await self._run_worker_stage(
|
||||
user_context=user_context,
|
||||
@@ -234,6 +244,7 @@ class AgentScopeRunner:
|
||||
worker_output_model=worker_output_model,
|
||||
pipeline=pipeline,
|
||||
runtime_client_time=runtime_client_time,
|
||||
runtime_mode=runtime_mode,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
||||
@@ -242,6 +253,7 @@ class AgentScopeRunner:
|
||||
run_input=run_input,
|
||||
step_name=AgentType.WORKER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
runtime_mode=runtime_mode,
|
||||
)
|
||||
return worker_output
|
||||
|
||||
@@ -332,6 +344,7 @@ class AgentScopeRunner:
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
runtime_mode: RuntimeMode,
|
||||
work_memory: WorkProfileContent | None,
|
||||
) -> StageExecutionResult:
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
@@ -340,6 +353,7 @@ class AgentScopeRunner:
|
||||
session_id=run_input.thread_id,
|
||||
run_id=run_input.run_id,
|
||||
stage=stage_config.agent_type.value,
|
||||
runtime_mode=runtime_mode.value,
|
||||
emit_text_events=True,
|
||||
emit_tool_events=True,
|
||||
)
|
||||
@@ -437,12 +451,14 @@ class AgentScopeRunner:
|
||||
run_input: RunAgentInput,
|
||||
step_name: str,
|
||||
event_type: str,
|
||||
runtime_mode: RuntimeMode,
|
||||
extra_event: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"type": event_type,
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"runtime_mode": runtime_mode.value,
|
||||
"stepName": step_name,
|
||||
}
|
||||
if extra_event:
|
||||
@@ -459,6 +475,15 @@ class AgentScopeRunner:
|
||||
getattr(run_input, "forwarded_props", None)
|
||||
)
|
||||
|
||||
@staticmethod
def _resolve_runtime_mode(*, run_input: RunAgentInput) -> RuntimeMode:
    """Parse the runtime mode from ``run_input.forwarded_props``.

    A missing ``forwarded_props`` attribute is passed to the parser as
    ``None``. When parsing raises ``ValueError`` the CHAT mode is used.
    """
    try:
        forwarded = getattr(run_input, "forwarded_props", None)
        mode = parse_forwarded_props_runtime_mode(forwarded)
    except ValueError:
        mode = RuntimeMode.CHAT
    return mode
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_api_key(*, factory_name: str) -> str:
|
||||
normalized_factory_name = factory_name.strip().upper()
|
||||
|
||||
@@ -20,6 +20,7 @@ class PipelineStageEmitter:
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
stage: str,
|
||||
runtime_mode: str,
|
||||
emit_text_events: bool,
|
||||
emit_tool_events: bool,
|
||||
) -> None:
|
||||
@@ -27,6 +28,7 @@ class PipelineStageEmitter:
|
||||
self._session_id = session_id
|
||||
self._run_id = run_id
|
||||
self._stage = stage
|
||||
self._runtime_mode = runtime_mode
|
||||
self._emit_text_events = emit_text_events
|
||||
self._emit_tool_events = emit_tool_events
|
||||
self._emitted_tool_calls: set[str] = set()
|
||||
@@ -127,6 +129,7 @@ class PipelineStageEmitter:
|
||||
"type": event_type,
|
||||
"threadId": self._session_id,
|
||||
"runId": self._run_id,
|
||||
"runtime_mode": self._runtime_mode,
|
||||
**payload,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -6,6 +6,13 @@ from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from agentscope.message import Msg
|
||||
from core.agentscope.caches import create_user_context_cache
|
||||
from core.agentscope.caches.attachment_content_cache import (
|
||||
create_attachment_content_cache,
|
||||
)
|
||||
from core.agentscope.caches.context_messages_cache import (
|
||||
create_context_messages_cache,
|
||||
)
|
||||
from core.agentscope.events import (
|
||||
AgentScopeAgUiCodec,
|
||||
AgentScopeEventPipeline,
|
||||
@@ -20,12 +27,16 @@ from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import worker_agent_broker, worker_automation_broker
|
||||
from schemas.agent.forwarded_props import (
|
||||
RuntimeMode,
|
||||
parse_forwarded_props_runtime_mode,
|
||||
)
|
||||
from schemas.domain.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.domain.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.shared.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
@@ -73,28 +84,56 @@ async def _build_user_context(
|
||||
*,
|
||||
owner_id: UUID,
|
||||
session: Any,
|
||||
session_id: str,
|
||||
) -> UserContext:
|
||||
cache = create_user_context_cache()
|
||||
cached = await cache.get(session_id=UUID(session_id))
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
current_user = CurrentUser(id=owner_id)
|
||||
user_service = get_user_service(session=session, user=current_user)
|
||||
return await user_service.get_me()
|
||||
user_context = await user_service.get_me()
|
||||
|
||||
await cache.set(session_id=UUID(session_id), context=user_context)
|
||||
return user_context
|
||||
|
||||
|
||||
async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
thread_id: str,
|
||||
runtime_mode: RuntimeMode = RuntimeMode.CHAT,
|
||||
context_config: "MessageContextConfig",
|
||||
) -> list[Msg]:
|
||||
context_service = AgentContextService(repository=AgentRepository(session))
|
||||
result = await context_service.load_context_messages(
|
||||
context_cache = create_context_messages_cache()
|
||||
attachment_cache = create_attachment_content_cache()
|
||||
raw_messages = await context_cache.get(
|
||||
thread_id=thread_id,
|
||||
runtime_mode=runtime_mode.value,
|
||||
context_config=context_config,
|
||||
)
|
||||
|
||||
if not result:
|
||||
return []
|
||||
if raw_messages is None:
|
||||
context_service = AgentContextService(repository=AgentRepository(session))
|
||||
result = await context_service.load_context_messages(
|
||||
thread_id=thread_id,
|
||||
context_config=context_config,
|
||||
)
|
||||
if not result:
|
||||
return []
|
||||
|
||||
messages_obj = result.get("messages")
|
||||
if not isinstance(messages_obj, list):
|
||||
return []
|
||||
raw_messages = [item for item in messages_obj if isinstance(item, dict)]
|
||||
await context_cache.set(
|
||||
thread_id=thread_id,
|
||||
runtime_mode=runtime_mode.value,
|
||||
context_config=context_config,
|
||||
messages=raw_messages,
|
||||
)
|
||||
|
||||
raw_messages: list[dict[str, object]] = result.get("messages") or []
|
||||
if not raw_messages:
|
||||
return []
|
||||
|
||||
@@ -120,6 +159,24 @@ async def _build_recent_context_messages(
|
||||
:_MAX_CONTEXT_ATTACHMENTS
|
||||
]
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.mime_type or "image/png"
|
||||
cached_b64 = await attachment_cache.get(
|
||||
bucket=attachment.bucket,
|
||||
path=attachment.path,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
if cached_b64:
|
||||
image_blocks.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": cached_b64,
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
try:
|
||||
image_bytes = await supabase_service.download_bytes(
|
||||
bucket=attachment.bucket,
|
||||
@@ -128,12 +185,18 @@ async def _build_recent_context_messages(
|
||||
except Exception:
|
||||
continue
|
||||
b64_data = base64.b64encode(image_bytes).decode("utf-8")
|
||||
await attachment_cache.set(
|
||||
bucket=attachment.bucket,
|
||||
path=attachment.path,
|
||||
mime_type=mime_type,
|
||||
base64_data=b64_data,
|
||||
)
|
||||
image_blocks.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": attachment.mime_type or "image/png",
|
||||
"media_type": mime_type,
|
||||
"data": b64_data,
|
||||
},
|
||||
}
|
||||
@@ -183,6 +246,13 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
raise ValueError("run_input is required")
|
||||
|
||||
run_input = parse_run_input(run_input_raw)
|
||||
runtime_mode = RuntimeMode.CHAT
|
||||
try:
|
||||
runtime_mode = parse_forwarded_props_runtime_mode(
|
||||
getattr(run_input, "forwarded_props", None)
|
||||
)
|
||||
except ValueError:
|
||||
runtime_mode = RuntimeMode.CHAT
|
||||
runtime_config = RuntimeConfig.model_validate(runtime_config_raw or {})
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
@@ -195,7 +265,9 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
current_user = CurrentUser(id=owner_id)
|
||||
user_context = await _build_user_context(owner_id=owner_id, session=session)
|
||||
user_context = await _build_user_context(
|
||||
owner_id=owner_id, session=session, session_id=thread_id
|
||||
)
|
||||
memories_service = MemoriesService(
|
||||
repository=SQLAlchemyMemoriesRepository(session),
|
||||
session=session,
|
||||
@@ -226,6 +298,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages = await _build_recent_context_messages(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
runtime_mode=runtime_mode,
|
||||
context_config=runtime_config.context,
|
||||
)
|
||||
|
||||
|
||||
@@ -161,6 +161,15 @@ class AgentRuntimeSettings(BaseModel):
|
||||
user_context_cache_prefix: str = "agent:user-context"
|
||||
user_context_cache_ttl_seconds: int = Field(default=600, ge=60, le=86400)
|
||||
user_context_cache_max_turns: int = Field(default=6, ge=1, le=100)
|
||||
context_messages_cache_prefix: str = "agent:context-messages"
|
||||
context_messages_cache_ttl_seconds: int = Field(default=600, ge=30, le=86400)
|
||||
attachment_content_cache_prefix: str = "agent:attachment-content"
|
||||
attachment_content_cache_ttl_seconds: int = Field(default=1800, ge=30, le=86400)
|
||||
attachment_content_cache_max_base64_bytes: int = Field(
|
||||
default=6 * 1024 * 1024,
|
||||
ge=1024,
|
||||
le=64 * 1024 * 1024,
|
||||
)
|
||||
|
||||
|
||||
class AutomationSchedulerSettings(BaseModel):
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from schemas.agent.ui_hints import UiHintsPayload
|
||||
|
||||
@@ -86,6 +86,17 @@ class KeyEntity(BaseModel):
|
||||
type: str
|
||||
value: str | None = None
|
||||
|
||||
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: object) -> object:
    """Turn scalar ``value`` inputs into text before field validation runs.

    ``bool``, ``int`` and ``float`` inputs are converted with ``str()``.
    ``None``, strings and every other type are handed back untouched, so the
    field's declared ``str | None`` type decides whether they are accepted.
    """
    # ``None`` and ``str`` are never instances of the scalar tuple, so they
    # take the same pass-through path as any other non-scalar input.
    is_scalar = isinstance(value, (bool, int, float))
    return str(value) if is_scalar else value
|
||||
|
||||
|
||||
class ConstraintItem(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .factory import get_cache_store
|
||||
from .interfaces import CacheStore
|
||||
|
||||
__all__ = ["CacheStore", "get_cache_store"]
|
||||
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .interfaces import CacheStore
|
||||
from .redis_store import RedisCacheStore
|
||||
|
||||
_cache_store: CacheStore | None = None
|
||||
|
||||
|
||||
def get_cache_store() -> CacheStore:
    """Return the module-wide cache store, creating it on first use.

    The first call builds a ``RedisCacheStore`` and remembers it in the
    module-level ``_cache_store``; later calls return that same object.
    """
    global _cache_store
    store = _cache_store
    if store is None:
        store = RedisCacheStore()
        _cache_store = store
    return store
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class CacheStore(Protocol):
    """Structural interface for the async hash and set operations of a cache.

    Parameters declared before ``/`` are positional-only, so an implementation
    may use different names for them and still satisfy the protocol.
    """

    async def hgetall(self, key: str, /) -> dict[str, str]:
        """Return the field/value pairs of the hash stored at *key*."""
        ...

    async def hset(self, key: str, /, mapping: dict[str, str]) -> int:
        """Write the fields of *mapping* to the hash stored at *key*."""
        ...

    async def hincrby(self, key: str, field: str, amount: int = 1, /) -> int:
        """Increment *field* of the hash stored at *key* by *amount*."""
        ...

    async def expire(self, key: str, ttl_seconds: int, /) -> int:
        """Apply a time-to-live of *ttl_seconds* to *key*."""
        ...

    async def delete(self, *keys: str) -> int:
        """Delete the given *keys*."""
        ...

    async def sadd(self, key: str, *members: str) -> int:
        """Add *members* to the set stored at *key*."""
        ...

    async def smembers(self, key: str, /) -> set[str]:
        """Return the members of the set stored at *key*."""
        ...
|
||||
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
from .interfaces import CacheStore
|
||||
|
||||
|
||||
def _to_text(value: Any) -> str | None:
    """Convert a raw cache reply item to ``str``.

    Strings are returned unchanged and ``bytes`` are decoded as UTF-8.
    Undecodable bytes and values of any other type yield ``None``.
    """
    if isinstance(value, bytes):
        try:
            return value.decode("utf-8")
        except UnicodeDecodeError:
            return None
    return value if isinstance(value, str) else None
|
||||
|
||||
|
||||
async def _maybe_await(value: Any) -> Any:
    """Await *value* when it is awaitable; otherwise return it as-is."""
    if not inspect.isawaitable(value):
        return value
    return await value
|
||||
|
||||
|
||||
class RedisCacheStore(CacheStore):
    """``CacheStore`` implementation that forwards to the shared Redis client.

    Every method fetches the client with ``get_or_init_redis_client()`` and
    passes the command result through ``_maybe_await``, so both awaitable and
    plain return values from the client are accepted.
    """

    async def hgetall(self, key: str) -> dict[str, str]:
        """Return the hash stored at *key* as a ``str -> str`` mapping.

        Entries whose field name or value ``_to_text`` cannot convert are
        skipped. A reply that is not a ``dict`` yields an empty mapping.
        """
        client = await get_or_init_redis_client()
        raw = await _maybe_await(client.hgetall(key))
        if not isinstance(raw, dict):
            return {}

        decoded: dict[str, str] = {}
        for raw_key, raw_value in raw.items():
            key_text = _to_text(raw_key)
            value_text = _to_text(raw_value)
            if key_text is None or value_text is None:
                continue
            decoded[key_text] = value_text
        return decoded

    async def hset(self, key: str, mapping: dict[str, str]) -> int:
        """Send *mapping* to the client's ``hset`` for *key*; return its reply as ``int``."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.hset(key, mapping=mapping))
        return int(result)

    async def hincrby(self, key: str, field: str, amount: int = 1) -> int:
        """Call the client's ``hincrby`` for *field* of *key*; return its reply as ``int``."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.hincrby(key, field, amount))
        return int(result)

    async def expire(self, key: str, ttl_seconds: int) -> int:
        """Call the client's ``expire`` with *ttl_seconds*; return its reply as ``int``."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.expire(key, ttl_seconds))
        return int(result)

    async def delete(self, *keys: str) -> int:
        """Delete *keys* through the client; return its reply as ``int``.

        Returns ``0`` without fetching the client when no keys are given.
        """
        if not keys:
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.delete(*keys))
        return int(result)

    async def sadd(self, key: str, *members: str) -> int:
        """Add *members* to the set at *key*; return the client's reply as ``int``.

        Returns ``0`` without fetching the client when no members are given.
        """
        if not members:
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.sadd(key, *members))
        return int(result)

    async def smembers(self, key: str) -> set[str]:
        """Return the members of the set stored at *key* as text.

        Members ``_to_text`` cannot convert are skipped. A reply that is not a
        set, frozenset, list or tuple yields an empty set.
        """
        client = await get_or_init_redis_client()
        raw = await _maybe_await(client.smembers(key))
        if not isinstance(raw, (set, frozenset, list, tuple)):
            return set()
        # Filter on ``is not None`` rather than truthiness: an empty string
        # is a legitimate member and must survive, matching how ``hgetall``
        # treats empty field values.
        return {text for item in raw if (text := _to_text(item)) is not None}
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from datetime import date, datetime, timezone
|
||||
import hashlib
|
||||
|
||||
from urllib.parse import urlparse
|
||||
@@ -10,6 +10,9 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.agentscope.caches.context_messages_cache import (
|
||||
create_context_messages_cache,
|
||||
)
|
||||
from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
@@ -124,6 +127,13 @@ class AgentService:
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
await self._repository.commit()
|
||||
await self._append_context_cache_user_message(
|
||||
thread_id=thread_id,
|
||||
runtime_mode=runtime_mode,
|
||||
visibility_mask=visibility_mask,
|
||||
content=user_message_text,
|
||||
metadata=user_message_metadata,
|
||||
)
|
||||
|
||||
queue = "automation" if runtime_mode == RuntimeMode.AUTOMATION else "agent"
|
||||
task_id = await self._queue.enqueue(
|
||||
@@ -147,6 +157,44 @@ class AgentService:
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def _append_context_cache_user_message(
    self,
    *,
    thread_id: str,
    runtime_mode: RuntimeMode,
    visibility_mask: int,
    content: str,
    metadata: AgentChatMessageMetadata | None,
) -> None:
    """Best-effort append of the just-stored user message to the context cache.

    Builds a ``role="user"`` message stamped with the current UTC time, adds
    the JSON-dumped metadata when there is any, and passes it to
    ``ContextMessagesCache.append_message``. Any exception raised while
    creating the cache or appending is logged as a warning and swallowed, so
    the caller is not interrupted by a cache failure.
    """
    metadata_payload = None
    if isinstance(metadata, AgentChatMessageMetadata):
        metadata_payload = metadata.model_dump(mode="json", exclude_none=True)

    sent_at = datetime.now(timezone.utc)
    message_payload: dict[str, object] = {
        "role": "user",
        "content": content,
        "timestamp": sent_at.isoformat(timespec="seconds"),
    }
    if isinstance(metadata_payload, dict):
        message_payload["metadata"] = metadata_payload

    try:
        await create_context_messages_cache().append_message(
            thread_id=thread_id,
            runtime_mode=runtime_mode.value,
            visibility_mask=visibility_mask,
            message=message_payload,
        )
    except Exception as exc:
        # Deliberately broad: only the cache append is skipped on failure.
        logger.warning(
            "Failed to append user message to context cache",
            thread_id=thread_id,
            runtime_mode=runtime_mode.value,
            error=str(exc),
        )
|
||||
|
||||
async def _resolve_user_message_visibility_mask(
|
||||
self, *, runtime_mode: RuntimeMode
|
||||
) -> int:
|
||||
|
||||
@@ -7,10 +7,10 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.agentscope.persistence.user_context_cache import (
|
||||
from core.agentscope.caches.user_context_cache import (
|
||||
create_user_context_cache,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from schemas.shared.user import UserContext, parse_profile_settings
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.caches.context_messages_cache import ContextMessagesCache
|
||||
from schemas.domain.automation import ContextWindowMode, MessageContextConfig
|
||||
|
||||
|
||||
class _FakeCacheStore:
    """In-memory stand-in for the cache store used by the tests in this file.

    Hashes and sets live in two plain dicts. ``hincrby`` and ``expire`` do no
    work and return fixed values.
    """

    def __init__(self) -> None:
        self.hash_store: dict[str, dict[str, str]] = {}
        self.set_store: dict[str, set[str]] = {}

    async def hgetall(self, key: str) -> dict[str, str]:
        """Return a copy of the hash at *key*, or an empty dict if absent."""
        stored = self.hash_store.get(key)
        return {} if stored is None else dict(stored)

    async def hset(self, key: str, mapping: dict[str, str]) -> int:
        """Replace the hash at *key* with a copy of *mapping*; always returns 1."""
        self.hash_store[key] = {**mapping}
        return 1

    async def hincrby(self, key: str, field: str, amount: int = 1) -> int:
        """No-op that ignores its arguments and returns 0."""
        _ = (key, field, amount)
        return 0

    async def expire(self, key: str, ttl_seconds: int) -> int:
        """No-op that ignores its arguments and returns 1."""
        _ = (key, ttl_seconds)
        return 1

    async def delete(self, *keys: str) -> int:
        """Drop *keys* from both stores and return how many keys were passed."""
        for backing in (self.hash_store, self.set_store):
            for name in keys:
                backing.pop(name, None)
        return len(keys)

    async def sadd(self, key: str, *members: str) -> int:
        """Add *members* to the set at *key*; return how many were new."""
        bucket = self.set_store.setdefault(key, set())
        fresh = set(members) - bucket
        bucket |= fresh
        return len(fresh)

    async def smembers(self, key: str) -> set[str]:
        """Return a copy of the set at *key*, or an empty set if absent."""
        return set(self.set_store.get(key, ()))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_context_messages_cache_set_get_roundtrip() -> None:
    """Messages stored with ``set`` come back from ``get`` for the same lookup."""
    messages_cache = ContextMessagesCache(
        client=_FakeCacheStore(),
        key_prefix="agent:context-messages",
        ttl_seconds=600,
    )
    window_config = MessageContextConfig(
        window_mode=ContextWindowMode.DAY,
        window_count=2,
    )
    lookup: dict[str, Any] = {
        "thread_id": "thread-1",
        "runtime_mode": "chat",
        "context_config": window_config,
    }
    seeded: list[dict[str, object]] = [
        {"role": "user", "content": "hello", "timestamp": "2026-03-25T08:00:00+00:00"}
    ]

    await messages_cache.set(**lookup, messages=seeded)
    cached_messages = await messages_cache.get(**lookup)

    assert cached_messages is not None
    assert cached_messages[0]["content"] == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_context_messages_cache_append_skips_when_not_visible() -> None:
    """Appending with ``visibility_mask=1`` leaves the cached list at one message."""
    messages_cache = ContextMessagesCache(
        client=_FakeCacheStore(),
        key_prefix="agent:context-messages",
        ttl_seconds=600,
    )
    window_config = MessageContextConfig(
        window_mode=ContextWindowMode.NUMBER,
        window_count=1,
    )
    seeded = cast(
        list[dict[str, Any]],
        [
            {
                "role": "user",
                "content": "q1",
                "timestamp": "2026-03-25T08:00:00+00:00",
            }
        ],
    )
    await messages_cache.set(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window_config,
        messages=seeded,
    )

    reply = {"role": "assistant", "content": "a1"}
    await messages_cache.append_message(
        thread_id="thread-1",
        runtime_mode="chat",
        visibility_mask=1,
        message=reply,
    )

    cached_messages = await messages_cache.get(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window_config,
    )
    assert cached_messages is not None
    assert len(cached_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_context_messages_cache_append_trims_number_window() -> None:
    """With a NUMBER window of 1, appending ``a1`` leaves only ``q1`` and ``a1``."""
    messages_cache = ContextMessagesCache(
        client=_FakeCacheStore(),
        key_prefix="agent:context-messages",
        ttl_seconds=600,
    )
    window_config = MessageContextConfig(
        window_mode=ContextWindowMode.NUMBER,
        window_count=1,
    )
    # (role, content, timestamp) for the three messages seeded into the cache.
    seeded_turns = (
        ("user", "q0", "2026-03-25T08:00:00+00:00"),
        ("assistant", "a0", "2026-03-25T08:01:00+00:00"),
        ("user", "q1", "2026-03-25T08:02:00+00:00"),
    )
    seeded = cast(
        list[dict[str, Any]],
        [
            {"role": role, "content": text, "timestamp": stamp}
            for role, text, stamp in seeded_turns
        ],
    )
    await messages_cache.set(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window_config,
        messages=seeded,
    )

    await messages_cache.append_message(
        thread_id="thread-1",
        runtime_mode="chat",
        visibility_mask=2,
        message={"role": "assistant", "content": "a1"},
    )

    cached_messages = await messages_cache.get(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window_config,
    )
    assert cached_messages is not None
    contents = [str(entry["content"]) for entry in cached_messages]
    assert contents == ["q1", "a1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_context_messages_cache_append_trims_day_window() -> None:
    """With a DAY window of 2, appending ``d0`` drops the entry from two days before."""
    messages_cache = ContextMessagesCache(
        client=_FakeCacheStore(),
        key_prefix="agent:context-messages",
        ttl_seconds=600,
    )
    window_config = MessageContextConfig(
        window_mode=ContextWindowMode.DAY,
        window_count=2,
    )

    now = datetime(2026, 3, 25, 10, 0, tzinfo=timezone.utc)
    one_day = timedelta(days=1)
    yesterday = now - one_day
    two_days_ago = now - timedelta(days=2)

    seeded = cast(
        list[dict[str, Any]],
        [
            {"role": "assistant", "content": label, "timestamp": moment.isoformat()}
            for label, moment in (("d-2", two_days_ago), ("d-1", yesterday))
        ],
    )
    await messages_cache.set(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window_config,
        messages=seeded,
    )

    todays_message = {
        "role": "assistant",
        "content": "d0",
        "timestamp": now.isoformat(),
    }
    await messages_cache.append_message(
        thread_id="thread-1",
        runtime_mode="chat",
        visibility_mask=2,
        message=todays_message,
    )

    cached_messages = await messages_cache.get(
        thread_id="thread-1",
        runtime_mode="chat",
        context_config=window_config,
    )
    assert cached_messages is not None
    contents = [str(entry["content"]) for entry in cached_messages]
    assert contents == ["d-1", "d0"]
|
||||
@@ -145,6 +145,38 @@ async def test_store_persists_tool_output_with_summary_as_content(
|
||||
assert append_kwargs["visibility_mask"] == (1 << 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_store_sets_history_only_visibility_for_automation_worker_output(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An automation worker ``TEXT_MESSAGE_END`` is appended with mask ``1 << 0``."""
    captured: dict[str, object] = {}
    _patch_repositories(
        monkeypatch,
        captured,
        SimpleNamespace(state_snapshot={}, message_count=3),
    )
    event_store = store_module.SqlAlchemyEventStore(
        session_factory=lambda: _FakeSessionCtx()
    )

    worker_end_event = {
        "type": "TEXT_MESSAGE_END",
        "threadId": "00000000-0000-0000-0000-000000000001",
        "runId": "run-auto-1",
        "messageId": "assistant-auto-1",
        "role": "assistant",
        "stage": "worker",
        "runtime_mode": "automation",
        "status": "success",
        "answer": "automation-result",
        "key_points": [],
        "result_type": "summary",
        "suggested_actions": [],
        "error": None,
    }
    await event_store.persist(worker_end_event)

    append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
    expected_mask = 1 << 0
    assert append_kwargs["content"] == "automation-result"
    assert append_kwargs["visibility_mask"] == expected_mask
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_router_step_output_for_cost_tracking(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.persistence.user_context_cache import UserContextCache
|
||||
from core.agentscope.caches.user_context_cache import UserContextCache
|
||||
from schemas.shared.user import (
|
||||
UserContext,
|
||||
parse_profile_settings,
|
||||
|
||||
@@ -28,6 +28,7 @@ async def test_tool_result_event_uses_runtime_tool_call_id() -> None:
|
||||
session_id="thread-1",
|
||||
run_id="run-1",
|
||||
stage="worker",
|
||||
runtime_mode="chat",
|
||||
emit_text_events=False,
|
||||
emit_tool_events=True,
|
||||
)
|
||||
@@ -65,3 +66,4 @@ async def test_tool_result_event_uses_runtime_tool_call_id() -> None:
|
||||
result_events = [e for e in pipeline.events if e.get("type") == "TOOL_CALL_RESULT"]
|
||||
assert len(result_events) == 1
|
||||
assert result_events[0]["tool_call_id"] == "runtime-call-123"
|
||||
assert result_events[0]["runtime_mode"] == "chat"
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.runtime_models import RouterAgentOutput
|
||||
|
||||
|
||||
def test_router_agent_output_coerces_key_entity_value_to_string() -> None:
    """An integer key-entity ``value`` is exposed as its string form after validation."""
    numeric_entity = {
        "name": "priority",
        "type": "number",
        "value": 8,
    }
    payload = {
        "normalized_task_input": {
            "user_text": "test",
            "multimodal_summary": [],
            "context_summary": "",
        },
        "key_entities": [numeric_entity],
        "constraints": [],
        "task_typing": {"primary": "planning", "secondary": []},
        "execution_mode": "onestep",
        "result_typing": {"primary": "summary", "secondary": []},
        "ui": {"ui_mode": "none", "ui_decision_reason": "test"},
    }

    parsed = RouterAgentOutput.model_validate(payload)

    first_entity = parsed.key_entities[0]
    assert first_entity.value == "8"
|
||||
@@ -0,0 +1,392 @@
|
||||
# Agent Run Cancel (Failed Semantics) Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 为 `/api/v1/agent/runs` 增加可中断能力,在用户触发 cancel 后真正停止运行中的 agent 流程,并以 `RUN_ERROR(code=RUN_CANCELED)` 结束,最终将 session 状态落为 `failed`。
|
||||
|
||||
**Architecture:** 使用“协作取消 + 主任务中断”方案:API 层写入 Redis cancel 信号,runtime 在 worker 进程内并行运行一个 watcher 协程轮询该信号,命中后先调用 active agent 的 `interrupt()` 做优雅收尾,再 `cancel()` 当前 run 主任务做硬兜底。终态统一通过 `RUN_ERROR` 事件落库,复用现有 `FAILED` 会话语义,避免数据库枚举迁移。
|
||||
|
||||
**Tech Stack:** FastAPI, TaskIQ, Redis, AgentScope, SQLAlchemy, Pytest, Ruff, BasedPyright
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 先更新协议文档(接口与事件语义)
|
||||
|
||||
**Files:**
|
||||
- Modify: `docs/protocols/agent/api-endpoints.md`
|
||||
- Modify: `docs/protocols/agent/sse-events.md`
|
||||
|
||||
**Step 1: 在 API 文档新增 cancel 端点契约**
|
||||
|
||||
在 `api-endpoints.md` 的端点清单添加:
|
||||
|
||||
```md
|
||||
| POST | `/runs/{thread_id}/cancel` | 请求取消指定 run |
|
||||
```
|
||||
|
||||
并新增章节说明:
|
||||
- 请求参数:路径参数 `thread_id` + 必填的 `runId`(建议作为 query 参数传递)
|
||||
- 返回:`202 Accepted` + `accepted: true`
|
||||
- 语义:仅表示“取消请求已接收”,不保证已即时终止
|
||||
|
||||
**Step 2: 在 SSE 文档补充取消终态语义**
|
||||
|
||||
在 `sse-events.md` 的 `RUN_ERROR` 章节补充:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "RUN_ERROR",
|
||||
"threadId": "...",
|
||||
"runId": "...",
|
||||
"message": "run canceled by user",
|
||||
"code": "RUN_CANCELED"
|
||||
}
|
||||
```
|
||||
|
||||
并明确:
|
||||
- `RUN_CANCELED` 是用户主动中断,不是系统异常
|
||||
- 本阶段仍复用 session `failed`(向后兼容)
|
||||
|
||||
**Step 3: 文档自检**
|
||||
|
||||
检查文档是否同时覆盖:
|
||||
- HTTP 行为
|
||||
- SSE 终态事件
|
||||
- 兼容策略(不引入新 session 状态)
|
||||
|
||||
**Step 4: 提交文档变更**
|
||||
|
||||
```bash
|
||||
git add docs/protocols/agent/api-endpoints.md docs/protocols/agent/sse-events.md
|
||||
git commit -m "docs: define agent run cancel API and RUN_CANCELED error semantics"
|
||||
```
|
||||
|
||||
### Task 2: 打通 API 层到队列层的 cancel 信号写入
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/schemas.py`
|
||||
- Modify: `backend/src/v1/agent/dependencies.py`
|
||||
- Modify: `backend/src/v1/agent/service.py`
|
||||
- Modify: `backend/src/v1/agent/router.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_service.py`
|
||||
- Test: `backend/tests/integration/v1/agent/test_routes.py`
|
||||
|
||||
**Step 1: 增加 cancel 接口响应 schema**
|
||||
|
||||
在 `v1/agent/schemas.py` 增加:
|
||||
|
||||
```python
|
||||
class CancelRunResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
thread_id: str = Field(alias="threadId")
|
||||
run_id: str = Field(alias="runId")
|
||||
accepted: bool
|
||||
```
|
||||
|
||||
**Step 2: 扩展队列协议接口**
|
||||
|
||||
在 `QueueClientLike` 增加:
|
||||
|
||||
```python
|
||||
async def request_cancel(self, *, thread_id: str, run_id: str, requested_by: str) -> None: ...
|
||||
```
|
||||
|
||||
**Step 3: 在 `TaskiqQueueClient` 实现 request_cancel**
|
||||
|
||||
在 `v1/agent/dependencies.py` 中新增:
|
||||
- cancel key 规范:`agent:cancel:{thread_id}:{run_id}`
|
||||
- `SET key value EX <ttl>` 写入取消信号
|
||||
- `value` 可写 json 字符串(包含 user_id/timestamp)
|
||||
|
||||
**Step 4: 在 service 层新增 cancel_run**
|
||||
|
||||
在 `v1/agent/service.py` 增加方法:
|
||||
- 校验 session owner(复用 `get_session_owner + ensure_session_owner`)
|
||||
- 调用 `self._queue.request_cancel(...)`
|
||||
- 返回 `accepted` 结果 DTO
|
||||
|
||||
**Step 5: 在 router 新增 cancel 路由**
|
||||
|
||||
在 `v1/agent/router.py` 新增:
|
||||
|
||||
```python
|
||||
@router.post("/runs/{thread_id}/cancel", response_model=CancelRunResponse, status_code=202)
|
||||
async def cancel_run(...):
|
||||
...
|
||||
```
|
||||
|
||||
约束:
|
||||
- `runId` 必填(建议 query)
|
||||
- 非 owner 返回 403
|
||||
- 参数非法返回 422
|
||||
|
||||
**Step 6: 写 service 单测(先红)**
|
||||
|
||||
在 `test_service.py` 添加:
|
||||
- owner 可发起 cancel,`queue.request_cancel` 被调用
|
||||
- 非 owner cancel 返回 403
|
||||
|
||||
**Step 7: 运行单测确认失败**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_service.py -k cancel -q`
|
||||
|
||||
Expected: 至少 1 个测试失败(新逻辑尚未实现)
|
||||
|
||||
**Step 8: 实现最小代码使测试通过**
|
||||
|
||||
按 Step 1-5 完成实现,避免额外重构。
|
||||
|
||||
**Step 9: 运行测试验证通过**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_service.py -k cancel -q`
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 10: 增加路由集成测试**
|
||||
|
||||
在 `test_routes.py` 增加:
|
||||
- `POST /api/v1/agent/runs/{thread_id}/cancel?runId=...` 返回 202
|
||||
- 响应字段别名正确(`threadId/runId/accepted`)
|
||||
|
||||
**Step 11: 运行路由测试**
|
||||
|
||||
Run: `uv run pytest backend/tests/integration/v1/agent/test_routes.py -k cancel -q`
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 12: 提交 API 层变更**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/agent/schemas.py backend/src/v1/agent/dependencies.py backend/src/v1/agent/service.py backend/src/v1/agent/router.py backend/tests/unit/v1/agent/test_service.py backend/tests/integration/v1/agent/test_routes.py
|
||||
git commit -m "feat: add agent run cancel endpoint and Redis cancel signal"
|
||||
```
|
||||
|
||||
### Task 3: runtime runner 植入取消 watcher 与优雅中断
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/core/agentscope/runtime/runner.py`
|
||||
- Test: `backend/tests/unit/core/agentscope/runtime/test_runner.py`
|
||||
|
||||
**Step 1: 在 runner.execute 增加 cancel_checker 参数**
|
||||
|
||||
更新 `execute()` 签名:
|
||||
|
||||
```python
|
||||
cancel_checker: Callable[[], Awaitable[bool]] | None = None
|
||||
```
|
||||
|
||||
并保持默认 `None` 向后兼容。
|
||||
|
||||
**Step 2: 增加 active agent 引用与锁**
|
||||
|
||||
在 `AgentScopeRunner.__init__` 增加:
|
||||
- `self._active_agent: JsonReActAgent | None = None`
|
||||
- `self._active_agent_lock = asyncio.Lock()`
|
||||
|
||||
**Step 3: 在 `_run_worker_stage` 设置 active agent 生命周期**
|
||||
|
||||
在 `agent.reply_json(...)` 前后包裹:
|
||||
- before: 记录 `self._active_agent = agent`
|
||||
- finally: 清理引用
|
||||
|
||||
**Step 4: 新增 `_watch_cancel_signal` 协程**
|
||||
|
||||
行为:
|
||||
- 循环调用 `cancel_checker()`
|
||||
- 命中后先尝试 `await active_agent.interrupt()`
|
||||
- 再 `run_task.cancel("run canceled by user")`
|
||||
- 间隔 `await asyncio.sleep(0.2)`
|
||||
|
||||
**Step 5: 在 execute 启停 watcher**
|
||||
|
||||
- `run_task = asyncio.current_task()`
|
||||
- 如果有 `cancel_checker`,`create_task(_watch_cancel_signal(...))`
|
||||
- `finally` 中停止 watcher 并 `await` 回收
|
||||
|
||||
**Step 6: 补 stage 边界取消 gate(关键)**
|
||||
|
||||
在 router 结束后、worker 开始前检查一次 `cancel_checker()`:
|
||||
- 为 true 时抛 `asyncio.CancelledError`
|
||||
|
||||
目的:防止“router 已结束但仍进入 worker”。
|
||||
|
||||
**Step 7: 写 runner 单测(先红)**
|
||||
|
||||
新增测试用例:
|
||||
- cancel 信号触发后,`execute` 抛出 `CancelledError`
|
||||
- worker 未被继续执行(或中途被中断)
|
||||
|
||||
**Step 8: 运行 runner 测试**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_runner.py -k cancel -q`
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 9: 提交 runner 变更**
|
||||
|
||||
```bash
|
||||
git add backend/src/core/agentscope/runtime/runner.py backend/tests/unit/core/agentscope/runtime/test_runner.py
|
||||
git commit -m "feat: add cooperative cancellation watcher to agentscope runner"
|
||||
```
|
||||
|
||||
### Task 4: orchestrator 与 task worker 处理 CancelledError 终态
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/core/agentscope/runtime/orchestrator.py`
|
||||
- Modify: `backend/src/core/agentscope/runtime/tasks.py`
|
||||
- Test: `backend/tests/unit/core/agentscope/runtime/test_orchestrator.py`
|
||||
- Test: `backend/tests/unit/core/agentscope/runtime/test_tasks.py`
|
||||
|
||||
**Step 1: orchestrator 单独捕获 CancelledError**
|
||||
|
||||
在 `orchestrator.run()` 添加:
|
||||
|
||||
```python
|
||||
except asyncio.CancelledError:
|
||||
await self._pipeline.emit(... RUN_ERROR code="RUN_CANCELED" ...)
|
||||
raise
|
||||
```
|
||||
|
||||
保留现有 `except Exception` 处理系统错误。
|
||||
|
||||
**Step 2: task 层构造 cancel_checker 并注入 runtime.run**
|
||||
|
||||
在 `tasks.py`:
|
||||
- 构造 key:`agent:cancel:{thread_id}:{run_id}`
|
||||
- 定义 `async def cancel_checker() -> bool: return bool(await redis.exists(key))`
|
||||
- 调用 `runtime.run(..., cancel_checker=cancel_checker)`
|
||||
|
||||
**Step 3: task 层补资源清理**
|
||||
|
||||
在 `run_agentscope_task` 的 `finally`:
|
||||
- 删除 cancel key 或缩短 TTL
|
||||
- 记录日志(仅必要字段)
|
||||
|
||||
**Step 4: 写 orchestrator 单测(先红)**
|
||||
|
||||
验证:
|
||||
- 收到 `CancelledError` 时发 `RUN_ERROR` 且 `code == "RUN_CANCELED"`
|
||||
|
||||
**Step 5: 写 tasks 单测(先红)**
|
||||
|
||||
验证:
|
||||
- runtime 收到的 `cancel_checker` 可用
|
||||
- key 命中时上抛 `CancelledError` 路径成立
|
||||
|
||||
**Step 6: 运行测试**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py -k cancel -q`
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 7: 提交 runtime 编排层变更**
|
||||
|
||||
```bash
|
||||
git add backend/src/core/agentscope/runtime/orchestrator.py backend/src/core/agentscope/runtime/tasks.py backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py
|
||||
git commit -m "fix: emit RUN_CANCELED error when run task is interrupted"
|
||||
```
|
||||
|
||||
### Task 5: 事件流与持久化一致性回归
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/tests/unit/core/agentscope/events/test_store.py`
|
||||
- Modify: `backend/tests/unit/core/agentscope/events/test_agui_codec.py`
|
||||
- Modify: `backend/tests/integration/v1/agent/test_sse_flow_live.py`
|
||||
|
||||
**Step 1: 补 event store 行为测试**
|
||||
|
||||
新增断言:
|
||||
- 当 `RUN_ERROR` 且 `code=RUN_CANCELED` 时,session 状态依然为 `FAILED`
|
||||
|
||||
**Step 2: 补 codec 测试**
|
||||
|
||||
新增断言:
|
||||
- `RUN_ERROR` 的 `code` 字段能正确透传到 wire event
|
||||
|
||||
**Step 3: 补 SSE 集成测试**
|
||||
|
||||
场景:
|
||||
- 触发 `/runs`
|
||||
- 触发 `/runs/{thread_id}/cancel`
|
||||
- SSE 最终出现 `RUN_ERROR(code=RUN_CANCELED)`
|
||||
|
||||
**Step 4: 运行事件相关测试**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py backend/tests/integration/v1/agent/test_sse_flow_live.py -k "cancel or run_error" -q`
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: 提交事件层变更**
|
||||
|
||||
```bash
|
||||
git add backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py backend/tests/integration/v1/agent/test_sse_flow_live.py
|
||||
git commit -m "test: cover RUN_CANCELED propagation across store codec and SSE"
|
||||
```
|
||||
|
||||
### Task 6: 全量验证与发布前检查
|
||||
|
||||
**Files:**
|
||||
- Modify: `docs/protocols/agent/api-endpoints.md`(如需补充最终字段)
|
||||
- Modify: `docs/protocols/agent/sse-events.md`(如需补充最终字段)
|
||||
|
||||
**Step 1: 运行受影响单元测试集合**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
uv run pytest backend/tests/unit/v1/agent/test_service.py backend/tests/unit/core/agentscope/runtime/test_runner.py backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py -q
|
||||
```
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 2: 运行受影响集成测试集合**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
uv run pytest backend/tests/integration/v1/agent/test_routes.py backend/tests/integration/v1/agent/test_sse_flow_live.py -q
|
||||
```
|
||||
|
||||
Expected: PASS
|
||||
|
||||
**Step 3: 运行静态检查**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
uv run ruff check backend/src backend/tests
|
||||
uv run basedpyright
|
||||
```
|
||||
|
||||
Expected: PASS(无新增 lint/type 错误)
|
||||
|
||||
**Step 4: 手工验证路径**
|
||||
|
||||
手工流程:
|
||||
- 发起 `/runs`
|
||||
- 立刻调用 `/runs/{thread_id}/cancel?runId=...`
|
||||
- 观察 SSE:应以 `RUN_ERROR(code=RUN_CANCELED)` 结束
|
||||
- 检查 session:`status=failed`
|
||||
|
||||
**Step 5: 最终提交**
|
||||
|
||||
```bash
|
||||
git add docs/protocols/agent/api-endpoints.md docs/protocols/agent/sse-events.md backend/src/v1/agent/*.py backend/src/core/agentscope/runtime/*.py backend/tests/unit/core/agentscope/runtime/*.py backend/tests/unit/core/agentscope/events/*.py backend/tests/unit/v1/agent/test_service.py backend/tests/integration/v1/agent/test_routes.py backend/tests/integration/v1/agent/test_sse_flow_live.py
|
||||
git commit -m "feat: support run cancellation with RUN_CANCELED failed semantics"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 风险与回滚
|
||||
|
||||
- 风险 1:cancel key 误命中导致误中断
|
||||
- 缓解:key 粒度使用 `thread_id + run_id`,并设置 TTL
|
||||
- 风险 2:中断时出现重复终态事件
|
||||
- 缓解:在 orchestrator 保证 CancelledError 只走 `RUN_ERROR` 分支,避免继续发 `RUN_FINISHED`
|
||||
- 风险 3:高并发下 Redis 轮询压力上升
|
||||
- 缓解:轮询间隔 200ms,后续按并发量评估改为 pub/sub
|
||||
|
||||
回滚策略:
|
||||
- 回滚 `router/service/dependencies` cancel 新接口
|
||||
- 回滚 `runner/orchestrator/tasks` cancel 注入逻辑
|
||||
- 保持原 `POST /runs` 与 SSE 流程不变
|
||||
@@ -364,7 +364,8 @@ cost = uncached_prompt_tokens * input_cost_per_token
|
||||
| 0 | `UI_HISTORY` | `/history` API 投影可见的消息 |
|
||||
| 1 | `CONTEXT_ASSEMBLY` | 运行时上下文装配(context assembly)可见 |
|
||||
|
||||
> 新消息入库时,`chat` 模式设置 `mask = UI_HISTORY | CONTEXT_ASSEMBLY`(值为 3),`automation` 模式设置 `mask = 0`。
|
||||
> 用户输入入库时,`chat` 模式设置 `mask = UI_HISTORY | CONTEXT_ASSEMBLY`(值为 3),`automation` 模式设置 `mask = 0`。
|
||||
> agent 运行产物入库时,`automation` 模式设置 `mask = UI_HISTORY`(值为 1),用于展示历史但不参与 context assembly。
|
||||
|
||||
### /history API
|
||||
|
||||
@@ -385,6 +386,7 @@ WHERE (visibility_mask & 2) != 0
|
||||
**影响**:
|
||||
- `chat` 模式用户输入:mask=3 → 进入 `/history` ✅,进入 context assembly ✅
|
||||
- `automation` 模式用户输入:mask=0 → 进入 `/history` ❌,进入 context assembly ❌
|
||||
- `automation` 模式 agent 输出:mask=1 → 进入 `/history` ✅,进入 context assembly ❌
|
||||
|
||||
### Automation 模式上下文注入
|
||||
|
||||
@@ -396,7 +398,8 @@ WHERE (visibility_mask & 2) != 0
|
||||
|------|--------|--------------|
|
||||
| Pipeline | `router` -> `worker` | `router` -> `worker` |
|
||||
| 用户输入 visibility_mask | `UI_HISTORY \| CONTEXT_ASSEMBLY` | `0` |
|
||||
| 进入 /history | ✅ | ❌ |
|
||||
| agent 输出 visibility_mask | `UI_HISTORY \| CONTEXT_ASSEMBLY`(memory stage 仅 `UI_HISTORY`) | `UI_HISTORY` |
|
||||
| 进入 /history | ✅ | ✅(仅 agent 输出) |
|
||||
| 进入 context assembly | ✅(自动) | ❌(通过 run_input 注入) |
|
||||
| enabled_tools 来源 | `system_agents.yaml` worker 配置 | `AutomationJob.config.enabled_tools` |
|
||||
| context 配置来源 | `system_agents.yaml` router context_messages | `AutomationJob.config.context` |
|
||||
|
||||
Reference in New Issue
Block a user