feat: 添加 points_audit_ledger 及 JSON 字段 Pydantic Schema 约束
This commit is contained in:
@@ -1,14 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.points_audit_ledger import PointsAuditLedger
|
||||
from models.points_ledger import PointsLedger
|
||||
from models.register_bonus_claims import RegisterBonusClaims
|
||||
from models.user_points import UserPoints
|
||||
from schemas.domain.points import ApplyPointsChangeCommand
|
||||
from schemas.domain.points import (
|
||||
AppendAuditLedgerCommand,
|
||||
ApplyPointsChangeCommand,
|
||||
PointsChargeSnapshot,
|
||||
)
|
||||
from schemas.enums import AgentChatMessageRole
|
||||
|
||||
|
||||
class PointsRepository:
|
||||
@@ -57,6 +66,72 @@ class PointsRepository:
|
||||
self._session.add(entry)
|
||||
await self._session.flush()
|
||||
|
||||
async def has_audit_event(self, *, event_id: str) -> bool:
|
||||
stmt = select(PointsAuditLedger.id).where(
|
||||
PointsAuditLedger.event_id == event_id
|
||||
)
|
||||
row = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
async def append_audit_ledger(self, *, command: AppendAuditLedgerCommand) -> None:
|
||||
entry = PointsAuditLedger(
|
||||
event_id=command.event_id,
|
||||
user_id_snapshot=command.user_id_snapshot,
|
||||
user_email_snapshot=command.user_email_snapshot,
|
||||
change_type=command.change_type.value,
|
||||
biz_type=command.biz_type.value if command.biz_type is not None else None,
|
||||
biz_id=command.biz_id,
|
||||
direction=command.direction,
|
||||
amount=command.amount,
|
||||
balance_after=command.balance_after,
|
||||
billed_to=command.billed_to,
|
||||
run_id=command.run_id,
|
||||
request_id=command.request_id,
|
||||
input_tokens=command.input_tokens,
|
||||
output_tokens=command.output_tokens,
|
||||
cost=command.cost,
|
||||
metadata_json=command.metadata,
|
||||
)
|
||||
self._session.add(entry)
|
||||
await self._session.flush()
|
||||
|
||||
async def get_run_usage_snapshot(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
run_id: str,
|
||||
) -> PointsChargeSnapshot | None:
|
||||
stmt = (
|
||||
select(AgentChatMessage)
|
||||
.where(
|
||||
AgentChatMessage.session_id == session_id,
|
||||
AgentChatMessage.role == AgentChatMessageRole.ASSISTANT,
|
||||
AgentChatMessage.deleted_at.is_(None),
|
||||
)
|
||||
.order_by(AgentChatMessage.seq.desc())
|
||||
.limit(20)
|
||||
)
|
||||
messages = list((await self._session.execute(stmt)).scalars().all())
|
||||
message = None
|
||||
for candidate in messages:
|
||||
metadata = candidate.metadata_json or {}
|
||||
if metadata.get("run_id") == run_id:
|
||||
message = candidate
|
||||
break
|
||||
|
||||
if message is None:
|
||||
return None
|
||||
|
||||
cost_value = message.cost if message.cost is not None else Decimal("0")
|
||||
return PointsChargeSnapshot(
|
||||
message_id=message.id,
|
||||
message_seq=max(int(message.seq), 1),
|
||||
model_code=(message.model_code or "agent_run").strip() or "agent_run",
|
||||
input_tokens=max(int(message.input_tokens), 0),
|
||||
output_tokens=max(int(message.output_tokens), 0),
|
||||
cost=Decimal(str(cost_value)),
|
||||
)
|
||||
|
||||
async def get_user_points(self, *, user_id: UUID) -> UserPoints:
|
||||
insert_stmt = (
|
||||
insert(UserPoints)
|
||||
@@ -67,3 +142,25 @@ class PointsRepository:
|
||||
|
||||
stmt = select(UserPoints).where(UserPoints.user_id == user_id)
|
||||
return (await self._session.execute(stmt)).scalar_one()
|
||||
|
||||
async def claim_register_bonus(
|
||||
self,
|
||||
*,
|
||||
email_hash: str,
|
||||
user_email_snapshot: str,
|
||||
first_user_id: UUID,
|
||||
grant_event_id: str,
|
||||
) -> bool:
|
||||
stmt = (
|
||||
insert(RegisterBonusClaims)
|
||||
.values(
|
||||
email_hash=email_hash,
|
||||
user_email_snapshot=user_email_snapshot,
|
||||
first_user_id=first_user_id,
|
||||
grant_event_id=grant_event_id,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[RegisterBonusClaims.email_hash])
|
||||
.returning(RegisterBonusClaims.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
return inserted_id is not None
|
||||
|
||||
@@ -3,10 +3,19 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from core.config.settings import config
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from schemas.domain.points import ConsumeLedgerMetadata, PointsChargeSnapshot
|
||||
from schemas.domain.points import (
|
||||
AppendAuditLedgerCommand,
|
||||
AuditLedgerMetadata,
|
||||
ConsumeLedgerMetadata,
|
||||
RegisterLedgerMetadata,
|
||||
PointsChargeSnapshot,
|
||||
)
|
||||
from schemas.enums import PointsBizType, PointsChangeType, PointsOperatorType
|
||||
from schemas.domain.points import ApplyPointsChangeCommand
|
||||
from v1.points.repository import PointsRepository
|
||||
@@ -31,10 +40,128 @@ class PointsBalanceResult:
|
||||
can_run: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PlatformCostAuditResult:
|
||||
audited: bool
|
||||
event_id: str
|
||||
cost: Decimal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisterBonusResult:
|
||||
granted: bool
|
||||
amount: int
|
||||
balance_after: int
|
||||
event_id: str
|
||||
|
||||
|
||||
class PointsService:
|
||||
def __init__(self, repository: PointsRepository) -> None:
|
||||
self._repository = repository
|
||||
|
||||
async def grant_register_bonus_if_eligible(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
user_email: str,
|
||||
) -> RegisterBonusResult:
|
||||
normalized_email = self._normalize_email(user_email)
|
||||
if not normalized_email:
|
||||
return RegisterBonusResult(
|
||||
granted=False,
|
||||
amount=0,
|
||||
balance_after=0,
|
||||
event_id="",
|
||||
)
|
||||
|
||||
bonus_points = int(config.points_policy.register_bonus_points)
|
||||
if bonus_points <= 0:
|
||||
account = await self._repository.get_or_create_user_points_for_update(
|
||||
user_id=user_id
|
||||
)
|
||||
return RegisterBonusResult(
|
||||
granted=False,
|
||||
amount=0,
|
||||
balance_after=int(account.balance),
|
||||
event_id="",
|
||||
)
|
||||
|
||||
email_hash = self._build_register_bonus_email_hash(normalized_email)
|
||||
event_hash = hashlib.sha1(
|
||||
f"{normalized_email}:{email_hash}".encode("utf-8")
|
||||
).hexdigest()
|
||||
event_id = f"register.bonus:{event_hash}"
|
||||
|
||||
claimed = await self._repository.claim_register_bonus(
|
||||
email_hash=email_hash,
|
||||
user_email_snapshot=normalized_email,
|
||||
first_user_id=user_id,
|
||||
grant_event_id=event_id,
|
||||
)
|
||||
account = await self._repository.get_or_create_user_points_for_update(
|
||||
user_id=user_id
|
||||
)
|
||||
if not claimed:
|
||||
return RegisterBonusResult(
|
||||
granted=False,
|
||||
amount=0,
|
||||
balance_after=int(account.balance),
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
balance = int(account.balance)
|
||||
account.balance = balance + bonus_points
|
||||
account.lifetime_earned = int(account.lifetime_earned) + bonus_points
|
||||
account.version = int(account.version) + 1
|
||||
|
||||
metadata = RegisterLedgerMetadata(
|
||||
operator_type=PointsOperatorType.SYSTEM,
|
||||
run_id=event_id,
|
||||
ext={
|
||||
"source": "register_bonus_policy",
|
||||
"email_hash": email_hash,
|
||||
},
|
||||
)
|
||||
command = ApplyPointsChangeCommand(
|
||||
user_id=user_id,
|
||||
change_type=PointsChangeType.REGISTER,
|
||||
event_id=event_id,
|
||||
amount=bonus_points,
|
||||
direction=1,
|
||||
operator_id=None,
|
||||
metadata=metadata,
|
||||
)
|
||||
await self._repository.append_ledger(
|
||||
command=command,
|
||||
balance_after=int(account.balance),
|
||||
)
|
||||
await self._repository.append_audit_ledger(
|
||||
command=AppendAuditLedgerCommand(
|
||||
event_id=event_id,
|
||||
user_id_snapshot=user_id,
|
||||
user_email_snapshot=normalized_email,
|
||||
change_type=PointsChangeType.REGISTER,
|
||||
direction=1,
|
||||
amount=bonus_points,
|
||||
balance_after=int(account.balance),
|
||||
billed_to="user",
|
||||
run_id=event_id,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cost=Decimal("0"),
|
||||
metadata=AuditLedgerMetadata(
|
||||
source="register_bonus_policy",
|
||||
email_hash=email_hash,
|
||||
),
|
||||
)
|
||||
)
|
||||
return RegisterBonusResult(
|
||||
granted=True,
|
||||
amount=bonus_points,
|
||||
balance_after=int(account.balance),
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
async def ensure_run_points_available(
|
||||
self,
|
||||
*,
|
||||
@@ -84,6 +211,7 @@ class PointsService:
|
||||
session_id: UUID,
|
||||
run_id: str,
|
||||
operator_id: UUID | None,
|
||||
user_email: str | None = None,
|
||||
) -> RunChargeResult:
|
||||
event_source = f"{session_id}:{run_id}".encode("utf-8")
|
||||
event_hash = hashlib.sha1(event_source).hexdigest()
|
||||
@@ -122,18 +250,28 @@ class PointsService:
|
||||
account.lifetime_spent = int(account.lifetime_spent) + RUN_POINTS_COST
|
||||
account.version = int(account.version) + 1
|
||||
|
||||
usage_snapshot = await self._repository.get_run_usage_snapshot(
|
||||
session_id=session_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
usage_missing = usage_snapshot is None
|
||||
charge_snapshot = usage_snapshot or PointsChargeSnapshot(
|
||||
message_id=uuid4(),
|
||||
message_seq=1,
|
||||
model_code="agent_run",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cost=Decimal("0"),
|
||||
)
|
||||
|
||||
metadata = ConsumeLedgerMetadata(
|
||||
operator_type=PointsOperatorType.USER,
|
||||
run_id=run_id,
|
||||
charge=PointsChargeSnapshot(
|
||||
message_id=uuid4(),
|
||||
message_seq=1,
|
||||
model_code="agent_run",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cost=Decimal("0"),
|
||||
),
|
||||
ext={"source": "run_success"},
|
||||
charge=charge_snapshot,
|
||||
ext={
|
||||
"source": "run_success",
|
||||
"usage_missing": usage_missing,
|
||||
},
|
||||
)
|
||||
command = ApplyPointsChangeCommand(
|
||||
user_id=user_id,
|
||||
@@ -150,9 +288,112 @@ class PointsService:
|
||||
command=command,
|
||||
balance_after=int(account.balance),
|
||||
)
|
||||
await self._repository.append_audit_ledger(
|
||||
command=AppendAuditLedgerCommand(
|
||||
event_id=event_id,
|
||||
user_id_snapshot=user_id,
|
||||
user_email_snapshot=(user_email or "").strip().lower() or None,
|
||||
change_type=PointsChangeType.CONSUME,
|
||||
biz_type=PointsBizType.CHAT,
|
||||
biz_id=session_id,
|
||||
direction=-1,
|
||||
amount=RUN_POINTS_COST,
|
||||
balance_after=int(account.balance),
|
||||
billed_to="user",
|
||||
run_id=run_id,
|
||||
input_tokens=charge_snapshot.input_tokens,
|
||||
output_tokens=charge_snapshot.output_tokens,
|
||||
cost=charge_snapshot.cost,
|
||||
metadata=AuditLedgerMetadata(
|
||||
source="run_success",
|
||||
usage_missing=usage_missing,
|
||||
),
|
||||
)
|
||||
)
|
||||
return RunChargeResult(
|
||||
charged=True,
|
||||
amount=RUN_POINTS_COST,
|
||||
balance_after=int(account.balance),
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
async def record_failed_run_platform_cost(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
session_id: UUID,
|
||||
run_id: str,
|
||||
operator_id: UUID | None,
|
||||
user_email: str | None = None,
|
||||
failure_kind: Literal["failed", "canceled"],
|
||||
) -> PlatformCostAuditResult:
|
||||
event_source = f"{session_id}:{run_id}:{failure_kind}".encode("utf-8")
|
||||
event_hash = hashlib.sha1(event_source).hexdigest()
|
||||
event_kind = "fail" if failure_kind == "failed" else "cancel"
|
||||
event_id = f"chat.run.{event_kind}:{event_hash}"
|
||||
|
||||
if await self._repository.has_audit_event(event_id=event_id):
|
||||
return PlatformCostAuditResult(
|
||||
audited=False,
|
||||
event_id=event_id,
|
||||
cost=Decimal("0"),
|
||||
)
|
||||
|
||||
usage_snapshot = await self._repository.get_run_usage_snapshot(
|
||||
session_id=session_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
if usage_snapshot is None or usage_snapshot.cost <= Decimal("0"):
|
||||
return PlatformCostAuditResult(
|
||||
audited=False,
|
||||
event_id=event_id,
|
||||
cost=Decimal("0"),
|
||||
)
|
||||
|
||||
account = await self._repository.get_or_create_user_points_for_update(
|
||||
user_id=user_id
|
||||
)
|
||||
await self._repository.append_audit_ledger(
|
||||
command=AppendAuditLedgerCommand(
|
||||
event_id=event_id,
|
||||
user_id_snapshot=user_id,
|
||||
user_email_snapshot=(user_email or "").strip().lower() or None,
|
||||
change_type=PointsChangeType.CONSUME,
|
||||
biz_type=PointsBizType.CHAT,
|
||||
biz_id=session_id,
|
||||
direction=0,
|
||||
amount=0,
|
||||
balance_after=int(account.balance),
|
||||
billed_to="platform",
|
||||
run_id=run_id,
|
||||
input_tokens=usage_snapshot.input_tokens,
|
||||
output_tokens=usage_snapshot.output_tokens,
|
||||
cost=usage_snapshot.cost,
|
||||
metadata=AuditLedgerMetadata(
|
||||
source=f"run_{failure_kind}",
|
||||
failure_kind=failure_kind,
|
||||
operator_id=str(operator_id) if operator_id is not None else None,
|
||||
),
|
||||
)
|
||||
)
|
||||
return PlatformCostAuditResult(
|
||||
audited=True,
|
||||
event_id=event_id,
|
||||
cost=usage_snapshot.cost,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
@staticmethod
|
||||
def _build_register_bonus_email_hash(normalized_email: str) -> str:
|
||||
key = config.points_policy.register_bonus_hmac_key.get_secret_value().strip()
|
||||
if not key:
|
||||
raise RuntimeError("points_policy.register_bonus_hmac_key is required")
|
||||
digest = hmac.new(
|
||||
key=key.encode("utf-8"),
|
||||
msg=normalized_email.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
)
|
||||
return digest.hexdigest()
|
||||
|
||||
Reference in New Issue
Block a user