fix(agent): address CRITICAL/HIGH security and validation issues
- Fix SSE JSON injection: use json.dumps for safe serialization - Add tool validation to dispatcher with allowlist - Add field validation to tool_registry with proper error handling - Add run_id consistency check (409 on mismatch) - Add RunAgentInput constraints: min_length, extra=forbid - Fix crewai_flow: use Field(default_factory), prefix unused params
This commit is contained in:
@@ -2,14 +2,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class FlowState(BaseModel):
|
class FlowState(BaseModel):
|
||||||
stage_trace: list[str] = []
|
stage_trace: list[str] = Field(default_factory=list)
|
||||||
current_stage: str | None = None
|
current_stage: str | None = None
|
||||||
message: str = ""
|
message: str = ""
|
||||||
context: dict[str, Any] = {}
|
context: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class AgentFlow:
|
class AgentFlow:
|
||||||
@@ -27,12 +27,12 @@ class AgentFlow:
|
|||||||
self.state.stage_trace.append("intent")
|
self.state.stage_trace.append("intent")
|
||||||
return {"stage": "intent", "result": "intent recognized"}
|
return {"stage": "intent", "result": "intent recognized"}
|
||||||
|
|
||||||
async def task_execution(self, prev_result: dict[str, Any]) -> dict[str, Any]:
|
async def task_execution(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||||
self.state.current_stage = "execution"
|
self.state.current_stage = "execution"
|
||||||
self.state.stage_trace.append("execution")
|
self.state.stage_trace.append("execution")
|
||||||
return {"stage": "execution", "result": "task executed"}
|
return {"stage": "execution", "result": "task executed"}
|
||||||
|
|
||||||
async def result_reporting(self, prev_result: dict[str, Any]) -> dict[str, Any]:
|
async def result_reporting(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||||
self.state.current_stage = "reporting"
|
self.state.current_stage = "reporting"
|
||||||
self.state.stage_trace.append("reporting")
|
self.state.stage_trace.append("reporting")
|
||||||
return {"stage": "reporting", "result": "reported"}
|
return {"stage": "reporting", "result": "reported"}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from v1.agent.dependencies import get_agent_service
|
from v1.agent.dependencies import get_agent_service
|
||||||
@@ -29,6 +29,11 @@ async def resume_run(
|
|||||||
input_data: RunAgentInput,
|
input_data: RunAgentInput,
|
||||||
service: Annotated[AgentChatService, Depends(get_agent_service)],
|
service: Annotated[AgentChatService, Depends(get_agent_service)],
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
|
if input_data.runId != run_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
|
||||||
|
)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
service.stream_resume(run_id, input_data),
|
service.stream_resume(run_id, input_data),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
|
|||||||
@@ -3,18 +3,20 @@ from __future__ import annotations
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
class RunAgentInput(BaseModel):
|
class RunAgentInput(BaseModel):
|
||||||
threadId: str
|
model_config = ConfigDict(extra="forbid")
|
||||||
runId: str
|
|
||||||
parentRunId: str | None = None
|
threadId: str = Field(min_length=1, max_length=255)
|
||||||
state: dict[str, Any]
|
runId: str = Field(min_length=1, max_length=255)
|
||||||
messages: list[dict[str, Any]]
|
parentRunId: str | None = Field(default=None, max_length=255)
|
||||||
tools: list[dict[str, Any]]
|
state: dict[str, Any] = Field(default_factory=dict)
|
||||||
context: list[dict[str, Any]]
|
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
forwardedProps: dict[str, Any]
|
tools: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
context: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
forwardedProps: dict[str, Any] = Field(default_factory=dict)
|
||||||
resume: dict[str, Any] | None = None
|
resume: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
@@ -379,17 +380,23 @@ class AgentChatService(BaseService):
|
|||||||
return ResumeDecisionResult(applied=True)
|
return ResumeDecisionResult(applied=True)
|
||||||
|
|
||||||
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
|
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
|
||||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
|
if "\n" in input_data.runId or "\r" in input_data.runId:
|
||||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
raise ValueError("runId must not contain newlines")
|
||||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n'
|
|
||||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n'
|
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n"
|
||||||
yield 'data: {"type": "RUN_FINISHED", "runId": "' + input_data.runId + '"}\n\n'
|
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n"
|
||||||
|
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n"
|
||||||
|
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n"
|
||||||
|
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n"
|
||||||
|
|
||||||
async def stream_resume(
|
async def stream_resume(
|
||||||
self, run_id: str, input_data: RunAgentInput
|
self, run_id: str, input_data: RunAgentInput
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
|
if "\n" in run_id or "\r" in run_id:
|
||||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
raise ValueError("runId must not contain newlines")
|
||||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n'
|
|
||||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n'
|
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n"
|
||||||
yield 'data: {"type": "RUN_FINISHED", "runId": "' + run_id + '"}\n\n'
|
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n"
|
||||||
|
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n"
|
||||||
|
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n"
|
||||||
|
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n"
|
||||||
|
|||||||
@@ -4,6 +4,19 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from v1.agent.tool_registry import validate_tool_spec
|
||||||
|
|
||||||
|
|
||||||
|
ALLOWED_BACKEND_TOOLS = frozenset(
|
||||||
|
{
|
||||||
|
"srv.search_docs",
|
||||||
|
"srv.get_user_info",
|
||||||
|
"srv.send_message",
|
||||||
|
"srv.transfer_funds",
|
||||||
|
"srv.delete_file",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InterruptResult(BaseModel):
|
class InterruptResult(BaseModel):
|
||||||
interrupt_type: str
|
interrupt_type: str
|
||||||
@@ -27,6 +40,8 @@ class ToolDispatcher:
|
|||||||
def dispatch_tool_call(
|
def dispatch_tool_call(
|
||||||
tool: dict[str, Any],
|
tool: dict[str, Any],
|
||||||
) -> InterruptResult | BackendExecutionResult:
|
) -> InterruptResult | BackendExecutionResult:
|
||||||
|
validate_tool_spec(tool)
|
||||||
|
|
||||||
name = tool["name"]
|
name = tool["name"]
|
||||||
target = tool["execution_target"]
|
target = tool["execution_target"]
|
||||||
args = tool.get("args", {})
|
args = tool.get("args", {})
|
||||||
@@ -39,6 +54,9 @@ def dispatch_tool_call(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if target == "backend":
|
if target == "backend":
|
||||||
|
if name not in ALLOWED_BACKEND_TOOLS:
|
||||||
|
raise ValueError(f"Backend tool '{name}' not in allowlist")
|
||||||
|
|
||||||
requires_approval = tool.get("requires_approval", False)
|
requires_approval = tool.get("requires_approval", False)
|
||||||
if requires_approval:
|
if requires_approval:
|
||||||
return InterruptResult(
|
return InterruptResult(
|
||||||
|
|||||||
@@ -4,8 +4,16 @@ from typing import Any
|
|||||||
|
|
||||||
|
|
||||||
def validate_tool_spec(spec: dict[str, Any]) -> None:
|
def validate_tool_spec(spec: dict[str, Any]) -> None:
|
||||||
name = spec["name"]
|
try:
|
||||||
target = spec["execution_target"]
|
name = spec["name"]
|
||||||
|
target = spec["execution_target"]
|
||||||
|
except KeyError as e:
|
||||||
|
raise ValueError(f"Missing required field: {e.args[0]}") from e
|
||||||
|
|
||||||
|
if not name or not isinstance(name, str):
|
||||||
|
raise ValueError("Tool name must be a non-empty string")
|
||||||
|
if not target or not isinstance(target, str):
|
||||||
|
raise ValueError("execution_target must be a non-empty string")
|
||||||
|
|
||||||
if not (name.startswith("ui.") or name.startswith("srv.")):
|
if not (name.startswith("ui.") or name.startswith("srv.")):
|
||||||
raise ValueError("Tool name must be in ui.* or srv.* namespace")
|
raise ValueError("Tool name must be in ui.* or srv.* namespace")
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class TestChatRoutes:
|
|||||||
def test_resume_route_streams_sse_events(self, client: TestClient):
|
def test_resume_route_streams_sse_events(self, client: TestClient):
|
||||||
payload = {
|
payload = {
|
||||||
"threadId": "t1",
|
"threadId": "t1",
|
||||||
"runId": "r2",
|
"runId": "r1",
|
||||||
"state": {},
|
"state": {},
|
||||||
"messages": [],
|
"messages": [],
|
||||||
"tools": [],
|
"tools": [],
|
||||||
|
|||||||
Reference in New Issue
Block a user