feat(agent): complete task4-6 tool result persistence flow
This commit is contained in:
@@ -9,6 +9,9 @@ from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.agent.infrastructure.storage.tool_result_storage import (
|
||||
create_tool_result_storage,
|
||||
)
|
||||
from core.agent.infrastructure.queue.tasks import (
|
||||
run_command_task,
|
||||
run_command_task_bulk,
|
||||
@@ -109,8 +112,9 @@ class RedisEventStream:
|
||||
|
||||
|
||||
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
|
||||
tool_result_storage = create_tool_result_storage()
|
||||
return AgentService(
|
||||
repository=AgentRepository(session),
|
||||
repository=AgentRepository(session, tool_result_storage=tool_result_storage),
|
||||
queue=TaskiqQueueClient(),
|
||||
stream=RedisEventStream(),
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
import json
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -12,9 +13,21 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
|
||||
|
||||
class ToolResultPayloadStorage(Protocol):
|
||||
async def read_json(
|
||||
self, *, bucket: str, path: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class AgentRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
tool_result_storage: ToolResultPayloadStorage | None = None,
|
||||
) -> None:
|
||||
self._session = session
|
||||
self._tool_result_storage = tool_result_storage
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
try:
|
||||
@@ -42,7 +55,9 @@ class AgentRepository:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid session_id"
|
||||
) from exc
|
||||
|
||||
session = AgentChatSession(
|
||||
id=session_uuid,
|
||||
@@ -118,10 +133,13 @@ class AgentRepository:
|
||||
)
|
||||
messages = (await self._session.execute(message_stmt)).scalars().all()
|
||||
has_more = any(day < target_day for day in unique_days)
|
||||
snapshot_messages: list[dict[str, object]] = []
|
||||
for message in messages:
|
||||
snapshot_messages.append(await self._to_snapshot_message(message))
|
||||
return {
|
||||
"day": target_day.isoformat(),
|
||||
"hasMore": has_more,
|
||||
"messages": [self._to_snapshot_message(msg) for msg in messages],
|
||||
"messages": snapshot_messages,
|
||||
}
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
@@ -141,8 +159,9 @@ class AgentRepository:
|
||||
return None
|
||||
return str(latest_id)
|
||||
|
||||
@staticmethod
|
||||
def _to_snapshot_message(message: AgentChatMessage) -> dict[str, object]:
|
||||
async def _to_snapshot_message(
|
||||
self, message: AgentChatMessage
|
||||
) -> dict[str, object]:
|
||||
role = (
|
||||
message.role.value
|
||||
if isinstance(message.role, AgentChatMessageRole)
|
||||
@@ -167,14 +186,35 @@ class AgentRepository:
|
||||
parsed_content = decoded
|
||||
except (TypeError, ValueError):
|
||||
parsed_content = None
|
||||
if parsed_content is not None:
|
||||
ui = parsed_content.get("ui")
|
||||
|
||||
hydrated_content: dict[str, object] | None = None
|
||||
if self._tool_result_storage is not None:
|
||||
storage_bucket = metadata.get("storage_bucket")
|
||||
storage_path = metadata.get("storage_path")
|
||||
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
|
||||
try:
|
||||
hydrated_content = await self._tool_result_storage.read_json(
|
||||
bucket=storage_bucket,
|
||||
path=storage_path,
|
||||
)
|
||||
except Exception:
|
||||
hydrated_content = None
|
||||
|
||||
resolved_content = hydrated_content or parsed_content
|
||||
if resolved_content is not None:
|
||||
result = resolved_content.get("result")
|
||||
if isinstance(result, dict):
|
||||
result_content = result.get("content")
|
||||
if isinstance(result_content, str):
|
||||
payload["content"] = result_content
|
||||
ui = resolved_content.get("ui")
|
||||
if isinstance(ui, dict):
|
||||
payload["ui"] = ui
|
||||
display_content = parsed_content.get("content")
|
||||
display_content = resolved_content.get("content")
|
||||
if isinstance(display_content, str):
|
||||
payload["content"] = display_content
|
||||
else:
|
||||
|
||||
if "content" not in payload:
|
||||
payload["content"] = message.content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
|
||||
Reference in New Issue
Block a user