refactor: 移除 crewai agent 架构相关代码并更新 LLM 配置
This commit is contained in:
@@ -1,80 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
from v1.agent.tool_registry import validate_tool_spec
|
||||
|
||||
|
||||
class TestAgentSecurityRules:
|
||||
def test_tool_name_must_be_allowlisted(self):
|
||||
validate_tool_spec({"name": "ui.navigate_to", "execution_target": "frontend"})
|
||||
validate_tool_spec({"name": "srv.search_docs", "execution_target": "backend"})
|
||||
|
||||
def test_tool_name_rejected_if_not_in_namespace(self):
|
||||
try:
|
||||
validate_tool_spec(
|
||||
{"name": "malicious.tool", "execution_target": "frontend"}
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("Should have raised ValueError for unknown namespace")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_frontend_result_fails_when_interrupt_mismatch(self):
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
)
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, session_obj: AgentChatSession) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(self._session_obj)
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
service = AgentChatService(
|
||||
session=FakeAsyncSession(session), # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-other",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert result.applied is False
|
||||
@@ -1,25 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.crewai_flow import AgentFlow
|
||||
|
||||
|
||||
class TestCrewAIFlow:
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_stages_run_in_order(self):
|
||||
flow = AgentFlow()
|
||||
await flow.run()
|
||||
assert flow.state.stage_trace == ["intent", "execution", "reporting"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_state_initialized(self):
|
||||
flow = AgentFlow()
|
||||
assert flow.state.stage_trace == []
|
||||
assert flow.state.current_stage is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_updates_current_stage(self):
|
||||
flow = AgentFlow()
|
||||
await flow.run()
|
||||
assert flow.state.current_stage == "reporting"
|
||||
@@ -1,187 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self._sessions: dict[UUID, AgentChatSession] = {}
|
||||
self.last_fetch_with_lock = False
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
if isinstance(obj, AgentChatSession):
|
||||
self._sessions[obj.id] = obj
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
async def refresh(self, obj: object) -> None:
|
||||
pass
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
self.last_fetch_with_lock = "FOR UPDATE" in str(stmt)
|
||||
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(next(iter(self._sessions.values()), None))
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db() -> FakeAsyncSession:
|
||||
return FakeAsyncSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
|
||||
sess = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
)
|
||||
fake_db.add(sess)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(fake_db: FakeAsyncSession) -> AgentChatService:
|
||||
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestResumeIdempotency:
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_is_idempotent(
|
||||
self,
|
||||
service: AgentChatService,
|
||||
session: AgentChatSession,
|
||||
fake_db: FakeAsyncSession,
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
first = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
second = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert first.applied is True
|
||||
assert second.applied is False
|
||||
assert fake_db.last_fetch_with_lock is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_updates_status_to_approved(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
tool_name="srv.delete_file",
|
||||
tool_args={"file_id": "f1"},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert result.applied is True
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
|
||||
assert snapshot["pending_tool_call"]["decision"] == {"decision": "approved"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_updates_status_to_rejected(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-3",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-3",
|
||||
decision={"decision": "rejected"},
|
||||
)
|
||||
|
||||
assert result.applied is True
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "REJECTED"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_expired_pending_marks_expired_and_not_applied(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) - timedelta(seconds=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-expired",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-expired",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert result.applied is False
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "EXPIRED"
|
||||
@@ -1,127 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.schemas import AgentSessionSnapshot, RunAgentInput
|
||||
|
||||
|
||||
class TestRunAgentInput:
|
||||
def test_requires_full_fields(self):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
model = RunAgentInput.model_validate(payload)
|
||||
assert model.threadId == "t1"
|
||||
assert model.runId == "r1"
|
||||
assert model.parentRunId is None
|
||||
assert model.state == {}
|
||||
assert model.messages == []
|
||||
assert model.tools == []
|
||||
assert model.context == []
|
||||
assert model.forwardedProps == {}
|
||||
assert model.resume is None
|
||||
|
||||
def test_resume_optional(self):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r2",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
model = RunAgentInput.model_validate(payload)
|
||||
assert model.resume is not None
|
||||
assert model.resume["interruptId"] == "int-1"
|
||||
assert model.resume["payload"]["decision"] == "approved"
|
||||
|
||||
def test_parent_run_id_optional(self):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r3",
|
||||
"parentRunId": "p1",
|
||||
"state": {"key": "value"},
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"tools": [{"name": "ui.navigate_to"}],
|
||||
"context": [{"type": "user", "id": "u1"}],
|
||||
"forwardedProps": {"theme": "dark"},
|
||||
}
|
||||
model = RunAgentInput.model_validate(payload)
|
||||
assert model.parentRunId == "p1"
|
||||
assert model.state == {"key": "value"}
|
||||
assert len(model.messages) == 1
|
||||
assert model.messages[0]["role"] == "user"
|
||||
|
||||
|
||||
class TestAgentSessionSnapshot:
|
||||
def test_state_snapshot_v2_model_accepts_valid_payload(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00Z",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00Z",
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
model = AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
assert model.version == 2
|
||||
assert model.pending_tool_call is not None
|
||||
assert model.pending_tool_call.interrupt_id == "int-1"
|
||||
assert model.pending_tool_call.updated_at == datetime(
|
||||
2026, 3, 3, 11, 59, tzinfo=timezone.utc
|
||||
)
|
||||
|
||||
def test_state_snapshot_v2_rejects_wrong_version(self):
|
||||
payload = {
|
||||
"version": 1,
|
||||
"pending_tool_call": None,
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
def test_state_snapshot_v2_requires_pending_tool_call_key(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
def test_state_snapshot_v2_rejects_extra_fields(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00Z",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00Z",
|
||||
"unexpected": True,
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1", "foo": "bar"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
@@ -1,168 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self._sessions: dict[UUID, AgentChatSession] = {}
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
if isinstance(obj, AgentChatSession):
|
||||
self._sessions[obj.id] = obj
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
async def refresh(self, obj: object) -> None:
|
||||
pass
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(next(iter(self._sessions.values()), None))
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db() -> FakeAsyncSession:
|
||||
return FakeAsyncSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
|
||||
sess = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
)
|
||||
fake_db.add(sess)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(fake_db: FakeAsyncSession) -> AgentChatService:
|
||||
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestPendingToolCall:
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_pending_tool_call_to_state_snapshot(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot is not None
|
||||
assert snapshot["version"] == 2
|
||||
assert snapshot["run_context"]["thread_id"] == "t1"
|
||||
assert snapshot["run_context"]["run_id"] == "r1"
|
||||
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
|
||||
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
|
||||
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_state_snapshot_returns_none_when_empty(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pending_tool_call_status(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
tool_name="srv.delete_file",
|
||||
tool_args={"file_id": "f1"},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
await service.update_pending_tool_call_status(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
status="APPROVED_EXECUTING",
|
||||
)
|
||||
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_legacy_snapshot_is_rejected(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-legacy",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_snapshot_rejects_naive_datetime(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
session.state_snapshot = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-naive",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00",
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-naive",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
@@ -1,126 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, sessions: list[AgentChatSession]) -> None:
|
||||
self._sessions = {session.id: session for session in sessions}
|
||||
self.commit_called = False
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
for session in self._sessions.values():
|
||||
return _Result(session)
|
||||
return _Result(None)
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_called = True
|
||||
|
||||
|
||||
def _build_input(run_id: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "t1",
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_resume_rejects_non_owner_session() -> None:
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot={
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": datetime.now(timezone.utc).isoformat(),
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
service = AgentChatService(
|
||||
session=FakeAsyncSession([session]), # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_resume_commits_expired_state_before_410() -> None:
|
||||
owner_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot={
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2000-01-01T00:00:00+00:00",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
fake_db = FakeAsyncSession([session])
|
||||
service = AgentChatService(
|
||||
session=fake_db, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=owner_id),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
|
||||
|
||||
assert exc_info.value.status_code == 410
|
||||
assert fake_db.commit_called is True
|
||||
@@ -1,65 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.tool_dispatcher import (
|
||||
BackendExecutionResult,
|
||||
InterruptResult,
|
||||
ToolDispatcher,
|
||||
dispatch_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestToolDispatcher:
|
||||
def test_frontend_tool_returns_interrupt(self):
|
||||
tool = {
|
||||
"name": "ui.navigate_to",
|
||||
"execution_target": "frontend",
|
||||
"args": {"path": "/home"},
|
||||
}
|
||||
result = dispatch_tool_call(tool)
|
||||
assert isinstance(result, InterruptResult)
|
||||
assert result.interrupt_type == "tool_execution"
|
||||
assert result.tool_name == "ui.navigate_to"
|
||||
|
||||
def test_backend_tool_executes_directly(self):
|
||||
tool = {
|
||||
"name": "srv.get_user_info",
|
||||
"execution_target": "backend",
|
||||
"args": {"user_id": "u1"},
|
||||
"requires_approval": False,
|
||||
}
|
||||
result = dispatch_tool_call(tool)
|
||||
assert isinstance(result, BackendExecutionResult)
|
||||
assert result.tool_name == "srv.get_user_info"
|
||||
|
||||
def test_backend_tool_with_approval_returns_interrupt(self):
|
||||
tool = {
|
||||
"name": "srv.transfer_funds",
|
||||
"execution_target": "backend",
|
||||
"args": {"to": "u2", "amount": 100},
|
||||
"requires_approval": True,
|
||||
}
|
||||
result = dispatch_tool_call(tool)
|
||||
assert isinstance(result, InterruptResult)
|
||||
assert result.interrupt_type == "approval_required"
|
||||
assert result.tool_name == "srv.transfer_funds"
|
||||
|
||||
def test_dispatcher_class_can_dispatch(self):
|
||||
dispatcher = ToolDispatcher()
|
||||
tool = {
|
||||
"name": "ui.navigate_to",
|
||||
"execution_target": "frontend",
|
||||
"args": {"message": "Hello"},
|
||||
}
|
||||
result = dispatcher.dispatch(tool)
|
||||
assert isinstance(result, InterruptResult)
|
||||
|
||||
def test_unknown_frontend_tool_is_rejected(self):
|
||||
tool = {
|
||||
"name": "ui.unknown_action",
|
||||
"execution_target": "frontend",
|
||||
"args": {},
|
||||
}
|
||||
with pytest.raises(ValueError, match="not in allowlist"):
|
||||
dispatch_tool_call(tool)
|
||||
@@ -1,27 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from v1.agent.tool_registry import validate_tool_spec
|
||||
|
||||
|
||||
class TestValidateToolSpec:
|
||||
def test_ui_namespace_must_be_frontend(self):
|
||||
with pytest.raises(ValueError, match="ui.* must use frontend target"):
|
||||
validate_tool_spec(
|
||||
{"name": "ui.navigate_to", "execution_target": "backend"}
|
||||
)
|
||||
|
||||
def test_srv_namespace_must_be_backend(self):
|
||||
with pytest.raises(ValueError, match="srv.* must use backend target"):
|
||||
validate_tool_spec(
|
||||
{"name": "srv.search_docs", "execution_target": "frontend"}
|
||||
)
|
||||
|
||||
def test_ui_namespace_with_frontend_is_valid(self):
|
||||
validate_tool_spec({"name": "ui.navigate_to", "execution_target": "frontend"})
|
||||
|
||||
def test_srv_namespace_with_backend_is_valid(self):
|
||||
validate_tool_spec({"name": "srv.search_docs", "execution_target": "backend"})
|
||||
|
||||
def test_other_namespace_is_rejected(self):
|
||||
with pytest.raises(ValueError, match="must be in ui.* or srv.* namespace"):
|
||||
validate_tool_spec({"name": "other.tool", "execution_target": "frontend"})
|
||||
Reference in New Issue
Block a user