feat: 重构 memory 系统,支持 user memory 和 work memory 分离
This commit is contained in:
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from schemas.automation import AutomationJobConfig, MessageContextConfig
|
||||
|
||||
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
|
||||
def _automation_yaml_path(config_name: str) -> Path:
|
||||
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
|
||||
raise ValueError("invalid automation config name")
|
||||
return (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "automation"
|
||||
/ f"{config_name}.yaml"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
|
||||
path = _automation_yaml_path(config_name)
|
||||
with path.open("r", encoding="utf-8") as file:
|
||||
loaded: Any = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"invalid automation config format: {path}")
|
||||
config = AutomationJobConfig.model_validate(loaded)
|
||||
if config_name == "memory_extraction":
|
||||
if config.enabled_tools != [AgentTool.MEMORY_WRITE, AgentTool.MEMORY_FORGET]:
|
||||
raise ValueError(
|
||||
"memory_extraction enabled_tools must be [memory.write, memory.forget]"
|
||||
)
|
||||
if config.context != MessageContextConfig(window_count=2):
|
||||
raise ValueError(
|
||||
"memory_extraction context must be latest_chat/day with window_count=2"
|
||||
)
|
||||
return config
|
||||
@@ -1,8 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db import get_db
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.auth.registration_bootstrap import (
|
||||
RegistrationAutomationBootstrapService,
|
||||
RegistrationBootstrapRepository,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
return AuthService(gateway=SupabaseAuthGateway())
|
||||
def get_auth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AuthService:
|
||||
bootstrapper = RegistrationAutomationBootstrapService(
|
||||
repository=RegistrationBootstrapRepository(session=session),
|
||||
session=session,
|
||||
)
|
||||
return AuthService(
|
||||
gateway=SupabaseAuthGateway(),
|
||||
registration_bootstrapper=bootstrapper,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
|
||||
from models.memories import MemoryType
|
||||
from models.profile import Profile
|
||||
from schemas.automation import AutomationJobConfig
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.user.context import parse_profile_settings
|
||||
from v1.auth.automation_static_config import load_static_automation_job_config
|
||||
from v1.auth.schemas import RegistrationBootstrapRequest
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
|
||||
logger = get_logger("v1.auth.registration_bootstrap")
|
||||
|
||||
_LOCAL_RUN_HOUR = 8
|
||||
_LOCAL_RUN_MINUTE = 0
|
||||
|
||||
|
||||
class RegistrationBootstrapRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._memories_repository = SQLAlchemyMemoriesRepository(session)
|
||||
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str:
|
||||
stmt = select(Profile.settings).where(Profile.id == user_id)
|
||||
settings = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
parsed = parse_profile_settings(
|
||||
settings if isinstance(settings, dict) else None
|
||||
)
|
||||
return parsed.preferences.timezone
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
) -> bool:
|
||||
stmt = (
|
||||
insert(AutomationJob)
|
||||
.values(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=title,
|
||||
config=config.model_dump(mode="json"),
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
timezone=timezone_name,
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
created_by=owner_id,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["owner_id", "bootstrap_key"],
|
||||
index_where=AutomationJob.deleted_at.is_(None)
|
||||
& AutomationJob.bootstrap_key.is_not(None),
|
||||
)
|
||||
.returning(AutomationJob.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
return inserted_id is not None
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool:
|
||||
return await self._memories_repository.create_if_absent(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
|
||||
|
||||
|
||||
class RegistrationBootstrapRepositoryLike(Protocol):
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
) -> bool: ...
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class SessionLike(Protocol):
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
def compute_next_local_time_utc(
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
local_hour: int,
|
||||
local_minute: int,
|
||||
) -> tuple[datetime, datetime]:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
except ZoneInfoNotFoundError:
|
||||
timezone_obj = ZoneInfo("UTC")
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
today_run_local = local_now.replace(
|
||||
hour=local_hour,
|
||||
minute=local_minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
run_local = (
|
||||
today_run_local
|
||||
if local_now <= today_run_local
|
||||
else today_run_local + timedelta(days=1)
|
||||
)
|
||||
next_local = run_local + timedelta(days=1)
|
||||
return run_local.astimezone(UTC), next_local.astimezone(UTC)
|
||||
|
||||
|
||||
class RegistrationAutomationBootstrapService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: RegistrationBootstrapRepositoryLike,
|
||||
session: SessionLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
|
||||
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
|
||||
owner_id = request.user_id
|
||||
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
|
||||
|
||||
definitions = [
|
||||
{
|
||||
"bootstrap_key": "memory_extraction",
|
||||
"config_name": "memory_extraction",
|
||||
"title": "Memory Agent",
|
||||
"local_hour": _LOCAL_RUN_HOUR,
|
||||
"local_minute": _LOCAL_RUN_MINUTE,
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
inserted_any = False
|
||||
created_or_updated_memory = False
|
||||
|
||||
user_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.USER,
|
||||
content=UserMemoryContent().model_dump(mode="json"),
|
||||
)
|
||||
work_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.WORK,
|
||||
content=WorkProfileContent().model_dump(mode="json"),
|
||||
)
|
||||
created_or_updated_memory = user_initialized or work_initialized
|
||||
|
||||
for definition in definitions:
|
||||
bootstrap_key = str(definition["bootstrap_key"])
|
||||
job_config = load_static_automation_job_config(
|
||||
config_name=str(definition["config_name"])
|
||||
)
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=datetime.now(UTC),
|
||||
timezone_name=timezone_name,
|
||||
local_hour=int(definition["local_hour"]),
|
||||
local_minute=int(definition["local_minute"]),
|
||||
)
|
||||
inserted = (
|
||||
await self._repository.insert_bootstrap_automation_job_if_absent(
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=str(definition["title"]),
|
||||
config=job_config,
|
||||
timezone_name=timezone_name,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
)
|
||||
)
|
||||
inserted_any = inserted_any or inserted
|
||||
if inserted_any or created_or_updated_memory:
|
||||
await self._session.commit()
|
||||
logger.info(
|
||||
"user automation jobs bootstrapped",
|
||||
user_id=user_id,
|
||||
timezone=timezone_name,
|
||||
memory_initialized=created_or_updated_memory,
|
||||
)
|
||||
except Exception:
|
||||
await self._session.rollback()
|
||||
raise
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
@@ -49,3 +51,9 @@ class UserByPhoneResponse(BaseModel):
|
||||
|
||||
class OtpSendResponse(BaseModel):
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class RegistrationBootstrapRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
user_id: UUID
|
||||
|
||||
@@ -28,9 +28,15 @@ class AuthServiceGateway(Protocol):
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
_registration_bootstrapper: RegistrationBootstrapper | None
|
||||
|
||||
def __init__(self, gateway: AuthServiceGateway) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
gateway: AuthServiceGateway,
|
||||
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
|
||||
) -> None:
|
||||
self._gateway = gateway
|
||||
self._registration_bootstrapper = registration_bootstrapper
|
||||
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
await self._gateway.send_otp(request)
|
||||
@@ -38,10 +44,20 @@ class AuthService:
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
return await self._gateway.create_phone_session(request)
|
||||
response = await self._gateway.create_phone_session(request)
|
||||
if self._registration_bootstrapper is not None:
|
||||
await self._registration_bootstrapper.ensure_user_automation_jobs(
|
||||
user_id=response.user.id
|
||||
)
|
||||
return response
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
return await self._gateway.refresh_session(request)
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user