fix(agent): serialize crewai flow stages and remove nested asyncio.run
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FlowState(BaseModel):
|
||||
stage_trace: list[str] = []
|
||||
current_stage: str | None = None
|
||||
message: str = ""
|
||||
context: dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentFlow:
|
||||
def __init__(self) -> None:
|
||||
self.state = FlowState()
|
||||
|
||||
async def run(self) -> dict[str, Any]:
|
||||
result = await self.intent_recognition()
|
||||
result = await self.task_execution(result)
|
||||
result = await self.result_reporting(result)
|
||||
return result
|
||||
|
||||
async def intent_recognition(self) -> dict[str, Any]:
|
||||
self.state.current_stage = "intent"
|
||||
self.state.stage_trace.append("intent")
|
||||
return {"stage": "intent", "result": "intent recognized"}
|
||||
|
||||
async def task_execution(self, prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||
self.state.current_stage = "execution"
|
||||
self.state.stage_trace.append("execution")
|
||||
return {"stage": "execution", "result": "task executed"}
|
||||
|
||||
async def result_reporting(self, prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||
self.state.current_stage = "reporting"
|
||||
self.state.stage_trace.append("reporting")
|
||||
return {"stage": "reporting", "result": "reported"}
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.crewai_flow import AgentFlow
|
||||
|
||||
|
||||
class TestCrewAIFlow:
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_stages_run_in_order(self):
|
||||
flow = AgentFlow()
|
||||
await flow.run()
|
||||
assert flow.state.stage_trace == ["intent", "execution", "reporting"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_state_initialized(self):
|
||||
flow = AgentFlow()
|
||||
assert flow.state.stage_trace == []
|
||||
assert flow.state.current_stage is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_updates_current_stage(self):
|
||||
flow = AgentFlow()
|
||||
await flow.run()
|
||||
assert flow.state.current_stage == "reporting"
|
||||
Reference in New Issue
Block a user