from __future__ import annotations

from typing import TYPE_CHECKING, Protocol
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import SQLAlchemyError

from core.db.base_repository import BaseRepository
from core.logging import get_logger
from models.memories import Memory
from schemas.enums import MemoryStatus, MemoryType

if TYPE_CHECKING:
    # Only needed for annotations; annotations are lazy strings here.
    from sqlalchemy.ext.asyncio import AsyncSession

logger = get_logger("v1.memories.repository")


class MemoriesRepositoryLike(Protocol):
    """Structural interface for memories repositories.

    Being a ``typing.Protocol``, any object that provides these async methods
    with compatible signatures satisfies it; no inheritance is required.
    """

    async def create(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> Memory:
        """Create a memory of ``memory_type`` for ``owner_id`` and return it."""

    async def get_by_type_for_owner(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
    ) -> Memory | None:
        """Return the owner's memory of ``memory_type``, or ``None``."""

    async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's ``MemoryType.USER`` memory, or ``None``."""

    async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's ``MemoryType.WORK`` memory, or ``None``."""

    async def create_if_absent(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> bool:
        """Create the memory only if none exists yet.

        Returns a flag; in ``SQLAlchemyMemoriesRepository`` it is ``True``
        when this call inserted a row and ``False`` otherwise.
        """

    async def update_content(
        self,
        memory: Memory,
        content: dict | None = None,
    ) -> Memory:
        """Update ``memory``'s content (when ``content`` is given) and return it."""
class SQLAlchemyMemoriesRepository(BaseRepository[Memory]):
    """Async SQLAlchemy repository for ``Memory`` rows.

    Methods flush but never commit, so committing or rolling back is left to
    whoever manages the session. Any ``SQLAlchemyError`` raised while talking
    to the database is logged with identifying context and then re-raised.
    """

    _session: AsyncSession

    def __init__(self, session: AsyncSession) -> None:
        super().__init__(session=session, model=Memory)
        # Same session that was handed to BaseRepository above.
        self._session = session

    async def create(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> Memory:
        """Add a new ACTIVE memory for ``owner_id`` and flush it.

        This is a plain ORM insert with no conflict handling. If a unique
        constraint on ``(owner_id, memory_type)`` exists, a duplicate surfaces
        as a database error at flush time. Use ``create_if_absent`` for
        idempotent creation.

        Args:
            owner_id: Owner of the new memory.
            memory_type: Kind of memory to create.
            content: Content stored on the memory.

        Returns:
            The newly added instance. It is flushed (INSERT issued) but not
            committed.

        Raises:
            SQLAlchemyError: if the flush fails; the error is logged first.
        """
        try:
            memory = Memory(
                owner_id=owner_id,
                memory_type=memory_type,
                content=content,
                status=MemoryStatus.ACTIVE,
            )
            self._session.add(memory)
            await self._session.flush()
            return memory
        except SQLAlchemyError:
            logger.exception(
                "Failed to create memory",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def get_by_type_for_owner(
        self, *, owner_id: UUID, memory_type: MemoryType
    ) -> Memory | None:
        """Return the owner's ACTIVE memory of ``memory_type``, or ``None``.

        Non-ACTIVE rows are never returned.

        Raises:
            SQLAlchemyError: on query failure. This includes
                ``MultipleResultsFound`` when more than one ACTIVE row matches
                on the ``scalar_one_or_none`` path. The error is logged first.
        """
        try:
            stmt = (
                select(Memory)
                .where(Memory.owner_id == owner_id)
                .where(Memory.memory_type == memory_type)
                .where(Memory.status == MemoryStatus.ACTIVE)
            )
            result = await self._session.execute(stmt)
            if hasattr(result, "scalar_one_or_none"):
                return result.scalar_one_or_none()
            # NOTE(review): a real SQLAlchemy Result always has
            # scalar_one_or_none, so this fallback presumably exists for test
            # doubles. Unlike the branch above, it returns the first row
            # instead of raising on duplicates. Confirm before removing.
            scalars = result.scalars()
            rows = list(scalars.all())
            return rows[0] if rows else None
        except SQLAlchemyError:
            logger.exception(
                "Failed to get memory by type for owner",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's ACTIVE ``MemoryType.USER`` memory, or ``None``."""
        return await self.get_by_type_for_owner(
            owner_id=owner_id, memory_type=MemoryType.USER
        )

    async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's ACTIVE ``MemoryType.WORK`` memory, or ``None``."""
        return await self.get_by_type_for_owner(
            owner_id=owner_id, memory_type=MemoryType.WORK
        )

    async def create_if_absent(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> bool:
        """Insert an ACTIVE memory unless one conflicts on the unique key.

        Uses PostgreSQL ``INSERT ... ON CONFLICT (owner_id, memory_type) DO
        NOTHING``. The check and the insert are therefore one atomic statement
        rather than a read-then-write. This requires a unique index or
        constraint on exactly those columns.

        Returns:
            ``True`` if this call inserted the row. ``False`` if a conflicting
            row already existed; ``RETURNING`` then yields no row.

        Raises:
            SQLAlchemyError: on statement or flush failure; logged first.
        """
        # NOTE(review): if the unique index is not partial on status, an
        # existing non-ACTIVE row also counts as a conflict, yet
        # get_by_type_for_owner only returns ACTIVE rows. Confirm intended.
        try:
            stmt = (
                insert(Memory)
                .values(
                    owner_id=owner_id,
                    memory_type=memory_type,
                    content=content,
                    status=MemoryStatus.ACTIVE,
                )
                .on_conflict_do_nothing(index_elements=["owner_id", "memory_type"])
                .returning(Memory.id)
            )
            inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
            # The INSERT above already ran. This flush only pushes any other
            # pending ORM changes in the session.
            await self._session.flush()
            return inserted_id is not None
        except SQLAlchemyError:
            logger.exception(
                "Failed to create memory if absent",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def update_content(
        self,
        memory: Memory,
        content: dict | None = None,
    ) -> Memory:
        """Replace ``memory.content`` when ``content`` is given, then flush.

        Args:
            memory: The memory to update, presumably attached to this session.
            content: New content. ``None`` leaves the content untouched but
                still flushes pending changes.

        Returns:
            The same ``memory`` instance.

        Raises:
            SQLAlchemyError: if flagging or flushing fails; logged first.
        """
        # Local import keeps this method self-contained.
        from sqlalchemy.orm.attributes import flag_modified

        try:
            if content is not None:
                memory.content = content
                # SQLAlchemy decides whether to UPDATE by comparing the new
                # value with the value it recorded as committed. If the caller
                # mutated memory.content in place and passed the same dict
                # back, both sides are the same object and compare equal, so
                # the write would be silently dropped. Force the column into
                # the UPDATE.
                flag_modified(memory, "content")
            await self._session.flush()
            return memory
        except SQLAlchemyError:
            logger.exception(
                "Failed to update memory content",
                memory_id=str(memory.id),
            )
            raise