# File: social-app/backend/tests/unit/v1/agent/test_repository.py
# Python source, 177 lines, 5.2 KiB (header captured from a web file viewer).

from __future__ import annotations
from datetime import datetime, timezone
from types import SimpleNamespace
from uuid import uuid4
import pytest
from models.agent_chat_message import AgentChatMessageRole
from v1.agent.repository import AgentRepository
class _ExecuteResult:
    """Test double for the object returned by ``_FakeSession.execute``.

    It wraps a single preset value and exposes it through
    ``scalar_one_or_none``, the only accessor defined here.
    """

    def __init__(self, value: object) -> None:
        """Remember *value* so it can be returned later."""
        self._value = value

    def scalar_one_or_none(self) -> object:
        """Return the stored value unchanged (it may be ``None``)."""
        return self._value
class _FakeSession:
    """In-memory test double for the session handed to ``AgentRepository``.

    Every ``execute`` call answers with the same preset row, ``add`` records
    the objects it receives, and ``flush`` only flips a flag, so tests can
    inspect what happened without a database.
    """

    def __init__(self, session_row: object) -> None:
        """Set up the fake with the row that every ``execute`` call yields."""
        self.session_row = session_row
        # Objects passed to ``add``, in call order.
        self.added: list[object] = []
        # Becomes True once ``flush`` has been awaited.
        self.flushed = False

    async def execute(self, stmt):  # noqa: ANN001
        """Ignore *stmt* and return a result wrapping the preset row."""
        del stmt  # The statement is deliberately unused by the fake.
        return _ExecuteResult(self.session_row)

    def add(self, obj: object) -> None:
        """Record *obj* instead of staging it for persistence."""
        self.added.append(obj)

    async def flush(self) -> None:
        """Mark the session as flushed; no I/O is performed."""
        self.flushed = True
class _FakeToolResultStorage:
def __init__(self, payload: dict[str, object] | None) -> None:
self._payload = payload
async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
del bucket, path
return self._payload
@pytest.mark.asyncio
async def test_tool_message_hydrates_content_from_object_storage() -> None:
    """A TOOL message with a stored payload exposes the stored result content.

    The message's inline content is only a placeholder; the snapshot is
    expected to carry the ``content`` found under ``result`` in the payload
    returned by the storage fake, plus the tool call id from the metadata.
    """
    repository = AgentRepository(
        session=SimpleNamespace(),  # type: ignore[arg-type]
        tool_result_storage=_FakeToolResultStorage(
            {
                "toolName": "front.navigate_to_route",
                "result": {"ok": True, "applied": True, "content": "已跳转"},
            }
        ),
    )
    message = SimpleNamespace(
        id=uuid4(),
        role=AgentChatMessageRole.TOOL,
        created_at=datetime.now(timezone.utc),
        content='{"offloaded":true}',
        metadata_json={
            "tool_call_id": "call-1",
            "storage_bucket": "private",
            "storage_path": "tool-results/run-1/call-1.json",
        },
    )

    payload = await repository._to_snapshot_message(message)  # type: ignore[arg-type]

    assert payload["toolCallId"] == "call-1"
    assert payload["content"] == "已跳转"
@pytest.mark.asyncio
async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
    """A TOOL message keeps its inline content when storage yields ``None``.

    The metadata still names a bucket and path, but the storage fake returns
    no payload, so the snapshot must carry the message's own content.
    """
    repository = AgentRepository(
        session=SimpleNamespace(),  # type: ignore[arg-type]
        tool_result_storage=_FakeToolResultStorage(None),
    )
    message = SimpleNamespace(
        id=uuid4(),
        role=AgentChatMessageRole.TOOL,
        created_at=datetime.now(timezone.utc),
        content="inline-tool-content",
        metadata_json={
            "tool_call_id": "call-2",
            "storage_bucket": "private",
            "storage_path": "tool-results/run-1/call-2.json",
        },
    )

    payload = await repository._to_snapshot_message(message)  # type: ignore[arg-type]

    assert payload["toolCallId"] == "call-2"
    assert payload["content"] == "inline-tool-content"
@pytest.mark.asyncio
async def test_user_message_snapshot_includes_renderable_attachments() -> None:
    """A USER message snapshot carries the attachments listed in its metadata.

    The repository is built without a tool-result storage; the snapshot is
    expected to report role ``"user"``, the original text, and the single
    attachment entry unchanged.
    """
    repository = AgentRepository(
        session=SimpleNamespace(),  # type: ignore[arg-type]
    )
    message = SimpleNamespace(
        id=uuid4(),
        role=AgentChatMessageRole.USER,
        created_at=datetime.now(timezone.utc),
        content="请分析这张图",
        metadata_json={
            "attachments": [
                {
                    "bucket": "agent-chat-attachments",
                    "path": "agent-inputs/u1/t1/r1/m1/att-1.png",
                    "mimeType": "image/png",
                }
            ]
        },
    )

    payload = await repository._to_snapshot_message(message)  # type: ignore[arg-type]

    assert payload["role"] == "user"
    assert payload["content"] == "请分析这张图"
    assert payload["attachments"] == [
        {
            "bucket": "agent-chat-attachments",
            "path": "agent-inputs/u1/t1/r1/m1/att-1.png",
            "mimeType": "image/png",
        }
    ]
@pytest.mark.asyncio
async def test_persist_user_message_sets_session_title_when_empty() -> None:
    """Persisting a user message titles an untitled session.

    Starting from a session row with no title and zero messages, the test
    expects the title to equal the message text without its surrounding
    whitespace, the message count to reach 1, and the session to be flushed.
    """
    session_id = str(uuid4())
    session_row = SimpleNamespace(
        message_count=0,
        title=None,
        last_activity_at=datetime.now(timezone.utc),
    )
    fake_session = _FakeSession(session_row)
    repository = AgentRepository(session=fake_session)  # type: ignore[arg-type]

    await repository.persist_user_message(
        session_id=session_id,
        run_id="run-1",
        content_text=" 请帮我安排明天下午开会 ",
        metadata=None,
    )

    assert session_row.title == "请帮我安排明天下午开会"
    assert session_row.message_count == 1
    assert fake_session.flushed is True
@pytest.mark.asyncio
async def test_persist_user_message_keeps_existing_session_title() -> None:
    """Persisting a user message leaves an existing session title alone.

    Starting from a session row that already has a title and one message,
    the test expects the title to be unchanged and the count to reach 2.
    """
    session_id = str(uuid4())
    session_row = SimpleNamespace(
        message_count=1,
        title="已有标题",
        last_activity_at=datetime.now(timezone.utc),
    )
    fake_session = _FakeSession(session_row)
    repository = AgentRepository(session=fake_session)  # type: ignore[arg-type]

    await repository.persist_user_message(
        session_id=session_id,
        run_id="run-2",
        content_text="新的消息内容",
        metadata=None,
    )

    assert session_row.title == "已有标题"
    assert session_row.message_count == 2