From 9aefb76c9e87bc3483acc1835bb9bb2eec67a4de Mon Sep 17 00:00:00 2001 From: qzl Date: Tue, 3 Mar 2026 16:25:43 +0800 Subject: [PATCH] fix(agent): address CRITICAL/HIGH security and validation issues - Fix SSE JSON injection: use json.dumps for safe serialization - Add tool validation to dispatcher with allowlist - Add field validation to tool_registry with proper error handling - Add run_id consistency check (409 on mismatch) - Add RunAgentInput constraints: min_length, extra=forbid - Fix crewai_flow: use Field(default_factory), prefix unused params --- backend/src/v1/agent/crewai_flow.py | 10 +++---- backend/src/v1/agent/router.py | 7 ++++- backend/src/v1/agent/schemas.py | 20 +++++++------- backend/src/v1/agent/service.py | 27 ++++++++++++------- backend/src/v1/agent/tool_dispatcher.py | 18 +++++++++++++ backend/src/v1/agent/tool_registry.py | 12 +++++++-- .../integration/v1/agent/test_chat_routes.py | 2 +- 7 files changed, 68 insertions(+), 28 deletions(-) diff --git a/backend/src/v1/agent/crewai_flow.py b/backend/src/v1/agent/crewai_flow.py index fefffb6..271ec7e 100644 --- a/backend/src/v1/agent/crewai_flow.py +++ b/backend/src/v1/agent/crewai_flow.py @@ -2,14 +2,14 @@ from __future__ import annotations from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field class FlowState(BaseModel): - stage_trace: list[str] = [] + stage_trace: list[str] = Field(default_factory=list) current_stage: str | None = None message: str = "" - context: dict[str, Any] = {} + context: dict[str, Any] = Field(default_factory=dict) class AgentFlow: @@ -27,12 +27,12 @@ class AgentFlow: self.state.stage_trace.append("intent") return {"stage": "intent", "result": "intent recognized"} - async def task_execution(self, prev_result: dict[str, Any]) -> dict[str, Any]: + async def task_execution(self, _prev_result: dict[str, Any]) -> dict[str, Any]: self.state.current_stage = "execution" self.state.stage_trace.append("execution") return {"stage": "execution", "result": "task executed"} - async def result_reporting(self, prev_result: dict[str, Any]) -> dict[str, Any]: + async def result_reporting(self, _prev_result: dict[str, Any]) -> dict[str, Any]: self.state.current_stage = "reporting" self.state.stage_trace.append("reporting") return {"stage": "reporting", "result": "reported"} diff --git a/backend/src/v1/agent/router.py b/backend/src/v1/agent/router.py index 16391ba..667a35b 100644 --- a/backend/src/v1/agent/router.py +++ b/backend/src/v1/agent/router.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from v1.agent.dependencies import get_agent_service @@ -29,6 +29,11 @@ async def resume_run( input_data: RunAgentInput, service: Annotated[AgentChatService, Depends(get_agent_service)], ) -> StreamingResponse: + if input_data.runId != run_id: + raise HTTPException( + status_code=409, + detail=f"run_id mismatch: path={run_id}, body={input_data.runId}", + ) return StreamingResponse( service.stream_resume(run_id, input_data), media_type="text/event-stream", diff --git a/backend/src/v1/agent/schemas.py b/backend/src/v1/agent/schemas.py index 31bc573..59a5a8b 100644 --- a/backend/src/v1/agent/schemas.py +++ b/backend/src/v1/agent/schemas.py @@ -3,18 +3,20 @@ from __future__ import annotations from typing import Any from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class RunAgentInput(BaseModel): - threadId: str - runId: str - parentRunId: str | None = None - state: dict[str, Any] - messages: list[dict[str, Any]] - tools: list[dict[str, Any]] - context: list[dict[str, Any]] - forwardedProps: dict[str, Any] + model_config = ConfigDict(extra="forbid") + + threadId: str = Field(min_length=1, max_length=255) + runId: str = Field(min_length=1, max_length=255) + parentRunId: str | None = Field(default=None, max_length=255) + state: dict[str, Any] = Field(default_factory=dict) + messages: list[dict[str, Any]] = Field(default_factory=list) + tools: list[dict[str, Any]] = Field(default_factory=list) + context: list[dict[str, Any]] = Field(default_factory=list) + forwardedProps: dict[str, Any] = Field(default_factory=dict) resume: dict[str, Any] | None = None diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py index a583ba2..893056c 100644 --- a/backend/src/v1/agent/service.py +++ b/backend/src/v1/agent/service.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from collections.abc import AsyncGenerator from datetime import datetime, timezone from decimal import Decimal @@ -379,17 +380,23 @@ class AgentChatService(BaseService): return ResumeDecisionResult(applied=True) async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]: - yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n' - yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n' - yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n' - yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n' - yield 'data: {"type": "RUN_FINISHED", "runId": "' + input_data.runId + '"}\n\n' + if "\n" in input_data.runId or "\r" in input_data.runId: + raise ValueError("runId must not contain newlines") + + yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n" + yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n" + yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n" + yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n" + yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n" async def stream_resume( self, run_id: str, input_data: RunAgentInput ) -> AsyncGenerator[str, None]: - yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n' - yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n' - yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n' - yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n' - yield 'data: {"type": "RUN_FINISHED", "runId": "' + run_id + '"}\n\n' + if "\n" in run_id or "\r" in run_id: + raise ValueError("runId must not contain newlines") + + yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n" + yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n" + yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n" + yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n" + yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n" diff --git a/backend/src/v1/agent/tool_dispatcher.py b/backend/src/v1/agent/tool_dispatcher.py index 264f14f..efe1909 100644 --- a/backend/src/v1/agent/tool_dispatcher.py +++ b/backend/src/v1/agent/tool_dispatcher.py @@ -4,6 +4,19 @@ from typing import Any from pydantic import BaseModel +from v1.agent.tool_registry import validate_tool_spec + + +ALLOWED_BACKEND_TOOLS = frozenset( + { + "srv.search_docs", + "srv.get_user_info", + "srv.send_message", + "srv.transfer_funds", + "srv.delete_file", + } +) + class InterruptResult(BaseModel): interrupt_type: str @@ -27,6 +40,8 @@ class ToolDispatcher: def dispatch_tool_call( tool: dict[str, Any], ) -> InterruptResult | BackendExecutionResult: + validate_tool_spec(tool) + name = tool["name"] target = tool["execution_target"] args = tool.get("args", {}) @@ -39,6 +54,9 @@ def dispatch_tool_call( ) if target == "backend": + if name not in ALLOWED_BACKEND_TOOLS: + raise ValueError(f"Backend tool '{name}' not in allowlist") + requires_approval = tool.get("requires_approval", False) if requires_approval: return InterruptResult( diff --git a/backend/src/v1/agent/tool_registry.py b/backend/src/v1/agent/tool_registry.py index 0f63641..5c5458b 100644 --- a/backend/src/v1/agent/tool_registry.py +++ b/backend/src/v1/agent/tool_registry.py @@ -4,8 +4,16 @@ from typing import Any def validate_tool_spec(spec: dict[str, Any]) -> None: - name = spec["name"] - target = spec["execution_target"] + try: + name = spec["name"] + target = spec["execution_target"] + except KeyError as e: + raise ValueError(f"Missing required field: {e.args[0]}") from e + + if not name or not isinstance(name, str): + raise ValueError("Tool name must be a non-empty string") + if not target or not isinstance(target, str): + raise ValueError("execution_target must be a non-empty string") if not (name.startswith("ui.") or name.startswith("srv.")): raise ValueError("Tool name must be in ui.* or srv.* namespace") diff --git a/backend/tests/integration/v1/agent/test_chat_routes.py b/backend/tests/integration/v1/agent/test_chat_routes.py index 4ba9832..b81ab75 100644 --- a/backend/tests/integration/v1/agent/test_chat_routes.py +++ b/backend/tests/integration/v1/agent/test_chat_routes.py @@ -62,7 +62,7 @@ class TestChatRoutes: def test_resume_route_streams_sse_events(self, client: TestClient): payload = { "threadId": "t1", - "runId": "r2", + "runId": "r1", "state": {}, "messages": [], "tools": [],