550 lines
20 KiB
Python
550 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import date, datetime, timezone
|
|
import hashlib
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
from ag_ui.core import RunAgentInput
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from core.auth.models import CurrentUser
|
|
from core.agentscope.caches.context_messages_cache import (
|
|
create_context_messages_cache,
|
|
)
|
|
from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
|
from core.config.settings import config
|
|
from core.logging import get_logger
|
|
from schemas.agent.forwarded_props import (
|
|
parse_forwarded_props_runtime_mode,
|
|
RuntimeMode,
|
|
)
|
|
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
|
from schemas.domain.automation import RuntimeConfig
|
|
from schemas.domain.chat_message import (
|
|
AgentChatMessageMetadata,
|
|
UserMessageAttachment,
|
|
extract_user_message_attachments,
|
|
)
|
|
from v1.agent.schemas import (
|
|
AgentRepositoryLike,
|
|
AttachmentStorageLike,
|
|
CancelRequested,
|
|
EventStreamLike,
|
|
HistorySnapshotResponse,
|
|
QueueClientLike,
|
|
TaskAccepted,
|
|
)
|
|
from v1.agent.utils import (
|
|
MAX_ATTACHMENT_BYTES,
|
|
MAX_ATTACHMENTS_PER_MESSAGE,
|
|
is_safe_attachment_path,
|
|
mime_to_suffix,
|
|
)
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Reject the request unless ``current_user`` owns the session.

    Args:
        owner_id: Identifier of the session owner, as a string.
        current_user: Authenticated caller; its ``id`` is compared after
            being converted with ``str``.

    Raises:
        HTTPException: 403 ``Forbidden`` when the two identifiers differ.
    """
    caller_id = str(current_user.id)
    if owner_id != caller_id:
        raise HTTPException(status_code=403, detail="Forbidden")
    return None
|
|
|
|
|
|
class AgentService:
    """Application service behind the agent endpoints.

    Combines the injected repository, queue client, event stream and optional
    attachment storage to: enqueue and cancel runs, persist the user message
    that starts a run, upload and re-sign attachments, read stream events and
    build history snapshots. Every operation that targets a session verifies
    that the caller owns it via ``ensure_session_owner``.
    """

    _repository: AgentRepositoryLike
    _queue: QueueClientLike
    _stream: EventStreamLike
    _attachment_storage: AttachmentStorageLike | None

    # Value passed as ``expires_in_seconds`` for every signed URL issued here.
    _SIGNED_URL_EXPIRES_IN_SECONDS = 3600

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
        attachment_storage: AttachmentStorageLike | None = None,
    ) -> None:
        """Store the collaborators; ``attachment_storage`` may be omitted."""
        self._repository = repository
        self._queue = queue
        self._stream = stream
        self._attachment_storage = attachment_storage

    async def _ensure_session_access(
        self,
        *,
        thread_id: str,
        current_user: CurrentUser,
    ) -> bool:
        """Make sure ``thread_id`` exists and belongs to ``current_user``.

        A 404 ``HTTPException`` from the owner lookup is treated as "session
        does not exist yet" and triggers creation for the caller.

        Returns:
            ``True`` when this call created (and committed) the session,
            ``False`` when it already existed and the caller owns it.

        Raises:
            HTTPException: 403 when the session belongs to someone else; any
                non-404 error from the repository lookup is re-raised as is.
        """
        try:
            owner = await self._repository.get_session_owner(session_id=thread_id)
        except HTTPException as exc:
            if exc.status_code != 404:
                raise
        else:
            ensure_session_owner(owner_id=owner, current_user=current_user)
            return False

        try:
            await self._repository.create_session_for_user(
                user_id=str(current_user.id),
                session_id=thread_id,
            )
            await self._repository.commit()
        except IntegrityError:
            # Presumably another request created the same session first
            # (TODO confirm); roll back and fall back to an ownership check.
            await self._repository.rollback()
            owner = await self._repository.get_session_owner(session_id=thread_id)
            ensure_session_owner(owner_id=owner, current_user=current_user)
            return False
        return True

    async def enqueue_run(
        self,
        *,
        run_input: RunAgentInput,
        current_user: CurrentUser,
        runtime_config: RuntimeConfig | None = None,
    ) -> TaskAccepted:
        """Persist the latest user message and queue a run for it.

        Args:
            run_input: Run request; ``thread_id`` doubles as the session id.
            current_user: Authenticated caller, recorded as the run owner.
            runtime_config: Optional explicit configuration; when ``None`` it
                is built from the system agents configuration.

        Returns:
            ``TaskAccepted`` with the queue task id and a ``created`` flag
            telling whether the session was created by this call.

        Raises:
            HTTPException: 422 for invalid forwarded props or attachments,
                403 when the session belongs to another user, 503 when an
                attachment is supplied but no storage is configured.
        """
        thread_id = run_input.thread_id
        run_id = run_input.run_id
        forwarded_props = getattr(run_input, "forwarded_props", None)
        try:
            runtime_mode = parse_forwarded_props_runtime_mode(forwarded_props)
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

        if runtime_config is None:
            # Imported lazily, as in the original; presumably to avoid an
            # import cycle — TODO confirm.
            from v1.agent.system_agents_config import (
                build_runtime_config_from_system_agents,
            )

            runtime_config = build_runtime_config_from_system_agents()

        created = await self._ensure_session_access(
            thread_id=thread_id,
            current_user=current_user,
        )

        user_message_text, user_message_metadata = await self._prepare_user_message(
            run_input=run_input,
            current_user=current_user,
        )
        visibility_mask = await self._resolve_user_message_visibility_mask(
            runtime_mode=runtime_mode
        )
        await self._repository.persist_user_message(
            session_id=thread_id,
            content=user_message_text,
            metadata=user_message_metadata,
            visibility_mask=visibility_mask,
        )
        await self._repository.commit()
        # The cache append runs after the commit and swallows its own errors.
        await self._append_context_cache_user_message(
            thread_id=thread_id,
            runtime_mode=runtime_mode,
            visibility_mask=visibility_mask,
            content=user_message_text,
            metadata=user_message_metadata,
        )

        queue = "automation" if runtime_mode == RuntimeMode.AUTOMATION else "agent"
        task_id = await self._queue.enqueue(
            command={
                "command": "run",
                "owner_id": str(current_user.id),
                "run_input": run_input.model_dump(
                    mode="json", by_alias=True, exclude_none=True
                ),
                "runtime_config": runtime_config.model_dump(
                    mode="json", by_alias=True, exclude_none=True
                ),
                "queue": queue,
            },
            dedup_key=None,
        )
        return TaskAccepted(
            task_id=task_id,
            thread_id=thread_id,
            run_id=run_id,
            created=created,
        )

    async def cancel_run(
        self,
        *,
        thread_id: str,
        run_id: str,
        current_user: CurrentUser,
    ) -> CancelRequested:
        """Ask the queue to cancel ``run_id`` after verifying ownership.

        Raises:
            HTTPException: 403 when the caller does not own the session; any
                error raised by the owner lookup propagates unchanged.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        await self._queue.request_cancel(
            thread_id=thread_id,
            run_id=run_id,
            requested_by=str(current_user.id),
        )
        return CancelRequested(
            thread_id=thread_id,
            run_id=run_id,
            accepted=True,
        )

    async def _append_context_cache_user_message(
        self,
        *,
        thread_id: str,
        runtime_mode: RuntimeMode,
        visibility_mask: int,
        content: str,
        metadata: AgentChatMessageMetadata | None,
    ) -> None:
        """Best-effort append of the user message to the context cache.

        Any exception raised while creating the cache or appending is logged
        as a warning and suppressed, so a cache failure never fails the run.
        """
        metadata_payload = (
            metadata.model_dump(mode="json", exclude_none=True)
            if isinstance(metadata, AgentChatMessageMetadata)
            else None
        )
        message_payload: dict[str, object] = {
            "role": "user",
            "content": content,
            "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
        }
        if isinstance(metadata_payload, dict):
            message_payload["metadata"] = metadata_payload

        try:
            context_cache = create_context_messages_cache()
            await context_cache.append_message(
                thread_id=thread_id,
                runtime_mode=runtime_mode.value,
                visibility_mask=visibility_mask,
                message=message_payload,
            )
        except Exception as exc:
            # NOTE(review): keyword log fields here vs. ``extra={...}`` in the
            # attachment methods — confirm which style ``get_logger`` expects.
            logger.warning(
                "Failed to append user message to context cache",
                thread_id=thread_id,
                runtime_mode=runtime_mode.value,
                error=str(exc),
            )

    async def _resolve_user_message_visibility_mask(
        self, *, runtime_mode: RuntimeMode
    ) -> int:
        """Return the visibility mask for a user message in ``runtime_mode``.

        Chat messages get the ``UI_HISTORY`` and ``CONTEXT_ASSEMBLY`` bits;
        every other mode gets ``0``.
        """
        if runtime_mode == RuntimeMode.CHAT:
            return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) | bit_mask(
                bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
            )
        return 0

    async def _prepare_user_message(
        self,
        *,
        run_input: RunAgentInput,
        current_user: CurrentUser,
    ) -> tuple[str, AgentChatMessageMetadata | None]:
        """Extract the latest user text and validate its binary attachments.

        Only dict blocks with ``type == "binary"`` and a non-empty string
        ``url`` are considered; a non-string ``mimeType`` falls back to
        ``application/octet-stream``.

        Returns:
            The message text and, when at least one attachment was accepted,
            metadata carrying the run id and the attachments (else ``None``).

        Raises:
            HTTPException: 503 when an attachment is present but storage is
                not configured; 422 for an invalid signed URL or when more
                than ``MAX_ATTACHMENTS_PER_MESSAGE`` attachments are sent.
        """
        text, content_blocks = extract_latest_user_payload(run_input)

        user_attachments: list[UserMessageAttachment] = []
        for block in content_blocks:
            if not isinstance(block, dict) or block.get("type") != "binary":
                continue

            url = block.get("url")
            mime_type = block.get("mimeType")
            if not isinstance(url, str) or not url:
                continue
            if not isinstance(mime_type, str):
                mime_type = "application/octet-stream"

            # Checked per block on purpose: a message without attachments
            # must still work when no storage is configured.
            if self._attachment_storage is None:
                raise HTTPException(
                    status_code=503,
                    detail="Attachment storage unavailable",
                )

            try:
                bucket, path = self._validate_binary_signed_url(
                    url=url,
                    thread_id=run_input.thread_id,
                    current_user=current_user,
                )
                user_attachments.append(
                    UserMessageAttachment(
                        bucket=bucket,
                        path=path,
                        mime_type=mime_type,
                    )
                )
                if len(user_attachments) > MAX_ATTACHMENTS_PER_MESSAGE:
                    raise HTTPException(status_code=422, detail="Too many attachments")
            except HTTPException:
                raise
            except Exception as exc:  # noqa: BLE001
                # Log the URL without its query string: for signed URLs the
                # query carries the access token and must not reach the logs.
                logger.warning(
                    "Failed to parse signed URL",
                    url=self._redact_signed_url(url),
                    error=str(exc),
                )
                raise HTTPException(
                    status_code=422, detail="Invalid signed image url"
                ) from exc

        metadata: AgentChatMessageMetadata | None = None
        if user_attachments:
            metadata = AgentChatMessageMetadata(
                run_id=run_input.run_id,
                user_message_attachments=user_attachments,
            )

        return text, metadata

    async def upload_attachment(
        self,
        *,
        thread_id: str,
        filename: str | None,
        content_type: str | None,
        payload: bytes,
        current_user: CurrentUser,
    ) -> dict[str, str]:
        """Store an image attachment for a session and return a signed URL.

        The session is created for the caller when it does not exist yet.
        Accepted content types are ``image/png``, ``image/jpeg`` and
        ``image/webp`` (compared case-insensitively).

        Returns:
            Dict with ``bucket``, ``path``, ``mimeType`` and ``url`` keys.

        Raises:
            HTTPException: 403 for a foreign session, 503 without storage,
                422 for an unsupported type or empty payload, 413 when the
                payload exceeds ``MAX_ATTACHMENT_BYTES``, 502 when the upload
                or URL signing fails.
        """
        await self._ensure_session_access(
            thread_id=thread_id,
            current_user=current_user,
        )
        if self._attachment_storage is None:
            raise HTTPException(
                status_code=503, detail="Attachment storage unavailable"
            )

        if not isinstance(content_type, str):
            raise HTTPException(status_code=422, detail="Unsupported attachment type")
        mime_type = content_type.lower()
        if mime_type not in {"image/png", "image/jpeg", "image/webp"}:
            raise HTTPException(status_code=422, detail="Unsupported attachment type")
        if not payload:
            raise HTTPException(status_code=422, detail="Empty attachment")
        if len(payload) > MAX_ATTACHMENT_BYTES:
            raise HTTPException(status_code=413, detail="Attachment too large")

        # Object name is derived from hashes of the file name and content;
        # SHA-1 is used only for naming here, not as a security control.
        suffix = mime_to_suffix(mime_type)
        checksum = hashlib.sha1(payload).hexdigest()[:16]
        filename_seed = filename if isinstance(filename, str) and filename else "upload"
        filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
        path = (
            f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
            f"{filename_hash}-{checksum}.{suffix}"
        )
        bucket_name = config.storage.bucket
        try:
            stored_path = await self._attachment_storage.upload_bytes(
                bucket=bucket_name,
                path=path,
                content=payload,
                content_type=mime_type,
            )
            signed_url = await self._attachment_storage.create_signed_url(
                bucket=bucket_name,
                path=stored_path,
                expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
            )
        except Exception as exc:  # noqa: BLE001
            logger.exception(
                "Attachment upload failed",
                extra={
                    "bucket": bucket_name,
                    "path": path,
                    "mime_type": mime_type,
                    "thread_id": thread_id,
                },
            )
            raise HTTPException(
                status_code=502, detail="Failed to upload attachment"
            ) from exc

        return {
            "bucket": bucket_name,
            "path": stored_path,
            "mimeType": mime_type,
            "url": signed_url,
        }

    async def create_attachment_signed_url(
        self,
        *,
        bucket: str,
        path: str,
        current_user: CurrentUser,
    ) -> dict[str, str]:
        """Issue a fresh signed URL for one of the caller's attachments.

        ``bucket`` and ``path`` are stripped of surrounding whitespace; the
        bucket must equal ``config.storage.bucket`` and the path must pass
        ``is_safe_attachment_path`` for the ``agent-inputs/<user id>/`` prefix.

        Returns:
            Dict with ``bucket``, ``path`` and ``url`` keys.

        Raises:
            HTTPException: 503 without storage, 422 for a wrong bucket or
                out-of-scope path, 502 when signing fails.
        """
        if self._attachment_storage is None:
            raise HTTPException(
                status_code=503, detail="Attachment storage unavailable"
            )
        normalized_bucket = bucket.strip()
        if normalized_bucket != config.storage.bucket:
            raise HTTPException(status_code=422, detail="Invalid attachment bucket")

        normalized_path = path.strip()
        expected_prefix = f"agent-inputs/{current_user.id}/"
        if not is_safe_attachment_path(
            normalized_path, expected_prefix=expected_prefix
        ):
            raise HTTPException(status_code=422, detail="Invalid attachment path scope")

        try:
            signed_url = await self._attachment_storage.create_signed_url(
                bucket=normalized_bucket,
                path=normalized_path,
                expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
            )
        except Exception as exc:  # noqa: BLE001
            logger.exception(
                "Attachment signed URL generation failed",
                extra={
                    "bucket": normalized_bucket,
                    "path": normalized_path,
                    "user_id": str(current_user.id),
                },
            )
            raise HTTPException(
                status_code=502, detail="Failed to generate signed URL"
            ) from exc

        return {
            "bucket": normalized_bucket,
            "path": normalized_path,
            "url": signed_url,
        }

    async def stream_events(
        self,
        *,
        thread_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Read stream events for a session the caller owns.

        ``last_event_id`` is forwarded unchanged to the stream's ``read``.

        Raises:
            HTTPException: 403 when the caller does not own the session.
        """
        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        return await self._stream.read(
            session_id=thread_id,
            last_event_id=last_event_id,
        )

    async def get_history_snapshot(
        self,
        *,
        thread_id: str,
        before: date | None,
        current_user: CurrentUser,
    ) -> HistorySnapshotResponse:
        """Build a one-day history snapshot for a session the caller owns.

        Messages are fetched with the ``UI_HISTORY`` visibility bit; messages
        whose role is ``"tool"`` are omitted. Attachments whose path is inside
        the caller's upload prefix for this thread get a fresh signed URL.

        Raises:
            HTTPException: 403 when the caller does not own the session.
        """
        # Function-scope imports kept from the original; presumably they
        # avoid an import cycle — TODO confirm.
        from schemas.domain.chat_message import AgentChatMessage
        from v1.agent.utils import convert_message_to_history
        from v1.agent.schemas import HistoryMessage

        owner = await self._repository.get_session_owner(session_id=thread_id)
        ensure_session_owner(owner_id=owner, current_user=current_user)
        day_payload = await self._repository.get_history_day(
            session_id=thread_id,
            before=before,
            visibility_mask=bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)),
        )

        # Identical for every message of the thread, so built once.
        expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"

        messages: list[HistoryMessage] = []
        if day_payload:
            raw_messages_obj = day_payload.get("messages")
            raw_messages = (
                raw_messages_obj if isinstance(raw_messages_obj, list) else []
            )
            for msg_dict in raw_messages:
                msg = AgentChatMessage.model_validate(msg_dict)
                if msg.role == "tool":
                    continue

                signed_urls: dict[str, str] = {}
                attachments = extract_user_message_attachments(msg.metadata)
                if self._attachment_storage and attachments:
                    for attachment in attachments:
                        if not is_safe_attachment_path(
                            attachment.path,
                            expected_prefix=expected_prefix,
                        ):
                            continue
                        signed_url = await self._attachment_storage.create_signed_url(
                            bucket=attachment.bucket,
                            path=attachment.path,
                            expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
                        )
                        key = f"{attachment.bucket}/{attachment.path}"
                        signed_urls[key] = signed_url

                def _get_signed_url(payload: dict[str, str]) -> str:
                    # NOTE(review): raises KeyError for attachments skipped
                    # above (or when storage is unset) — confirm that
                    # convert_message_to_history tolerates that.
                    key = f"{payload['bucket']}/{payload['path']}"
                    return signed_urls[key]

                converted = convert_message_to_history(msg, _get_signed_url)
                messages.append(HistoryMessage.model_validate(converted))

        day_value = day_payload.get("day") if day_payload else None
        has_more = bool(day_payload.get("hasMore")) if day_payload else False
        return HistorySnapshotResponse(
            scope="history_day",
            threadId=thread_id,
            day=str(day_value) if day_value else None,
            hasMore=has_more,
            messages=messages,
        )

    async def get_user_history_snapshot(
        self,
        *,
        current_user: CurrentUser,
        thread_id: str | None,
        before: date | None,
    ) -> HistorySnapshotResponse:
        """Return a history snapshot, defaulting to the user's latest session.

        When ``thread_id`` is ``None`` the repository is asked for the
        caller's latest session id; if there is none an empty snapshot with
        ``threadId=None`` is returned.
        """
        target_thread_id = thread_id
        if target_thread_id is None:
            target_thread_id = await self._repository.get_latest_session_id_for_user(
                user_id=str(current_user.id)
            )
        if target_thread_id is None:
            return HistorySnapshotResponse(
                scope="history_day",
                threadId=None,
                day=None,
                hasMore=False,
                messages=[],
            )
        return await self.get_history_snapshot(
            thread_id=target_thread_id,
            before=before,
            current_user=current_user,
        )

    @staticmethod
    def _redact_signed_url(url: str) -> str:
        """Return ``url`` without its query string and fragment.

        Uses plain string splitting so it cannot raise on malformed input.
        """
        return url.split("?", 1)[0].split("#", 1)[0]

    def _validate_binary_signed_url(
        self,
        *,
        url: str,
        thread_id: str,
        current_user: CurrentUser,
    ) -> tuple[str, str]:
        """Check that a signed URL points at the caller's own upload.

        The URL host must equal the host of ``config.supabase.url``, the
        parsed bucket must equal ``config.storage.bucket`` and the parsed
        path must pass ``is_safe_attachment_path`` for the caller's upload
        prefix in ``thread_id``.

        Returns:
            The ``(bucket, path)`` pair parsed from the URL.

        Raises:
            HTTPException: 503 without storage; 422 for a wrong host, an
                unparsable URL, a wrong bucket or an out-of-scope path.
        """
        if self._attachment_storage is None:
            raise HTTPException(
                status_code=503, detail="Attachment storage unavailable"
            )
        parsed = urlparse(url)
        expected_host = urlparse(config.supabase.url).netloc
        if parsed.netloc != expected_host:
            raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_HOST")

        try:
            bucket, path = self._attachment_storage.parse_signed_url(url)
        except Exception as exc:  # noqa: BLE001
            raise HTTPException(
                status_code=422, detail="Invalid signed image url"
            ) from exc

        if bucket != config.storage.bucket:
            raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_BUCKET")

        expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
        if not is_safe_attachment_path(path, expected_prefix=expected_prefix):
            raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_PATH_SCOPE")
        return bucket, path
|