refactor: 重构聊天模块支持 SSE 断线重连及用户上下文隔离

This commit is contained in:
zl-q
2026-03-30 09:06:10 +08:00
parent 1aac62f39e
commit 4285b4ec80
28 changed files with 1624 additions and 658 deletions
@@ -70,7 +70,6 @@ def _router_rules(llm_config: SystemAgentLLMConfig | None) -> list[str]:
"- Return key_entities and constraints that are execution-relevant; low confidence -> omit rather than guess.",
"- Set execution_mode by complexity: onestep / tool_assisted / multistep.",
"- Set result_typing.primary to the most suitable response shape; use clarification_request only when required info is missing.",
"- Set ui.ui_mode and ui.ui_decision_reason based on whether structured UI improves actionability.",
f"- task_typing.primary must use one TaskType enum: {_enum_values(TaskType)}.",
f"- task_typing.secondary max 3 enums: {_enum_values(TaskType)}.",
f"- result_typing.primary must use one ResultType enum: {_enum_values(ResultType)}.",
@@ -279,7 +279,7 @@ class AgentScopeRunner:
runtime_mode: RuntimeMode,
work_memory: WorkProfileContent | None,
) -> WorkerAgentOutputLite:
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
worker_output_model = resolve_worker_output_model(router_output.execution_mode)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
@@ -51,7 +51,10 @@ def _status_badge_needed(intent: UiHintIntent, status: UiHintStatus) -> bool:
def _status_label(status: str) -> str:
return status.upper()
normalized = status.strip().lower()
if not normalized:
return "ui.status.info"
return f"ui.status.{normalized}"
# ============================================================
-4
View File
@@ -14,13 +14,11 @@ from schemas.agent.runtime_models import (
ResultTyping,
ResultType,
RouterAgentOutput,
RouterUiDecision,
RunStatus,
TaskType,
TaskTyping,
ToolAgentOutput,
ToolStatus,
UiMode,
WorkerAgentOutputLite,
WorkerAgentOutputRich,
resolve_worker_output_model,
@@ -47,7 +45,6 @@ __all__ = [
"ClientTimeContext",
"ResultType",
"RouterAgentOutput",
"RouterUiDecision",
"RunStatus",
"RuntimeMode",
"TaskType",
@@ -56,7 +53,6 @@ __all__ = [
"SystemVisibilityBit",
"ToolAgentOutput",
"ToolStatus",
"UiMode",
"UiHintAction",
"UiHintIntent",
"UiHintSection",
+6 -17
View File
@@ -62,11 +62,6 @@ class ExecutionMode(str, Enum):
MULTISTEP = "multistep"
class UiMode(str, Enum):
NONE = "none"
RICH = "rich"
class RunStatus(str, Enum):
SUCCESS = "success"
PARTIAL_SUCCESS = "partial_success"
@@ -114,13 +109,6 @@ class NormalizedTaskInput(BaseModel):
context_summary: str = Field(default="", max_length=2000)
class RouterUiDecision(BaseModel):
model_config = ConfigDict(extra="forbid")
ui_mode: UiMode
ui_decision_reason: str
class RouterAgentOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
@@ -130,7 +118,6 @@ class RouterAgentOutput(BaseModel):
task_typing: TaskTyping
execution_mode: ExecutionMode
result_typing: ResultTyping
ui: RouterUiDecision
class ErrorInfo(BaseModel):
@@ -175,7 +162,9 @@ class AgentOutput(WorkerAgentOutputRich):
WorkerAgentOutput = WorkerAgentOutputLite | WorkerAgentOutputRich
def resolve_worker_output_model(ui_mode: UiMode) -> type[WorkerAgentOutputLite]:
if ui_mode == UiMode.RICH:
return WorkerAgentOutputRich
return WorkerAgentOutputLite
def resolve_worker_output_model(
execution_mode: ExecutionMode,
) -> type[WorkerAgentOutputLite]:
if execution_mode == ExecutionMode.ONESTEP:
return WorkerAgentOutputLite
return WorkerAgentOutputRich
+2 -1
View File
@@ -595,11 +595,12 @@ def build_status_panel(
secondary_button: UiButtonNode | None = None,
node_id: str | None = None,
) -> UiStackNode:
status_label = f"ui.status.{status.value}"
children: list[UiNode] = [
build_stack(
[
build_text(title, role=TextRole.TITLE),
build_badge(label=status.value.upper(), status=status),
build_badge(label=status_label, status=status),
],
direction=LayoutDirection.HORIZONTAL,
gap=8,
+18
View File
@@ -48,6 +48,7 @@ from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
logger = get_logger("v1.agent.router")
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_RUN_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,128}$")
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}
@@ -120,6 +121,11 @@ def _is_terminal_run_event(event: dict[str, object]) -> bool:
)
def _is_target_run_event(event: dict[str, object], *, target_run_id: str) -> bool:
run_id = event.get("runId")
return isinstance(run_id, str) and run_id == target_run_id
@router.post(
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
@@ -188,9 +194,19 @@ async def stream_events(
thread_id: str,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
run_id: str | None = Query(default=None, alias="runId"),
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
if run_id is None or _RUN_ID_RE.fullmatch(run_id) is None:
raise ApiProblemError(
status_code=422,
detail=problem_payload(
code="AGENT_INVALID_RUN_ID",
detail="Invalid runId",
),
)
if last_event_id is not None and (
len(last_event_id) > 32 or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
):
@@ -255,6 +271,8 @@ async def stream_events(
if not row_id or not isinstance(event, dict):
continue
cursor = row_id
if not _is_target_run_event(event, target_run_id=run_id):
continue
yield to_sse_event(row_id, event)
if _is_terminal_run_event(event):
terminal_event_reached = True