refactor: 简化 AgentScope 运行时模块与事件处理

- 移除冗余的 user_token 参数传递
- 重构 tool.result 事件使用 ToolAgentOutput 模型
- 重构 text.end 事件使用 WorkerAgentOutput 模型
- 简化 store 模块的 tool result 处理逻辑
- 更新 router/service 适配新事件结构
- 清理废弃的测试文件与设计文档
- 新增 AgentRuns 多模态存储设计文档
This commit is contained in:
qzl
2026-03-13 17:27:18 +08:00
parent 3273d63b23
commit 1c02503d1d
29 changed files with 1259 additions and 2725 deletions
@@ -38,16 +38,13 @@ def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
data = event.get("data")
if isinstance(data, dict):
if event_type == "tool.result":
for key in (
"messageId",
"toolCallId",
"callId",
"toolName",
"stage",
"taskId",
"ui",
"content",
):
for key in ("messageId", "toolCallId", "toolAgentOutput"):
value = data.get(key)
if value is not None:
payload[key] = value
return payload
if event_type == "text.end":
for key in ("messageId", "workerAgentOutput"):
value = data.get(key)
if value is not None:
payload[key] = value
+30 -93
View File
@@ -1,16 +1,19 @@
from __future__ import annotations
import json
import re
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID, uuid4
from core.agentscope.events.tool_result_summary import build_tool_content_summary
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from schemas.agent.runtime_models import (
ToolAgentOutput,
WorkerAgentOutputLite,
WorkerAgentOutputRich,
)
class EventStore(Protocol):
@@ -193,6 +196,19 @@ class SqlAlchemyEventStore:
if isinstance(stage, str) and stage:
metadata["stage"] = stage
worker_payload = event.get("workerAgentOutput")
if isinstance(worker_payload, dict):
try:
if "ui_hints" in worker_payload:
worker_output = WorkerAgentOutputRich.model_validate(worker_payload)
else:
worker_output = WorkerAgentOutputLite.model_validate(worker_payload)
except Exception:
worker_output = None
else:
content = worker_output.answer
metadata["worker_agent_output"] = worker_output.model_dump(mode="json")
role_value = context.get("role")
if not isinstance(role_value, str):
role_value = "assistant"
@@ -252,6 +268,14 @@ class SqlAlchemyEventStore:
if not isinstance(tool_name, str) or not tool_name:
return
raw_output = event.get("toolAgentOutput")
if not isinstance(raw_output, dict):
return
try:
tool_output = ToolAgentOutput.model_validate(raw_output)
except Exception:
return
run_id = event.get("runId")
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
task_id = event.get("taskId")
@@ -264,43 +288,18 @@ class SqlAlchemyEventStore:
else f"{task_id_value}-{uuid4().hex[:8]}"
)
summary = build_tool_content_summary(
tool_name=tool_name,
args=event.get("args") if isinstance(event.get("args"), dict) else None,
result=event.get("result"),
error=event.get("error"),
)
raw_result_value = event.get("result")
raw_result: dict[str, object] = (
raw_result_value if isinstance(raw_result_value, dict) else {}
)
ui_candidate = raw_result.get("ui")
ui_schema = ui_candidate if isinstance(ui_candidate, dict) else None
result_type = raw_result.get("type")
result_data = raw_result.get("data")
if (
ui_schema is None
and isinstance(result_type, str)
and isinstance(result_data, dict)
):
ui_schema = raw_result
payload: dict[str, object] = {
"toolName": tool_name,
"ui_schema": ui_schema,
"result": _sanitize_result(raw_result),
"error": _sanitize_error(event.get("error")),
"toolAgentOutput": tool_output.model_dump(mode="json"),
"callId": call_id_value,
"runId": run_id_value,
"taskId": task_id_value,
"content": summary,
"content": tool_output.result_summary,
}
metadata: dict[str, object] = {
"tool_name": tool_name,
"tool_call_id": call_id_value,
"summary_version": "v1",
"tool_agent_output": tool_output.model_dump(mode="json"),
}
if run_id_value:
metadata["run_id"] = run_id_value
@@ -332,9 +331,7 @@ class SqlAlchemyEventStore:
storage_path=storage_path,
)
content = summary or json.dumps(
payload, ensure_ascii=False, separators=(",", ":")
)
content = tool_output.result_summary
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
@@ -429,63 +426,3 @@ def _sanitize_path_component(value: str) -> str:
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
compact = compact.strip(".-")
return compact or "id"
def _sanitize_error(value: object) -> str | None:
if isinstance(value, str) and value.strip():
return " ".join(value.split())[:300]
if isinstance(value, dict):
for key in ("message", "error", "detail"):
item = value.get(key)
if isinstance(item, str) and item.strip():
return " ".join(item.split())[:300]
return None
def _sanitize_result(value: object) -> dict[str, object]:
if not isinstance(value, dict):
return {}
def _is_sensitive_key(key: str) -> bool:
normalized = key.strip().lower().replace("-", "_")
if not normalized:
return False
exact = {
"password",
"token",
"secret",
"api_key",
"apikey",
"credential",
"authorization",
"auth",
}
if normalized in exact:
return True
patterns = (
"password",
"token",
"secret",
"auth",
"credential",
"api_key",
"apikey",
"authorization",
)
return any(pattern in normalized for pattern in patterns)
def _sanitize_value(item: object) -> object:
if isinstance(item, dict):
return _sanitize_result(item)
if isinstance(item, list):
return [_sanitize_value(entry) for entry in item]
return item
sanitized: dict[str, object] = {}
for key, item in value.items():
key_text = str(key)
if _is_sensitive_key(key_text):
sanitized[str(key)] = "[REDACTED]"
continue
sanitized[str(key)] = _sanitize_value(item)
return sanitized
@@ -78,21 +78,19 @@ def build_intent_user_prompt(
*, user_input: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(user_input, list):
instruction_block = {
"type": "text",
"text": "\n\n".join(
[
ROUTER_STAGE_INSTRUCTION,
"[Output Schema]",
_schema_json(RouterAgentOutput),
]
),
}
return [
{
"type": "text",
"text": "\n\n".join(
[
ROUTER_STAGE_INSTRUCTION,
"[Output Schema]",
_schema_json(RouterAgentOutput),
"[User Input]",
json.dumps(
user_input, ensure_ascii=True, separators=(",", ":")
),
]
),
}
instruction_block,
*user_input,
]
return "\n\n".join(
[
@@ -50,11 +50,9 @@ class AgentScopeRuntimeOrchestrator:
*,
command: RunAgentInput,
owner_id: UUID,
user_token: str,
user_context: UserContext,
session: AsyncSession,
) -> dict[str, Any]:
del user_token
return await self._execute(
command=command,
owner_id=owner_id,
@@ -68,11 +66,9 @@ class AgentScopeRuntimeOrchestrator:
*,
command: RunAgentInput,
owner_id: UUID,
user_token: str,
user_context: UserContext,
session: AsyncSession,
) -> dict[str, Any]:
del user_token
return await self._execute(
command=command,
owner_id=owner_id,
@@ -116,7 +112,7 @@ class AgentScopeRuntimeOrchestrator:
user_input = _to_resume_user_input(command)
else:
_, content_blocks = extract_latest_user_payload(command)
user_input = _to_user_input_payload(content_blocks)
user_input = _to_model_user_input(content_blocks)
router_toolkit = build_stage_toolkit(
stage="intent",
session=session,
@@ -159,16 +155,38 @@ class AgentScopeRuntimeOrchestrator:
worker_payload = result.get("worker") if isinstance(result, dict) else None
worker = worker_payload if isinstance(worker_payload, dict) else {}
response_metadata = worker.get("response_metadata")
metadata = response_metadata if isinstance(response_metadata, dict) else {}
assistant_text = _resolve_worker_answer(worker)
tool_outputs_raw = worker.get("tool_outputs")
if isinstance(tool_outputs_raw, list):
for idx, item in enumerate(tool_outputs_raw, start=1):
if not isinstance(item, dict):
continue
tool_name = item.get("tool_name")
tool_call_id = item.get("tool_call_id")
if not isinstance(tool_name, str) or not tool_name:
continue
if not isinstance(tool_call_id, str) or not tool_call_id:
tool_call_id = f"{run_id}-tool-{idx}"
await self._pipeline.emit(
session_id=thread_id,
event={
"type": "tool.result",
"threadId": thread_id,
"runId": run_id,
"data": {
"messageId": f"tool-{tool_call_id}",
"toolCallId": tool_call_id,
"toolAgentOutput": item,
},
},
)
await self._emit_stage_text(
thread_id=thread_id,
run_id=run_id,
stage_name="worker",
message_id=f"assistant-{run_id}",
text=assistant_text,
response_metadata=metadata,
worker_agent_output=worker,
)
await self._pipeline.emit(
@@ -215,7 +233,7 @@ class AgentScopeRuntimeOrchestrator:
stage_name: str,
message_id: str,
text: str,
response_metadata: dict[str, Any],
worker_agent_output: dict[str, Any],
) -> None:
await self._pipeline.emit(
session_id=thread_id,
@@ -250,8 +268,7 @@ class AgentScopeRuntimeOrchestrator:
"runId": run_id,
"data": {
"messageId": message_id,
"stage": stage_name,
**_text_end_telemetry_payload(response_metadata),
"workerAgentOutput": worker_agent_output,
},
},
)
@@ -271,6 +288,28 @@ def _to_user_input_payload(
return content_blocks
def _to_model_user_input(
    content_blocks: list[dict[str, Any]],
) -> str | list[dict[str, Any]]:
    """Translate incoming content blocks into model-facing input blocks.

    Non-dict entries are ignored. ``text`` blocks with non-blank text are kept
    as ``{"type": "text", "text": ...}``; ``binary`` blocks with a non-empty
    string ``url`` become ``{"type": "image_url", "image_url": {"url": ...}}``.
    All other block types are dropped. The resulting list is handed to
    ``_to_user_input_payload``, which produces the final return value.
    """
    model_blocks: list[dict[str, Any]] = []
    for entry in content_blocks:
        if not isinstance(entry, dict):
            continue
        kind = entry.get("type")
        if kind == "text":
            body = entry.get("text")
            if isinstance(body, str) and body.strip():
                model_blocks.append({"type": "text", "text": body})
        elif kind == "binary":
            # NOTE(review): every binary block is forwarded as an image URL,
            # regardless of its mime type — confirm that is intended.
            link = entry.get("url")
            if isinstance(link, str) and link:
                model_blocks.append({"type": "image_url", "image_url": {"url": link}})
    return _to_user_input_payload(model_blocks)
def _to_resume_user_input(command: RunAgentInput) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for message in command.messages:
@@ -296,66 +335,3 @@ def _resolve_worker_answer(worker: dict[str, Any]) -> str:
return message
return "抱歉,这次没有产出可用结果,请重试。"
def _text_end_telemetry_payload(metadata: dict[str, Any]) -> dict[str, Any]:
    """Extract the telemetry fields attached to a ``text.end`` event.

    Reads the model name and the numeric usage figures from ``metadata``,
    accepting either camelCase or snake_case source keys, and returns them
    under camelCase keys. Fields that cannot be resolved are omitted.
    """
    payload: dict[str, Any] = {}

    model = _first_non_empty_str(metadata, keys=("model", "model_code"))
    if model is not None:
        payload["model"] = model

    # (output key, accepted source keys, whether a fractional value is allowed)
    numeric_fields = (
        ("inputTokens", ("inputTokens", "input_tokens"), False),
        ("outputTokens", ("outputTokens", "output_tokens"), False),
        ("latencyMs", ("latencyMs", "latency_ms"), False),
        ("cost", ("cost", "total_cost"), True),
    )
    for field, aliases, fractional in numeric_fields:
        number = _first_number(metadata, keys=aliases, allow_float=fractional)
        if number is not None:
            payload[field] = number
    return payload
def _first_non_empty_str(
metadata: dict[str, Any], *, keys: tuple[str, ...]
) -> str | None:
for key in keys:
value = metadata.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _first_number(
metadata: dict[str, Any],
*,
keys: tuple[str, ...],
allow_float: bool = False,
) -> int | float | None:
for key in keys:
value = metadata.get(key)
if isinstance(value, bool):
continue
if isinstance(value, int):
if value < 0:
continue
return value
if isinstance(value, float):
if value < 0:
continue
return value if allow_float else int(value)
if isinstance(value, str):
try:
parsed = float(value) if allow_float else int(value)
except ValueError:
continue
if parsed >= 0:
return parsed
return None
@@ -68,16 +68,6 @@ def _build_user_context(*, owner_id: UUID, run_input: RunAgentInput) -> UserCont
)
def _extract_user_token(
*, command: dict[str, Any], run_input: RunAgentInput
) -> str | None:
del run_input
raw_token = command.get("user_token")
if isinstance(raw_token, str) and raw_token.strip():
return raw_token.strip()
return None
async def _build_recent_context_messages(
*,
session: Any,
@@ -147,7 +137,6 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
if command_type == "resume":
extract_latest_tool_result(parsed_run_input)
user_context = _build_user_context(owner_id=owner_id, run_input=parsed_run_input)
user_token = _extract_user_token(command=command, run_input=parsed_run_input) or ""
redis_client = await get_or_init_redis_client()
bus = RedisStreamBus(
@@ -189,7 +178,6 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
await runtime.resume(
command=parsed_run_input,
owner_id=owner_id,
user_token=user_token,
user_context=user_context,
session=session,
)
@@ -197,7 +185,6 @@ async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
await runtime.run(
command=parsed_run_input,
owner_id=owner_id,
user_token=user_token,
user_context=user_context,
session=session,
)
@@ -167,16 +167,18 @@ class UiCompiler:
timestamp: str | None = None,
meta: dict[str, Any] | None = None,
) -> dict[str, Any]:
hints = output.ui_hints or self._build_default_worker_hints(output)
if output.error is not None and not self._contains_error_block(hints.blocks):
output_ui_hints = getattr(output, "ui_hints", None)
hints = output_ui_hints or self._build_default_worker_hints(output)
output_error = getattr(output, "error", None)
if output_error is not None and not self._contains_error_block(hints.blocks):
hints = self._append_error_block(
hints,
UiHintErrorBlock(
kind="error",
errorCode=output.error.code,
message=output.error.message,
retryable=output.error.retryable,
details=self._stringify_details(output.error.details),
errorCode=output_error.code,
message=output_error.message,
retryable=output_error.retryable,
details=self._stringify_details(output_error.details),
),
)
@@ -0,0 +1,57 @@
from __future__ import annotations
import json
from typing import Protocol
from services.base.supabase import supabase_service
class ToolResultStorage(Protocol):
async def upload_json(
self,
*,
bucket: str,
path: str,
payload: dict[str, object],
) -> str: ...
async def read_json(
self,
*,
bucket: str,
path: str,
) -> dict[str, object] | None: ...
class SupabaseToolResultStorage:
    """Tool-result storage that reads and writes through ``supabase_service``."""

    async def upload_json(
        self,
        *,
        bucket: str,
        path: str,
        payload: dict[str, object],
    ) -> str:
        """Upload ``payload`` as compact, ASCII-escaped JSON and return ``path``."""
        body = json.dumps(
            payload, ensure_ascii=True, separators=(",", ":")
        ).encode("utf-8")
        await supabase_service.upload_bytes(
            bucket=bucket,
            path=path,
            content=body,
            content_type="application/json",
        )
        return path

    async def read_json(
        self,
        *,
        bucket: str,
        path: str,
    ) -> dict[str, object] | None:
        """Download and decode the object at ``path``.

        Returns the decoded value when it is a JSON object, otherwise ``None``.
        Download and JSON decoding errors are not caught here and propagate to
        the caller.
        """
        blob = await supabase_service.download_bytes(bucket=bucket, path=path)
        document = json.loads(blob.decode("utf-8"))
        return document if isinstance(document, dict) else None
def create_tool_result_storage() -> ToolResultStorage:
    """Build the default tool-result storage (the Supabase-backed one)."""
    storage = SupabaseToolResultStorage()
    return storage
+2 -16
View File
@@ -1,11 +1,3 @@
from core.agentscope.schemas.agui_input import (
extract_latest_tool_result,
extract_latest_user_content,
extract_latest_user_payload,
extract_latest_user_text,
parse_run_input,
validate_run_request_messages_contract,
)
from schemas.agent.runtime_models import (
ResultType,
RouterAgentOutput,
@@ -14,17 +6,17 @@ from schemas.agent.runtime_models import (
ToolAgentOutput,
ToolStatus,
UiMode,
WorkerAgentOutput,
WorkerAgentOutputLite,
WorkerAgentOutputRich,
WorkerAgentOutput,
resolve_worker_output_model,
)
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.agent.ui_hints import (
UiHintAction,
UiHintBlock,
UiHintsPayload,
)
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
__all__ = [
"AgentType",
@@ -43,10 +35,4 @@ __all__ = [
"WorkerAgentOutputRich",
"WorkerAgentOutput",
"resolve_worker_output_model",
"extract_latest_tool_result",
"extract_latest_user_content",
"extract_latest_user_payload",
"extract_latest_user_text",
"parse_run_input",
"validate_run_request_messages_contract",
]
-1
View File
@@ -1 +0,0 @@
from core.agentscope.schemas.agui_input import * # noqa: F403
+1 -1
View File
@@ -374,7 +374,7 @@ class WorkerAgentOutputRich(WorkerAgentOutputLite):
)
WorkerAgentOutput = WorkerAgentOutputRich
WorkerAgentOutput = WorkerAgentOutputLite | WorkerAgentOutputRich
def resolve_worker_output_model(ui_mode: UiMode) -> type[WorkerAgentOutputLite]:
+2 -2
View File
@@ -1,3 +1,3 @@
from schemas.messages.chat_message import AgentChatMessageMetadata
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
__all__ = ["AgentChatMessageMetadata"]
__all__ = ["AgentChatMessage", "AgentChatMessageMetadata"]
+17 -1
View File
@@ -1,8 +1,11 @@
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from typing import ClassVar
from uuid import UUID
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from ..agent import AgentType, ToolAgentOutput, WorkerAgentOutput
@@ -22,3 +25,16 @@ class AgentChatMessageMetadata(BaseModel):
user_message_attachments: UserMessageAttachments | None = None
tool_agent_output: ToolAgentOutput | None = None
worker_agent_output: WorkerAgentOutput | None = None
class AgentChatMessage(BaseModel):
    """Canonical schema aligned with `messages` table columns."""

    # extra="forbid": validation fails if the input carries any field that is
    # not declared below.
    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
    id: UUID
    # presumably the ordering index of the message within its session — TODO confirm
    seq: int
    role: str
    content: str
    # Either the typed metadata model or a free-form dict is accepted; may be absent.
    metadata: AgentChatMessageMetadata | dict[str, object] | None = None
    timestamp: datetime
+12 -145
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
import json
from typing import Protocol
from uuid import UUID
@@ -9,10 +8,9 @@ from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.config.settings import config
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
from services.base.supabase import supabase_service
from schemas.messages.chat_message import AgentChatMessage as AgentChatMessageSchema
class ToolResultPayloadStorage(Protocol):
@@ -210,132 +208,17 @@ class AgentRepository:
if isinstance(message.role, AgentChatMessageRole)
else str(message.role)
)
payload: dict[str, object] = {
"id": str(message.id),
"role": role,
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
}
if role == AgentChatMessageRole.TOOL.value:
metadata = message.metadata_json or {}
tool_call_id = metadata.get("tool_call_id")
if isinstance(tool_call_id, str) and tool_call_id:
payload["toolCallId"] = tool_call_id
parsed_content: dict[str, object] | None = None
try:
decoded = json.loads(message.content)
if isinstance(decoded, dict):
parsed_content = decoded
except (TypeError, ValueError):
parsed_content = None
hydrated_content: dict[str, object] | None = None
if self._tool_result_storage is not None:
storage_bucket = metadata.get("storage_bucket")
storage_path = metadata.get("storage_path")
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
expected_bucket = config.storage.bucket
message_session_id = getattr(message, "session_id", None)
expected_prefix = (
f"tool-results/{message_session_id}/"
if message_session_id is not None
else None
)
tool_call_id = metadata.get("tool_call_id")
is_legacy_path = isinstance(
tool_call_id, str
) and storage_path.endswith(f"/{tool_call_id}.json")
if (
storage_bucket == expected_bucket
and _is_safe_storage_path(storage_path)
and (
(
expected_prefix is not None
and storage_path.startswith(expected_prefix)
)
or (
storage_path.startswith("tool-results/")
and is_legacy_path
)
)
):
try:
hydrated_content = (
await self._tool_result_storage.read_json(
bucket=storage_bucket,
path=storage_path,
)
)
except Exception:
hydrated_content = None
resolved_content = hydrated_content or parsed_content
payload["content"] = message.content
if resolved_content is not None:
ui = resolved_content.get("ui")
if not isinstance(ui, dict):
ui = resolved_content.get("ui_schema")
if isinstance(ui, dict):
payload["ui"] = ui
display_content = resolved_content.get("content")
if not isinstance(display_content, str):
nested_result = resolved_content.get("result")
if isinstance(nested_result, dict):
nested_content = nested_result.get("content")
if isinstance(nested_content, str):
display_content = nested_content
if (
isinstance(display_content, str)
and display_content.strip()
and (
not payload["content"]
or _looks_like_offloaded_placeholder(str(payload["content"]))
)
):
payload["content"] = display_content
else:
payload["content"] = message.content
if role == AgentChatMessageRole.USER.value:
metadata = message.metadata_json or {}
user_attachments = metadata.get("user_message_attachments")
if isinstance(user_attachments, dict):
bucket = user_attachments.get("bucket")
path = user_attachments.get("path")
mime_type = user_attachments.get("mime_type")
if (
isinstance(bucket, str)
and isinstance(path, str)
and isinstance(mime_type, str)
):
try:
signed_url = await supabase_service.create_signed_url(
bucket=bucket,
path=path,
expires_in_seconds=3600,
)
attachment_block = {
"type": "binary",
"mimeType": mime_type,
"url": signed_url,
}
existing_content = message.content
if (
isinstance(existing_content, str)
and existing_content.strip()
):
content_blocks = [
{"type": "text", "text": existing_content}
]
content_blocks.append(attachment_block)
payload["content"] = content_blocks
else:
payload["content"] = [attachment_block]
except Exception: # noqa: BLE001
pass
return payload
payload_model = AgentChatMessageSchema.model_validate(
{
"id": str(message.id),
"seq": int(message.seq),
"role": role,
"content": message.content,
"metadata": message.metadata_json,
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
}
)
return payload_model.model_dump(mode="json", exclude_none=True)
def _has_title(title: object) -> bool:
@@ -347,19 +230,3 @@ def _derive_session_title(content_text: str) -> str | None:
if not normalized:
return None
return normalized[:80]
def _is_safe_storage_path(path: str) -> bool:
normalized = path.strip()
if not normalized:
return False
if normalized.startswith("/"):
return False
if ".." in normalized:
return False
return True
def _looks_like_offloaded_placeholder(content: str) -> bool:
normalized = content.strip().lower()
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
+20 -64
View File
@@ -11,9 +11,7 @@ from typing import Annotated, Union
from ag_ui.core import RunAgentInput
from core.agentscope.events import to_sse_event
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
from core.auth.models import CurrentUser
from core.config.settings import config
from core.logging import get_logger
from fastapi import (
APIRouter,
@@ -38,6 +36,7 @@ from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
AsrTranscribeResponse,
AttachmentReference,
AttachmentSignedUrlResponse,
AttachmentUploadResponse,
TaskAcceptedResponse,
)
@@ -63,42 +62,6 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
}
def _verified_access_token_for_user(
    *,
    authorization: str | None,
    current_user: CurrentUser,
) -> str:
    """Extract the bearer token from ``authorization`` and verify it.

    Returns:
        The raw token, once it has verified successfully and its ``sub`` claim
        equals ``current_user.id``.

    Raises:
        HTTPException: 401 when the header is missing, blank, not a ``Bearer``
            credential, carries an empty token, or the token fails
            verification; 503 when no JWT secret is configured; 403 when the
            token's subject is not the current user.
    """
    # A missing or non-string header is treated like an empty one: it cannot
    # start with the bearer scheme, so it is rejected by the same check.
    header = authorization.strip() if isinstance(authorization, str) else ""
    if not header.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")

    token = header[7:].strip()
    if not token:
        raise HTTPException(status_code=401, detail="Unauthorized")

    jwt_secret = config.supabase.jwt_secret
    if jwt_secret is None:
        raise HTTPException(status_code=503, detail="Auth verifier unavailable")

    verifier = JwtVerifier(
        issuer=str(config.supabase.jwt_issuer),
        jwt_secret=jwt_secret.get_secret_value(),
        jwt_algorithm=config.supabase.jwt_algorithm,
    )
    try:
        claims = verifier.verify(token)
    except TokenValidationError as exc:
        raise HTTPException(status_code=401, detail="Unauthorized") from exc

    subject = claims.get("sub")
    subject_matches = isinstance(subject, str) and subject == str(current_user.id)
    if not subject_matches:
        raise HTTPException(status_code=403, detail="Forbidden")
    return token
def _looks_like_wav_header(header: bytes) -> bool:
if len(header) < _WAV_HEADER_MIN_BYTES:
return False
@@ -164,7 +127,6 @@ async def enqueue_run(
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
authorization: str | None = Header(default=None, alias="Authorization"),
) -> TaskAcceptedResponse:
try:
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
@@ -174,15 +136,10 @@ async def enqueue_run(
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
user_token = _verified_access_token_for_user(
authorization=authorization,
current_user=current_user,
)
task = await service.enqueue_run(
run_input=request,
current_user=current_user,
user_token=user_token,
)
return TaskAcceptedResponse(
taskId=task.task_id,
@@ -202,7 +159,6 @@ async def enqueue_resume(
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
authorization: str | None = Header(default=None, alias="Authorization"),
) -> TaskAcceptedResponse:
if request.thread_id != thread_id:
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
@@ -214,15 +170,10 @@ async def enqueue_resume(
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
user_token = _verified_access_token_for_user(
authorization=authorization,
current_user=current_user,
)
task = await service.enqueue_resume(
thread_id=thread_id,
run_input=request,
current_user=current_user,
user_token=user_token,
)
return TaskAcceptedResponse(
taskId=task.task_id,
@@ -304,20 +255,6 @@ async def stream_events(
)
@router.get("/runs/{thread_id}/history")
async def get_history_snapshot(
thread_id: str,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
before: date | None = Query(default=None),
) -> dict[str, object]:
return await service.get_history_snapshot(
thread_id=thread_id,
before=before,
current_user=current_user,
)
@router.get("/history")
async def get_user_history_snapshot(
service: Annotated[AgentService, Depends(get_agent_service)],
@@ -360,6 +297,25 @@ async def upload_attachment(
)
# Issue a signed URL for an attachment identified by `bucket` and `path`,
# both read from the query string with length limits (1-100 and 1-500).
@router.get(
    "/attachments/signed-url",
    response_model=AttachmentSignedUrlResponse,
    status_code=status.HTTP_200_OK,
)
async def create_attachment_signed_url(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    bucket: str = Query(min_length=1, max_length=100),
    path: str = Query(min_length=1, max_length=500),
) -> AttachmentSignedUrlResponse:
    # The route does no validation of its own beyond the Query constraints;
    # bucket/path are handed to the service together with the current user.
    signed = await service.create_attachment_signed_url(
        bucket=bucket,
        path=path,
        current_user=current_user,
    )
    # The service result is expanded as keyword arguments, so its keys must
    # match the response model's fields.
    return AttachmentSignedUrlResponse(**signed)
@router.post(
"/transcribe",
response_model=AsrTranscribeResponse,
+6
View File
@@ -27,3 +27,9 @@ class AttachmentReference(BaseModel):
class AttachmentUploadResponse(BaseModel):
attachment: AttachmentReference
# Response body carrying a signed URL together with the storage location
# (bucket and path) it was issued for.
class AttachmentSignedUrlResponse(BaseModel):
    bucket: str
    path: str
    # The signed URL string itself.
    url: str
+91 -30
View File
@@ -5,6 +5,7 @@ from dataclasses import dataclass
from datetime import date
import hashlib
from typing import Any, Protocol
from urllib.parse import urlparse
import dashscope
from ag_ui.core import RunAgentInput, StateSnapshotEvent
@@ -23,19 +24,6 @@ _MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
def _normalize_bearer_token(value: str | None) -> str | None:
if not isinstance(value, str):
return None
normalized = value.strip()
if not normalized:
return None
lower = normalized.lower()
if lower.startswith("bearer "):
token = normalized[7:].strip()
return token or None
return normalized
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
@@ -70,14 +58,6 @@ class AgentRepositoryLike(Protocol):
metadata: dict[str, object] | None,
) -> None: ...
async def get_message_attachment_reference(
self,
*,
session_id: str,
message_id: str,
attachment_index: int,
) -> dict[str, str] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -148,7 +128,6 @@ class AgentService:
*,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
) -> TaskAccepted:
created = False
thread_id = run_input.thread_id
@@ -188,7 +167,6 @@ class AgentService:
command={
"command": "run",
"owner_id": str(current_user.id),
"user_token": _normalize_bearer_token(user_token),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
@@ -226,19 +204,28 @@ class AgentService:
mime_type = "application/octet-stream"
if self._attachment_storage is None:
continue
raise HTTPException(
status_code=503,
detail="Attachment storage unavailable",
)
try:
bucket, path = self._attachment_storage.parse_signed_url(url)
bucket, path = self._validate_binary_signed_url(
url=url,
thread_id=run_input.thread_id,
current_user=current_user,
)
user_attachments = UserMessageAttachments(
bucket=bucket,
path=path,
mime_type=mime_type,
)
break
except Exception: # noqa: BLE001
logger.warning("Failed to parse signed URL", url=url)
continue
except HTTPException:
raise
except Exception as exc: # noqa: BLE001
logger.warning("Failed to parse signed URL", url=url, error=str(exc))
raise HTTPException(status_code=422, detail="Invalid signed image url")
metadata: dict[str, object] | None = None
if user_attachments is not None:
@@ -329,13 +316,57 @@ class AgentService:
"url": signed_url,
}
async def create_attachment_signed_url(
    self,
    *,
    bucket: str,
    path: str,
    current_user: CurrentUser,
) -> dict[str, str]:
    """Create a signed URL for one of the current user's attachments.

    The bucket must equal the configured storage bucket and the path must pass
    ``_is_safe_attachment_path`` for the prefix ``agent-inputs/<user id>/``.

    Returns:
        A dict with the trimmed ``bucket`` and ``path`` plus the signed ``url``.

    Raises:
        HTTPException: 503 when no attachment storage is configured, 422 when
            the bucket or path is rejected, 502 when signing fails.
    """
    storage = self._attachment_storage
    if storage is None:
        raise HTTPException(
            status_code=503, detail="Attachment storage unavailable"
        )

    target_bucket = bucket.strip()
    if target_bucket != config.storage.bucket:
        raise HTTPException(status_code=422, detail="Invalid attachment bucket")

    target_path = path.strip()
    owner_prefix = f"agent-inputs/{current_user.id}/"
    path_in_scope = _is_safe_attachment_path(
        target_path, expected_prefix=owner_prefix
    )
    if not path_in_scope:
        raise HTTPException(status_code=422, detail="Invalid attachment path scope")

    try:
        signed_url = await storage.create_signed_url(
            bucket=target_bucket,
            path=target_path,
            expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
        )
    except Exception:  # noqa: BLE001
        # Any storage failure is logged with its traceback and surfaced to the
        # client as a 502.
        log_context = {
            "bucket": target_bucket,
            "path": target_path,
            "user_id": str(current_user.id),
        }
        logger.exception(
            "Attachment signed URL generation failed",
            extra=log_context,
        )
        raise HTTPException(status_code=502, detail="Failed to generate signed URL")

    return {
        "bucket": target_bucket,
        "path": target_path,
        "url": signed_url,
    }
async def enqueue_resume(
self,
*,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
user_token: str | None = None,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
@@ -345,7 +376,6 @@ class AgentService:
command={
"command": "resume",
"owner_id": str(current_user.id),
"user_token": _normalize_bearer_token(user_token),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
@@ -428,6 +458,37 @@ class AgentService:
current_user=current_user,
)
def _validate_binary_signed_url(
    self,
    *,
    url: str,
    thread_id: str,
    current_user: CurrentUser,
) -> tuple[str, str]:
    """Check that a client-supplied signed URL points at this user's upload.

    The URL must target the configured Supabase host, resolve (via the
    attachment storage's ``parse_signed_url``) to the configured bucket, and
    its path must pass ``_is_safe_attachment_path`` for the prefix
    ``agent-inputs/<user id>/<thread id>/uploads/``.

    Returns:
        The ``(bucket, path)`` pair extracted from the URL.

    Raises:
        HTTPException: 503 when no attachment storage is configured; 422 when
            the host, URL format, bucket or path scope is rejected.
    """
    if self._attachment_storage is None:
        raise HTTPException(
            status_code=503, detail="Attachment storage unavailable"
        )

    # Host names are case-insensitive, so compare them lower-cased. str()
    # keeps this working if the configured URL is a URL object rather than a
    # plain string.
    actual_host = urlparse(url).netloc.lower()
    expected_host = urlparse(str(config.supabase.url)).netloc.lower()
    # An empty expected host (misconfigured or scheme-less base URL) must not
    # let host-less URLs through, so it is rejected rather than matched.
    if not expected_host or actual_host != expected_host:
        raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_HOST")

    try:
        bucket, path = self._attachment_storage.parse_signed_url(url)
    except Exception as exc:  # noqa: BLE001
        raise HTTPException(
            status_code=422, detail="Invalid signed image url"
        ) from exc

    if bucket != config.storage.bucket:
        raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_BUCKET")

    expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
    if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
        raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_PATH_SCOPE")
    return bucket, path
class AsrService:
def __init__(self) -> None: