feat: 添加 points_audit_ledger 及 JSON 字段 Pydantic Schema 约束

This commit is contained in:
qzl
2026-04-10 12:28:18 +08:00
parent 46513829cd
commit 0ac8b81a66
34 changed files with 2595 additions and 1757 deletions
+98 -1
View File
@@ -1,14 +1,23 @@
from __future__ import annotations
from decimal import Decimal
from uuid import UUID
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage
from models.points_audit_ledger import PointsAuditLedger
from models.points_ledger import PointsLedger
from models.register_bonus_claims import RegisterBonusClaims
from models.user_points import UserPoints
from schemas.domain.points import ApplyPointsChangeCommand
from schemas.domain.points import (
AppendAuditLedgerCommand,
ApplyPointsChangeCommand,
PointsChargeSnapshot,
)
from schemas.enums import AgentChatMessageRole
class PointsRepository:
@@ -57,6 +66,72 @@ class PointsRepository:
self._session.add(entry)
await self._session.flush()
async def has_audit_event(self, *, event_id: str) -> bool:
stmt = select(PointsAuditLedger.id).where(
PointsAuditLedger.event_id == event_id
)
row = (await self._session.execute(stmt)).scalar_one_or_none()
return row is not None
async def append_audit_ledger(self, *, command: AppendAuditLedgerCommand) -> None:
entry = PointsAuditLedger(
event_id=command.event_id,
user_id_snapshot=command.user_id_snapshot,
user_email_snapshot=command.user_email_snapshot,
change_type=command.change_type.value,
biz_type=command.biz_type.value if command.biz_type is not None else None,
biz_id=command.biz_id,
direction=command.direction,
amount=command.amount,
balance_after=command.balance_after,
billed_to=command.billed_to,
run_id=command.run_id,
request_id=command.request_id,
input_tokens=command.input_tokens,
output_tokens=command.output_tokens,
cost=command.cost,
metadata_json=command.metadata,
)
self._session.add(entry)
await self._session.flush()
async def get_run_usage_snapshot(
self,
*,
session_id: UUID,
run_id: str,
) -> PointsChargeSnapshot | None:
stmt = (
select(AgentChatMessage)
.where(
AgentChatMessage.session_id == session_id,
AgentChatMessage.role == AgentChatMessageRole.ASSISTANT,
AgentChatMessage.deleted_at.is_(None),
)
.order_by(AgentChatMessage.seq.desc())
.limit(20)
)
messages = list((await self._session.execute(stmt)).scalars().all())
message = None
for candidate in messages:
metadata = candidate.metadata_json or {}
if metadata.get("run_id") == run_id:
message = candidate
break
if message is None:
return None
cost_value = message.cost if message.cost is not None else Decimal("0")
return PointsChargeSnapshot(
message_id=message.id,
message_seq=max(int(message.seq), 1),
model_code=(message.model_code or "agent_run").strip() or "agent_run",
input_tokens=max(int(message.input_tokens), 0),
output_tokens=max(int(message.output_tokens), 0),
cost=Decimal(str(cost_value)),
)
async def get_user_points(self, *, user_id: UUID) -> UserPoints:
insert_stmt = (
insert(UserPoints)
@@ -67,3 +142,25 @@ class PointsRepository:
stmt = select(UserPoints).where(UserPoints.user_id == user_id)
return (await self._session.execute(stmt)).scalar_one()
async def claim_register_bonus(
self,
*,
email_hash: str,
user_email_snapshot: str,
first_user_id: UUID,
grant_event_id: str,
) -> bool:
stmt = (
insert(RegisterBonusClaims)
.values(
email_hash=email_hash,
user_email_snapshot=user_email_snapshot,
first_user_id=first_user_id,
grant_event_id=grant_event_id,
)
.on_conflict_do_nothing(index_elements=[RegisterBonusClaims.email_hash])
.returning(RegisterBonusClaims.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
return inserted_id is not None
+251 -10
View File
@@ -3,10 +3,19 @@ from __future__ import annotations
from dataclasses import dataclass
from decimal import Decimal
import hashlib
import hmac
from typing import Literal
from uuid import UUID, uuid4
from core.config.settings import config
from core.http.errors import ApiProblemError, problem_payload
from schemas.domain.points import ConsumeLedgerMetadata, PointsChargeSnapshot
from schemas.domain.points import (
AppendAuditLedgerCommand,
AuditLedgerMetadata,
ConsumeLedgerMetadata,
RegisterLedgerMetadata,
PointsChargeSnapshot,
)
from schemas.enums import PointsBizType, PointsChangeType, PointsOperatorType
from schemas.domain.points import ApplyPointsChangeCommand
from v1.points.repository import PointsRepository
@@ -31,10 +40,128 @@ class PointsBalanceResult:
can_run: bool
@dataclass(frozen=True)
class PlatformCostAuditResult:
audited: bool
event_id: str
cost: Decimal
@dataclass(frozen=True)
class RegisterBonusResult:
granted: bool
amount: int
balance_after: int
event_id: str
class PointsService:
def __init__(self, repository: PointsRepository) -> None:
self._repository = repository
async def grant_register_bonus_if_eligible(
self,
*,
user_id: UUID,
user_email: str,
) -> RegisterBonusResult:
normalized_email = self._normalize_email(user_email)
if not normalized_email:
return RegisterBonusResult(
granted=False,
amount=0,
balance_after=0,
event_id="",
)
bonus_points = int(config.points_policy.register_bonus_points)
if bonus_points <= 0:
account = await self._repository.get_or_create_user_points_for_update(
user_id=user_id
)
return RegisterBonusResult(
granted=False,
amount=0,
balance_after=int(account.balance),
event_id="",
)
email_hash = self._build_register_bonus_email_hash(normalized_email)
event_hash = hashlib.sha1(
f"{normalized_email}:{email_hash}".encode("utf-8")
).hexdigest()
event_id = f"register.bonus:{event_hash}"
claimed = await self._repository.claim_register_bonus(
email_hash=email_hash,
user_email_snapshot=normalized_email,
first_user_id=user_id,
grant_event_id=event_id,
)
account = await self._repository.get_or_create_user_points_for_update(
user_id=user_id
)
if not claimed:
return RegisterBonusResult(
granted=False,
amount=0,
balance_after=int(account.balance),
event_id=event_id,
)
balance = int(account.balance)
account.balance = balance + bonus_points
account.lifetime_earned = int(account.lifetime_earned) + bonus_points
account.version = int(account.version) + 1
metadata = RegisterLedgerMetadata(
operator_type=PointsOperatorType.SYSTEM,
run_id=event_id,
ext={
"source": "register_bonus_policy",
"email_hash": email_hash,
},
)
command = ApplyPointsChangeCommand(
user_id=user_id,
change_type=PointsChangeType.REGISTER,
event_id=event_id,
amount=bonus_points,
direction=1,
operator_id=None,
metadata=metadata,
)
await self._repository.append_ledger(
command=command,
balance_after=int(account.balance),
)
await self._repository.append_audit_ledger(
command=AppendAuditLedgerCommand(
event_id=event_id,
user_id_snapshot=user_id,
user_email_snapshot=normalized_email,
change_type=PointsChangeType.REGISTER,
direction=1,
amount=bonus_points,
balance_after=int(account.balance),
billed_to="user",
run_id=event_id,
input_tokens=0,
output_tokens=0,
cost=Decimal("0"),
metadata=AuditLedgerMetadata(
source="register_bonus_policy",
email_hash=email_hash,
),
)
)
return RegisterBonusResult(
granted=True,
amount=bonus_points,
balance_after=int(account.balance),
event_id=event_id,
)
async def ensure_run_points_available(
self,
*,
@@ -84,6 +211,7 @@ class PointsService:
session_id: UUID,
run_id: str,
operator_id: UUID | None,
user_email: str | None = None,
) -> RunChargeResult:
event_source = f"{session_id}:{run_id}".encode("utf-8")
event_hash = hashlib.sha1(event_source).hexdigest()
@@ -122,18 +250,28 @@ class PointsService:
account.lifetime_spent = int(account.lifetime_spent) + RUN_POINTS_COST
account.version = int(account.version) + 1
usage_snapshot = await self._repository.get_run_usage_snapshot(
session_id=session_id,
run_id=run_id,
)
usage_missing = usage_snapshot is None
charge_snapshot = usage_snapshot or PointsChargeSnapshot(
message_id=uuid4(),
message_seq=1,
model_code="agent_run",
input_tokens=0,
output_tokens=0,
cost=Decimal("0"),
)
metadata = ConsumeLedgerMetadata(
operator_type=PointsOperatorType.USER,
run_id=run_id,
charge=PointsChargeSnapshot(
message_id=uuid4(),
message_seq=1,
model_code="agent_run",
input_tokens=0,
output_tokens=0,
cost=Decimal("0"),
),
ext={"source": "run_success"},
charge=charge_snapshot,
ext={
"source": "run_success",
"usage_missing": usage_missing,
},
)
command = ApplyPointsChangeCommand(
user_id=user_id,
@@ -150,9 +288,112 @@ class PointsService:
command=command,
balance_after=int(account.balance),
)
await self._repository.append_audit_ledger(
command=AppendAuditLedgerCommand(
event_id=event_id,
user_id_snapshot=user_id,
user_email_snapshot=(user_email or "").strip().lower() or None,
change_type=PointsChangeType.CONSUME,
biz_type=PointsBizType.CHAT,
biz_id=session_id,
direction=-1,
amount=RUN_POINTS_COST,
balance_after=int(account.balance),
billed_to="user",
run_id=run_id,
input_tokens=charge_snapshot.input_tokens,
output_tokens=charge_snapshot.output_tokens,
cost=charge_snapshot.cost,
metadata=AuditLedgerMetadata(
source="run_success",
usage_missing=usage_missing,
),
)
)
return RunChargeResult(
charged=True,
amount=RUN_POINTS_COST,
balance_after=int(account.balance),
event_id=event_id,
)
async def record_failed_run_platform_cost(
self,
*,
user_id: UUID,
session_id: UUID,
run_id: str,
operator_id: UUID | None,
user_email: str | None = None,
failure_kind: Literal["failed", "canceled"],
) -> PlatformCostAuditResult:
event_source = f"{session_id}:{run_id}:{failure_kind}".encode("utf-8")
event_hash = hashlib.sha1(event_source).hexdigest()
event_kind = "fail" if failure_kind == "failed" else "cancel"
event_id = f"chat.run.{event_kind}:{event_hash}"
if await self._repository.has_audit_event(event_id=event_id):
return PlatformCostAuditResult(
audited=False,
event_id=event_id,
cost=Decimal("0"),
)
usage_snapshot = await self._repository.get_run_usage_snapshot(
session_id=session_id,
run_id=run_id,
)
if usage_snapshot is None or usage_snapshot.cost <= Decimal("0"):
return PlatformCostAuditResult(
audited=False,
event_id=event_id,
cost=Decimal("0"),
)
account = await self._repository.get_or_create_user_points_for_update(
user_id=user_id
)
await self._repository.append_audit_ledger(
command=AppendAuditLedgerCommand(
event_id=event_id,
user_id_snapshot=user_id,
user_email_snapshot=(user_email or "").strip().lower() or None,
change_type=PointsChangeType.CONSUME,
biz_type=PointsBizType.CHAT,
biz_id=session_id,
direction=0,
amount=0,
balance_after=int(account.balance),
billed_to="platform",
run_id=run_id,
input_tokens=usage_snapshot.input_tokens,
output_tokens=usage_snapshot.output_tokens,
cost=usage_snapshot.cost,
metadata=AuditLedgerMetadata(
source=f"run_{failure_kind}",
failure_kind=failure_kind,
operator_id=str(operator_id) if operator_id is not None else None,
),
)
)
return PlatformCostAuditResult(
audited=True,
event_id=event_id,
cost=usage_snapshot.cost,
)
@staticmethod
def _normalize_email(email: str) -> str:
return email.strip().lower()
@staticmethod
def _build_register_bonus_email_hash(normalized_email: str) -> str:
key = config.points_policy.register_bonus_hmac_key.get_secret_value().strip()
if not key:
raise RuntimeError("points_policy.register_bonus_hmac_key is required")
digest = hmac.new(
key=key.encode("utf-8"),
msg=normalized_email.encode("utf-8"),
digestmod=hashlib.sha256,
)
return digest.hexdigest()