feat: 重构 memory 系统,支持 user memory 和 work memory 分离
This commit is contained in:
@@ -12,6 +12,14 @@ from v1.auth.schemas import (
|
||||
from v1.auth.service import AuthService, AuthServiceGateway
|
||||
|
||||
|
||||
class FakeRegistrationBootstrapper:
|
||||
def __init__(self) -> None:
|
||||
self.called_user_ids: list[str] = []
|
||||
|
||||
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
|
||||
self.called_user_ids.append(user_id)
|
||||
|
||||
|
||||
class FakeGateway(AuthServiceGateway):
|
||||
def __init__(self, response: SessionResponse) -> None:
|
||||
self._response = response
|
||||
@@ -75,6 +83,27 @@ async def test_create_phone_session_forwards_payload() -> None:
|
||||
assert response.user.phone == "+8613812345678"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_phone_session_bootstraps_automation_job() -> None:
|
||||
user = AuthUser(id="b196f8be-c5f4-45d8-8f07-65c0ddf4d3de", phone="+8613812345678")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
gateway = FakeGateway(token_response)
|
||||
bootstrapper = FakeRegistrationBootstrapper()
|
||||
service = AuthService(gateway=gateway, registration_bootstrapper=bootstrapper)
|
||||
|
||||
await service.create_phone_session(
|
||||
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
|
||||
)
|
||||
|
||||
assert bootstrapper.called_user_ids == ["b196f8be-c5f4-45d8-8f07-65c0ddf4d3de"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_session_forwards_payload() -> None:
|
||||
user = AuthUser(id="user-1", phone="+8613812345678")
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.auth.registration_bootstrap import (
|
||||
compute_next_local_time_utc,
|
||||
)
|
||||
|
||||
|
||||
def test_compute_next_local_time_utc_from_asia_shanghai() -> None:
|
||||
now_utc = datetime(2026, 3, 23, 0, 30, tzinfo=timezone.utc)
|
||||
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name="Asia/Shanghai",
|
||||
local_hour=8,
|
||||
local_minute=0,
|
||||
)
|
||||
|
||||
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
|
||||
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_compute_next_local_time_utc_rolls_to_next_day_when_passed() -> None:
|
||||
now_utc = datetime(2026, 3, 23, 2, 30, tzinfo=timezone.utc)
|
||||
|
||||
run_at, next_run_at = compute_next_local_time_utc(
|
||||
now_utc=now_utc,
|
||||
timezone_name="Asia/Shanghai",
|
||||
local_hour=8,
|
||||
local_minute=0,
|
||||
)
|
||||
|
||||
assert run_at == datetime(2026, 3, 24, 0, 0, tzinfo=timezone.utc)
|
||||
assert next_run_at == datetime(2026, 3, 25, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registration_service_is_idempotent_when_job_exists() -> None:
|
||||
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
|
||||
|
||||
expected_owner_id = uuid4()
|
||||
|
||||
class _Repo:
|
||||
inserted = 0
|
||||
upsert_calls = 0
|
||||
|
||||
async def get_profile_timezone(self, *, user_id):
|
||||
assert user_id == expected_owner_id
|
||||
return "Asia/Shanghai"
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
|
||||
assert kwargs["owner_id"] == expected_owner_id
|
||||
assert kwargs["bootstrap_key"] == "memory_extraction"
|
||||
self.inserted += 1
|
||||
return False
|
||||
|
||||
async def upsert_initial_memory(self, **kwargs):
|
||||
self.upsert_calls += 1
|
||||
return False
|
||||
|
||||
class _Session:
|
||||
async def commit(self):
|
||||
raise AssertionError("must not commit when already exists")
|
||||
|
||||
async def rollback(self):
|
||||
raise AssertionError("must not rollback when no error")
|
||||
|
||||
service = RegistrationAutomationBootstrapService(
|
||||
repository=cast(Any, _Repo()), session=cast(Any, _Session())
|
||||
)
|
||||
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registration_service_creates_initial_memories_when_missing() -> None:
|
||||
from v1.auth.registration_bootstrap import RegistrationAutomationBootstrapService
|
||||
|
||||
expected_owner_id = uuid4()
|
||||
|
||||
class _Repo:
|
||||
async def get_profile_timezone(self, *, user_id):
|
||||
assert user_id == expected_owner_id
|
||||
return "Asia/Shanghai"
|
||||
|
||||
async def upsert_initial_memory(self, **kwargs):
|
||||
return True
|
||||
|
||||
async def insert_bootstrap_automation_job_if_absent(self, **kwargs):
|
||||
_ = kwargs
|
||||
return True
|
||||
|
||||
class _Session:
|
||||
committed = 0
|
||||
|
||||
async def commit(self):
|
||||
self.committed += 1
|
||||
|
||||
async def rollback(self):
|
||||
raise AssertionError("must not rollback when no error")
|
||||
|
||||
session = _Session()
|
||||
service = RegistrationAutomationBootstrapService(
|
||||
repository=cast(Any, _Repo()), session=cast(Any, session)
|
||||
)
|
||||
await service.ensure_user_automation_jobs(user_id=str(expected_owner_id))
|
||||
|
||||
assert session.committed == 1
|
||||
Reference in New Issue
Block a user