feat: 添加 Agent 步骤事件与图片附件功能
- 新增 stepStarted/stepFinished 事件类型支持 - 前端实现图片附件上传和预览功能 - 后端增强工具结果存储和事件处理 - 完善相关单元测试和集成测试
This commit is contained in:
@@ -37,6 +37,21 @@ def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
if event_type == "tool.result":
|
||||
for key in (
|
||||
"messageId",
|
||||
"toolCallId",
|
||||
"callId",
|
||||
"toolName",
|
||||
"stage",
|
||||
"taskId",
|
||||
"ui",
|
||||
"content",
|
||||
):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
payload[key] = value
|
||||
return payload
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
data_map = cast(dict[str, Any], data)
|
||||
payload.update({k: v for k, v in data_map.items() if k not in reserved})
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from core.agentscope.events.tool_result_summary import build_tool_content_summary
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
|
||||
@@ -14,6 +17,16 @@ class EventStore(Protocol):
|
||||
async def persist(self, event: dict[str, Any]) -> None: ...
|
||||
|
||||
|
||||
class ToolResultStorageLike(Protocol):
    """Structural interface for a storage client that can persist JSON documents.

    Any object exposing a compatible ``upload_json`` coroutine satisfies this
    protocol; no inheritance is required.
    """

    async def upload_json(
        self, *, bucket: str, path: str, payload: dict[str, object]
    ) -> str:
        """Store ``payload`` as a JSON object at ``path`` inside ``bucket``.

        All arguments are keyword-only. The meaning of the returned string is
        defined by the concrete implementation.
        """
        ...
|
||||
|
||||
|
||||
class NullEventStore:
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
del event
|
||||
@@ -21,9 +34,20 @@ class NullEventStore:
|
||||
|
||||
class SqlAlchemyEventStore:
|
||||
_session_factory: Callable[[], Any]
|
||||
_tool_result_storage: ToolResultStorageLike | None
|
||||
_tool_result_bucket: str | None
|
||||
_logger = get_logger("core.agentscope.events.store")
|
||||
|
||||
def __init__(self, *, session_factory: Any) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_factory: Any,
|
||||
tool_result_storage: ToolResultStorageLike | None = None,
|
||||
tool_result_bucket: str | None = None,
|
||||
) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._tool_result_storage = tool_result_storage
|
||||
self._tool_result_bucket = tool_result_bucket
|
||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
@@ -228,23 +252,89 @@ class SqlAlchemyEventStore:
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
payload = {
|
||||
"args": event.get("args"),
|
||||
"result": event.get("result"),
|
||||
"error": event.get("error"),
|
||||
"call_id": event.get("callId"),
|
||||
}
|
||||
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||
metadata: dict[str, object] = {"tool_name": tool_name}
|
||||
run_id = event.get("runId")
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
|
||||
task_id = event.get("taskId")
|
||||
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
|
||||
call_id_value = event.get("callId")
|
||||
if not isinstance(call_id_value, str) or not call_id_value:
|
||||
call_id_value = (
|
||||
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
|
||||
if run_id_value
|
||||
else f"{task_id_value}-{uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
summary = build_tool_content_summary(
|
||||
tool_name=tool_name,
|
||||
args=event.get("args") if isinstance(event.get("args"), dict) else None,
|
||||
result=event.get("result"),
|
||||
error=event.get("error"),
|
||||
)
|
||||
|
||||
raw_result_value = event.get("result")
|
||||
raw_result: dict[str, object] = (
|
||||
raw_result_value if isinstance(raw_result_value, dict) else {}
|
||||
)
|
||||
ui_candidate = raw_result.get("ui")
|
||||
ui_schema = ui_candidate if isinstance(ui_candidate, dict) else None
|
||||
result_type = raw_result.get("type")
|
||||
result_data = raw_result.get("data")
|
||||
if (
|
||||
ui_schema is None
|
||||
and isinstance(result_type, str)
|
||||
and isinstance(result_data, dict)
|
||||
):
|
||||
ui_schema = raw_result
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"toolName": tool_name,
|
||||
"ui_schema": ui_schema,
|
||||
"result": _sanitize_result(raw_result),
|
||||
"error": _sanitize_error(event.get("error")),
|
||||
"callId": call_id_value,
|
||||
"runId": run_id_value,
|
||||
"taskId": task_id_value,
|
||||
"content": summary,
|
||||
}
|
||||
|
||||
metadata: dict[str, object] = {
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": call_id_value,
|
||||
"summary_version": "v1",
|
||||
}
|
||||
if run_id_value:
|
||||
metadata["run_id"] = run_id_value
|
||||
stage = event.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
task_id = event.get("taskId")
|
||||
if isinstance(task_id, str) and task_id:
|
||||
metadata["task_id"] = task_id
|
||||
if task_id_value:
|
||||
metadata["task_id"] = task_id_value
|
||||
|
||||
if self._tool_result_storage is not None and self._tool_result_bucket:
|
||||
safe_run = _sanitize_path_component(run_id_value or "run")
|
||||
safe_call = _sanitize_path_component(call_id_value)
|
||||
storage_path = f"tool-results/{session_id}/{safe_run}/{safe_call}.json"
|
||||
try:
|
||||
await self._tool_result_storage.upload_json(
|
||||
bucket=self._tool_result_bucket,
|
||||
path=storage_path,
|
||||
payload=payload,
|
||||
)
|
||||
metadata["storage_bucket"] = self._tool_result_bucket
|
||||
metadata["storage_path"] = storage_path
|
||||
except Exception: # noqa: BLE001
|
||||
metadata["storage_upload_failed"] = True
|
||||
self._logger.warning(
|
||||
"tool result storage upload failed",
|
||||
session_id=str(session_id),
|
||||
run_id=run_id_value,
|
||||
call_id=call_id_value,
|
||||
storage_path=storage_path,
|
||||
)
|
||||
|
||||
content = summary or json.dumps(
|
||||
payload, ensure_ascii=False, separators=(",", ":")
|
||||
)
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
@@ -333,3 +423,69 @@ class SqlAlchemyEventStore:
|
||||
except (InvalidOperation, TypeError, ValueError):
|
||||
return Decimal("0")
|
||||
return parsed if parsed >= 0 else Decimal("0")
|
||||
|
||||
|
||||
def _sanitize_path_component(value: str) -> str:
|
||||
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
|
||||
compact = compact.strip(".-")
|
||||
return compact or "id"
|
||||
|
||||
|
||||
def _sanitize_error(value: object) -> str | None:
|
||||
if isinstance(value, str) and value.strip():
|
||||
return " ".join(value.split())[:300]
|
||||
if isinstance(value, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
item = value.get(key)
|
||||
if isinstance(item, str) and item.strip():
|
||||
return " ".join(item.split())[:300]
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_result(value: object) -> dict[str, object]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
normalized = key.strip().lower().replace("-", "_")
|
||||
if not normalized:
|
||||
return False
|
||||
exact = {
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"credential",
|
||||
"authorization",
|
||||
"auth",
|
||||
}
|
||||
if normalized in exact:
|
||||
return True
|
||||
patterns = (
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"auth",
|
||||
"credential",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"authorization",
|
||||
)
|
||||
return any(pattern in normalized for pattern in patterns)
|
||||
|
||||
def _sanitize_value(item: object) -> object:
|
||||
if isinstance(item, dict):
|
||||
return _sanitize_result(item)
|
||||
if isinstance(item, list):
|
||||
return [_sanitize_value(entry) for entry in item]
|
||||
return item
|
||||
|
||||
sanitized: dict[str, object] = {}
|
||||
for key, item in value.items():
|
||||
key_text = str(key)
|
||||
if _is_sensitive_key(key_text):
|
||||
sanitized[str(key)] = "[REDACTED]"
|
||||
continue
|
||||
sanitized[str(key)] = _sanitize_value(item)
|
||||
return sanitized
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_tool_content_summary(
    *,
    tool_name: str,
    args: dict[str, Any] | None,
    result: Any,
    error: Any,
) -> str:
    """Build a short, single-line summary describing a tool invocation.

    Precedence of the returned text:

    1. an explicit error message taken from ``error``;
    2. a business failure reported inside ``result``;
    3. a tool-specific sentence for the known calendar/user tools, when the
       fields it needs are present;
    4. the ``content`` / ``message`` text of ``result``;
    5. a generic "completed" sentence.

    ``args`` and ``result`` are treated as empty mappings when they are not
    dicts. The result is always passed through ``_truncate``.
    """
    arg_map: dict[str, Any] = args if isinstance(args, dict) else {}
    result_map: dict[str, Any] = result if isinstance(result, dict) else {}

    failure = _extract_error_message(error)
    if failure is None:
        failure = _extract_business_failure_message(result_map)
    if failure is not None:
        return _truncate(f"{tool_name} 执行失败:{failure}")

    headline: str | None = None
    if tool_name == "calendar_write":
        title = _pick_first_str(result_map, ("title",)) or _pick_first_str(
            arg_map, ("title",)
        )
        if title:
            start_at = _pick_first_str(result_map, ("startAt", "start_at"))
            if start_at:
                headline = f"已创建日程:{title}({start_at})"
            else:
                headline = f"已创建日程:{title}"
    elif tool_name == "calendar_read":
        total = _extract_total(result_map)
        if total is not None:
            query = _pick_first_str(arg_map, ("query",)) or "全部"
            headline = f"查询到 {total} 条日程({query})"
    else:
        # Tools whose summary is just a fixed prefix plus one looked-up field.
        target_rules = (
            ("calendar_delete", ("title", "eventId", "event_id"), "已删除日程:"),
            ("calendar_share", ("target", "user", "userName"), "已分享日程给 "),
            ("user_resolve", ("name", "userName", "userId"), "已匹配用户:"),
        )
        for known_tool, lookup_keys, prefix in target_rules:
            if tool_name == known_tool:
                target = _pick_first_str(result_map, lookup_keys)
                if target:
                    headline = f"{prefix}{target}"
                break

    if headline is not None:
        return _truncate(headline)

    generic = _pick_first_str(result_map, ("content", "message"))
    if generic:
        return _truncate(generic)
    return _truncate(f"{tool_name} 执行完成")
|
||||
|
||||
|
||||
def _extract_error_message(error: Any) -> str | None:
|
||||
if isinstance(error, str) and error.strip():
|
||||
return error.strip()
|
||||
if isinstance(error, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
value = error.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _pick_first_str(payload: dict[str, Any], keys: tuple[str, ...]) -> str | None:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str):
|
||||
normalized = " ".join(value.split())
|
||||
if normalized:
|
||||
return normalized
|
||||
return None
|
||||
|
||||
|
||||
def _extract_total(result: dict[str, Any]) -> int | None:
|
||||
candidates: list[Any] = [result.get("total")]
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict):
|
||||
candidates.append(data.get("total"))
|
||||
events = data.get("events")
|
||||
if isinstance(events, list):
|
||||
candidates.append(len(events))
|
||||
for value in candidates:
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int) and value >= 0:
|
||||
return value
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
return int(value)
|
||||
return None
|
||||
|
||||
|
||||
def _extract_business_failure_message(result: dict[str, Any]) -> str | None:
    """Return the failure text of a result that reports ``ok`` as ``False``.

    The top-level mapping is inspected first: when ``result["ok"] is False``
    its ``message`` / ``error`` / ``detail`` text is returned if present.
    Otherwise the nested ``result["data"]`` mapping is checked the same way,
    falling back to its ``code`` string. ``None`` means no failure text was
    found.
    """
    message_keys = ("message", "error", "detail")

    if result.get("ok") is False:
        headline = _pick_first_str(result, message_keys)
        if headline:
            return headline

    nested = result.get("data")
    if not isinstance(nested, dict) or nested.get("ok") is not False:
        return None
    return (
        _pick_first_str(nested, message_keys)
        or _pick_first_str(nested, ("code",))
        or None
    )
|
||||
|
||||
|
||||
def _truncate(text: str, limit: int = 80) -> str:
|
||||
normalized = " ".join(text.split())
|
||||
if len(normalized) <= limit:
|
||||
return normalized
|
||||
return normalized[: limit - 3] + "..."
|
||||
@@ -42,11 +42,19 @@ def build_intent_user_prompt(
|
||||
*, user_input: str | list[dict[str, Any]]
|
||||
) -> str | list[dict[str, Any]]:
|
||||
if isinstance(user_input, list):
|
||||
context_messages = _conversation_context_messages(user_input)
|
||||
context_hint = (
|
||||
json.dumps(context_messages, ensure_ascii=True, separators=(",", ":"))
|
||||
if context_messages
|
||||
else "[]"
|
||||
)
|
||||
instruction_text = "\n\n".join(
|
||||
[
|
||||
INTENT_TASK_INSTRUCTION,
|
||||
"[Output Schema]",
|
||||
_schema_json(IntentOutput),
|
||||
"[Conversation Context]",
|
||||
context_hint,
|
||||
"[User Input]",
|
||||
"Use the following multimodal blocks as the latest user input.",
|
||||
]
|
||||
@@ -127,6 +135,56 @@ def _latest_user_content_blocks(
|
||||
return []
|
||||
|
||||
|
||||
def _conversation_context_messages(
|
||||
user_input: list[dict[str, Any]],
|
||||
) -> list[dict[str, str]]:
|
||||
latest_user_index = -1
|
||||
for index in range(len(user_input) - 1, -1, -1):
|
||||
item = user_input[index]
|
||||
if isinstance(item, dict) and item.get("role") == "user":
|
||||
latest_user_index = index
|
||||
break
|
||||
|
||||
if latest_user_index <= 0:
|
||||
return []
|
||||
|
||||
context_items: list[dict[str, str]] = []
|
||||
for item in user_input[:latest_user_index]:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
role = item.get("role")
|
||||
if role not in {"user", "assistant"}:
|
||||
continue
|
||||
content = item.get("content")
|
||||
text = _content_to_text(content)
|
||||
if text:
|
||||
context_items.append({"role": str(role), "content": text})
|
||||
|
||||
if len(context_items) <= 12:
|
||||
return context_items
|
||||
return context_items[-12:]
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return " ".join(content.split())
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
text = block.get("text")
|
||||
if isinstance(text, str) and text.strip():
|
||||
parts.append(" ".join(text.split()))
|
||||
elif block_type in {"binary", "image"}:
|
||||
parts.append("[image]")
|
||||
return " ".join(parts).strip()
|
||||
|
||||
|
||||
def _binary_source_block(item: dict[str, Any]) -> dict[str, Any] | None:
|
||||
mime_type = item.get("mimeType")
|
||||
media_type = mime_type if isinstance(mime_type, str) and mime_type else "image/png"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
import structlog
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.logging import get_logger
|
||||
@@ -31,7 +33,9 @@ class PipelineLike(Protocol):
|
||||
class AgentRouteRuntime:
|
||||
_orchestrator: OrchestratorLike
|
||||
_pipeline: PipelineLike
|
||||
_logger = get_logger("core.agentscope.runtime.agent_route_runtime")
|
||||
_logger: structlog.stdlib.BoundLogger = get_logger(
|
||||
"core.agentscope.runtime.agent_route_runtime"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, *, orchestrator: OrchestratorLike, pipeline: PipelineLike
|
||||
@@ -144,15 +148,6 @@ class AgentRouteRuntime:
|
||||
"data": {"stepName": "execution"},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.finish",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "execution"},
|
||||
},
|
||||
)
|
||||
|
||||
await self._emit_stage_text(
|
||||
thread_id=command.thread_id,
|
||||
@@ -191,6 +186,15 @@ class AgentRouteRuntime:
|
||||
task_id=task.task_id,
|
||||
tool_calls=_task_tool_calls(task),
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.finish",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "execution"},
|
||||
},
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
@@ -294,6 +298,13 @@ class AgentRouteRuntime:
|
||||
tool_name = tool_call.get("tool_name")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
continue
|
||||
call_id = f"{run_id}-{task_id}-{index}"
|
||||
result_payload = _build_tool_result_event_payload(
|
||||
tool_name=tool_name,
|
||||
call_id=call_id,
|
||||
raw_result=tool_call.get("result"),
|
||||
raw_error=tool_call.get("error"),
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=thread_id,
|
||||
event={
|
||||
@@ -301,18 +312,175 @@ class AgentRouteRuntime:
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"data": {
|
||||
"callId": f"{run_id}-{task_id}-{index}",
|
||||
"messageId": result_payload["messageId"],
|
||||
"toolCallId": call_id,
|
||||
"callId": call_id,
|
||||
"stage": "execution",
|
||||
"taskId": task_id,
|
||||
"toolName": tool_name,
|
||||
"args": tool_call.get("args", {}),
|
||||
"result": tool_call.get("result"),
|
||||
"error": tool_call.get("error"),
|
||||
"args": _sanitize_result(tool_call.get("args", {})),
|
||||
"result": result_payload["result"],
|
||||
"error": result_payload["error"],
|
||||
"ui": result_payload["ui"],
|
||||
"content": result_payload["content"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_tool_result_event_payload(
    *,
    tool_name: str,
    call_id: str,
    raw_result: Any,
    raw_error: Any,
) -> dict[str, Any]:
    """Assemble the sanitized fields of a ``tool.result`` event.

    The raw result is normalised to a dict and sanitized, and the raw error is
    reduced to a short sanitized string. The returned mapping contains:

    * ``messageId`` - ``"tool-result-<call_id>"``;
    * ``result`` / ``error`` - the sanitized values;
    * ``ui`` - the result's ``ui`` dict, or the whole result when it has a
      string ``type`` and a dict ``data``, otherwise ``None``;
    * ``content`` - text from the result, else the error text, else a generic
      completion sentence.
    """
    safe_result = _sanitize_result(_normalize_tool_result(raw_result))
    safe_error = _sanitize_error(raw_error)

    embedded_ui = safe_result.get("ui")
    ui_schema: dict[str, Any] | None
    if isinstance(embedded_ui, dict):
        ui_schema = embedded_ui
    elif isinstance(safe_result.get("type"), str) and isinstance(
        safe_result.get("data"), dict
    ):
        # The result itself already has the shape of a UI description.
        ui_schema = safe_result
    else:
        ui_schema = None

    summary = _extract_result_text_content(safe_result)
    if summary is None:
        if isinstance(safe_error, str):
            summary = safe_error
        else:
            summary = f"{tool_name} 执行完成"

    return {
        "messageId": f"tool-result-{call_id}",
        "result": safe_result,
        "error": safe_error,
        "ui": ui_schema,
        "content": summary,
    }
|
||||
|
||||
|
||||
def _normalize_tool_result(raw_result: Any) -> dict[str, Any]:
|
||||
if isinstance(raw_result, dict):
|
||||
content = raw_result.get("content")
|
||||
if isinstance(content, str):
|
||||
parsed = _try_parse_json_object(content)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
return raw_result
|
||||
if isinstance(raw_result, str):
|
||||
parsed = _try_parse_json_object(raw_result)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
text = raw_result.strip()
|
||||
if text:
|
||||
return {"content": text}
|
||||
if raw_result is not None:
|
||||
return {"value": raw_result}
|
||||
return {}
|
||||
|
||||
|
||||
def _try_parse_json_object(value: str) -> dict[str, Any] | None:
|
||||
raw = value.strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except ValueError:
|
||||
return None
|
||||
if not isinstance(parsed, dict):
|
||||
return None
|
||||
return parsed
|
||||
|
||||
|
||||
def _extract_result_text_content(result: dict[str, Any]) -> str | None:
|
||||
content = result.get("content")
|
||||
if isinstance(content, str) and content.strip():
|
||||
return content
|
||||
data = result.get("data")
|
||||
if isinstance(data, dict):
|
||||
message = data.get("message")
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_error(value: Any) -> str | None:
|
||||
if isinstance(value, str) and value.strip():
|
||||
text = " ".join(value.split())
|
||||
return _redact_sensitive_text(text)[:300]
|
||||
if isinstance(value, dict):
|
||||
for key in ("message", "error", "detail"):
|
||||
item = value.get(key)
|
||||
if isinstance(item, str) and item.strip():
|
||||
text = " ".join(item.split())
|
||||
return _redact_sensitive_text(text)[:300]
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_result(value: Any) -> dict[str, Any]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
normalized = key.strip().lower().replace("-", "_")
|
||||
if not normalized:
|
||||
return False
|
||||
exact = {
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"credential",
|
||||
"authorization",
|
||||
"auth",
|
||||
}
|
||||
if normalized in exact:
|
||||
return True
|
||||
patterns = (
|
||||
"password",
|
||||
"token",
|
||||
"secret",
|
||||
"credential",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"authorization",
|
||||
)
|
||||
return any(pattern in normalized for pattern in patterns)
|
||||
|
||||
def _sanitize_value(item: Any) -> Any:
|
||||
if isinstance(item, dict):
|
||||
return _sanitize_result(item)
|
||||
if isinstance(item, list):
|
||||
return [_sanitize_value(entry) for entry in item]
|
||||
if isinstance(item, str):
|
||||
return _redact_sensitive_text(item)
|
||||
return item
|
||||
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key, item in value.items():
|
||||
key_text = str(key)
|
||||
if _is_sensitive_key(key_text):
|
||||
sanitized[key_text] = "[REDACTED]"
|
||||
continue
|
||||
sanitized[key_text] = _sanitize_value(item)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _redact_sensitive_text(value: str) -> str:
|
||||
redacted = value
|
||||
key_value_patterns = (
|
||||
r"(?i)(authorization)\s*[:=]\s*bearer\s+[^\s,;]+",
|
||||
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s*[:=]\s*[^\s,;]+",
|
||||
r"(?i)(password|token|secret|api[_-]?key|authorization|credential)\s+[^\s,;]+",
|
||||
)
|
||||
for pattern in key_value_patterns:
|
||||
redacted = re.sub(pattern, r"\1=[REDACTED]", redacted)
|
||||
redacted = re.sub(r"(?i)bearer\s+[^\s,;]+", "Bearer [REDACTED]", redacted)
|
||||
return redacted
|
||||
|
||||
|
||||
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {}
|
||||
model = _first_non_empty_str(metadata, keys=("model", "model_code"))
|
||||
|
||||
@@ -52,6 +52,24 @@ def _tools_payload_from_schema(
|
||||
return payload
|
||||
|
||||
|
||||
def _merge_tool_schemas(
|
||||
*schema_sets: list[dict[str, object]],
|
||||
) -> list[dict[str, object]]:
|
||||
merged: list[dict[str, object]] = []
|
||||
seen_names: set[str] = set()
|
||||
for schemas in schema_sets:
|
||||
for schema in schemas:
|
||||
function = schema.get("function")
|
||||
if not isinstance(function, dict):
|
||||
continue
|
||||
name = function.get("name")
|
||||
if not isinstance(name, str) or not name or name in seen_names:
|
||||
continue
|
||||
seen_names.add(name)
|
||||
merged.append(schema)
|
||||
return merged
|
||||
|
||||
|
||||
class AgentScopeRuntimeOrchestrator:
|
||||
_runner: Any
|
||||
_config_loader: Callable[[AsyncSession], Awaitable[dict[str, RuntimeStageConfig]]]
|
||||
@@ -96,10 +114,20 @@ class AgentScopeRuntimeOrchestrator:
|
||||
enable_hitl=False,
|
||||
)
|
||||
intent_tools_schema = intent_toolkit.get_json_schemas()
|
||||
execution_toolkit = build_stage_toolkit(
|
||||
stage="execution",
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
enable_hitl=True,
|
||||
)
|
||||
execution_tools_schema = execution_toolkit.get_json_schemas()
|
||||
intent_prompt = build_system_prompt(
|
||||
stage="intent",
|
||||
user_context=user_context,
|
||||
tools=_tools_payload_from_schema(intent_tools_schema),
|
||||
tools=_tools_payload_from_schema(
|
||||
_merge_tool_schemas(intent_tools_schema, execution_tools_schema)
|
||||
),
|
||||
)
|
||||
intent_payload = await self._runner.run_json_stage(
|
||||
stage_config=stage_config["intent"],
|
||||
@@ -125,14 +153,6 @@ class AgentScopeRuntimeOrchestrator:
|
||||
|
||||
execution_output: ExecutionBatchOutput | None = None
|
||||
if intent_output.route == "TASK_EXECUTION":
|
||||
execution_toolkit = build_stage_toolkit(
|
||||
stage="execution",
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
enable_hitl=True,
|
||||
)
|
||||
execution_tools_schema = execution_toolkit.get_json_schemas()
|
||||
execution_prompt = build_system_prompt(
|
||||
stage="execution",
|
||||
user_context=user_context,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
from time import perf_counter
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -106,6 +107,9 @@ class AgentScopeReActRunner:
|
||||
stage_config=stage_config,
|
||||
response=response,
|
||||
latency_ms=latency_ms,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
assistant_text=text_content,
|
||||
)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception(
|
||||
@@ -234,6 +238,9 @@ def _merge_stage_response_metadata(
|
||||
stage_config: RuntimeStageConfig,
|
||||
response: Any,
|
||||
latency_ms: int,
|
||||
system_prompt: str,
|
||||
user_prompt: str | list[dict[str, Any]],
|
||||
assistant_text: str,
|
||||
) -> dict[str, Any]:
|
||||
result = dict(payload)
|
||||
existing = result.get("response_metadata")
|
||||
@@ -247,6 +254,15 @@ def _merge_stage_response_metadata(
|
||||
completion_tokens = _to_non_negative_int(
|
||||
_read_value(usage, "completion_tokens") or _read_value(usage, "output_tokens")
|
||||
)
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = _estimate_token_count(
|
||||
{
|
||||
"system": system_prompt,
|
||||
"user": user_prompt,
|
||||
}
|
||||
)
|
||||
if completion_tokens is None:
|
||||
completion_tokens = _estimate_token_count(assistant_text)
|
||||
cost = _to_non_negative_float(
|
||||
_read_value(usage, "cost")
|
||||
or _read_value(_read_value(usage, "metadata"), "cost")
|
||||
@@ -352,3 +368,16 @@ def _estimate_cost_by_pricing(
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _estimate_token_count(value: object) -> int:
|
||||
try:
|
||||
serialized = (
|
||||
value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
serialized = str(value)
|
||||
normalized = serialized.strip()
|
||||
if not normalized:
|
||||
return 0
|
||||
return max(1, math.ceil(len(normalized) / 4))
|
||||
|
||||
@@ -21,6 +21,7 @@ from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from core.agentscope.tools.tool_result_storage import create_tool_result_storage
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
@@ -67,16 +68,10 @@ def _build_user_context(*, owner_id: UUID, run_input: RunCommand) -> UserAgentCo
|
||||
def _extract_user_token(
|
||||
*, command: dict[str, Any], run_input: RunCommand
|
||||
) -> str | None:
|
||||
del run_input
|
||||
raw_token = command.get("user_token")
|
||||
if isinstance(raw_token, str) and raw_token.strip():
|
||||
return raw_token.strip()
|
||||
forwarded = (
|
||||
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
|
||||
)
|
||||
for key in ("accessToken", "userToken", "token"):
|
||||
value = forwarded.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
@@ -162,7 +157,11 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
)
|
||||
pipeline = AgentScopeEventPipeline(
|
||||
codec=AgentScopeAgUiCodec(),
|
||||
store=SqlAlchemyEventStore(session_factory=AsyncSessionLocal),
|
||||
store=SqlAlchemyEventStore(
|
||||
session_factory=AsyncSessionLocal,
|
||||
tool_result_storage=create_tool_result_storage(),
|
||||
tool_result_bucket=config.storage.bucket,
|
||||
),
|
||||
bus=bus,
|
||||
)
|
||||
runtime = route_runtime_type(
|
||||
|
||||
@@ -67,6 +67,7 @@ def validate_run_request_messages_contract(run_input: RunAgentInput) -> None:
|
||||
message = run_input.messages[0]
|
||||
if getattr(message, "role", None) != "user":
|
||||
raise ValueError("RunAgentInput.messages[0].role must be user")
|
||||
_validate_user_content_blocks(getattr(message, "content", None))
|
||||
extract_latest_user_payload(run_input)
|
||||
|
||||
|
||||
@@ -106,84 +107,76 @@ def extract_latest_user_payload(
|
||||
text_parts.append(text)
|
||||
blocks.append({"type": "text", "text": text})
|
||||
continue
|
||||
if item_type not in {"image", "binary"}:
|
||||
if item_type != "binary":
|
||||
continue
|
||||
source_type: str | None = None
|
||||
source_value: str | None = None
|
||||
source_mime: str | None = None
|
||||
if item_type == "binary":
|
||||
source_mime = (
|
||||
item.get("mimeType")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "mime_type", None)
|
||||
)
|
||||
source_url = (
|
||||
item.get("url")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "url", None)
|
||||
)
|
||||
source_data = (
|
||||
item.get("data")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "data", None)
|
||||
)
|
||||
if isinstance(source_url, str) and source_url:
|
||||
source_type = "url"
|
||||
source_value = source_url
|
||||
elif isinstance(source_data, str) and source_data:
|
||||
source_type = "data"
|
||||
source_value = source_data
|
||||
else:
|
||||
source = getattr(item, "source", None)
|
||||
source_type = (
|
||||
source.get("type")
|
||||
if isinstance(source, dict)
|
||||
else getattr(source, "type", None)
|
||||
)
|
||||
source_value = (
|
||||
source.get("value")
|
||||
if isinstance(source, dict)
|
||||
else getattr(source, "value", None)
|
||||
)
|
||||
source_mime = (
|
||||
source.get("mimeType")
|
||||
if isinstance(source, dict)
|
||||
else getattr(source, "mimeType", None)
|
||||
)
|
||||
if (
|
||||
source_type == "url"
|
||||
and isinstance(source_value, str)
|
||||
and source_value
|
||||
):
|
||||
source_url = (
|
||||
item.get("url")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "url", None)
|
||||
)
|
||||
if isinstance(source_url, str) and source_url:
|
||||
blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": source_value}}
|
||||
)
|
||||
elif (
|
||||
source_type == "data"
|
||||
and isinstance(source_value, str)
|
||||
and source_value
|
||||
):
|
||||
mime_type = (
|
||||
source_mime
|
||||
if isinstance(source_mime, str) and source_mime
|
||||
else "image/png"
|
||||
)
|
||||
blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{source_value}"
|
||||
},
|
||||
}
|
||||
{"type": "image_url", "image_url": {"url": source_url}}
|
||||
)
|
||||
combined = "".join(text_parts).strip()
|
||||
if combined:
|
||||
if combined or blocks:
|
||||
return combined, blocks
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
|
||||
|
||||
def _validate_user_content_blocks(content: Any) -> None:
|
||||
if isinstance(content, str):
|
||||
if content.strip():
|
||||
return
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
if not isinstance(content, list):
|
||||
raise ValueError("RunAgentInput.messages[0].content must be string or list")
|
||||
|
||||
has_text = False
|
||||
has_binary = False
|
||||
for item in content:
|
||||
item_type = getattr(item, "type", None)
|
||||
if item_type == "text":
|
||||
text = getattr(item, "text", None)
|
||||
if isinstance(text, str) and text.strip():
|
||||
has_text = True
|
||||
continue
|
||||
if item_type == "binary":
|
||||
mime_type = (
|
||||
item.get("mimeType")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "mime_type", None)
|
||||
)
|
||||
url = (
|
||||
item.get("url")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "url", None)
|
||||
)
|
||||
data = (
|
||||
item.get("data")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "data", None)
|
||||
)
|
||||
if not isinstance(mime_type, str) or not mime_type.startswith("image/"):
|
||||
raise ValueError("binary content requires image mimeType")
|
||||
if not isinstance(url, str) or not url:
|
||||
raise ValueError("binary content requires url")
|
||||
if isinstance(data, str) and data:
|
||||
raise ValueError("binary content data is not allowed")
|
||||
has_binary = True
|
||||
continue
|
||||
raise ValueError("unsupported content block type")
|
||||
|
||||
if not has_text and not has_binary:
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
|
||||
|
||||
def extract_latest_tool_result(
|
||||
run_input: RunAgentInput,
|
||||
) -> tuple[str, dict[str, object]]:
|
||||
|
||||
@@ -3,10 +3,11 @@ from __future__ import annotations
|
||||
import logging
|
||||
from logging.config import dictConfig
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from core.config.settings import PROJECT_ROOT, RuntimeSettings, Settings
|
||||
from core.config.settings import PROJECT_ROOT, RuntimeSettings, Settings, config
|
||||
from core.logging.formatters import (
|
||||
build_plain_formatter,
|
||||
build_processor_formatter,
|
||||
@@ -77,7 +78,7 @@ def build_logging_config(runtime: RuntimeSettings) -> dict[str, object]:
|
||||
|
||||
|
||||
def configure_logging(settings: Settings | None = None) -> None:
|
||||
active_settings = settings or Settings()
|
||||
active_settings = settings if settings is not None else cast(Settings, config)
|
||||
runtime = active_settings.runtime
|
||||
|
||||
try:
|
||||
|
||||
@@ -19,19 +19,14 @@ class SupabaseService(BaseServiceProvider):
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
self._client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.anon_key,
|
||||
)
|
||||
self._admin_client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.service_role_key,
|
||||
)
|
||||
self._init_clients()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service initialized")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Supabase service initialization failed", error=str(exc))
|
||||
self.logger.warning(
|
||||
"Supabase service initialization failed", error=str(exc)
|
||||
)
|
||||
self._client = None
|
||||
self._admin_client = None
|
||||
self._set_initialized(False)
|
||||
@@ -51,7 +46,9 @@ class SupabaseService(BaseServiceProvider):
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
await asyncio.to_thread(client.auth.get_session)
|
||||
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
|
||||
await asyncio.to_thread(
|
||||
admin_client.auth.admin.list_users, page=1, per_page=1
|
||||
)
|
||||
return {
|
||||
"status": "healthy",
|
||||
"details": {
|
||||
@@ -70,17 +67,35 @@ class SupabaseService(BaseServiceProvider):
|
||||
return self._require_admin_client()
|
||||
|
||||
def _require_client(self) -> Any:
|
||||
if self._client is None or self._admin_client is None:
|
||||
self._init_clients()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service lazily initialized")
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Supabase client is not initialized")
|
||||
return client
|
||||
|
||||
def _require_admin_client(self) -> Any:
|
||||
if self._client is None or self._admin_client is None:
|
||||
self._init_clients()
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service lazily initialized")
|
||||
admin_client = self._admin_client
|
||||
if admin_client is None:
|
||||
raise RuntimeError("Supabase admin client is not initialized")
|
||||
return admin_client
|
||||
|
||||
def _init_clients(self) -> None:
    """Create the anon-key client and the service-role client from settings."""
    anon = create_client(self._settings.url, self._settings.anon_key)
    self._client = anon
    admin = create_client(self._settings.url, self._settings.service_role_key)
    self._admin_client = admin
|
||||
|
||||
|
||||
supabase_service: SupabaseService = register_service_instance(
|
||||
"supabase", SupabaseService()
|
||||
|
||||
@@ -3,11 +3,20 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from storage3.exceptions import StorageApiError
|
||||
|
||||
from core.config.settings import config
|
||||
from services.base.supabase import supabase_service
|
||||
|
||||
|
||||
class AgentAttachmentStorage:
|
||||
def _validate_bucket(self, *, bucket: str) -> None:
    """Reject any bucket other than the one configured for storage.

    Raises:
        RuntimeError: If ``bucket`` differs from ``config.storage.bucket``.
    """
    if bucket == config.storage.bucket:
        return
    raise RuntimeError("Invalid attachment bucket")
|
||||
|
||||
def _bucket_client(self, *, bucket: str) -> Any:
|
||||
self._validate_bucket(bucket=bucket)
|
||||
client = supabase_service.get_admin_client()
|
||||
storage = getattr(client, "storage", None)
|
||||
if storage is None:
|
||||
@@ -39,9 +48,82 @@ class AgentAttachmentStorage:
|
||||
},
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_upload)
|
||||
try:
|
||||
await asyncio.to_thread(_upload)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if not _is_bucket_not_found_error(exc):
|
||||
raise
|
||||
await self._ensure_bucket_exists(bucket=bucket)
|
||||
await asyncio.to_thread(_upload)
|
||||
return path
|
||||
|
||||
async def _ensure_bucket_exists(self, *, bucket: str) -> None:
    """Create ``bucket`` as a private bucket unless it already exists.

    The blocking storage calls run in a worker thread.

    Raises:
        RuntimeError: If the storage client or its ``create_bucket`` method
            is unavailable.
        Exception: Any creation error whose message does not indicate that
            the bucket already exists.
    """

    def _create_if_missing() -> None:
        storage = getattr(supabase_service.get_admin_client(), "storage", None)
        if storage is None:
            raise RuntimeError("Supabase storage client unavailable")

        lookup = getattr(storage, "get_bucket", None)
        if callable(lookup):
            try:
                lookup(bucket)
            except Exception:  # noqa: BLE001
                # A failed lookup is not fatal; fall through and try to
                # create the bucket instead.
                pass
            else:
                return

        creator = getattr(storage, "create_bucket", None)
        if not callable(creator):
            raise RuntimeError("Supabase storage create_bucket is unavailable")
        try:
            creator(bucket, options={"public": False})
        except Exception as exc:  # noqa: BLE001
            reason = str(exc).lower()
            # Treat an "already exists" style error as success.
            if not ("already exists" in reason or "duplicate" in reason):
                raise

    await asyncio.to_thread(_create_if_missing)
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
    """Download the object at ``path`` from ``bucket`` and return its bytes.

    The blocking download runs in a worker thread.

    Raises:
        RuntimeError: If the bucket client has no callable ``download`` or
            the download result is not bytes-like.
    """

    def _fetch() -> object:
        fetch = getattr(self._bucket_client(bucket=bucket), "download", None)
        if callable(fetch):
            return fetch(path)
        raise RuntimeError("Supabase storage download is unavailable")

    blob = await asyncio.to_thread(_fetch)
    if isinstance(blob, bytes):
        return blob
    if isinstance(blob, bytearray):
        return bytes(blob)
    if isinstance(blob, memoryview):
        return blob.tobytes()
    raise RuntimeError("Invalid attachment payload")
|
||||
|
||||
async def create_signed_url(
    self,
    *,
    bucket: str,
    path: str,
    expires_in_seconds: int,
) -> str:
    """Return a signed URL for ``path`` in ``bucket``.

    The signer may answer with a plain string or with a dict carrying the URL
    under ``signedURL``, ``signedUrl`` or ``url``. The blocking call runs in a
    worker thread.

    Raises:
        RuntimeError: If the bucket client has no callable
            ``create_signed_url`` or the response carries no usable URL.
    """

    def _sign() -> object:
        signer = getattr(
            self._bucket_client(bucket=bucket), "create_signed_url", None
        )
        if callable(signer):
            return signer(path, expires_in_seconds)
        raise RuntimeError("Supabase storage signed url is unavailable")

    response = await asyncio.to_thread(_sign)
    if isinstance(response, str):
        return response
    if isinstance(response, dict):
        # First truthy value among the known key spellings wins.
        candidate = None
        for key in ("signedURL", "signedUrl", "url"):
            candidate = response.get(key)
            if candidate:
                break
        if isinstance(candidate, str) and candidate:
            return candidate
    raise RuntimeError("Invalid signed url payload")
|
||||
|
||||
|
||||
def create_attachment_storage() -> AgentAttachmentStorage | None:
|
||||
try:
|
||||
@@ -49,3 +131,11 @@ def create_attachment_storage() -> AgentAttachmentStorage | None:
|
||||
except Exception:
|
||||
return None
|
||||
return AgentAttachmentStorage()
|
||||
|
||||
|
||||
def _is_bucket_not_found_error(exc: Exception) -> bool:
|
||||
if isinstance(exc, StorageApiError):
|
||||
message = str(exc).lower()
|
||||
return "bucket" in message and "not found" in message
|
||||
message = str(exc).lower()
|
||||
return "bucket" in message and "not found" in message
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.config.settings import config
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
|
||||
@@ -200,6 +201,61 @@ class AgentRepository:
|
||||
return None
|
||||
return str(latest_id)
|
||||
|
||||
async def get_message_attachment_reference(
    self,
    *,
    session_id: str,
    message_id: str,
    attachment_index: int,
) -> dict[str, str] | None:
    """Look up the stored reference of one attachment on a chat message.

    Returns a dict with ``bucket``, ``path`` and ``mimeType``, or ``None``
    when the message is not found (deleted messages are excluded), the index
    is out of range, or the stored entry is not a dict with non-empty string
    values for all three keys.

    Raises:
        HTTPException: 422 when either id is not a valid UUID.
    """
    try:
        session_uuid = UUID(session_id)
        message_uuid = UUID(message_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=422, detail="Invalid message/session id"
        ) from exc

    query = select(AgentChatMessage).where(
        AgentChatMessage.id == message_uuid,
        AgentChatMessage.session_id == session_uuid,
        AgentChatMessage.deleted_at.is_(None),
    )
    row = (await self._session.execute(query)).scalar_one_or_none()
    if row is None:
        return None

    meta = row.metadata_json if isinstance(row.metadata_json, dict) else {}
    stored = meta.get("attachments")
    if not isinstance(stored, list):
        return None
    if not 0 <= attachment_index < len(stored):
        return None

    entry = stored[attachment_index]
    if not isinstance(entry, dict):
        return None

    reference: dict[str, str] = {}
    for key in ("bucket", "path", "mimeType"):
        value = entry.get(key)
        # Every field must be a non-empty string for the reference to count.
        if not isinstance(value, str) or not value:
            return None
        reference[key] = value
    return reference
|
||||
|
||||
async def _to_snapshot_message(
|
||||
self, message: AgentChatMessage
|
||||
) -> dict[str, object]:
|
||||
@@ -233,30 +289,65 @@ class AgentRepository:
|
||||
storage_bucket = metadata.get("storage_bucket")
|
||||
storage_path = metadata.get("storage_path")
|
||||
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
|
||||
try:
|
||||
hydrated_content = await self._tool_result_storage.read_json(
|
||||
bucket=storage_bucket,
|
||||
path=storage_path,
|
||||
expected_bucket = config.storage.bucket
|
||||
message_session_id = getattr(message, "session_id", None)
|
||||
expected_prefix = (
|
||||
f"tool-results/{message_session_id}/"
|
||||
if message_session_id is not None
|
||||
else None
|
||||
)
|
||||
tool_call_id = metadata.get("tool_call_id")
|
||||
is_legacy_path = isinstance(
|
||||
tool_call_id, str
|
||||
) and storage_path.endswith(f"/{tool_call_id}.json")
|
||||
if (
|
||||
storage_bucket == expected_bucket
|
||||
and _is_safe_storage_path(storage_path)
|
||||
and (
|
||||
(
|
||||
expected_prefix is not None
|
||||
and storage_path.startswith(expected_prefix)
|
||||
)
|
||||
or (
|
||||
storage_path.startswith("tool-results/")
|
||||
and is_legacy_path
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
hydrated_content = None
|
||||
):
|
||||
try:
|
||||
hydrated_content = (
|
||||
await self._tool_result_storage.read_json(
|
||||
bucket=storage_bucket,
|
||||
path=storage_path,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
hydrated_content = None
|
||||
|
||||
resolved_content = hydrated_content or parsed_content
|
||||
payload["content"] = message.content
|
||||
if resolved_content is not None:
|
||||
result = resolved_content.get("result")
|
||||
if isinstance(result, dict):
|
||||
result_content = result.get("content")
|
||||
if isinstance(result_content, str):
|
||||
payload["content"] = result_content
|
||||
ui = resolved_content.get("ui")
|
||||
if not isinstance(ui, dict):
|
||||
ui = resolved_content.get("ui_schema")
|
||||
if isinstance(ui, dict):
|
||||
payload["ui"] = ui
|
||||
display_content = resolved_content.get("content")
|
||||
if isinstance(display_content, str):
|
||||
if not isinstance(display_content, str):
|
||||
nested_result = resolved_content.get("result")
|
||||
if isinstance(nested_result, dict):
|
||||
nested_content = nested_result.get("content")
|
||||
if isinstance(nested_content, str):
|
||||
display_content = nested_content
|
||||
if (
|
||||
isinstance(display_content, str)
|
||||
and display_content.strip()
|
||||
and (
|
||||
not payload["content"]
|
||||
or _looks_like_offloaded_placeholder(str(payload["content"]))
|
||||
)
|
||||
):
|
||||
payload["content"] = display_content
|
||||
|
||||
if "content" not in payload:
|
||||
payload["content"] = message.content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
metadata = message.metadata_json or {}
|
||||
@@ -264,7 +355,22 @@ class AgentRepository:
|
||||
metadata.get("attachments") if isinstance(metadata, dict) else None
|
||||
)
|
||||
if isinstance(attachments, list):
|
||||
rendered = [item for item in attachments if isinstance(item, dict)]
|
||||
rendered: list[dict[str, object]] = []
|
||||
for index, item in enumerate(attachments):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
mime_type = item.get("mimeType")
|
||||
if not isinstance(mime_type, str) or not mime_type:
|
||||
continue
|
||||
rendered.append(
|
||||
{
|
||||
"mimeType": mime_type,
|
||||
"previewPath": (
|
||||
f"/api/v1/agent/runs/{message.session_id}/attachments/"
|
||||
f"{message.id}/{index}"
|
||||
),
|
||||
}
|
||||
)
|
||||
if rendered:
|
||||
payload["attachments"] = rendered
|
||||
return payload
|
||||
@@ -279,3 +385,19 @@ def _derive_session_title(content_text: str) -> str | None:
|
||||
if not normalized:
|
||||
return None
|
||||
return normalized[:80]
|
||||
|
||||
|
||||
def _is_safe_storage_path(path: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _looks_like_offloaded_placeholder(content: str) -> bool:
|
||||
normalized = content.strip().lower()
|
||||
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
|
||||
|
||||
@@ -10,7 +10,17 @@ import time
|
||||
from typing import Annotated, Union
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
Header,
|
||||
Query,
|
||||
Request,
|
||||
status,
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
@@ -20,11 +30,18 @@ from core.agentscope.schemas.agui_input import (
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
)
|
||||
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
|
||||
from v1.agent.schemas import (
|
||||
AsrTranscribeResponse,
|
||||
AttachmentReference,
|
||||
AttachmentUploadResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
from v1.agent.service import AgentService, asr_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
@@ -38,6 +55,7 @@ _SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
|
||||
_WAV_HEADER_MIN_BYTES = 12
|
||||
_ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
"audio/wav",
|
||||
@@ -46,6 +64,42 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def _verified_access_token_for_user(
|
||||
*,
|
||||
authorization: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> str | None:
|
||||
if not isinstance(authorization, str):
|
||||
return None
|
||||
normalized = authorization.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if not normalized.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
token = normalized[7:].strip()
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
jwt_secret = config.supabase.jwt_secret
|
||||
if jwt_secret is None:
|
||||
raise HTTPException(status_code=503, detail="Auth verifier unavailable")
|
||||
|
||||
verifier = JwtVerifier(
|
||||
issuer=str(config.supabase.jwt_issuer),
|
||||
jwt_secret=jwt_secret.get_secret_value(),
|
||||
jwt_algorithm=config.supabase.jwt_algorithm,
|
||||
)
|
||||
try:
|
||||
payload = verifier.verify(token)
|
||||
except TokenValidationError as exc:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||
|
||||
subject = payload.get("sub")
|
||||
if not isinstance(subject, str) or subject != str(current_user.id):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
return token
|
||||
|
||||
|
||||
def _looks_like_wav_header(header: bytes) -> bool:
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
return False
|
||||
@@ -111,6 +165,7 @@ async def enqueue_run(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
@@ -120,10 +175,15 @@ async def enqueue_run(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
task = await service.enqueue_run(
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -143,6 +203,7 @@ async def enqueue_resume(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
if request.thread_id != thread_id:
|
||||
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
|
||||
@@ -154,10 +215,15 @@ async def enqueue_resume(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
task = await service.enqueue_resume(
|
||||
thread_id=thread_id,
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -253,6 +319,31 @@ async def get_history_snapshot(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/attachments/{message_id}/{attachment_index}")
async def get_attachment_preview(
    thread_id: str,
    message_id: str,
    attachment_index: int,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> StreamingResponse:
    """Return the bytes of one message attachment for preview.

    The response is marked privately cacheable for 300 seconds.

    Raises:
        HTTPException: 422 when ``attachment_index`` is negative.
    """
    if attachment_index < 0:
        raise HTTPException(status_code=422, detail="Invalid attachment index")

    content, media_type = await service.get_attachment_preview(
        thread_id=thread_id,
        message_id=message_id,
        attachment_index=attachment_index,
        current_user=current_user,
    )
    cache_headers = {"Cache-Control": "private, max-age=300"}
    return StreamingResponse(
        iter([content]),
        media_type=media_type,
        headers=cache_headers,
    )
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_user_history_snapshot(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
@@ -267,6 +358,34 @@ async def get_user_history_snapshot(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/attachments",
    response_model=AttachmentUploadResponse,
    status_code=status.HTTP_200_OK,
)
async def upload_attachment(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str = Form(alias="threadId"),
    file: UploadFile = File(),
) -> AttachmentUploadResponse:
    """Upload one attachment for a thread and return its storage reference.

    Raises:
        HTTPException: 422 when the upload is empty; 413 when it exceeds
            ``_MAX_ATTACHMENT_UPLOAD_BYTES``.
    """
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without buffering an arbitrarily large body in memory.
    payload = await file.read(_MAX_ATTACHMENT_UPLOAD_BYTES + 1)
    if not payload:
        raise HTTPException(status_code=422, detail="Empty attachment")
    if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Attachment too large")
    attachment = await service.upload_attachment(
        thread_id=thread_id,
        filename=file.filename,
        content_type=file.content_type,
        payload=payload,
        current_user=current_user,
    )
    return AttachmentUploadResponse(
        attachment=AttachmentReference.model_validate(attachment),
    )
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcribe",
|
||||
response_model=AsrTranscribeResponse,
|
||||
|
||||
@@ -14,3 +14,16 @@ class TaskAcceptedResponse(BaseModel):
|
||||
|
||||
class AsrTranscribeResponse(BaseModel):
    """Response body carrying the transcript of an audio upload."""

    transcript: str = Field(description="Transcribed text from audio")
|
||||
|
||||
|
||||
class AttachmentReference(BaseModel):
    """Reference to a stored attachment.

    ``mime_type`` is exposed as ``mimeType`` on the wire; the model accepts
    either spelling on input and serializes by alias.
    """

    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)

    # Storage bucket holding the object.
    bucket: str
    # Object path inside the bucket.
    path: str
    # MIME type of the attachment, aliased to "mimeType".
    mime_type: str = Field(alias="mimeType")
    # URL for the attachment.
    url: str
|
||||
|
||||
|
||||
class AttachmentUploadResponse(BaseModel):
    """Response body of the attachment upload endpoint."""

    # Reference to the attachment that was just stored.
    attachment: AttachmentReference
|
||||
|
||||
+297
-60
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
import hashlib
|
||||
@@ -19,17 +18,22 @@ from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
|
||||
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
|
||||
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
|
||||
|
||||
|
||||
def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
|
||||
forwarded = run_input.forwarded_props
|
||||
if not isinstance(forwarded, dict):
|
||||
def _normalize_bearer_token(value: str | None) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
for key in ("accessToken", "userToken", "token"):
|
||||
value = forwarded.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
lower = normalized.lower()
|
||||
if lower.startswith("bearer "):
|
||||
token = normalized[7:].strip()
|
||||
return token or None
|
||||
return normalized
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -66,6 +70,14 @@ class AgentRepositoryLike(Protocol):
|
||||
metadata: dict[str, object] | None,
|
||||
) -> None: ...
|
||||
|
||||
async def get_message_attachment_reference(
    self,
    *,
    session_id: str,
    message_id: str,
    attachment_index: int,
) -> dict[str, str] | None:
    """Return the stored reference of one message attachment, or ``None``."""
    ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
@@ -92,6 +104,16 @@ class AttachmentStorageLike(Protocol):
|
||||
content_type: str,
|
||||
) -> str: ...
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
    """Return the raw bytes of the object at ``path`` in ``bucket``."""
    ...
|
||||
|
||||
async def create_signed_url(
    self,
    *,
    bucket: str,
    path: str,
    expires_in_seconds: int,
) -> str:
    """Return a signed URL for ``path`` in ``bucket``."""
    ...
|
||||
|
||||
|
||||
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
|
||||
if owner_id != str(current_user.id):
|
||||
@@ -104,6 +126,8 @@ class AgentService:
|
||||
_stream: EventStreamLike
|
||||
_attachment_storage: AttachmentStorageLike | None
|
||||
|
||||
_SIGNED_URL_EXPIRES_IN_SECONDS = 3600
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -122,6 +146,7 @@ class AgentService:
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
@@ -161,7 +186,7 @@ class AgentService:
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _extract_user_token_from_run_input(run_input),
|
||||
"user_token": _normalize_bearer_token(user_token),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
@@ -179,57 +204,115 @@ class AgentService:
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> tuple[str, dict[str, object] | None]:
|
||||
text, content_blocks = extract_latest_user_payload(run_input)
|
||||
text, _ = extract_latest_user_payload(run_input)
|
||||
content_blocks = _extract_latest_user_content_blocks(run_input)
|
||||
attachments: list[dict[str, object]] = []
|
||||
if self._attachment_storage is not None:
|
||||
for index, block in enumerate(content_blocks):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") != "image_url":
|
||||
continue
|
||||
image_value = block.get("image_url")
|
||||
if not isinstance(image_value, dict):
|
||||
continue
|
||||
url = image_value.get("url")
|
||||
if not isinstance(url, str) or not url.startswith("data:"):
|
||||
continue
|
||||
decoded = _decode_data_url(url)
|
||||
if decoded is None:
|
||||
continue
|
||||
mime_type, payload = decoded
|
||||
suffix = _mime_to_suffix(mime_type)
|
||||
checksum = hashlib.sha1(payload).hexdigest()[:16]
|
||||
path = (
|
||||
f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
|
||||
f"{run_input.run_id}/attachment-{index}-{checksum}.{suffix}"
|
||||
binary_blocks = [
|
||||
block
|
||||
for block in content_blocks
|
||||
if isinstance(block, dict) and block.get("type") == "binary"
|
||||
]
|
||||
if binary_blocks:
|
||||
if self._attachment_storage is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Attachment storage unavailable",
|
||||
)
|
||||
bucket_name = config.storage.bucket
|
||||
forwarded_props = (
|
||||
run_input.forwarded_props
|
||||
if isinstance(run_input.forwarded_props, dict)
|
||||
else {}
|
||||
)
|
||||
raw_attachments = forwarded_props.get("attachments")
|
||||
if not isinstance(raw_attachments, list):
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid attachments payload"
|
||||
)
|
||||
if len(raw_attachments) != len(binary_blocks):
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid attachments payload"
|
||||
)
|
||||
|
||||
total_attachment_bytes = 0
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/{run_input.thread_id}/"
|
||||
for index, raw_attachment in enumerate(raw_attachments):
|
||||
if not isinstance(raw_attachment, dict):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Invalid attachment reference",
|
||||
)
|
||||
bucket = raw_attachment.get("bucket")
|
||||
path = raw_attachment.get("path")
|
||||
mime_type = raw_attachment.get("mimeType")
|
||||
if (
|
||||
not isinstance(bucket, str)
|
||||
or not isinstance(path, str)
|
||||
or not isinstance(mime_type, str)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Invalid attachment reference",
|
||||
)
|
||||
if bucket != config.storage.bucket:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
if mime_type.lower() not in _ALLOWED_ATTACHMENT_MIME_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Unsupported attachment type",
|
||||
)
|
||||
|
||||
binary_block = binary_blocks[index]
|
||||
binary_mime = binary_block.get("mimeType")
|
||||
binary_url = binary_block.get("url")
|
||||
if (
|
||||
not isinstance(binary_mime, str)
|
||||
or binary_mime != mime_type
|
||||
or not isinstance(binary_url, str)
|
||||
or not binary_url
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="Invalid attachments payload",
|
||||
)
|
||||
|
||||
try:
|
||||
stored_path = await self._attachment_storage.upload_bytes(
|
||||
bucket=bucket_name,
|
||||
payload = await self._attachment_storage.download_bytes(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
content=payload,
|
||||
content_type=mime_type,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Attachment upload failed",
|
||||
"Attachment validation download failed",
|
||||
extra={
|
||||
"bucket": bucket_name,
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"mime_type": mime_type,
|
||||
"thread_id": run_input.thread_id,
|
||||
"run_id": run_input.run_id,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to upload attachment",
|
||||
detail="Failed to fetch attachment",
|
||||
)
|
||||
payload_size = len(payload)
|
||||
if payload_size > _MAX_ATTACHMENT_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Attachment too large",
|
||||
)
|
||||
total_attachment_bytes += payload_size
|
||||
if total_attachment_bytes > _MAX_TOTAL_ATTACHMENT_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Attachments too large",
|
||||
)
|
||||
|
||||
attachments.append(
|
||||
{
|
||||
"bucket": bucket_name,
|
||||
"path": stored_path,
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"mimeType": mime_type,
|
||||
}
|
||||
)
|
||||
@@ -238,12 +321,94 @@ class AgentService:
|
||||
metadata["attachments"] = attachments
|
||||
return text, metadata or None
|
||||
|
||||
async def upload_attachment(
    self,
    *,
    thread_id: str,
    filename: str | None,
    content_type: str | None,
    payload: bytes,
    current_user: CurrentUser,
) -> dict[str, str]:
    """Validate an uploaded attachment, store it, and return its reference.

    The session identified by ``thread_id`` is looked up first. If the lookup
    reports 404 the session is created for ``current_user``; otherwise the
    caller must pass ``ensure_session_owner``.

    Args:
        thread_id: Session id the attachment belongs to; also part of the
            storage path.
        filename: Optional client-side file name. Only a hash of it is used
            in the storage path; ``"upload"`` is hashed when it is missing
            or empty.
        content_type: Declared MIME type; compared case-insensitively with
            ``_ALLOWED_ATTACHMENT_MIME_TYPES``.
        payload: Raw attachment bytes.
        current_user: The authenticated caller.

    Returns:
        A dict with the keys ``bucket``, ``path``, ``mimeType`` and ``url``
        (the signed URL produced by the storage backend).

    Raises:
        HTTPException: 503 when no attachment storage is configured, 422 for
            an unsupported type or an empty payload, 413 when the payload
            exceeds ``_MAX_ATTACHMENT_BYTES``, 502 when the upload or the URL
            signing fails. Non-404 errors from the owner lookup propagate
            unchanged.
    """
    try:
        session_owner = await self._repository.get_session_owner(
            session_id=thread_id
        )
    except HTTPException as lookup_error:
        if lookup_error.status_code != 404:
            raise
        # Unknown session: create it on the fly for the uploading user.
        try:
            await self._repository.create_session_for_user(
                user_id=str(current_user.id),
                session_id=thread_id,
            )
            await self._repository.commit()
        except IntegrityError:
            # The insert collided with an existing row, so roll back and
            # check ownership of whatever session is there now.
            await self._repository.rollback()
            session_owner = await self._repository.get_session_owner(
                session_id=thread_id
            )
            ensure_session_owner(
                owner_id=session_owner, current_user=current_user
            )
    else:
        ensure_session_owner(owner_id=session_owner, current_user=current_user)

    storage = self._attachment_storage
    if storage is None:
        raise HTTPException(
            status_code=503, detail="Attachment storage unavailable"
        )

    # A missing and a disallowed content type are reported identically.
    if (
        not isinstance(content_type, str)
        or content_type.lower() not in _ALLOWED_ATTACHMENT_MIME_TYPES
    ):
        raise HTTPException(status_code=422, detail="Unsupported attachment type")
    mime_type = content_type.lower()

    if not payload:
        raise HTTPException(status_code=422, detail="Empty attachment")
    if len(payload) > _MAX_ATTACHMENT_BYTES:
        raise HTTPException(status_code=413, detail="Attachment too large")

    # Object name: <hash of file name>-<hash of content>.<suffix>
    suffix = _mime_to_suffix(mime_type)
    checksum = hashlib.sha1(payload).hexdigest()[:16]
    if isinstance(filename, str) and filename:
        filename_seed = filename
    else:
        filename_seed = "upload"
    filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
    object_name = f"{filename_hash}-{checksum}.{suffix}"
    path = f"agent-inputs/{current_user.id}/{thread_id}/uploads/{object_name}"
    bucket_name = config.storage.bucket

    try:
        stored_path = await storage.upload_bytes(
            bucket=bucket_name,
            path=path,
            content=payload,
            content_type=mime_type,
        )
        signed_url = await storage.create_signed_url(
            bucket=bucket_name,
            path=stored_path,
            expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
        )
    except Exception:  # noqa: BLE001
        log_context = {
            "bucket": bucket_name,
            "path": path,
            "mime_type": mime_type,
            "thread_id": thread_id,
        }
        logger.exception("Attachment upload failed", extra=log_context)
        raise HTTPException(status_code=502, detail="Failed to upload attachment")

    return {
        "bucket": bucket_name,
        "path": stored_path,
        "mimeType": mime_type,
        "url": signed_url,
    }
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
) -> TaskAccepted:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
@@ -253,7 +418,7 @@ class AgentService:
|
||||
command={
|
||||
"command": "resume",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _extract_user_token_from_run_input(run_input),
|
||||
"user_token": _normalize_bearer_token(user_token),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
@@ -336,6 +501,63 @@ class AgentService:
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
async def get_attachment_preview(
    self,
    *,
    thread_id: str,
    message_id: str,
    attachment_index: int,
    current_user: CurrentUser,
) -> tuple[bytes, str]:
    """Download one stored message attachment on behalf of its owner.

    Args:
        thread_id: Session id the message belongs to.
        message_id: Id of the message carrying the attachment.
        attachment_index: Index of the attachment within that message,
            passed through to the repository lookup.
        current_user: The authenticated caller; must pass
            ``ensure_session_owner`` for the session.

    Returns:
        ``(payload, mime_type)``: the downloaded bytes and the MIME type
        recorded in the attachment reference.

    Raises:
        HTTPException: 503 when no attachment storage is configured, 404
            when the reference is missing or malformed, 403 when the
            reference points outside the configured bucket or outside the
            caller's ``agent-inputs/<user>/<thread>/`` prefix, 502 when the
            download fails.
    """
    session_owner = await self._repository.get_session_owner(session_id=thread_id)
    ensure_session_owner(owner_id=session_owner, current_user=current_user)

    storage = self._attachment_storage
    if storage is None:
        raise HTTPException(
            status_code=503, detail="Attachment storage unavailable"
        )

    reference = await self._repository.get_message_attachment_reference(
        session_id=thread_id,
        message_id=message_id,
        attachment_index=attachment_index,
    )
    if reference is None:
        raise HTTPException(status_code=404, detail="Attachment not found")

    bucket = reference.get("bucket")
    path = reference.get("path")
    mime_type = reference.get("mimeType")
    # A reference with any non-string field is treated like a missing one.
    if not all(isinstance(field, str) for field in (bucket, path, mime_type)):
        raise HTTPException(status_code=404, detail="Attachment not found")

    # Only serve objects from the configured bucket and from this user's
    # own prefix for this thread.
    if bucket != config.storage.bucket:
        raise HTTPException(status_code=403, detail="Forbidden")
    owned_prefix = f"agent-inputs/{current_user.id}/{thread_id}/"
    if not _is_safe_attachment_path(path, expected_prefix=owned_prefix):
        raise HTTPException(status_code=403, detail="Forbidden")

    try:
        payload = await storage.download_bytes(bucket=bucket, path=path)
    except Exception:  # noqa: BLE001
        log_context = {
            "thread_id": thread_id,
            "message_id": message_id,
            "attachment_index": attachment_index,
            "bucket": bucket,
        }
        logger.exception("Attachment download failed", extra=log_context)
        raise HTTPException(status_code=502, detail="Failed to fetch attachment")
    else:
        return payload, mime_type
|
||||
|
||||
|
||||
class AsrService:
|
||||
def __init__(self) -> None:
|
||||
@@ -445,22 +667,26 @@ class AsrService:
|
||||
asr_service = AsrService()
|
||||
|
||||
|
||||
def _decode_data_url(data_url: str) -> tuple[str, bytes] | None:
|
||||
if not data_url.startswith("data:"):
|
||||
return None
|
||||
header, sep, payload = data_url.partition(",")
|
||||
if not sep:
|
||||
return None
|
||||
mime_type = "image/png"
|
||||
if ";" in header:
|
||||
maybe_mime = header[5:].split(";", 1)[0].strip()
|
||||
if maybe_mime:
|
||||
mime_type = maybe_mime
|
||||
try:
|
||||
decoded = base64.b64decode(payload, validate=True)
|
||||
except ValueError:
|
||||
return None
|
||||
return mime_type, decoded
|
||||
def _extract_latest_user_content_blocks(
|
||||
run_input: RunAgentInput,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not run_input.messages:
|
||||
return []
|
||||
latest = run_input.messages[-1]
|
||||
content = getattr(latest, "content", None)
|
||||
if not isinstance(content, list):
|
||||
return []
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
blocks.append(item)
|
||||
continue
|
||||
model_dump = getattr(item, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
dumped = model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
if isinstance(dumped, dict):
|
||||
blocks.append(dumped)
|
||||
return blocks
|
||||
|
||||
|
||||
def _mime_to_suffix(mime_type: str) -> str:
|
||||
@@ -470,3 +696,14 @@ def _mime_to_suffix(mime_type: str) -> str:
|
||||
"image/webp": "webp",
|
||||
}
|
||||
return mapping.get(mime_type.lower(), "bin")
|
||||
|
||||
|
||||
def _is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return normalized.startswith(expected_prefix)
|
||||
|
||||
Reference in New Issue
Block a user