diff --git a/charts/kserve-resources/templates/clusterrole.yaml b/charts/kserve-resources/templates/clusterrole.yaml index 6ba81454c83..d383b3c8b2a 100644 --- a/charts/kserve-resources/templates/clusterrole.yaml +++ b/charts/kserve-resources/templates/clusterrole.yaml @@ -51,9 +51,17 @@ rules: - "" resources: - secrets + verbs: + - get +- apiGroups: + - "" + resources: - serviceaccounts verbs: + - create + - delete - get + - patch - apiGroups: - admissionregistration.k8s.io resources: @@ -124,6 +132,34 @@ rules: - patch - update - watch +- apiGroups: + - rbac.authorization.k8s.io + resourceNames: + - kserve-inferencegraph-auth-verifiers + resources: + - clusterrolebindings + verbs: + - create + - get + - patch + - update +- apiGroups: + - route.openshift.io + resources: + - routes + verbs: + - create + - get + - list + - patch + - update + - watch +- apiGroups: + - route.openshift.io + resources: + - routes/status + verbs: + - get - apiGroups: - serving.knative.dev resources: diff --git a/cmd/router/main.go b/cmd/router/main.go index dc50d061ff3..61951ac34e0 100644 --- a/cmd/router/main.go +++ b/cmd/router/main.go @@ -18,10 +18,13 @@ package main import ( "bytes" + "context" + "crypto/rand" "encoding/json" goerrors "errors" "fmt" "io" + "math/big" "net/http" "net/url" "os" @@ -31,18 +34,19 @@ import ( "syscall" "time" - "github.com/kserve/kserve/pkg/constants" "github.com/pkg/errors" - + flag "github.com/spf13/pflag" "github.com/tidwall/gjson" + authnv1 "k8s.io/api/authentication/v1" + authzv1 "k8s.io/api/authorization/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" - "crypto/rand" - "math/big" - "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" - flag "github.com/spf13/pflag" + "github.com/kserve/kserve/pkg/constants" ) var log = logf.Log.WithName("InferenceGraphRouter") @@ -411,10 +415,158 @@ func 
// Diff context carried over unchanged from the start of this line (tail of the hunk header):
//   compilePatterns(patterns []string) ([]*regexp.Regexp, error) { }

// Command-line flags and shared state of the router.
var (
	// enableAuthFlag turns on the authorization middleware in front of the graph handler.
	enableAuthFlag = flag.Bool("enable-auth", false, "protect the inference graph with authorization")
	// graphName is the name of the InferenceGraph resource this router serves; it is
	// the resource name used when checking whether a caller is authorized.
	graphName = flag.String("inferencegraph-name", "", "the name of the associated inference graph Kubernetes resource")
	// jsonGraph holds the serialized graph definition passed on the command line.
	jsonGraph = flag.String("graph-json", "", "serialized json graph def")
	// compiledHeaderPatterns is populated elsewhere in the file.
	// NOTE(review): presumably the header-propagation patterns produced by compilePatterns — confirm.
	compiledHeaderPatterns []*regexp.Regexp
)

// findBearerToken parses the standard HTTP Authorization header and returns the
// Bearer token the client provided.
//
// The authentication scheme is matched case-insensitively ("Bearer", "bearer", ...)
// and whitespace around the token is ignored. When no usable token is present, an
// empty string is returned and the response has ALREADY been written: status 401,
// a "WWW-Authenticate: Bearer" challenge and an "X-Forbidden-Reason" header naming
// the cause. Callers must therefore not write to w after receiving "".
func findBearerToken(w http.ResponseWriter, r *http.Request) string {
	const schemePrefix = "Bearer "

	// reject writes the 401 response and yields the empty token, so every failure
	// path below leaves the response in a well-defined state.
	reject := func(reason string) string {
		w.Header().Set("WWW-Authenticate", "Bearer")
		w.Header().Set("X-Forbidden-Reason", reason)
		w.WriteHeader(http.StatusUnauthorized)
		return ""
	}

	// Look for the HTTP Authorization header. Reject the request if it is missing or blank.
	authHeader := r.Header.Get("Authorization")
	if strings.TrimSpace(authHeader) == "" {
		return reject("No credentials were provided")
	}

	// Only the Bearer scheme is accepted; scheme names are case-insensitive.
	if len(authHeader) < len(schemePrefix) || !strings.EqualFold(authHeader[:len(schemePrefix)], schemePrefix) {
		return reject("Only Bearer tokens are supported")
	}

	// A header such as "Bearer " carries no credentials at all. Without this check
	// the empty token would be returned with no response written.
	token := strings.TrimSpace(authHeader[len(schemePrefix):])
	if token == "" {
		return reject("No credentials were provided")
	}
	return token
}

// validateTokenIsAuthenticated queries the Kubernetes cluster to find if the provided token is
// valid and flagged as authenticated. If the token is usable, the result of the TokenReview
// is returned. Otherwise, the HTTP response is sent rejecting the request and setting
// a meaningful status code along with a reason (if available).
+func validateTokenIsAuthenticated(w http.ResponseWriter, token string, clientset *kubernetes.Clientset) *authnv1.TokenReview { + // Check the token is valid + tokenReview := authnv1.TokenReview{} + tokenReview.Spec.Token = token + tokenReviewResult, err := clientset.AuthenticationV1().TokenReviews().Create(context.Background(), &tokenReview, metav1.CreateOptions{}) + if err != nil { + log.Error(err, "failed to create TokenReview when verifying credentials") + w.WriteHeader(http.StatusInternalServerError) + return nil + } + if len(tokenReviewResult.Status.Error) != 0 { + w.Header().Set("X-Forbidden-Reason", tokenReviewResult.Status.Error) + w.WriteHeader(http.StatusUnauthorized) + return nil + } + if !tokenReviewResult.Status.Authenticated { + w.Header().Set("X-Forbidden-Reason", "The provided token is unauthenticated") + w.WriteHeader(http.StatusUnauthorized) + return nil + } + return tokenReviewResult +} + +// checkRequestIsAuthorized verifies that the user in the provided tokenReviewResult has privileges to query the +// Kubernetes API and get the InferenceGraph resource that belongs to this pod. If so, the request is considered +// as allowed and `true` is returned. Otherwise, the HTTP response is sent rejecting the request and setting +// a meaningful status code along with a reason (if available). 
+func checkRequestIsAuthorized(w http.ResponseWriter, _ *http.Request, tokenReviewResult *authnv1.TokenReview, clientset *kubernetes.Clientset) bool { + // Read pod namespace + const namespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" + namespaceBytes, err := os.ReadFile(namespaceFile) + if err != nil { + log.Error(err, "failed to read namespace file while verifying credentials") + w.WriteHeader(http.StatusInternalServerError) + return false + } + namespace := string(namespaceBytes) + + // Check the subject is authorized to query the InferenceGraph + if len(*graphName) == 0 { + log.Error(errors.New("no graph name provided"), "the --inferencegraph-name flag wasn't provided") + w.WriteHeader(http.StatusInternalServerError) + return false + } + accessReview := authzv1.SubjectAccessReview{ + Spec: authzv1.SubjectAccessReviewSpec{ + ResourceAttributes: &authzv1.ResourceAttributes{ + Namespace: namespace, + Verb: "get", + Group: "serving.kserve.io", + Resource: "inferencegraphs", + Name: *graphName, + }, + User: tokenReviewResult.Status.User.Username, + Groups: nil, + }, + } + + accessReviewResult, err := clientset.AuthorizationV1().SubjectAccessReviews().Create(context.Background(), &accessReview, metav1.CreateOptions{}) + if err != nil { + log.Error(err, "failed to create LocalSubjectAccessReview when verifying credentials") + w.WriteHeader(http.StatusInternalServerError) + return false + } + if accessReviewResult.Status.Allowed { + // Note: This is here so that the request is NOT allowed by default. 
+ return true + } + + w.Header().Add("X-Forbidden-Reason", "Access to the InferenceGraph is not allowed") + if len(accessReviewResult.Status.Reason) != 0 { + w.Header().Add("X-Forbidden-Reason", accessReviewResult.Status.Reason) + } + if len(accessReviewResult.Status.EvaluationError) != 0 { + w.Header().Add("X-Forbidden-Reason", accessReviewResult.Status.EvaluationError) + } + + w.WriteHeader(http.StatusUnauthorized) + return false +} + +// authMiddleware uses the Middleware pattern to protect the InferenceGraph behind authorization. +// It expects that a Bearer token is provided in the request in the standard HTTP Authorization +// header. The token is verified against Kubernetes using the TokenReview and SubjectAccessReview APIs. +// If the token is valid and has enough privileges, the handler provided in the `next` argument is run. +// Otherwise, `next` is not invoked and the reason for the rejection is sent in response headers. +func authMiddleware(next http.Handler) (http.Handler, error) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + k8sConfig, k8sConfigErr := rest.InClusterConfig() + if k8sConfigErr != nil { + log.Error(k8sConfigErr, "failed to create rest configuration to connect to Kubernetes API") + w.WriteHeader(http.StatusInternalServerError) + return + } + + clientset, clientsetErr := kubernetes.NewForConfig(k8sConfig) + if clientsetErr != nil { + log.Error(k8sConfigErr, "failed to create Kubernetes client to connect to API") + return + } + + token := findBearerToken(w, r) + if len(token) == 0 { + return + } + + tokenReviewResult := validateTokenIsAuthenticated(w, token, clientset) + if tokenReviewResult == nil { + return + } + + isAuthorized := checkRequestIsAuthorized(w, r, tokenReviewResult, clientset) + if isAuthorized { + next.ServeHTTP(w, r) + } + }), nil +} + func main() { flag.Parse() logf.SetLogger(zap.New()) @@ -434,14 +586,23 @@ func main() { os.Exit(1) } - http.HandleFunc("/", graphHandler) + var 
entrypointHandler http.Handler + entrypointHandler = http.HandlerFunc(graphHandler) + if *enableAuthFlag { + entrypointHandler, err = authMiddleware(entrypointHandler) + log.Info("This Router has authorization enabled") + if err != nil { + log.Error(err, "failed to create entrypoint handler") + os.Exit(1) + } + } server := &http.Server{ - Addr: ":8080", // specify the address and port - Handler: http.HandlerFunc(graphHandler), // specify your HTTP handler - ReadTimeout: time.Minute, // set the maximum duration for reading the entire request, including the body - WriteTimeout: time.Minute, // set the maximum duration before timing out writes of the response - IdleTimeout: 3 * time.Minute, // set the maximum amount of time to wait for the next request when keep-alives are enabled + Addr: ":8080", // specify the address and port + Handler: entrypointHandler, // specify your HTTP handler + ReadTimeout: time.Minute, // set the maximum duration for reading the entire request, including the body + WriteTimeout: time.Minute, // set the maximum duration before timing out writes of the response + IdleTimeout: 3 * time.Minute, // set the maximum amount of time to wait for the next request when keep-alives are enabled } err = server.ListenAndServe() diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index db02f051ebc..7f3a3b848c4 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -38,9 +38,17 @@ rules: - "" resources: - secrets + verbs: + - get +- apiGroups: + - "" + resources: - serviceaccounts verbs: + - create + - delete - get + - patch - apiGroups: - admissionregistration.k8s.io resources: @@ -111,6 +119,17 @@ rules: - patch - update - watch +- apiGroups: + - rbac.authorization.k8s.io + resourceNames: + - kserve-inferencegraph-auth-verifiers + resources: + - clusterrolebindings + verbs: + - create + - get + - patch + - update - apiGroups: - route.openshift.io resources: diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 
6dbc5f6ef2a..9880d78d4d3 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -53,6 +53,8 @@ var ( const ( RouterHeadersPropagateEnvVar = "PROPAGATE_HEADERS" InferenceGraphLabel = "serving.kserve.io/inferencegraph" + InferenceGraphAuthCRBName = "kserve-inferencegraph-auth-verifiers" + InferenceGraphFinalizerName = "inferencegraph.finalizers" ) // TrainedModel Constants diff --git a/pkg/controller/v1alpha1/inferencegraph/controller.go b/pkg/controller/v1alpha1/inferencegraph/controller.go index 2a3e80ec145..3622197eed5 100644 --- a/pkg/controller/v1alpha1/inferencegraph/controller.go +++ b/pkg/controller/v1alpha1/inferencegraph/controller.go @@ -16,6 +16,8 @@ limitations under the License. // +kubebuilder:rbac:groups=serving.kserve.io,resources=inferencegraphs;inferencegraphs/finalizers,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=serving.kserve.io,resources=inferencegraphs/status,verbs=get;update;patch +// +kubebuilder:rbac:groups="",resources=serviceaccounts,verbs=create;patch;delete +// +kubebuilder:rbac:groups=rbac.authorization.k8s.io,resources=clusterrolebindings,verbs=create;get;update;patch,resourceNames=kserve-inferencegraph-auth-verifiers // +kubebuilder:rbac:groups=serving.knative.dev,resources=services,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=serving.knative.dev,resources=services/finalizers,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=serving.knative.dev,resources=services/status,verbs=get;update;patch @@ -27,6 +29,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "github.com/go-logr/logr" osv1 "github.com/openshift/api/route/v1" @@ -48,6 +51,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "sigs.k8s.io/yaml" v1alpha1api "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" 
"github.com/kserve/kserve/pkg/apis/serving/v1beta1" @@ -71,8 +75,9 @@ type InferenceGraphReconciler struct { type InferenceGraphState string const ( - InferenceGraphNotReadyState InferenceGraphState = "InferenceGraphNotReady" - InferenceGraphReadyState InferenceGraphState = "InferenceGraphReady" + InferenceGraphControllerName string = "inferencegraph-controller" + InferenceGraphNotReadyState InferenceGraphState = "InferenceGraphNotReady" + InferenceGraphReadyState InferenceGraphState = "InferenceGraphReady" ) type RouterConfig struct { @@ -173,9 +178,50 @@ func (r *InferenceGraphReconciler) Reconcile(ctx context.Context, req ctrl.Reque return reconcile.Result{}, errors.Wrapf(err, "fails to create DeployConfig") } + // examine DeletionTimestamp to determine if object is under deletion + if graph.ObjectMeta.DeletionTimestamp.IsZero() { + // The object is not being deleted, so if it does not have our finalizer, + // then lets add the finalizer. + if !utils.Includes(graph.ObjectMeta.Finalizers, constants.InferenceGraphFinalizerName) { + graph.ObjectMeta.Finalizers = append(graph.ObjectMeta.Finalizers, constants.InferenceGraphFinalizerName) + patchYaml := "metadata:\n finalizers: [" + strings.Join(graph.ObjectMeta.Finalizers, ",") + "]" + patchJson, _ := yaml.YAMLToJSON([]byte(patchYaml)) + if err = r.Patch(ctx, graph, client.RawPatch(types.MergePatchType, patchJson)); err != nil { + return reconcile.Result{}, err + } + } + } else { + // The object is being deleted + if utils.Includes(graph.ObjectMeta.Finalizers, constants.InferenceGraphFinalizerName) { + // our finalizer is present, so lets cleanup resources + if err = r.onDeleteCleanup(ctx, graph); err != nil { + // if fail to delete the external dependency here, return with error + // so that it can be retried + return ctrl.Result{}, err + } + + // remove our finalizer from the list and update it. 
+ graph.ObjectMeta.Finalizers = utils.RemoveString(graph.ObjectMeta.Finalizers, constants.InferenceGraphFinalizerName) + patchYaml := "metadata:\n finalizers: [" + strings.Join(graph.ObjectMeta.Finalizers, ",") + "]" + patchJson, _ := yaml.YAMLToJSON([]byte(patchYaml)) + if err = r.Patch(ctx, graph, client.RawPatch(types.MergePatchType, patchJson)); err != nil { + return reconcile.Result{}, err + } + } + + // Stop reconciliation as the item is being deleted + return ctrl.Result{}, nil + } + deploymentMode := isvcutils.GetDeploymentMode(graph.Status.DeploymentMode, graph.ObjectMeta.Annotations, deployConfig) r.Log.Info("Inference graph deployment ", "deployment mode ", deploymentMode) if deploymentMode == constants.RawDeployment { + // If the inference graph has auth enabled, create the supporting resources + err = handleInferenceGraphRawAuthResources(ctx, r.Clientset, r.Scheme, graph) + if err != nil { + return ctrl.Result{}, errors.Wrapf(err, "fails to reconcile resources for auth verification") + } + // Create inference graph resources such as deployment, service, hpa in raw deployment mode deployment, url, err := handleInferenceGraphRawDeployment(r.Client, r.Clientset, r.Scheme, graph, routerConfig) @@ -295,6 +341,16 @@ func inferenceGraphReadiness(status v1alpha1api.InferenceGraphStatus) bool { status.GetCondition(apis.ConditionReady).Status == v1.ConditionTrue } +func (r *InferenceGraphReconciler) onDeleteCleanup(ctx context.Context, graph *v1alpha1api.InferenceGraph) error { + if err := removeAuthPrivilegesFromGraphServiceAccount(ctx, r.Clientset, graph); err != nil { + return err + } + if err := deleteGraphServiceAccount(ctx, r.Clientset, graph); err != nil { + return err + } + return nil +} + func (r *InferenceGraphReconciler) SetupWithManager(mgr ctrl.Manager, deployConfig *v1beta1api.DeployConfig) error { r.ClientConfig = mgr.GetConfig() diff --git a/pkg/controller/v1alpha1/inferencegraph/controller_test.go 
b/pkg/controller/v1alpha1/inferencegraph/controller_test.go index 7f0cdc3e7ae..18b43d0ed07 100644 --- a/pkg/controller/v1alpha1/inferencegraph/controller_test.go +++ b/pkg/controller/v1alpha1/inferencegraph/controller_test.go @@ -27,13 +27,16 @@ import ( "google.golang.org/protobuf/proto" appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/scheme" "knative.dev/pkg/kmp" knservingv1 "knative.dev/serving/pkg/apis/serving/v1" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/apiutil" "sigs.k8s.io/controller-runtime/pkg/reconcile" "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" @@ -65,6 +68,13 @@ var _ = Describe("Inference Graph controller test", func() { ] } }`, + "oauthProxy": `{ + "image": "registry.redhat.io/openshift4/ose-oauth-proxy@sha256:8507daed246d4d367704f7d7193233724acf1072572e1226ca063c066b858ecf", + "memoryRequest": "64Mi", + "memoryLimit": "128Mi", + "cpuRequest": "100m", + "cpuLimit": "200m" + }`, } ) @@ -910,4 +920,130 @@ var _ = Describe("Inference Graph controller test", func() { }, timeout, interval).Should(BeTrue()) }) }) + + Context("When creating an IG in Raw deployment mode with auth", func() { + var configMap *v1.ConfigMap + var inferenceGraph *v1alpha1.InferenceGraph + + BeforeEach(func() { + configMap = &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceConfigMapName, + Namespace: constants.KServeNamespace, + }, + Data: configs, + } + Expect(k8sClient.Create(ctx, configMap)).NotTo(HaveOccurred()) + + graphName := "igrawauth1" + ctx := context.Background() + inferenceGraph = &v1alpha1.InferenceGraph{ + ObjectMeta: metav1.ObjectMeta{ + Name: graphName, + Namespace: "default", + Annotations: map[string]string{ + "serving.kserve.io/deploymentMode": 
string(constants.RawDeployment), + constants.ODHKserveRawAuth: "true", + }, + }, + Spec: v1alpha1.InferenceGraphSpec{ + Nodes: map[string]v1alpha1.InferenceRouter{ + v1alpha1.GraphRootNodeName: { + RouterType: v1alpha1.Sequence, + Steps: []v1alpha1.InferenceStep{ + { + InferenceTarget: v1alpha1.InferenceTarget{ + ServiceURL: "http://someservice.exmaple.com", + }, + }, + }, + }, + }, + }, + } + Expect(k8sClient.Create(ctx, inferenceGraph)).Should(Succeed()) + }) + AfterEach(func() { + _ = k8sClient.Delete(ctx, inferenceGraph) + igKey := types.NamespacedName{Namespace: inferenceGraph.GetNamespace(), Name: inferenceGraph.GetName()} + Eventually(func() error { return k8sClient.Get(ctx, igKey, inferenceGraph) }, timeout, interval).ShouldNot(Succeed()) + + _ = k8sClient.Delete(ctx, configMap) + cmKey := types.NamespacedName{Namespace: configMap.GetNamespace(), Name: configMap.GetName()} + Eventually(func() error { return k8sClient.Get(ctx, cmKey, configMap) }, timeout, interval).ShouldNot(Succeed()) + }) + + It("Should create or update a ClusterRoleBinding giving privileges to validate auth", func() { + Eventually(func(g Gomega) { + crbKey := types.NamespacedName{Name: constants.InferenceGraphAuthCRBName} + clusterRoleBinding := rbacv1.ClusterRoleBinding{} + g.Expect(k8sClient.Get(ctx, crbKey, &clusterRoleBinding)).To(Succeed()) + + crGVK, err := apiutil.GVKForObject(&rbacv1.ClusterRole{}, scheme.Scheme) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(clusterRoleBinding.RoleRef).To(Equal(rbacv1.RoleRef{ + APIGroup: crGVK.Group, + Kind: crGVK.Kind, + Name: "system:auth-delegator", + })) + g.Expect(clusterRoleBinding.Subjects).To(ContainElement(rbacv1.Subject{ + Kind: "ServiceAccount", + APIGroup: "", + Name: getServiceAccountNameForGraph(inferenceGraph), + Namespace: inferenceGraph.GetNamespace(), + })) + }, timeout, interval).Should(Succeed()) + }) + + It("Should create a ServiceAccount for querying the Kubernetes API to check tokens", func() { + Eventually(func(g 
Gomega) { + saKey := types.NamespacedName{Namespace: inferenceGraph.GetNamespace(), Name: getServiceAccountNameForGraph(inferenceGraph)} + serviceAccount := v1.ServiceAccount{} + g.Expect(k8sClient.Get(ctx, saKey, &serviceAccount)).To(Succeed()) + g.Expect(serviceAccount.OwnerReferences).ToNot(BeEmpty()) + }, timeout, interval).Should(Succeed()) + }) + + It("Should configure the InferenceGraph deployment with auth enabled", func() { + Eventually(func(g Gomega) { + igDeployment := appsv1.Deployment{} + g.Expect(k8sClient.Get(ctx, types.NamespacedName{Namespace: inferenceGraph.GetNamespace(), Name: inferenceGraph.GetName()}, &igDeployment)).To(Succeed()) + g.Expect(igDeployment.Spec.Template.Spec.AutomountServiceAccountToken).To(Equal(proto.Bool(true))) + g.Expect(igDeployment.Spec.Template.Spec.ServiceAccountName).To(Equal(getServiceAccountNameForGraph(inferenceGraph))) + // g.Expect(igDeployment.Spec.Template.Spec.Containers).To(HaveLen(1)) // TODO: Restore in RHOAIENG-21300 + g.Expect(igDeployment.Spec.Template.Spec.Containers[0].Args).To(ContainElements("--enable-auth", "--inferencegraph-name", inferenceGraph.GetName())) + }, timeout, interval).Should(Succeed()) + }) + + It("Should delete the ServiceAccount when the InferenceGraph is deleted", func() { + serviceAccount := v1.ServiceAccount{} + saKey := types.NamespacedName{Namespace: inferenceGraph.GetNamespace(), Name: getServiceAccountNameForGraph(inferenceGraph)} + + Eventually(func() error { + return k8sClient.Get(ctx, saKey, &serviceAccount) + }, timeout, interval).Should(Succeed()) + + Expect(k8sClient.Delete(ctx, inferenceGraph)).To(Succeed()) + Eventually(func() error { + return k8sClient.Get(ctx, saKey, &serviceAccount) + }, timeout, interval).Should(WithTransform(errors.IsNotFound, BeTrue())) + }) + + It("Should remove the ServiceAccount as subject of the ClusterRoleBinding when the InferenceGraph is deleted", func() { + crbKey := types.NamespacedName{Name: constants.InferenceGraphAuthCRBName} + + 
Eventually(func() []rbacv1.Subject { + clusterRoleBinding := rbacv1.ClusterRoleBinding{} + _ = k8sClient.Get(ctx, crbKey, &clusterRoleBinding) + return clusterRoleBinding.Subjects + }, timeout, interval).Should(ContainElement(HaveField("Name", getServiceAccountNameForGraph(inferenceGraph)))) + + Expect(k8sClient.Delete(ctx, inferenceGraph)).To(Succeed()) + Eventually(func() []rbacv1.Subject { + clusterRoleBinding := rbacv1.ClusterRoleBinding{} + _ = k8sClient.Get(ctx, crbKey, &clusterRoleBinding) + return clusterRoleBinding.Subjects + }, timeout, interval).ShouldNot(ContainElement(HaveField("Name", getServiceAccountNameForGraph(inferenceGraph)))) + }) + }) }) diff --git a/pkg/controller/v1alpha1/inferencegraph/raw_ig.go b/pkg/controller/v1alpha1/inferencegraph/raw_ig.go index f29248f76f5..80e0f2f2be5 100644 --- a/pkg/controller/v1alpha1/inferencegraph/raw_ig.go +++ b/pkg/controller/v1alpha1/inferencegraph/raw_ig.go @@ -17,19 +17,27 @@ limitations under the License. package inferencegraph import ( + "context" "encoding/json" + "fmt" "strings" "github.com/pkg/errors" "google.golang.org/protobuf/proto" appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + corev1cfg "k8s.io/client-go/applyconfigurations/core/v1" + metav1cfg "k8s.io/client-go/applyconfigurations/meta/v1" + rbacv1cfg "k8s.io/client-go/applyconfigurations/rbac/v1" "k8s.io/client-go/kubernetes" "knative.dev/pkg/apis" knapis "knative.dev/pkg/apis" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/apiutil" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" logf "sigs.k8s.io/controller-runtime/pkg/log" @@ -76,6 +84,7 @@ func createInferenceGraphPodSpec(graph *v1alpha1api.InferenceGraph, config *Rout }, }, Affinity: graph.Spec.Affinity, + ServiceAccountName: "default", AutomountServiceAccountToken: 
proto.Bool(false), // Inference graph does not need access to api server } @@ -90,6 +99,28 @@ func createInferenceGraphPodSpec(graph *v1alpha1api.InferenceGraph, config *Rout } } + // If auth is enabled for the InferenceGraph: + // * Add --enable-auth argument, to properly secure kserve-router + // * Add the --inferencegraph-name argument, so that the router is aware of its name + // * Enable auto-mount of the ServiceAccount, because it is required for validating tokens + // * Set a non-default ServiceAccount with enough privileges to verify auth + if graph.GetAnnotations()[constants.ODHKserveRawAuth] == "true" { + podSpec.Containers[0].Args = append(podSpec.Containers[0].Args, "--enable-auth") + + podSpec.Containers[0].Args = append(podSpec.Containers[0].Args, "--inferencegraph-name") + podSpec.Containers[0].Args = append(podSpec.Containers[0].Args, graph.GetName()) + + podSpec.AutomountServiceAccountToken = proto.Bool(true) + + // In ODH, when auth is enabled, it is required to have the InferenceGraph running + // with a ServiceAccount that can query the Kubernetes API to validate tokens + // and privileges. + // In KServe v0.14 there is no way for users to set the ServiceAccount for an + // InferenceGraph. In ODH this is used at our advantage to set a non-default SA + // and bind needed privileges for the auth verification. 
+ podSpec.ServiceAccountName = fmt.Sprintf("%s-auth-verifier", graph.GetName()) + } + return podSpec } @@ -180,6 +211,144 @@ func handleInferenceGraphRawDeployment(cl client.Client, clientset kubernetes.In return deployment[0], reconciler.URL, nil } +func handleInferenceGraphRawAuthResources(ctx context.Context, clientset kubernetes.Interface, scheme *runtime.Scheme, graph *v1alpha1api.InferenceGraph) error { + saName := getServiceAccountNameForGraph(graph) + + if graph.GetAnnotations()[constants.ODHKserveRawAuth] == "true" { + graphGVK, err := apiutil.GVKForObject(graph, scheme) + if err != nil { + return errors.Wrapf(err, "fails get GVK for inference graph") + } + ownerReference := metav1cfg.OwnerReference(). + WithKind(graphGVK.Kind). + WithAPIVersion(graphGVK.GroupVersion().String()). + WithName(graph.GetName()). + WithUID(graph.UID). + WithBlockOwnerDeletion(true). + WithController(true) + + // Create a Service Account that can be used to check auth + saAuthVerifier := corev1cfg.ServiceAccount(saName, graph.GetNamespace()). 
+ WithOwnerReferences(ownerReference) + _, err = clientset.CoreV1().ServiceAccounts(graph.GetNamespace()).Apply(ctx, saAuthVerifier, metav1.ApplyOptions{FieldManager: InferenceGraphControllerName}) + if err != nil { + return errors.Wrapf(err, "fails to apply auth-verifier service account for inference graph") + } + + // Bind the required privileges to the Service Account + err = addAuthPrivilegesToGraphServiceAccount(ctx, clientset, graph) + if err != nil { + return err + } + } else { + err := removeAuthPrivilegesFromGraphServiceAccount(ctx, clientset, graph) + if err != nil { + return err + } + + err = deleteGraphServiceAccount(ctx, clientset, graph) + if err != nil { + return err + } + } + + return nil +} + +func addAuthPrivilegesToGraphServiceAccount(ctx context.Context, clientset kubernetes.Interface, graph *v1alpha1api.InferenceGraph) error { + clusterRoleBinding, err := clientset.RbacV1().ClusterRoleBindings().Get(ctx, constants.InferenceGraphAuthCRBName, metav1.GetOptions{}) + if client.IgnoreNotFound(err) != nil { + return errors.Wrapf(err, "fails to get cluster role binding kserve-inferencegraph-auth-verifiers while configuring inference graph auth") + } + + saName := getServiceAccountNameForGraph(graph) + if apierrors.IsNotFound(err) { + clusterRoleAuxiliary := rbacv1.ClusterRole{} + rbRoleRef := rbacv1cfg.RoleRef(). + WithKind("ClusterRole"). + WithName("system:auth-delegator"). + WithAPIGroup(clusterRoleAuxiliary.GroupVersionKind().Group) + rbSubject := rbacv1cfg.Subject(). + WithKind("ServiceAccount"). + WithNamespace(graph.GetNamespace()). + WithName(saName) + crbApply := rbacv1cfg.ClusterRoleBinding(constants.InferenceGraphAuthCRBName). + WithRoleRef(rbRoleRef). 
+ WithSubjects(rbSubject) + + _, err = clientset.RbacV1().ClusterRoleBindings().Apply(ctx, crbApply, metav1.ApplyOptions{FieldManager: InferenceGraphControllerName}) + if err != nil { + return errors.Wrapf(err, "fails to apply kserve-inferencegraph-auth-verifiers ClusterRoleBinding for inference graph") + } + } else { + isPresent := false + for _, subject := range clusterRoleBinding.Subjects { + if subject.Kind == "ServiceAccount" && subject.Name == saName && subject.Namespace == graph.GetNamespace() { + isPresent = true + break + } + } + if !isPresent { + clusterRoleBinding.Subjects = append(clusterRoleBinding.Subjects, rbacv1.Subject{ + Kind: "ServiceAccount", + Name: saName, + Namespace: graph.GetNamespace(), + }) + _, err = clientset.RbacV1().ClusterRoleBindings().Update(ctx, clusterRoleBinding, metav1.UpdateOptions{FieldManager: InferenceGraphControllerName}) + if err != nil { + return errors.Wrapf(err, "fails to bind privileges for auth verification to inference graph") + } + } + } + + return nil +} + +func removeAuthPrivilegesFromGraphServiceAccount(ctx context.Context, clientset kubernetes.Interface, graph *v1alpha1api.InferenceGraph) error { + clusterRole, err := clientset.RbacV1().ClusterRoleBindings().Get(ctx, constants.InferenceGraphAuthCRBName, metav1.GetOptions{}) + if err != nil { + if apierrors.IsNotFound(err) { + return nil + } + return errors.Wrapf(err, "fails to get cluster role binding kserve-inferencegraph-auth-verifiers while deconfiguring inference graph auth") + } + + isPresent := false + saName := getServiceAccountNameForGraph(graph) + for idx, subject := range clusterRole.Subjects { + if subject.Kind == "ServiceAccount" && subject.Name == saName && subject.Namespace == graph.GetNamespace() { + isPresent = true + + // Remove the no longer needed entry + clusterRole.Subjects[idx] = clusterRole.Subjects[len(clusterRole.Subjects)-1] + clusterRole.Subjects = clusterRole.Subjects[:len(clusterRole.Subjects)-1] + break + } + } + + if isPresent { + 
_, err = clientset.RbacV1().ClusterRoleBindings().Update(ctx, clusterRole, metav1.UpdateOptions{FieldManager: InferenceGraphControllerName}) + if err != nil { + return errors.Wrapf(err, "fails to remove privileges for auth verification from inference graph") + } + } + + return nil +} + +func deleteGraphServiceAccount(ctx context.Context, clientset kubernetes.Interface, graph *v1alpha1api.InferenceGraph) error { + saName := getServiceAccountNameForGraph(graph) + err := clientset.CoreV1().ServiceAccounts(graph.GetNamespace()).Delete(ctx, saName, metav1.DeleteOptions{}) + if client.IgnoreNotFound(err) != nil { + return errors.Wrapf(err, "fails to delete service account for inference graph while deconfiguring auth") + } + return nil +} + +func getServiceAccountNameForGraph(graph *v1alpha1api.InferenceGraph) string { + return fmt.Sprintf("%s-auth-verifier", graph.GetName()) +} + /* PropagateRawStatus Propagates deployment status onto Inference graph status. In raw deployment mode, deployment available denotes the ready status for IG diff --git a/pkg/controller/v1alpha1/inferencegraph/raw_ig_test.go b/pkg/controller/v1alpha1/inferencegraph/raw_ig_test.go index 0a26c740b86..dd5a29df85c 100644 --- a/pkg/controller/v1alpha1/inferencegraph/raw_ig_test.go +++ b/pkg/controller/v1alpha1/inferencegraph/raw_ig_test.go @@ -20,9 +20,6 @@ import ( "testing" "github.com/google/go-cmp/cmp" - . "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" - "github.com/kserve/kserve/pkg/apis/serving/v1beta1" - "github.com/kserve/kserve/pkg/constants" "google.golang.org/protobuf/proto" appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" @@ -30,6 +27,10 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "knative.dev/pkg/apis" duckv1 "knative.dev/pkg/apis/duck/v1" + + . 
"github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/kserve/kserve/pkg/apis/serving/v1beta1" + "github.com/kserve/kserve/pkg/constants" ) func TestCreateInferenceGraphPodSpec(t *testing.T) { @@ -174,6 +175,7 @@ func TestCreateInferenceGraphPodSpec(t *testing.T) { }, }, AutomountServiceAccountToken: proto.Bool(false), + ServiceAccountName: "default", }, "basicgraphwithheaders": { Containers: []v1.Container{ @@ -212,6 +214,7 @@ func TestCreateInferenceGraphPodSpec(t *testing.T) { }, }, AutomountServiceAccountToken: proto.Bool(false), + ServiceAccountName: "default", }, "withresource": { Containers: []v1.Container{ @@ -244,6 +247,7 @@ func TestCreateInferenceGraphPodSpec(t *testing.T) { }, }, AutomountServiceAccountToken: proto.Bool(false), + ServiceAccountName: "default", }, } diff --git a/pkg/controller/v1beta1/inferenceservice/reconcilers/deployment/deployment_reconciler.go b/pkg/controller/v1beta1/inferenceservice/reconcilers/deployment/deployment_reconciler.go index 7be9bf9ad6b..2835e9a997c 100644 --- a/pkg/controller/v1beta1/inferenceservice/reconcilers/deployment/deployment_reconciler.go +++ b/pkg/controller/v1beta1/inferenceservice/reconcilers/deployment/deployment_reconciler.go @@ -30,10 +30,6 @@ import ( "k8s.io/client-go/kubernetes" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/kserve/kserve/pkg/apis/serving/v1beta1" - "github.com/kserve/kserve/pkg/constants" - v1beta1utils "github.com/kserve/kserve/pkg/controller/v1beta1/inferenceservice/utils" - "github.com/kserve/kserve/pkg/utils" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" apierr "k8s.io/apimachinery/pkg/api/errors" @@ -44,6 +40,11 @@ import ( "knative.dev/pkg/kmp" kclient "sigs.k8s.io/controller-runtime/pkg/client" logf "sigs.k8s.io/controller-runtime/pkg/log" + + "github.com/kserve/kserve/pkg/apis/serving/v1beta1" + "github.com/kserve/kserve/pkg/constants" + v1beta1utils "github.com/kserve/kserve/pkg/controller/v1beta1/inferenceservice/utils" + 
"github.com/kserve/kserve/pkg/utils" ) var log = logf.Log.WithName("DeploymentReconciler") @@ -192,7 +193,7 @@ func addOauthContainerToDeployment(clientset kubernetes.Interface, deployment *a var isvcname string var upstreamPort string var sa string - if val, ok := componentMeta.Labels[constants.InferenceServiceLabel]; ok { + if val, ok := componentMeta.Labels[constants.InferenceServicePodLabelKey]; ok { isvcname = val } else { isvcname = componentMeta.Name