refactor: 重构 AgentScope ReAct Runner 与事件处理
- 重构 runtime/runner.py 实现 ReAct Agent 核心逻辑 - 更新事件编码器与存储机制 - 优化 prompt 系统与 tool 调用 - 调整 agent service 与 repository 配合
This commit is contained in:
@@ -30,8 +30,8 @@ class AgentRepository:
|
||||
*,
|
||||
tool_result_storage: ToolResultPayloadStorage | None = None,
|
||||
) -> None:
|
||||
self._session = session
|
||||
self._tool_result_storage = tool_result_storage
|
||||
self._session: AsyncSession = session
|
||||
self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
try:
|
||||
@@ -138,34 +138,31 @@ class AgentRepository:
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
|
||||
|
||||
timestamp_stmt = (
|
||||
before_start = (
|
||||
datetime.combine(before, time.min, tzinfo=timezone.utc)
|
||||
if before is not None
|
||||
else None
|
||||
)
|
||||
|
||||
target_created_at_stmt = (
|
||||
select(AgentChatMessage.created_at)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.order_by(AgentChatMessage.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
rows = (await self._session.execute(timestamp_stmt)).scalars().all()
|
||||
unique_days: list[date] = []
|
||||
for created_at in rows:
|
||||
if created_at is None:
|
||||
continue
|
||||
day = created_at.astimezone(timezone.utc).date()
|
||||
if day not in unique_days:
|
||||
unique_days.append(day)
|
||||
if before_start is not None:
|
||||
target_created_at_stmt = target_created_at_stmt.where(
|
||||
AgentChatMessage.created_at < before_start
|
||||
)
|
||||
target_created_at = (
|
||||
await self._session.execute(target_created_at_stmt)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not unique_days:
|
||||
if target_created_at is None:
|
||||
return None
|
||||
|
||||
target_day: date | None = None
|
||||
if before is None:
|
||||
target_day = unique_days[0]
|
||||
else:
|
||||
for day in unique_days:
|
||||
if day < before:
|
||||
target_day = day
|
||||
break
|
||||
if target_day is None:
|
||||
return None
|
||||
target_day = target_created_at.astimezone(timezone.utc).date()
|
||||
|
||||
start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
|
||||
end = start + timedelta(days=1)
|
||||
@@ -178,7 +175,16 @@ class AgentRepository:
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = (await self._session.execute(message_stmt)).scalars().all()
|
||||
has_more = any(day < target_day for day in unique_days)
|
||||
has_more_stmt = (
|
||||
select(AgentChatMessage.id)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.created_at < start)
|
||||
.limit(1)
|
||||
)
|
||||
has_more = (
|
||||
await self._session.execute(has_more_stmt)
|
||||
).scalar_one_or_none() is not None
|
||||
snapshot_messages: list[dict[str, object]] = []
|
||||
for message in messages:
|
||||
snapshot_messages.append(await self._to_snapshot_message(message))
|
||||
|
||||
@@ -128,6 +128,10 @@ async def enqueue_run(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
request = parse_run_input(request.model_dump(by_alias=True, exclude_none=True))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
try:
|
||||
validate_run_request_messages_contract(request)
|
||||
except ValueError as exc:
|
||||
|
||||
@@ -170,12 +170,9 @@ class AgentService:
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"run_input": {
|
||||
"messages": [
|
||||
msg.model_dump(mode="json", exclude_none=True)
|
||||
for msg in run_input.messages
|
||||
],
|
||||
},
|
||||
"run_input": run_input.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
@@ -204,7 +201,7 @@ class AgentService:
|
||||
|
||||
yesterday = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=today.get("day"), # type: ignore
|
||||
before=self._parse_history_day(today.get("day")),
|
||||
)
|
||||
|
||||
messages: list[dict[str, object]] = []
|
||||
@@ -215,6 +212,16 @@ class AgentService:
|
||||
|
||||
return {"messages": messages}
|
||||
|
||||
def _parse_history_day(self, value: object) -> date | None:
|
||||
if isinstance(value, date):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return date.fromisoformat(value)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def _prepare_user_message(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -17,7 +17,7 @@ from schemas.messages.chat_message import (
|
||||
|
||||
def convert_message_to_history(
|
||||
message: AgentChatMessage,
|
||||
get_signed_url_fn: Callable[[str, str], str] | None = None,
|
||||
get_signed_url_fn: Callable[[dict[str, str]], str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
将 AgentChatMessage 转换为 HistoryMessage 格式
|
||||
@@ -55,14 +55,14 @@ def convert_message_to_history(
|
||||
result["url"] = url
|
||||
|
||||
if ui_schema:
|
||||
result["uiSchema"] = ui_schema
|
||||
result["ui_schema"] = ui_schema
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _convert_user_attachments(
|
||||
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
|
||||
get_signed_url_fn: Callable[[str, str], str] | None,
|
||||
get_signed_url_fn: Callable[[dict[str, str]], str] | None,
|
||||
) -> str | None:
|
||||
"""转换用户附件为临时访问 URL"""
|
||||
if not metadata:
|
||||
@@ -100,9 +100,19 @@ def _compile_tool_ui_hints(
|
||||
tool_output_data = metadata.get("tool_agent_output")
|
||||
if not tool_output_data:
|
||||
return None
|
||||
if isinstance(tool_output_data, dict):
|
||||
raw_ui_schema = tool_output_data.get("ui_schema")
|
||||
if isinstance(raw_ui_schema, dict):
|
||||
return raw_ui_schema
|
||||
legacy_ui_schema = tool_output_data.get("uiSchema")
|
||||
if isinstance(legacy_ui_schema, dict):
|
||||
return legacy_ui_schema
|
||||
from schemas.agent.runtime_models import ToolAgentOutput
|
||||
|
||||
tool_output = ToolAgentOutput.model_validate(tool_output_data)
|
||||
try:
|
||||
tool_output = ToolAgentOutput.model_validate(tool_output_data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not tool_output:
|
||||
return None
|
||||
@@ -131,9 +141,19 @@ def _compile_worker_ui_hints(
|
||||
worker_output_data = metadata.get("worker_agent_output")
|
||||
if not worker_output_data:
|
||||
return None
|
||||
if isinstance(worker_output_data, dict):
|
||||
raw_ui_schema = worker_output_data.get("ui_schema")
|
||||
if isinstance(raw_ui_schema, dict):
|
||||
return raw_ui_schema
|
||||
legacy_ui_schema = worker_output_data.get("uiSchema")
|
||||
if isinstance(legacy_ui_schema, dict):
|
||||
return legacy_ui_schema
|
||||
from schemas.agent.runtime_models import WorkerAgentOutputRich
|
||||
|
||||
worker_output = WorkerAgentOutputRich.model_validate(worker_output_data)
|
||||
try:
|
||||
worker_output = WorkerAgentOutputRich.model_validate(worker_output_data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not worker_output:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user