feat: 重构 memory 系统,支持 user memory 和 work memory 分离
This commit is contained in:
@@ -9,11 +9,12 @@ from pathlib import Path
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.automation import (
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MemoryContextConfig,
|
||||
MessageContextConfig,
|
||||
RuntimeConfig,
|
||||
)
|
||||
|
||||
@@ -38,9 +39,9 @@ def _load_system_agents_yaml(path: Path | None = None) -> dict:
|
||||
return loaded
|
||||
|
||||
|
||||
def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextConfig:
|
||||
def _parse_context_messages_config(yaml_config: dict | None) -> MessageContextConfig:
|
||||
if not yaml_config:
|
||||
return MemoryContextConfig()
|
||||
return MessageContextConfig()
|
||||
mode_str = yaml_config.get("mode", "day")
|
||||
count = yaml_config.get("count", 2)
|
||||
try:
|
||||
@@ -51,7 +52,7 @@ def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextCon
|
||||
window_mode = ContextWindowMode(mode_str)
|
||||
except ValueError:
|
||||
window_mode = ContextWindowMode.DAY
|
||||
return MemoryContextConfig(
|
||||
return MessageContextConfig(
|
||||
source=source,
|
||||
window_mode=window_mode,
|
||||
window_count=count,
|
||||
@@ -93,9 +94,9 @@ def build_runtime_config_from_system_agents(
|
||||
router_config.context_messages.model_dump() if router_config else None
|
||||
)
|
||||
|
||||
enabled_tools: list[str] = []
|
||||
enabled_tools: list[AgentTool] = []
|
||||
if worker_config and worker_config.enabled_tools:
|
||||
enabled_tools = [str(t) for t in worker_config.enabled_tools]
|
||||
enabled_tools = list(worker_config.enabled_tools)
|
||||
|
||||
return RuntimeConfig(
|
||||
enabled_tools=enabled_tools,
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from schemas.automation import AutomationJobConfig, MessageContextConfig
|
||||
|
||||
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
|
||||
def _automation_yaml_path(config_name: str) -> Path:
|
||||
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
|
||||
raise ValueError("invalid automation config name")
|
||||
return (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "automation"
|
||||
/ f"{config_name}.yaml"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
|
||||
path = _automation_yaml_path(config_name)
|
||||
with path.open("r", encoding="utf-8") as file:
|
||||
loaded: Any = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"invalid automation config format: {path}")
|
||||
config = AutomationJobConfig.model_validate(loaded)
|
||||
if config_name == "memory_extraction":
|
||||
if config.enabled_tools != [AgentTool.MEMORY_WRITE, AgentTool.MEMORY_FORGET]:
|
||||
raise ValueError(
|
||||
"memory_extraction enabled_tools must be [memory.write, memory.forget]"
|
||||
)
|
||||
if config.context != MessageContextConfig(window_count=2):
|
||||
raise ValueError(
|
||||
"memory_extraction context must be latest_chat/day with window_count=2"
|
||||
)
|
||||
return config
|
||||
@@ -1,8 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db import get_db
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.auth.registration_bootstrap import (
|
||||
RegistrationAutomationBootstrapService,
|
||||
RegistrationBootstrapRepository,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
return AuthService(gateway=SupabaseAuthGateway())
|
||||
def get_auth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AuthService:
|
||||
bootstrapper = RegistrationAutomationBootstrapService(
|
||||
repository=RegistrationBootstrapRepository(session=session),
|
||||
session=session,
|
||||
)
|
||||
return AuthService(
|
||||
gateway=SupabaseAuthGateway(),
|
||||
registration_bootstrapper=bootstrapper,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
|
||||
from models.memories import MemoryType
|
||||
from models.profile import Profile
|
||||
from schemas.automation import AutomationJobConfig
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.user.context import parse_profile_settings
|
||||
from v1.auth.automation_static_config import load_static_automation_job_config
|
||||
from v1.auth.schemas import RegistrationBootstrapRequest
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
|
||||
logger = get_logger("v1.auth.registration_bootstrap")
|
||||
|
||||
_LOCAL_RUN_HOUR = 8
|
||||
_LOCAL_RUN_MINUTE = 0
|
||||
|
||||
|
||||
class RegistrationBootstrapRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._memories_repository = SQLAlchemyMemoriesRepository(session)
|
||||
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str:
|
||||
stmt = select(Profile.settings).where(Profile.id == user_id)
|
||||
settings = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
parsed = parse_profile_settings(
|
||||
settings if isinstance(settings, dict) else None
|
||||
)
|
||||
return parsed.preferences.timezone
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
) -> bool:
|
||||
stmt = (
|
||||
insert(AutomationJob)
|
||||
.values(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=title,
|
||||
config=config.model_dump(mode="json"),
|
||||
schedule_type=ScheduleType.DAILY,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
timezone=timezone_name,
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
created_by=owner_id,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["owner_id", "bootstrap_key"],
|
||||
index_where=AutomationJob.deleted_at.is_(None)
|
||||
& AutomationJob.bootstrap_key.is_not(None),
|
||||
)
|
||||
.returning(AutomationJob.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
return inserted_id is not None
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool:
|
||||
return await self._memories_repository.create_if_absent(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
|
||||
|
||||
|
||||
class RegistrationBootstrapRepositoryLike(Protocol):
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
run_at: datetime,
|
||||
next_run_at: datetime,
|
||||
) -> bool: ...
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class SessionLike(Protocol):
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
def compute_next_local_time_utc(
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
local_hour: int,
|
||||
local_minute: int,
|
||||
) -> tuple[datetime, datetime]:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
except ZoneInfoNotFoundError:
|
||||
timezone_obj = ZoneInfo("UTC")
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
today_run_local = local_now.replace(
|
||||
hour=local_hour,
|
||||
minute=local_minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
run_local = (
|
||||
today_run_local
|
||||
if local_now <= today_run_local
|
||||
else today_run_local + timedelta(days=1)
|
||||
)
|
||||
next_local = run_local + timedelta(days=1)
|
||||
return run_local.astimezone(UTC), next_local.astimezone(UTC)
|
||||
|
||||
|
||||
class RegistrationAutomationBootstrapService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: RegistrationBootstrapRepositoryLike,
|
||||
session: SessionLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
|
||||
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
|
||||
owner_id = request.user_id
|
||||
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
|
||||
|
||||
definitions = [
|
||||
{
|
||||
"bootstrap_key": "memory_extraction",
|
||||
"config_name": "memory_extraction",
|
||||
"title": "Memory Agent",
|
||||
"local_hour": _LOCAL_RUN_HOUR,
|
||||
"local_minute": _LOCAL_RUN_MINUTE,
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
inserted_any = False
|
||||
created_or_updated_memory = False
|
||||
|
||||
user_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.USER,
|
||||
content=UserMemoryContent().model_dump(mode="json"),
|
||||
)
|
||||
work_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.WORK,
|
||||
content=WorkProfileContent().model_dump(mode="json"),
|
||||
)
|
||||
created_or_updated_memory = user_initialized or work_initialized
|
||||
|
||||
for definition in definitions:
|
||||
bootstrap_key = str(definition["bootstrap_key"])
|
||||
job_config = load_static_automation_job_config(
|
||||
config_name=str(definition["config_name"])
|
||||
)
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=datetime.now(UTC),
|
||||
timezone_name=timezone_name,
|
||||
local_hour=int(definition["local_hour"]),
|
||||
local_minute=int(definition["local_minute"]),
|
||||
)
|
||||
inserted = (
|
||||
await self._repository.insert_bootstrap_automation_job_if_absent(
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=str(definition["title"]),
|
||||
config=job_config,
|
||||
timezone_name=timezone_name,
|
||||
run_at=run_at,
|
||||
next_run_at=next_run_at,
|
||||
)
|
||||
)
|
||||
inserted_any = inserted_any or inserted
|
||||
if inserted_any or created_or_updated_memory:
|
||||
await self._session.commit()
|
||||
logger.info(
|
||||
"user automation jobs bootstrapped",
|
||||
user_id=user_id,
|
||||
timezone=timezone_name,
|
||||
memory_initialized=created_or_updated_memory,
|
||||
)
|
||||
except Exception:
|
||||
await self._session.rollback()
|
||||
raise
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
@@ -49,3 +51,9 @@ class UserByPhoneResponse(BaseModel):
|
||||
|
||||
class OtpSendResponse(BaseModel):
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class RegistrationBootstrapRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
user_id: UUID
|
||||
|
||||
@@ -28,9 +28,15 @@ class AuthServiceGateway(Protocol):
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
_registration_bootstrapper: RegistrationBootstrapper | None
|
||||
|
||||
def __init__(self, gateway: AuthServiceGateway) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
gateway: AuthServiceGateway,
|
||||
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
|
||||
) -> None:
|
||||
self._gateway = gateway
|
||||
self._registration_bootstrapper = registration_bootstrapper
|
||||
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
await self._gateway.send_otp(request)
|
||||
@@ -38,10 +44,20 @@ class AuthService:
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
return await self._gateway.create_phone_session(request)
|
||||
response = await self._gateway.create_phone_session(request)
|
||||
if self._registration_bootstrapper is not None:
|
||||
await self._registration_bootstrapper.ensure_user_automation_jobs(
|
||||
user_id=response.user.id
|
||||
)
|
||||
return response
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
return await self._gateway.refresh_session(request)
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
from v1.memories.service import MemoriesService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
async def get_memories_repository(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SQLAlchemyMemoriesRepository:
|
||||
return SQLAlchemyMemoriesRepository(session)
|
||||
|
||||
|
||||
async def get_memories_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> MemoriesService:
|
||||
repository = SQLAlchemyMemoriesRepository(session)
|
||||
return MemoriesService(
|
||||
repository=repository,
|
||||
session=session,
|
||||
current_user=current_user,
|
||||
)
|
||||
@@ -3,29 +3,162 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.memories import Memory
|
||||
from core.logging import get_logger
|
||||
from models.memories import Memory, MemoryStatus, MemoryType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.memories.repository")
|
||||
|
||||
|
||||
class MemoriesRepositoryLike(Protocol):
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]: ...
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> Memory: ...
|
||||
|
||||
async def get_by_type_for_owner(
|
||||
self, *, owner_id: UUID, memory_type: MemoryType
|
||||
) -> Memory | None: ...
|
||||
|
||||
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
|
||||
|
||||
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
|
||||
|
||||
async def create_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool: ...
|
||||
|
||||
async def update_content(
|
||||
self,
|
||||
memory: Memory,
|
||||
content: dict | None = None,
|
||||
) -> Memory: ...
|
||||
|
||||
|
||||
class MemoriesRepository(BaseRepository[Memory]):
|
||||
class SQLAlchemyMemoriesRepository(BaseRepository[Memory]):
|
||||
_session: AsyncSession
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=Memory)
|
||||
self._session = session
|
||||
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]:
|
||||
stmt = (
|
||||
select(Memory)
|
||||
.where(Memory.owner_id == owner_id)
|
||||
.where(Memory.status == "active")
|
||||
.order_by(Memory.created_at.desc())
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> Memory:
|
||||
try:
|
||||
memory = Memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
status=MemoryStatus.ACTIVE,
|
||||
)
|
||||
self._session.add(memory)
|
||||
await self._session.flush()
|
||||
return memory
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to create memory",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_type_for_owner(
|
||||
self, *, owner_id: UUID, memory_type: MemoryType
|
||||
) -> Memory | None:
|
||||
try:
|
||||
stmt = (
|
||||
select(Memory)
|
||||
.where(Memory.owner_id == owner_id)
|
||||
.where(Memory.memory_type == memory_type)
|
||||
.where(Memory.status == MemoryStatus.ACTIVE)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
if hasattr(result, "scalar_one_or_none"):
|
||||
return result.scalar_one_or_none()
|
||||
scalars = result.scalars()
|
||||
rows = list(scalars.all())
|
||||
return rows[0] if rows else None
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to get memory by type for owner",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
|
||||
return await self.get_by_type_for_owner(
|
||||
owner_id=owner_id, memory_type=MemoryType.USER
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
|
||||
return await self.get_by_type_for_owner(
|
||||
owner_id=owner_id, memory_type=MemoryType.WORK
|
||||
)
|
||||
|
||||
async def create_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool:
|
||||
try:
|
||||
stmt = (
|
||||
insert(Memory)
|
||||
.values(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
status=MemoryStatus.ACTIVE,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["owner_id", "memory_type"])
|
||||
.returning(Memory.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
return inserted_id is not None
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to create memory if absent",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_content(
|
||||
self,
|
||||
memory: Memory,
|
||||
content: dict | None = None,
|
||||
) -> Memory:
|
||||
try:
|
||||
if content is not None:
|
||||
memory.content = content
|
||||
|
||||
await self._session.flush()
|
||||
return memory
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to update memory content",
|
||||
memory_id=str(memory.id),
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from v1.memories.dependencies import get_memories_service
|
||||
from v1.memories.schemas import (
|
||||
MemoryListResponse,
|
||||
UserMemoryPartialUpdate,
|
||||
UserMemoryUpdate,
|
||||
WorkMemoryPartialUpdate,
|
||||
WorkMemoryUpdate,
|
||||
)
|
||||
from v1.memories.service import MemoriesService
|
||||
|
||||
router = APIRouter(prefix="/memories", tags=["memories"])
|
||||
|
||||
|
||||
@router.get("", response_model=MemoryListResponse)
|
||||
async def get_all_memories(
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> MemoryListResponse:
|
||||
result = await service.get_all_memories()
|
||||
return MemoryListResponse(
|
||||
user_memory=result["user_memory"],
|
||||
work_memory=result["work_memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/user", response_model=UserMemoryContent | None)
|
||||
async def get_user_memory(
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> UserMemoryContent | None:
|
||||
return await service.get_user_memory()
|
||||
|
||||
|
||||
@router.get("/work", response_model=WorkProfileContent | None)
|
||||
async def get_work_memory(
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> WorkProfileContent | None:
|
||||
return await service.get_work_memory()
|
||||
|
||||
|
||||
@router.put("/user", response_model=UserMemoryContent, status_code=status.HTTP_200_OK)
|
||||
async def update_user_memory(
|
||||
payload: UserMemoryUpdate,
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> UserMemoryContent:
|
||||
return await service.update_user_memory(
|
||||
content=payload.content,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/work", response_model=WorkProfileContent, status_code=status.HTTP_200_OK)
|
||||
async def update_work_memory(
|
||||
payload: WorkMemoryUpdate,
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> WorkProfileContent:
|
||||
return await service.update_work_memory(
|
||||
content=payload.content,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/user", response_model=UserMemoryContent)
|
||||
async def patch_user_memory(
|
||||
payload: UserMemoryPartialUpdate,
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> UserMemoryContent:
|
||||
return await service.patch_user_memory(
|
||||
content=payload.content,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/work", response_model=WorkProfileContent)
|
||||
async def patch_work_memory(
|
||||
payload: WorkMemoryPartialUpdate,
|
||||
service: Annotated[MemoriesService, Depends(get_memories_service)],
|
||||
) -> WorkProfileContent:
|
||||
return await service.patch_work_memory(
|
||||
content=payload.content,
|
||||
)
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
|
||||
|
||||
class UserMemoryUpdate(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
content: UserMemoryContent
|
||||
|
||||
|
||||
class WorkMemoryUpdate(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
content: WorkProfileContent
|
||||
|
||||
|
||||
class UserMemoryPartialUpdate(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
content: UserMemoryContent | None = None
|
||||
|
||||
|
||||
class WorkMemoryPartialUpdate(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
content: WorkProfileContent | None = None
|
||||
|
||||
|
||||
class MemoryListResponse(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True)
|
||||
|
||||
user_memory: UserMemoryContent | None = None
|
||||
work_memory: WorkProfileContent | None = None
|
||||
@@ -1,53 +1,286 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from models.memories import Memory
|
||||
from schemas.memories import MemoryContext, MemoryListResponse, MemorySource, MemoryType
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from models.memories import Memory, MemoryType
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from v1.memories.repository import MemoriesRepositoryLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.memories.service")
|
||||
|
||||
|
||||
class MemoriesService(BaseService):
|
||||
"""Memories service handling user/work memory operations.
|
||||
|
||||
Each user has exactly 2 memory records:
|
||||
- user_type: stores personal preferences, people, places, etc.
|
||||
- work_type: stores work profile, projects, team, etc.
|
||||
|
||||
Responsibilities:
|
||||
- Authorization checks
|
||||
- Validation (ownership, memory type)
|
||||
- Transaction boundary (commit/rollback)
|
||||
- Converting ORM models to response schemas
|
||||
"""
|
||||
|
||||
class MemoriesService:
|
||||
_repository: MemoriesRepositoryLike
|
||||
_session: AsyncSession
|
||||
|
||||
def __init__(self, repository: MemoriesRepositoryLike) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
repository: MemoriesRepositoryLike,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
def _to_context(self, memory: Memory) -> MemoryContext:
|
||||
return MemoryContext(
|
||||
memory_type=MemoryType(memory.memory_type.value),
|
||||
source=MemorySource(memory.source.value),
|
||||
title=memory.title,
|
||||
content=memory.content,
|
||||
created_at=memory.created_at,
|
||||
updated_at=memory.updated_at,
|
||||
async def get_user_memory(self) -> UserMemoryContent | None:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
return None
|
||||
|
||||
return self._parse_user_content(memory)
|
||||
|
||||
async def get_work_memory(self) -> WorkProfileContent | None:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
return None
|
||||
|
||||
return self._parse_work_content(memory)
|
||||
|
||||
async def get_all_memories(self) -> dict:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
user_memory = await self._repository.get_user_memory_for_owner(
|
||||
owner_id=user_id
|
||||
)
|
||||
work_memory = await self._repository.get_work_memory_for_owner(
|
||||
owner_id=user_id
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
return {
|
||||
"user_memory": self._parse_user_content(user_memory)
|
||||
if user_memory
|
||||
else None,
|
||||
"work_memory": self._parse_work_content(work_memory)
|
||||
if work_memory
|
||||
else None,
|
||||
}
|
||||
|
||||
async def update_user_memory(
|
||||
self,
|
||||
*,
|
||||
content: UserMemoryContent,
|
||||
) -> UserMemoryContent:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
try:
|
||||
memory = await self._repository.create(
|
||||
owner_id=user_id,
|
||||
memory_type=MemoryType.USER,
|
||||
content=content.model_dump(),
|
||||
)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Memories service unavailable"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
memory = await self._repository.update_content(
|
||||
memory=memory,
|
||||
content=content.model_dump(),
|
||||
)
|
||||
await self._session.commit()
|
||||
await self._session.refresh(memory)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Memories service unavailable"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"user_memory_updated",
|
||||
extra={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
async def get_user_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
user_memories = [
|
||||
self._to_context(memory)
|
||||
for memory in memories
|
||||
if memory.memory_type.value == "user"
|
||||
]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=user_memories, total=len(user_memories)
|
||||
return self._parse_user_content(memory)
|
||||
|
||||
async def update_work_memory(
|
||||
self,
|
||||
*,
|
||||
content: WorkProfileContent,
|
||||
) -> WorkProfileContent:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
try:
|
||||
memory = await self._repository.create(
|
||||
owner_id=user_id,
|
||||
memory_type=MemoryType.WORK,
|
||||
content=content.model_dump(),
|
||||
)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Memories service unavailable"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
memory = await self._repository.update_content(
|
||||
memory=memory,
|
||||
content=content.model_dump(),
|
||||
)
|
||||
await self._session.commit()
|
||||
await self._session.refresh(memory)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Memories service unavailable"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"work_memory_updated",
|
||||
extra={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
async def get_agent_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
agent_memories = [
|
||||
self._to_context(memory)
|
||||
for memory in memories
|
||||
if memory.memory_type.value == "work"
|
||||
]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=agent_memories, total=len(agent_memories)
|
||||
return self._parse_work_content(memory)
|
||||
|
||||
async def patch_user_memory(
|
||||
self, *, content: UserMemoryContent | None = None
|
||||
) -> UserMemoryContent:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
raise HTTPException(status_code=404, detail="User memory not found")
|
||||
|
||||
try:
|
||||
update_data: dict = {}
|
||||
if content is not None:
|
||||
existing_content = memory.content or {}
|
||||
merged = content.model_dump()
|
||||
existing_content.update(
|
||||
{k: v for k, v in merged.items() if v is not None}
|
||||
)
|
||||
update_data["content"] = existing_content
|
||||
|
||||
if update_data:
|
||||
memory = await self._repository.update_content(
|
||||
memory=memory,
|
||||
content=update_data.get("content"),
|
||||
)
|
||||
await self._session.commit()
|
||||
await self._session.refresh(memory)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
logger.info(
|
||||
"user_memory_patched",
|
||||
extra={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
async def get_all_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
memory_contexts = [self._to_context(memory) for memory in memories]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=memory_contexts, total=len(memory_contexts)
|
||||
return self._parse_user_content(memory)
|
||||
|
||||
async def patch_work_memory(
|
||||
self, *, content: WorkProfileContent | None = None
|
||||
) -> WorkProfileContent:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
try:
|
||||
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
if memory is None:
|
||||
raise HTTPException(status_code=404, detail="Work memory not found")
|
||||
|
||||
try:
|
||||
update_data: dict = {}
|
||||
if content is not None:
|
||||
existing_content = memory.content or {}
|
||||
merged = content.model_dump()
|
||||
existing_content.update(
|
||||
{k: v for k, v in merged.items() if v is not None}
|
||||
)
|
||||
update_data["content"] = existing_content
|
||||
|
||||
if update_data:
|
||||
memory = await self._repository.update_content(
|
||||
memory=memory,
|
||||
content=update_data.get("content"),
|
||||
)
|
||||
await self._session.commit()
|
||||
await self._session.refresh(memory)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
logger.info(
|
||||
"work_memory_patched",
|
||||
extra={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
return self._parse_work_content(memory)
|
||||
|
||||
def _parse_user_content(self, memory: Memory) -> UserMemoryContent:
|
||||
content_dict = memory.content or {}
|
||||
return UserMemoryContent.model_validate(content_dict)
|
||||
|
||||
def _parse_work_content(self, memory: Memory) -> WorkProfileContent:
|
||||
content_dict = memory.content or {}
|
||||
return WorkProfileContent.model_validate(content_dict)
|
||||
|
||||
async def get_memory_model(self, *, memory_type: MemoryType) -> Memory | None:
|
||||
user_id = self.require_user_id()
|
||||
try:
|
||||
return await self._repository.get_by_type_for_owner(
|
||||
owner_id=user_id,
|
||||
memory_type=memory_type,
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="Memories service unavailable")
|
||||
|
||||
@@ -7,6 +7,7 @@ from v1.app.router import router as app_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.friendships.router import router as friendships_router
|
||||
from v1.inbox_messages.router import router as inbox_messages_router
|
||||
from v1.memories.router import router as memories_router
|
||||
from v1.schedule_items.router import router as schedule_items_router
|
||||
from v1.todo.router import router as todo_router
|
||||
from v1.users.router import router as users_router
|
||||
@@ -16,8 +17,8 @@ router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(app_router)
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(friendships_router)
|
||||
router.include_router(memories_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(schedule_items_router)
|
||||
router.include_router(inbox_messages_router)
|
||||
|
||||
Reference in New Issue
Block a user