Files
eryao/backend/src/v1/notifications/repository.py
T
qzl c79c773d67 feat(notification): add target_mode enum constraint and merge register-notifications script
- Add NotificationTargetMode enum (new_users/exist_users/all_users/user_ids)
- Create Alembic migrations: drop duplicate indexes, add target_mode column
- Merge register-notifications.sh into dev-migrate.sh sync-notifications subcommand
- Shorten notification config path: static/notification/notifications -> static/notifications
- Update registration flow to dispatch notifications by target_mode
- Add is_first_registration to RegisterBonusResult for first-time user detection
- Remove dead code: link_published_notifications_to_user
- Update welcome_points.yaml to target new_users only
- Add 44 unit tests + 1 integration test, all passing
2026-04-16 17:50:57 +08:00

161 lines
5.5 KiB
Python

from __future__ import annotations

from datetime import datetime, timezone
from uuid import UUID

from sqlalchemy import func, select, tuple_, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from models.notification import Notification
from models.user_notification import UserNotification
from schemas.enums import NotificationTargetMode
class NotificationRepository:
    """Data-access layer for notifications and their per-user links.

    Read paths only surface notifications whose ``status`` is ``"published"``
    and whose ``deleted_at`` is NULL. Write methods flush but do not commit;
    call :meth:`commit` to persist changes.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    @staticmethod
    def _visible_notification_criteria():
        """Return the WHERE criteria that make a ``Notification`` user-visible."""
        return (
            Notification.status == "published",
            Notification.deleted_at.is_(None),
        )

    async def list_notifications(
        self,
        *,
        user_id: UUID,
        limit: int = 20,
        cursor: datetime | None = None,
        cursor_id: UUID | None = None,
    ) -> list[tuple[UserNotification, Notification]]:
        """Return one page of a user's visible notifications, newest first.

        Up to ``limit + 1`` rows are fetched so the caller can detect whether
        another page exists without a separate count query.

        Args:
            user_id: Owner of the ``UserNotification`` rows.
            limit: Page size; one extra row beyond it is fetched.
            cursor: ``created_at`` of the last row of the previous page; only
                rows strictly older than it are returned.
            cursor_id: Optional ``id`` of that same last row. When supplied
                together with ``cursor``, paging compares the
                ``(created_at, id)`` pair so rows sharing a timestamp are
                neither skipped nor repeated. Ignored when ``cursor`` is None.

        Returns:
            ``(UserNotification, Notification)`` pairs ordered by
            ``created_at`` and then ``id``, both descending.
        """
        stmt = (
            select(UserNotification, Notification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.user_id == user_id,
                *self._visible_notification_criteria(),
            )
            # ``id`` breaks ties between rows with identical ``created_at``
            # (possible when several links are inserted together), giving a
            # stable order across pages.
            .order_by(UserNotification.created_at.desc(), UserNotification.id.desc())
            .limit(limit + 1)
        )
        if cursor is not None:
            if cursor_id is not None:
                stmt = stmt.where(
                    tuple_(UserNotification.created_at, UserNotification.id)
                    < (cursor, cursor_id)
                )
            else:
                stmt = stmt.where(UserNotification.created_at < cursor)
        rows = (await self._session.execute(stmt)).all()
        return [(row[0], row[1]) for row in rows]

    async def get_unread_count(self, *, user_id: UUID) -> int:
        """Count the user's unread links whose notification is visible."""
        stmt = (
            select(func.count())
            .select_from(UserNotification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.user_id == user_id,
                UserNotification.is_read.is_(False),
                *self._visible_notification_criteria(),
            )
        )
        return (await self._session.execute(stmt)).scalar_one()

    async def get_user_notification(
        self,
        *,
        user_notification_id: UUID,
        user_id: UUID,
    ) -> tuple[UserNotification, Notification] | None:
        """Fetch one link owned by ``user_id`` together with its notification.

        Returns:
            The ``(UserNotification, Notification)`` pair, or None when the
            link does not exist, belongs to another user, or its notification
            is not visible.
        """
        stmt = (
            select(UserNotification, Notification)
            .join(Notification, UserNotification.notification_id == Notification.id)
            .where(
                UserNotification.id == user_notification_id,
                UserNotification.user_id == user_id,
                *self._visible_notification_criteria(),
            )
        )
        row = (await self._session.execute(stmt)).first()
        return None if row is None else (row[0], row[1])

    async def mark_read(self, *, user_notification_id: UUID, user_id: UUID) -> bool:
        """Mark a single unread link owned by ``user_id`` as read.

        Unlike the read queries, this does not check the parent
        notification's status or deletion state.

        Returns:
            True if a row was updated; False if no unread row owned by
            ``user_id`` matched ``user_notification_id``.
        """
        stmt = select(UserNotification).where(
            UserNotification.id == user_notification_id,
            UserNotification.user_id == user_id,
            UserNotification.is_read.is_(False),
        )
        user_notification = (await self._session.execute(stmt)).scalar_one_or_none()
        if user_notification is None:
            return False
        user_notification.is_read = True
        # Application clock (UTC); mark_all_read uses the database's now().
        user_notification.read_at = datetime.now(timezone.utc)
        await self._session.flush()
        return True

    async def mark_all_read(self, *, user_id: UUID) -> int:
        """Mark every unread, visible notification link of ``user_id`` as read.

        The whole transition is one UPDATE statement. The ``is_read`` check is
        evaluated by the UPDATE itself, so rows marked read concurrently keep
        their existing ``read_at`` and are not counted.

        Returns:
            Number of links this call moved from unread to read.
        """
        visible_notification_ids = select(Notification.id).where(
            *self._visible_notification_criteria()
        )
        stmt = (
            update(UserNotification)
            .where(
                UserNotification.user_id == user_id,
                UserNotification.is_read.is_(False),
                UserNotification.notification_id.in_(visible_notification_ids),
            )
            .values(is_read=True, read_at=func.now())
            .returning(UserNotification.id)
        )
        updated_ids = (await self._session.execute(stmt)).scalars().all()
        await self._session.flush()
        return len(updated_ids)

    async def commit(self) -> None:
        """Commit the underlying session's current transaction."""
        await self._session.commit()

    async def link_notifications_for_registered_user(
        self, *, user_id: UUID, is_first_registration: bool
    ) -> int:
        """Link visible notifications targeted at a registering user.

        First-time registrants receive ``NEW_USERS`` and ``ALL_USERS``
        notifications; other registrations receive only ``ALL_USERS`` ones.
        Links that already exist are skipped.

        Returns:
            Number of ``UserNotification`` rows actually inserted.
        """
        target_modes: list[NotificationTargetMode] = (
            [NotificationTargetMode.NEW_USERS, NotificationTargetMode.ALL_USERS]
            if is_first_registration
            else [NotificationTargetMode.ALL_USERS]
        )
        id_query = select(Notification.id).where(
            *self._visible_notification_criteria(),
            Notification.target_mode.in_(target_modes),
        )
        notification_ids = (await self._session.execute(id_query)).scalars().all()
        if not notification_ids:
            return 0
        stmt = (
            insert(UserNotification)
            .values(
                [
                    {"user_id": user_id, "notification_id": nid}
                    for nid in notification_ids
                ]
            )
            # PostgreSQL requires a unique index/constraint on exactly these
            # columns; conflicting rows are skipped and not RETURNed.
            .on_conflict_do_nothing(index_elements=["user_id", "notification_id"])
            .returning(UserNotification.id)
        )
        inserted_ids = (await self._session.execute(stmt)).scalars().all()
        await self._session.flush()
        return len(inserted_ids)