Skip to content

Commit f20b8aa

Browse files
[COR-2297/] Fix nested offloaded type validation (flyteorg#552) (flyteorg#5996)
The following workflow works when we are not offloading literals in flytekit ``` import logging from typing import List from flytekit import map_task, task, workflow,LaunchPlan logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("flytekit") logger.setLevel(logging.DEBUG) @task(cache=True, cache_version="1.1") def my_30mb_task(i: str) -> str: return f"Hello world {i}" * 30 * 100 * 1024 @task(cache=True, cache_version="1.1") def generate_strs(count: int) -> List[str]: return ["a"] * count @workflow def my_30mb_wf(mbs: int) -> List[str]: strs = generate_strs(count=mbs) return map_task(my_30mb_task)(i=strs) @workflow def big_inputs_wf(input: List[str]): noop() @task(cache=True, cache_version="1.1") def noop(): ... big_inputs_wf_lp = LaunchPlan.get_or_create(name="big_inputs_wf_lp", workflow=big_inputs_wf) @workflow def ref_wf(mbs: int): big_inputs_wf_lp(input=my_30mb_wf(mbs)) ``` Without flytekit offloading the return type is OffloadedLiteral{inferredType:{Collection{String}} and when checked against big_inputs_wf launchplan which needs Collection{String} , the LiteralTypeToLiteral returns the inferredType : Collection{String} If we enable offloading in flytekit, the returned data from map task is Collection{OffloadedLiteral<{inferredType:{Collection{String}}} When passing this Input to big_inputs_wf which takes Collection{String} then the type check fails due to LiteralTypeToLiteral returning Collection{OffloadedLiteral{inferredType:{Collection{String}}} as Collection{Collection{String}} Flytekit handles this case by special casing Collection{OffloadedLiteral} and similar special casing is needed in flyte code base Tested this by deploying this PR changes https://dogfood-gcp.cloud-staging.union.ai/console/projects/flytesnacks/domains/development/executions/akxs97cdmkmxhhqp228x/nodes Earlier it would fail like this https://dogfood-gcp.cloud-staging.union.ai/console/projects/flytesnacks/domains/development/executions/ap4thjp5528kjfspcsds/nodes ``` [UserError] failed to launch workflow, caused by: rpc error: code = InvalidArgument desc = invalid input input wrong type. Expected collection_type:{simple:STRING}, but got collection_type:{collection_type:{simple:STRING}} ``` Rollout to canary and then all prod byoc and serverless tenants Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F). - [x] To be upstreamed to OSS *TODO: Link Linear issue(s) using [magic words](https://linear.app/docs/github#magic-words). `fixes` will move to merged status, while `ref` will only link the PR.* * [X] Added tests * [ ] Ran a deploy dry run and shared the terraform plan * [ ] Added logging and metrics * [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list) * [ ] Updated documentation
1 parent 3c3ae05 commit f20b8aa

File tree

2 files changed

+110
-10
lines changed

2 files changed

+110
-10
lines changed

flytepropeller/pkg/compiler/validators/utils.go

+19-6
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,18 @@ func buildMultipleTypeUnion(innerType []*core.LiteralType) *core.LiteralType {
202202
return unionLiteralType
203203
}
204204

205-
func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {
205+
func literalTypeForLiterals(literals []*core.Literal) (*core.LiteralType, bool) {
206206
innerType := make([]*core.LiteralType, 0, 1)
207207
innerTypeSet := sets.NewString()
208208
var noneType *core.LiteralType
209+
isOffloadedType := false
209210
for _, x := range literals {
210211
otherType := LiteralTypeForLiteral(x)
211212
otherTypeKey := otherType.String()
213+
if _, ok := x.GetValue().(*core.Literal_OffloadedMetadata); ok {
214+
isOffloadedType = true
215+
return otherType, isOffloadedType
216+
}
212217
if _, ok := x.GetValue().(*core.Literal_Collection); ok {
213218
if x.GetCollection().GetLiterals() == nil {
214219
noneType = otherType
@@ -230,9 +235,9 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {
230235
if len(innerType) == 0 {
231236
return &core.LiteralType{
232237
Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE},
233-
}
238+
}, isOffloadedType
234239
} else if len(innerType) == 1 {
235-
return innerType[0]
240+
return innerType[0], isOffloadedType
236241
}
237242

238243
// sort inner types to ensure consistent union types are generated
@@ -247,7 +252,7 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType {
247252

248253
return 0
249254
})
250-
return buildMultipleTypeUnion(innerType)
255+
return buildMultipleTypeUnion(innerType), isOffloadedType
251256
}
252257

253258
// ValidateLiteralType check if the literal type is valid, return error if the literal is invalid.
@@ -274,15 +279,23 @@ func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType {
274279
case *core.Literal_Scalar:
275280
return literalTypeForScalar(l.GetScalar())
276281
case *core.Literal_Collection:
282+
collectionType, isOffloaded := literalTypeForLiterals(l.GetCollection().Literals)
283+
if isOffloaded {
284+
return collectionType
285+
}
277286
return &core.LiteralType{
278287
Type: &core.LiteralType_CollectionType{
279-
CollectionType: literalTypeForLiterals(l.GetCollection().Literals),
288+
CollectionType: collectionType,
280289
},
281290
}
282291
case *core.Literal_Map:
292+
mapValueType, isOffloaded := literalTypeForLiterals(maps.Values(l.GetMap().Literals))
293+
if isOffloaded {
294+
return mapValueType
295+
}
283296
return &core.LiteralType{
284297
Type: &core.LiteralType_MapValueType{
285-
MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().Literals)),
298+
MapValueType: mapValueType,
286299
},
287300
}
288301
case *core.Literal_OffloadedMetadata:

flytepropeller/pkg/compiler/validators/utils_test.go

+91-4
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ import (
1313

1414
func TestLiteralTypeForLiterals(t *testing.T) {
1515
t.Run("empty", func(t *testing.T) {
16-
lt := literalTypeForLiterals(nil)
16+
lt, isOffloaded := literalTypeForLiterals(nil)
1717
assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String())
18+
assert.False(t, isOffloaded)
1819
})
1920

2021
t.Run("binary idl with raw binary data and no tag", func(t *testing.T) {
@@ -94,17 +95,18 @@ func TestLiteralTypeForLiterals(t *testing.T) {
9495
})
9596

9697
t.Run("homogeneous", func(t *testing.T) {
97-
lt := literalTypeForLiterals([]*core.Literal{
98+
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
9899
coreutils.MustMakeLiteral(5),
99100
coreutils.MustMakeLiteral(0),
100101
coreutils.MustMakeLiteral(5),
101102
})
102103

103104
assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetSimple().String())
105+
assert.False(t, isOffloaded)
104106
})
105107

106108
t.Run("non-homogenous", func(t *testing.T) {
107-
lt := literalTypeForLiterals([]*core.Literal{
109+
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
108110
coreutils.MustMakeLiteral("hello"),
109111
coreutils.MustMakeLiteral(5),
110112
coreutils.MustMakeLiteral("world"),
@@ -115,10 +117,11 @@ func TestLiteralTypeForLiterals(t *testing.T) {
115117
assert.Len(t, lt.GetUnionType().Variants, 2)
116118
assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String())
117119
assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String())
120+
assert.False(t, isOffloaded)
118121
})
119122

120123
t.Run("non-homogenous ensure ordering", func(t *testing.T) {
121-
lt := literalTypeForLiterals([]*core.Literal{
124+
lt, isOffloaded := literalTypeForLiterals([]*core.Literal{
122125
coreutils.MustMakeLiteral(5),
123126
coreutils.MustMakeLiteral("world"),
124127
coreutils.MustMakeLiteral(0),
@@ -128,6 +131,7 @@ func TestLiteralTypeForLiterals(t *testing.T) {
128131
assert.Len(t, lt.GetUnionType().Variants, 2)
129132
assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String())
130133
assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String())
134+
assert.False(t, isOffloaded)
131135
})
132136

133137
t.Run("list with mixed types", func(t *testing.T) {
@@ -454,6 +458,89 @@ func TestLiteralTypeForLiterals(t *testing.T) {
454458
assert.True(t, proto.Equal(expectedLt, lt))
455459
})
456460

461+
t.Run("nested Lists of offloaded List of string types", func(t *testing.T) {
462+
inferredType := &core.LiteralType{
463+
Type: &core.LiteralType_CollectionType{
464+
CollectionType: &core.LiteralType{
465+
Type: &core.LiteralType_Simple{
466+
Simple: core.SimpleType_STRING,
467+
},
468+
},
469+
},
470+
}
471+
literals := &core.Literal{
472+
Value: &core.Literal_Collection{
473+
Collection: &core.LiteralCollection{
474+
Literals: []*core.Literal{
475+
{
476+
Value: &core.Literal_OffloadedMetadata{
477+
OffloadedMetadata: &core.LiteralOffloadedMetadata{
478+
Uri: "dummy/uri-1",
479+
SizeBytes: 1000,
480+
InferredType: inferredType,
481+
},
482+
},
483+
},
484+
{
485+
Value: &core.Literal_OffloadedMetadata{
486+
OffloadedMetadata: &core.LiteralOffloadedMetadata{
487+
Uri: "dummy/uri-2",
488+
SizeBytes: 1000,
489+
InferredType: inferredType,
490+
},
491+
},
492+
},
493+
},
494+
},
495+
},
496+
}
497+
expectedLt := inferredType
498+
lt := LiteralTypeForLiteral(literals)
499+
assert.True(t, proto.Equal(expectedLt, lt))
500+
})
501+
t.Run("nested map of offloaded map of string types", func(t *testing.T) {
502+
inferredType := &core.LiteralType{
503+
Type: &core.LiteralType_MapValueType{
504+
MapValueType: &core.LiteralType{
505+
Type: &core.LiteralType_Simple{
506+
Simple: core.SimpleType_STRING,
507+
},
508+
},
509+
},
510+
}
511+
literals := &core.Literal{
512+
Value: &core.Literal_Map{
513+
Map: &core.LiteralMap{
514+
Literals: map[string]*core.Literal{
515+
516+
"key1": {
517+
Value: &core.Literal_OffloadedMetadata{
518+
OffloadedMetadata: &core.LiteralOffloadedMetadata{
519+
Uri: "dummy/uri-1",
520+
SizeBytes: 1000,
521+
InferredType: inferredType,
522+
},
523+
},
524+
},
525+
"key2": {
526+
Value: &core.Literal_OffloadedMetadata{
527+
OffloadedMetadata: &core.LiteralOffloadedMetadata{
528+
Uri: "dummy/uri-2",
529+
SizeBytes: 1000,
530+
InferredType: inferredType,
531+
},
532+
},
533+
},
534+
},
535+
},
536+
},
537+
}
538+
539+
expectedLt := inferredType
540+
lt := LiteralTypeForLiteral(literals)
541+
assert.True(t, proto.Equal(expectedLt, lt))
542+
})
543+
457544
}
458545

459546
func TestJoinVariableMapsUniqueKeys(t *testing.T) {

0 commit comments

Comments
 (0)