fix(agent): address CRITICAL/HIGH security and validation issues

- Fix SSE JSON injection: use json.dumps for safe serialization
- Add tool validation to dispatcher with allowlist
- Add field validation to tool_registry with proper error handling
- Add run_id consistency check (409 on mismatch)
- Add RunAgentInput constraints: min_length/max_length on IDs, extra=forbid, default factories for collection fields
- Fix crewai_flow: use Field(default_factory=...) for mutable defaults, prefix unused params with an underscore
This commit is contained in:
qzl
2026-03-03 16:25:43 +08:00
parent ff85c1ab08
commit 9aefb76c9e
7 changed files with 68 additions and 28 deletions
+5 -5
View File
@@ -2,14 +2,14 @@ from __future__ import annotations
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
class FlowState(BaseModel):
    """Mutable state carried through the stages of an agent flow."""

    # Names of the stages run so far, in order; stage methods append to it.
    stage_trace: list[str] = Field(default_factory=list)
    # Name of the stage most recently entered, or None before any stage ran.
    current_stage: str | None = None
    message: str = ""
    # Free-form key/value data; each instance starts with its own empty dict.
    context: dict[str, Any] = Field(default_factory=dict)
class AgentFlow:
@@ -27,12 +27,12 @@ class AgentFlow:
self.state.stage_trace.append("intent")
return {"stage": "intent", "result": "intent recognized"}
async def task_execution(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
    """Record the "execution" stage on the flow state and report it.

    Args:
        _prev_result: Output of the preceding stage. It is not read here;
            the leading underscore marks it as intentionally unused.

    Returns:
        A dict with the stage name and a fixed result message.
    """
    self.state.current_stage = "execution"
    self.state.stage_trace.append("execution")
    return {"stage": "execution", "result": "task executed"}
async def result_reporting(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
    """Record the "reporting" stage on the flow state and report it.

    Args:
        _prev_result: Output of the preceding stage. It is not read here;
            the leading underscore marks it as intentionally unused.

    Returns:
        A dict with the stage name and a fixed result message.
    """
    self.state.current_stage = "reporting"
    self.state.stage_trace.append("reporting")
    return {"stage": "reporting", "result": "reported"}
+6 -1
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from v1.agent.dependencies import get_agent_service
@@ -29,6 +29,11 @@ async def resume_run(
input_data: RunAgentInput,
service: Annotated[AgentChatService, Depends(get_agent_service)],
) -> StreamingResponse:
if input_data.runId != run_id:
raise HTTPException(
status_code=409,
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
)
return StreamingResponse(
service.stream_resume(run_id, input_data),
media_type="text/event-stream",
+11 -9
View File
@@ -3,18 +3,20 @@ from __future__ import annotations
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class RunAgentInput(BaseModel):
    """Request body for starting or resuming an agent run."""

    # Reject request bodies that carry fields not declared on this model.
    model_config = ConfigDict(extra="forbid")

    # Identifiers must be non-empty and at most 255 characters.
    threadId: str = Field(min_length=1, max_length=255)
    runId: str = Field(min_length=1, max_length=255)
    parentRunId: str | None = Field(default=None, max_length=255)

    # Collection fields may be omitted; each instance then gets a fresh
    # empty container.
    state: dict[str, Any] = Field(default_factory=dict)
    messages: list[dict[str, Any]] = Field(default_factory=list)
    tools: list[dict[str, Any]] = Field(default_factory=list)
    context: list[dict[str, Any]] = Field(default_factory=list)
    forwardedProps: dict[str, Any] = Field(default_factory=dict)

    resume: dict[str, Any] | None = None
+17 -10
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from decimal import Decimal
@@ -379,17 +380,23 @@ class AgentChatService(BaseService):
return ResumeDecisionResult(applied=True)
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
    """Yield the server-sent events for a new run.

    Every payload is serialized with ``json.dumps`` so that quotes and
    control characters in ``runId`` are escaped instead of being spliced
    into the JSON text.

    Args:
        input_data: Run request; only ``runId`` is read here.

    Yields:
        One ``data: <json>`` frame per event, each terminated by a blank line.

    Raises:
        ValueError: If ``runId`` contains a carriage return or line feed.
            Raised on first iteration, before any frame is produced.
    """
    run_id = input_data.runId
    # Defense in depth: a raw line break would terminate an SSE frame early.
    if "\n" in run_id or "\r" in run_id:
        raise ValueError("runId must not contain newlines")

    events = (
        {"type": "RUN_STARTED", "runId": run_id},
        {"type": "TEXT_MESSAGE_START", "messageId": "m1"},
        {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"},
        {"type": "TEXT_MESSAGE_END", "messageId": "m1"},
        {"type": "RUN_FINISHED", "runId": run_id},
    )
    for event in events:
        yield f"data: {json.dumps(event)}\n\n"
async def stream_resume(
    self, run_id: str, input_data: RunAgentInput
) -> AsyncGenerator[str, None]:
    """Yield the server-sent events for a resumed run.

    Every payload is serialized with ``json.dumps`` so that quotes and
    control characters in ``run_id`` are escaped instead of being spliced
    into the JSON text.

    Args:
        run_id: Identifier of the run being resumed.
        input_data: Resume request body; not read by this implementation.

    Yields:
        One ``data: <json>`` frame per event, each terminated by a blank line.

    Raises:
        ValueError: If ``run_id`` contains a carriage return or line feed.
            Raised on first iteration, before any frame is produced.
    """
    # Defense in depth: a raw line break would terminate an SSE frame early.
    if "\n" in run_id or "\r" in run_id:
        raise ValueError("runId must not contain newlines")

    events = (
        {"type": "RUN_STARTED", "runId": run_id},
        {"type": "TEXT_MESSAGE_START", "messageId": "m2"},
        {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"},
        {"type": "TEXT_MESSAGE_END", "messageId": "m2"},
        {"type": "RUN_FINISHED", "runId": run_id},
    )
    for event in events:
        yield f"data: {json.dumps(event)}\n\n"
+18
View File
@@ -4,6 +4,19 @@ from typing import Any
from pydantic import BaseModel
from v1.agent.tool_registry import validate_tool_spec
# Names of the backend ("srv.*") tools that may be executed; kept immutable
# so the allowlist cannot be altered at runtime.
ALLOWED_BACKEND_TOOLS = frozenset(
    (
        "srv.search_docs",
        "srv.get_user_info",
        "srv.send_message",
        "srv.transfer_funds",
        "srv.delete_file",
    )
)
class InterruptResult(BaseModel):
interrupt_type: str
@@ -27,6 +40,8 @@ class ToolDispatcher:
def dispatch_tool_call(
tool: dict[str, Any],
) -> InterruptResult | BackendExecutionResult:
validate_tool_spec(tool)
name = tool["name"]
target = tool["execution_target"]
args = tool.get("args", {})
@@ -39,6 +54,9 @@ def dispatch_tool_call(
)
if target == "backend":
if name not in ALLOWED_BACKEND_TOOLS:
raise ValueError(f"Backend tool '{name}' not in allowlist")
requires_approval = tool.get("requires_approval", False)
if requires_approval:
return InterruptResult(
+10 -2
View File
@@ -4,8 +4,16 @@ from typing import Any
def validate_tool_spec(spec: dict[str, Any]) -> None:
    """Validate the required fields of a tool specification.

    Args:
        spec: Tool specification; must contain ``name`` and
            ``execution_target``.

    Raises:
        ValueError: If a required field is missing, if either field is not
            a non-empty string, or if ``name`` is outside the ``ui.*`` /
            ``srv.*`` namespaces.
    """
    # Both lookups live inside the try so a missing key surfaces as a
    # ValueError rather than a bare KeyError.
    try:
        name = spec["name"]
        target = spec["execution_target"]
    except KeyError as e:
        raise ValueError(f"Missing required field: {e.args[0]}") from e

    # Check the type first so truthiness is only evaluated on real strings.
    if not isinstance(name, str) or not name:
        raise ValueError("Tool name must be a non-empty string")
    if not isinstance(target, str) or not target:
        raise ValueError("execution_target must be a non-empty string")

    if not name.startswith(("ui.", "srv.")):
        raise ValueError("Tool name must be in ui.* or srv.* namespace")