312 lines
10 KiB
Python
312 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
from typing import Any, cast
|
|
from uuid import UUID
|
|
|
|
from agentscope.message import Msg
|
|
from core.agentscope.events import (
|
|
AgentScopeAgUiCodec,
|
|
AgentScopeEventPipeline,
|
|
RedisStreamBus,
|
|
SqlAlchemyEventStore,
|
|
)
|
|
from core.agentscope.runtime.context_service import AgentContextService
|
|
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
|
from core.agentscope.runtime.pipeline_registry import build_default_pipeline_spec
|
|
from core.agentscope.schemas.agui_input import parse_run_input
|
|
from core.auth.models import CurrentUser
|
|
from core.config.settings import config
|
|
from core.db.session import AsyncSessionLocal
|
|
from core.logging import get_logger
|
|
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
|
from models.automation_jobs import AutomationJob
|
|
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
|
from schemas.automation.config import AutomationJobConfig
|
|
from schemas.messages.chat_message import (
|
|
AgentChatMessageMetadata,
|
|
extract_user_message_attachments,
|
|
)
|
|
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
|
|
from schemas.user import UserContext
|
|
from services.base.redis import get_or_init_redis_client
|
|
from services.base.supabase import supabase_service
|
|
from v1.agent.repository import AgentRepository
|
|
from v1.users.dependencies import get_user_service
|
|
from sqlalchemy import select
|
|
|
|
# Maximum number of attachments re-downloaded for EACH historical user message
# when rebuilding multimodal context (the cap is applied per message, not per run).
_MAX_CONTEXT_ATTACHMENTS = 3

# Module logger, registered under the runtime tasks namespace.
logger = get_logger("core.agentscope.runtime.tasks")
|
|
|
|
|
|
def _serialize_tool_agent_output(
|
|
*,
|
|
metadata: AgentChatMessageMetadata | dict[str, object] | None,
|
|
) -> str | None:
|
|
if metadata is None:
|
|
return None
|
|
|
|
try:
|
|
resolved_metadata = (
|
|
metadata
|
|
if isinstance(metadata, AgentChatMessageMetadata)
|
|
else AgentChatMessageMetadata.model_validate(metadata)
|
|
)
|
|
except Exception:
|
|
return None
|
|
|
|
tool_agent_output = resolved_metadata.tool_agent_output
|
|
if tool_agent_output is None:
|
|
return None
|
|
|
|
return json.dumps(
|
|
tool_agent_output.model_dump(mode="json", exclude_none=True),
|
|
ensure_ascii=True,
|
|
separators=(",", ":"),
|
|
)
|
|
|
|
|
|
def _load_runtime() -> type[Any]:
    """Return the runtime orchestrator class (not an instance).

    The caller instantiates it with an event pipeline.
    """
    runtime_cls: type[Any] = AgentScopeRuntimeOrchestrator
    return runtime_cls
|
|
|
|
|
|
async def _build_user_context(
    *,
    owner_id: UUID,
    session: Any,
) -> UserContext:
    """Resolve the ``UserContext`` for ``owner_id`` through the user service.

    Args:
        owner_id: Id of the user the run executes on behalf of.
        session: Open DB session handed to ``get_user_service``.

    Returns:
        The result of the service's ``get_me()`` call.
    """
    service = get_user_service(
        session=session,
        user=CurrentUser(id=owner_id),
    )
    user_context = await service.get_me()
    return user_context
|
|
|
|
|
|
async def _load_raw_context_messages(
    *,
    context_service: AgentContextService,
    thread_id: str,
    context_mode: str,
    memory_job_config: AutomationJobConfig | None,
) -> list[dict[str, object]]:
    """Fetch raw history rows for ``thread_id`` using the applicable window strategy.

    With a memory job config, history is windowed by days or by user-message
    count and filtered by the UI_HISTORY visibility bit. Otherwise the
    service's default loader is used with ``context_mode`` as the agent mode.
    Returns an empty list when the service yields nothing.
    """
    if memory_job_config is None:
        result = await context_service.load_context_messages(
            thread_id=thread_id,
            system_agent_mode=context_mode,
        )
    else:
        visibility_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
        window = memory_job_config.context
        if window.window_mode.value == "day":
            result = await context_service.load_by_day_window(
                thread_id=thread_id,
                day_count=window.window_count,
                visibility_mask=visibility_mask,
            )
        else:
            result = await context_service.load_by_user_message_window(
                thread_id=thread_id,
                user_message_limit=window.window_count,
                visibility_mask=visibility_mask,
            )

    if not result:
        return []
    return result.get("messages") or []


async def _download_attachment_image_blocks(
    *,
    metadata: AgentChatMessageMetadata | dict[str, object],
    thread_id: str,
) -> list[dict[str, Any]]:
    """Download up to ``_MAX_CONTEXT_ATTACHMENTS`` attachments as base64 image blocks.

    Failed downloads are logged and skipped, so one broken attachment does not
    abort the run.
    NOTE(review): every attachment is emitted as an ``"image"`` block; confirm
    that ``extract_user_message_attachments`` only yields image attachments.
    """
    image_blocks: list[dict[str, Any]] = []
    attachments = extract_user_message_attachments(metadata)[:_MAX_CONTEXT_ATTACHMENTS]
    for attachment in attachments:
        try:
            image_bytes = await supabase_service.download_bytes(
                bucket=attachment.bucket,
                path=attachment.path,
            )
        except Exception as exc:
            # Previously swallowed silently; keep it best-effort but diagnosable.
            logger.warning(
                "failed to download context attachment",
                thread_id=thread_id,
                bucket=attachment.bucket,
                path=attachment.path,
                error=repr(exc),
            )
            continue

        image_blocks.append(
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": attachment.mime_type or "image/png",
                    "data": base64.b64encode(image_bytes).decode("utf-8"),
                },
            }
        )
    return image_blocks


async def _build_recent_context_messages(
    *,
    session: Any,
    thread_id: str,
    context_mode: str,
    memory_job_config: AutomationJobConfig | None = None,
) -> list[Msg]:
    """Convert a thread's recent history into AgentScope ``Msg`` objects.

    Conversion rules:
    - User messages with attachments become multimodal messages: an optional
      text block followed by base64 image blocks. If no attachment could be
      downloaded, they fall back to plain text.
    - Tool messages are re-emitted as assistant messages whose content is the
      serialized ``tool_agent_output``. They are dropped when there is none.
    - Roles other than user/assistant/system are sent with role ``"user"``.

    Args:
        session: DB session used to build the ``AgentRepository``.
        thread_id: Thread whose history is loaded.
        context_mode: Agent mode passed to the default history loader.
        memory_job_config: When given, selects the windowed loaders instead.

    Returns:
        Converted messages, or an empty list when there is no history.
    """
    context_service = AgentContextService(repository=AgentRepository(session))
    raw_messages = await _load_raw_context_messages(
        context_service=context_service,
        thread_id=thread_id,
        context_mode=context_mode,
        memory_job_config=memory_job_config,
    )

    converted: list[Msg] = []
    for msg in raw_messages:
        role_raw = msg.get("role")
        role = role_raw if isinstance(role_raw, str) else "user"
        content_raw = msg.get("content", "")
        content: str = content_raw if isinstance(content_raw, str) else ""
        metadata_raw = msg.get("metadata")
        metadata: AgentChatMessageMetadata | dict[str, object] | None = (
            metadata_raw
            if isinstance(metadata_raw, (AgentChatMessageMetadata, dict))
            else None
        )

        if role == "user" and metadata:
            image_blocks = await _download_attachment_image_blocks(
                metadata=metadata,
                thread_id=thread_id,
            )
            if image_blocks:
                multimodal_content: list[dict[str, Any]] = []
                if content:
                    multimodal_content.append({"type": "text", "text": content})
                multimodal_content.extend(image_blocks)
                converted.append(
                    Msg(
                        name="user",
                        role="user",
                        content=cast(Any, multimodal_content),
                    )
                )
                continue

        if role == "tool":
            tool_content = _serialize_tool_agent_output(metadata=metadata)
            if not tool_content:
                continue
            role = "assistant"
            content = tool_content

        converted.append(
            Msg(
                name=role or "user",
                role=role if role in ("user", "assistant", "system") else "user",
                content=content,
            )
        )

    return converted
|
|
|
|
|
|
async def _load_memory_job_config(
|
|
*,
|
|
session: Any,
|
|
owner_id: UUID,
|
|
automation_job_id: str,
|
|
) -> AutomationJobConfig:
|
|
try:
|
|
job_uuid = UUID(automation_job_id)
|
|
except ValueError as exc:
|
|
raise ValueError("automation_job_id is invalid") from exc
|
|
|
|
stmt = (
|
|
select(AutomationJob)
|
|
.where(AutomationJob.id == job_uuid)
|
|
.where(AutomationJob.owner_id == owner_id)
|
|
.where(AutomationJob.deleted_at.is_(None))
|
|
)
|
|
row = (await session.execute(stmt)).scalar_one_or_none()
|
|
if row is None:
|
|
raise ValueError("automation job not found")
|
|
return AutomationJobConfig.model_validate(row.config or {})
|
|
|
|
|
|
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
    """Execute one AgentScope runtime command.

    ``command`` keys read here:
    - ``command``: command type, default ``"run"``. It is case- and
      whitespace-insensitive, and only ``"run"`` is accepted.
    - ``owner_id``: UUID string of the user the run executes as.
    - ``run_input``: raw run input, parsed with ``parse_run_input``.
    - ``automation_job_id``: UUID string, required when the forwarded props
      select the ``"memory"`` agent mode.

    Returns:
        ``{"thread_id": ..., "run_id": ..., "status": "completed"}`` once the
        runtime has finished.

    Raises:
        ValueError: for an unsupported command type, a missing or invalid
            field, or a memory-mode automation job that cannot be found.
    """
    # Reject unsupported commands before doing any parsing or I/O.
    command_type = str(command.get("command", "run")).strip().lower()
    if command_type != "run":
        raise ValueError("invalid command type")

    raw_owner_id = command.get("owner_id")
    if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
        raise ValueError("owner_id is required")
    try:
        # Parse the same stripped value that was validated above.
        owner_id = UUID(raw_owner_id.strip())
    except ValueError as exc:
        raise ValueError("owner_id is invalid") from exc

    run_input_raw = command.get("run_input")
    if run_input_raw is None:
        raise ValueError("run_input is required")

    run_input = parse_run_input(run_input_raw)
    system_agent_mode = parse_forwarded_props_agent_type(
        getattr(run_input, "forwarded_props", None)
    )

    automation_job_id: str | None = None
    if system_agent_mode == "memory":
        raw_automation_job_id = command.get("automation_job_id")
        if not isinstance(raw_automation_job_id, str) or not raw_automation_job_id:
            raise ValueError("automation_job_id is required for memory mode")
        automation_job_id = raw_automation_job_id

    pipeline_spec = build_default_pipeline_spec(mode=system_agent_mode)
    thread_id = run_input.thread_id
    run_id = run_input.run_id
    orchestrator = _load_runtime()

    # Use this DB session only for the reads that need it and close it before
    # the (potentially long) agent run. Keeping it open across the run would
    # hold a pooled connection idle for the whole run. Event persistence opens
    # its own sessions through SqlAlchemyEventStore(session_factory=...).
    # NOTE(review): assumes user_context / memory_job_config / context_messages
    # are plain data that does not lazily read from the session.
    async with AsyncSessionLocal() as session:
        user_context = await _build_user_context(owner_id=owner_id, session=session)

        memory_job_config: AutomationJobConfig | None = None
        if automation_job_id is not None:
            memory_job_config = await _load_memory_job_config(
                session=session,
                owner_id=owner_id,
                automation_job_id=automation_job_id,
            )

        context_messages = await _build_recent_context_messages(
            session=session,
            thread_id=thread_id,
            context_mode=pipeline_spec.stages[0].context_policy.consumer_agent_type,
            memory_job_config=memory_job_config,
        )

    redis_client = await get_or_init_redis_client()
    bus = RedisStreamBus(
        client=redis_client,
        stream_prefix=config.agent_runtime.redis_stream_prefix,
        read_count=config.agent_runtime.redis_stream_read_count,
        block_ms=config.agent_runtime.redis_stream_block_ms,
    )
    pipeline = AgentScopeEventPipeline(
        codec=AgentScopeAgUiCodec(),
        store=SqlAlchemyEventStore(
            session_factory=AsyncSessionLocal,
        ),
        bus=bus,
    )
    runtime = orchestrator(
        pipeline=pipeline,
    )

    await runtime.run(
        run_input=run_input,
        context_messages=context_messages,
        user_context=user_context,
        system_agent_mode=system_agent_mode,
        memory_job_config=memory_job_config,
    )

    logger.info(
        "agentscope runtime task completed",
        command_type=command_type,
        thread_id=thread_id,
        run_id=run_id,
    )
    return {
        "thread_id": thread_id,
        "run_id": run_id,
        "status": "completed",
    }
|
|
|
|
|
|
@default_broker.task(task_name="tasks.agentscope.run_command")
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
    """Task registered on ``default_broker``; delegates to ``run_agentscope_task``."""
    outcome = await run_agentscope_task(command)
    return outcome
|
|
|
|
|
|
@critical_broker.task(task_name="tasks.agentscope.run_command.critical")
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
    """Task registered on ``critical_broker``; delegates to ``run_agentscope_task``."""
    outcome = await run_agentscope_task(command)
    return outcome
|
|
|
|
|
|
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
    """Task registered on ``bulk_broker``; delegates to ``run_agentscope_task``."""
    outcome = await run_agentscope_task(command)
    return outcome
|