diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 1ed8530d754..cbc524c7a1b 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -23,7 +23,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" - "github.com/kserve/kserve/pkg/utils" istio_networking "istio.io/api/networking/v1alpha3" istioclientv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1" v1 "k8s.io/api/core/v1" @@ -41,6 +40,10 @@ import ( metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" "sigs.k8s.io/controller-runtime/pkg/webhook" + "github.com/kserve/kserve/pkg/utils" + + routev1 "github.com/openshift/api/route/v1" + "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" "github.com/kserve/kserve/pkg/apis/serving/v1beta1" "github.com/kserve/kserve/pkg/constants" @@ -50,7 +53,6 @@ import ( v1beta1controller "github.com/kserve/kserve/pkg/controller/v1beta1/inferenceservice" "github.com/kserve/kserve/pkg/webhook/admission/pod" "github.com/kserve/kserve/pkg/webhook/admission/servingruntime" - routev1 "github.com/openshift/api/route/v1" ) var ( @@ -281,7 +283,7 @@ func main() { if err = ctrl.NewWebhookManagedBy(mgr). For(&v1beta1.InferenceService{}). WithDefaulter(&v1beta1.InferenceServiceDefaulter{}). - WithValidator(&v1beta1.InferenceServiceValidator{}). + WithValidator(&v1beta1.InferenceServiceValidator{Client: mgr.GetClient()}). Complete(); err != nil { setupLog.Error(err, "unable to create webhook", "webhook", "v1beta1") os.Exit(1) diff --git a/config/webhook/manifests.yaml b/config/webhook/manifests.yaml index ddac00f3376..85265a00dd9 100644 --- a/config/webhook/manifests.yaml +++ b/config/webhook/manifests.yaml @@ -76,6 +76,7 @@ webhooks: operations: - CREATE - UPDATE + - DELETE resources: - inferenceservices --- diff --git a/pkg/apis/serving/v1beta1/inference_service_validation.go b/pkg/apis/serving/v1beta1/inference_service_validation.go index 2b2c5046d8e..43c28a299f8 100644 --- a/pkg/apis/serving/v1beta1/inference_service_validation.go +++ b/pkg/apis/serving/v1beta1/inference_service_validation.go @@ -21,19 +21,21 @@ import ( "errors" "fmt" "reflect" + "regexp" "strconv" "strings" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" - "regexp" - - "github.com/kserve/kserve/pkg/constants" - "github.com/kserve/kserve/pkg/utils" "k8s.io/apimachinery/pkg/runtime" "knative.dev/serving/pkg/apis/autoscaling" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/webhook" + + "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/kserve/kserve/pkg/constants" + "github.com/kserve/kserve/pkg/utils" ) // regular expressions for validation of isvc name @@ -57,9 +59,11 @@ var ( // // NOTE: The +kubebuilder:object:generate=false and +k8s:deepcopy-gen=false marker prevents controller-gen from generating DeepCopy methods, // as this struct is used only for temporary operations and does not need to be deeply copied. -type InferenceServiceValidator struct{} +type InferenceServiceValidator struct { + Client client.Client +} -// +kubebuilder:webhook:verbs=create;update,path=/validate-inferenceservices,mutating=false,failurePolicy=fail,groups=serving.kserve.io,resources=inferenceservices,versions=v1beta1,name=inferenceservice.kserve-webhook-server.validator +// +kubebuilder:webhook:verbs=create;update;delete,path=/validate-inferenceservices,mutating=false,failurePolicy=fail,groups=serving.kserve.io,resources=inferenceservices,versions=v1beta1,name=inferenceservice.kserve-webhook-server.validator var _ webhook.CustomValidator = &InferenceServiceValidator{} // ValidateCreate implements webhook.Validator so a webhook will be registered for the type @@ -93,7 +97,36 @@ func (v *InferenceServiceValidator) ValidateDelete(ctx context.Context, obj runt return nil, err } validatorLogger.Info("validate delete", "name", isvc.Name) - return nil, nil + return v.validateInferenceServiceReferences(ctx, isvc) +} + +// validateInferenceServiceReferences checks if there are any InferenceGraphs that are referencing the given +// InferenceService in isvc argument, and returns an error if there are references to it. +func (v *InferenceServiceValidator) validateInferenceServiceReferences(ctx context.Context, isvc *InferenceService) (admission.Warnings, error) { + igList := v1alpha1.InferenceGraphList{} + err := v.Client.List(ctx, &igList, client.InNamespace(isvc.GetNamespace())) + if err != nil { + return admission.Warnings{}, fmt.Errorf("failed to fetch list of InferenceGraphs: %w", err) + } + + var isvcReferences []string + for _, ig := range igList.Items { + node_loop: + for _, igNode := range ig.Spec.Nodes { + for _, step := range igNode.Steps { + if step.ServiceName == isvc.GetName() { + isvcReferences = append(isvcReferences, ig.GetName()) + break node_loop + } + } + } + } + + if len(isvcReferences) != 0 { + return admission.Warnings{}, fmt.Errorf("InferenceService [%s] is being used in the following InferenceGraphs: %s", isvc.GetName(), strings.Join(isvcReferences, ", ")) + } + + return admission.Warnings{}, nil } // GetIntReference returns the pointer for the integer input diff --git a/pkg/webhook/validation/inferenceservice/inferenceservice_webhook_test.go b/pkg/webhook/validation/inferenceservice/inferenceservice_webhook_test.go new file mode 100644 index 00000000000..975beef71ef --- /dev/null +++ b/pkg/webhook/validation/inferenceservice/inferenceservice_webhook_test.go @@ -0,0 +1,127 @@ +package inferenceservice + +import ( + "context" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "google.golang.org/protobuf/proto" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/kserve/kserve/pkg/apis/serving/v1beta1" + "github.com/kserve/kserve/pkg/constants" +) + +var _ = Describe("InferenceService validator Webhook", func() { + + Context("When deleting InferenceService under Validating Webhook", func() { + var validator v1beta1.InferenceServiceValidator + var servingRuntime *v1alpha1.ServingRuntime + var isvc1, isvc2 *v1beta1.InferenceService + + BeforeEach(func() { + validator = v1beta1.InferenceServiceValidator{Client: k8sClient} + + // Create a serving runtime + servingRuntime = &v1alpha1.ServingRuntime{ + ObjectMeta: metav1.ObjectMeta{ + Name: "tf-serving", + Namespace: "default", + }, + Spec: v1alpha1.ServingRuntimeSpec{ + SupportedModelFormats: []v1alpha1.SupportedModelFormat{ + { + Name: "tensorflow", + Version: proto.String("1"), + AutoSelect: proto.Bool(true), + }, + }, + ServingRuntimePodSpec: v1alpha1.ServingRuntimePodSpec{ + Containers: []v1.Container{ + { + Name: constants.InferenceServiceContainerName, + Image: "tensorflow/serving:1.14.0", + }, + }, + }, + Disabled: proto.Bool(false), + }, + } + Expect(k8sClient.Create(ctx, servingRuntime)).To(Succeed()) + + // Create two inference services to be referenced by an inference graph + isvc1 = &v1beta1.InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Spec: v1beta1.InferenceServiceSpec{ + Predictor: v1beta1.PredictorSpec{ + Tensorflow: &v1beta1.TFServingSpec{ + PredictorExtensionSpec: v1beta1.PredictorExtensionSpec{ + StorageURI: proto.String("s3://test/mnist/export"), + RuntimeVersion: proto.String("1.14.0"), + }, + }, + }, + }, + } + isvc1.DefaultInferenceService(nil, nil, &v1beta1.SecurityConfig{AutoMountServiceAccountToken: false}, nil) + isvc2 = isvc1.DeepCopy() + isvc2.Name = isvc2.Name + "-second" + + Expect(k8sClient.Create(ctx, isvc1)).Should(Succeed()) + Expect(k8sClient.Create(ctx, isvc2)).Should(Succeed()) + }) + + AfterEach(func() { + Expect(k8sClient.Delete(ctx, servingRuntime)).To(WithTransform(client.IgnoreNotFound, Succeed())) + Expect(k8sClient.Delete(ctx, isvc1)).To(WithTransform(client.IgnoreNotFound, Succeed())) + Expect(k8sClient.Delete(ctx, isvc2)).To(WithTransform(client.IgnoreNotFound, Succeed())) + }) + + It("Should allow deleting an InferenceService that is not referenced by an InferenceGraph", func() { + Expect(validator.ValidateDelete(ctx, isvc1)).Error().ToNot(HaveOccurred()) + }) + + It("Should prevent deleting an InferenceService that is referenced by an InferenceGraph", func() { + inferenceGraph := v1alpha1.InferenceGraph{ + ObjectMeta: metav1.ObjectMeta{ + Name: "inferencegraph-one", + Namespace: "default", + }, + Spec: v1alpha1.InferenceGraphSpec{ + Nodes: map[string]v1alpha1.InferenceRouter{ + "root": { + RouterType: v1alpha1.Sequence, + Steps: []v1alpha1.InferenceStep{ + {StepName: "first", InferenceTarget: v1alpha1.InferenceTarget{ + ServiceName: isvc1.GetName(), + }}, + }, + }, + }, + }, + } + Expect(k8sClient.Create(ctx, &inferenceGraph)).To(Succeed()) + defer func(ctx context.Context, inferenceGraph *v1alpha1.InferenceGraph) { + _ = k8sClient.Delete(ctx, inferenceGraph) + }(ctx, &inferenceGraph) + + Eventually(func() error { + var checkIg v1alpha1.InferenceGraph + return k8sClient.Get(ctx, types.NamespacedName{ + Namespace: inferenceGraph.GetNamespace(), + Name: inferenceGraph.GetName(), + }, &checkIg) + }).ShouldNot(HaveOccurred()) + + _, err := validator.ValidateDelete(ctx, isvc1) + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/pkg/webhook/validation/inferenceservice/webhook_suite_test.go b/pkg/webhook/validation/inferenceservice/webhook_suite_test.go new file mode 100644 index 00000000000..af677f4bd63 --- /dev/null +++ b/pkg/webhook/validation/inferenceservice/webhook_suite_test.go @@ -0,0 +1,154 @@ +/* +Copyright 2021 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package inferenceservice + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "path/filepath" + "testing" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + v1 "k8s.io/api/core/v1" + netv1 "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" + knservingv1 "knative.dev/serving/pkg/apis/serving/v1" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/envtest" + logf "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "sigs.k8s.io/controller-runtime/pkg/webhook" + + routev1 "github.com/openshift/api/route/v1" + + v1alpha1 "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/kserve/kserve/pkg/apis/serving/v1beta1" + "github.com/kserve/kserve/pkg/constants" + pkgtest "github.com/kserve/kserve/pkg/testing" +) + +// These tests use Ginkgo (BDD-style Go testing framework). Refer to +// http://onsi.github.io/ginkgo/ to learn more about Ginkgo. + +var ( + k8sClient client.Client + testEnv *envtest.Environment + cancel context.CancelFunc + ctx context.Context + clientset kubernetes.Interface +) + +func TestAPIs(t *testing.T) { + RegisterFailHandler(Fail) + + RunSpecs(t, "InferenceService Webhook Suite") +} + +var _ = BeforeSuite(func() { + logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + ctx, cancel = context.WithCancel(context.TODO()) + By("bootstrapping test environment") + crdDirectoryPaths := []string{ + filepath.Join("..", "..", "..", "..", "test", "crds"), + } + testEnv = pkgtest.SetupEnvTest(crdDirectoryPaths) + cfg, err := testEnv.Start() + Expect(err).ToNot(HaveOccurred()) + Expect(cfg).ToNot(BeNil()) + + err = v1alpha1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + err = v1beta1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + err = knservingv1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + err = netv1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + err = routev1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + + k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}) + Expect(err).ToNot(HaveOccurred()) + Expect(k8sClient).ToNot(BeNil()) + + clientset, err = kubernetes.NewForConfig(cfg) + Expect(err).ToNot(HaveOccurred()) + Expect(clientset).ToNot(BeNil()) + + //Create namespace + kfservingNamespaceObj := &v1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: constants.KServeNamespace, + }, + } + Expect(k8sClient.Create(context.Background(), kfservingNamespaceObj)).Should(Succeed()) + + webhookInstallOptions := &testEnv.WebhookInstallOptions + mgr, err := ctrl.NewManager(cfg, ctrl.Options{ + Scheme: scheme.Scheme, + WebhookServer: webhook.NewServer(webhook.Options{ + Host: webhookInstallOptions.LocalServingHost, + Port: webhookInstallOptions.LocalServingPort, + CertDir: webhookInstallOptions.LocalServingCertDir, + }), + LeaderElection: false, + Metrics: metricsserver.Options{BindAddress: "0"}, + }) + Expect(err).NotTo(HaveOccurred()) + + err = ctrl.NewWebhookManagedBy(mgr). + For(&v1beta1.InferenceService{}). + WithValidator(&v1beta1.InferenceServiceValidator{Client: mgr.GetClient()}). + Complete() + Expect(err).ToNot(HaveOccurred()) + + go func() { + defer GinkgoRecover() + err = mgr.Start(ctx) + Expect(err).ToNot(HaveOccurred()) + }() + + k8sClient = mgr.GetClient() + Expect(k8sClient).ToNot(BeNil()) + + // wait for the webhook server to get ready. + dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", webhookInstallOptions.LocalServingHost, webhookInstallOptions.LocalServingPort) + Eventually(func() error { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + if err != nil { + return err + } + + return conn.Close() + }).Should(Succeed()) +}) + +var _ = AfterSuite(func() { + cancel() + By("tearing down the test environment") + err := testEnv.Stop() + Expect(err).ToNot(HaveOccurred()) +})