feat: 重构 memory 系统,支持 user memory 和 work memory 分离
This commit is contained in:
@@ -6,7 +6,7 @@ import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from schemas.automation import MemoryContextConfig, RuntimeConfig
|
||||
from schemas.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ def _run_input() -> RunAgentInput:
|
||||
def _runtime_config() -> RuntimeConfig:
|
||||
return RuntimeConfig(
|
||||
enabled_tools=[],
|
||||
context=MemoryContextConfig(),
|
||||
context=MessageContextConfig(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from schemas.agent.runtime_models import (
|
||||
WorkerAgentOutputLite,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.automation import MemoryContextConfig, RuntimeConfig
|
||||
from schemas.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def _user_context() -> UserContext:
|
||||
def _runtime_config() -> RuntimeConfig:
|
||||
return RuntimeConfig(
|
||||
enabled_tools=[],
|
||||
context=MemoryContextConfig(),
|
||||
context=MessageContextConfig(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
import core.agentscope.runtime.tasks as tasks_module
|
||||
from schemas.agent import ToolStatus
|
||||
from schemas.automation import ContextWindowMode, MemoryContextConfig
|
||||
from schemas.automation import ContextWindowMode, MessageContextConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -201,7 +201,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, context_config
|
||||
return {
|
||||
@@ -237,7 +237,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_config=MemoryContextConfig(
|
||||
context_config=MessageContextConfig(
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
@@ -264,7 +264,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, context_config
|
||||
return {
|
||||
@@ -295,7 +295,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_config=MemoryContextConfig(),
|
||||
context_config=MessageContextConfig(),
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -319,7 +319,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, context_config
|
||||
return {
|
||||
@@ -337,7 +337,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_config=MemoryContextConfig(),
|
||||
context_config=MessageContextConfig(),
|
||||
)
|
||||
|
||||
assert messages == []
|
||||
@@ -357,7 +357,7 @@ async def test_build_recent_context_messages_passes_context_config(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
captured_config["config"] = context_config
|
||||
@@ -365,7 +365,7 @@ async def test_build_recent_context_messages_passes_context_config(
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
cfg = MemoryContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
|
||||
cfg = MessageContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from agentscope.tool import ToolResponse
|
||||
|
||||
from core.agentscope.tools.custom import memory as memory_module
|
||||
from models.memories import MemoryType
|
||||
from schemas.memories.memory_content import UserMemoryContent
|
||||
|
||||
|
||||
def _decode_tool_response(response: ToolResponse) -> dict[str, object]:
|
||||
assert response.content
|
||||
first = response.content[0]
|
||||
text = str(first.get("text", "")) if isinstance(first, dict) else str(first.text)
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
def _payload_error_code(payload: dict[str, object]) -> str:
|
||||
error = payload.get("error")
|
||||
if not isinstance(error, dict):
|
||||
return ""
|
||||
return str(error.get("code") or "")
|
||||
|
||||
|
||||
class _FakeMemoriesService:
|
||||
def __init__(self) -> None:
|
||||
self.memory: object | None = None
|
||||
self.updated_user = 0
|
||||
self.updated_work = 0
|
||||
|
||||
async def get_memory_model(self, *, memory_type: MemoryType):
|
||||
_ = memory_type
|
||||
return self.memory
|
||||
|
||||
async def update_user_memory(self, **kwargs):
|
||||
_ = kwargs
|
||||
self.updated_user += 1
|
||||
return SimpleNamespace()
|
||||
|
||||
async def update_work_memory(self, **kwargs):
|
||||
_ = kwargs
|
||||
self.updated_work += 1
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
def _user_memory():
|
||||
return SimpleNamespace(
|
||||
id=uuid4(),
|
||||
owner_id=uuid4(),
|
||||
memory_type=MemoryType.USER,
|
||||
content={"preferences": {"communication_style": "简洁"}},
|
||||
status="active",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_write_requires_runtime_context() -> None:
|
||||
response = await memory_module.memory_write(
|
||||
memory_type="user",
|
||||
user_content=UserMemoryContent(interests=["跑步"]),
|
||||
)
|
||||
payload = _decode_tool_response(response)
|
||||
assert payload["status"] == "failure"
|
||||
assert _payload_error_code(payload) == "MISSING_RUNTIME_ARGS"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_write_updates_user_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_service = _FakeMemoriesService()
|
||||
monkeypatch.setattr(
|
||||
memory_module, "create_memories_service", lambda **_: fake_service
|
||||
)
|
||||
|
||||
response = await memory_module.memory_write(
|
||||
memory_type="user",
|
||||
user_content=UserMemoryContent(interests=["阅读"]),
|
||||
session=SimpleNamespace(),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
payload = _decode_tool_response(response)
|
||||
|
||||
assert payload["status"] == "success"
|
||||
assert "memory_type=user" in str(payload["result"])
|
||||
assert fake_service.updated_user == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_forget_updates_content_paths(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_service = _FakeMemoriesService()
|
||||
fake_service.memory = _user_memory()
|
||||
monkeypatch.setattr(
|
||||
memory_module, "create_memories_service", lambda **_: fake_service
|
||||
)
|
||||
|
||||
response = await memory_module.memory_forget(
|
||||
memory_type="user",
|
||||
forget_paths=["preferences.communication_style"],
|
||||
session=SimpleNamespace(),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
payload = _decode_tool_response(response)
|
||||
|
||||
assert payload["status"] == "success"
|
||||
assert "forgotten=1" in str(payload["result"])
|
||||
assert fake_service.updated_user == 1
|
||||
@@ -158,46 +158,44 @@ def test_build_system_prompt_keeps_sections_focused_without_language_duplication
|
||||
assert "Follow agent contracts strictly" not in prompt
|
||||
|
||||
|
||||
def test_build_system_prompt_includes_memory_section_when_memories_provided() -> None:
|
||||
from schemas.memories import (
|
||||
MemoryContext,
|
||||
MemoryListResponse,
|
||||
MemorySource,
|
||||
MemoryType,
|
||||
def test_build_system_prompt_includes_user_memory_section_for_router() -> None:
|
||||
from schemas.memories.memory_content import UserMemoryContent
|
||||
|
||||
user_memory = UserMemoryContent()
|
||||
|
||||
prompt = build_system_prompt(
|
||||
agent_type=AgentType.ROUTER,
|
||||
user_context=_build_user_context(),
|
||||
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
|
||||
user_memory=user_memory,
|
||||
)
|
||||
|
||||
memories = MemoryListResponse(
|
||||
owner_id=uuid4(),
|
||||
memories=[
|
||||
MemoryContext(
|
||||
memory_type=MemoryType.USER,
|
||||
source=MemorySource.MANUAL,
|
||||
title="User prefers morning meetings",
|
||||
content={"text": "User likes meetings before 10am"},
|
||||
created_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
|
||||
),
|
||||
],
|
||||
total=1,
|
||||
)
|
||||
assert "<!-- USER_MEMORY_START -->" in prompt
|
||||
assert "[User Memory]" in prompt
|
||||
|
||||
|
||||
def test_build_system_prompt_includes_work_memory_section_for_worker() -> None:
|
||||
from schemas.memories.memory_content import WorkProfileContent
|
||||
|
||||
work_memory = WorkProfileContent()
|
||||
|
||||
prompt = build_system_prompt(
|
||||
agent_type=AgentType.WORKER,
|
||||
user_context=_build_user_context(),
|
||||
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
|
||||
memories=memories,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
|
||||
assert "<!-- MEMORY_START -->" in prompt
|
||||
assert "[User Memories]" in prompt
|
||||
assert "User prefers morning meetings" in prompt
|
||||
assert "<!-- WORK_MEMORY_START -->" in prompt
|
||||
assert "[Work Memory]" in prompt
|
||||
|
||||
|
||||
def test_build_system_prompt_omits_memory_section_when_no_memories() -> None:
|
||||
prompt = build_system_prompt(
|
||||
agent_type=AgentType.WORKER,
|
||||
agent_type=AgentType.ROUTER,
|
||||
user_context=_build_user_context(),
|
||||
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
assert "<!-- MEMORY_START -->" not in prompt
|
||||
assert "<!-- USER_MEMORY_START -->" not in prompt
|
||||
assert "<!-- WORK_MEMORY_START -->" not in prompt
|
||||
|
||||
@@ -28,25 +28,3 @@ def test_build_stage_toolkit_uses_explicit_enabled_tools_as_final_set(
|
||||
)
|
||||
|
||||
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
|
||||
|
||||
|
||||
def test_build_stage_toolkit_uses_memory_defaults_without_explicit_tools(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_build_toolkit(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.tools.toolkit.build_toolkit", _fake_build_toolkit
|
||||
)
|
||||
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.MEMORY,
|
||||
session=cast(Any, object()),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
|
||||
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
|
||||
|
||||
@@ -16,13 +16,14 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
|
||||
toolkit = build_toolkit(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token-123",
|
||||
)
|
||||
schemas = toolkit.get_json_schemas()
|
||||
names = {item["function"]["name"] for item in schemas}
|
||||
assert "calendar_read" in names
|
||||
assert "calendar_write" in names
|
||||
assert "calendar_share" in names
|
||||
assert "memory_write" in names
|
||||
assert "memory_forget" in names
|
||||
|
||||
write_schema = next(
|
||||
item for item in schemas if item["function"]["name"] == "calendar_write"
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from v1.auth.automation_static_config import load_static_automation_job_config
|
||||
|
||||
|
||||
def test_memory_automation_static_config_contract() -> None:
|
||||
config = load_static_automation_job_config(config_name="memory_extraction")
|
||||
|
||||
assert config.context.window_mode.value == "day"
|
||||
assert config.context.window_count == 2
|
||||
assert [tool.value for tool in config.enabled_tools] == [
|
||||
"memory.write",
|
||||
"memory.forget",
|
||||
]
|
||||
prompt = config.input_template
|
||||
assert "提取" in prompt
|
||||
assert "遗忘" in prompt
|
||||
@@ -16,3 +16,18 @@ def test_memory_automation_job_trigger_exists_in_0004_migration() -> None:
|
||||
assert "'agent_type', 'memory'" in content
|
||||
assert "ux_automation_jobs_owner_memory_active" in content
|
||||
assert "input_template" in content
|
||||
|
||||
|
||||
def test_bootstrap_key_replaces_agent_type_unique_anchor() -> None:
|
||||
migration = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "alembic"
|
||||
/ "versions"
|
||||
/ "20260323_0003_bootstrap_job_key_and_unique_indexes.py"
|
||||
)
|
||||
content = migration.read_text(encoding="utf-8")
|
||||
|
||||
assert "bootstrap_key" in content
|
||||
assert "ux_automation_jobs_owner_bootstrap_key_active" in content
|
||||
assert "ux_memories_owner_memory_type" in content
|
||||
assert "DROP INDEX IF EXISTS ux_automation_jobs_owner_memory_active" in content
|
||||
|
||||
@@ -12,6 +12,14 @@ from v1.auth.schemas import (
|
||||
from v1.auth.service import AuthService, AuthServiceGateway
|
||||
|
||||
|
||||
class FakeRegistrationBootstrapper:
|
||||
def __init__(self) -> None:
|
||||
self.called_user_ids: list[str] = []
|
||||
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
|
||||
self.called_user_ids.append(user_id)
|
||||
|
||||
|
||||
class FakeGateway(AuthServiceGateway):
|
||||
def __init__(self, response: SessionResponse) -> None:
|
||||
self._response = response
|
||||
@@ -75,6 +83,27 @@ async def test_create_phone_session_forwards_payload() -> None:
|
||||
assert response.user.phone == "+8613812345678"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_phone_session_bootstraps_automation_job() -> None:
|
||||
user = AuthUser(id="b196f8be-c5f4-45d8-8f07-65c0ddf4d3de", phone="+8613812345678")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
gateway = FakeGateway(token_response)
|
||||
bootstrapper = FakeRegistrationBootstrapper()
|
||||
service = AuthService(gateway=gateway, registration_bootstrapper=bootstrapper)
|
||||
|
||||
await service.create_phone_session(
|
||||
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
|
||||
)
|
||||
|
||||
assert bootstrapper.called_user_ids == ["b196f8be-c5f4-45d8-8f07-65c0ddf4d3de"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_session_forwards_payload() -> None:
|
||||
user = AuthUser(id="user-1", phone="+8613812345678")
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.auth.registration_bootstrap import (
|
||||
compute_next_local_time_utc,
|
||||
)
|
||||
|
||||
|
||||
def test_compute_next_local_time_utc_from_asia_shanghai() -> None:
|
||||
now_utc = datetime(2026, 3, 23, 0, 30, tzinfo=timezone.utc)
|
||||
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name="Asia/Shanghai",
|
||||
local_hour=8,
|
||||
local_minute=0,
|
||||
)
|
||||
|
||||
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
|
||||
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_compute_next_local_time_utc_rolls_to_next_day_when_passed() -> None:
|
||||
now_utc = datetime(2026, 3, 23, 2, 30, tzinfo=timezone.utc)
|
||||
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name="Asia/Shanghai",
|
||||
local_hour=8,
|
||||
local_minute=0,
|
||||
)
|
||||
|
||||
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
|
||||
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registration_service_is_idempotent_when_job_exists() -> None:
|
||||
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
|
||||
|
||||
expected_owner_id = uuid4()
|
||||
|
||||
class _Repo:
|
||||
inserted = 0
|
||||
upsert_calls = 0
|
||||
|
||||
async def get_profile_timezone(self, *, user_id):
|
||||
assert user_id == expected_owner_id
|
||||
return "Asia/Shanghai"
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
|
||||
assert kwargs["owner_id"] == expected_owner_id
|
||||
assert kwargs["bootstrap_key"] == "memory_extraction"
|
||||
self.inserted += 1
|
||||
return False
|
||||
|
||||
async def upsert_initial_memory(self, **kwargs):
|
||||
self.upsert_calls += 1
|
||||
return False
|
||||
|
||||
class _Session:
|
||||
async def commit(self):
|
||||
raise AssertionError("must not commit when already exists")
|
||||
|
||||
async def rollback(self):
|
||||
raise AssertionError("must not rollback when no error")
|
||||
|
||||
service = RegistrationAutomationBootstrapService(
|
||||
repository=cast(Any, _Repo()), session=cast(Any, _Session())
|
||||
)
|
||||
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registration_service_creates_initial_memories_when_missing() -> None:
|
||||
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
|
||||
|
||||
expected_owner_id = uuid4()
|
||||
|
||||
class _Repo:
|
||||
async def get_profile_timezone(self, *, user_id):
|
||||
assert user_id == expected_owner_id
|
||||
return "Asia/Shanghai"
|
||||
|
||||
async def upsert_initial_memory(self, **kwargs):
|
||||
return True
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
|
||||
_ = kwargs
|
||||
return True
|
||||
|
||||
class _Session:
|
||||
committed = 0
|
||||
|
||||
async def commit(self):
|
||||
self.committed += 1
|
||||
|
||||
async def rollback(self):
|
||||
raise AssertionError("must not rollback when no error")
|
||||
|
||||
session = _Session()
|
||||
service = RegistrationAutomationBootstrapService(
|
||||
repository=cast(Any, _Repo()), session=cast(Any, session)
|
||||
)
|
||||
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
|
||||
|
||||
assert session.committed == 1
|
||||
Reference in New Issue
Block a user