Files
social-app/backend/src/v1/memories/repository.py
T

166 lines
4.8 KiB
Python

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.attributes import flag_modified

from core.db.base_repository import BaseRepository
from core.logging import get_logger
from models.memories import Memory
from schemas.enums import MemoryStatus, MemoryType

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession
# Module-level logger for this repository module. The except-handlers pass
# structured context (owner_id, memory_type, memory_id) as keyword arguments.
# NOTE(review): assumes core.logging.get_logger returns a structured logger
# that accepts such kwargs — confirm.
logger = get_logger("v1.memories.repository")
class MemoriesRepositoryLike(Protocol):
    """Structural (duck-typed) interface for memory persistence.

    Any object providing these async methods with matching signatures can be
    used wherever a memories repository is expected; no subclassing is needed.
    """

    async def create(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> Memory:
        """Create and return a new memory of ``memory_type`` for ``owner_id``."""

    async def get_by_type_for_owner(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
    ) -> Memory | None:
        """Return the owner's memory of ``memory_type``, or ``None`` if absent."""

    async def get_user_memory_for_owner(
        self,
        *,
        owner_id: UUID,
    ) -> Memory | None:
        """Return the owner's ``MemoryType.USER`` memory, or ``None``."""

    async def get_work_memory_for_owner(
        self,
        *,
        owner_id: UUID,
    ) -> Memory | None:
        """Return the owner's ``MemoryType.WORK`` memory, or ``None``."""

    async def create_if_absent(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> bool:
        """Create the memory only if none exists; return whether one was created."""

    async def update_content(
        self,
        memory: Memory,
        content: dict | None = None,
    ) -> Memory:
        """Replace ``memory``'s content when ``content`` is given; return ``memory``."""
class SQLAlchemyMemoriesRepository(BaseRepository[Memory]):
    """Async SQLAlchemy (PostgreSQL) repository for :class:`Memory` rows.

    Structurally satisfies ``MemoriesRepositoryLike``. Methods only ``flush()``
    pending work and never commit, so the caller's unit of work controls the
    transaction. Every ``SQLAlchemyError`` is logged with identifying context
    and re-raised unchanged.
    """

    _session: AsyncSession

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to ``session`` (also handed to ``BaseRepository``)."""
        super().__init__(session=session, model=Memory)
        # Locally typed handle used by every method below.
        # NOTE(review): BaseRepository may already store the session — confirm.
        self._session = session

    async def create(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> Memory:
        """Add a new ACTIVE memory for ``owner_id`` and flush it.

        Args:
            owner_id: Owner of the new memory.
            memory_type: Kind of memory to create.
            content: Payload stored on ``Memory.content``.

        Returns:
            The flushed (not committed) ``Memory`` instance.

        Raises:
            SQLAlchemyError: If the flush fails (logged, then re-raised). Unlike
                :meth:`create_if_absent`, no conflict handling is done here.
        """
        try:
            memory = Memory(
                owner_id=owner_id,
                memory_type=memory_type,
                content=content,
                status=MemoryStatus.ACTIVE,
            )
            self._session.add(memory)
            # Flush (not commit): the INSERT runs inside the caller's
            # transaction and DB errors surface here rather than at commit.
            await self._session.flush()
            return memory
        except SQLAlchemyError:
            logger.exception(
                "Failed to create memory",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def get_by_type_for_owner(
        self, *, owner_id: UUID, memory_type: MemoryType
    ) -> Memory | None:
        """Return the ACTIVE memory of ``memory_type`` owned by ``owner_id``.

        Args:
            owner_id: Owner whose memory is looked up.
            memory_type: Kind of memory to fetch.

        Returns:
            The matching ``Memory``, or ``None`` if there is none.

        Raises:
            SQLAlchemyError: On query failure (logged, then re-raised). With a
                real SQLAlchemy result this includes ``MultipleResultsFound``
                when more than one ACTIVE row matches.
        """
        try:
            stmt = (
                select(Memory)
                .where(Memory.owner_id == owner_id)
                .where(Memory.memory_type == memory_type)
                .where(Memory.status == MemoryStatus.ACTIVE)
            )
            result = await self._session.execute(stmt)
            if hasattr(result, "scalar_one_or_none"):
                return result.scalar_one_or_none()
            # A real SQLAlchemy Result always has scalar_one_or_none(); this
            # fallback presumably serves test doubles that only expose
            # .scalars() — confirm before removing.
            rows = list(result.scalars().all())
            return rows[0] if rows else None
        except SQLAlchemyError:
            logger.exception(
                "Failed to get memory by type for owner",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def get_user_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's ACTIVE ``MemoryType.USER`` memory, or ``None``."""
        return await self.get_by_type_for_owner(
            owner_id=owner_id, memory_type=MemoryType.USER
        )

    async def get_work_memory_for_owner(self, *, owner_id: UUID) -> Memory | None:
        """Return the owner's ACTIVE ``MemoryType.WORK`` memory, or ``None``."""
        return await self.get_by_type_for_owner(
            owner_id=owner_id, memory_type=MemoryType.WORK
        )

    async def create_if_absent(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> bool:
        """Insert an ACTIVE memory unless one exists for (owner_id, memory_type).

        Emits ``INSERT ... ON CONFLICT (owner_id, memory_type) DO NOTHING
        RETURNING id``. RETURNING yields a row only when the insert actually
        happened, so no prior SELECT is needed and concurrent callers cannot
        both insert. This is a Core statement, so no ``Memory`` object is
        added to the session; load it afterwards if you need it.

        Args:
            owner_id: Owner of the memory.
            memory_type: Kind of memory to create.
            content: Payload stored in the ``content`` column.

        Returns:
            ``True`` if a row was inserted, ``False`` if a conflicting row
            already existed.

        Raises:
            SQLAlchemyError: On statement failure (logged, then re-raised).
        """
        # NOTE(review): if the (owner_id, memory_type) unique index also covers
        # non-ACTIVE rows, this returns False while get_by_type_for_owner()
        # returns None for the same pair — confirm that is intended.
        try:
            stmt = (
                insert(Memory)
                .values(
                    owner_id=owner_id,
                    memory_type=memory_type,
                    content=content,
                    status=MemoryStatus.ACTIVE,
                )
                .on_conflict_do_nothing(index_elements=["owner_id", "memory_type"])
                .returning(Memory.id)
            )
            inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
            # The INSERT above has already executed; this flush only pushes any
            # other pending ORM changes in the session.
            await self._session.flush()
            return inserted_id is not None
        except SQLAlchemyError:
            logger.exception(
                "Failed to create memory if absent",
                owner_id=str(owner_id),
                memory_type=memory_type.value,
            )
            raise

    async def update_content(
        self,
        memory: Memory,
        content: dict | None = None,
    ) -> Memory:
        """Replace ``memory.content`` (when ``content`` is given) and flush.

        Args:
            memory: The persistent ``Memory`` to update.
            content: New payload. If ``None``, the content is left as is and
                only the flush happens.

        Returns:
            The same ``memory`` instance.

        Raises:
            SQLAlchemyError: If the flush fails (logged, then re-raised).
        """
        try:
            if content is not None:
                memory.content = content
                # Re-assigning a value equal (or identical) to the loaded one is
                # seen by SQLAlchemy as "no change", and plain JSON/JSONB columns
                # don't track in-place mutation. Without this, a caller that
                # mutated memory.content and passed it back would silently lose
                # the update. If the column already uses MutableDict, this is
                # redundant but harmless.
                flag_modified(memory, "content")
            await self._session.flush()
            return memory
        except SQLAlchemyError:
            logger.exception(
                "Failed to update memory content",
                memory_id=str(memory.id),
            )
            raise