Skip to content

Commit

Permalink
feat: add ProcessState & deprecate GetState (#67)
Browse files Browse the repository at this point in the history
* feat: add ProcessState & deprecate GetState

* feat: cannot set max steps in dag
  • Loading branch information
meguminnnnnnnnn authored Feb 14, 2025
1 parent 7b64b27 commit 648ec7c
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 59 deletions.
4 changes: 2 additions & 2 deletions compose/dag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestDAG(t *testing.T) {
t.Fatal(err)
}

r, err := g.compile(context.Background(), &graphCompileOptions{nodeTriggerMode: AllPredecessor, maxRunSteps: 10})
r, err := g.compile(context.Background(), &graphCompileOptions{nodeTriggerMode: AllPredecessor})
if err != nil {
t.Fatal(err)
}
Expand All @@ -146,7 +146,7 @@ func TestDAG(t *testing.T) {
}

// test Compile[I,O]
runner, err := g.Compile(context.Background(), WithMaxRunSteps(100), WithNodeTriggerMode(AllPredecessor))
runner, err := g.Compile(context.Background(), WithNodeTriggerMode(AllPredecessor))
if err != nil {
t.Fatal(err)
}
Expand Down
5 changes: 4 additions & 1 deletion compose/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -847,14 +847,17 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
if err != nil {
return nil, err
}
r.dag = true
}

if opt != nil {
r.options = *opt
}

// default options
if r.options.maxRunSteps == 0 {
if r.dag && r.options.maxRunSteps > 0 {
return nil, fmt.Errorf("cannot set max run steps in dag mode")
} else if !r.dag && r.options.maxRunSteps == 0 {
r.options.maxRunSteps = len(r.chanSubscribeTo) + 10
}

Expand Down
5 changes: 1 addition & 4 deletions compose/graph_compile_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ func WithGraphCompileCallbacks(cbs ...GraphCompileCallback) GraphCompileOption {
}

// WithGetStateEnable enables/disables GetState in Workflow nodes.
// note: Only use this in Workflow
// Since WorkflowNodes execute concurrently without ordering guarantees relative to other nodes' state handlers.
// GetState is disabled in WorkflowNodes to prevent race conditions by default.
// When enabled, users must handle concurrent state access safety (e.g. using locks) themselves.
// Deprecated: use ProcessState instead of GetState, which is always concurrency-safe.
func WithGetStateEnable(enabled bool) GraphCompileOption {
return func(o *graphCompileOptions) {
o.getStateEnabled = enabled
Expand Down
25 changes: 17 additions & 8 deletions compose/graph_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type runner struct {

chanBuilder chanBuilder // could be nil
eager bool
dag bool

runCtx func(ctx context.Context) context.Context

Expand Down Expand Up @@ -130,14 +131,22 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti
ctx = r.runCtx(ctx)
}

// Update maxSteps if provided in options.
for i := range opts {
if opts[i].maxRunSteps > 0 {
maxSteps = opts[i].maxRunSteps
if r.dag {
for i := range opts {
if opts[i].maxRunSteps > 0 {
return nil, fmt.Errorf("cannot set max run steps in dag")
}
}
} else {
// Update maxSteps if provided in options.
for i := range opts {
if opts[i].maxRunSteps > 0 {
maxSteps = opts[i].maxRunSteps
}
}
if maxSteps < 1 {
return nil, errors.New("max run steps limit must be at least 1")
}
}
if maxSteps < 1 {
return nil, errors.New("recursion limit must be at least 1")
}

// Extract and validate options for each node.
Expand All @@ -162,7 +171,7 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti
return nil, fmt.Errorf("context has been canceled: %w", ctx.Err())
default:
}
if step == maxSteps {
if !r.dag && step >= maxSteps {
return nil, ErrExceedMaxSteps
}

Expand Down
58 changes: 42 additions & 16 deletions compose/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"reflect"
"sync"

"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino/utils/generic"
Expand All @@ -33,6 +34,7 @@ type stateKey struct{}
type internalState struct {
state any
forbidden bool
mu sync.Mutex
}

// StatePreHandler is a function that is called before the node is executed.
Expand All @@ -51,10 +53,12 @@ type StreamStatePostHandler[O, S any] func(ctx context.Context, out *schema.Stre

func convertPreHandler[I, S any](handler StatePreHandler[I, S]) *composableRunnable {
rf := func(ctx context.Context, in I, opts ...any) (I, error) {
cState, err := getState[S](ctx)
cState, pMu, err := getState[S](ctx)
if err != nil {
return in, err
}
pMu.Lock()
defer pMu.Unlock()

return handler(ctx, in, cState)
}
Expand All @@ -64,10 +68,12 @@ func convertPreHandler[I, S any](handler StatePreHandler[I, S]) *composableRunna

func convertPostHandler[O, S any](handler StatePostHandler[O, S]) *composableRunnable {
rf := func(ctx context.Context, out O, opts ...any) (O, error) {
cState, err := getState[S](ctx)
cState, pMu, err := getState[S](ctx)
if err != nil {
return out, err
}
pMu.Lock()
defer pMu.Unlock()

return handler(ctx, out, cState)
}
Expand All @@ -77,10 +83,12 @@ func convertPostHandler[O, S any](handler StatePostHandler[O, S]) *composableRun

func streamConvertPreHandler[I, S any](handler StreamStatePreHandler[I, S]) *composableRunnable {
rf := func(ctx context.Context, in *schema.StreamReader[I], opts ...any) (*schema.StreamReader[I], error) {
cState, err := getState[S](ctx)
cState, pMu, err := getState[S](ctx)
if err != nil {
return in, err
}
pMu.Lock()
defer pMu.Unlock()

return handler(ctx, in, cState)
}
Expand All @@ -90,35 +98,51 @@ func streamConvertPreHandler[I, S any](handler StreamStatePreHandler[I, S]) *com

func streamConvertPostHandler[O, S any](handler StreamStatePostHandler[O, S]) *composableRunnable {
rf := func(ctx context.Context, out *schema.StreamReader[O], opts ...any) (*schema.StreamReader[O], error) {
cState, err := getState[S](ctx)
cState, pMu, err := getState[S](ctx)
if err != nil {
return out, err
}
pMu.Lock()
defer pMu.Unlock()

return handler(ctx, out, cState)
}

return runnableLambda[O, O](nil, nil, nil, rf, false)
}

// GetState gets the state from the context.
// When using this method to read or write state in custom nodes, it may lead to data race because other nodes may concurrently access the state.
// You need to be aware of and resolve this situation, typically by adding a mutex.
// It's recommended to only READ the returned state. If you want to WRITE to state, consider using StatePreHandler / StatePostHandler because they are concurrency safe out of the box.
// note: this method will report error
// ProcessState processes the state from the context in a concurrency-safe way.
// This is the recommended way to access and modify state in custom nodes.
// The provided function handler will be executed with exclusive access to the state (protected by mutex).
// note: this method will report error if state type doesn't match or state is not found in context
// e.g.
//
// lambdaFunc := func(ctx context.Context, in string, opts ...any) (string, error) {
// state, err := compose.GetState[*testState](ctx)
// err := compose.ProcessState[*testState](ctx, func(state *testState) error {
// // do something with state in a concurrency-safe way
// state.Count++
// return nil
// })
// if err != nil {
// return "", err
// }
// // do something with state
// return in, nil
// }
//
// stateGraph := compose.NewStateGraph[string, string, testState](genStateFunc)
// stateGraph.AddNode("node1", lambdaFunc)
func ProcessState[S any](ctx context.Context, handler func(context.Context, S) error) error {
s, pMu, err := getState[S](ctx)
if err != nil {
return fmt.Errorf("get state from context fail: %w", err)
}
pMu.Lock()
defer pMu.Unlock()
return handler(ctx, s)
}

// GetState gets the state from the context.
// Deprecated: use ProcessState instead.
func GetState[S any](ctx context.Context) (S, error) {
state := ctx.Value(stateKey{})

Expand All @@ -137,15 +161,17 @@ func GetState[S any](ctx context.Context) (S, error) {
return cState, nil
}

func getState[S any](ctx context.Context) (S, error) {
func getState[S any](ctx context.Context) (S, *sync.Mutex, error) {
state := ctx.Value(stateKey{})

cState, ok := state.(*internalState).state.(S)
interState := state.(*internalState)

cState, ok := interState.state.(S)
if !ok {
var s S
return s, fmt.Errorf("unexpected state type. expected: %v, got: %v",
generic.TypeOf[S](), reflect.TypeOf(state))
return s, nil, fmt.Errorf("unexpected state type. expected: %v, got: %v",
generic.TypeOf[S](), reflect.TypeOf(interState.state))
}

return cState, nil
return cState, &interState.mu, nil
}
22 changes: 16 additions & 6 deletions compose/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,13 @@ func TestStateGraphUtils(t *testing.T) {
state: &testStruct{UserID: 10},
})

ts, err := GetState[*testStruct](ctx)
var userID int64
err := ProcessState[*testStruct](ctx, func(_ context.Context, state *testStruct) error {
userID = state.UserID
return nil
})
assert.NoError(t, err)
assert.Equal(t, int64(10), ts.UserID)
assert.Equal(t, int64(10), userID)
})

t.Run("getState_nil", func(t *testing.T) {
Expand All @@ -193,7 +197,9 @@ func TestStateGraphUtils(t *testing.T) {
ctx := context.Background()
ctx = context.WithValue(ctx, stateKey{}, &internalState{})

_, err := GetState[*testStruct](ctx)
err := ProcessState[*testStruct](ctx, func(_ context.Context, state *testStruct) error {
return nil
})
assert.ErrorContains(t, err, "unexpected state type. expected: *compose.testStruct, got: <nil>")
})

Expand All @@ -207,7 +213,9 @@ func TestStateGraphUtils(t *testing.T) {
state: &testStruct{UserID: 10},
})

_, err := GetState[string](ctx)
err := ProcessState[string](ctx, func(_ context.Context, state string) error {
return nil
})
assert.ErrorContains(t, err, "unexpected state type. expected: string, got: *compose.testStruct")

})
Expand All @@ -224,11 +232,13 @@ func TestStateChain(t *testing.T) {
}))

r, err := sc.AppendLambda(InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
s, err := GetState[*testState](ctx)
err = ProcessState[*testState](ctx, func(_ context.Context, state *testState) error {
state.Field1 = "node1"
return nil
})
if err != nil {
return "", err
}
s.Field1 = "node1"
return input, nil
}), WithStatePostHandler(func(ctx context.Context, out string, state *testState) (string, error) {
state.Field2 = "node2"
Expand Down
41 changes: 24 additions & 17 deletions flow/agent/react/react.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package react

import (
"context"
"fmt"

"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
Expand Down Expand Up @@ -222,17 +221,23 @@ func NewAgent(ctx context.Context, config *AgentConfig) (_ *Agent, err error) {
func buildReturnDirectly(graph *compose.Graph[[]*schema.Message, *schema.Message]) (err error) {
directReturn := func(ctx context.Context, msgs *schema.StreamReader[[]*schema.Message]) (*schema.StreamReader[*schema.Message], error) {
return schema.StreamReaderWithConvert(msgs, func(msgs []*schema.Message) (*schema.Message, error) {
state, err := compose.GetState[*state](ctx)
var msg *schema.Message
err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error {
for i := range msgs {
if msgs[i] != nil && msgs[i].ToolCallID == state.ReturnDirectlyToolCallID {
msg = msgs[i]
return nil
}
}
return nil
})
if err != nil {
return nil, fmt.Errorf("get state failed: %w", err)
return nil, err
}
for i := range msgs {
msg := msgs[i]
if msg != nil && msg.ToolCallID == state.ReturnDirectlyToolCallID {
return msg, nil
}
if msg == nil {
return nil, schema.ErrNoValue
}
return nil, schema.ErrNoValue
return msg, nil
}), nil
}

Expand All @@ -245,16 +250,18 @@ func buildReturnDirectly(graph *compose.Graph[[]*schema.Message, *schema.Message
err = graph.AddBranch(nodeKeyTools, compose.NewStreamGraphBranch(func(ctx context.Context, msgsStream *schema.StreamReader[[]*schema.Message]) (endNode string, err error) {
msgsStream.Close()

s, err := compose.GetState[*state](ctx) // last msg stored in state should contain the tool call information
err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error {
if len(state.ReturnDirectlyToolCallID) > 0 {
endNode = nodeKeyDirectReturn
} else {
endNode = nodeKeyModel
}
return nil
})
if err != nil {
return "", fmt.Errorf("get state in branch failed: %w", err)
}

if len(s.ReturnDirectlyToolCallID) > 0 {
return nodeKeyDirectReturn, nil
return "", err
}

return nodeKeyModel, nil
return endNode, nil
}, map[string]bool{nodeKeyModel: true, nodeKeyDirectReturn: true}))
if err != nil {
return err
Expand Down
Loading

0 comments on commit 648ec7c

Please sign in to comment.