feat: 重构 memory 系统,支持 user memory 和 work memory 分离

This commit is contained in:
qzl
2026-03-23 14:25:47 +08:00
parent 3aacc756db
commit 6be616f108
70 changed files with 7031 additions and 431 deletions
@@ -6,7 +6,7 @@ import pytest
from ag_ui.core import RunAgentInput
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
from schemas.automation import MemoryContextConfig, RuntimeConfig
from schemas.automation import MessageContextConfig, RuntimeConfig
from schemas.user import UserContext, parse_profile_settings
@@ -51,7 +51,7 @@ def _run_input() -> RunAgentInput:
def _runtime_config() -> RuntimeConfig:
return RuntimeConfig(
enabled_tools=[],
context=MemoryContextConfig(),
context=MessageContextConfig(),
)
@@ -18,7 +18,7 @@ from schemas.agent.runtime_models import (
WorkerAgentOutputLite,
)
from schemas.agent.system_agent import AgentType
from schemas.automation import MemoryContextConfig, RuntimeConfig
from schemas.automation import MessageContextConfig, RuntimeConfig
from schemas.user import UserContext, parse_profile_settings
@@ -48,7 +48,7 @@ def _user_context() -> UserContext:
def _runtime_config() -> RuntimeConfig:
return RuntimeConfig(
enabled_tools=[],
context=MemoryContextConfig(),
context=MessageContextConfig(),
)
@@ -7,7 +7,7 @@ import pytest
import core.agentscope.runtime.tasks as tasks_module
from schemas.agent import ToolStatus
from schemas.automation import ContextWindowMode, MemoryContextConfig
from schemas.automation import ContextWindowMode, MessageContextConfig
from schemas.user import UserContext, parse_profile_settings
@@ -201,7 +201,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -237,7 +237,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(
context_config=MessageContextConfig(
window_mode=ContextWindowMode.DAY,
window_count=2,
),
@@ -264,7 +264,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -295,7 +295,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(),
context_config=MessageContextConfig(),
)
assert len(messages) == 1
@@ -319,7 +319,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -337,7 +337,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(),
context_config=MessageContextConfig(),
)
assert messages == []
@@ -357,7 +357,7 @@ async def test_build_recent_context_messages_passes_context_config(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id
captured_config["config"] = context_config
@@ -365,7 +365,7 @@ async def test_build_recent_context_messages_passes_context_config(
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
cfg = MemoryContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
cfg = MessageContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
@@ -0,0 +1,113 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from uuid import uuid4
import pytest
from agentscope.tool import ToolResponse
from core.agentscope.tools.custom import memory as memory_module
from models.memories import MemoryType
from schemas.memories.memory_content import UserMemoryContent
def _decode_tool_response(response: ToolResponse) -> dict[str, object]:
assert response.content
first = response.content[0]
text = str(first.get("text", "")) if isinstance(first, dict) else str(first.text)
return json.loads(text)
def _payload_error_code(payload: dict[str, object]) -> str:
error = payload.get("error")
if not isinstance(error, dict):
return ""
return str(error.get("code") or "")
class _FakeMemoriesService:
def __init__(self) -> None:
self.memory: object | None = None
self.updated_user = 0
self.updated_work = 0
async def get_memory_model(self, *, memory_type: MemoryType):
_ = memory_type
return self.memory
async def update_user_memory(self, **kwargs):
_ = kwargs
self.updated_user += 1
return SimpleNamespace()
async def update_work_memory(self, **kwargs):
_ = kwargs
self.updated_work += 1
return SimpleNamespace()
def _user_memory():
return SimpleNamespace(
id=uuid4(),
owner_id=uuid4(),
memory_type=MemoryType.USER,
content={"preferences": {"communication_style": "简洁"}},
status="active",
)
@pytest.mark.asyncio
async def test_memory_write_requires_runtime_context() -> None:
response = await memory_module.memory_write(
memory_type="user",
user_content=UserMemoryContent(interests=["跑步"]),
)
payload = _decode_tool_response(response)
assert payload["status"] == "failure"
assert _payload_error_code(payload) == "MISSING_RUNTIME_ARGS"
@pytest.mark.asyncio
async def test_memory_write_updates_user_content(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_service = _FakeMemoriesService()
monkeypatch.setattr(
memory_module, "create_memories_service", lambda **_: fake_service
)
response = await memory_module.memory_write(
memory_type="user",
user_content=UserMemoryContent(interests=["阅读"]),
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(response)
assert payload["status"] == "success"
assert "memory_type=user" in str(payload["result"])
assert fake_service.updated_user == 1
@pytest.mark.asyncio
async def test_memory_forget_updates_content_paths(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_service = _FakeMemoriesService()
fake_service.memory = _user_memory()
monkeypatch.setattr(
memory_module, "create_memories_service", lambda **_: fake_service
)
response = await memory_module.memory_forget(
memory_type="user",
forget_paths=["preferences.communication_style"],
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(response)
assert payload["status"] == "success"
assert "forgotten=1" in str(payload["result"])
assert fake_service.updated_user == 1
@@ -158,46 +158,44 @@ def test_build_system_prompt_keeps_sections_focused_without_language_duplication
assert "Follow agent contracts strictly" not in prompt
def test_build_system_prompt_includes_memory_section_when_memories_provided() -> None:
from schemas.memories import (
MemoryContext,
MemoryListResponse,
MemorySource,
MemoryType,
def test_build_system_prompt_includes_user_memory_section_for_router() -> None:
from schemas.memories.memory_content import UserMemoryContent
user_memory = UserMemoryContent()
prompt = build_system_prompt(
agent_type=AgentType.ROUTER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
user_memory=user_memory,
)
memories = MemoryListResponse(
owner_id=uuid4(),
memories=[
MemoryContext(
memory_type=MemoryType.USER,
source=MemorySource.MANUAL,
title="User prefers morning meetings",
content={"text": "User likes meetings before 10am"},
created_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
),
],
total=1,
)
assert "<!-- USER_MEMORY_START -->" in prompt
assert "[User Memory]" in prompt
def test_build_system_prompt_includes_work_memory_section_for_worker() -> None:
from schemas.memories.memory_content import WorkProfileContent
work_memory = WorkProfileContent()
prompt = build_system_prompt(
agent_type=AgentType.WORKER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
memories=memories,
work_memory=work_memory,
)
assert "<!-- MEMORY_START -->" in prompt
assert "[User Memories]" in prompt
assert "User prefers morning meetings" in prompt
assert "<!-- WORK_MEMORY_START -->" in prompt
assert "[Work Memory]" in prompt
def test_build_system_prompt_omits_memory_section_when_no_memories() -> None:
prompt = build_system_prompt(
agent_type=AgentType.WORKER,
agent_type=AgentType.ROUTER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
)
assert "<!-- MEMORY_START -->" not in prompt
assert "<!-- USER_MEMORY_START -->" not in prompt
assert "<!-- WORK_MEMORY_START -->" not in prompt
@@ -28,25 +28,3 @@ def test_build_stage_toolkit_uses_explicit_enabled_tools_as_final_set(
)
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
def test_build_stage_toolkit_uses_memory_defaults_without_explicit_tools(
monkeypatch,
) -> None:
captured: dict[str, object] = {}
def _fake_build_toolkit(**kwargs):
captured.update(kwargs)
return object()
monkeypatch.setattr(
"core.agentscope.tools.toolkit.build_toolkit", _fake_build_toolkit
)
build_stage_toolkit(
agent_type=AgentType.MEMORY,
session=cast(Any, object()),
owner_id=uuid4(),
)
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
@@ -16,13 +16,14 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
toolkit = build_toolkit(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
user_token="token-123",
)
schemas = toolkit.get_json_schemas()
names = {item["function"]["name"] for item in schemas}
assert "calendar_read" in names
assert "calendar_write" in names
assert "calendar_share" in names
assert "memory_write" in names
assert "memory_forget" in names
write_schema = next(
item for item in schemas if item["function"]["name"] == "calendar_write"
@@ -0,0 +1,17 @@
from __future__ import annotations
from v1.auth.automation_static_config import load_static_automation_job_config
def test_memory_automation_static_config_contract() -> None:
config = load_static_automation_job_config(config_name="memory_extraction")
assert config.context.window_mode.value == "day"
assert config.context.window_count == 2
assert [tool.value for tool in config.enabled_tools] == [
"memory.write",
"memory.forget",
]
prompt = config.input_template
assert "提取" in prompt
assert "遗忘" in prompt
@@ -16,3 +16,18 @@ def test_memory_automation_job_trigger_exists_in_0004_migration() -> None:
assert "'agent_type', 'memory'" in content
assert "ux_automation_jobs_owner_memory_active" in content
assert "input_template" in content
def test_bootstrap_key_replaces_agent_type_unique_anchor() -> None:
migration = (
Path(__file__).resolve().parents[3]
/ "alembic"
/ "versions"
/ "20260323_0003_bootstrap_job_key_and_unique_indexes.py"
)
content = migration.read_text(encoding="utf-8")
assert "bootstrap_key" in content
assert "ux_automation_jobs_owner_bootstrap_key_active" in content
assert "ux_memories_owner_memory_type" in content
assert "DROP INDEX IF EXISTS ux_automation_jobs_owner_memory_active" in content
@@ -12,6 +12,14 @@ from v1.auth.schemas import (
from v1.auth.service import AuthService, AuthServiceGateway
class FakeRegistrationBootstrapper:
def __init__(self) -> None:
self.called_user_ids: list[str] = []
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
self.called_user_ids.append(user_id)
class FakeGateway(AuthServiceGateway):
def __init__(self, response: SessionResponse) -> None:
self._response = response
@@ -75,6 +83,27 @@ async def test_create_phone_session_forwards_payload() -> None:
assert response.user.phone == "+8613812345678"
@pytest.mark.asyncio
async def test_create_phone_session_bootstraps_automation_job() -> None:
user = AuthUser(id="b196f8be-c5f4-45d8-8f07-65c0ddf4d3de", phone="+8613812345678")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
gateway = FakeGateway(token_response)
bootstrapper = FakeRegistrationBootstrapper()
service = AuthService(gateway=gateway, registration_bootstrapper=bootstrapper)
await service.create_phone_session(
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
)
assert bootstrapper.called_user_ids == ["b196f8be-c5f4-45d8-8f07-65c0ddf4d3de"]
@pytest.mark.asyncio
async def test_refresh_session_forwards_payload() -> None:
user = AuthUser(id="user-1", phone="+8613812345678")
@@ -0,0 +1,112 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, cast
from uuid import uuid4
import pytest
from v1.auth.registration_bootstrap import (
compute_next_local_time_utc,
)
def test_compute_next_local_time_utc_from_asia_shanghai() -> None:
now_utc = datetime(2026, 3, 23, 0, 30, tzinfo=timezone.utc)
run_at, next_run_at = compute_next_local_time_utc(
now_utc=now_utc,
timezone_name="Asia/Shanghai",
local_hour=8,
local_minute=0,
)
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
def test_compute_next_local_time_utc_rolls_to_next_day_when_passed() -> None:
now_utc = datetime(2026, 3, 23, 2, 30, tzinfo=timezone.utc)
run_at, next_run_at = compute_next_local_time_utc(
now_utc=now_utc,
timezone_name="Asia/Shanghai",
local_hour=8,
local_minute=0,
)
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
@pytest.mark.asyncio
async def test_registration_service_is_idempotent_when_job_exists() -> None:
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
expected_owner_id = uuid4()
class _Repo:
inserted = 0
upsert_calls = 0
async def get_profile_timezone(self, *, user_id):
assert user_id == expected_owner_id
return "Asia/Shanghai"
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
assert kwargs["owner_id"] == expected_owner_id
assert kwargs["bootstrap_key"] == "memory_extraction"
self.inserted += 1
return False
async def upsert_initial_memory(self, **kwargs):
self.upsert_calls += 1
return False
class _Session:
async def commit(self):
raise AssertionError("must not commit when already exists")
async def rollback(self):
raise AssertionError("must not rollback when no error")
service = RegistrationAutomationBootstrapService(
repository=cast(Any, _Repo()), session=cast(Any, _Session())
)
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
@pytest.mark.asyncio
async def test_registration_service_creates_initial_memories_when_missing() -> None:
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
expected_owner_id = uuid4()
class _Repo:
async def get_profile_timezone(self, *, user_id):
assert user_id == expected_owner_id
return "Asia/Shanghai"
async def upsert_initial_memory(self, **kwargs):
return True
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
_ = kwargs
return True
class _Session:
committed = 0
async def commit(self):
self.committed += 1
async def rollback(self):
raise AssertionError("must not rollback when no error")
session = _Session()
service = RegistrationAutomationBootstrapService(
repository=cast(Any, _Repo()), session=cast(Any, session)
)
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
assert session.committed == 1