Skip to content

Commit

Permalink
allow and validate fan-in to same map dest in workflow
Browse files Browse the repository at this point in the history
Change-Id: Ib459e12c3f05498f613d0fb97eefd49d036e68ee
  • Loading branch information
shentongmartin committed Jan 16, 2025
1 parent de3f0ac commit 2e90ccb
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 38 deletions.
63 changes: 53 additions & 10 deletions compose/field_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package compose

import (
"errors"
"fmt"
"reflect"

Expand Down Expand Up @@ -69,11 +70,27 @@ func assignOne[T any](dest T, taken any, to string) (T, error) {

}

func convertTo[T any](mappings map[string]any, mustSucceed bool) (T, error) {
func convertTo[T any](mappings map[Mapping]any, mustSucceed bool) (T, error) {
t := generic.NewInstance[T]()

var err error
for fieldName, taken := range mappings {
var (
err error
field2Values = make(map[string][]any)
)

for m, taken := range mappings {
field2Values[m.to] = append(field2Values[m.to], taken)
}

for fieldName, values := range field2Values {
taken := values[0]
if len(values) > 1 {
taken, err = mergeValues(values)
if err != nil {
return t, fmt.Errorf("convertTo %T failed when merge multiple values for field %s, %w", t, fieldName, err)
}
}

t, err = assignOne(t, taken, fieldName)
if err != nil {
if mustSucceed {
Expand All @@ -86,34 +103,34 @@ func convertTo[T any](mappings map[string]any, mustSucceed bool) (T, error) {
return t, nil
}

type fieldMapFn func(any) (map[string]any, error)
type fieldMapFn func(any) (map[Mapping]any, error)
type streamFieldMapFn func(streamReader) streamReader

func mappingAssign[T any](in map[string]any, mustSucceed bool) (any, error) {
func mappingAssign[T any](in map[Mapping]any, mustSucceed bool) (any, error) {
return convertTo[T](in, mustSucceed)
}

func mappingStreamAssign[T any](in streamReader, mustSucceed bool) streamReader {
s, ok := unpackStreamReader[map[string]any](in)
s, ok := unpackStreamReader[map[Mapping]any](in)
if !ok {
panic("mappingStreamAssign incoming streamReader chunk type not map[string]any")
}

return packStreamReader(schema.StreamReaderWithConvert(s, func(v map[string]any) (T, error) {
return packStreamReader(schema.StreamReaderWithConvert(s, func(v map[Mapping]any) (T, error) {
return convertTo[T](v, mustSucceed)
}))
}

func fieldMap(mappings []*Mapping) fieldMapFn {
return func(input any) (map[string]any, error) {
result := make(map[string]any, len(mappings))
return func(input any) (map[Mapping]any, error) {
result := make(map[Mapping]any, len(mappings))
for _, mapping := range mappings {
taken, err := takeOne(input, mapping.from)
if err != nil {
return nil, err
}

result[mapping.to] = taken
result[*mapping] = taken
}

return result, nil
Expand Down Expand Up @@ -194,3 +211,29 @@ func checkAndExtractFieldType(field string, typ reflect.Type) (reflect.Type, err

return f.Type, nil
}

func checkMappingGroup(mappings []*Mapping) error {
if len(mappings) <= 1 {
return nil
}

var toMap = make(map[string]bool, len(mappings))

for _, mapping := range mappings {
if mapping.empty() {
return errors.New("multiple mappings have an empty mapping")
}

if len(mapping.to) == 0 {
return fmt.Errorf("multiple mappings have a mapping to entire output, mapping= %s", mapping)
}

if _, ok := toMap[mapping.to]; ok {
return fmt.Errorf("multiple mappings have the same To = %s, mappings=%v", mapping.to, mappings)
}

toMap[mapping.to] = true
}

return nil
}
67 changes: 67 additions & 0 deletions compose/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,10 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
if err != nil {
return nil, err
}

if err = g.validateFanIn(invertedEdges, g.node2Mappings); err != nil {
return nil, err
}
}

if opt != nil {
Expand Down Expand Up @@ -1097,3 +1101,66 @@ func (g *graph) checkNodesPreConvertMustSucceed() map[string]bool {

return result
}

func (g *graph) validateFanIn(invertedEdge map[string][]string, node2Mappings map[string][]*Mapping) (err error) {
for toNodeKey, fromNodeKeys := range invertedEdge {
if len(fromNodeKeys) <= 1 { // no fan in
continue
}

toNodeInputType := g.getNodeInputType(toNodeKey)

mappings, ok := node2Mappings[toNodeKey]
if !ok {
if toNodeInputType.Kind() != reflect.Map {
return fmt.Errorf("fan in downstream node[%s] has non-map input type: %v", toNodeKey, toNodeInputType)
}

for _, fromNodeKey := range fromNodeKeys {
fromType := g.getNodeOutputType(fromNodeKey)
if fromType != toNodeInputType {
return fmt.Errorf("fan in upstream node[%s] has different output type: %v, expected: %v", toNodeKey, fromType, toNodeInputType)
}
}

continue
}

toField2Mappings := make(map[string][]*Mapping, len(mappings))
for _, mapping := range mappings {
toField2Mappings[mapping.to] = append(toField2Mappings[mapping.to], mapping)
}

for fieldName, mappingGroup := range toField2Mappings {
if len(mappingGroup) <= 1 {
continue
}

toType := toNodeInputType
if len(fieldName) > 0 {
if toType, err = checkAndExtractFieldType(fieldName, toNodeInputType); err != nil {
return err
}
}

if toType.Kind() != reflect.Map {
return fmt.Errorf("downstream node[%s]'s fan in field[%s] has non-map input type: %v", toNodeKey, fieldName, toType)
}

for _, m := range mappingGroup {
fromType := g.getNodeOutputType(m.fromNodeKey)
if len(m.from) > 0 {
if fromType, err = checkAndExtractFieldType(m.from, fromType); err != nil {
return err
}
}

if fromType != toType {
return fmt.Errorf("upstream node[%s]'s output field[%s] has different input type: %v, expected type: %v", m.fromNodeKey, m.from, fromType, toType)
}
}
}
}

return nil
}
4 changes: 2 additions & 2 deletions compose/runnable.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ func (rp *runnablePacker[I, O, TOption]) toComposableRunnable() *composableRunna
}

func buildConverter[I any]() *composableRunnable {
inputType := reflect.TypeOf(map[string]any{})
inputType := reflect.TypeOf(map[Mapping]any{})
outputType := generic.TypeOf[I]()
i := func(ctx context.Context, input any, opts ...any) (output any, err error) {
in, ok := input.(map[string]any)
in, ok := input.(map[Mapping]any)
if !ok {
panic(newUnexpectedInputTypeErr(inputType, reflect.TypeOf(input)))
}
Expand Down
31 changes: 8 additions & 23 deletions compose/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,26 +332,18 @@ func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) {
return fmt.Errorf("workflow node = %s has no input", toNode)
}

toSet := make(map[string]bool, len(node.inputs))

fromNode2Mappings := make(map[string][]*Mapping, len(node.inputs))
for i := range node.inputs {
input := node.inputs[i]

if len(input.to) == 0 && len(node.inputs) > 1 {
return fmt.Errorf("workflow node = %s has multiple incoming mappings, one of them maps to entire input", toNode)
}

if _, ok := toSet[input.to]; ok {
return fmt.Errorf("workflow node = %s has multiple incoming mappings mapped to same field = %s", toNode, input.to)
}
toSet[input.to] = true

fromNodeKey := input.fromNodeKey
fromNode2Mappings[fromNodeKey] = append(fromNode2Mappings[fromNodeKey], input)
}

for fromNode, mappings := range fromNode2Mappings {
if err = checkMappingGroup(mappings); err != nil {
return err
}

if mappings[0].empty() {
if err = wf.gg.AddEdge(fromNode, toNode); err != nil {
return err
Expand All @@ -366,25 +358,18 @@ func (wf *Workflow[I, O]) addEdgesWithMapping() (err error) {
return errors.New("workflow END has no input mapping")
}

toSet := make(map[string]bool, len(wf.end))
fromNode2EndMappings := make(map[string][]*Mapping, len(wf.end))
for i := range wf.end {
input := wf.end[i]

if len(input.to) == 0 && len(wf.end) > 1 {
return fmt.Errorf("workflow node = %s has multiple incoming mappings, one of them maps to entire input", END)
}

if _, ok := toSet[input.to]; ok {
return fmt.Errorf("workflow node = %s has multiple incoming mappings mapped to same field = %s", END, input.to)
}
toSet[input.to] = true

fromNodeKey := input.fromNodeKey
fromNode2EndMappings[fromNodeKey] = append(fromNode2EndMappings[fromNodeKey], input)
}

for fromNode, mappings := range fromNode2EndMappings {
if err = checkMappingGroup(mappings); err != nil {
return err
}

if mappings[0].empty() {
if err = wf.gg.AddEdge(fromNode, END); err != nil {
return err
Expand Down
90 changes: 87 additions & 3 deletions compose/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"

"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/internal/mock/components/embedding"
"github.com/cloudwego/eino/internal/mock/components/indexer"
"github.com/cloudwego/eino/internal/mock/components/model"
Expand Down Expand Up @@ -315,7 +316,7 @@ func TestWorkflowCompile(t *testing.T) {
w.AddToolsNode("1", &ToolsNode{}).AddInput(NewMapping(START), NewMapping(START).From("Content").To("Content"))
w.AddEnd(NewMapping("1"))
_, err := w.Compile(ctx)
assert.ErrorContains(t, err, "one of them maps to entire input")
assert.ErrorContains(t, err, "multiple mappings have an empty mapping")
})

t.Run("multiple mappings have mapping to entire output ", func(t *testing.T) {
Expand All @@ -326,7 +327,7 @@ func TestWorkflowCompile(t *testing.T) {
)
w.AddEnd(NewMapping("1"))
_, err := w.Compile(ctx)
assert.ErrorContains(t, err, " maps to entire input")
assert.ErrorContains(t, err, "multiple mappings have a mapping to entire output")
})

t.Run("multiple mappings have duplicate ToField", func(t *testing.T) {
Expand All @@ -337,6 +338,89 @@ func TestWorkflowCompile(t *testing.T) {
)
w.AddEnd(NewMapping("1"))
_, err := w.Compile(ctx)
assert.ErrorContains(t, err, "mapped to same field")
assert.ErrorContains(t, err, "multiple mappings have the same To")
})
}

func TestFanInToSameDest(t *testing.T) {
t.Run("traditional outputKey fan-in with map[string]any", func(t *testing.T) {
wf := NewWorkflow[string, []*schema.Message]()
wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
return in, nil
}), WithOutputKey("q1")).AddInput(NewMapping(START))
wf.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
return in, nil
}), WithOutputKey("q2")).AddInput(NewMapping(START))
wf.AddChatTemplateNode("prompt", prompt.FromMessages(schema.Jinja2, schema.UserMessage("{{q1}}_{{q2}}"))).
AddInput(NewMapping("1"), NewMapping("2"))
wf.AddEnd(NewMapping("prompt"))
c, err := wf.Compile(context.Background())
assert.NoError(t, err)
out, err := c.Invoke(context.Background(), "query")
assert.NoError(t, err)
assert.Equal(t, []*schema.Message{{Role: schema.User, Content: "query_query"}}, out)
})

t.Run("multiple int fan-in to single int", func(t *testing.T) {
wf := NewWorkflow[int, int]()
wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in int) (output int, err error) {
return in, nil
})).AddInput(NewMapping(START))
wf.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, in int) (output int, err error) {
return in, nil
})).AddInput(NewMapping(START))
wf.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, in int) (output int, err error) {
return in, nil
})).AddInput(NewMapping("1"), NewMapping("2"))
wf.AddEnd(NewMapping("3"))
_, err := wf.Compile(context.Background())
assert.ErrorContains(t, err, "has non-map input type: int")
})

t.Run("fan-in to a field of map", func(t *testing.T) {
type dest struct {
F map[string]any
}

type in struct {
A string
B int
}

wf := NewWorkflow[in, dest]()
wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
return in, nil
}), WithOutputKey("A")).AddInput(NewMapping(START).From("A"))
wf.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, in int) (output int, err error) {
return in, nil
}), WithOutputKey("B")).AddInput(NewMapping(START).From("B"))
wf.AddEnd(NewMapping("1").To("F"), NewMapping("2").To("F"))
c, err := wf.Compile(context.Background())
assert.NoError(t, err)
out, err := c.Invoke(context.Background(), in{A: "a", B: 1})
assert.NoError(t, err)
assert.Equal(t, dest{F: map[string]any{"A": "a", "B": 1}}, out)
})

t.Run("fan-in to a field of non-map", func(t *testing.T) {
type dest struct {
F string
}

type in struct {
A string
B string
}

wf := NewWorkflow[in, dest]()
wf.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
return in, nil
})).AddInput(NewMapping(START).From("A"))
wf.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, in string) (output string, err error) {
return in, nil
})).AddInput(NewMapping(START).From("B"))
wf.AddEnd(NewMapping("1").To("F"), NewMapping("2").To("F"))
_, err := wf.Compile(context.Background())
assert.ErrorContains(t, err, "has non-map input type: string")
})
}

0 comments on commit 2e90ccb

Please sign in to comment.