166 lines
4.8 KiB
Python
166 lines
4.8 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Protocol
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from core.db.base_repository import BaseRepository
|
|
from core.logging import get_logger
|
|
from models.memories import Memory
|
|
from schemas.enums import MemoryStatus, MemoryType
|
|
|
|
if TYPE_CHECKING:
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
# Module-level logger used by the memories repository code in this module.
logger = get_logger("v1.memories.repository")
|
|
|
|
|
|
class MemoriesRepositoryLike(Protocol):
    """Structural interface for creating, reading and updating ``Memory`` rows.

    Any object providing these coroutine methods with matching signatures
    satisfies the protocol; no inheritance is required.
    """

    async def create(
        self, *, owner_id: UUID, memory_type: MemoryType, content: dict
    ) -> Memory:
        """Persist a new memory of ``memory_type`` for ``owner_id`` and return it."""

    async def get_by_type_for_owner(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
    ) -> Memory | None:
        """Return the owner's memory of the given type, or ``None`` if absent."""

    async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's user memory, or ``None`` if absent."""

    async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's work memory, or ``None`` if absent."""

    async def create_if_absent(
        self, *, owner_id: UUID, memory_type: MemoryType, content: dict
    ) -> bool:
        """Create the memory only when none exists; report whether one was inserted."""

    async def update_content(
        self, memory: Memory, content: dict | None = None
    ) -> Memory:
        """Optionally replace ``memory``'s content and return the same memory."""
|
|
|
|
|
|
class SQLAlchemyMemoriesRepository(BaseRepository[Memory]):
    """Async SQLAlchemy implementation of the memories repository.

    Methods flush pending changes but never commit; the transaction boundary
    belongs to whoever owns ``session``. Any ``SQLAlchemyError`` is logged
    with identifying context and then re-raised unchanged.
    """

    # Typed handle to the async session this repository operates on.
    _session: AsyncSession

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to ``session`` for the ``Memory`` model."""
        super().__init__(session=session, model=Memory)
        # NOTE(review): BaseRepository may already keep the session; this
        # assignment guarantees the typed attribute exists either way.
        self._session = session

    async def create(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> Memory:
        """Add a new ACTIVE memory for ``owner_id`` and flush it.

        Args:
            owner_id: Owner of the new memory.
            memory_type: Kind of memory to create.
            content: Payload stored on the memory.

        Returns:
            The newly added ``Memory`` instance (flushed, not committed).

        Raises:
            SQLAlchemyError: If the flush fails. NOTE(review): presumably this
                includes an existing (owner_id, memory_type) row, given the
                unique conflict target used by ``create_if_absent``.
        """
        try:
            memory = Memory(
                owner_id=owner_id,
                memory_type=memory_type,
                content=content,
                status=MemoryStatus.ACTIVE,
            )
            self._session.add(memory)
            await self._session.flush()
            return memory
        except SQLAlchemyError:
            logger.exception(
                "Failed to create memory",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def get_by_type_for_owner(
        self, *, owner_id: UUID, memory_type: MemoryType
    ) -> Memory | None:
        """Return the owner's ACTIVE memory of ``memory_type``, or ``None``.

        Raises:
            SQLAlchemyError: If the query fails (including more than one
                matching row when ``scalar_one_or_none`` is used).
        """
        try:
            stmt = (
                select(Memory)
                .where(Memory.owner_id == owner_id)
                .where(Memory.memory_type == memory_type)
                .where(Memory.status == MemoryStatus.ACTIVE)
            )
            result = await self._session.execute(stmt)
            # Real results expose scalar_one_or_none; the fallback below only
            # needs scalars().all() — presumably to accommodate lightweight
            # test doubles (TODO confirm before removing).
            if hasattr(result, "scalar_one_or_none"):
                return result.scalar_one_or_none()
            scalars = result.scalars()
            rows = list(scalars.all())
            return rows[0] if rows else None
        except SQLAlchemyError:
            logger.exception(
                "Failed to get memory by type for owner",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's ACTIVE ``MemoryType.USER`` memory, or ``None``."""
        return await self.get_by_type_for_owner(
            owner_id=owner_id, memory_type=MemoryType.USER
        )

    async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's ACTIVE ``MemoryType.WORK`` memory, or ``None``."""
        return await self.get_by_type_for_owner(
            owner_id=owner_id, memory_type=MemoryType.WORK
        )

    async def create_if_absent(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> bool:
        """Insert an ACTIVE memory unless one already exists for the pair.

        Uses PostgreSQL ``INSERT ... ON CONFLICT (owner_id, memory_type) DO
        NOTHING RETURNING id``, so no separate existence check is needed.

        NOTE(review): the conflict target ignores ``status``. An existing
        non-ACTIVE row therefore blocks the insert and this returns ``False``,
        even though ``get_by_type_for_owner`` (ACTIVE only) would return
        ``None``. Confirm this is intended.

        Returns:
            ``True`` if a new row was inserted, ``False`` on conflict.

        Raises:
            SQLAlchemyError: If the statement or flush fails.
        """
        try:
            stmt = (
                insert(Memory)
                .values(
                    owner_id=owner_id,
                    memory_type=memory_type,
                    content=content,
                    status=MemoryStatus.ACTIVE,
                )
                .on_conflict_do_nothing(index_elements=["owner_id", "memory_type"])
                .returning(Memory.id)
            )
            # RETURNING yields no row when the conflict clause skipped the insert.
            inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
            await self._session.flush()
            return inserted_id is not None
        except SQLAlchemyError:
            logger.exception(
                "Failed to create memory if absent",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def update_content(
        self,
        memory: Memory,
        content: dict | None = None,
    ) -> Memory:
        """Replace ``memory.content`` when given, then flush pending changes.

        Args:
            memory: The memory to update; presumably attached to this
                repository's session.
            content: New payload. ``None`` leaves the content untouched and
                only flushes.

        Returns:
            The same ``memory`` instance.

        Raises:
            SQLAlchemyError: If marking the attribute dirty or the flush fails.
        """
        # Local import keeps this fix self-contained within the method.
        from sqlalchemy.orm.attributes import flag_modified

        try:
            if content is not None:
                memory.content = content
                # If the caller mutated memory.content in place and passed the
                # same dict back, the new value compares equal to the committed
                # one and the flush would emit no UPDATE (unless the column is
                # wrapped in MutableDict). Explicitly mark the attribute dirty.
                flag_modified(memory, "content")

            await self._session.flush()
            return memory
        except SQLAlchemyError:
            logger.exception(
                "Failed to update memory content",
                memory_id=str(memory.id),
            )
            raise
|