refactor: 重构 schemas 结构,统一枚举定义
This commit is contained in:
@@ -6,7 +6,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
|
||||
from core.agentscope.persistence.user_context_cache import UserContextCache
|
||||
from schemas.user.context import (
|
||||
from schemas.shared.user import (
|
||||
UserContext,
|
||||
parse_profile_settings,
|
||||
)
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.runtime.registry_builder import build_consumer_registry
|
||||
|
||||
|
||||
def test_build_consumer_registry_from_system_agent_configs() -> None:
|
||||
registry = build_consumer_registry(
|
||||
system_agent_configs={
|
||||
"router": {"config": {"visibility_consumer_bit": 16}},
|
||||
"worker": {"config": {"visibility_consumer_bit": 17}},
|
||||
"memory": {"config": {"visibility_consumer_bit": 18}},
|
||||
}
|
||||
)
|
||||
|
||||
assert registry.resolve_agent_bit(agent_type="router") == 16
|
||||
assert registry.resolve_agent_bit(agent_type="worker") == 17
|
||||
|
||||
|
||||
def test_build_consumer_registry_rejects_duplicate_bit() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate visibility bit"):
|
||||
build_consumer_registry(
|
||||
system_agent_configs={
|
||||
"router": {"config": {"visibility_consumer_bit": 16}},
|
||||
"worker": {"config": {"visibility_consumer_bit": 16}},
|
||||
}
|
||||
)
|
||||
@@ -6,8 +6,8 @@ import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from schemas.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
from schemas.domain.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.shared.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
class _FakePipeline:
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_worker_has_two_stages() -> None:
|
||||
spec = build_default_pipeline_spec(mode="worker")
|
||||
|
||||
assert spec.mode == "worker"
|
||||
assert [item.stage_name for item in spec.stages] == ["router", "worker"]
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_memory_has_single_stage() -> None:
|
||||
spec = build_default_pipeline_spec(mode="memory")
|
||||
|
||||
assert spec.mode == "memory"
|
||||
assert [item.stage_name for item in spec.stages] == ["memory"]
|
||||
|
||||
|
||||
def test_build_default_pipeline_spec_rejects_unknown_mode() -> None:
|
||||
with pytest.raises(ValueError, match="unsupported pipeline mode"):
|
||||
build_default_pipeline_spec(mode="planner")
|
||||
@@ -18,8 +18,8 @@ from schemas.agent.runtime_models import (
|
||||
WorkerAgentOutputLite,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
from schemas.domain.automation import MessageContextConfig, RuntimeConfig
|
||||
from schemas.shared.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
def _run_input() -> RunAgentInput:
|
||||
|
||||
@@ -7,8 +7,8 @@ import pytest
|
||||
|
||||
import core.agentscope.runtime.tasks as tasks_module
|
||||
from schemas.agent import ToolStatus
|
||||
from schemas.automation import ContextWindowMode, MessageContextConfig
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
from schemas.domain.automation import ContextWindowMode, MessageContextConfig
|
||||
from schemas.shared.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
def _run_input_payload() -> dict[str, Any]:
|
||||
|
||||
@@ -1,250 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.schemas.agui_input import (
|
||||
MAX_MESSAGES,
|
||||
MAX_RUN_ID_LENGTH,
|
||||
MAX_RUN_INPUT_BYTES,
|
||||
MAX_TEXT_CHARS,
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
)
|
||||
|
||||
|
||||
def _base_payload() -> dict[str, object]:
|
||||
return {
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
}
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_invalid_uuid() -> None:
|
||||
payload = _base_payload()
|
||||
payload["threadId"] = "bad-uuid"
|
||||
|
||||
with pytest.raises(ValueError, match="threadId must be a valid UUID"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_message_count_over_limit() -> None:
|
||||
payload = _base_payload()
|
||||
payload["messages"] = [
|
||||
{"id": f"u{i}", "role": "user", "content": "x"} for i in range(MAX_MESSAGES + 1)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="RunAgentInput.messages exceeds limit"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_user_text_over_limit() -> None:
|
||||
payload = _base_payload()
|
||||
payload["messages"] = [
|
||||
{"id": "u1", "role": "user", "content": "x" * (MAX_TEXT_CHARS + 1)}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="RunAgentInput user message text exceeds limit"
|
||||
):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_payload_over_limit() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {"blob": "x" * MAX_RUN_INPUT_BYTES}
|
||||
|
||||
with pytest.raises(ValueError, match="RunAgentInput payload exceeds size limit"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_run_id_over_limit() -> None:
|
||||
payload = _base_payload()
|
||||
payload["runId"] = "r" * (MAX_RUN_ID_LENGTH + 1)
|
||||
|
||||
with pytest.raises(ValueError, match="runId exceeds length limit"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_validate_run_request_messages_contract_requires_single_user_message() -> None:
|
||||
payload = _base_payload()
|
||||
payload["messages"] = [
|
||||
{"id": "u1", "role": "user", "content": "hello"},
|
||||
{"id": "u2", "role": "user", "content": "again"},
|
||||
]
|
||||
run_input = parse_run_input(payload)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="RunAgentInput.messages must contain exactly one user message",
|
||||
):
|
||||
validate_run_request_messages_contract(run_input)
|
||||
|
||||
|
||||
def test_validate_run_request_messages_contract_accepts_binary_url_blocks() -> None:
|
||||
payload = _base_payload()
|
||||
payload["messages"] = [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请分析"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/a.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
run_input = parse_run_input(payload)
|
||||
|
||||
validate_run_request_messages_contract(run_input)
|
||||
|
||||
|
||||
def test_validate_run_request_messages_contract_rejects_binary_data_block() -> None:
|
||||
payload = _base_payload()
|
||||
payload["messages"] = [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请分析"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
run_input = parse_run_input(payload)
|
||||
|
||||
with pytest.raises(ValueError, match="binary content requires url"):
|
||||
validate_run_request_messages_contract(run_input)
|
||||
|
||||
|
||||
def test_parse_run_input_accepts_snake_case_aliases() -> None:
|
||||
payload = {
|
||||
"thread_id": "00000000-0000-0000-0000-000000000001",
|
||||
"run_id": "run-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mime_type": "image/png",
|
||||
"url": "https://signed.example/a.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwarded_props": {"agent_type": "worker"},
|
||||
}
|
||||
|
||||
run_input = parse_run_input(payload)
|
||||
|
||||
assert run_input.thread_id == "00000000-0000-0000-0000-000000000001"
|
||||
assert run_input.run_id == "run-1"
|
||||
validate_run_request_messages_contract(run_input)
|
||||
|
||||
|
||||
def test_parse_run_input_accepts_client_time_forwarded_props() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
},
|
||||
}
|
||||
|
||||
run_input = parse_run_input(payload)
|
||||
|
||||
assert run_input.forwarded_props is not None
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_invalid_client_time_timezone() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "Mars/OlympusMons",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_invalid_client_time_now_iso() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16 09:12:33",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_invalid_client_time_epoch_type() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": "1773658353000",
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_unknown_forwarded_props_key() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"agent_type": "worker",
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
},
|
||||
"unexpected": {"foo": "bar"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
|
||||
|
||||
def test_parse_run_input_rejects_missing_forwarded_props_agent_type() -> None:
|
||||
payload = _base_payload()
|
||||
payload["forwardedProps"] = {
|
||||
"client_time": {
|
||||
"device_timezone": "America/Los_Angeles",
|
||||
"client_now_iso": "2026-03-16T09:12:33-07:00",
|
||||
"client_epoch_ms": 1773658353000,
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid RunAgentInput.forwardedProps"):
|
||||
parse_run_input(payload)
|
||||
@@ -40,6 +40,7 @@ class _FakeService:
|
||||
start_at=datetime(2026, 3, 17, 9, 0, tzinfo=timezone.utc),
|
||||
end_at=datetime(2026, 3, 17, 9, 30, tzinfo=timezone.utc),
|
||||
timezone="Asia/Shanghai",
|
||||
status="active",
|
||||
metadata=SimpleNamespace(
|
||||
location=None, color="#4F46E5", reminder_minutes=15
|
||||
),
|
||||
@@ -247,7 +248,7 @@ async def test_calendar_read_returns_structured_result_with_ids(
|
||||
assert "total=1" in payload["result"]
|
||||
assert "timezone=Asia/Shanghai" in payload["result"]
|
||||
assert "description=今天下午五点的会议" in payload["result"]
|
||||
assert "status=" in payload["result"]
|
||||
assert "status=active" in payload["result"]
|
||||
assert fake_service.created_id in payload["result"]
|
||||
assert fake_service.list_calls == [{"page": 1, "page_size": 20, "query": "会议"}]
|
||||
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.tools.tool_config import ToolApprovalConfig, ToolConfig, ToolGroup
|
||||
from core.agentscope.tools.tool_middleware import create_approval_middleware
|
||||
|
||||
|
||||
async def _next_handler(**kwargs: Any) -> AsyncGenerator[dict[str, object], None]:
|
||||
async def _generator() -> AsyncGenerator[dict[str, object], None]:
|
||||
yield {"ok": True, "tool_call": kwargs.get("tool_call")}
|
||||
|
||||
return _generator()
|
||||
|
||||
|
||||
def _extract_error_payload(chunk: object) -> dict[str, Any]:
|
||||
content = getattr(chunk, "content", None)
|
||||
if not isinstance(content, list) or not content:
|
||||
return {}
|
||||
first_block = content[0]
|
||||
text = getattr(first_block, "text", None)
|
||||
if not isinstance(text, str) and isinstance(first_block, dict):
|
||||
raw_text = first_block.get("text")
|
||||
text = raw_text if isinstance(raw_text, str) else None
|
||||
if not isinstance(text, str):
|
||||
return {}
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_default_write_does_not_require_approval() -> None:
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=False),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
{"tool_call": {"name": "calendar.write", "input": {"operation": "create"}}},
|
||||
_next_handler,
|
||||
):
|
||||
responses.append(chunk)
|
||||
|
||||
assert responses[0]["ok"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_pending_when_tool_requires_approval() -> None:
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
{"tool_call": {"name": "calendar_write", "input": {"operation": "create"}}},
|
||||
_next_handler,
|
||||
):
|
||||
responses.append(chunk)
|
||||
|
||||
payload = _extract_error_payload(responses[0])
|
||||
assert payload["error"]["code"] == "TOOL_PENDING_APPROVAL"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_passes_when_write_approved() -> None:
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
},
|
||||
approval_resolver=lambda _name, _args, _config: "approved",
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
{
|
||||
"tool_call": {
|
||||
"name": "calendar.write",
|
||||
"input": {
|
||||
"operation": "create",
|
||||
"_hitl": {"approval": "required"},
|
||||
},
|
||||
}
|
||||
},
|
||||
_next_handler,
|
||||
):
|
||||
responses.append(chunk)
|
||||
|
||||
assert responses[0]["ok"] is True
|
||||
sanitized_input = responses[0]["tool_call"]["input"]
|
||||
assert "_hitl" not in sanitized_input
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hitl_middleware_rejected_short_circuits() -> None:
|
||||
middleware = create_approval_middleware(
|
||||
config_by_name={
|
||||
"calendar_write": ToolConfig(
|
||||
name="calendar_write",
|
||||
group=ToolGroup.EXECUTE,
|
||||
approval=ToolApprovalConfig(required=True),
|
||||
)
|
||||
},
|
||||
approval_resolver=lambda _name, _args, _config: "rejected",
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for chunk in middleware(
|
||||
{"tool_call": {"name": "calendar_write", "input": {"operation": "create"}}},
|
||||
_next_handler,
|
||||
):
|
||||
responses.append(chunk)
|
||||
|
||||
payload = _extract_error_payload(responses[0])
|
||||
assert payload["error"]["code"] == "TOOL_REJECTED"
|
||||
@@ -9,7 +9,7 @@ from agentscope.tool import ToolResponse
|
||||
|
||||
from core.agentscope.tools.custom import memory as memory_module
|
||||
from models.memories import MemoryType
|
||||
from schemas.memories.memory_content import UserMemoryContent
|
||||
from schemas.domain.memory_content import UserMemoryContent
|
||||
|
||||
|
||||
def _decode_tool_response(response: ToolResponse) -> dict[str, object]:
|
||||
|
||||
@@ -9,7 +9,7 @@ from core.agentscope.prompts.system_prompt import (
|
||||
)
|
||||
from schemas.agent.forwarded_props import ClientTimeContext
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.user.context import UserContext, parse_profile_settings
|
||||
from schemas.shared.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
def _build_user_context(*, timezone_name: str = "Asia/Shanghai") -> UserContext:
|
||||
@@ -159,7 +159,7 @@ def test_build_system_prompt_keeps_sections_focused_without_language_duplication
|
||||
|
||||
|
||||
def test_build_system_prompt_includes_user_memory_section_for_router() -> None:
|
||||
from schemas.memories.memory_content import UserMemoryContent
|
||||
from schemas.domain.memory_content import UserMemoryContent
|
||||
|
||||
user_memory = UserMemoryContent()
|
||||
|
||||
@@ -175,7 +175,7 @@ def test_build_system_prompt_includes_user_memory_section_for_router() -> None:
|
||||
|
||||
|
||||
def test_build_system_prompt_includes_work_memory_section_for_worker() -> None:
|
||||
from schemas.memories.memory_content import WorkProfileContent
|
||||
from schemas.domain.memory_content import WorkProfileContent
|
||||
|
||||
work_memory = WorkProfileContent()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from uuid import UUID, uuid4
|
||||
import pytest
|
||||
|
||||
from models.automation_jobs import AutomationJob as OrmAutomationJob, ScheduleType
|
||||
from schemas.automation import (
|
||||
from schemas.domain.automation import (
|
||||
RuntimeConfig,
|
||||
ScheduleConfig,
|
||||
ScheduleRunAt,
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin
|
||||
from core.db.base_repository import BaseRepository
|
||||
|
||||
|
||||
class Widget(SoftDeleteMixin, Base):
|
||||
__tablename__ = "widgets"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
auth_users = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
Base.metadata.remove(auth_users)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_filters_soft_deleted(db_session: AsyncSession) -> None:
|
||||
repository = BaseRepository(db_session, Widget)
|
||||
widget_id = uuid4()
|
||||
|
||||
widget = Widget(id=widget_id, name="widget")
|
||||
db_session.add(widget)
|
||||
await db_session.commit()
|
||||
|
||||
found = await repository.get_by_id(widget_id)
|
||||
assert found is not None
|
||||
|
||||
deleted = await repository.soft_delete_by_id(widget_id)
|
||||
assert deleted is not None
|
||||
assert deleted.deleted_at is not None
|
||||
|
||||
missing = await repository.get_by_id(widget_id)
|
||||
assert missing is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_sets_timestamp(db_session: AsyncSession) -> None:
|
||||
repository = BaseRepository(db_session, Widget)
|
||||
widget_id = uuid4()
|
||||
|
||||
widget = Widget(id=widget_id, name="widget")
|
||||
db_session.add(widget)
|
||||
await db_session.commit()
|
||||
|
||||
deleted = await repository.soft_delete_by_id(widget_id)
|
||||
assert deleted is not None
|
||||
assert isinstance(deleted.deleted_at, datetime)
|
||||
deleted_at = deleted.deleted_at
|
||||
if deleted_at.tzinfo is None:
|
||||
deleted_at = deleted_at.replace(tzinfo=timezone.utc)
|
||||
assert deleted_at <= datetime.now(timezone.utc)
|
||||
@@ -1,134 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from models.profile import Profile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
"""Create in-memory SQLite engine for testing."""
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
"""Create a database session for testing."""
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_create(db_session: AsyncSession) -> None:
|
||||
"""Test creating a Profile model."""
|
||||
profile_id = uuid4()
|
||||
profile = Profile(
|
||||
id=profile_id,
|
||||
username="testuser",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(profile)
|
||||
|
||||
assert profile.id == profile_id
|
||||
assert profile.username == "testuser"
|
||||
assert profile.created_at is not None
|
||||
assert profile.updated_at is not None
|
||||
assert profile.deleted_at is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_get_by_id(db_session: AsyncSession) -> None:
|
||||
"""Test retrieving a Profile by ID."""
|
||||
profile_id = uuid4()
|
||||
profile = Profile(
|
||||
id=profile_id,
|
||||
username="testuser",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.get(Profile, profile_id)
|
||||
assert result is not None
|
||||
assert result.username == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_get_by_username(db_session: AsyncSession) -> None:
|
||||
"""Test retrieving a Profile by username."""
|
||||
profile = Profile(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Profile).where(Profile.username == "testuser")
|
||||
)
|
||||
found = result.scalar_one()
|
||||
assert found is not None
|
||||
assert found.username == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_update(db_session: AsyncSession) -> None:
|
||||
"""Test updating a Profile."""
|
||||
profile = Profile(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
bio="Old bio",
|
||||
)
|
||||
db_session.add(profile)
|
||||
await db_session.commit()
|
||||
|
||||
profile.bio = "New bio"
|
||||
await db_session.commit()
|
||||
await db_session.refresh(profile)
|
||||
|
||||
assert profile.bio == "New bio"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_profile_model_allows_duplicate_usernames(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
first = Profile(id=uuid4(), username="same_name")
|
||||
second = Profile(id=uuid4(), username="same_name")
|
||||
|
||||
db_session.add(first)
|
||||
db_session.add(second)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Profile).where(Profile.username == "same_name")
|
||||
)
|
||||
found = result.scalars().all()
|
||||
assert len(found) == 2
|
||||
@@ -1,32 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[4]
|
||||
APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
|
||||
|
||||
|
||||
def test_worker_commands_use_taskiq() -> None:
|
||||
content = APP_SCRIPT.read_text(encoding="utf-8")
|
||||
removed_runner = "uv run celery"
|
||||
|
||||
assert "uv run taskiq worker" in content
|
||||
assert "core.taskiq.app:critical_broker" in content
|
||||
assert "core.taskiq.app:default_broker" in content
|
||||
assert "core.taskiq.app:bulk_broker" in content
|
||||
assert 'pgrep -f "uv run taskiq worker core.taskiq.app:"' in content
|
||||
assert 'kill_pids_gracefully "taskiq workers"' in content
|
||||
assert "gunicorn" not in content
|
||||
assert removed_runner not in content
|
||||
|
||||
|
||||
def test_web_command_uses_uvicorn_only() -> None:
|
||||
content = APP_SCRIPT.read_text(encoding="utf-8")
|
||||
|
||||
assert "uv run uvicorn app:app" in content
|
||||
assert 'WEB_PORT="${SOCIAL_WEB__PORT:-5775}"' in content
|
||||
assert "SOCIAL_WEB__WORKERS" in content
|
||||
assert 'UVICORN_LOG_LEVEL="${UVICORN_LOG_LEVEL,,}"' in content
|
||||
assert "SOCIAL_WEB__GUNICORN__" not in content
|
||||
assert "uv run gunicorn" not in content
|
||||
@@ -1,36 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.schemas.consumer_registry import (
|
||||
AgentConsumerBinding,
|
||||
ConsumerRegistry,
|
||||
)
|
||||
from core.agentscope.schemas.pipeline_spec import ExecutorKind, PipelineSpec, StageSpec
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
def test_consumer_registry_rejects_duplicate_bits() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate visibility bit"):
|
||||
ConsumerRegistry(
|
||||
bindings=[
|
||||
AgentConsumerBinding(agent_type="router", bit=16),
|
||||
AgentConsumerBinding(agent_type="worker", bit=16),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_pipeline_spec_requires_non_empty_stages() -> None:
|
||||
with pytest.raises(ValueError, match="at least 1 item"):
|
||||
PipelineSpec(mode="worker", stages=[])
|
||||
|
||||
|
||||
def test_stage_spec_normalizes_stage_name() -> None:
|
||||
spec = StageSpec(
|
||||
stage_name=" Worker ",
|
||||
agent_type=AgentType.WORKER,
|
||||
executor_kind=ExecutorKind.REACT,
|
||||
)
|
||||
|
||||
assert spec.stage_name == "worker"
|
||||
assert spec.agent_type == AgentType.WORKER
|
||||
@@ -2,23 +2,25 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from schemas.automation.config import AutomationJobConfig, default_memory_job_config
|
||||
from schemas.domain.automation import AutomationJobConfig
|
||||
from v1.auth.automation_static_config import load_static_automation_job_config
|
||||
|
||||
|
||||
def test_default_memory_job_config_has_expected_defaults() -> None:
|
||||
config = default_memory_job_config()
|
||||
def test_memory_extraction_static_config_has_expected_defaults() -> None:
|
||||
config = load_static_automation_job_config(config_name="memory_extraction")
|
||||
|
||||
assert config.agent_type.value == "memory"
|
||||
assert config.model_code == "qwen3.5-flash"
|
||||
assert "memory.write" in (config.enabled_tools or [])
|
||||
assert "memory.forget" in (config.enabled_tools or [])
|
||||
assert config.context is not None
|
||||
assert config.context.source.value == "latest_chat"
|
||||
assert config.schedule is not None
|
||||
assert config.schedule.type.value == "daily"
|
||||
|
||||
|
||||
def test_automation_job_config_rejects_non_flash_model() -> None:
|
||||
with pytest.raises(ValueError, match="model_code must be qwen3.5-flash"):
|
||||
def test_automation_job_config_rejects_missing_weekdays_for_weekly() -> None:
|
||||
with pytest.raises(ValueError, match="weekdays is required"):
|
||||
AutomationJobConfig.model_validate(
|
||||
{
|
||||
"agent_type": "memory",
|
||||
"model_code": "qwen-plus",
|
||||
"enabled_tools": ["calendar.read"],
|
||||
"input_template": "x",
|
||||
"context": {
|
||||
@@ -26,5 +28,9 @@ def test_automation_job_config_rejects_non_flash_model() -> None:
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
"schedule": {
|
||||
"type": "weekly",
|
||||
"run_at": {"hour": 9, "minute": 0},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from schemas.messages.chat_message import AgentChatMessage
|
||||
from schemas.domain.chat_message import AgentChatMessage
|
||||
|
||||
|
||||
def test_agent_chat_message_schema_matches_messages_columns() -> None:
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.dependencies import RedisEventStream, TaskiqQueueClient
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, str] = {}
|
||||
self.delete_calls: list[str] = []
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
*,
|
||||
nx: bool = False,
|
||||
ex: int | None = None,
|
||||
) -> bool:
|
||||
del ex
|
||||
if nx and key in self.store:
|
||||
return False
|
||||
self.store[key] = value
|
||||
return True
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self.store.get(key)
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
self.delete_calls.append(key)
|
||||
existed = 1 if key in self.store else 0
|
||||
self.store.pop(key, None)
|
||||
return existed
|
||||
|
||||
|
||||
class _FakeAsyncResult:
|
||||
def __init__(self, task_id: str) -> None:
|
||||
self.task_id = task_id
|
||||
|
||||
|
||||
class _FakeRedisStreamClient:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
resolved_client = {"value": False}
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
assert payload["command"] == "run"
|
||||
return _FakeAsyncResult("task-123")
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
resolved_client["value"] = True
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
|
||||
|
||||
assert resolved_client["value"] is True
|
||||
assert task_id == "task-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_resume_dedup_returns_existing_task_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
resolved_client = {"value": False}
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
return _FakeAsyncResult("new-task-id")
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
resolved_client["value"] = True
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert resolved_client["value"] is True
|
||||
assert task_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
|
||||
attempts = {"count": 0}
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_get(key: str) -> str | None:
|
||||
attempts["count"] += 1
|
||||
if attempts["count"] > 1:
|
||||
fake_redis.store[key] = "existing-task-id"
|
||||
return fake_redis.store.get(key)
|
||||
|
||||
async def _fake_sleep(_: float) -> None:
|
||||
return None
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
raise AssertionError("should not enqueue when dedup task id appears")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(fake_redis, "get", _fake_get)
|
||||
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert task_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_failure_cleans_dedup_lock(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
dedup_key = "resume:session-1:call-1"
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
|
||||
del payload
|
||||
raise RuntimeError("enqueue failed")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
with pytest.raises(RuntimeError, match="enqueue failed"):
|
||||
await client.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": "session-1",
|
||||
"tool_call_id": "call-1",
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
assert redis_key in fake_redis.delete_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_uses_critical_queue_when_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
raise AssertionError("default queue should not be selected")
|
||||
|
||||
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
return _FakeAsyncResult("critical-task-id")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
|
||||
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "run", "queue": "critical"},
|
||||
dedup_key=None,
|
||||
)
|
||||
|
||||
assert task_id == "critical-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_uses_bulk_queue_when_requested(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedis:
|
||||
return fake_redis
|
||||
|
||||
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
raise AssertionError("default queue should not be selected")
|
||||
|
||||
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
|
||||
return _FakeAsyncResult("bulk-task-id")
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
|
||||
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
|
||||
|
||||
client = TaskiqQueueClient()
|
||||
task_id = await client.enqueue(
|
||||
command={"command": "run", "queue": "bulk"},
|
||||
dedup_key=None,
|
||||
)
|
||||
|
||||
assert task_id == "bulk-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_stream_caps_block_ms_below_socket_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
|
||||
return _FakeRedisStreamClient()
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 5000)
|
||||
monkeypatch.setattr(deps.config.redis, "socket_timeout", 1.0)
|
||||
|
||||
stream = RedisEventStream()
|
||||
bus = await stream._get_bus()
|
||||
|
||||
assert bus._block_ms == 900
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_stream_uses_configured_block_ms_when_safe(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from v1.agent import dependencies as deps
|
||||
|
||||
async def _fake_get_or_init_client() -> _FakeRedisStreamClient:
|
||||
return _FakeRedisStreamClient()
|
||||
|
||||
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
|
||||
monkeypatch.setattr(deps.config.agent_runtime, "redis_stream_block_ms", 200)
|
||||
monkeypatch.setattr(deps.config.redis, "socket_timeout", 2.0)
|
||||
|
||||
stream = RedisEventStream()
|
||||
bus = await stream._get_bus()
|
||||
|
||||
assert bus._block_ms == 200
|
||||
@@ -12,7 +12,7 @@ import pytest
|
||||
import v1.agent.service as agent_service_module
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
from schemas.domain.chat_message import AgentChatMessageMetadata
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from models.automation_jobs import ScheduleType
|
||||
from v1.auth.registration_bootstrap import (
|
||||
compute_first_run_at_utc,
|
||||
)
|
||||
from schemas.automation import ScheduleConfig, ScheduleRunAt
|
||||
from schemas.domain.automation import ScheduleConfig, ScheduleRunAt
|
||||
|
||||
|
||||
def test_compute_first_run_at_utc_from_asia_shanghai() -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
from schemas.automation import (
|
||||
from schemas.domain.automation import (
|
||||
AgentTool,
|
||||
AutomationJobConfig,
|
||||
ContextSource,
|
||||
|
||||
@@ -5,7 +5,8 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from schemas.automation import AgentTool, AutomationJobConfig
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from schemas.domain.automation import AutomationJobConfig
|
||||
from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobResponse,
|
||||
|
||||
@@ -17,7 +17,7 @@ from v1.automation_jobs.schemas import (
|
||||
AutomationJobCreateRequest,
|
||||
AutomationJobUpdateRequest,
|
||||
)
|
||||
from schemas.automation import (
|
||||
from schemas.domain.automation import (
|
||||
AgentTool,
|
||||
AutomationJobConfig,
|
||||
ContextSource,
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from schemas.user.context import UserContext
|
||||
from schemas.shared.user import UserContext
|
||||
|
||||
from v1.friendships.schemas import (
|
||||
FriendRequestCreate,
|
||||
|
||||
@@ -16,7 +16,7 @@ def test_inbox_message_response_schema() -> None:
|
||||
sender_id=uuid4(),
|
||||
message_type=InboxMessageType.CALENDAR,
|
||||
schedule_item_id=uuid4(),
|
||||
content="Join my calendar",
|
||||
content={"type": "invite", "permission": 1, "action": "pending"},
|
||||
is_read=False,
|
||||
status=InboxMessageStatus.PENDING,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
|
||||
@@ -22,7 +22,7 @@ def _build_message(
|
||||
status: InboxMessageModelStatus = InboxMessageModelStatus.PENDING,
|
||||
message_type: InboxMessageModelType = InboxMessageModelType.CALENDAR,
|
||||
schedule_item_id: UUID | None = None,
|
||||
content: str = '{"permission": 7}',
|
||||
content: dict[str, object] = {"permission": 7},
|
||||
) -> InboxMessage:
|
||||
message = MagicMock(spec=InboxMessage)
|
||||
message.id = message_id
|
||||
|
||||
@@ -500,3 +500,43 @@ async def test_list_by_date_range_rolls_back_when_query_fails_after_archive(
|
||||
assert exc_info.value.status_code == 503
|
||||
mock_session.rollback.assert_awaited_once()
|
||||
mock_session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_maps_legacy_completed_to_archived(
|
||||
mock_session: AsyncMock,
|
||||
mock_inbox_repository: MagicMock,
|
||||
) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
item = _create_mock_schedule_item()
|
||||
setattr(item, "status", "completed")
|
||||
service = ScheduleItemService(
|
||||
repository=FakeRepo(item),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=user_id),
|
||||
inbox_repository=mock_inbox_repository,
|
||||
)
|
||||
|
||||
result = await service.get_by_id(item.id)
|
||||
|
||||
assert result.status == ScheduleItemStatus.ARCHIVED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_maps_legacy_canceled_to_archived(
|
||||
mock_session: AsyncMock,
|
||||
mock_inbox_repository: MagicMock,
|
||||
) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
item = _create_mock_schedule_item()
|
||||
setattr(item, "status", "canceled")
|
||||
service = ScheduleItemService(
|
||||
repository=FakeRepo(item),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=user_id),
|
||||
inbox_repository=mock_inbox_repository,
|
||||
)
|
||||
|
||||
result = await service.get_by_id(item.id)
|
||||
|
||||
assert result.status == ScheduleItemStatus.ARCHIVED
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
@@ -50,6 +49,11 @@ class FakeInboxRepo:
|
||||
return self._inbox
|
||||
return None
|
||||
|
||||
async def get_calendar_invite(
|
||||
self, schedule_item_id: UUID, recipient_id: UUID
|
||||
) -> InboxMessage | None:
|
||||
return await self.get_pending_calendar_invite(schedule_item_id, recipient_id)
|
||||
|
||||
async def create(self, data: dict) -> InboxMessage:
|
||||
return MagicMock()
|
||||
|
||||
@@ -80,6 +84,9 @@ def mock_session() -> AsyncMock:
|
||||
@pytest.fixture
|
||||
def mock_repo() -> MagicMock:
|
||||
repo = MagicMock()
|
||||
repo.get_subscription = AsyncMock(return_value=None)
|
||||
repo.update_subscription_status = AsyncMock(return_value=None)
|
||||
repo.archive_expired_subscribed_items = AsyncMock(return_value=0)
|
||||
repo.create_subscription = AsyncMock(return_value=MagicMock())
|
||||
return repo
|
||||
|
||||
@@ -196,6 +203,9 @@ async def test_list_by_date_range_with_subscriptions(
|
||||
|
||||
mock_repo.list_by_date_range = AsyncMock(return_value=[owned_item])
|
||||
mock_repo.get_user_subscriptions = AsyncMock(return_value=[subscription])
|
||||
mock_repo.list_subscribed_items_by_date_range = AsyncMock(
|
||||
return_value=[(subscribed_item, subscription)]
|
||||
)
|
||||
mock_repo.get_by_id = AsyncMock(return_value=subscribed_item)
|
||||
|
||||
service = ScheduleItemService(
|
||||
@@ -214,7 +224,6 @@ async def test_list_by_date_range_with_subscriptions(
|
||||
|
||||
result = await service.list_by_date_range(request)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].is_owner is True
|
||||
assert result[1].is_owner is False
|
||||
assert result[1].permission == 1
|
||||
assert len(result) == 1
|
||||
assert result[0].is_owner is False
|
||||
assert result[0].permission == 1
|
||||
|
||||
Reference in New Issue
Block a user