from __future__ import annotations from datetime import datetime, timezone from uuid import UUID from sqlalchemy.dialects.postgresql import insert from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession from models.notification import Notification from models.user_notification import UserNotification from schemas.enums import NotificationTargetMode class NotificationRepository: def __init__(self, session: AsyncSession) -> None: self._session = session async def list_notifications( self, *, user_id: UUID, limit: int = 20, cursor: datetime | None = None, ) -> list[tuple[UserNotification, Notification]]: stmt = ( select(UserNotification, Notification) .join(Notification, UserNotification.notification_id == Notification.id) .where( UserNotification.user_id == user_id, Notification.status == "published", Notification.deleted_at.is_(None), ) .order_by(UserNotification.created_at.desc()) .limit(limit + 1) ) if cursor is not None: stmt = stmt.where(UserNotification.created_at < cursor) rows = (await self._session.execute(stmt)).all() return [(row[0], row[1]) for row in rows] async def get_unread_count(self, *, user_id: UUID) -> int: stmt = ( select(func.count()) .select_from(UserNotification) .join(Notification, UserNotification.notification_id == Notification.id) .where( UserNotification.user_id == user_id, UserNotification.is_read.is_(False), Notification.status == "published", Notification.deleted_at.is_(None), ) ) result = (await self._session.execute(stmt)).scalar_one() return result async def get_user_notification( self, *, user_notification_id: UUID, user_id: UUID, ) -> tuple[UserNotification, Notification] | None: stmt = ( select(UserNotification, Notification) .join(Notification, UserNotification.notification_id == Notification.id) .where( UserNotification.id == user_notification_id, UserNotification.user_id == user_id, Notification.status == "published", Notification.deleted_at.is_(None), ) ) row = (await self._session.execute(stmt)).first() if row is None: return None return (row[0], row[1]) async def mark_read(self, *, user_notification_id: UUID, user_id: UUID) -> bool: stmt = select(UserNotification).where( UserNotification.id == user_notification_id, UserNotification.user_id == user_id, UserNotification.is_read.is_(False), ) un = (await self._session.execute(stmt)).scalar_one_or_none() if un is None: return False un.is_read = True un.read_at = datetime.now(timezone.utc) await self._session.flush() return True async def mark_all_read(self, *, user_id: UUID) -> int: un_ids_stmt = ( select(UserNotification.id) .join(Notification, UserNotification.notification_id == Notification.id) .where( UserNotification.user_id == user_id, UserNotification.is_read.is_(False), Notification.status == "published", Notification.deleted_at.is_(None), ) ) un_ids = list((await self._session.execute(un_ids_stmt)).scalars().all()) if not un_ids: return 0 count = len(un_ids) stmt = ( update(UserNotification) .where(UserNotification.id.in_(un_ids)) .values(is_read=True, read_at=func.now()) ) await self._session.execute(stmt) await self._session.flush() return count async def commit(self) -> None: await self._session.commit() async def link_notifications_for_registered_user( self, *, user_id: UUID, is_first_registration: bool ) -> int: target_modes: list[NotificationTargetMode] if is_first_registration: target_modes = [ NotificationTargetMode.NEW_USERS, NotificationTargetMode.ALL_USERS, ] else: target_modes = [NotificationTargetMode.ALL_USERS] notification_ids = list( ( await self._session.execute( select(Notification.id).where( Notification.status == "published", Notification.deleted_at.is_(None), Notification.target_mode.in_(target_modes), ) ) ) .scalars() .all() ) if not notification_ids: return 0 stmt = ( insert(UserNotification) .values( [ {"user_id": user_id, "notification_id": nid} for nid in notification_ids ] ) .on_conflict_do_nothing(index_elements=["user_id", "notification_id"]) .returning(UserNotification.id) ) result = await self._session.execute(stmt) await self._session.flush() return len(list(result.scalars().all()))