fix(agent): address CRITICAL/HIGH security and validation issues
- Fix SSE JSON injection: use json.dumps for safe serialization - Add tool validation to dispatcher with allowlist - Add field validation to tool_registry with proper error handling - Add run_id consistency check (409 on mismatch) - Add RunAgentInput constraints: min_length, extra=forbid - Fix crewai_flow: use Field(default_factory), prefix unused params
This commit is contained in:
@@ -2,14 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FlowState(BaseModel):
|
||||
stage_trace: list[str] = []
|
||||
stage_trace: list[str] = Field(default_factory=list)
|
||||
current_stage: str | None = None
|
||||
message: str = ""
|
||||
context: dict[str, Any] = {}
|
||||
context: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentFlow:
|
||||
@@ -27,12 +27,12 @@ class AgentFlow:
|
||||
self.state.stage_trace.append("intent")
|
||||
return {"stage": "intent", "result": "intent recognized"}
|
||||
|
||||
async def task_execution(self, prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||
async def task_execution(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||
self.state.current_stage = "execution"
|
||||
self.state.stage_trace.append("execution")
|
||||
return {"stage": "execution", "result": "task executed"}
|
||||
|
||||
async def result_reporting(self, prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||
async def result_reporting(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||
self.state.current_stage = "reporting"
|
||||
self.state.stage_trace.append("reporting")
|
||||
return {"stage": "reporting", "result": "reported"}
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
@@ -29,6 +29,11 @@ async def resume_run(
|
||||
input_data: RunAgentInput,
|
||||
service: Annotated[AgentChatService, Depends(get_agent_service)],
|
||||
) -> StreamingResponse:
|
||||
if input_data.runId != run_id:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
|
||||
)
|
||||
return StreamingResponse(
|
||||
service.stream_resume(run_id, input_data),
|
||||
media_type="text/event-stream",
|
||||
|
||||
@@ -3,18 +3,20 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class RunAgentInput(BaseModel):
|
||||
threadId: str
|
||||
runId: str
|
||||
parentRunId: str | None = None
|
||||
state: dict[str, Any]
|
||||
messages: list[dict[str, Any]]
|
||||
tools: list[dict[str, Any]]
|
||||
context: list[dict[str, Any]]
|
||||
forwardedProps: dict[str, Any]
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
threadId: str = Field(min_length=1, max_length=255)
|
||||
runId: str = Field(min_length=1, max_length=255)
|
||||
parentRunId: str | None = Field(default=None, max_length=255)
|
||||
state: dict[str, Any] = Field(default_factory=dict)
|
||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||
tools: list[dict[str, Any]] = Field(default_factory=list)
|
||||
context: list[dict[str, Any]] = Field(default_factory=list)
|
||||
forwardedProps: dict[str, Any] = Field(default_factory=dict)
|
||||
resume: dict[str, Any] | None = None
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
@@ -379,17 +380,23 @@ class AgentChatService(BaseService):
|
||||
return ResumeDecisionResult(applied=True)
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "' + input_data.runId + '"}\n\n'
|
||||
if "\n" in input_data.runId or "\r" in input_data.runId:
|
||||
raise ValueError("runId must not contain newlines")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n"
|
||||
|
||||
async def stream_resume(
|
||||
self, run_id: str, input_data: RunAgentInput
|
||||
) -> AsyncGenerator[str, None]:
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "' + run_id + '"}\n\n'
|
||||
if "\n" in run_id or "\r" in run_id:
|
||||
raise ValueError("runId must not contain newlines")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n"
|
||||
|
||||
@@ -4,6 +4,19 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from v1.agent.tool_registry import validate_tool_spec
|
||||
|
||||
|
||||
ALLOWED_BACKEND_TOOLS = frozenset(
|
||||
{
|
||||
"srv.search_docs",
|
||||
"srv.get_user_info",
|
||||
"srv.send_message",
|
||||
"srv.transfer_funds",
|
||||
"srv.delete_file",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class InterruptResult(BaseModel):
|
||||
interrupt_type: str
|
||||
@@ -27,6 +40,8 @@ class ToolDispatcher:
|
||||
def dispatch_tool_call(
|
||||
tool: dict[str, Any],
|
||||
) -> InterruptResult | BackendExecutionResult:
|
||||
validate_tool_spec(tool)
|
||||
|
||||
name = tool["name"]
|
||||
target = tool["execution_target"]
|
||||
args = tool.get("args", {})
|
||||
@@ -39,6 +54,9 @@ def dispatch_tool_call(
|
||||
)
|
||||
|
||||
if target == "backend":
|
||||
if name not in ALLOWED_BACKEND_TOOLS:
|
||||
raise ValueError(f"Backend tool '{name}' not in allowlist")
|
||||
|
||||
requires_approval = tool.get("requires_approval", False)
|
||||
if requires_approval:
|
||||
return InterruptResult(
|
||||
|
||||
@@ -4,8 +4,16 @@ from typing import Any
|
||||
|
||||
|
||||
def validate_tool_spec(spec: dict[str, Any]) -> None:
|
||||
name = spec["name"]
|
||||
target = spec["execution_target"]
|
||||
try:
|
||||
name = spec["name"]
|
||||
target = spec["execution_target"]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing required field: {e.args[0]}") from e
|
||||
|
||||
if not name or not isinstance(name, str):
|
||||
raise ValueError("Tool name must be a non-empty string")
|
||||
if not target or not isinstance(target, str):
|
||||
raise ValueError("execution_target must be a non-empty string")
|
||||
|
||||
if not (name.startswith("ui.") or name.startswith("srv.")):
|
||||
raise ValueError("Tool name must be in ui.* or srv.* namespace")
|
||||
|
||||
@@ -62,7 +62,7 @@ class TestChatRoutes:
|
||||
def test_resume_route_streams_sse_events(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r2",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
|
||||
Reference in New Issue
Block a user