chore: checkpoint current backend/runtime changes

This commit is contained in:
qzl
2026-03-06 17:28:17 +08:00
parent 2c59fe5ee2
commit b6087fd195
32 changed files with 1641 additions and 469 deletions
@@ -0,0 +1,227 @@
from __future__ import annotations
import pytest
from v1.agent.dependencies import TaskiqQueueClient
class _FakeRedis:
def __init__(self) -> None:
self.store: dict[str, str] = {}
self.delete_calls: list[str] = []
async def set(
self,
key: str,
value: str,
*,
nx: bool = False,
ex: int | None = None,
) -> bool:
del ex
if nx and key in self.store:
return False
self.store[key] = value
return True
async def get(self, key: str) -> str | None:
return self.store.get(key)
async def delete(self, key: str) -> int:
self.delete_calls.append(key)
existed = 1 if key in self.store else 0
self.store.pop(key, None)
return existed
class _FakeAsyncResult:
def __init__(self, task_id: str) -> None:
self.task_id = task_id
@pytest.mark.asyncio
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
assert payload["command"] == "run"
return _FakeAsyncResult("task-123")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
assert resolved_client["value"] is True
assert task_id == "task-123"
@pytest.mark.asyncio
async def test_enqueue_resume_dedup_returns_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
return _FakeAsyncResult("new-task-id")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
dedup_key = "resume:session-1:call-1"
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert resolved_client["value"] is True
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
attempts = {"count": 0}
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_get(key: str) -> str | None:
attempts["count"] += 1
if attempts["count"] > 1:
fake_redis.store[key] = "existing-task-id"
return fake_redis.store.get(key)
async def _fake_sleep(_: float) -> None:
return None
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise AssertionError("should not enqueue when dedup task id appears")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(fake_redis, "get", _fake_get)
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise RuntimeError("enqueue failed")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
with pytest.raises(RuntimeError, match="enqueue failed"):
await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert redis_key in fake_redis.delete_calls
@pytest.mark.asyncio
async def test_enqueue_uses_critical_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("critical-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "critical"},
dedup_key=None,
)
assert task_id == "critical-task-id"
@pytest.mark.asyncio
async def test_enqueue_uses_bulk_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("bulk-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "bulk"},
dedup_key=None,
)
assert task_id == "bulk-task-id"
+40 -28
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
@@ -12,37 +12,44 @@ from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
class TestSupabaseAuthGateway:
@pytest.fixture
def gateway(self) -> SupabaseAuthGateway:
with patch("v1.auth.gateway.create_client") as mock_create:
mock_client = MagicMock()
mock_admin_client = MagicMock()
mock_create.side_effect = [mock_client, mock_admin_client]
return SupabaseAuthGateway()
def gateway(
self, monkeypatch: pytest.MonkeyPatch
) -> tuple[SupabaseAuthGateway, MagicMock, MagicMock]:
mock_client = MagicMock()
mock_admin_client = MagicMock()
monkeypatch.setattr("v1.auth.gateway.supabase_service.get_client", lambda: mock_client)
monkeypatch.setattr(
"v1.auth.gateway.supabase_service.get_admin_client",
lambda: mock_admin_client,
)
return SupabaseAuthGateway(), mock_client, mock_admin_client
@pytest.mark.asyncio
async def test_request_password_reset_calls_email_with_string(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_with_redirect(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with(
"test@example.com",
@@ -51,64 +58,68 @@ class TestSupabaseAuthGateway:
@pytest.mark.asyncio
async def test_request_password_reset_swallows_auth_error(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
result = await gateway.request_password_reset(request)
result = await sut.request_password_reset(request)
mock_reset_email.assert_called_once()
assert result is None
@pytest.mark.asyncio
async def test_request_password_reset_extracts_email_from_mapping(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest.model_construct(
email={"email": "test@example.com"},
redirect_to=None,
)
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_rejects_invalid_email_shape(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, _, _ = gateway
request = PasswordResetRequest.model_construct(
email={"unexpected": "value"},
redirect_to=None,
)
with pytest.raises(HTTPException) as exc_info:
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid email"
@pytest.mark.asyncio
async def test_confirm_password_reset_updates_password_by_user_id(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, mock_admin_client = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id="user-1"),
)
mock_verify_otp = MagicMock(return_value=verify_response)
gateway._client.auth.verify_otp = mock_verify_otp
mock_client.auth.verify_otp = mock_verify_otp
mock_update_user_by_id = MagicMock()
gateway._admin_client.auth.admin = SimpleNamespace(
mock_admin_client.auth.admin = SimpleNamespace(
update_user_by_id=mock_update_user_by_id
)
@@ -118,7 +129,7 @@ class TestSupabaseAuthGateway:
new_password="newpassword123",
)
await gateway.confirm_password_reset(request)
await sut.confirm_password_reset(request)
mock_verify_otp.assert_called_once_with(
{
@@ -134,13 +145,14 @@ class TestSupabaseAuthGateway:
@pytest.mark.asyncio
async def test_confirm_password_reset_raises_when_user_id_missing(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id=""),
)
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
request = PasswordResetConfirmRequest(
email="test@example.com",
@@ -149,7 +161,7 @@ class TestSupabaseAuthGateway:
)
with pytest.raises(HTTPException) as exc_info:
await gateway.confirm_password_reset(request)
await sut.confirm_password_reset(request)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid or expired verification code"