feat(notification): add target_mode enum constraint and merge register-notifications script
- Add NotificationTargetMode enum (new_users/exist_users/all_users/user_ids) - Create Alembic migrations: drop duplicate indexes, add target_mode column - Merge register-notifications.sh into dev-migrate.sh sync-notifications subcommand - Shorten notification config path: static/notification/notifications -> static/notifications - Update registration flow to dispatch notifications by target_mode - Add is_first_registration to RegisterBonusResult for first-time user detection - Remove dead code: link_published_notifications_to_user - Update welcome_points.yaml to target new_users only - Add 44 unit tests + 1 integration test, all passing
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
"""drop duplicate indexes on llm_factory.name and llms.model_code
|
||||
|
||||
Revision ID: 20260416_0002
|
||||
Revises: 20260416_0001
|
||||
Create Date: 2026-04-16
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "20260416_0002"
|
||||
down_revision: Union[str, Sequence[str], None] = "20260416_0001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_index("ix_llm_factory_name", table_name="llm_factory")
|
||||
op.drop_index("ix_llms_model_code", table_name="llms")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_index("ix_llm_factory_name", "llm_factory", ["name"], unique=True)
|
||||
op.create_index("ix_llms_model_code", "llms", ["model_code"], unique=True)
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add target_mode to notifications
|
||||
|
||||
Revision ID: 20260416_0003
|
||||
Revises: 20260416_0002
|
||||
Create Date: 2026-04-16
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "20260416_0003"
|
||||
down_revision: Union[str, Sequence[str], None] = "20260416_0002"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notifications",
|
||||
sa.Column(
|
||||
"target_mode",
|
||||
sa.String(32),
|
||||
nullable=False,
|
||||
server_default="all_users",
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE notifications ADD CONSTRAINT ck_notifications_target_mode "
|
||||
"CHECK (target_mode IN ('new_users', 'exist_users', 'all_users', 'user_ids'))"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("ALTER TABLE notifications DROP CONSTRAINT ck_notifications_target_mode")
|
||||
op.drop_column("notifications", "target_mode")
|
||||
@@ -13,6 +13,7 @@ from backend.src.schemas.shared.notification import (
|
||||
NotificationPayload,
|
||||
NotificationPayloadNone,
|
||||
)
|
||||
from schemas.enums import NotificationTargetMode
|
||||
|
||||
|
||||
class StaticNotificationDefinition(BaseModel):
|
||||
@@ -32,14 +33,16 @@ class StaticNotificationDefinition(BaseModel):
|
||||
class StaticNotificationTargets(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
mode: Literal["all_users", "user_ids"]
|
||||
mode: NotificationTargetMode
|
||||
user_ids: list[UUID] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_target_mode(self) -> StaticNotificationTargets:
|
||||
if self.mode == "all_users" and self.user_ids is not None:
|
||||
raise ValueError("targets.user_ids must be absent when mode=all_users")
|
||||
if self.mode == "user_ids":
|
||||
if self.mode != NotificationTargetMode.USER_IDS and self.user_ids is not None:
|
||||
raise ValueError(
|
||||
"targets.user_ids must be absent when mode is not user_ids"
|
||||
)
|
||||
if self.mode == NotificationTargetMode.USER_IDS:
|
||||
if self.user_ids is None or len(self.user_ids) == 0:
|
||||
raise ValueError(
|
||||
"targets.user_ids must be a non-empty list when mode=user_ids"
|
||||
|
||||
@@ -23,7 +23,9 @@ from core.config.notification.static_schema import (
|
||||
)
|
||||
from models.auth_user import AuthUser
|
||||
from models.notification import Notification
|
||||
from models.register_bonus_claims import RegisterBonusClaims
|
||||
from models.user_notification import UserNotification
|
||||
from schemas.enums import NotificationTargetMode
|
||||
from utils.paths import get_notification_config_dir
|
||||
|
||||
logger = get_logger("core.config.notification.static_sync")
|
||||
@@ -203,6 +205,7 @@ async def _sync_document(
|
||||
body=definition.body,
|
||||
payload=_payload_to_dict(definition.payload),
|
||||
status=definition.status,
|
||||
target_mode=document.config.targets.mode,
|
||||
published_at=_resolve_published_at(existing=None, config=definition),
|
||||
revoked_at=_resolve_revoked_at(existing=None, config=definition),
|
||||
deleted_at=_resolve_deleted_at(existing=None, config=definition),
|
||||
@@ -219,6 +222,7 @@ async def _sync_document(
|
||||
notification=notification,
|
||||
config=definition,
|
||||
content_hash=content_hash,
|
||||
target_mode=document.config.targets.mode,
|
||||
)
|
||||
if changed:
|
||||
updated = 1
|
||||
@@ -373,7 +377,11 @@ def _resolve_deleted_at(
|
||||
|
||||
|
||||
def _apply_notification_updates(
|
||||
*, notification: Notification, config: object, content_hash: str
|
||||
*,
|
||||
notification: Notification,
|
||||
config: object,
|
||||
content_hash: str,
|
||||
target_mode: NotificationTargetMode,
|
||||
) -> bool:
|
||||
next_values = {
|
||||
"type": getattr(config, "type"),
|
||||
@@ -383,6 +391,7 @@ def _apply_notification_updates(
|
||||
"body": getattr(config, "body"),
|
||||
"payload": _payload_to_dict(getattr(config, "payload")),
|
||||
"status": getattr(config, "status"),
|
||||
"target_mode": target_mode,
|
||||
"published_at": _resolve_published_at(existing=notification, config=config),
|
||||
"revoked_at": _resolve_revoked_at(existing=notification, config=config),
|
||||
"deleted_at": _resolve_deleted_at(existing=notification, config=config),
|
||||
@@ -399,7 +408,7 @@ async def _resolve_target_user_ids(
|
||||
*, session: AsyncSession, config: StaticNotificationFile
|
||||
) -> list[UUID]:
|
||||
targets = config.targets
|
||||
if targets.mode == "user_ids":
|
||||
if targets.mode == NotificationTargetMode.USER_IDS:
|
||||
requested_user_ids = list(dict.fromkeys(targets.user_ids or []))
|
||||
result = await session.execute(
|
||||
select(AuthUser.id).where(AuthUser.id.in_(requested_user_ids))
|
||||
@@ -416,6 +425,21 @@ async def _resolve_target_user_ids(
|
||||
+ ", ".join(sorted(missing_user_ids))
|
||||
)
|
||||
return requested_user_ids
|
||||
if targets.mode in (
|
||||
NotificationTargetMode.NEW_USERS,
|
||||
NotificationTargetMode.EXIST_USERS,
|
||||
):
|
||||
claimed_result = await session.execute(
|
||||
select(RegisterBonusClaims.first_user_id_snapshot).where(
|
||||
RegisterBonusClaims.first_user_id_snapshot.isnot(None)
|
||||
)
|
||||
)
|
||||
claimed_ids = set(claimed_result.scalars().all())
|
||||
all_users_result = await session.execute(select(AuthUser.id))
|
||||
all_user_ids = all_users_result.scalars().all()
|
||||
if targets.mode == NotificationTargetMode.NEW_USERS:
|
||||
return [uid for uid in all_user_ids if uid not in claimed_ids]
|
||||
return [uid for uid in all_user_ids if uid in claimed_ids]
|
||||
result = await session.execute(select(AuthUser.id))
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
+1
-1
@@ -11,4 +11,4 @@ notification:
|
||||
tab: balance
|
||||
|
||||
targets:
|
||||
mode: all_users
|
||||
mode: new_users
|
||||
@@ -84,7 +84,7 @@ async def run_init_data() -> bool:
|
||||
|
||||
|
||||
async def bootstrap() -> bool:
|
||||
logger.info("Starting bootstrap (migrate + init-data)")
|
||||
logger.info("Starting bootstrap (migrate + init-data + sync-notifications)")
|
||||
|
||||
if not run_migrations():
|
||||
logger.error("Bootstrap aborted: migrations failed")
|
||||
@@ -94,6 +94,10 @@ async def bootstrap() -> bool:
|
||||
logger.error("Bootstrap aborted: init-data failed")
|
||||
return False
|
||||
|
||||
if not await run_sync_notifications():
|
||||
logger.error("Bootstrap aborted: sync-notifications failed")
|
||||
return False
|
||||
|
||||
logger.info("Bootstrap completed successfully")
|
||||
return True
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
from core.db.types import json_jsonb
|
||||
from schemas.enums import NotificationTargetMode
|
||||
|
||||
|
||||
class Notification(TimestampMixin, SoftDeleteMixin, Base):
|
||||
@@ -18,6 +19,10 @@ class Notification(TimestampMixin, SoftDeleteMixin, Base):
|
||||
"status IN ('draft', 'published', 'revoked')",
|
||||
name="ck_notifications_status",
|
||||
),
|
||||
CheckConstraint(
|
||||
"target_mode IN ('new_users', 'exist_users', 'all_users', 'user_ids')",
|
||||
name="ck_notifications_target_mode",
|
||||
),
|
||||
CheckConstraint(
|
||||
"jsonb_typeof(payload) = 'object'",
|
||||
name="ck_notifications_payload_object",
|
||||
@@ -63,6 +68,9 @@ class Notification(TimestampMixin, SoftDeleteMixin, Base):
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(16), nullable=False, server_default=text("'published'")
|
||||
)
|
||||
target_mode: Mapped[NotificationTargetMode] = mapped_column(
|
||||
String(32), nullable=False, server_default=text("'all_users'")
|
||||
)
|
||||
published_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -151,3 +151,10 @@ class GroupMemberStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
MUTED = "muted"
|
||||
REMOVED = "removed"
|
||||
|
||||
|
||||
class NotificationTargetMode(str, Enum):
|
||||
NEW_USERS = "new_users"
|
||||
EXIST_USERS = "exist_users"
|
||||
ALL_USERS = "all_users"
|
||||
USER_IDS = "user_ids"
|
||||
|
||||
@@ -24,7 +24,7 @@ def get_database_config_dir() -> Path:
|
||||
|
||||
|
||||
def get_notification_config_dir() -> Path:
|
||||
return get_static_config_dir() / "notification/notifications"
|
||||
return get_static_config_dir() / "notifications"
|
||||
|
||||
|
||||
def get_divination_data_dir() -> Path:
|
||||
|
||||
@@ -73,13 +73,14 @@ async def create_email_session(
|
||||
)
|
||||
result = await service.create_email_session(payload)
|
||||
points_service = PointsService(repository=PointsRepository(session))
|
||||
await points_service.grant_register_bonus_if_eligible(
|
||||
bonus_result = await points_service.grant_register_bonus_if_eligible(
|
||||
user_id=UUID(result.user.id),
|
||||
user_email=result.user.email,
|
||||
)
|
||||
notification_service = NotificationService(NotificationRepository(session))
|
||||
linked_count = await notification_service.link_published_notifications_to_user(
|
||||
user_id=UUID(result.user.id)
|
||||
linked_count = await notification_service.link_notifications_for_registered_user(
|
||||
user_id=UUID(result.user.id),
|
||||
is_first_registration=bonus_result.is_first_registration,
|
||||
)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.notification import Notification
|
||||
from models.user_notification import UserNotification
|
||||
from schemas.enums import NotificationTargetMode
|
||||
|
||||
|
||||
class NotificationRepository:
|
||||
@@ -116,13 +117,24 @@ class NotificationRepository:
|
||||
async def commit(self) -> None:
|
||||
await self._session.commit()
|
||||
|
||||
async def link_published_notifications_to_user(self, *, user_id: UUID) -> int:
|
||||
async def link_notifications_for_registered_user(
|
||||
self, *, user_id: UUID, is_first_registration: bool
|
||||
) -> int:
|
||||
target_modes: list[NotificationTargetMode]
|
||||
if is_first_registration:
|
||||
target_modes = [
|
||||
NotificationTargetMode.NEW_USERS,
|
||||
NotificationTargetMode.ALL_USERS,
|
||||
]
|
||||
else:
|
||||
target_modes = [NotificationTargetMode.ALL_USERS]
|
||||
notification_ids = list(
|
||||
(
|
||||
await self._session.execute(
|
||||
select(Notification.id).where(
|
||||
Notification.status == "published",
|
||||
Notification.deleted_at.is_(None),
|
||||
Notification.target_mode.in_(target_modes),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -136,8 +148,8 @@ class NotificationRepository:
|
||||
insert(UserNotification)
|
||||
.values(
|
||||
[
|
||||
{"user_id": user_id, "notification_id": notification_id}
|
||||
for notification_id in notification_ids
|
||||
{"user_id": user_id, "notification_id": nid}
|
||||
for nid in notification_ids
|
||||
]
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["user_id", "notification_id"])
|
||||
|
||||
@@ -123,9 +123,11 @@ class NotificationService:
|
||||
await self._repository.commit()
|
||||
return updated_count
|
||||
|
||||
async def link_published_notifications_to_user(self, *, user_id: UUID) -> int:
|
||||
return await self._repository.link_published_notifications_to_user(
|
||||
user_id=user_id
|
||||
async def link_notifications_for_registered_user(
|
||||
self, *, user_id: UUID, is_first_registration: bool
|
||||
) -> int:
|
||||
return await self._repository.link_notifications_for_registered_user(
|
||||
user_id=user_id, is_first_registration=is_first_registration
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ class RegisterBonusResult:
|
||||
amount: int
|
||||
balance_after: int
|
||||
event_id: str
|
||||
is_first_registration: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -122,14 +123,17 @@ class PointsService:
|
||||
account = await self._repository.get_or_create_user_points_for_update(
|
||||
user_id=user_id
|
||||
)
|
||||
if claim is not None and claim.balance_snapshot is not None:
|
||||
account.balance = max(int(claim.balance_snapshot), 0)
|
||||
account.version = int(account.version) + 1
|
||||
if claim is not None:
|
||||
is_first_registration = claim.first_user_id_snapshot is None
|
||||
if claim.balance_snapshot is not None:
|
||||
account.balance = max(int(claim.balance_snapshot), 0)
|
||||
account.version = int(account.version) + 1
|
||||
return RegisterBonusResult(
|
||||
granted=False,
|
||||
amount=0,
|
||||
balance_after=int(account.balance),
|
||||
event_id=event_id,
|
||||
is_first_registration=is_first_registration,
|
||||
)
|
||||
|
||||
claimed = await self._repository.claim_register_bonus(
|
||||
@@ -144,6 +148,7 @@ class PointsService:
|
||||
amount=0,
|
||||
balance_after=int(account.balance),
|
||||
event_id=event_id,
|
||||
is_first_registration=False,
|
||||
)
|
||||
|
||||
balance = int(account.balance)
|
||||
@@ -197,6 +202,7 @@ class PointsService:
|
||||
amount=bonus_points,
|
||||
balance_after=int(account.balance),
|
||||
event_id=event_id,
|
||||
is_first_registration=True,
|
||||
)
|
||||
|
||||
async def ensure_run_points_available(
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from models.notification import Notification
|
||||
from models.user_notification import UserNotification
|
||||
|
||||
|
||||
class IdentityData(TypedDict):
|
||||
email: str
|
||||
code: str
|
||||
|
||||
|
||||
async def _create_email_session(
|
||||
client: httpx.AsyncClient,
|
||||
*,
|
||||
email: str,
|
||||
code: str,
|
||||
) -> dict[str, object]:
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/email-session",
|
||||
json={"email": email, "token": code},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _delete_user(client: httpx.AsyncClient, *, token: str) -> None:
|
||||
resp = await client.delete(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notification_target_mode_first_reg_and_reregister(
|
||||
api_client: httpx.AsyncClient,
|
||||
test_identity: IdentityData,
|
||||
db_cleanup: list[str],
|
||||
) -> None:
|
||||
email = str(test_identity["email"]).strip().lower()
|
||||
db_cleanup.append(email)
|
||||
|
||||
first = await _create_email_session(
|
||||
api_client, email=email, code=str(test_identity["code"])
|
||||
)
|
||||
user1 = first.get("user")
|
||||
assert isinstance(user1, dict)
|
||||
user1_id = UUID(str(user1["id"]))
|
||||
token1 = str(first["access_token"])
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(Notification.target_mode)
|
||||
.join(UserNotification, UserNotification.notification_id == Notification.id)
|
||||
.where(UserNotification.user_id == user1_id)
|
||||
.order_by(Notification.target_mode)
|
||||
)
|
||||
first_target_modes = [str(row[0]) for row in result.all()]
|
||||
|
||||
assert "new_users" in first_target_modes
|
||||
assert "exist_users" not in first_target_modes
|
||||
|
||||
await _delete_user(api_client, token=token1)
|
||||
time.sleep(0.5)
|
||||
|
||||
second = await _create_email_session(
|
||||
api_client, email=email, code=str(test_identity["code"])
|
||||
)
|
||||
user2 = second.get("user")
|
||||
assert isinstance(user2, dict)
|
||||
user2_id = UUID(str(user2["id"]))
|
||||
token2 = str(second["access_token"])
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(Notification.target_mode)
|
||||
.join(UserNotification, UserNotification.notification_id == Notification.id)
|
||||
.where(UserNotification.user_id == user2_id)
|
||||
.order_by(Notification.target_mode)
|
||||
)
|
||||
second_target_modes = [str(row[0]) for row in result.all()]
|
||||
|
||||
assert "new_users" not in second_target_modes
|
||||
assert "all_users" not in second_target_modes
|
||||
assert "exist_users" not in second_target_modes
|
||||
|
||||
await _delete_user(api_client, token=token2)
|
||||
@@ -0,0 +1,150 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.enums import NotificationTargetMode
|
||||
from v1.notifications.service import NotificationService
|
||||
|
||||
|
||||
class _FakeNotification:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id: UUID,
|
||||
target_mode: NotificationTargetMode = NotificationTargetMode.ALL_USERS,
|
||||
status: str = "published",
|
||||
deleted_at: datetime | None = None,
|
||||
):
|
||||
self.id = id
|
||||
self.target_mode = target_mode
|
||||
self.status = status
|
||||
self.deleted_at = deleted_at
|
||||
|
||||
|
||||
class _TrackingNotificationRepository:
|
||||
def __init__(self, notifications: list[_FakeNotification]) -> None:
|
||||
self._notifications = notifications
|
||||
self.linked_notification_ids: list[list[UUID]] = []
|
||||
self.linked_is_first: list[bool] = []
|
||||
|
||||
async def link_notifications_for_registered_user(
|
||||
self, *, user_id: UUID, is_first_registration: bool
|
||||
) -> int:
|
||||
target_modes: list[NotificationTargetMode]
|
||||
if is_first_registration:
|
||||
target_modes = [
|
||||
NotificationTargetMode.NEW_USERS,
|
||||
NotificationTargetMode.ALL_USERS,
|
||||
]
|
||||
else:
|
||||
target_modes = [NotificationTargetMode.ALL_USERS]
|
||||
|
||||
matched = [
|
||||
n
|
||||
for n in self._notifications
|
||||
if n.status == "published"
|
||||
and n.deleted_at is None
|
||||
and n.target_mode in target_modes
|
||||
]
|
||||
self.linked_notification_ids.append([n.id for n in matched])
|
||||
self.linked_is_first.append(is_first_registration)
|
||||
return len(matched)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notification_new_users() -> _FakeNotification:
|
||||
return _FakeNotification(id=uuid4(), target_mode=NotificationTargetMode.NEW_USERS)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notification_all_users() -> _FakeNotification:
|
||||
return _FakeNotification(id=uuid4(), target_mode=NotificationTargetMode.ALL_USERS)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notification_exist_users() -> _FakeNotification:
|
||||
return _FakeNotification(id=uuid4(), target_mode=NotificationTargetMode.EXIST_USERS)
|
||||
|
||||
|
||||
class TestLinkNotificationsForRegisteredUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_registration_gets_new_users_and_all_users(
|
||||
self,
|
||||
notification_new_users: _FakeNotification,
|
||||
notification_all_users: _FakeNotification,
|
||||
notification_exist_users: _FakeNotification,
|
||||
) -> None:
|
||||
repo = _TrackingNotificationRepository(
|
||||
[notification_new_users, notification_all_users, notification_exist_users]
|
||||
)
|
||||
service = NotificationService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
count = await service.link_notifications_for_registered_user(
|
||||
user_id=uuid4(), is_first_registration=True
|
||||
)
|
||||
|
||||
assert count == 2
|
||||
linked_ids = repo.linked_notification_ids[0]
|
||||
assert notification_new_users.id in linked_ids
|
||||
assert notification_all_users.id in linked_ids
|
||||
assert notification_exist_users.id not in linked_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reregistered_user_only_gets_all_users(
|
||||
self,
|
||||
notification_new_users: _FakeNotification,
|
||||
notification_all_users: _FakeNotification,
|
||||
notification_exist_users: _FakeNotification,
|
||||
) -> None:
|
||||
repo = _TrackingNotificationRepository(
|
||||
[notification_new_users, notification_all_users, notification_exist_users]
|
||||
)
|
||||
service = NotificationService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
count = await service.link_notifications_for_registered_user(
|
||||
user_id=uuid4(), is_first_registration=False
|
||||
)
|
||||
|
||||
assert count == 1
|
||||
linked_ids = repo.linked_notification_ids[0]
|
||||
assert notification_new_users.id not in linked_ids
|
||||
assert notification_all_users.id in linked_ids
|
||||
assert notification_exist_users.id not in linked_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_published_notifications_returns_zero(self) -> None:
|
||||
repo = _TrackingNotificationRepository([])
|
||||
service = NotificationService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
count = await service.link_notifications_for_registered_user(
|
||||
user_id=uuid4(), is_first_registration=True
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_only_new_users_notification_first_registration(self) -> None:
|
||||
n = _FakeNotification(id=uuid4(), target_mode=NotificationTargetMode.NEW_USERS)
|
||||
repo = _TrackingNotificationRepository([n])
|
||||
service = NotificationService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
count = await service.link_notifications_for_registered_user(
|
||||
user_id=uuid4(), is_first_registration=True
|
||||
)
|
||||
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_only_new_users_notification_reregistered(self) -> None:
|
||||
n = _FakeNotification(id=uuid4(), target_mode=NotificationTargetMode.NEW_USERS)
|
||||
repo = _TrackingNotificationRepository([n])
|
||||
service = NotificationService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
count = await service.link_notifications_for_registered_user(
|
||||
user_id=uuid4(), is_first_registration=False
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.register_bonus_claims import RegisterBonusClaims
|
||||
from v1.points.service import PointsService
|
||||
|
||||
|
||||
class _FakeAccount:
|
||||
balance: int = 100
|
||||
frozen_balance: int = 0
|
||||
lifetime_earned: int = 0
|
||||
lifetime_spent: int = 0
|
||||
version: int = 0
|
||||
|
||||
|
||||
class _FakePointsRepository:
|
||||
def __init__(self, *, claim: RegisterBonusClaims | None = None) -> None:
|
||||
self.account = _FakeAccount()
|
||||
self.claim = claim
|
||||
self.claimed = False
|
||||
self.appended_ledger: list[object] = []
|
||||
self.appended_audit: list[object] = []
|
||||
|
||||
async def get_or_create_user_points_for_update(
|
||||
self, *, user_id: object
|
||||
) -> _FakeAccount:
|
||||
return self.account
|
||||
|
||||
async def has_ledger_event(self, *, user_id: object, event_id: str) -> bool:
|
||||
return False
|
||||
|
||||
async def append_ledger(self, *, command: object, balance_after: int) -> None:
|
||||
self.appended_ledger.append(command)
|
||||
|
||||
async def append_audit_ledger(self, *, command: object) -> None:
|
||||
self.appended_audit.append(command)
|
||||
|
||||
async def has_audit_event(self, *, event_id: str) -> bool:
|
||||
return False
|
||||
|
||||
async def claim_register_bonus(
|
||||
self,
|
||||
*,
|
||||
email_hash: str,
|
||||
user_email_snapshot: str,
|
||||
first_user_id_snapshot: object,
|
||||
grant_event_id: str,
|
||||
) -> bool:
|
||||
if self.claimed:
|
||||
return False
|
||||
self.claimed = True
|
||||
return True
|
||||
|
||||
async def get_register_bonus_claim(
|
||||
self, *, email_hash: str
|
||||
) -> RegisterBonusClaims | None:
|
||||
return self.claim
|
||||
|
||||
|
||||
class TestRegisterBonusIsFirstRegistration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_registration_sets_true(self) -> None:
|
||||
repo = _FakePointsRepository(claim=None)
|
||||
service = PointsService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
result = await service.grant_register_bonus_if_eligible(
|
||||
user_id=uuid4(), user_email="new@example.com"
|
||||
)
|
||||
|
||||
assert result.granted is True
|
||||
assert result.is_first_registration is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reregistered_with_existing_claim_sets_false(self) -> None:
|
||||
existing_claim = RegisterBonusClaims(
|
||||
email_hash="abc",
|
||||
user_email_snapshot="old@example.com",
|
||||
first_user_id_snapshot=uuid4(),
|
||||
balance_snapshot=50,
|
||||
grant_event_id="evt",
|
||||
)
|
||||
repo = _FakePointsRepository(claim=existing_claim)
|
||||
service = PointsService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
result = await service.grant_register_bonus_if_eligible(
|
||||
user_id=uuid4(), user_email="old@example.com"
|
||||
)
|
||||
|
||||
assert result.granted is False
|
||||
assert result.is_first_registration is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reregistered_claim_without_first_user_id_sets_true(self) -> None:
|
||||
claim_no_snapshot = RegisterBonusClaims(
|
||||
email_hash="abc",
|
||||
user_email_snapshot="edge@example.com",
|
||||
first_user_id_snapshot=None,
|
||||
balance_snapshot=50,
|
||||
grant_event_id="evt",
|
||||
)
|
||||
repo = _FakePointsRepository(claim=claim_no_snapshot)
|
||||
service = PointsService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
result = await service.grant_register_bonus_if_eligible(
|
||||
user_id=uuid4(), user_email="edge@example.com"
|
||||
)
|
||||
|
||||
assert result.granted is False
|
||||
assert result.is_first_registration is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_competition_failure_sets_false(self) -> None:
|
||||
repo = _FakePointsRepository(claim=None)
|
||||
repo.claimed = True
|
||||
service = PointsService(repository=repo) # type: ignore[arg-type]
|
||||
|
||||
result = await service.grant_register_bonus_if_eligible(
|
||||
user_id=uuid4(), user_email="race@example.com"
|
||||
)
|
||||
|
||||
assert result.granted is False
|
||||
assert result.is_first_registration is False
|
||||
@@ -10,6 +10,7 @@ from core.config.notification.static_sync import (
|
||||
build_static_notification_content_hash,
|
||||
load_static_notification_documents,
|
||||
)
|
||||
from schemas.enums import NotificationTargetMode
|
||||
|
||||
|
||||
def _write_yaml(path: Path, content: str) -> None:
|
||||
@@ -43,10 +44,86 @@ def test_load_static_notification_file_parses_valid_yaml(tmp_path: Path) -> None
|
||||
|
||||
assert loaded.notification.source_key == "welcome_bonus"
|
||||
assert loaded.notification.payload.action == "open_route"
|
||||
assert loaded.targets.mode == "user_ids"
|
||||
assert loaded.targets.mode == NotificationTargetMode.USER_IDS
|
||||
assert len(loaded.targets.user_ids or []) == 1
|
||||
|
||||
|
||||
def test_load_static_notification_file_parses_new_users(tmp_path: Path) -> None:
|
||||
file_path = tmp_path / "welcome_points.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
"""
|
||||
notification:
|
||||
source_key: welcome_points
|
||||
version: 1
|
||||
type: system
|
||||
status: published
|
||||
title: Welcome
|
||||
body: You got points.
|
||||
payload:
|
||||
action: open_route
|
||||
route: /points
|
||||
tab: balance
|
||||
targets:
|
||||
mode: new_users
|
||||
""",
|
||||
)
|
||||
|
||||
loaded = load_static_notification_file(file_path)
|
||||
|
||||
assert loaded.targets.mode == NotificationTargetMode.NEW_USERS
|
||||
assert loaded.targets.user_ids is None
|
||||
|
||||
|
||||
def test_load_static_notification_file_parses_exist_users(tmp_path: Path) -> None:
|
||||
file_path = tmp_path / "promo.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
"""
|
||||
notification:
|
||||
source_key: promo_return
|
||||
version: 1
|
||||
type: system
|
||||
status: published
|
||||
title: Come back
|
||||
body: We miss you.
|
||||
payload:
|
||||
action: none
|
||||
targets:
|
||||
mode: exist_users
|
||||
""",
|
||||
)
|
||||
|
||||
loaded = load_static_notification_file(file_path)
|
||||
|
||||
assert loaded.targets.mode == NotificationTargetMode.EXIST_USERS
|
||||
assert loaded.targets.user_ids is None
|
||||
|
||||
|
||||
def test_load_static_notification_file_parses_all_users(tmp_path: Path) -> None:
|
||||
file_path = tmp_path / "announce.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
"""
|
||||
notification:
|
||||
source_key: system_announce
|
||||
version: 1
|
||||
type: system
|
||||
status: published
|
||||
title: Announcement
|
||||
body: Maintenance at midnight.
|
||||
payload:
|
||||
action: none
|
||||
targets:
|
||||
mode: all_users
|
||||
""",
|
||||
)
|
||||
|
||||
loaded = load_static_notification_file(file_path)
|
||||
|
||||
assert loaded.targets.mode == NotificationTargetMode.ALL_USERS
|
||||
|
||||
|
||||
def test_load_static_notification_file_rejects_invalid_targets(tmp_path: Path) -> None:
|
||||
file_path = tmp_path / "invalid.yaml"
|
||||
_write_yaml(
|
||||
@@ -72,6 +149,81 @@ def test_load_static_notification_file_rejects_invalid_targets(tmp_path: Path) -
|
||||
load_static_notification_file(file_path)
|
||||
|
||||
|
||||
def test_load_static_notification_file_rejects_unknown_mode(tmp_path: Path) -> None:
|
||||
file_path = tmp_path / "bad_mode.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
"""
|
||||
notification:
|
||||
source_key: bad_mode
|
||||
version: 1
|
||||
type: system
|
||||
status: published
|
||||
title: Bad
|
||||
body: Bad mode.
|
||||
payload:
|
||||
action: none
|
||||
targets:
|
||||
mode: non_existent
|
||||
""",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid static notification data"):
|
||||
load_static_notification_file(file_path)
|
||||
|
||||
|
||||
def test_load_static_notification_file_rejects_new_users_with_user_ids(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
file_path = tmp_path / "bad_new_users.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
"""
|
||||
notification:
|
||||
source_key: bad_new
|
||||
version: 1
|
||||
type: system
|
||||
status: published
|
||||
title: Bad
|
||||
body: Bad.
|
||||
payload:
|
||||
action: none
|
||||
targets:
|
||||
mode: new_users
|
||||
user_ids:
|
||||
- 11111111-1111-1111-1111-111111111111
|
||||
""",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid static notification data"):
|
||||
load_static_notification_file(file_path)
|
||||
|
||||
|
||||
def test_load_static_notification_file_rejects_user_ids_without_list(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
file_path = tmp_path / "bad_user_ids.yaml"
|
||||
_write_yaml(
|
||||
file_path,
|
||||
"""
|
||||
notification:
|
||||
source_key: bad_uids
|
||||
version: 1
|
||||
type: system
|
||||
status: published
|
||||
title: Bad
|
||||
body: Bad.
|
||||
payload:
|
||||
action: none
|
||||
targets:
|
||||
mode: user_ids
|
||||
""",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid static notification data"):
|
||||
load_static_notification_file(file_path)
|
||||
|
||||
|
||||
def test_load_static_notification_documents_rejects_duplicate_source_key(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user