diff --git a/backend/alembic/versions/20260227_0006_invite_codes_and_profile_referral.py b/backend/alembic/versions/20260227_0006_invite_codes_and_profile_referral.py new file mode 100644 index 0000000..5867909 --- /dev/null +++ b/backend/alembic/versions/20260227_0006_invite_codes_and_profile_referral.py @@ -0,0 +1,197 @@ +"""invite_codes table and profile referral fields + +Revision ID: 202602270006 +Revises: 202602260005 +Create Date: 2026-02-27 10:00:00 +""" + +from typing import Sequence, Union + +from alembic import op + +revision: str = "202602270006" +down_revision: Union[str, Sequence[str], None] = "202602260005" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute( + """ + CREATE TABLE invite_codes ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + code VARCHAR(8) NOT NULL UNIQUE, + owner_id UUID REFERENCES profiles(id) ON DELETE SET NULL, + status VARCHAR(20) NOT NULL DEFAULT 'active' CHECK (status IN ('active', 'disabled', 'expired')), + used_count INTEGER NOT NULL DEFAULT 0 CHECK (used_count >= 0), + max_uses INTEGER CHECK (max_uses IS NULL OR max_uses >= 1), + expires_at TIMESTAMPTZ NULL, + reward_config JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + """ + ) + op.execute("CREATE INDEX ix_invite_codes_owner_id ON invite_codes(owner_id)") + op.execute( + "CREATE INDEX ix_invite_codes_code ON invite_codes(code) WHERE status = 'active'" + ) + + op.execute("ALTER TABLE invite_codes ENABLE ROW LEVEL SECURITY") + op.execute("DROP POLICY IF EXISTS invite_codes_all_denied ON invite_codes") + op.execute( + "CREATE POLICY invite_codes_all_denied ON invite_codes FOR ALL USING (false)" + ) + + op.execute( + """ + ALTER TABLE profiles ADD COLUMN referred_by UUID REFERENCES profiles(id) ON DELETE SET NULL + """ + ) + op.execute("CREATE INDEX ix_profiles_referred_by ON profiles(referred_by)") + + op.execute( + """ + CREATE OR REPLACE FUNCTION public.generate_invite_code() + RETURNS TEXT + LANGUAGE plpgsql + SECURITY DEFINER + SET search_path = public + AS $$ + DECLARE + chars TEXT := 'ABCDEFGHJKMNPQRSTUVWXYZ23456789'; + result TEXT := ''; + i INT; + BEGIN + FOR i IN 1..8 LOOP + result := result || substr(chars, floor(random() * length(chars) + 1)::int, 1); + END LOOP; + RETURN result; + END; + $$; + """ + ) + + op.execute( + """ + CREATE OR REPLACE FUNCTION public.create_profile_for_new_user() + RETURNS trigger + LANGUAGE plpgsql + SECURITY DEFINER + SET search_path = public + AS $$ + DECLARE + invite_code_value TEXT; + referrer_id UUID; + new_code TEXT; + attempts INT := 0; + BEGIN + INSERT INTO public.profiles (id, username, avatar_url, bio, settings, referred_by, created_at, updated_at) + VALUES ( + NEW.id, + COALESCE( + NEW.raw_user_meta_data ->> 'username', + split_part(NEW.email, '@', 1), + 'user_' || substring(NEW.id::text, 1, 8) + ), + NULL, + NULL, + '{}'::jsonb, + NULL, + now(), + now() + ) + ON CONFLICT (id) DO NOTHING; + + LOOP + BEGIN + new_code := public.generate_invite_code(); + INSERT INTO public.invite_codes (code, owner_id, status, used_count, max_uses, expires_at, reward_config) + VALUES ( + new_code, + NEW.id, + 'active', + 0, + NULL, + NULL, + '{}'::jsonb + ); + EXIT; + EXCEPTION WHEN unique_violation THEN + attempts := attempts + 1; + IF attempts >= 100 THEN + RAISE EXCEPTION 'Failed to generate unique invite code after 100 attempts'; + END IF; + END; + END LOOP; + + invite_code_value := NEW.raw_user_meta_data ->> 'invite_code'; + IF invite_code_value IS NOT NULL AND length(invite_code_value) = 8 THEN + invite_code_value := upper(invite_code_value); + IF invite_code_value ~ '^[ABCDEFGHJKMNPQRSTUVWXYZ23456789]{8}$' THEN + UPDATE public.invite_codes + SET used_count = used_count + 1 + WHERE code = invite_code_value + AND status = 'active' + AND (max_uses IS NULL OR used_count < max_uses) + AND (expires_at IS NULL OR expires_at > NOW()) + RETURNING owner_id INTO referrer_id; + + IF referrer_id IS NOT NULL THEN + UPDATE public.profiles + SET referred_by = referrer_id + WHERE id = NEW.id; + END IF; + END IF; + END IF; + + RETURN NEW; + END; + $$; + """ + ) + + +def downgrade() -> None: + op.execute("DROP FUNCTION IF EXISTS public.create_profile_for_new_user()") + op.execute( + """ + CREATE OR REPLACE FUNCTION public.create_profile_for_new_user() + RETURNS trigger + LANGUAGE plpgsql + SECURITY DEFINER + SET search_path = public + AS $$ + BEGIN + INSERT INTO public.profiles (id, username, avatar_url, bio, settings, created_at, updated_at) + VALUES ( + NEW.id, + COALESCE( + NEW.raw_user_meta_data ->> 'username', + split_part(NEW.email, '@', 1), + 'user_' || substring(NEW.id::text, 1, 8) + ), + NULL, + NULL, + '{}'::jsonb, + now(), + now() + ) + ON CONFLICT (id) DO NOTHING; + + RETURN NEW; + END; + $$; + """ + ) + + op.execute("DROP FUNCTION IF EXISTS public.generate_invite_code()") + + op.execute("DROP INDEX IF EXISTS ix_profiles_referred_by") + op.execute("ALTER TABLE profiles DROP COLUMN IF EXISTS referred_by") + + op.execute("DROP POLICY IF EXISTS invite_codes_all_denied ON invite_codes") + op.execute("ALTER TABLE invite_codes DISABLE ROW LEVEL SECURITY") + op.execute("DROP INDEX IF EXISTS ix_invite_codes_code") + op.execute("DROP INDEX IF EXISTS ix_invite_codes_owner_id") + op.execute("DROP TABLE IF EXISTS invite_codes") diff --git a/backend/src/models/__init__.py b/backend/src/models/__init__.py index 6ac51d9..e6f6bbc 100644 --- a/backend/src/models/__init__.py +++ b/backend/src/models/__init__.py @@ -6,6 +6,7 @@ from models.automation_jobs import AutomationJob from models.group_members import GroupMember from models.groups import Group from models.inbox_messages import InboxMessage +from models.invite_code import InviteCode, InviteCodeStatus from models.llm import Llm from models.llm_factory import LlmFactory from models.memories import Memory @@ -23,6 +24,8 @@ __all__ = [ "GroupMember", "Group", "InboxMessage", + "InviteCode", + "InviteCodeStatus", "Llm", "LlmFactory", "Memory", diff --git a/backend/src/models/invite_code.py b/backend/src/models/invite_code.py new file mode 100644 index 0000000..5c7774b --- /dev/null +++ b/backend/src/models/invite_code.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import uuid +from datetime import datetime +from enum import Enum + +from sqlalchemy import CheckConstraint, DateTime, ForeignKey, Integer, String +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column + +from core.db.base import Base, TimestampMixin + + +class InviteCodeStatus(str, Enum): + ACTIVE = "active" + DISABLED = "disabled" + EXPIRED = "expired" + + +class InviteCode(TimestampMixin, Base): + """Invite code model. + + Tracks invite codes generated by users for referral system. + """ + + __tablename__: str = "invite_codes" + __table_args__ = ( + CheckConstraint( + "status IN ('active', 'disabled', 'expired')", + name="invite_codes_status_check", + ), + CheckConstraint("used_count >= 0", name="invite_codes_used_count_check"), + CheckConstraint( + "max_uses IS NULL OR max_uses >= 1", + name="invite_codes_max_uses_check", + ), + {"extend_existing": True}, + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + code: Mapped[str] = mapped_column( + String(8), + nullable=False, + unique=True, + index=True, + ) + owner_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), + ForeignKey("profiles.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + status: Mapped[str] = mapped_column( + String(20), + nullable=False, + default=InviteCodeStatus.ACTIVE.value, + ) + used_count: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=0, + ) + max_uses: Mapped[int | None] = mapped_column( + Integer, + nullable=True, + ) + expires_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + reward_config: Mapped[dict] = mapped_column( + JSONB, + nullable=False, + server_default="{}", + ) diff --git a/backend/src/models/profile.py b/backend/src/models/profile.py index 917a6af..60e8c5c 100644 --- a/backend/src/models/profile.py +++ b/backend/src/models/profile.py @@ -2,7 +2,7 @@ from __future__ import annotations import uuid -from sqlalchemy import String, Text +from sqlalchemy import ForeignKey, String, Text from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column @@ -42,3 +42,9 @@ class Profile(TimestampMixin, SoftDeleteMixin, Base): nullable=False, server_default="{}", ) + referred_by: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), + ForeignKey("profiles.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) diff --git a/backend/src/v1/auth/gateway.py b/backend/src/v1/auth/gateway.py index 5a647fd..efac7ba 100644 --- a/backend/src/v1/auth/gateway.py +++ b/backend/src/v1/auth/gateway.py @@ -39,10 +39,13 @@ class SupabaseAuthGateway(AuthServiceGateway): async def create_verification( self, request: VerificationCreateRequest ) -> VerificationCreateResponse: + metadata: dict[str, Any] = {"username": request.username} + if request.invite_code: + metadata["invite_code"] = request.invite_code payload: dict[str, Any] = { "email": request.email, "password": request.password, - "data": {"username": request.username}, + "data": metadata, } if request.redirect_to: payload["options"] = {"email_redirect_to": request.redirect_to} diff --git a/backend/src/v1/auth/schemas.py b/backend/src/v1/auth/schemas.py index 4821f8b..d11b7aa 100644 --- a/backend/src/v1/auth/schemas.py +++ b/backend/src/v1/auth/schemas.py @@ -1,13 +1,21 @@ from __future__ import annotations -from pydantic import BaseModel, EmailStr, Field +from pydantic import BaseModel, ConfigDict, EmailStr, Field class VerificationCreateRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + username: str = Field(min_length=3, max_length=30) email: EmailStr password: str = Field(min_length=6) redirect_to: str | None = None + invite_code: str | None = Field( + default=None, + min_length=8, + max_length=8, + pattern=r"^[ABCDEFGHJKMNPQRSTUVWXYZ23456789]{8}$", + ) class VerificationResendRequest(BaseModel): @@ -65,7 +73,3 @@ class PasswordResetConfirmRequest(BaseModel): email: EmailStr token: str = Field(pattern=r"^\d{6}$") new_password: str = Field(min_length=6) - - -class PasswordResetResponse(BaseModel): - message: str = "Password reset email sent"