Files
eryao/backend/src/v1/notifications/repository.py
T

161 lines
5.5 KiB
Python
Raw Normal View History

2026-04-10 18:50:08 +08:00
from __future__ import annotations
from datetime import datetime, timezone
2026-04-10 18:50:08 +08:00
from uuid import UUID
from sqlalchemy.dialects.postgresql import insert
2026-04-10 18:50:08 +08:00
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from models.notification import Notification
from models.user_notification import UserNotification
from schemas.enums import NotificationTargetMode
2026-04-10 18:50:08 +08:00
class NotificationRepository:
    """Data-access layer for notifications as seen by individual users.

    Wraps the ``AsyncSession`` passed to the constructor. A notification is
    treated as *visible* when ``Notification.status == "published"`` and
    ``Notification.deleted_at`` is NULL; listing, counting, fetching,
    bulk-marking and linking all restrict themselves to visible
    notifications. Write methods flush but never commit -- call
    :meth:`commit` for that.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    @staticmethod
    def _visible_notification_criteria():
        """Return the WHERE criteria selecting visible (published, not deleted) notifications."""
        return (
            Notification.status == "published",
            Notification.deleted_at.is_(None),
        )

    async def list_notifications(
        self,
        *,
        user_id: UUID,
        limit: int = 20,
        cursor: datetime | None = None,
        cursor_id: UUID | None = None,
    ) -> list[tuple[UserNotification, Notification]]:
        """Return one page of the user's visible notifications, newest first.

        Up to ``limit + 1`` rows are fetched (presumably so the caller can
        detect whether another page exists -- confirm against caller).

        Args:
            user_id: Owner of the ``UserNotification`` rows to list.
            limit: Page size; the query uses ``LIMIT limit + 1``.
            cursor: If given, only rows with ``created_at`` older than this
                (or equal to it, see ``cursor_id``) are returned.
            cursor_id: Optional tie-breaker: the ``UserNotification.id`` of
                the last row on the previous page. When supplied together
                with ``cursor``, rows whose ``created_at`` equals ``cursor``
                but whose id sorts lower are also returned, so rows sharing
                a timestamp are not skipped between pages. When omitted, the
                filter is ``created_at < cursor`` only (previous behavior).

        Returns:
            ``(UserNotification, Notification)`` pairs ordered by
            ``UserNotification.created_at`` descending, then by
            ``UserNotification.id`` descending for a deterministic order.
        """
        stmt = (
            select(UserNotification, Notification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.user_id == user_id,
                *self._visible_notification_criteria(),
            )
        )
        if cursor is not None:
            if cursor_id is None:
                stmt = stmt.where(UserNotification.created_at < cursor)
            else:
                # Keyset pagination on (created_at, id), matching the ORDER BY below.
                stmt = stmt.where(
                    (UserNotification.created_at < cursor)
                    | (
                        (UserNotification.created_at == cursor)
                        & (UserNotification.id < cursor_id)
                    )
                )
        stmt = stmt.order_by(
            UserNotification.created_at.desc(),
            UserNotification.id.desc(),
        ).limit(limit + 1)
        rows = (await self._session.execute(stmt)).all()
        return [(user_notification, notification) for user_notification, notification in rows]

    async def get_unread_count(self, *, user_id: UUID) -> int:
        """Return how many of the user's visible notifications are still unread."""
        stmt = (
            select(func.count())
            .select_from(UserNotification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.user_id == user_id,
                UserNotification.is_read.is_(False),
                *self._visible_notification_criteria(),
            )
        )
        return (await self._session.execute(stmt)).scalar_one()

    async def get_user_notification(
        self,
        *,
        user_notification_id: UUID,
        user_id: UUID,
    ) -> tuple[UserNotification, Notification] | None:
        """Fetch one of the user's visible notifications.

        Returns:
            The ``(UserNotification, Notification)`` pair, or None when no row
            with ``user_notification_id`` belongs to ``user_id`` or the
            parent notification is not visible.
        """
        stmt = (
            select(UserNotification, Notification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.id == user_notification_id,
                UserNotification.user_id == user_id,
                *self._visible_notification_criteria(),
            )
        )
        row = (await self._session.execute(stmt)).first()
        if row is None:
            return None
        return (row[0], row[1])

    async def mark_read(self, *, user_notification_id: UUID, user_id: UUID) -> bool:
        """Mark one of the user's notifications as read.

        Issued as a single conditional ``UPDATE ... RETURNING`` so that when
        two requests race, only one of them observes the unread -> read
        transition (the previous SELECT-then-modify let both return True).

        NOTE(review): unlike the other methods this does not check that the
        parent notification is published / not deleted -- confirm intended.

        Returns:
            True if this call changed the row from unread to read; False if
            no row with ``user_notification_id`` belongs to ``user_id`` or it
            was already read.
        """
        stmt = (
            update(UserNotification)
            .where(
                UserNotification.id == user_notification_id,
                UserNotification.user_id == user_id,
                UserNotification.is_read.is_(False),
            )
            # Database clock, the same source mark_all_read uses.
            .values(is_read=True, read_at=func.now())
            .returning(UserNotification.id)
        )
        updated_id = (await self._session.execute(stmt)).scalar_one_or_none()
        await self._session.flush()
        return updated_id is not None

    async def mark_all_read(self, *, user_id: UUID) -> int:
        """Mark every unread, visible notification of the user as read.

        Done as one ``UPDATE ... RETURNING`` instead of SELECT ids + UPDATE
        ``IN (ids)``: the statement re-checks ``is_read`` so rows marked read
        concurrently keep their original ``read_at``, the returned count is
        exactly the rows this call changed, and no unbounded id list is sent.

        Returns:
            Number of user-notification rows flipped from unread to read.
        """
        visible_notification_ids = select(Notification.id).where(
            *self._visible_notification_criteria()
        )
        stmt = (
            update(UserNotification)
            .where(
                UserNotification.user_id == user_id,
                UserNotification.is_read.is_(False),
                UserNotification.notification_id.in_(visible_notification_ids),
            )
            .values(is_read=True, read_at=func.now())
            .returning(UserNotification.id)
        )
        result = await self._session.execute(stmt)
        count = len(result.scalars().all())
        await self._session.flush()
        return count

    async def commit(self) -> None:
        """Commit the current transaction of the underlying session."""
        await self._session.commit()

    async def link_notifications_for_registered_user(
        self, *, user_id: UUID, is_first_registration: bool
    ) -> int:
        """Create ``UserNotification`` rows for a newly registered user.

        Links every visible notification whose ``target_mode`` is ALL_USERS,
        plus NEW_USERS notifications when ``is_first_registration`` is true.
        Pairs that already exist are skipped via
        ``ON CONFLICT (user_id, notification_id) DO NOTHING``.

        NOTE(review): every linked notification adds two bind parameters to
        the multi-row INSERT; a very large number of matching notifications
        could hit the driver's parameter limit -- consider batching if that
        is realistic.

        Returns:
            Number of rows actually inserted (conflicting pairs excluded).
        """
        target_modes: list[NotificationTargetMode] = [NotificationTargetMode.ALL_USERS]
        if is_first_registration:
            target_modes.append(NotificationTargetMode.NEW_USERS)

        id_stmt = select(Notification.id).where(
            *self._visible_notification_criteria(),
            Notification.target_mode.in_(target_modes),
        )
        notification_ids = (await self._session.execute(id_stmt)).scalars().all()
        if not notification_ids:
            return 0

        stmt = (
            insert(UserNotification)
            .values(
                [
                    {"user_id": user_id, "notification_id": nid}
                    for nid in notification_ids
                ]
            )
            .on_conflict_do_nothing(index_elements=["user_id", "notification_id"])
            # RETURNING only yields rows that were inserted, not conflicts.
            .returning(UserNotification.id)
        )
        result = await self._session.execute(stmt)
        inserted = len(result.scalars().all())
        await self._session.flush()
        return inserted