Files
social-app/backend/tests/unit/v1/agent/test_service.py
T

429 lines
13 KiB
Python

from __future__ import annotations
from datetime import date
from types import SimpleNamespace
from uuid import UUID
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
import v1.agent.service as agent_service_module
from v1.agent.service import AgentService, AsrService
class _FakeRepository:
    """In-memory stand-in passed as ``repository=`` to ``AgentService``.

    It records which mutating calls were made (commit, rollback, delete,
    create, persist) so tests can assert on them afterwards. It serves canned
    answers for the read methods.
    """

    def __init__(self) -> None:
        # Flags/values flipped by the corresponding async methods below.
        self.committed = False
        self.rolled_back = False
        self.created_with_session_id: str | None = None
        self.deleted_session_id: str | None = None
        self.persisted_user_messages: list[dict[str, object]] = []

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner of the single known session, else raise a 404."""
        if session_id != "00000000-0000-0000-0000-000000000001":
            raise HTTPException(status_code=404, detail="Session not found")
        return "00000000-0000-0000-0000-000000000001"

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Record the requested session id and echo it back.

        A falsy ``session_id`` falls back to a fixed generated id.
        """
        del user_id
        self.created_with_session_id = session_id
        if session_id:
            return session_id
        return "00000000-0000-0000-0000-000000000999"

    async def commit(self) -> None:
        """Mark the fake transaction as committed."""
        self.committed = True

    async def rollback(self) -> None:
        """Mark the fake transaction as rolled back."""
        self.rolled_back = True

    async def delete_session(self, *, session_id: str) -> None:
        """Remember which session id deletion was requested for."""
        self.deleted_session_id = session_id

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        """Return one canned history day dated 2026-03-06.

        Returns ``None`` when paging to a ``before`` cursor on or earlier
        than that day.
        """
        del session_id
        last_day = date(2026, 3, 6)
        if before is None or before > last_day:
            return {
                "day": "2026-03-06",
                "hasMore": False,
                "messages": [{"id": "m1", "role": "assistant", "content": "hello"}],
            }
        return None

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Always report the single known session as the user's latest one."""
        del user_id
        return "00000000-0000-0000-0000-000000000001"

    async def persist_user_message(
        self,
        *,
        session_id: str,
        run_id: str,
        content_text: str,
        metadata: dict[str, object] | None,
    ) -> None:
        """Append the persisted user message as a plain dict for later inspection."""
        record: dict[str, object] = {
            "session_id": session_id,
            "run_id": run_id,
            "content_text": content_text,
            "metadata": metadata,
        }
        self.persisted_user_messages.append(record)
class _FakeQueue:
    """Queue stand-in whose ``enqueue`` always succeeds with task id ``"task-1"``."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Ignore both arguments and hand back the fixed task id."""
        del dedup_key, command
        return "task-1"
class _FailingQueue:
    """Queue stand-in whose ``enqueue`` always raises ``RuntimeError``."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Simulate a broker failure regardless of the arguments.

        Raises:
            RuntimeError: always, with the message ``"enqueue failed"``.
        """
        del dedup_key, command
        raise RuntimeError("enqueue failed")
class _FakeStream:
    """Event-stream stand-in that yields a single ``RUN_STARTED`` event."""

    async def read(
        self, *, session_id: str, last_event_id: str | None
    ) -> list[dict[str, object]]:
        """Return one event whose ``cursor`` echoes the caller's ``last_event_id``."""
        del session_id
        started = {"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
        return [started]
class _FakeAttachmentStorage:
    """Attachment-storage stand-in that records every upload instead of sending it."""

    def __init__(self) -> None:
        # One dict per ``upload_bytes`` call, in call order.
        self.calls: list[dict[str, object]] = []

    async def upload_bytes(
        self,
        *,
        bucket: str,
        path: str,
        content: bytes,
        content_type: str,
    ) -> str:
        """Record the upload arguments and return ``path`` unchanged."""
        self.calls.append(
            dict(bucket=bucket, path=path, content=content, content_type=content_type)
        )
        return path
def _user() -> CurrentUser:
    """Build the test user.

    Its id matches the owner id that ``_FakeRepository`` reports for the
    known session.
    """
    user_id = UUID("00000000-0000-0000-0000-000000000001")
    return CurrentUser(id=user_id, email="user@example.com")
def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
    """Build a minimal ``RunAgentInput`` carrying one plain-text ``"hello"`` user message."""
    payload = {
        "threadId": thread_id,
        "runId": run_id,
        "state": {},
        "messages": [{"id": "u1", "role": "user", "content": "hello"}],
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }
    return RunAgentInput.model_validate(payload)
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
    """Resuming the same run twice must report the same task id.

    The shared ``_FakeQueue`` returns the constant ``"task-1"`` for every
    call, so comparing task ids with it could never fail. This test instead
    uses a queue that issues a new id for every distinct enqueue and reuses an
    id only for a repeated non-``None`` ``dedup_key``. The assertion therefore
    only holds if the service actually deduplicates the second resume.
    """

    class _DedupQueue:
        # Deduplicating queue model: same dedup_key -> same task id,
        # anything else -> a fresh "task-N" id.
        def __init__(self) -> None:
            self.calls = 0
            self._task_by_key: dict[str, str] = {}

        async def enqueue(
            self, *, command: dict[str, object], dedup_key: str | None
        ) -> str:
            del command
            if dedup_key is not None and dedup_key in self._task_by_key:
                return self._task_by_key[dedup_key]
            self.calls += 1
            task_id = f"task-{self.calls}"
            if dedup_key is not None:
                self._task_by_key[dedup_key] = task_id
            return task_id

    thread_id = "00000000-0000-0000-0000-000000000001"
    service = AgentService(
        repository=_FakeRepository(),
        queue=_DedupQueue(),
        stream=_FakeStream(),
    )
    user = _user()
    run_input = _build_run_input(thread_id=thread_id, run_id="run-1")
    first = await service.enqueue_resume(
        thread_id=thread_id,
        run_input=run_input,
        current_user=user,
    )
    second = await service.enqueue_resume(
        thread_id=thread_id,
        run_input=run_input,
        current_user=user,
    )
    assert first.task_id == second.task_id
async def test_enqueue_run_creates_missing_thread_session() -> None:
    """Enqueuing into an unknown thread id creates that session and commits it."""
    new_thread_id = "00000000-0000-0000-0000-000000000999"
    repository = _FakeRepository()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    accepted = await service.enqueue_run(
        run_input=_build_run_input(thread_id=new_thread_id, run_id="run-1"),
        current_user=_user(),
    )

    # The response echoes the thread/run and flags the session as new ...
    assert accepted.thread_id == "00000000-0000-0000-0000-000000000999"
    assert accepted.run_id == "run-1"
    assert accepted.created is True
    # ... and the repository was asked to create exactly that id, then committed.
    assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
    assert repository.committed is True
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
    """A queue failure propagates and does not delete the freshly created session."""
    thread_id = "00000000-0000-0000-0000-000000000999"
    repository = _FakeRepository()
    service = AgentService(
        repository=repository,
        queue=_FailingQueue(),
        stream=_FakeStream(),
    )
    run_input = _build_run_input(thread_id=thread_id, run_id="run-1")

    try:
        await service.enqueue_run(
            run_input=run_input,
            current_user=_user(),
        )
    except RuntimeError as exc:
        assert str(exc) == "enqueue failed"
    else:
        # Lives in ``else`` so the ``try`` body contains only the call under test.
        raise AssertionError("expected RuntimeError")

    # "Kept" is only meaningful if the session was actually created first.
    assert repository.created_with_session_id == thread_id
    assert repository.deleted_session_id is None
async def test_enqueue_run_handles_session_create_race() -> None:
    """A duplicate-key insert rolls back and the run is accepted as not-created."""

    class _RaceRepository(_FakeRepository):
        # Simulates a concurrent writer. The session looks missing until an
        # insert has been attempted, and every insert hits a duplicate key.
        def __init__(self) -> None:
            super().__init__()
            self.create_calls = 0

        async def get_session_owner(self, *, session_id: str) -> str:
            if self.create_calls:
                return "00000000-0000-0000-0000-000000000001"
            raise HTTPException(status_code=404, detail="Session not found")

        async def create_session_for_user(
            self, *, user_id: str, session_id: str | None = None
        ) -> str:
            del user_id, session_id
            self.create_calls += 1
            raise IntegrityError("insert", {}, Exception("duplicate key"))

    racing_repository = _RaceRepository()
    service = AgentService(
        repository=racing_repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )
    accepted = await service.enqueue_run(
        run_input=_build_run_input(
            thread_id="00000000-0000-0000-0000-000000000999",
            run_id="run-1",
        ),
        current_user=_user(),
    )

    assert accepted.created is False
    assert racing_repository.rolled_back is True
async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
    monkeypatch,
) -> None:
    """A binary image part is uploaded to storage and referenced in the message metadata."""
    bucket = "agent-test-bucket"
    thread_id = "00000000-0000-0000-0000-000000000001"
    run_id = "run-with-image"
    monkeypatch.setattr(agent_service_module.config.storage, "bucket", bucket)

    repository = _FakeRepository()
    attachment_storage = _FakeAttachmentStorage()
    service = AgentService(
        repository=repository,
        queue=_FakeQueue(),
        stream=_FakeStream(),
        attachment_storage=attachment_storage,
    )

    # One user message mixing a text part with a base64 binary part
    # ("aGVsbG8=" decodes to b"hello") labelled as image/png.
    user_message = {
        "id": "u1",
        "role": "user",
        "content": [
            {"type": "text", "text": "帮我看下这张图"},
            {"type": "binary", "data": "aGVsbG8=", "mimeType": "image/png"},
        ],
    }
    run_input = RunAgentInput.model_validate(
        {
            "threadId": thread_id,
            "runId": run_id,
            "state": {},
            "messages": [user_message],
            "tools": [],
            "context": [],
            "forwardedProps": {},
        }
    )

    accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
    assert accepted.task_id == "task-1"

    # Exactly one upload, carrying the decoded bytes and the declared MIME type.
    assert len(attachment_storage.calls) == 1
    upload = attachment_storage.calls[0]
    assert upload["bucket"] == bucket
    assert upload["content"] == b"hello"
    assert upload["content_type"] == "image/png"

    # The persisted user message references the upload via metadata.attachments.
    assert repository.persisted_user_messages
    persisted = repository.persisted_user_messages[0]
    assert persisted["session_id"] == thread_id
    assert persisted["run_id"] == run_id
    metadata = persisted["metadata"]
    assert isinstance(metadata, dict)
    attachments = metadata.get("attachments")
    assert isinstance(attachments, list)
    assert attachments and isinstance(attachments[0], dict)
    first_attachment = attachments[0]
    assert first_attachment["bucket"] == bucket
    assert isinstance(first_attachment["path"], str)
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
    """The repository's history-day payload comes back wrapped in a STATE_SNAPSHOT event."""
    thread_id = "00000000-0000-0000-0000-000000000001"
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    event = await service.get_history_snapshot(
        thread_id=thread_id,
        before=date(2026, 3, 7),
        current_user=_user(),
    )

    assert event["type"] == "STATE_SNAPSHOT"
    assert event["threadId"] == thread_id
    snapshot = event["snapshot"]
    assert isinstance(snapshot, dict)
    assert snapshot["scope"] == "history_day"
    assert snapshot["day"] == "2026-03-06"
    first_message = snapshot["messages"][0]
    assert first_message["id"] == "m1"
async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None:
    """Without an explicit thread id, the user's latest session is used."""
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    event = await service.get_user_history_snapshot(
        current_user=_user(),
        thread_id=None,
        before=None,
    )

    expected_thread = "00000000-0000-0000-0000-000000000001"
    assert event["type"] == "STATE_SNAPSHOT"
    assert event["threadId"] == expected_thread
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
    """An attribute-style result whose ``output`` is a dict yields the sentence text."""
    canned = SimpleNamespace(
        status_code=200,
        message="ok",
        output={"sentence": {"text": "你好,世界"}},
        request_id="req-test",
    )

    class _FakeRecognition:
        # Patched over ``agent_service_module.Recognition``. It ignores its
        # constructor arguments and returns the canned result from ``call``.
        def __init__(self, **kwargs) -> None:
            del kwargs

        def call(self, *, file: str):
            del file
            return canned

    monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
    monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")

    transcript = await AsrService().transcribe_file("/tmp/test.wav", "test.wav")
    assert transcript == "你好,世界"
async def test_asr_service_parses_sentence_when_result_is_dict(monkeypatch) -> None:
    """A plain-dict recognition result (not attribute-style) still yields the sentence text."""
    canned = {
        "status_code": 200,
        "message": "ok",
        "output": {"sentence": {"text": "字典结果"}},
        "request_id": "req-dict",
    }

    class _FakeRecognition:
        # Patched over ``agent_service_module.Recognition``. It ignores its
        # constructor arguments and returns the canned dict from ``call``.
        def __init__(self, **kwargs) -> None:
            del kwargs

        def call(self, *, file: str):
            del file
            return canned

    monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
    monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")

    transcript = await AsrService().transcribe_file("/tmp/test.wav", "test.wav")
    assert transcript == "字典结果"
async def test_asr_service_returns_empty_when_sentence_missing(monkeypatch) -> None:
    """A successful result with no ``sentence`` in its output transcribes to ``""``."""
    canned = {
        "status_code": 200,
        "message": "ok",
        "output": {},
    }

    class _FakeRecognition:
        # Patched over ``agent_service_module.Recognition``. It ignores its
        # constructor arguments and returns the canned dict from ``call``.
        def __init__(self, **kwargs) -> None:
            del kwargs

        def call(self, *, file: str):
            del file
            return canned

    monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
    monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")

    transcript = await AsrService().transcribe_file("/tmp/test.wav", "test.wav")
    assert transcript == ""