feat: 重构 memory 系统,支持 user memory 和 work memory 分离
This commit is contained in:
@@ -3,29 +3,162 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.memories import Memory
|
||||
from core.logging import get_logger
|
||||
from models.memories import Memory, MemoryStatus, MemoryType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.memories.repository")
|
||||
|
||||
|
||||
class MemoriesRepositoryLike(Protocol):
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]: ...
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> Memory: ...
|
||||
|
||||
async def get_by_type_for_owner(
|
||||
self, *, owner_id: UUID, memory_type: MemoryType
|
||||
) -> Memory | None: ...
|
||||
|
||||
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
|
||||
|
||||
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None: ...
|
||||
|
||||
async def create_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool: ...
|
||||
|
||||
async def update_content(
|
||||
self,
|
||||
memory: Memory,
|
||||
content: dict | None = None,
|
||||
) -> Memory: ...
|
||||
|
||||
|
||||
class MemoriesRepository(BaseRepository[Memory]):
|
||||
class SQLAlchemyMemoriesRepository(BaseRepository[Memory]):
|
||||
_session: AsyncSession
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=Memory)
|
||||
self._session = session
|
||||
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]:
|
||||
stmt = (
|
||||
select(Memory)
|
||||
.where(Memory.owner_id == owner_id)
|
||||
.where(Memory.status == "active")
|
||||
.order_by(Memory.created_at.desc())
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> Memory:
|
||||
try:
|
||||
memory = Memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
status=MemoryStatus.ACTIVE,
|
||||
)
|
||||
self._session.add(memory)
|
||||
await self._session.flush()
|
||||
return memory
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to create memory",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_type_for_owner(
|
||||
self, *, owner_id: UUID, memory_type: MemoryType
|
||||
) -> Memory | None:
|
||||
try:
|
||||
stmt = (
|
||||
select(Memory)
|
||||
.where(Memory.owner_id == owner_id)
|
||||
.where(Memory.memory_type == memory_type)
|
||||
.where(Memory.status == MemoryStatus.ACTIVE)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
if hasattr(result, "scalar_one_or_none"):
|
||||
return result.scalar_one_or_none()
|
||||
scalars = result.scalars()
|
||||
rows = list(scalars.all())
|
||||
return rows[0] if rows else None
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to get memory by type for owner",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
|
||||
return await self.get_by_type_for_owner(
|
||||
owner_id=owner_id, memory_type=MemoryType.USER
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
|
||||
return await self.get_by_type_for_owner(
|
||||
owner_id=owner_id, memory_type=MemoryType.WORK
|
||||
)
|
||||
|
||||
async def create_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool:
|
||||
try:
|
||||
stmt = (
|
||||
insert(Memory)
|
||||
.values(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
status=MemoryStatus.ACTIVE,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["owner_id", "memory_type"])
|
||||
.returning(Memory.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
return inserted_id is not None
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to create memory if absent",
|
||||
owner_id=str(owner_id),
|
||||
memory_type=memory_type.value,
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_content(
|
||||
self,
|
||||
memory: Memory,
|
||||
content: dict | None = None,
|
||||
) -> Memory:
|
||||
try:
|
||||
if content is not None:
|
||||
memory.content = content
|
||||
|
||||
await self._session.flush()
|
||||
return memory
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to update memory content",
|
||||
memory_id=str(memory.id),
|
||||
)
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user