from __future__ import annotations from datetime import date from types import SimpleNamespace from uuid import UUID from ag_ui.core import RunAgentInput from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from core.auth.models import CurrentUser import v1.agent.service as agent_service_module from v1.agent.service import AgentService, AsrService class _FakeRepository: def __init__(self) -> None: self.committed = False self.rolled_back = False self.deleted_session_id: str | None = None self.created_with_session_id: str | None = None self.persisted_user_messages: list[dict[str, object]] = [] async def get_session_owner(self, *, session_id: str) -> str: if session_id == "00000000-0000-0000-0000-000000000001": return "00000000-0000-0000-0000-000000000001" raise HTTPException(status_code=404, detail="Session not found") async def create_session_for_user( self, *, user_id: str, session_id: str | None = None ) -> str: del user_id self.created_with_session_id = session_id return session_id or "00000000-0000-0000-0000-000000000999" async def commit(self) -> None: self.committed = True async def rollback(self) -> None: self.rolled_back = True async def delete_session(self, *, session_id: str) -> None: self.deleted_session_id = session_id async def get_history_day( self, *, session_id: str, before: date | None ) -> dict[str, object] | None: del session_id if before is not None and before <= date(2026, 3, 6): return None return { "day": "2026-03-06", "hasMore": False, "messages": [{"id": "m1", "role": "assistant", "content": "hello"}], } async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: del user_id return "00000000-0000-0000-0000-000000000001" async def persist_user_message( self, *, session_id: str, run_id: str, content_text: str, metadata: dict[str, object] | None, ) -> None: self.persisted_user_messages.append( { "session_id": session_id, "run_id": run_id, "content_text": content_text, "metadata": metadata, } ) class _FakeQueue: async def enqueue( self, *, command: dict[str, object], dedup_key: str | None ) -> str: del command, dedup_key return "task-1" class _FailingQueue: async def enqueue( self, *, command: dict[str, object], dedup_key: str | None ) -> str: del command, dedup_key raise RuntimeError("enqueue failed") class _FakeStream: async def read( self, *, session_id: str, last_event_id: str | None ) -> list[dict[str, object]]: del session_id return [ {"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id} ] class _FakeAttachmentStorage: def __init__(self) -> None: self.calls: list[dict[str, object]] = [] async def upload_bytes( self, *, bucket: str, path: str, content: bytes, content_type: str, ) -> str: self.calls.append( { "bucket": bucket, "path": path, "content": content, "content_type": content_type, } ) return path def _user() -> CurrentUser: return CurrentUser( id=UUID("00000000-0000-0000-0000-000000000001"), email="user@example.com", ) def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput: return RunAgentInput.model_validate( { "threadId": thread_id, "runId": run_id, "state": {}, "messages": [{"id": "u1", "role": "user", "content": "hello"}], "tools": [], "context": [], "forwardedProps": {}, } ) async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None: service = AgentService( repository=_FakeRepository(), queue=_FakeQueue(), stream=_FakeStream(), ) user = _user() run_input = _build_run_input( thread_id="00000000-0000-0000-0000-000000000001", run_id="run-1", ) first = await service.enqueue_resume( thread_id="00000000-0000-0000-0000-000000000001", run_input=run_input, current_user=user, ) second = await service.enqueue_resume( thread_id="00000000-0000-0000-0000-000000000001", run_input=run_input, current_user=user, ) assert first.task_id == second.task_id async def test_enqueue_run_creates_missing_thread_session() -> None: repository = _FakeRepository() service = AgentService( repository=repository, queue=_FakeQueue(), stream=_FakeStream(), ) run_input = _build_run_input( thread_id="00000000-0000-0000-0000-000000000999", run_id="run-1", ) accepted = await service.enqueue_run( run_input=run_input, current_user=_user(), ) assert accepted.thread_id == "00000000-0000-0000-0000-000000000999" assert accepted.run_id == "run-1" assert accepted.created is True assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999" assert repository.committed is True async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None: repository = _FakeRepository() service = AgentService( repository=repository, queue=_FailingQueue(), stream=_FakeStream(), ) run_input = _build_run_input( thread_id="00000000-0000-0000-0000-000000000999", run_id="run-1", ) try: await service.enqueue_run( run_input=run_input, current_user=_user(), ) raise AssertionError("expected RuntimeError") except RuntimeError as exc: assert str(exc) == "enqueue failed" assert repository.deleted_session_id is None async def test_enqueue_run_handles_session_create_race() -> None: class _RaceRepository(_FakeRepository): def __init__(self) -> None: super().__init__() self.create_calls = 0 async def get_session_owner(self, *, session_id: str) -> str: if self.create_calls == 0: raise HTTPException(status_code=404, detail="Session not found") return "00000000-0000-0000-0000-000000000001" async def create_session_for_user( self, *, user_id: str, session_id: str | None = None ) -> str: del user_id, session_id self.create_calls += 1 raise IntegrityError("insert", {}, Exception("duplicate key")) repository = _RaceRepository() service = AgentService( repository=repository, queue=_FakeQueue(), stream=_FakeStream(), ) run_input = _build_run_input( thread_id="00000000-0000-0000-0000-000000000999", run_id="run-1", ) accepted = await service.enqueue_run( run_input=run_input, current_user=_user(), ) assert accepted.created is False assert repository.rolled_back is True async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata( monkeypatch, ) -> None: monkeypatch.setattr( agent_service_module.config.storage, "bucket", "agent-test-bucket" ) repository = _FakeRepository() attachment_storage = _FakeAttachmentStorage() service = AgentService( repository=repository, queue=_FakeQueue(), stream=_FakeStream(), attachment_storage=attachment_storage, ) run_input = RunAgentInput.model_validate( { "threadId": "00000000-0000-0000-0000-000000000001", "runId": "run-with-image", "state": {}, "messages": [ { "id": "u1", "role": "user", "content": [ {"type": "text", "text": "帮我看下这张图"}, { "type": "binary", "data": "aGVsbG8=", "mimeType": "image/png", }, ], } ], "tools": [], "context": [], "forwardedProps": {}, } ) accepted = await service.enqueue_run(run_input=run_input, current_user=_user()) assert accepted.task_id == "task-1" assert len(attachment_storage.calls) == 1 upload = attachment_storage.calls[0] assert upload["bucket"] == "agent-test-bucket" assert upload["content"] == b"hello" assert upload["content_type"] == "image/png" assert repository.persisted_user_messages persisted = repository.persisted_user_messages[0] assert persisted["session_id"] == "00000000-0000-0000-0000-000000000001" assert persisted["run_id"] == "run-with-image" metadata = persisted["metadata"] assert isinstance(metadata, dict) attachments = metadata.get("attachments") assert isinstance(attachments, list) assert attachments and isinstance(attachments[0], dict) assert attachments[0]["bucket"] == "agent-test-bucket" assert isinstance(attachments[0]["path"], str) async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None: service = AgentService( repository=_FakeRepository(), queue=_FakeQueue(), stream=_FakeStream(), ) event = await service.get_history_snapshot( thread_id="00000000-0000-0000-0000-000000000001", before=date(2026, 3, 7), current_user=_user(), ) assert event["type"] == "STATE_SNAPSHOT" assert event["threadId"] == "00000000-0000-0000-0000-000000000001" snapshot = event["snapshot"] assert isinstance(snapshot, dict) assert snapshot["scope"] == "history_day" assert snapshot["day"] == "2026-03-06" assert snapshot["messages"][0]["id"] == "m1" async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None: service = AgentService( repository=_FakeRepository(), queue=_FakeQueue(), stream=_FakeStream(), ) event = await service.get_user_history_snapshot( current_user=_user(), thread_id=None, before=None, ) assert event["type"] == "STATE_SNAPSHOT" assert event["threadId"] == "00000000-0000-0000-0000-000000000001" async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None: result = SimpleNamespace( status_code=200, message="ok", output={"sentence": {"text": "你好,世界"}}, request_id="req-test", ) class _FakeRecognition: def __init__(self, **kwargs) -> None: del kwargs def call(self, *, file: str): del file return result monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition) monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key") service = AsrService() transcript = await service.transcribe_file("/tmp/test.wav", "test.wav") assert transcript == "你好,世界" async def test_asr_service_parses_sentence_when_result_is_dict(monkeypatch) -> None: result = { "status_code": 200, "message": "ok", "output": {"sentence": {"text": "字典结果"}}, "request_id": "req-dict", } class _FakeRecognition: def __init__(self, **kwargs) -> None: del kwargs def call(self, *, file: str): del file return result monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition) monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key") service = AsrService() transcript = await service.transcribe_file("/tmp/test.wav", "test.wav") assert transcript == "字典结果" async def test_asr_service_returns_empty_when_sentence_missing(monkeypatch) -> None: result = { "status_code": 200, "message": "ok", "output": {}, } class _FakeRecognition: def __init__(self, **kwargs) -> None: del kwargs def call(self, *, file: str): del file return result monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition) monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key") service = AsrService() transcript = await service.transcribe_file("/tmp/test.wav", "test.wav") assert transcript == ""