feat: 实现站内通知系统
- 后端: 新增 notifications/user_notifications 表迁移及 ORM 模型
- 后端: 实现 schema/repository/service/router 全套通知 API
- GET /api/v1/notifications (列表+游标分页)
- GET /api/v1/notifications/unread-count
- PATCH /api/v1/notifications/{id}/read (幂等)
- PATCH /api/v1/notifications/mark-all-read (幂等)
- 后端: payload 使用 Pydantic discriminated union (none/open_route/open_url)
- 后端: 19 个单元测试全部通过
- Flutter: 通知 feature 完整实现 (models/apis/repositories/bloc/UI)
- Flutter: Home 页通知按钮接入真实页面,显示未读 badge
- Flutter: 14 个测试全部通过
- 协议文档: notification-inbox-protocol.md 及错误码注册
This commit is contained in:
@@ -0,0 +1,170 @@
|
||||
"""add notifications and user_notifications tables
|
||||
|
||||
Revision ID: 20260411_0004
|
||||
Revises: 20260411_0003
|
||||
Create Date: 2026-04-11 12:00:00
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "20260411_0004"
|
||||
down_revision: Union[str, Sequence[str], None] = "20260411_0003"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"notifications",
|
||||
sa.Column(
|
||||
"id", sa.UUID(), server_default=sa.text("gen_random_uuid()"), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.String(length=32),
|
||||
server_default=sa.text("'system'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("title", sa.Text(), nullable=False),
|
||||
sa.Column("body", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"payload",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String(length=16),
|
||||
server_default=sa.text("'published'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.CheckConstraint(
|
||||
"status IN ('draft', 'published', 'revoked')",
|
||||
name="ck_notifications_status",
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"jsonb_typeof(payload) = 'object'",
|
||||
name="ck_notifications_payload_object",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_notifications_status_created_at",
|
||||
"notifications",
|
||||
["status", sa.text("created_at DESC")],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_notifications_published_at",
|
||||
"notifications",
|
||||
[sa.text("published_at DESC")],
|
||||
)
|
||||
_enable_rls("notifications")
|
||||
|
||||
op.create_table(
|
||||
"user_notifications",
|
||||
sa.Column(
|
||||
"id", sa.UUID(), server_default=sa.text("gen_random_uuid()"), nullable=False
|
||||
),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("notification_id", sa.UUID(), nullable=False),
|
||||
sa.Column(
|
||||
"is_read", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||
),
|
||||
sa.Column("read_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["auth.users.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["notification_id"], ["notifications.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"user_id", "notification_id", name="uq_user_notifications_user_notification"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_notifications_user_created_at",
|
||||
"user_notifications",
|
||||
["user_id", sa.text("created_at DESC")],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_notifications_user_unread",
|
||||
"user_notifications",
|
||||
["user_id", "is_read"],
|
||||
)
|
||||
_enable_rls("user_notifications")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
_drop_rls("user_notifications")
|
||||
op.drop_index("ix_user_notifications_user_unread", table_name="user_notifications")
|
||||
op.drop_index(
|
||||
"ix_user_notifications_user_created_at", table_name="user_notifications"
|
||||
)
|
||||
op.drop_table("user_notifications")
|
||||
|
||||
_drop_rls("notifications")
|
||||
op.drop_index("ix_notifications_published_at", table_name="notifications")
|
||||
op.drop_index("ix_notifications_status_created_at", table_name="notifications")
|
||||
op.drop_table("notifications")
|
||||
|
||||
|
||||
def _enable_rls(table_name: str) -> None:
|
||||
for role in ["anon", "authenticated"]:
|
||||
for action in ["select", "insert", "update", "delete"]:
|
||||
op.execute(
|
||||
f"DROP POLICY IF EXISTS {role}_{action}_{table_name} ON {table_name}"
|
||||
)
|
||||
op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY")
|
||||
for role in ["anon", "authenticated"]:
|
||||
op.execute(
|
||||
f"CREATE POLICY {role}_select_{table_name} ON {table_name} FOR SELECT TO {role} USING (false)"
|
||||
)
|
||||
op.execute(
|
||||
f"CREATE POLICY {role}_insert_{table_name} ON {table_name} FOR INSERT TO {role} WITH CHECK (false)"
|
||||
)
|
||||
op.execute(
|
||||
f"CREATE POLICY {role}_update_{table_name} ON {table_name} FOR UPDATE TO {role} USING (false) WITH CHECK (false)"
|
||||
)
|
||||
op.execute(
|
||||
f"CREATE POLICY {role}_delete_{table_name} ON {table_name} FOR DELETE TO {role} USING (false)"
|
||||
)
|
||||
|
||||
|
||||
def _drop_rls(table_name: str) -> None:
|
||||
for role in ["anon", "authenticated"]:
|
||||
for action in ["select", "insert", "update", "delete"]:
|
||||
op.execute(
|
||||
f"DROP POLICY IF EXISTS {role}_{action}_{table_name} ON {table_name}"
|
||||
)
|
||||
op.execute(f"ALTER TABLE {table_name} DISABLE ROW LEVEL SECURITY")
|
||||
@@ -10,7 +10,9 @@ from .points_audit_ledger import PointsAuditLedger
|
||||
from .points_ledger import PointsLedger
|
||||
from .profile import Profile
|
||||
from .register_bonus_claims import RegisterBonusClaims
|
||||
from .notification import Notification
|
||||
from .system_agents import SystemAgents
|
||||
from .user_notification import UserNotification
|
||||
from .user_points import UserPoints
|
||||
|
||||
__all__ = [
|
||||
@@ -20,10 +22,12 @@ __all__ = [
|
||||
"InviteCode",
|
||||
"Llm",
|
||||
"LlmFactory",
|
||||
"Notification",
|
||||
"PointsAuditLedger",
|
||||
"PointsLedger",
|
||||
"Profile",
|
||||
"RegisterBonusClaims",
|
||||
"SystemAgents",
|
||||
"UserNotification",
|
||||
"UserPoints",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import CheckConstraint, DateTime, Index, String, Text, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
from core.db.types import json_jsonb
|
||||
|
||||
|
||||
class Notification(TimestampMixin, SoftDeleteMixin, Base):
|
||||
__tablename__ = "notifications"
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"status IN ('draft', 'published', 'revoked')",
|
||||
name="ck_notifications_status",
|
||||
),
|
||||
CheckConstraint(
|
||||
"jsonb_typeof(payload) = 'object'",
|
||||
name="ck_notifications_payload_object",
|
||||
),
|
||||
Index(
|
||||
"ix_notifications_status_created_at",
|
||||
"status",
|
||||
"created_at",
|
||||
),
|
||||
Index(
|
||||
"ix_notifications_published_at",
|
||||
"published_at",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
type: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, server_default=text("'system'")
|
||||
)
|
||||
title: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
body: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
payload: Mapped[dict[str, object]] = mapped_column(
|
||||
json_jsonb,
|
||||
nullable=False,
|
||||
server_default=text("'{}'::jsonb"),
|
||||
default=dict,
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(16), nullable=False, server_default=text("'published'")
|
||||
)
|
||||
published_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, UniqueConstraint, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, TimestampMixin
|
||||
|
||||
|
||||
class UserNotification(TimestampMixin, Base):
|
||||
__tablename__ = "user_notifications"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id",
|
||||
"notification_id",
|
||||
name="uq_user_notifications_user_notification",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("auth.users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
notification_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("notifications.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
is_read: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, server_default=text("false")
|
||||
)
|
||||
read_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db import get_db
|
||||
from v1.notifications.repository import NotificationRepository
|
||||
from v1.notifications.service import NotificationService
|
||||
|
||||
|
||||
def get_notification_service(
|
||||
session: AsyncSession = Depends(get_db),
|
||||
) -> NotificationService:
|
||||
return NotificationService(repository=NotificationRepository(session))
|
||||
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.notification import Notification
|
||||
from models.user_notification import UserNotification
|
||||
|
||||
|
||||
class NotificationRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def list_notifications(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
limit: int = 20,
|
||||
cursor: datetime | None = None,
|
||||
) -> list[tuple[UserNotification, Notification]]:
|
||||
stmt = (
|
||||
select(UserNotification, Notification)
|
||||
.join(Notification, UserNotification.notification_id == Notification.id)
|
||||
.where(
|
||||
UserNotification.user_id == user_id,
|
||||
Notification.status == "published",
|
||||
Notification.deleted_at.is_(None),
|
||||
)
|
||||
.order_by(UserNotification.created_at.desc())
|
||||
.limit(limit + 1)
|
||||
)
|
||||
if cursor is not None:
|
||||
stmt = stmt.where(UserNotification.created_at < cursor)
|
||||
|
||||
rows = (await self._session.execute(stmt)).all()
|
||||
return [(row[0], row[1]) for row in rows]
|
||||
|
||||
async def get_unread_count(self, *, user_id: UUID) -> int:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(UserNotification)
|
||||
.join(Notification, UserNotification.notification_id == Notification.id)
|
||||
.where(
|
||||
UserNotification.user_id == user_id,
|
||||
UserNotification.is_read.is_(False),
|
||||
Notification.status == "published",
|
||||
Notification.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
result = (await self._session.execute(stmt)).scalar_one()
|
||||
return result
|
||||
|
||||
async def get_user_notification(
|
||||
self,
|
||||
*,
|
||||
user_notification_id: UUID,
|
||||
user_id: UUID,
|
||||
) -> tuple[UserNotification, Notification] | None:
|
||||
stmt = (
|
||||
select(UserNotification, Notification)
|
||||
.join(Notification, UserNotification.notification_id == Notification.id)
|
||||
.where(
|
||||
UserNotification.id == user_notification_id,
|
||||
UserNotification.user_id == user_id,
|
||||
Notification.status == "published",
|
||||
Notification.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
row = (await self._session.execute(stmt)).first()
|
||||
if row is None:
|
||||
return None
|
||||
return (row[0], row[1])
|
||||
|
||||
async def mark_read(self, *, user_notification_id: UUID, user_id: UUID) -> bool:
|
||||
stmt = select(UserNotification).where(
|
||||
UserNotification.id == user_notification_id,
|
||||
UserNotification.user_id == user_id,
|
||||
UserNotification.is_read.is_(False),
|
||||
)
|
||||
un = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if un is None:
|
||||
return False
|
||||
un.is_read = True
|
||||
un.read_at = datetime.now()
|
||||
await self._session.flush()
|
||||
return True
|
||||
|
||||
async def mark_all_read(self, *, user_id: UUID) -> int:
|
||||
un_ids_stmt = (
|
||||
select(UserNotification.id)
|
||||
.join(Notification, UserNotification.notification_id == Notification.id)
|
||||
.where(
|
||||
UserNotification.user_id == user_id,
|
||||
UserNotification.is_read.is_(False),
|
||||
Notification.status == "published",
|
||||
Notification.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
un_ids = list((await self._session.execute(un_ids_stmt)).scalars().all())
|
||||
if not un_ids:
|
||||
return 0
|
||||
count = len(un_ids)
|
||||
stmt = (
|
||||
update(UserNotification)
|
||||
.where(UserNotification.id.in_(un_ids))
|
||||
.values(is_read=True, read_at=func.now())
|
||||
)
|
||||
await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return count
|
||||
@@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.notifications.dependencies import get_notification_service
|
||||
from v1.notifications.schemas import (
|
||||
MarkAllReadResponse,
|
||||
NotificationItemResponse,
|
||||
NotificationListResponse,
|
||||
UnreadCountResponse,
|
||||
)
|
||||
from v1.notifications.service import NotificationService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
||||
@router.get("", response_model=NotificationListResponse)
|
||||
async def list_notifications(
|
||||
service: Annotated[NotificationService, Depends(get_notification_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
limit: int = Query(default=20, ge=1, le=50),
|
||||
cursor: str | None = Query(default=None),
|
||||
) -> NotificationListResponse:
|
||||
from datetime import datetime
|
||||
|
||||
parsed_cursor = None
|
||||
if cursor is not None:
|
||||
try:
|
||||
parsed_cursor = datetime.fromisoformat(cursor.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
parsed_cursor = None
|
||||
|
||||
result = await service.list_notifications(
|
||||
user_id=current_user.id,
|
||||
limit=limit,
|
||||
cursor=parsed_cursor,
|
||||
)
|
||||
items = []
|
||||
for item in result.items:
|
||||
items.append(
|
||||
NotificationItemResponse(
|
||||
id=str(item.id),
|
||||
notificationId=str(item.notification_id),
|
||||
type=item.type,
|
||||
title=item.title,
|
||||
body=item.body,
|
||||
payload=item.payload,
|
||||
isRead=item.is_read,
|
||||
readAt=item.read_at,
|
||||
createdAt=item.created_at,
|
||||
)
|
||||
)
|
||||
return NotificationListResponse(
|
||||
items=items,
|
||||
nextCursor=result.next_cursor.isoformat() if result.next_cursor else None,
|
||||
hasMore=result.has_more,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/unread-count", response_model=UnreadCountResponse)
|
||||
async def get_unread_count(
|
||||
service: Annotated[NotificationService, Depends(get_notification_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> UnreadCountResponse:
|
||||
count = await service.get_unread_count(user_id=current_user.id)
|
||||
return UnreadCountResponse(count=count)
|
||||
|
||||
|
||||
@router.patch("/{notification_id}/read", response_model=NotificationItemResponse)
|
||||
async def mark_notification_read(
|
||||
notification_id: str,
|
||||
service: Annotated[NotificationService, Depends(get_notification_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> NotificationItemResponse:
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
uid = UUID(notification_id)
|
||||
except ValueError:
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
detail=problem_payload(
|
||||
code="NOTIFICATION_NOT_FOUND",
|
||||
detail="Notification not found or not owned by current user",
|
||||
),
|
||||
)
|
||||
|
||||
item = await service.mark_read(
|
||||
user_notification_id=uid,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return NotificationItemResponse(
|
||||
id=str(item.id),
|
||||
notificationId=str(item.notification_id),
|
||||
type=item.type,
|
||||
title=item.title,
|
||||
body=item.body,
|
||||
payload=item.payload,
|
||||
isRead=item.is_read,
|
||||
readAt=item.read_at,
|
||||
createdAt=item.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/mark-all-read", response_model=MarkAllReadResponse)
|
||||
async def mark_all_read(
|
||||
service: Annotated[NotificationService, Depends(get_notification_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> MarkAllReadResponse:
|
||||
updated_count = await service.mark_all_read(user_id=current_user.id)
|
||||
return MarkAllReadResponse(updatedCount=updated_count)
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class NotificationPayloadNone(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["none"]
|
||||
|
||||
|
||||
class NotificationPayloadRoute(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["open_route"]
|
||||
route: str = Field(max_length=200)
|
||||
entity_id: str | None = Field(default=None, max_length=64)
|
||||
tab: str | None = Field(default=None, max_length=32)
|
||||
|
||||
|
||||
class NotificationPayloadUrl(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["open_url"]
|
||||
url: str = Field(max_length=500)
|
||||
|
||||
|
||||
NotificationPayload = Union[
|
||||
NotificationPayloadNone, NotificationPayloadRoute, NotificationPayloadUrl
|
||||
]
|
||||
|
||||
|
||||
class NotificationItemResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
id: str
|
||||
notification_id: str = Field(alias="notificationId")
|
||||
type: str
|
||||
title: str
|
||||
body: str
|
||||
payload: NotificationPayload
|
||||
is_read: bool = Field(alias="isRead")
|
||||
read_at: datetime | None = Field(alias="readAt", default=None)
|
||||
created_at: datetime = Field(alias="createdAt")
|
||||
|
||||
|
||||
class NotificationListResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
items: list[NotificationItemResponse]
|
||||
next_cursor: str | None = Field(alias="nextCursor", default=None)
|
||||
has_more: bool = Field(alias="hasMore", default=False)
|
||||
|
||||
|
||||
class UnreadCountResponse(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
count: int = Field(ge=0)
|
||||
|
||||
|
||||
class MarkAllReadResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
updated_count: int = Field(alias="updatedCount", ge=0)
|
||||
@@ -0,0 +1,141 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from v1.notifications.repository import NotificationRepository
|
||||
from v1.notifications.schemas import (
|
||||
NotificationPayloadNone,
|
||||
NotificationPayloadRoute,
|
||||
NotificationPayloadUrl,
|
||||
NotificationPayload,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NotificationListItem:
|
||||
id: UUID
|
||||
notification_id: UUID
|
||||
type: str
|
||||
title: str
|
||||
body: str
|
||||
payload: NotificationPayload
|
||||
is_read: bool
|
||||
read_at: datetime | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NotificationListResult:
|
||||
items: list[NotificationListItem]
|
||||
next_cursor: datetime | None
|
||||
has_more: bool
|
||||
|
||||
|
||||
class NotificationService:
|
||||
def __init__(self, repository: NotificationRepository) -> None:
|
||||
self._repository = repository
|
||||
|
||||
async def list_notifications(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
limit: int = 20,
|
||||
cursor: datetime | None = None,
|
||||
) -> NotificationListResult:
|
||||
actual_limit = min(limit, 50)
|
||||
rows = await self._repository.list_notifications(
|
||||
user_id=user_id,
|
||||
limit=actual_limit + 1,
|
||||
cursor=cursor,
|
||||
)
|
||||
has_more = len(rows) > actual_limit
|
||||
items = rows[:actual_limit]
|
||||
next_cursor = None
|
||||
if has_more and items:
|
||||
next_cursor = items[-1][0].created_at
|
||||
|
||||
list_items = []
|
||||
for un, n in items:
|
||||
payload = _parse_payload(n.payload)
|
||||
list_items.append(
|
||||
NotificationListItem(
|
||||
id=un.id,
|
||||
notification_id=n.id,
|
||||
type=n.type,
|
||||
title=n.title,
|
||||
body=n.body,
|
||||
payload=payload,
|
||||
is_read=un.is_read,
|
||||
read_at=un.read_at,
|
||||
created_at=un.created_at,
|
||||
)
|
||||
)
|
||||
return NotificationListResult(
|
||||
items=list_items,
|
||||
next_cursor=next_cursor,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
async def get_unread_count(self, *, user_id: UUID) -> int:
|
||||
return await self._repository.get_unread_count(user_id=user_id)
|
||||
|
||||
async def mark_read(
|
||||
self, *, user_notification_id: UUID, user_id: UUID
|
||||
) -> NotificationListItem:
|
||||
result = await self._repository.get_user_notification(
|
||||
user_notification_id=user_notification_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if result is None:
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
detail=problem_payload(
|
||||
code="NOTIFICATION_NOT_FOUND",
|
||||
detail="Notification not found or not owned by current user",
|
||||
),
|
||||
)
|
||||
un, n = result
|
||||
if not un.is_read:
|
||||
await self._repository.mark_read(
|
||||
user_notification_id=user_notification_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
payload = _parse_payload(n.payload)
|
||||
return NotificationListItem(
|
||||
id=un.id,
|
||||
notification_id=n.id,
|
||||
type=n.type,
|
||||
title=n.title,
|
||||
body=n.body,
|
||||
payload=payload,
|
||||
is_read=True,
|
||||
read_at=un.read_at or datetime.now(),
|
||||
created_at=un.created_at,
|
||||
)
|
||||
|
||||
async def mark_all_read(self, *, user_id: UUID) -> int:
|
||||
return await self._repository.mark_all_read(user_id=user_id)
|
||||
|
||||
|
||||
def _parse_payload(raw: dict[str, object]) -> NotificationPayload:
|
||||
action = raw.get("action")
|
||||
if action == "none":
|
||||
return NotificationPayloadNone(action="none")
|
||||
if action == "open_route":
|
||||
return NotificationPayloadRoute(
|
||||
action="open_route",
|
||||
route=str(raw.get("route", "")),
|
||||
entity_id=str(raw["entity_id"])
|
||||
if "entity_id" in raw and raw["entity_id"] is not None
|
||||
else None,
|
||||
tab=str(raw["tab"]) if "tab" in raw and raw["tab"] is not None else None,
|
||||
)
|
||||
if action == "open_url":
|
||||
return NotificationPayloadUrl(
|
||||
action="open_url",
|
||||
url=str(raw.get("url", "")),
|
||||
)
|
||||
return NotificationPayloadNone(action="none")
|
||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter
|
||||
|
||||
from v1.agent.router import router as agent_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.notifications.router import router as notifications_router
|
||||
from v1.points.router import router as points_router
|
||||
from v1.users.router import router as users_router
|
||||
|
||||
@@ -11,5 +12,6 @@ from v1.users.router import router as users_router
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(notifications_router)
|
||||
router.include_router(points_router)
|
||||
router.include_router(users_router)
|
||||
|
||||
@@ -0,0 +1,381 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.notifications.service import NotificationService, _parse_payload
|
||||
from v1.notifications.schemas import (
|
||||
NotificationPayloadNone,
|
||||
NotificationPayloadRoute,
|
||||
NotificationPayloadUrl,
|
||||
)
|
||||
from core.http.errors import ApiProblemError
|
||||
|
||||
|
||||
class _FakeUserNotification:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id: UUID,
|
||||
user_id: UUID,
|
||||
notification_id: UUID,
|
||||
is_read: bool = False,
|
||||
read_at: datetime | None = None,
|
||||
created_at: datetime | None = None,
|
||||
):
|
||||
self.id = id
|
||||
self.user_id = user_id
|
||||
self.notification_id = notification_id
|
||||
self.is_read = is_read
|
||||
self.read_at = read_at
|
||||
self.created_at = created_at or datetime.now()
|
||||
|
||||
|
||||
class _FakeNotification:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id: UUID,
|
||||
type: str = "system",
|
||||
title: str = "Test",
|
||||
body: str = "Test body",
|
||||
payload: dict | None = None,
|
||||
status: str = "published",
|
||||
deleted_at: datetime | None = None,
|
||||
created_at: datetime | None = None,
|
||||
):
|
||||
self.id = id
|
||||
self.type = type
|
||||
self.title = title
|
||||
self.body = body
|
||||
self.payload = payload or {"action": "none"}
|
||||
self.status = status
|
||||
self.deleted_at = deleted_at
|
||||
self.created_at = created_at or datetime.now()
|
||||
|
||||
|
||||
class _FakeNotificationRepository:
|
||||
def __init__(self) -> None:
|
||||
self._items: list[tuple[_FakeUserNotification, _FakeNotification]] = []
|
||||
self._mark_read_ids: list[UUID] = []
|
||||
self._mark_all_read_user_ids: list[UUID] = []
|
||||
|
||||
def add_item(self, un: _FakeUserNotification, n: _FakeNotification) -> None:
|
||||
self._items.append((un, n))
|
||||
|
||||
async def list_notifications(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
limit: int = 20,
|
||||
cursor: datetime | None = None,
|
||||
) -> list[tuple[_FakeUserNotification, _FakeNotification]]:
|
||||
user_items = [
|
||||
(un, n)
|
||||
for un, n in self._items
|
||||
if un.user_id == user_id
|
||||
and n.status == "published"
|
||||
and n.deleted_at is None
|
||||
]
|
||||
if cursor is not None:
|
||||
user_items = [(un, n) for un, n in user_items if un.created_at < cursor]
|
||||
user_items.sort(key=lambda x: x[0].created_at, reverse=True)
|
||||
return user_items[:limit]
|
||||
|
||||
async def get_unread_count(self, *, user_id: UUID) -> int:
|
||||
return sum(
|
||||
1
|
||||
for un, n in self._items
|
||||
if un.user_id == user_id
|
||||
and not un.is_read
|
||||
and n.status == "published"
|
||||
and n.deleted_at is None
|
||||
)
|
||||
|
||||
async def get_user_notification(
|
||||
self,
|
||||
*,
|
||||
user_notification_id: UUID,
|
||||
user_id: UUID,
|
||||
) -> tuple[_FakeUserNotification, _FakeNotification] | None:
|
||||
for un, n in self._items:
|
||||
if un.id == user_notification_id and un.user_id == user_id:
|
||||
return (un, n)
|
||||
return None
|
||||
|
||||
async def mark_read(self, *, user_notification_id: UUID, user_id: UUID) -> bool:
|
||||
self._mark_read_ids.append(user_notification_id)
|
||||
for un, n in self._items:
|
||||
if un.id == user_notification_id and un.user_id == user_id:
|
||||
un.is_read = True
|
||||
un.read_at = datetime.now()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def mark_all_read(self, *, user_id: UUID) -> int:
|
||||
self._mark_all_read_user_ids.append(user_id)
|
||||
count = 0
|
||||
for un, n in self._items:
|
||||
if (
|
||||
un.user_id == user_id
|
||||
and not un.is_read
|
||||
and n.status == "published"
|
||||
and n.deleted_at is None
|
||||
):
|
||||
un.is_read = True
|
||||
un.read_at = datetime.now()
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_repo() -> _FakeNotificationRepository:
|
||||
return _FakeNotificationRepository()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(fake_repo: _FakeNotificationRepository) -> NotificationService:
|
||||
return NotificationService(repository=fake_repo) # type: ignore[arg-type]
|
||||
|
||||
|
||||
USER_A = uuid4()
|
||||
USER_B = uuid4()
|
||||
|
||||
|
||||
def _make_notification(
|
||||
*,
|
||||
user_id: UUID,
|
||||
notification_id: UUID | None = None,
|
||||
is_read: bool = False,
|
||||
read_at: datetime | None = None,
|
||||
title: str = "Test",
|
||||
body: str = "Test body",
|
||||
payload: dict | None = None,
|
||||
status: str = "published",
|
||||
deleted_at: datetime | None = None,
|
||||
) -> tuple[_FakeUserNotification, _FakeNotification]:
|
||||
nid = notification_id or uuid4()
|
||||
unid = uuid4()
|
||||
n = _FakeNotification(
|
||||
id=nid,
|
||||
title=title,
|
||||
body=body,
|
||||
payload=payload,
|
||||
status=status,
|
||||
deleted_at=deleted_at,
|
||||
)
|
||||
un = _FakeUserNotification(
|
||||
id=unid,
|
||||
user_id=user_id,
|
||||
notification_id=nid,
|
||||
is_read=is_read,
|
||||
read_at=read_at,
|
||||
)
|
||||
return un, n
|
||||
|
||||
|
||||
class TestListNotifications:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_only_user_a_notifications(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un_a, n_a = _make_notification(user_id=USER_A, title="A1")
|
||||
un_b, n_b = _make_notification(user_id=USER_B, title="B1")
|
||||
fake_repo.add_item(un_a, n_a)
|
||||
fake_repo.add_item(un_b, n_b)
|
||||
|
||||
result = await service.list_notifications(user_id=USER_A)
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].title == "A1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_excludes_revoked_notifications(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un, n = _make_notification(user_id=USER_A, status="revoked")
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
result = await service.list_notifications(user_id=USER_A)
|
||||
assert len(result.items) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_excludes_deleted_notifications(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un, n = _make_notification(user_id=USER_A, deleted_at=datetime.now())
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
result = await service.list_notifications(user_id=USER_A)
|
||||
assert len(result.items) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_has_more(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
for i in range(3):
|
||||
un, n = _make_notification(user_id=USER_A, title=f"N{i}")
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
result = await service.list_notifications(user_id=USER_A, limit=2)
|
||||
assert len(result.items) == 2
|
||||
assert result.has_more is True
|
||||
assert result.next_cursor is not None
|
||||
|
||||
|
||||
class TestUnreadCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_counts_unread_only(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un_read, n_read = _make_notification(user_id=USER_A, is_read=True)
|
||||
un_unread, n_unread = _make_notification(user_id=USER_A, is_read=False)
|
||||
fake_repo.add_item(un_read, n_read)
|
||||
fake_repo.add_item(un_unread, n_unread)
|
||||
|
||||
count = await service.get_unread_count(user_id=USER_A)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_excludes_revoked_from_count(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un, n = _make_notification(user_id=USER_A, status="revoked", is_read=False)
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
count = await service.get_unread_count(user_id=USER_A)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_b_unread_not_counted_for_user_a(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un, n = _make_notification(user_id=USER_B, is_read=False)
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
count = await service.get_unread_count(user_id=USER_A)
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestMarkRead:
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_read_success(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un, n = _make_notification(user_id=USER_A, is_read=False)
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
result = await service.mark_read(user_notification_id=un.id, user_id=USER_A)
|
||||
assert result.is_read is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_read_idempotent(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
now = datetime.now()
|
||||
un, n = _make_notification(user_id=USER_A, is_read=True, read_at=now)
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
result = await service.mark_read(user_notification_id=un.id, user_id=USER_A)
|
||||
assert result.is_read is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_read_wrong_user_raises_404(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un, n = _make_notification(user_id=USER_A)
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
with pytest.raises(ApiProblemError) as exc_info:
|
||||
await service.mark_read(user_notification_id=un.id, user_id=USER_B)
|
||||
assert exc_info.value.status_code == 404
|
||||
assert exc_info.value.code == "NOTIFICATION_NOT_FOUND"
|
||||
|
||||
|
||||
class TestMarkAllRead:
|
||||
@pytest.mark.asyncio
|
||||
async def test_marks_all_unread_as_read(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un1, n1 = _make_notification(user_id=USER_A, is_read=False)
|
||||
un2, n2 = _make_notification(user_id=USER_A, is_read=False)
|
||||
fake_repo.add_item(un1, n1)
|
||||
fake_repo.add_item(un2, n2)
|
||||
|
||||
updated = await service.mark_all_read(user_id=USER_A)
|
||||
assert updated == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idempotent_when_all_read(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un, n = _make_notification(user_id=USER_A, is_read=True)
|
||||
fake_repo.add_item(un, n)
|
||||
|
||||
updated = await service.mark_all_read(user_id=USER_A)
|
||||
assert updated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_affect_other_user(
|
||||
self, service: NotificationService, fake_repo: _FakeNotificationRepository
|
||||
):
|
||||
un_a, n_a = _make_notification(user_id=USER_A, is_read=False)
|
||||
un_b, n_b = _make_notification(user_id=USER_B, is_read=False)
|
||||
fake_repo.add_item(un_a, n_a)
|
||||
fake_repo.add_item(un_b, n_b)
|
||||
|
||||
updated = await service.mark_all_read(user_id=USER_A)
|
||||
assert updated == 1
|
||||
assert un_b.is_read is False
|
||||
|
||||
|
||||
class TestParsePayload:
|
||||
def test_none_action(self):
|
||||
payload = _parse_payload({"action": "none"})
|
||||
assert isinstance(payload, NotificationPayloadNone)
|
||||
assert payload.action == "none"
|
||||
|
||||
def test_open_route_action(self):
|
||||
payload = _parse_payload(
|
||||
{
|
||||
"action": "open_route",
|
||||
"route": "/history",
|
||||
"entity_id": "abc-123",
|
||||
"tab": "details",
|
||||
}
|
||||
)
|
||||
assert isinstance(payload, NotificationPayloadRoute)
|
||||
assert payload.route == "/history"
|
||||
assert payload.entity_id == "abc-123"
|
||||
assert payload.tab == "details"
|
||||
|
||||
def test_open_url_action(self):
|
||||
payload = _parse_payload(
|
||||
{
|
||||
"action": "open_url",
|
||||
"url": "https://example.com",
|
||||
}
|
||||
)
|
||||
assert isinstance(payload, NotificationPayloadUrl)
|
||||
assert payload.url == "https://example.com"
|
||||
|
||||
def test_unknown_action_defaults_to_none(self):
|
||||
payload = _parse_payload({"action": "unknown"})
|
||||
assert isinstance(payload, NotificationPayloadNone)
|
||||
|
||||
def test_missing_action_defaults_to_none(self):
|
||||
payload = _parse_payload({})
|
||||
assert isinstance(payload, NotificationPayloadNone)
|
||||
|
||||
def test_open_route_minimal(self):
|
||||
payload = _parse_payload(
|
||||
{
|
||||
"action": "open_route",
|
||||
"route": "/settings",
|
||||
}
|
||||
)
|
||||
assert isinstance(payload, NotificationPayloadRoute)
|
||||
assert payload.route == "/settings"
|
||||
assert payload.entity_id is None
|
||||
assert payload.tab is None
|
||||
Reference in New Issue
Block a user