feat(agent): 增强多模态链路与工具调用能力

This commit is contained in:
zl-q
2026-03-12 00:18:45 +08:00
parent 18db6c50e7
commit 21ba8e4a44
35 changed files with 2057 additions and 829 deletions
@@ -10,6 +10,31 @@ from models.agent_chat_message import AgentChatMessageRole
from v1.agent.repository import AgentRepository
class _ExecuteResult:
def __init__(self, value: object) -> None:
self._value = value
def scalar_one_or_none(self) -> object:
return self._value
class _FakeSession:
def __init__(self, session_row: object) -> None:
self.session_row = session_row
self.added: list[object] = []
self.flushed = False
async def execute(self, stmt): # noqa: ANN001
del stmt
return _ExecuteResult(self.session_row)
def add(self, obj: object) -> None:
self.added.append(obj)
async def flush(self) -> None:
self.flushed = True
class _FakeToolResultStorage:
def __init__(self, payload: dict[str, object] | None) -> None:
self._payload = payload
@@ -104,3 +129,48 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
"mimeType": "image/png",
}
]
@pytest.mark.asyncio
async def test_persist_user_message_sets_session_title_when_empty() -> None:
session_id = str(uuid4())
session_row = SimpleNamespace(
message_count=0,
title=None,
last_activity_at=datetime.now(timezone.utc),
)
fake_session = _FakeSession(session_row)
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
await repository.persist_user_message(
session_id=session_id,
run_id="run-1",
content_text=" 请帮我安排明天下午开会 ",
metadata=None,
)
assert session_row.title == "请帮我安排明天下午开会"
assert session_row.message_count == 1
assert fake_session.flushed is True
@pytest.mark.asyncio
async def test_persist_user_message_keeps_existing_session_title() -> None:
session_id = str(uuid4())
session_row = SimpleNamespace(
message_count=1,
title="已有标题",
last_activity_at=datetime.now(timezone.utc),
)
fake_session = _FakeSession(session_row)
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
await repository.persist_user_message(
session_id=session_id,
run_id="run-2",
content_text="新的消息内容",
metadata=None,
)
assert session_row.title == "已有标题"
assert session_row.message_count == 2
@@ -175,3 +175,53 @@ async def test_enqueue_resume_accepts_valid_tool_contract(
assert result.task_id == "task-resume-1"
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
assert result.run_id == "run-resume-1"
@pytest.mark.asyncio
async def test_stream_events_retries_on_redis_timeout(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _acquire(*, user_id: str) -> bool:
del user_id
return True
async def _release(*, user_id: str) -> None:
del user_id
monkeypatch.setattr(agent_router, "_acquire_sse_slot", _acquire)
monkeypatch.setattr(agent_router, "_release_sse_slot", _release)
class _Request:
async def is_disconnected(self) -> bool:
return False
class _Service:
def __init__(self) -> None:
self.calls = 0
async def stream_events(self, **kwargs): # noqa: ANN003
del kwargs
self.calls += 1
if self.calls == 1:
raise RuntimeError("Timeout reading from localhost:6379")
if self.calls == 2:
return [{"id": "1-0", "event": {"type": "RUN_FINISHED"}}]
return []
response = await agent_router.stream_events(
request=cast(Any, _Request()),
thread_id="00000000-0000-0000-0000-000000000001",
service=cast(Any, _Service()),
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
last_event_id=None,
idle_limit=2,
)
chunks: list[str] = []
async for chunk in response.body_iterator:
chunks.append(str(chunk))
if any("RUN_FINISHED" in item for item in chunks):
break
merged = "".join(chunks)
assert "event: RUN_FINISHED" in merged
@@ -124,6 +124,19 @@ class _FakeAttachmentStorage:
return path
class _AlwaysFailAttachmentStorage:
async def upload_bytes(
self,
*,
bucket: str,
path: str,
content: bytes,
content_type: str,
) -> str:
del bucket, path, content, content_type
raise RuntimeError("upload failed")
def _user() -> CurrentUser:
return CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
@@ -317,6 +330,54 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
assert isinstance(attachments[0]["path"], str)
async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
monkeypatch,
) -> None:
monkeypatch.setattr(
agent_service_module.config.storage, "bucket", "agent-test-bucket"
)
repository = _FakeRepository()
service = AgentService(
repository=repository,
queue=_FakeQueue(),
stream=_FakeStream(),
attachment_storage=_AlwaysFailAttachmentStorage(),
)
run_input = RunAgentInput.model_validate(
{
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-with-image-fail",
"state": {},
"messages": [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "帮我看下这张图"},
{
"type": "binary",
"data": "aGVsbG8=",
"mimeType": "image/png",
},
],
}
],
"tools": [],
"context": [],
"forwardedProps": {},
}
)
try:
await service.enqueue_run(run_input=run_input, current_user=_user())
raise AssertionError("expected HTTPException")
except HTTPException as exc:
assert exc.status_code == 502
assert exc.detail == "Failed to upload attachment"
assert repository.persisted_user_messages == []
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
service = AgentService(
repository=_FakeRepository(),