Files
social-app/backend/src/v1/agent/dependencies.py
T

127 lines
4.2 KiB
Python
Raw Normal View History

from __future__ import annotations

import asyncio
import logging
from typing import Any

from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession

from core.agentscope.events import RedisStreamBus
from core.agentscope.runtime.tasks import (
    run_command_task,
    run_command_task_bulk,
    run_command_task_critical,
)
from core.agentscope.tools.tool_result_storage import (
    create_tool_result_storage,
)
from core.config.settings import config
from core.db import get_db
from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
# How many times a duplicate enqueue re-checks the dedup key (sleeping
# DEDUP_WAIT_SECONDS each time) before giving up with RuntimeError.
DEDUP_WAIT_RETRIES = 20
# Delay between those re-checks; the total wait is about
# DEDUP_WAIT_RETRIES * DEDUP_WAIT_SECONDS (~1 s).
DEDUP_WAIT_SECONDS = 0.05
# Expiry (seconds) of the ``agent:dedup:<key>`` Redis entry, applied both
# while it holds the in-flight marker and after it holds the task id.
DEDUP_LOCK_SECONDS = 300
# Placeholder stored under the dedup key while the first request is still
# enqueuing; duplicates keep waiting while they see this value.
DEDUP_INFLIGHT_MARKER = "__inflight__"
def _event_stream_block_ms() -> int:
    """Return the ``block_ms`` value handed to ``RedisStreamBus``.

    ``config.agent_runtime.redis_stream_block_ms`` is clamped into
    ``[1, socket_timeout_ms - 100]``. Here ``socket_timeout_ms`` is
    ``config.redis.socket_timeout`` (seconds) converted to milliseconds and
    floored at 1. The upper bound never drops below 1, so the result is
    always at least 1.
    """
    requested_ms = int(config.agent_runtime.redis_stream_block_ms)
    # Stay 100 ms under the socket timeout. Presumably this lets a blocking
    # stream read return before the Redis socket read times out.
    timeout_ms = max(int(float(config.redis.socket_timeout) * 1000), 1)
    ceiling_ms = max(timeout_ms - 100, 1)
    if requested_ms > ceiling_ms:
        return ceiling_ms
    return requested_ms if requested_ms >= 1 else 1
class TaskiqQueueClient:
    """Enqueue agent commands on Taskiq, optionally de-duplicated through Redis.

    When ``enqueue`` receives a ``dedup_key``, it first claims the Redis key
    ``agent:dedup:<dedup_key>`` with an in-flight marker. Once the command is
    enqueued, it overwrites that value with the Taskiq task id. Other calls
    with the same key return the recorded task id instead of enqueuing again.
    """

    def __init__(self) -> None:
        # Populated on first use by ``_get_redis``.
        self._redis: Redis | None = None

    async def _get_redis(self) -> Redis:
        """Return the Redis client, fetched via ``get_or_init_redis_client`` once per instance."""
        if self._redis is None:
            self._redis = await get_or_init_redis_client()
        return self._redis

    @staticmethod
    def _select_queue_task(command: dict[str, object]) -> Any:
        """Return the Taskiq task matching ``command["queue"]``.

        The value is compared case-insensitively after stripping whitespace.
        ``"critical"`` and ``"bulk"`` select their dedicated tasks; anything
        else, including a missing key, selects ``run_command_task``.
        """
        queue = str(command.get("queue", "default")).strip().lower()
        if queue == "critical":
            return run_command_task_critical
        if queue == "bulk":
            return run_command_task_bulk
        return run_command_task

    @staticmethod
    def _as_text(value: object) -> str | None:
        """Return a Redis GET result as ``str``, or ``None`` when the key is absent.

        Clients created without ``decode_responses=True`` return ``bytes``.
        Those would never compare equal to the ``str`` in-flight marker, and
        would leak out of ``enqueue`` as a non-``str`` task id.
        """
        if value is None:
            return None
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return str(value)

    async def _claim_dedup_key(
        self, redis_client: Redis, redis_key: str
    ) -> str | None:
        """Claim ``redis_key`` for this request, or return the task id stored there.

        Returns:
            ``None`` once this call owns the key. The key then holds
            ``DEDUP_INFLIGHT_MARKER`` and the caller must enqueue. Otherwise
            returns the task id recorded by an earlier request.

        Raises:
            RuntimeError: the key still holds the in-flight marker after
                ``DEDUP_WAIT_RETRIES`` waits of ``DEDUP_WAIT_SECONDS``.
        """
        for attempt in range(DEDUP_WAIT_RETRIES + 1):
            # Retry the claim each round. If the first holder's enqueue failed
            # and it deleted the key, this request should enqueue instead of
            # timing out with "still in progress".
            claimed = await redis_client.set(
                redis_key,
                DEDUP_INFLIGHT_MARKER,
                nx=True,
                ex=DEDUP_LOCK_SECONDS,
            )
            if claimed:
                return None
            existing = self._as_text(await redis_client.get(redis_key))
            if existing and existing != DEDUP_INFLIGHT_MARKER:
                return existing
            if attempt < DEDUP_WAIT_RETRIES:
                await asyncio.sleep(DEDUP_WAIT_SECONDS)
        raise RuntimeError("duplicate request is still in progress")

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Enqueue ``command`` and return its Taskiq task id as a string.

        Args:
            command: Payload for the Taskiq task. Its optional ``"queue"``
                entry selects the task (see ``_select_queue_task``). The dict
                is shallow-copied before being sent.
            dedup_key: When truthy, calls sharing this key enqueue at most once
                while the Redis entry lives (``DEDUP_LOCK_SECONDS``). Later
                calls receive the first call's task id.

        Returns:
            The id of the newly enqueued task, or the id recorded by an earlier
            call with the same ``dedup_key``.

        Raises:
            RuntimeError: another call with the same ``dedup_key`` is still
                enqueuing after the wait budget is exhausted.
            Exception: whatever ``kiq`` raises. The dedup key is released
                first so that a retry can enqueue.
        """
        redis_client = await self._get_redis()
        payload = dict(command)
        # Select the task before claiming the dedup key. That way nothing
        # between the claim and the try below can leave the key stuck in flight.
        queue_task = self._select_queue_task(payload)

        redis_key: str | None = None
        if dedup_key:
            redis_key = f"agent:dedup:{dedup_key}"
            existing = await self._claim_dedup_key(redis_client, redis_key)
            if existing is not None:
                return existing

        try:
            result = await queue_task.kiq(payload)
        except Exception:
            # kiq failed: release the claim so a retry is not blocked until
            # the key expires.
            if redis_key is not None:
                await redis_client.delete(redis_key)
            raise

        task_id = str(result.task_id)
        if redis_key is not None:
            try:
                await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
            except Exception:
                # The task is already enqueued, so report success. Keep the
                # in-flight marker instead of deleting it: deleting would let
                # a retry enqueue the same command a second time. Duplicates
                # within the TTL window get RuntimeError instead.
                logging.getLogger(__name__).warning(
                    "enqueued task %s but failed to record it under %s",
                    task_id,
                    redis_key,
                    exc_info=True,
                )
        return task_id
class RedisEventStream:
    """Read agent session events through a lazily created ``RedisStreamBus``."""

    def __init__(self) -> None:
        # Built on the first ``_get_bus`` call.
        self._bus: RedisStreamBus | None = None

    async def _get_bus(self) -> RedisStreamBus:
        """Return the cached bus, building it on first use.

        The bus is built from ``get_or_init_redis_client()``, the
        ``config.agent_runtime`` stream settings and ``_event_stream_block_ms()``.
        """
        if self._bus is not None:
            return self._bus
        redis_client = await get_or_init_redis_client()
        self._bus = RedisStreamBus(
            client=redis_client,
            stream_prefix=config.agent_runtime.redis_stream_prefix,
            read_count=config.agent_runtime.redis_stream_read_count,
            block_ms=_event_stream_block_ms(),
        )
        return self._bus

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, Any]]:
        """Fetch rows for ``session_id`` from the bus and tag each with a cursor.

        Delegates to ``RedisStreamBus.read`` with ``session_id`` and
        ``last_event_id``. Each returned row is a new dict with a ``"cursor"``
        key equal to the row's ``"id"`` (``None`` if the row has no id). Any
        existing ``"cursor"`` value is overwritten.
        """
        bus = await self._get_bus()
        events = await bus.read(session_id=session_id, last_event_id=last_event_id)
        annotated: list[dict[str, Any]] = []
        for event in events:
            entry = dict(event)
            entry["cursor"] = event.get("id")
            annotated.append(entry)
        return annotated
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
    """Assemble an ``AgentService`` for a FastAPI dependency.

    The service is wired to:

    - an ``AgentRepository`` over the ``get_db`` session, backed by a fresh
      tool-result storage;
    - a new ``TaskiqQueueClient``;
    - a new ``RedisEventStream``;
    - ``supabase_service`` for attachments.
    """
    repository = AgentRepository(
        session,
        tool_result_storage=create_tool_result_storage(),
    )
    return AgentService(
        repository=repository,
        queue=TaskiqQueueClient(),
        stream=RedisEventStream(),
        attachment_storage=supabase_service,
    )