chore: 迁移到 social-app 架构,集成 Supabase 和 taskiq worker
This commit is contained in:
@@ -49,6 +49,32 @@ This file governs `backend/**` only. Keep it minimal, enforceable, and non-dupli
|
||||
- Strong typing required at boundaries (Pydantic/dataclass); avoid weak untyped payload contracts.
|
||||
- Protocol/data contract changes must stay aligned with `docs/protocols/`.
|
||||
|
||||
## Database Rules
|
||||
|
||||
- Supabase Auth is identity source; backend enforces business authorization.
|
||||
- Use service-role DB access only in backend.
|
||||
- Soft delete uses `deleted_at`; reads must exclude deleted records by default.
|
||||
- Alembic is the only schema migration source of truth.
|
||||
- Database migrations use `./infra/scripts/dev-migrate.sh`:
|
||||
- `migrate` - run migrations only
|
||||
- `init-data` - seed data only
|
||||
- `bootstrap` - migrate + init-data
|
||||
|
||||
## Agent Runtime & Tools
|
||||
|
||||
- AG-UI protocol is mandatory for agent loop behavior.
|
||||
- `ToolAgentOutput.result` is the canonical tool result field.
|
||||
- Tool results must be machine-oriented and include IDs/outcomes needed for chaining.
|
||||
|
||||
## Tool Schema Rules for Small Models (e.g., qwen3.5-flash)
|
||||
|
||||
- Prefer `operations: list[OperationModel]` over parallel arrays.
|
||||
- Validate tool args with strict Pydantic models (`extra="forbid"`).
|
||||
- Keep payloads JSON-native (objects/lists), shallow, and deterministic.
|
||||
- Make action-specific required fields explicit and fail with structured errors.
|
||||
- Return per-item outcomes (`success/failed`, identifiers, partial status) for self-correction.
|
||||
- Avoid broad entry-point coercion fallbacks; fix schema/prompt alignment first.
|
||||
- Do not pass provider request fields with `None` values (avoid upstream 400 blocking tool calls).
|
||||
|
||||
## Testing
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI(
|
||||
title="Eryao API",
|
||||
description="觅爻签问后端服务",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
@@ -1,3 +0,0 @@
|
||||
from .settings import Settings, config
|
||||
|
||||
__all__ = ["Settings", "config"]
|
||||
@@ -8,6 +8,7 @@ from pydantic import (
|
||||
AnyHttpUrl,
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
computed_field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -118,11 +119,54 @@ class RedisSettings(BaseModel):
|
||||
return f"redis://{self.host}:{self.port}/{self.db}"
|
||||
|
||||
|
||||
class SupabaseSettings(BaseModel):
|
||||
public_url: AnyHttpUrl
|
||||
anon_key: str = "CHANGE_ME"
|
||||
service_role_key: str = "CHANGE_ME"
|
||||
jwt_secret: SecretStr | None = Field(default=None, exclude=True)
|
||||
jwt_algorithm: Literal["HS256"] = "HS256"
|
||||
jwt_issuer: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def compute_defaults(self) -> "SupabaseSettings":
|
||||
base = str(self.public_url).rstrip("/")
|
||||
if self.jwt_issuer is None:
|
||||
self.jwt_issuer = f"{base}/auth/v1"
|
||||
|
||||
return self
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return str(self.public_url)
|
||||
|
||||
|
||||
class StorageSettings(BaseModel):
|
||||
provider: Literal["supabase"] = "supabase"
|
||||
signed_url_ttl_seconds: int = Field(default=600, ge=60, le=3600)
|
||||
retention_days: int = Field(default=30, ge=1, le=3650)
|
||||
|
||||
class AttachmentSettings(BaseModel):
|
||||
bucket: str = Field(default="eryao-attachments", min_length=3, max_length=63)
|
||||
max_size_mb: int = Field(default=20, ge=1, le=200)
|
||||
|
||||
class AvatarSettings(BaseModel):
|
||||
bucket: str = Field(default="avatars", min_length=3, max_length=63)
|
||||
max_size_mb: int = Field(default=2, ge=1, le=10)
|
||||
|
||||
attachment: AttachmentSettings = Field(default_factory=AttachmentSettings)
|
||||
avatar: AvatarSettings = Field(default_factory=AvatarSettings)
|
||||
|
||||
|
||||
class LlmSettings(BaseModel):
|
||||
provider_keys: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 3306
|
||||
name: str = "eryao"
|
||||
user: str = "root"
|
||||
port: int = 5432
|
||||
name: str = "postgres"
|
||||
user: str = "postgres"
|
||||
password: str = "CHANGE_ME"
|
||||
|
||||
@computed_field
|
||||
@@ -130,83 +174,11 @@ class DatabaseSettings(BaseModel):
|
||||
def url(self) -> str:
|
||||
password = quote(self.password, safe="")
|
||||
return (
|
||||
f"mysql+aiomysql://{self.user}:{password}"
|
||||
f"postgresql+asyncpg://{self.user}:{password}"
|
||||
f"@{self.host}:{self.port}/{self.name}"
|
||||
)
|
||||
|
||||
|
||||
class AppVersionSettings(BaseModel):
|
||||
manifest_path: str = Field(
|
||||
default="deploy/static/releases/manifest.json",
|
||||
description="发布清单文件路径,相对于项目根目录",
|
||||
)
|
||||
release_path_prefix: str = Field(
|
||||
default="releases",
|
||||
description="下载 URL 中文件目录前缀",
|
||||
)
|
||||
download_base_url: AnyHttpUrl | None = Field(
|
||||
default=None,
|
||||
description="下载链接基础域名,如 https://your-domain.com",
|
||||
)
|
||||
|
||||
@field_validator("download_base_url", mode="before")
|
||||
@classmethod
|
||||
def empty_download_base_url_to_none(cls, value: object) -> object:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
@field_validator("manifest_path")
|
||||
@classmethod
|
||||
def validate_manifest_path(cls, value: str) -> str:
|
||||
normalized = Path(value)
|
||||
if normalized.is_absolute() or ".." in normalized.parts:
|
||||
raise ValueError("manifest_path must be a safe relative path")
|
||||
return value
|
||||
|
||||
|
||||
class AliyunSmsSettings(BaseModel):
|
||||
access_key_id: str = "CHANGE_ME"
|
||||
access_key_secret: str = "CHANGE_ME"
|
||||
sign_name: str = "CHANGE_ME"
|
||||
template_code: str = "CHANGE_ME"
|
||||
region_id: str = "cn-hangzhou"
|
||||
endpoint: str = "dysmsapi.aliyuncs.com"
|
||||
test_mode: bool = False
|
||||
|
||||
|
||||
class AliyunContentSecuritySettings(BaseModel):
|
||||
access_key_id: str = "CHANGE_ME"
|
||||
access_key_secret: str = "CHANGE_ME"
|
||||
endpoint: str = "green-cip.cn-shenzhen.aliyuncs.com"
|
||||
|
||||
|
||||
class AlipaySettings(BaseModel):
|
||||
app_id: str = "CHANGE_ME"
|
||||
merchant_id: str = "CHANGE_ME"
|
||||
public_key: str = "CHANGE_ME"
|
||||
private_key: str = "CHANGE_ME"
|
||||
sign_type: str = "RSA2"
|
||||
notify_url: str = ""
|
||||
timeout_express: str = "30m"
|
||||
sandbox: bool = False
|
||||
|
||||
|
||||
class DeepSeekSettings(BaseModel):
|
||||
api_key: str = "CHANGE_ME"
|
||||
|
||||
|
||||
class AuthSettings(BaseModel):
|
||||
token_expiration_days: int = 7
|
||||
token_refresh_threshold_hours: int = 2
|
||||
|
||||
|
||||
class VerificationSettings(BaseModel):
|
||||
code_length: int = 6
|
||||
expiration_minutes: int = 5
|
||||
test_mode: bool = False
|
||||
|
||||
|
||||
class SensitiveWordSettings(BaseModel):
|
||||
use_aliyun: bool = True
|
||||
fallback_to_local: bool = True
|
||||
@@ -217,6 +189,11 @@ class TestSettings(BaseModel):
|
||||
password: str = ""
|
||||
|
||||
|
||||
class TaskiqSettings(BaseModel):
|
||||
broker_url: str | None = None
|
||||
result_backend_url: str | None = None
|
||||
|
||||
|
||||
def _resolve_env_file() -> str:
|
||||
current = Path(__file__).resolve()
|
||||
for parent in [current, *current.parents]:
|
||||
@@ -233,24 +210,31 @@ class Settings(BaseSettings):
|
||||
runtime: RuntimeSettings = RuntimeSettings()
|
||||
cors: CorsSettings = CorsSettings()
|
||||
redis: RedisSettings = RedisSettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
app_version: AppVersionSettings = AppVersionSettings()
|
||||
aliyun_sms: AliyunSmsSettings = AliyunSmsSettings()
|
||||
aliyun_content_security: AliyunContentSecuritySettings = (
|
||||
AliyunContentSecuritySettings()
|
||||
supabase: SupabaseSettings = Field(
|
||||
default_factory=lambda: SupabaseSettings(public_url="http://localhost:8001")
|
||||
)
|
||||
alipay: AlipaySettings = AlipaySettings()
|
||||
deepseek: DeepSeekSettings = DeepSeekSettings()
|
||||
auth: AuthSettings = AuthSettings()
|
||||
verification: VerificationSettings = VerificationSettings()
|
||||
sensitive_word: SensitiveWordSettings = SensitiveWordSettings()
|
||||
storage: StorageSettings = StorageSettings()
|
||||
llm: LlmSettings = LlmSettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
sensitive_word: SensitiveWordSettings = Field(default_factory=SensitiveWordSettings)
|
||||
test: TestSettings = Field(default_factory=TestSettings)
|
||||
taskiq: TaskiqSettings = Field(default_factory=TaskiqSettings)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return self.database.url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def taskiq_broker_url(self) -> str:
|
||||
return self.taskiq.broker_url or self.redis.url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def taskiq_result_backend_url(self) -> str:
|
||||
return self.taskiq.result_backend_url or self.redis.url
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_file=_resolve_env_file(),
|
||||
env_prefix="ERYAO_",
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
input_template: |
|
||||
你正在执行一次"自动化记忆回顾与整理"任务。
|
||||
|
||||
任务目标:
|
||||
1) 回顾最近两天的聊天与上下文,识别用户长期偏好、习惯和关键事实的变化。
|
||||
2) 对已经失效、被否定或明显过期的信息执行遗忘。
|
||||
3) 对新增且有证据支持的信息执行写入。
|
||||
4) 严禁编造;没有证据就不要写入。
|
||||
5) 只更新最小必要字段,避免过度覆盖。
|
||||
|
||||
输出要求:
|
||||
- 必须使用以下固定格式输出:
|
||||
<----------【周期任务输出】---------->
|
||||
【记忆回顾】<一句人性化总结,说明今天主要发生了什么>
|
||||
【新增记忆】<按"X条:要点1;要点2"描述;没有则写"0条">
|
||||
【遗忘记忆】<按"X条:要点1;要点2"描述;没有则写"0条">
|
||||
【未来展望】<基于本次记忆变化,给出1-2条温和、可执行的后续建议;若暂无建议则说明"可继续观察">
|
||||
|
||||
表达风格:
|
||||
- 语言自然、温和、可读,像助理在做每日回顾。
|
||||
- 结论先行,避免空话,不要输出与任务无关的闲聊内容。
|
||||
enabled_tools:
|
||||
- memory.write
|
||||
- memory.forget
|
||||
context:
|
||||
source: latest_chat
|
||||
window_mode: day
|
||||
window_count: 2
|
||||
schedule:
|
||||
type: daily
|
||||
run_at:
|
||||
hour: 8
|
||||
minute: 0
|
||||
weekdays: null
|
||||
@@ -1,158 +0,0 @@
|
||||
version: "1.0"
|
||||
routes:
|
||||
- route_id: auth.boot
|
||||
path: /boot
|
||||
description: Bootstraps auth session and redirects to login or home.
|
||||
category: auth
|
||||
auth_required: false
|
||||
- route_id: auth.login
|
||||
path: /login
|
||||
description: Login entry for unauthenticated users.
|
||||
category: auth
|
||||
auth_required: false
|
||||
- route_id: home.main
|
||||
path: /
|
||||
description: Main assistant home screen.
|
||||
category: home
|
||||
auth_required: true
|
||||
- route_id: message.invite_list
|
||||
path: /messages/invites
|
||||
description: Lists message invitations.
|
||||
category: messages
|
||||
auth_required: true
|
||||
- route_id: message.invite_detail
|
||||
path: /messages/invites/{id}
|
||||
description: Shows details for a single invitation.
|
||||
category: messages
|
||||
auth_required: true
|
||||
path_params:
|
||||
- id
|
||||
- route_id: contacts.list
|
||||
path: /contacts
|
||||
description: Contact list and quick relationship actions.
|
||||
category: contacts
|
||||
auth_required: true
|
||||
- route_id: contacts.add
|
||||
path: /contacts/add
|
||||
description: Create or edit a contact profile.
|
||||
category: contacts
|
||||
auth_required: true
|
||||
- route_id: calendar.dayweek
|
||||
path: /calendar/dayweek
|
||||
description: Day and week calendar view.
|
||||
category: calendar
|
||||
auth_required: true
|
||||
query_params:
|
||||
- date
|
||||
- from
|
||||
- route_id: calendar.month
|
||||
path: /calendar/month
|
||||
description: Month calendar overview.
|
||||
category: calendar
|
||||
auth_required: true
|
||||
query_params:
|
||||
- from
|
||||
- route_id: calendar.event_detail
|
||||
path: /calendar/events/{id}
|
||||
description: Detail page for one calendar event.
|
||||
category: calendar
|
||||
auth_required: true
|
||||
path_params:
|
||||
- id
|
||||
- route_id: calendar.event_create
|
||||
path: /calendar/events/new
|
||||
description: Create page for one calendar event.
|
||||
category: calendar
|
||||
auth_required: true
|
||||
query_params:
|
||||
- date
|
||||
- route_id: calendar.event_edit
|
||||
path: /calendar/events/{id}/edit
|
||||
description: Edit page for one calendar event.
|
||||
category: calendar
|
||||
auth_required: true
|
||||
path_params:
|
||||
- id
|
||||
- route_id: calendar.event_share
|
||||
path: /calendar/events/{id}/share
|
||||
description: Share settings page for one calendar event.
|
||||
category: calendar
|
||||
auth_required: true
|
||||
path_params:
|
||||
- id
|
||||
- route_id: todo.list
|
||||
path: /todo
|
||||
description: Todo quadrants and backlog overview.
|
||||
category: todo
|
||||
auth_required: true
|
||||
- route_id: todo.create
|
||||
path: /todo/new
|
||||
description: Create page for one todo item.
|
||||
category: todo
|
||||
auth_required: true
|
||||
- route_id: todo.detail
|
||||
path: /todo/{id}
|
||||
description: Detail page for one todo item.
|
||||
category: todo
|
||||
auth_required: true
|
||||
path_params:
|
||||
- id
|
||||
- route_id: todo.edit
|
||||
path: /todo/{id}/edit
|
||||
description: Dedicated subpage for editing one todo item (not an in-page modal).
|
||||
category: todo
|
||||
auth_required: true
|
||||
path_params:
|
||||
- id
|
||||
- route_id: settings.main
|
||||
path: /settings
|
||||
description: Settings hub page.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.features
|
||||
path: /settings/features
|
||||
description: Automation job list page.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.job_new
|
||||
path: /settings/job/new
|
||||
description: Create page for one automation job.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.job_detail
|
||||
path: /settings/job/{id}
|
||||
description: Detail page for one automation job.
|
||||
category: settings
|
||||
auth_required: true
|
||||
path_params:
|
||||
- id
|
||||
- route_id: settings.memory
|
||||
path: /settings/memory
|
||||
description: Memory preferences and controls.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.memory_user
|
||||
path: /settings/memory/user
|
||||
description: User memory summary view.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.memory_work
|
||||
path: /settings/memory/work
|
||||
description: Work memory summary view.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.memory_user_edit
|
||||
path: /settings/memory/user/edit
|
||||
description: Edit user memory details.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.memory_work_edit
|
||||
path: /settings/memory/work/edit
|
||||
description: Edit work memory details.
|
||||
category: settings
|
||||
auth_required: true
|
||||
- route_id: settings.edit_profile
|
||||
path: /edit-profile
|
||||
description: Profile editing page.
|
||||
category: settings
|
||||
auth_required: true
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy.dialects.mysql import JSON as MySQLJSON
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
json_type = JSON().with_variant(MySQLJSON, "mysql")
|
||||
json_jsonb = JSON().with_variant(JSONB, "postgresql")
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = []
|
||||
|
||||
@@ -5,9 +5,6 @@ import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from core.automation.scheduler import run_automation_scheduler_scan
|
||||
from core.config.initial.init_data import initialize_data
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
@@ -16,7 +13,6 @@ logger = get_logger("core.runtime.cli")
|
||||
|
||||
|
||||
def _resolve_alembic_path() -> Path:
|
||||
"""Resolve alembic.ini path relative to project root."""
|
||||
project_root = Path(__file__).parents[3]
|
||||
alembic_path = project_root / "alembic" / "alembic.ini"
|
||||
if not alembic_path.exists():
|
||||
@@ -25,7 +21,6 @@ def _resolve_alembic_path() -> Path:
|
||||
|
||||
|
||||
def _redact_sensitive(text: str) -> str:
|
||||
"""Redact sensitive information from log output."""
|
||||
import re
|
||||
|
||||
SENSITIVE_KEYS = ("password", "token", "secret", "api_key")
|
||||
@@ -40,7 +35,6 @@ def _redact_sensitive(text: str) -> str:
|
||||
|
||||
|
||||
def run_migrations() -> bool:
|
||||
"""Run alembic migrations in a subprocess to avoid event loop conflicts."""
|
||||
import os
|
||||
|
||||
logger.info("Running alembic migrations")
|
||||
@@ -75,7 +69,6 @@ def run_migrations() -> bool:
|
||||
|
||||
|
||||
async def run_init_data() -> bool:
|
||||
"""Initialize bootstrap data."""
|
||||
logger.info("Running init-data")
|
||||
try:
|
||||
result = await initialize_data()
|
||||
@@ -90,7 +83,6 @@ async def run_init_data() -> bool:
|
||||
|
||||
|
||||
async def bootstrap() -> bool:
|
||||
"""Run migrations followed by init-data."""
|
||||
logger.info("Starting bootstrap (migrate + init-data)")
|
||||
|
||||
if not run_migrations():
|
||||
@@ -105,52 +97,11 @@ async def bootstrap() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def run_automation_scheduler_forever() -> None:
|
||||
if not config.automation_scheduler.enabled:
|
||||
logger.info("Automation scheduler disabled by config")
|
||||
return
|
||||
|
||||
interval_seconds = int(config.automation_scheduler.interval_seconds)
|
||||
batch_limit = int(config.automation_scheduler.batch_limit)
|
||||
logger.info(
|
||||
"Starting automation scheduler",
|
||||
interval_seconds=interval_seconds,
|
||||
batch_limit=batch_limit,
|
||||
)
|
||||
|
||||
async def scan_job() -> None:
|
||||
try:
|
||||
await run_automation_scheduler_scan(limit=batch_limit)
|
||||
except Exception as exc:
|
||||
logger.exception("Automation scheduler scan failed", error=str(exc))
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(
|
||||
scan_job,
|
||||
trigger=IntervalTrigger(seconds=interval_seconds),
|
||||
id="automation_scheduler_scan",
|
||||
name="Automation scheduler scan",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True,
|
||||
)
|
||||
scheduler.start()
|
||||
|
||||
stop_event = asyncio.Event()
|
||||
try:
|
||||
await stop_event.wait()
|
||||
finally:
|
||||
scheduler.shutdown(wait=False)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""CLI entry point."""
|
||||
if len(sys.argv) < 2:
|
||||
logger.error("No command provided")
|
||||
logger.info("Usage: python -m core.runtime.cli <command>")
|
||||
logger.info(
|
||||
"Available commands: migrate, init-data, bootstrap, automation-scheduler"
|
||||
)
|
||||
logger.info("Available commands: migrate, init-data, bootstrap")
|
||||
return 1
|
||||
|
||||
command = sys.argv[1]
|
||||
@@ -161,14 +112,9 @@ def main() -> int:
|
||||
success = asyncio.run(run_init_data())
|
||||
elif command == "bootstrap":
|
||||
success = asyncio.run(bootstrap())
|
||||
elif command == "automation-scheduler":
|
||||
asyncio.run(run_automation_scheduler_forever())
|
||||
return 0
|
||||
else:
|
||||
logger.error("Unknown command", command=command)
|
||||
logger.info(
|
||||
"Available commands: migrate, init-data, bootstrap, automation-scheduler"
|
||||
)
|
||||
logger.info("Available commands: migrate, init-data, bootstrap")
|
||||
return 1
|
||||
|
||||
return 0 if success else 1
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = []
|
||||
@@ -0,0 +1,3 @@
|
||||
from core.taskiq.app import broker, worker_agent_broker, worker_general_broker
|
||||
|
||||
__all__ = ["broker", "worker_agent_broker", "worker_general_broker"]
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import configure_logging, log_service_banner
|
||||
|
||||
|
||||
configure_logging(config)
|
||||
log_service_banner(
|
||||
service_name=config.runtime.service_name,
|
||||
environment=config.runtime.environment,
|
||||
)
|
||||
|
||||
|
||||
def _build_broker(queue_name: str) -> ListQueueBroker:
|
||||
return ListQueueBroker(
|
||||
url=config.taskiq_broker_url,
|
||||
queue_name=queue_name,
|
||||
).with_result_backend(
|
||||
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
|
||||
)
|
||||
|
||||
|
||||
worker_agent_broker = _build_broker("agent")
|
||||
worker_general_broker = _build_broker("general")
|
||||
|
||||
broker = worker_agent_broker
|
||||
|
||||
__all__ = ["broker", "worker_agent_broker", "worker_general_broker"]
|
||||
@@ -1,12 +1,11 @@
|
||||
from . import user, divination, payment, notification, feedback, version, log, violation
|
||||
from __future__ import annotations
|
||||
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.system_agents import SystemAgents
|
||||
|
||||
__all__ = [
|
||||
"user",
|
||||
"divination",
|
||||
"payment",
|
||||
"notification",
|
||||
"feedback",
|
||||
"version",
|
||||
"log",
|
||||
"violation",
|
||||
"Llm",
|
||||
"LlmFactory",
|
||||
"SystemAgents",
|
||||
]
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.base import Base, TimestampMixin
|
||||
from sqlalchemy import BigInteger, DateTime, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class DivinationRecord(TimestampMixin, Base):
|
||||
__tablename__ = "user_divination_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
trace_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
question: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
question_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
divination_data: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
deepseek_request: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
deepseek_response: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
interpretation_result: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
api_success: Mapped[bool] = mapped_column(Integer, nullable=False, default=0)
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
api_duration_ms: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
phone_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
|
||||
|
||||
class DivinationHistory(TimestampMixin, Base):
|
||||
__tablename__ = "user_divination_history"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
phone_number: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
local_record_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
json_data: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
ai_result: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
question_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
question: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Integer, nullable=False, default=1)
|
||||
sync_time: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.base import Base
|
||||
from sqlalchemy import BigInteger, DateTime, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class UserFeedback(Base):
|
||||
__tablename__ = "user_feedback"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
phone_number: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
|
||||
|
||||
class Llm(TimestampMixin, SoftDeleteMixin, Base):
|
||||
__tablename__: str = "llms"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
factory_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("llm_factory.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
model_code: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False, unique=True, index=True
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
|
||||
|
||||
class LlmFactory(TimestampMixin, SoftDeleteMixin, Base):
|
||||
__tablename__: str = "llm_factory"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False, unique=True, index=True
|
||||
)
|
||||
request_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
avatar: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
@@ -1,39 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.base import Base
|
||||
from sqlalchemy import BigInteger, DateTime, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class NetworkAccessLog(Base):
|
||||
__tablename__ = "network_access_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
phone_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
client_ip: Mapped[str] = mapped_column(String(45), nullable=False)
|
||||
client_port: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
server_ip: Mapped[str] = mapped_column(String(45), nullable=False)
|
||||
server_port: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
http_method: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
request_path: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
request_url: Mapped[str] = mapped_column(String(1000), nullable=False)
|
||||
user_agent: Mapped[str | None] = mapped_column(String(1000), nullable=True)
|
||||
device_info: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
processing_time_ms: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
request_size: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
response_size: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
x_forwarded_for: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
x_real_ip: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||
referer: Mapped[str | None] = mapped_column(String(1000), nullable=True)
|
||||
operation_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
operation_result: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
session_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
access_time: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -1,14 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.base import Base
|
||||
from sqlalchemy import BigInteger, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = "notification"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.base import Base, TimestampMixin
|
||||
from sqlalchemy import BigInteger, DateTime, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class PaymentOrder(TimestampMixin, Base):
|
||||
__tablename__ = "payment_order"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
order_no: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
amount: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
coin_count: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
subject: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
body: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
channel: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(16), nullable=False, default="CREATED")
|
||||
trade_no: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
payment_time: Mapped[str | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class PaymentRecord(Base):
|
||||
__tablename__ = "payment_record"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
order_no: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
trade_no: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
channel: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
notify_type: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
trade_status: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
notify_data: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
process_status: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
process_message: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
coin_added: Mapped[bool] = mapped_column(Integer, nullable=False, default=0)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import JSON, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, TimestampMixin
|
||||
|
||||
|
||||
class SystemAgents(TimestampMixin, Base):
|
||||
__tablename__: str = "system_agents"
|
||||
|
||||
agent_type: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
primary_key=True,
|
||||
)
|
||||
llm_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("llms.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
)
|
||||
config: Mapped[dict] = mapped_column(
|
||||
JSON().with_variant(JSONB, "postgresql"),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
)
|
||||
@@ -1,52 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.base import Base, TimestampMixin
|
||||
from sqlalchemy import BigInteger, DateTime, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class User(TimestampMixin, Base):
|
||||
__tablename__ = "user_profile"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
phone_number: Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
|
||||
nickname: Mapped[str] = mapped_column(String(50), nullable=False, default="")
|
||||
gender: Mapped[str] = mapped_column(String(10), nullable=False, default="男")
|
||||
birthday: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default="2000-01-01"
|
||||
)
|
||||
signature: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||
|
||||
|
||||
class UserToken(Base):
|
||||
__tablename__ = "user_tokens"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
expire_time: Mapped[str] = mapped_column(DateTime, nullable=False)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class VerificationCode(Base):
|
||||
__tablename__ = "verification_codes"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
phone_number: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
code: Mapped[str] = mapped_column(String(6), nullable=False)
|
||||
expiration_time: Mapped[str] = mapped_column(DateTime, nullable=False)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
used: Mapped[bool] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
|
||||
class UserCoin(Base):
|
||||
__tablename__ = "user_coin"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, unique=True)
|
||||
phone_number: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
coin_balance: Mapped[int] = mapped_column(Integer, nullable=False, default=3)
|
||||
@@ -1,19 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.base import Base, TimestampMixin
|
||||
from sqlalchemy import BigInteger, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class AppVersion(TimestampMixin, Base):
|
||||
__tablename__ = "app_version"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
version_name: Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
|
||||
version_code: Mapped[int] = mapped_column(Integer, unique=True, nullable=False)
|
||||
min_supported_version: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
min_supported_code: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
is_force_update: Mapped[bool] = mapped_column(Integer, nullable=False, default=0)
|
||||
update_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
download_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Integer, nullable=False, default=1)
|
||||
@@ -1,30 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.db.base import Base
|
||||
from sqlalchemy import BigInteger, DateTime, Float, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class SensitiveWordViolation(Base):
|
||||
__tablename__ = "sensitive_word_violations"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
content_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
original_content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
violation_type: Mapped[str] = mapped_column(String(30), nullable=False)
|
||||
detection_service: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default="LOCAL"
|
||||
)
|
||||
risk_level: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
confidence: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
aliyun_response: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
matched_words: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
client_ip: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||
user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
violation_time: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Backend reusable schemas package."""
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from schemas.agent.forwarded_props import (
|
||||
ClientTimeContext,
|
||||
ForwardedPropsPayload,
|
||||
parse_forwarded_props_client_time,
|
||||
parse_forwarded_props_runtime_mode,
|
||||
)
|
||||
from schemas.agent.forwarded_props import RuntimeMode
|
||||
from schemas.agent.runtime_models import (
|
||||
AgentOutput,
|
||||
ConstraintItem,
|
||||
ExecutionMode,
|
||||
KeyEntity,
|
||||
NormalizedTaskInput,
|
||||
ResultTyping,
|
||||
ResultType,
|
||||
RouterAgentOutput,
|
||||
RunStatus,
|
||||
TaskType,
|
||||
TaskTyping,
|
||||
ToolAgentOutput,
|
||||
ToolStatus,
|
||||
WorkerAgentOutputLite,
|
||||
WorkerAgentOutputRich,
|
||||
resolve_worker_output_model,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.agent.visibility import SystemVisibilityBit, VisibilityMask, bit_mask
|
||||
from schemas.agent.ui_hints import (
|
||||
UiHintAction,
|
||||
UiHintIntent,
|
||||
UiHintSection,
|
||||
UiHintStatus,
|
||||
UiHintsPayload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentType",
|
||||
"AgentOutput",
|
||||
"ConstraintItem",
|
||||
"ExecutionMode",
|
||||
"ForwardedPropsPayload",
|
||||
"KeyEntity",
|
||||
"NormalizedTaskInput",
|
||||
"ResultTyping",
|
||||
"ClientTimeContext",
|
||||
"ResultType",
|
||||
"RouterAgentOutput",
|
||||
"RunStatus",
|
||||
"RuntimeMode",
|
||||
"TaskType",
|
||||
"TaskTyping",
|
||||
"SystemAgentLLMConfig",
|
||||
"SystemVisibilityBit",
|
||||
"ToolAgentOutput",
|
||||
"ToolStatus",
|
||||
"UiHintAction",
|
||||
"UiHintIntent",
|
||||
"UiHintSection",
|
||||
"UiHintStatus",
|
||||
"UiHintsPayload",
|
||||
"VisibilityMask",
|
||||
"WorkerAgentOutputLite",
|
||||
"WorkerAgentOutputRich",
|
||||
"bit_mask",
|
||||
"parse_forwarded_props_client_time",
|
||||
"parse_forwarded_props_runtime_mode",
|
||||
"resolve_worker_output_model",
|
||||
]
|
||||
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import re
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
StrictInt,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
_RFC3339_WITH_TZ_PATTERN = re.compile(
|
||||
r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})$"
|
||||
)
|
||||
|
||||
|
||||
class ClientTimeContext(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
device_timezone: str = Field(
|
||||
...,
|
||||
description="IANA timezone from client device, e.g. America/Los_Angeles.",
|
||||
)
|
||||
client_now_iso: str = Field(
|
||||
...,
|
||||
description="RFC3339 datetime with timezone offset from client device.",
|
||||
)
|
||||
client_epoch_ms: StrictInt = Field(
|
||||
...,
|
||||
ge=0,
|
||||
description="Unix epoch milliseconds from client device.",
|
||||
)
|
||||
|
||||
@field_validator("device_timezone")
|
||||
@classmethod
|
||||
def validate_device_timezone(cls, value: str) -> str:
|
||||
try:
|
||||
ZoneInfo(value)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
raise ValueError("invalid client_time.device_timezone") from exc
|
||||
return value
|
||||
|
||||
@field_validator("client_now_iso")
|
||||
@classmethod
|
||||
def validate_client_now_iso(cls, value: str) -> str:
|
||||
if not _RFC3339_WITH_TZ_PATTERN.fullmatch(value):
|
||||
raise ValueError("invalid client_time.client_now_iso")
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
try:
|
||||
parsed = datetime.fromisoformat(normalized)
|
||||
except ValueError as exc:
|
||||
raise ValueError("invalid client_time.client_now_iso") from exc
|
||||
if parsed.tzinfo is None:
|
||||
raise ValueError("invalid client_time.client_now_iso")
|
||||
return value
|
||||
|
||||
|
||||
class RuntimeMode(str, Enum):
|
||||
CHAT = "chat"
|
||||
AUTOMATION = "automation"
|
||||
|
||||
|
||||
class ForwardedPropsPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
runtime_mode: RuntimeMode
|
||||
client_time: ClientTimeContext | None = None
|
||||
|
||||
|
||||
def parse_forwarded_props(forwarded_props: object) -> ForwardedPropsPayload:
|
||||
if not isinstance(forwarded_props, dict):
|
||||
raise ValueError("invalid RunAgentInput.forwardedProps")
|
||||
try:
|
||||
return ForwardedPropsPayload.model_validate(forwarded_props)
|
||||
except ValidationError as exc:
|
||||
raise ValueError("invalid RunAgentInput.forwardedProps") from exc
|
||||
|
||||
|
||||
def parse_forwarded_props_client_time(
|
||||
forwarded_props: object,
|
||||
) -> ClientTimeContext | None:
|
||||
payload = parse_forwarded_props(forwarded_props)
|
||||
return payload.client_time
|
||||
|
||||
|
||||
def parse_forwarded_props_runtime_mode(forwarded_props: object) -> RuntimeMode:
|
||||
payload = parse_forwarded_props(forwarded_props)
|
||||
return payload.runtime_mode
|
||||
@@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from schemas.agent.ui_hints import UiHintsPayload
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
KNOWLEDGE = "knowledge"
|
||||
RECOMMENDATION = "recommendation"
|
||||
PLANNING = "planning"
|
||||
SCHEDULING = "scheduling"
|
||||
REMINDER_MANAGEMENT = "reminder_management"
|
||||
TODO_MANAGEMENT = "todo_management"
|
||||
COMMUNICATION_DRAFTING = "communication_drafting"
|
||||
INFORMATION_ORGANIZATION = "information_organization"
|
||||
STATUS_TRACKING = "status_tracking"
|
||||
TRANSACTION_ASSIST = "transaction_assist"
|
||||
ACTION_EXECUTION = "action_execution"
|
||||
TROUBLESHOOTING = "troubleshooting"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class ResultType(str, Enum):
|
||||
DIRECT_ANSWER = "direct_answer"
|
||||
OPTIONS_WITH_RECOMMENDATION = "options_with_recommendation"
|
||||
ACTION_PLAN = "action_plan"
|
||||
SCHEDULE_PROPOSAL = "schedule_proposal"
|
||||
TODO_LIST = "todo_list"
|
||||
DRAFT_MESSAGE = "draft_message"
|
||||
SUMMARY = "summary"
|
||||
PROGRESS_SUMMARY = "progress_summary"
|
||||
DIAGNOSIS_REPORT = "diagnosis_report"
|
||||
STRUCTURED_PAYLOAD = "structured_payload"
|
||||
EXECUTION_REPORT = "execution_report"
|
||||
CLARIFICATION_REQUEST = "clarification_request"
|
||||
SAFETY_BLOCK = "safety_block"
|
||||
ERROR_REPORT = "error_report"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class TaskTyping(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
primary: TaskType
|
||||
secondary: list[TaskType] = Field(default_factory=list, max_length=3)
|
||||
|
||||
|
||||
class ResultTyping(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
primary: ResultType
|
||||
secondary: list[ResultType] = Field(default_factory=list, max_length=3)
|
||||
|
||||
|
||||
class ExecutionMode(str, Enum):
|
||||
ONESTEP = "onestep"
|
||||
TOOL_ASSISTED = "tool_assisted"
|
||||
MULTISTEP = "multistep"
|
||||
|
||||
|
||||
class RunStatus(str, Enum):
|
||||
SUCCESS = "success"
|
||||
PARTIAL_SUCCESS = "partial_success"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class ToolStatus(str, Enum):
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
PARTIAL = "partial"
|
||||
|
||||
|
||||
class KeyEntity(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str
|
||||
type: str
|
||||
value: str | None = None
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def normalize_value(cls, value: object) -> object:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, bool | int | float):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
|
||||
class ConstraintItem(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
key: str
|
||||
value: str
|
||||
required: bool = True
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def normalize_value(cls, value: object) -> object:
|
||||
if isinstance(value, bool | int | float):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
|
||||
class NormalizedTaskInput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
user_text: str
|
||||
multimodal_summary: list[str] = Field(default_factory=list)
|
||||
context_summary: str = Field(default="", max_length=2000)
|
||||
|
||||
|
||||
class RouterAgentOutput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
normalized_task_input: NormalizedTaskInput
|
||||
key_entities: list[KeyEntity] = Field(default_factory=list)
|
||||
constraints: list[ConstraintItem] = Field(default_factory=list)
|
||||
task_typing: TaskTyping
|
||||
execution_mode: ExecutionMode
|
||||
result_typing: ResultTyping
|
||||
|
||||
|
||||
class ErrorInfo(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
code: str
|
||||
message: str
|
||||
retryable: bool = False
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ToolAgentOutput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
tool_name: str
|
||||
tool_call_id: str
|
||||
tool_call_args: dict[str, Any] | None = None
|
||||
status: ToolStatus
|
||||
result: str
|
||||
error: ErrorInfo | None = None
|
||||
|
||||
|
||||
class WorkerAgentOutputLite(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
status: RunStatus = RunStatus.SUCCESS
|
||||
answer: str
|
||||
key_points: list[str] = Field(default_factory=list)
|
||||
result_type: ResultType = ResultType.UNKNOWN
|
||||
suggested_actions: list[str] = Field(default_factory=list)
|
||||
error: ErrorInfo | None = None
|
||||
|
||||
|
||||
class WorkerAgentOutputRich(WorkerAgentOutputLite):
|
||||
ui_hints: UiHintsPayload | None = None
|
||||
|
||||
|
||||
class AgentOutput(WorkerAgentOutputRich):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
WorkerAgentOutput = WorkerAgentOutputLite | WorkerAgentOutputRich
|
||||
|
||||
|
||||
def resolve_worker_output_model(
|
||||
execution_mode: ExecutionMode,
|
||||
) -> type[WorkerAgentOutputLite]:
|
||||
if execution_mode == ExecutionMode.ONESTEP:
|
||||
return WorkerAgentOutputLite
|
||||
return WorkerAgentOutputRich
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
ROUTER = "router"
|
||||
WORKER = "worker"
|
||||
|
||||
|
||||
class ContextBuildStrategy(str, Enum):
|
||||
DAY = "day"
|
||||
NUMBER = "number"
|
||||
|
||||
|
||||
class ContextMessagesConfig(BaseModel):
|
||||
mode: ContextBuildStrategy = ContextBuildStrategy.NUMBER
|
||||
count: int = Field(default=20, ge=1, le=200)
|
||||
|
||||
|
||||
class SystemAgentLLMConfig(BaseModel):
|
||||
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
|
||||
max_tokens: int | None = Field(default=None, ge=1)
|
||||
timeout_seconds: float | None = Field(default=30.0, gt=0.0, le=300.0)
|
||||
context_messages: ContextMessagesConfig = Field(
|
||||
default_factory=ContextMessagesConfig
|
||||
)
|
||||
enabled_tools: list[str] = Field(default_factory=list, max_length=32)
|
||||
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
UiHints - 描述性 UI 提示
|
||||
|
||||
设计原则:
|
||||
- 描述性而非渲染性: 告诉编译器“要展示什么”,而不是“如何渲染”
|
||||
- 最小化 token: 保持字段简洁
|
||||
- 可编译: 可机械转换为 UiSchemaRenderer
|
||||
- 尽量无损: hints 中的主要内容字段应尽量被保留到 renderer 中
|
||||
|
||||
Version: 2.1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import re
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import field_validator
|
||||
|
||||
_NAVIGATION_PATH_PATTERN = re.compile(r"^/[A-Za-z0-9/_-]*$")
|
||||
_NAVIGATION_PARAM_KEY_PATTERN = re.compile(r"^[A-Za-z][A-Za-z0-9_]{0,31}$")
|
||||
_MAX_NAVIGATION_PARAMS = 8
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Enums
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintStatus(str, Enum):
|
||||
INFO = "info"
|
||||
SUCCESS = "success"
|
||||
WARNING = "warning"
|
||||
ERROR = "error"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class UiHintIntent(str, Enum):
|
||||
"""主要展示意图(弱提示,不应决定字段生死)"""
|
||||
|
||||
MESSAGE = "message" # 普通消息/说明
|
||||
DATA = "data" # 数据/结果摘要
|
||||
LIST = "list" # 列表为主
|
||||
STATUS = "status" # 状态结果为主
|
||||
FORM = "form" # 结构化内容(当前不表示真实输入表单)
|
||||
MIXED = "mixed" # 混合内容
|
||||
|
||||
|
||||
class UiHintActionStyle(str, Enum):
|
||||
PRIMARY = "primary"
|
||||
SECONDARY = "secondary"
|
||||
GHOST = "ghost"
|
||||
DANGER = "danger"
|
||||
|
||||
|
||||
class UiHintTextFormat(str, Enum):
|
||||
PLAIN = "plain"
|
||||
MARKDOWN = "markdown"
|
||||
|
||||
|
||||
class UiHintActionType(str, Enum):
|
||||
NAVIGATION = "navigation"
|
||||
URL = "url"
|
||||
EVENT = "event"
|
||||
TOOL = "tool"
|
||||
COPY = "copy"
|
||||
PAYLOAD = "payload"
|
||||
|
||||
|
||||
class UiHintIconSource(str, Enum):
|
||||
ICON = "icon"
|
||||
EMOJI = "emoji"
|
||||
URL = "url"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Base Config
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintBaseModel(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Action Targets
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintActionNavigation(UiHintBaseModel):
|
||||
type: Literal["navigation"]
|
||||
path: str = Field(..., description="Internal route path.")
|
||||
params: dict[str, Any] | None = Field(default=None, description="Route params.")
|
||||
|
||||
@field_validator("path")
|
||||
@classmethod
|
||||
def validate_navigation_path(cls, value: str) -> str:
|
||||
path = value.strip()
|
||||
if not path:
|
||||
raise ValueError("navigation path must not be empty")
|
||||
if len(path) > 256:
|
||||
raise ValueError("navigation path is too long")
|
||||
if path.startswith("//") or "://" in path:
|
||||
raise ValueError("navigation path must be internal")
|
||||
if "?" in path or "#" in path:
|
||||
raise ValueError("navigation path must not contain query or fragment")
|
||||
if ":" in path:
|
||||
raise ValueError("navigation path must be concrete without placeholders")
|
||||
if _NAVIGATION_PATH_PATTERN.fullmatch(path) is None:
|
||||
raise ValueError("navigation path contains unsupported characters")
|
||||
return path
|
||||
|
||||
@field_validator("params")
|
||||
@classmethod
|
||||
def validate_navigation_params(
|
||||
cls, value: dict[str, Any] | None
|
||||
) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if len(value) > _MAX_NAVIGATION_PARAMS:
|
||||
raise ValueError("navigation params exceed limit")
|
||||
|
||||
normalized: dict[str, Any] = {}
|
||||
for key, param_value in value.items():
|
||||
if _NAVIGATION_PARAM_KEY_PATTERN.fullmatch(key) is None:
|
||||
raise ValueError("navigation param key is invalid")
|
||||
if isinstance(param_value, (str, int, float, bool)):
|
||||
normalized[key] = param_value
|
||||
continue
|
||||
raise ValueError("navigation params must be scalar")
|
||||
return normalized
|
||||
|
||||
|
||||
class UiHintActionUrl(UiHintBaseModel):
|
||||
type: Literal["url"]
|
||||
url: str = Field(..., description="External URL.")
|
||||
target: Literal["_self", "_blank"] | None = Field(default=None)
|
||||
|
||||
|
||||
class UiHintActionEvent(UiHintBaseModel):
|
||||
type: Literal["event"]
|
||||
event: str = Field(..., description="Frontend event name.")
|
||||
payload: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
|
||||
class UiHintActionTool(UiHintBaseModel):
|
||||
type: Literal["tool"]
|
||||
tool_id: str = Field(alias="toolId", description="Tool identifier.")
|
||||
params: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
|
||||
class UiHintActionCopy(UiHintBaseModel):
|
||||
type: Literal["copy"]
|
||||
content: str = Field(..., description="Content to copy.")
|
||||
success_message: str | None = Field(alias="successMessage", default=None)
|
||||
|
||||
|
||||
class UiHintActionPayload(UiHintBaseModel):
|
||||
type: Literal["payload"]
|
||||
payload: dict[str, Any] = Field(..., description="Structured payload.")
|
||||
submit_to: str | None = Field(alias="submitTo", default=None)
|
||||
|
||||
|
||||
UiHintActionTarget = (
|
||||
UiHintActionNavigation
|
||||
| UiHintActionUrl
|
||||
| UiHintActionEvent
|
||||
| UiHintActionTool
|
||||
| UiHintActionCopy
|
||||
| UiHintActionPayload
|
||||
)
|
||||
|
||||
|
||||
class UiHintAction(UiHintBaseModel):
|
||||
label: str = Field(..., description="Button label.")
|
||||
style: UiHintActionStyle | None = Field(default=None, description="Button style.")
|
||||
disabled: bool = Field(default=False, description="Disabled state.")
|
||||
action: UiHintActionTarget = Field(..., description="Action to execute.")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Small Descriptive Models
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintIcon(UiHintBaseModel):
|
||||
source: UiHintIconSource = Field(default=UiHintIconSource.ICON)
|
||||
value: str = Field(..., description="Icon identifier / emoji / url.")
|
||||
color: str | None = Field(default=None)
|
||||
size: int | None = Field(default=None)
|
||||
|
||||
|
||||
class UiHintKvItem(UiHintBaseModel):
|
||||
key: str = Field(..., description="Key identifier.")
|
||||
label: str | None = Field(default=None, description="Display label.")
|
||||
value: Any = Field(default=None, description="Value.")
|
||||
copyable: bool = Field(default=False, description="Allow copy.")
|
||||
|
||||
|
||||
class UiHintListItem(UiHintBaseModel):
|
||||
id: str | None = Field(default=None)
|
||||
title: str = Field(..., description="Item title.")
|
||||
subtitle: str | None = Field(default=None)
|
||||
description: str | None = Field(default=None)
|
||||
icon: UiHintIcon | None = Field(default=None)
|
||||
status: UiHintStatus | None = Field(default=None)
|
||||
actions: list[UiHintAction] = Field(default_factory=list)
|
||||
|
||||
@field_validator("status", mode="before")
|
||||
@classmethod
|
||||
def normalize_status(cls, value: object) -> object:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, dict):
|
||||
status_type = value.get("type")
|
||||
if isinstance(status_type, str):
|
||||
return status_type
|
||||
status_value = value.get("status")
|
||||
if isinstance(status_value, str):
|
||||
return status_value
|
||||
return value
|
||||
|
||||
|
||||
class UiHintSection(UiHintBaseModel):
|
||||
title: str | None = Field(default=None, description="Section title.")
|
||||
description: str | None = Field(default=None, description="Section description.")
|
||||
icon: UiHintIcon | None = Field(default=None, description="Section icon.")
|
||||
|
||||
content: str | None = Field(default=None, description="Main text content.")
|
||||
content_format: UiHintTextFormat = Field(
|
||||
default=UiHintTextFormat.PLAIN,
|
||||
alias="contentFormat",
|
||||
description="Section content text format.",
|
||||
)
|
||||
|
||||
items: list[UiHintKvItem] = Field(default_factory=list, description="KV items.")
|
||||
list_items: list[UiHintListItem] = Field(
|
||||
default_factory=list,
|
||||
alias="listItems",
|
||||
description="List items.",
|
||||
)
|
||||
actions: list[UiHintAction] = Field(default_factory=list, description="Actions.")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Root Payload
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiHintsPayload(UiHintBaseModel):
|
||||
"""
|
||||
描述性 UI 提示
|
||||
|
||||
设计目标:
|
||||
- agent 输出尽可能短
|
||||
- 不表达布局细节
|
||||
- 编译器负责转换为完整 UiSchemaRenderer
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
json_schema_extra={
|
||||
"examples": [
|
||||
{
|
||||
"intent": "status",
|
||||
"status": "success",
|
||||
"title": "日程已创建",
|
||||
"body": "本次创建已成功完成。",
|
||||
"items": [
|
||||
{"key": "title", "label": "主题", "value": "Q1 规划会议"},
|
||||
{"key": "time", "label": "时间", "value": "2026-03-15 14:00"},
|
||||
],
|
||||
"actions": [
|
||||
{
|
||||
"label": "查看详情",
|
||||
"style": "primary",
|
||||
"action": {
|
||||
"type": "navigation",
|
||||
"path": "/calendar/evt_123",
|
||||
},
|
||||
},
|
||||
{
|
||||
"label": "删除",
|
||||
"style": "danger",
|
||||
"action": {
|
||||
"type": "tool",
|
||||
"toolId": "calendar.delete",
|
||||
"params": {"eventId": "evt_123"},
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
version: str = Field(default="2.1")
|
||||
|
||||
intent: UiHintIntent = Field(
|
||||
default=UiHintIntent.MESSAGE,
|
||||
description="Primary display intent.",
|
||||
)
|
||||
status: UiHintStatus = Field(
|
||||
default=UiHintStatus.INFO,
|
||||
description="Overall status.",
|
||||
)
|
||||
|
||||
title: str | None = Field(default=None, description="Top-level title.")
|
||||
description: str | None = Field(default=None, description="Top-level description.")
|
||||
|
||||
body: str | None = Field(default=None, description="Top-level main body text.")
|
||||
body_format: UiHintTextFormat = Field(
|
||||
default=UiHintTextFormat.PLAIN,
|
||||
alias="bodyFormat",
|
||||
description="Body text format.",
|
||||
)
|
||||
|
||||
items: list[UiHintKvItem] = Field(
|
||||
default_factory=list,
|
||||
description="Top-level key-value items.",
|
||||
)
|
||||
list_items: list[UiHintListItem] = Field(
|
||||
default_factory=list,
|
||||
alias="listItems",
|
||||
description="Top-level list items.",
|
||||
)
|
||||
sections: list[UiHintSection] = Field(
|
||||
default_factory=list,
|
||||
description="Grouped sections.",
|
||||
)
|
||||
actions: list[UiHintAction] = Field(
|
||||
default_factory=list,
|
||||
description="Top-level actions.",
|
||||
)
|
||||
|
||||
icon: UiHintIcon | None = Field(
|
||||
default=None,
|
||||
description="Top-level icon.",
|
||||
)
|
||||
meta: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Extra meta, e.g. requestId/toolId/traceId/userId.",
|
||||
)
|
||||
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
UI Schema Renderer Protocol
|
||||
|
||||
目标:
|
||||
- 只保留“基础组件 + 布局容器”
|
||||
- 最终返回一个 UiSchemaRenderer
|
||||
- 前端只需要递归渲染 root 布局树即可
|
||||
|
||||
Version: 2.0
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, TypedDict, Union
|
||||
|
||||
# ============================================================
|
||||
# Enums
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiStatus(str, Enum):
|
||||
INFO = "info"
|
||||
SUCCESS = "success"
|
||||
WARNING = "warning"
|
||||
ERROR = "error"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class IconSource(str, Enum):
|
||||
ICON = "icon"
|
||||
EMOJI = "emoji"
|
||||
URL = "url"
|
||||
|
||||
|
||||
class TextFormat(str, Enum):
|
||||
PLAIN = "plain"
|
||||
MARKDOWN = "markdown"
|
||||
|
||||
|
||||
class TextRole(str, Enum):
|
||||
TITLE = "title"
|
||||
SUBTITLE = "subtitle"
|
||||
BODY = "body"
|
||||
CAPTION = "caption"
|
||||
CODE = "code"
|
||||
|
||||
|
||||
class ButtonStyle(str, Enum):
|
||||
PRIMARY = "primary"
|
||||
SECONDARY = "secondary"
|
||||
GHOST = "ghost"
|
||||
DANGER = "danger"
|
||||
|
||||
|
||||
class LayoutDirection(str, Enum):
|
||||
VERTICAL = "vertical"
|
||||
HORIZONTAL = "horizontal"
|
||||
|
||||
|
||||
class LayoutAppearance(str, Enum):
|
||||
PLAIN = "plain"
|
||||
CARD = "card"
|
||||
SECTION = "section"
|
||||
|
||||
|
||||
class LayoutAlign(str, Enum):
|
||||
START = "start"
|
||||
CENTER = "center"
|
||||
END = "end"
|
||||
STRETCH = "stretch"
|
||||
|
||||
|
||||
class LayoutJustify(str, Enum):
|
||||
START = "start"
|
||||
CENTER = "center"
|
||||
END = "end"
|
||||
SPACE_BETWEEN = "space-between"
|
||||
|
||||
|
||||
class RendererTheme(str, Enum):
|
||||
DEFAULT = "default"
|
||||
LIGHT = "light"
|
||||
DARK = "dark"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Meta
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiMeta(TypedDict, total=False):
|
||||
requestId: str
|
||||
toolId: str
|
||||
traceId: str
|
||||
userId: str
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Action Payloads
|
||||
# ============================================================
|
||||
|
||||
|
||||
class NavigateAction(TypedDict, total=False):
|
||||
type: Literal["navigation"]
|
||||
path: str
|
||||
params: dict[str, Any]
|
||||
|
||||
|
||||
class UrlAction(TypedDict, total=False):
|
||||
type: Literal["url"]
|
||||
url: str
|
||||
target: Literal["_self", "_blank"]
|
||||
|
||||
|
||||
class EventAction(TypedDict, total=False):
|
||||
type: Literal["event"]
|
||||
event: str
|
||||
payload: dict[str, Any]
|
||||
|
||||
|
||||
class ToolAction(TypedDict, total=False):
|
||||
type: Literal["tool"]
|
||||
toolId: str
|
||||
params: dict[str, Any]
|
||||
|
||||
|
||||
class CopyAction(TypedDict, total=False):
|
||||
type: Literal["copy"]
|
||||
content: str
|
||||
successMessage: str
|
||||
|
||||
|
||||
class PayloadAction(TypedDict, total=False):
|
||||
type: Literal["payload"]
|
||||
payload: dict[str, Any]
|
||||
submitTo: str
|
||||
|
||||
|
||||
UiActionPayload = Union[
|
||||
NavigateAction,
|
||||
UrlAction,
|
||||
EventAction,
|
||||
ToolAction,
|
||||
CopyAction,
|
||||
PayloadAction,
|
||||
]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Shared Small Types
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiIconSpec(TypedDict, total=False):
|
||||
source: str
|
||||
value: str
|
||||
color: str
|
||||
size: int
|
||||
|
||||
|
||||
class UiKvItem(TypedDict, total=False):
|
||||
key: str
|
||||
label: str
|
||||
value: Any
|
||||
copyable: bool
|
||||
|
||||
|
||||
class UiBaseNode(TypedDict, total=False):
|
||||
id: str
|
||||
visible: bool
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Primitive Components
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiTextNode(UiBaseNode, total=False):
|
||||
type: Literal["text"]
|
||||
content: str
|
||||
format: str # TextFormat
|
||||
role: str # TextRole
|
||||
status: str # UiStatus
|
||||
maxLines: int
|
||||
|
||||
|
||||
class UiIconNode(UiBaseNode, total=False):
|
||||
type: Literal["icon"]
|
||||
source: str # IconSource
|
||||
value: str
|
||||
color: str
|
||||
size: int
|
||||
|
||||
|
||||
class UiBadgeNode(UiBaseNode, total=False):
|
||||
type: Literal["badge"]
|
||||
label: str
|
||||
status: str # UiStatus
|
||||
|
||||
|
||||
class UiButtonNode(UiBaseNode, total=False):
|
||||
type: Literal["button"]
|
||||
label: str
|
||||
style: str # ButtonStyle
|
||||
disabled: bool
|
||||
icon: UiIconSpec
|
||||
action: UiActionPayload
|
||||
|
||||
|
||||
class UiKvNode(UiBaseNode, total=False):
|
||||
type: Literal["kv"]
|
||||
items: list[UiKvItem]
|
||||
columns: int
|
||||
|
||||
|
||||
class UiDividerNode(UiBaseNode, total=False):
|
||||
type: Literal["divider"]
|
||||
inset: int
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Layout Containers
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiStackNode(UiBaseNode, total=False):
|
||||
type: Literal["stack"]
|
||||
direction: str # LayoutDirection
|
||||
gap: int
|
||||
appearance: str # LayoutAppearance
|
||||
status: str # UiStatus
|
||||
align: str # LayoutAlign
|
||||
justify: str # LayoutJustify
|
||||
wrap: bool
|
||||
children: list["UiNode"]
|
||||
|
||||
|
||||
class UiGridNode(UiBaseNode, total=False):
|
||||
type: Literal["grid"]
|
||||
columns: int
|
||||
gap: int
|
||||
appearance: str # LayoutAppearance
|
||||
status: str # UiStatus
|
||||
children: list["UiNode"]
|
||||
|
||||
|
||||
UiNode = Union[
|
||||
UiTextNode,
|
||||
UiIconNode,
|
||||
UiBadgeNode,
|
||||
UiButtonNode,
|
||||
UiKvNode,
|
||||
UiDividerNode,
|
||||
UiStackNode,
|
||||
UiGridNode,
|
||||
]
|
||||
|
||||
UiLayoutNode = Union[UiStackNode, UiGridNode]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Root Renderer
|
||||
# ============================================================
|
||||
|
||||
|
||||
class UiSchemaRenderer(TypedDict, total=False):
|
||||
version: str
|
||||
locale: str
|
||||
status: str # UiStatus
|
||||
theme: str # RendererTheme
|
||||
meta: UiMeta
|
||||
root: UiLayoutNode
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Root Builder
|
||||
# ============================================================
|
||||
|
||||
|
||||
def build_renderer(
|
||||
root: UiLayoutNode,
|
||||
*,
|
||||
version: str = "2.0",
|
||||
locale: str = "zh-CN",
|
||||
status: UiStatus = UiStatus.INFO,
|
||||
theme: RendererTheme = RendererTheme.DEFAULT,
|
||||
meta: UiMeta | None = None,
|
||||
) -> UiSchemaRenderer:
|
||||
renderer: UiSchemaRenderer = {
|
||||
"version": version,
|
||||
"locale": locale,
|
||||
"status": status.value,
|
||||
"theme": theme.value,
|
||||
"root": root,
|
||||
}
|
||||
if meta:
|
||||
renderer["meta"] = meta
|
||||
return renderer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Primitive Builders
|
||||
# ============================================================
|
||||
|
||||
|
||||
def build_text(
|
||||
content: str,
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
format: TextFormat = TextFormat.PLAIN,
|
||||
role: TextRole = TextRole.BODY,
|
||||
status: UiStatus | None = None,
|
||||
max_lines: int | None = None,
|
||||
visible: bool = True,
|
||||
) -> UiTextNode:
|
||||
node: UiTextNode = {
|
||||
"type": "text",
|
||||
"content": content,
|
||||
"format": format.value,
|
||||
"role": role.value,
|
||||
"visible": visible,
|
||||
}
|
||||
if node_id:
|
||||
node["id"] = node_id
|
||||
if status:
|
||||
node["status"] = status.value
|
||||
if max_lines is not None:
|
||||
node["maxLines"] = max_lines
|
||||
return node
|
||||
|
||||
|
||||
def build_icon(
|
||||
source: IconSource,
|
||||
value: str,
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
color: str | None = None,
|
||||
size: int | None = None,
|
||||
visible: bool = True,
|
||||
) -> UiIconNode:
|
||||
node: UiIconNode = {
|
||||
"type": "icon",
|
||||
"source": source.value,
|
||||
"value": value,
|
||||
"visible": visible,
|
||||
}
|
||||
if node_id:
|
||||
node["id"] = node_id
|
||||
if color:
|
||||
node["color"] = color
|
||||
if size is not None:
|
||||
node["size"] = size
|
||||
return node
|
||||
|
||||
|
||||
def build_badge(
|
||||
label: str,
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
status: UiStatus = UiStatus.INFO,
|
||||
visible: bool = True,
|
||||
) -> UiBadgeNode:
|
||||
node: UiBadgeNode = {
|
||||
"type": "badge",
|
||||
"label": label,
|
||||
"status": status.value,
|
||||
"visible": visible,
|
||||
}
|
||||
if node_id:
|
||||
node["id"] = node_id
|
||||
return node
|
||||
|
||||
|
||||
def build_button(
|
||||
label: str,
|
||||
action: UiActionPayload,
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
style: ButtonStyle = ButtonStyle.PRIMARY,
|
||||
disabled: bool = False,
|
||||
icon: UiIconSpec | None = None,
|
||||
visible: bool = True,
|
||||
) -> UiButtonNode:
|
||||
node: UiButtonNode = {
|
||||
"type": "button",
|
||||
"label": label,
|
||||
"style": style.value,
|
||||
"disabled": disabled,
|
||||
"action": action,
|
||||
"visible": visible,
|
||||
}
|
||||
if node_id:
|
||||
node["id"] = node_id
|
||||
if icon:
|
||||
node["icon"] = icon
|
||||
return node
|
||||
|
||||
|
||||
def build_kv(
|
||||
items: list[UiKvItem],
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
columns: int = 1,
|
||||
visible: bool = True,
|
||||
) -> UiKvNode:
|
||||
node: UiKvNode = {
|
||||
"type": "kv",
|
||||
"items": items,
|
||||
"columns": columns,
|
||||
"visible": visible,
|
||||
}
|
||||
if node_id:
|
||||
node["id"] = node_id
|
||||
return node
|
||||
|
||||
|
||||
def build_divider(
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
inset: int = 0,
|
||||
visible: bool = True,
|
||||
) -> UiDividerNode:
|
||||
node: UiDividerNode = {
|
||||
"type": "divider",
|
||||
"inset": inset,
|
||||
"visible": visible,
|
||||
}
|
||||
if node_id:
|
||||
node["id"] = node_id
|
||||
return node
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Layout Builders
|
||||
# ============================================================
|
||||
|
||||
|
||||
def build_stack(
|
||||
children: list[UiNode],
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
direction: LayoutDirection = LayoutDirection.VERTICAL,
|
||||
gap: int = 12,
|
||||
appearance: LayoutAppearance = LayoutAppearance.PLAIN,
|
||||
status: UiStatus | None = None,
|
||||
align: LayoutAlign = LayoutAlign.START,
|
||||
justify: LayoutJustify = LayoutJustify.START,
|
||||
wrap: bool = False,
|
||||
visible: bool = True,
|
||||
) -> UiStackNode:
|
||||
node: UiStackNode = {
|
||||
"type": "stack",
|
||||
"direction": direction.value,
|
||||
"gap": gap,
|
||||
"appearance": appearance.value,
|
||||
"align": align.value,
|
||||
"justify": justify.value,
|
||||
"wrap": wrap,
|
||||
"children": children,
|
||||
"visible": visible,
|
||||
}
|
||||
if node_id:
|
||||
node["id"] = node_id
|
||||
if status:
|
||||
node["status"] = status.value
|
||||
return node
|
||||
|
||||
|
||||
def build_grid(
|
||||
children: list[UiNode],
|
||||
*,
|
||||
columns: int,
|
||||
node_id: str | None = None,
|
||||
gap: int = 12,
|
||||
appearance: LayoutAppearance = LayoutAppearance.PLAIN,
|
||||
status: UiStatus | None = None,
|
||||
visible: bool = True,
|
||||
) -> UiGridNode:
|
||||
node: UiGridNode = {
|
||||
"type": "grid",
|
||||
"columns": columns,
|
||||
"gap": gap,
|
||||
"appearance": appearance.value,
|
||||
"children": children,
|
||||
"visible": visible,
|
||||
}
|
||||
if node_id:
|
||||
node["id"] = node_id
|
||||
if status:
|
||||
node["status"] = status.value
|
||||
return node
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Small Action Builders
|
||||
# ============================================================
|
||||
|
||||
|
||||
def action_navigation(
|
||||
path: str, params: dict[str, Any] | None = None
|
||||
) -> NavigateAction:
|
||||
action: NavigateAction = {"type": "navigation", "path": path}
|
||||
if params:
|
||||
action["params"] = params
|
||||
return action
|
||||
|
||||
|
||||
def action_url(url: str, target: Literal["_self", "_blank"] = "_blank") -> UrlAction:
|
||||
return {"type": "url", "url": url, "target": target}
|
||||
|
||||
|
||||
def action_event(event: str, payload: dict[str, Any] | None = None) -> EventAction:
|
||||
action: EventAction = {"type": "event", "event": event}
|
||||
if payload:
|
||||
action["payload"] = payload
|
||||
return action
|
||||
|
||||
|
||||
def action_tool(tool_id: str, params: dict[str, Any] | None = None) -> ToolAction:
|
||||
action: ToolAction = {"type": "tool", "toolId": tool_id}
|
||||
if params:
|
||||
action["params"] = params
|
||||
return action
|
||||
|
||||
|
||||
def action_copy(content: str, success_message: str | None = None) -> CopyAction:
|
||||
action: CopyAction = {"type": "copy", "content": content}
|
||||
if success_message:
|
||||
action["successMessage"] = success_message
|
||||
return action
|
||||
|
||||
|
||||
def action_payload(
|
||||
payload: dict[str, Any], submit_to: str | None = None
|
||||
) -> PayloadAction:
|
||||
action: PayloadAction = {"type": "payload", "payload": payload}
|
||||
if submit_to:
|
||||
action["submitTo"] = submit_to
|
||||
return action
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Derived Helpers (协议外的便捷封装,不是基础原语)
|
||||
# ============================================================
|
||||
|
||||
|
||||
def build_card(
|
||||
children: list[UiNode],
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
gap: int = 12,
|
||||
status: UiStatus | None = None,
|
||||
) -> UiStackNode:
|
||||
return build_stack(
|
||||
children,
|
||||
node_id=node_id,
|
||||
direction=LayoutDirection.VERTICAL,
|
||||
gap=gap,
|
||||
appearance=LayoutAppearance.CARD,
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def build_section(
|
||||
title: str,
|
||||
children: list[UiNode],
|
||||
*,
|
||||
node_id: str | None = None,
|
||||
description: str | None = None,
|
||||
status: UiStatus | None = None,
|
||||
gap: int = 12,
|
||||
) -> UiStackNode:
|
||||
header_nodes: list[UiNode] = [build_text(title, role=TextRole.TITLE)]
|
||||
if description:
|
||||
header_nodes.append(build_text(description, role=TextRole.CAPTION))
|
||||
|
||||
all_children = header_nodes + children
|
||||
return build_stack(
|
||||
all_children,
|
||||
node_id=node_id,
|
||||
direction=LayoutDirection.VERTICAL,
|
||||
gap=gap,
|
||||
appearance=LayoutAppearance.SECTION,
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def build_status_panel(
|
||||
title: str,
|
||||
message: str,
|
||||
*,
|
||||
status: UiStatus,
|
||||
primary_button: UiButtonNode | None = None,
|
||||
secondary_button: UiButtonNode | None = None,
|
||||
node_id: str | None = None,
|
||||
) -> UiStackNode:
|
||||
status_label = f"ui.status.{status.value}"
|
||||
children: list[UiNode] = [
|
||||
build_stack(
|
||||
[
|
||||
build_text(title, role=TextRole.TITLE),
|
||||
build_badge(label=status_label, status=status),
|
||||
],
|
||||
direction=LayoutDirection.HORIZONTAL,
|
||||
gap=8,
|
||||
align=LayoutAlign.CENTER,
|
||||
justify=LayoutJustify.SPACE_BETWEEN,
|
||||
),
|
||||
build_text(message, role=TextRole.BODY, status=status),
|
||||
]
|
||||
|
||||
actions: list[UiNode] = []
|
||||
if primary_button:
|
||||
actions.append(primary_button)
|
||||
if secondary_button:
|
||||
actions.append(secondary_button)
|
||||
|
||||
if actions:
|
||||
children.append(
|
||||
build_stack(
|
||||
actions,
|
||||
direction=LayoutDirection.HORIZONTAL,
|
||||
gap=8,
|
||||
)
|
||||
)
|
||||
|
||||
return build_card(children, node_id=node_id, status=status)
|
||||
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class SystemVisibilityBit(IntEnum):
|
||||
UI_HISTORY = 0
|
||||
CONTEXT_ASSEMBLY = 1
|
||||
|
||||
|
||||
class VisibilityMask(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
value: int = Field(..., ge=0, le=(1 << 63) - 1)
|
||||
|
||||
@classmethod
|
||||
def from_bits(cls, *, bits: list[int]) -> "VisibilityMask":
|
||||
mask = 0
|
||||
for bit in bits:
|
||||
validate_visibility_bit(bit=bit)
|
||||
mask |= 1 << bit
|
||||
return cls(value=mask)
|
||||
|
||||
def contains(self, *, bit: int) -> bool:
|
||||
validate_visibility_bit(bit=bit)
|
||||
return bool(self.value & (1 << bit))
|
||||
|
||||
|
||||
class VisibilityBitRef(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
bit: int = Field(..., ge=0, le=63)
|
||||
|
||||
@field_validator("bit")
|
||||
@classmethod
|
||||
def _validate_bit(cls, value: int) -> int:
|
||||
validate_visibility_bit(bit=value)
|
||||
return value
|
||||
|
||||
|
||||
def validate_visibility_bit(*, bit: int) -> None:
|
||||
if bit < 0 or bit > 63:
|
||||
raise ValueError("visibility bit must be in range [0, 63]")
|
||||
|
||||
|
||||
def bit_mask(*, bit: int) -> int:
|
||||
validate_visibility_bit(bit=bit)
|
||||
return 1 << bit
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
close_registered_services,
|
||||
initialize_registered_services,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
resolve_registered_services,
|
||||
)
|
||||
from services.base.supabase import SupabaseService, supabase_service
|
||||
|
||||
__all__ = [
|
||||
"BaseServiceProvider",
|
||||
"RedisService",
|
||||
"ServiceRegistry",
|
||||
"SupabaseService",
|
||||
"close_registered_services",
|
||||
"get_or_init_redis_client",
|
||||
"initialize_registered_services",
|
||||
"redis_service",
|
||||
"register_service",
|
||||
"register_service_instance",
|
||||
"resolve_registered_services",
|
||||
"supabase_service",
|
||||
]
|
||||
@@ -0,0 +1,136 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from core.config.settings import RedisSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class RedisService(BaseServiceProvider):
|
||||
def __init__(self, settings: RedisSettings | None = None) -> None:
|
||||
super().__init__("redis")
|
||||
self._settings = settings or config.redis
|
||||
self._client: Optional[redis.Redis] = None
|
||||
self._loop_id: int | None = None
|
||||
|
||||
def _build_client(self) -> redis.Redis:
|
||||
return redis.from_url(
|
||||
self._settings.url,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=self._settings.socket_connect_timeout,
|
||||
socket_timeout=self._settings.socket_timeout,
|
||||
max_connections=self._settings.max_connections,
|
||||
)
|
||||
|
||||
def _require_client(self) -> redis.Redis:
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Redis client is not initialized")
|
||||
return client
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
client = self._build_client()
|
||||
ping_result = client.ping()
|
||||
if inspect.isawaitable(ping_result):
|
||||
await ping_result
|
||||
self._client = client
|
||||
self._loop_id = _current_loop_id()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Redis service initialized")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Redis service initialization failed", error=str(exc))
|
||||
self._client = None
|
||||
self._loop_id = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def close(self) -> bool:
|
||||
client = self._client
|
||||
if client is None:
|
||||
self._loop_id = None
|
||||
return True
|
||||
try:
|
||||
await client.aclose()
|
||||
self.logger.info("Redis service closed")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.exception("Redis service close failed", error=str(exc))
|
||||
return False
|
||||
finally:
|
||||
self._client = None
|
||||
self._loop_id = None
|
||||
self._set_initialized(False)
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
client = self._client
|
||||
if client is None:
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
ping_result = client.ping()
|
||||
ping = (
|
||||
await ping_result if inspect.isawaitable(ping_result) else ping_result
|
||||
)
|
||||
info_result = client.info()
|
||||
info = (
|
||||
await info_result if inspect.isawaitable(info_result) else info_result
|
||||
)
|
||||
return {
|
||||
"status": "healthy" if ping else "unhealthy",
|
||||
"details": {
|
||||
"ping": ping,
|
||||
"redis_version": info.get("redis_version"),
|
||||
"connected_clients": info.get("connected_clients"),
|
||||
"used_memory": info.get("used_memory_human"),
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds"),
|
||||
},
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Redis health check failed", error=str(exc))
|
||||
return {"status": "unhealthy", "details": {"error": str(exc)}}
|
||||
|
||||
def get_client(self) -> redis.Redis:
|
||||
return self._require_client()
|
||||
|
||||
|
||||
def _current_loop_id() -> int | None:
|
||||
try:
|
||||
return id(asyncio.get_running_loop())
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
|
||||
async def get_or_init_redis_client() -> redis.Redis:
|
||||
current_loop_id = _current_loop_id()
|
||||
bound_loop_id = redis_service._loop_id
|
||||
if (
|
||||
redis_service.is_initialized
|
||||
and bound_loop_id is not None
|
||||
and current_loop_id is not None
|
||||
and bound_loop_id != current_loop_id
|
||||
):
|
||||
redis_service.logger.warning(
|
||||
"Redis client bound to different event loop; reinitializing",
|
||||
previous_loop_id=bound_loop_id,
|
||||
current_loop_id=current_loop_id,
|
||||
)
|
||||
redis_service._client = None
|
||||
redis_service._loop_id = None
|
||||
redis_service._set_initialized(False)
|
||||
|
||||
if not redis_service.is_initialized:
|
||||
initialized = await redis_service.initialize()
|
||||
if not initialized:
|
||||
raise RuntimeError("Redis service initialization failed")
|
||||
return redis_service.get_client()
|
||||
|
||||
|
||||
redis_service: RedisService = register_service_instance("redis", RedisService())
|
||||
|
||||
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
|
||||
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar
|
||||
|
||||
from core.logging import get_logger
|
||||
|
||||
|
||||
class BaseServiceProvider(ABC):
|
||||
def __init__(self, service_name: str) -> None:
|
||||
self.service_name = service_name
|
||||
self._initialized = False
|
||||
self.logger = get_logger("services.base").bind(service=service_name)
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self, **kwargs: Any) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def _set_initialized(self, value: bool) -> None:
|
||||
self._initialized = value
|
||||
|
||||
def get_service_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.service_name,
|
||||
"initialized": self._initialized,
|
||||
"type": self.__class__.__name__,
|
||||
}
|
||||
|
||||
|
||||
class ServiceRegistry:
|
||||
_services: Dict[str, Callable[..., BaseServiceProvider]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, service_name: str, factory: Callable[..., BaseServiceProvider]
|
||||
) -> None:
|
||||
cls._services = {**cls._services, service_name: factory}
|
||||
|
||||
@classmethod
|
||||
def get_service_factory(
|
||||
cls, service_name: str
|
||||
) -> Optional[Callable[..., BaseServiceProvider]]:
|
||||
return cls._services.get(service_name)
|
||||
|
||||
@classmethod
|
||||
def list_services(cls) -> list[str]:
|
||||
return sorted(cls._services.keys())
|
||||
|
||||
@classmethod
|
||||
def create_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
return cls.get_service(service_name, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
factory = cls.get_service_factory(service_name)
|
||||
if not factory:
|
||||
return None
|
||||
return factory(**kwargs)
|
||||
|
||||
|
||||
def register_service(service_name: str) -> Callable[[type], type]:
|
||||
def decorator(service_class: type) -> type:
|
||||
ServiceRegistry.register(service_name, service_class)
|
||||
return service_class
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
TService = TypeVar("TService", bound=BaseServiceProvider)
|
||||
|
||||
|
||||
def register_service_instance(service_name: str, service: TService) -> TService:
|
||||
ServiceRegistry.register(service_name, lambda: service)
|
||||
return service
|
||||
|
||||
|
||||
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
|
||||
services: list[BaseServiceProvider] = []
|
||||
for service_name in service_names:
|
||||
service = ServiceRegistry.get_service(service_name)
|
||||
if service is None:
|
||||
raise RuntimeError(f"Service is not registered: {service_name}")
|
||||
services.append(service)
|
||||
return services
|
||||
|
||||
|
||||
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
|
||||
lifecycle_logger = get_logger("services.base.lifecycle")
|
||||
all_closed = True
|
||||
for service in reversed(services):
|
||||
try:
|
||||
closed = await service.close()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
lifecycle_logger.warning(
|
||||
"Failed to close service",
|
||||
service=service.service_name,
|
||||
error=str(exc),
|
||||
)
|
||||
all_closed = False
|
||||
continue
|
||||
if not closed:
|
||||
lifecycle_logger.warning(
|
||||
"Service close returned false",
|
||||
service=service.service_name,
|
||||
)
|
||||
all_closed = False
|
||||
return all_closed
|
||||
|
||||
|
||||
async def initialize_registered_services(
|
||||
service_names: list[str],
|
||||
) -> tuple[bool, list[BaseServiceProvider]]:
|
||||
lifecycle_logger = get_logger("services.base.lifecycle")
|
||||
initialized_services: list[BaseServiceProvider] = []
|
||||
try:
|
||||
services = resolve_registered_services(service_names)
|
||||
except RuntimeError as exc:
|
||||
lifecycle_logger.error("Failed to resolve registered services", error=str(exc))
|
||||
return False, []
|
||||
|
||||
for service in services:
|
||||
try:
|
||||
initialized = await service.initialize()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
lifecycle_logger.warning(
|
||||
"Service initialization raised exception",
|
||||
service=service.service_name,
|
||||
error=str(exc),
|
||||
)
|
||||
initialized = False
|
||||
|
||||
if not initialized:
|
||||
lifecycle_logger.error(
|
||||
"Service initialization failed, rolling back",
|
||||
service=service.service_name,
|
||||
)
|
||||
await close_registered_services(initialized_services)
|
||||
return False, []
|
||||
|
||||
initialized_services.append(service)
|
||||
|
||||
return True, initialized_services
|
||||
@@ -0,0 +1,304 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from supabase import create_client
|
||||
from storage3.exceptions import StorageApiError
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class SupabaseService(BaseServiceProvider):
|
||||
def __init__(self, settings: SupabaseSettings | None = None) -> None:
|
||||
super().__init__("supabase")
|
||||
self._settings = settings or config.supabase
|
||||
self._client: Any = None
|
||||
self._admin_client: Any = None
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
self._init_clients()
|
||||
await self._ensure_storage_bucket()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service initialized")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning(
|
||||
"Supabase service initialization failed", error=str(exc)
|
||||
)
|
||||
self._client = None
|
||||
self._admin_client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def close(self) -> bool:
|
||||
self._client = None
|
||||
self._admin_client = None
|
||||
self._set_initialized(False)
|
||||
self.logger.info("Supabase service closed")
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
client = self._client
|
||||
admin_client = self._admin_client
|
||||
if client is None or admin_client is None:
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
await asyncio.to_thread(client.auth.get_session)
|
||||
await asyncio.to_thread(
|
||||
admin_client.auth.admin.list_users, page=1, per_page=1
|
||||
)
|
||||
return {
|
||||
"status": "healthy",
|
||||
"details": {
|
||||
"anon_client": "ready",
|
||||
"admin_client": "ready",
|
||||
},
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Supabase health check failed", error=str(exc))
|
||||
return {"status": "unhealthy", "details": {"error": str(exc)}}
|
||||
|
||||
def get_client(self) -> Any:
|
||||
return self._require_client()
|
||||
|
||||
def get_admin_client(self) -> Any:
|
||||
return self._require_admin_client()
|
||||
|
||||
def _require_client(self) -> Any:
|
||||
if self._client is None or self._admin_client is None:
|
||||
self._init_clients()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service lazily initialized")
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Supabase client is not initialized")
|
||||
return client
|
||||
|
||||
def _require_admin_client(self) -> Any:
|
||||
if self._client is None or self._admin_client is None:
|
||||
self._init_clients()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service lazily initialized")
|
||||
admin_client = self._admin_client
|
||||
if admin_client is None:
|
||||
raise RuntimeError("Supabase admin client is not initialized")
|
||||
return admin_client
|
||||
|
||||
def _init_clients(self) -> None:
|
||||
self._client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.anon_key,
|
||||
)
|
||||
self._admin_client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.service_role_key,
|
||||
)
|
||||
|
||||
async def _ensure_storage_bucket(self) -> None:
|
||||
storage = getattr(self._admin_client, "storage", None)
|
||||
if storage is None:
|
||||
self.logger.warning("Storage client unavailable, skipping bucket check")
|
||||
return
|
||||
|
||||
get_bucket = getattr(storage, "get_bucket", None)
|
||||
if not callable(get_bucket):
|
||||
self.logger.warning("Storage get_bucket unavailable, skipping bucket check")
|
||||
return
|
||||
|
||||
buckets = [
|
||||
(config.storage.attachment.bucket, False),
|
||||
(config.storage.avatar.bucket, True),
|
||||
]
|
||||
|
||||
def _check_and_create() -> None:
|
||||
for bucket_name, is_public in buckets:
|
||||
try:
|
||||
get_bucket(bucket_name)
|
||||
self.logger.debug(
|
||||
"Storage bucket already exists", bucket=bucket_name
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
create_bucket = getattr(storage, "create_bucket", None)
|
||||
if not callable(create_bucket):
|
||||
self.logger.warning(
|
||||
"Storage create_bucket unavailable, skipping bucket creation"
|
||||
)
|
||||
return
|
||||
try:
|
||||
create_bucket(bucket_name, options={"public": is_public})
|
||||
self.logger.info(
|
||||
"Storage bucket created",
|
||||
bucket=bucket_name,
|
||||
public=is_public,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
msg = str(exc).lower()
|
||||
if "already exists" in msg or "duplicate" in msg:
|
||||
self.logger.debug(
|
||||
"Storage bucket already exists (race)",
|
||||
bucket=bucket_name,
|
||||
)
|
||||
continue
|
||||
self.logger.warning(
|
||||
"Failed to create storage bucket",
|
||||
bucket=bucket_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_check_and_create)
|
||||
|
||||
def _get_storage(self) -> Any:
|
||||
"""Get the storage client from admin client."""
|
||||
client = self.get_admin_client()
|
||||
storage = getattr(client, "storage", None)
|
||||
if storage is None:
|
||||
raise RuntimeError("Supabase storage client unavailable")
|
||||
return storage
|
||||
|
||||
def _get_bucket_client(self, bucket: str) -> Any:
|
||||
"""Get a bucket client for the specified bucket."""
|
||||
storage = self._get_storage()
|
||||
from_bucket = getattr(storage, "from_", None)
|
||||
if not callable(from_bucket):
|
||||
raise RuntimeError("Supabase storage bucket accessor unavailable")
|
||||
return from_bucket(bucket)
|
||||
|
||||
def _validate_bucket(self, bucket: str) -> None:
|
||||
"""Validate that the bucket matches one of configured storage buckets."""
|
||||
allowed_buckets = {
|
||||
config.storage.attachment.bucket,
|
||||
config.storage.avatar.bucket,
|
||||
}
|
||||
if bucket not in allowed_buckets:
|
||||
raise RuntimeError("Invalid storage bucket")
|
||||
|
||||
def _ensure_bucket_client(self, bucket: str) -> Any:
|
||||
"""Validate bucket and return authenticated bucket client."""
|
||||
self._validate_bucket(bucket)
|
||||
return self._get_bucket_client(bucket)
|
||||
|
||||
def _is_bucket_not_found_error(self, exc: Exception) -> bool:
|
||||
"""Check if the exception indicates a bucket was not found."""
|
||||
if isinstance(exc, StorageApiError):
|
||||
message = str(exc).lower()
|
||||
return "bucket" in message and "not found" in message
|
||||
message = str(exc).lower()
|
||||
return "bucket" in message and "not found" in message
|
||||
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str:
|
||||
def _upload() -> object:
|
||||
bucket_client = self._ensure_bucket_client(bucket)
|
||||
upload = getattr(bucket_client, "upload", None)
|
||||
if not callable(upload):
|
||||
raise RuntimeError("Supabase storage upload is unavailable")
|
||||
return upload(
|
||||
path,
|
||||
content,
|
||||
{
|
||||
"content-type": content_type,
|
||||
"upsert": "true",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_upload)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if not self._is_bucket_not_found_error(exc):
|
||||
raise
|
||||
await self._ensure_bucket_exists(bucket=bucket)
|
||||
await asyncio.to_thread(_upload)
|
||||
return path
|
||||
|
||||
async def _ensure_bucket_exists(self, *, bucket: str) -> None:
|
||||
def _ensure() -> None:
|
||||
storage = self._get_storage()
|
||||
get_bucket = getattr(storage, "get_bucket", None)
|
||||
if not callable(get_bucket):
|
||||
raise RuntimeError("Supabase storage get_bucket is unavailable")
|
||||
try:
|
||||
get_bucket(bucket)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
msg = str(exc).lower()
|
||||
if "bucket" in msg and "not found" in msg:
|
||||
raise RuntimeError(f"Storage bucket '{bucket}' does not exist")
|
||||
raise
|
||||
|
||||
await asyncio.to_thread(_ensure)
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
def _download() -> object:
|
||||
bucket_client = self._ensure_bucket_client(bucket)
|
||||
download = getattr(bucket_client, "download", None)
|
||||
if not callable(download):
|
||||
raise RuntimeError("Supabase storage download is unavailable")
|
||||
return download(path)
|
||||
|
||||
raw = await asyncio.to_thread(_download)
|
||||
if isinstance(raw, bytes):
|
||||
return raw
|
||||
if isinstance(raw, bytearray):
|
||||
return bytes(raw)
|
||||
if isinstance(raw, memoryview):
|
||||
return raw.tobytes()
|
||||
raise RuntimeError("Invalid attachment payload")
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str:
|
||||
def _create_signed_url() -> object:
|
||||
bucket_client = self._ensure_bucket_client(bucket)
|
||||
signer = getattr(bucket_client, "create_signed_url", None)
|
||||
if not callable(signer):
|
||||
raise RuntimeError("Supabase storage signed url is unavailable")
|
||||
return signer(path, expires_in_seconds)
|
||||
|
||||
raw = await asyncio.to_thread(_create_signed_url)
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
signed_url = raw.get("signedURL") or raw.get("signedUrl") or raw.get("url")
|
||||
if isinstance(signed_url, str) and signed_url:
|
||||
return signed_url
|
||||
raise RuntimeError("Invalid signed url payload")
|
||||
|
||||
def parse_signed_url(self, url: str) -> tuple[str, str]:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url)
|
||||
path_parts = parsed.path.strip("/").split("/")
|
||||
|
||||
if (
|
||||
len(path_parts) < 4
|
||||
or path_parts[0] != "storage"
|
||||
or path_parts[1] != "v1"
|
||||
or path_parts[2] != "object"
|
||||
or path_parts[3] != "sign"
|
||||
):
|
||||
raise RuntimeError("Invalid signed URL format")
|
||||
|
||||
bucket = path_parts[4]
|
||||
path = "/".join(path_parts[5:])
|
||||
|
||||
return bucket, path
|
||||
|
||||
|
||||
supabase_service: SupabaseService = register_service_instance(
|
||||
"supabase", SupabaseService()
|
||||
)
|
||||
|
||||
__all__ = ["SupabaseService", "supabase_service"]
|
||||
@@ -0,0 +1,4 @@
|
||||
from .factory import get_cache_store
|
||||
from .interfaces import CacheStore
|
||||
|
||||
__all__ = ["CacheStore", "get_cache_store"]
|
||||
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .interfaces import CacheStore
|
||||
from .redis_store import RedisCacheStore
|
||||
|
||||
_cache_store: CacheStore | None = None
|
||||
|
||||
|
||||
def get_cache_store() -> CacheStore:
|
||||
global _cache_store
|
||||
if _cache_store is None:
|
||||
_cache_store = RedisCacheStore()
|
||||
return _cache_store
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class CacheStore(Protocol):
|
||||
async def hgetall(self, key: str, /) -> dict[str, str]: ...
|
||||
|
||||
async def hset(self, key: str, /, mapping: dict[str, str]) -> int: ...
|
||||
|
||||
async def hincrby(self, key: str, field: str, amount: int = 1, /) -> int: ...
|
||||
|
||||
async def expire(self, key: str, ttl_seconds: int, /) -> int: ...
|
||||
|
||||
async def delete(self, *keys: str) -> int: ...
|
||||
|
||||
async def sadd(self, key: str, *members: str) -> int: ...
|
||||
|
||||
async def smembers(self, key: str, /) -> set[str]: ...
|
||||
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
from .interfaces import CacheStore
|
||||
|
||||
|
||||
def _to_text(value: Any) -> str | None:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
return value.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def _maybe_await(value: Any) -> Any:
|
||||
if inspect.isawaitable(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
|
||||
class RedisCacheStore(CacheStore):
|
||||
async def hgetall(self, key: str) -> dict[str, str]:
|
||||
client = await get_or_init_redis_client()
|
||||
raw = await _maybe_await(client.hgetall(key))
|
||||
if not isinstance(raw, dict):
|
||||
return {}
|
||||
|
||||
decoded: dict[str, str] = {}
|
||||
for raw_key, raw_value in raw.items():
|
||||
key_text = _to_text(raw_key)
|
||||
value_text = _to_text(raw_value)
|
||||
if key_text is None or value_text is None:
|
||||
continue
|
||||
decoded[key_text] = value_text
|
||||
return decoded
|
||||
|
||||
async def hset(self, key: str, mapping: dict[str, str]) -> int:
|
||||
client = await get_or_init_redis_client()
|
||||
result = await _maybe_await(client.hset(key, mapping=mapping))
|
||||
return int(result)
|
||||
|
||||
async def hincrby(self, key: str, field: str, amount: int = 1) -> int:
|
||||
client = await get_or_init_redis_client()
|
||||
result = await _maybe_await(client.hincrby(key, field, amount))
|
||||
return int(result)
|
||||
|
||||
async def expire(self, key: str, ttl_seconds: int) -> int:
|
||||
client = await get_or_init_redis_client()
|
||||
result = await _maybe_await(client.expire(key, ttl_seconds))
|
||||
return int(result)
|
||||
|
||||
async def delete(self, *keys: str) -> int:
|
||||
if not keys:
|
||||
return 0
|
||||
client = await get_or_init_redis_client()
|
||||
result = await _maybe_await(client.delete(*keys))
|
||||
return int(result)
|
||||
|
||||
async def sadd(self, key: str, *members: str) -> int:
|
||||
if not members:
|
||||
return 0
|
||||
client = await get_or_init_redis_client()
|
||||
result = await _maybe_await(client.sadd(key, *members))
|
||||
return int(result)
|
||||
|
||||
async def smembers(self, key: str) -> set[str]:
|
||||
client = await get_or_init_redis_client()
|
||||
raw = await _maybe_await(client.smembers(key))
|
||||
if isinstance(raw, set):
|
||||
return {value for item in raw if (value := _to_text(item))}
|
||||
if isinstance(raw, list | tuple):
|
||||
return {value for item in raw if (value := _to_text(item))}
|
||||
return set()
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.llm_pricing.service import LlmPricingService
|
||||
|
||||
__all__ = ["LlmPricingService"]
|
||||
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from core.config.initial.init_data import load_llm_catalog
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PricingTier:
|
||||
max_prompt_tokens: int
|
||||
input_cost_per_token: float
|
||||
output_cost_per_token: float
|
||||
cache_hit_cost_per_token: float
|
||||
|
||||
|
||||
class LlmPricingService:
|
||||
_pricing_by_model: dict[str, tuple[PricingTier, ...]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pricing_by_model = self._build_pricing_map()
|
||||
|
||||
@staticmethod
|
||||
def _build_pricing_map() -> dict[str, tuple[PricingTier, ...]]:
|
||||
catalog = load_llm_catalog()
|
||||
pricing_by_model: dict[str, tuple[PricingTier, ...]] = {}
|
||||
for model in catalog.get("llms", []):
|
||||
if not isinstance(model, dict):
|
||||
continue
|
||||
model_code = str(model.get("model_code", "")).strip().lower()
|
||||
raw_tiers = model.get("pricing_tiers")
|
||||
if not isinstance(raw_tiers, list) or not raw_tiers:
|
||||
continue
|
||||
|
||||
tiers = [
|
||||
PricingTier(
|
||||
max_prompt_tokens=int(item.get("max_prompt_tokens", 0) or 0),
|
||||
input_cost_per_token=float(
|
||||
item.get("input_cost_per_token", 0.0) or 0.0
|
||||
),
|
||||
output_cost_per_token=float(
|
||||
item.get("output_cost_per_token", 0.0) or 0.0
|
||||
),
|
||||
cache_hit_cost_per_token=float(
|
||||
item.get("cache_hit_cost_per_token", 0.0) or 0.0
|
||||
),
|
||||
)
|
||||
for item in raw_tiers
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
if not tiers:
|
||||
continue
|
||||
ordered_tiers = tuple(
|
||||
sorted(tiers, key=lambda item: item.max_prompt_tokens)
|
||||
)
|
||||
if model_code:
|
||||
pricing_by_model[model_code] = ordered_tiers
|
||||
return pricing_by_model
|
||||
|
||||
def calculate_cost(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cached_prompt_tokens: int = 0,
|
||||
) -> float:
|
||||
tiers = self._pricing_by_model.get(model.strip().lower())
|
||||
if tiers is None:
|
||||
raise ValueError(f"unknown model pricing: {model}")
|
||||
|
||||
normalized_prompt_tokens = max(int(prompt_tokens), 0)
|
||||
normalized_completion_tokens = max(int(completion_tokens), 0)
|
||||
normalized_cached_tokens = min(
|
||||
max(int(cached_prompt_tokens), 0), normalized_prompt_tokens
|
||||
)
|
||||
uncached_prompt_tokens = normalized_prompt_tokens - normalized_cached_tokens
|
||||
|
||||
selected_tier = tiers[-1]
|
||||
for tier in tiers:
|
||||
if normalized_prompt_tokens <= tier.max_prompt_tokens:
|
||||
selected_tier = tier
|
||||
break
|
||||
|
||||
cached_token_rate = (
|
||||
selected_tier.cache_hit_cost_per_token
|
||||
if selected_tier.cache_hit_cost_per_token > 0
|
||||
else selected_tier.input_cost_per_token
|
||||
)
|
||||
|
||||
return float(
|
||||
uncached_prompt_tokens * selected_tier.input_cost_per_token
|
||||
+ normalized_cached_tokens * cached_token_rate
|
||||
+ normalized_completion_tokens * selected_tier.output_cost_per_token
|
||||
)
|
||||
|
||||
def build_usage_metadata(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
usage_summary: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
summary = usage_summary or {}
|
||||
input_tokens = max(int(summary.get("input_tokens", 0) or 0), 0)
|
||||
output_tokens = max(int(summary.get("output_tokens", 0) or 0), 0)
|
||||
total_tokens = max(
|
||||
int(summary.get("total_tokens", input_tokens + output_tokens) or 0), 0
|
||||
)
|
||||
latency_ms = max(int(summary.get("latency_ms", 0) or 0), 0)
|
||||
cached_prompt_tokens = max(int(summary.get("cached_prompt_tokens", 0) or 0), 0)
|
||||
prompt_cache_hit_tokens = max(
|
||||
int(summary.get("prompt_cache_hit_tokens", cached_prompt_tokens) or 0), 0
|
||||
)
|
||||
prompt_cache_miss_tokens = max(
|
||||
int(
|
||||
summary.get(
|
||||
"prompt_cache_miss_tokens",
|
||||
max(input_tokens - prompt_cache_hit_tokens, 0),
|
||||
)
|
||||
or 0
|
||||
),
|
||||
0,
|
||||
)
|
||||
reasoning_tokens = max(int(summary.get("reasoning_tokens", 0) or 0), 0)
|
||||
direct_cost_raw = summary.get("direct_cost")
|
||||
direct_cost_observed = bool(int(summary.get("direct_cost_observed", 0) or 0))
|
||||
direct_cost_complete = bool(int(summary.get("direct_cost_complete", 0) or 0))
|
||||
model_call_records = max(int(summary.get("model_call_records", 0) or 0), 0)
|
||||
usage_records = max(int(summary.get("usage_records", 0) or 0), 0)
|
||||
usage_complete = model_call_records == 0 or model_call_records == usage_records
|
||||
direct_cost = self._coerce_non_negative_float(direct_cost_raw)
|
||||
|
||||
if (
|
||||
usage_complete
|
||||
and direct_cost_observed
|
||||
and direct_cost_complete
|
||||
and direct_cost is not None
|
||||
):
|
||||
cost = direct_cost
|
||||
cost_source = "provider"
|
||||
else:
|
||||
cost = self.calculate_cost(
|
||||
model=model,
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
cached_prompt_tokens=cached_prompt_tokens,
|
||||
)
|
||||
cost_source = (
|
||||
"incomplete_usage_fallback"
|
||||
if not usage_complete
|
||||
else (
|
||||
"catalog_fallback_incomplete_provider_cost"
|
||||
if direct_cost_observed and not direct_cost_complete
|
||||
else "catalog_fallback"
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"inputTokens": input_tokens,
|
||||
"outputTokens": output_tokens,
|
||||
"totalTokens": total_tokens,
|
||||
"cachedPromptTokens": cached_prompt_tokens,
|
||||
"promptCacheHitTokens": prompt_cache_hit_tokens,
|
||||
"promptCacheMissTokens": prompt_cache_miss_tokens,
|
||||
"reasoningTokens": reasoning_tokens,
|
||||
"cost": cost,
|
||||
"costSource": cost_source,
|
||||
"usageComplete": usage_complete,
|
||||
"latencyMs": latency_ms,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _coerce_non_negative_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if parsed < 0:
|
||||
return None
|
||||
return parsed
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from schemas.domain.automation import AutomationJobConfig
|
||||
|
||||
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
|
||||
def _automation_yaml_path(config_name: str) -> Path:
|
||||
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
|
||||
raise ValueError("invalid automation config name")
|
||||
return (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "automation"
|
||||
/ f"{config_name}.yaml"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
|
||||
path = _automation_yaml_path(config_name)
|
||||
with path.open("r", encoding="utf-8") as file:
|
||||
loaded: Any = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"invalid automation config format: {path}")
|
||||
return AutomationJobConfig.model_validate(loaded)
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db import get_db
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.auth.registration_bootstrap import (
|
||||
RegistrationAutomationBootstrapService,
|
||||
RegistrationBootstrapRepository,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
def get_auth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AuthService:
|
||||
bootstrapper = RegistrationAutomationBootstrapService(
|
||||
repository=RegistrationBootstrapRepository(session=session),
|
||||
session=session,
|
||||
)
|
||||
return AuthService(
|
||||
gateway=SupabaseAuthGateway(),
|
||||
registration_bootstrapper=bootstrapper,
|
||||
)
|
||||
@@ -0,0 +1,430 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from supabase import AuthError
|
||||
|
||||
from core.http.errors import ApiProblemError
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByIdResponse,
|
||||
UserByPhoneResponse,
|
||||
)
|
||||
from v1.auth.service import AuthServiceGateway
|
||||
|
||||
logger = get_logger("v1.auth.gateway")
|
||||
|
||||
AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable"
|
||||
|
||||
|
||||
def _auth_error(
|
||||
*,
|
||||
status_code: int,
|
||||
code: str,
|
||||
detail: str,
|
||||
) -> ApiProblemError:
|
||||
return ApiProblemError(status_code=status_code, code=code, detail=detail)
|
||||
|
||||
|
||||
class SupabaseAuthGateway(AuthServiceGateway):
|
||||
def __init__(self) -> None:
|
||||
self._user_lookup_cache_ttl_seconds: int = 60
|
||||
self._user_lookup_cache_expires_at: float = 0.0
|
||||
self._users_by_phone: dict[str, Any] = {}
|
||||
self._users_by_id: dict[str, Any] = {}
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
return supabase_service.get_client()
|
||||
|
||||
def _get_admin_client(self) -> Any:
|
||||
return supabase_service.get_admin_client()
|
||||
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"phone": request.phone,
|
||||
"options": {"should_create_user": True},
|
||||
}
|
||||
try:
|
||||
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
|
||||
await asyncio.to_thread(sign_in_with_otp, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Send otp failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise _auth_error(
|
||||
status_code=503,
|
||||
code="AUTH_SERVICE_UNAVAILABLE",
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise _auth_error(
|
||||
status_code=429,
|
||||
code="AUTH_TOO_MANY_REQUESTS",
|
||||
detail="Too many requests",
|
||||
) from exc
|
||||
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": "sms",
|
||||
"phone": request.phone,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, payload)
|
||||
return _map_auth_response(
|
||||
response,
|
||||
"Invalid verification code",
|
||||
"AUTH_VERIFICATION_CODE_INVALID",
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning("Create phone session failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise _auth_error(
|
||||
status_code=503,
|
||||
code="AUTH_SERVICE_UNAVAILABLE",
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code="AUTH_VERIFICATION_CODE_INVALID",
|
||||
detail="Invalid verification code",
|
||||
) from exc
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
client.auth.refresh_session,
|
||||
request.refresh_token,
|
||||
)
|
||||
return _map_auth_response(
|
||||
response,
|
||||
"Invalid refresh token",
|
||||
"AUTH_REFRESH_TOKEN_INVALID",
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning("Refresh failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise _auth_error(
|
||||
status_code=503,
|
||||
code="AUTH_SERVICE_UNAVAILABLE",
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code="AUTH_REFRESH_TOKEN_INVALID",
|
||||
detail="Invalid refresh token",
|
||||
) from exc
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
if not refresh_token:
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code="AUTH_REFRESH_TOKEN_MISSING",
|
||||
detail="Missing refresh token",
|
||||
)
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
client.auth.refresh_session,
|
||||
refresh_token,
|
||||
)
|
||||
session = getattr(response, "session", None)
|
||||
if session is None:
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code="AUTH_REFRESH_TOKEN_INVALID",
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
client.auth.set_session,
|
||||
str(session.access_token),
|
||||
str(session.refresh_token),
|
||||
)
|
||||
await asyncio.to_thread(client.auth.sign_out)
|
||||
except AuthError as exc:
|
||||
logger.warning("Logout failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise _auth_error(
|
||||
status_code=503,
|
||||
code="AUTH_SERVICE_UNAVAILABLE",
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code="AUTH_REFRESH_TOKEN_INVALID",
|
||||
detail="Invalid refresh token",
|
||||
) from exc
|
||||
|
||||
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
|
||||
normalized_phone = _normalize_phone(phone)
|
||||
if not normalized_phone:
|
||||
raise _auth_error(
|
||||
status_code=404,
|
||||
code="AUTH_USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
|
||||
user = self._users_by_phone.get(normalized_phone)
|
||||
if user is None:
|
||||
raise _auth_error(
|
||||
status_code=404,
|
||||
code="AUTH_USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
user_phone = _normalize_phone(getattr(user, "phone", ""))
|
||||
if not user_phone:
|
||||
raise _auth_error(
|
||||
status_code=404,
|
||||
code="AUTH_USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
return UserByPhoneResponse(
|
||||
id=str(getattr(user, "id", "")),
|
||||
phone=user_phone,
|
||||
created_at=str(getattr(user, "created_at", "")),
|
||||
phone_confirmed_at=(
|
||||
str(getattr(user, "phone_confirmed_at", ""))
|
||||
if getattr(user, "phone_confirmed_at", None)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> UserByIdResponse:
|
||||
users = await self.get_users_by_ids([user_id])
|
||||
resolved = users.get(user_id)
|
||||
if resolved is None:
|
||||
raise _auth_error(
|
||||
status_code=404,
|
||||
code="AUTH_USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
)
|
||||
return resolved
|
||||
|
||||
async def get_users_by_ids(
|
||||
self, user_ids: list[str]
|
||||
) -> dict[str, UserByIdResponse]:
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
resolved: dict[str, UserByIdResponse] = {}
|
||||
for raw_user_id in user_ids:
|
||||
normalized_user_id = raw_user_id.strip()
|
||||
if not normalized_user_id:
|
||||
continue
|
||||
user = self._users_by_id.get(normalized_user_id)
|
||||
if user is None:
|
||||
continue
|
||||
user_attrs = getattr(user, "user", user)
|
||||
resolved[normalized_user_id] = UserByIdResponse(
|
||||
id=str(getattr(user_attrs, "id", "")),
|
||||
phone=getattr(user_attrs, "phone", None),
|
||||
created_at=str(getattr(user_attrs, "created_at", "")),
|
||||
phone_confirmed_at=(
|
||||
str(getattr(user_attrs, "phone_confirmed_at", ""))
|
||||
if getattr(user_attrs, "phone_confirmed_at", None)
|
||||
else None
|
||||
),
|
||||
)
|
||||
return resolved
|
||||
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
normalized_query = _normalize_phone_search_query(query)
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
if normalized_query.startswith("+"):
|
||||
matched_user = self._users_by_phone.get(normalized_query)
|
||||
if matched_user is None:
|
||||
return []
|
||||
user_id = str(getattr(matched_user, "id", ""))
|
||||
return [user_id] if user_id else []
|
||||
|
||||
digits = _digits_only(normalized_query)
|
||||
if not digits:
|
||||
return []
|
||||
|
||||
matched_records: list[tuple[str, str]] = []
|
||||
for cached_phone, candidate in self._users_by_phone.items():
|
||||
candidate_digits = _digits_only(cached_phone)
|
||||
if not candidate_digits.endswith(digits):
|
||||
continue
|
||||
user_id = str(getattr(candidate, "id", ""))
|
||||
if user_id:
|
||||
matched_records.append((cached_phone, user_id))
|
||||
|
||||
if not matched_records:
|
||||
return []
|
||||
|
||||
unique_ids: list[str] = []
|
||||
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
|
||||
if user_id in unique_ids:
|
||||
continue
|
||||
unique_ids.append(user_id)
|
||||
if len(unique_ids) >= max(1, limit):
|
||||
break
|
||||
return unique_ids
|
||||
|
||||
async def _refresh_user_lookup_cache_if_needed(self) -> None:
|
||||
now = time.monotonic()
|
||||
if now < self._user_lookup_cache_expires_at:
|
||||
return
|
||||
|
||||
admin_client = self._get_admin_client()
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
users_by_phone: dict[str, Any] = {}
|
||||
users_by_id: dict[str, Any] = {}
|
||||
for candidate in users:
|
||||
candidate_id = str(getattr(candidate, "id", "")).strip()
|
||||
if candidate_id:
|
||||
users_by_id[candidate_id] = candidate
|
||||
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
|
||||
if candidate_phone:
|
||||
users_by_phone[candidate_phone] = candidate
|
||||
self._users_by_id = users_by_id
|
||||
self._users_by_phone = users_by_phone
|
||||
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
|
||||
|
||||
|
||||
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
|
||||
raw_status = getattr(exc, "status", None)
|
||||
if raw_status is None:
|
||||
raw_status = getattr(exc, "status_code", None)
|
||||
if isinstance(raw_status, int) and 500 <= raw_status < 600:
|
||||
return True
|
||||
|
||||
raw_code = getattr(exc, "code", None)
|
||||
code = str(raw_code).lower() if raw_code is not None else ""
|
||||
message = str(exc).lower()
|
||||
indicators = (
|
||||
"request_timeout",
|
||||
"timed out",
|
||||
"timeout",
|
||||
"gateway timeout",
|
||||
"bad_gateway",
|
||||
"service_unavailable",
|
||||
"internal_server_error",
|
||||
"unexpected_failure",
|
||||
"upstream",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"5xx",
|
||||
)
|
||||
return any(token in code or token in message for token in indicators)
|
||||
|
||||
|
||||
def _map_auth_response(
|
||||
response: object, failure_message: str, failure_code: str
|
||||
) -> SessionResponse:
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
if session is None or user is None:
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code=failure_code,
|
||||
detail=failure_message,
|
||||
)
|
||||
|
||||
phone = _normalize_phone(getattr(user, "phone", None))
|
||||
if not phone:
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code=failure_code,
|
||||
detail=failure_message,
|
||||
)
|
||||
|
||||
try:
|
||||
auth_user = AuthUser(id=str(user.id), phone=str(phone))
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Auth response returned invalid phone format",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
raise _auth_error(
|
||||
status_code=401,
|
||||
code=failure_code,
|
||||
detail=failure_message,
|
||||
) from exc
|
||||
return SessionResponse(
|
||||
access_token=str(session.access_token),
|
||||
refresh_token=str(session.refresh_token),
|
||||
expires_in=int(session.expires_in or 0),
|
||||
token_type=str(session.token_type),
|
||||
user=auth_user,
|
||||
)
|
||||
|
||||
|
||||
def _list_auth_users(client: Any) -> list[Any]:
|
||||
users: list[Any] = []
|
||||
page = 1
|
||||
max_pages = 100
|
||||
|
||||
while page <= max_pages:
|
||||
response = client.auth.admin.list_users(page=page, per_page=100)
|
||||
batch = (
|
||||
list(response)
|
||||
if isinstance(response, list)
|
||||
else list(getattr(response, "users", []))
|
||||
)
|
||||
users.extend(batch)
|
||||
|
||||
if len(batch) < 100:
|
||||
break
|
||||
page += 1
|
||||
|
||||
return users
|
||||
|
||||
|
||||
def _sanitize_phone_token(raw: object) -> str:
|
||||
token = str(raw).strip()
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
token = token.replace(separator, "")
|
||||
return token
|
||||
|
||||
|
||||
def _normalize_phone(raw_phone: object) -> str | None:
|
||||
phone = _sanitize_phone_token(raw_phone)
|
||||
if not phone:
|
||||
return None
|
||||
if phone.startswith("00") and len(phone) > 2:
|
||||
return f"+{phone[2:]}"
|
||||
if phone.startswith("+"):
|
||||
return phone
|
||||
if phone.isdigit():
|
||||
return f"+{phone}"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_phone_search_query(raw_query: str) -> str | None:
|
||||
query = _sanitize_phone_token(raw_query)
|
||||
if not query:
|
||||
return None
|
||||
if query.startswith("00") and len(query) > 2:
|
||||
return f"+{query[2:]}"
|
||||
if query.startswith("+"):
|
||||
return query
|
||||
if query.isdigit():
|
||||
return query
|
||||
return None
|
||||
|
||||
|
||||
def _digits_only(value: str) -> str:
|
||||
return "".join(ch for ch in value if ch.isdigit())
|
||||
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from time import monotonic
|
||||
|
||||
from core.http.errors import ApiProblemError
|
||||
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
_BUCKETS: dict[str, deque[float]] = {}
|
||||
_LAST_SEEN: dict[str, float] = {}
|
||||
_LOCK = asyncio.Lock()
|
||||
_CLEANUP_INTERVAL = 200
|
||||
_CALL_COUNT = 0
|
||||
logger = get_logger("v1.auth.rate_limit")
|
||||
_REDIS_LIMIT_SCRIPT = """
|
||||
local current = redis.call("INCR", KEYS[1])
|
||||
if current == 1 then
|
||||
redis.call("EXPIRE", KEYS[1], ARGV[1])
|
||||
end
|
||||
return current
|
||||
"""
|
||||
|
||||
|
||||
async def enforce_rate_limit(
|
||||
*,
|
||||
scope: str,
|
||||
identifier: str,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
key = f"auth:rate_limit:{scope}:{identifier.lower()}"
|
||||
try:
|
||||
await _enforce_rate_limit_with_redis(
|
||||
key=key,
|
||||
limit=limit,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
return
|
||||
except ApiProblemError:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Rate limit fallback to in-memory",
|
||||
scope=scope,
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
await _enforce_rate_limit_in_memory(
|
||||
key=key,
|
||||
limit=limit,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
|
||||
async def _enforce_rate_limit_with_redis(
|
||||
*,
|
||||
key: str,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
client = await get_or_init_redis_client()
|
||||
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds) # type: ignore[await]
|
||||
if int(current) > limit:
|
||||
raise ApiProblemError(
|
||||
status_code=429,
|
||||
code="AUTH_TOO_MANY_REQUESTS",
|
||||
detail="Too many requests",
|
||||
)
|
||||
|
||||
|
||||
async def _enforce_rate_limit_in_memory(
|
||||
*,
|
||||
key: str,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
global _CALL_COUNT
|
||||
now = monotonic()
|
||||
async with _LOCK:
|
||||
bucket = _BUCKETS.setdefault(key, deque())
|
||||
_LAST_SEEN[key] = now
|
||||
cutoff = now - float(window_seconds)
|
||||
while bucket and bucket[0] <= cutoff:
|
||||
bucket.popleft()
|
||||
if len(bucket) >= limit:
|
||||
raise ApiProblemError(
|
||||
status_code=429,
|
||||
code="AUTH_TOO_MANY_REQUESTS",
|
||||
detail="Too many requests",
|
||||
)
|
||||
bucket.append(now)
|
||||
_CALL_COUNT += 1
|
||||
if _CALL_COUNT % _CLEANUP_INTERVAL == 0:
|
||||
_cleanup_stale_buckets(now)
|
||||
|
||||
|
||||
def _cleanup_stale_buckets(now: float) -> None:
|
||||
stale_keys = [
|
||||
key
|
||||
for key, last_seen in _LAST_SEEN.items()
|
||||
if key not in _BUCKETS or (not _BUCKETS[key] and now - last_seen > 3600)
|
||||
]
|
||||
for key in stale_keys:
|
||||
_BUCKETS.pop(key, None)
|
||||
_LAST_SEEN.pop(key, None)
|
||||
|
||||
|
||||
def reset_rate_limit_state() -> None:
|
||||
_BUCKETS.clear()
|
||||
_LAST_SEEN.clear()
|
||||
global _CALL_COUNT
|
||||
_CALL_COUNT = 0
|
||||
@@ -0,0 +1,239 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, time, timedelta
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
from models.automation_jobs import AutomationJob
|
||||
from schemas.enums import AutomationJobStatus, MemoryType, ScheduleType
|
||||
from models.profile import Profile
|
||||
from schemas.domain.automation import AutomationJobConfig, ScheduleConfig
|
||||
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from schemas.shared.user import parse_profile_settings
|
||||
from v1.auth.automation_static_config import load_static_automation_job_config
|
||||
from v1.auth.schemas import RegistrationBootstrapRequest
|
||||
from v1.memories.repository import SQLAlchemyMemoriesRepository
|
||||
|
||||
logger = get_logger("v1.auth.registration_bootstrap")
|
||||
|
||||
|
||||
class RegistrationBootstrapRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._memories_repository = SQLAlchemyMemoriesRepository(session)
|
||||
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str:
|
||||
stmt = select(Profile.settings).where(Profile.id == user_id)
|
||||
settings = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
parsed = parse_profile_settings(
|
||||
settings if isinstance(settings, dict) else None
|
||||
)
|
||||
return parsed.preferences.timezone
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
next_run_at: datetime,
|
||||
) -> bool:
|
||||
stmt = (
|
||||
insert(AutomationJob)
|
||||
.values(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=title,
|
||||
config=config.model_dump(mode="json"),
|
||||
next_run_at=next_run_at,
|
||||
timezone=timezone_name,
|
||||
status=AutomationJobStatus.ACTIVE,
|
||||
created_by=owner_id,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["owner_id", "bootstrap_key"],
|
||||
index_where=AutomationJob.deleted_at.is_(None)
|
||||
& AutomationJob.bootstrap_key.is_not(None),
|
||||
)
|
||||
.returning(AutomationJob.id)
|
||||
)
|
||||
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
return inserted_id is not None
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool:
|
||||
return await self._memories_repository.create_if_absent(
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
|
||||
|
||||
|
||||
class RegistrationBootstrapRepositoryLike(Protocol):
|
||||
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
bootstrap_key: str,
|
||||
title: str,
|
||||
config: AutomationJobConfig,
|
||||
timezone_name: str,
|
||||
next_run_at: datetime,
|
||||
) -> bool: ...
|
||||
|
||||
async def upsert_initial_memory(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class SessionLike(Protocol):
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
def compute_first_run_at_utc(
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
schedule: ScheduleConfig,
|
||||
) -> datetime:
|
||||
try:
|
||||
timezone_obj = ZoneInfo(timezone_name)
|
||||
except ZoneInfoNotFoundError:
|
||||
timezone_obj = ZoneInfo("UTC")
|
||||
|
||||
local_now = now_utc.astimezone(timezone_obj)
|
||||
run_clock = time(
|
||||
hour=schedule.run_at.hour,
|
||||
minute=schedule.run_at.minute,
|
||||
tzinfo=timezone_obj,
|
||||
)
|
||||
|
||||
if schedule.type == ScheduleType.DAILY:
|
||||
candidate_local = datetime.combine(local_now.date(), run_clock)
|
||||
if candidate_local <= local_now:
|
||||
candidate_local = candidate_local + timedelta(days=1)
|
||||
return candidate_local.astimezone(UTC)
|
||||
|
||||
weekdays = schedule.weekdays or []
|
||||
if not weekdays:
|
||||
raise ValueError("weekly schedule requires weekdays")
|
||||
|
||||
normalized_weekdays = sorted(set(weekdays))
|
||||
for day_offset in range(0, 8):
|
||||
candidate_day = local_now.date() + timedelta(days=day_offset)
|
||||
if candidate_day.isoweekday() not in normalized_weekdays:
|
||||
continue
|
||||
candidate_local = datetime.combine(candidate_day, run_clock)
|
||||
if candidate_local > local_now:
|
||||
return candidate_local.astimezone(UTC)
|
||||
|
||||
fallback_day = local_now.date() + timedelta(days=7)
|
||||
while fallback_day.isoweekday() not in normalized_weekdays:
|
||||
fallback_day = fallback_day + timedelta(days=1)
|
||||
fallback_local = datetime.combine(fallback_day, run_clock)
|
||||
return fallback_local.astimezone(UTC)
|
||||
|
||||
|
||||
class RegistrationAutomationBootstrapService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: RegistrationBootstrapRepositoryLike,
|
||||
session: SessionLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
|
||||
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
|
||||
owner_id = request.user_id
|
||||
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
|
||||
|
||||
definitions = [
|
||||
{
|
||||
"bootstrap_key": "memory_extraction",
|
||||
"config_name": "memory_extraction",
|
||||
"title": "记忆推送",
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
inserted_any = False
|
||||
created_or_updated_memory = False
|
||||
|
||||
user_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.USER,
|
||||
content=UserMemoryContent().model_dump(mode="json"),
|
||||
)
|
||||
work_initialized = await self._repository.upsert_initial_memory(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.WORK,
|
||||
content=WorkProfileContent().model_dump(mode="json"),
|
||||
)
|
||||
created_or_updated_memory = user_initialized or work_initialized
|
||||
|
||||
for definition in definitions:
|
||||
bootstrap_key = str(definition["bootstrap_key"])
|
||||
job_config = load_static_automation_job_config(
|
||||
config_name=str(definition["config_name"])
|
||||
)
|
||||
schedule = job_config.schedule
|
||||
if schedule is None:
|
||||
raise ValueError(
|
||||
f"bootstrap job {bootstrap_key} has no schedule configured"
|
||||
)
|
||||
next_run_at = compute_first_run_at_utc(
|
||||
now_utc=datetime.now(UTC),
|
||||
timezone_name=timezone_name,
|
||||
schedule=schedule,
|
||||
)
|
||||
inserted = (
|
||||
await self._repository.insert_bootstrap_automation_job_if_absent(
|
||||
owner_id=owner_id,
|
||||
bootstrap_key=bootstrap_key,
|
||||
title=str(definition["title"]),
|
||||
config=job_config,
|
||||
timezone_name=timezone_name,
|
||||
next_run_at=next_run_at,
|
||||
)
|
||||
)
|
||||
inserted_any = inserted_any or inserted
|
||||
if inserted_any or created_or_updated_memory:
|
||||
await self._session.commit()
|
||||
logger.info(
|
||||
"user automation jobs bootstrapped",
|
||||
user_id=user_id,
|
||||
timezone=timezone_name,
|
||||
memory_initialized=created_or_updated_memory,
|
||||
)
|
||||
except Exception:
|
||||
await self._session.rollback()
|
||||
raise
|
||||
@@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
|
||||
from core.config.settings import config
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.schemas import (
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/otp/send", status_code=204)
|
||||
async def send_otp(
|
||||
payload: OtpSendRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope="otp_send_phone",
|
||||
identifier=payload.phone,
|
||||
limit=3,
|
||||
window_seconds=60,
|
||||
)
|
||||
await enforce_rate_limit(
|
||||
scope="otp_send_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.send_otp(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/phone-session", response_model=SessionResponse)
|
||||
async def create_phone_session(
|
||||
payload: PhoneSessionCreateRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse:
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope="phone_session_phone",
|
||||
identifier=payload.phone,
|
||||
limit=6,
|
||||
window_seconds=300,
|
||||
)
|
||||
await enforce_rate_limit(
|
||||
scope="phone_session_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=300,
|
||||
)
|
||||
return await service.create_phone_session(payload)
|
||||
|
||||
|
||||
@router.post("/sessions/refresh", response_model=SessionResponse)
|
||||
async def refresh_session(
|
||||
payload: SessionRefreshRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse:
|
||||
await enforce_rate_limit(
|
||||
scope="refresh",
|
||||
identifier=_client_ip(request),
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
)
|
||||
return await service.refresh_session(payload)
|
||||
|
||||
|
||||
@router.delete("/sessions", status_code=204)
|
||||
async def delete_session(
|
||||
payload: SessionDeleteRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="logout",
|
||||
identifier=_client_ip(request),
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.delete_session(payload.refresh_token)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
host = request.client.host if request.client else ""
|
||||
if not host:
|
||||
return "unknown"
|
||||
|
||||
if _should_trust_proxy_headers(host):
|
||||
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
first = forwarded_for.split(",")[0].strip()
|
||||
if first:
|
||||
return first
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
return host
|
||||
|
||||
|
||||
def _should_trust_proxy_headers(host: str) -> bool:
|
||||
trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips}
|
||||
return host in trusted_proxies
|
||||
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
|
||||
|
||||
|
||||
class OtpSendRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class PhoneSessionCreateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
|
||||
|
||||
class SessionRefreshRequest(BaseModel):
|
||||
refresh_token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class SessionDeleteRequest(BaseModel):
|
||||
refresh_token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: str
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
user: AuthUser
|
||||
|
||||
|
||||
class UserByPhoneResponse(BaseModel):
|
||||
id: str
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
created_at: str
|
||||
phone_confirmed_at: str | None = None
|
||||
|
||||
|
||||
class UserByIdResponse(BaseModel):
|
||||
id: str
|
||||
phone: str | None = None
|
||||
created_at: str
|
||||
phone_confirmed_at: str | None = None
|
||||
|
||||
|
||||
class OtpSendResponse(BaseModel):
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class RegistrationBootstrapRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
user_id: UUID
|
||||
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from v1.auth.schemas import (
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
)
|
||||
|
||||
|
||||
class AuthServiceGateway(Protocol):
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
_registration_bootstrapper: RegistrationBootstrapper | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gateway: AuthServiceGateway,
|
||||
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
|
||||
) -> None:
|
||||
self._gateway = gateway
|
||||
self._registration_bootstrapper = registration_bootstrapper
|
||||
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
await self._gateway.send_otp(request)
|
||||
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
response = await self._gateway.create_phone_session(request)
|
||||
if self._registration_bootstrapper is not None:
|
||||
await self._registration_bootstrapper.ensure_user_automation_jobs(
|
||||
user_id=response.user.id
|
||||
)
|
||||
return response
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
return await self._gateway.refresh_session(request)
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
|
||||
class RegistrationBootstrapper(Protocol):
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user