Files
social-app/backend/src/v1/agent/service.py
T

213 lines
6.3 KiB
Python
Raw Normal View History

from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Protocol
from ag_ui.core import StateSnapshotEvent
from ag_ui.core import RunAgentInput
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from core.auth.models import CurrentUser
@dataclass(frozen=True)
class TaskAccepted:
    """Immutable receipt describing an agent command that was queued.

    Instances are frozen, so they compare by value and are hashable.
    """

    # Identifier returned by the queue client's ``enqueue`` call.
    task_id: str
    # Thread (session) id the queued command targets.
    thread_id: str
    # Run id taken from the submitted run input.
    run_id: str
    # True only when the request created a brand-new session.
    created: bool
class AgentRepositoryLike(Protocol):
    """Structural interface for the session persistence used by the agent service.

    All methods are async and take keyword-only arguments.
    """

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owning user id for ``session_id``.

        Callers in this module treat an ``HTTPException`` with status 404
        raised from here as "the session does not exist".
        """
        ...

    async def create_session_for_user(
        self,
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> str:
        """Create a session owned by ``user_id``, optionally with a fixed id.

        Returns a string; presumably the session id — confirm with the
        concrete repository.
        """
        ...

    async def commit(self) -> None:
        """Commit pending changes."""
        ...

    async def rollback(self) -> None:
        """Discard pending changes (used after an ``IntegrityError``)."""
        ...

    async def get_history_day(
        self,
        *,
        session_id: str,
        before: date | None,
    ) -> dict[str, object] | None:
        """Return one day of history for ``session_id``, or ``None``.

        The service reads the ``"day"``, ``"hasMore"`` and ``"messages"``
        keys from a non-empty result.
        """
        ...

    async def get_latest_session_id_for_user(
        self,
        *,
        user_id: str,
    ) -> str | None:
        """Return the id of ``user_id``'s latest session, or ``None`` if none."""
        ...
class QueueClientLike(Protocol):
    """Structural interface for the task queue the agent service submits to."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Submit ``command`` and return the resulting task id.

        ``dedup_key`` is ``None`` for plain runs; resumes pass a string key,
        presumably so the queue can drop duplicates — confirm with the
        concrete client.
        """
        ...
class EventStreamLike(Protocol):
    """Structural interface for reading a session's event stream."""

    async def read(self, *, session_id: str, last_event_id: str | None) -> list[dict[str, object]]:
        """Return event dicts for ``session_id``.

        ``last_event_id`` is a resume cursor; presumably only events after it
        are returned — confirm with the concrete stream implementation.
        """
        ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Allow the call only when ``current_user`` owns the session.

    ``current_user.id`` is stringified before comparison, so an owner id
    stored as a string matches a numeric user id.

    Raises:
        HTTPException: 403 "Forbidden" when the ids differ.
    """
    caller_id = str(current_user.id)
    if owner_id == caller_id:
        return
    raise HTTPException(status_code=403, detail="Forbidden")
class AgentService:
    """Application service for queuing agent runs/resumes and reading sessions.

    Every thread-scoped operation first verifies, via ``ensure_session_owner``,
    that ``current_user`` owns the session identified by the thread id.
    """

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
    ) -> None:
        self._repository = repository
        self._queue = queue
        self._stream = stream

    async def enqueue_run(
        self,
        *,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a new agent run, creating its session on first use.

        Args:
            run_input: AG-UI run input. Its ``thread_id`` is used as the
                session id and its ``run_id`` is echoed in the result.
            current_user: The authenticated caller.

        Returns:
            A ``TaskAccepted`` whose ``created`` flag is True only when this
            call created the session.

        Raises:
            HTTPException: 403 when the session belongs to another user; any
                non-404 error from the owner lookup is re-raised as-is.
            IntegrityError / queue errors: propagated unchanged.
        """
        thread_id = run_input.thread_id
        run_id = run_input.run_id
        created = await self._ensure_run_session(
            thread_id=thread_id,
            current_user=current_user,
        )
        task_id = await self._queue.enqueue(
            command={
                "command": "run",
                "run_input": run_input.model_dump(mode="json", by_alias=True),
            },
            dedup_key=None,
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_id,
            created=created,
        )

    async def _ensure_run_session(
        self,
        *,
        thread_id: str,
        current_user: CurrentUser,
    ) -> bool:
        """Make sure ``thread_id`` exists and is owned by ``current_user``.

        Returns True if the session was created (and committed) here, False if
        it already existed and passed the ownership check.
        """
        try:
            owner = await self._repository.get_session_owner(session_id=thread_id)
        except HTTPException as exc:
            # Only "not found" means "create it"; anything else propagates.
            if exc.status_code != 404:
                raise
        else:
            ensure_session_owner(owner_id=owner, current_user=current_user)
            return False

        try:
            await self._repository.create_session_for_user(
                user_id=str(current_user.id),
                session_id=thread_id,
            )
            await self._repository.commit()
        except IntegrityError:
            # Presumably a concurrent request inserted the same session id
            # first: roll back, then re-read the owner so the caller still
            # gets a proper ownership check against the winning row.
            await self._repository.rollback()
            owner = await self._repository.get_session_owner(session_id=thread_id)
            ensure_session_owner(owner_id=owner, current_user=current_user)
            return False
        return True

    async def enqueue_resume(
        self,
        *,
        thread_id: str,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> TaskAccepted:
        """Queue a resume of an existing thread owned by ``current_user``.

        The queue call receives ``dedup_key="resume:{thread_id}:{run_id}"``.

        Raises:
            HTTPException: 403 when the thread belongs to another user; 400
                when ``run_input.thread_id`` differs from ``thread_id``; plus
                any error raised by the owner lookup.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        # The ownership check above only covers ``thread_id``, but the queued
        # command carries ``run_input`` verbatim. Reject payloads that target a
        # different thread so a caller cannot resume someone else's session.
        if run_input.thread_id != thread_id:
            raise HTTPException(status_code=400, detail="Thread id mismatch")
        dedup_key = f"resume:{thread_id}:{run_input.run_id}"
        task_id = await self._queue.enqueue(
            command={
                "command": "resume",
                "run_input": run_input.model_dump(mode="json", by_alias=True),
            },
            dedup_key=dedup_key,
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_input.run_id,
            created=False,
        )

    async def stream_events(
        self,
        *,
        thread_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return the event dicts read from the stream for an owned thread.

        Raises:
            HTTPException: 403 when the thread belongs to another user, plus
                any error raised by the owner lookup.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        return await self._stream.read(
            session_id=thread_id,
            last_event_id=last_event_id,
        )

    async def get_history_snapshot(
        self,
        *,
        thread_id: str,
        before: date | None,
        current_user: CurrentUser,
    ) -> dict[str, object]:
        """Return one day of history for an owned thread as a snapshot event.

        Args:
            thread_id: Thread (session) to read.
            before: Passed through to ``get_history_day`` as the day cursor.
            current_user: The authenticated caller.

        Returns:
            A serialized ``StateSnapshotEvent`` dict with ``threadId`` also
            set at its top level.

        Raises:
            HTTPException: 403 when the thread belongs to another user, plus
                any error raised by the owner lookup.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        day_payload = await self._repository.get_history_day(
            session_id=thread_id,
            before=before,
        )
        return self._history_snapshot_event(
            thread_id=thread_id,
            day_payload=day_payload,
        )

    async def get_user_history_snapshot(
        self,
        *,
        current_user: CurrentUser,
        thread_id: str | None,
        before: date | None,
    ) -> dict[str, object]:
        """Return a history snapshot for ``thread_id`` or the user's latest thread.

        When ``thread_id`` is None the user's latest session is used. If the
        user has no session, an empty snapshot (no top-level ``threadId``) is
        returned.
        """
        target_thread_id = thread_id
        if target_thread_id is None:
            target_thread_id = await self._repository.get_latest_session_id_for_user(
                user_id=str(current_user.id)
            )
        if target_thread_id is None:
            return self._history_snapshot_event(thread_id=None, day_payload=None)
        return await self.get_history_snapshot(
            thread_id=target_thread_id,
            before=before,
            current_user=current_user,
        )

    @staticmethod
    def _history_snapshot_event(
        *,
        thread_id: str | None,
        day_payload: dict[str, object] | None,
    ) -> dict[str, object]:
        """Serialize one history day as a ``StateSnapshotEvent`` dict.

        A missing or empty ``day_payload`` yields an empty day (``day`` None,
        ``hasMore`` False, no messages). ``threadId`` is copied to the top
        level of the event only when a thread id is given.
        """
        if day_payload:
            day = day_payload["day"]
            has_more = day_payload["hasMore"]
            messages = day_payload["messages"]
        else:
            day, has_more, messages = None, False, []
        event = StateSnapshotEvent(
            snapshot={
                "scope": "history_day",
                "threadId": thread_id,
                "day": day,
                "hasMore": has_more,
                "messages": messages,
            }
        ).model_dump(mode="json", by_alias=True, exclude_none=True)
        if thread_id is not None:
            event["threadId"] = thread_id
        return event