feat: 重构 memory 系统,支持 user memory 和 work memory 分离

This commit is contained in:
qzl
2026-03-23 14:25:47 +08:00
parent 3aacc756db
commit 6be616f108
70 changed files with 7031 additions and 431 deletions
@@ -6,7 +6,7 @@ import pytest
from ag_ui.core import RunAgentInput
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
from schemas.automation import MemoryContextConfig, RuntimeConfig
from schemas.automation import MessageContextConfig, RuntimeConfig
from schemas.user import UserContext, parse_profile_settings
@@ -51,7 +51,7 @@ def _run_input() -> RunAgentInput:
def _runtime_config() -> RuntimeConfig:
return RuntimeConfig(
enabled_tools=[],
context=MemoryContextConfig(),
context=MessageContextConfig(),
)
@@ -18,7 +18,7 @@ from schemas.agent.runtime_models import (
WorkerAgentOutputLite,
)
from schemas.agent.system_agent import AgentType
from schemas.automation import MemoryContextConfig, RuntimeConfig
from schemas.automation import MessageContextConfig, RuntimeConfig
from schemas.user import UserContext, parse_profile_settings
@@ -48,7 +48,7 @@ def _user_context() -> UserContext:
def _runtime_config() -> RuntimeConfig:
return RuntimeConfig(
enabled_tools=[],
context=MemoryContextConfig(),
context=MessageContextConfig(),
)
@@ -7,7 +7,7 @@ import pytest
import core.agentscope.runtime.tasks as tasks_module
from schemas.agent import ToolStatus
from schemas.automation import ContextWindowMode, MemoryContextConfig
from schemas.automation import ContextWindowMode, MessageContextConfig
from schemas.user import UserContext, parse_profile_settings
@@ -201,7 +201,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -237,7 +237,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(
context_config=MessageContextConfig(
window_mode=ContextWindowMode.DAY,
window_count=2,
),
@@ -264,7 +264,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -295,7 +295,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(),
context_config=MessageContextConfig(),
)
assert len(messages) == 1
@@ -319,7 +319,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -337,7 +337,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(),
context_config=MessageContextConfig(),
)
assert messages == []
@@ -357,7 +357,7 @@ async def test_build_recent_context_messages_passes_context_config(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id
captured_config["config"] = context_config
@@ -365,7 +365,7 @@ async def test_build_recent_context_messages_passes_context_config(
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
cfg = MemoryContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
cfg = MessageContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
@@ -0,0 +1,113 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from uuid import uuid4
import pytest
from agentscope.tool import ToolResponse
from core.agentscope.tools.custom import memory as memory_module
from models.memories import MemoryType
from schemas.memories.memory_content import UserMemoryContent
def _decode_tool_response(response: ToolResponse) -> dict[str, object]:
assert response.content
first = response.content[0]
text = str(first.get("text", "")) if isinstance(first, dict) else str(first.text)
return json.loads(text)
def _payload_error_code(payload: dict[str, object]) -> str:
error = payload.get("error")
if not isinstance(error, dict):
return ""
return str(error.get("code") or "")
class _FakeMemoriesService:
def __init__(self) -> None:
self.memory: object | None = None
self.updated_user = 0
self.updated_work = 0
async def get_memory_model(self, *, memory_type: MemoryType):
_ = memory_type
return self.memory
async def update_user_memory(self, **kwargs):
_ = kwargs
self.updated_user += 1
return SimpleNamespace()
async def update_work_memory(self, **kwargs):
_ = kwargs
self.updated_work += 1
return SimpleNamespace()
def _user_memory():
return SimpleNamespace(
id=uuid4(),
owner_id=uuid4(),
memory_type=MemoryType.USER,
content={"preferences": {"communication_style": "简洁"}},
status="active",
)
@pytest.mark.asyncio
async def test_memory_write_requires_runtime_context() -> None:
response = await memory_module.memory_write(
memory_type="user",
user_content=UserMemoryContent(interests=["跑步"]),
)
payload = _decode_tool_response(response)
assert payload["status"] == "failure"
assert _payload_error_code(payload) == "MISSING_RUNTIME_ARGS"
@pytest.mark.asyncio
async def test_memory_write_updates_user_content(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_service = _FakeMemoriesService()
monkeypatch.setattr(
memory_module, "create_memories_service", lambda **_: fake_service
)
response = await memory_module.memory_write(
memory_type="user",
user_content=UserMemoryContent(interests=["阅读"]),
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(response)
assert payload["status"] == "success"
assert "memory_type=user" in str(payload["result"])
assert fake_service.updated_user == 1
@pytest.mark.asyncio
async def test_memory_forget_updates_content_paths(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_service = _FakeMemoriesService()
fake_service.memory = _user_memory()
monkeypatch.setattr(
memory_module, "create_memories_service", lambda **_: fake_service
)
response = await memory_module.memory_forget(
memory_type="user",
forget_paths=["preferences.communication_style"],
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(response)
assert payload["status"] == "success"
assert "forgotten=1" in str(payload["result"])
assert fake_service.updated_user == 1
@@ -158,46 +158,44 @@ def test_build_system_prompt_keeps_sections_focused_without_language_duplication
assert "Follow agent contracts strictly" not in prompt
def test_build_system_prompt_includes_memory_section_when_memories_provided() -> None:
from schemas.memories import (
MemoryContext,
MemoryListResponse,
MemorySource,
MemoryType,
def test_build_system_prompt_includes_user_memory_section_for_router() -> None:
from schemas.memories.memory_content import UserMemoryContent
user_memory = UserMemoryContent()
prompt = build_system_prompt(
agent_type=AgentType.ROUTER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
user_memory=user_memory,
)
memories = MemoryListResponse(
owner_id=uuid4(),
memories=[
MemoryContext(
memory_type=MemoryType.USER,
source=MemorySource.MANUAL,
title="User prefers morning meetings",
content={"text": "User likes meetings before 10am"},
created_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
),
],
total=1,
)
assert "<!-- USER_MEMORY_START -->" in prompt
assert "[User Memory]" in prompt
def test_build_system_prompt_includes_work_memory_section_for_worker() -> None:
from schemas.memories.memory_content import WorkProfileContent
work_memory = WorkProfileContent()
prompt = build_system_prompt(
agent_type=AgentType.WORKER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
memories=memories,
work_memory=work_memory,
)
assert "<!-- MEMORY_START -->" in prompt
assert "[User Memories]" in prompt
assert "User prefers morning meetings" in prompt
assert "<!-- WORK_MEMORY_START -->" in prompt
assert "[Work Memory]" in prompt
def test_build_system_prompt_omits_memory_section_when_no_memories() -> None:
prompt = build_system_prompt(
agent_type=AgentType.WORKER,
agent_type=AgentType.ROUTER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
)
assert "<!-- MEMORY_START -->" not in prompt
assert "<!-- USER_MEMORY_START -->" not in prompt
assert "<!-- WORK_MEMORY_START -->" not in prompt
@@ -28,25 +28,3 @@ def test_build_stage_toolkit_uses_explicit_enabled_tools_as_final_set(
)
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
def test_build_stage_toolkit_uses_memory_defaults_without_explicit_tools(
monkeypatch,
) -> None:
captured: dict[str, object] = {}
def _fake_build_toolkit(**kwargs):
captured.update(kwargs)
return object()
monkeypatch.setattr(
"core.agentscope.tools.toolkit.build_toolkit", _fake_build_toolkit
)
build_stage_toolkit(
agent_type=AgentType.MEMORY,
session=cast(Any, object()),
owner_id=uuid4(),
)
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
@@ -16,13 +16,14 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
toolkit = build_toolkit(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
user_token="token-123",
)
schemas = toolkit.get_json_schemas()
names = {item["function"]["name"] for item in schemas}
assert "calendar_read" in names
assert "calendar_write" in names
assert "calendar_share" in names
assert "memory_write" in names
assert "memory_forget" in names
write_schema = next(
item for item in schemas if item["function"]["name"] == "calendar_write"