Files
social-app/backend/tests/unit/v1/agent/test_stream_resume_security.py
T

127 lines
4.1 KiB
Python

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4

import pytest
from fastapi import HTTPException

from core.auth.models import CurrentUser
from models.agent_chat_session import (
    AgentChatSession,
    AgentChatSessionStatus,
    SessionType,
)
from v1.agent.schemas import RunAgentInput
from v1.agent.service import AgentChatService
class FakeAsyncSession:
    """In-memory stand-in for the async DB session that ``AgentChatService`` uses.

    The fake does not interpret the statements it receives: ``execute`` and
    ``scalar`` both answer with the first session passed to the constructor
    (or ``None`` when there is none). Tests should therefore seed a single
    session, and any filtering a real query would apply (for example by
    owner) is only observable if the service also enforces it itself.
    ``commit`` only records that it was called.
    """

    class _Result:
        """Minimal result object exposing ``scalar_one_or_none``."""

        def __init__(self, session_obj: AgentChatSession | None) -> None:
            self._session_obj = session_obj

        def scalar_one_or_none(self) -> AgentChatSession | None:
            return self._session_obj

    def __init__(self, sessions: list[AgentChatSession]) -> None:
        # Keyed by session id; iteration order of the dict decides which
        # session counts as "first".
        self._sessions = {session.id: session for session in sessions}
        self.commit_called = False

    def _first_session(self) -> AgentChatSession | None:
        """Return the first seeded session, or ``None`` if none were given."""
        return next(iter(self._sessions.values()), None)

    async def execute(self, stmt: object) -> FakeAsyncSession._Result:
        """Ignore ``stmt`` and wrap the first seeded session in a result object."""
        return self._Result(self._first_session())

    async def scalar(self, stmt: object) -> AgentChatSession | None:
        """Ignore ``stmt`` and return the first seeded session (or ``None``)."""
        return self._first_session()

    async def commit(self) -> None:
        """Record that a commit happened; nothing is persisted."""
        self.commit_called = True
def _build_input(
    run_id: str,
    *,
    thread_id: str = "t1",
    interrupt_id: str = "int-1",
    decision: str = "approved",
) -> RunAgentInput:
    """Build a ``RunAgentInput`` carrying a ``resume`` decision for ``run_id``.

    Args:
        run_id: Value placed in the request's ``runId`` field.
        thread_id: Value for ``threadId``; defaults to ``"t1"``.
        interrupt_id: Interrupt the resume targets; defaults to ``"int-1"``.
        decision: Decision sent in the resume payload; defaults to
            ``"approved"``.

    Returns:
        The validated request. State, messages, tools, context and forwarded
        props are all empty.
    """
    return RunAgentInput.model_validate(
        {
            "threadId": thread_id,
            "runId": run_id,
            "state": {},
            "messages": [],
            "tools": [],
            "context": [],
            "forwardedProps": {},
            "resume": {
                "interruptId": interrupt_id,
                "payload": {"decision": decision},
            },
        }
    )
@pytest.mark.asyncio
async def test_stream_resume_rejects_non_owner_session() -> None:
    """A non-owner's resume attempt is rejected with 404 and writes nothing.

    The session belongs to a random ``uuid4`` user, while the service acts as
    a fixed UUID whose version nibble is 0. A ``uuid4`` always has version 4,
    so the two can never coincide.

    The pending approval is deliberately still live (it expires an hour from
    now), so ownership is the only legitimate reason to reject the request.
    An expiry of ``datetime.now()`` would already be in the past by the time
    the service reads it.
    """
    intruder_id = UUID("00000000-0000-0000-0000-000000000001")
    now = datetime.now(timezone.utc)
    session = AgentChatSession(
        id=uuid4(),
        user_id=uuid4(),
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot={
            "version": 2,
            "pending_tool_call": {
                "interrupt_id": "int-1",
                "tool_name": "srv.transfer_funds",
                "tool_args": {"to": "u2", "amount": 100},
                "status": "PENDING_APPROVAL",
                "expires_at": (now + timedelta(hours=1)).isoformat(),
                "decision": None,
                "result": None,
                "updated_at": now.isoformat(),
            },
            "run_context": {"thread_id": "t1", "run_id": str(uuid4())},
        },
    )
    fake_db = FakeAsyncSession([session])
    service = AgentChatService(
        session=fake_db,  # type: ignore[arg-type]
        current_user=CurrentUser(id=intruder_id),
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.prepare_resume(str(session.id), _build_input(str(session.id)))

    assert exc_info.value.status_code == 404
    # A rejected non-owner must not be able to persist any state change.
    assert fake_db.commit_called is False
@pytest.mark.asyncio
async def test_prepare_resume_commits_expired_state_before_410() -> None:
    """The owner resuming an expired approval gets 410, after a commit.

    The pending tool call's ``expires_at`` is pinned to 2000-01-01, so the
    approval window is long closed. The test expects ``prepare_resume`` to
    raise ``HTTPException(410)`` and the fake DB to have recorded a commit.
    """
    owner_id = UUID("00000000-0000-0000-0000-000000000001")
    expired_call = {
        "interrupt_id": "int-1",
        "tool_name": "srv.transfer_funds",
        "tool_args": {"to": "u2", "amount": 100},
        "status": "PENDING_APPROVAL",
        # Fixed date far in the past: the approval has already expired.
        "expires_at": "2000-01-01T00:00:00+00:00",
        "decision": None,
        "result": None,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    snapshot = {
        "version": 2,
        "pending_tool_call": expired_call,
        "run_context": {"thread_id": "t1", "run_id": str(uuid4())},
    }
    session = AgentChatSession(
        id=uuid4(),
        user_id=owner_id,
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot=snapshot,
    )
    fake_db = FakeAsyncSession([session])
    service = AgentChatService(
        session=fake_db,  # type: ignore[arg-type]
        current_user=CurrentUser(id=owner_id),
    )
    session_id = str(session.id)

    with pytest.raises(HTTPException) as exc_info:
        await service.prepare_resume(session_id, _build_input(session_id))

    assert exc_info.value.status_code == 410
    assert fake_db.commit_called is True