f0af44d840
- Update agent router/service/repository with new endpoints - Update auth routes with phone-based authentication - Update users service with new phone lookup - Update schedule_items with new schemas - Update message schemas with visibility support - Update settings with new automation scheduler config - Update CLI with new commands - Update tests to match new API contracts
181 lines
5.3 KiB
Python
181 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
from types import SimpleNamespace
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from models.agent_chat_message import AgentChatMessageRole
|
|
from sqlalchemy import select
|
|
from models.agent_chat_message import AgentChatMessage
|
|
from v1.agent.repository import AgentRepository
|
|
|
|
|
|
class _ExecuteResult:
    """Stand-in for a query result that carries exactly one scalar.

    Only ``scalar_one_or_none`` is provided; it hands back whatever object
    the fake was built with, ``None`` included.
    """

    def __init__(self, value: object) -> None:
        # Kept as-is: no copy, no validation.
        self._scalar = value

    def scalar_one_or_none(self) -> object:
        """Return the wrapped object unchanged."""
        scalar = self._scalar
        return scalar


class _FakeSession:
    """Minimal async session double used by the ``persist_user_message`` tests.

    Every ``execute`` yields the same pre-configured session row, ``add``
    remembers what it was given, and ``flush`` only flips a flag so tests
    can check that it was awaited.
    """

    def __init__(self, session_row: object) -> None:
        # True once ``flush`` has been awaited at least once.
        self.flushed = False
        # Objects passed to ``add``, in call order.
        self.added: list[object] = []
        self.session_row = session_row

    async def execute(self, stmt):  # noqa: ANN001
        """Ignore the statement and wrap the configured session row."""
        del stmt  # the query text is irrelevant to this fake
        return _ExecuteResult(self.session_row)

    def add(self, obj: object) -> None:
        """Remember ``obj`` so the test can inspect what was added."""
        self.added += [obj]

    async def flush(self) -> None:
        """Record that a flush happened; nothing is written anywhere."""
        self.flushed = True


@pytest.mark.asyncio
async def test_snapshot_message_returns_raw_db_columns() -> None:
    """``_to_snapshot_message`` should echo the stored column values."""
    repo = AgentRepository(session=SimpleNamespace())  # type: ignore[arg-type]
    created = datetime.now(timezone.utc)
    row = SimpleNamespace(
        id=uuid4(),
        session_id=uuid4(),
        seq=7,
        role=AgentChatMessageRole.TOOL,
        content='{"offloaded":true}',
        model_code=None,
        tool_name=None,
        input_tokens=0,
        output_tokens=0,
        cost=0,
        latency_ms=None,
        metadata_json={"tool_call_id": "call-1"},
        created_at=created,
    )

    snapshot = await repo._to_snapshot_message(row)  # type: ignore[arg-type]

    # Each payload key must carry the raw column value unchanged.
    expected = {
        "seq": 7,
        "role": "tool",
        "content": '{"offloaded":true}',
        "metadata": {"tool_call_id": "call-1"},
    }
    for key, want in expected.items():
        assert snapshot[key] == want
    assert "timestamp" in snapshot


@pytest.mark.asyncio
async def test_persist_user_message_sets_session_title_when_empty() -> None:
    """An untitled session adopts the whitespace-stripped first user message."""
    row = SimpleNamespace(
        message_count=0,
        title=None,
        last_activity_at=datetime.now(timezone.utc),
    )
    db = _FakeSession(row)
    repo = AgentRepository(session=db)  # type: ignore[arg-type]

    await repo.persist_user_message(
        session_id=str(uuid4()),
        content=" 请帮我安排明天下午开会 ",
        metadata=None,
        visibility_mask=1,
    )

    assert row.title == "请帮我安排明天下午开会"
    assert row.message_count == 1
    assert db.flushed is True


@pytest.mark.asyncio
async def test_persist_user_message_keeps_existing_session_title() -> None:
    """A session that already has a title keeps it; only the count advances."""
    row = SimpleNamespace(
        message_count=1,
        title="已有标题",
        last_activity_at=datetime.now(timezone.utc),
    )
    repo = AgentRepository(session=_FakeSession(row))  # type: ignore[arg-type]

    await repo.persist_user_message(
        session_id=str(uuid4()),
        content="新的消息内容",
        metadata=None,
        visibility_mask=1,
    )

    assert row.title == "已有标题"
    assert row.message_count == 2


class _ScalarRows:
    """Stand-in for the object returned by ``Result.scalars()``.

    ``all`` returns the very list supplied at construction (not a copy).
    """

    def __init__(self, rows: list[object]) -> None:
        self._items = rows

    def all(self) -> list[object]:
        """Return the stored row list itself."""
        items = self._items
        return items


class _ExecuteRowsResult:
    """Stand-in for a multi-row query result.

    ``scalars`` wraps the configured rows in a fresh ``_ScalarRows`` on every
    call.
    """

    def __init__(self, rows: list[object]) -> None:
        self._items = rows

    def scalars(self) -> _ScalarRows:
        """Expose the stored rows through a ``_ScalarRows`` wrapper."""
        wrapper = _ScalarRows(self._items)
        return wrapper


class _FakeHistorySession:
    """Scripted async session double for ``get_history_day``.

    The reply to ``execute`` depends only on how many times it has been
    awaited; the statement itself is ignored:

    * 1st call: a scalar result wrapping 2026-03-16 11:00 UTC;
    * 2nd call: a row result holding one USER message ("hello") created at
      that same instant;
    * every later call: a scalar result wrapping a fresh UUID.

    NOTE(review): presumably the later calls answer an "older messages
    exist?" probe, which is why the test expects ``hasMore`` to be True —
    confirm against ``AgentRepository.get_history_day``.
    """

    def __init__(self) -> None:
        # Number of ``execute`` awaits so far; selects the scripted reply.
        self._calls = 0

    async def execute(self, stmt):  # noqa: ANN001
        del stmt  # replies are scripted by call order, not by query
        self._calls += 1

        if self._calls > 2:
            return _ExecuteResult(uuid4())

        anchor = datetime(2026, 3, 16, 11, 0, tzinfo=timezone.utc)
        if self._calls == 1:
            return _ExecuteResult(anchor)

        greeting = SimpleNamespace(
            id=uuid4(),
            seq=1,
            role=AgentChatMessageRole.USER,
            content="hello",
            model_code=None,
            tool_name=None,
            input_tokens=0,
            output_tokens=0,
            cost=0,
            latency_ms=None,
            metadata_json=None,
            created_at=anchor,
        )
        return _ExecuteRowsResult([greeting])


@pytest.mark.asyncio
async def test_get_history_day_uses_target_day_queries_only() -> None:
    """The newest day's history is assembled from the scripted fake queries."""
    repo = AgentRepository(session=_FakeHistorySession())  # type: ignore[arg-type]

    result = await repo.get_history_day(session_id=str(uuid4()), before=None)

    assert result is not None
    assert result["day"] == "2026-03-16"
    assert result["hasMore"] is True
    day_messages = result["messages"]
    assert isinstance(day_messages, list)
    assert len(day_messages) == 1


def test_apply_visibility_filter_adds_bitwise_expression() -> None:
    """The filtered SQL should mention ``visibility_mask`` and a bitwise ``&``."""
    repo = AgentRepository(session=SimpleNamespace())  # type: ignore[arg-type]

    # Render the statement once and inspect the resulting SQL text.
    sql = str(
        repo._apply_visibility_filter(
            stmt=select(AgentChatMessage),
            visibility_mask=1,
        )
    )

    assert "visibility_mask" in sql
    assert "&" in sql
|