From 599c597e69560908ed195724acd33c0412475471 Mon Sep 17 00:00:00 2001 From: qzl Date: Wed, 25 Mar 2026 17:41:55 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=20agentscope=20?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E6=9E=B6=E6=9E=84=EF=BC=8C=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E5=92=8C=E9=99=84=E4=BB=B6=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 2 +- .../src/core/agentscope/caches/__init__.py | 21 + .../caches/attachment_content_cache.py | 96 +++++ .../caches/context_messages_cache.py | 281 +++++++++++++ .../user_context_cache.py | 63 +-- backend/src/core/agentscope/events/store.py | 103 ++++- .../core/agentscope/persistence/__init__.py | 9 - backend/src/core/agentscope/runtime/runner.py | 25 ++ .../core/agentscope/runtime/stage_emitter.py | 3 + backend/src/core/agentscope/runtime/tasks.py | 91 +++- backend/src/core/config/settings.py | 9 + backend/src/schemas/agent/runtime_models.py | 13 +- backend/src/services/caches/__init__.py | 4 + backend/src/services/caches/factory.py | 13 + backend/src/services/caches/interfaces.py | 19 + backend/src/services/caches/redis_store.py | 80 ++++ backend/src/v1/agent/service.py | 50 ++- backend/src/v1/users/service.py | 4 +- .../caches/test_context_messages_cache.py | 228 ++++++++++ .../unit/core/agentscope/events/test_store.py | 32 ++ .../persistence/test_user_context_cache.py | 2 +- .../agentscope/runtime/test_stage_emitter.py | 2 + .../unit/schemas/agent/test_runtime_models.py | 38 ++ .../2026-03-25-agent-run-cancel-failed.md | 392 ++++++++++++++++++ docs/protocols/agent/sse-events.md | 7 +- 25 files changed, 1509 insertions(+), 78 deletions(-) create mode 100644 backend/src/core/agentscope/caches/__init__.py create mode 100644 backend/src/core/agentscope/caches/attachment_content_cache.py create mode 100644 backend/src/core/agentscope/caches/context_messages_cache.py rename backend/src/core/agentscope/{persistence => 
caches}/user_context_cache.py (81%) delete mode 100644 backend/src/core/agentscope/persistence/__init__.py create mode 100644 backend/src/services/caches/__init__.py create mode 100644 backend/src/services/caches/factory.py create mode 100644 backend/src/services/caches/interfaces.py create mode 100644 backend/src/services/caches/redis_store.py create mode 100644 backend/tests/unit/core/agentscope/caches/test_context_messages_cache.py create mode 100644 backend/tests/unit/schemas/agent/test_runtime_models.py create mode 100644 docs/plans/2026-03-25-agent-run-cancel-failed.md diff --git a/.env.example b/.env.example index c4bc8c3..cbd5337 100644 --- a/.env.example +++ b/.env.example @@ -92,5 +92,5 @@ SOCIAL_APP_VERSION__DOWNLOAD_BASE_URL= ############ # Test相关 ############ -SOCIAL_TEST__PHONE=+8613812345678 +SOCIAL_TEST__PHONE=8613812345678 SOCIAL_TEST__PASSWORD=Test@123456 diff --git a/backend/src/core/agentscope/caches/__init__.py b/backend/src/core/agentscope/caches/__init__.py new file mode 100644 index 0000000..c404a36 --- /dev/null +++ b/backend/src/core/agentscope/caches/__init__.py @@ -0,0 +1,21 @@ +from .attachment_content_cache import ( + AttachmentContentCache, + create_attachment_content_cache, +) +from .context_messages_cache import ( + ContextMessagesCache, + create_context_messages_cache, +) +from .user_context_cache import ( + UserContextCache, + create_user_context_cache, +) + +__all__ = [ + "AttachmentContentCache", + "ContextMessagesCache", + "UserContextCache", + "create_attachment_content_cache", + "create_context_messages_cache", + "create_user_context_cache", +] diff --git a/backend/src/core/agentscope/caches/attachment_content_cache.py b/backend/src/core/agentscope/caches/attachment_content_cache.py new file mode 100644 index 0000000..2855270 --- /dev/null +++ b/backend/src/core/agentscope/caches/attachment_content_cache.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from core.config.settings import config +from core.logging 
import get_logger +from services.caches import CacheStore, get_cache_store + +logger = get_logger("core.agentscope.caches.attachment_content_cache") + + +class AttachmentContentCache: + def __init__( + self, + *, + client: CacheStore, + key_prefix: str, + ttl_seconds: int, + max_base64_bytes: int, + ) -> None: + self._client = client + self._key_prefix = key_prefix + self._ttl_seconds = ttl_seconds + self._max_base64_bytes = max_base64_bytes + + async def get( + self, + *, + bucket: str, + path: str, + mime_type: str, + ) -> str | None: + key = self._key(bucket=bucket, path=path, mime_type=mime_type) + try: + raw = await self._client.hgetall(key) + except Exception as exc: + logger.warning( + "Failed to read attachment content cache", + bucket=bucket, + path=path, + error=str(exc), + ) + return None + + payload = raw.get("base64") + if not isinstance(payload, str) or not payload: + return None + return payload + + async def set( + self, + *, + bucket: str, + path: str, + mime_type: str, + base64_data: str, + ) -> None: + encoded_bytes = len(base64_data.encode("utf-8")) + if encoded_bytes > self._max_base64_bytes: + logger.info( + "Skip attachment cache write due to size limit", + bucket=bucket, + path=path, + encoded_bytes=encoded_bytes, + max_bytes=self._max_base64_bytes, + ) + return + + key = self._key(bucket=bucket, path=path, mime_type=mime_type) + try: + await self._client.hset( + key, + mapping={ + "base64": base64_data, + "mime_type": mime_type, + }, + ) + await self._client.expire(key, self._ttl_seconds) + except Exception as exc: + logger.warning( + "Failed to write attachment content cache", + bucket=bucket, + path=path, + error=str(exc), + ) + + def _key(self, *, bucket: str, path: str, mime_type: str) -> str: + return f"{self._key_prefix}:{bucket}:{path}:mime:{mime_type}" + + +def create_attachment_content_cache() -> AttachmentContentCache: + runtime = config.agent_runtime + return AttachmentContentCache( + client=get_cache_store(), + 
key_prefix=runtime.attachment_content_cache_prefix, + ttl_seconds=runtime.attachment_content_cache_ttl_seconds, + max_base64_bytes=runtime.attachment_content_cache_max_base64_bytes, + ) diff --git a/backend/src/core/agentscope/caches/context_messages_cache.py b/backend/src/core/agentscope/caches/context_messages_cache.py new file mode 100644 index 0000000..eab6e3b --- /dev/null +++ b/backend/src/core/agentscope/caches/context_messages_cache.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +from datetime import datetime, timezone +import json + +from core.config.settings import config +from core.logging import get_logger +from schemas.agent.visibility import SystemVisibilityBit, bit_mask +from schemas.domain.automation import ContextWindowMode, MessageContextConfig +from services.caches import CacheStore, get_cache_store + +logger = get_logger("core.agentscope.caches.context_messages_cache") + + +class ContextMessagesCache: + def __init__( + self, + *, + client: CacheStore, + key_prefix: str, + ttl_seconds: int, + ) -> None: + self._client = client + self._key_prefix = key_prefix + self._ttl_seconds = ttl_seconds + + async def get( + self, + *, + thread_id: str, + runtime_mode: str, + context_config: MessageContextConfig, + ) -> list[dict[str, object]] | None: + key = self._cache_key( + thread_id=thread_id, + runtime_mode=runtime_mode, + context_config=context_config, + ) + try: + raw = await self._client.hgetall(key) + except Exception as exc: + logger.warning( + "Failed to read context messages cache", + thread_id=thread_id, + error=str(exc), + ) + return None + + payload = raw.get("payload") + if not isinstance(payload, str) or not payload: + return None + + try: + decoded = json.loads(payload) + except Exception: + await self._safe_delete(key) + return None + + if not isinstance(decoded, dict): + await self._safe_delete(key) + return None + + messages = decoded.get("messages") + if not isinstance(messages, list): + await self._safe_delete(key) + return 
None + + return [item for item in messages if isinstance(item, dict)] + + async def set( + self, + *, + thread_id: str, + runtime_mode: str, + context_config: MessageContextConfig, + messages: list[dict[str, object]], + ) -> None: + key = self._cache_key( + thread_id=thread_id, + runtime_mode=runtime_mode, + context_config=context_config, + ) + index_key = self._index_key(thread_id=thread_id, runtime_mode=runtime_mode) + payload = json.dumps( + { + "window_mode": context_config.window_mode.value, + "window_count": int(context_config.window_count), + "messages": [item for item in messages if isinstance(item, dict)], + }, + ensure_ascii=True, + separators=(",", ":"), + ) + try: + await self._client.hset(key, mapping={"payload": payload}) + await self._client.expire(key, self._ttl_seconds) + await self._client.sadd(index_key, key) + await self._client.expire(index_key, self._ttl_seconds) + except Exception as exc: + logger.warning( + "Failed to write context messages cache", + thread_id=thread_id, + error=str(exc), + ) + + async def append_message( + self, + *, + thread_id: str, + runtime_mode: str, + visibility_mask: int, + message: dict[str, object], + ) -> None: + if not self._is_context_visible(visibility_mask=visibility_mask): + return + + index_key = self._index_key(thread_id=thread_id, runtime_mode=runtime_mode) + try: + keys = await self._client.smembers(index_key) + except Exception as exc: + logger.warning( + "Failed to read context cache index", + thread_id=thread_id, + error=str(exc), + ) + return + + if not keys: + return + + normalized = self._normalize_message(message) + for key in keys: + try: + raw = await self._client.hgetall(key) + payload_raw = raw.get("payload") + if not isinstance(payload_raw, str) or not payload_raw: + continue + decoded = json.loads(payload_raw) + if not isinstance(decoded, dict): + continue + + mode = decoded.get("window_mode") + count = decoded.get("window_count") + messages_raw = decoded.get("messages") + if ( + not 
isinstance(mode, str) + or not isinstance(count, int) + or not isinstance(messages_raw, list) + ): + continue + + messages = [item for item in messages_raw if isinstance(item, dict)] + messages.append(dict(normalized)) + trimmed = self._trim_messages(messages=messages, mode=mode, count=count) + + next_payload = json.dumps( + { + "window_mode": mode, + "window_count": count, + "messages": trimmed, + }, + ensure_ascii=True, + separators=(",", ":"), + ) + await self._client.hset(key, mapping={"payload": next_payload}) + await self._client.expire(key, self._ttl_seconds) + except Exception as exc: + logger.warning( + "Failed to append context cache message", + key=key, + thread_id=thread_id, + error=str(exc), + ) + + def _cache_key( + self, + *, + thread_id: str, + runtime_mode: str, + context_config: MessageContextConfig, + ) -> str: + return ( + f"{self._key_prefix}:{thread_id}:rm:{runtime_mode}:" + f"wm:{context_config.window_mode.value}:wc:{int(context_config.window_count)}" + ) + + def _index_key(self, *, thread_id: str, runtime_mode: str) -> str: + return f"{self._key_prefix}:index:{runtime_mode}:{thread_id}" + + async def _safe_delete(self, key: str) -> None: + try: + await self._client.delete(key) + except Exception: + return + + def _trim_messages( + self, + *, + messages: list[dict[str, object]], + mode: str, + count: int, + ) -> list[dict[str, object]]: + safe_count = max(int(count), 1) + if mode == ContextWindowMode.NUMBER.value: + return self._trim_by_user_window(messages=messages, count=safe_count) + return self._trim_by_day_window(messages=messages, count=safe_count) + + def _trim_by_user_window( + self, + *, + messages: list[dict[str, object]], + count: int, + ) -> list[dict[str, object]]: + selected_reversed: list[dict[str, object]] = [] + user_count = 0 + for item in reversed(messages): + selected_reversed.append(item) + role = item.get("role") + if isinstance(role, str) and role == "user": + user_count += 1 + if user_count >= count: + break + 
selected_reversed.reverse() + return selected_reversed + + def _trim_by_day_window( + self, + *, + messages: list[dict[str, object]], + count: int, + ) -> list[dict[str, object]]: + selected_reversed: list[dict[str, object]] = [] + days_seen: list[str] = [] + for item in reversed(messages): + day_value = self._extract_day(item) + if day_value not in days_seen: + if len(days_seen) >= count: + break + days_seen.append(day_value) + selected_reversed.append(item) + selected_reversed.reverse() + return selected_reversed + + @staticmethod + def _extract_day(message: dict[str, object]) -> str: + raw = message.get("timestamp") + if isinstance(raw, str) and raw: + normalized = raw.replace("Z", "+00:00") + try: + return datetime.fromisoformat(normalized).date().isoformat() + except ValueError: + pass + return datetime.now(timezone.utc).date().isoformat() + + @staticmethod + def _normalize_message(message: dict[str, object]) -> dict[str, object]: + normalized: dict[str, object] = { + "role": str(message.get("role") or "assistant"), + "content": str(message.get("content") or ""), + "timestamp": str( + message.get("timestamp") + or datetime.now(timezone.utc).isoformat(timespec="seconds") + ), + } + metadata = message.get("metadata") + if isinstance(metadata, dict): + normalized["metadata"] = metadata + return normalized + + @staticmethod + def _is_context_visible(*, visibility_mask: int) -> bool: + required = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)) + return (max(int(visibility_mask), 0) & required) != 0 + + +def create_context_messages_cache() -> ContextMessagesCache: + runtime = config.agent_runtime + return ContextMessagesCache( + client=get_cache_store(), + key_prefix=runtime.context_messages_cache_prefix, + ttl_seconds=runtime.context_messages_cache_ttl_seconds, + ) diff --git a/backend/src/core/agentscope/persistence/user_context_cache.py b/backend/src/core/agentscope/caches/user_context_cache.py similarity index 81% rename from 
backend/src/core/agentscope/persistence/user_context_cache.py rename to backend/src/core/agentscope/caches/user_context_cache.py index a040ed8..2e18199 100644 --- a/backend/src/core/agentscope/persistence/user_context_cache.py +++ b/backend/src/core/agentscope/caches/user_context_cache.py @@ -1,49 +1,26 @@ from __future__ import annotations -import inspect import json from collections.abc import Iterable -from typing import Any, Protocol +from typing import Any from uuid import UUID -import redis.asyncio as redis from core.config.settings import config from core.logging import get_logger from schemas.shared.user import ( UserContext, parse_profile_settings, ) +from services.caches import CacheStore, get_cache_store -logger = get_logger("core.agentscope.persistence.user_context_cache") - - -class RedisHashClient(Protocol): - def hgetall(self, name: str, /) -> Any: ... - - def hset(self, name: str, /, *args: Any, **kwargs: Any) -> Any: ... - - def hincrby(self, name: str, key: str, amount: int = 1, /) -> Any: ... - - def expire(self, name: str, time: int, /) -> Any: ... - - def delete(self, *names: str) -> Any: ... - - def sadd(self, name: str, *values: str) -> Any: ... - - def smembers(self, name: str) -> Any: ... 
- - -async def _maybe_await(value: Any) -> Any: - if inspect.isawaitable(value): - return await value - return value +logger = get_logger("core.agentscope.caches.user_context_cache") class UserContextCache: def __init__( self, *, - client: RedisHashClient, + client: CacheStore, key_prefix: str, ttl_seconds: int, max_turns: int, @@ -56,7 +33,7 @@ class UserContextCache: async def get(self, *, session_id: UUID) -> UserContext | None: key = self._key(session_id) try: - raw = await _maybe_await(self._client.hgetall(key)) + raw = await self._client.hgetall(key) except Exception as exc: logger.warning( "Failed to read user context cache", @@ -107,18 +84,16 @@ class UserContextCache: index_key = self._user_sessions_key(user_id) payload = self._serialize(context) try: - await _maybe_await( - self._client.hset( - key, - mapping={ - "payload": payload, - "turns_used": "0", - }, - ) + await self._client.hset( + key, + mapping={ + "payload": payload, + "turns_used": "0", + }, ) - await _maybe_await(self._client.expire(key, self._ttl_seconds)) - await _maybe_await(self._client.sadd(index_key, key)) - await _maybe_await(self._client.expire(index_key, self._ttl_seconds)) + await self._client.expire(key, self._ttl_seconds) + await self._client.sadd(index_key, key) + await self._client.expire(index_key, self._ttl_seconds) except Exception as exc: logger.warning( "Failed to write user context cache", @@ -130,7 +105,7 @@ class UserContextCache: async def invalidate_user(self, *, user_id: UUID) -> int: index_key = self._user_sessions_key(user_id) try: - members_raw = await _maybe_await(self._client.smembers(index_key)) + members_raw = await self._client.smembers(index_key) except Exception as exc: logger.warning( "Failed to read user context cache index", @@ -147,7 +122,7 @@ class UserContextCache: deleted = 0 try: - deleted_raw = await _maybe_await(self._client.delete(*members)) + deleted_raw = await self._client.delete(*members) deleted = self._parse_int(deleted_raw) except 
Exception as exc: logger.warning( @@ -205,7 +180,7 @@ class UserContextCache: async def _safe_delete(self, key: str) -> None: try: - await _maybe_await(self._client.delete(key)) + await self._client.delete(key) except Exception as exc: logger.warning( "Failed to delete user context cache key", key=key, error=str(exc) @@ -214,7 +189,7 @@ class UserContextCache: async def _safe_hincrby(self, key: str, field: str, amount: int) -> None: try: - await _maybe_await(self._client.hincrby(key, field, amount)) + await self._client.hincrby(key, field, amount) except Exception as exc: logger.warning( "Failed to update user context cache usage", @@ -272,7 +247,7 @@ class UserContextCache: def create_user_context_cache() -> UserContextCache: - client = redis.from_url(config.redis.url, decode_responses=True) + client = get_cache_store() runtime_settings = config.agent_runtime return UserContextCache( client=client, diff --git a/backend/src/core/agentscope/events/store.py b/backend/src/core/agentscope/events/store.py index 3113785..2d668fd 100644 --- a/backend/src/core/agentscope/events/store.py +++ b/backend/src/core/agentscope/events/store.py @@ -1,11 +1,16 @@ from __future__ import annotations +from datetime import datetime, timezone from decimal import Decimal, InvalidOperation from typing import Any, Callable, Protocol from uuid import UUID +from core.agentscope.caches.context_messages_cache import ( + create_context_messages_cache, +) from core.agentscope.events.persistence import MessageRepository, SessionRepository from core.logging import get_logger +from schemas.agent.forwarded_props import RuntimeMode from schemas.enums import AgentChatMessageRole, AgentChatSessionStatus from schemas.agent.system_agent import AgentType from schemas.agent.runtime_models import AgentOutput, RouterAgentOutput, ToolAgentOutput @@ -174,7 +179,10 @@ class SqlAlchemyEventStore: if locked_session is None: return seq = int(getattr(locked_session, "message_count", 0) or 0) + 1 - await 
message_repo.append_message( + visibility_mask = self._resolve_stage_visibility_mask( + event=event, + ) + persisted = await message_repo.append_message( session_id=session_id, seq=seq, role=role, @@ -186,9 +194,16 @@ class SqlAlchemyEventStore: output_tokens=output_tokens, cost=cost, latency_ms=latency_ms, - visibility_mask=self._resolve_stage_visibility_mask( - event=event, - ), + visibility_mask=visibility_mask, + ) + await self._append_context_cache_message( + session_id=session_id, + event=event, + visibility_mask=visibility_mask, + role=role.value, + content=content, + metadata=metadata_model.model_dump(mode="json", exclude_none=True), + timestamp=self._resolve_message_timestamp(persisted), ) current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING) @@ -339,16 +354,26 @@ class SqlAlchemyEventStore: if locked_session is None: return seq = int(getattr(locked_session, "message_count", 0) or 0) + 1 - await message_repo.append_message( + visibility_mask = self._resolve_stage_visibility_mask( + event=event, + ) + persisted = await message_repo.append_message( session_id=session_id, seq=seq, role=AgentChatMessageRole.TOOL, content=content, tool_name=tool_output.tool_name, metadata=metadata_model.model_dump(mode="json", exclude_none=True), - visibility_mask=self._resolve_stage_visibility_mask( - event=event, - ), + visibility_mask=visibility_mask, + ) + await self._append_context_cache_message( + session_id=session_id, + event=event, + visibility_mask=visibility_mask, + role=AgentChatMessageRole.TOOL.value, + content=content, + metadata=metadata_model.model_dump(mode="json", exclude_none=True), + timestamp=self._resolve_message_timestamp(persisted), ) current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING) @@ -377,6 +402,13 @@ class SqlAlchemyEventStore: *, event: dict[str, Any], ) -> int: + runtime_mode = self._event_value(event, "runtime_mode") + if ( + isinstance(runtime_mode, str) + and runtime_mode.strip().lower() 
== RuntimeMode.AUTOMATION.value + ): + return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) + raw_stage = self._event_value(event, "stage") if not isinstance(raw_stage, str): return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) @@ -387,6 +419,61 @@ class SqlAlchemyEventStore: bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY) ) + async def _append_context_cache_message( + self, + *, + session_id: UUID, + event: dict[str, Any], + visibility_mask: int, + role: str, + content: str, + metadata: dict[str, object] | None, + timestamp: str, + ) -> None: + message_payload: dict[str, object] = { + "role": role, + "content": content, + "timestamp": timestamp, + } + if isinstance(metadata, dict): + message_payload["metadata"] = metadata + + try: + context_cache = create_context_messages_cache() + await context_cache.append_message( + thread_id=str(session_id), + runtime_mode=self._resolve_runtime_mode(event=event), + visibility_mask=visibility_mask, + message=message_payload, + ) + except Exception as exc: + self._logger.warning( + "Failed to append context cache message from event", + thread_id=str(session_id), + error=str(exc), + ) + + @staticmethod + def _resolve_runtime_mode(*, event: dict[str, Any]) -> str: + raw = event.get("runtime_mode") + if isinstance(raw, str): + normalized = raw.strip().lower() + if normalized: + return normalized + return RuntimeMode.CHAT.value + + @staticmethod + def _resolve_message_timestamp(message: Any) -> str: + created_at = getattr(message, "created_at", None) + if isinstance(created_at, str) and created_at: + return created_at + if isinstance(created_at, datetime): + try: + return created_at.astimezone(timezone.utc).isoformat() + except Exception: + pass + return datetime.now(timezone.utc).isoformat(timespec="seconds") + async def _update_session_state( self, *, diff --git a/backend/src/core/agentscope/persistence/__init__.py b/backend/src/core/agentscope/persistence/__init__.py deleted file mode 100644 index 62c5c0a..0000000 --- 
a/backend/src/core/agentscope/persistence/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from core.agentscope.persistence.user_context_cache import ( - UserContextCache, - create_user_context_cache, -) - -__all__ = [ - "UserContextCache", - "create_user_context_cache", -] diff --git a/backend/src/core/agentscope/runtime/runner.py b/backend/src/core/agentscope/runtime/runner.py index 7cc9ef8..f66bfc2 100644 --- a/backend/src/core/agentscope/runtime/runner.py +++ b/backend/src/core/agentscope/runtime/runner.py @@ -29,7 +29,9 @@ from models.llm_factory import LlmFactory from models.system_agents import SystemAgents from schemas.agent.forwarded_props import ( ClientTimeContext, + RuntimeMode, parse_forwarded_props_client_time, + parse_forwarded_props_runtime_mode, ) from schemas.agent.runtime_models import ( RouterAgentOutput, @@ -76,6 +78,7 @@ class AgentScopeRunner: ) -> dict[str, Any]: owner_id = UUID(user_context.id) runtime_client_time = self._resolve_runtime_client_time(run_input=run_input) + runtime_mode = self._resolve_runtime_mode(run_input=run_input) async with AsyncSessionLocal() as session: router_config = await self._load_stage_config( @@ -99,6 +102,7 @@ class AgentScopeRunner: context_messages=context_messages, stage_config=router_config, runtime_client_time=runtime_client_time, + runtime_mode=runtime_mode, user_memory=user_memory, ) worker_output = await self._execute_worker_step( @@ -109,6 +113,7 @@ class AgentScopeRunner: toolkit=worker_toolkit, stage_config=worker_config, runtime_client_time=runtime_client_time, + runtime_mode=runtime_mode, work_memory=work_memory, ) return { @@ -171,6 +176,7 @@ class AgentScopeRunner: context_messages: list[Msg], stage_config: SystemAgentRuntimeConfig, runtime_client_time: ClientTimeContext | None, + runtime_mode: RuntimeMode, user_memory: UserMemoryContent | None, ) -> RouterAgentOutput: await self._emit_step_event( @@ -178,6 +184,7 @@ class AgentScopeRunner: run_input=run_input, step_name=AgentType.ROUTER.value, 
event_type="STEP_STARTED", + runtime_mode=runtime_mode, ) router_result = await self._run_router_stage( user_context=user_context, @@ -193,6 +200,7 @@ class AgentScopeRunner: run_input=run_input, step_name=AgentType.ROUTER.value, event_type="STEP_FINISHED", + runtime_mode=runtime_mode, extra_event={ "_router_persist": { "router_output": router_output.model_dump( @@ -214,6 +222,7 @@ class AgentScopeRunner: toolkit: Any, stage_config: SystemAgentRuntimeConfig, runtime_client_time: ClientTimeContext | None, + runtime_mode: RuntimeMode, work_memory: WorkProfileContent | None, ) -> WorkerAgentOutputLite: worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode) @@ -222,6 +231,7 @@ class AgentScopeRunner: run_input=run_input, step_name=AgentType.WORKER.value, event_type="STEP_STARTED", + runtime_mode=runtime_mode, ) worker_result = await self._run_worker_stage( user_context=user_context, @@ -234,6 +244,7 @@ class AgentScopeRunner: worker_output_model=worker_output_model, pipeline=pipeline, runtime_client_time=runtime_client_time, + runtime_mode=runtime_mode, work_memory=work_memory, ) worker_output = worker_output_model.model_validate(worker_result.payload) @@ -242,6 +253,7 @@ class AgentScopeRunner: run_input=run_input, step_name=AgentType.WORKER.value, event_type="STEP_FINISHED", + runtime_mode=runtime_mode, ) return worker_output @@ -332,6 +344,7 @@ class AgentScopeRunner: worker_output_model: type[WorkerAgentOutputLite], pipeline: PipelineLike, runtime_client_time: ClientTimeContext | None, + runtime_mode: RuntimeMode, work_memory: WorkProfileContent | None, ) -> StageExecutionResult: tracking_model = self._build_model(stage_config=stage_config) @@ -340,6 +353,7 @@ class AgentScopeRunner: session_id=run_input.thread_id, run_id=run_input.run_id, stage=stage_config.agent_type.value, + runtime_mode=runtime_mode.value, emit_text_events=True, emit_tool_events=True, ) @@ -437,12 +451,14 @@ class AgentScopeRunner: run_input: RunAgentInput, step_name: str, 
event_type: str, + runtime_mode: RuntimeMode, extra_event: dict[str, Any] | None = None, ) -> None: payload: dict[str, Any] = { "type": event_type, "threadId": run_input.thread_id, "runId": run_input.run_id, + "runtime_mode": runtime_mode.value, "stepName": step_name, } if extra_event: @@ -459,6 +475,15 @@ class AgentScopeRunner: getattr(run_input, "forwarded_props", None) ) + @staticmethod + def _resolve_runtime_mode(*, run_input: RunAgentInput) -> RuntimeMode: + try: + return parse_forwarded_props_runtime_mode( + getattr(run_input, "forwarded_props", None) + ) + except ValueError: + return RuntimeMode.CHAT + @staticmethod def _resolve_provider_api_key(*, factory_name: str) -> str: normalized_factory_name = factory_name.strip().upper() diff --git a/backend/src/core/agentscope/runtime/stage_emitter.py b/backend/src/core/agentscope/runtime/stage_emitter.py index fb9e7f1..c8a157c 100644 --- a/backend/src/core/agentscope/runtime/stage_emitter.py +++ b/backend/src/core/agentscope/runtime/stage_emitter.py @@ -20,6 +20,7 @@ class PipelineStageEmitter: session_id: str, run_id: str, stage: str, + runtime_mode: str, emit_text_events: bool, emit_tool_events: bool, ) -> None: @@ -27,6 +28,7 @@ class PipelineStageEmitter: self._session_id = session_id self._run_id = run_id self._stage = stage + self._runtime_mode = runtime_mode self._emit_text_events = emit_text_events self._emit_tool_events = emit_tool_events self._emitted_tool_calls: set[str] = set() @@ -127,6 +129,7 @@ class PipelineStageEmitter: "type": event_type, "threadId": self._session_id, "runId": self._run_id, + "runtime_mode": self._runtime_mode, **payload, }, ) diff --git a/backend/src/core/agentscope/runtime/tasks.py b/backend/src/core/agentscope/runtime/tasks.py index 9798ecc..4c27621 100644 --- a/backend/src/core/agentscope/runtime/tasks.py +++ b/backend/src/core/agentscope/runtime/tasks.py @@ -6,6 +6,13 @@ from typing import Any, cast from uuid import UUID from agentscope.message import Msg +from 
core.agentscope.caches import create_user_context_cache +from core.agentscope.caches.attachment_content_cache import ( + create_attachment_content_cache, +) +from core.agentscope.caches.context_messages_cache import ( + create_context_messages_cache, +) from core.agentscope.events import ( AgentScopeAgUiCodec, AgentScopeEventPipeline, @@ -20,12 +27,16 @@ from core.config.settings import config from core.db.session import AsyncSessionLocal from core.logging import get_logger from core.taskiq.app import worker_agent_broker, worker_automation_broker +from schemas.agent.forwarded_props import ( + RuntimeMode, + parse_forwarded_props_runtime_mode, +) from schemas.domain.automation import MessageContextConfig, RuntimeConfig -from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent from schemas.domain.chat_message import ( AgentChatMessageMetadata, extract_user_message_attachments, ) +from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent from schemas.shared.user import UserContext from services.base.redis import get_or_init_redis_client from services.base.supabase import supabase_service @@ -73,28 +84,56 @@ async def _build_user_context( *, owner_id: UUID, session: Any, + session_id: str, ) -> UserContext: + cache = create_user_context_cache() + cached = await cache.get(session_id=UUID(session_id)) + if cached: + return cached + current_user = CurrentUser(id=owner_id) user_service = get_user_service(session=session, user=current_user) - return await user_service.get_me() + user_context = await user_service.get_me() + + await cache.set(session_id=UUID(session_id), context=user_context) + return user_context async def _build_recent_context_messages( *, session: Any, thread_id: str, + runtime_mode: RuntimeMode = RuntimeMode.CHAT, context_config: "MessageContextConfig", ) -> list[Msg]: - context_service = AgentContextService(repository=AgentRepository(session)) - result = await context_service.load_context_messages( + 
context_cache = create_context_messages_cache() + attachment_cache = create_attachment_content_cache() + raw_messages = await context_cache.get( thread_id=thread_id, + runtime_mode=runtime_mode.value, context_config=context_config, ) - if not result: - return [] + if raw_messages is None: + context_service = AgentContextService(repository=AgentRepository(session)) + result = await context_service.load_context_messages( + thread_id=thread_id, + context_config=context_config, + ) + if not result: + return [] + + messages_obj = result.get("messages") + if not isinstance(messages_obj, list): + return [] + raw_messages = [item for item in messages_obj if isinstance(item, dict)] + await context_cache.set( + thread_id=thread_id, + runtime_mode=runtime_mode.value, + context_config=context_config, + messages=raw_messages, + ) - raw_messages: list[dict[str, object]] = result.get("messages") or [] if not raw_messages: return [] @@ -120,6 +159,24 @@ async def _build_recent_context_messages( :_MAX_CONTEXT_ATTACHMENTS ] for attachment in attachments: + mime_type = attachment.mime_type or "image/png" + cached_b64 = await attachment_cache.get( + bucket=attachment.bucket, + path=attachment.path, + mime_type=mime_type, + ) + if cached_b64: + image_blocks.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": cached_b64, + }, + } + ) + continue try: image_bytes = await supabase_service.download_bytes( bucket=attachment.bucket, @@ -128,12 +185,18 @@ async def _build_recent_context_messages( except Exception: continue b64_data = base64.b64encode(image_bytes).decode("utf-8") + await attachment_cache.set( + bucket=attachment.bucket, + path=attachment.path, + mime_type=mime_type, + base64_data=b64_data, + ) image_blocks.append( { "type": "image", "source": { "type": "base64", - "media_type": attachment.mime_type or "image/png", + "media_type": mime_type, "data": b64_data, }, } @@ -183,6 +246,13 @@ async def run_agentscope_task(command: 
dict[str, Any]) -> dict[str, object]: raise ValueError("run_input is required") run_input = parse_run_input(run_input_raw) + runtime_mode = RuntimeMode.CHAT + try: + runtime_mode = parse_forwarded_props_runtime_mode( + getattr(run_input, "forwarded_props", None) + ) + except ValueError: + runtime_mode = RuntimeMode.CHAT runtime_config = RuntimeConfig.model_validate(runtime_config_raw or {}) thread_id = run_input.thread_id run_id = run_input.run_id @@ -195,7 +265,9 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]: async with AsyncSessionLocal() as session: current_user = CurrentUser(id=owner_id) - user_context = await _build_user_context(owner_id=owner_id, session=session) + user_context = await _build_user_context( + owner_id=owner_id, session=session, session_id=thread_id + ) memories_service = MemoriesService( repository=SQLAlchemyMemoriesRepository(session), session=session, @@ -226,6 +298,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]: context_messages = await _build_recent_context_messages( session=session, thread_id=thread_id, + runtime_mode=runtime_mode, context_config=runtime_config.context, ) diff --git a/backend/src/core/config/settings.py b/backend/src/core/config/settings.py index 5dcd117..46acb3a 100644 --- a/backend/src/core/config/settings.py +++ b/backend/src/core/config/settings.py @@ -161,6 +161,15 @@ class AgentRuntimeSettings(BaseModel): user_context_cache_prefix: str = "agent:user-context" user_context_cache_ttl_seconds: int = Field(default=600, ge=60, le=86400) user_context_cache_max_turns: int = Field(default=6, ge=1, le=100) + context_messages_cache_prefix: str = "agent:context-messages" + context_messages_cache_ttl_seconds: int = Field(default=600, ge=30, le=86400) + attachment_content_cache_prefix: str = "agent:attachment-content" + attachment_content_cache_ttl_seconds: int = Field(default=1800, ge=30, le=86400) + attachment_content_cache_max_base64_bytes: int = Field( + 
default=6 * 1024 * 1024, + ge=1024, + le=64 * 1024 * 1024, + ) class AutomationSchedulerSettings(BaseModel): diff --git a/backend/src/schemas/agent/runtime_models.py b/backend/src/schemas/agent/runtime_models.py index 37356c5..c5317d4 100644 --- a/backend/src/schemas/agent/runtime_models.py +++ b/backend/src/schemas/agent/runtime_models.py @@ -3,7 +3,7 @@ from __future__ import annotations from enum import Enum from typing import Any -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from schemas.agent.ui_hints import UiHintsPayload @@ -86,6 +86,17 @@ class KeyEntity(BaseModel): type: str value: str | None = None + @field_validator("value", mode="before") + @classmethod + def normalize_value(cls, value: object) -> object: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, bool | int | float): + return str(value) + return value + class ConstraintItem(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/backend/src/services/caches/__init__.py b/backend/src/services/caches/__init__.py new file mode 100644 index 0000000..c7ca694 --- /dev/null +++ b/backend/src/services/caches/__init__.py @@ -0,0 +1,4 @@ +from .factory import get_cache_store +from .interfaces import CacheStore + +__all__ = ["CacheStore", "get_cache_store"] diff --git a/backend/src/services/caches/factory.py b/backend/src/services/caches/factory.py new file mode 100644 index 0000000..0ed3c34 --- /dev/null +++ b/backend/src/services/caches/factory.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from .interfaces import CacheStore +from .redis_store import RedisCacheStore + +_cache_store: CacheStore | None = None + + +def get_cache_store() -> CacheStore: + global _cache_store + if _cache_store is None: + _cache_store = RedisCacheStore() + return _cache_store diff --git a/backend/src/services/caches/interfaces.py b/backend/src/services/caches/interfaces.py new file 
mode 100644 index 0000000..ff5950b --- /dev/null +++ b/backend/src/services/caches/interfaces.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import Protocol + + +class CacheStore(Protocol): + async def hgetall(self, key: str, /) -> dict[str, str]: ... + + async def hset(self, key: str, /, mapping: dict[str, str]) -> int: ... + + async def hincrby(self, key: str, field: str, amount: int = 1, /) -> int: ... + + async def expire(self, key: str, ttl_seconds: int, /) -> int: ... + + async def delete(self, *keys: str) -> int: ... + + async def sadd(self, key: str, *members: str) -> int: ... + + async def smembers(self, key: str, /) -> set[str]: ... diff --git a/backend/src/services/caches/redis_store.py b/backend/src/services/caches/redis_store.py new file mode 100644 index 0000000..9a7a7cf --- /dev/null +++ b/backend/src/services/caches/redis_store.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import inspect +from typing import Any + +from services.base.redis import get_or_init_redis_client + +from .interfaces import CacheStore + + +def _to_text(value: Any) -> str | None: + if isinstance(value, str): + return value + if isinstance(value, bytes): + try: + return value.decode("utf-8") + except UnicodeDecodeError: + return None + return None + + +async def _maybe_await(value: Any) -> Any: + if inspect.isawaitable(value): + return await value + return value + + +class RedisCacheStore(CacheStore): + async def hgetall(self, key: str) -> dict[str, str]: + client = await get_or_init_redis_client() + raw = await _maybe_await(client.hgetall(key)) + if not isinstance(raw, dict): + return {} + + decoded: dict[str, str] = {} + for raw_key, raw_value in raw.items(): + key_text = _to_text(raw_key) + value_text = _to_text(raw_value) + if key_text is None or value_text is None: + continue + decoded[key_text] = value_text + return decoded + + async def hset(self, key: str, mapping: dict[str, str]) -> int: + client = await get_or_init_redis_client() + 
result = await _maybe_await(client.hset(key, mapping=mapping)) + return int(result) + + async def hincrby(self, key: str, field: str, amount: int = 1) -> int: + client = await get_or_init_redis_client() + result = await _maybe_await(client.hincrby(key, field, amount)) + return int(result) + + async def expire(self, key: str, ttl_seconds: int) -> int: + client = await get_or_init_redis_client() + result = await _maybe_await(client.expire(key, ttl_seconds)) + return int(result) + + async def delete(self, *keys: str) -> int: + if not keys: + return 0 + client = await get_or_init_redis_client() + result = await _maybe_await(client.delete(*keys)) + return int(result) + + async def sadd(self, key: str, *members: str) -> int: + if not members: + return 0 + client = await get_or_init_redis_client() + result = await _maybe_await(client.sadd(key, *members)) + return int(result) + + async def smembers(self, key: str) -> set[str]: + client = await get_or_init_redis_client() + raw = await _maybe_await(client.smembers(key)) + if isinstance(raw, set): + return {value for item in raw if (value := _to_text(item))} + if isinstance(raw, list | tuple): + return {value for item in raw if (value := _to_text(item))} + return set() diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py index f234838..e1e94f3 100644 --- a/backend/src/v1/agent/service.py +++ b/backend/src/v1/agent/service.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import date +from datetime import date, datetime, timezone import hashlib from urllib.parse import urlparse @@ -10,6 +10,9 @@ from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from core.auth.models import CurrentUser +from core.agentscope.caches.context_messages_cache import ( + create_context_messages_cache, +) from core.agentscope.schemas.agui_input import extract_latest_user_payload from core.config.settings import config from core.logging import get_logger @@ -124,6 +127,13 @@ class 
AgentService: visibility_mask=visibility_mask, ) await self._repository.commit() + await self._append_context_cache_user_message( + thread_id=thread_id, + runtime_mode=runtime_mode, + visibility_mask=visibility_mask, + content=user_message_text, + metadata=user_message_metadata, + ) queue = "automation" if runtime_mode == RuntimeMode.AUTOMATION else "agent" task_id = await self._queue.enqueue( @@ -147,6 +157,44 @@ class AgentService: created=created, ) + async def _append_context_cache_user_message( + self, + *, + thread_id: str, + runtime_mode: RuntimeMode, + visibility_mask: int, + content: str, + metadata: AgentChatMessageMetadata | None, + ) -> None: + metadata_payload = ( + metadata.model_dump(mode="json", exclude_none=True) + if isinstance(metadata, AgentChatMessageMetadata) + else None + ) + message_payload: dict[str, object] = { + "role": "user", + "content": content, + "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"), + } + if isinstance(metadata_payload, dict): + message_payload["metadata"] = metadata_payload + + try: + context_cache = create_context_messages_cache() + await context_cache.append_message( + thread_id=thread_id, + runtime_mode=runtime_mode.value, + visibility_mask=visibility_mask, + message=message_payload, + ) + except Exception as exc: + logger.warning( + "Failed to append user message to context cache", + thread_id=thread_id, + runtime_mode=runtime_mode.value, + error=str(exc), + ) + async def _resolve_user_message_visibility_mask( self, *, runtime_mode: RuntimeMode ) -> int: diff --git a/backend/src/v1/users/service.py b/backend/src/v1/users/service.py index 11acee1..5d061b8 100644 --- a/backend/src/v1/users/service.py +++ b/backend/src/v1/users/service.py @@ -7,10 +7,10 @@ from uuid import UUID from fastapi import HTTPException from sqlalchemy.exc import SQLAlchemyError -from core.auth.models import CurrentUser -from core.agentscope.persistence.user_context_cache import ( +from 
core.agentscope.caches.user_context_cache import ( create_user_context_cache, ) +from core.auth.models import CurrentUser from core.db.base_service import BaseService from core.logging import get_logger from schemas.shared.user import UserContext, parse_profile_settings diff --git a/backend/tests/unit/core/agentscope/caches/test_context_messages_cache.py b/backend/tests/unit/core/agentscope/caches/test_context_messages_cache.py new file mode 100644 index 0000000..cc2f47a --- /dev/null +++ b/backend/tests/unit/core/agentscope/caches/test_context_messages_cache.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Any, cast + +import pytest + +from core.agentscope.caches.context_messages_cache import ContextMessagesCache +from schemas.domain.automation import ContextWindowMode, MessageContextConfig + + +class _FakeCacheStore: + def __init__(self) -> None: + self.hash_store: dict[str, dict[str, str]] = {} + self.set_store: dict[str, set[str]] = {} + + async def hgetall(self, key: str) -> dict[str, str]: + return dict(self.hash_store.get(key, {})) + + async def hset(self, key: str, mapping: dict[str, str]) -> int: + self.hash_store[key] = dict(mapping) + return 1 + + async def hincrby(self, key: str, field: str, amount: int = 1) -> int: + del key, field, amount + return 0 + + async def expire(self, key: str, ttl_seconds: int) -> int: + del key, ttl_seconds + return 1 + + async def delete(self, *keys: str) -> int: + for key in keys: + self.hash_store.pop(key, None) + self.set_store.pop(key, None) + return len(keys) + + async def sadd(self, key: str, *members: str) -> int: + values = self.set_store.setdefault(key, set()) + before = len(values) + for member in members: + values.add(member) + return len(values) - before + + async def smembers(self, key: str) -> set[str]: + return set(self.set_store.get(key, set())) + + +@pytest.mark.asyncio +async def 
test_context_messages_cache_set_get_roundtrip() -> None: + store = _FakeCacheStore() + cache = ContextMessagesCache( + client=store, + key_prefix="agent:context-messages", + ttl_seconds=600, + ) + config = MessageContextConfig(window_mode=ContextWindowMode.DAY, window_count=2) + messages: list[dict[str, object]] = [ + {"role": "user", "content": "hello", "timestamp": "2026-03-25T08:00:00+00:00"} + ] + + await cache.set( + thread_id="thread-1", + runtime_mode="chat", + context_config=config, + messages=messages, + ) + loaded = await cache.get( + thread_id="thread-1", + runtime_mode="chat", + context_config=config, + ) + + assert loaded is not None + assert loaded[0]["content"] == "hello" + + +@pytest.mark.asyncio +async def test_context_messages_cache_append_skips_when_not_visible() -> None: + store = _FakeCacheStore() + cache = ContextMessagesCache( + client=store, + key_prefix="agent:context-messages", + ttl_seconds=600, + ) + config = MessageContextConfig( + window_mode=ContextWindowMode.NUMBER, + window_count=1, + ) + await cache.set( + thread_id="thread-1", + runtime_mode="chat", + context_config=config, + messages=cast( + list[dict[str, Any]], + [ + { + "role": "user", + "content": "q1", + "timestamp": "2026-03-25T08:00:00+00:00", + } + ], + ), + ) + + await cache.append_message( + thread_id="thread-1", + runtime_mode="chat", + visibility_mask=1, + message={"role": "assistant", "content": "a1"}, + ) + + loaded = await cache.get( + thread_id="thread-1", + runtime_mode="chat", + context_config=config, + ) + assert loaded is not None + assert len(loaded) == 1 + + +@pytest.mark.asyncio +async def test_context_messages_cache_append_trims_number_window() -> None: + store = _FakeCacheStore() + cache = ContextMessagesCache( + client=store, + key_prefix="agent:context-messages", + ttl_seconds=600, + ) + config = MessageContextConfig( + window_mode=ContextWindowMode.NUMBER, + window_count=1, + ) + await cache.set( + thread_id="thread-1", + runtime_mode="chat", + 
context_config=config, + messages=cast( + list[dict[str, Any]], + [ + { + "role": "user", + "content": "q0", + "timestamp": "2026-03-25T08:00:00+00:00", + }, + { + "role": "assistant", + "content": "a0", + "timestamp": "2026-03-25T08:01:00+00:00", + }, + { + "role": "user", + "content": "q1", + "timestamp": "2026-03-25T08:02:00+00:00", + }, + ], + ), + ) + + await cache.append_message( + thread_id="thread-1", + runtime_mode="chat", + visibility_mask=2, + message={"role": "assistant", "content": "a1"}, + ) + + loaded = await cache.get( + thread_id="thread-1", + runtime_mode="chat", + context_config=config, + ) + assert loaded is not None + assert [str(item["content"]) for item in loaded] == ["q1", "a1"] + + +@pytest.mark.asyncio +async def test_context_messages_cache_append_trims_day_window() -> None: + store = _FakeCacheStore() + cache = ContextMessagesCache( + client=store, + key_prefix="agent:context-messages", + ttl_seconds=600, + ) + config = MessageContextConfig(window_mode=ContextWindowMode.DAY, window_count=2) + + now = datetime(2026, 3, 25, 10, 0, tzinfo=timezone.utc) + yesterday = now - timedelta(days=1) + two_days_ago = now - timedelta(days=2) + + await cache.set( + thread_id="thread-1", + runtime_mode="chat", + context_config=config, + messages=cast( + list[dict[str, Any]], + [ + { + "role": "assistant", + "content": "d-2", + "timestamp": two_days_ago.isoformat(), + }, + { + "role": "assistant", + "content": "d-1", + "timestamp": yesterday.isoformat(), + }, + ], + ), + ) + + await cache.append_message( + thread_id="thread-1", + runtime_mode="chat", + visibility_mask=2, + message={ + "role": "assistant", + "content": "d0", + "timestamp": now.isoformat(), + }, + ) + + loaded = await cache.get( + thread_id="thread-1", + runtime_mode="chat", + context_config=config, + ) + assert loaded is not None + assert [str(item["content"]) for item in loaded] == ["d-1", "d0"] diff --git a/backend/tests/unit/core/agentscope/events/test_store.py 
b/backend/tests/unit/core/agentscope/events/test_store.py index b2f81ce..b13094c 100644 --- a/backend/tests/unit/core/agentscope/events/test_store.py +++ b/backend/tests/unit/core/agentscope/events/test_store.py @@ -145,6 +145,38 @@ async def test_store_persists_tool_output_with_summary_as_content( assert append_kwargs["visibility_mask"] == (1 << 0) +@pytest.mark.asyncio +async def test_store_sets_history_only_visibility_for_automation_worker_output( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: dict[str, object] = {} + fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=3) + _patch_repositories(monkeypatch, captured, fake_chat_session) + + store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx()) + await store.persist( + { + "type": "TEXT_MESSAGE_END", + "threadId": "00000000-0000-0000-0000-000000000001", + "runId": "run-auto-1", + "messageId": "assistant-auto-1", + "role": "assistant", + "stage": "worker", + "runtime_mode": "automation", + "status": "success", + "answer": "automation-result", + "key_points": [], + "result_type": "summary", + "suggested_actions": [], + "error": None, + } + ) + + append_kwargs = cast(dict[str, Any], captured["append_kwargs"]) + assert append_kwargs["content"] == "automation-result" + assert append_kwargs["visibility_mask"] == (1 << 0) + + @pytest.mark.asyncio async def test_store_persists_router_step_output_for_cost_tracking( monkeypatch: pytest.MonkeyPatch, diff --git a/backend/tests/unit/core/agentscope/persistence/test_user_context_cache.py b/backend/tests/unit/core/agentscope/persistence/test_user_context_cache.py index 4ec64de..50eda6d 100644 --- a/backend/tests/unit/core/agentscope/persistence/test_user_context_cache.py +++ b/backend/tests/unit/core/agentscope/persistence/test_user_context_cache.py @@ -5,7 +5,7 @@ from uuid import uuid4 import pytest -from core.agentscope.persistence.user_context_cache import UserContextCache +from 
core.agentscope.caches.user_context_cache import UserContextCache from schemas.shared.user import ( UserContext, parse_profile_settings, diff --git a/backend/tests/unit/core/agentscope/runtime/test_stage_emitter.py b/backend/tests/unit/core/agentscope/runtime/test_stage_emitter.py index 71b4c9c..8434a6b 100644 --- a/backend/tests/unit/core/agentscope/runtime/test_stage_emitter.py +++ b/backend/tests/unit/core/agentscope/runtime/test_stage_emitter.py @@ -28,6 +28,7 @@ async def test_tool_result_event_uses_runtime_tool_call_id() -> None: session_id="thread-1", run_id="run-1", stage="worker", + runtime_mode="chat", emit_text_events=False, emit_tool_events=True, ) @@ -65,3 +66,4 @@ async def test_tool_result_event_uses_runtime_tool_call_id() -> None: result_events = [e for e in pipeline.events if e.get("type") == "TOOL_CALL_RESULT"] assert len(result_events) == 1 assert result_events[0]["tool_call_id"] == "runtime-call-123" + assert result_events[0]["runtime_mode"] == "chat" diff --git a/backend/tests/unit/schemas/agent/test_runtime_models.py b/backend/tests/unit/schemas/agent/test_runtime_models.py new file mode 100644 index 0000000..6da1dd2 --- /dev/null +++ b/backend/tests/unit/schemas/agent/test_runtime_models.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from schemas.agent.runtime_models import RouterAgentOutput + + +def test_router_agent_output_coerces_key_entity_value_to_string() -> None: + payload = { + "normalized_task_input": { + "user_text": "test", + "multimodal_summary": [], + "context_summary": "", + }, + "key_entities": [ + { + "name": "priority", + "type": "number", + "value": 8, + } + ], + "constraints": [], + "task_typing": { + "primary": "planning", + "secondary": [], + }, + "execution_mode": "onestep", + "result_typing": { + "primary": "summary", + "secondary": [], + }, + "ui": { + "ui_mode": "none", + "ui_decision_reason": "test", + }, + } + + model = RouterAgentOutput.model_validate(payload) + + assert model.key_entities[0].value == 
"8" diff --git a/docs/plans/2026-03-25-agent-run-cancel-failed.md b/docs/plans/2026-03-25-agent-run-cancel-failed.md new file mode 100644 index 0000000..9d3654e --- /dev/null +++ b/docs/plans/2026-03-25-agent-run-cancel-failed.md @@ -0,0 +1,392 @@ +# Agent Run Cancel (Failed Semantics) Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** 为 `/api/v1/agent/runs` 增加可中断能力,在用户触发 cancel 后真正停止运行中的 agent 流程,并以 `RUN_ERROR(code=RUN_CANCELED)` 结束,最终将 session 状态落为 `failed`。 + +**Architecture:** 使用“协作取消 + 主任务中断”方案:API 层写入 Redis cancel 信号,runtime 在 worker 进程内并行 watcher 监听信号,命中后先调用 active agent 的 `interrupt()` 做优雅收尾,再 `cancel()` 当前 run 主任务做硬兜底。终态统一通过 `RUN_ERROR` 事件落库,复用现有 `FAILED` 会话语义,避免数据库枚举迁移。 + +**Tech Stack:** FastAPI, TaskIQ, Redis, AgentScope, SQLAlchemy, Pytest, Ruff, BasedPyright + +--- + +### Task 1: 先更新协议文档(接口与事件语义) + +**Files:** +- Modify: `docs/protocols/agent/api-endpoints.md` +- Modify: `docs/protocols/agent/sse-events.md` + +**Step 1: 在 API 文档新增 cancel 端点契约** + +在 `api-endpoints.md` 的端点清单添加: + +```md +| POST | `/runs/{thread_id}/cancel` | 请求取消指定 run | +``` + +并新增章节说明: +- 请求参数:`thread_id` + `runId`(建议 query) +- 返回:`202 Accepted` + `accepted: true` +- 语义:仅表示“取消请求已接收”,不保证已即时终止 + +**Step 2: 在 SSE 文档补充取消终态语义** + +在 `sse-events.md` 的 `RUN_ERROR` 章节补充: + +```json +{ + "type": "RUN_ERROR", + "threadId": "...", + "runId": "...", + "message": "run canceled by user", + "code": "RUN_CANCELED" +} +``` + +并明确: +- `RUN_CANCELED` 是用户主动中断,不是系统异常 +- 本阶段仍复用 session `failed`(向后兼容) + +**Step 3: 文档自检** + +检查文档是否同时覆盖: +- HTTP 行为 +- SSE 终态事件 +- 兼容策略(不引入新 session 状态) + +**Step 4: 提交文档变更** + +```bash +git add docs/protocols/agent/api-endpoints.md docs/protocols/agent/sse-events.md +git commit -m "docs: define agent run cancel API and RUN_CANCELED error semantics" +``` + +### Task 2: 打通 API 层到队列层的 cancel 信号写入 + +**Files:** +- Modify: `backend/src/v1/agent/schemas.py` +- Modify: 
`backend/src/v1/agent/dependencies.py` +- Modify: `backend/src/v1/agent/service.py` +- Modify: `backend/src/v1/agent/router.py` +- Test: `backend/tests/unit/v1/agent/test_service.py` +- Test: `backend/tests/integration/v1/agent/test_routes.py` + +**Step 1: 增加 cancel 接口响应 schema** + +在 `v1/agent/schemas.py` 增加: + +```python +class CancelRunResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True) + thread_id: str = Field(alias="threadId") + run_id: str = Field(alias="runId") + accepted: bool +``` + +**Step 2: 扩展队列协议接口** + +在 `QueueClientLike` 增加: + +```python +async def request_cancel(self, *, thread_id: str, run_id: str, requested_by: str) -> None: ... +``` + +**Step 3: 在 `TaskiqQueueClient` 实现 request_cancel** + +在 `v1/agent/dependencies.py` 中新增: +- cancel key 规范:`agent:cancel:{thread_id}:{run_id}` +- `SET key value EX {ttl_seconds}` 写入取消信号 +- `value` 可写 json 字符串(包含 user_id/timestamp) + +**Step 4: 在 service 层新增 cancel_run** + +在 `v1/agent/service.py` 增加方法: +- 校验 session owner(复用 `get_session_owner + ensure_session_owner`) +- 调用 `self._queue.request_cancel(...)` +- 返回 `accepted` 结果 DTO + +**Step 5: 在 router 新增 cancel 路由** + +在 `v1/agent/router.py` 新增: + +```python +@router.post("/runs/{thread_id}/cancel", response_model=CancelRunResponse, status_code=202) +async def cancel_run(...): + ...
+``` + +约束: +- `runId` 必填(建议 query) +- 非 owner 返回 403 +- 参数非法返回 422 + +**Step 6: 写 service 单测(先红)** + +在 `test_service.py` 添加: +- owner 可发起 cancel,`queue.request_cancel` 被调用 +- 非 owner cancel 返回 403 + +**Step 7: 运行单测确认失败** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_service.py -k cancel -q` + +Expected: 至少 1 个测试失败(新逻辑尚未实现) + +**Step 8: 实现最小代码使测试通过** + +按 Step 1-5 完成实现,避免额外重构。 + +**Step 9: 运行测试验证通过** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_service.py -k cancel -q` + +Expected: PASS + +**Step 10: 增加路由集成测试** + +在 `test_routes.py` 增加: +- `POST /api/v1/agent/runs/{thread_id}/cancel?runId=...` 返回 202 +- 响应字段别名正确(`threadId/runId/accepted`) + +**Step 11: 运行路由测试** + +Run: `uv run pytest backend/tests/integration/v1/agent/test_routes.py -k cancel -q` + +Expected: PASS + +**Step 12: 提交 API 层变更** + +```bash +git add backend/src/v1/agent/schemas.py backend/src/v1/agent/dependencies.py backend/src/v1/agent/service.py backend/src/v1/agent/router.py backend/tests/unit/v1/agent/test_service.py backend/tests/integration/v1/agent/test_routes.py +git commit -m "feat: add agent run cancel endpoint and Redis cancel signal" +``` + +### Task 3: runtime runner 植入取消 watcher 与优雅中断 + +**Files:** +- Modify: `backend/src/core/agentscope/runtime/runner.py` +- Test: `backend/tests/unit/core/agentscope/runtime/test_runner.py` + +**Step 1: 在 runner.execute 增加 cancel_checker 参数** + +更新 `execute()` 签名: + +```python +cancel_checker: Callable[[], Awaitable[bool]] | None = None +``` + +并保持默认 `None` 向后兼容。 + +**Step 2: 增加 active agent 引用与锁** + +在 `AgentScopeRunner.__init__` 增加: +- `self._active_agent: JsonReActAgent | None = None` +- `self._active_agent_lock = asyncio.Lock()` + +**Step 3: 在 `_run_worker_stage` 设置 active agent 生命周期** + +在 `agent.reply_json(...)` 前后包裹: +- before: 记录 `self._active_agent = agent` +- finally: 清理引用 + +**Step 4: 新增 `_watch_cancel_signal` 协程** + +行为: +- 循环调用 `cancel_checker()` +- 命中后先尝试 `await active_agent.interrupt()` +- 再 `run_task.cancel("run 
canceled by user")` +- 间隔 `await asyncio.sleep(0.2)` + +**Step 5: 在 execute 启停 watcher** + +- `run_task = asyncio.current_task()` +- 如果有 `cancel_checker`,`create_task(_watch_cancel_signal(...))` +- `finally` 中停止 watcher 并 `await` 回收 + +**Step 6: 补 stage 边界取消 gate(关键)** + +在 router 结束后、worker 开始前检查一次 `cancel_checker()`: +- 为 true 时抛 `asyncio.CancelledError` + +目的:防止“router 已结束但仍进入 worker”。 + +**Step 7: 写 runner 单测(先红)** + +新增测试用例: +- cancel 信号触发后,`execute` 抛出 `CancelledError` +- worker 未被继续执行(或中途被中断) + +**Step 8: 运行 runner 测试** + +Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_runner.py -k cancel -q` + +Expected: PASS + +**Step 9: 提交 runner 变更** + +```bash +git add backend/src/core/agentscope/runtime/runner.py backend/tests/unit/core/agentscope/runtime/test_runner.py +git commit -m "feat: add cooperative cancellation watcher to agentscope runner" +``` + +### Task 4: orchestrator 与 task worker 处理 CancelledError 终态 + +**Files:** +- Modify: `backend/src/core/agentscope/runtime/orchestrator.py` +- Modify: `backend/src/core/agentscope/runtime/tasks.py` +- Test: `backend/tests/unit/core/agentscope/runtime/test_orchestrator.py` +- Test: `backend/tests/unit/core/agentscope/runtime/test_tasks.py` + +**Step 1: orchestrator 单独捕获 CancelledError** + +在 `orchestrator.run()` 添加: + +```python +except asyncio.CancelledError: + await self._pipeline.emit(... RUN_ERROR code="RUN_CANCELED" ...) 
+ raise +``` + +保留现有 `except Exception` 处理系统错误。 + +**Step 2: task 层构造 cancel_checker 并注入 runtime.run** + +在 `tasks.py`: +- 构造 key:`agent:cancel:{thread_id}:{run_id}` +- 定义 `async def cancel_checker() -> bool: return bool(await redis.exists(key))` +- 调用 `runtime.run(..., cancel_checker=cancel_checker)` + +**Step 3: task 层补资源清理** + +在 `run_agentscope_task` 的 `finally`: +- 删除 cancel key 或缩短 TTL +- 记录日志(仅必要字段) + +**Step 4: 写 orchestrator 单测(先红)** + +验证: +- 收到 `CancelledError` 时发 `RUN_ERROR` 且 `code == "RUN_CANCELED"` + +**Step 5: 写 tasks 单测(先红)** + +验证: +- runtime 收到的 `cancel_checker` 可用 +- key 命中时上抛 `CancelledError` 路径成立 + +**Step 6: 运行测试** + +Run: `uv run pytest backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py -k cancel -q` + +Expected: PASS + +**Step 7: 提交 runtime 编排层变更** + +```bash +git add backend/src/core/agentscope/runtime/orchestrator.py backend/src/core/agentscope/runtime/tasks.py backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py +git commit -m "fix: emit RUN_CANCELED error when run task is interrupted" +``` + +### Task 5: 事件流与持久化一致性回归 + +**Files:** +- Modify: `backend/tests/unit/core/agentscope/events/test_store.py` +- Modify: `backend/tests/unit/core/agentscope/events/test_agui_codec.py` +- Modify: `backend/tests/integration/v1/agent/test_sse_flow_live.py` + +**Step 1: 补 event store 行为测试** + +新增断言: +- 当 `RUN_ERROR` 且 `code=RUN_CANCELED` 时,session 状态依然为 `FAILED` + +**Step 2: 补 codec 测试** + +新增断言: +- `RUN_ERROR` 的 `code` 字段能正确透传到 wire event + +**Step 3: 补 SSE 集成测试** + +场景: +- 触发 `/runs` +- 触发 `/runs/{thread_id}/cancel` +- SSE 最终出现 `RUN_ERROR(code=RUN_CANCELED)` + +**Step 4: 运行事件相关测试** + +Run: `uv run pytest backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py backend/tests/integration/v1/agent/test_sse_flow_live.py -k "cancel or run_error" -q` + +Expected: 
PASS + +**Step 5: 提交事件层变更** + +```bash +git add backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py backend/tests/integration/v1/agent/test_sse_flow_live.py +git commit -m "test: cover RUN_CANCELED propagation across store codec and SSE" +``` + +### Task 6: 全量验证与发布前检查 + +**Files:** +- Modify: `docs/protocols/agent/api-endpoints.md`(如需补充最终字段) +- Modify: `docs/protocols/agent/sse-events.md`(如需补充最终字段) + +**Step 1: 运行受影响单元测试集合** + +Run: + +```bash +uv run pytest backend/tests/unit/v1/agent/test_service.py backend/tests/unit/core/agentscope/runtime/test_runner.py backend/tests/unit/core/agentscope/runtime/test_orchestrator.py backend/tests/unit/core/agentscope/runtime/test_tasks.py backend/tests/unit/core/agentscope/events/test_store.py backend/tests/unit/core/agentscope/events/test_agui_codec.py -q +``` + +Expected: PASS + +**Step 2: 运行受影响集成测试集合** + +Run: + +```bash +uv run pytest backend/tests/integration/v1/agent/test_routes.py backend/tests/integration/v1/agent/test_sse_flow_live.py -q +``` + +Expected: PASS + +**Step 3: 运行静态检查** + +Run: + +```bash +uv run ruff check backend/src backend/tests +uv run basedpyright +``` + +Expected: PASS(无新增 lint/type 错误) + +**Step 4: 手工验证路径** + +手工流程: +- 发起 `/runs` +- 立刻调用 `/runs/{thread_id}/cancel?runId=...` +- 观察 SSE:应以 `RUN_ERROR(code=RUN_CANCELED)` 结束 +- 检查 session:`status=failed` + +**Step 5: 最终提交** + +```bash +git add docs/protocols/agent/api-endpoints.md docs/protocols/agent/sse-events.md backend/src/v1/agent/*.py backend/src/core/agentscope/runtime/*.py backend/tests/unit/core/agentscope/runtime/*.py backend/tests/unit/core/agentscope/events/*.py backend/tests/unit/v1/agent/test_service.py backend/tests/integration/v1/agent/test_routes.py backend/tests/integration/v1/agent/test_sse_flow_live.py +git commit -m "feat: support run cancellation with RUN_CANCELED failed semantics" +``` + +--- + +## 风险与回滚 + +- 风险 1:cancel key 误命中导致误中断 + - 缓解:key 粒度使用 `thread_id + 
run_id`,并设置 TTL +- 风险 2:中断时出现重复终态事件 + - 缓解:在 orchestrator 保证 CancelledError 只走 `RUN_ERROR` 分支,避免继续发 `RUN_FINISHED` +- 风险 3:高并发下 Redis 轮询压力上升 + - 缓解:轮询间隔 200ms,后续按并发量评估改为 pub/sub + +回滚策略: +- 回滚 `router/service/dependencies` cancel 新接口 +- 回滚 `runner/orchestrator/tasks` cancel 注入逻辑 +- 保持原 `POST /runs` 与 SSE 流程不变 diff --git a/docs/protocols/agent/sse-events.md b/docs/protocols/agent/sse-events.md index ab4a601..5b38dcf 100644 --- a/docs/protocols/agent/sse-events.md +++ b/docs/protocols/agent/sse-events.md @@ -364,7 +364,8 @@ cost = uncached_prompt_tokens * input_cost_per_token | 0 | `UI_HISTORY` | `/history` API 投影可见的消息 | | 1 | `CONTEXT_ASSEMBLY` | 运行时上下文装配(context assembly)可见 | -> 新消息入库时,`chat` 模式设置 `mask = UI_HISTORY | CONTEXT_ASSEMBLY`(值为 3),`automation` 模式设置 `mask = 0`。 +> 用户输入入库时,`chat` 模式设置 `mask = UI_HISTORY | CONTEXT_ASSEMBLY`(值为 3),`automation` 模式设置 `mask = 0`。 +> agent 运行产物入库时,`automation` 模式设置 `mask = UI_HISTORY`(值为 1),用于展示历史但不参与 context assembly。 ### /history API @@ -385,6 +386,7 @@ WHERE (visibility_mask & 2) != 0 **影响**: - `chat` 模式用户输入:mask=3 → 进入 `/history` ✅,进入 context assembly ✅ - `automation` 模式用户输入:mask=0 → 进入 `/history` ❌,进入 context assembly ❌ +- `automation` 模式 agent 输出:mask=1 → 进入 `/history` ✅,进入 context assembly ❌ ### Automation 模式上下文注入 @@ -396,7 +398,8 @@ WHERE (visibility_mask & 2) != 0 |------|--------|--------------| | Pipeline | `router` -> `worker` | `router` -> `worker` | | 用户输入 visibility_mask | `UI_HISTORY \| CONTEXT_ASSEMBLY` | `0` | -| 进入 /history | ✅ | ❌ | +| agent 输出 visibility_mask | `UI_HISTORY \| CONTEXT_ASSEMBLY`(memory stage 仅 `UI_HISTORY`) | `UI_HISTORY` | +| 进入 /history | ✅ | ✅(仅 agent 输出) | | 进入 context assembly | ✅(自动) | ❌(通过 run_input 注入) | | enabled_tools 来源 | `system_agents.yaml` worker 配置 | `AutomationJob.config.enabled_tools` | | context 配置来源 | `system_agents.yaml` router context_messages | `AutomationJob.config.context` |