feat: 重构 memory 系统,支持 user memory 和 work memory 分离

This commit is contained in:
qzl
2026-03-23 14:25:47 +08:00
parent 3aacc756db
commit 6be616f108
70 changed files with 7031 additions and 431 deletions
@@ -0,0 +1,38 @@
"""drop source column from memories
Revision ID: 202603230001
Revises: 202603200001
Create Date: 2026-03-23 18:00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "202603230001"
down_revision: Union[str, Sequence[str], None] = "202603200001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
columns = {column["name"] for column in inspector.get_columns("memories")}
if "source" in columns:
op.drop_column("memories", "source")
def downgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
columns = {column["name"] for column in inspector.get_columns("memories")}
if "source" not in columns:
op.add_column(
"memories",
sa.Column(
"source", sa.String(length=20), nullable=False, server_default="agent"
),
)
op.alter_column("memories", "source", server_default=None)
@@ -0,0 +1,34 @@
"""drop title column from memories
Revision ID: 202603230002
Revises: 202603230001
Create Date: 2026-03-23 21:00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "202603230002"
down_revision: Union[str, Sequence[str], None] = "202603230001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
columns = {column["name"] for column in inspector.get_columns("memories")}
if "title" in columns:
op.drop_column("memories", "title")
def downgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
columns = {column["name"] for column in inspector.get_columns("memories")}
if "title" not in columns:
op.add_column(
"memories", sa.Column("title", sa.String(length=255), nullable=True)
)
@@ -0,0 +1,355 @@
"""add bootstrap key and unique indexes for registration bootstrap
Revision ID: 202603230003
Revises: 202603230002
Create Date: 2026-03-23 23:10:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "202603230003"
down_revision: Union[str, Sequence[str], None] = "202603230002"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
automation_columns = {
column["name"] for column in inspector.get_columns("automation_jobs")
}
if "bootstrap_key" not in automation_columns:
op.add_column(
"automation_jobs",
sa.Column("bootstrap_key", sa.String(length=64), nullable=True),
)
op.execute("DROP INDEX IF EXISTS ux_automation_jobs_owner_memory_active")
op.execute(
"""
UPDATE public.automation_jobs
SET bootstrap_key = 'memory_extraction'
WHERE bootstrap_key IS NULL
AND (
config->>'agent_type' = 'memory'
OR (
created_by = owner_id
AND title = 'Memory Agent'
AND coalesce(config->'enabled_tools', '[]'::jsonb) @> '["memory.write", "memory.forget"]'::jsonb
AND jsonb_array_length(coalesce(config->'enabled_tools', '[]'::jsonb)) = 2
AND coalesce(config->'context', '{}'::jsonb) @> jsonb_build_object(
'source', 'latest_chat',
'window_mode', 'day',
'window_count', 2
)
)
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS public.memories_dedup_backup_202603230003
(LIKE public.memories INCLUDING ALL)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS public.automation_jobs_dedup_backup_202603230003 (
id UUID PRIMARY KEY
)
"""
)
op.execute(
"""
WITH ranked AS (
SELECT
id,
row_number() OVER (
PARTITION BY owner_id, bootstrap_key
ORDER BY updated_at DESC, created_at DESC, id DESC
) AS rn
FROM public.automation_jobs
WHERE deleted_at IS NULL
AND bootstrap_key IS NOT NULL
)
INSERT INTO public.automation_jobs_dedup_backup_202603230003(id)
SELECT id
FROM ranked
WHERE rn > 1
ON CONFLICT (id) DO NOTHING
"""
)
op.execute(
"""
WITH ranked AS (
SELECT
id,
row_number() OVER (
PARTITION BY owner_id, memory_type
ORDER BY updated_at DESC, created_at DESC, id DESC
) AS rn
FROM public.memories
)
INSERT INTO public.memories_dedup_backup_202603230003
SELECT m.*
FROM public.memories m
JOIN ranked r ON r.id = m.id
WHERE r.rn > 1
ON CONFLICT (id) DO NOTHING
"""
)
op.execute(
"""
WITH ranked AS (
SELECT
id,
row_number() OVER (
PARTITION BY owner_id, bootstrap_key
ORDER BY updated_at DESC, created_at DESC, id DESC
) AS rn
FROM public.automation_jobs
WHERE deleted_at IS NULL
AND bootstrap_key IS NOT NULL
)
UPDATE public.automation_jobs aj
SET deleted_at = now()
FROM ranked r
WHERE aj.id = r.id
AND r.rn > 1
"""
)
op.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS ux_automation_jobs_owner_bootstrap_key_active
ON public.automation_jobs(owner_id, bootstrap_key)
WHERE deleted_at IS NULL
AND bootstrap_key IS NOT NULL
"""
)
op.execute(
"""
WITH ranked AS (
SELECT
id,
row_number() OVER (
PARTITION BY owner_id, memory_type
ORDER BY updated_at DESC, created_at DESC, id DESC
) AS rn
FROM public.memories
)
DELETE FROM public.memories m
USING ranked r
WHERE m.id = r.id
AND r.rn > 1
"""
)
op.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS ux_memories_owner_memory_type
ON public.memories(owner_id, memory_type)
"""
)
op.execute(
"""
CREATE OR REPLACE FUNCTION public.create_profile_for_new_user()
RETURNS trigger
LANGUAGE plpgsql
SECURITY DEFINER
SET search_path = ''
AS $$
DECLARE
base_seed TEXT;
candidate_username TEXT;
attempt INT := 0;
BEGIN
base_seed := coalesce(NEW.phone, NEW.id::text);
LOOP
candidate_username := 'user_' || public.generate_profile_username_suffix(base_seed || ':' || attempt::text);
EXIT WHEN NOT EXISTS (
SELECT 1 FROM public.profiles p WHERE p.username = candidate_username
);
attempt := attempt + 1;
IF attempt >= 50 THEN
candidate_username := 'user_' || substr(replace(NEW.id::text, '-', ''), 1, 6);
EXIT;
END IF;
END LOOP;
INSERT INTO public.profiles (id, username, avatar_url, bio, settings, created_at, updated_at)
VALUES (
NEW.id,
candidate_username,
NULL,
NULL,
'{}'::jsonb,
now(),
now()
)
ON CONFLICT (id) DO NOTHING;
RETURN NEW;
END;
$$;
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS ux_memories_owner_memory_type")
op.execute("DROP INDEX IF EXISTS ux_automation_jobs_owner_bootstrap_key_active")
bind = op.get_bind()
inspector = sa.inspect(bind)
tables = set(inspector.get_table_names(schema="public"))
if "automation_jobs_dedup_backup_202603230003" in tables:
op.execute(
"""
UPDATE public.automation_jobs aj
SET deleted_at = NULL
FROM public.automation_jobs_dedup_backup_202603230003 b
WHERE aj.id = b.id
"""
)
op.execute(
"DROP TABLE IF EXISTS public.automation_jobs_dedup_backup_202603230003"
)
if "memories_dedup_backup_202603230003" in tables:
op.execute(
"""
INSERT INTO public.memories
SELECT b.*
FROM public.memories_dedup_backup_202603230003 b
LEFT JOIN public.memories m ON m.id = b.id
WHERE m.id IS NULL
ON CONFLICT (id) DO NOTHING
"""
)
op.execute("DROP TABLE IF EXISTS public.memories_dedup_backup_202603230003")
automation_columns = {
column["name"] for column in inspector.get_columns("automation_jobs")
}
if "bootstrap_key" in automation_columns:
op.drop_column("automation_jobs", "bootstrap_key")
op.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS ux_automation_jobs_owner_memory_active
ON public.automation_jobs(owner_id)
WHERE deleted_at IS NULL
AND config->>'agent_type' = 'memory'
"""
)
op.execute(
"""
CREATE OR REPLACE FUNCTION public.create_profile_for_new_user()
RETURNS trigger
LANGUAGE plpgsql
SECURITY DEFINER
SET search_path = ''
AS $$
DECLARE
base_seed TEXT;
candidate_username TEXT;
attempt INT := 0;
BEGIN
base_seed := coalesce(NEW.phone, NEW.id::text);
LOOP
candidate_username := 'user_' || public.generate_profile_username_suffix(base_seed || ':' || attempt::text);
EXIT WHEN NOT EXISTS (
SELECT 1 FROM public.profiles p WHERE p.username = candidate_username
);
attempt := attempt + 1;
IF attempt >= 50 THEN
candidate_username := 'user_' || substr(replace(NEW.id::text, '-', ''), 1, 6);
EXIT;
END IF;
END LOOP;
INSERT INTO public.profiles (id, username, avatar_url, bio, settings, created_at, updated_at)
VALUES (
NEW.id,
candidate_username,
NULL,
NULL,
'{}'::jsonb,
now(),
now()
)
ON CONFLICT (id) DO NOTHING;
BEGIN
IF NOT EXISTS (
SELECT 1
FROM public.automation_jobs aj
WHERE aj.owner_id = NEW.id
AND aj.deleted_at IS NULL
AND aj.config->>'agent_type' = 'memory'
) THEN
INSERT INTO public.automation_jobs (
id,
owner_id,
title,
config,
schedule_type,
run_at,
next_run_at,
timezone,
status,
created_by,
created_at,
updated_at
) VALUES (
gen_random_uuid(),
NEW.id,
'Memory Agent',
jsonb_build_object(
'agent_type', 'memory',
'model_code', 'qwen3.5-flash',
'enabled_tools', jsonb_build_array('calendar.read', 'user.lookup'),
'input_template', '请基于最近聊天上下文生成一段可执行的记忆总结与建议。',
'context', jsonb_build_object(
'source', 'latest_chat',
'window_mode', 'day',
'window_count', 2
)
),
'daily',
now(),
now() + interval '1 day',
'UTC',
'active',
NEW.id,
now(),
now()
);
END IF;
EXCEPTION WHEN unique_violation THEN
NULL;
END;
RETURN NEW;
END;
$$;
"""
)
@@ -1,11 +1,15 @@
from core.agentscope.prompts.agent_prompt import build_agent_prompt
from core.agentscope.prompts.memory_prompt import build_memory_prompt
from core.agentscope.prompts.memory_prompt import (
build_user_memory_prompt,
build_work_memory_prompt,
)
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.prompts.tool_prompt import build_tools_prompt
__all__ = [
"build_agent_prompt",
"build_memory_prompt",
"build_user_memory_prompt",
"build_work_memory_prompt",
"build_system_prompt",
"build_tools_prompt",
]
@@ -1,52 +1,59 @@
from __future__ import annotations
import json
from typing import Any
from schemas.memories import MemoryContext, MemoryListResponse
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
def _wrap_section(section: str, content: str) -> str:
marker_map = {
"memory": ("<!-- MEMORY_START -->", "<!-- MEMORY_END -->"),
"user_memory": ("<!-- USER_MEMORY_START -->", "<!-- USER_MEMORY_END -->"),
"work_memory": ("<!-- WORK_MEMORY_START -->", "<!-- WORK_MEMORY_END -->"),
}
start, end = marker_map[section]
body = content.strip()
return f"{start}\n{body}\n{end}" if body else f"{start}\n{end}"
def _format_memory_content(content: dict[str, Any]) -> str:
def _format_content(content: UserMemoryContent | WorkProfileContent) -> str:
if isinstance(content, dict):
return json.dumps(content, ensure_ascii=True, separators=(",", ":"))
return str(content)
return json.dumps(
content.model_dump(mode="json"), ensure_ascii=True, separators=(",", ":")
)
def _format_memory(ctx: MemoryContext) -> str:
parts = [
f"[{ctx.memory_type.value.upper()}] {ctx.title or 'Untitled'}",
f" source: {ctx.source.value}",
f" content: {_format_memory_content(ctx.content)}",
]
if ctx.created_at:
parts.append(f" created_at: {ctx.created_at.isoformat()}")
return "\n".join(parts)
def build_memory_prompt(
def build_user_memory_prompt(
*,
memories: MemoryListResponse,
user_memory: UserMemoryContent | None,
) -> str | None:
if not memories.memories:
if user_memory is None:
return None
lines: list[str] = [
"[User Memories]",
"- Memories are persistent context from previous sessions.",
"- Use them to ground responses in known user facts and preferences.",
"- Do not invent facts not present in memories.",
"[User Memory]",
"- User memory contains personal preferences, habits, people, and places.",
"- Use this to understand the user's personal context and preferences.",
"- Do not invent facts not present here.",
f"content: {_format_content(user_memory)}",
]
for ctx in memories.memories:
lines.append(_format_memory(ctx))
return _wrap_section("user_memory", "\n".join(lines))
return _wrap_section("memory", "\n".join(lines))
def build_work_memory_prompt(
*,
work_memory: WorkProfileContent | None,
) -> str | None:
if work_memory is None:
return None
lines: list[str] = [
"[Work Memory]",
"- Work memory contains projects, team members, habits, and milestones.",
"- Use this to understand the user's work context and ongoing tasks.",
"- Do not invent facts not present here.",
f"content: {_format_content(work_memory)}",
]
return _wrap_section("work_memory", "\n".join(lines))
@@ -9,12 +9,15 @@ from ag_ui.core.types import Tool
from core.agentscope.prompts.agent_prompt import (
build_agent_prompt,
)
from core.agentscope.prompts.memory_prompt import build_memory_prompt
from core.agentscope.prompts.memory_prompt import (
build_user_memory_prompt,
build_work_memory_prompt,
)
from core.agentscope.prompts.route_prompt import build_frontend_route_prompt
from core.agentscope.prompts.tool_prompt import build_tools_prompt
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.agent.forwarded_props import ClientTimeContext
from schemas.memories import MemoryListResponse
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from schemas.user.context import UserContext
@@ -210,9 +213,16 @@ def build_system_prompt(
runtime_client_time: ClientTimeContext | None = None,
extra_context: str | None = None,
tools: Sequence[Tool | dict[str, Any]] | None = None,
memories: MemoryListResponse | None = None,
user_memory: UserMemoryContent | None = None,
work_memory: WorkProfileContent | None = None,
) -> str:
include_route_section = agent_type == AgentType.WORKER
if agent_type == AgentType.ROUTER:
memory_prompt = build_user_memory_prompt(user_memory=user_memory)
else:
memory_prompt = build_work_memory_prompt(work_memory=work_memory)
sections: list[str | None] = [
_build_identity_section(),
_build_env_section(
@@ -228,7 +238,7 @@ def build_system_prompt(
llm_config=llm_config,
),
build_tools_prompt(tools=tools) if tools else None,
build_memory_prompt(memories=memories) if memories else None,
memory_prompt,
_build_output_rules(),
]
return "\n\n".join(item for item in sections if item).strip()
@@ -7,7 +7,7 @@ from agentscope.message import Msg
from core.agentscope.runtime.runner import AgentScopeRunner
from core.logging import get_logger
from schemas.automation import RuntimeConfig
from schemas.memories import MemoryListResponse
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from schemas.user import UserContext
logger = get_logger("core.agentscope.runtime.orchestrator")
@@ -26,7 +26,8 @@ class RunnerLike(Protocol):
pipeline: PipelineLike,
run_input: RunAgentInput,
runtime_config: RuntimeConfig,
memories: MemoryListResponse | None,
user_memory: UserMemoryContent | None,
work_memory: WorkProfileContent | None,
) -> dict[str, Any]: ...
@@ -50,7 +51,8 @@ class AgentScopeRuntimeOrchestrator:
context_messages: list[Msg],
user_context: UserContext,
runtime_config: RuntimeConfig,
memories: MemoryListResponse | None = None,
user_memory: UserMemoryContent | None = None,
work_memory: WorkProfileContent | None = None,
) -> dict[str, Any]:
thread_id = run_input.thread_id
run_id = run_input.run_id
@@ -70,7 +72,8 @@ class AgentScopeRuntimeOrchestrator:
pipeline=self._pipeline,
run_input=run_input,
runtime_config=runtime_config,
memories=memories,
user_memory=user_memory,
work_memory=work_memory,
)
await self._pipeline.emit(
+13 -12
View File
@@ -41,7 +41,7 @@ from schemas.agent.system_agent import (
SystemAgentLLMConfig,
)
from schemas.automation import RuntimeConfig
from schemas.memories import MemoryListResponse
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from schemas.user import UserContext
from services.litellm.service import LiteLLMService
from sqlalchemy import select
@@ -71,7 +71,8 @@ class AgentScopeRunner:
pipeline: PipelineLike,
run_input: RunAgentInput,
runtime_config: RuntimeConfig,
memories: MemoryListResponse | None = None,
user_memory: UserMemoryContent | None = None,
work_memory: WorkProfileContent | None = None,
) -> dict[str, Any]:
owner_id = UUID(user_context.id)
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
@@ -98,7 +99,7 @@ class AgentScopeRunner:
context_messages=context_messages,
stage_config=router_config,
runtime_client_time=runtime_client_time,
memories=memories,
user_memory=user_memory,
)
worker_output = await self._execute_worker_step(
pipeline=pipeline,
@@ -108,7 +109,7 @@ class AgentScopeRunner:
toolkit=worker_toolkit,
stage_config=worker_config,
runtime_client_time=runtime_client_time,
memories=memories,
work_memory=work_memory,
)
return {
"router": router_output.model_dump(mode="json", exclude_none=True),
@@ -166,7 +167,7 @@ class AgentScopeRunner:
context_messages: list[Msg],
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
memories: MemoryListResponse | None,
user_memory: UserMemoryContent | None,
) -> RouterAgentOutput:
await self._emit_step_event(
pipeline=pipeline,
@@ -179,7 +180,7 @@ class AgentScopeRunner:
context_messages=context_messages,
stage_config=stage_config,
runtime_client_time=runtime_client_time,
memories=memories,
user_memory=user_memory,
run_input=run_input,
)
router_output = RouterAgentOutput.model_validate(router_result.payload)
@@ -201,7 +202,7 @@ class AgentScopeRunner:
toolkit: Any,
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
memories: MemoryListResponse | None,
work_memory: WorkProfileContent | None,
) -> WorkerAgentOutputLite:
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
await self._emit_step_event(
@@ -221,7 +222,7 @@ class AgentScopeRunner:
worker_output_model=worker_output_model,
pipeline=pipeline,
runtime_client_time=runtime_client_time,
memories=memories,
work_memory=work_memory,
)
worker_output = worker_output_model.model_validate(worker_result.payload)
await self._emit_step_event(
@@ -239,7 +240,7 @@ class AgentScopeRunner:
context_messages: list[Msg],
stage_config: SystemAgentRuntimeConfig,
runtime_client_time: ClientTimeContext | None,
memories: MemoryListResponse | None,
user_memory: UserMemoryContent | None,
run_input: RunAgentInput,
) -> StageExecutionResult:
messages_for_router = self._build_router_messages(
@@ -260,7 +261,7 @@ class AgentScopeRunner:
now_utc=datetime.now(timezone.utc),
runtime_client_time=runtime_client_time,
tools=None,
memories=memories,
user_memory=user_memory,
),
"system",
),
@@ -319,7 +320,7 @@ class AgentScopeRunner:
worker_output_model: type[WorkerAgentOutputLite],
pipeline: PipelineLike,
runtime_client_time: ClientTimeContext | None,
memories: MemoryListResponse | None,
work_memory: WorkProfileContent | None,
) -> StageExecutionResult:
tracking_model = self._build_model(stage_config=stage_config)
emitter = PipelineStageEmitter(
@@ -340,7 +341,7 @@ class AgentScopeRunner:
runtime_client_time=runtime_client_time,
extra_context=stage_config.extra_context,
tools=None,
memories=memories,
work_memory=work_memory,
),
toolkit=toolkit,
model=tracking_model,
+14 -8
View File
@@ -20,8 +20,8 @@ from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import worker_agent_broker, worker_automation_broker
from schemas.automation import MemoryContextConfig, RuntimeConfig
from schemas.memories import MemoryListResponse
from schemas.automation import MessageContextConfig, RuntimeConfig
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from schemas.messages.chat_message import (
AgentChatMessageMetadata,
extract_user_message_attachments,
@@ -30,7 +30,7 @@ from schemas.user import UserContext
from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service
from v1.agent.repository import AgentRepository
from v1.memories.repository import MemoriesRepository
from v1.memories.repository import SQLAlchemyMemoriesRepository
from v1.memories.service import MemoriesService
from v1.users.dependencies import get_user_service
@@ -83,7 +83,7 @@ async def _build_recent_context_messages(
*,
session: Any,
thread_id: str,
context_config: "MemoryContextConfig",
context_config: "MessageContextConfig",
) -> list[Msg]:
context_service = AgentContextService(repository=AgentRepository(session))
result = await context_service.load_context_messages(
@@ -194,11 +194,16 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
orchestrator = _load_runtime()
async with AsyncSessionLocal() as session:
current_user = CurrentUser(id=owner_id)
user_context = await _build_user_context(owner_id=owner_id, session=session)
memories_service = MemoriesService(MemoriesRepository(session))
memories: MemoryListResponse = await memories_service.get_all_memories(
owner_id=owner_id
memories_service = MemoriesService(
repository=SQLAlchemyMemoriesRepository(session),
session=session,
current_user=current_user,
)
memories_result = await memories_service.get_all_memories()
user_memory: UserMemoryContent | None = memories_result.get("user_memory")
work_memory: WorkProfileContent | None = memories_result.get("work_memory")
redis_client = await get_or_init_redis_client()
bus = RedisStreamBus(
@@ -229,7 +234,8 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
context_messages=context_messages,
user_context=user_context,
runtime_config=runtime_config,
memories=memories,
user_memory=user_memory,
work_memory=work_memory,
)
logger.info(
"agentscope runtime task completed",
@@ -6,7 +6,7 @@ from typing import Any, Protocol
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.automation import ContextWindowMode, MemoryContextConfig
from schemas.automation import ContextWindowMode, MessageContextConfig
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
@@ -86,7 +86,7 @@ class AgentContextService:
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
context_loader = CONTEXT_LOADER_REGISTRY.resolve(
@@ -6,10 +6,16 @@ from core.agentscope.tools.custom.calendar import (
from core.agentscope.tools.custom.user_lookup import (
user_lookup,
)
from core.agentscope.tools.custom.memory import (
memory_forget,
memory_write,
)
__all__ = [
"calendar_read",
"calendar_write",
"calendar_share",
"user_lookup",
"memory_write",
"memory_forget",
]
@@ -0,0 +1,330 @@
from copy import deepcopy
from typing import Annotated, Any, cast
from uuid import UUID
from agentscope.tool import ToolResponse
from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy.ext.asyncio import AsyncSession
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
from core.agentscope.tools.utils.memory_domain import (
create_memories_service,
map_memory_exception,
)
from core.agentscope.tools.utils.tool_response_builder import (
build_error_output,
build_tool_response,
)
from models.memories import MemoryType
from schemas.agent.runtime_models import ToolAgentOutput, ToolStatus
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
class MemoryWriteArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
memory_type: MemoryType = MemoryType.USER
user_content: UserMemoryContent | None = None
work_content: WorkProfileContent | None = None
@model_validator(mode="after")
def validate_content(self) -> "MemoryWriteArgs":
if self.memory_type == MemoryType.USER:
if self.user_content is None or self.work_content is not None:
raise ValueError("memory_type=user requires user_content only")
else:
if self.work_content is None or self.user_content is not None:
raise ValueError("memory_type=work requires work_content only")
return self
class MemoryForgetArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
memory_type: MemoryType = MemoryType.USER
forget_paths: list[str] = Field(min_length=1, max_length=100)
@model_validator(mode="after")
def validate_forget_paths(self) -> "MemoryForgetArgs":
allowed_roots = (
set(UserMemoryContent.model_fields)
if self.memory_type == MemoryType.USER
else set(WorkProfileContent.model_fields)
)
normalized: list[str] = []
for raw_path in self.forget_paths:
path = raw_path.strip()
if not path:
continue
parts = [part for part in path.split(".") if part]
if not parts:
continue
if len(parts) > 5:
raise ValueError("forget path depth exceeds limit")
if parts[0] not in allowed_roots:
raise ValueError("forget path root is not allowed")
normalized.append(path)
if not normalized:
raise ValueError("forget_paths cannot be empty")
self.forget_paths = normalized
return self
def _memory_error_output(
*,
tool_name: str,
tool_call_args: dict[str, Any],
code: str,
message: str,
retryable: bool,
) -> ToolResponse:
output = build_error_output(
tool_name=tool_name,
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
code=code,
message=message,
retryable=retryable,
)
output = output.model_copy(update={"tool_call_args": tool_call_args})
return build_tool_response(output)
def _validate_runtime_context(
*,
tool_name: str,
tool_call_args: dict[str, Any],
session: Any,
owner_id: Any,
) -> ToolResponse | None:
if session is None or owner_id is None:
return _memory_error_output(
tool_name=tool_name,
tool_call_args=tool_call_args,
code="MISSING_RUNTIME_ARGS",
message="记忆工具缺少运行时参数",
retryable=False,
)
return None
def _deep_merge_dict(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
merged = deepcopy(base)
for key, value in patch.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
merged[key] = _deep_merge_dict(cast(dict[str, Any], merged[key]), value)
else:
merged[key] = value
return merged
def _remove_content_paths(
base_payload: dict[str, Any],
paths: list[str],
) -> tuple[dict[str, Any], list[str]]:
result = deepcopy(base_payload)
removed: list[str] = []
for raw_path in paths:
path = raw_path.strip()
if not path:
continue
keys = [part for part in path.split(".") if part]
if not keys:
continue
if _delete_nested_path(result, keys):
removed.append(path)
return result, removed
def _delete_nested_path(payload: dict[str, Any], keys: list[str]) -> bool:
current: dict[str, Any] = payload
for key in keys[:-1]:
next_value = current.get(key)
if not isinstance(next_value, dict):
return False
current = next_value
leaf = keys[-1]
if leaf in current:
del current[leaf]
return True
return False
async def memory_write(
memory_type: Annotated[
str,
Field(description="Memory type: user or work."),
] = "user",
user_content: Annotated[
UserMemoryContent | None,
Field(description="Patch payload for user memory content."),
] = None,
work_content: Annotated[
WorkProfileContent | None,
Field(description="Patch payload for work memory content."),
] = None,
session: Any = None,
owner_id: Any = None,
) -> ToolResponse:
tool_name = "memory_write"
tool_call_args: dict[str, Any] = {
"memory_type": memory_type,
"user_content": user_content,
"work_content": work_content,
}
runtime_error = _validate_runtime_context(
tool_name=tool_name,
tool_call_args=tool_call_args,
session=session,
owner_id=owner_id,
)
if runtime_error is not None:
return runtime_error
try:
parsed_args = MemoryWriteArgs.model_validate(tool_call_args)
service = create_memories_service(
session=cast(AsyncSession, session),
owner_id=cast(UUID, owner_id),
)
existing = await service.get_memory_model(memory_type=parsed_args.memory_type)
if parsed_args.memory_type == MemoryType.USER:
base_model = (
UserMemoryContent.model_validate(existing.content)
if existing is not None
else UserMemoryContent()
)
patch_model = cast(UserMemoryContent, parsed_args.user_content)
merged = _deep_merge_dict(
base_model.model_dump(),
patch_model.model_dump(exclude_unset=True),
)
validated = UserMemoryContent.model_validate(merged)
await service.update_user_memory(
content=validated,
)
else:
base_model = (
WorkProfileContent.model_validate(existing.content)
if existing is not None
else WorkProfileContent()
)
patch_model = cast(WorkProfileContent, parsed_args.work_content)
merged = _deep_merge_dict(
base_model.model_dump(),
patch_model.model_dump(exclude_unset=True),
)
validated = WorkProfileContent.model_validate(merged)
await service.update_work_memory(
content=validated,
)
summary = f"status=success memory_type={parsed_args.memory_type.value}"
return build_tool_response(
ToolAgentOutput(
tool_name=tool_name,
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
tool_call_args=tool_call_args,
status=ToolStatus.SUCCESS,
result=summary,
)
)
except Exception as exc: # noqa: BLE001
code, message, retryable = map_memory_exception(exc)
return _memory_error_output(
tool_name=tool_name,
tool_call_args=tool_call_args,
code=code,
message=message,
retryable=retryable,
)
async def memory_forget(
memory_type: Annotated[
str,
Field(description="Memory type: user or work."),
] = "user",
forget_paths: Annotated[
list[str] | None,
Field(description="Dot paths to remove from content."),
] = None,
session: Any = None,
owner_id: Any = None,
) -> ToolResponse:
tool_name = "memory_forget"
tool_call_args: dict[str, Any] = {
"memory_type": memory_type,
"forget_paths": forget_paths or [],
}
runtime_error = _validate_runtime_context(
tool_name=tool_name,
tool_call_args=tool_call_args,
session=session,
owner_id=owner_id,
)
if runtime_error is not None:
return runtime_error
try:
parsed_args = MemoryForgetArgs.model_validate(tool_call_args)
service = create_memories_service(
session=cast(AsyncSession, session),
owner_id=cast(UUID, owner_id),
)
existing = await service.get_memory_model(memory_type=parsed_args.memory_type)
if existing is None:
summary = f"status=success memory_type={parsed_args.memory_type.value} forgotten=0"
return build_tool_response(
ToolAgentOutput(
tool_name=tool_name,
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
tool_call_args=tool_call_args,
status=ToolStatus.SUCCESS,
result=summary,
)
)
if parsed_args.memory_type == MemoryType.USER:
base_model = UserMemoryContent.model_validate(existing.content)
updated_dict, removed_paths = _remove_content_paths(
base_model.model_dump(),
parsed_args.forget_paths,
)
validated = UserMemoryContent.model_validate(updated_dict)
await service.update_user_memory(
content=validated,
)
else:
base_model = WorkProfileContent.model_validate(existing.content)
updated_dict, removed_paths = _remove_content_paths(
base_model.model_dump(),
parsed_args.forget_paths,
)
validated = WorkProfileContent.model_validate(updated_dict)
await service.update_work_memory(
content=validated,
)
summary = (
f"status=success memory_type={parsed_args.memory_type.value} forgotten={len(removed_paths)} "
f"skipped=0"
)
return build_tool_response(
ToolAgentOutput(
tool_name=tool_name,
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
tool_call_args=tool_call_args,
status=ToolStatus.SUCCESS,
result=summary,
)
)
except Exception as exc: # noqa: BLE001
code, message, retryable = map_memory_exception(exc)
return _memory_error_output(
tool_name=tool_name,
tool_call_args=tool_call_args,
code=code,
message=message,
retryable=retryable,
)
@@ -4,17 +4,13 @@ from dataclasses import dataclass
from enum import Enum
class ToolGroup(str, Enum):
READ = "read"
EXECUTE = "execute"
MEMORY = "memory"
class AgentTool(str, Enum):
CALENDAR_READ = "calendar.read"
CALENDAR_WRITE = "calendar.write"
CALENDAR_SHARE = "calendar.share"
USER_LOOKUP = "user.lookup"
MEMORY_WRITE = "memory.write"
MEMORY_FORGET = "memory.forget"
@dataclass(frozen=True)
@@ -25,29 +21,32 @@ class ToolApprovalConfig:
@dataclass(frozen=True)
class ToolConfig:
name: str
group: ToolGroup
approval: ToolApprovalConfig
TOOL_CONFIGS: dict[str, ToolConfig] = {
"calendar_read": ToolConfig(
name="calendar_read",
group=ToolGroup.READ,
approval=ToolApprovalConfig(required=False),
),
"user_lookup": ToolConfig(
name="user_lookup",
group=ToolGroup.MEMORY,
approval=ToolApprovalConfig(required=False),
),
"calendar_write": ToolConfig(
name="calendar_write",
group=ToolGroup.EXECUTE,
approval=ToolApprovalConfig(required=False),
),
"calendar_share": ToolConfig(
name="calendar_share",
group=ToolGroup.EXECUTE,
approval=ToolApprovalConfig(required=False),
),
"memory_write": ToolConfig(
name="memory_write",
approval=ToolApprovalConfig(required=False),
),
"memory_forget": ToolConfig(
name="memory_forget",
approval=ToolApprovalConfig(required=False),
),
}
@@ -57,6 +56,8 @@ AGENT_TOOL_TO_FUNCTION_NAME: dict[AgentTool, str] = {
AgentTool.CALENDAR_WRITE: "calendar_write",
AgentTool.CALENDAR_SHARE: "calendar_share",
AgentTool.USER_LOOKUP: "user_lookup",
AgentTool.MEMORY_WRITE: "memory_write",
AgentTool.MEMORY_FORGET: "memory_forget",
}
TOOL_NAME_ALIASES: dict[str, AgentTool] = {
@@ -68,6 +69,10 @@ TOOL_NAME_ALIASES: dict[str, AgentTool] = {
"calendar_share": AgentTool.CALENDAR_SHARE,
AgentTool.USER_LOOKUP.value: AgentTool.USER_LOOKUP,
"user_lookup": AgentTool.USER_LOOKUP,
AgentTool.MEMORY_WRITE.value: AgentTool.MEMORY_WRITE,
"memory_write": AgentTool.MEMORY_WRITE,
AgentTool.MEMORY_FORGET.value: AgentTool.MEMORY_FORGET,
"memory_forget": AgentTool.MEMORY_FORGET,
}
@@ -10,6 +10,10 @@ from core.agentscope.tools.custom.calendar import (
calendar_share,
calendar_write,
)
from core.agentscope.tools.custom.memory import (
memory_forget,
memory_write,
)
from core.agentscope.tools.custom.user_lookup import user_lookup
from core.agentscope.tools.tool_config import (
TOOL_CONFIGS,
@@ -23,6 +27,8 @@ TOOL_FUNCTIONS: dict[str, Any] = {
"calendar_write": calendar_write,
"calendar_share": calendar_share,
"user_lookup": user_lookup,
"memory_write": memory_write,
"memory_forget": memory_forget,
}
@@ -0,0 +1,32 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from v1.memories.repository import SQLAlchemyMemoriesRepository
from v1.memories.service import MemoriesService
def create_memories_service(
session: AsyncSession,
owner_id: UUID,
) -> MemoriesService:
return MemoriesService(
repository=SQLAlchemyMemoriesRepository(session),
session=session,
current_user=CurrentUser(id=owner_id),
)
def map_memory_exception(exc: Exception) -> tuple[str, str, bool]:
if isinstance(exc, HTTPException):
detail = exc.detail
if isinstance(detail, str) and detail.strip():
return "OPERATION_FAILED", detail.strip(), exc.status_code >= 500
return "OPERATION_FAILED", "记忆操作失败", exc.status_code >= 500
if isinstance(exc, ValueError):
return "INVALID_ARGUMENT", "请求参数无效", False
return "INTERNAL_ERROR", "记忆操作失败", True
+2 -10
View File
@@ -1,11 +1,3 @@
from core.automation.scheduler import (
AutomationSchedulerService,
DispatchResult,
SqlAlchemyAutomationSchedulerRepository,
)
from core.automation.scheduler import run_automation_scheduler_scan
__all__ = [
"AutomationSchedulerService",
"DispatchResult",
"SqlAlchemyAutomationSchedulerRepository",
]
__all__ = ["run_automation_scheduler_scan"]
@@ -0,0 +1,8 @@
input_template: 请基于最近两天用户聊天上下文提取用户记忆;如果已有记忆内容变化请更新;如果记忆已失效请执行遗忘。
enabled_tools:
- memory.write
- memory.forget
context:
source: latest_chat
window_mode: day
window_count: 2
+4
View File
@@ -32,6 +32,10 @@ class AutomationJob(TimestampMixin, SoftDeleteMixin, Base):
UUID(as_uuid=True),
nullable=False,
)
bootstrap_key: Mapped[str | None] = mapped_column(
String(64),
nullable=True,
)
title: Mapped[str] = mapped_column(
String(255),
nullable=False,
-11
View File
@@ -16,12 +16,6 @@ class MemoryType(str, Enum):
WORK = "work"
class MemorySource(str, Enum):
MANUAL = "manual"
AGENT = "agent"
IMPORTED = "imported"
class MemoryStatus(str, Enum):
ACTIVE = "active"
DISABLED = "disabled"
@@ -46,15 +40,10 @@ class Memory(TimestampMixin, Base):
String(20),
nullable=False,
)
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
content: Mapped[dict] = mapped_column(
json_jsonb,
nullable=False,
)
source: Mapped[MemorySource] = mapped_column(
String(20),
nullable=False,
)
status: Mapped[MemoryStatus] = mapped_column(
String(20),
nullable=False,
-2
View File
@@ -12,7 +12,6 @@ from schemas.inbox.messages import (
parse_calendar_content,
)
from schemas.invite_codes import InviteCodeRewardConfig
from schemas.memories import MemoryContext
from schemas.messages import AgentChatMessageMetadata
from schemas.schedule.items import (
AttachmentType,
@@ -36,7 +35,6 @@ __all__ = [
"InboxMessageStatus",
"InboxMessageType",
"InviteCodeRewardConfig",
"MemoryContext",
"ScheduleItemMetadata",
"ScheduleItemMetadataAttachment",
"ScheduleItemSourceType",
+4 -2
View File
@@ -20,7 +20,7 @@ class ContextWindowMode(str, Enum):
NUMBER = "number"
class MemoryContextConfig(BaseModel):
class MessageContextConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
source: ContextSource = ContextSource.LATEST_CHAT
@@ -32,7 +32,7 @@ class RuntimeConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
context: MemoryContextConfig = Field(default_factory=MemoryContextConfig)
context: MessageContextConfig = Field(default_factory=MessageContextConfig)
class AutomationJobConfig(RuntimeConfig):
@@ -46,6 +46,7 @@ class AutomationJob(BaseModel):
id: UUID
owner_id: UUID
bootstrap_key: str | None = Field(default=None, min_length=1, max_length=64)
title: str = Field(..., min_length=1, max_length=255)
config: AutomationJobConfig
schedule_type: ScheduleType
@@ -63,6 +64,7 @@ class AutomationJob(BaseModel):
return cls(
id=obj.id,
owner_id=obj.owner_id,
bootstrap_key=obj.bootstrap_key,
title=obj.title,
config=AutomationJobConfig.model_validate(obj.config or {}),
schedule_type=obj.schedule_type,
+18 -33
View File
@@ -2,10 +2,19 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, ClassVar, Literal
from typing import ClassVar, Literal
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict
from schemas.memories.memory_content import (
TeamMember,
UserMemoryContent,
UserPreferences,
WorkHabit,
WorkProfileContent,
WorkProject,
)
class MemoryType(str, Enum):
@@ -13,12 +22,6 @@ class MemoryType(str, Enum):
WORK = "work"
class MemorySource(str, Enum):
MANUAL = "manual"
AGENT = "agent"
IMPORTED = "imported"
class MemoryStatus(str, Enum):
ACTIVE = "active"
DISABLED = "disabled"
@@ -33,38 +36,20 @@ class MemoryModel(BaseModel):
owner_id: UUID
agent_id: UUID | None = None
memory_type: Literal["user", "work"]
title: str | None = None
content: dict[str, Any]
source: MemorySource
content: UserMemoryContent | WorkProfileContent
status: MemoryStatus
created_at: datetime
updated_at: datetime
class MemoryContext(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
memory_type: MemoryType
source: MemorySource
title: str | None = None
content: dict[str, Any]
created_at: datetime
updated_at: datetime
class MemoryListResponse(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
owner_id: UUID
memories: list[MemoryContext] = Field(default_factory=list)
total: int
__all__ = [
"MemoryContext",
"MemoryListResponse",
"MemoryModel",
"MemorySource",
"MemoryStatus",
"MemoryType",
"TeamMember",
"UserMemoryContent",
"UserPreferences",
"WorkHabit",
"WorkProfileContent",
"WorkProject",
]
@@ -0,0 +1,192 @@
from __future__ import annotations
from datetime import date, datetime
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
Weekday = Literal["mon", "tue", "wed", "thu", "fri", "sat", "sun"]
ProjectStatus = Literal["planned", "active", "paused", "completed"]
PreferenceLevel = Literal["like", "neutral", "avoid"]
MemorySource = Literal["user", "inferred", "calendar", "email", "agent"]
class MemoryMeta(BaseModel):
source: MemorySource | None = Field(default=None, description="记忆来源")
confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="置信度")
last_updated_at: datetime | None = Field(default=None, description="最后更新时间")
class TimeWindow(BaseModel):
weekdays: list[Weekday] = Field(default_factory=list, description="适用星期")
start: str = Field(description="开始时间,HH:MM")
end: str = Field(description="结束时间,HH:MM")
class PersonMemory(BaseModel):
name: str = Field(description="人物姓名")
relationship: str | None = Field(
default=None, description="与用户关系,如家人/同事/导师/朋友"
)
role: str | None = Field(default=None, description="角色,如老板/导师/合作方")
preferred_contact_channel: str | None = Field(
default=None, description="偏好联系方式"
)
notes: str | None = Field(default=None, description="补充说明")
meta: MemoryMeta = Field(default_factory=MemoryMeta)
class PlaceMemory(BaseModel):
name: str = Field(description="地点名称")
category: str | None = Field(
default=None, description="地点类别,如home/office/gym/cafe"
)
address: str | None = Field(default=None, description="地址")
timezone: str | None = Field(default=None, description="地点时区")
commute_minutes: int | None = Field(default=None, ge=0, description="典型通勤时长")
preference: PreferenceLevel | None = Field(default=None, description="地点偏好")
notes: str | None = Field(default=None, description="补充说明")
meta: MemoryMeta = Field(default_factory=MemoryMeta)
class UserPreferences(BaseModel):
communication_style: str | None = Field(
default=None, description="沟通风格,如简洁直接"
)
language_preference: list[str] = Field(default_factory=list, description="语言偏好")
location_preference: str | None = Field(
default=None, description="地点偏好,如喜欢远程"
)
work_lifestyle: str | None = Field(default=None, description="作息方式,如早睡早起")
notification_preference: list[str] = Field(
default_factory=list, description="通知方式偏好"
)
class SchedulingPreferences(BaseModel):
productive_windows: list[TimeWindow] = Field(
default_factory=list, description="高效率时段"
)
preferred_meeting_windows: list[TimeWindow] = Field(
default_factory=list, description="偏好的会议时段"
)
no_meeting_windows: list[TimeWindow] = Field(
default_factory=list, description="尽量不安排会议的时段"
)
deep_work_windows: list[TimeWindow] = Field(
default_factory=list, description="深度工作时段"
)
preferred_meeting_duration_minutes: list[int] = Field(
default_factory=lambda: [30, 60], description="偏好的会议时长"
)
meeting_buffer_minutes: int | None = Field(
default=None, ge=0, description="会议间缓冲时间"
)
max_meetings_per_day: int | None = Field(
default=None, ge=0, description="单日会议上限"
)
notes: str | None = Field(default=None, description="其他排程说明")
class RecurringRoutine(BaseModel):
name: str = Field(description="周期性安排名称")
description: str | None = Field(default=None, description="周期性安排描述")
cadence: str | None = Field(
default=None, description="频率,如daily/weekly/monthly"
)
time_windows: list[TimeWindow] = Field(
default_factory=list, description="通常发生时段"
)
importance: str | None = Field(default=None, description="重要程度")
meta: MemoryMeta = Field(default_factory=MemoryMeta)
class UserMemoryContent(BaseModel):
model_config = ConfigDict(extra="allow")
occupation: str | None = Field(default=None, description="职业")
timezone: str | None = Field(default=None, description="时区")
primary_language: str | None = Field(default=None, description="主要语言")
people: list[PersonMemory] = Field(default_factory=list, description="重要人物")
places: list[PlaceMemory] = Field(default_factory=list, description="常去地点")
preferences: UserPreferences = Field(default_factory=UserPreferences)
scheduling_preferences: SchedulingPreferences = Field(
default_factory=SchedulingPreferences
)
interests: list[str] = Field(default_factory=list, description="兴趣爱好")
avoid_topics: list[str] = Field(default_factory=list, description="不想讨论的话题")
custom_rules: list[str] = Field(default_factory=list, description="用户自定义规则")
recurring_routines: list[RecurringRoutine] = Field(
default_factory=list, description="周期性习惯/安排"
)
class Milestone(BaseModel):
name: str = Field(description="里程碑名称")
due_date: date | None = Field(default=None, description="截止日期")
status: str | None = Field(default=None, description="状态")
notes: str | None = Field(default=None, description="补充说明")
class WorkProject(BaseModel):
name: str = Field(description="项目名")
description: str | None = Field(default=None, description="项目描述")
status: ProjectStatus | None = Field(default=None, description="项目状态")
priority: str | None = Field(default=None, description="项目优先级")
deadline: date | None = Field(default=None, description="项目截止时间")
collaborators: list[str] = Field(default_factory=list, description="协作人")
key_milestones: list[Milestone] = Field(
default_factory=list, description="关键里程碑"
)
notes: str | None = Field(default=None, description="补充说明")
meta: MemoryMeta = Field(default_factory=MemoryMeta)
class WorkHabit(BaseModel):
available_hours: list[TimeWindow] = Field(
default_factory=list, description="常规可工作时间"
)
deep_work_blocks: list[TimeWindow] = Field(
default_factory=list, description="偏好的深度工作时间"
)
preferred_meeting_windows: list[TimeWindow] = Field(
default_factory=list, description="偏好的会议时间"
)
no_meeting_windows: list[TimeWindow] = Field(
default_factory=list, description="不希望开会的时间"
)
preferred_meeting_duration_minutes: list[int] = Field(
default_factory=lambda: [30, 60], description="偏好的会议时长"
)
notification_channel: str | None = Field(default=None, description="首选沟通渠道")
notes: str | None = Field(default=None, description="补充说明")
class TeamMember(BaseModel):
name: str = Field(description="成员姓名")
role: str | None = Field(default=None, description="团队角色")
relationship: str | None = Field(
default=None, description="关系,如直属上级/同事/合作方"
)
preferred_contact_channel: str | None = Field(
default=None, description="偏好沟通渠道"
)
notes: str | None = Field(default=None, description="补充说明")
meta: MemoryMeta = Field(default_factory=MemoryMeta)
class WorkProfileContent(BaseModel):
model_config = ConfigDict(extra="allow")
occupation: str | None = Field(default=None, description="职业身份")
expertise: list[str] = Field(default_factory=list, description="专业领域")
preferred_tools: list[str] = Field(default_factory=list, description="惯用工具")
current_projects: list[WorkProject] = Field(
default_factory=list, description="长期项目画像"
)
work_habits: WorkHabit = Field(default_factory=WorkHabit)
team_members: list[TeamMember] = Field(default_factory=list, description="团队成员")
team_context: str | None = Field(default=None, description="团队背景")
work_rules: list[str] = Field(
default_factory=list, description="工作规则或默认原则"
)
+7 -6
View File
@@ -9,11 +9,12 @@ from pathlib import Path
import yaml
from pydantic import ValidationError
from core.agentscope.tools.tool_config import AgentTool
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.automation import (
ContextSource,
ContextWindowMode,
MemoryContextConfig,
MessageContextConfig,
RuntimeConfig,
)
@@ -38,9 +39,9 @@ def _load_system_agents_yaml(path: Path | None = None) -> dict:
return loaded
def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextConfig:
def _parse_context_messages_config(yaml_config: dict | None) -> MessageContextConfig:
if not yaml_config:
return MemoryContextConfig()
return MessageContextConfig()
mode_str = yaml_config.get("mode", "day")
count = yaml_config.get("count", 2)
try:
@@ -51,7 +52,7 @@ def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextCon
window_mode = ContextWindowMode(mode_str)
except ValueError:
window_mode = ContextWindowMode.DAY
return MemoryContextConfig(
return MessageContextConfig(
source=source,
window_mode=window_mode,
window_count=count,
@@ -93,9 +94,9 @@ def build_runtime_config_from_system_agents(
router_config.context_messages.model_dump() if router_config else None
)
enabled_tools: list[str] = []
enabled_tools: list[AgentTool] = []
if worker_config and worker_config.enabled_tools:
enabled_tools = [str(t) for t in worker_config.enabled_tools]
enabled_tools = list(worker_config.enabled_tools)
return RuntimeConfig(
enabled_tools=enabled_tools,
@@ -0,0 +1,46 @@
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
import re
from typing import Any
import yaml
from core.agentscope.tools.tool_config import AgentTool
from schemas.automation import AutomationJobConfig, MessageContextConfig
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
def _automation_yaml_path(config_name: str) -> Path:
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
raise ValueError("invalid automation config name")
return (
Path(__file__).resolve().parents[2]
/ "core"
/ "config"
/ "static"
/ "automation"
/ f"{config_name}.yaml"
)
@lru_cache(maxsize=16)
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
path = _automation_yaml_path(config_name)
with path.open("r", encoding="utf-8") as file:
loaded: Any = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"invalid automation config format: {path}")
config = AutomationJobConfig.model_validate(loaded)
if config_name == "memory_extraction":
if config.enabled_tools != [AgentTool.MEMORY_WRITE, AgentTool.MEMORY_FORGET]:
raise ValueError(
"memory_extraction enabled_tools must be [memory.write, memory.forget]"
)
if config.context != MessageContextConfig(window_count=2):
raise ValueError(
"memory_extraction context must be latest_chat/day with window_count=2"
)
return config
+21 -2
View File
@@ -1,8 +1,27 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.db import get_db
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.registration_bootstrap import (
RegistrationAutomationBootstrapService,
RegistrationBootstrapRepository,
)
from v1.auth.service import AuthService
def get_auth_service() -> AuthService:
return AuthService(gateway=SupabaseAuthGateway())
def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
bootstrapper = RegistrationAutomationBootstrapService(
repository=RegistrationBootstrapRepository(session=session),
session=session,
)
return AuthService(
gateway=SupabaseAuthGateway(),
registration_bootstrapper=bootstrapper,
)
@@ -0,0 +1,228 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Protocol
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
from models.memories import MemoryType
from models.profile import Profile
from schemas.automation import AutomationJobConfig
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from schemas.user.context import parse_profile_settings
from v1.auth.automation_static_config import load_static_automation_job_config
from v1.auth.schemas import RegistrationBootstrapRequest
from v1.memories.repository import SQLAlchemyMemoriesRepository
logger = get_logger("v1.auth.registration_bootstrap")
_LOCAL_RUN_HOUR = 8
_LOCAL_RUN_MINUTE = 0
class RegistrationBootstrapRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
self._memories_repository = SQLAlchemyMemoriesRepository(session)
async def get_profile_timezone(self, *, user_id: UUID) -> str:
stmt = select(Profile.settings).where(Profile.id == user_id)
settings = (await self._session.execute(stmt)).scalar_one_or_none()
parsed = parse_profile_settings(
settings if isinstance(settings, dict) else None
)
return parsed.preferences.timezone
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
run_at: datetime,
next_run_at: datetime,
) -> bool:
stmt = (
insert(AutomationJob)
.values(
id=uuid4(),
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=title,
config=config.model_dump(mode="json"),
schedule_type=ScheduleType.DAILY,
run_at=run_at,
next_run_at=next_run_at,
timezone=timezone_name,
status=AutomationJobStatus.ACTIVE,
created_by=owner_id,
)
.on_conflict_do_nothing(
index_elements=["owner_id", "bootstrap_key"],
index_where=AutomationJob.deleted_at.is_(None)
& AutomationJob.bootstrap_key.is_not(None),
)
.returning(AutomationJob.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
await self._session.flush()
return inserted_id is not None
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool:
return await self._memories_repository.create_if_absent(
owner_id=owner_id,
memory_type=memory_type,
content=content,
)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
class RegistrationBootstrapRepositoryLike(Protocol):
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
run_at: datetime,
next_run_at: datetime,
) -> bool: ...
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool: ...
class SessionLike(Protocol):
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
def compute_next_local_time_utc(
*,
now_utc: datetime,
timezone_name: str,
local_hour: int,
local_minute: int,
) -> tuple[datetime, datetime]:
try:
timezone_obj = ZoneInfo(timezone_name)
except ZoneInfoNotFoundError:
timezone_obj = ZoneInfo("UTC")
local_now = now_utc.astimezone(timezone_obj)
today_run_local = local_now.replace(
hour=local_hour,
minute=local_minute,
second=0,
microsecond=0,
)
run_local = (
today_run_local
if local_now <= today_run_local
else today_run_local + timedelta(days=1)
)
next_local = run_local + timedelta(days=1)
return run_local.astimezone(UTC), next_local.astimezone(UTC)
class RegistrationAutomationBootstrapService:
def __init__(
self,
*,
repository: RegistrationBootstrapRepositoryLike,
session: SessionLike,
) -> None:
self._repository = repository
self._session = session
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
owner_id = request.user_id
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
definitions = [
{
"bootstrap_key": "memory_extraction",
"config_name": "memory_extraction",
"title": "Memory Agent",
"local_hour": _LOCAL_RUN_HOUR,
"local_minute": _LOCAL_RUN_MINUTE,
}
]
try:
inserted_any = False
created_or_updated_memory = False
user_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.USER,
content=UserMemoryContent().model_dump(mode="json"),
)
work_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.WORK,
content=WorkProfileContent().model_dump(mode="json"),
)
created_or_updated_memory = user_initialized or work_initialized
for definition in definitions:
bootstrap_key = str(definition["bootstrap_key"])
job_config = load_static_automation_job_config(
config_name=str(definition["config_name"])
)
run_at, next_run_at = compute_next_local_time_utc(
now_utc=datetime.now(UTC),
timezone_name=timezone_name,
local_hour=int(definition["local_hour"]),
local_minute=int(definition["local_minute"]),
)
inserted = (
await self._repository.insert_bootstrap_automation_job_if_absent(
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=str(definition["title"]),
config=job_config,
timezone_name=timezone_name,
run_at=run_at,
next_run_at=next_run_at,
)
)
inserted_any = inserted_any or inserted
if inserted_any or created_or_updated_memory:
await self._session.commit()
logger.info(
"user automation jobs bootstrapped",
user_id=user_id,
timezone=timezone_name,
memory_initialized=created_or_updated_memory,
)
except Exception:
await self._session.rollback()
raise
+8
View File
@@ -1,5 +1,7 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
SUPABASE_PASSWORD_MIN_LENGTH = 6
@@ -49,3 +51,9 @@ class UserByPhoneResponse(BaseModel):
class OtpSendResponse(BaseModel):
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class RegistrationBootstrapRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
user_id: UUID
+18 -2
View File
@@ -28,9 +28,15 @@ class AuthServiceGateway(Protocol):
class AuthService:
_gateway: AuthServiceGateway
_registration_bootstrapper: RegistrationBootstrapper | None
def __init__(self, gateway: AuthServiceGateway) -> None:
def __init__(
self,
gateway: AuthServiceGateway,
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
) -> None:
self._gateway = gateway
self._registration_bootstrapper = registration_bootstrapper
async def send_otp(self, request: OtpSendRequest) -> None:
await self._gateway.send_otp(request)
@@ -38,10 +44,20 @@ class AuthService:
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
return await self._gateway.create_phone_session(request)
response = await self._gateway.create_phone_session(request)
if self._registration_bootstrapper is not None:
await self._registration_bootstrapper.ensure_user_automation_jobs(
user_id=response.user.id
)
return response
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
return await self._gateway.refresh_session(request)
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
raise NotImplementedError
+30
View File
@@ -0,0 +1,30 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from core.db import get_db
from v1.memories.repository import SQLAlchemyMemoriesRepository
from v1.memories.service import MemoriesService
from v1.users.dependencies import get_current_user
async def get_memories_repository(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SQLAlchemyMemoriesRepository:
return SQLAlchemyMemoriesRepository(session)
async def get_memories_service(
session: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> MemoriesService:
repository = SQLAlchemyMemoriesRepository(session)
return MemoriesService(
repository=repository,
session=session,
current_user=current_user,
)
+144 -11
View File
@@ -3,29 +3,162 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository
from models.memories import Memory
from core.logging import get_logger
from models.memories import Memory, MemoryStatus, MemoryType
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.memories.repository")
class MemoriesRepositoryLike(Protocol):
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]: ...
async def create(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> Memory: ...
async def get_by_type_for_owner(
self, *, owner_id: UUID, memory_type: MemoryType
) -> Memory | None: ...
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
async def create_if_absent(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool: ...
async def update_content(
self,
memory: Memory,
content: dict | None = None,
) -> Memory: ...
class MemoriesRepository(BaseRepository[Memory]):
class SQLAlchemyMemoriesRepository(BaseRepository[Memory]):
_session: AsyncSession
def __init__(self, session: AsyncSession) -> None:
super().__init__(session=session, model=Memory)
self._session = session
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]:
stmt = (
select(Memory)
.where(Memory.owner_id == owner_id)
.where(Memory.status == "active")
.order_by(Memory.created_at.desc())
async def create(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> Memory:
try:
memory = Memory(
owner_id=owner_id,
memory_type=memory_type,
content=content,
status=MemoryStatus.ACTIVE,
)
self._session.add(memory)
await self._session.flush()
return memory
except SQLAlchemyError:
logger.exception(
"Failed to create memory",
owner_id=str(owner_id),
memory_type=memory_type.value,
)
raise
async def get_by_type_for_owner(
self, *, owner_id: UUID, memory_type: MemoryType
) -> Memory | None:
try:
stmt = (
select(Memory)
.where(Memory.owner_id == owner_id)
.where(Memory.memory_type == memory_type)
.where(Memory.status == MemoryStatus.ACTIVE)
)
result = await self._session.execute(stmt)
if hasattr(result, "scalar_one_or_none"):
return result.scalar_one_or_none()
scalars = result.scalars()
rows = list(scalars.all())
return rows[0] if rows else None
except SQLAlchemyError:
logger.exception(
"Failed to get memory by type for owner",
owner_id=str(owner_id),
memory_type=memory_type.value,
)
raise
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
return await self.get_by_type_for_owner(
owner_id=owner_id, memory_type=MemoryType.USER
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
return await self.get_by_type_for_owner(
owner_id=owner_id, memory_type=MemoryType.WORK
)
async def create_if_absent(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool:
try:
stmt = (
insert(Memory)
.values(
owner_id=owner_id,
memory_type=memory_type,
content=content,
status=MemoryStatus.ACTIVE,
)
.on_conflict_do_nothing(index_elements=["owner_id", "memory_type"])
.returning(Memory.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
await self._session.flush()
return inserted_id is not None
except SQLAlchemyError:
logger.exception(
"Failed to create memory if absent",
owner_id=str(owner_id),
memory_type=memory_type.value,
)
raise
async def update_content(
self,
memory: Memory,
content: dict | None = None,
) -> Memory:
try:
if content is not None:
memory.content = content
await self._session.flush()
return memory
except SQLAlchemyError:
logger.exception(
"Failed to update memory content",
memory_id=str(memory.id),
)
raise
+83
View File
@@ -0,0 +1,83 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, status
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from v1.memories.dependencies import get_memories_service
from v1.memories.schemas import (
MemoryListResponse,
UserMemoryPartialUpdate,
UserMemoryUpdate,
WorkMemoryPartialUpdate,
WorkMemoryUpdate,
)
from v1.memories.service import MemoriesService
router = APIRouter(prefix="/memories", tags=["memories"])
@router.get("", response_model=MemoryListResponse)
async def get_all_memories(
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> MemoryListResponse:
result = await service.get_all_memories()
return MemoryListResponse(
user_memory=result["user_memory"],
work_memory=result["work_memory"],
)
@router.get("/user", response_model=UserMemoryContent | None)
async def get_user_memory(
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> UserMemoryContent | None:
return await service.get_user_memory()
@router.get("/work", response_model=WorkProfileContent | None)
async def get_work_memory(
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> WorkProfileContent | None:
return await service.get_work_memory()
@router.put("/user", response_model=UserMemoryContent, status_code=status.HTTP_200_OK)
async def update_user_memory(
payload: UserMemoryUpdate,
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> UserMemoryContent:
return await service.update_user_memory(
content=payload.content,
)
@router.put("/work", response_model=WorkProfileContent, status_code=status.HTTP_200_OK)
async def update_work_memory(
payload: WorkMemoryUpdate,
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> WorkProfileContent:
return await service.update_work_memory(
content=payload.content,
)
@router.patch("/user", response_model=UserMemoryContent)
async def patch_user_memory(
payload: UserMemoryPartialUpdate,
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> UserMemoryContent:
return await service.patch_user_memory(
content=payload.content,
)
@router.patch("/work", response_model=WorkProfileContent)
async def patch_work_memory(
payload: WorkMemoryPartialUpdate,
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> WorkProfileContent:
return await service.patch_work_memory(
content=payload.content,
)
+38
View File
@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import ClassVar
from pydantic import BaseModel, ConfigDict
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
class UserMemoryUpdate(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
content: UserMemoryContent
class WorkMemoryUpdate(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
content: WorkProfileContent
class UserMemoryPartialUpdate(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
content: UserMemoryContent | None = None
class WorkMemoryPartialUpdate(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
content: WorkProfileContent | None = None
class MemoryListResponse(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True)
user_memory: UserMemoryContent | None = None
work_memory: WorkProfileContent | None = None
+269 -36
View File
@@ -1,53 +1,286 @@
from __future__ import annotations
from uuid import UUID
from typing import TYPE_CHECKING
from models.memories import Memory
from schemas.memories import MemoryContext, MemoryListResponse, MemorySource, MemoryType
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from models.memories import Memory, MemoryType
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from v1.memories.repository import MemoriesRepositoryLike
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.memories.service")
class MemoriesService(BaseService):
"""Memories service handling user/work memory operations.
Each user has exactly 2 memory records:
- user_type: stores personal preferences, people, places, etc.
- work_type: stores work profile, projects, team, etc.
Responsibilities:
- Authorization checks
- Validation (ownership, memory type)
- Transaction boundary (commit/rollback)
- Converting ORM models to response schemas
"""
class MemoriesService:
_repository: MemoriesRepositoryLike
_session: AsyncSession
def __init__(self, repository: MemoriesRepositoryLike) -> None:
def __init__(
self,
repository: MemoriesRepositoryLike,
session: AsyncSession,
current_user: CurrentUser | None,
) -> None:
super().__init__(current_user=current_user)
self._repository = repository
self._session = session
def _to_context(self, memory: Memory) -> MemoryContext:
return MemoryContext(
memory_type=MemoryType(memory.memory_type.value),
source=MemorySource(memory.source.value),
title=memory.title,
content=memory.content,
created_at=memory.created_at,
updated_at=memory.updated_at,
async def get_user_memory(self) -> UserMemoryContent | None:
user_id = self.require_user_id()
try:
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
return None
return self._parse_user_content(memory)
async def get_work_memory(self) -> WorkProfileContent | None:
user_id = self.require_user_id()
try:
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
return None
return self._parse_work_content(memory)
async def get_all_memories(self) -> dict:
user_id = self.require_user_id()
try:
user_memory = await self._repository.get_user_memory_for_owner(
owner_id=user_id
)
work_memory = await self._repository.get_work_memory_for_owner(
owner_id=user_id
)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
return {
"user_memory": self._parse_user_content(user_memory)
if user_memory
else None,
"work_memory": self._parse_work_content(work_memory)
if work_memory
else None,
}
async def update_user_memory(
self,
*,
content: UserMemoryContent,
) -> UserMemoryContent:
user_id = self.require_user_id()
try:
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
try:
memory = await self._repository.create(
owner_id=user_id,
memory_type=MemoryType.USER,
content=content.model_dump(),
)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(
status_code=503, detail="Memories service unavailable"
)
else:
try:
memory = await self._repository.update_content(
memory=memory,
content=content.model_dump(),
)
await self._session.commit()
await self._session.refresh(memory)
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(
status_code=503, detail="Memories service unavailable"
)
logger.info(
"user_memory_updated",
extra={"user_id": str(user_id)},
)
async def get_user_memories(self, *, owner_id: UUID) -> MemoryListResponse:
memories = await self._repository.get_active_memories(owner_id=owner_id)
user_memories = [
self._to_context(memory)
for memory in memories
if memory.memory_type.value == "user"
]
return MemoryListResponse(
owner_id=owner_id, memories=user_memories, total=len(user_memories)
return self._parse_user_content(memory)
async def update_work_memory(
self,
*,
content: WorkProfileContent,
) -> WorkProfileContent:
user_id = self.require_user_id()
try:
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
try:
memory = await self._repository.create(
owner_id=user_id,
memory_type=MemoryType.WORK,
content=content.model_dump(),
)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(
status_code=503, detail="Memories service unavailable"
)
else:
try:
memory = await self._repository.update_content(
memory=memory,
content=content.model_dump(),
)
await self._session.commit()
await self._session.refresh(memory)
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(
status_code=503, detail="Memories service unavailable"
)
logger.info(
"work_memory_updated",
extra={"user_id": str(user_id)},
)
async def get_agent_memories(self, *, owner_id: UUID) -> MemoryListResponse:
memories = await self._repository.get_active_memories(owner_id=owner_id)
agent_memories = [
self._to_context(memory)
for memory in memories
if memory.memory_type.value == "work"
]
return MemoryListResponse(
owner_id=owner_id, memories=agent_memories, total=len(agent_memories)
return self._parse_work_content(memory)
async def patch_user_memory(
self, *, content: UserMemoryContent | None = None
) -> UserMemoryContent:
user_id = self.require_user_id()
try:
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
raise HTTPException(status_code=404, detail="User memory not found")
try:
update_data: dict = {}
if content is not None:
existing_content = memory.content or {}
merged = content.model_dump()
existing_content.update(
{k: v for k, v in merged.items() if v is not None}
)
update_data["content"] = existing_content
if update_data:
memory = await self._repository.update_content(
memory=memory,
content=update_data.get("content"),
)
await self._session.commit()
await self._session.refresh(memory)
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(status_code=503, detail="Memories service unavailable")
logger.info(
"user_memory_patched",
extra={"user_id": str(user_id)},
)
async def get_all_memories(self, *, owner_id: UUID) -> MemoryListResponse:
memories = await self._repository.get_active_memories(owner_id=owner_id)
memory_contexts = [self._to_context(memory) for memory in memories]
return MemoryListResponse(
owner_id=owner_id, memories=memory_contexts, total=len(memory_contexts)
return self._parse_user_content(memory)
async def patch_work_memory(
self, *, content: WorkProfileContent | None = None
) -> WorkProfileContent:
user_id = self.require_user_id()
try:
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
raise HTTPException(status_code=404, detail="Work memory not found")
try:
update_data: dict = {}
if content is not None:
existing_content = memory.content or {}
merged = content.model_dump()
existing_content.update(
{k: v for k, v in merged.items() if v is not None}
)
update_data["content"] = existing_content
if update_data:
memory = await self._repository.update_content(
memory=memory,
content=update_data.get("content"),
)
await self._session.commit()
await self._session.refresh(memory)
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(status_code=503, detail="Memories service unavailable")
logger.info(
"work_memory_patched",
extra={"user_id": str(user_id)},
)
return self._parse_work_content(memory)
def _parse_user_content(self, memory: Memory) -> UserMemoryContent:
content_dict = memory.content or {}
return UserMemoryContent.model_validate(content_dict)
def _parse_work_content(self, memory: Memory) -> WorkProfileContent:
content_dict = memory.content or {}
return WorkProfileContent.model_validate(content_dict)
async def get_memory_model(self, *, memory_type: MemoryType) -> Memory | None:
user_id = self.require_user_id()
try:
return await self._repository.get_by_type_for_owner(
owner_id=user_id,
memory_type=memory_type,
)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
+2 -1
View File
@@ -7,6 +7,7 @@ from v1.app.router import router as app_router
from v1.auth.router import router as auth_router
from v1.friendships.router import router as friendships_router
from v1.inbox_messages.router import router as inbox_messages_router
from v1.memories.router import router as memories_router
from v1.schedule_items.router import router as schedule_items_router
from v1.todo.router import router as todo_router
from v1.users.router import router as users_router
@@ -16,8 +17,8 @@ router = APIRouter(prefix="/api/v1")
router.include_router(app_router)
router.include_router(auth_router)
router.include_router(agent_router)
router.include_router(agent_router)
router.include_router(friendships_router)
router.include_router(memories_router)
router.include_router(users_router)
router.include_router(schedule_items_router)
router.include_router(inbox_messages_router)
@@ -6,7 +6,7 @@ import pytest
from ag_ui.core import RunAgentInput
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
from schemas.automation import MemoryContextConfig, RuntimeConfig
from schemas.automation import MessageContextConfig, RuntimeConfig
from schemas.user import UserContext, parse_profile_settings
@@ -51,7 +51,7 @@ def _run_input() -> RunAgentInput:
def _runtime_config() -> RuntimeConfig:
return RuntimeConfig(
enabled_tools=[],
context=MemoryContextConfig(),
context=MessageContextConfig(),
)
@@ -18,7 +18,7 @@ from schemas.agent.runtime_models import (
WorkerAgentOutputLite,
)
from schemas.agent.system_agent import AgentType
from schemas.automation import MemoryContextConfig, RuntimeConfig
from schemas.automation import MessageContextConfig, RuntimeConfig
from schemas.user import UserContext, parse_profile_settings
@@ -48,7 +48,7 @@ def _user_context() -> UserContext:
def _runtime_config() -> RuntimeConfig:
return RuntimeConfig(
enabled_tools=[],
context=MemoryContextConfig(),
context=MessageContextConfig(),
)
@@ -7,7 +7,7 @@ import pytest
import core.agentscope.runtime.tasks as tasks_module
from schemas.agent import ToolStatus
from schemas.automation import ContextWindowMode, MemoryContextConfig
from schemas.automation import ContextWindowMode, MessageContextConfig
from schemas.user import UserContext, parse_profile_settings
@@ -201,7 +201,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -237,7 +237,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(
context_config=MessageContextConfig(
window_mode=ContextWindowMode.DAY,
window_count=2,
),
@@ -264,7 +264,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -295,7 +295,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(),
context_config=MessageContextConfig(),
)
assert len(messages) == 1
@@ -319,7 +319,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id, context_config
return {
@@ -337,7 +337,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
context_config=MemoryContextConfig(),
context_config=MessageContextConfig(),
)
assert messages == []
@@ -357,7 +357,7 @@ async def test_build_recent_context_messages_passes_context_config(
self,
*,
thread_id: str,
context_config: MemoryContextConfig,
context_config: MessageContextConfig,
) -> dict[str, object] | None:
del thread_id
captured_config["config"] = context_config
@@ -365,7 +365,7 @@ async def test_build_recent_context_messages_passes_context_config(
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
cfg = MemoryContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
cfg = MessageContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
messages = await tasks_module._build_recent_context_messages(
session=object(),
thread_id=str(uuid4()),
@@ -0,0 +1,113 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from uuid import uuid4
import pytest
from agentscope.tool import ToolResponse
from core.agentscope.tools.custom import memory as memory_module
from models.memories import MemoryType
from schemas.memories.memory_content import UserMemoryContent
def _decode_tool_response(response: ToolResponse) -> dict[str, object]:
assert response.content
first = response.content[0]
text = str(first.get("text", "")) if isinstance(first, dict) else str(first.text)
return json.loads(text)
def _payload_error_code(payload: dict[str, object]) -> str:
error = payload.get("error")
if not isinstance(error, dict):
return ""
return str(error.get("code") or "")
class _FakeMemoriesService:
def __init__(self) -> None:
self.memory: object | None = None
self.updated_user = 0
self.updated_work = 0
async def get_memory_model(self, *, memory_type: MemoryType):
_ = memory_type
return self.memory
async def update_user_memory(self, **kwargs):
_ = kwargs
self.updated_user += 1
return SimpleNamespace()
async def update_work_memory(self, **kwargs):
_ = kwargs
self.updated_work += 1
return SimpleNamespace()
def _user_memory():
return SimpleNamespace(
id=uuid4(),
owner_id=uuid4(),
memory_type=MemoryType.USER,
content={"preferences": {"communication_style": "简洁"}},
status="active",
)
@pytest.mark.asyncio
async def test_memory_write_requires_runtime_context() -> None:
response = await memory_module.memory_write(
memory_type="user",
user_content=UserMemoryContent(interests=["跑步"]),
)
payload = _decode_tool_response(response)
assert payload["status"] == "failure"
assert _payload_error_code(payload) == "MISSING_RUNTIME_ARGS"
@pytest.mark.asyncio
async def test_memory_write_updates_user_content(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_service = _FakeMemoriesService()
monkeypatch.setattr(
memory_module, "create_memories_service", lambda **_: fake_service
)
response = await memory_module.memory_write(
memory_type="user",
user_content=UserMemoryContent(interests=["阅读"]),
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(response)
assert payload["status"] == "success"
assert "memory_type=user" in str(payload["result"])
assert fake_service.updated_user == 1
@pytest.mark.asyncio
async def test_memory_forget_updates_content_paths(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_service = _FakeMemoriesService()
fake_service.memory = _user_memory()
monkeypatch.setattr(
memory_module, "create_memories_service", lambda **_: fake_service
)
response = await memory_module.memory_forget(
memory_type="user",
forget_paths=["preferences.communication_style"],
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(response)
assert payload["status"] == "success"
assert "forgotten=1" in str(payload["result"])
assert fake_service.updated_user == 1
@@ -158,46 +158,44 @@ def test_build_system_prompt_keeps_sections_focused_without_language_duplication
assert "Follow agent contracts strictly" not in prompt
def test_build_system_prompt_includes_memory_section_when_memories_provided() -> None:
from schemas.memories import (
MemoryContext,
MemoryListResponse,
MemorySource,
MemoryType,
def test_build_system_prompt_includes_user_memory_section_for_router() -> None:
from schemas.memories.memory_content import UserMemoryContent
user_memory = UserMemoryContent()
prompt = build_system_prompt(
agent_type=AgentType.ROUTER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
user_memory=user_memory,
)
memories = MemoryListResponse(
owner_id=uuid4(),
memories=[
MemoryContext(
memory_type=MemoryType.USER,
source=MemorySource.MANUAL,
title="User prefers morning meetings",
content={"text": "User likes meetings before 10am"},
created_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
),
],
total=1,
)
assert "<!-- USER_MEMORY_START -->" in prompt
assert "[User Memory]" in prompt
def test_build_system_prompt_includes_work_memory_section_for_worker() -> None:
from schemas.memories.memory_content import WorkProfileContent
work_memory = WorkProfileContent()
prompt = build_system_prompt(
agent_type=AgentType.WORKER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
memories=memories,
work_memory=work_memory,
)
assert "<!-- MEMORY_START -->" in prompt
assert "[User Memories]" in prompt
assert "User prefers morning meetings" in prompt
assert "<!-- WORK_MEMORY_START -->" in prompt
assert "[Work Memory]" in prompt
def test_build_system_prompt_omits_memory_section_when_no_memories() -> None:
prompt = build_system_prompt(
agent_type=AgentType.WORKER,
agent_type=AgentType.ROUTER,
user_context=_build_user_context(),
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
)
assert "<!-- MEMORY_START -->" not in prompt
assert "<!-- USER_MEMORY_START -->" not in prompt
assert "<!-- WORK_MEMORY_START -->" not in prompt
@@ -28,25 +28,3 @@ def test_build_stage_toolkit_uses_explicit_enabled_tools_as_final_set(
)
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
def test_build_stage_toolkit_uses_memory_defaults_without_explicit_tools(
monkeypatch,
) -> None:
captured: dict[str, object] = {}
def _fake_build_toolkit(**kwargs):
captured.update(kwargs)
return object()
monkeypatch.setattr(
"core.agentscope.tools.toolkit.build_toolkit", _fake_build_toolkit
)
build_stage_toolkit(
agent_type=AgentType.MEMORY,
session=cast(Any, object()),
owner_id=uuid4(),
)
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
@@ -16,13 +16,14 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
toolkit = build_toolkit(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
user_token="token-123",
)
schemas = toolkit.get_json_schemas()
names = {item["function"]["name"] for item in schemas}
assert "calendar_read" in names
assert "calendar_write" in names
assert "calendar_share" in names
assert "memory_write" in names
assert "memory_forget" in names
write_schema = next(
item for item in schemas if item["function"]["name"] == "calendar_write"
@@ -0,0 +1,17 @@
from __future__ import annotations
from v1.auth.automation_static_config import load_static_automation_job_config
def test_memory_automation_static_config_contract() -> None:
config = load_static_automation_job_config(config_name="memory_extraction")
assert config.context.window_mode.value == "day"
assert config.context.window_count == 2
assert [tool.value for tool in config.enabled_tools] == [
"memory.write",
"memory.forget",
]
prompt = config.input_template
assert "提取" in prompt
assert "遗忘" in prompt
@@ -16,3 +16,18 @@ def test_memory_automation_job_trigger_exists_in_0004_migration() -> None:
assert "'agent_type', 'memory'" in content
assert "ux_automation_jobs_owner_memory_active" in content
assert "input_template" in content
def test_bootstrap_key_replaces_agent_type_unique_anchor() -> None:
migration = (
Path(__file__).resolve().parents[3]
/ "alembic"
/ "versions"
/ "20260323_0003_bootstrap_job_key_and_unique_indexes.py"
)
content = migration.read_text(encoding="utf-8")
assert "bootstrap_key" in content
assert "ux_automation_jobs_owner_bootstrap_key_active" in content
assert "ux_memories_owner_memory_type" in content
assert "DROP INDEX IF EXISTS ux_automation_jobs_owner_memory_active" in content
@@ -12,6 +12,14 @@ from v1.auth.schemas import (
from v1.auth.service import AuthService, AuthServiceGateway
class FakeRegistrationBootstrapper:
def __init__(self) -> None:
self.called_user_ids: list[str] = []
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
self.called_user_ids.append(user_id)
class FakeGateway(AuthServiceGateway):
def __init__(self, response: SessionResponse) -> None:
self._response = response
@@ -75,6 +83,27 @@ async def test_create_phone_session_forwards_payload() -> None:
assert response.user.phone == "+8613812345678"
@pytest.mark.asyncio
async def test_create_phone_session_bootstraps_automation_job() -> None:
user = AuthUser(id="b196f8be-c5f4-45d8-8f07-65c0ddf4d3de", phone="+8613812345678")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
gateway = FakeGateway(token_response)
bootstrapper = FakeRegistrationBootstrapper()
service = AuthService(gateway=gateway, registration_bootstrapper=bootstrapper)
await service.create_phone_session(
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
)
assert bootstrapper.called_user_ids == ["b196f8be-c5f4-45d8-8f07-65c0ddf4d3de"]
@pytest.mark.asyncio
async def test_refresh_session_forwards_payload() -> None:
user = AuthUser(id="user-1", phone="+8613812345678")
@@ -0,0 +1,112 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, cast
from uuid import uuid4
import pytest
from v1.auth.registration_bootstrap import (
compute_next_local_time_utc,
)
def test_compute_next_local_time_utc_from_asia_shanghai() -> None:
now_utc = datetime(2026, 3, 23, 0, 30, tzinfo=timezone.utc)
run_at, next_run_at = compute_next_local_time_utc(
now_utc=now_utc,
timezone_name="Asia/Shanghai",
local_hour=8,
local_minute=0,
)
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
def test_compute_next_local_time_utc_rolls_to_next_day_when_passed() -> None:
now_utc = datetime(2026, 3, 23, 2, 30, tzinfo=timezone.utc)
run_at, next_run_at = compute_next_local_time_utc(
now_utc=now_utc,
timezone_name="Asia/Shanghai",
local_hour=8,
local_minute=0,
)
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
@pytest.mark.asyncio
async def test_registration_service_is_idempotent_when_job_exists() -> None:
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
expected_owner_id = uuid4()
class _Repo:
inserted = 0
upsert_calls = 0
async def get_profile_timezone(self, *, user_id):
assert user_id == expected_owner_id
return "Asia/Shanghai"
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
assert kwargs["owner_id"] == expected_owner_id
assert kwargs["bootstrap_key"] == "memory_extraction"
self.inserted += 1
return False
async def upsert_initial_memory(self, **kwargs):
self.upsert_calls += 1
return False
class _Session:
async def commit(self):
raise AssertionError("must not commit when already exists")
async def rollback(self):
raise AssertionError("must not rollback when no error")
service = RegistrationAutomationBootstrapService(
repository=cast(Any, _Repo()), session=cast(Any, _Session())
)
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
@pytest.mark.asyncio
async def test_registration_service_creates_initial_memories_when_missing() -> None:
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
expected_owner_id = uuid4()
class _Repo:
async def get_profile_timezone(self, *, user_id):
assert user_id == expected_owner_id
return "Asia/Shanghai"
async def upsert_initial_memory(self, **kwargs):
return True
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
_ = kwargs
return True
class _Session:
committed = 0
async def commit(self):
self.committed += 1
async def rollback(self):
raise AssertionError("must not rollback when no error")
session = _Session()
service = RegistrationAutomationBootstrapService(
repository=cast(Any, _Repo()), session=cast(Any, session)
)
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
assert session.committed == 1