refactor: 简化 AgentScope 运行时模块与 prompt 系统
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
@@ -10,7 +10,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from schemas.messages.chat_message import AgentChatMessage as AgentChatMessageSchema
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessage as AgentChatMessageSchema,
|
||||
AgentChatMessageMetadata,
|
||||
)
|
||||
|
||||
|
||||
class ToolResultPayloadStorage(Protocol):
|
||||
@@ -88,10 +91,11 @@ class AgentRepository:
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
content_text: str,
|
||||
metadata: dict[str, object] | None,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
) -> None:
|
||||
from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage
|
||||
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
@@ -108,17 +112,17 @@ class AgentRepository:
|
||||
|
||||
next_seq = int(session_row.message_count or 0) + 1
|
||||
if not _has_title(session_row.title):
|
||||
session_title = _derive_session_title(content_text)
|
||||
session_title = _derive_session_title(content)
|
||||
if session_title is not None:
|
||||
session_row.title = session_title
|
||||
payload_metadata = dict(metadata or {})
|
||||
payload_metadata["run_id"] = run_id
|
||||
message = AgentChatMessage(
|
||||
|
||||
message = OrmAgentChatMessage(
|
||||
id=uuid4(),
|
||||
session_id=session_uuid,
|
||||
seq=next_seq,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content=content_text,
|
||||
metadata_json=payload_metadata,
|
||||
content=content,
|
||||
metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
|
||||
)
|
||||
self._session.add(message)
|
||||
session_row.message_count = next_seq
|
||||
|
||||
@@ -11,6 +11,10 @@ from typing import Annotated, Union
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.agentscope.events import to_sse_event
|
||||
from core.agentscope.schemas.agui_input import (
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
from core.logging import get_logger
|
||||
from fastapi import (
|
||||
@@ -26,11 +30,6 @@ from fastapi import (
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from core.agentscope.schemas.agui_input import (
|
||||
extract_latest_tool_result,
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
)
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import (
|
||||
@@ -129,8 +128,7 @@ async def enqueue_run(
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
validate_run_request_messages_contract(normalized)
|
||||
validate_run_request_messages_contract(request)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
@@ -149,40 +147,6 @@ async def enqueue_run(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs/{thread_id}/resume",
|
||||
response_model=TaskAcceptedResponse,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
)
|
||||
async def enqueue_resume(
|
||||
thread_id: str,
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> TaskAcceptedResponse:
|
||||
if request.thread_id != thread_id:
|
||||
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
|
||||
try:
|
||||
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
extract_latest_tool_result(normalized)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
task = await service.enqueue_resume(
|
||||
thread_id=thread_id,
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
threadId=task.thread_id,
|
||||
runId=task.run_id,
|
||||
created=task.created,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/events")
|
||||
async def stream_events(
|
||||
request: Request,
|
||||
|
||||
@@ -17,6 +17,10 @@ from core.auth.models import CurrentUser
|
||||
from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachments,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
|
||||
@@ -53,9 +57,8 @@ class AgentRepositoryLike(Protocol):
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
content_text: str,
|
||||
metadata: dict[str, object] | None,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
@@ -157,8 +160,7 @@ class AgentService:
|
||||
)
|
||||
await self._repository.persist_user_message(
|
||||
session_id=thread_id,
|
||||
run_id=run_id,
|
||||
content_text=user_message_text,
|
||||
content=user_message_text,
|
||||
metadata=user_message_metadata,
|
||||
)
|
||||
await self._repository.commit()
|
||||
@@ -167,7 +169,12 @@ class AgentService:
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
"run_input": {
|
||||
"messages": [
|
||||
msg.model_dump(mode="json", exclude_none=True)
|
||||
for msg in run_input.messages
|
||||
],
|
||||
},
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
@@ -178,14 +185,41 @@ class AgentService:
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def load_agent_input_messages(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
) -> dict[str, object] | None:
|
||||
"""Load recent messages for runtime agent input.
|
||||
|
||||
Returns messages from today and yesterday (if exists).
|
||||
"""
|
||||
today = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=None,
|
||||
)
|
||||
if not today:
|
||||
return None
|
||||
|
||||
yesterday = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=today.get("day"), # type: ignore
|
||||
)
|
||||
|
||||
messages: list[dict[str, object]] = []
|
||||
if yesterday and yesterday.get("messages"):
|
||||
messages.extend(yesterday["messages"]) # type: ignore
|
||||
if today.get("messages"):
|
||||
messages.extend(today["messages"]) # type: ignore
|
||||
|
||||
return {"messages": messages}
|
||||
|
||||
async def _prepare_user_message(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> tuple[str, dict[str, object] | None]:
|
||||
from schemas.messages.chat_message import UserMessageAttachments
|
||||
|
||||
) -> tuple[str, AgentChatMessageMetadata | None]:
|
||||
text, content_blocks = extract_latest_user_payload(run_input)
|
||||
|
||||
user_attachments: UserMessageAttachments | None = None
|
||||
@@ -227,11 +261,12 @@ class AgentService:
|
||||
logger.warning("Failed to parse signed URL", url=url, error=str(exc))
|
||||
raise HTTPException(status_code=422, detail="Invalid signed image url")
|
||||
|
||||
metadata: dict[str, object] | None = None
|
||||
metadata: AgentChatMessageMetadata | None = None
|
||||
if user_attachments is not None:
|
||||
metadata = {
|
||||
"user_message_attachments": user_attachments.model_dump(by_alias=True),
|
||||
}
|
||||
metadata = AgentChatMessageMetadata(
|
||||
run_id=run_input.run_id,
|
||||
user_message_attachments=user_attachments,
|
||||
)
|
||||
|
||||
return text, metadata
|
||||
|
||||
@@ -361,33 +396,6 @@ class AgentService:
|
||||
"url": signed_url,
|
||||
}
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> TaskAccepted:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
dedup_key = f"resume:{thread_id}:{run_input.run_id}"
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"owner_id": str(current_user.id),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
return TaskAccepted(
|
||||
task_id=task_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_input.run_id,
|
||||
created=False,
|
||||
)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -61,7 +61,7 @@ async def _enforce_rate_limit_with_redis(
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
client = await get_or_init_redis_client()
|
||||
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds)
|
||||
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds) # type: ignore[await]
|
||||
if int(current) > limit:
|
||||
raise HTTPException(status_code=429, detail="Too many requests")
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
|
||||
super().__init__(session, Friendship)
|
||||
|
||||
async def create_request(
|
||||
self, initiator_id: UUID, recipient_id: UUID, message: str | None = None
|
||||
self, initiator_id: UUID, recipient_id: UUID, content: str | None = None
|
||||
) -> tuple[Friendship, InboxMessage]:
|
||||
try:
|
||||
user_low_id = min(initiator_id, recipient_id)
|
||||
@@ -100,7 +100,7 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
|
||||
self._session.add(friendship)
|
||||
await self._session.flush()
|
||||
|
||||
inbox_content = FriendshipContent(type="request", message=message)
|
||||
inbox_content = FriendshipContent(type="request", message=content)
|
||||
inbox = InboxMessage(
|
||||
recipient_id=recipient_id,
|
||||
sender_id=initiator_id,
|
||||
@@ -126,7 +126,7 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
|
||||
self,
|
||||
friendship: Friendship,
|
||||
initiator_id: UUID,
|
||||
message: str | None = None,
|
||||
content: str | None = None,
|
||||
) -> tuple[Friendship, InboxMessage]:
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
@@ -135,7 +135,7 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
|
||||
friendship.initiator_id = initiator_id
|
||||
friendship.updated_by = initiator_id
|
||||
|
||||
inbox_content = FriendshipContent(type="request", message=message)
|
||||
inbox_content = FriendshipContent(type="request", message=content)
|
||||
inbox = InboxMessage(
|
||||
recipient_id=(
|
||||
friendship.user_low_id
|
||||
|
||||
@@ -18,7 +18,7 @@ class InboxMessageResponse(BaseModel):
|
||||
message_type: InboxMessageType
|
||||
schedule_item_id: UUID | None = None
|
||||
friendship_id: UUID | None = None
|
||||
content: str | None = None
|
||||
content: dict | None = None
|
||||
is_read: bool = False
|
||||
status: InboxMessageStatus = InboxMessageStatus.PENDING
|
||||
created_at: datetime
|
||||
|
||||
@@ -7,33 +7,33 @@ from fastapi import APIRouter, Depends
|
||||
|
||||
from schemas.user.context import UserContext
|
||||
from v1.users.dependencies import get_user_service
|
||||
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.schemas import UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.service import UserService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
@router.get("/me", response_model=UserContext)
|
||||
async def get_me(
|
||||
service: Annotated[UserService, Depends(get_user_service)],
|
||||
) -> UserResponse:
|
||||
) -> UserContext:
|
||||
return await service.get_me()
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserResponse)
|
||||
@router.patch("/me", response_model=UserContext)
|
||||
async def update_me(
|
||||
payload: UserUpdateRequest,
|
||||
service: Annotated[UserService, Depends(get_user_service)],
|
||||
) -> UserResponse:
|
||||
) -> UserContext:
|
||||
return await service.update_me(payload)
|
||||
|
||||
|
||||
@router.post("/search", response_model=list[UserResponse])
|
||||
@router.post("/search", response_model=list[UserContext])
|
||||
async def search_users(
|
||||
payload: UserSearchRequest,
|
||||
service: Annotated[UserService, Depends(get_user_service)],
|
||||
) -> list[UserResponse]:
|
||||
) -> list[UserContext]:
|
||||
return await service.search_users(payload)
|
||||
|
||||
|
||||
|
||||
@@ -11,14 +11,6 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from schemas.user.context import UserContext
|
||||
|
||||
|
||||
class UserResponse(UserContext):
|
||||
"""当前用户,含 email,无 settings"""
|
||||
|
||||
settings: None = Field(default=None, exclude=True) # type: ignore[assignment]
|
||||
|
||||
|
||||
class UserSearchRequest(BaseModel):
|
||||
query: str = Field(min_length=1, max_length=100)
|
||||
|
||||
@@ -13,8 +13,9 @@ from core.agentscope.persistence.user_context_cache import (
|
||||
)
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from schemas.user.context import UserContext, parse_profile_settings
|
||||
from v1.users.repository import UserRepository
|
||||
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.schemas import UserSearchRequest, UserUpdateRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -82,7 +83,7 @@ class UserService(BaseService):
|
||||
user_context_cache or create_user_context_cache(),
|
||||
)
|
||||
|
||||
async def get_me(self) -> UserResponse:
|
||||
async def get_me(self) -> UserContext:
|
||||
user_id = self.require_user_id()
|
||||
try:
|
||||
user = await self._repository.get_by_user_id(user_id)
|
||||
@@ -92,12 +93,13 @@ class UserService(BaseService):
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
email = self._current_user.email if self._current_user else None
|
||||
return UserResponse(
|
||||
return UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
email=email,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
)
|
||||
|
||||
async def get_user_by_id(self, user_id: UUID) -> "UserContext":
|
||||
@@ -116,7 +118,7 @@ class UserService(BaseService):
|
||||
avatar_url=profile.avatar_url,
|
||||
)
|
||||
|
||||
async def update_me(self, update: UserUpdateRequest) -> UserResponse:
|
||||
async def update_me(self, update: UserUpdateRequest) -> UserContext:
|
||||
user_id = self.require_user_id()
|
||||
update_data: dict[str, str | None] = {
|
||||
key: value
|
||||
@@ -151,15 +153,16 @@ class UserService(BaseService):
|
||||
)
|
||||
|
||||
email = self._current_user.email if self._current_user else None
|
||||
return UserResponse(
|
||||
return UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
email=email,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
)
|
||||
|
||||
async def get_by_username(self, username: str) -> UserResponse:
|
||||
async def get_by_username(self, username: str) -> UserContext:
|
||||
try:
|
||||
user = await self._repository.get_by_username(username)
|
||||
except SQLAlchemyError:
|
||||
@@ -167,14 +170,15 @@ class UserService(BaseService):
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return UserResponse(
|
||||
return UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
)
|
||||
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserContext]:
|
||||
query = request.query.strip()
|
||||
|
||||
if _EMAIL_PATTERN.match(query):
|
||||
@@ -182,7 +186,7 @@ class UserService(BaseService):
|
||||
|
||||
return await self._search_by_username(query)
|
||||
|
||||
async def _search_by_email(self, email: str) -> list[UserResponse]:
|
||||
async def _search_by_email(self, email: str) -> list[UserContext]:
|
||||
if self._auth_gateway is None:
|
||||
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
|
||||
|
||||
@@ -199,26 +203,28 @@ class UserService(BaseService):
|
||||
return []
|
||||
|
||||
return [
|
||||
UserResponse(
|
||||
UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
)
|
||||
]
|
||||
|
||||
async def _search_by_username(self, query: str) -> list[UserResponse]:
|
||||
async def _search_by_username(self, query: str) -> list[UserContext]:
|
||||
try:
|
||||
users = await self._repository.search_users(query, limit=20)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="User store unavailable")
|
||||
|
||||
return [
|
||||
UserResponse(
|
||||
UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user