refactor: 重构 AgentScope ReAct Runner 与事件处理
- 重构 runtime/runner.py 实现 ReAct Agent 核心逻辑 - 更新事件编码器与存储机制 - 优化 prompt 系统与 tool 调用 - 调整 agent service 与 repository 配合
This commit is contained in:
@@ -10,11 +10,11 @@ from ag_ui.core import (
|
||||
RunErrorEvent,
|
||||
StepStartedEvent,
|
||||
StepFinishedEvent,
|
||||
TextMessageStartEvent,
|
||||
TextMessageContentEvent,
|
||||
TextMessageEndEvent,
|
||||
ToolCallResultEvent,
|
||||
)
|
||||
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
|
||||
from schemas.agent.ui_hints import UiHintsPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -25,8 +25,6 @@ _INTERNAL_TO_AGUI: dict[str, EventType] = {
|
||||
"run.error": EventType.RUN_ERROR,
|
||||
"step.start": EventType.STEP_STARTED,
|
||||
"step.finish": EventType.STEP_FINISHED,
|
||||
"text.start": EventType.TEXT_MESSAGE_START,
|
||||
"text.delta": EventType.TEXT_MESSAGE_CONTENT,
|
||||
"text.end": EventType.TEXT_MESSAGE_END,
|
||||
"tool.start": EventType.TOOL_CALL_START,
|
||||
"tool.args": EventType.TOOL_CALL_ARGS,
|
||||
@@ -53,6 +51,34 @@ def _is_agui_event(event: dict[str, Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _sanitize_agui_event(event: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = dict(event)
|
||||
event_type = str(payload.get("type", "")).strip().upper()
|
||||
if event_type in {
|
||||
EventType.TEXT_MESSAGE_END.value,
|
||||
EventType.TOOL_CALL_RESULT.value,
|
||||
}:
|
||||
ui_hints = payload.get("ui_hints")
|
||||
if ui_hints is not None:
|
||||
try:
|
||||
ui_hints_payload = UiHintsPayload.model_validate(ui_hints)
|
||||
ui_schema = compile_ui_hints(ui_hints_payload)
|
||||
payload["ui_schema"] = ui_schema
|
||||
except Exception:
|
||||
pass
|
||||
payload.pop("ui_hints", None)
|
||||
if event_type == EventType.TEXT_MESSAGE_END.value:
|
||||
for key in (
|
||||
"inputTokens",
|
||||
"outputTokens",
|
||||
"cost",
|
||||
"latencyMs",
|
||||
"model",
|
||||
):
|
||||
payload.pop(key, None)
|
||||
return payload
|
||||
|
||||
|
||||
def _build_run_started(event: dict[str, Any]) -> RunStartedEvent:
|
||||
return RunStartedEvent(
|
||||
thread_id=event.get("threadId", ""),
|
||||
@@ -77,31 +103,21 @@ def _build_run_error(event: dict[str, Any]) -> RunErrorEvent:
|
||||
|
||||
def _build_step_started(event: dict[str, Any]) -> StepStartedEvent:
|
||||
data = event.get("data", {})
|
||||
step_name = event.get("stepName", "")
|
||||
if (not isinstance(step_name, str) or not step_name) and isinstance(data, dict):
|
||||
step_name = data.get("stepName", "")
|
||||
return StepStartedEvent(
|
||||
step_name=data.get("stepName", ""),
|
||||
step_name=step_name if isinstance(step_name, str) else "",
|
||||
)
|
||||
|
||||
|
||||
def _build_step_finished(event: dict[str, Any]) -> StepFinishedEvent:
|
||||
data = event.get("data", {})
|
||||
step_name = event.get("stepName", "")
|
||||
if (not isinstance(step_name, str) or not step_name) and isinstance(data, dict):
|
||||
step_name = data.get("stepName", "")
|
||||
return StepFinishedEvent(
|
||||
step_name=data.get("stepName", ""),
|
||||
)
|
||||
|
||||
|
||||
def _build_text_start(event: dict[str, Any]) -> TextMessageStartEvent:
|
||||
data = event.get("data", {})
|
||||
return TextMessageStartEvent(
|
||||
message_id=data.get("messageId", ""),
|
||||
role=data.get("role", "assistant"),
|
||||
)
|
||||
|
||||
|
||||
def _build_text_delta(event: dict[str, Any]) -> TextMessageContentEvent:
|
||||
data = event.get("data", {})
|
||||
return TextMessageContentEvent(
|
||||
message_id=data.get("messageId", ""),
|
||||
delta=data.get("delta", ""),
|
||||
step_name=step_name if isinstance(step_name, str) else "",
|
||||
)
|
||||
|
||||
|
||||
@@ -128,8 +144,6 @@ _BUILDER_MAP: dict[str, Any] = {
|
||||
"run.error": _build_run_error,
|
||||
"step.start": _build_step_started,
|
||||
"step.finish": _build_step_finished,
|
||||
"text.start": _build_text_start,
|
||||
"text.delta": _build_text_delta,
|
||||
"text.end": _build_text_end,
|
||||
"tool.result": _build_tool_result,
|
||||
}
|
||||
@@ -140,7 +154,7 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
|
||||
return event.model_dump(by_alias=True, exclude_none=True)
|
||||
|
||||
if _is_agui_event(event):
|
||||
return event
|
||||
return _sanitize_agui_event(event)
|
||||
|
||||
internal_type = str(event.get("type", "")).strip()
|
||||
|
||||
@@ -156,24 +170,29 @@ def to_agui_wire_event(event: dict[str, Any] | BaseEvent) -> dict[str, Any]:
|
||||
text_end_payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
text_end_payload["runId"] = run_id
|
||||
for key in ("messageId", "workerAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
text_end_payload[key] = value
|
||||
reserved = {
|
||||
"type",
|
||||
"threadId",
|
||||
"runId",
|
||||
"inputTokens",
|
||||
"outputTokens",
|
||||
"cost",
|
||||
"latencyMs",
|
||||
"model",
|
||||
}
|
||||
text_end_payload.update({k: v for k, v in data.items() if k not in reserved})
|
||||
return text_end_payload
|
||||
|
||||
if internal_type == "tool.result" and isinstance(data, dict):
|
||||
tool_result_payload = {
|
||||
tool_result_payload: dict[str, Any] = {
|
||||
"type": _convert_to_agui_type(internal_type).value,
|
||||
}
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
tool_result_payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
tool_result_payload["runId"] = run_id
|
||||
for key in ("messageId", "toolCallId", "toolAgentOutput"):
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
tool_result_payload[key] = value
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
tool_result_payload.update({k: v for k, v in data.items() if k not in reserved})
|
||||
return tool_result_payload
|
||||
|
||||
builder = _BUILDER_MAP.get(internal_type)
|
||||
|
||||
@@ -1,35 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
|
||||
from core.agentscope.events.persistence import MessageRepository, SessionRepository
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from schemas.agent.runtime_models import (
|
||||
ToolAgentOutput,
|
||||
WorkerAgentOutputLite,
|
||||
WorkerAgentOutputRich,
|
||||
)
|
||||
from schemas.agent.runtime_models import ToolAgentOutput, WorkerAgentOutputRich
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
|
||||
|
||||
class EventStore(Protocol):
|
||||
async def persist(self, event: dict[str, Any]) -> None: ...
|
||||
|
||||
|
||||
class ToolResultStorageLike(Protocol):
|
||||
async def upload_json(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
payload: dict[str, object],
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class NullEventStore:
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
del event
|
||||
@@ -37,22 +24,14 @@ class NullEventStore:
|
||||
|
||||
class SqlAlchemyEventStore:
|
||||
_session_factory: Callable[[], Any]
|
||||
_tool_result_storage: ToolResultStorageLike | None
|
||||
_tool_result_bucket: str | None
|
||||
_logger = get_logger("core.agentscope.events.store")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_factory: Any,
|
||||
tool_result_storage: ToolResultStorageLike | None = None,
|
||||
tool_result_bucket: str | None = None,
|
||||
) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._tool_result_storage = tool_result_storage
|
||||
self._tool_result_bucket = tool_result_bucket
|
||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
event_type = str(event.get("type", "")).strip().upper().replace(".", "_")
|
||||
@@ -63,22 +42,11 @@ class SqlAlchemyEventStore:
|
||||
session_id = UUID(thread_id)
|
||||
except ValueError:
|
||||
return
|
||||
session_key = str(session_id)
|
||||
|
||||
async with self._session_factory() as session:
|
||||
session_repo = SessionRepository(session)
|
||||
message_repo = MessageRepository(session)
|
||||
chat_session = await session_repo.get_session(session_id=session_id)
|
||||
if chat_session is None:
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
return
|
||||
|
||||
if event_type == "TEXT_MESSAGE_CONTENT":
|
||||
self._buffer_text_delta(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "TEXT_MESSAGE_START":
|
||||
self._buffer_text_context(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "RUN_STARTED":
|
||||
@@ -95,7 +63,6 @@ class SqlAlchemyEventStore:
|
||||
status=AgentChatSessionStatus.FAILED,
|
||||
message_delta=0,
|
||||
)
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
elif event_type == "RUN_FINISHED":
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
@@ -103,7 +70,6 @@ class SqlAlchemyEventStore:
|
||||
status=AgentChatSessionStatus.COMPLETED,
|
||||
message_delta=0,
|
||||
)
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
elif event_type == "TEXT_MESSAGE_END":
|
||||
await self._persist_text_message(
|
||||
event=event,
|
||||
@@ -123,42 +89,6 @@ class SqlAlchemyEventStore:
|
||||
|
||||
await session.commit()
|
||||
|
||||
def _buffer_text_delta(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = self._event_value(event, "messageId")
|
||||
delta = self._event_value(event, "delta")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
if not isinstance(delta, str) or not delta:
|
||||
return
|
||||
key = (session_key, message_id)
|
||||
current = self._message_buffers.get(key, "")
|
||||
self._message_buffers[key] = f"{current}{delta}"
|
||||
|
||||
def _clear_session_buffers(self, *, session_key: str) -> None:
|
||||
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
|
||||
for key in stale_keys:
|
||||
self._message_buffers.pop(key, None)
|
||||
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
|
||||
for key in stale_context_keys:
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = self._event_value(event, "messageId")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
key = (session_key, message_id)
|
||||
role = self._event_value(event, "role")
|
||||
stage = self._event_value(event, "stage")
|
||||
tool_name = self._event_value(event, "toolName")
|
||||
context: dict[str, object] = {}
|
||||
if isinstance(role, str) and role:
|
||||
context["role"] = role
|
||||
if isinstance(stage, str) and stage:
|
||||
context["stage"] = stage
|
||||
if isinstance(tool_name, str) and tool_name:
|
||||
context["tool_name"] = tool_name
|
||||
self._message_contexts[key] = context
|
||||
|
||||
async def _persist_text_message(
|
||||
self,
|
||||
*,
|
||||
@@ -170,13 +100,11 @@ class SqlAlchemyEventStore:
|
||||
) -> None:
|
||||
message_id_raw = self._event_value(event, "messageId")
|
||||
message_id = message_id_raw if isinstance(message_id_raw, str) else ""
|
||||
key = (str(session_id), message_id)
|
||||
content = self._message_buffers.get(key, "")
|
||||
content_value = self._event_value(event, "answer")
|
||||
content = content_value if isinstance(content_value, str) else ""
|
||||
if not content:
|
||||
return
|
||||
|
||||
context = self._message_contexts.get(key, {})
|
||||
|
||||
input_tokens = self._to_int(self._event_value(event, "inputTokens"))
|
||||
output_tokens = self._to_int(self._event_value(event, "outputTokens"))
|
||||
token_delta = input_tokens + output_tokens
|
||||
@@ -185,35 +113,48 @@ class SqlAlchemyEventStore:
|
||||
run_id = self._event_value(event, "runId")
|
||||
model_code = self._event_value(event, "model")
|
||||
|
||||
metadata: dict[str, object] = {"message_id": message_id}
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
if latency_ms is not None:
|
||||
metadata["latency_ms"] = latency_ms
|
||||
stage = self._event_value(event, "stage")
|
||||
if not isinstance(stage, str):
|
||||
stage = context.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else None
|
||||
if run_id_value is None:
|
||||
return
|
||||
|
||||
worker_payload = self._event_value(event, "workerAgentOutput")
|
||||
if isinstance(worker_payload, dict):
|
||||
try:
|
||||
if "ui_hints" in worker_payload:
|
||||
worker_output = WorkerAgentOutputRich.model_validate(worker_payload)
|
||||
else:
|
||||
worker_output = WorkerAgentOutputLite.model_validate(worker_payload)
|
||||
except Exception:
|
||||
worker_output = None
|
||||
else:
|
||||
content = worker_output.answer
|
||||
metadata["worker_agent_output"] = worker_output.model_dump(mode="json")
|
||||
worker_output_fields = (
|
||||
"status",
|
||||
"answer",
|
||||
"key_points",
|
||||
"result_type",
|
||||
"suggested_actions",
|
||||
"error",
|
||||
"ui_hints",
|
||||
)
|
||||
worker_output_payload: dict[str, object] = {}
|
||||
for field in worker_output_fields:
|
||||
value = self._event_value(event, field)
|
||||
if value is not None:
|
||||
worker_output_payload[field] = value
|
||||
|
||||
role_value = context.get("role")
|
||||
if not worker_output_payload:
|
||||
return
|
||||
|
||||
try:
|
||||
worker_output = WorkerAgentOutputRich.model_validate(worker_output_payload)
|
||||
metadata_model = AgentChatMessageMetadata(
|
||||
run_id=run_id_value,
|
||||
agent_type=AgentType.WORKER,
|
||||
worker_agent_output=worker_output,
|
||||
)
|
||||
except Exception:
|
||||
self._logger.warning(
|
||||
"invalid worker metadata payload",
|
||||
run_id=run_id_value,
|
||||
message_id=message_id,
|
||||
)
|
||||
return
|
||||
|
||||
role_value = self._event_value(event, "role")
|
||||
if not isinstance(role_value, str):
|
||||
role_value = "assistant"
|
||||
role = self._resolve_role(role_value)
|
||||
tool_name = context.get("tool_name")
|
||||
tool_name = self._event_value(event, "tool_name")
|
||||
tool_name_value = (
|
||||
tool_name if isinstance(tool_name, str) and tool_name else None
|
||||
)
|
||||
@@ -231,7 +172,7 @@ class SqlAlchemyEventStore:
|
||||
content=content,
|
||||
model_code=model_code if isinstance(model_code, str) else None,
|
||||
tool_name=tool_name_value,
|
||||
metadata=metadata,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
@@ -252,8 +193,6 @@ class SqlAlchemyEventStore:
|
||||
token_delta=token_delta,
|
||||
cost_delta=cost,
|
||||
)
|
||||
self._message_buffers.pop(key, None)
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
async def _persist_tool_call_result(
|
||||
self,
|
||||
@@ -264,72 +203,33 @@ class SqlAlchemyEventStore:
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
tool_name = self._event_value(event, "toolName")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else None
|
||||
if run_id_value is None:
|
||||
return
|
||||
|
||||
raw_output = self._event_value(event, "toolAgentOutput")
|
||||
if not isinstance(raw_output, dict):
|
||||
return
|
||||
raw_output: dict[str, object] = {
|
||||
"tool_name": self._event_value(event, "tool_name"),
|
||||
"tool_call_id": self._event_value(event, "tool_call_id"),
|
||||
"tool_call_args": self._event_value(event, "tool_call_args"),
|
||||
"status": self._event_value(event, "status"),
|
||||
"result_summary": self._event_value(event, "result_summary"),
|
||||
"error": self._event_value(event, "error"),
|
||||
"ui_hints": self._event_value(event, "ui_hints"),
|
||||
}
|
||||
|
||||
try:
|
||||
tool_output = ToolAgentOutput.model_validate(raw_output)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else ""
|
||||
task_id = self._event_value(event, "taskId")
|
||||
task_id_value = task_id if isinstance(task_id, str) and task_id else "task"
|
||||
call_id_value = self._event_value(event, "callId")
|
||||
if not isinstance(call_id_value, str) or not call_id_value:
|
||||
call_id_value = (
|
||||
f"{run_id_value}-{task_id_value}-{uuid4().hex[:8]}"
|
||||
if run_id_value
|
||||
else f"{task_id_value}-{uuid4().hex[:8]}"
|
||||
metadata_model = AgentChatMessageMetadata(
|
||||
run_id=run_id_value,
|
||||
tool_agent_output=tool_output,
|
||||
)
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"toolAgentOutput": tool_output.model_dump(mode="json"),
|
||||
"callId": call_id_value,
|
||||
"runId": run_id_value,
|
||||
"taskId": task_id_value,
|
||||
"content": tool_output.result_summary,
|
||||
}
|
||||
|
||||
metadata: dict[str, object] = {
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": call_id_value,
|
||||
"tool_agent_output": tool_output.model_dump(mode="json"),
|
||||
}
|
||||
if run_id_value:
|
||||
metadata["run_id"] = run_id_value
|
||||
stage = self._event_value(event, "stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
if task_id_value:
|
||||
metadata["task_id"] = task_id_value
|
||||
|
||||
if self._tool_result_storage is not None and self._tool_result_bucket:
|
||||
safe_run = _sanitize_path_component(run_id_value or "run")
|
||||
safe_call = _sanitize_path_component(call_id_value)
|
||||
storage_path = f"tool-results/{session_id}/{safe_run}/{safe_call}.json"
|
||||
try:
|
||||
await self._tool_result_storage.upload_json(
|
||||
bucket=self._tool_result_bucket,
|
||||
path=storage_path,
|
||||
payload=payload,
|
||||
)
|
||||
metadata["storage_bucket"] = self._tool_result_bucket
|
||||
metadata["storage_path"] = storage_path
|
||||
except Exception: # noqa: BLE001
|
||||
metadata["storage_upload_failed"] = True
|
||||
self._logger.warning(
|
||||
"tool result storage upload failed",
|
||||
session_id=str(session_id),
|
||||
run_id=run_id_value,
|
||||
call_id=call_id_value,
|
||||
storage_path=storage_path,
|
||||
)
|
||||
except Exception:
|
||||
self._logger.warning(
|
||||
"invalid tool metadata payload",
|
||||
run_id=run_id_value,
|
||||
)
|
||||
return
|
||||
|
||||
content = tool_output.result_summary
|
||||
|
||||
@@ -344,8 +244,8 @@ class SqlAlchemyEventStore:
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
content=content,
|
||||
tool_name=tool_name,
|
||||
metadata=metadata,
|
||||
tool_name=tool_output.tool_name,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -433,9 +333,3 @@ class SqlAlchemyEventStore:
|
||||
if isinstance(data, dict):
|
||||
return data.get(key, default)
|
||||
return default
|
||||
|
||||
|
||||
def _sanitize_path_component(value: str) -> str:
|
||||
compact = re.sub(r"[^A-Za-z0-9._-]", "-", value.strip())
|
||||
compact = compact.strip(".-")
|
||||
return compact or "id"
|
||||
|
||||
Reference in New Issue
Block a user