feat: 重构 memory 系统,支持 user memory 和 work memory 分离

This commit is contained in:
qzl
2026-03-23 14:25:47 +08:00
parent 3aacc756db
commit 6be616f108
70 changed files with 7031 additions and 431 deletions
@@ -0,0 +1,46 @@
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
import re
from typing import Any
import yaml
from core.agentscope.tools.tool_config import AgentTool
from schemas.automation import AutomationJobConfig, MessageContextConfig
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
def _automation_yaml_path(config_name: str) -> Path:
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
raise ValueError("invalid automation config name")
return (
Path(__file__).resolve().parents[2]
/ "core"
/ "config"
/ "static"
/ "automation"
/ f"{config_name}.yaml"
)
@lru_cache(maxsize=16)
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
path = _automation_yaml_path(config_name)
with path.open("r", encoding="utf-8") as file:
loaded: Any = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"invalid automation config format: {path}")
config = AutomationJobConfig.model_validate(loaded)
if config_name == "memory_extraction":
if config.enabled_tools != [AgentTool.MEMORY_WRITE, AgentTool.MEMORY_FORGET]:
raise ValueError(
"memory_extraction enabled_tools must be [memory.write, memory.forget]"
)
if config.context != MessageContextConfig(window_count=2):
raise ValueError(
"memory_extraction context must be latest_chat/day with window_count=2"
)
return config
+21 -2
View File
@@ -1,8 +1,27 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.db import get_db
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.registration_bootstrap import (
RegistrationAutomationBootstrapService,
RegistrationBootstrapRepository,
)
from v1.auth.service import AuthService
def get_auth_service() -> AuthService:
return AuthService(gateway=SupabaseAuthGateway())
def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
bootstrapper = RegistrationAutomationBootstrapService(
repository=RegistrationBootstrapRepository(session=session),
session=session,
)
return AuthService(
gateway=SupabaseAuthGateway(),
registration_bootstrapper=bootstrapper,
)
@@ -0,0 +1,228 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Protocol
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
from models.memories import MemoryType
from models.profile import Profile
from schemas.automation import AutomationJobConfig
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from schemas.user.context import parse_profile_settings
from v1.auth.automation_static_config import load_static_automation_job_config
from v1.auth.schemas import RegistrationBootstrapRequest
from v1.memories.repository import SQLAlchemyMemoriesRepository
logger = get_logger("v1.auth.registration_bootstrap")
_LOCAL_RUN_HOUR = 8
_LOCAL_RUN_MINUTE = 0
class RegistrationBootstrapRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
self._memories_repository = SQLAlchemyMemoriesRepository(session)
async def get_profile_timezone(self, *, user_id: UUID) -> str:
stmt = select(Profile.settings).where(Profile.id == user_id)
settings = (await self._session.execute(stmt)).scalar_one_or_none()
parsed = parse_profile_settings(
settings if isinstance(settings, dict) else None
)
return parsed.preferences.timezone
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
run_at: datetime,
next_run_at: datetime,
) -> bool:
stmt = (
insert(AutomationJob)
.values(
id=uuid4(),
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=title,
config=config.model_dump(mode="json"),
schedule_type=ScheduleType.DAILY,
run_at=run_at,
next_run_at=next_run_at,
timezone=timezone_name,
status=AutomationJobStatus.ACTIVE,
created_by=owner_id,
)
.on_conflict_do_nothing(
index_elements=["owner_id", "bootstrap_key"],
index_where=AutomationJob.deleted_at.is_(None)
& AutomationJob.bootstrap_key.is_not(None),
)
.returning(AutomationJob.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
await self._session.flush()
return inserted_id is not None
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool:
return await self._memories_repository.create_if_absent(
owner_id=owner_id,
memory_type=memory_type,
content=content,
)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
class RegistrationBootstrapRepositoryLike(Protocol):
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
run_at: datetime,
next_run_at: datetime,
) -> bool: ...
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool: ...
class SessionLike(Protocol):
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
def compute_next_local_time_utc(
*,
now_utc: datetime,
timezone_name: str,
local_hour: int,
local_minute: int,
) -> tuple[datetime, datetime]:
try:
timezone_obj = ZoneInfo(timezone_name)
except ZoneInfoNotFoundError:
timezone_obj = ZoneInfo("UTC")
local_now = now_utc.astimezone(timezone_obj)
today_run_local = local_now.replace(
hour=local_hour,
minute=local_minute,
second=0,
microsecond=0,
)
run_local = (
today_run_local
if local_now <= today_run_local
else today_run_local + timedelta(days=1)
)
next_local = run_local + timedelta(days=1)
return run_local.astimezone(UTC), next_local.astimezone(UTC)
class RegistrationAutomationBootstrapService:
def __init__(
self,
*,
repository: RegistrationBootstrapRepositoryLike,
session: SessionLike,
) -> None:
self._repository = repository
self._session = session
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
owner_id = request.user_id
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
definitions = [
{
"bootstrap_key": "memory_extraction",
"config_name": "memory_extraction",
"title": "Memory Agent",
"local_hour": _LOCAL_RUN_HOUR,
"local_minute": _LOCAL_RUN_MINUTE,
}
]
try:
inserted_any = False
created_or_updated_memory = False
user_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.USER,
content=UserMemoryContent().model_dump(mode="json"),
)
work_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.WORK,
content=WorkProfileContent().model_dump(mode="json"),
)
created_or_updated_memory = user_initialized or work_initialized
for definition in definitions:
bootstrap_key = str(definition["bootstrap_key"])
job_config = load_static_automation_job_config(
config_name=str(definition["config_name"])
)
run_at, next_run_at = compute_next_local_time_utc(
now_utc=datetime.now(UTC),
timezone_name=timezone_name,
local_hour=int(definition["local_hour"]),
local_minute=int(definition["local_minute"]),
)
inserted = (
await self._repository.insert_bootstrap_automation_job_if_absent(
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=str(definition["title"]),
config=job_config,
timezone_name=timezone_name,
run_at=run_at,
next_run_at=next_run_at,
)
)
inserted_any = inserted_any or inserted
if inserted_any or created_or_updated_memory:
await self._session.commit()
logger.info(
"user automation jobs bootstrapped",
user_id=user_id,
timezone=timezone_name,
memory_initialized=created_or_updated_memory,
)
except Exception:
await self._session.rollback()
raise
+8
View File
@@ -1,5 +1,7 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
SUPABASE_PASSWORD_MIN_LENGTH = 6
@@ -49,3 +51,9 @@ class UserByPhoneResponse(BaseModel):
class OtpSendResponse(BaseModel):
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class RegistrationBootstrapRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
user_id: UUID
+18 -2
View File
@@ -28,9 +28,15 @@ class AuthServiceGateway(Protocol):
class AuthService:
_gateway: AuthServiceGateway
_registration_bootstrapper: RegistrationBootstrapper | None
def __init__(self, gateway: AuthServiceGateway) -> None:
def __init__(
self,
gateway: AuthServiceGateway,
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
) -> None:
self._gateway = gateway
self._registration_bootstrapper = registration_bootstrapper
async def send_otp(self, request: OtpSendRequest) -> None:
await self._gateway.send_otp(request)
@@ -38,10 +44,20 @@ class AuthService:
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
return await self._gateway.create_phone_session(request)
response = await self._gateway.create_phone_session(request)
if self._registration_bootstrapper is not None:
await self._registration_bootstrapper.ensure_user_automation_jobs(
user_id=response.user.id
)
return response
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
return await self._gateway.refresh_session(request)
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
raise NotImplementedError