Skip to content

Commit 2edf4cb

Browse files
committed
Merge remote-tracking branch 'origin' into adopt-protogetter
Signed-off-by: Eduardo Apolinario <[email protected]>
2 parents d2dcc2e + be67457 commit 2edf4cb

File tree

18 files changed

+508
-29
lines changed

18 files changed

+508
-29
lines changed

flyteadmin/pkg/manager/impl/validation/launch_plan_validator.go

+14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package validation
33
import (
44
"context"
55

6+
"github.com/robfig/cron/v3"
67
"google.golang.org/grpc/codes"
78

89
"github.com/flyteorg/flyte/flyteadmin/pkg/common"
@@ -91,6 +92,19 @@ func validateSchedule(request *admin.LaunchPlanCreateRequest, expectedInputs *co
9192
"KickoffTimeInputArg must reference a datetime input. [%v] is a [%v]", schedule.GetKickoffTimeInputArg(), param.GetVar().GetType())
9293
}
9394
}
95+
96+
// validate cron expression
97+
var cronExpression string
98+
if schedule.GetCronExpression() != "" {
99+
cronExpression = schedule.GetCronExpression()
100+
} else if schedule.GetCronSchedule() != nil {
101+
cronExpression = schedule.GetCronSchedule().GetSchedule()
102+
}
103+
if cronExpression != "" {
104+
if _, err := cron.ParseStandard(cronExpression); err != nil {
105+
return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "Invalid cron expression: %v", err)
106+
}
107+
}
94108
}
95109
return nil
96110
}

flyteadmin/pkg/manager/impl/validation/launch_plan_validator_test.go

+37-6
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ func TestValidateSchedule_ArgNotFixed(t *testing.T) {
358358
},
359359
}
360360
t.Run("with deprecated cron expression", func(t *testing.T) {
361-
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * * *")
361+
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * *")
362362

363363
err := validateSchedule(request, inputMap)
364364
assert.NotNil(t, err)
@@ -370,15 +370,15 @@ func TestValidateSchedule_ArgNotFixed(t *testing.T) {
370370
assert.NotNil(t, err)
371371
})
372372
t.Run("with cron schedule", func(t *testing.T) {
373-
request := testutils.GetLaunchPlanRequestWithCronSchedule("* * * * * *")
373+
request := testutils.GetLaunchPlanRequestWithCronSchedule("* * * * *")
374374

375375
err := validateSchedule(request, inputMap)
376376
assert.NotNil(t, err)
377377
})
378378
}
379379

380380
func TestValidateSchedule_KickoffTimeArgDoesNotExist(t *testing.T) {
381-
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * * *")
381+
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * *")
382382
inputMap := &core.ParameterMap{
383383
Parameters: map[string]*core.Parameter{},
384384
}
@@ -389,7 +389,7 @@ func TestValidateSchedule_KickoffTimeArgDoesNotExist(t *testing.T) {
389389
}
390390

391391
func TestValidateSchedule_KickoffTimeArgPointsAtWrongType(t *testing.T) {
392-
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * * *")
392+
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * *")
393393
inputMap := &core.ParameterMap{
394394
Parameters: map[string]*core.Parameter{
395395
foo: {
@@ -409,7 +409,7 @@ func TestValidateSchedule_KickoffTimeArgPointsAtWrongType(t *testing.T) {
409409
}
410410

411411
func TestValidateSchedule_NoRequired(t *testing.T) {
412-
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * * *")
412+
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * *")
413413
inputMap := &core.ParameterMap{
414414
Parameters: map[string]*core.Parameter{
415415
foo: {
@@ -428,7 +428,7 @@ func TestValidateSchedule_NoRequired(t *testing.T) {
428428
}
429429

430430
func TestValidateSchedule_KickoffTimeBound(t *testing.T) {
431-
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * * *")
431+
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * *")
432432
inputMap := &core.ParameterMap{
433433
Parameters: map[string]*core.Parameter{
434434
foo: {
@@ -446,3 +446,34 @@ func TestValidateSchedule_KickoffTimeBound(t *testing.T) {
446446
err := validateSchedule(request, inputMap)
447447
assert.Nil(t, err)
448448
}
449+
450+
func TestValidateSchedule_InvalidCronExpression(t *testing.T) {
451+
inputMap := &core.ParameterMap{
452+
Parameters: map[string]*core.Parameter{
453+
foo: {
454+
Var: &core.Variable{
455+
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_DATETIME}},
456+
},
457+
Behavior: &core.Parameter_Required{
458+
Required: true,
459+
},
460+
},
461+
},
462+
}
463+
464+
t.Run("with unsupported cron special characters on deprecated cron schedule: #", func(t *testing.T) {
465+
request := testutils.GetLaunchPlanRequestWithDeprecatedCronSchedule("* * * * MON#1")
466+
request.Spec.EntityMetadata.Schedule.KickoffTimeInputArg = foo
467+
468+
err := validateSchedule(request, inputMap)
469+
assert.NotNil(t, err)
470+
})
471+
472+
t.Run("with unsupported cron special characters: #", func(t *testing.T) {
473+
request := testutils.GetLaunchPlanRequestWithCronSchedule("* * * * MON#1")
474+
request.Spec.EntityMetadata.Schedule.KickoffTimeInputArg = foo
475+
476+
err := validateSchedule(request, inputMap)
477+
assert.NotNil(t, err)
478+
})
479+
}

flytecopilot/go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ require (
8282
github.com/prometheus/client_model v0.6.1 // indirect
8383
github.com/prometheus/common v0.53.0 // indirect
8484
github.com/prometheus/procfs v0.15.1 // indirect
85+
github.com/shamaton/msgpack/v2 v2.2.2 // indirect
8586
github.com/sirupsen/logrus v1.9.3 // indirect
8687
github.com/spf13/afero v1.8.2 // indirect
8788
github.com/spf13/cast v1.4.1 // indirect

flytecopilot/go.sum

+2
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR
309309
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
310310
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
311311
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
312+
github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs=
313+
github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI=
312314
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
313315
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
314316
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=

flyteidl/clients/go/coreutils/extract_literal_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package coreutils
55

66
import (
7+
"os"
78
"testing"
89
"time"
910

@@ -125,6 +126,7 @@ func TestFetchLiteral(t *testing.T) {
125126
})
126127

127128
t.Run("Generic", func(t *testing.T) {
129+
os.Setenv(FlyteUseOldDcFormat, "true")
128130
literalVal := map[string]interface{}{
129131
"x": 1,
130132
"y": "ystringvalue",
@@ -150,6 +152,7 @@ func TestFetchLiteral(t *testing.T) {
150152
for key, val := range expectedStructVal.GetFields() {
151153
assert.Equal(t, val.GetKind(), extractedStructValue.GetFields()[key].GetKind())
152154
}
155+
os.Unsetenv(FlyteUseOldDcFormat)
153156
})
154157

155158
t.Run("Generic Passed As String", func(t *testing.T) {

flyteidl/clients/go/coreutils/literals.go

+30-6
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,24 @@ import (
55
"encoding/json"
66
"fmt"
77
"math"
8+
"os"
89
"reflect"
910
"strconv"
1011
"strings"
1112
"time"
1213

13-
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
14-
"github.com/flyteorg/flyte/flytestdlib/storage"
1514
"github.com/golang/protobuf/jsonpb"
1615
"github.com/golang/protobuf/ptypes"
1716
structpb "github.com/golang/protobuf/ptypes/struct"
1817
"github.com/pkg/errors"
18+
"github.com/shamaton/msgpack/v2"
19+
20+
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
21+
"github.com/flyteorg/flyte/flytestdlib/storage"
1922
)
2023

2124
const MESSAGEPACK = "msgpack"
25+
const FlyteUseOldDcFormat = "FLYTE_USE_OLD_DC_FORMAT"
2226

2327
func MakePrimitive(v interface{}) (*core.Primitive, error) {
2428
switch p := v.(type) {
@@ -561,12 +565,32 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro
561565
strValue = fmt.Sprintf("%.0f", math.Trunc(f))
562566
}
563567
if newT.Simple == core.SimpleType_STRUCT {
568+
useOldFormat := strings.ToLower(os.Getenv(FlyteUseOldDcFormat))
564569
if _, isValueStringType := v.(string); !isValueStringType {
565-
byteValue, err := json.Marshal(v)
566-
if err != nil {
567-
return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v)
570+
if useOldFormat == "1" || useOldFormat == "t" || useOldFormat == "true" {
571+
byteValue, err := json.Marshal(v)
572+
if err != nil {
573+
return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v)
574+
}
575+
strValue = string(byteValue)
576+
} else {
577+
byteValue, err := msgpack.Marshal(v)
578+
if err != nil {
579+
return nil, fmt.Errorf("unable to marshal to msgpack bytes for struct value %v", v)
580+
}
581+
return &core.Literal{
582+
Value: &core.Literal_Scalar{
583+
Scalar: &core.Scalar{
584+
Value: &core.Scalar_Binary{
585+
Binary: &core.Binary{
586+
Value: byteValue,
587+
Tag: MESSAGEPACK,
588+
},
589+
},
590+
},
591+
},
592+
}, nil
568593
}
569-
strValue = string(byteValue)
570594
}
571595
}
572596
lv, err := MakeLiteralForSimpleType(newT.Simple, strValue)

flyteidl/clients/go/coreutils/literals_test.go

+66
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package coreutils
55

66
import (
77
"fmt"
8+
"os"
89
"reflect"
910
"strconv"
1011
"testing"
@@ -14,6 +15,7 @@ import (
1415
"github.com/golang/protobuf/ptypes"
1516
structpb "github.com/golang/protobuf/ptypes/struct"
1617
"github.com/pkg/errors"
18+
"github.com/shamaton/msgpack/v2"
1719
"github.com/stretchr/testify/assert"
1820

1921
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
@@ -455,6 +457,7 @@ func TestMakeLiteralForType(t *testing.T) {
455457
})
456458

457459
t.Run("Generic", func(t *testing.T) {
460+
os.Setenv(FlyteUseOldDcFormat, "true")
458461
literalVal := map[string]interface{}{
459462
"x": 1,
460463
"y": "ystringvalue",
@@ -480,6 +483,69 @@ func TestMakeLiteralForType(t *testing.T) {
480483
for key, val := range expectedStructVal.GetFields() {
481484
assert.Equal(t, val.GetKind(), extractedStructValue.GetFields()[key].GetKind())
482485
}
486+
os.Unsetenv(FlyteUseOldDcFormat)
487+
})
488+
489+
t.Run("SimpleBinary", func(t *testing.T) {
490+
// We compare the deserialized values instead of the raw msgpack bytes because Go does not guarantee the order
491+
// of map keys during serialization. This means that while the serialized bytes may differ, the deserialized
492+
// values should be logically equivalent.
493+
494+
var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}}
495+
v := map[string]interface{}{
496+
"a": int64(1),
497+
"b": 3.14,
498+
"c": "example_string",
499+
"d": map[string]interface{}{
500+
"1": int64(100),
501+
"2": int64(200),
502+
},
503+
"e": map[string]interface{}{
504+
"a": int64(1),
505+
"b": 3.14,
506+
},
507+
"f": []string{"a", "b", "c"},
508+
}
509+
510+
val, err := MakeLiteralForType(literalType, v)
511+
assert.NoError(t, err)
512+
513+
msgpackBytes, err := msgpack.Marshal(v)
514+
assert.NoError(t, err)
515+
516+
literalVal := &core.Literal{
517+
Value: &core.Literal_Scalar{
518+
Scalar: &core.Scalar{
519+
Value: &core.Scalar_Binary{
520+
Binary: &core.Binary{
521+
Value: msgpackBytes,
522+
Tag: MESSAGEPACK,
523+
},
524+
},
525+
},
526+
},
527+
}
528+
529+
expectedLiteralVal, err := ExtractFromLiteral(literalVal)
530+
assert.NoError(t, err)
531+
actualLiteralVal, err := ExtractFromLiteral(val)
532+
assert.NoError(t, err)
533+
534+
// Check if the extracted value is of type *core.Binary (not []byte)
535+
expectedBinary, ok := expectedLiteralVal.(*core.Binary)
536+
assert.True(t, ok, "expectedLiteralVal is not of type *core.Binary")
537+
actualBinary, ok := actualLiteralVal.(*core.Binary)
538+
assert.True(t, ok, "actualLiteralVal is not of type *core.Binary")
539+
540+
// Now check if the Binary values match
541+
var expectedVal, actualVal map[string]interface{}
542+
err = msgpack.Unmarshal(expectedBinary.Value, &expectedVal)
543+
assert.NoError(t, err)
544+
err = msgpack.Unmarshal(actualBinary.Value, &actualVal)
545+
assert.NoError(t, err)
546+
547+
// Finally, assert that the deserialized values are equal
548+
assert.Equal(t, expectedVal, actualVal)
483549
})
484550

485551
t.Run("ArrayStrings", func(t *testing.T) {

flyteidl/go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ require (
1313
github.com/mitchellh/mapstructure v1.5.0
1414
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
1515
github.com/pkg/errors v0.9.1
16+
github.com/shamaton/msgpack/v2 v2.2.2
1617
github.com/spf13/pflag v1.0.5
1718
github.com/stretchr/testify v1.9.0
1819
golang.org/x/net v0.27.0

flyteidl/go.sum

+2
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoG
215215
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
216216
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
217217
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
218+
github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs=
219+
github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI=
218220
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
219221
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
220222
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=

flyteplugins/go/tasks/plugins/k8s/ray/config.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ var (
2222
IncludeDashboard: true,
2323
DashboardHost: "0.0.0.0",
2424
EnableUsageStats: false,
25-
ServiceAccount: "default",
25+
ServiceAccount: "",
2626
Defaults: DefaultConfig{
2727
HeadNode: NodeConfig{
2828
StartParameters: map[string]string{

flyteplugins/go/tasks/plugins/k8s/ray/config_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ func TestLoadDefaultServiceAccountConfig(t *testing.T) {
3232

3333
t.Run("serviceAccount", func(t *testing.T) {
3434
config := GetConfig()
35-
assert.Equal(t, config.ServiceAccount, "default")
35+
assert.Equal(t, config.ServiceAccount, "")
3636
})
3737
}

flyteplugins/go/tasks/plugins/k8s/ray/ray.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra
190190
}
191191

192192
serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
193-
if len(serviceAccountName) == 0 {
193+
if len(serviceAccountName) == 0 || cfg.ServiceAccount != "" {
194194
serviceAccountName = cfg.ServiceAccount
195195
}
196196

0 commit comments

Comments
 (0)