chore: 迁移到 social-app 架构,集成 Supabase 和 taskiq worker

This commit is contained in:
qzl
2026-04-02 16:36:35 +08:00
parent 695adb7d6f
commit 92cdfd9fca
132 changed files with 5802 additions and 759 deletions
View File
-14
View File
@@ -1,14 +0,0 @@
from __future__ import annotations
from fastapi import FastAPI
app = FastAPI(
title="Eryao API",
description="觅爻签问后端服务",
version="0.1.0",
)
@app.get("/health")
async def health_check() -> dict[str, str]:
return {"status": "ok"}
-3
View File
@@ -1,3 +0,0 @@
from .settings import Settings, config
__all__ = ["Settings", "config"]
+70 -86
View File
@@ -8,6 +8,7 @@ from pydantic import (
AnyHttpUrl,
BaseModel,
Field,
SecretStr,
computed_field,
field_validator,
model_validator,
@@ -118,11 +119,54 @@ class RedisSettings(BaseModel):
return f"redis://{self.host}:{self.port}/{self.db}"
class SupabaseSettings(BaseModel):
public_url: AnyHttpUrl
anon_key: str = "CHANGE_ME"
service_role_key: str = "CHANGE_ME"
jwt_secret: SecretStr | None = Field(default=None, exclude=True)
jwt_algorithm: Literal["HS256"] = "HS256"
jwt_issuer: str | None = None
@model_validator(mode="after")
def compute_defaults(self) -> "SupabaseSettings":
base = str(self.public_url).rstrip("/")
if self.jwt_issuer is None:
self.jwt_issuer = f"{base}/auth/v1"
return self
@computed_field
@property
def url(self) -> str:
return str(self.public_url)
class StorageSettings(BaseModel):
provider: Literal["supabase"] = "supabase"
signed_url_ttl_seconds: int = Field(default=600, ge=60, le=3600)
retention_days: int = Field(default=30, ge=1, le=3650)
class AttachmentSettings(BaseModel):
bucket: str = Field(default="eryao-attachments", min_length=3, max_length=63)
max_size_mb: int = Field(default=20, ge=1, le=200)
class AvatarSettings(BaseModel):
bucket: str = Field(default="avatars", min_length=3, max_length=63)
max_size_mb: int = Field(default=2, ge=1, le=10)
attachment: AttachmentSettings = Field(default_factory=AttachmentSettings)
avatar: AvatarSettings = Field(default_factory=AvatarSettings)
class LlmSettings(BaseModel):
provider_keys: dict[str, str] = Field(default_factory=dict)
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 3306
name: str = "eryao"
user: str = "root"
port: int = 5432
name: str = "postgres"
user: str = "postgres"
password: str = "CHANGE_ME"
@computed_field
@@ -130,83 +174,11 @@ class DatabaseSettings(BaseModel):
def url(self) -> str:
password = quote(self.password, safe="")
return (
f"mysql+aiomysql://{self.user}:{password}"
f"postgresql+asyncpg://{self.user}:{password}"
f"@{self.host}:{self.port}/{self.name}"
)
class AppVersionSettings(BaseModel):
manifest_path: str = Field(
default="deploy/static/releases/manifest.json",
description="发布清单文件路径,相对于项目根目录",
)
release_path_prefix: str = Field(
default="releases",
description="下载 URL 中文件目录前缀",
)
download_base_url: AnyHttpUrl | None = Field(
default=None,
description="下载链接基础域名,如 https://your-domain.com",
)
@field_validator("download_base_url", mode="before")
@classmethod
def empty_download_base_url_to_none(cls, value: object) -> object:
if value == "":
return None
return value
@field_validator("manifest_path")
@classmethod
def validate_manifest_path(cls, value: str) -> str:
normalized = Path(value)
if normalized.is_absolute() or ".." in normalized.parts:
raise ValueError("manifest_path must be a safe relative path")
return value
class AliyunSmsSettings(BaseModel):
access_key_id: str = "CHANGE_ME"
access_key_secret: str = "CHANGE_ME"
sign_name: str = "CHANGE_ME"
template_code: str = "CHANGE_ME"
region_id: str = "cn-hangzhou"
endpoint: str = "dysmsapi.aliyuncs.com"
test_mode: bool = False
class AliyunContentSecuritySettings(BaseModel):
access_key_id: str = "CHANGE_ME"
access_key_secret: str = "CHANGE_ME"
endpoint: str = "green-cip.cn-shenzhen.aliyuncs.com"
class AlipaySettings(BaseModel):
app_id: str = "CHANGE_ME"
merchant_id: str = "CHANGE_ME"
public_key: str = "CHANGE_ME"
private_key: str = "CHANGE_ME"
sign_type: str = "RSA2"
notify_url: str = ""
timeout_express: str = "30m"
sandbox: bool = False
class DeepSeekSettings(BaseModel):
api_key: str = "CHANGE_ME"
class AuthSettings(BaseModel):
token_expiration_days: int = 7
token_refresh_threshold_hours: int = 2
class VerificationSettings(BaseModel):
code_length: int = 6
expiration_minutes: int = 5
test_mode: bool = False
class SensitiveWordSettings(BaseModel):
use_aliyun: bool = True
fallback_to_local: bool = True
@@ -217,6 +189,11 @@ class TestSettings(BaseModel):
password: str = ""
class TaskiqSettings(BaseModel):
broker_url: str | None = None
result_backend_url: str | None = None
def _resolve_env_file() -> str:
current = Path(__file__).resolve()
for parent in [current, *current.parents]:
@@ -233,24 +210,31 @@ class Settings(BaseSettings):
runtime: RuntimeSettings = RuntimeSettings()
cors: CorsSettings = CorsSettings()
redis: RedisSettings = RedisSettings()
database: DatabaseSettings = DatabaseSettings()
app_version: AppVersionSettings = AppVersionSettings()
aliyun_sms: AliyunSmsSettings = AliyunSmsSettings()
aliyun_content_security: AliyunContentSecuritySettings = (
AliyunContentSecuritySettings()
supabase: SupabaseSettings = Field(
default_factory=lambda: SupabaseSettings(public_url="http://localhost:8001")
)
alipay: AlipaySettings = AlipaySettings()
deepseek: DeepSeekSettings = DeepSeekSettings()
auth: AuthSettings = AuthSettings()
verification: VerificationSettings = VerificationSettings()
sensitive_word: SensitiveWordSettings = SensitiveWordSettings()
storage: StorageSettings = StorageSettings()
llm: LlmSettings = LlmSettings()
database: DatabaseSettings = DatabaseSettings()
sensitive_word: SensitiveWordSettings = Field(default_factory=SensitiveWordSettings)
test: TestSettings = Field(default_factory=TestSettings)
taskiq: TaskiqSettings = Field(default_factory=TaskiqSettings)
@computed_field
@property
def database_url(self) -> str:
return self.database.url
@computed_field
@property
def taskiq_broker_url(self) -> str:
return self.taskiq.broker_url or self.redis.url
@computed_field
@property
def taskiq_result_backend_url(self) -> str:
return self.taskiq.result_backend_url or self.redis.url
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_file=_resolve_env_file(),
env_prefix="ERYAO_",
@@ -1,34 +0,0 @@
input_template: |
你正在执行一次"自动化记忆回顾与整理"任务。
任务目标:
1) 回顾最近两天的聊天与上下文,识别用户长期偏好、习惯和关键事实的变化。
2) 对已经失效、被否定或明显过期的信息执行遗忘。
3) 对新增且有证据支持的信息执行写入。
4) 严禁编造;没有证据就不要写入。
5) 只更新最小必要字段,避免过度覆盖。
输出要求:
- 必须使用以下固定格式输出:
<----------【周期任务输出】---------->
【记忆回顾】<一句人性化总结,说明今天主要发生了什么>
【新增记忆】<按"X条:要点1;要点2"描述;没有则写"0条">
【遗忘记忆】<按"X条:要点1;要点2"描述;没有则写"0条">
【未来展望】<基于本次记忆变化,给出1-2条温和、可执行的后续建议;若暂无建议则说明"可继续观察">
表达风格:
- 语言自然、温和、可读,像助理在做每日回顾。
- 结论先行,避免空话,不要输出与任务无关的闲聊内容。
enabled_tools:
- memory.write
- memory.forget
context:
source: latest_chat
window_mode: day
window_count: 2
schedule:
type: daily
run_at:
hour: 8
minute: 0
weekdays: null
@@ -1,158 +0,0 @@
version: "1.0"
routes:
- route_id: auth.boot
path: /boot
description: Bootstraps auth session and redirects to login or home.
category: auth
auth_required: false
- route_id: auth.login
path: /login
description: Login entry for unauthenticated users.
category: auth
auth_required: false
- route_id: home.main
path: /
description: Main assistant home screen.
category: home
auth_required: true
- route_id: message.invite_list
path: /messages/invites
description: Lists message invitations.
category: messages
auth_required: true
- route_id: message.invite_detail
path: /messages/invites/{id}
description: Shows details for a single invitation.
category: messages
auth_required: true
path_params:
- id
- route_id: contacts.list
path: /contacts
description: Contact list and quick relationship actions.
category: contacts
auth_required: true
- route_id: contacts.add
path: /contacts/add
description: Create or edit a contact profile.
category: contacts
auth_required: true
- route_id: calendar.dayweek
path: /calendar/dayweek
description: Day and week calendar view.
category: calendar
auth_required: true
query_params:
- date
- from
- route_id: calendar.month
path: /calendar/month
description: Month calendar overview.
category: calendar
auth_required: true
query_params:
- from
- route_id: calendar.event_detail
path: /calendar/events/{id}
description: Detail page for one calendar event.
category: calendar
auth_required: true
path_params:
- id
- route_id: calendar.event_create
path: /calendar/events/new
description: Create page for one calendar event.
category: calendar
auth_required: true
query_params:
- date
- route_id: calendar.event_edit
path: /calendar/events/{id}/edit
description: Edit page for one calendar event.
category: calendar
auth_required: true
path_params:
- id
- route_id: calendar.event_share
path: /calendar/events/{id}/share
description: Share settings page for one calendar event.
category: calendar
auth_required: true
path_params:
- id
- route_id: todo.list
path: /todo
description: Todo quadrants and backlog overview.
category: todo
auth_required: true
- route_id: todo.create
path: /todo/new
description: Create page for one todo item.
category: todo
auth_required: true
- route_id: todo.detail
path: /todo/{id}
description: Detail page for one todo item.
category: todo
auth_required: true
path_params:
- id
- route_id: todo.edit
path: /todo/{id}/edit
description: Dedicated subpage for editing one todo item (not an in-page modal).
category: todo
auth_required: true
path_params:
- id
- route_id: settings.main
path: /settings
description: Settings hub page.
category: settings
auth_required: true
- route_id: settings.features
path: /settings/features
description: Automation job list page.
category: settings
auth_required: true
- route_id: settings.job_new
path: /settings/job/new
description: Create page for one automation job.
category: settings
auth_required: true
- route_id: settings.job_detail
path: /settings/job/{id}
description: Detail page for one automation job.
category: settings
auth_required: true
path_params:
- id
- route_id: settings.memory
path: /settings/memory
description: Memory preferences and controls.
category: settings
auth_required: true
- route_id: settings.memory_user
path: /settings/memory/user
description: User memory summary view.
category: settings
auth_required: true
- route_id: settings.memory_work
path: /settings/memory/work
description: Work memory summary view.
category: settings
auth_required: true
- route_id: settings.memory_user_edit
path: /settings/memory/user/edit
description: Edit user memory details.
category: settings
auth_required: true
- route_id: settings.memory_work_edit
path: /settings/memory/work/edit
description: Edit work memory details.
category: settings
auth_required: true
- route_id: settings.edit_profile
path: /edit-profile
description: Profile editing page.
category: settings
auth_required: true
+2 -2
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from sqlalchemy import JSON
from sqlalchemy.dialects.mysql import JSON as MySQLJSON
from sqlalchemy.dialects.postgresql import JSONB
json_type = JSON().with_variant(MySQLJSON, "mysql")
json_jsonb = JSON().with_variant(JSONB, "postgresql")
+3
View File
@@ -0,0 +1,3 @@
from __future__ import annotations
__all__ = []
+2 -56
View File
@@ -5,9 +5,6 @@ import subprocess
import sys
from pathlib import Path
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from core.automation.scheduler import run_automation_scheduler_scan
from core.config.initial.init_data import initialize_data
from core.config.settings import config
from core.logging import get_logger
@@ -16,7 +13,6 @@ logger = get_logger("core.runtime.cli")
def _resolve_alembic_path() -> Path:
"""Resolve alembic.ini path relative to project root."""
project_root = Path(__file__).parents[3]
alembic_path = project_root / "alembic" / "alembic.ini"
if not alembic_path.exists():
@@ -25,7 +21,6 @@ def _resolve_alembic_path() -> Path:
def _redact_sensitive(text: str) -> str:
"""Redact sensitive information from log output."""
import re
SENSITIVE_KEYS = ("password", "token", "secret", "api_key")
@@ -40,7 +35,6 @@ def _redact_sensitive(text: str) -> str:
def run_migrations() -> bool:
"""Run alembic migrations in a subprocess to avoid event loop conflicts."""
import os
logger.info("Running alembic migrations")
@@ -75,7 +69,6 @@ def run_migrations() -> bool:
async def run_init_data() -> bool:
"""Initialize bootstrap data."""
logger.info("Running init-data")
try:
result = await initialize_data()
@@ -90,7 +83,6 @@ async def run_init_data() -> bool:
async def bootstrap() -> bool:
"""Run migrations followed by init-data."""
logger.info("Starting bootstrap (migrate + init-data)")
if not run_migrations():
@@ -105,52 +97,11 @@ async def bootstrap() -> bool:
return True
async def run_automation_scheduler_forever() -> None:
if not config.automation_scheduler.enabled:
logger.info("Automation scheduler disabled by config")
return
interval_seconds = int(config.automation_scheduler.interval_seconds)
batch_limit = int(config.automation_scheduler.batch_limit)
logger.info(
"Starting automation scheduler",
interval_seconds=interval_seconds,
batch_limit=batch_limit,
)
async def scan_job() -> None:
try:
await run_automation_scheduler_scan(limit=batch_limit)
except Exception as exc:
logger.exception("Automation scheduler scan failed", error=str(exc))
scheduler = AsyncIOScheduler()
scheduler.add_job(
scan_job,
trigger=IntervalTrigger(seconds=interval_seconds),
id="automation_scheduler_scan",
name="Automation scheduler scan",
replace_existing=True,
max_instances=1,
coalesce=True,
)
scheduler.start()
stop_event = asyncio.Event()
try:
await stop_event.wait()
finally:
scheduler.shutdown(wait=False)
def main() -> int:
"""CLI entry point."""
if len(sys.argv) < 2:
logger.error("No command provided")
logger.info("Usage: python -m core.runtime.cli <command>")
logger.info(
"Available commands: migrate, init-data, bootstrap, automation-scheduler"
)
logger.info("Available commands: migrate, init-data, bootstrap")
return 1
command = sys.argv[1]
@@ -161,14 +112,9 @@ def main() -> int:
success = asyncio.run(run_init_data())
elif command == "bootstrap":
success = asyncio.run(bootstrap())
elif command == "automation-scheduler":
asyncio.run(run_automation_scheduler_forever())
return 0
else:
logger.error("Unknown command", command=command)
logger.info(
"Available commands: migrate, init-data, bootstrap, automation-scheduler"
)
logger.info("Available commands: migrate, init-data, bootstrap")
return 1
return 0 if success else 1
+3
View File
@@ -0,0 +1,3 @@
from __future__ import annotations
__all__ = []
+3
View File
@@ -0,0 +1,3 @@
from core.taskiq.app import broker, worker_agent_broker, worker_general_broker
__all__ = ["broker", "worker_agent_broker", "worker_general_broker"]
+30
View File
@@ -0,0 +1,30 @@
from __future__ import annotations
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
from core.config.settings import config
from core.logging import configure_logging, log_service_banner
configure_logging(config)
log_service_banner(
service_name=config.runtime.service_name,
environment=config.runtime.environment,
)
def _build_broker(queue_name: str) -> ListQueueBroker:
return ListQueueBroker(
url=config.taskiq_broker_url,
queue_name=queue_name,
).with_result_backend(
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
)
worker_agent_broker = _build_broker("agent")
worker_general_broker = _build_broker("general")
broker = worker_agent_broker
__all__ = ["broker", "worker_agent_broker", "worker_general_broker"]
+8 -9
View File
@@ -1,12 +1,11 @@
from . import user, divination, payment, notification, feedback, version, log, violation
from __future__ import annotations
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.system_agents import SystemAgents
__all__ = [
"user",
"divination",
"payment",
"notification",
"feedback",
"version",
"log",
"violation",
"Llm",
"LlmFactory",
"SystemAgents",
]
-41
View File
@@ -1,41 +0,0 @@
from __future__ import annotations
from core.db.base import Base, TimestampMixin
from sqlalchemy import BigInteger, DateTime, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
class DivinationRecord(TimestampMixin, Base):
__tablename__ = "user_divination_records"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
trace_id: Mapped[str] = mapped_column(String(64), nullable=False)
question: Mapped[str] = mapped_column(Text, nullable=False)
question_type: Mapped[str] = mapped_column(String(50), nullable=False)
divination_data: Mapped[str] = mapped_column(Text, nullable=False)
deepseek_request: Mapped[str] = mapped_column(Text, nullable=False)
deepseek_response: Mapped[str | None] = mapped_column(Text, nullable=True)
interpretation_result: Mapped[str | None] = mapped_column(Text, nullable=True)
api_success: Mapped[bool] = mapped_column(Integer, nullable=False, default=0)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
api_duration_ms: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
phone_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
class DivinationHistory(TimestampMixin, Base):
__tablename__ = "user_divination_history"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
phone_number: Mapped[str] = mapped_column(String(20), nullable=False)
local_record_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
json_data: Mapped[str] = mapped_column(Text, nullable=False)
ai_result: Mapped[str] = mapped_column(Text, nullable=False)
question_type: Mapped[str] = mapped_column(String(50), nullable=False)
question: Mapped[str] = mapped_column(Text, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
is_active: Mapped[bool] = mapped_column(Integer, nullable=False, default=1)
sync_time: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
-17
View File
@@ -1,17 +0,0 @@
from __future__ import annotations
from core.db.base import Base
from sqlalchemy import BigInteger, DateTime, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
class UserFeedback(Base):
__tablename__ = "user_feedback"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
phone_number: Mapped[str] = mapped_column(String(20), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
+26
View File
@@ -0,0 +1,26 @@
from __future__ import annotations
import uuid
from sqlalchemy import ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class Llm(TimestampMixin, SoftDeleteMixin, Base):
__tablename__: str = "llms"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
factory_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("llm_factory.id", ondelete="RESTRICT"),
nullable=False,
index=True,
)
model_code: Mapped[str] = mapped_column(
String(50), nullable=False, unique=True, index=True
)
+22
View File
@@ -0,0 +1,22 @@
from __future__ import annotations
import uuid
from sqlalchemy import String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class LlmFactory(TimestampMixin, SoftDeleteMixin, Base):
__tablename__: str = "llm_factory"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(
String(50), nullable=False, unique=True, index=True
)
request_url: Mapped[str] = mapped_column(String(255), nullable=False)
avatar: Mapped[str | None] = mapped_column(Text, nullable=True)
-39
View File
@@ -1,39 +0,0 @@
from __future__ import annotations
from core.db.base import Base
from sqlalchemy import BigInteger, DateTime, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
class NetworkAccessLog(Base):
__tablename__ = "network_access_logs"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
phone_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
client_ip: Mapped[str] = mapped_column(String(45), nullable=False)
client_port: Mapped[int | None] = mapped_column(Integer, nullable=True)
server_ip: Mapped[str] = mapped_column(String(45), nullable=False)
server_port: Mapped[int] = mapped_column(Integer, nullable=False)
http_method: Mapped[str] = mapped_column(String(10), nullable=False)
request_path: Mapped[str] = mapped_column(String(500), nullable=False)
request_url: Mapped[str] = mapped_column(String(1000), nullable=False)
user_agent: Mapped[str | None] = mapped_column(String(1000), nullable=True)
device_info: Mapped[str | None] = mapped_column(Text, nullable=True)
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
processing_time_ms: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
request_size: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
response_size: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
x_forwarded_for: Mapped[str | None] = mapped_column(String(500), nullable=True)
x_real_ip: Mapped[str | None] = mapped_column(String(45), nullable=True)
referer: Mapped[str | None] = mapped_column(String(1000), nullable=True)
operation_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
operation_result: Mapped[str | None] = mapped_column(String(20), nullable=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
session_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
access_time: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
created_at: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
-14
View File
@@ -1,14 +0,0 @@
from __future__ import annotations
from core.db.base import Base
from sqlalchemy import BigInteger, String, Text
from sqlalchemy.orm import Mapped, mapped_column
class Notification(Base):
__tablename__ = "notification"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
title: Mapped[str] = mapped_column(String(255), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
-40
View File
@@ -1,40 +0,0 @@
from __future__ import annotations
from core.db.base import Base, TimestampMixin
from sqlalchemy import BigInteger, DateTime, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
class PaymentOrder(TimestampMixin, Base):
__tablename__ = "payment_order"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
order_no: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
amount: Mapped[str] = mapped_column(String(20), nullable=False)
coin_count: Mapped[int] = mapped_column(Integer, nullable=False)
subject: Mapped[str] = mapped_column(String(256), nullable=False)
body: Mapped[str | None] = mapped_column(String(512), nullable=True)
channel: Mapped[str] = mapped_column(String(16), nullable=False)
status: Mapped[str] = mapped_column(String(16), nullable=False, default="CREATED")
trade_no: Mapped[str | None] = mapped_column(String(64), nullable=True)
payment_time: Mapped[str | None] = mapped_column(DateTime, nullable=True)
class PaymentRecord(Base):
__tablename__ = "payment_record"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
order_no: Mapped[str] = mapped_column(String(64), nullable=False)
trade_no: Mapped[str] = mapped_column(String(64), nullable=False)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
channel: Mapped[str] = mapped_column(String(16), nullable=False)
notify_type: Mapped[str] = mapped_column(String(16), nullable=False)
trade_status: Mapped[str] = mapped_column(String(32), nullable=False)
notify_data: Mapped[str] = mapped_column(Text, nullable=False)
process_status: Mapped[str] = mapped_column(String(16), nullable=False)
process_message: Mapped[str | None] = mapped_column(String(512), nullable=True)
coin_added: Mapped[bool] = mapped_column(Integer, nullable=False, default=0)
created_at: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
+32
View File
@@ -0,0 +1,32 @@
from __future__ import annotations
import uuid
from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, TimestampMixin
class SystemAgents(TimestampMixin, Base):
__tablename__: str = "system_agents"
agent_type: Mapped[str] = mapped_column(
String(20),
primary_key=True,
)
llm_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("llms.id", ondelete="RESTRICT"),
nullable=False,
)
status: Mapped[str] = mapped_column(
String(20),
nullable=False,
)
config: Mapped[dict] = mapped_column(
JSON().with_variant(JSONB, "postgresql"),
nullable=False,
server_default="{}",
)
-52
View File
@@ -1,52 +0,0 @@
from __future__ import annotations
from core.db.base import Base, TimestampMixin
from sqlalchemy import BigInteger, DateTime, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
class User(TimestampMixin, Base):
__tablename__ = "user_profile"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
phone_number: Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
nickname: Mapped[str] = mapped_column(String(50), nullable=False, default="")
gender: Mapped[str] = mapped_column(String(10), nullable=False, default="")
birthday: Mapped[str] = mapped_column(
String(20), nullable=False, default="2000-01-01"
)
signature: Mapped[str] = mapped_column(String(255), nullable=False, default="")
class UserToken(Base):
__tablename__ = "user_tokens"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
token: Mapped[str] = mapped_column(String(255), nullable=False)
expire_time: Mapped[str] = mapped_column(DateTime, nullable=False)
created_at: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
class VerificationCode(Base):
__tablename__ = "verification_codes"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
phone_number: Mapped[str] = mapped_column(String(20), nullable=False)
code: Mapped[str] = mapped_column(String(6), nullable=False)
expiration_time: Mapped[str] = mapped_column(DateTime, nullable=False)
created_at: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
used: Mapped[bool] = mapped_column(Integer, nullable=False, default=0)
class UserCoin(Base):
__tablename__ = "user_coin"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, unique=True)
phone_number: Mapped[str] = mapped_column(String(20), nullable=False)
coin_balance: Mapped[int] = mapped_column(Integer, nullable=False, default=3)
-19
View File
@@ -1,19 +0,0 @@
from __future__ import annotations
from core.db.base import Base, TimestampMixin
from sqlalchemy import BigInteger, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
class AppVersion(TimestampMixin, Base):
__tablename__ = "app_version"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
version_name: Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
version_code: Mapped[int] = mapped_column(Integer, unique=True, nullable=False)
min_supported_version: Mapped[str] = mapped_column(String(20), nullable=False)
min_supported_code: Mapped[int] = mapped_column(Integer, nullable=False)
is_force_update: Mapped[bool] = mapped_column(Integer, nullable=False, default=0)
update_message: Mapped[str | None] = mapped_column(Text, nullable=True)
download_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
is_active: Mapped[bool] = mapped_column(Integer, nullable=False, default=1)
-30
View File
@@ -1,30 +0,0 @@
from __future__ import annotations
from core.db.base import Base
from sqlalchemy import BigInteger, DateTime, Float, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
class SensitiveWordViolation(Base):
__tablename__ = "sensitive_word_violations"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
content_type: Mapped[str] = mapped_column(String(20), nullable=False)
original_content: Mapped[str] = mapped_column(Text, nullable=False)
violation_type: Mapped[str] = mapped_column(String(30), nullable=False)
detection_service: Mapped[str] = mapped_column(
String(20), nullable=False, default="LOCAL"
)
risk_level: Mapped[str | None] = mapped_column(String(50), nullable=True)
confidence: Mapped[float | None] = mapped_column(Float, nullable=True)
aliyun_response: Mapped[str | None] = mapped_column(Text, nullable=True)
matched_words: Mapped[str] = mapped_column(Text, nullable=False)
client_ip: Mapped[str | None] = mapped_column(String(45), nullable=True)
user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
violation_time: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
created_at: Mapped[str] = mapped_column(
DateTime, nullable=False, server_default=func.now()
)
+1
View File
@@ -0,0 +1 @@
"""Backend reusable schemas package."""
+68
View File
@@ -0,0 +1,68 @@
from schemas.agent.forwarded_props import (
ClientTimeContext,
ForwardedPropsPayload,
parse_forwarded_props_client_time,
parse_forwarded_props_runtime_mode,
)
from schemas.agent.forwarded_props import RuntimeMode
from schemas.agent.runtime_models import (
AgentOutput,
ConstraintItem,
ExecutionMode,
KeyEntity,
NormalizedTaskInput,
ResultTyping,
ResultType,
RouterAgentOutput,
RunStatus,
TaskType,
TaskTyping,
ToolAgentOutput,
ToolStatus,
WorkerAgentOutputLite,
WorkerAgentOutputRich,
resolve_worker_output_model,
)
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.agent.visibility import SystemVisibilityBit, VisibilityMask, bit_mask
from schemas.agent.ui_hints import (
UiHintAction,
UiHintIntent,
UiHintSection,
UiHintStatus,
UiHintsPayload,
)
__all__ = [
"AgentType",
"AgentOutput",
"ConstraintItem",
"ExecutionMode",
"ForwardedPropsPayload",
"KeyEntity",
"NormalizedTaskInput",
"ResultTyping",
"ClientTimeContext",
"ResultType",
"RouterAgentOutput",
"RunStatus",
"RuntimeMode",
"TaskType",
"TaskTyping",
"SystemAgentLLMConfig",
"SystemVisibilityBit",
"ToolAgentOutput",
"ToolStatus",
"UiHintAction",
"UiHintIntent",
"UiHintSection",
"UiHintStatus",
"UiHintsPayload",
"VisibilityMask",
"WorkerAgentOutputLite",
"WorkerAgentOutputRich",
"bit_mask",
"parse_forwarded_props_client_time",
"parse_forwarded_props_runtime_mode",
"resolve_worker_output_model",
]
@@ -0,0 +1,93 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
import re
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from pydantic import (
BaseModel,
ConfigDict,
Field,
StrictInt,
ValidationError,
field_validator,
)
_RFC3339_WITH_TZ_PATTERN = re.compile(
r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})$"
)
class ClientTimeContext(BaseModel):
model_config = ConfigDict(extra="forbid")
device_timezone: str = Field(
...,
description="IANA timezone from client device, e.g. America/Los_Angeles.",
)
client_now_iso: str = Field(
...,
description="RFC3339 datetime with timezone offset from client device.",
)
client_epoch_ms: StrictInt = Field(
...,
ge=0,
description="Unix epoch milliseconds from client device.",
)
@field_validator("device_timezone")
@classmethod
def validate_device_timezone(cls, value: str) -> str:
try:
ZoneInfo(value)
except ZoneInfoNotFoundError as exc:
raise ValueError("invalid client_time.device_timezone") from exc
return value
@field_validator("client_now_iso")
@classmethod
def validate_client_now_iso(cls, value: str) -> str:
if not _RFC3339_WITH_TZ_PATTERN.fullmatch(value):
raise ValueError("invalid client_time.client_now_iso")
normalized = value.replace("Z", "+00:00")
try:
parsed = datetime.fromisoformat(normalized)
except ValueError as exc:
raise ValueError("invalid client_time.client_now_iso") from exc
if parsed.tzinfo is None:
raise ValueError("invalid client_time.client_now_iso")
return value
class RuntimeMode(str, Enum):
CHAT = "chat"
AUTOMATION = "automation"
class ForwardedPropsPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
runtime_mode: RuntimeMode
client_time: ClientTimeContext | None = None
def parse_forwarded_props(forwarded_props: object) -> ForwardedPropsPayload:
if not isinstance(forwarded_props, dict):
raise ValueError("invalid RunAgentInput.forwardedProps")
try:
return ForwardedPropsPayload.model_validate(forwarded_props)
except ValidationError as exc:
raise ValueError("invalid RunAgentInput.forwardedProps") from exc
def parse_forwarded_props_client_time(
forwarded_props: object,
) -> ClientTimeContext | None:
payload = parse_forwarded_props(forwarded_props)
return payload.client_time
def parse_forwarded_props_runtime_mode(forwarded_props: object) -> RuntimeMode:
payload = parse_forwarded_props(forwarded_props)
return payload.runtime_mode
+177
View File
@@ -0,0 +1,177 @@
from __future__ import annotations
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_validator
from schemas.agent.ui_hints import UiHintsPayload
class TaskType(str, Enum):
KNOWLEDGE = "knowledge"
RECOMMENDATION = "recommendation"
PLANNING = "planning"
SCHEDULING = "scheduling"
REMINDER_MANAGEMENT = "reminder_management"
TODO_MANAGEMENT = "todo_management"
COMMUNICATION_DRAFTING = "communication_drafting"
INFORMATION_ORGANIZATION = "information_organization"
STATUS_TRACKING = "status_tracking"
TRANSACTION_ASSIST = "transaction_assist"
ACTION_EXECUTION = "action_execution"
TROUBLESHOOTING = "troubleshooting"
UNKNOWN = "unknown"
class ResultType(str, Enum):
DIRECT_ANSWER = "direct_answer"
OPTIONS_WITH_RECOMMENDATION = "options_with_recommendation"
ACTION_PLAN = "action_plan"
SCHEDULE_PROPOSAL = "schedule_proposal"
TODO_LIST = "todo_list"
DRAFT_MESSAGE = "draft_message"
SUMMARY = "summary"
PROGRESS_SUMMARY = "progress_summary"
DIAGNOSIS_REPORT = "diagnosis_report"
STRUCTURED_PAYLOAD = "structured_payload"
EXECUTION_REPORT = "execution_report"
CLARIFICATION_REQUEST = "clarification_request"
SAFETY_BLOCK = "safety_block"
ERROR_REPORT = "error_report"
UNKNOWN = "unknown"
class TaskTyping(BaseModel):
model_config = ConfigDict(extra="forbid")
primary: TaskType
secondary: list[TaskType] = Field(default_factory=list, max_length=3)
class ResultTyping(BaseModel):
model_config = ConfigDict(extra="forbid")
primary: ResultType
secondary: list[ResultType] = Field(default_factory=list, max_length=3)
class ExecutionMode(str, Enum):
ONESTEP = "onestep"
TOOL_ASSISTED = "tool_assisted"
MULTISTEP = "multistep"
class RunStatus(str, Enum):
SUCCESS = "success"
PARTIAL_SUCCESS = "partial_success"
FAILED = "failed"
class ToolStatus(str, Enum):
SUCCESS = "success"
FAILURE = "failure"
PARTIAL = "partial"
class KeyEntity(BaseModel):
model_config = ConfigDict(extra="forbid")
name: str
type: str
value: str | None = None
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: object) -> object:
if value is None:
return None
if isinstance(value, str):
return value
if isinstance(value, bool | int | float):
return str(value)
return value
class ConstraintItem(BaseModel):
model_config = ConfigDict(extra="forbid")
key: str
value: str
required: bool = True
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: object) -> object:
if isinstance(value, bool | int | float):
return str(value)
return value
class NormalizedTaskInput(BaseModel):
model_config = ConfigDict(extra="forbid")
user_text: str
multimodal_summary: list[str] = Field(default_factory=list)
context_summary: str = Field(default="", max_length=2000)
class RouterAgentOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
normalized_task_input: NormalizedTaskInput
key_entities: list[KeyEntity] = Field(default_factory=list)
constraints: list[ConstraintItem] = Field(default_factory=list)
task_typing: TaskTyping
execution_mode: ExecutionMode
result_typing: ResultTyping
class ErrorInfo(BaseModel):
model_config = ConfigDict(extra="forbid")
code: str
message: str
retryable: bool = False
details: dict[str, Any] | None = None
class ToolAgentOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
tool_name: str
tool_call_id: str
tool_call_args: dict[str, Any] | None = None
status: ToolStatus
result: str
error: ErrorInfo | None = None
class WorkerAgentOutputLite(BaseModel):
model_config = ConfigDict(extra="forbid")
status: RunStatus = RunStatus.SUCCESS
answer: str
key_points: list[str] = Field(default_factory=list)
result_type: ResultType = ResultType.UNKNOWN
suggested_actions: list[str] = Field(default_factory=list)
error: ErrorInfo | None = None
class WorkerAgentOutputRich(WorkerAgentOutputLite):
ui_hints: UiHintsPayload | None = None
class AgentOutput(WorkerAgentOutputRich):
model_config = ConfigDict(extra="forbid")
WorkerAgentOutput = WorkerAgentOutputLite | WorkerAgentOutputRich
def resolve_worker_output_model(
execution_mode: ExecutionMode,
) -> type[WorkerAgentOutputLite]:
if execution_mode == ExecutionMode.ONESTEP:
return WorkerAgentOutputLite
return WorkerAgentOutputRich
+30
View File
@@ -0,0 +1,30 @@
from __future__ import annotations
from enum import Enum
from pydantic import BaseModel, Field
class AgentType(str, Enum):
ROUTER = "router"
WORKER = "worker"
class ContextBuildStrategy(str, Enum):
DAY = "day"
NUMBER = "number"
class ContextMessagesConfig(BaseModel):
mode: ContextBuildStrategy = ContextBuildStrategy.NUMBER
count: int = Field(default=20, ge=1, le=200)
class SystemAgentLLMConfig(BaseModel):
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
max_tokens: int | None = Field(default=None, ge=1)
timeout_seconds: float | None = Field(default=30.0, gt=0.0, le=300.0)
context_messages: ContextMessagesConfig = Field(
default_factory=ContextMessagesConfig
)
enabled_tools: list[str] = Field(default_factory=list, max_length=32)
+349
View File
@@ -0,0 +1,349 @@
"""
UiHints - 描述性 UI 提示
设计原则:
- 描述性而非渲染性: 告诉编译器“要展示什么”,而不是“如何渲染”
- 最小化 token: 保持字段简洁
- 可编译: 可机械转换为 UiSchemaRenderer
- 尽量无损: hints 中的主要内容字段应尽量被保留到 renderer 中
Version: 2.1
"""
from __future__ import annotations
from enum import Enum
import re
from typing import Any, ClassVar, Literal
from pydantic import BaseModel, ConfigDict, Field
from pydantic import field_validator
_NAVIGATION_PATH_PATTERN = re.compile(r"^/[A-Za-z0-9/_-]*$")
_NAVIGATION_PARAM_KEY_PATTERN = re.compile(r"^[A-Za-z][A-Za-z0-9_]{0,31}$")
_MAX_NAVIGATION_PARAMS = 8
# ============================================================
# Enums
# ============================================================
class UiHintStatus(str, Enum):
INFO = "info"
SUCCESS = "success"
WARNING = "warning"
ERROR = "error"
PENDING = "pending"
class UiHintIntent(str, Enum):
"""主要展示意图(弱提示,不应决定字段生死)"""
MESSAGE = "message" # 普通消息/说明
DATA = "data" # 数据/结果摘要
LIST = "list" # 列表为主
STATUS = "status" # 状态结果为主
FORM = "form" # 结构化内容(当前不表示真实输入表单)
MIXED = "mixed" # 混合内容
class UiHintActionStyle(str, Enum):
PRIMARY = "primary"
SECONDARY = "secondary"
GHOST = "ghost"
DANGER = "danger"
class UiHintTextFormat(str, Enum):
PLAIN = "plain"
MARKDOWN = "markdown"
class UiHintActionType(str, Enum):
NAVIGATION = "navigation"
URL = "url"
EVENT = "event"
TOOL = "tool"
COPY = "copy"
PAYLOAD = "payload"
class UiHintIconSource(str, Enum):
ICON = "icon"
EMOJI = "emoji"
URL = "url"
# ============================================================
# Base Config
# ============================================================
class UiHintBaseModel(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="forbid",
populate_by_name=True,
)
# ============================================================
# Action Targets
# ============================================================
class UiHintActionNavigation(UiHintBaseModel):
type: Literal["navigation"]
path: str = Field(..., description="Internal route path.")
params: dict[str, Any] | None = Field(default=None, description="Route params.")
@field_validator("path")
@classmethod
def validate_navigation_path(cls, value: str) -> str:
path = value.strip()
if not path:
raise ValueError("navigation path must not be empty")
if len(path) > 256:
raise ValueError("navigation path is too long")
if path.startswith("//") or "://" in path:
raise ValueError("navigation path must be internal")
if "?" in path or "#" in path:
raise ValueError("navigation path must not contain query or fragment")
if ":" in path:
raise ValueError("navigation path must be concrete without placeholders")
if _NAVIGATION_PATH_PATTERN.fullmatch(path) is None:
raise ValueError("navigation path contains unsupported characters")
return path
@field_validator("params")
@classmethod
def validate_navigation_params(
cls, value: dict[str, Any] | None
) -> dict[str, Any] | None:
if value is None:
return None
if len(value) > _MAX_NAVIGATION_PARAMS:
raise ValueError("navigation params exceed limit")
normalized: dict[str, Any] = {}
for key, param_value in value.items():
if _NAVIGATION_PARAM_KEY_PATTERN.fullmatch(key) is None:
raise ValueError("navigation param key is invalid")
if isinstance(param_value, (str, int, float, bool)):
normalized[key] = param_value
continue
raise ValueError("navigation params must be scalar")
return normalized
class UiHintActionUrl(UiHintBaseModel):
type: Literal["url"]
url: str = Field(..., description="External URL.")
target: Literal["_self", "_blank"] | None = Field(default=None)
class UiHintActionEvent(UiHintBaseModel):
type: Literal["event"]
event: str = Field(..., description="Frontend event name.")
payload: dict[str, Any] | None = Field(default=None)
class UiHintActionTool(UiHintBaseModel):
type: Literal["tool"]
tool_id: str = Field(alias="toolId", description="Tool identifier.")
params: dict[str, Any] | None = Field(default=None)
class UiHintActionCopy(UiHintBaseModel):
type: Literal["copy"]
content: str = Field(..., description="Content to copy.")
success_message: str | None = Field(alias="successMessage", default=None)
class UiHintActionPayload(UiHintBaseModel):
type: Literal["payload"]
payload: dict[str, Any] = Field(..., description="Structured payload.")
submit_to: str | None = Field(alias="submitTo", default=None)
UiHintActionTarget = (
UiHintActionNavigation
| UiHintActionUrl
| UiHintActionEvent
| UiHintActionTool
| UiHintActionCopy
| UiHintActionPayload
)
class UiHintAction(UiHintBaseModel):
label: str = Field(..., description="Button label.")
style: UiHintActionStyle | None = Field(default=None, description="Button style.")
disabled: bool = Field(default=False, description="Disabled state.")
action: UiHintActionTarget = Field(..., description="Action to execute.")
# ============================================================
# Small Descriptive Models
# ============================================================
class UiHintIcon(UiHintBaseModel):
source: UiHintIconSource = Field(default=UiHintIconSource.ICON)
value: str = Field(..., description="Icon identifier / emoji / url.")
color: str | None = Field(default=None)
size: int | None = Field(default=None)
class UiHintKvItem(UiHintBaseModel):
key: str = Field(..., description="Key identifier.")
label: str | None = Field(default=None, description="Display label.")
value: Any = Field(default=None, description="Value.")
copyable: bool = Field(default=False, description="Allow copy.")
class UiHintListItem(UiHintBaseModel):
id: str | None = Field(default=None)
title: str = Field(..., description="Item title.")
subtitle: str | None = Field(default=None)
description: str | None = Field(default=None)
icon: UiHintIcon | None = Field(default=None)
status: UiHintStatus | None = Field(default=None)
actions: list[UiHintAction] = Field(default_factory=list)
@field_validator("status", mode="before")
@classmethod
def normalize_status(cls, value: object) -> object:
if value is None:
return None
if isinstance(value, dict):
status_type = value.get("type")
if isinstance(status_type, str):
return status_type
status_value = value.get("status")
if isinstance(status_value, str):
return status_value
return value
class UiHintSection(UiHintBaseModel):
title: str | None = Field(default=None, description="Section title.")
description: str | None = Field(default=None, description="Section description.")
icon: UiHintIcon | None = Field(default=None, description="Section icon.")
content: str | None = Field(default=None, description="Main text content.")
content_format: UiHintTextFormat = Field(
default=UiHintTextFormat.PLAIN,
alias="contentFormat",
description="Section content text format.",
)
items: list[UiHintKvItem] = Field(default_factory=list, description="KV items.")
list_items: list[UiHintListItem] = Field(
default_factory=list,
alias="listItems",
description="List items.",
)
actions: list[UiHintAction] = Field(default_factory=list, description="Actions.")
# ============================================================
# Root Payload
# ============================================================
class UiHintsPayload(UiHintBaseModel):
"""
描述性 UI 提示
设计目标:
- agent 输出尽可能短
- 不表达布局细节
- 编译器负责转换为完整 UiSchemaRenderer
"""
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="forbid",
populate_by_name=True,
json_schema_extra={
"examples": [
{
"intent": "status",
"status": "success",
"title": "日程已创建",
"body": "本次创建已成功完成。",
"items": [
{"key": "title", "label": "主题", "value": "Q1 规划会议"},
{"key": "time", "label": "时间", "value": "2026-03-15 14:00"},
],
"actions": [
{
"label": "查看详情",
"style": "primary",
"action": {
"type": "navigation",
"path": "/calendar/evt_123",
},
},
{
"label": "删除",
"style": "danger",
"action": {
"type": "tool",
"toolId": "calendar.delete",
"params": {"eventId": "evt_123"},
},
},
],
}
]
},
)
version: str = Field(default="2.1")
intent: UiHintIntent = Field(
default=UiHintIntent.MESSAGE,
description="Primary display intent.",
)
status: UiHintStatus = Field(
default=UiHintStatus.INFO,
description="Overall status.",
)
title: str | None = Field(default=None, description="Top-level title.")
description: str | None = Field(default=None, description="Top-level description.")
body: str | None = Field(default=None, description="Top-level main body text.")
body_format: UiHintTextFormat = Field(
default=UiHintTextFormat.PLAIN,
alias="bodyFormat",
description="Body text format.",
)
items: list[UiHintKvItem] = Field(
default_factory=list,
description="Top-level key-value items.",
)
list_items: list[UiHintListItem] = Field(
default_factory=list,
alias="listItems",
description="Top-level list items.",
)
sections: list[UiHintSection] = Field(
default_factory=list,
description="Grouped sections.",
)
actions: list[UiHintAction] = Field(
default_factory=list,
description="Top-level actions.",
)
icon: UiHintIcon | None = Field(
default=None,
description="Top-level icon.",
)
meta: dict[str, Any] = Field(
default_factory=dict,
description="Extra meta, e.g. requestId/toolId/traceId/userId.",
)
+628
View File
@@ -0,0 +1,628 @@
"""
UI Schema Renderer Protocol
目标:
- 只保留“基础组件 + 布局容器”
- 最终返回一个 UiSchemaRenderer
- 前端只需要递归渲染 root 布局树即可
Version: 2.0
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Literal, TypedDict, Union
# ============================================================
# Enums
# ============================================================
class UiStatus(str, Enum):
INFO = "info"
SUCCESS = "success"
WARNING = "warning"
ERROR = "error"
PENDING = "pending"
class IconSource(str, Enum):
ICON = "icon"
EMOJI = "emoji"
URL = "url"
class TextFormat(str, Enum):
PLAIN = "plain"
MARKDOWN = "markdown"
class TextRole(str, Enum):
TITLE = "title"
SUBTITLE = "subtitle"
BODY = "body"
CAPTION = "caption"
CODE = "code"
class ButtonStyle(str, Enum):
PRIMARY = "primary"
SECONDARY = "secondary"
GHOST = "ghost"
DANGER = "danger"
class LayoutDirection(str, Enum):
VERTICAL = "vertical"
HORIZONTAL = "horizontal"
class LayoutAppearance(str, Enum):
PLAIN = "plain"
CARD = "card"
SECTION = "section"
class LayoutAlign(str, Enum):
START = "start"
CENTER = "center"
END = "end"
STRETCH = "stretch"
class LayoutJustify(str, Enum):
START = "start"
CENTER = "center"
END = "end"
SPACE_BETWEEN = "space-between"
class RendererTheme(str, Enum):
DEFAULT = "default"
LIGHT = "light"
DARK = "dark"
# ============================================================
# Meta
# ============================================================
class UiMeta(TypedDict, total=False):
requestId: str
toolId: str
traceId: str
userId: str
# ============================================================
# Action Payloads
# ============================================================
class NavigateAction(TypedDict, total=False):
type: Literal["navigation"]
path: str
params: dict[str, Any]
class UrlAction(TypedDict, total=False):
type: Literal["url"]
url: str
target: Literal["_self", "_blank"]
class EventAction(TypedDict, total=False):
type: Literal["event"]
event: str
payload: dict[str, Any]
class ToolAction(TypedDict, total=False):
type: Literal["tool"]
toolId: str
params: dict[str, Any]
class CopyAction(TypedDict, total=False):
type: Literal["copy"]
content: str
successMessage: str
class PayloadAction(TypedDict, total=False):
type: Literal["payload"]
payload: dict[str, Any]
submitTo: str
UiActionPayload = Union[
NavigateAction,
UrlAction,
EventAction,
ToolAction,
CopyAction,
PayloadAction,
]
# ============================================================
# Shared Small Types
# ============================================================
class UiIconSpec(TypedDict, total=False):
source: str
value: str
color: str
size: int
class UiKvItem(TypedDict, total=False):
key: str
label: str
value: Any
copyable: bool
class UiBaseNode(TypedDict, total=False):
id: str
visible: bool
# ============================================================
# Primitive Components
# ============================================================
class UiTextNode(UiBaseNode, total=False):
type: Literal["text"]
content: str
format: str # TextFormat
role: str # TextRole
status: str # UiStatus
maxLines: int
class UiIconNode(UiBaseNode, total=False):
type: Literal["icon"]
source: str # IconSource
value: str
color: str
size: int
class UiBadgeNode(UiBaseNode, total=False):
type: Literal["badge"]
label: str
status: str # UiStatus
class UiButtonNode(UiBaseNode, total=False):
type: Literal["button"]
label: str
style: str # ButtonStyle
disabled: bool
icon: UiIconSpec
action: UiActionPayload
class UiKvNode(UiBaseNode, total=False):
type: Literal["kv"]
items: list[UiKvItem]
columns: int
class UiDividerNode(UiBaseNode, total=False):
type: Literal["divider"]
inset: int
# ============================================================
# Layout Containers
# ============================================================
class UiStackNode(UiBaseNode, total=False):
type: Literal["stack"]
direction: str # LayoutDirection
gap: int
appearance: str # LayoutAppearance
status: str # UiStatus
align: str # LayoutAlign
justify: str # LayoutJustify
wrap: bool
children: list["UiNode"]
class UiGridNode(UiBaseNode, total=False):
type: Literal["grid"]
columns: int
gap: int
appearance: str # LayoutAppearance
status: str # UiStatus
children: list["UiNode"]
UiNode = Union[
UiTextNode,
UiIconNode,
UiBadgeNode,
UiButtonNode,
UiKvNode,
UiDividerNode,
UiStackNode,
UiGridNode,
]
UiLayoutNode = Union[UiStackNode, UiGridNode]
# ============================================================
# Root Renderer
# ============================================================
class UiSchemaRenderer(TypedDict, total=False):
version: str
locale: str
status: str # UiStatus
theme: str # RendererTheme
meta: UiMeta
root: UiLayoutNode
# ============================================================
# Root Builder
# ============================================================
def build_renderer(
root: UiLayoutNode,
*,
version: str = "2.0",
locale: str = "zh-CN",
status: UiStatus = UiStatus.INFO,
theme: RendererTheme = RendererTheme.DEFAULT,
meta: UiMeta | None = None,
) -> UiSchemaRenderer:
renderer: UiSchemaRenderer = {
"version": version,
"locale": locale,
"status": status.value,
"theme": theme.value,
"root": root,
}
if meta:
renderer["meta"] = meta
return renderer
# ============================================================
# Primitive Builders
# ============================================================
def build_text(
content: str,
*,
node_id: str | None = None,
format: TextFormat = TextFormat.PLAIN,
role: TextRole = TextRole.BODY,
status: UiStatus | None = None,
max_lines: int | None = None,
visible: bool = True,
) -> UiTextNode:
node: UiTextNode = {
"type": "text",
"content": content,
"format": format.value,
"role": role.value,
"visible": visible,
}
if node_id:
node["id"] = node_id
if status:
node["status"] = status.value
if max_lines is not None:
node["maxLines"] = max_lines
return node
def build_icon(
source: IconSource,
value: str,
*,
node_id: str | None = None,
color: str | None = None,
size: int | None = None,
visible: bool = True,
) -> UiIconNode:
node: UiIconNode = {
"type": "icon",
"source": source.value,
"value": value,
"visible": visible,
}
if node_id:
node["id"] = node_id
if color:
node["color"] = color
if size is not None:
node["size"] = size
return node
def build_badge(
label: str,
*,
node_id: str | None = None,
status: UiStatus = UiStatus.INFO,
visible: bool = True,
) -> UiBadgeNode:
node: UiBadgeNode = {
"type": "badge",
"label": label,
"status": status.value,
"visible": visible,
}
if node_id:
node["id"] = node_id
return node
def build_button(
label: str,
action: UiActionPayload,
*,
node_id: str | None = None,
style: ButtonStyle = ButtonStyle.PRIMARY,
disabled: bool = False,
icon: UiIconSpec | None = None,
visible: bool = True,
) -> UiButtonNode:
node: UiButtonNode = {
"type": "button",
"label": label,
"style": style.value,
"disabled": disabled,
"action": action,
"visible": visible,
}
if node_id:
node["id"] = node_id
if icon:
node["icon"] = icon
return node
def build_kv(
items: list[UiKvItem],
*,
node_id: str | None = None,
columns: int = 1,
visible: bool = True,
) -> UiKvNode:
node: UiKvNode = {
"type": "kv",
"items": items,
"columns": columns,
"visible": visible,
}
if node_id:
node["id"] = node_id
return node
def build_divider(
*,
node_id: str | None = None,
inset: int = 0,
visible: bool = True,
) -> UiDividerNode:
node: UiDividerNode = {
"type": "divider",
"inset": inset,
"visible": visible,
}
if node_id:
node["id"] = node_id
return node
# ============================================================
# Layout Builders
# ============================================================
def build_stack(
children: list[UiNode],
*,
node_id: str | None = None,
direction: LayoutDirection = LayoutDirection.VERTICAL,
gap: int = 12,
appearance: LayoutAppearance = LayoutAppearance.PLAIN,
status: UiStatus | None = None,
align: LayoutAlign = LayoutAlign.START,
justify: LayoutJustify = LayoutJustify.START,
wrap: bool = False,
visible: bool = True,
) -> UiStackNode:
node: UiStackNode = {
"type": "stack",
"direction": direction.value,
"gap": gap,
"appearance": appearance.value,
"align": align.value,
"justify": justify.value,
"wrap": wrap,
"children": children,
"visible": visible,
}
if node_id:
node["id"] = node_id
if status:
node["status"] = status.value
return node
def build_grid(
children: list[UiNode],
*,
columns: int,
node_id: str | None = None,
gap: int = 12,
appearance: LayoutAppearance = LayoutAppearance.PLAIN,
status: UiStatus | None = None,
visible: bool = True,
) -> UiGridNode:
node: UiGridNode = {
"type": "grid",
"columns": columns,
"gap": gap,
"appearance": appearance.value,
"children": children,
"visible": visible,
}
if node_id:
node["id"] = node_id
if status:
node["status"] = status.value
return node
# ============================================================
# Small Action Builders
# ============================================================
def action_navigation(
path: str, params: dict[str, Any] | None = None
) -> NavigateAction:
action: NavigateAction = {"type": "navigation", "path": path}
if params:
action["params"] = params
return action
def action_url(url: str, target: Literal["_self", "_blank"] = "_blank") -> UrlAction:
return {"type": "url", "url": url, "target": target}
def action_event(event: str, payload: dict[str, Any] | None = None) -> EventAction:
action: EventAction = {"type": "event", "event": event}
if payload:
action["payload"] = payload
return action
def action_tool(tool_id: str, params: dict[str, Any] | None = None) -> ToolAction:
action: ToolAction = {"type": "tool", "toolId": tool_id}
if params:
action["params"] = params
return action
def action_copy(content: str, success_message: str | None = None) -> CopyAction:
action: CopyAction = {"type": "copy", "content": content}
if success_message:
action["successMessage"] = success_message
return action
def action_payload(
payload: dict[str, Any], submit_to: str | None = None
) -> PayloadAction:
action: PayloadAction = {"type": "payload", "payload": payload}
if submit_to:
action["submitTo"] = submit_to
return action
# ============================================================
# Derived Helpers (协议外的便捷封装,不是基础原语)
# ============================================================
def build_card(
children: list[UiNode],
*,
node_id: str | None = None,
gap: int = 12,
status: UiStatus | None = None,
) -> UiStackNode:
return build_stack(
children,
node_id=node_id,
direction=LayoutDirection.VERTICAL,
gap=gap,
appearance=LayoutAppearance.CARD,
status=status,
)
def build_section(
title: str,
children: list[UiNode],
*,
node_id: str | None = None,
description: str | None = None,
status: UiStatus | None = None,
gap: int = 12,
) -> UiStackNode:
header_nodes: list[UiNode] = [build_text(title, role=TextRole.TITLE)]
if description:
header_nodes.append(build_text(description, role=TextRole.CAPTION))
all_children = header_nodes + children
return build_stack(
all_children,
node_id=node_id,
direction=LayoutDirection.VERTICAL,
gap=gap,
appearance=LayoutAppearance.SECTION,
status=status,
)
def build_status_panel(
title: str,
message: str,
*,
status: UiStatus,
primary_button: UiButtonNode | None = None,
secondary_button: UiButtonNode | None = None,
node_id: str | None = None,
) -> UiStackNode:
status_label = f"ui.status.{status.value}"
children: list[UiNode] = [
build_stack(
[
build_text(title, role=TextRole.TITLE),
build_badge(label=status_label, status=status),
],
direction=LayoutDirection.HORIZONTAL,
gap=8,
align=LayoutAlign.CENTER,
justify=LayoutJustify.SPACE_BETWEEN,
),
build_text(message, role=TextRole.BODY, status=status),
]
actions: list[UiNode] = []
if primary_button:
actions.append(primary_button)
if secondary_button:
actions.append(secondary_button)
if actions:
children.append(
build_stack(
actions,
direction=LayoutDirection.HORIZONTAL,
gap=8,
)
)
return build_card(children, node_id=node_id, status=status)
+50
View File
@@ -0,0 +1,50 @@
from __future__ import annotations
from enum import IntEnum
from pydantic import BaseModel, ConfigDict, Field, field_validator
class SystemVisibilityBit(IntEnum):
UI_HISTORY = 0
CONTEXT_ASSEMBLY = 1
class VisibilityMask(BaseModel):
model_config = ConfigDict(extra="forbid")
value: int = Field(..., ge=0, le=(1 << 63) - 1)
@classmethod
def from_bits(cls, *, bits: list[int]) -> "VisibilityMask":
mask = 0
for bit in bits:
validate_visibility_bit(bit=bit)
mask |= 1 << bit
return cls(value=mask)
def contains(self, *, bit: int) -> bool:
validate_visibility_bit(bit=bit)
return bool(self.value & (1 << bit))
class VisibilityBitRef(BaseModel):
model_config = ConfigDict(extra="forbid")
bit: int = Field(..., ge=0, le=63)
@field_validator("bit")
@classmethod
def _validate_bit(cls, value: int) -> int:
validate_visibility_bit(bit=value)
return value
def validate_visibility_bit(*, bit: int) -> None:
if bit < 0 or bit > 63:
raise ValueError("visibility bit must be in range [0, 63]")
def bit_mask(*, bit: int) -> int:
validate_visibility_bit(bit=bit)
return 1 << bit
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+28
View File
@@ -0,0 +1,28 @@
from __future__ import annotations
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
close_registered_services,
initialize_registered_services,
register_service,
register_service_instance,
resolve_registered_services,
)
from services.base.supabase import SupabaseService, supabase_service
__all__ = [
"BaseServiceProvider",
"RedisService",
"ServiceRegistry",
"SupabaseService",
"close_registered_services",
"get_or_init_redis_client",
"initialize_registered_services",
"redis_service",
"register_service",
"register_service_instance",
"resolve_registered_services",
"supabase_service",
]
+136
View File
@@ -0,0 +1,136 @@
from __future__ import annotations
import asyncio
import inspect
from typing import Any, Dict, Optional
import redis.asyncio as redis
from core.config.settings import RedisSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class RedisService(BaseServiceProvider):
def __init__(self, settings: RedisSettings | None = None) -> None:
super().__init__("redis")
self._settings = settings or config.redis
self._client: Optional[redis.Redis] = None
self._loop_id: int | None = None
def _build_client(self) -> redis.Redis:
return redis.from_url(
self._settings.url,
decode_responses=True,
socket_connect_timeout=self._settings.socket_connect_timeout,
socket_timeout=self._settings.socket_timeout,
max_connections=self._settings.max_connections,
)
def _require_client(self) -> redis.Redis:
client = self._client
if client is None:
raise RuntimeError("Redis client is not initialized")
return client
async def initialize(self, **_: Any) -> bool:
try:
client = self._build_client()
ping_result = client.ping()
if inspect.isawaitable(ping_result):
await ping_result
self._client = client
self._loop_id = _current_loop_id()
self._set_initialized(True)
self.logger.info("Redis service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Redis service initialization failed", error=str(exc))
self._client = None
self._loop_id = None
self._set_initialized(False)
return False
async def close(self) -> bool:
client = self._client
if client is None:
self._loop_id = None
return True
try:
await client.aclose()
self.logger.info("Redis service closed")
return True
except Exception as exc: # noqa: BLE001
self.logger.exception("Redis service close failed", error=str(exc))
return False
finally:
self._client = None
self._loop_id = None
self._set_initialized(False)
async def health_check(self) -> Dict[str, Any]:
client = self._client
if client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
ping_result = client.ping()
ping = (
await ping_result if inspect.isawaitable(ping_result) else ping_result
)
info_result = client.info()
info = (
await info_result if inspect.isawaitable(info_result) else info_result
)
return {
"status": "healthy" if ping else "unhealthy",
"details": {
"ping": ping,
"redis_version": info.get("redis_version"),
"connected_clients": info.get("connected_clients"),
"used_memory": info.get("used_memory_human"),
"uptime_in_seconds": info.get("uptime_in_seconds"),
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Redis health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> redis.Redis:
return self._require_client()
def _current_loop_id() -> int | None:
try:
return id(asyncio.get_running_loop())
except RuntimeError:
return None
async def get_or_init_redis_client() -> redis.Redis:
current_loop_id = _current_loop_id()
bound_loop_id = redis_service._loop_id
if (
redis_service.is_initialized
and bound_loop_id is not None
and current_loop_id is not None
and bound_loop_id != current_loop_id
):
redis_service.logger.warning(
"Redis client bound to different event loop; reinitializing",
previous_loop_id=bound_loop_id,
current_loop_id=current_loop_id,
)
redis_service._client = None
redis_service._loop_id = None
redis_service._set_initialized(False)
if not redis_service.is_initialized:
initialized = await redis_service.initialize()
if not initialized:
raise RuntimeError("Redis service initialization failed")
return redis_service.get_client()
redis_service: RedisService = register_service_instance("redis", RedisService())
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
@@ -0,0 +1,158 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, TypeVar
from core.logging import get_logger
class BaseServiceProvider(ABC):
def __init__(self, service_name: str) -> None:
self.service_name = service_name
self._initialized = False
self.logger = get_logger("services.base").bind(service=service_name)
@abstractmethod
async def initialize(self, **kwargs: Any) -> bool:
raise NotImplementedError
@abstractmethod
async def close(self) -> bool:
raise NotImplementedError
@abstractmethod
async def health_check(self) -> Dict[str, Any]:
raise NotImplementedError
@property
def is_initialized(self) -> bool:
return self._initialized
def _set_initialized(self, value: bool) -> None:
self._initialized = value
def get_service_info(self) -> Dict[str, Any]:
return {
"name": self.service_name,
"initialized": self._initialized,
"type": self.__class__.__name__,
}
class ServiceRegistry:
_services: Dict[str, Callable[..., BaseServiceProvider]] = {}
@classmethod
def register(
cls, service_name: str, factory: Callable[..., BaseServiceProvider]
) -> None:
cls._services = {**cls._services, service_name: factory}
@classmethod
def get_service_factory(
cls, service_name: str
) -> Optional[Callable[..., BaseServiceProvider]]:
return cls._services.get(service_name)
@classmethod
def list_services(cls) -> list[str]:
return sorted(cls._services.keys())
@classmethod
def create_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
return cls.get_service(service_name, **kwargs)
@classmethod
def get_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
factory = cls.get_service_factory(service_name)
if not factory:
return None
return factory(**kwargs)
def register_service(service_name: str) -> Callable[[type], type]:
def decorator(service_class: type) -> type:
ServiceRegistry.register(service_name, service_class)
return service_class
return decorator
TService = TypeVar("TService", bound=BaseServiceProvider)
def register_service_instance(service_name: str, service: TService) -> TService:
ServiceRegistry.register(service_name, lambda: service)
return service
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
services: list[BaseServiceProvider] = []
for service_name in service_names:
service = ServiceRegistry.get_service(service_name)
if service is None:
raise RuntimeError(f"Service is not registered: {service_name}")
services.append(service)
return services
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
lifecycle_logger = get_logger("services.base.lifecycle")
all_closed = True
for service in reversed(services):
try:
closed = await service.close()
except Exception as exc: # noqa: BLE001
lifecycle_logger.warning(
"Failed to close service",
service=service.service_name,
error=str(exc),
)
all_closed = False
continue
if not closed:
lifecycle_logger.warning(
"Service close returned false",
service=service.service_name,
)
all_closed = False
return all_closed
async def initialize_registered_services(
service_names: list[str],
) -> tuple[bool, list[BaseServiceProvider]]:
lifecycle_logger = get_logger("services.base.lifecycle")
initialized_services: list[BaseServiceProvider] = []
try:
services = resolve_registered_services(service_names)
except RuntimeError as exc:
lifecycle_logger.error("Failed to resolve registered services", error=str(exc))
return False, []
for service in services:
try:
initialized = await service.initialize()
except Exception as exc: # noqa: BLE001
lifecycle_logger.warning(
"Service initialization raised exception",
service=service.service_name,
error=str(exc),
)
initialized = False
if not initialized:
lifecycle_logger.error(
"Service initialization failed, rolling back",
service=service.service_name,
)
await close_registered_services(initialized_services)
return False, []
initialized_services.append(service)
return True, initialized_services
+304
View File
@@ -0,0 +1,304 @@
from __future__ import annotations
import asyncio
from typing import Any
from supabase import create_client
from storage3.exceptions import StorageApiError
from core.config.settings import SupabaseSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class SupabaseService(BaseServiceProvider):
def __init__(self, settings: SupabaseSettings | None = None) -> None:
super().__init__("supabase")
self._settings = settings or config.supabase
self._client: Any = None
self._admin_client: Any = None
async def initialize(self, **_: Any) -> bool:
try:
self._init_clients()
await self._ensure_storage_bucket()
self._set_initialized(True)
self.logger.info("Supabase service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning(
"Supabase service initialization failed", error=str(exc)
)
self._client = None
self._admin_client = None
self._set_initialized(False)
return False
async def close(self) -> bool:
self._client = None
self._admin_client = None
self._set_initialized(False)
self.logger.info("Supabase service closed")
return True
async def health_check(self) -> dict[str, Any]:
client = self._client
admin_client = self._admin_client
if client is None or admin_client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
await asyncio.to_thread(client.auth.get_session)
await asyncio.to_thread(
admin_client.auth.admin.list_users, page=1, per_page=1
)
return {
"status": "healthy",
"details": {
"anon_client": "ready",
"admin_client": "ready",
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> Any:
return self._require_client()
def get_admin_client(self) -> Any:
return self._require_admin_client()
def _require_client(self) -> Any:
if self._client is None or self._admin_client is None:
self._init_clients()
self._set_initialized(True)
self.logger.info("Supabase service lazily initialized")
client = self._client
if client is None:
raise RuntimeError("Supabase client is not initialized")
return client
def _require_admin_client(self) -> Any:
if self._client is None or self._admin_client is None:
self._init_clients()
self._set_initialized(True)
self.logger.info("Supabase service lazily initialized")
admin_client = self._admin_client
if admin_client is None:
raise RuntimeError("Supabase admin client is not initialized")
return admin_client
def _init_clients(self) -> None:
self._client = create_client(
self._settings.url,
self._settings.anon_key,
)
self._admin_client = create_client(
self._settings.url,
self._settings.service_role_key,
)
async def _ensure_storage_bucket(self) -> None:
storage = getattr(self._admin_client, "storage", None)
if storage is None:
self.logger.warning("Storage client unavailable, skipping bucket check")
return
get_bucket = getattr(storage, "get_bucket", None)
if not callable(get_bucket):
self.logger.warning("Storage get_bucket unavailable, skipping bucket check")
return
buckets = [
(config.storage.attachment.bucket, False),
(config.storage.avatar.bucket, True),
]
def _check_and_create() -> None:
for bucket_name, is_public in buckets:
try:
get_bucket(bucket_name)
self.logger.debug(
"Storage bucket already exists", bucket=bucket_name
)
except Exception: # noqa: BLE001
create_bucket = getattr(storage, "create_bucket", None)
if not callable(create_bucket):
self.logger.warning(
"Storage create_bucket unavailable, skipping bucket creation"
)
return
try:
create_bucket(bucket_name, options={"public": is_public})
self.logger.info(
"Storage bucket created",
bucket=bucket_name,
public=is_public,
)
except Exception as exc: # noqa: BLE001
msg = str(exc).lower()
if "already exists" in msg or "duplicate" in msg:
self.logger.debug(
"Storage bucket already exists (race)",
bucket=bucket_name,
)
continue
self.logger.warning(
"Failed to create storage bucket",
bucket=bucket_name,
error=str(exc),
)
await asyncio.to_thread(_check_and_create)
def _get_storage(self) -> Any:
"""Get the storage client from admin client."""
client = self.get_admin_client()
storage = getattr(client, "storage", None)
if storage is None:
raise RuntimeError("Supabase storage client unavailable")
return storage
def _get_bucket_client(self, bucket: str) -> Any:
"""Get a bucket client for the specified bucket."""
storage = self._get_storage()
from_bucket = getattr(storage, "from_", None)
if not callable(from_bucket):
raise RuntimeError("Supabase storage bucket accessor unavailable")
return from_bucket(bucket)
def _validate_bucket(self, bucket: str) -> None:
"""Validate that the bucket matches one of configured storage buckets."""
allowed_buckets = {
config.storage.attachment.bucket,
config.storage.avatar.bucket,
}
if bucket not in allowed_buckets:
raise RuntimeError("Invalid storage bucket")
def _ensure_bucket_client(self, bucket: str) -> Any:
"""Validate bucket and return authenticated bucket client."""
self._validate_bucket(bucket)
return self._get_bucket_client(bucket)
def _is_bucket_not_found_error(self, exc: Exception) -> bool:
"""Check if the exception indicates a bucket was not found."""
if isinstance(exc, StorageApiError):
message = str(exc).lower()
return "bucket" in message and "not found" in message
message = str(exc).lower()
return "bucket" in message and "not found" in message
async def upload_bytes(
self,
*,
bucket: str,
path: str,
content: bytes,
content_type: str,
) -> str:
def _upload() -> object:
bucket_client = self._ensure_bucket_client(bucket)
upload = getattr(bucket_client, "upload", None)
if not callable(upload):
raise RuntimeError("Supabase storage upload is unavailable")
return upload(
path,
content,
{
"content-type": content_type,
"upsert": "true",
},
)
try:
await asyncio.to_thread(_upload)
except Exception as exc: # noqa: BLE001
if not self._is_bucket_not_found_error(exc):
raise
await self._ensure_bucket_exists(bucket=bucket)
await asyncio.to_thread(_upload)
return path
async def _ensure_bucket_exists(self, *, bucket: str) -> None:
def _ensure() -> None:
storage = self._get_storage()
get_bucket = getattr(storage, "get_bucket", None)
if not callable(get_bucket):
raise RuntimeError("Supabase storage get_bucket is unavailable")
try:
get_bucket(bucket)
except Exception as exc: # noqa: BLE001
msg = str(exc).lower()
if "bucket" in msg and "not found" in msg:
raise RuntimeError(f"Storage bucket '{bucket}' does not exist")
raise
await asyncio.to_thread(_ensure)
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
def _download() -> object:
bucket_client = self._ensure_bucket_client(bucket)
download = getattr(bucket_client, "download", None)
if not callable(download):
raise RuntimeError("Supabase storage download is unavailable")
return download(path)
raw = await asyncio.to_thread(_download)
if isinstance(raw, bytes):
return raw
if isinstance(raw, bytearray):
return bytes(raw)
if isinstance(raw, memoryview):
return raw.tobytes()
raise RuntimeError("Invalid attachment payload")
async def create_signed_url(
self,
*,
bucket: str,
path: str,
expires_in_seconds: int,
) -> str:
def _create_signed_url() -> object:
bucket_client = self._ensure_bucket_client(bucket)
signer = getattr(bucket_client, "create_signed_url", None)
if not callable(signer):
raise RuntimeError("Supabase storage signed url is unavailable")
return signer(path, expires_in_seconds)
raw = await asyncio.to_thread(_create_signed_url)
if isinstance(raw, str):
return raw
if isinstance(raw, dict):
signed_url = raw.get("signedURL") or raw.get("signedUrl") or raw.get("url")
if isinstance(signed_url, str) and signed_url:
return signed_url
raise RuntimeError("Invalid signed url payload")
def parse_signed_url(self, url: str) -> tuple[str, str]:
from urllib.parse import urlparse
parsed = urlparse(url)
path_parts = parsed.path.strip("/").split("/")
if (
len(path_parts) < 4
or path_parts[0] != "storage"
or path_parts[1] != "v1"
or path_parts[2] != "object"
or path_parts[3] != "sign"
):
raise RuntimeError("Invalid signed URL format")
bucket = path_parts[4]
path = "/".join(path_parts[5:])
return bucket, path
supabase_service: SupabaseService = register_service_instance(
"supabase", SupabaseService()
)
__all__ = ["SupabaseService", "supabase_service"]
+4
View File
@@ -0,0 +1,4 @@
from .factory import get_cache_store
from .interfaces import CacheStore
__all__ = ["CacheStore", "get_cache_store"]
+13
View File
@@ -0,0 +1,13 @@
from __future__ import annotations
from .interfaces import CacheStore
from .redis_store import RedisCacheStore
_cache_store: CacheStore | None = None
def get_cache_store() -> CacheStore:
global _cache_store
if _cache_store is None:
_cache_store = RedisCacheStore()
return _cache_store
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Protocol
class CacheStore(Protocol):
async def hgetall(self, key: str, /) -> dict[str, str]: ...
async def hset(self, key: str, /, mapping: dict[str, str]) -> int: ...
async def hincrby(self, key: str, field: str, amount: int = 1, /) -> int: ...
async def expire(self, key: str, ttl_seconds: int, /) -> int: ...
async def delete(self, *keys: str) -> int: ...
async def sadd(self, key: str, *members: str) -> int: ...
async def smembers(self, key: str, /) -> set[str]: ...
@@ -0,0 +1,80 @@
from __future__ import annotations
import inspect
from typing import Any
from services.base.redis import get_or_init_redis_client
from .interfaces import CacheStore
def _to_text(value: Any) -> str | None:
if isinstance(value, str):
return value
if isinstance(value, bytes):
try:
return value.decode("utf-8")
except UnicodeDecodeError:
return None
return None
async def _maybe_await(value: Any) -> Any:
if inspect.isawaitable(value):
return await value
return value
class RedisCacheStore(CacheStore):
async def hgetall(self, key: str) -> dict[str, str]:
client = await get_or_init_redis_client()
raw = await _maybe_await(client.hgetall(key))
if not isinstance(raw, dict):
return {}
decoded: dict[str, str] = {}
for raw_key, raw_value in raw.items():
key_text = _to_text(raw_key)
value_text = _to_text(raw_value)
if key_text is None or value_text is None:
continue
decoded[key_text] = value_text
return decoded
async def hset(self, key: str, mapping: dict[str, str]) -> int:
client = await get_or_init_redis_client()
result = await _maybe_await(client.hset(key, mapping=mapping))
return int(result)
async def hincrby(self, key: str, field: str, amount: int = 1) -> int:
client = await get_or_init_redis_client()
result = await _maybe_await(client.hincrby(key, field, amount))
return int(result)
async def expire(self, key: str, ttl_seconds: int) -> int:
client = await get_or_init_redis_client()
result = await _maybe_await(client.expire(key, ttl_seconds))
return int(result)
async def delete(self, *keys: str) -> int:
if not keys:
return 0
client = await get_or_init_redis_client()
result = await _maybe_await(client.delete(*keys))
return int(result)
async def sadd(self, key: str, *members: str) -> int:
if not members:
return 0
client = await get_or_init_redis_client()
result = await _maybe_await(client.sadd(key, *members))
return int(result)
async def smembers(self, key: str) -> set[str]:
client = await get_or_init_redis_client()
raw = await _maybe_await(client.smembers(key))
if isinstance(raw, set):
return {value for item in raw if (value := _to_text(item))}
if isinstance(raw, list | tuple):
return {value for item in raw if (value := _to_text(item))}
return set()
@@ -0,0 +1,5 @@
from __future__ import annotations
from services.llm_pricing.service import LlmPricingService
__all__ = ["LlmPricingService"]
+183
View File
@@ -0,0 +1,183 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from core.config.initial.init_data import load_llm_catalog
@dataclass(frozen=True)
class PricingTier:
max_prompt_tokens: int
input_cost_per_token: float
output_cost_per_token: float
cache_hit_cost_per_token: float
class LlmPricingService:
_pricing_by_model: dict[str, tuple[PricingTier, ...]]
def __init__(self) -> None:
self._pricing_by_model = self._build_pricing_map()
@staticmethod
def _build_pricing_map() -> dict[str, tuple[PricingTier, ...]]:
catalog = load_llm_catalog()
pricing_by_model: dict[str, tuple[PricingTier, ...]] = {}
for model in catalog.get("llms", []):
if not isinstance(model, dict):
continue
model_code = str(model.get("model_code", "")).strip().lower()
raw_tiers = model.get("pricing_tiers")
if not isinstance(raw_tiers, list) or not raw_tiers:
continue
tiers = [
PricingTier(
max_prompt_tokens=int(item.get("max_prompt_tokens", 0) or 0),
input_cost_per_token=float(
item.get("input_cost_per_token", 0.0) or 0.0
),
output_cost_per_token=float(
item.get("output_cost_per_token", 0.0) or 0.0
),
cache_hit_cost_per_token=float(
item.get("cache_hit_cost_per_token", 0.0) or 0.0
),
)
for item in raw_tiers
if isinstance(item, dict)
]
if not tiers:
continue
ordered_tiers = tuple(
sorted(tiers, key=lambda item: item.max_prompt_tokens)
)
if model_code:
pricing_by_model[model_code] = ordered_tiers
return pricing_by_model
def calculate_cost(
self,
*,
model: str,
prompt_tokens: int,
completion_tokens: int,
cached_prompt_tokens: int = 0,
) -> float:
tiers = self._pricing_by_model.get(model.strip().lower())
if tiers is None:
raise ValueError(f"unknown model pricing: {model}")
normalized_prompt_tokens = max(int(prompt_tokens), 0)
normalized_completion_tokens = max(int(completion_tokens), 0)
normalized_cached_tokens = min(
max(int(cached_prompt_tokens), 0), normalized_prompt_tokens
)
uncached_prompt_tokens = normalized_prompt_tokens - normalized_cached_tokens
selected_tier = tiers[-1]
for tier in tiers:
if normalized_prompt_tokens <= tier.max_prompt_tokens:
selected_tier = tier
break
cached_token_rate = (
selected_tier.cache_hit_cost_per_token
if selected_tier.cache_hit_cost_per_token > 0
else selected_tier.input_cost_per_token
)
return float(
uncached_prompt_tokens * selected_tier.input_cost_per_token
+ normalized_cached_tokens * cached_token_rate
+ normalized_completion_tokens * selected_tier.output_cost_per_token
)
def build_usage_metadata(
self,
*,
model: str,
usage_summary: dict[str, Any] | None,
) -> dict[str, Any]:
summary = usage_summary or {}
input_tokens = max(int(summary.get("input_tokens", 0) or 0), 0)
output_tokens = max(int(summary.get("output_tokens", 0) or 0), 0)
total_tokens = max(
int(summary.get("total_tokens", input_tokens + output_tokens) or 0), 0
)
latency_ms = max(int(summary.get("latency_ms", 0) or 0), 0)
cached_prompt_tokens = max(int(summary.get("cached_prompt_tokens", 0) or 0), 0)
prompt_cache_hit_tokens = max(
int(summary.get("prompt_cache_hit_tokens", cached_prompt_tokens) or 0), 0
)
prompt_cache_miss_tokens = max(
int(
summary.get(
"prompt_cache_miss_tokens",
max(input_tokens - prompt_cache_hit_tokens, 0),
)
or 0
),
0,
)
reasoning_tokens = max(int(summary.get("reasoning_tokens", 0) or 0), 0)
direct_cost_raw = summary.get("direct_cost")
direct_cost_observed = bool(int(summary.get("direct_cost_observed", 0) or 0))
direct_cost_complete = bool(int(summary.get("direct_cost_complete", 0) or 0))
model_call_records = max(int(summary.get("model_call_records", 0) or 0), 0)
usage_records = max(int(summary.get("usage_records", 0) or 0), 0)
usage_complete = model_call_records == 0 or model_call_records == usage_records
direct_cost = self._coerce_non_negative_float(direct_cost_raw)
if (
usage_complete
and direct_cost_observed
and direct_cost_complete
and direct_cost is not None
):
cost = direct_cost
cost_source = "provider"
else:
cost = self.calculate_cost(
model=model,
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
cached_prompt_tokens=cached_prompt_tokens,
)
cost_source = (
"incomplete_usage_fallback"
if not usage_complete
else (
"catalog_fallback_incomplete_provider_cost"
if direct_cost_observed and not direct_cost_complete
else "catalog_fallback"
)
)
return {
"model": model,
"inputTokens": input_tokens,
"outputTokens": output_tokens,
"totalTokens": total_tokens,
"cachedPromptTokens": cached_prompt_tokens,
"promptCacheHitTokens": prompt_cache_hit_tokens,
"promptCacheMissTokens": prompt_cache_miss_tokens,
"reasoningTokens": reasoning_tokens,
"cost": cost,
"costSource": cost_source,
"usageComplete": usage_complete,
"latencyMs": latency_ms,
}
@staticmethod
def _coerce_non_negative_float(value: Any) -> float | None:
if value is None:
return None
try:
parsed = float(value)
except (TypeError, ValueError):
return None
if parsed < 0:
return None
return parsed
View File
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,35 @@
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
import re
from typing import Any
import yaml
from schemas.domain.automation import AutomationJobConfig
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
def _automation_yaml_path(config_name: str) -> Path:
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
raise ValueError("invalid automation config name")
return (
Path(__file__).resolve().parents[2]
/ "core"
/ "config"
/ "static"
/ "automation"
/ f"{config_name}.yaml"
)
@lru_cache(maxsize=16)
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
path = _automation_yaml_path(config_name)
with path.open("r", encoding="utf-8") as file:
loaded: Any = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"invalid automation config format: {path}")
return AutomationJobConfig.model_validate(loaded)
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.db import get_db
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.registration_bootstrap import (
RegistrationAutomationBootstrapService,
RegistrationBootstrapRepository,
)
from v1.auth.service import AuthService
def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
bootstrapper = RegistrationAutomationBootstrapService(
repository=RegistrationBootstrapRepository(session=session),
session=session,
)
return AuthService(
gateway=SupabaseAuthGateway(),
registration_bootstrapper=bootstrapper,
)
+430
View File
@@ -0,0 +1,430 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, cast
from pydantic import ValidationError
from supabase import AuthError
from core.http.errors import ApiProblemError
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByIdResponse,
UserByPhoneResponse,
)
from v1.auth.service import AuthServiceGateway
logger = get_logger("v1.auth.gateway")
AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable"
def _auth_error(
*,
status_code: int,
code: str,
detail: str,
) -> ApiProblemError:
return ApiProblemError(status_code=status_code, code=code, detail=detail)
class SupabaseAuthGateway(AuthServiceGateway):
def __init__(self) -> None:
self._user_lookup_cache_ttl_seconds: int = 60
self._user_lookup_cache_expires_at: float = 0.0
self._users_by_phone: dict[str, Any] = {}
self._users_by_id: dict[str, Any] = {}
def _get_client(self) -> Any:
return supabase_service.get_client()
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def send_otp(self, request: OtpSendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {
"phone": request.phone,
"options": {"should_create_user": True},
}
try:
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
await asyncio.to_thread(sign_in_with_otp, payload)
except AuthError as exc:
logger.warning("Send otp failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=429,
code="AUTH_TOO_MANY_REQUESTS",
detail="Too many requests",
) from exc
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {
"type": "sms",
"phone": request.phone,
"token": request.token,
}
try:
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(
response,
"Invalid verification code",
"AUTH_VERIFICATION_CODE_INVALID",
)
except AuthError as exc:
logger.warning("Create phone session failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_VERIFICATION_CODE_INVALID",
detail="Invalid verification code",
) from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
response = await asyncio.to_thread(
client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(
response,
"Invalid refresh token",
"AUTH_REFRESH_TOKEN_INVALID",
)
except AuthError as exc:
logger.warning("Refresh failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
) from exc
async def delete_session(self, refresh_token: str | None) -> None:
if not refresh_token:
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_MISSING",
detail="Missing refresh token",
)
client = self._get_client()
try:
response = await asyncio.to_thread(
client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
)
await asyncio.to_thread(
client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
) from exc
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
normalized_phone = _normalize_phone(phone)
if not normalized_phone:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
await self._refresh_user_lookup_cache_if_needed()
user = self._users_by_phone.get(normalized_phone)
if user is None:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
user_phone = _normalize_phone(getattr(user, "phone", ""))
if not user_phone:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
return UserByPhoneResponse(
id=str(getattr(user, "id", "")),
phone=user_phone,
created_at=str(getattr(user, "created_at", "")),
phone_confirmed_at=(
str(getattr(user, "phone_confirmed_at", ""))
if getattr(user, "phone_confirmed_at", None)
else None
),
)
async def get_user_by_id(self, user_id: str) -> UserByIdResponse:
users = await self.get_users_by_ids([user_id])
resolved = users.get(user_id)
if resolved is None:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
return resolved
async def get_users_by_ids(
self, user_ids: list[str]
) -> dict[str, UserByIdResponse]:
await self._refresh_user_lookup_cache_if_needed()
resolved: dict[str, UserByIdResponse] = {}
for raw_user_id in user_ids:
normalized_user_id = raw_user_id.strip()
if not normalized_user_id:
continue
user = self._users_by_id.get(normalized_user_id)
if user is None:
continue
user_attrs = getattr(user, "user", user)
resolved[normalized_user_id] = UserByIdResponse(
id=str(getattr(user_attrs, "id", "")),
phone=getattr(user_attrs, "phone", None),
created_at=str(getattr(user_attrs, "created_at", "")),
phone_confirmed_at=(
str(getattr(user_attrs, "phone_confirmed_at", ""))
if getattr(user_attrs, "phone_confirmed_at", None)
else None
),
)
return resolved
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
normalized_query = _normalize_phone_search_query(query)
if not normalized_query:
return []
await self._refresh_user_lookup_cache_if_needed()
if normalized_query.startswith("+"):
matched_user = self._users_by_phone.get(normalized_query)
if matched_user is None:
return []
user_id = str(getattr(matched_user, "id", ""))
return [user_id] if user_id else []
digits = _digits_only(normalized_query)
if not digits:
return []
matched_records: list[tuple[str, str]] = []
for cached_phone, candidate in self._users_by_phone.items():
candidate_digits = _digits_only(cached_phone)
if not candidate_digits.endswith(digits):
continue
user_id = str(getattr(candidate, "id", ""))
if user_id:
matched_records.append((cached_phone, user_id))
if not matched_records:
return []
unique_ids: list[str] = []
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
if user_id in unique_ids:
continue
unique_ids.append(user_id)
if len(unique_ids) >= max(1, limit):
break
return unique_ids
async def _refresh_user_lookup_cache_if_needed(self) -> None:
now = time.monotonic()
if now < self._user_lookup_cache_expires_at:
return
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_phone: dict[str, Any] = {}
users_by_id: dict[str, Any] = {}
for candidate in users:
candidate_id = str(getattr(candidate, "id", "")).strip()
if candidate_id:
users_by_id[candidate_id] = candidate
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
if candidate_phone:
users_by_phone[candidate_phone] = candidate
self._users_by_id = users_by_id
self._users_by_phone = users_by_phone
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
raw_status = getattr(exc, "status", None)
if raw_status is None:
raw_status = getattr(exc, "status_code", None)
if isinstance(raw_status, int) and 500 <= raw_status < 600:
return True
raw_code = getattr(exc, "code", None)
code = str(raw_code).lower() if raw_code is not None else ""
message = str(exc).lower()
indicators = (
"request_timeout",
"timed out",
"timeout",
"gateway timeout",
"bad_gateway",
"service_unavailable",
"internal_server_error",
"unexpected_failure",
"upstream",
"500",
"502",
"503",
"504",
"5xx",
)
return any(token in code or token in message for token in indicators)
def _map_auth_response(
response: object, failure_message: str, failure_code: str
) -> SessionResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
)
phone = _normalize_phone(getattr(user, "phone", None))
if not phone:
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
)
try:
auth_user = AuthUser(id=str(user.id), phone=str(phone))
except ValidationError as exc:
logger.warning(
"Auth response returned invalid phone format",
error_type=type(exc).__name__,
)
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
) from exc
return SessionResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
def _list_auth_users(client: Any) -> list[Any]:
users: list[Any] = []
page = 1
max_pages = 100
while page <= max_pages:
response = client.auth.admin.list_users(page=page, per_page=100)
batch = (
list(response)
if isinstance(response, list)
else list(getattr(response, "users", []))
)
users.extend(batch)
if len(batch) < 100:
break
page += 1
return users
def _sanitize_phone_token(raw: object) -> str:
token = str(raw).strip()
for separator in (" ", "-", "(", ")"):
token = token.replace(separator, "")
return token
def _normalize_phone(raw_phone: object) -> str | None:
phone = _sanitize_phone_token(raw_phone)
if not phone:
return None
if phone.startswith("00") and len(phone) > 2:
return f"+{phone[2:]}"
if phone.startswith("+"):
return phone
if phone.isdigit():
return f"+{phone}"
return None
def _normalize_phone_search_query(raw_query: str) -> str | None:
query = _sanitize_phone_token(raw_query)
if not query:
return None
if query.startswith("00") and len(query) > 2:
return f"+{query[2:]}"
if query.startswith("+"):
return query
if query.isdigit():
return query
return None
def _digits_only(value: str) -> str:
return "".join(ch for ch in value if ch.isdigit())
+114
View File
@@ -0,0 +1,114 @@
from __future__ import annotations
import asyncio
from collections import deque
from time import monotonic
from core.http.errors import ApiProblemError
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
_BUCKETS: dict[str, deque[float]] = {}
_LAST_SEEN: dict[str, float] = {}
_LOCK = asyncio.Lock()
_CLEANUP_INTERVAL = 200
_CALL_COUNT = 0
logger = get_logger("v1.auth.rate_limit")
_REDIS_LIMIT_SCRIPT = """
local current = redis.call("INCR", KEYS[1])
if current == 1 then
redis.call("EXPIRE", KEYS[1], ARGV[1])
end
return current
"""
async def enforce_rate_limit(
*,
scope: str,
identifier: str,
limit: int,
window_seconds: int,
) -> None:
key = f"auth:rate_limit:{scope}:{identifier.lower()}"
try:
await _enforce_rate_limit_with_redis(
key=key,
limit=limit,
window_seconds=window_seconds,
)
return
except ApiProblemError:
raise
except Exception as exc: # noqa: BLE001
logger.warning(
"Rate limit fallback to in-memory",
scope=scope,
error_type=type(exc).__name__,
)
await _enforce_rate_limit_in_memory(
key=key,
limit=limit,
window_seconds=window_seconds,
)
async def _enforce_rate_limit_with_redis(
*,
key: str,
limit: int,
window_seconds: int,
) -> None:
client = await get_or_init_redis_client()
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds) # type: ignore[await]
if int(current) > limit:
raise ApiProblemError(
status_code=429,
code="AUTH_TOO_MANY_REQUESTS",
detail="Too many requests",
)
async def _enforce_rate_limit_in_memory(
*,
key: str,
limit: int,
window_seconds: int,
) -> None:
global _CALL_COUNT
now = monotonic()
async with _LOCK:
bucket = _BUCKETS.setdefault(key, deque())
_LAST_SEEN[key] = now
cutoff = now - float(window_seconds)
while bucket and bucket[0] <= cutoff:
bucket.popleft()
if len(bucket) >= limit:
raise ApiProblemError(
status_code=429,
code="AUTH_TOO_MANY_REQUESTS",
detail="Too many requests",
)
bucket.append(now)
_CALL_COUNT += 1
if _CALL_COUNT % _CLEANUP_INTERVAL == 0:
_cleanup_stale_buckets(now)
def _cleanup_stale_buckets(now: float) -> None:
stale_keys = [
key
for key, last_seen in _LAST_SEEN.items()
if key not in _BUCKETS or (not _BUCKETS[key] and now - last_seen > 3600)
]
for key in stale_keys:
_BUCKETS.pop(key, None)
_LAST_SEEN.pop(key, None)
def reset_rate_limit_state() -> None:
_BUCKETS.clear()
_LAST_SEEN.clear()
global _CALL_COUNT
_CALL_COUNT = 0
@@ -0,0 +1,239 @@
from __future__ import annotations
from datetime import UTC, datetime, time, timedelta
from typing import Protocol
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
from models.automation_jobs import AutomationJob
from schemas.enums import AutomationJobStatus, MemoryType, ScheduleType
from models.profile import Profile
from schemas.domain.automation import AutomationJobConfig, ScheduleConfig
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from schemas.shared.user import parse_profile_settings
from v1.auth.automation_static_config import load_static_automation_job_config
from v1.auth.schemas import RegistrationBootstrapRequest
from v1.memories.repository import SQLAlchemyMemoriesRepository
logger = get_logger("v1.auth.registration_bootstrap")
class RegistrationBootstrapRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
self._memories_repository = SQLAlchemyMemoriesRepository(session)
async def get_profile_timezone(self, *, user_id: UUID) -> str:
stmt = select(Profile.settings).where(Profile.id == user_id)
settings = (await self._session.execute(stmt)).scalar_one_or_none()
parsed = parse_profile_settings(
settings if isinstance(settings, dict) else None
)
return parsed.preferences.timezone
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
next_run_at: datetime,
) -> bool:
stmt = (
insert(AutomationJob)
.values(
id=uuid4(),
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=title,
config=config.model_dump(mode="json"),
next_run_at=next_run_at,
timezone=timezone_name,
status=AutomationJobStatus.ACTIVE,
created_by=owner_id,
)
.on_conflict_do_nothing(
index_elements=["owner_id", "bootstrap_key"],
index_where=AutomationJob.deleted_at.is_(None)
& AutomationJob.bootstrap_key.is_not(None),
)
.returning(AutomationJob.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
await self._session.flush()
return inserted_id is not None
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool:
return await self._memories_repository.create_if_absent(
owner_id=owner_id,
memory_type=memory_type,
content=content,
)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
class RegistrationBootstrapRepositoryLike(Protocol):
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
next_run_at: datetime,
) -> bool: ...
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool: ...
class SessionLike(Protocol):
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
def compute_first_run_at_utc(
*,
now_utc: datetime,
timezone_name: str,
schedule: ScheduleConfig,
) -> datetime:
try:
timezone_obj = ZoneInfo(timezone_name)
except ZoneInfoNotFoundError:
timezone_obj = ZoneInfo("UTC")
local_now = now_utc.astimezone(timezone_obj)
run_clock = time(
hour=schedule.run_at.hour,
minute=schedule.run_at.minute,
tzinfo=timezone_obj,
)
if schedule.type == ScheduleType.DAILY:
candidate_local = datetime.combine(local_now.date(), run_clock)
if candidate_local <= local_now:
candidate_local = candidate_local + timedelta(days=1)
return candidate_local.astimezone(UTC)
weekdays = schedule.weekdays or []
if not weekdays:
raise ValueError("weekly schedule requires weekdays")
normalized_weekdays = sorted(set(weekdays))
for day_offset in range(0, 8):
candidate_day = local_now.date() + timedelta(days=day_offset)
if candidate_day.isoweekday() not in normalized_weekdays:
continue
candidate_local = datetime.combine(candidate_day, run_clock)
if candidate_local > local_now:
return candidate_local.astimezone(UTC)
fallback_day = local_now.date() + timedelta(days=7)
while fallback_day.isoweekday() not in normalized_weekdays:
fallback_day = fallback_day + timedelta(days=1)
fallback_local = datetime.combine(fallback_day, run_clock)
return fallback_local.astimezone(UTC)
class RegistrationAutomationBootstrapService:
def __init__(
self,
*,
repository: RegistrationBootstrapRepositoryLike,
session: SessionLike,
) -> None:
self._repository = repository
self._session = session
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
owner_id = request.user_id
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
definitions = [
{
"bootstrap_key": "memory_extraction",
"config_name": "memory_extraction",
"title": "记忆推送",
}
]
try:
inserted_any = False
created_or_updated_memory = False
user_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.USER,
content=UserMemoryContent().model_dump(mode="json"),
)
work_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.WORK,
content=WorkProfileContent().model_dump(mode="json"),
)
created_or_updated_memory = user_initialized or work_initialized
for definition in definitions:
bootstrap_key = str(definition["bootstrap_key"])
job_config = load_static_automation_job_config(
config_name=str(definition["config_name"])
)
schedule = job_config.schedule
if schedule is None:
raise ValueError(
f"bootstrap job {bootstrap_key} has no schedule configured"
)
next_run_at = compute_first_run_at_utc(
now_utc=datetime.now(UTC),
timezone_name=timezone_name,
schedule=schedule,
)
inserted = (
await self._repository.insert_bootstrap_automation_job_if_absent(
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=str(definition["title"]),
config=job_config,
timezone_name=timezone_name,
next_run_at=next_run_at,
)
)
inserted_any = inserted_any or inserted
if inserted_any or created_or_updated_memory:
await self._session.commit()
logger.info(
"user automation jobs bootstrapped",
user_id=user_id,
timezone=timezone_name,
memory_initialized=created_or_updated_memory,
)
except Exception:
await self._session.rollback()
raise
+117
View File
@@ -0,0 +1,117 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request, Response
from core.config.settings import config
from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service
from v1.auth.schemas import (
OtpSendRequest,
PhoneSessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
SessionResponse,
)
from v1.auth.service import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/otp/send", status_code=204)
async def send_otp(
payload: OtpSendRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
client_ip = _client_ip(request)
await enforce_rate_limit(
scope="otp_send_phone",
identifier=payload.phone,
limit=3,
window_seconds=60,
)
await enforce_rate_limit(
scope="otp_send_ip",
identifier=client_ip,
limit=20,
window_seconds=60,
)
await service.send_otp(payload)
return Response(status_code=204)
@router.post("/phone-session", response_model=SessionResponse)
async def create_phone_session(
payload: PhoneSessionCreateRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
client_ip = _client_ip(request)
await enforce_rate_limit(
scope="phone_session_phone",
identifier=payload.phone,
limit=6,
window_seconds=300,
)
await enforce_rate_limit(
scope="phone_session_ip",
identifier=client_ip,
limit=20,
window_seconds=300,
)
return await service.create_phone_session(payload)
@router.post("/sessions/refresh", response_model=SessionResponse)
async def refresh_session(
payload: SessionRefreshRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
await enforce_rate_limit(
scope="refresh",
identifier=_client_ip(request),
limit=10,
window_seconds=60,
)
return await service.refresh_session(payload)
@router.delete("/sessions", status_code=204)
async def delete_session(
payload: SessionDeleteRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="logout",
identifier=_client_ip(request),
limit=10,
window_seconds=60,
)
await service.delete_session(payload.refresh_token)
return Response(status_code=204)
def _client_ip(request: Request) -> str:
host = request.client.host if request.client else ""
if not host:
return "unknown"
if _should_trust_proxy_headers(host):
forwarded_for = request.headers.get("x-forwarded-for", "")
if forwarded_for:
first = forwarded_for.split(",")[0].strip()
if first:
return first
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
return host
def _should_trust_proxy_headers(host: str) -> bool:
trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips}
return host in trusted_proxies
+66
View File
@@ -0,0 +1,66 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
SUPABASE_PASSWORD_MIN_LENGTH = 6
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
class OtpSendRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class PhoneSessionCreateRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
token: str = Field(pattern=r"^\d{6}$")
class SessionRefreshRequest(BaseModel):
refresh_token: str = Field(min_length=1)
class SessionDeleteRequest(BaseModel):
refresh_token: str = Field(min_length=1)
class AuthUser(BaseModel):
id: str
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class SessionResponse(BaseModel):
access_token: str
refresh_token: str
expires_in: int
token_type: str
user: AuthUser
class UserByPhoneResponse(BaseModel):
id: str
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
created_at: str
phone_confirmed_at: str | None = None
class UserByIdResponse(BaseModel):
id: str
phone: str | None = None
created_at: str
phone_confirmed_at: str | None = None
class OtpSendResponse(BaseModel):
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class RegistrationBootstrapRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
user_id: UUID
+63
View File
@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import Protocol
from v1.auth.schemas import (
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
)
class AuthServiceGateway(Protocol):
async def send_otp(self, request: OtpSendRequest) -> None:
raise NotImplementedError
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
raise NotImplementedError
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
raise NotImplementedError
async def delete_session(self, refresh_token: str | None) -> None:
raise NotImplementedError
class AuthService:
_gateway: AuthServiceGateway
_registration_bootstrapper: RegistrationBootstrapper | None
def __init__(
self,
gateway: AuthServiceGateway,
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
) -> None:
self._gateway = gateway
self._registration_bootstrapper = registration_bootstrapper
async def send_otp(self, request: OtpSendRequest) -> None:
await self._gateway.send_otp(request)
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
response = await self._gateway.create_phone_session(request)
if self._registration_bootstrapper is not None:
await self._registration_bootstrapper.ensure_user_automation_jobs(
user_id=response.user.id
)
return response
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
return await self._gateway.refresh_session(request)
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
raise NotImplementedError