Implement authorization for Raw InferenceGraphs #499

Merged
36 changes: 36 additions & 0 deletions charts/kserve-resources/templates/clusterrole.yaml
@@ -51,9 +51,17 @@ rules:
- ""
resources:
- secrets
verbs:
- get
- apiGroups:
- ""
resources:
- serviceaccounts
verbs:
- create
- delete
- get
- patch
- apiGroups:
- admissionregistration.k8s.io
resources:
@@ -124,6 +132,34 @@ rules:
- patch
- update
- watch
- apiGroups:
- rbac.authorization.k8s.io
resourceNames:
- kserve-inferencegraph-auth-verifiers
resources:
- clusterrolebindings
verbs:
- create
- get
- patch
- update
- apiGroups:
- route.openshift.io
resources:
- routes
verbs:
- create
- get
- list
- patch
- update
- watch
- apiGroups:
- route.openshift.io
resources:
- routes/status
verbs:
- get
- apiGroups:
- serving.knative.dev
resources:
185 changes: 173 additions & 12 deletions cmd/router/main.go
@@ -18,10 +18,13 @@ package main

import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
goerrors "errors"
"fmt"
"io"
"math/big"
"net/http"
"net/url"
"os"
@@ -31,18 +34,19 @@ import (
"syscall"
"time"

"github.com/kserve/kserve/pkg/constants"
"github.com/pkg/errors"

flag "github.com/spf13/pflag"
"github.com/tidwall/gjson"
authnv1 "k8s.io/api/authentication/v1"
authzv1 "k8s.io/api/authorization/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"

"crypto/rand"
"math/big"

"github.com/kserve/kserve/pkg/apis/serving/v1alpha1"
flag "github.com/spf13/pflag"
"github.com/kserve/kserve/pkg/constants"
)

var log = logf.Log.WithName("InferenceGraphRouter")
@@ -411,10 +415,158 @@ func compilePatterns(patterns []string) ([]*regexp.Regexp, error) {
}

var (
enableAuthFlag = flag.Bool("enable-auth", false, "protect the inference graph with authorization")
graphName = flag.String("inferencegraph-name", "", "the name of the associated inference graph Kubernetes resource")
jsonGraph = flag.String("graph-json", "", "serialized json graph def")
compiledHeaderPatterns []*regexp.Regexp
)

// findBearerToken parses the standard HTTP Authorization header to find and return
// a Bearer token that a client may have provided in the request. If the token
// is found, it is returned. Else, an empty string is returned and the HTTP response
// is sent to the client with proper status code and the reason for the request being
// rejected.
func findBearerToken(w http.ResponseWriter, r *http.Request) string {
// Look for the HTTP Authorization header. Reject the request if it is missing.
authHeader := r.Header.Get("Authorization")
if len(authHeader) == 0 {
w.Header().Set("X-Forbidden-Reason", "No credentials were provided")
w.WriteHeader(http.StatusUnauthorized)
return ""
}

// Parse Auth header
token := strings.TrimPrefix(authHeader, "Bearer ")
if token == authHeader {
w.Header().Set("X-Forbidden-Reason", "Only Bearer tokens are supported")
w.WriteHeader(http.StatusUnauthorized)
return ""
}
return token
}
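
As a side note on the header parsing above, here is a minimal standalone sketch of the same Bearer extraction written as a pure function, which makes the edge cases easy to exercise. `bearerToken` is a hypothetical helper and is not part of this PR. Unlike `findBearerToken`, it also treats an empty token after the `Bearer ` prefix as missing credentials, which is an assumption about the desired behavior.

```go
package main

import (
	"fmt"
	"strings"
)

// bearerToken is a hypothetical helper, not part of this PR. It extracts the
// token from an Authorization header value and reports whether a non-empty
// Bearer token was present.
func bearerToken(header string) (string, bool) {
	const prefix = "Bearer "
	if !strings.HasPrefix(header, prefix) {
		// Missing header or a different scheme (for example Basic).
		return "", false
	}
	token := strings.TrimSpace(strings.TrimPrefix(header, prefix))
	if token == "" {
		// The scheme was given but no token followed it.
		return "", false
	}
	return token, true
}

func main() {
	headers := []string{"", "Basic dXNlcjpwYXNz", "Bearer ", "Bearer abc123"}
	for _, h := range headers {
		token, ok := bearerToken(h)
		fmt.Printf("header=%q token=%q ok=%v\n", h, token, ok)
	}
}
```

Returning an explicit `ok` flag keeps "no token" separate from "empty token", so the caller never has to infer a rejection from an empty string.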

// validateTokenIsAuthenticated queries the Kubernetes cluster to find if the provided token is
// valid and flagged as authenticated. If the token is usable, the result of the TokenReview
// is returned. Otherwise, the HTTP response is sent rejecting the request and setting
// a meaningful status code along with a reason (if available).
func validateTokenIsAuthenticated(w http.ResponseWriter, token string, clientset *kubernetes.Clientset) *authnv1.TokenReview {
// Check the token is valid
tokenReview := authnv1.TokenReview{}
tokenReview.Spec.Token = token
tokenReviewResult, err := clientset.AuthenticationV1().TokenReviews().Create(context.Background(), &tokenReview, metav1.CreateOptions{})
if err != nil {
log.Error(err, "failed to create TokenReview when verifying credentials")
w.WriteHeader(http.StatusInternalServerError)
return nil
}
if len(tokenReviewResult.Status.Error) != 0 {
w.Header().Set("X-Forbidden-Reason", tokenReviewResult.Status.Error)
w.WriteHeader(http.StatusUnauthorized)
return nil
}
if !tokenReviewResult.Status.Authenticated {
w.Header().Set("X-Forbidden-Reason", "The provided token is unauthenticated")
w.WriteHeader(http.StatusUnauthorized)
return nil
}
return tokenReviewResult
}

// checkRequestIsAuthorized verifies that the user in the provided tokenReviewResult has privileges to query the
// Kubernetes API and get the InferenceGraph resource that belongs to this pod. If so, the request is considered
// as allowed and `true` is returned. Otherwise, the HTTP response is sent rejecting the request and setting
// a meaningful status code along with a reason (if available).

Question:
What happens if the token has get privileges for the IG but not for one or more of the ISVCs in the IG? Should we verify that the token has the correct privileges for the IG and all of its ISVCs?

@israel-hdez israel-hdez Feb 20, 2025

The token needs privileges for both the IG and the ISVCs. The IG performs only its own check.

Later, the token is forwarded to each ISVC so that it can do its own check. This effectively delegates that check to the ISVC.

This may not be optimal: if the auth-protected ISVC is the last one in the IG and the previous ones are not protected, the request fails only after wasting resources on the earlier steps. But I think the current implementation is good enough, given that we are not sure how users will use InferenceGraph. So I'd say we should optimize once we are sure.
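
To illustrate the trade-off described in this reply, here is a minimal sketch of the delegation model, assuming a sequential graph in which the router forwards the caller's Authorization header to every step. The names `step`, `runSequence`, and `demo` are hypothetical, and in-memory handlers stand in for real InferenceServices. It shows that when only the last step is protected, a bad token is rejected only after the earlier steps have already run.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// step is a hypothetical stand-in for an InferenceService. If requiredToken
// is non-empty, the step rejects requests that do not carry that Bearer token.
func step(requiredToken string, calls *int) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		*calls++
		if requiredToken != "" && r.Header.Get("Authorization") != "Bearer "+requiredToken {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		w.WriteHeader(http.StatusOK)
	})
}

// runSequence invokes the steps in order, forwarding the same Authorization
// header to each one, and stops at the first non-200 response.
func runSequence(steps []http.Handler, authHeader string) int {
	status := http.StatusOK
	for _, s := range steps {
		req := httptest.NewRequest(http.MethodPost, "/", nil)
		req.Header.Set("Authorization", authHeader)
		rec := httptest.NewRecorder()
		s.ServeHTTP(rec, req)
		status = rec.Code
		if status != http.StatusOK {
			break
		}
	}
	return status
}

// demo runs two unprotected steps followed by one protected step and returns
// the final status together with the number of steps that were invoked.
func demo(token string) (status, calls int) {
	steps := []http.Handler{
		step("", &calls),
		step("", &calls),
		step("good-token", &calls),
	}
	status = runSequence(steps, "Bearer "+token)
	return status, calls
}

func main() {
	for _, token := range []string{"good-token", "wrong-token"} {
		status, calls := demo(token)
		fmt.Printf("token=%s status=%d steps invoked=%d\n", token, status, calls)
	}
}
```

With the wrong token, all three steps are still invoked before the rejection. That is the wasted work mentioned above, and checking every ISVC up front would avoid it at the cost of extra access reviews per request.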

func checkRequestIsAuthorized(w http.ResponseWriter, _ *http.Request, tokenReviewResult *authnv1.TokenReview, clientset *kubernetes.Clientset) bool {
// Read pod namespace
const namespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
namespaceBytes, err := os.ReadFile(namespaceFile)
if err != nil {
log.Error(err, "failed to read namespace file while verifying credentials")
w.WriteHeader(http.StatusInternalServerError)
return false
}
namespace := string(namespaceBytes)

// Check the subject is authorized to query the InferenceGraph
if len(*graphName) == 0 {
log.Error(errors.New("no graph name provided"), "the --inferencegraph-name flag wasn't provided")
w.WriteHeader(http.StatusInternalServerError)
return false
}
accessReview := authzv1.SubjectAccessReview{
Spec: authzv1.SubjectAccessReviewSpec{
ResourceAttributes: &authzv1.ResourceAttributes{
Namespace: namespace,
Verb: "get",
Group: "serving.kserve.io",
Resource: "inferencegraphs",
Name: *graphName,
},
User: tokenReviewResult.Status.User.Username,
Groups: nil,
},
}

accessReviewResult, err := clientset.AuthorizationV1().SubjectAccessReviews().Create(context.Background(), &accessReview, metav1.CreateOptions{})
if err != nil {
log.Error(err, "failed to create SubjectAccessReview when verifying credentials")
w.WriteHeader(http.StatusInternalServerError)
return false
}
if accessReviewResult.Status.Allowed {
// Note: only an explicit allow returns true, so the request is denied by default.
return true
}

w.Header().Add("X-Forbidden-Reason", "Access to the InferenceGraph is not allowed")
if len(accessReviewResult.Status.Reason) != 0 {
w.Header().Add("X-Forbidden-Reason", accessReviewResult.Status.Reason)
}
if len(accessReviewResult.Status.EvaluationError) != 0 {
w.Header().Add("X-Forbidden-Reason", accessReviewResult.Status.EvaluationError)
}

w.WriteHeader(http.StatusUnauthorized)
return false
}

// authMiddleware uses the Middleware pattern to protect the InferenceGraph behind authorization.
// It expects that a Bearer token is provided in the request in the standard HTTP Authorization
// header. The token is verified against Kubernetes using the TokenReview and SubjectAccessReview APIs.
// If the token is valid and has enough privileges, the handler provided in the `next` argument is run.
// Otherwise, `next` is not invoked and the reason for the rejection is sent in response headers.
func authMiddleware(next http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
k8sConfig, k8sConfigErr := rest.InClusterConfig()
if k8sConfigErr != nil {
log.Error(k8sConfigErr, "failed to create rest configuration to connect to Kubernetes API")
w.WriteHeader(http.StatusInternalServerError)
return
}

clientset, clientsetErr := kubernetes.NewForConfig(k8sConfig)
if clientsetErr != nil {
log.Error(clientsetErr, "failed to create Kubernetes client to connect to API")
w.WriteHeader(http.StatusInternalServerError)
return
}

token := findBearerToken(w, r)
if len(token) == 0 {
return
}

tokenReviewResult := validateTokenIsAuthenticated(w, token, clientset)
if tokenReviewResult == nil {
return
}

isAuthorized := checkRequestIsAuthorized(w, r, tokenReviewResult, clientset)
if isAuthorized {
next.ServeHTTP(w, r)
}
}), nil
}
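
In the middleware above, the rest configuration and the clientset are built inside the handler, so they are recreated for every request. A minimal sketch of an alternative shape follows, assuming the credential check is injected as a function that is constructed once. The `authorizer` type and the `withAuth` and `serve` names are hypothetical and not part of this PR, and the stub stands in for the TokenReview and SubjectAccessReview calls.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// authorizer is a hypothetical stand-in for the TokenReview and
// SubjectAccessReview calls. A real one could close over a clientset that is
// built once at startup.
type authorizer func(token string) (allowed bool, reason string)

// withAuth wraps next so that it only runs when authorize allows the token.
func withAuth(next http.Handler, authorize authorizer) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
		if allowed, reason := authorize(token); !allowed {
			w.Header().Set("X-Forbidden-Reason", reason)
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// serve sends one request through the wrapped handler and returns the status
// code and the X-Forbidden-Reason header, if any.
func serve(token string) (int, string) {
	graph := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "graph response")
	})
	// Stub authorizer: accepts a single hard-coded token (assumption for the demo).
	stub := func(t string) (bool, string) {
		if t == "secret" {
			return true, ""
		}
		return false, "token rejected by stub"
	}
	handler := withAuth(graph, stub)

	req := httptest.NewRequest(http.MethodPost, "/", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)
	return rec.Code, rec.Header().Get("X-Forbidden-Reason")
}

func main() {
	for _, token := range []string{"secret", "wrong"} {
		status, reason := serve(token)
		fmt.Printf("token=%s status=%d reason=%q\n", token, status, reason)
	}
}
```

Injecting the check also makes the middleware testable without a cluster, since a stub can replace the Kubernetes calls.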

func main() {
flag.Parse()
logf.SetLogger(zap.New())
@@ -434,14 +586,23 @@ func main() {
os.Exit(1)
}

http.HandleFunc("/", graphHandler)
var entrypointHandler http.Handler
entrypointHandler = http.HandlerFunc(graphHandler)
if *enableAuthFlag {
entrypointHandler, err = authMiddleware(entrypointHandler)
if err != nil {
log.Error(err, "failed to create entrypoint handler")
os.Exit(1)
}
log.Info("This Router has authorization enabled")
}

server := &http.Server{
Addr: ":8080", // specify the address and port
Handler: http.HandlerFunc(graphHandler), // specify your HTTP handler
ReadTimeout: time.Minute, // set the maximum duration for reading the entire request, including the body
WriteTimeout: time.Minute, // set the maximum duration before timing out writes of the response
IdleTimeout: 3 * time.Minute, // set the maximum amount of time to wait for the next request when keep-alives are enabled
Addr: ":8080", // specify the address and port
Handler: entrypointHandler, // specify your HTTP handler
ReadTimeout: time.Minute, // set the maximum duration for reading the entire request, including the body
WriteTimeout: time.Minute, // set the maximum duration before timing out writes of the response
IdleTimeout: 3 * time.Minute, // set the maximum amount of time to wait for the next request when keep-alives are enabled
}
err = server.ListenAndServe()

19 changes: 19 additions & 0 deletions config/rbac/role.yaml
@@ -38,9 +38,17 @@ rules:
- ""
resources:
- secrets
verbs:
- get
- apiGroups:
- ""
resources:
- serviceaccounts
verbs:
- create
- delete
- get
- patch
- apiGroups:
- admissionregistration.k8s.io
resources:
@@ -111,6 +119,17 @@ rules:
- patch
- update
- watch
- apiGroups:
- rbac.authorization.k8s.io
resourceNames:
- kserve-inferencegraph-auth-verifiers
resources:
- clusterrolebindings
verbs:
- create
- get
- patch
- update
- apiGroups:
- route.openshift.io
resources:
2 changes: 2 additions & 0 deletions pkg/constants/constants.go
@@ -53,6 +53,8 @@ var (
const (
RouterHeadersPropagateEnvVar = "PROPAGATE_HEADERS"
InferenceGraphLabel = "serving.kserve.io/inferencegraph"
InferenceGraphAuthCRBName = "kserve-inferencegraph-auth-verifiers"
InferenceGraphFinalizerName = "inferencegraph.finalizers"
)

// TrainedModel Constants