feat: 重构 memory 系统,支持 user memory 和 work memory 分离

This commit is contained in:
qzl
2026-03-23 14:25:47 +08:00
parent 3aacc756db
commit 6be616f108
70 changed files with 7031 additions and 431 deletions
+7 -6
View File
@@ -9,11 +9,12 @@ from pathlib import Path
import yaml
from pydantic import ValidationError
from core.agentscope.tools.tool_config import AgentTool
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.automation import (
ContextSource,
ContextWindowMode,
MemoryContextConfig,
MessageContextConfig,
RuntimeConfig,
)
@@ -38,9 +39,9 @@ def _load_system_agents_yaml(path: Path | None = None) -> dict:
return loaded
def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextConfig:
def _parse_context_messages_config(yaml_config: dict | None) -> MessageContextConfig:
if not yaml_config:
return MemoryContextConfig()
return MessageContextConfig()
mode_str = yaml_config.get("mode", "day")
count = yaml_config.get("count", 2)
try:
@@ -51,7 +52,7 @@ def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextCon
window_mode = ContextWindowMode(mode_str)
except ValueError:
window_mode = ContextWindowMode.DAY
return MemoryContextConfig(
return MessageContextConfig(
source=source,
window_mode=window_mode,
window_count=count,
@@ -93,9 +94,9 @@ def build_runtime_config_from_system_agents(
router_config.context_messages.model_dump() if router_config else None
)
enabled_tools: list[str] = []
enabled_tools: list[AgentTool] = []
if worker_config and worker_config.enabled_tools:
enabled_tools = [str(t) for t in worker_config.enabled_tools]
enabled_tools = list(worker_config.enabled_tools)
return RuntimeConfig(
enabled_tools=enabled_tools,
@@ -0,0 +1,46 @@
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
import re
from typing import Any
import yaml
from core.agentscope.tools.tool_config import AgentTool
from schemas.automation import AutomationJobConfig, MessageContextConfig
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
def _automation_yaml_path(config_name: str) -> Path:
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
raise ValueError("invalid automation config name")
return (
Path(__file__).resolve().parents[2]
/ "core"
/ "config"
/ "static"
/ "automation"
/ f"{config_name}.yaml"
)
@lru_cache(maxsize=16)
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
path = _automation_yaml_path(config_name)
with path.open("r", encoding="utf-8") as file:
loaded: Any = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"invalid automation config format: {path}")
config = AutomationJobConfig.model_validate(loaded)
if config_name == "memory_extraction":
if config.enabled_tools != [AgentTool.MEMORY_WRITE, AgentTool.MEMORY_FORGET]:
raise ValueError(
"memory_extraction enabled_tools must be [memory.write, memory.forget]"
)
if config.context != MessageContextConfig(window_count=2):
raise ValueError(
"memory_extraction context must be latest_chat/day with window_count=2"
)
return config
+21 -2
View File
@@ -1,8 +1,27 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.db import get_db
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.registration_bootstrap import (
RegistrationAutomationBootstrapService,
RegistrationBootstrapRepository,
)
from v1.auth.service import AuthService
def get_auth_service() -> AuthService:
return AuthService(gateway=SupabaseAuthGateway())
def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
bootstrapper = RegistrationAutomationBootstrapService(
repository=RegistrationBootstrapRepository(session=session),
session=session,
)
return AuthService(
gateway=SupabaseAuthGateway(),
registration_bootstrapper=bootstrapper,
)
@@ -0,0 +1,228 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Protocol
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
from models.automation_jobs import AutomationJob, AutomationJobStatus, ScheduleType
from models.memories import MemoryType
from models.profile import Profile
from schemas.automation import AutomationJobConfig
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from schemas.user.context import parse_profile_settings
from v1.auth.automation_static_config import load_static_automation_job_config
from v1.auth.schemas import RegistrationBootstrapRequest
from v1.memories.repository import SQLAlchemyMemoriesRepository
logger = get_logger("v1.auth.registration_bootstrap")
_LOCAL_RUN_HOUR = 8
_LOCAL_RUN_MINUTE = 0
class RegistrationBootstrapRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
self._memories_repository = SQLAlchemyMemoriesRepository(session)
async def get_profile_timezone(self, *, user_id: UUID) -> str:
stmt = select(Profile.settings).where(Profile.id == user_id)
settings = (await self._session.execute(stmt)).scalar_one_or_none()
parsed = parse_profile_settings(
settings if isinstance(settings, dict) else None
)
return parsed.preferences.timezone
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
run_at: datetime,
next_run_at: datetime,
) -> bool:
stmt = (
insert(AutomationJob)
.values(
id=uuid4(),
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=title,
config=config.model_dump(mode="json"),
schedule_type=ScheduleType.DAILY,
run_at=run_at,
next_run_at=next_run_at,
timezone=timezone_name,
status=AutomationJobStatus.ACTIVE,
created_by=owner_id,
)
.on_conflict_do_nothing(
index_elements=["owner_id", "bootstrap_key"],
index_where=AutomationJob.deleted_at.is_(None)
& AutomationJob.bootstrap_key.is_not(None),
)
.returning(AutomationJob.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
await self._session.flush()
return inserted_id is not None
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool:
return await self._memories_repository.create_if_absent(
owner_id=owner_id,
memory_type=memory_type,
content=content,
)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
class RegistrationBootstrapRepositoryLike(Protocol):
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
run_at: datetime,
next_run_at: datetime,
) -> bool: ...
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool: ...
class SessionLike(Protocol):
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
def compute_next_local_time_utc(
*,
now_utc: datetime,
timezone_name: str,
local_hour: int,
local_minute: int,
) -> tuple[datetime, datetime]:
try:
timezone_obj = ZoneInfo(timezone_name)
except ZoneInfoNotFoundError:
timezone_obj = ZoneInfo("UTC")
local_now = now_utc.astimezone(timezone_obj)
today_run_local = local_now.replace(
hour=local_hour,
minute=local_minute,
second=0,
microsecond=0,
)
run_local = (
today_run_local
if local_now <= today_run_local
else today_run_local + timedelta(days=1)
)
next_local = run_local + timedelta(days=1)
return run_local.astimezone(UTC), next_local.astimezone(UTC)
class RegistrationAutomationBootstrapService:
def __init__(
self,
*,
repository: RegistrationBootstrapRepositoryLike,
session: SessionLike,
) -> None:
self._repository = repository
self._session = session
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
owner_id = request.user_id
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
definitions = [
{
"bootstrap_key": "memory_extraction",
"config_name": "memory_extraction",
"title": "Memory Agent",
"local_hour": _LOCAL_RUN_HOUR,
"local_minute": _LOCAL_RUN_MINUTE,
}
]
try:
inserted_any = False
created_or_updated_memory = False
user_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.USER,
content=UserMemoryContent().model_dump(mode="json"),
)
work_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.WORK,
content=WorkProfileContent().model_dump(mode="json"),
)
created_or_updated_memory = user_initialized or work_initialized
for definition in definitions:
bootstrap_key = str(definition["bootstrap_key"])
job_config = load_static_automation_job_config(
config_name=str(definition["config_name"])
)
run_at, next_run_at = compute_next_local_time_utc(
now_utc=datetime.now(UTC),
timezone_name=timezone_name,
local_hour=int(definition["local_hour"]),
local_minute=int(definition["local_minute"]),
)
inserted = (
await self._repository.insert_bootstrap_automation_job_if_absent(
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=str(definition["title"]),
config=job_config,
timezone_name=timezone_name,
run_at=run_at,
next_run_at=next_run_at,
)
)
inserted_any = inserted_any or inserted
if inserted_any or created_or_updated_memory:
await self._session.commit()
logger.info(
"user automation jobs bootstrapped",
user_id=user_id,
timezone=timezone_name,
memory_initialized=created_or_updated_memory,
)
except Exception:
await self._session.rollback()
raise
+8
View File
@@ -1,5 +1,7 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
SUPABASE_PASSWORD_MIN_LENGTH = 6
@@ -49,3 +51,9 @@ class UserByPhoneResponse(BaseModel):
class OtpSendResponse(BaseModel):
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class RegistrationBootstrapRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
user_id: UUID
+18 -2
View File
@@ -28,9 +28,15 @@ class AuthServiceGateway(Protocol):
class AuthService:
_gateway: AuthServiceGateway
_registration_bootstrapper: RegistrationBootstrapper | None
def __init__(self, gateway: AuthServiceGateway) -> None:
def __init__(
self,
gateway: AuthServiceGateway,
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
) -> None:
self._gateway = gateway
self._registration_bootstrapper = registration_bootstrapper
async def send_otp(self, request: OtpSendRequest) -> None:
await self._gateway.send_otp(request)
@@ -38,10 +44,20 @@ class AuthService:
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
return await self._gateway.create_phone_session(request)
response = await self._gateway.create_phone_session(request)
if self._registration_bootstrapper is not None:
await self._registration_bootstrapper.ensure_user_automation_jobs(
user_id=response.user.id
)
return response
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
return await self._gateway.refresh_session(request)
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
raise NotImplementedError
+30
View File
@@ -0,0 +1,30 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from core.db import get_db
from v1.memories.repository import SQLAlchemyMemoriesRepository
from v1.memories.service import MemoriesService
from v1.users.dependencies import get_current_user
async def get_memories_repository(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SQLAlchemyMemoriesRepository:
return SQLAlchemyMemoriesRepository(session)
async def get_memories_service(
session: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> MemoriesService:
repository = SQLAlchemyMemoriesRepository(session)
return MemoriesService(
repository=repository,
session=session,
current_user=current_user,
)
+144 -11
View File
@@ -3,29 +3,162 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository
from models.memories import Memory
from core.logging import get_logger
from models.memories import Memory, MemoryStatus, MemoryType
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.memories.repository")
class MemoriesRepositoryLike(Protocol):
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]: ...
async def create(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> Memory: ...
async def get_by_type_for_owner(
self, *, owner_id: UUID, memory_type: MemoryType
) -> Memory | None: ...
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
async def create_if_absent(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool: ...
async def update_content(
self,
memory: Memory,
content: dict | None = None,
) -> Memory: ...
class MemoriesRepository(BaseRepository[Memory]):
class SQLAlchemyMemoriesRepository(BaseRepository[Memory]):
_session: AsyncSession
def __init__(self, session: AsyncSession) -> None:
super().__init__(session=session, model=Memory)
self._session = session
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]:
stmt = (
select(Memory)
.where(Memory.owner_id == owner_id)
.where(Memory.status == "active")
.order_by(Memory.created_at.desc())
async def create(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> Memory:
try:
memory = Memory(
owner_id=owner_id,
memory_type=memory_type,
content=content,
status=MemoryStatus.ACTIVE,
)
self._session.add(memory)
await self._session.flush()
return memory
except SQLAlchemyError:
logger.exception(
"Failed to create memory",
owner_id=str(owner_id),
memory_type=memory_type.value,
)
raise
async def get_by_type_for_owner(
self, *, owner_id: UUID, memory_type: MemoryType
) -> Memory | None:
try:
stmt = (
select(Memory)
.where(Memory.owner_id == owner_id)
.where(Memory.memory_type == memory_type)
.where(Memory.status == MemoryStatus.ACTIVE)
)
result = await self._session.execute(stmt)
if hasattr(result, "scalar_one_or_none"):
return result.scalar_one_or_none()
scalars = result.scalars()
rows = list(scalars.all())
return rows[0] if rows else None
except SQLAlchemyError:
logger.exception(
"Failed to get memory by type for owner",
owner_id=str(owner_id),
memory_type=memory_type.value,
)
raise
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
return await self.get_by_type_for_owner(
owner_id=owner_id, memory_type=MemoryType.USER
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
return await self.get_by_type_for_owner(
owner_id=owner_id, memory_type=MemoryType.WORK
)
async def create_if_absent(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool:
try:
stmt = (
insert(Memory)
.values(
owner_id=owner_id,
memory_type=memory_type,
content=content,
status=MemoryStatus.ACTIVE,
)
.on_conflict_do_nothing(index_elements=["owner_id", "memory_type"])
.returning(Memory.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
await self._session.flush()
return inserted_id is not None
except SQLAlchemyError:
logger.exception(
"Failed to create memory if absent",
owner_id=str(owner_id),
memory_type=memory_type.value,
)
raise
async def update_content(
self,
memory: Memory,
content: dict | None = None,
) -> Memory:
try:
if content is not None:
memory.content = content
await self._session.flush()
return memory
except SQLAlchemyError:
logger.exception(
"Failed to update memory content",
memory_id=str(memory.id),
)
raise
+83
View File
@@ -0,0 +1,83 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, status
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from v1.memories.dependencies import get_memories_service
from v1.memories.schemas import (
MemoryListResponse,
UserMemoryPartialUpdate,
UserMemoryUpdate,
WorkMemoryPartialUpdate,
WorkMemoryUpdate,
)
from v1.memories.service import MemoriesService
router = APIRouter(prefix="/memories", tags=["memories"])
@router.get("", response_model=MemoryListResponse)
async def get_all_memories(
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> MemoryListResponse:
result = await service.get_all_memories()
return MemoryListResponse(
user_memory=result["user_memory"],
work_memory=result["work_memory"],
)
@router.get("/user", response_model=UserMemoryContent | None)
async def get_user_memory(
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> UserMemoryContent | None:
return await service.get_user_memory()
@router.get("/work", response_model=WorkProfileContent | None)
async def get_work_memory(
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> WorkProfileContent | None:
return await service.get_work_memory()
@router.put("/user", response_model=UserMemoryContent, status_code=status.HTTP_200_OK)
async def update_user_memory(
payload: UserMemoryUpdate,
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> UserMemoryContent:
return await service.update_user_memory(
content=payload.content,
)
@router.put("/work", response_model=WorkProfileContent, status_code=status.HTTP_200_OK)
async def update_work_memory(
payload: WorkMemoryUpdate,
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> WorkProfileContent:
return await service.update_work_memory(
content=payload.content,
)
@router.patch("/user", response_model=UserMemoryContent)
async def patch_user_memory(
payload: UserMemoryPartialUpdate,
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> UserMemoryContent:
return await service.patch_user_memory(
content=payload.content,
)
@router.patch("/work", response_model=WorkProfileContent)
async def patch_work_memory(
payload: WorkMemoryPartialUpdate,
service: Annotated[MemoriesService, Depends(get_memories_service)],
) -> WorkProfileContent:
return await service.patch_work_memory(
content=payload.content,
)
+38
View File
@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import ClassVar
from pydantic import BaseModel, ConfigDict
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
class UserMemoryUpdate(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
content: UserMemoryContent
class WorkMemoryUpdate(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
content: WorkProfileContent
class UserMemoryPartialUpdate(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
content: UserMemoryContent | None = None
class WorkMemoryPartialUpdate(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
content: WorkProfileContent | None = None
class MemoryListResponse(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True)
user_memory: UserMemoryContent | None = None
work_memory: WorkProfileContent | None = None
+269 -36
View File
@@ -1,53 +1,286 @@
from __future__ import annotations
from uuid import UUID
from typing import TYPE_CHECKING
from models.memories import Memory
from schemas.memories import MemoryContext, MemoryListResponse, MemorySource, MemoryType
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from models.memories import Memory, MemoryType
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
from v1.memories.repository import MemoriesRepositoryLike
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.memories.service")
class MemoriesService(BaseService):
"""Memories service handling user/work memory operations.
Each user has exactly 2 memory records:
- user_type: stores personal preferences, people, places, etc.
- work_type: stores work profile, projects, team, etc.
Responsibilities:
- Authorization checks
- Validation (ownership, memory type)
- Transaction boundary (commit/rollback)
- Converting ORM models to response schemas
"""
class MemoriesService:
_repository: MemoriesRepositoryLike
_session: AsyncSession
def __init__(self, repository: MemoriesRepositoryLike) -> None:
def __init__(
self,
repository: MemoriesRepositoryLike,
session: AsyncSession,
current_user: CurrentUser | None,
) -> None:
super().__init__(current_user=current_user)
self._repository = repository
self._session = session
def _to_context(self, memory: Memory) -> MemoryContext:
return MemoryContext(
memory_type=MemoryType(memory.memory_type.value),
source=MemorySource(memory.source.value),
title=memory.title,
content=memory.content,
created_at=memory.created_at,
updated_at=memory.updated_at,
async def get_user_memory(self) -> UserMemoryContent | None:
user_id = self.require_user_id()
try:
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
return None
return self._parse_user_content(memory)
async def get_work_memory(self) -> WorkProfileContent | None:
user_id = self.require_user_id()
try:
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
return None
return self._parse_work_content(memory)
async def get_all_memories(self) -> dict:
user_id = self.require_user_id()
try:
user_memory = await self._repository.get_user_memory_for_owner(
owner_id=user_id
)
work_memory = await self._repository.get_work_memory_for_owner(
owner_id=user_id
)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
return {
"user_memory": self._parse_user_content(user_memory)
if user_memory
else None,
"work_memory": self._parse_work_content(work_memory)
if work_memory
else None,
}
async def update_user_memory(
self,
*,
content: UserMemoryContent,
) -> UserMemoryContent:
user_id = self.require_user_id()
try:
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
try:
memory = await self._repository.create(
owner_id=user_id,
memory_type=MemoryType.USER,
content=content.model_dump(),
)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(
status_code=503, detail="Memories service unavailable"
)
else:
try:
memory = await self._repository.update_content(
memory=memory,
content=content.model_dump(),
)
await self._session.commit()
await self._session.refresh(memory)
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(
status_code=503, detail="Memories service unavailable"
)
logger.info(
"user_memory_updated",
extra={"user_id": str(user_id)},
)
async def get_user_memories(self, *, owner_id: UUID) -> MemoryListResponse:
memories = await self._repository.get_active_memories(owner_id=owner_id)
user_memories = [
self._to_context(memory)
for memory in memories
if memory.memory_type.value == "user"
]
return MemoryListResponse(
owner_id=owner_id, memories=user_memories, total=len(user_memories)
return self._parse_user_content(memory)
async def update_work_memory(
self,
*,
content: WorkProfileContent,
) -> WorkProfileContent:
user_id = self.require_user_id()
try:
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
try:
memory = await self._repository.create(
owner_id=user_id,
memory_type=MemoryType.WORK,
content=content.model_dump(),
)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(
status_code=503, detail="Memories service unavailable"
)
else:
try:
memory = await self._repository.update_content(
memory=memory,
content=content.model_dump(),
)
await self._session.commit()
await self._session.refresh(memory)
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(
status_code=503, detail="Memories service unavailable"
)
logger.info(
"work_memory_updated",
extra={"user_id": str(user_id)},
)
async def get_agent_memories(self, *, owner_id: UUID) -> MemoryListResponse:
memories = await self._repository.get_active_memories(owner_id=owner_id)
agent_memories = [
self._to_context(memory)
for memory in memories
if memory.memory_type.value == "work"
]
return MemoryListResponse(
owner_id=owner_id, memories=agent_memories, total=len(agent_memories)
return self._parse_work_content(memory)
async def patch_user_memory(
self, *, content: UserMemoryContent | None = None
) -> UserMemoryContent:
user_id = self.require_user_id()
try:
memory = await self._repository.get_user_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
raise HTTPException(status_code=404, detail="User memory not found")
try:
update_data: dict = {}
if content is not None:
existing_content = memory.content or {}
merged = content.model_dump()
existing_content.update(
{k: v for k, v in merged.items() if v is not None}
)
update_data["content"] = existing_content
if update_data:
memory = await self._repository.update_content(
memory=memory,
content=update_data.get("content"),
)
await self._session.commit()
await self._session.refresh(memory)
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(status_code=503, detail="Memories service unavailable")
logger.info(
"user_memory_patched",
extra={"user_id": str(user_id)},
)
async def get_all_memories(self, *, owner_id: UUID) -> MemoryListResponse:
memories = await self._repository.get_active_memories(owner_id=owner_id)
memory_contexts = [self._to_context(memory) for memory in memories]
return MemoryListResponse(
owner_id=owner_id, memories=memory_contexts, total=len(memory_contexts)
return self._parse_user_content(memory)
async def patch_work_memory(
self, *, content: WorkProfileContent | None = None
) -> WorkProfileContent:
user_id = self.require_user_id()
try:
memory = await self._repository.get_work_memory_for_owner(owner_id=user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
if memory is None:
raise HTTPException(status_code=404, detail="Work memory not found")
try:
update_data: dict = {}
if content is not None:
existing_content = memory.content or {}
merged = content.model_dump()
existing_content.update(
{k: v for k, v in merged.items() if v is not None}
)
update_data["content"] = existing_content
if update_data:
memory = await self._repository.update_content(
memory=memory,
content=update_data.get("content"),
)
await self._session.commit()
await self._session.refresh(memory)
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(status_code=503, detail="Memories service unavailable")
logger.info(
"work_memory_patched",
extra={"user_id": str(user_id)},
)
return self._parse_work_content(memory)
def _parse_user_content(self, memory: Memory) -> UserMemoryContent:
content_dict = memory.content or {}
return UserMemoryContent.model_validate(content_dict)
def _parse_work_content(self, memory: Memory) -> WorkProfileContent:
content_dict = memory.content or {}
return WorkProfileContent.model_validate(content_dict)
async def get_memory_model(self, *, memory_type: MemoryType) -> Memory | None:
user_id = self.require_user_id()
try:
return await self._repository.get_by_type_for_owner(
owner_id=user_id,
memory_type=memory_type,
)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Memories service unavailable")
+2 -1
View File
@@ -7,6 +7,7 @@ from v1.app.router import router as app_router
from v1.auth.router import router as auth_router
from v1.friendships.router import router as friendships_router
from v1.inbox_messages.router import router as inbox_messages_router
from v1.memories.router import router as memories_router
from v1.schedule_items.router import router as schedule_items_router
from v1.todo.router import router as todo_router
from v1.users.router import router as users_router
@@ -16,8 +17,8 @@ router = APIRouter(prefix="/api/v1")
router.include_router(app_router)
router.include_router(auth_router)
router.include_router(agent_router)
router.include_router(agent_router)
router.include_router(friendships_router)
router.include_router(memories_router)
router.include_router(users_router)
router.include_router(schedule_items_router)
router.include_router(inbox_messages_router)