133 lines
3.6 KiB
Python
133 lines
3.6 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Protocol
|
||
|
|
|
||
|
|
from fastapi import HTTPException
|
||
|
|
|
||
|
|
from core.auth.models import CurrentUser
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class TaskAccepted:
    """Immutable receipt describing a command that was handed to the queue."""

    # Identifier returned by the queue client for the submitted command.
    task_id: str
    # Session the submitted command targets.
    session_id: str
    # True only when the session was created while handling this request.
    created: bool
|
||
|
|
|
||
|
|
|
||
|
|
class AgentRepositoryLike(Protocol):
    """Structural interface for the session storage used by the agent service.

    Any object exposing these four coroutine methods satisfies the protocol;
    no inheritance is required.
    """

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner identifier recorded for ``session_id``."""
        ...

    async def create_session_for_user(self, *, user_id: str) -> str:
        """Create a session for ``user_id`` and return the new session identifier."""
        ...

    async def commit(self) -> None:
        """Commit the repository's pending work (semantics are implementation-defined)."""
        ...

    async def rollback(self) -> None:
        """Discard the repository's pending work (semantics are implementation-defined)."""
        ...
|
||
|
|
|
||
|
|
|
||
|
|
class QueueClientLike(Protocol):
    """Structural interface for the client that submits commands to a task queue."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Submit ``command`` and return the identifier of the resulting task.

        ``dedup_key`` is either ``None`` or a caller-built string; presumably the
        implementation uses it to collapse duplicate submissions (confirm
        against the concrete client).
        """
        ...
|
||
|
|
|
||
|
|
|
||
|
|
class EventStreamLike(Protocol):
    """Structural interface for the source of per-session events."""

    async def read(
        self, *, session_id: str, last_event_id: str | None
    ) -> list[dict[str, object]]:
        """Return the events available for ``session_id`` as a list of dicts.

        ``last_event_id`` may be ``None``; presumably it is a resume cursor so
        that only newer events are returned (confirm against the concrete
        stream implementation).
        """
        ...
|
||
|
|
|
||
|
|
|
||
|
|
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Reject the request unless ``current_user`` owns the session.

    Args:
        owner_id: Owner identifier stored for the session.
        current_user: Authenticated caller; its ``id`` is compared in string form.

    Raises:
        HTTPException: With status 403 and detail ``"Forbidden"`` when the
            identifiers differ.
    """
    # The stored owner is a string, so the caller's id is normalised to match.
    caller_id = str(current_user.id)
    if owner_id == caller_id:
        return
    raise HTTPException(status_code=403, detail="Forbidden")
|
||
|
|
|
||
|
|
|
||
|
|
class AgentService:
    """Authorise callers, queue agent commands, and read session events.

    Every operation on an existing session first looks up the session's owner
    in the repository and raises HTTP 403 when the caller is not that owner.
    """

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
    ) -> None:
        """Store the collaborators; nothing is awaited or called here.

        Args:
            repository: Session storage (owner lookup, creation, commit).
            queue: Client used to submit ``run`` / ``resume`` commands.
            stream: Source of per-session events.
        """
        self._repository = repository
        self._queue = queue
        self._stream = stream

    async def _require_owner(
        self, *, session_id: str, current_user: CurrentUser
    ) -> None:
        """Fetch the owner of ``session_id`` and raise 403 unless it is the caller."""
        owner = await self._repository.get_session_owner(session_id=session_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)

    async def enqueue_run(
        self,
        *,
        session_id: str | None,
        prompt: str,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a ``run`` command, creating a session first when none is given.

        Args:
            session_id: Existing session to run in, or ``None`` to create one
                for the caller.
            prompt: Text forwarded to the queue as ``user_input``.
            current_user: Authenticated caller.

        Returns:
            A ``TaskAccepted`` carrying the queue's task id, the target session
            id, and whether the session was created by this call.

        Raises:
            HTTPException: 403 when ``session_id`` is given and belongs to
                someone else. Errors from the repository or the queue
                propagate unchanged.
        """
        if session_id is None:
            target_session_id = await self._repository.create_session_for_user(
                user_id=str(current_user.id)
            )
            # The new session is committed before the command is queued.
            # NOTE(review): if the enqueue below fails, the committed session
            # is left behind without a task - confirm that this is intended.
            await self._repository.commit()
            created = True
        else:
            await self._require_owner(
                session_id=session_id, current_user=current_user
            )
            target_session_id = session_id
            created = False

        # Run commands are submitted without a dedup key.
        task_id = await self._queue.enqueue(
            command={
                "command": "run",
                "session_id": target_session_id,
                "user_input": prompt,
            },
            dedup_key=None,
        )
        return TaskAccepted(
            task_id=task_id, session_id=target_session_id, created=created
        )

    async def enqueue_resume(
        self,
        *,
        session_id: str,
        tool_call_id: str,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a ``resume`` command for a tool call in an existing session.

        Args:
            session_id: Session to resume; must belong to ``current_user``.
            tool_call_id: Tool call the resume refers to.
            current_user: Authenticated caller.

        Returns:
            A ``TaskAccepted`` with ``created`` always ``False``.

        Raises:
            HTTPException: 403 when the caller does not own the session.
        """
        await self._require_owner(session_id=session_id, current_user=current_user)

        # The key is derived only from the session and tool call, so repeated
        # resumes of the same tool call submit the same dedup key.
        task_id = await self._queue.enqueue(
            command={
                "command": "resume",
                "session_id": session_id,
                "tool_call_id": tool_call_id,
            },
            dedup_key=f"resume:{session_id}:{tool_call_id}",
        )
        return TaskAccepted(task_id=task_id, session_id=session_id, created=False)

    async def stream_events(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return the events of a session the caller owns.

        Args:
            session_id: Session whose events are read.
            last_event_id: Passed through to the stream unchanged; may be
                ``None``.
            current_user: Authenticated caller.

        Returns:
            Whatever list the event stream's ``read`` returns.

        Raises:
            HTTPException: 403 when the caller does not own the session.
        """
        await self._require_owner(session_id=session_id, current_user=current_user)
        return await self._stream.read(
            session_id=session_id,
            last_event_id=last_event_id,
        )
|