feat: 重构 memory 系统,支持 user memory 和 work memory 分离
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
"""drop source column from memories
|
||||
|
||||
Revision ID: 202603230001
|
||||
Revises: 202603200001
|
||||
Create Date: 2026-03-23 18:00:00
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "202603230001"
|
||||
down_revision: Union[str, Sequence[str], None] = "202603200001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
columns = {column["name"] for column in inspector.get_columns("memories")}
|
||||
if "source" in columns:
|
||||
op.drop_column("memories", "source")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
columns = {column["name"] for column in inspector.get_columns("memories")}
|
||||
if "source" not in columns:
|
||||
op.add_column(
|
||||
"memories",
|
||||
sa.Column(
|
||||
"source", sa.String(length=20), nullable=False, server_default="agent"
|
||||
),
|
||||
)
|
||||
op.alter_column("memories", "source", server_default=None)
|
||||
@@ -0,0 +1,34 @@
|
||||
"""drop title column from memories
|
||||
|
||||
Revision ID: 202603230002
|
||||
Revises: 202603230001
|
||||
Create Date: 2026-03-23 21:00:00
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "202603230002"
|
||||
down_revision: Union[str, Sequence[str], None] = "202603230001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
columns = {column["name"] for column in inspector.get_columns("memories")}
|
||||
if "title" in columns:
|
||||
op.drop_column("memories", "title")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
columns = {column["name"] for column in inspector.get_columns("memories")}
|
||||
if "title" not in columns:
|
||||
op.add_column(
|
||||
"memories", sa.Column("title", sa.String(length=255), nullable=True)
|
||||
)
|
||||
@@ -0,0 +1,355 @@
|
||||
"""add bootstrap key and unique indexes for registration bootstrap
|
||||
|
||||
Revision ID: 202603230003
|
||||
Revises: 202603230002
|
||||
Create Date: 2026-03-23 23:10:00
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "202603230003"
|
||||
down_revision: Union[str, Sequence[str], None] = "202603230002"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
automation_columns = {
|
||||
column["name"] for column in inspector.get_columns("automation_jobs")
|
||||
}
|
||||
if "bootstrap_key" not in automation_columns:
|
||||
op.add_column(
|
||||
"automation_jobs",
|
||||
sa.Column("bootstrap_key", sa.String(length=64), nullable=True),
|
||||
)
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS ux_automation_jobs_owner_memory_active")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE public.automation_jobs
|
||||
SET bootstrap_key = 'memory_extraction'
|
||||
WHERE bootstrap_key IS NULL
|
||||
AND (
|
||||
config->>'agent_type' = 'memory'
|
||||
OR (
|
||||
created_by = owner_id
|
||||
AND title = 'Memory Agent'
|
||||
AND coalesce(config->'enabled_tools', '[]'::jsonb) @> '["memory.write", "memory.forget"]'::jsonb
|
||||
AND jsonb_array_length(coalesce(config->'enabled_tools', '[]'::jsonb)) = 2
|
||||
AND coalesce(config->'context', '{}'::jsonb) @> jsonb_build_object(
|
||||
'source', 'latest_chat',
|
||||
'window_mode', 'day',
|
||||
'window_count', 2
|
||||
)
|
||||
)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS public.memories_dedup_backup_202603230003
|
||||
(LIKE public.memories INCLUDING ALL)
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS public.automation_jobs_dedup_backup_202603230003 (
|
||||
id UUID PRIMARY KEY
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
WITH ranked AS (
|
||||
SELECT
|
||||
id,
|
||||
row_number() OVER (
|
||||
PARTITION BY owner_id, bootstrap_key
|
||||
ORDER BY updated_at DESC, created_at DESC, id DESC
|
||||
) AS rn
|
||||
FROM public.automation_jobs
|
||||
WHERE deleted_at IS NULL
|
||||
AND bootstrap_key IS NOT NULL
|
||||
)
|
||||
INSERT INTO public.automation_jobs_dedup_backup_202603230003(id)
|
||||
SELECT id
|
||||
FROM ranked
|
||||
WHERE rn > 1
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
WITH ranked AS (
|
||||
SELECT
|
||||
id,
|
||||
row_number() OVER (
|
||||
PARTITION BY owner_id, memory_type
|
||||
ORDER BY updated_at DESC, created_at DESC, id DESC
|
||||
) AS rn
|
||||
FROM public.memories
|
||||
)
|
||||
INSERT INTO public.memories_dedup_backup_202603230003
|
||||
SELECT m.*
|
||||
FROM public.memories m
|
||||
JOIN ranked r ON r.id = m.id
|
||||
WHERE r.rn > 1
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
WITH ranked AS (
|
||||
SELECT
|
||||
id,
|
||||
row_number() OVER (
|
||||
PARTITION BY owner_id, bootstrap_key
|
||||
ORDER BY updated_at DESC, created_at DESC, id DESC
|
||||
) AS rn
|
||||
FROM public.automation_jobs
|
||||
WHERE deleted_at IS NULL
|
||||
AND bootstrap_key IS NOT NULL
|
||||
)
|
||||
UPDATE public.automation_jobs aj
|
||||
SET deleted_at = now()
|
||||
FROM ranked r
|
||||
WHERE aj.id = r.id
|
||||
AND r.rn > 1
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS ux_automation_jobs_owner_bootstrap_key_active
|
||||
ON public.automation_jobs(owner_id, bootstrap_key)
|
||||
WHERE deleted_at IS NULL
|
||||
AND bootstrap_key IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
WITH ranked AS (
|
||||
SELECT
|
||||
id,
|
||||
row_number() OVER (
|
||||
PARTITION BY owner_id, memory_type
|
||||
ORDER BY updated_at DESC, created_at DESC, id DESC
|
||||
) AS rn
|
||||
FROM public.memories
|
||||
)
|
||||
DELETE FROM public.memories m
|
||||
USING ranked r
|
||||
WHERE m.id = r.id
|
||||
AND r.rn > 1
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS ux_memories_owner_memory_type
|
||||
ON public.memories(owner_id, memory_type)
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION public.create_profile_for_new_user()
|
||||
RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
SECURITY DEFINER
|
||||
SET search_path = ''
|
||||
AS $$
|
||||
DECLARE
|
||||
base_seed TEXT;
|
||||
candidate_username TEXT;
|
||||
attempt INT := 0;
|
||||
BEGIN
|
||||
base_seed := coalesce(NEW.phone, NEW.id::text);
|
||||
|
||||
LOOP
|
||||
candidate_username := 'user_' || public.generate_profile_username_suffix(base_seed || ':' || attempt::text);
|
||||
EXIT WHEN NOT EXISTS (
|
||||
SELECT 1 FROM public.profiles p WHERE p.username = candidate_username
|
||||
);
|
||||
attempt := attempt + 1;
|
||||
IF attempt >= 50 THEN
|
||||
candidate_username := 'user_' || substr(replace(NEW.id::text, '-', ''), 1, 6);
|
||||
EXIT;
|
||||
END IF;
|
||||
END LOOP;
|
||||
|
||||
INSERT INTO public.profiles (id, username, avatar_url, bio, settings, created_at, updated_at)
|
||||
VALUES (
|
||||
NEW.id,
|
||||
candidate_username,
|
||||
NULL,
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
now(),
|
||||
now()
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ux_memories_owner_memory_type")
|
||||
op.execute("DROP INDEX IF EXISTS ux_automation_jobs_owner_bootstrap_key_active")
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
tables = set(inspector.get_table_names(schema="public"))
|
||||
|
||||
if "automation_jobs_dedup_backup_202603230003" in tables:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE public.automation_jobs aj
|
||||
SET deleted_at = NULL
|
||||
FROM public.automation_jobs_dedup_backup_202603230003 b
|
||||
WHERE aj.id = b.id
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"DROP TABLE IF EXISTS public.automation_jobs_dedup_backup_202603230003"
|
||||
)
|
||||
|
||||
if "memories_dedup_backup_202603230003" in tables:
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO public.memories
|
||||
SELECT b.*
|
||||
FROM public.memories_dedup_backup_202603230003 b
|
||||
LEFT JOIN public.memories m ON m.id = b.id
|
||||
WHERE m.id IS NULL
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
)
|
||||
op.execute("DROP TABLE IF EXISTS public.memories_dedup_backup_202603230003")
|
||||
|
||||
automation_columns = {
|
||||
column["name"] for column in inspector.get_columns("automation_jobs")
|
||||
}
|
||||
if "bootstrap_key" in automation_columns:
|
||||
op.drop_column("automation_jobs", "bootstrap_key")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS ux_automation_jobs_owner_memory_active
|
||||
ON public.automation_jobs(owner_id)
|
||||
WHERE deleted_at IS NULL
|
||||
AND config->>'agent_type' = 'memory'
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION public.create_profile_for_new_user()
|
||||
RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
SECURITY DEFINER
|
||||
SET search_path = ''
|
||||
AS $$
|
||||
DECLARE
|
||||
base_seed TEXT;
|
||||
candidate_username TEXT;
|
||||
attempt INT := 0;
|
||||
BEGIN
|
||||
base_seed := coalesce(NEW.phone, NEW.id::text);
|
||||
|
||||
LOOP
|
||||
candidate_username := 'user_' || public.generate_profile_username_suffix(base_seed || ':' || attempt::text);
|
||||
EXIT WHEN NOT EXISTS (
|
||||
SELECT 1 FROM public.profiles p WHERE p.username = candidate_username
|
||||
);
|
||||
attempt := attempt + 1;
|
||||
IF attempt >= 50 THEN
|
||||
candidate_username := 'user_' || substr(replace(NEW.id::text, '-', ''), 1, 6);
|
||||
EXIT;
|
||||
END IF;
|
||||
END LOOP;
|
||||
|
||||
INSERT INTO public.profiles (id, username, avatar_url, bio, settings, created_at, updated_at)
|
||||
VALUES (
|
||||
NEW.id,
|
||||
candidate_username,
|
||||
NULL,
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
now(),
|
||||
now()
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM public.automation_jobs aj
|
||||
WHERE aj.owner_id = NEW.id
|
||||
AND aj.deleted_at IS NULL
|
||||
AND aj.config->>'agent_type' = 'memory'
|
||||
) THEN
|
||||
INSERT INTO public.automation_jobs (
|
||||
id,
|
||||
owner_id,
|
||||
title,
|
||||
config,
|
||||
schedule_type,
|
||||
run_at,
|
||||
next_run_at,
|
||||
timezone,
|
||||
status,
|
||||
created_by,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(),
|
||||
NEW.id,
|
||||
'Memory Agent',
|
||||
jsonb_build_object(
|
||||
'agent_type', 'memory',
|
||||
'model_code', 'qwen3.5-flash',
|
||||
'enabled_tools', jsonb_build_array('calendar.read', 'user.lookup'),
|
||||
'input_template', '请基于最近聊天上下文生成一段可执行的记忆总结与建议。',
|
||||
'context', jsonb_build_object(
|
||||
'source', 'latest_chat',
|
||||
'window_mode', 'day',
|
||||
'window_count', 2
|
||||
)
|
||||
),
|
||||
'daily',
|
||||
now(),
|
||||
now() + interval '1 day',
|
||||
'UTC',
|
||||
'active',
|
||||
NEW.id,
|
||||
now(),
|
||||
now()
|
||||
);
|
||||
END IF;
|
||||
EXCEPTION WHEN unique_violation THEN
|
||||
NULL;
|
||||
END;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
@@ -1,11 +1,15 @@
|
||||
from core.agentscope.prompts.agent_prompt import build_agent_prompt
|
||||
from core.agentscope.prompts.memory_prompt import build_memory_prompt
|
||||
from core.agentscope.prompts.memory_prompt import (
|
||||
build_user_memory_prompt,
|
||||
build_work_memory_prompt,
|
||||
)
|
||||
from core.agentscope.prompts.system_prompt import build_system_prompt
|
||||
from core.agentscope.prompts.tool_prompt import build_tools_prompt
|
||||
|
||||
__all__ = [
|
||||
"build_agent_prompt",
|
||||
"build_memory_prompt",
|
||||
"build_user_memory_prompt",
|
||||
"build_work_memory_prompt",
|
||||
"build_system_prompt",
|
||||
"build_tools_prompt",
|
||||
]
|
||||
|
||||
@@ -1,52 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from schemas.memories import MemoryContext, MemoryListResponse
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
|
||||
|
||||
def _wrap_section(section: str, content: str) -> str:
|
||||
marker_map = {
|
||||
"memory": ("<!-- MEMORY_START -->", "<!-- MEMORY_END -->"),
|
||||
"user_memory": ("<!-- USER_MEMORY_START -->", "<!-- USER_MEMORY_END -->"),
|
||||
"work_memory": ("<!-- WORK_MEMORY_START -->", "<!-- WORK_MEMORY_END -->"),
|
||||
}
|
||||
start, end = marker_map[section]
|
||||
body = content.strip()
|
||||
return f"{start}\n{body}\n{end}" if body else f"{start}\n{end}"
|
||||
|
||||
|
||||
def _format_memory_content(content: dict[str, Any]) -> str:
|
||||
def _format_content(content: UserMemoryContent | WorkProfileContent) -> str:
|
||||
if isinstance(content, dict):
|
||||
return json.dumps(content, ensure_ascii=True, separators=(",", ":"))
|
||||
return str(content)
|
||||
return json.dumps(
|
||||
content.model_dump(mode="json"), ensure_ascii=True, separators=(",", ":")
|
||||
)
|
||||
|
||||
|
||||
def _format_memory(ctx: MemoryContext) -> str:
|
||||
parts = [
|
||||
f"[{ctx.memory_type.value.upper()}] {ctx.title or 'Untitled'}",
|
||||
f" source: {ctx.source.value}",
|
||||
f" content: {_format_memory_content(ctx.content)}",
|
||||
]
|
||||
if ctx.created_at:
|
||||
parts.append(f" created_at: {ctx.created_at.isoformat()}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def build_memory_prompt(
|
||||
def build_user_memory_prompt(
|
||||
*,
|
||||
memories: MemoryListResponse,
|
||||
user_memory: UserMemoryContent | None,
|
||||
) -> str | None:
|
||||
if not memories.memories:
|
||||
if user_memory is None:
|
||||
return None
|
||||
|
||||
lines: list[str] = [
|
||||
"[User Memories]",
|
||||
"- Memories are persistent context from previous sessions.",
|
||||
"- Use them to ground responses in known user facts and preferences.",
|
||||
"- Do not invent facts not present in memories.",
|
||||
"[User Memory]",
|
||||
"- User memory contains personal preferences, habits, people, and places.",
|
||||
"- Use this to understand the user's personal context and preferences.",
|
||||
"- Do not invent facts not present here.",
|
||||
f"content: {_format_content(user_memory)}",
|
||||
]
|
||||
|
||||
for ctx in memories.memories:
|
||||
lines.append(_format_memory(ctx))
|
||||
return _wrap_section("user_memory", "\n".join(lines))
|
||||
|
||||
return _wrap_section("memory", "\n".join(lines))
|
||||
|
||||
def build_work_memory_prompt(
|
||||
*,
|
||||
work_memory: WorkProfileContent | None,
|
||||
) -> str | None:
|
||||
if work_memory is None:
|
||||
return None
|
||||
|
||||
lines: list[str] = [
|
||||
"[Work Memory]",
|
||||
"- Work memory contains projects, team members, habits, and milestones.",
|
||||
"- Use this to understand the user's work context and ongoing tasks.",
|
||||
"- Do not invent facts not present here.",
|
||||
f"content: {_format_content(work_memory)}",
|
||||
]
|
||||
|
||||
return _wrap_section("work_memory", "\n".join(lines))
|
||||
|
||||
@@ -9,12 +9,15 @@ from ag_ui.core.types import Tool
|
||||
from core.agentscope.prompts.agent_prompt import (
|
||||
build_agent_prompt,
|
||||
)
|
||||
from core.agentscope.prompts.memory_prompt import build_memory_prompt
|
||||
from core.agentscope.prompts.memory_prompt import (
|
||||
build_user_memory_prompt,
|
||||
build_work_memory_prompt,
|
||||
)
|
||||
from core.agentscope.prompts.route_prompt import build_frontend_route_prompt
|
||||
from core.agentscope.prompts.tool_prompt import build_tools_prompt
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.forwarded_props import ClientTimeContext
|
||||
from schemas.memories import MemoryListResponse
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.user.context import UserContext
|
||||
|
||||
|
||||
@@ -210,9 +213,16 @@ def build_system_prompt(
|
||||
runtime_client_time: ClientTimeContext | None = None,
|
||||
extra_context: str | None = None,
|
||||
tools: Sequence[Tool | dict[str, Any]] | None = None,
|
||||
memories: MemoryListResponse | None = None,
|
||||
user_memory: UserMemoryContent | None = None,
|
||||
work_memory: WorkProfileContent | None = None,
|
||||
) -> str:
|
||||
include_route_section = agent_type == AgentType.WORKER
|
||||
|
||||
if agent_type == AgentType.ROUTER:
|
||||
memory_prompt = build_user_memory_prompt(user_memory=user_memory)
|
||||
else:
|
||||
memory_prompt = build_work_memory_prompt(work_memory=work_memory)
|
||||
|
||||
sections: list[str | None] = [
|
||||
_build_identity_section(),
|
||||
_build_env_section(
|
||||
@@ -228,7 +238,7 @@ def build_system_prompt(
|
||||
llm_config=llm_config,
|
||||
),
|
||||
build_tools_prompt(tools=tools) if tools else None,
|
||||
build_memory_prompt(memories=memories) if memories else None,
|
||||
memory_prompt,
|
||||
_build_output_rules(),
|
||||
]
|
||||
return "\n\n".join(item for item in sections if item).strip()
|
||||
|
||||
@@ -7,7 +7,7 @@ from agentscope.message import Msg
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from core.logging import get_logger
|
||||
from schemas.automation import RuntimeConfig
|
||||
from schemas.memories import MemoryListResponse
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.user import UserContext
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.orchestrator")
|
||||
@@ -26,7 +26,8 @@ class RunnerLike(Protocol):
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
runtime_config: RuntimeConfig,
|
||||
memories: MemoryListResponse | None,
|
||||
user_memory: UserMemoryContent | None,
|
||||
work_memory: WorkProfileContent | None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@@ -50,7 +51,8 @@ class AgentScopeRuntimeOrchestrator:
|
||||
context_messages: list[Msg],
|
||||
user_context: UserContext,
|
||||
runtime_config: RuntimeConfig,
|
||||
memories: MemoryListResponse | None = None,
|
||||
user_memory: UserMemoryContent | None = None,
|
||||
work_memory: WorkProfileContent | None = None,
|
||||
) -> dict[str, Any]:
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
@@ -70,7 +72,8 @@ class AgentScopeRuntimeOrchestrator:
|
||||
pipeline=self._pipeline,
|
||||
run_input=run_input,
|
||||
runtime_config=runtime_config,
|
||||
memories=memories,
|
||||
user_memory=user_memory,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
|
||||
@@ -41,7 +41,7 @@ from schemas.agent.system_agent import (
|
||||
SystemAgentLLMConfig,
|
||||
)
|
||||
from schemas.automation import RuntimeConfig
|
||||
from schemas.memories import MemoryListResponse
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.user import UserContext
|
||||
from services.litellm.service import LiteLLMService
|
||||
from sqlalchemy import select
|
||||
@@ -71,7 +71,8 @@ class AgentScopeRunner:
|
||||
pipeline: PipelineLike,
|
||||
run_input: RunAgentInput,
|
||||
runtime_config: RuntimeConfig,
|
||||
memories: MemoryListResponse | None = None,
|
||||
user_memory: UserMemoryContent | None = None,
|
||||
work_memory: WorkProfileContent | None = None,
|
||||
) -> dict[str, Any]:
|
||||
owner_id = UUID(user_context.id)
|
||||
runtime_client_time = self._resolve_runtime_client_time(run_input=run_input)
|
||||
@@ -98,7 +99,7 @@ class AgentScopeRunner:
|
||||
context_messages=context_messages,
|
||||
stage_config=router_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
memories=memories,
|
||||
user_memory=user_memory,
|
||||
)
|
||||
worker_output = await self._execute_worker_step(
|
||||
pipeline=pipeline,
|
||||
@@ -108,7 +109,7 @@ class AgentScopeRunner:
|
||||
toolkit=worker_toolkit,
|
||||
stage_config=worker_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
memories=memories,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
return {
|
||||
"router": router_output.model_dump(mode="json", exclude_none=True),
|
||||
@@ -166,7 +167,7 @@ class AgentScopeRunner:
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
memories: MemoryListResponse | None,
|
||||
user_memory: UserMemoryContent | None,
|
||||
) -> RouterAgentOutput:
|
||||
await self._emit_step_event(
|
||||
pipeline=pipeline,
|
||||
@@ -179,7 +180,7 @@ class AgentScopeRunner:
|
||||
context_messages=context_messages,
|
||||
stage_config=stage_config,
|
||||
runtime_client_time=runtime_client_time,
|
||||
memories=memories,
|
||||
user_memory=user_memory,
|
||||
run_input=run_input,
|
||||
)
|
||||
router_output = RouterAgentOutput.model_validate(router_result.payload)
|
||||
@@ -201,7 +202,7 @@ class AgentScopeRunner:
|
||||
toolkit: Any,
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
memories: MemoryListResponse | None,
|
||||
work_memory: WorkProfileContent | None,
|
||||
) -> WorkerAgentOutputLite:
|
||||
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
|
||||
await self._emit_step_event(
|
||||
@@ -221,7 +222,7 @@ class AgentScopeRunner:
|
||||
worker_output_model=worker_output_model,
|
||||
pipeline=pipeline,
|
||||
runtime_client_time=runtime_client_time,
|
||||
memories=memories,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
worker_output = worker_output_model.model_validate(worker_result.payload)
|
||||
await self._emit_step_event(
|
||||
@@ -239,7 +240,7 @@ class AgentScopeRunner:
|
||||
context_messages: list[Msg],
|
||||
stage_config: SystemAgentRuntimeConfig,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
memories: MemoryListResponse | None,
|
||||
user_memory: UserMemoryContent | None,
|
||||
run_input: RunAgentInput,
|
||||
) -> StageExecutionResult:
|
||||
messages_for_router = self._build_router_messages(
|
||||
@@ -260,7 +261,7 @@ class AgentScopeRunner:
|
||||
now_utc=datetime.now(timezone.utc),
|
||||
runtime_client_time=runtime_client_time,
|
||||
tools=None,
|
||||
memories=memories,
|
||||
user_memory=user_memory,
|
||||
),
|
||||
"system",
|
||||
),
|
||||
@@ -319,7 +320,7 @@ class AgentScopeRunner:
|
||||
worker_output_model: type[WorkerAgentOutputLite],
|
||||
pipeline: PipelineLike,
|
||||
runtime_client_time: ClientTimeContext | None,
|
||||
memories: MemoryListResponse | None,
|
||||
work_memory: WorkProfileContent | None,
|
||||
) -> StageExecutionResult:
|
||||
tracking_model = self._build_model(stage_config=stage_config)
|
||||
emitter = PipelineStageEmitter(
|
||||
@@ -340,7 +341,7 @@ class AgentScopeRunner:
|
||||
runtime_client_time=runtime_client_time,
|
||||
extra_context=stage_config.extra_context,
|
||||
tools=None,
|
||||
memories=memories,
|
||||
work_memory=work_memory,
|
||||
),
|
||||
toolkit=toolkit,
|
||||
model=tracking_model,
|
||||
|
||||
@@ -20,8 +20,8 @@ from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import worker_agent_broker, worker_automation_broker
|
||||
from schemas.automation import MemoryContextConfig, RuntimeConfig
|
||||
from schemas.memories import MemoryListResponse
|
||||
from schemas.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
extract_user_message_attachments,
|
||||
@@ -30,7 +30,7 @@ from schemas.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.memories.repository import MemoriesRepository
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
from v1.memories.service import MemoriesService
|
||||
from v1.users.dependencies import get_user_service
|
||||
|
||||
@@ -83,7 +83,7 @@ async def _build_recent_context_messages(
|
||||
*,
|
||||
session: Any,
|
||||
thread_id: str,
|
||||
context_config: "MemoryContextConfig",
|
||||
context_config: "MessageContextConfig",
|
||||
) -> list[Msg]:
|
||||
context_service = AgentContextService(repository=AgentRepository(session))
|
||||
result = await context_service.load_context_messages(
|
||||
@@ -194,11 +194,16 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
orchestrator = _load_runtime()
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
current_user = CurrentUser(id=owner_id)
|
||||
user_context = await _build_user_context(owner_id=owner_id, session=session)
|
||||
memories_service = MemoriesService(MemoriesRepository(session))
|
||||
memories: MemoryListResponse = await memories_service.get_all_memories(
|
||||
owner_id=owner_id
|
||||
memories_service = MemoriesService(
|
||||
repository=SQLAlchemyMemoriesRepository(session),
|
||||
session=session,
|
||||
current_user=current_user,
|
||||
)
|
||||
memories_result = await memories_service.get_all_memories()
|
||||
user_memory: UserMemoryContent | None = memories_result.get("user_memory")
|
||||
work_memory: WorkProfileContent | None = memories_result.get("work_memory")
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
bus = RedisStreamBus(
|
||||
@@ -229,7 +234,8 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
context_messages=context_messages,
|
||||
user_context=user_context,
|
||||
runtime_config=runtime_config,
|
||||
memories=memories,
|
||||
user_memory=user_memory,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
logger.info(
|
||||
"agentscope runtime task completed",
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Protocol
|
||||
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
|
||||
from schemas.automation import ContextWindowMode, MemoryContextConfig
|
||||
from schemas.automation import ContextWindowMode, MessageContextConfig
|
||||
|
||||
|
||||
_DEFAULT_CONTEXT_WINDOW_USER_MESSAGES = 20
|
||||
@@ -86,7 +86,7 @@ class AgentContextService:
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
visibility_mask = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
|
||||
context_loader = CONTEXT_LOADER_REGISTRY.resolve(
|
||||
|
||||
@@ -6,10 +6,16 @@ from core.agentscope.tools.custom.calendar import (
|
||||
from core.agentscope.tools.custom.user_lookup import (
|
||||
user_lookup,
|
||||
)
|
||||
from core.agentscope.tools.custom.memory import (
|
||||
memory_forget,
|
||||
memory_write,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"calendar_read",
|
||||
"calendar_write",
|
||||
"calendar_share",
|
||||
"user_lookup",
|
||||
"memory_write",
|
||||
"memory_forget",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,330 @@
|
||||
from copy import deepcopy
|
||||
from typing import Annotated, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from agentscope.tool import ToolResponse
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
|
||||
from core.agentscope.tools.utils.memory_domain import (
|
||||
create_memories_service,
|
||||
map_memory_exception,
|
||||
)
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
build_error_output,
|
||||
build_tool_response,
|
||||
)
|
||||
from models.memories import MemoryType
|
||||
from schemas.agent.runtime_models import ToolAgentOutput, ToolStatus
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
|
||||
|
||||
class MemoryWriteArgs(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
memory_type: MemoryType = MemoryType.USER
|
||||
user_content: UserMemoryContent | None = None
|
||||
work_content: WorkProfileContent | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_content(self) -> "MemoryWriteArgs":
|
||||
if self.memory_type == MemoryType.USER:
|
||||
if self.user_content is None or self.work_content is not None:
|
||||
raise ValueError("memory_type=user requires user_content only")
|
||||
else:
|
||||
if self.work_content is None or self.user_content is not None:
|
||||
raise ValueError("memory_type=work requires work_content only")
|
||||
return self
|
||||
|
||||
|
||||
class MemoryForgetArgs(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
memory_type: MemoryType = MemoryType.USER
|
||||
forget_paths: list[str] = Field(min_length=1, max_length=100)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_forget_paths(self) -> "MemoryForgetArgs":
|
||||
allowed_roots = (
|
||||
set(UserMemoryContent.model_fields)
|
||||
if self.memory_type == MemoryType.USER
|
||||
else set(WorkProfileContent.model_fields)
|
||||
)
|
||||
normalized: list[str] = []
|
||||
for raw_path in self.forget_paths:
|
||||
path = raw_path.strip()
|
||||
if not path:
|
||||
continue
|
||||
parts = [part for part in path.split(".") if part]
|
||||
if not parts:
|
||||
continue
|
||||
if len(parts) > 5:
|
||||
raise ValueError("forget path depth exceeds limit")
|
||||
if parts[0] not in allowed_roots:
|
||||
raise ValueError("forget path root is not allowed")
|
||||
normalized.append(path)
|
||||
if not normalized:
|
||||
raise ValueError("forget_paths cannot be empty")
|
||||
self.forget_paths = normalized
|
||||
return self
|
||||
|
||||
|
||||
def _memory_error_output(
|
||||
*,
|
||||
tool_name: str,
|
||||
tool_call_args: dict[str, Any],
|
||||
code: str,
|
||||
message: str,
|
||||
retryable: bool,
|
||||
) -> ToolResponse:
|
||||
output = build_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
code=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
output = output.model_copy(update={"tool_call_args": tool_call_args})
|
||||
return build_tool_response(output)
|
||||
|
||||
|
||||
def _validate_runtime_context(
|
||||
*,
|
||||
tool_name: str,
|
||||
tool_call_args: dict[str, Any],
|
||||
session: Any,
|
||||
owner_id: Any,
|
||||
) -> ToolResponse | None:
|
||||
if session is None or owner_id is None:
|
||||
return _memory_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
code="MISSING_RUNTIME_ARGS",
|
||||
message="记忆工具缺少运行时参数",
|
||||
retryable=False,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _deep_merge_dict(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
|
||||
merged = deepcopy(base)
|
||||
for key, value in patch.items():
|
||||
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||
merged[key] = _deep_merge_dict(cast(dict[str, Any], merged[key]), value)
|
||||
else:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
|
||||
def _remove_content_paths(
|
||||
base_payload: dict[str, Any],
|
||||
paths: list[str],
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
result = deepcopy(base_payload)
|
||||
removed: list[str] = []
|
||||
for raw_path in paths:
|
||||
path = raw_path.strip()
|
||||
if not path:
|
||||
continue
|
||||
keys = [part for part in path.split(".") if part]
|
||||
if not keys:
|
||||
continue
|
||||
if _delete_nested_path(result, keys):
|
||||
removed.append(path)
|
||||
return result, removed
|
||||
|
||||
|
||||
def _delete_nested_path(payload: dict[str, Any], keys: list[str]) -> bool:
|
||||
current: dict[str, Any] = payload
|
||||
for key in keys[:-1]:
|
||||
next_value = current.get(key)
|
||||
if not isinstance(next_value, dict):
|
||||
return False
|
||||
current = next_value
|
||||
leaf = keys[-1]
|
||||
if leaf in current:
|
||||
del current[leaf]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def memory_write(
|
||||
memory_type: Annotated[
|
||||
str,
|
||||
Field(description="Memory type: user or work."),
|
||||
] = "user",
|
||||
user_content: Annotated[
|
||||
UserMemoryContent | None,
|
||||
Field(description="Patch payload for user memory content."),
|
||||
] = None,
|
||||
work_content: Annotated[
|
||||
WorkProfileContent | None,
|
||||
Field(description="Patch payload for work memory content."),
|
||||
] = None,
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
tool_name = "memory_write"
|
||||
tool_call_args: dict[str, Any] = {
|
||||
"memory_type": memory_type,
|
||||
"user_content": user_content,
|
||||
"work_content": work_content,
|
||||
}
|
||||
runtime_error = _validate_runtime_context(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if runtime_error is not None:
|
||||
return runtime_error
|
||||
|
||||
try:
|
||||
parsed_args = MemoryWriteArgs.model_validate(tool_call_args)
|
||||
service = create_memories_service(
|
||||
session=cast(AsyncSession, session),
|
||||
owner_id=cast(UUID, owner_id),
|
||||
)
|
||||
existing = await service.get_memory_model(memory_type=parsed_args.memory_type)
|
||||
|
||||
if parsed_args.memory_type == MemoryType.USER:
|
||||
base_model = (
|
||||
UserMemoryContent.model_validate(existing.content)
|
||||
if existing is not None
|
||||
else UserMemoryContent()
|
||||
)
|
||||
patch_model = cast(UserMemoryContent, parsed_args.user_content)
|
||||
merged = _deep_merge_dict(
|
||||
base_model.model_dump(),
|
||||
patch_model.model_dump(exclude_unset=True),
|
||||
)
|
||||
validated = UserMemoryContent.model_validate(merged)
|
||||
await service.update_user_memory(
|
||||
content=validated,
|
||||
)
|
||||
else:
|
||||
base_model = (
|
||||
WorkProfileContent.model_validate(existing.content)
|
||||
if existing is not None
|
||||
else WorkProfileContent()
|
||||
)
|
||||
patch_model = cast(WorkProfileContent, parsed_args.work_content)
|
||||
merged = _deep_merge_dict(
|
||||
base_model.model_dump(),
|
||||
patch_model.model_dump(exclude_unset=True),
|
||||
)
|
||||
validated = WorkProfileContent.model_validate(merged)
|
||||
await service.update_work_memory(
|
||||
content=validated,
|
||||
)
|
||||
|
||||
summary = f"status=success memory_type={parsed_args.memory_type.value}"
|
||||
return build_tool_response(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
result=summary,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
code, message, retryable = map_memory_exception(exc)
|
||||
return _memory_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
code=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
|
||||
|
||||
async def memory_forget(
|
||||
memory_type: Annotated[
|
||||
str,
|
||||
Field(description="Memory type: user or work."),
|
||||
] = "user",
|
||||
forget_paths: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Dot paths to remove from content."),
|
||||
] = None,
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
tool_name = "memory_forget"
|
||||
tool_call_args: dict[str, Any] = {
|
||||
"memory_type": memory_type,
|
||||
"forget_paths": forget_paths or [],
|
||||
}
|
||||
runtime_error = _validate_runtime_context(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if runtime_error is not None:
|
||||
return runtime_error
|
||||
|
||||
try:
|
||||
parsed_args = MemoryForgetArgs.model_validate(tool_call_args)
|
||||
service = create_memories_service(
|
||||
session=cast(AsyncSession, session),
|
||||
owner_id=cast(UUID, owner_id),
|
||||
)
|
||||
existing = await service.get_memory_model(memory_type=parsed_args.memory_type)
|
||||
if existing is None:
|
||||
summary = f"status=success memory_type={parsed_args.memory_type.value} forgotten=0"
|
||||
return build_tool_response(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
result=summary,
|
||||
)
|
||||
)
|
||||
|
||||
if parsed_args.memory_type == MemoryType.USER:
|
||||
base_model = UserMemoryContent.model_validate(existing.content)
|
||||
updated_dict, removed_paths = _remove_content_paths(
|
||||
base_model.model_dump(),
|
||||
parsed_args.forget_paths,
|
||||
)
|
||||
validated = UserMemoryContent.model_validate(updated_dict)
|
||||
await service.update_user_memory(
|
||||
content=validated,
|
||||
)
|
||||
else:
|
||||
base_model = WorkProfileContent.model_validate(existing.content)
|
||||
updated_dict, removed_paths = _remove_content_paths(
|
||||
base_model.model_dump(),
|
||||
parsed_args.forget_paths,
|
||||
)
|
||||
validated = WorkProfileContent.model_validate(updated_dict)
|
||||
await service.update_work_memory(
|
||||
content=validated,
|
||||
)
|
||||
|
||||
summary = (
|
||||
f"status=success memory_type={parsed_args.memory_type.value} forgotten={len(removed_paths)} "
|
||||
f"skipped=0"
|
||||
)
|
||||
return build_tool_response(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
result=summary,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
code, message, retryable = map_memory_exception(exc)
|
||||
return _memory_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
code=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
@@ -4,17 +4,13 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ToolGroup(str, Enum):
|
||||
READ = "read"
|
||||
EXECUTE = "execute"
|
||||
MEMORY = "memory"
|
||||
|
||||
|
||||
class AgentTool(str, Enum):
|
||||
CALENDAR_READ = "calendar.read"
|
||||
CALENDAR_WRITE = "calendar.write"
|
||||
CALENDAR_SHARE = "calendar.share"
|
||||
USER_LOOKUP = "user.lookup"
|
||||
MEMORY_WRITE = "memory.write"
|
||||
MEMORY_FORGET = "memory.forget"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -25,29 +21,32 @@ class ToolApprovalConfig:
|
||||
@dataclass(frozen=True)
|
||||
class ToolConfig:
|
||||
name: str
|
||||
group: ToolGroup
|
||||
approval: ToolApprovalConfig
|
||||
|
||||
|
||||
TOOL_CONFIGS: dict[str, ToolConfig] = {
|
||||
"calendar_read": ToolConfig(
|
||||
name="calendar_read",
|
||||
group=ToolGroup.READ,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"user_lookup": ToolConfig(
|
||||
name="user_lookup",
|
||||
group=ToolGroup.MEMORY,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"calendar_share": ToolConfig(
|
||||
name="calendar_share",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"memory_write": ToolConfig(
|
||||
name="memory_write",
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
"memory_forget": ToolConfig(
|
||||
name="memory_forget",
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
),
|
||||
}
|
||||
@@ -57,6 +56,8 @@ AGENT_TOOL_TO_FUNCTION_NAME: dict[AgentTool, str] = {
|
||||
AgentTool.CALENDAR_WRITE: "calendar_write",
|
||||
AgentTool.CALENDAR_SHARE: "calendar_share",
|
||||
AgentTool.USER_LOOKUP: "user_lookup",
|
||||
AgentTool.MEMORY_WRITE: "memory_write",
|
||||
AgentTool.MEMORY_FORGET: "memory_forget",
|
||||
}
|
||||
|
||||
TOOL_NAME_ALIASES: dict[str, AgentTool] = {
|
||||
@@ -68,6 +69,10 @@ TOOL_NAME_ALIASES: dict[str, AgentTool] = {
|
||||
"calendar_share": AgentTool.CALENDAR_SHARE,
|
||||
AgentTool.USER_LOOKUP.value: AgentTool.USER_LOOKUP,
|
||||
"user_lookup": AgentTool.USER_LOOKUP,
|
||||
AgentTool.MEMORY_WRITE.value: AgentTool.MEMORY_WRITE,
|
||||
"memory_write": AgentTool.MEMORY_WRITE,
|
||||
AgentTool.MEMORY_FORGET.value: AgentTool.MEMORY_FORGET,
|
||||
"memory_forget": AgentTool.MEMORY_FORGET,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,10 @@ from core.agentscope.tools.custom.calendar import (
|
||||
calendar_share,
|
||||
calendar_write,
|
||||
)
|
||||
from core.agentscope.tools.custom.memory import (
|
||||
memory_forget,
|
||||
memory_write,
|
||||
)
|
||||
from core.agentscope.tools.custom.user_lookup import user_lookup
|
||||
from core.agentscope.tools.tool_config import (
|
||||
TOOL_CONFIGS,
|
||||
@@ -23,6 +27,8 @@ TOOL_FUNCTIONS: dict[str, Any] = {
|
||||
"calendar_write": calendar_write,
|
||||
"calendar_share": calendar_share,
|
||||
"user_lookup": user_lookup,
|
||||
"memory_write": memory_write,
|
||||
"memory_forget": memory_forget,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
from v1.memories.service import MemoriesService
|
||||
|
||||
|
||||
def create_memories_service(
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
) -> MemoriesService:
|
||||
return MemoriesService(
|
||||
repository=SQLAlchemyMemoriesRepository(session),
|
||||
session=session,
|
||||
current_user=CurrentUser(id=owner_id),
|
||||
)
|
||||
|
||||
|
||||
def map_memory_exception(exc: Exception) -> tuple[str, str, bool]:
|
||||
if isinstance(exc, HTTPException):
|
||||
detail = exc.detail
|
||||
if isinstance(detail, str) and detail.strip():
|
||||
return "OPERATION_FAILED", detail.strip(), exc.status_code >= 500
|
||||
return "OPERATION_FAILED", "记忆操作失败", exc.status_code >= 500
|
||||
if isinstance(exc, ValueError):
|
||||
return "INVALID_ARGUMENT", "请求参数无效", False
|
||||
return "INTERNAL_ERROR", "记忆操作失败", True
|
||||
@@ -1,11 +1,3 @@
|
||||
from core.automation.scheduler import (
|
||||
AutomationSchedulerService,
|
||||
DispatchResult,
|
||||
SqlAlchemyAutomationSchedulerRepository,
|
||||
)
|
||||
from core.automation.scheduler import run_automation_scheduler_scan
|
||||
|
||||
__all__ = [
|
||||
"AutomationSchedulerService",
|
||||
"DispatchResult",
|
||||
"SqlAlchemyAutomationSchedulerRepository",
|
||||
]
|
||||
__all__ = ["run_automation_scheduler_scan"]
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
input_template: 请基于最近两天用户聊天上下文提取用户记忆;如果已有记忆内容变化请更新;如果记忆已失效请执行遗忘。
|
||||
enabled_tools:
|
||||
- memory.write
|
||||
- memory.forget
|
||||
context:
|
||||
source: latest_chat
|
||||
window_mode: day
|
||||
window_count: 2
|
||||
@@ -32,6 +32,10 @@ class AutomationJob(TimestampMixin, SoftDeleteMixin, Base):
|
||||
UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
)
|
||||
bootstrap_key: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
nullable=True,
|
||||
)
|
||||
title: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
|
||||
@@ -16,12 +16,6 @@ class MemoryType(str, Enum):
|
||||
WORK = "work"
|
||||
|
||||
|
||||
class MemorySource(str, Enum):
|
||||
MANUAL = "manual"
|
||||
AGENT = "agent"
|
||||
IMPORTED = "imported"
|
||||
|
||||
|
||||
class MemoryStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
DISABLED = "disabled"
|
||||
@@ -46,15 +40,10 @@ class Memory(TimestampMixin, Base):
|
||||
String(20),
|
||||
nullable=False,
|
||||
)
|
||||
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
content: Mapped[dict] = mapped_column(
|
||||
json_jsonb,
|
||||
nullable=False,
|
||||
)
|
||||
source: Mapped[MemorySource] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
)
|
||||
status: Mapped[MemoryStatus] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
|
||||
@@ -12,7 +12,6 @@ from schemas.inbox.messages import (
|
||||
parse_calendar_content,
|
||||
)
|
||||
from schemas.invite_codes import InviteCodeRewardConfig
|
||||
from schemas.memories import MemoryContext
|
||||
from schemas.messages import AgentChatMessageMetadata
|
||||
from schemas.schedule.items import (
|
||||
AttachmentType,
|
||||
@@ -36,7 +35,6 @@ __all__ = [
|
||||
"InboxMessageStatus",
|
||||
"InboxMessageType",
|
||||
"InviteCodeRewardConfig",
|
||||
"MemoryContext",
|
||||
"ScheduleItemMetadata",
|
||||
"ScheduleItemMetadataAttachment",
|
||||
"ScheduleItemSourceType",
|
||||
|
||||
@@ -20,7 +20,7 @@ class ContextWindowMode(str, Enum):
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class MemoryContextConfig(BaseModel):
|
||||
class MessageContextConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
source: ContextSource = ContextSource.LATEST_CHAT
|
||||
@@ -32,7 +32,7 @@ class RuntimeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
enabled_tools: list[AgentTool] = Field(default_factory=list, max_length=32)
|
||||
context: MemoryContextConfig = Field(default_factory=MemoryContextConfig)
|
||||
context: MessageContextConfig = Field(default_factory=MessageContextConfig)
|
||||
|
||||
|
||||
class AutomationJobConfig(RuntimeConfig):
|
||||
@@ -46,6 +46,7 @@ class AutomationJob(BaseModel):
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
bootstrap_key: str | None = Field(default=None, min_length=1, max_length=64)
|
||||
title: str = Field(..., min_length=1, max_length=255)
|
||||
config: AutomationJobConfig
|
||||
schedule_type: ScheduleType
|
||||
@@ -63,6 +64,7 @@ class AutomationJob(BaseModel):
|
||||
return cls(
|
||||
id=obj.id,
|
||||
owner_id=obj.owner_id,
|
||||
bootstrap_key=obj.bootstrap_key,
|
||||
title=obj.title,
|
||||
config=AutomationJobConfig.model_validate(obj.config or {}),
|
||||
schedule_type=obj.schedule_type,
|
||||
|
||||
@@ -2,10 +2,19 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Literal
|
||||
from typing import ClassVar, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from schemas.memories.memory_content import (
|
||||
TeamMember,
|
||||
UserMemoryContent,
|
||||
UserPreferences,
|
||||
WorkHabit,
|
||||
WorkProfileContent,
|
||||
WorkProject,
|
||||
)
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
@@ -13,12 +22,6 @@ class MemoryType(str, Enum):
|
||||
WORK = "work"
|
||||
|
||||
|
||||
class MemorySource(str, Enum):
|
||||
MANUAL = "manual"
|
||||
AGENT = "agent"
|
||||
IMPORTED = "imported"
|
||||
|
||||
|
||||
class MemoryStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
DISABLED = "disabled"
|
||||
@@ -33,38 +36,20 @@ class MemoryModel(BaseModel):
|
||||
owner_id: UUID
|
||||
agent_id: UUID | None = None
|
||||
memory_type: Literal["user", "work"]
|
||||
title: str | None = None
|
||||
content: dict[str, Any]
|
||||
source: MemorySource
|
||||
content: UserMemoryContent | WorkProfileContent
|
||||
status: MemoryStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MemoryContext(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
memory_type: MemoryType
|
||||
source: MemorySource
|
||||
title: str | None = None
|
||||
content: dict[str, Any]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MemoryListResponse(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
owner_id: UUID
|
||||
memories: list[MemoryContext] = Field(default_factory=list)
|
||||
total: int
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemoryContext",
|
||||
"MemoryListResponse",
|
||||
"MemoryModel",
|
||||
"MemorySource",
|
||||
"MemoryStatus",
|
||||
"MemoryType",
|
||||
"TeamMember",
|
||||
"UserMemoryContent",
|
||||
"UserPreferences",
|
||||
"WorkHabit",
|
||||
"WorkProfileContent",
|
||||
"WorkProject",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
Weekday = Literal["mon", "tue", "wed", "thu", "fri", "sat", "sun"]
|
||||
ProjectStatus = Literal["planned", "active", "paused", "completed"]
|
||||
PreferenceLevel = Literal["like", "neutral", "avoid"]
|
||||
MemorySource = Literal["user", "inferred", "calendar", "email", "agent"]
|
||||
|
||||
|
||||
class MemoryMeta(BaseModel):
|
||||
source: MemorySource | None = Field(default=None, description="记忆来源")
|
||||
confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="置信度")
|
||||
last_updated_at: datetime | None = Field(default=None, description="最后更新时间")
|
||||
|
||||
|
||||
class TimeWindow(BaseModel):
|
||||
weekdays: list[Weekday] = Field(default_factory=list, description="适用星期")
|
||||
start: str = Field(description="开始时间,HH:MM")
|
||||
end: str = Field(description="结束时间,HH:MM")
|
||||
|
||||
|
||||
class PersonMemory(BaseModel):
|
||||
name: str = Field(description="人物姓名")
|
||||
relationship: str | None = Field(
|
||||
default=None, description="与用户关系,如家人/同事/导师/朋友"
|
||||
)
|
||||
role: str | None = Field(default=None, description="角色,如老板/导师/合作方")
|
||||
preferred_contact_channel: str | None = Field(
|
||||
default=None, description="偏好联系方式"
|
||||
)
|
||||
notes: str | None = Field(default=None, description="补充说明")
|
||||
meta: MemoryMeta = Field(default_factory=MemoryMeta)
|
||||
|
||||
|
||||
class PlaceMemory(BaseModel):
|
||||
name: str = Field(description="地点名称")
|
||||
category: str | None = Field(
|
||||
default=None, description="地点类别,如home/office/gym/cafe"
|
||||
)
|
||||
address: str | None = Field(default=None, description="地址")
|
||||
timezone: str | None = Field(default=None, description="地点时区")
|
||||
commute_minutes: int | None = Field(default=None, ge=0, description="典型通勤时长")
|
||||
preference: PreferenceLevel | None = Field(default=None, description="地点偏好")
|
||||
notes: str | None = Field(default=None, description="补充说明")
|
||||
meta: MemoryMeta = Field(default_factory=MemoryMeta)
|
||||
|
||||
|
||||
class UserPreferences(BaseModel):
|
||||
communication_style: str | None = Field(
|
||||
default=None, description="沟通风格,如简洁直接"
|
||||
)
|
||||
language_preference: list[str] = Field(default_factory=list, description="语言偏好")
|
||||
location_preference: str | None = Field(
|
||||
default=None, description="地点偏好,如喜欢远程"
|
||||
)
|
||||
work_lifestyle: str | None = Field(default=None, description="作息方式,如早睡早起")
|
||||
notification_preference: list[str] = Field(
|
||||
default_factory=list, description="通知方式偏好"
|
||||
)
|
||||
|
||||
|
||||
class SchedulingPreferences(BaseModel):
|
||||
productive_windows: list[TimeWindow] = Field(
|
||||
default_factory=list, description="高效率时段"
|
||||
)
|
||||
preferred_meeting_windows: list[TimeWindow] = Field(
|
||||
default_factory=list, description="偏好的会议时段"
|
||||
)
|
||||
no_meeting_windows: list[TimeWindow] = Field(
|
||||
default_factory=list, description="尽量不安排会议的时段"
|
||||
)
|
||||
deep_work_windows: list[TimeWindow] = Field(
|
||||
default_factory=list, description="深度工作时段"
|
||||
)
|
||||
preferred_meeting_duration_minutes: list[int] = Field(
|
||||
default_factory=lambda: [30, 60], description="偏好的会议时长"
|
||||
)
|
||||
meeting_buffer_minutes: int | None = Field(
|
||||
default=None, ge=0, description="会议间缓冲时间"
|
||||
)
|
||||
max_meetings_per_day: int | None = Field(
|
||||
default=None, ge=0, description="单日会议上限"
|
||||
)
|
||||
notes: str | None = Field(default=None, description="其他排程说明")
|
||||
|
||||
|
||||
class RecurringRoutine(BaseModel):
|
||||
name: str = Field(description="周期性安排名称")
|
||||
description: str | None = Field(default=None, description="周期性安排描述")
|
||||
cadence: str | None = Field(
|
||||
default=None, description="频率,如daily/weekly/monthly"
|
||||
)
|
||||
time_windows: list[TimeWindow] = Field(
|
||||
default_factory=list, description="通常发生时段"
|
||||
)
|
||||
importance: str | None = Field(default=None, description="重要程度")
|
||||
meta: MemoryMeta = Field(default_factory=MemoryMeta)
|
||||
|
||||
|
||||
class UserMemoryContent(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
occupation: str | None = Field(default=None, description="职业")
|
||||
timezone: str | None = Field(default=None, description="时区")
|
||||
primary_language: str | None = Field(default=None, description="主要语言")
|
||||
people: list[PersonMemory] = Field(default_factory=list, description="重要人物")
|
||||
places: list[PlaceMemory] = Field(default_factory=list, description="常去地点")
|
||||
preferences: UserPreferences = Field(default_factory=UserPreferences)
|
||||
scheduling_preferences: SchedulingPreferences = Field(
|
||||
default_factory=SchedulingPreferences
|
||||
)
|
||||
interests: list[str] = Field(default_factory=list, description="兴趣爱好")
|
||||
avoid_topics: list[str] = Field(default_factory=list, description="不想讨论的话题")
|
||||
custom_rules: list[str] = Field(default_factory=list, description="用户自定义规则")
|
||||
recurring_routines: list[RecurringRoutine] = Field(
|
||||
default_factory=list, description="周期性习惯/安排"
|
||||
)
|
||||
|
||||
|
||||
class Milestone(BaseModel):
|
||||
name: str = Field(description="里程碑名称")
|
||||
due_date: date | None = Field(default=None, description="截止日期")
|
||||
status: str | None = Field(default=None, description="状态")
|
||||
notes: str | None = Field(default=None, description="补充说明")
|
||||
|
||||
|
||||
class WorkProject(BaseModel):
|
||||
name: str = Field(description="项目名")
|
||||
description: str | None = Field(default=None, description="项目描述")
|
||||
status: ProjectStatus | None = Field(default=None, description="项目状态")
|
||||
priority: str | None = Field(default=None, description="项目优先级")
|
||||
deadline: date | None = Field(default=None, description="项目截止时间")
|
||||
collaborators: list[str] = Field(default_factory=list, description="协作人")
|
||||
key_milestones: list[Milestone] = Field(
|
||||
default_factory=list, description="关键里程碑"
|
||||
)
|
||||
notes: str | None = Field(default=None, description="补充说明")
|
||||
meta: MemoryMeta = Field(default_factory=MemoryMeta)
|
||||
|
||||
|
||||
class WorkHabit(BaseModel):
|
||||
available_hours: list[TimeWindow] = Field(
|
||||
default_factory=list, description="常规可工作时间"
|
||||
)
|
||||
deep_work_blocks: list[TimeWindow] = Field(
|
||||
default_factory=list, description="偏好的深度工作时间"
|
||||
)
|
||||
preferred_meeting_windows: list[TimeWindow] = Field(
|
||||
default_factory=list, description="偏好的会议时间"
|
||||
)
|
||||
no_meeting_windows: list[TimeWindow] = Field(
|
||||
default_factory=list, description="不希望开会的时间"
|
||||
)
|
||||
preferred_meeting_duration_minutes: list[int] = Field(
|
||||
default_factory=lambda: [30, 60], description="偏好的会议时长"
|
||||
)
|
||||
notification_channel: str | None = Field(default=None, description="首选沟通渠道")
|
||||
notes: str | None = Field(default=None, description="补充说明")
|
||||
|
||||
|
||||
class TeamMember(BaseModel):
|
||||
name: str = Field(description="成员姓名")
|
||||
role: str | None = Field(default=None, description="团队角色")
|
||||
relationship: str | None = Field(
|
||||
default=None, description="关系,如直属上级/同事/合作方"
|
||||
)
|
||||
preferred_contact_channel: str | None = Field(
|
||||
default=None, description="偏好沟通渠道"
|
||||
)
|
||||
notes: str | None = Field(default=None, description="补充说明")
|
||||
meta: MemoryMeta = Field(default_factory=MemoryMeta)
|
||||
|
||||
|
||||
class WorkProfileContent(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
occupation: str | None = Field(default=None, description="职业身份")
|
||||
expertise: list[str] = Field(default_factory=list, description="专业领域")
|
||||
preferred_tools: list[str] = Field(default_factory=list, description="惯用工具")
|
||||
current_projects: list[WorkProject] = Field(
|
||||
default_factory=list, description="长期项目画像"
|
||||
)
|
||||
work_habits: WorkHabit = Field(default_factory=WorkHabit)
|
||||
team_members: list[TeamMember] = Field(default_factory=list, description="团队成员")
|
||||
team_context: str | None = Field(default=None, description="团队背景")
|
||||
work_rules: list[str] = Field(
|
||||
default_factory=list, description="工作规则或默认原则"
|
||||
)
|
||||
@@ -9,11 +9,12 @@ from pathlib import Path
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.automation import (
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MemoryContextConfig,
|
||||
MessageContextConfig,
|
||||
RuntimeConfig,
|
||||
)
|
||||
|
||||
@@ -38,9 +39,9 @@ def _load_system_agents_yaml(path: Path | None = None) -> dict:
|
||||
return loaded
|
||||
|
||||
|
||||
def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextConfig:
|
||||
def _parse_context_messages_config(yaml_config: dict | None) -> MessageContextConfig:
|
||||
if not yaml_config:
|
||||
return MemoryContextConfig()
|
||||
return MessageContextConfig()
|
||||
mode_str = yaml_config.get("mode", "day")
|
||||
count = yaml_config.get("count", 2)
|
||||
try:
|
||||
@@ -51,7 +52,7 @@ def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextCon
|
||||
window_mode = ContextWindowMode(mode_str)
|
||||
except ValueError:
|
||||
window_mode = ContextWindowMode.DAY
|
||||
return MemoryContextConfig(
|
||||
return MessageContextConfig(
|
||||
source=source,
|
||||
window_mode=window_mode,
|
||||
window_count=count,
|
||||
@@ -93,9 +94,9 @@ def build_runtime_config_from_system_agents(
|
||||
router_config.context_messages.model_dump() if router_config else None
|
||||
)
|
||||
|
||||
enabled_tools: list[str] = []
|
||||
enabled_tools: list[AgentTool] = []
|
||||
if worker_config and worker_config.enabled_tools:
|
||||
enabled_tools = [str(t) for t in worker_config.enabled_tools]
|
||||
enabled_tools = list(worker_config.enabled_tools)
|
||||
|
||||
return RuntimeConfig(
|
||||
enabled_tools=enabled_tools,
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from schemas.automation import AutomationJobConfig, MessageContextConfig
|
||||
|
||||
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
|
||||
def _automation_yaml_path(config_name: str) -> Path:
|
||||
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
|
||||
raise ValueError("invalid automation config name")
|
||||
return (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "automation"
|
||||
/ f"{config_name}.yaml"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
|
||||
path = _automation_yaml_path(config_name)
|
||||
with path.open("r", encoding="utf-8") as file:
|
||||
loaded: Any = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"invalid automation config format: {path}")
|
||||
config = AutomationJobConfig.model_validate(loaded)
|
||||
if config_name == "memory_extraction":
|
||||
if config.enabled_tools != [AgentTool.MEMORY_WRITE, AgentTool.MEMORY_FORGET]:
|
||||
raise ValueError(
|
||||
"memory_extraction enabled_tools must be [memory.write, memory.forget]"
|
||||
)
|
||||
if config.context != MessageContextConfig(window_count=2):
|
||||
raise ValueError(
|
||||
"memory_extraction context must be latest_chat/day with window_count=2"
|
||||
)
|
||||
return config
|
||||
@@ -1,8 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db import get_db
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.auth.registration_bootstrap import (
|
||||
RegistrationAutomationBootstrapService,
|
||||
RegistrationBootstrapRepository,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
return AuthService(gateway=SupabaseAuthGateway())
|
||||
def get_auth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AuthService:
|
||||
bootstrapper = RegistrationAutomationBootstrapService(
|
||||
repository=RegistrationBootstrapRepository(session=session),
|
||||
session=session,
|
||||
)
|
||||
return AuthService(
|
||||
gateway=SupabaseAuthGateway(),
|
||||
registration_bootstrapper=bootstrapper,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
|
||||
from models.memories import MemoryType
|
||||
from models.profile import Profile
|
||||
from schemas.automation import AutomationJobConfig
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.user.context import parse_profile_settings
|
||||
from v1.auth.automation_static_config import load_static_automation_job_config
|
||||
from v1.auth.schemas import RegistrationBootstrapRequest
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
|
||||
logger = get_logger("v1.auth.registration_bootstrap")
|
||||
|
||||
_LOCAL_RUN_HOUR = 8
|
||||
_LOCAL_RUN_MINUTE = 0
|
||||
|
||||
|
||||
class RegistrationBootstrapRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._memories_repository = SQLAlchemyMemoriesRepository(session)
|
||||
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str:
|
||||
stmt = select(Profile.settings).where(Profile.id == user_id)
|
||||
settings = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
parsed = parse_profile_settings(
|
||||
settings if isinstance(settings, dict) else None
|
||||
)
|
||||
return parsed.preferences.timezone
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
) -> bool:
|
||||
stmt = (
|
||||
insert(AutomationJob)
|
||||
.values(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=title,
|
||||
config=config.model_dump(mode="json"),
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
timezone=timezone_name,
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
created_by=owner_id,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["owner_id", "bootstrap_key"],
|
||||
index_where=AutomationJob.deleted_at.is_(None)
|
||||
& AutomationJob.bootstrap_key.is_not(None),
|
||||
)
|
||||
.returning(AutomationJob.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
return inserted_id is not None
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool:
|
||||
return await self._memories_repository.create_if_absent(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
|
||||
|
||||
|
||||
class RegistrationBootstrapRepositoryLike(Protocol):
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
) -> bool: ...
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class SessionLike(Protocol):
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
def compute_next_local_time_utc(
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
local_hour: int,
|
||||
local_minute: int,
|
||||
) -> tuple[datetime, datetime]:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
except ZoneInfoNotFoundError:
|
||||
timezone_obj = ZoneInfo("UTC")
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
today_run_local = local_now.replace(
|
||||
hour=local_hour,
|
||||
minute=local_minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
run_local = (
|
||||
today_run_local
|
||||
if local_now <= today_run_local
|
||||
else today_run_local + timedelta(days=1)
|
||||
)
|
||||
next_local = run_local + timedelta(days=1)
|
||||
return run_local.astimezone(UTC), next_local.astimezone(UTC)
|
||||
|
||||
|
||||
class RegistrationAutomationBootstrapService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: RegistrationBootstrapRepositoryLike,
|
||||
session: SessionLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
|
||||
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
|
||||
owner_id = request.user_id
|
||||
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
|
||||
|
||||
definitions = [
|
||||
{
|
||||
"bootstrap_key": "memory_extraction",
|
||||
"config_name": "memory_extraction",
|
||||
"title": "Memory Agent",
|
||||
"local_hour": _LOCAL_RUN_HOUR,
|
||||
"local_minute": _LOCAL_RUN_MINUTE,
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
inserted_any = False
|
||||
created_or_updated_memory = False
|
||||
|
||||
user_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.USER,
|
||||
content=UserMemoryContent().model_dump(mode="json"),
|
||||
)
|
||||
work_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.WORK,
|
||||
content=WorkProfileContent().model_dump(mode="json"),
|
||||
)
|
||||
created_or_updated_memory = user_initialized or work_initialized
|
||||
|
||||
for definition in definitions:
|
||||
bootstrap_key = str(definition["bootstrap_key"])
|
||||
job_config = load_static_automation_job_config(
|
||||
config_name=str(definition["config_name"])
|
||||
)
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=datetime.now(UTC),
|
||||
timezone_name=timezone_name,
|
||||
local_hour=int(definition["local_hour"]),
|
||||
local_minute=int(definition["local_minute"]),
|
||||
)
|
||||
inserted = (
|
||||
await self._repository.insert_bootstrap_automation_job_if_absent(
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=str(definition["title"]),
|
||||
config=job_config,
|
||||
timezone_name=timezone_name,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
)
|
||||
)
|
||||
inserted_any = inserted_any or inserted
|
||||
if inserted_any or created_or_updated_memory:
|
||||
await self._session.commit()
|
||||
logger.info(
|
||||
"user automation jobs bootstrapped",
|
||||
user_id=user_id,
|
||||
timezone=timezone_name,
|
||||
memory_initialized=created_or_updated_memory,
|
||||
)
|
||||
except Exception:
|
||||
await self._session.rollback()
|
||||
raise
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
@@ -49,3 +51,9 @@ class UserByPhoneResponse(BaseModel):
|
||||
|
||||
class OtpSendResponse(BaseModel):
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class RegistrationBootstrapRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
user_id: UUID
|
||||
|
||||
@@ -28,9 +28,15 @@ class AuthServiceGateway(Protocol):
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
_registration_bootstrapper: RegistrationBootstrapper | None
|
||||
|
||||
def __init__(self, gateway: AuthServiceGateway) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
gateway: AuthServiceGateway,
|
||||
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
|
||||
) -> None:
|
||||
self._gateway = gateway
|
||||
self._registration_bootstrapper = registration_bootstrapper
|
||||
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
await self._gateway.send_otp(request)
|
||||
@@ -38,10 +44,20 @@ class AuthService:
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
return await self._gateway.create_phone_session(request)
|
||||
response = await self._gateway.create_phone_session(request)
|
||||
if self._registration_bootstrapper is not None:
|
||||
await self._registration_bootstrapper.ensure_user_automation_jobs(
|
||||
user_id=response.user.id
|
||||
)
|
||||
return response
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
return await self._gateway.refresh_session(request)
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
from v1.memories.service import MemoriesService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
async def get_memories_repository(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SQLAlchemyMemoriesRepository:
|
||||
return SQLAlchemyMemoriesRepository(session)
|
||||
|
||||
|
||||
async def get_memories_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> MemoriesService:
|
||||
repository = SQLAlchemyMemoriesRepository(session)
|
||||
return MemoriesService(
|
||||
repository=repository,
|
||||
session=session,
|
||||
current_user=current_user,
|
||||
)
|
||||
@@ -3,29 +3,162 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.memories import Memory
|
||||
from core.logging import get_logger
|
||||
from models.memories import Memory, MemoryStatus, MemoryType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.memories.repository")
|
||||
|
||||
|
||||
class MemoriesRepositoryLike(Protocol):
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]: ...
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> Memory: ...
|
||||
|
||||
async def get_by_type_for_owner(
|
||||
self, *, owner_id: UUID, memory_type: MemoryType
|
||||
) -> Memory | None: ...
|
||||
|
||||
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
|
||||
|
||||
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
|
||||
|
||||
async def create_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool: ...
|
||||
|
||||
async def update_content(
|
||||
self,
|
||||
memory: Memory,
|
||||
content: dict | None = None,
|
||||
) -> Memory: ...
|
||||
|
||||
|
||||
class MemoriesRepository(BaseRepository[Memory]):
|
||||
class SQLAlchemyMemoriesRepository(BaseRepository[Memory]):
|
||||
_session: AsyncSession
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=Memory)
|
||||
self._session = session
|
||||
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]:
|
||||
stmt = (
|
||||
select(Memory)
|
||||
.where(Memory.owner_id == owner_id)
|
||||
.where(Memory.status == "active")
|
||||
.order_by(Memory.created_at.desc())
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> Memory:
|
||||
try:
|
||||
memory = Memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
status=MemoryStatus.ACTIVE,
|
||||
)
|
||||
self._session.add(memory)
|
||||
await self._session.flush()
|
||||
return memory
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to create memory",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_type_for_owner(
|
||||
self, *, owner_id: UUID, memory_type: MemoryType
|
||||
) -> Memory | None:
|
||||
try:
|
||||
stmt = (
|
||||
select(Memory)
|
||||
.where(Memory.owner_id == owner_id)
|
||||
.where(Memory.memory_type == memory_type)
|
||||
.where(Memory.status == MemoryStatus.ACTIVE)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
if hasattr(result, "scalar_one_or_none"):
|
||||
return result.scalar_one_or_none()
|
||||
scalars = result.scalars()
|
||||
rows = list(scalars.all())
|
||||
return rows[0] if rows else None
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to get memory by type for owner",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
|
||||
return await self.get_by_type_for_owner(
|
||||
owner_id=owner_id, memory_type=MemoryType.USER
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
|
||||
return await self.get_by_type_for_owner(
|
||||
owner_id=owner_id, memory_type=MemoryType.WORK
|
||||
)
|
||||
|
||||
async def create_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool:
|
||||
try:
|
||||
stmt = (
|
||||
insert(Memory)
|
||||
.values(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
status=MemoryStatus.ACTIVE,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["owner_id", "memory_type"])
|
||||
.returning(Memory.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
return inserted_id is not None
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to create memory if absent",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_content(
|
||||
self,
|
||||
memory: Memory,
|
||||
content: dict | None = None,
|
||||
) -> Memory:
|
||||
try:
|
||||
if content is not None:
|
||||
memory.content = content
|
||||
|
||||
await self._session.flush()
|
||||
return memory
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to update memory content",
|
||||
memory_id=str(memory.id),
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from v1.memories.dependencies import get_memories_service
|
||||
from v1.memories.schemas import (
|
||||
MemoryListResponse,
|
||||
UserMemoryPartialUpdate,
|
||||
UserMemoryUpdate,
|
||||
WorkMemoryPartialUpdate,
|
||||
WorkMemoryUpdate,
|
||||
)
|
||||
from v1.memories.service import MemoriesService
|
||||
|
||||
router = APIRouter(prefix="/memories", tags=["memories"])
|
||||
|
||||
|
||||
@router.get("", response_model=MemoryListResponse)
|
||||
async def get_all_memories(
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> MemoryListResponse:
|
||||
result = await service.get_all_memories()
|
||||
return MemoryListResponse(
|
||||
user_memory=result["user_memory"],
|
||||
work_memory=result["work_memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/user", response_model=UserMemoryContent | None)
|
||||
async def get_user_memory(
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> UserMemoryContent | None:
|
||||
return await service.get_user_memory()
|
||||
|
||||
|
||||
@router.get("/work", response_model=WorkProfileContent | None)
|
||||
async def get_work_memory(
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> WorkProfileContent | None:
|
||||
return await service.get_work_memory()
|
||||
|
||||
|
||||
@router.put("/user", response_model=UserMemoryContent, status_code=status.HTTP_200_OK)
|
||||
async def update_user_memory(
|
||||
payload: UserMemoryUpdate,
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> UserMemoryContent:
|
||||
return await service.update_user_memory(
|
||||
content=payload.content,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/work", response_model=WorkProfileContent, status_code=status.HTTP_200_OK)
|
||||
async def update_work_memory(
|
||||
payload: WorkMemoryUpdate,
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> WorkProfileContent:
|
||||
return await service.update_work_memory(
|
||||
content=payload.content,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/user", response_model=UserMemoryContent)
|
||||
async def patch_user_memory(
|
||||
payload: UserMemoryPartialUpdate,
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> UserMemoryContent:
|
||||
return await service.patch_user_memory(
|
||||
content=payload.content,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/work", response_model=WorkProfileContent)
|
||||
async def patch_work_memory(
|
||||
payload: WorkMemoryPartialUpdate,
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> WorkProfileContent:
|
||||
return await service.patch_work_memory(
|
||||
content=payload.content,
|
||||
)
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
|
||||
|
||||
class UserMemoryUpdate(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
content: UserMemoryContent
|
||||
|
||||
|
||||
class WorkMemoryUpdate(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
content: WorkProfileContent
|
||||
|
||||
|
||||
class UserMemoryPartialUpdate(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
content: UserMemoryContent | None = None
|
||||
|
||||
|
||||
class WorkMemoryPartialUpdate(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
content: WorkProfileContent | None = None
|
||||
|
||||
|
||||
class MemoryListResponse(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True)
|
||||
|
||||
user_memory: UserMemoryContent | None = None
|
||||
work_memory: WorkProfileContent | None = None
|
||||
@@ -1,53 +1,286 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from models.memories import Memory
|
||||
from schemas.memories import MemoryContext, MemoryListResponse, MemorySource, MemoryType
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from models.memories import Memory, MemoryType
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from v1.memories.repository import MemoriesRepositoryLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.memories.service")
|
||||
|
||||
|
||||
class MemoriesService(BaseService):
|
||||
"""Memories service handling user/work memory operations.
|
||||
|
||||
Each user has exactly 2 memory records:
|
||||
- user_type: stores personal preferences, people, places, etc.
|
||||
- work_type: stores work profile, projects, team, etc.
|
||||
|
||||
Responsibilities:
|
||||
- Authorization checks
|
||||
- Validation (ownership, memory type)
|
||||
- Transaction boundary (commit/rollback)
|
||||
- Converting ORM models to response schemas
|
||||
"""
|
||||
|
||||
class MemoriesService:
|
||||
_repository: MemoriesRepositoryLike
|
||||
_session: AsyncSession
|
||||
|
||||
def __init__(self, repository: MemoriesRepositoryLike) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
repository: MemoriesRepositoryLike,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
def _to_context(self, memory: Memory) -> MemoryContext:
|
||||
return MemoryContext(
|
||||
memory_type=MemoryType(memory.memory_type.value),
|
||||
source=MemorySource(memory.source.value),
|
||||
title=memory.title,
|
||||
content=memory.content,
|
||||
created_at=memory.created_at,
|
||||
updated_at=memory.updated_at,
|
||||
async def get_user_memory(self) -> UserMemoryContent | None:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
return None
|
||||
|
||||
return self._parse_user_content(memory)
|
||||
|
||||
async def get_work_memory(self) -> WorkProfileContent | None:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
return None
|
||||
|
||||
return self._parse_work_content(memory)
|
||||
|
||||
async def get_all_memories(self) -> dict:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
user_memory = await self._repository.get_user_memory_for_owner(
|
||||
owner_id=user_id
|
||||
)
|
||||
work_memory = await self._repository.get_work_memory_for_owner(
|
||||
owner_id=user_id
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
return {
|
||||
"user_memory": self._parse_user_content(user_memory)
|
||||
if user_memory
|
||||
else None,
|
||||
"work_memory": self._parse_work_content(work_memory)
|
||||
if work_memory
|
||||
else None,
|
||||
}
|
||||
|
||||
async def update_user_memory(
|
||||
self,
|
||||
*,
|
||||
content: UserMemoryContent,
|
||||
) -> UserMemoryContent:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
try:
|
||||
memory = await self._repository.create(
|
||||
owner_id=user_id,
|
||||
memory_type=MemoryType.USER,
|
||||
content=content.model_dump(),
|
||||
)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Memories service unavailable"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
memory = await self._repository.update_content(
|
||||
memory=memory,
|
||||
content=content.model_dump(),
|
||||
)
|
||||
await self._session.commit()
|
||||
await self._session.refresh(memory)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Memories service unavailable"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"user_memory_updated",
|
||||
extra={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
async def get_user_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
user_memories = [
|
||||
self._to_context(memory)
|
||||
for memory in memories
|
||||
if memory.memory_type.value == "user"
|
||||
]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=user_memories, total=len(user_memories)
|
||||
return self._parse_user_content(memory)
|
||||
|
||||
async def update_work_memory(
|
||||
self,
|
||||
*,
|
||||
content: WorkProfileContent,
|
||||
) -> WorkProfileContent:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
try:
|
||||
memory = await self._repository.create(
|
||||
owner_id=user_id,
|
||||
memory_type=MemoryType.WORK,
|
||||
content=content.model_dump(),
|
||||
)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Memories service unavailable"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
memory = await self._repository.update_content(
|
||||
memory=memory,
|
||||
content=content.model_dump(),
|
||||
)
|
||||
await self._session.commit()
|
||||
await self._session.refresh(memory)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Memories service unavailable"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"work_memory_updated",
|
||||
extra={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
async def get_agent_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
agent_memories = [
|
||||
self._to_context(memory)
|
||||
for memory in memories
|
||||
if memory.memory_type.value == "work"
|
||||
]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=agent_memories, total=len(agent_memories)
|
||||
return self._parse_work_content(memory)
|
||||
|
||||
async def patch_user_memory(
|
||||
self, *, content: UserMemoryContent | None = None
|
||||
) -> UserMemoryContent:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
raise HTTPException(status_code=404, detail="User memory not found")
|
||||
|
||||
try:
|
||||
update_data: dict = {}
|
||||
if content is not None:
|
||||
existing_content = memory.content or {}
|
||||
merged = content.model_dump()
|
||||
existing_content.update(
|
||||
{k: v for k, v in merged.items() if v is not None}
|
||||
)
|
||||
update_data["content"] = existing_content
|
||||
|
||||
if update_data:
|
||||
memory = await self._repository.update_content(
|
||||
memory=memory,
|
||||
content=update_data.get("content"),
|
||||
)
|
||||
await self._session.commit()
|
||||
await self._session.refresh(memory)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
logger.info(
|
||||
"user_memory_patched",
|
||||
extra={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
async def get_all_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
memory_contexts = [self._to_context(memory) for memory in memories]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=memory_contexts, total=len(memory_contexts)
|
||||
return self._parse_user_content(memory)
|
||||
|
||||
async def patch_work_memory(
|
||||
self, *, content: WorkProfileContent | None = None
|
||||
) -> WorkProfileContent:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
raise HTTPException(status_code=404, detail="Work memory not found")
|
||||
|
||||
try:
|
||||
update_data: dict = {}
|
||||
if content is not None:
|
||||
existing_content = memory.content or {}
|
||||
merged = content.model_dump()
|
||||
existing_content.update(
|
||||
{k: v for k, v in merged.items() if v is not None}
|
||||
)
|
||||
update_data["content"] = existing_content
|
||||
|
||||
if update_data:
|
||||
memory = await self._repository.update_content(
|
||||
memory=memory,
|
||||
content=update_data.get("content"),
|
||||
)
|
||||
await self._session.commit()
|
||||
await self._session.refresh(memory)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
logger.info(
|
||||
"work_memory_patched",
|
||||
extra={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
return self._parse_work_content(memory)
|
||||
|
||||
def _parse_user_content(self, memory: Memory) -> UserMemoryContent:
|
||||
content_dict = memory.content or {}
|
||||
return UserMemoryContent.model_validate(content_dict)
|
||||
|
||||
def _parse_work_content(self, memory: Memory) -> WorkProfileContent:
|
||||
content_dict = memory.content or {}
|
||||
return WorkProfileContent.model_validate(content_dict)
|
||||
|
||||
async def get_memory_model(self, *, memory_type: MemoryType) -> Memory | None:
|
||||
user_id = self.require_user_id()
|
||||
try:
|
||||
return await self._repository.get_by_type_for_owner(
|
||||
owner_id=user_id,
|
||||
memory_type=memory_type,
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
@@ -7,6 +7,7 @@ from v1.app.router import router as app_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.friendships.router import router as friendships_router
|
||||
from v1.inbox_messages.router import router as inbox_messages_router
|
||||
from v1.memories.router import router as memories_router
|
||||
from v1.schedule_items.router import router as schedule_items_router
|
||||
from v1.todo.router import router as todo_router
|
||||
from v1.users.router import router as users_router
|
||||
@@ -16,8 +17,8 @@ router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(app_router)
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(friendships_router)
|
||||
router.include_router(memories_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(schedule_items_router)
|
||||
router.include_router(inbox_messages_router)
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from schemas.automation import MemoryContextConfig, RuntimeConfig
|
||||
from schemas.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ def _run_input() -> RunAgentInput:
|
||||
def _runtime_config() -> RuntimeConfig:
|
||||
return RuntimeConfig(
|
||||
enabled_tools=[],
|
||||
context=MemoryContextConfig(),
|
||||
context=MessageContextConfig(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from schemas.agent.runtime_models import (
|
||||
WorkerAgentOutputLite,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.automation import MemoryContextConfig, RuntimeConfig
|
||||
from schemas.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def _user_context() -> UserContext:
|
||||
def _runtime_config() -> RuntimeConfig:
|
||||
return RuntimeConfig(
|
||||
enabled_tools=[],
|
||||
context=MemoryContextConfig(),
|
||||
context=MessageContextConfig(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
import core.agentscope.runtime.tasks as tasks_module
|
||||
from schemas.agent import ToolStatus
|
||||
from schemas.automation import ContextWindowMode, MemoryContextConfig
|
||||
from schemas.automation import ContextWindowMode, MessageContextConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
@@ -201,7 +201,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, context_config
|
||||
return {
|
||||
@@ -237,7 +237,7 @@ async def test_build_recent_context_messages_includes_all_user_attachments(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_config=MemoryContextConfig(
|
||||
context_config=MessageContextConfig(
|
||||
window_mode=ContextWindowMode.DAY,
|
||||
window_count=2,
|
||||
),
|
||||
@@ -264,7 +264,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, context_config
|
||||
return {
|
||||
@@ -295,7 +295,7 @@ async def test_build_recent_context_messages_uses_tool_metadata_output(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_config=MemoryContextConfig(),
|
||||
context_config=MessageContextConfig(),
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -319,7 +319,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id, context_config
|
||||
return {
|
||||
@@ -337,7 +337,7 @@ async def test_build_recent_context_messages_skips_tool_without_metadata_output(
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
context_config=MemoryContextConfig(),
|
||||
context_config=MessageContextConfig(),
|
||||
)
|
||||
|
||||
assert messages == []
|
||||
@@ -357,7 +357,7 @@ async def test_build_recent_context_messages_passes_context_config(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
context_config: MemoryContextConfig,
|
||||
context_config: MessageContextConfig,
|
||||
) -> dict[str, object] | None:
|
||||
del thread_id
|
||||
captured_config["config"] = context_config
|
||||
@@ -365,7 +365,7 @@ async def test_build_recent_context_messages_passes_context_config(
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentContextService", _FakeContextService)
|
||||
|
||||
cfg = MemoryContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
|
||||
cfg = MessageContextConfig(window_mode=ContextWindowMode.NUMBER, window_count=10)
|
||||
messages = await tasks_module._build_recent_context_messages(
|
||||
session=object(),
|
||||
thread_id=str(uuid4()),
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from agentscope.tool import ToolResponse
|
||||
|
||||
from core.agentscope.tools.custom import memory as memory_module
|
||||
from models.memories import MemoryType
|
||||
from schemas.memories.memory_content import UserMemoryContent
|
||||
|
||||
|
||||
def _decode_tool_response(response: ToolResponse) -> dict[str, object]:
|
||||
assert response.content
|
||||
first = response.content[0]
|
||||
text = str(first.get("text", "")) if isinstance(first, dict) else str(first.text)
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
def _payload_error_code(payload: dict[str, object]) -> str:
|
||||
error = payload.get("error")
|
||||
if not isinstance(error, dict):
|
||||
return ""
|
||||
return str(error.get("code") or "")
|
||||
|
||||
|
||||
class _FakeMemoriesService:
|
||||
def __init__(self) -> None:
|
||||
self.memory: object | None = None
|
||||
self.updated_user = 0
|
||||
self.updated_work = 0
|
||||
|
||||
async def get_memory_model(self, *, memory_type: MemoryType):
|
||||
_ = memory_type
|
||||
return self.memory
|
||||
|
||||
async def update_user_memory(self, **kwargs):
|
||||
_ = kwargs
|
||||
self.updated_user += 1
|
||||
return SimpleNamespace()
|
||||
|
||||
async def update_work_memory(self, **kwargs):
|
||||
_ = kwargs
|
||||
self.updated_work += 1
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
def _user_memory():
|
||||
return SimpleNamespace(
|
||||
id=uuid4(),
|
||||
owner_id=uuid4(),
|
||||
memory_type=MemoryType.USER,
|
||||
content={"preferences": {"communication_style": "简洁"}},
|
||||
status="active",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_write_requires_runtime_context() -> None:
|
||||
response = await memory_module.memory_write(
|
||||
memory_type="user",
|
||||
user_content=UserMemoryContent(interests=["跑步"]),
|
||||
)
|
||||
payload = _decode_tool_response(response)
|
||||
assert payload["status"] == "failure"
|
||||
assert _payload_error_code(payload) == "MISSING_RUNTIME_ARGS"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_write_updates_user_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_service = _FakeMemoriesService()
|
||||
monkeypatch.setattr(
|
||||
memory_module, "create_memories_service", lambda **_: fake_service
|
||||
)
|
||||
|
||||
response = await memory_module.memory_write(
|
||||
memory_type="user",
|
||||
user_content=UserMemoryContent(interests=["阅读"]),
|
||||
session=SimpleNamespace(),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
payload = _decode_tool_response(response)
|
||||
|
||||
assert payload["status"] == "success"
|
||||
assert "memory_type=user" in str(payload["result"])
|
||||
assert fake_service.updated_user == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_forget_updates_content_paths(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_service = _FakeMemoriesService()
|
||||
fake_service.memory = _user_memory()
|
||||
monkeypatch.setattr(
|
||||
memory_module, "create_memories_service", lambda **_: fake_service
|
||||
)
|
||||
|
||||
response = await memory_module.memory_forget(
|
||||
memory_type="user",
|
||||
forget_paths=["preferences.communication_style"],
|
||||
session=SimpleNamespace(),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
payload = _decode_tool_response(response)
|
||||
|
||||
assert payload["status"] == "success"
|
||||
assert "forgotten=1" in str(payload["result"])
|
||||
assert fake_service.updated_user == 1
|
||||
@@ -158,46 +158,44 @@ def test_build_system_prompt_keeps_sections_focused_without_language_duplication
|
||||
assert "Follow agent contracts strictly" not in prompt
|
||||
|
||||
|
||||
def test_build_system_prompt_includes_memory_section_when_memories_provided() -> None:
|
||||
from schemas.memories import (
|
||||
MemoryContext,
|
||||
MemoryListResponse,
|
||||
MemorySource,
|
||||
MemoryType,
|
||||
def test_build_system_prompt_includes_user_memory_section_for_router() -> None:
|
||||
from schemas.memories.memory_content import UserMemoryContent
|
||||
|
||||
user_memory = UserMemoryContent()
|
||||
|
||||
prompt = build_system_prompt(
|
||||
agent_type=AgentType.ROUTER,
|
||||
user_context=_build_user_context(),
|
||||
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
|
||||
user_memory=user_memory,
|
||||
)
|
||||
|
||||
memories = MemoryListResponse(
|
||||
owner_id=uuid4(),
|
||||
memories=[
|
||||
MemoryContext(
|
||||
memory_type=MemoryType.USER,
|
||||
source=MemorySource.MANUAL,
|
||||
title="User prefers morning meetings",
|
||||
content={"text": "User likes meetings before 10am"},
|
||||
created_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 3, 1, tzinfo=timezone.utc),
|
||||
),
|
||||
],
|
||||
total=1,
|
||||
)
|
||||
assert "<!-- USER_MEMORY_START -->" in prompt
|
||||
assert "[User Memory]" in prompt
|
||||
|
||||
|
||||
def test_build_system_prompt_includes_work_memory_section_for_worker() -> None:
|
||||
from schemas.memories.memory_content import WorkProfileContent
|
||||
|
||||
work_memory = WorkProfileContent()
|
||||
|
||||
prompt = build_system_prompt(
|
||||
agent_type=AgentType.WORKER,
|
||||
user_context=_build_user_context(),
|
||||
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
|
||||
memories=memories,
|
||||
work_memory=work_memory,
|
||||
)
|
||||
|
||||
assert "<!-- MEMORY_START -->" in prompt
|
||||
assert "[User Memories]" in prompt
|
||||
assert "User prefers morning meetings" in prompt
|
||||
assert "<!-- WORK_MEMORY_START -->" in prompt
|
||||
assert "[Work Memory]" in prompt
|
||||
|
||||
|
||||
def test_build_system_prompt_omits_memory_section_when_no_memories() -> None:
|
||||
prompt = build_system_prompt(
|
||||
agent_type=AgentType.WORKER,
|
||||
agent_type=AgentType.ROUTER,
|
||||
user_context=_build_user_context(),
|
||||
now_utc=datetime(2026, 3, 11, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
assert "<!-- MEMORY_START -->" not in prompt
|
||||
assert "<!-- USER_MEMORY_START -->" not in prompt
|
||||
assert "<!-- WORK_MEMORY_START -->" not in prompt
|
||||
|
||||
@@ -28,25 +28,3 @@ def test_build_stage_toolkit_uses_explicit_enabled_tools_as_final_set(
|
||||
)
|
||||
|
||||
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
|
||||
|
||||
|
||||
def test_build_stage_toolkit_uses_memory_defaults_without_explicit_tools(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_build_toolkit(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.tools.toolkit.build_toolkit", _fake_build_toolkit
|
||||
)
|
||||
|
||||
build_stage_toolkit(
|
||||
agent_type=AgentType.MEMORY,
|
||||
session=cast(Any, object()),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
|
||||
assert captured["enabled_tool_names"] == {"calendar_read", "user_lookup"}
|
||||
|
||||
@@ -16,13 +16,14 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
|
||||
toolkit = build_toolkit(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token-123",
|
||||
)
|
||||
schemas = toolkit.get_json_schemas()
|
||||
names = {item["function"]["name"] for item in schemas}
|
||||
assert "calendar_read" in names
|
||||
assert "calendar_write" in names
|
||||
assert "calendar_share" in names
|
||||
assert "memory_write" in names
|
||||
assert "memory_forget" in names
|
||||
|
||||
write_schema = next(
|
||||
item for item in schemas if item["function"]["name"] == "calendar_write"
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from v1.auth.automation_static_config import load_static_automation_job_config
|
||||
|
||||
|
||||
def test_memory_automation_static_config_contract() -> None:
|
||||
config = load_static_automation_job_config(config_name="memory_extraction")
|
||||
|
||||
assert config.context.window_mode.value == "day"
|
||||
assert config.context.window_count == 2
|
||||
assert [tool.value for tool in config.enabled_tools] == [
|
||||
"memory.write",
|
||||
"memory.forget",
|
||||
]
|
||||
prompt = config.input_template
|
||||
assert "提取" in prompt
|
||||
assert "遗忘" in prompt
|
||||
@@ -16,3 +16,18 @@ def test_memory_automation_job_trigger_exists_in_0004_migration() -> None:
|
||||
assert "'agent_type', 'memory'" in content
|
||||
assert "ux_automation_jobs_owner_memory_active" in content
|
||||
assert "input_template" in content
|
||||
|
||||
|
||||
def test_bootstrap_key_replaces_agent_type_unique_anchor() -> None:
|
||||
migration = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "alembic"
|
||||
/ "versions"
|
||||
/ "20260323_0003_bootstrap_job_key_and_unique_indexes.py"
|
||||
)
|
||||
content = migration.read_text(encoding="utf-8")
|
||||
|
||||
assert "bootstrap_key" in content
|
||||
assert "ux_automation_jobs_owner_bootstrap_key_active" in content
|
||||
assert "ux_memories_owner_memory_type" in content
|
||||
assert "DROP INDEX IF EXISTS ux_automation_jobs_owner_memory_active" in content
|
||||
|
||||
@@ -12,6 +12,14 @@ from v1.auth.schemas import (
|
||||
from v1.auth.service import AuthService, AuthServiceGateway
|
||||
|
||||
|
||||
class FakeRegistrationBootstrapper:
|
||||
def __init__(self) -> None:
|
||||
self.called_user_ids: list[str] = []
|
||||
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
|
||||
self.called_user_ids.append(user_id)
|
||||
|
||||
|
||||
class FakeGateway(AuthServiceGateway):
|
||||
def __init__(self, response: SessionResponse) -> None:
|
||||
self._response = response
|
||||
@@ -75,6 +83,27 @@ async def test_create_phone_session_forwards_payload() -> None:
|
||||
assert response.user.phone == "+8613812345678"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_phone_session_bootstraps_automation_job() -> None:
|
||||
user = AuthUser(id="b196f8be-c5f4-45d8-8f07-65c0ddf4d3de", phone="+8613812345678")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
gateway = FakeGateway(token_response)
|
||||
bootstrapper = FakeRegistrationBootstrapper()
|
||||
service = AuthService(gateway=gateway, registration_bootstrapper=bootstrapper)
|
||||
|
||||
await service.create_phone_session(
|
||||
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
|
||||
)
|
||||
|
||||
assert bootstrapper.called_user_ids == ["b196f8be-c5f4-45d8-8f07-65c0ddf4d3de"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_session_forwards_payload() -> None:
|
||||
user = AuthUser(id="user-1", phone="+8613812345678")
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.auth.registration_bootstrap import (
|
||||
compute_next_local_time_utc,
|
||||
)
|
||||
|
||||
|
||||
def test_compute_next_local_time_utc_from_asia_shanghai() -> None:
|
||||
now_utc = datetime(2026, 3, 23, 0, 30, tzinfo=timezone.utc)
|
||||
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name="Asia/Shanghai",
|
||||
local_hour=8,
|
||||
local_minute=0,
|
||||
)
|
||||
|
||||
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
|
||||
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_compute_next_local_time_utc_rolls_to_next_day_when_passed() -> None:
|
||||
now_utc = datetime(2026, 3, 23, 2, 30, tzinfo=timezone.utc)
|
||||
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name="Asia/Shanghai",
|
||||
local_hour=8,
|
||||
local_minute=0,
|
||||
)
|
||||
|
||||
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
|
||||
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registration_service_is_idempotent_when_job_exists() -> None:
|
||||
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
|
||||
|
||||
expected_owner_id = uuid4()
|
||||
|
||||
class _Repo:
|
||||
inserted = 0
|
||||
upsert_calls = 0
|
||||
|
||||
async def get_profile_timezone(self, *, user_id):
|
||||
assert user_id == expected_owner_id
|
||||
return "Asia/Shanghai"
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
|
||||
assert kwargs["owner_id"] == expected_owner_id
|
||||
assert kwargs["bootstrap_key"] == "memory_extraction"
|
||||
self.inserted += 1
|
||||
return False
|
||||
|
||||
async def upsert_initial_memory(self, **kwargs):
|
||||
self.upsert_calls += 1
|
||||
return False
|
||||
|
||||
class _Session:
|
||||
async def commit(self):
|
||||
raise AssertionError("must not commit when already exists")
|
||||
|
||||
async def rollback(self):
|
||||
raise AssertionError("must not rollback when no error")
|
||||
|
||||
service = RegistrationAutomationBootstrapService(
|
||||
repository=cast(Any, _Repo()), session=cast(Any, _Session())
|
||||
)
|
||||
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registration_service_creates_initial_memories_when_missing() -> None:
|
||||
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
|
||||
|
||||
expected_owner_id = uuid4()
|
||||
|
||||
class _Repo:
|
||||
async def get_profile_timezone(self, *, user_id):
|
||||
assert user_id == expected_owner_id
|
||||
return "Asia/Shanghai"
|
||||
|
||||
async def upsert_initial_memory(self, **kwargs):
|
||||
return True
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
|
||||
_ = kwargs
|
||||
return True
|
||||
|
||||
class _Session:
|
||||
committed = 0
|
||||
|
||||
async def commit(self):
|
||||
self.committed += 1
|
||||
|
||||
async def rollback(self):
|
||||
raise AssertionError("must not rollback when no error")
|
||||
|
||||
session = _Session()
|
||||
service = RegistrationAutomationBootstrapService(
|
||||
repository=cast(Any, _Repo()), session=cast(Any, session)
|
||||
)
|
||||
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
|
||||
|
||||
assert session.committed == 1
|
||||
Reference in New Issue
Block a user