feat: integrate invite API and improve notification handling
- Add invite code display and binding functionality via API - Fix notification unread count sync on auth state change - Improve notification mark read with server state validation - Add auth state listener to trigger notification refresh - Add YaoCoinConverter for coin-to-yao type conversion - Remove YaoLegend from divination screens (UI cleanup) - Abbreviate relation labels in yao detail view - Add re-register notice to account delete screen - Update 'coins' terminology to 'points' in localization - Fix backend points consumption to only run in CHAT mode - Add HttpxAuthNoiseFilter to suppress auth endpoint logging - Fix notification static_schema import path - Add test coverage for notification bloc error handling - Update AGENTS.md page header rules and image handling - Delete deprecated run-dev.sh script
This commit is contained in:
@@ -384,13 +384,14 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
runtime_config=runtime_config,
|
||||
cancel_checker=_cancel_checker,
|
||||
)
|
||||
await points_service.consume_successful_run_points(
|
||||
user_id=owner_id,
|
||||
session_id=UUID(thread_id),
|
||||
run_id=run_id,
|
||||
operator_id=owner_id,
|
||||
user_email=owner_email,
|
||||
)
|
||||
if runtime_mode == RuntimeMode.CHAT:
|
||||
await points_service.consume_successful_run_points(
|
||||
user_id=owner_id,
|
||||
session_id=UUID(thread_id),
|
||||
run_id=run_id,
|
||||
operator_id=owner_id,
|
||||
user_email=owner_email,
|
||||
)
|
||||
await session.commit()
|
||||
except asyncio.CancelledError:
|
||||
await points_service.record_failed_run_platform_cost(
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import UUID
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
|
||||
|
||||
from v1.notifications.schemas import (
|
||||
from backend.src.schemas.shared.notification import (
|
||||
NotificationPayload,
|
||||
NotificationPayloadNone,
|
||||
)
|
||||
|
||||
@@ -39,6 +39,7 @@ def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
|
||||
file_path=log_dir / runtime.log_file_name,
|
||||
level=runtime.log_level,
|
||||
formatter=formatter_name,
|
||||
filters=["suppress_httpx_auth_noise"],
|
||||
)
|
||||
error_handler = build_file_handler_config(
|
||||
runtime,
|
||||
@@ -54,7 +55,10 @@ def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
|
||||
"filters": {
|
||||
"error_only": {
|
||||
"()": "core.logging.filters.ErrorLevelFilter",
|
||||
}
|
||||
},
|
||||
"suppress_httpx_auth_noise": {
|
||||
"()": "core.logging.filters.HttpxAuthNoiseFilter",
|
||||
},
|
||||
},
|
||||
"formatters": {
|
||||
"json": {
|
||||
|
||||
@@ -54,3 +54,16 @@ def build_sensitive_data_processor(
|
||||
class ErrorLevelFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return record.levelno >= logging.ERROR
|
||||
|
||||
|
||||
class HttpxAuthNoiseFilter(logging.Filter):
|
||||
_SUPPRESSED_FRAGMENTS = (
|
||||
"/auth/v1/user",
|
||||
"/auth/v1/token?grant_type=refresh_token",
|
||||
)
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.levelno >= logging.WARNING:
|
||||
return True
|
||||
message = record.getMessage()
|
||||
return not any(fragment in message for fragment in self._SUPPRESSED_FRAGMENTS)
|
||||
|
||||
@@ -25,7 +25,7 @@ class SpecialMark(str, Enum):
|
||||
|
||||
|
||||
class YaoDetail(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
position: int = Field(ge=1, le=6)
|
||||
spirit_name: str = Field(alias="spiritName", min_length=1)
|
||||
@@ -38,7 +38,7 @@ class YaoDetail(BaseModel):
|
||||
|
||||
|
||||
class FushenDetail(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
position: int = Field(ge=1, le=6)
|
||||
relation_name: str = Field(alias="relationName", min_length=1)
|
||||
@@ -47,7 +47,7 @@ class FushenDetail(BaseModel):
|
||||
|
||||
|
||||
class GanzhiDetail(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
year_gan_zhi: str = Field(alias="yearGanZhi", min_length=2, max_length=2)
|
||||
month_gan_zhi: str = Field(alias="monthGanZhi", min_length=2, max_length=2)
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class NotificationPayloadNone(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["none"]
|
||||
|
||||
|
||||
class NotificationPayloadRoute(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["open_route"]
|
||||
route: str = Field(max_length=200)
|
||||
entity_id: str | None = Field(default=None, max_length=64)
|
||||
tab: str | None = Field(default=None, max_length=32)
|
||||
|
||||
|
||||
class NotificationPayloadUrl(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["open_url"]
|
||||
url: str = Field(max_length=500)
|
||||
|
||||
|
||||
NotificationPayload = Union[
|
||||
NotificationPayloadNone,
|
||||
NotificationPayloadRoute,
|
||||
NotificationPayloadUrl,
|
||||
]
|
||||
@@ -170,7 +170,7 @@ class AgentRepository:
|
||||
session_row.last_activity_at = datetime.now(timezone.utc)
|
||||
await self._session.flush()
|
||||
|
||||
async def get_user_message_count(self, *, session_id: str) -> int:
|
||||
async def get_assistant_message_count(self, *, session_id: str) -> int:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
@@ -184,7 +184,7 @@ class AgentRepository:
|
||||
select(func.count(AgentChatMessage.id))
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.role == AgentChatMessageRole.USER)
|
||||
.where(AgentChatMessage.role == AgentChatMessageRole.ASSISTANT)
|
||||
)
|
||||
count = (await self._session.execute(stmt)).scalar_one()
|
||||
return int(count)
|
||||
@@ -266,7 +266,11 @@ class AgentRepository:
|
||||
).scalar_one_or_none() is not None
|
||||
snapshot_messages: list[dict[str, object]] = []
|
||||
for message in messages:
|
||||
snapshot_messages.append(await self._to_snapshot_message(message))
|
||||
snapshot_messages.append(
|
||||
(await self._to_chat_message_schema(message)).model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
)
|
||||
)
|
||||
return {
|
||||
"day": target_day.isoformat(),
|
||||
"hasMore": has_more,
|
||||
@@ -278,7 +282,7 @@ class AgentRepository:
|
||||
*,
|
||||
session_id: str,
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]:
|
||||
) -> list[AgentChatMessageSchema]:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
@@ -299,9 +303,9 @@ class AgentRepository:
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
messages = (await self._session.execute(message_stmt)).scalars().all()
|
||||
snapshot_messages: list[dict[str, object]] = []
|
||||
snapshot_messages: list[AgentChatMessageSchema] = []
|
||||
for message in messages:
|
||||
snapshot_messages.append(await self._to_snapshot_message(message))
|
||||
snapshot_messages.append(await self._to_chat_message_schema(message))
|
||||
return snapshot_messages
|
||||
|
||||
async def get_recent_messages_by_user_window(
|
||||
@@ -352,7 +356,11 @@ class AgentRepository:
|
||||
selected = list(reversed(selected_desc))
|
||||
snapshot_messages: list[dict[str, object]] = []
|
||||
for message in selected:
|
||||
snapshot_messages.append(await self._to_snapshot_message(message))
|
||||
snapshot_messages.append(
|
||||
(await self._to_chat_message_schema(message)).model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
)
|
||||
)
|
||||
return snapshot_messages
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
@@ -382,7 +390,7 @@ class AgentRepository:
|
||||
user_id: str,
|
||||
visibility_mask: int | None = None,
|
||||
session_limit: int = 50,
|
||||
) -> list[dict[str, object]]:
|
||||
) -> list[AgentChatMessageSchema]:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError as exc:
|
||||
@@ -404,7 +412,7 @@ class AgentRepository:
|
||||
if not session_ids:
|
||||
return []
|
||||
|
||||
snapshots: list[dict[str, object]] = []
|
||||
snapshots: list[AgentChatMessageSchema] = []
|
||||
for session_id in session_ids:
|
||||
message_stmt = (
|
||||
select(AgentChatMessage)
|
||||
@@ -423,10 +431,14 @@ class AgentRepository:
|
||||
)
|
||||
if not candidate_messages:
|
||||
continue
|
||||
selected_snapshot: dict[str, object] | None = None
|
||||
selected_snapshot: AgentChatMessageSchema | None = None
|
||||
for message in candidate_messages:
|
||||
snapshot = await self._to_snapshot_message(message)
|
||||
metadata = snapshot.get("metadata")
|
||||
snapshot = await self._to_chat_message_schema(message)
|
||||
metadata = (
|
||||
snapshot.metadata.model_dump(mode="json", exclude_none=True)
|
||||
if snapshot.metadata is not None
|
||||
else None
|
||||
)
|
||||
if not isinstance(metadata, dict):
|
||||
continue
|
||||
agent_output = metadata.get("agent_output")
|
||||
@@ -440,7 +452,7 @@ class AgentRepository:
|
||||
snapshots.append(selected_snapshot)
|
||||
|
||||
snapshots.sort(
|
||||
key=lambda item: str(item.get("timestamp") or ""),
|
||||
key=lambda item: str(item.timestamp),
|
||||
reverse=True,
|
||||
)
|
||||
return snapshots
|
||||
@@ -462,9 +474,9 @@ class AgentRepository:
|
||||
"config": config_payload,
|
||||
}
|
||||
|
||||
async def _to_snapshot_message(
|
||||
async def _to_chat_message_schema(
|
||||
self, message: AgentChatMessage
|
||||
) -> dict[str, object]:
|
||||
) -> AgentChatMessageSchema:
|
||||
role = (
|
||||
message.role.value
|
||||
if isinstance(message.role, AgentChatMessageRole)
|
||||
@@ -487,7 +499,7 @@ class AgentRepository:
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
return payload_model.model_dump(mode="json", exclude_none=True)
|
||||
return payload_model
|
||||
|
||||
def _apply_visibility_filter(
|
||||
self,
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from schemas.agent.runtime_models import ErrorInfo
|
||||
from schemas.domain.chat_message import AgentChatMessage
|
||||
from schemas.domain.divination import DerivedDivinationData
|
||||
|
||||
|
||||
@@ -37,7 +38,7 @@ class AgentRepositoryLike(Protocol):
|
||||
*,
|
||||
session_id: str,
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
) -> list[AgentChatMessage]: ...
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
|
||||
|
||||
@@ -47,7 +48,7 @@ class AgentRepositoryLike(Protocol):
|
||||
user_id: str,
|
||||
visibility_mask: int | None = None,
|
||||
session_limit: int = 50,
|
||||
) -> list[dict[str, object]]: ...
|
||||
) -> list[AgentChatMessage]: ...
|
||||
|
||||
async def persist_user_message(
|
||||
self,
|
||||
@@ -58,7 +59,7 @@ class AgentRepositoryLike(Protocol):
|
||||
visibility_mask: int,
|
||||
) -> None: ...
|
||||
|
||||
async def get_user_message_count(self, *, session_id: str) -> int: ...
|
||||
async def get_assistant_message_count(self, *, session_id: str) -> int: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
|
||||
@@ -46,7 +46,7 @@ from v1.agent.utils import (
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
MAX_RUNS_PER_SESSION = 2
|
||||
MAX_ASSISTANT_MESSAGES_PER_SESSION = 2
|
||||
|
||||
|
||||
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
|
||||
@@ -151,6 +151,7 @@ class AgentService:
|
||||
await self._enforce_run_preconditions(
|
||||
thread_id=thread_id,
|
||||
current_user=current_user,
|
||||
runtime_mode=runtime_mode,
|
||||
)
|
||||
except ApiProblemError:
|
||||
if created:
|
||||
@@ -247,7 +248,7 @@ class AgentService:
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
) -> None:
|
||||
metadata_payload = (
|
||||
metadata.model_dump(mode="json", exclude_none=True)
|
||||
metadata.model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
if isinstance(metadata, AgentChatMessageMetadata)
|
||||
else None
|
||||
)
|
||||
@@ -494,19 +495,23 @@ class AgentService:
|
||||
*,
|
||||
thread_id: str,
|
||||
current_user: CurrentUser,
|
||||
runtime_mode: RuntimeMode,
|
||||
) -> None:
|
||||
await self._points_service.ensure_run_points_available(user_id=current_user.id)
|
||||
if runtime_mode == RuntimeMode.CHAT:
|
||||
await self._points_service.ensure_run_points_available(
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
user_message_count = await self._repository.get_user_message_count(
|
||||
assistant_message_count = await self._repository.get_assistant_message_count(
|
||||
session_id=thread_id
|
||||
)
|
||||
if user_message_count >= MAX_RUNS_PER_SESSION:
|
||||
if assistant_message_count >= MAX_ASSISTANT_MESSAGES_PER_SESSION:
|
||||
raise ApiProblemError(
|
||||
status_code=409,
|
||||
detail=problem_payload(
|
||||
code="AGENT_SESSION_RUN_LIMIT_EXCEEDED",
|
||||
detail="Session run limit exceeded",
|
||||
params={"maxRuns": MAX_RUNS_PER_SESSION},
|
||||
params={"maxRuns": MAX_ASSISTANT_MESSAGES_PER_SESSION},
|
||||
),
|
||||
)
|
||||
|
||||
@@ -597,7 +602,6 @@ class AgentService:
|
||||
thread_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> HistorySnapshotResponse:
|
||||
from schemas.domain.chat_message import AgentChatMessage
|
||||
from v1.agent.utils import convert_message_to_history
|
||||
from v1.agent.schemas import HistoryMessage
|
||||
|
||||
@@ -609,11 +613,9 @@ class AgentService:
|
||||
)
|
||||
|
||||
messages: list[HistoryMessage] = []
|
||||
for msg_dict in raw_messages:
|
||||
msg = AgentChatMessage.model_validate(msg_dict)
|
||||
if msg.role == "tool":
|
||||
for msg in raw_messages:
|
||||
if msg.role not in {"user", "assistant"}:
|
||||
continue
|
||||
|
||||
signed_urls: dict[str, str] = {}
|
||||
attachments = extract_user_message_attachments(msg.metadata)
|
||||
if self._attachment_storage and attachments:
|
||||
@@ -653,7 +655,6 @@ class AgentService:
|
||||
current_user: CurrentUser,
|
||||
thread_id: str | None,
|
||||
) -> HistorySnapshotResponse:
|
||||
from schemas.domain.chat_message import AgentChatMessage
|
||||
from v1.agent.utils import convert_message_to_history
|
||||
from v1.agent.schemas import HistoryMessage
|
||||
|
||||
@@ -675,8 +676,9 @@ class AgentService:
|
||||
visible_messages = raw_messages[:summary_limit]
|
||||
|
||||
messages: list[HistoryMessage] = []
|
||||
for msg_dict in visible_messages:
|
||||
msg = AgentChatMessage.model_validate(msg_dict)
|
||||
for msg in visible_messages:
|
||||
if msg.role != "assistant":
|
||||
continue
|
||||
converted = convert_message_to_history(msg)
|
||||
messages.append(HistoryMessage.model_validate(converted))
|
||||
|
||||
|
||||
@@ -7,6 +7,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from core.logging import get_logger
|
||||
from v1.notifications.repository import NotificationRepository
|
||||
from v1.notifications.service import NotificationService
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.schemas import (
|
||||
@@ -22,6 +25,7 @@ from v1.points.service import PointsService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
logger = get_logger("v1.auth.router")
|
||||
|
||||
|
||||
@router.post("/otp/send", status_code=204)
|
||||
@@ -73,7 +77,16 @@ async def create_email_session(
|
||||
user_id=UUID(result.user.id),
|
||||
user_email=result.user.email,
|
||||
)
|
||||
notification_service = NotificationService(NotificationRepository(session))
|
||||
linked_count = await notification_service.link_published_notifications_to_user(
|
||||
user_id=UUID(result.user.id)
|
||||
)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Linked published notifications for authenticated user",
|
||||
user_id=result.user.id,
|
||||
linked_count=linked_count,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from v1.invite.repository import InviteCodeRepository
|
||||
from v1.invite.service import InviteCodeService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
def get_invite_code_repository(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> InviteCodeRepository:
|
||||
return InviteCodeRepository(session)
|
||||
|
||||
|
||||
def get_invite_code_service(
|
||||
repository: Annotated[InviteCodeRepository, Depends(get_invite_code_repository)],
|
||||
) -> InviteCodeService:
|
||||
return InviteCodeService(repository=repository)
|
||||
|
||||
|
||||
def get_current_user_for_invite(
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> CurrentUser:
|
||||
return current_user
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.invite_code import InviteCode
|
||||
|
||||
|
||||
class InviteCodeRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_by_owner_id(self, *, owner_id: UUID) -> InviteCode | None:
|
||||
stmt = (
|
||||
select(InviteCode)
|
||||
.where(InviteCode.owner_id == owner_id)
|
||||
.order_by(InviteCode.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.invite.dependencies import (
|
||||
get_current_user_for_invite,
|
||||
get_invite_code_service,
|
||||
)
|
||||
from v1.invite.schemas import MyInviteCodeResponse
|
||||
from v1.invite.service import InviteCodeService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/invite", tags=["invite"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=MyInviteCodeResponse)
|
||||
async def get_my_invite_code(
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user_for_invite)],
|
||||
service: Annotated[InviteCodeService, Depends(get_invite_code_service)],
|
||||
) -> MyInviteCodeResponse:
|
||||
return await service.get_my_invite_code(user_id=current_user.id)
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class MyInviteCodeResponse(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
code: str
|
||||
used_count: int
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from v1.invite.repository import InviteCodeRepository
|
||||
from v1.invite.schemas import MyInviteCodeResponse
|
||||
|
||||
|
||||
@dataclass
|
||||
class InviteCodeService:
|
||||
repository: InviteCodeRepository
|
||||
|
||||
async def get_my_invite_code(self, *, user_id: UUID) -> MyInviteCodeResponse:
|
||||
invite_code = await self.repository.get_by_owner_id(owner_id=user_id)
|
||||
if invite_code is None:
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
detail=problem_payload(
|
||||
code="INVITE_CODE_NOT_FOUND",
|
||||
detail="Invite code not found for current user",
|
||||
),
|
||||
)
|
||||
return MyInviteCodeResponse(
|
||||
code=invite_code.code,
|
||||
used_count=invite_code.used_count,
|
||||
)
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -111,3 +112,37 @@ class NotificationRepository:
|
||||
await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return count
|
||||
|
||||
async def commit(self) -> None:
|
||||
await self._session.commit()
|
||||
|
||||
async def link_published_notifications_to_user(self, *, user_id: UUID) -> int:
|
||||
notification_ids = list(
|
||||
(
|
||||
await self._session.execute(
|
||||
select(Notification.id).where(
|
||||
Notification.status == "published",
|
||||
Notification.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
if not notification_ids:
|
||||
return 0
|
||||
|
||||
stmt = (
|
||||
insert(UserNotification)
|
||||
.values(
|
||||
[
|
||||
{"user_id": user_id, "notification_id": notification_id}
|
||||
for notification_id in notification_ids
|
||||
]
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["user_id", "notification_id"])
|
||||
.returning(UserNotification.id)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from core.logging import get_logger
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.notifications.dependencies import get_notification_service
|
||||
from v1.notifications.schemas import (
|
||||
@@ -16,6 +17,7 @@ from v1.notifications.service import NotificationService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
logger = get_logger("v1.notifications.router")
|
||||
|
||||
|
||||
@router.get("", response_model=NotificationListResponse)
|
||||
@@ -39,6 +41,13 @@ async def list_notifications(
|
||||
limit=limit,
|
||||
cursor=parsed_cursor,
|
||||
)
|
||||
logger.info(
|
||||
"Notification list fetched",
|
||||
user_id=str(current_user.id),
|
||||
limit=limit,
|
||||
item_count=len(result.items),
|
||||
has_more=result.has_more,
|
||||
)
|
||||
items = []
|
||||
for item in result.items:
|
||||
items.append(
|
||||
@@ -67,6 +76,11 @@ async def get_unread_count(
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> UnreadCountResponse:
|
||||
count = await service.get_unread_count(user_id=current_user.id)
|
||||
logger.info(
|
||||
"Notification unread count fetched",
|
||||
user_id=str(current_user.id),
|
||||
count=count,
|
||||
)
|
||||
return UnreadCountResponse(count=count)
|
||||
|
||||
|
||||
@@ -95,6 +109,11 @@ async def mark_notification_read(
|
||||
user_notification_id=uid,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
logger.info(
|
||||
"Notification marked as read",
|
||||
user_id=str(current_user.id),
|
||||
user_notification_id=str(uid),
|
||||
)
|
||||
return NotificationItemResponse(
|
||||
id=str(item.id),
|
||||
notificationId=str(item.notification_id),
|
||||
@@ -114,4 +133,9 @@ async def mark_all_read(
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> MarkAllReadResponse:
|
||||
updated_count = await service.mark_all_read(user_id=current_user.id)
|
||||
logger.info(
|
||||
"All notifications marked as read",
|
||||
user_id=str(current_user.id),
|
||||
updated_count=updated_count,
|
||||
)
|
||||
return MarkAllReadResponse(updatedCount=updated_count)
|
||||
|
||||
@@ -1,35 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from schemas.shared.notification import (
|
||||
NotificationPayload,
|
||||
NotificationPayloadNone,
|
||||
NotificationPayloadRoute,
|
||||
NotificationPayloadUrl,
|
||||
)
|
||||
|
||||
class NotificationPayloadNone(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["none"]
|
||||
|
||||
|
||||
class NotificationPayloadRoute(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["open_route"]
|
||||
route: str = Field(max_length=200)
|
||||
entity_id: str | None = Field(default=None, max_length=64)
|
||||
tab: str | None = Field(default=None, max_length=32)
|
||||
|
||||
|
||||
class NotificationPayloadUrl(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["open_url"]
|
||||
url: str = Field(max_length=500)
|
||||
|
||||
|
||||
NotificationPayload = Union[
|
||||
NotificationPayloadNone, NotificationPayloadRoute, NotificationPayloadUrl
|
||||
__all__ = [
|
||||
"NotificationPayload",
|
||||
"NotificationPayloadNone",
|
||||
"NotificationPayloadRoute",
|
||||
"NotificationPayloadUrl",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ class NotificationService:
|
||||
user_notification_id=user_notification_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
await self._repository.commit()
|
||||
payload = _parse_payload(n.payload)
|
||||
return NotificationListItem(
|
||||
id=un.id,
|
||||
@@ -117,7 +118,15 @@ class NotificationService:
|
||||
)
|
||||
|
||||
async def mark_all_read(self, *, user_id: UUID) -> int:
|
||||
return await self._repository.mark_all_read(user_id=user_id)
|
||||
updated_count = await self._repository.mark_all_read(user_id=user_id)
|
||||
if updated_count > 0:
|
||||
await self._repository.commit()
|
||||
return updated_count
|
||||
|
||||
async def link_published_notifications_to_user(self, *, user_id: UUID) -> int:
|
||||
return await self._repository.link_published_notifications_to_user(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def _parse_payload(raw: dict[str, object]) -> NotificationPayload:
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter
|
||||
|
||||
from v1.agent.router import router as agent_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.invite.router import router as invite_router
|
||||
from v1.notifications.router import router as notifications_router
|
||||
from v1.points.router import router as points_router
|
||||
from v1.users.router import router as users_router
|
||||
@@ -12,6 +13,7 @@ from v1.users.router import router as users_router
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(invite_router)
|
||||
router.include_router(notifications_router)
|
||||
router.include_router(points_router)
|
||||
router.include_router(users_router)
|
||||
|
||||
@@ -3,9 +3,11 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.invite_code import InviteCode
|
||||
from models.points_audit_ledger import PointsAuditLedger
|
||||
from models.profile import Profile
|
||||
|
||||
|
||||
@@ -35,3 +37,28 @@ class SQLAlchemyUserRepository:
|
||||
|
||||
async def save(self) -> None:
|
||||
await self.session.commit()
|
||||
|
||||
async def delete_invite_codes_by_owner_id(self, *, user_id: UUID) -> int:
|
||||
stmt = delete(InviteCode).where(InviteCode.owner_id == user_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
async def delete_points_audit_snapshots(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
user_email: str | None,
|
||||
) -> int:
|
||||
if user_email:
|
||||
stmt = delete(PointsAuditLedger).where(
|
||||
or_(
|
||||
PointsAuditLedger.user_id_snapshot == user_id,
|
||||
PointsAuditLedger.user_email_snapshot == user_email,
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = delete(PointsAuditLedger).where(
|
||||
PointsAuditLedger.user_id_snapshot == user_id
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
@@ -296,7 +296,9 @@ class UserService:
|
||||
user_id = str(self.current_user.id)
|
||||
avatar_bucket = config.storage.avatar.bucket
|
||||
avatar_prefix = f"{self.current_user.id}/"
|
||||
points_repository = PointsRepository(self.repository.session)
|
||||
session = self.repository.session
|
||||
points_repository = PointsRepository(session) if session is not None else None
|
||||
normalized_email = (self.current_user.email or "").strip().lower() or None
|
||||
|
||||
try:
|
||||
await self.attachment_storage.delete_prefix(
|
||||
@@ -318,30 +320,51 @@ class UserService:
|
||||
),
|
||||
) from exc
|
||||
|
||||
try:
|
||||
user_email = (self.current_user.email or "").strip().lower()
|
||||
if user_email:
|
||||
email_hash = PointsService._build_register_bonus_email_hash(user_email)
|
||||
account = await points_repository.get_user_points(
|
||||
user_id=self.current_user.id
|
||||
if session is not None and points_repository is not None:
|
||||
try:
|
||||
deleted_invite_codes = (
|
||||
await self.repository.delete_invite_codes_by_owner_id(
|
||||
user_id=self.current_user.id
|
||||
)
|
||||
)
|
||||
await points_repository.update_register_bonus_balance_snapshot(
|
||||
email_hash=email_hash,
|
||||
balance_snapshot=int(account.balance),
|
||||
deleted_audit_rows = (
|
||||
await self.repository.delete_points_audit_snapshots(
|
||||
user_id=self.current_user.id,
|
||||
user_email=normalized_email,
|
||||
)
|
||||
)
|
||||
await self.repository.session.commit()
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Account deletion failed while persisting points snapshot",
|
||||
user_id=user_id,
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=502,
|
||||
detail=problem_payload(
|
||||
code="PROFILE_DELETE_FAILED",
|
||||
detail="Failed to delete account data",
|
||||
),
|
||||
) from exc
|
||||
|
||||
if normalized_email:
|
||||
email_hash = PointsService._build_register_bonus_email_hash(
|
||||
normalized_email
|
||||
)
|
||||
account = await points_repository.get_user_points(
|
||||
user_id=self.current_user.id
|
||||
)
|
||||
await points_repository.update_register_bonus_balance_snapshot(
|
||||
email_hash=email_hash,
|
||||
balance_snapshot=int(account.balance),
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"Account deletion local data cleanup completed",
|
||||
user_id=user_id,
|
||||
invite_codes_deleted=deleted_invite_codes,
|
||||
points_audit_rows_deleted=deleted_audit_rows,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Account deletion failed while cleaning local data",
|
||||
user_id=user_id,
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=502,
|
||||
detail=problem_payload(
|
||||
code="PROFILE_DELETE_FAILED",
|
||||
detail="Failed to delete account data",
|
||||
),
|
||||
) from exc
|
||||
|
||||
try:
|
||||
await self.attachment_storage.delete_auth_user(user_id=user_id)
|
||||
|
||||
Reference in New Issue
Block a user