117 lines
3.8 KiB
Python
117 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from fastapi import Depends
|
|
from redis.asyncio import Redis
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
|
from core.agent.infrastructure.queue.tasks import (
|
|
run_command_task,
|
|
run_command_task_bulk,
|
|
run_command_task_critical,
|
|
)
|
|
from core.config.settings import config
|
|
from core.db import get_db
|
|
from services.base.redis import get_or_init_redis_client
|
|
from v1.agent.repository import AgentRepository
|
|
from v1.agent.service import AgentService
|
|
|
|
# Polling budget for a duplicate enqueue that finds its dedup key held by an
# in-flight request: up to DEDUP_WAIT_RETRIES reads spaced DEDUP_WAIT_SECONDS
# apart (roughly one second in total) before giving up.
DEDUP_WAIT_RETRIES: int = 20
DEDUP_WAIT_SECONDS: float = 0.05

# Expiry in seconds (five minutes) applied to a dedup key, both while it holds
# the in-flight marker and after it stores the enqueued task id.
DEDUP_LOCK_SECONDS: int = 5 * 60

# Placeholder stored under a dedup key until the enqueue that owns the key
# records the real task id.
DEDUP_INFLIGHT_MARKER: str = "__inflight__"
|
|
|
|
|
|
class TaskiqQueueClient:
    """Enqueue agent commands as Taskiq tasks, with optional deduplication.

    When ``enqueue`` receives a ``dedup_key``, the Redis key
    ``agent:dedup:<dedup_key>`` guards the enqueue. The first caller claims it
    with ``DEDUP_INFLIGHT_MARKER``, enqueues the task, then overwrites the
    marker with the resulting task id. Later callers with the same key receive
    that task id instead of enqueueing again, for as long as the key lives
    (``DEDUP_LOCK_SECONDS``).
    """

    def __init__(self) -> None:
        # Redis client; resolved lazily on first use by _get_redis().
        self._redis: Redis | None = None

    async def _get_redis(self) -> Redis:
        """Return the cached Redis client, fetching it on the first call."""
        if self._redis is None:
            self._redis = await get_or_init_redis_client()
        return self._redis

    @staticmethod
    def _select_queue_task(command: dict[str, object]) -> Any:
        """Pick the Taskiq task that matches ``command["queue"]``.

        The value is stringified, stripped and lower-cased. ``"critical"`` and
        ``"bulk"`` select their dedicated tasks; anything else, including a
        missing key, selects the default ``run_command_task``.
        """
        queue = str(command.get("queue", "default")).strip().lower()
        if queue == "critical":
            return run_command_task_critical
        if queue == "bulk":
            return run_command_task_bulk
        return run_command_task

    @staticmethod
    async def _claim_or_wait(redis_client: Redis, redis_key: str) -> str | None:
        """Claim ``redis_key`` for this caller, or wait for its current owner.

        Returns:
            ``None`` when this caller now owns the key. The key then holds
            ``DEDUP_INFLIGHT_MARKER`` and the caller must perform the enqueue.
            Otherwise, the task id already recorded under the key by an
            earlier caller.

        Raises:
            RuntimeError: The key still holds the in-flight marker after
                ``DEDUP_WAIT_RETRIES`` polls.
        """
        claimed = await redis_client.set(
            redis_key,
            DEDUP_INFLIGHT_MARKER,
            nx=True,
            ex=DEDUP_LOCK_SECONDS,
        )
        if claimed:
            return None

        for _ in range(DEDUP_WAIT_RETRIES):
            existing = await redis_client.get(redis_key)
            # A client created without decode_responses returns bytes, which
            # would never compare equal to the str marker.
            if isinstance(existing, bytes):
                existing = existing.decode()
            if existing is None:
                # The previous owner released the key (its enqueue failed) or
                # the key expired: nothing is in progress any more, so try to
                # take ownership instead of timing out.
                claimed = await redis_client.set(
                    redis_key,
                    DEDUP_INFLIGHT_MARKER,
                    nx=True,
                    ex=DEDUP_LOCK_SECONDS,
                )
                if claimed:
                    return None
            elif existing and existing != DEDUP_INFLIGHT_MARKER:
                return existing
            await asyncio.sleep(DEDUP_WAIT_SECONDS)

        raise RuntimeError("duplicate request is still in progress")

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Enqueue ``command`` and return the Taskiq task id.

        Args:
            command: Command payload. It is shallow-copied before being sent;
                its optional ``"queue"`` entry selects the target task.
            dedup_key: When truthy, the command is enqueued at most once while
                the Redis dedup key lives. Duplicates get the first task id.

        Returns:
            The task id as a string.

        Raises:
            RuntimeError: A duplicate request is still in flight after the
                polling budget is exhausted.
            Exception: Errors from Redis or from ``kiq()`` propagate unchanged.
        """
        redis_client = await self._get_redis()
        redis_key: str | None = None
        if dedup_key:
            redis_key = f"agent:dedup:{dedup_key}"
            existing_task_id = await self._claim_or_wait(redis_client, redis_key)
            if existing_task_id is not None:
                return existing_task_id

        payload = dict(command)
        queue_task = self._select_queue_task(payload)
        # Guard only the enqueue itself. Once kiq() has succeeded the task
        # exists, and releasing the dedup key after that point would let a
        # retry enqueue the same command a second time.
        try:
            result = await queue_task.kiq(payload)
        except Exception:
            if redis_key is not None:
                await redis_client.delete(redis_key)
            raise

        task_id = str(result.task_id)
        if redis_key is not None:
            # If this write fails, the in-flight marker stays until it
            # expires, so duplicates time out instead of re-enqueueing.
            await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
        return task_id
|
|
|
|
|
|
class RedisEventStream:
    """Read agent session events through a ``RedisStreamEventStore``.

    The store is built on first use from ``config.agent_runtime`` settings
    and then reused by every later call on this instance.
    """

    def __init__(self) -> None:
        # Built lazily by _get_store().
        self._store: RedisStreamEventStore | None = None

    async def _get_store(self) -> RedisStreamEventStore:
        """Return the cached event store, constructing it on the first call."""
        if self._store is not None:
            return self._store
        redis = await get_or_init_redis_client()
        runtime = config.agent_runtime
        self._store = RedisStreamEventStore(
            client=redis,
            stream_prefix=runtime.redis_stream_prefix,
            read_count=runtime.redis_stream_read_count,
            block_ms=runtime.redis_stream_block_ms,
        )
        return self._store

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, Any]]:
        """Fetch events for ``session_id`` that come after ``last_event_id``.

        ``session_id`` is parsed with ``UUID(...)``, so a malformed id raises
        ``ValueError``. Each returned row is a copy of a store row with an
        added ``"cursor"`` entry; the store's rows are not mutated.
        """
        store = await self._get_store()
        events = await store.read_events(
            session_id=UUID(session_id),
            last_event_id=last_event_id,
        )
        # NOTE(review): every row is tagged with the *request's* last_event_id,
        # not with that row's own stream id. A client resuming from "cursor"
        # would re-read the same batch; confirm this is intended.
        return [dict(event, cursor=last_event_id) for event in events]
|
|
|
|
|
|
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
    """Assemble an ``AgentService`` bound to ``session``.

    ``session`` is supplied through FastAPI's ``Depends(get_db)``. Each call
    creates a fresh ``TaskiqQueueClient`` and ``RedisEventStream``, so their
    lazily cached Redis handles live only as long as the returned service.
    """
    repository = AgentRepository(session)
    queue = TaskiqQueueClient()
    stream = RedisEventStream()
    return AgentService(repository=repository, queue=queue, stream=stream)
|