Skip to content

Commit

Permalink
feat: add dag branch (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
meguminnnnnnnnn authored Feb 21, 2025
1 parent 5e5dc1f commit 3113915
Show file tree
Hide file tree
Showing 6 changed files with 350 additions and 34 deletions.
89 changes: 70 additions & 19 deletions compose/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,46 +22,47 @@ import (
)

func dagChannelBuilder(dependencies []string) channel {
waitList := make(map[string]bool, len(dependencies))
for _, dep := range dependencies {
waitList[dep] = false
}
return &dagChannel{
values: make(map[string]any),
waitList: dependencies,
waitList: waitList,
}
}

type waitPred struct {
key string
skipped bool
}

type dagChannel struct {
values map[string]any
waitList []string
waitList map[string]bool
value any
skipped bool
}

func (ch *dagChannel) update(ctx context.Context, ins map[string]any) error {
if ch.skipped {
return nil
}

for k, v := range ins {
if _, ok := ch.values[k]; ok {
return fmt.Errorf("dag channel update, calculate node repeatedly: %s", k)
}
ch.values[k] = v
}

for i := range ch.waitList {
if _, ok := ch.values[ch.waitList[i]]; !ok {
return nil
}
}

if len(ch.waitList) == 1 {
ch.value = ch.values[ch.waitList[0]]
return nil
}
v, err := mergeValues(mapToList(ch.values))
if err != nil {
return fmt.Errorf("dag channel merge value fail: %w", err)
}
ch.value = v

return nil
return ch.tryUpdateValue()
}

func (ch *dagChannel) get(ctx context.Context) (any, error) {
if ch.skipped {
return nil, fmt.Errorf("dag channel has been skipped")
}
if ch.value == nil {
return nil, fmt.Errorf("dag channel not ready, value is nil")
}
Expand All @@ -71,5 +72,55 @@ func (ch *dagChannel) get(ctx context.Context) (any, error) {
}

func (ch *dagChannel) ready(ctx context.Context) bool {
if ch.skipped {
return false
}
return ch.value != nil
}

func (ch *dagChannel) reportSkip(keys []string) (bool, error) {
for _, k := range keys {
if _, ok := ch.waitList[k]; ok {
ch.waitList[k] = true
}
}

allSkipped := true
for _, skipped := range ch.waitList {
if !skipped {
allSkipped = false
break
}
}
ch.skipped = allSkipped

var err error
if !allSkipped {
err = ch.tryUpdateValue()
}

return allSkipped, err
}

func (ch *dagChannel) tryUpdateValue() error {
var validList []string
for key, skipped := range ch.waitList {
if _, ok := ch.values[key]; !ok && !skipped {
return nil
} else if !skipped {
validList = append(validList, key)
}
}

if len(validList) == 1 {
ch.value = ch.values[validList[0]]
return nil
}
v, err := mergeValues(mapToList(ch.values))
if err != nil {
return err
}
ch.value = v
return nil

}
43 changes: 37 additions & 6 deletions compose/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
if isWorkflow(g.cmp) {
eager = true
}
if !eager && opt != nil && opt.getStateEnabled {
if !isWorkflow(g.cmp) && opt != nil && opt.getStateEnabled {
return nil, fmt.Errorf("shouldn't set WithGetStateEnable outside of the Workflow")
}
forbidGetState := true
Expand All @@ -745,11 +745,6 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
}
}

// dag doesn't support branch
if runType == runTypeDAG && len(g.branches) > 0 {
return nil, fmt.Errorf("dag doesn't support branch for now")
}

for key := range g.fieldMappingRecords {
// not allowed to map multiple fields to the same field
toMap := make(map[string]bool)
Expand Down Expand Up @@ -806,6 +801,17 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa

}
}
for start, branches := range g.branches {
for _, branch := range branches {
for end := range branch.endNodes {
if _, ok := invertedEdges[end]; !ok {
invertedEdges[end] = []string{start}
} else {
invertedEdges[end] = append(invertedEdges[end], start)
}
}
}
}

inputChannels := &chanCall{
writeTo: g.edges[START],
Expand Down Expand Up @@ -833,6 +839,12 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
edgeHandlerManager: &edgeHandlerManager{h: g.handlerOnEdges},
}

successors := make(map[string][]string)
for ch := range r.chanSubscribeTo {
successors[ch] = getSuccessors(r.chanSubscribeTo[ch])
}
r.successors = successors

if g.stateGenerator != nil {
r.runCtx = func(ctx context.Context) context.Context {
return context.WithValue(ctx, stateKey{}, &internalState{
Expand Down Expand Up @@ -868,6 +880,17 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
return r.toComposableRunnable(), nil
}

func getSuccessors(c *chanCall) []string {
ret := make([]string, len(c.writeTo))
copy(ret, c.writeTo)
for _, branch := range c.writeToBranches {
for node := range branch.endNodes {
ret = append(ret, node)
}
}
return ret
}

type subGraphCompileCallback struct {
closure func(ctx context.Context, info *GraphInfo)
}
Expand Down Expand Up @@ -1043,6 +1066,14 @@ func validateDAG(chanSubscribeTo map[string]*chanCall, invertedEdges map[string]
}
m[subNode]--
}
for _, subBranch := range chanSubscribeTo[node].writeToBranches {
for subNode := range subBranch.endNodes {
if subNode == END {
continue
}
m[subNode]--
}
}
m[node] = -1
}
}
Expand Down
37 changes: 35 additions & 2 deletions compose/graph_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type channel interface {
update(context.Context, map[string]any) error
get(context.Context) (any, error)
ready(context.Context) bool
reportSkip([]string) (bool, error)
}

type edgeHandlerManager struct {
Expand Down Expand Up @@ -108,8 +109,9 @@ func (p *preBranchHandlerManager) handle(nodeKey string, idx int, value any, isS
}

type channelManager struct {
isStream bool
channels map[string]channel
isStream bool
successors map[string][]string
channels map[string]channel

edgeHandlerManager *edgeHandlerManager
preNodeHandlerManager *preNodeHandlerManager
Expand Down Expand Up @@ -163,6 +165,37 @@ func (c *channelManager) updateAndGet(ctx context.Context, values map[string]map
return c.getFromReadyChannels(ctx, isStream)
}

func (c *channelManager) reportBranch(from string, skippedNodes []string) error {
var nKeys []string
for _, node := range skippedNodes {
skipped, err := c.channels[node].reportSkip([]string{from})
if err != nil {
return err
}
if skipped {
nKeys = append(nKeys, node)
}
}

for i := 0; i < len(nKeys); i++ {
key := nKeys[i]
if _, ok := c.successors[key]; !ok {
return fmt.Errorf("unknown node: %s", key)
}
for _, successor := range c.successors[key] {
skipped, err := c.channels[successor].reportSkip([]string{key})
if err != nil {
return err
}
if skipped {
nKeys = append(nKeys, successor)
}
// todo: detect if end node has been skipped?
}
}
return nil
}

type task struct {
ctx context.Context
nodeKey string
Expand Down
40 changes: 34 additions & 6 deletions compose/graph_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type chanBuilder func(d []string) channel
type runner struct {
chanSubscribeTo map[string]*chanCall
invertedEdges map[string][]string
successors map[string][]string
inputChannels *chanCall

chanBuilder chanBuilder // could be nil
Expand Down Expand Up @@ -176,7 +177,7 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti
}

// 1. Calculate active edges and resolve their values.
writeChannelValues, err := r.resolveCompletedTasks(ctx, completedTasks, isStream)
writeChannelValues, err := r.resolveCompletedTasks(ctx, completedTasks, isStream, cm)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -233,13 +234,13 @@ func (r *runner) createTasks(ctx context.Context, nodeMap map[string]any, optMap
return nextTasks, nil
}

func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool) (map[string]map[string]any, error) {
func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool, cm *channelManager) (map[string]map[string]any, error) {
writeChannelValues := make(map[string]map[string]any)
for _, t := range completedTasks {
// update channel & new_next_tasks
vs := copyItem(t.output, len(t.call.writeTo)+len(t.call.writeToBranches)*2)
nextNodeKeys, err := r.calculateNext(ctx, t.nodeKey, t.call,
vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream)
vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream, cm)
if err != nil {
return nil, fmt.Errorf("calculate next step fail, node: %s, error: %w", t.nodeKey, err)
}
Expand All @@ -253,7 +254,7 @@ func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*ta
return writeChannelValues, nil
}

func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool) ([]string, error) {
func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool, cm *channelManager) ([]string, error) {
if len(input) < len(startChan.writeToBranches) {
// unreachable
return nil, errors.New("calculate next input length is shorter than branches")
Expand All @@ -266,6 +267,7 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan
ret := make([]string, 0, len(startChan.writeTo))
ret = append(ret, startChan.writeTo...)

skippedNodes := make(map[string]struct{})
for i, branch := range startChan.writeToBranches {
// check branch input type if needed
var err error
Expand Down Expand Up @@ -305,8 +307,33 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan
return nil, errors.New("invoke branch result isn't string")
}
}

for node := range branch.endNodes {
if node != w {
skippedNodes[node] = struct{}{}
}
}

ret = append(ret, w)
}

// When a node has multiple branches,
// there may be a situation where a succeeding node is selected by some branches and discarded by the other branches,
// in which case the succeeding node should not be skipped.
var skippedNodeList []string
for _, selected := range ret {
if _, ok := skippedNodes[selected]; ok {
delete(skippedNodes, selected)
}
}
for skipped := range skippedNodes {
skippedNodeList = append(skippedNodeList, skipped)
}

err := cm.reportBranch(curNodeKey, skippedNodeList)
if err != nil {
return nil, err
}
return ret, nil
}

Expand Down Expand Up @@ -337,8 +364,9 @@ func (r *runner) initChannelManager(isStream bool) *channelManager {
chs[END] = builder(r.invertedEdges[END])

return &channelManager{
isStream: isStream,
channels: chs,
isStream: isStream,
channels: chs,
successors: r.successors,

edgeHandlerManager: r.edgeHandlerManager,
preNodeHandlerManager: r.preNodeHandlerManager,
Expand Down
Loading

0 comments on commit 3113915

Please sign in to comment.