fix(agent): address CRITICAL/HIGH security and validation issues

- Fix SSE JSON injection: use json.dumps for safe serialization
- Add tool validation to dispatcher with allowlist
- Add field validation to tool_registry with proper error handling
- Add run_id consistency check (409 on mismatch)
- Add RunAgentInput constraints: min_length, extra=forbid
- Fix crewai_flow: use Field(default_factory), prefix unused params
This commit is contained in:
qzl
2026-03-03 16:25:43 +08:00
parent ff85c1ab08
commit 9aefb76c9e
7 changed files with 68 additions and 28 deletions
+5 -5
View File
@@ -2,14 +2,14 @@ from __future__ import annotations
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel, Field
class FlowState(BaseModel): class FlowState(BaseModel):
stage_trace: list[str] = [] stage_trace: list[str] = Field(default_factory=list)
current_stage: str | None = None current_stage: str | None = None
message: str = "" message: str = ""
context: dict[str, Any] = {} context: dict[str, Any] = Field(default_factory=dict)
class AgentFlow: class AgentFlow:
@@ -27,12 +27,12 @@ class AgentFlow:
self.state.stage_trace.append("intent") self.state.stage_trace.append("intent")
return {"stage": "intent", "result": "intent recognized"} return {"stage": "intent", "result": "intent recognized"}
async def task_execution(self, prev_result: dict[str, Any]) -> dict[str, Any]: async def task_execution(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
self.state.current_stage = "execution" self.state.current_stage = "execution"
self.state.stage_trace.append("execution") self.state.stage_trace.append("execution")
return {"stage": "execution", "result": "task executed"} return {"stage": "execution", "result": "task executed"}
async def result_reporting(self, prev_result: dict[str, Any]) -> dict[str, Any]: async def result_reporting(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
self.state.current_stage = "reporting" self.state.current_stage = "reporting"
self.state.stage_trace.append("reporting") self.state.stage_trace.append("reporting")
return {"stage": "reporting", "result": "reported"} return {"stage": "reporting", "result": "reported"}
+6 -1
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from v1.agent.dependencies import get_agent_service from v1.agent.dependencies import get_agent_service
@@ -29,6 +29,11 @@ async def resume_run(
input_data: RunAgentInput, input_data: RunAgentInput,
service: Annotated[AgentChatService, Depends(get_agent_service)], service: Annotated[AgentChatService, Depends(get_agent_service)],
) -> StreamingResponse: ) -> StreamingResponse:
if input_data.runId != run_id:
raise HTTPException(
status_code=409,
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
)
return StreamingResponse( return StreamingResponse(
service.stream_resume(run_id, input_data), service.stream_resume(run_id, input_data),
media_type="text/event-stream", media_type="text/event-stream",
+11 -9
View File
@@ -3,18 +3,20 @@ from __future__ import annotations
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class RunAgentInput(BaseModel): class RunAgentInput(BaseModel):
threadId: str model_config = ConfigDict(extra="forbid")
runId: str
parentRunId: str | None = None threadId: str = Field(min_length=1, max_length=255)
state: dict[str, Any] runId: str = Field(min_length=1, max_length=255)
messages: list[dict[str, Any]] parentRunId: str | None = Field(default=None, max_length=255)
tools: list[dict[str, Any]] state: dict[str, Any] = Field(default_factory=dict)
context: list[dict[str, Any]] messages: list[dict[str, Any]] = Field(default_factory=list)
forwardedProps: dict[str, Any] tools: list[dict[str, Any]] = Field(default_factory=list)
context: list[dict[str, Any]] = Field(default_factory=list)
forwardedProps: dict[str, Any] = Field(default_factory=dict)
resume: dict[str, Any] | None = None resume: dict[str, Any] | None = None
+17 -10
View File
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import datetime, timezone from datetime import datetime, timezone
from decimal import Decimal from decimal import Decimal
@@ -379,17 +380,23 @@ class AgentChatService(BaseService):
return ResumeDecisionResult(applied=True) return ResumeDecisionResult(applied=True)
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]: async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n' if "\n" in input_data.runId or "\r" in input_data.runId:
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n' raise ValueError("runId must not contain newlines")
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n' yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n"
yield 'data: {"type": "RUN_FINISHED", "runId": "' + input_data.runId + '"}\n\n' yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n"
async def stream_resume( async def stream_resume(
self, run_id: str, input_data: RunAgentInput self, run_id: str, input_data: RunAgentInput
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n' if "\n" in run_id or "\r" in run_id:
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n' raise ValueError("runId must not contain newlines")
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n' yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n"
yield 'data: {"type": "RUN_FINISHED", "runId": "' + run_id + '"}\n\n' yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n"
+18
View File
@@ -4,6 +4,19 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from v1.agent.tool_registry import validate_tool_spec
ALLOWED_BACKEND_TOOLS = frozenset(
{
"srv.search_docs",
"srv.get_user_info",
"srv.send_message",
"srv.transfer_funds",
"srv.delete_file",
}
)
class InterruptResult(BaseModel): class InterruptResult(BaseModel):
interrupt_type: str interrupt_type: str
@@ -27,6 +40,8 @@ class ToolDispatcher:
def dispatch_tool_call( def dispatch_tool_call(
tool: dict[str, Any], tool: dict[str, Any],
) -> InterruptResult | BackendExecutionResult: ) -> InterruptResult | BackendExecutionResult:
validate_tool_spec(tool)
name = tool["name"] name = tool["name"]
target = tool["execution_target"] target = tool["execution_target"]
args = tool.get("args", {}) args = tool.get("args", {})
@@ -39,6 +54,9 @@ def dispatch_tool_call(
) )
if target == "backend": if target == "backend":
if name not in ALLOWED_BACKEND_TOOLS:
raise ValueError(f"Backend tool '{name}' not in allowlist")
requires_approval = tool.get("requires_approval", False) requires_approval = tool.get("requires_approval", False)
if requires_approval: if requires_approval:
return InterruptResult( return InterruptResult(
+10 -2
View File
@@ -4,8 +4,16 @@ from typing import Any
def validate_tool_spec(spec: dict[str, Any]) -> None: def validate_tool_spec(spec: dict[str, Any]) -> None:
name = spec["name"] try:
target = spec["execution_target"] name = spec["name"]
target = spec["execution_target"]
except KeyError as e:
raise ValueError(f"Missing required field: {e.args[0]}") from e
if not name or not isinstance(name, str):
raise ValueError("Tool name must be a non-empty string")
if not target or not isinstance(target, str):
raise ValueError("execution_target must be a non-empty string")
if not (name.startswith("ui.") or name.startswith("srv.")): if not (name.startswith("ui.") or name.startswith("srv.")):
raise ValueError("Tool name must be in ui.* or srv.* namespace") raise ValueError("Tool name must be in ui.* or srv.* namespace")
@@ -62,7 +62,7 @@ class TestChatRoutes:
def test_resume_route_streams_sse_events(self, client: TestClient): def test_resume_route_streams_sse_events(self, client: TestClient):
payload = { payload = {
"threadId": "t1", "threadId": "t1",
"runId": "r2", "runId": "r1",
"state": {}, "state": {},
"messages": [], "messages": [],
"tools": [], "tools": [],