refactor: 重构 schemas 结构,统一枚举定义

This commit is contained in:
qzl
2026-03-25 12:36:31 +08:00
parent 389f5248fc
commit d22ded21f8
122 changed files with 774 additions and 1456 deletions
@@ -1,283 +0,0 @@
from __future__ import annotations
import pytest
from v1.agent.dependencies import RedisEventStream, TaskiqQueueClient
class _FakeRedis:
def __init__(self) -> None:
self.store: dict[str, str] = {}
self.delete_calls: list[str] = []
async def set(
self,
key: str,
value: str,
*,
nx: bool = False,
ex: int | None = None,
) -> bool:
del ex
if nx and key in self.store:
return False
self.store[key] = value
return True
async def get(self, key: str) -> str | None:
return self.store.get(key)
async def delete(self, key: str) -> int:
self.delete_calls.append(key)
existed = 1 if key in self.store else 0
self.store.pop(key, None)
return existed
class _FakeAsyncResult:
def __init__(self, task_id: str) -> None:
self.task_id = task_id
class _FakeRedisStreamClient:
pass
@pytest.mark.asyncio
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
assert payload["command"] == "run"
return _FakeAsyncResult("task-123")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
assert resolved_client["value"] is True
assert task_id == "task-123"
@pytest.mark.asyncio
async def test_enqueue_resume_dedup_returns_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
return _FakeAsyncResult("new-task-id")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
dedup_key = "resume:session-1:call-1"
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={
"command": "resume",
"session_id": "session-1",
"tool_call_id": "call-1",
},
dedup_key=dedup_key,
)
assert resolved_client["value"] is True
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
attempts = {"count": 0}
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_get(key: str) -> str | None:
attempts["count"] += 1
if attempts["count"] > 1:
fake_redis.store[key] = "existing-task-id"
return fake_redis.store.get(key)
async def _fake_sleep(_: float) -> None:
return None
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise AssertionError("should not enqueue when dedup task id appears")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(fake_redis, "get", _fake_get)
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={
"command": "resume",
"session_id": "session-1",
"tool_call_id": "call-1",
},
dedup_key=dedup_key,
)
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_failure_cleans_dedup_lock(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise RuntimeError("enqueue failed")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
with pytest.raises(RuntimeError, match="enqueue failed"):
await client.enqueue(
command={
"command": "resume",
"session_id": "session-1",
"tool_call_id": "call-1",
},
dedup_key=dedup_key,
)
assert redis_key in fake_redis.delete_calls
@pytest.mark.asyncio
async def test_enqueue_uses_critical_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("critical-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "critical"},
dedup_key=None,
)
assert task_id == "critical-task-id"
@pytest.mark.asyncio
async def test_enqueue_uses_bulk_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("bulk-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "bulk"},
dedup_key=None,
)
assert task_id == "bulk-task-id"
@pytest.mark.asyncio
async def test_event_stream_caps_block_ms_below_socket_timeout(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
return _FakeRedisStreamClient()
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 5000)
monkeypatch.setattr(deps.config.redis, "socket_timeout", 1.0)
stream = RedisEventStream()
bus = await stream._get_bus()
assert bus._block_ms == 900
@pytest.mark.asyncio
async def test_event_stream_uses_configured_block_ms_when_safe(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
return _FakeRedisStreamClient()
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 200)
monkeypatch.setattr(deps.config.redis, "socket_timeout", 2.0)
stream = RedisEventStream()
bus = await stream._get_bus()
assert bus._block_ms == 200
+1 -1
View File
@@ -12,7 +12,7 @@ import pytest
import v1.agent.service as agent_service_module
from core.auth.models import CurrentUser
from core.config.settings import config
from schemas.messages.chat_message import AgentChatMessageMetadata
from schemas.domain.chat_message import AgentChatMessageMetadata
from v1.agent.service import AgentService
@@ -10,7 +10,7 @@ from models.automation_jobs import ScheduleType
from v1.auth.registration_bootstrap import (
compute_first_run_at_utc,
)
from schemas.automation import ScheduleConfig, ScheduleRunAt
from schemas.domain.automation import ScheduleConfig, ScheduleRunAt
def test_compute_first_run_at_utc_from_asia_shanghai() -> None:
@@ -4,7 +4,7 @@ from uuid import uuid4
import pytest
from models.automation_jobs import AutomationJobStatus, ScheduleType
from schemas.automation import (
from schemas.domain.automation import (
AgentTool,
AutomationJobConfig,
ContextSource,
@@ -5,7 +5,8 @@ from uuid import uuid4
import pytest
from pydantic import ValidationError
from schemas.automation import AgentTool, AutomationJobConfig
from core.agentscope.tools.tool_config import AgentTool
from schemas.domain.automation import AutomationJobConfig
from v1.automation_jobs.schemas import (
AutomationJobCreateRequest,
AutomationJobResponse,
@@ -17,7 +17,7 @@ from v1.automation_jobs.schemas import (
AutomationJobCreateRequest,
AutomationJobUpdateRequest,
)
from schemas.automation import (
from schemas.domain.automation import (
AgentTool,
AutomationJobConfig,
ContextSource,
@@ -5,7 +5,7 @@ from uuid import uuid4
import pytest
from pydantic import ValidationError
from schemas.user.context import UserContext
from schemas.shared.user import UserContext
from v1.friendships.schemas import (
FriendRequestCreate,
@@ -16,7 +16,7 @@ def test_inbox_message_response_schema() -> None:
sender_id=uuid4(),
message_type=InboxMessageType.CALENDAR,
schedule_item_id=uuid4(),
content="Join my calendar",
content={"type": "invite", "permission": 1, "action": "pending"},
is_read=False,
status=InboxMessageStatus.PENDING,
created_at=datetime.now(timezone.utc),
@@ -22,7 +22,7 @@ def _build_message(
status: InboxMessageModelStatus = InboxMessageModelStatus.PENDING,
message_type: InboxMessageModelType = InboxMessageModelType.CALENDAR,
schedule_item_id: UUID | None = None,
content: str = '{"permission": 7}',
content: dict[str, object] = {"permission": 7},
) -> InboxMessage:
message = MagicMock(spec=InboxMessage)
message.id = message_id
@@ -500,3 +500,43 @@ async def test_list_by_date_range_rolls_back_when_query_fails_after_archive(
assert exc_info.value.status_code == 503
mock_session.rollback.assert_awaited_once()
mock_session.commit.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_by_id_maps_legacy_completed_to_archived(
mock_session: AsyncMock,
mock_inbox_repository: MagicMock,
) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
item = _create_mock_schedule_item()
setattr(item, "status", "completed")
service = ScheduleItemService(
repository=FakeRepo(item),
session=mock_session,
current_user=CurrentUser(id=user_id),
inbox_repository=mock_inbox_repository,
)
result = await service.get_by_id(item.id)
assert result.status == ScheduleItemStatus.ARCHIVED
@pytest.mark.asyncio
async def test_get_by_id_maps_legacy_canceled_to_archived(
mock_session: AsyncMock,
mock_inbox_repository: MagicMock,
) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
item = _create_mock_schedule_item()
setattr(item, "status", "canceled")
service = ScheduleItemService(
repository=FakeRepo(item),
session=mock_session,
current_user=CurrentUser(id=user_id),
inbox_repository=mock_inbox_repository,
)
result = await service.get_by_id(item.id)
assert result.status == ScheduleItemStatus.ARCHIVED
@@ -1,4 +1,3 @@
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
@@ -50,6 +49,11 @@ class FakeInboxRepo:
return self._inbox
return None
async def get_calendar_invite(
self, schedule_item_id: UUID, recipient_id: UUID
) -> InboxMessage | None:
return await self.get_pending_calendar_invite(schedule_item_id, recipient_id)
async def create(self, data: dict) -> InboxMessage:
return MagicMock()
@@ -80,6 +84,9 @@ def mock_session() -> AsyncMock:
@pytest.fixture
def mock_repo() -> MagicMock:
repo = MagicMock()
repo.get_subscription = AsyncMock(return_value=None)
repo.update_subscription_status = AsyncMock(return_value=None)
repo.archive_expired_subscribed_items = AsyncMock(return_value=0)
repo.create_subscription = AsyncMock(return_value=MagicMock())
return repo
@@ -196,6 +203,9 @@ async def test_list_by_date_range_with_subscriptions(
mock_repo.list_by_date_range = AsyncMock(return_value=[owned_item])
mock_repo.get_user_subscriptions = AsyncMock(return_value=[subscription])
mock_repo.list_subscribed_items_by_date_range = AsyncMock(
return_value=[(subscribed_item, subscription)]
)
mock_repo.get_by_id = AsyncMock(return_value=subscribed_item)
service = ScheduleItemService(
@@ -214,7 +224,6 @@ async def test_list_by_date_range_with_subscriptions(
result = await service.list_by_date_range(request)
assert len(result) == 2
assert result[0].is_owner is True
assert result[1].is_owner is False
assert result[1].permission == 1
assert len(result) == 1
assert result[0].is_owner is False
assert result[0].permission == 1