feat: 重构 agentscope 缓存架构,新增消息和附件缓存

This commit is contained in:
qzl
2026-03-25 17:41:55 +08:00
parent d22ded21f8
commit 599c597e69
25 changed files with 1509 additions and 78 deletions
@@ -0,0 +1,21 @@
from .attachment_content_cache import (
AttachmentContentCache,
create_attachment_content_cache,
)
from .context_messages_cache import (
ContextMessagesCache,
create_context_messages_cache,
)
from .user_context_cache import (
UserContextCache,
create_user_context_cache,
)
__all__ = [
"AttachmentContentCache",
"ContextMessagesCache",
"UserContextCache",
"create_attachment_content_cache",
"create_context_messages_cache",
"create_user_context_cache",
]
@@ -0,0 +1,96 @@
from __future__ import annotations
from core.config.settings import config
from core.logging import get_logger
from services.caches import CacheStore, get_cache_store
logger = get_logger("core.agentscope.caches.attachment_content_cache")
class AttachmentContentCache:
    """Best-effort cache of base64-encoded attachment bodies.

    Each entry is a hash with ``base64`` and ``mime_type`` fields stored under
    ``{key_prefix}:{bucket}:{path}:mime:{mime_type}``. Store errors are logged
    and swallowed so that a cache outage never breaks the caller.
    """

    def __init__(
        self,
        *,
        client: CacheStore,
        key_prefix: str,
        ttl_seconds: int,
        max_base64_bytes: int,
    ) -> None:
        """Bind the cache to a store.

        Args:
            client: Store used for all hash operations.
            key_prefix: Prefix prepended to every entry key.
            ttl_seconds: Expiry applied to an entry after it is written.
            max_base64_bytes: Payloads whose UTF-8 size exceeds this are not cached.
        """
        self._client = client
        self._key_prefix = key_prefix
        self._ttl_seconds = ttl_seconds
        self._max_base64_bytes = max_base64_bytes

    async def get(
        self,
        *,
        bucket: str,
        path: str,
        mime_type: str,
    ) -> str | None:
        """Return the cached base64 payload, or ``None`` on a miss or read error."""
        key = self._key(bucket=bucket, path=path, mime_type=mime_type)
        try:
            raw = await self._client.hgetall(key)
        except Exception as exc:
            logger.warning(
                "Failed to read attachment content cache",
                bucket=bucket,
                path=path,
                error=str(exc),
            )
            return None
        payload = raw.get("base64")
        # A missing or empty field is treated as a miss.
        if not isinstance(payload, str) or not payload:
            return None
        return payload

    async def set(
        self,
        *,
        bucket: str,
        path: str,
        mime_type: str,
        base64_data: str,
    ) -> None:
        """Store ``base64_data`` with the configured TTL.

        Payloads larger than ``max_base64_bytes`` are skipped. If the write
        fails part-way, the entry is removed again so it cannot outlive its TTL.
        """
        encoded_bytes = len(base64_data.encode("utf-8"))
        if encoded_bytes > self._max_base64_bytes:
            logger.info(
                "Skip attachment cache write due to size limit",
                bucket=bucket,
                path=path,
                encoded_bytes=encoded_bytes,
                max_bytes=self._max_base64_bytes,
            )
            return
        key = self._key(bucket=bucket, path=path, mime_type=mime_type)
        try:
            await self._client.hset(
                key,
                mapping={
                    "base64": base64_data,
                    "mime_type": mime_type,
                },
            )
            await self._client.expire(key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to write attachment content cache",
                bucket=bucket,
                path=path,
                error=str(exc),
            )
            # HSET and EXPIRE are separate calls: if HSET succeeded and EXPIRE
            # failed, the (potentially multi-megabyte) entry would never
            # expire. Remove it rather than leak it.
            await self._discard(key)

    async def _discard(self, key: str) -> None:
        """Delete ``key`` on a best-effort basis; failures are only logged."""
        try:
            await self._client.delete(key)
        except Exception as exc:
            logger.warning(
                "Failed to delete attachment content cache key",
                key=key,
                error=str(exc),
            )

    def _key(self, *, bucket: str, path: str, mime_type: str) -> str:
        """Build the entry key for one attachment and MIME type."""
        return f"{self._key_prefix}:{bucket}:{path}:mime:{mime_type}"
def create_attachment_content_cache() -> AttachmentContentCache:
    """Build an ``AttachmentContentCache`` from ``config.agent_runtime``."""
    settings = config.agent_runtime
    store = get_cache_store()
    prefix = settings.attachment_content_cache_prefix
    ttl = settings.attachment_content_cache_ttl_seconds
    size_limit = settings.attachment_content_cache_max_base64_bytes
    return AttachmentContentCache(
        client=store,
        key_prefix=prefix,
        ttl_seconds=ttl,
        max_base64_bytes=size_limit,
    )
@@ -0,0 +1,281 @@
from __future__ import annotations
from datetime import datetime, timezone
import json
from core.config.settings import config
from core.logging import get_logger
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.domain.automation import ContextWindowMode, MessageContextConfig
from services.caches import CacheStore, get_cache_store
logger = get_logger("core.agentscope.caches.context_messages_cache")
class ContextMessagesCache:
    """Best-effort cache of a thread's assembled context messages.

    One entry is kept per (thread, runtime mode, window mode, window count)
    under ``{prefix}:{thread_id}:rm:{runtime_mode}:wm:{mode}:wc:{count}``. Its
    ``payload`` hash field is a JSON object with ``window_mode``,
    ``window_count`` and ``messages``. A set stored under
    ``{prefix}:index:{runtime_mode}:{thread_id}`` lists the entry keys so that
    a newly persisted message can be appended to every cached window.

    Store errors are logged and swallowed; an entry that can no longer be
    trusted is deleted so the next read falls back to the source of truth.
    """

    def __init__(
        self,
        *,
        client: CacheStore,
        key_prefix: str,
        ttl_seconds: int,
    ) -> None:
        """Bind the cache to a store.

        Args:
            client: Store used for hash and set operations.
            key_prefix: Prefix prepended to entry and index keys.
            ttl_seconds: Expiry applied to entries and to the index set.
        """
        self._client = client
        self._key_prefix = key_prefix
        self._ttl_seconds = ttl_seconds

    async def get(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
    ) -> list[dict[str, object]] | None:
        """Return the cached messages, or ``None`` on a miss.

        A read error or an absent payload is a miss. A payload that does not
        decode to ``{"messages": [...]}`` is deleted and reported as a miss.
        Non-dict items inside ``messages`` are dropped.
        """
        key = self._cache_key(
            thread_id=thread_id,
            runtime_mode=runtime_mode,
            context_config=context_config,
        )
        try:
            raw = await self._client.hgetall(key)
        except Exception as exc:
            logger.warning(
                "Failed to read context messages cache",
                thread_id=thread_id,
                error=str(exc),
            )
            return None
        payload = raw.get("payload")
        if not isinstance(payload, str) or not payload:
            return None
        try:
            decoded = json.loads(payload)
        except Exception:
            await self._safe_delete(key)
            return None
        messages = decoded.get("messages") if isinstance(decoded, dict) else None
        if not isinstance(messages, list):
            await self._safe_delete(key)
            return None
        return [item for item in messages if isinstance(item, dict)]

    async def set(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
        messages: list[dict[str, object]],
    ) -> None:
        """Store ``messages`` for this window and register the entry in the index.

        Any failure (including a payload that cannot be serialized) is logged,
        and the entry is removed so that a half-written or unindexed entry is
        never served.
        """
        key = self._cache_key(
            thread_id=thread_id,
            runtime_mode=runtime_mode,
            context_config=context_config,
        )
        index_key = self._index_key(thread_id=thread_id, runtime_mode=runtime_mode)
        try:
            payload = self._encode_payload(
                mode=context_config.window_mode.value,
                count=int(context_config.window_count),
                messages=[item for item in messages if isinstance(item, dict)],
            )
            await self._client.hset(key, mapping={"payload": payload})
            await self._client.expire(key, self._ttl_seconds)
            await self._client.sadd(index_key, key)
            await self._client.expire(index_key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to write context messages cache",
                thread_id=thread_id,
                error=str(exc),
            )
            # The four calls are not atomic. An entry without a TTL would
            # never expire, and an entry missing from the index would never
            # receive appended messages; drop it in either case.
            await self._safe_delete(key)

    async def append_message(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        visibility_mask: int,
        message: dict[str, object],
    ) -> None:
        """Append ``message`` to every cached window of the thread.

        Messages whose ``visibility_mask`` lacks the context-assembly bit are
        ignored. Each indexed entry is rewritten with the message appended and
        the window re-trimmed, and its TTL is renewed. The index TTL is renewed
        whenever at least one entry was rewritten, so the index cannot expire
        before the entries it tracks. An entry whose update fails is deleted.
        """
        if not self._is_context_visible(visibility_mask=visibility_mask):
            return
        index_key = self._index_key(thread_id=thread_id, runtime_mode=runtime_mode)
        try:
            keys = await self._client.smembers(index_key)
        except Exception as exc:
            logger.warning(
                "Failed to read context cache index",
                thread_id=thread_id,
                error=str(exc),
            )
            return
        if not keys:
            return
        normalized = self._normalize_message(message)
        any_rewritten = False
        for key in keys:
            try:
                if await self._append_to_entry(key=key, message=normalized):
                    any_rewritten = True
            except Exception as exc:
                logger.warning(
                    "Failed to append context cache message",
                    key=key,
                    thread_id=thread_id,
                    error=str(exc),
                )
                # The entry now lacks this message (or is unreadable); serving
                # it would hand stale context to the agent.
                await self._safe_delete(key)
        if not any_rewritten:
            return
        try:
            # Entry TTLs were just extended. Without this the index would
            # expire first and later appends would silently skip live entries.
            await self._client.expire(index_key, self._ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to refresh context cache index TTL",
                thread_id=thread_id,
                error=str(exc),
            )

    async def _append_to_entry(
        self,
        *,
        key: str,
        message: dict[str, object],
    ) -> bool:
        """Append ``message`` to the entry at ``key``.

        Returns ``True`` when the entry was rewritten and ``False`` when it was
        skipped because the payload is absent or not in the expected shape.
        Store and JSON errors propagate to the caller.
        """
        raw = await self._client.hgetall(key)
        payload_raw = raw.get("payload")
        if not isinstance(payload_raw, str) or not payload_raw:
            return False
        decoded = json.loads(payload_raw)
        if not isinstance(decoded, dict):
            return False
        mode = decoded.get("window_mode")
        count = decoded.get("window_count")
        messages_raw = decoded.get("messages")
        if not (
            isinstance(mode, str)
            and isinstance(count, int)
            and isinstance(messages_raw, list)
        ):
            return False
        messages = [item for item in messages_raw if isinstance(item, dict)]
        # Copy so that entries never share one message dict.
        messages.append(dict(message))
        trimmed = self._trim_messages(messages=messages, mode=mode, count=count)
        next_payload = self._encode_payload(mode=mode, count=count, messages=trimmed)
        await self._client.hset(key, mapping={"payload": next_payload})
        await self._client.expire(key, self._ttl_seconds)
        return True

    @staticmethod
    def _encode_payload(
        *,
        mode: str,
        count: int,
        messages: list[dict[str, object]],
    ) -> str:
        """Serialize one cache entry as compact, ASCII-only JSON."""
        return json.dumps(
            {
                "window_mode": mode,
                "window_count": count,
                "messages": messages,
            },
            ensure_ascii=True,
            separators=(",", ":"),
        )

    def _cache_key(
        self,
        *,
        thread_id: str,
        runtime_mode: str,
        context_config: MessageContextConfig,
    ) -> str:
        """Build the entry key for one thread, runtime mode and window config."""
        return (
            f"{self._key_prefix}:{thread_id}:rm:{runtime_mode}:"
            f"wm:{context_config.window_mode.value}:wc:{int(context_config.window_count)}"
        )

    def _index_key(self, *, thread_id: str, runtime_mode: str) -> str:
        """Build the key of the set that lists a thread's entry keys."""
        return f"{self._key_prefix}:index:{runtime_mode}:{thread_id}"

    async def _safe_delete(self, key: str) -> None:
        """Delete ``key``, ignoring any store error."""
        try:
            await self._client.delete(key)
        except Exception:
            return

    def _trim_messages(
        self,
        *,
        messages: list[dict[str, object]],
        mode: str,
        count: int,
    ) -> list[dict[str, object]]:
        """Trim ``messages`` to the window described by ``mode`` and ``count``.

        ``count`` is clamped to at least 1. The NUMBER mode trims by user
        turns; every other mode value trims by calendar day.
        """
        safe_count = max(int(count), 1)
        if mode == ContextWindowMode.NUMBER.value:
            return self._trim_by_user_window(messages=messages, count=safe_count)
        return self._trim_by_day_window(messages=messages, count=safe_count)

    def _trim_by_user_window(
        self,
        *,
        messages: list[dict[str, object]],
        count: int,
    ) -> list[dict[str, object]]:
        """Keep the tail starting at the ``count``-th most recent user message."""
        selected_reversed: list[dict[str, object]] = []
        user_count = 0
        for item in reversed(messages):
            selected_reversed.append(item)
            if item.get("role") == "user":
                user_count += 1
                if user_count >= count:
                    break
        selected_reversed.reverse()
        return selected_reversed

    def _trim_by_day_window(
        self,
        *,
        messages: list[dict[str, object]],
        count: int,
    ) -> list[dict[str, object]]:
        """Keep the tail covering the ``count`` most recent distinct days.

        Messages are scanned newest to oldest; scanning stops at the first
        message that would introduce one day too many.
        """
        selected_reversed: list[dict[str, object]] = []
        days_seen: set[str] = set()
        for item in reversed(messages):
            day_value = self._extract_day(item)
            if day_value not in days_seen:
                if len(days_seen) >= count:
                    break
                days_seen.add(day_value)
            selected_reversed.append(item)
        selected_reversed.reverse()
        return selected_reversed

    @staticmethod
    def _extract_day(message: dict[str, object]) -> str:
        """Return the ISO date of the message's ``timestamp``.

        Falls back to today's UTC date when the timestamp is missing or is not
        parseable as ISO 8601.
        """
        raw = message.get("timestamp")
        if isinstance(raw, str) and raw:
            # Older fromisoformat() implementations reject a trailing "Z".
            normalized = raw.replace("Z", "+00:00")
            try:
                return datetime.fromisoformat(normalized).date().isoformat()
            except ValueError:
                pass
        return datetime.now(timezone.utc).date().isoformat()

    @staticmethod
    def _normalize_message(message: dict[str, object]) -> dict[str, object]:
        """Reduce ``message`` to ``role``, ``content``, ``timestamp`` and metadata.

        Missing or falsy values default to role ``"assistant"``, empty content
        and the current UTC time. ``metadata`` is kept only when it is a dict.
        """
        normalized: dict[str, object] = {
            "role": str(message.get("role") or "assistant"),
            "content": str(message.get("content") or ""),
            "timestamp": str(
                message.get("timestamp")
                or datetime.now(timezone.utc).isoformat(timespec="seconds")
            ),
        }
        metadata = message.get("metadata")
        if isinstance(metadata, dict):
            normalized["metadata"] = metadata
        return normalized

    @staticmethod
    def _is_context_visible(*, visibility_mask: int) -> bool:
        """Return whether ``visibility_mask`` has the CONTEXT_ASSEMBLY bit set.

        Negative masks are clamped to zero before testing.
        """
        required = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
        return (max(int(visibility_mask), 0) & required) != 0
def create_context_messages_cache() -> ContextMessagesCache:
    """Build a ``ContextMessagesCache`` from ``config.agent_runtime``."""
    settings = config.agent_runtime
    store = get_cache_store()
    prefix = settings.context_messages_cache_prefix
    ttl = settings.context_messages_cache_ttl_seconds
    return ContextMessagesCache(client=store, key_prefix=prefix, ttl_seconds=ttl)
@@ -1,49 +1,26 @@
from __future__ import annotations
import inspect
import json
from collections.abc import Iterable
from typing import Any, Protocol
from typing import Any
from uuid import UUID
import redis.asyncio as redis
from core.config.settings import config
from core.logging import get_logger
from schemas.shared.user import (
UserContext,
parse_profile_settings,
)
from services.caches import CacheStore, get_cache_store
logger = get_logger("core.agentscope.persistence.user_context_cache")
class RedisHashClient(Protocol):
def hgetall(self, name: str, /) -> Any: ...
def hset(self, name: str, /, *args: Any, **kwargs: Any) -> Any: ...
def hincrby(self, name: str, key: str, amount: int = 1, /) -> Any: ...
def expire(self, name: str, time: int, /) -> Any: ...
def delete(self, *names: str) -> Any: ...
def sadd(self, name: str, *values: str) -> Any: ...
def smembers(self, name: str) -> Any: ...
async def _maybe_await(value: Any) -> Any:
if inspect.isawaitable(value):
return await value
return value
logger = get_logger("core.agentscope.caches.user_context_cache")
class UserContextCache:
def __init__(
self,
*,
client: RedisHashClient,
client: CacheStore,
key_prefix: str,
ttl_seconds: int,
max_turns: int,
@@ -56,7 +33,7 @@ class UserContextCache:
async def get(self, *, session_id: UUID) -> UserContext | None:
key = self._key(session_id)
try:
raw = await _maybe_await(self._client.hgetall(key))
raw = await self._client.hgetall(key)
except Exception as exc:
logger.warning(
"Failed to read user context cache",
@@ -107,18 +84,16 @@ class UserContextCache:
index_key = self._user_sessions_key(user_id)
payload = self._serialize(context)
try:
await _maybe_await(
self._client.hset(
key,
mapping={
"payload": payload,
"turns_used": "0",
},
)
await self._client.hset(
key,
mapping={
"payload": payload,
"turns_used": "0",
},
)
await _maybe_await(self._client.expire(key, self._ttl_seconds))
await _maybe_await(self._client.sadd(index_key, key))
await _maybe_await(self._client.expire(index_key, self._ttl_seconds))
await self._client.expire(key, self._ttl_seconds)
await self._client.sadd(index_key, key)
await self._client.expire(index_key, self._ttl_seconds)
except Exception as exc:
logger.warning(
"Failed to write user context cache",
@@ -130,7 +105,7 @@ class UserContextCache:
async def invalidate_user(self, *, user_id: UUID) -> int:
index_key = self._user_sessions_key(user_id)
try:
members_raw = await _maybe_await(self._client.smembers(index_key))
members_raw = await self._client.smembers(index_key)
except Exception as exc:
logger.warning(
"Failed to read user context cache index",
@@ -147,7 +122,7 @@ class UserContextCache:
deleted = 0
try:
deleted_raw = await _maybe_await(self._client.delete(*members))
deleted_raw = await self._client.delete(*members)
deleted = self._parse_int(deleted_raw)
except Exception as exc:
logger.warning(
@@ -205,7 +180,7 @@ class UserContextCache:
async def _safe_delete(self, key: str) -> None:
try:
await _maybe_await(self._client.delete(key))
await self._client.delete(key)
except Exception as exc:
logger.warning(
"Failed to delete user context cache key", key=key, error=str(exc)
@@ -214,7 +189,7 @@ class UserContextCache:
async def _safe_hincrby(self, key: str, field: str, amount: int) -> None:
try:
await _maybe_await(self._client.hincrby(key, field, amount))
await self._client.hincrby(key, field, amount)
except Exception as exc:
logger.warning(
"Failed to update user context cache usage",
@@ -272,7 +247,7 @@ class UserContextCache:
def create_user_context_cache() -> UserContextCache:
client = redis.from_url(config.redis.url, decode_responses=True)
client = get_cache_store()
runtime_settings = config.agent_runtime
return UserContextCache(
client=client,
+95 -8
View File
@@ -1,11 +1,16 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
from core.agentscope.caches.context_messages_cache import (
create_context_messages_cache,
)
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.logging import get_logger
from schemas.agent.forwarded_props import RuntimeMode
from schemas.enums import AgentChatMessageRole, AgentChatSessionStatus
from schemas.agent.system_agent import AgentType
from schemas.agent.runtime_models import AgentOutput, RouterAgentOutput, ToolAgentOutput
@@ -174,7 +179,10 @@ class SqlAlchemyEventStore:
if locked_session is None:
return
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
await message_repo.append_message(
visibility_mask = self._resolve_stage_visibility_mask(
event=event,
)
persisted = await message_repo.append_message(
session_id=session_id,
seq=seq,
role=role,
@@ -186,9 +194,16 @@ class SqlAlchemyEventStore:
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
visibility_mask=self._resolve_stage_visibility_mask(
event=event,
),
visibility_mask=visibility_mask,
)
await self._append_context_cache_message(
session_id=session_id,
event=event,
visibility_mask=visibility_mask,
role=role.value,
content=content,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
timestamp=self._resolve_message_timestamp(persisted),
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -339,16 +354,26 @@ class SqlAlchemyEventStore:
if locked_session is None:
return
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
await message_repo.append_message(
visibility_mask = self._resolve_stage_visibility_mask(
event=event,
)
persisted = await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.TOOL,
content=content,
tool_name=tool_output.tool_name,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
visibility_mask=self._resolve_stage_visibility_mask(
event=event,
),
visibility_mask=visibility_mask,
)
await self._append_context_cache_message(
session_id=session_id,
event=event,
visibility_mask=visibility_mask,
role=AgentChatMessageRole.TOOL.value,
content=content,
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
timestamp=self._resolve_message_timestamp(persisted),
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -377,6 +402,13 @@ class SqlAlchemyEventStore:
*,
event: dict[str, Any],
) -> int:
runtime_mode = self._event_value(event, "runtime_mode")
if (
isinstance(runtime_mode, str)
and runtime_mode.strip().lower() == RuntimeMode.AUTOMATION.value
):
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
raw_stage = self._event_value(event, "stage")
if not isinstance(raw_stage, str):
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
@@ -387,6 +419,61 @@ class SqlAlchemyEventStore:
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
)
async def _append_context_cache_message(
    self,
    *,
    session_id: UUID,
    event: dict[str, Any],
    visibility_mask: int,
    role: str,
    content: str,
    metadata: dict[str, object] | None,
    timestamp: str,
) -> None:
    """Push a just-persisted message into the thread's context cache.

    The session id is used as the cache thread id and the runtime mode is
    taken from ``event``. Any error is logged as a warning and not raised.
    """
    thread_id = str(session_id)
    entry: dict[str, object] = dict(role=role, content=content, timestamp=timestamp)
    if isinstance(metadata, dict):
        entry["metadata"] = metadata

    try:
        cache = create_context_messages_cache()
        mode = self._resolve_runtime_mode(event=event)
        await cache.append_message(
            thread_id=thread_id,
            runtime_mode=mode,
            visibility_mask=visibility_mask,
            message=entry,
        )
    except Exception as exc:
        self._logger.warning(
            "Failed to append context cache message from event",
            thread_id=thread_id,
            error=str(exc),
        )
@staticmethod
def _resolve_runtime_mode(*, event: dict[str, Any]) -> str:
    """Return the event's ``runtime_mode`` stripped and lower-cased.

    Falls back to the CHAT mode value when the field is missing, not a
    string, or blank.
    """
    candidate = event.get("runtime_mode")
    cleaned = candidate.strip().lower() if isinstance(candidate, str) else ""
    return cleaned or RuntimeMode.CHAT.value
@staticmethod
def _resolve_message_timestamp(message: Any) -> str:
    """Return an ISO 8601 timestamp for ``message``.

    A non-empty string ``created_at`` is returned unchanged; a ``datetime`` is
    converted to UTC. Otherwise the current UTC time (second precision) is used.
    """
    stamp = getattr(message, "created_at", None)
    if isinstance(stamp, str):
        if stamp:
            return stamp
    elif isinstance(stamp, datetime):
        try:
            return stamp.astimezone(timezone.utc).isoformat()
        except Exception:
            # Fall through to "now" when the value cannot be converted.
            pass
    return datetime.now(timezone.utc).isoformat(timespec="seconds")
async def _update_session_state(
self,
*,
@@ -1,9 +0,0 @@
from core.agentscope.persistence.user_context_cache import (
UserContextCache,
create_user_context_cache,
)
__all__ = [
"UserContextCache",
"create_user_context_cache",
]
@@ -29,7 +29,9 @@ from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
from schemas.agent.forwarded_props import (
ClientTimeContext,
RuntimeMode,
parse_forwarded_props_client_time,
parse_forwarded_props_runtime_mode,
)
from schemas.agent.runtime_models import (
RouterAgentOutput,
@@ -76,6 +78,7 @@ class AgentScopeRunner:
) -> dict[str, Any]:
owner_id = UUID(user_context.id)
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
runtime_mode = self._resolve_runtime_mode(run_input=run_input)
async with AsyncSessionLocal() as session:
router_config = await self._load_stage_config(
@@ -99,6 +102,7 @@ class AgentScopeRunner:
context_messages=context_messages,
stage_config=router_config,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
user_memory=user_memory,
)
worker_output = await self._execute_worker_step(
@@ -109,6 +113,7 @@ class AgentScopeRunner:
toolkit=worker_toolkit,
stage_config=worker_config,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
work_memory=work_memory,
)
return {
@@ -171,6 +176,7 @@ class AgentScopeRunner:
context_messages: list[Msg],
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
runtime_mode: RuntimeMode,
user_memory: UserMemoryContent | None,
) -> RouterAgentOutput:
await self._emit_step_event(
@@ -178,6 +184,7 @@ class AgentScopeRunner:
run_input=run_input,
step_name=AgentType.ROUTER.value,
event_type="STEP_STARTED",
runtime_mode=runtime_mode,
)
router_result = await self._run_router_stage(
user_context=user_context,
@@ -193,6 +200,7 @@ class AgentScopeRunner:
run_input=run_input,
step_name=AgentType.ROUTER.value,
event_type="STEP_FINISHED",
runtime_mode=runtime_mode,
extra_event={
"_router_persist": {
"router_output": router_output.model_dump(
@@ -214,6 +222,7 @@ class AgentScopeRunner:
toolkit: Any,
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
runtime_mode: RuntimeMode,
work_memory: WorkProfileContent | None,
) -> WorkerAgentOutputLite:
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
@@ -222,6 +231,7 @@ class AgentScopeRunner:
run_input=run_input,
step_name=AgentType.WORKER.value,
event_type="STEP_STARTED",
runtime_mode=runtime_mode,
)
worker_result = await self._run_worker_stage(
user_context=user_context,
@@ -234,6 +244,7 @@ class AgentScopeRunner:
worker_output_model=worker_output_model,
pipeline=pipeline,
runtime_client_time=runtime_client_time,
runtime_mode=runtime_mode,
work_memory=work_memory,
)
worker_output = worker_output_model.model_validate(worker_result.payload)
@@ -242,6 +253,7 @@ class AgentScopeRunner:
run_input=run_input,
step_name=AgentType.WORKER.value,
event_type="STEP_FINISHED",
runtime_mode=runtime_mode,
)
return worker_output
@@ -332,6 +344,7 @@ class AgentScopeRunner:
worker_output_model: type[WorkerAgentOutputLite],
pipeline: PipelineLike,
runtime_client_time: ClientTimeContext | None,
runtime_mode: RuntimeMode,
work_memory: WorkProfileContent | None,
) -> StageExecutionResult:
tracking_model = self._build_model(stage_config=stage_config)
@@ -340,6 +353,7 @@ class AgentScopeRunner:
session_id=run_input.thread_id,
run_id=run_input.run_id,
stage=stage_config.agent_type.value,
runtime_mode=runtime_mode.value,
emit_text_events=True,
emit_tool_events=True,
)
@@ -437,12 +451,14 @@ class AgentScopeRunner:
run_input: RunAgentInput,
step_name: str,
event_type: str,
runtime_mode: RuntimeMode,
extra_event: dict[str, Any] | None = None,
) -> None:
payload: dict[str, Any] = {
"type": event_type,
"threadId": run_input.thread_id,
"runId": run_input.run_id,
"runtime_mode": runtime_mode.value,
"stepName": step_name,
}
if extra_event:
@@ -459,6 +475,15 @@ class AgentScopeRunner:
getattr(run_input, "forwarded_props", None)
)
@staticmethod
def _resolve_runtime_mode(*, run_input: RunAgentInput) -> RuntimeMode:
    """Parse the runtime mode from ``run_input.forwarded_props``.

    Returns ``RuntimeMode.CHAT`` when parsing raises ``ValueError``.
    """
    try:
        forwarded = getattr(run_input, "forwarded_props", None)
        mode = parse_forwarded_props_runtime_mode(forwarded)
    except ValueError:
        mode = RuntimeMode.CHAT
    return mode
@staticmethod
def _resolve_provider_api_key(*, factory_name: str) -> str:
normalized_factory_name = factory_name.strip().upper()
@@ -20,6 +20,7 @@ class PipelineStageEmitter:
session_id: str,
run_id: str,
stage: str,
runtime_mode: str,
emit_text_events: bool,
emit_tool_events: bool,
) -> None:
@@ -27,6 +28,7 @@ class PipelineStageEmitter:
self._session_id = session_id
self._run_id = run_id
self._stage = stage
self._runtime_mode = runtime_mode
self._emit_text_events = emit_text_events
self._emit_tool_events = emit_tool_events
self._emitted_tool_calls: set[str] = set()
@@ -127,6 +129,7 @@ class PipelineStageEmitter:
"type": event_type,
"threadId": self._session_id,
"runId": self._run_id,
"runtime_mode": self._runtime_mode,
**payload,
},
)
+82 -9
View File
@@ -6,6 +6,13 @@ from typing import Any, cast
from uuid import UUID
from agentscope.message import Msg
from core.agentscope.caches import create_user_context_cache
from core.agentscope.caches.attachment_content_cache import (
create_attachment_content_cache,
)
from core.agentscope.caches.context_messages_cache import (
create_context_messages_cache,
)
from core.agentscope.events import (
AgentScopeAgUiCodec,
AgentScopeEventPipeline,
@@ -20,12 +27,16 @@ from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import worker_agent_broker, worker_automation_broker
from schemas.agent.forwarded_props import (
RuntimeMode,
parse_forwarded_props_runtime_mode,
)
from schemas.domain.automation import MessageContextConfig, RuntimeConfig
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from schemas.domain.chat_message import (
AgentChatMessageMetadata,
extract_user_message_attachments,
)
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from schemas.shared.user import UserContext
from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service
@@ -73,28 +84,56 @@ async def _build_user_context(
*,
owner_id: UUID,
session: Any,
session_id: str,
) -> UserContext:
cache = create_user_context_cache()
cached = await cache.get(session_id=UUID(session_id))
if cached:
return cached
current_user = CurrentUser(id=owner_id)
user_service = get_user_service(session=session, user=current_user)
return await user_service.get_me()
user_context = await user_service.get_me()
await cache.set(session_id=UUID(session_id), context=user_context)
return user_context
async def _build_recent_context_messages(
*,
session: Any,
thread_id: str,
runtime_mode: RuntimeMode = RuntimeMode.CHAT,
context_config: "MessageContextConfig",
) -> list[Msg]:
context_service = AgentContextService(repository=AgentRepository(session))
result = await context_service.load_context_messages(
context_cache = create_context_messages_cache()
attachment_cache = create_attachment_content_cache()
raw_messages = await context_cache.get(
thread_id=thread_id,
runtime_mode=runtime_mode.value,
context_config=context_config,
)
if not result:
return []
if raw_messages is None:
context_service = AgentContextService(repository=AgentRepository(session))
result = await context_service.load_context_messages(
thread_id=thread_id,
context_config=context_config,
)
if not result:
return []
messages_obj = result.get("messages")
if not isinstance(messages_obj, list):
return []
raw_messages = [item for item in messages_obj if isinstance(item, dict)]
await context_cache.set(
thread_id=thread_id,
runtime_mode=runtime_mode.value,
context_config=context_config,
messages=raw_messages,
)
raw_messages: list[dict[str, object]] = result.get("messages") or []
if not raw_messages:
return []
@@ -120,6 +159,24 @@ async def _build_recent_context_messages(
:_MAX_CONTEXT_ATTACHMENTS
]
for attachment in attachments:
mime_type = attachment.mime_type or "image/png"
cached_b64 = await attachment_cache.get(
bucket=attachment.bucket,
path=attachment.path,
mime_type=mime_type,
)
if cached_b64:
image_blocks.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": cached_b64,
},
}
)
continue
try:
image_bytes = await supabase_service.download_bytes(
bucket=attachment.bucket,
@@ -128,12 +185,18 @@ async def _build_recent_context_messages(
except Exception:
continue
b64_data = base64.b64encode(image_bytes).decode("utf-8")
await attachment_cache.set(
bucket=attachment.bucket,
path=attachment.path,
mime_type=mime_type,
base64_data=b64_data,
)
image_blocks.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": attachment.mime_type or "image/png",
"media_type": mime_type,
"data": b64_data,
},
}
@@ -183,6 +246,13 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
raise ValueError("run_input is required")
run_input = parse_run_input(run_input_raw)
runtime_mode = RuntimeMode.CHAT
try:
runtime_mode = parse_forwarded_props_runtime_mode(
getattr(run_input, "forwarded_props", None)
)
except ValueError:
runtime_mode = RuntimeMode.CHAT
runtime_config = RuntimeConfig.model_validate(runtime_config_raw or {})
thread_id = run_input.thread_id
run_id = run_input.run_id
@@ -195,7 +265,9 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
async with AsyncSessionLocal() as session:
current_user = CurrentUser(id=owner_id)
user_context = await _build_user_context(owner_id=owner_id, session=session)
user_context = await _build_user_context(
owner_id=owner_id, session=session, session_id=thread_id
)
memories_service = MemoriesService(
repository=SQLAlchemyMemoriesRepository(session),
session=session,
@@ -226,6 +298,7 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
context_messages = await _build_recent_context_messages(
session=session,
thread_id=thread_id,
runtime_mode=runtime_mode,
context_config=runtime_config.context,
)
+9
View File
@@ -161,6 +161,15 @@ class AgentRuntimeSettings(BaseModel):
user_context_cache_prefix: str = "agent:user-context"
user_context_cache_ttl_seconds: int = Field(default=600, ge=60, le=86400)
user_context_cache_max_turns: int = Field(default=6, ge=1, le=100)
context_messages_cache_prefix: str = "agent:context-messages"
context_messages_cache_ttl_seconds: int = Field(default=600, ge=30, le=86400)
attachment_content_cache_prefix: str = "agent:attachment-content"
attachment_content_cache_ttl_seconds: int = Field(default=1800, ge=30, le=86400)
attachment_content_cache_max_base64_bytes: int = Field(
default=6 * 1024 * 1024,
ge=1024,
le=64 * 1024 * 1024,
)
class AutomationSchedulerSettings(BaseModel):
+12 -1
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from schemas.agent.ui_hints import UiHintsPayload
@@ -86,6 +86,17 @@ class KeyEntity(BaseModel):
type: str
value: str | None = None
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: object) -> object:
    """Stringify bool/int/float inputs for ``value`` before validation.

    ``None``, strings and every other type are passed through unchanged.
    """
    if isinstance(value, (bool, int, float)):
        return str(value)
    return value
class ConstraintItem(BaseModel):
model_config = ConfigDict(extra="forbid")
+4
View File
@@ -0,0 +1,4 @@
from .factory import get_cache_store
from .interfaces import CacheStore
__all__ = ["CacheStore", "get_cache_store"]
+13
View File
@@ -0,0 +1,13 @@
from __future__ import annotations
from .interfaces import CacheStore
from .redis_store import RedisCacheStore
_cache_store: CacheStore | None = None
def get_cache_store() -> CacheStore:
    """Return the module-wide cache store, creating it on first use."""
    global _cache_store
    store = _cache_store
    if store is None:
        store = RedisCacheStore()
        _cache_store = store
    return store
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Protocol
class CacheStore(Protocol):
    """Async hash/set operations that the cache classes need from a backend."""

    async def hgetall(self, key: str, /) -> dict[str, str]:
        """Return all field/value pairs of the hash at ``key``."""

    async def hset(self, key: str, /, mapping: dict[str, str]) -> int:
        """Write the fields in ``mapping`` to the hash at ``key``."""

    async def hincrby(self, key: str, field: str, amount: int = 1, /) -> int:
        """Increment integer hash field ``field`` of ``key`` by ``amount``."""

    async def expire(self, key: str, ttl_seconds: int, /) -> int:
        """Set the time-to-live of ``key`` to ``ttl_seconds``."""

    async def delete(self, *keys: str) -> int:
        """Delete the given keys."""

    async def sadd(self, key: str, *members: str) -> int:
        """Add ``members`` to the set at ``key``."""

    async def smembers(self, key: str, /) -> set[str]:
        """Return the members of the set at ``key``."""
@@ -0,0 +1,80 @@
from __future__ import annotations
import inspect
from typing import Any
from services.base.redis import get_or_init_redis_client
from .interfaces import CacheStore
def _to_text(value: Any) -> str | None:
    """Coerce a store reply item to ``str``.

    Strings pass through, bytes are decoded as UTF-8, and anything else
    (including undecodable bytes) yields ``None``.
    """
    if isinstance(value, bytes):
        try:
            return value.decode("utf-8")
        except UnicodeDecodeError:
            return None
    return value if isinstance(value, str) else None
async def _maybe_await(value: Any) -> Any:
    """Await ``value`` when it is awaitable; otherwise return it unchanged."""
    return (await value) if inspect.isawaitable(value) else value
class RedisCacheStore(CacheStore):
    """``CacheStore`` backed by the shared Redis client.

    The client is fetched through ``get_or_init_redis_client()`` on every call.
    Replies are normalized to ``str`` so that callers see the same types
    whether or not the client decodes responses.
    """

    async def hgetall(self, key: str) -> dict[str, str]:
        """Return the hash at ``key`` as text.

        Fields whose name or value is neither ``str`` nor UTF-8 bytes are
        omitted. A non-dict reply yields an empty dict.
        """
        client = await get_or_init_redis_client()
        raw = await _maybe_await(client.hgetall(key))
        if not isinstance(raw, dict):
            return {}
        decoded: dict[str, str] = {}
        for raw_key, raw_value in raw.items():
            key_text = _to_text(raw_key)
            value_text = _to_text(raw_value)
            if key_text is None or value_text is None:
                continue
            decoded[key_text] = value_text
        return decoded

    async def hset(self, key: str, mapping: dict[str, str]) -> int:
        """Write ``mapping`` into the hash at ``key``; return the client's count."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.hset(key, mapping=mapping))
        return int(result)

    async def hincrby(self, key: str, field: str, amount: int = 1) -> int:
        """Increment hash field ``field`` by ``amount``; return the client's result."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.hincrby(key, field, amount))
        return int(result)

    async def expire(self, key: str, ttl_seconds: int) -> int:
        """Set the TTL of ``key``; return the client's result as an int."""
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.expire(key, ttl_seconds))
        return int(result)

    async def delete(self, *keys: str) -> int:
        """Delete ``keys``; returns 0 without contacting Redis when none are given."""
        if not keys:
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.delete(*keys))
        return int(result)

    async def sadd(self, key: str, *members: str) -> int:
        """Add ``members`` to the set at ``key``; returns 0 when none are given."""
        if not members:
            return 0
        client = await get_or_init_redis_client()
        result = await _maybe_await(client.sadd(key, *members))
        return int(result)

    async def smembers(self, key: str) -> set[str]:
        """Return the members of the set at ``key`` as text.

        Members that are neither ``str`` nor UTF-8 bytes are omitted. A reply
        that is not a set, frozenset, list or tuple yields an empty set.
        """
        client = await get_or_init_redis_client()
        raw = await _maybe_await(client.smembers(key))
        if not isinstance(raw, (set, frozenset, list, tuple)):
            return set()
        members: set[str] = set()
        for item in raw:
            text = _to_text(item)
            # Compare with None, as hgetall does: a truthiness test would
            # also discard a legitimately stored empty-string member.
            if text is not None:
                members.add(text)
        return members
+49 -1
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import date
from datetime import date, datetime, timezone
import hashlib
from urllib.parse import urlparse
@@ -10,6 +10,9 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
from core.agentscope.caches.context_messages_cache import (
create_context_messages_cache,
)
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
@@ -124,6 +127,13 @@ class AgentService:
visibility_mask=visibility_mask,
)
await self._repository.commit()
await self._append_context_cache_user_message(
thread_id=thread_id,
runtime_mode=runtime_mode,
visibility_mask=visibility_mask,
content=user_message_text,
metadata=user_message_metadata,
)
queue = "automation" if runtime_mode == RuntimeMode.AUTOMATION else "agent"
task_id = await self._queue.enqueue(
@@ -147,6 +157,44 @@ class AgentService:
created=created,
)
async def _append_context_cache_user_message(
    self,
    *,
    thread_id: str,
    runtime_mode: RuntimeMode,
    visibility_mask: int,
    content: str,
    metadata: AgentChatMessageMetadata | None,
) -> None:
    """Append the user's message to the thread's context cache.

    The message is stamped with the current UTC time. Any error is logged as
    a warning and not raised.
    """
    dumped = None
    if isinstance(metadata, AgentChatMessageMetadata):
        dumped = metadata.model_dump(mode="json", exclude_none=True)

    now_iso = datetime.now(timezone.utc).isoformat(timespec="seconds")
    entry: dict[str, object] = dict(role="user", content=content, timestamp=now_iso)
    if isinstance(dumped, dict):
        entry["metadata"] = dumped

    try:
        cache = create_context_messages_cache()
        await cache.append_message(
            thread_id=thread_id,
            runtime_mode=runtime_mode.value,
            visibility_mask=visibility_mask,
            message=entry,
        )
    except Exception as exc:
        logger.warning(
            "Failed to append user message to context cache",
            thread_id=thread_id,
            runtime_mode=runtime_mode.value,
            error=str(exc),
        )
async def _resolve_user_message_visibility_mask(
self, *, runtime_mode: RuntimeMode
) -> int:
+2 -2
View File
@@ -7,10 +7,10 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.agentscope.persistence.user_context_cache import (
from core.agentscope.caches.user_context_cache import (
create_user_context_cache,
)
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from schemas.shared.user import UserContext, parse_profile_settings