c79c773d67
- Add NotificationTargetMode enum (new_users/exist_users/all_users/user_ids)
- Create Alembic migrations: drop duplicate indexes, add target_mode column
- Merge register-notifications.sh into dev-migrate.sh as a sync-notifications subcommand
- Shorten notification config path: static/notification/notifications -> static/notifications
- Update registration flow to dispatch notifications by target_mode
- Add is_first_registration to RegisterBonusResult for first-time user detection
- Remove dead code: link_published_notifications_to_user
- Update welcome_points.yaml to target new_users only
- Add 44 unit tests and 1 integration test, all passing
161 lines
5.5 KiB
Python
161 lines
5.5 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
from uuid import UUID

from sqlalchemy import and_, func, or_, select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from models.notification import Notification
from models.user_notification import UserNotification
from schemas.enums import NotificationTargetMode
|
|
|
|
|
|
class NotificationRepository:
    """Data access for notifications and their per-user delivery rows.

    A ``UserNotification`` row links one user to one ``Notification``. All
    read paths only surface notifications that are *visible*: ``status`` is
    ``"published"`` and ``deleted_at`` is unset.

    Mutating methods ``flush`` the session but never commit; call
    :meth:`commit` to persist changes.
    """

    # Notification.status value that makes a notification eligible for users.
    _PUBLISHED_STATUS = "published"

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to an already-open async session."""
        self._session = session

    @classmethod
    def _visible_notification_criteria(cls) -> tuple:
        """Return the WHERE clauses that make a ``Notification`` user-visible."""
        return (
            Notification.status == cls._PUBLISHED_STATUS,
            Notification.deleted_at.is_(None),
        )

    async def list_notifications(
        self,
        *,
        user_id: UUID,
        limit: int = 20,
        cursor: datetime | None = None,
        cursor_id: UUID | None = None,
    ) -> list[tuple[UserNotification, Notification]]:
        """Return one page of a user's visible notifications, newest first.

        Args:
            user_id: Owner of the ``UserNotification`` rows.
            limit: Page size. ``limit + 1`` rows are fetched so the caller can
                tell whether another page exists.
            cursor: ``created_at`` of the last row of the previous page. When
                given, only rows strictly older than it are returned.
            cursor_id: Optional ``id`` of that same last row. When supplied
                together with ``cursor``, rows sharing the cursor timestamp but
                ordered after it (smaller ``id``) are also returned, so rows
                with tied ``created_at`` values are not skipped at a page
                boundary. Ignored when ``cursor`` is ``None``.

        Returns:
            Up to ``limit + 1`` ``(UserNotification, Notification)`` pairs,
            ordered by ``UserNotification.created_at`` then ``id``, descending.
        """
        stmt = (
            select(UserNotification, Notification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.user_id == user_id,
                *self._visible_notification_criteria(),
            )
            # ``id`` breaks ties so the order is deterministic when several rows
            # share a ``created_at`` (possible for rows inserted together, as
            # link_notifications_for_registered_user does).
            .order_by(
                UserNotification.created_at.desc(),
                UserNotification.id.desc(),
            )
            .limit(limit + 1)
        )
        if cursor is not None:
            if cursor_id is None:
                stmt = stmt.where(UserNotification.created_at < cursor)
            else:
                # Keyset predicate matching the (created_at, id) DESC ordering.
                stmt = stmt.where(
                    or_(
                        UserNotification.created_at < cursor,
                        and_(
                            UserNotification.created_at == cursor,
                            UserNotification.id < cursor_id,
                        ),
                    )
                )

        rows = (await self._session.execute(stmt)).all()
        return [(row[0], row[1]) for row in rows]

    async def get_unread_count(self, *, user_id: UUID) -> int:
        """Count the user's unread notifications that are still visible."""
        stmt = (
            select(func.count())
            .select_from(UserNotification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.user_id == user_id,
                UserNotification.is_read.is_(False),
                *self._visible_notification_criteria(),
            )
        )
        return (await self._session.execute(stmt)).scalar_one()

    async def get_user_notification(
        self,
        *,
        user_notification_id: UUID,
        user_id: UUID,
    ) -> tuple[UserNotification, Notification] | None:
        """Fetch one of the user's notifications by ``UserNotification.id``.

        Returns:
            The ``(UserNotification, Notification)`` pair, or ``None`` when the
            row does not exist, belongs to another user, or its notification is
            not visible.
        """
        stmt = (
            select(UserNotification, Notification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.id == user_notification_id,
                UserNotification.user_id == user_id,
                *self._visible_notification_criteria(),
            )
        )
        row = (await self._session.execute(stmt)).first()
        return None if row is None else (row[0], row[1])

    async def mark_read(self, *, user_notification_id: UUID, user_id: UUID) -> bool:
        """Mark a single unread ``UserNotification`` owned by ``user_id`` as read.

        This runs as one conditional UPDATE, so two concurrent calls for the
        same row cannot both report success.

        Returns:
            ``True`` if a row was updated; ``False`` if no unread row with that
            id exists for the user.
        """
        # NOTE(review): unlike the read paths, this does not require the parent
        # notification to be published / not deleted -- confirm this is intended.
        stmt = (
            update(UserNotification)
            .where(
                UserNotification.id == user_notification_id,
                UserNotification.user_id == user_id,
                UserNotification.is_read.is_(False),
            )
            .values(is_read=True, read_at=datetime.now(timezone.utc))
            # Keep any already-loaded UserNotification objects in sync.
            .execution_options(synchronize_session="fetch")
        )
        result = await self._session.execute(stmt)
        await self._session.flush()
        return result.rowcount > 0

    async def mark_all_read(self, *, user_id: UUID) -> int:
        """Mark every unread, visible notification of the user as read.

        A single UPDATE both selects and modifies the rows, so the returned
        count equals the rows actually changed even under concurrent updates.

        Returns:
            Number of ``UserNotification`` rows marked read.
        """
        visible_notification_ids = select(Notification.id).where(
            *self._visible_notification_criteria()
        )
        stmt = (
            update(UserNotification)
            .where(
                UserNotification.user_id == user_id,
                UserNotification.is_read.is_(False),
                UserNotification.notification_id.in_(visible_notification_ids),
            )
            .values(is_read=True, read_at=func.now())
            # The subquery criterion cannot be evaluated in Python, so fetch the
            # affected primary keys to keep loaded objects in sync.
            .execution_options(synchronize_session="fetch")
        )
        result = await self._session.execute(stmt)
        await self._session.flush()
        return result.rowcount

    async def commit(self) -> None:
        """Commit the underlying session's current transaction."""
        await self._session.commit()

    async def link_notifications_for_registered_user(
        self, *, user_id: UUID, is_first_registration: bool
    ) -> int:
        """Create ``UserNotification`` rows for a user at registration time.

        First-time registrants receive visible notifications whose target mode
        is ``NEW_USERS`` or ``ALL_USERS``; other registrations receive only
        ``ALL_USERS``. Existing ``(user_id, notification_id)`` links are left
        untouched.

        Returns:
            Number of newly inserted ``UserNotification`` rows.
        """
        target_modes: list[NotificationTargetMode] = (
            [NotificationTargetMode.NEW_USERS, NotificationTargetMode.ALL_USERS]
            if is_first_registration
            else [NotificationTargetMode.ALL_USERS]
        )
        ids_stmt = select(Notification.id).where(
            *self._visible_notification_criteria(),
            Notification.target_mode.in_(target_modes),
        )
        notification_ids = (await self._session.execute(ids_stmt)).scalars().all()
        if not notification_ids:
            # Nothing to link; also avoids issuing an INSERT with no VALUES rows.
            return 0

        stmt = (
            insert(UserNotification)
            .values(
                [
                    {"user_id": user_id, "notification_id": nid}
                    for nid in notification_ids
                ]
            )
            .on_conflict_do_nothing(index_elements=["user_id", "notification_id"])
            # With ON CONFLICT DO NOTHING, RETURNING yields only inserted rows,
            # so already-linked notifications are not counted.
            .returning(UserNotification.id)
        )
        result = await self._session.execute(stmt)
        await self._session.flush()
        return len(result.scalars().all())