ff40ff9dd8
- 数据库:添加 has_purchased_starter_pack 字段到 register_bonus_claims - 后端:创建静态配置管理套餐信息,支持按国家/地区区分 - 后端:新增 GET /api/v1/points/packages API 返回可用套餐 - 后端:创建 utils/paths.py 统一路径管理 - 前端:动态获取套餐信息,移除硬编码 - 前端:添加 ProductCode 枚举约束,前后端类型安全 - 配置:Profile 默认国家改为 US(ISO 3166-1 alpha-2) - 文档:更新协议文档说明新 API 和字段
212 lines
7.3 KiB
Python
212 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
from decimal import Decimal
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from models.agent_chat_message import AgentChatMessage
|
|
from models.points_audit_ledger import PointsAuditLedger
|
|
from models.points_ledger import PointsLedger
|
|
from models.profile import Profile
|
|
from models.register_bonus_claims import RegisterBonusClaims
|
|
from models.user_points import UserPoints
|
|
from schemas.domain.points import (
|
|
AppendAuditLedgerCommand,
|
|
ApplyPointsChangeCommand,
|
|
PointsChargeSnapshot,
|
|
)
|
|
from schemas.enums import AgentChatMessageRole
|
|
|
|
|
|
class PointsRepository:
    """Data-access layer for points balances, ledgers and register-bonus claims.

    Every method runs its statements on the ``AsyncSession`` passed to the
    constructor. No method in this class commits; writes are at most flushed,
    so transaction boundaries remain with the caller.
    """

    # How many of the most recent assistant messages are scanned when looking
    # for the message that belongs to a given run.
    _RUN_LOOKUP_WINDOW = 20

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to *session*."""
        self._session = session

    @staticmethod
    def _non_negative_int(value: int | None) -> int:
        """Return *value* as an int clamped at zero, treating ``None`` as 0."""
        if value is None:
            return 0
        return max(int(value), 0)

    async def _ensure_user_points_row(self, *, user_id: UUID) -> None:
        """Insert a ``UserPoints`` row for *user_id* unless one already exists.

        Only ``user_id`` is supplied; a conflict on that column is ignored.
        """
        stmt = (
            insert(UserPoints)
            .values(user_id=user_id)
            .on_conflict_do_nothing(index_elements=[UserPoints.user_id])
        )
        await self._session.execute(stmt)

    async def get_or_create_user_points_for_update(
        self, *, user_id: UUID
    ) -> UserPoints:
        """Return the user's ``UserPoints`` row selected with ``FOR UPDATE``.

        The row is created first when missing, so the select always finds
        exactly one row.
        """
        await self._ensure_user_points_row(user_id=user_id)

        stmt = select(UserPoints).where(UserPoints.user_id == user_id).with_for_update()
        return (await self._session.execute(stmt)).scalar_one()

    async def has_ledger_event(self, *, user_id: UUID, event_id: str) -> bool:
        """Return whether a ledger entry exists for *user_id* and *event_id*."""
        stmt = (
            select(PointsLedger.id)
            .where(
                PointsLedger.user_id == user_id,
                PointsLedger.event_id == event_id,
            )
            # An existence check must not raise MultipleResultsFound if more
            # than one row happens to match.
            .limit(1)
        )
        found_id = (await self._session.execute(stmt)).scalar_one_or_none()
        return found_id is not None

    async def append_ledger(
        self,
        *,
        command: ApplyPointsChangeCommand,
        balance_after: int,
    ) -> None:
        """Add a ``PointsLedger`` entry built from *command* and flush it.

        ``balance_after`` is stored as given; enum-valued fields are persisted
        by their ``.value`` and the metadata model is dumped in JSON mode with
        ``None`` fields omitted.
        """
        biz_type = command.biz_type.value if command.biz_type is not None else None
        entry = PointsLedger(
            user_id=command.user_id,
            direction=command.direction,
            amount=command.amount,
            balance_after=balance_after,
            change_type=command.change_type.value,
            biz_type=biz_type,
            biz_id=command.biz_id,
            event_id=command.event_id,
            operator_id=command.operator_id,
            metadata_json=command.metadata.model_dump(mode="json", exclude_none=True),
        )
        self._session.add(entry)
        await self._session.flush()

    async def has_audit_event(self, *, event_id: str) -> bool:
        """Return whether an audit-ledger entry exists for *event_id*."""
        stmt = (
            select(PointsAuditLedger.id)
            .where(PointsAuditLedger.event_id == event_id)
            # Same reasoning as has_ledger_event: never raise on duplicates.
            .limit(1)
        )
        found_id = (await self._session.execute(stmt)).scalar_one_or_none()
        return found_id is not None

    async def append_audit_ledger(self, *, command: AppendAuditLedgerCommand) -> None:
        """Add a ``PointsAuditLedger`` entry built from *command* and flush it."""
        biz_type = command.biz_type.value if command.biz_type is not None else None
        entry = PointsAuditLedger(
            event_id=command.event_id,
            user_id_snapshot=command.user_id_snapshot,
            user_email_snapshot=command.user_email_snapshot,
            change_type=command.change_type.value,
            biz_type=biz_type,
            biz_id=command.biz_id,
            direction=command.direction,
            amount=command.amount,
            balance_after=command.balance_after,
            billed_to=command.billed_to,
            run_id=command.run_id,
            request_id=command.request_id,
            input_tokens=command.input_tokens,
            output_tokens=command.output_tokens,
            cost=command.cost,
            metadata_json=command.metadata.model_dump(mode="json", exclude_none=True),
        )
        self._session.add(entry)
        await self._session.flush()

    async def get_run_usage_snapshot(
        self,
        *,
        session_id: UUID,
        run_id: str,
    ) -> PointsChargeSnapshot | None:
        """Return the usage snapshot of the assistant message produced by a run.

        Only the ``_RUN_LOOKUP_WINDOW`` newest non-deleted assistant messages
        of the chat session (highest ``seq`` first) are inspected; the first
        one whose ``metadata_json["run_id"]`` equals *run_id* is used.

        Returns:
            A ``PointsChargeSnapshot`` for the matching message, or ``None``
            when no message inside the window carries that run id.
        """
        stmt = (
            select(AgentChatMessage)
            .where(
                AgentChatMessage.session_id == session_id,
                AgentChatMessage.role == AgentChatMessageRole.ASSISTANT,
                AgentChatMessage.deleted_at.is_(None),
            )
            .order_by(AgentChatMessage.seq.desc())
            .limit(self._RUN_LOOKUP_WINDOW)
        )
        recent_messages = (await self._session.execute(stmt)).scalars().all()

        # run_id lives inside the JSON metadata, so matching happens in Python.
        message = next(
            (
                candidate
                for candidate in recent_messages
                if (candidate.metadata_json or {}).get("run_id") == run_id
            ),
            None,
        )
        if message is None:
            return None

        cost_value = message.cost if message.cost is not None else Decimal("0")
        return PointsChargeSnapshot(
            message_id=message.id,
            message_seq=max(int(message.seq), 1),
            # Blank or missing model codes fall back to the generic label.
            model_code=(message.model_code or "agent_run").strip() or "agent_run",
            # NOTE(review): token counts are treated like cost — a missing
            # value counts as 0 instead of raising TypeError. Confirm whether
            # these columns can actually be NULL.
            input_tokens=self._non_negative_int(message.input_tokens),
            output_tokens=self._non_negative_int(message.output_tokens),
            cost=Decimal(str(cost_value)),
        )

    async def get_user_points(self, *, user_id: UUID) -> UserPoints:
        """Return the user's ``UserPoints`` row, creating it first if missing.

        Unlike ``get_or_create_user_points_for_update`` the select carries no
        ``FOR UPDATE`` clause.
        """
        await self._ensure_user_points_row(user_id=user_id)

        stmt = select(UserPoints).where(UserPoints.user_id == user_id)
        return (await self._session.execute(stmt)).scalar_one()

    async def claim_register_bonus(
        self,
        *,
        email_hash: str,
        user_email_snapshot: str,
        first_user_id_snapshot: UUID,
        grant_event_id: str,
    ) -> bool:
        """Try to record a register-bonus claim for *email_hash*.

        The insert is skipped when a claim with the same ``email_hash``
        already exists.

        Returns:
            ``True`` when a new claim row was inserted, ``False`` when the
            insert was skipped because of the conflict.
        """
        stmt = (
            insert(RegisterBonusClaims)
            .values(
                email_hash=email_hash,
                user_email_snapshot=user_email_snapshot,
                first_user_id_snapshot=first_user_id_snapshot,
                grant_event_id=grant_event_id,
            )
            .on_conflict_do_nothing(index_elements=[RegisterBonusClaims.email_hash])
            # RETURNING yields a row only when the insert actually happened.
            .returning(RegisterBonusClaims.id)
        )
        inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
        return inserted_id is not None

    async def get_register_bonus_claim(
        self,
        *,
        email_hash: str,
    ) -> RegisterBonusClaims | None:
        """Return the register-bonus claim for *email_hash*, or ``None``."""
        stmt = (
            select(RegisterBonusClaims)
            .where(RegisterBonusClaims.email_hash == email_hash)
            .limit(1)
        )
        return (await self._session.execute(stmt)).scalar_one_or_none()

    async def update_register_bonus_balance_snapshot(
        self,
        *,
        email_hash: str,
        balance_snapshot: int,
    ) -> bool:
        """Store *balance_snapshot* on the claim for *email_hash* and flush.

        Returns:
            ``False`` when no claim exists for *email_hash*, ``True`` after
            the claim was updated.
        """
        claim = await self.get_register_bonus_claim(email_hash=email_hash)
        if claim is None:
            return False

        claim.balance_snapshot = int(balance_snapshot)
        await self._session.flush()
        return True

    async def has_purchased_starter_pack(
        self,
        *,
        email_hash: str,
    ) -> bool:
        """Return the claim's ``has_purchased_starter_pack`` flag as a bool.

        A missing claim is reported as ``False``.
        """
        claim = await self.get_register_bonus_claim(email_hash=email_hash)
        if claim is None:
            return False
        return bool(claim.has_purchased_starter_pack)

    async def get_profile_settings(
        self,
        *,
        user_id: UUID,
    ) -> dict[str, object] | None:
        """Return ``Profile.settings`` for the profile whose id is *user_id*.

        ``None`` is returned when no such profile row exists, and also when
        the stored settings value itself is NULL.
        """
        stmt = select(Profile.settings).where(Profile.id == user_id).limit(1)
        return (await self._session.execute(stmt)).scalar_one_or_none()
|