Skip to content

Commit

Permalink
Use host's root CA set by default and support normal SNI scenarios
Browse files Browse the repository at this point in the history
  • Loading branch information
edigaryev committed Jun 6, 2023
1 parent bcbd468 commit b274d8c
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 78 deletions.
133 changes: 99 additions & 34 deletions internal/command/context/create.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package context

import (
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"github.com/cirruslabs/orchard/internal/bootstraptoken"
"github.com/cirruslabs/orchard/internal/config"
"github.com/cirruslabs/orchard/internal/netconstants"
"github.com/cirruslabs/orchard/pkg/client"
clientpkg "github.com/cirruslabs/orchard/pkg/client"
"github.com/manifoldco/promptui"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"net/url"
"strconv"
Expand All @@ -24,6 +25,7 @@ var bootstrapTokenRaw string
var serviceAccountName string
var serviceAccountToken string
var force bool
var noPKI bool

func newCreateCommand() *cobra.Command {
command := &cobra.Command{
Expand All @@ -43,6 +45,10 @@ func newCreateCommand() *cobra.Command {
"service account token to use (alternative to --bootstrap-token)")
command.PersistentFlags().BoolVar(&force, "force", false,
"create the context even if a context with the same name already exists")
command.PersistentFlags().BoolVar(&noPKI, "no-pki", false,
"do not use the host's root CA set and instead validate the Controller's presented "+
"certificate using a bootstrap token (or manually via fingerprint, "+
"if no bootstrap token is provided)")

return command
}
Expand All @@ -53,64 +59,119 @@ func runCreate(cmd *cobra.Command, args []string) error {
return err
}

// Establish trust
var trustedControllerCertificate *x509.Certificate
// If the bootstrap token is present, extract
// service account credentials from it
// and remember it for further use
var bootstrapToken *bootstraptoken.BootstrapToken

if bootstrapTokenRaw != "" {
bootstrapToken, err := bootstraptoken.NewFromString(bootstrapTokenRaw)
bootstrapToken, err = bootstraptoken.NewFromString(bootstrapTokenRaw)
if err != nil {
return err
}

serviceAccountName = bootstrapToken.ServiceAccountName()
serviceAccountToken = bootstrapToken.ServiceAccountToken()
trustedControllerCertificate = bootstrapToken.Certificate()
} else {
trustedControllerCertificate, err = probeControllerCertificate(controllerURL)
if err != nil {
return err
if serviceAccountName == "" {
serviceAccountName = bootstrapToken.ServiceAccountName()
}
if serviceAccountToken == "" {
serviceAccountToken = bootstrapToken.ServiceAccountToken()
}
}

client, err := client.New(
client.WithAddress(controllerURL.String()),
client.WithTrustedCertificate(trustedControllerCertificate),
client.WithCredentials(serviceAccountName, serviceAccountToken),
)
trustedCertificate, err := tryToConnectToTheController(cmd.Context(), controllerURL, bootstrapToken)
if err != nil {
return err
}
if err := client.Check(cmd.Context()); err != nil {
return err
}

// Create and save the context
configHandle, err := config.NewHandle()
if err != nil {
return err
}

certificatePEMBytes := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: trustedControllerCertificate.Raw,
})

return configHandle.CreateContext(contextName, config.Context{
newContext := config.Context{
URL: controllerURL.String(),
Certificate: certificatePEMBytes,
ServiceAccountName: serviceAccountName,
ServiceAccountToken: serviceAccountToken,
}, force)
}

if trustedCertificate != nil {
certificatePEMBytes := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: trustedCertificate.Raw,
})

newContext.Certificate = certificatePEMBytes
}

return configHandle.CreateContext(contextName, newContext, force)
}

func tryToConnectToTheController(
ctx context.Context,
controllerURL *url.URL,
bootstrapToken *bootstraptoken.BootstrapToken,
) (*x509.Certificate, error) {
if !noPKI {
if err := tryToConnectWithPKI(ctx, controllerURL); err == nil {
// Connection successful and no certificate retrieval is needed
return nil, nil
} else if errors.Is(err, clientpkg.ErrAPI) {
// Makes no sense to go any further since it's an upper layer (HTTP, not TLS) error
return nil, err
}
}

return tryToConnectWithTrustedCertificate(ctx, controllerURL, bootstrapToken)
}

func tryToConnectWithPKI(ctx context.Context, controllerURL *url.URL) error {
client, err := clientpkg.New(
clientpkg.WithAddress(controllerURL.String()),
clientpkg.WithCredentials(serviceAccountName, serviceAccountToken),
)
if err != nil {
return err
}

return client.Check(ctx)
}

func probeControllerCertificate(controllerURL *url.URL) (*x509.Certificate, error) {
// Do not use PKI
emptyPool := x509.NewCertPool()
func tryToConnectWithTrustedCertificate(
ctx context.Context,
controllerURL *url.URL,
bootstrapToken *bootstraptoken.BootstrapToken,
) (*x509.Certificate, error) {
// Either (1) retrieve a trusted certificate from the bootstrap token
// or (2) retrieve it from the Controller and verify it interactively
var trustedControllerCertificate *x509.Certificate
var err error

//nolint:gosec // since we're not using PKI, InsecureSkipVerify is a must here
if bootstrapToken != nil {
trustedControllerCertificate = bootstrapToken.Certificate()
} else {
if trustedControllerCertificate, err = probeControllerCertificate(ctx, controllerURL); err != nil {
return nil, err
}
}

// Now try again with the trusted certificate
client, err := clientpkg.New(
clientpkg.WithAddress(controllerURL.String()),
clientpkg.WithCredentials(serviceAccountName, serviceAccountToken),
clientpkg.WithTrustedCertificate(trustedControllerCertificate),
)
if err != nil {
return nil, err
}

return trustedControllerCertificate, client.Check(ctx)
}

func probeControllerCertificate(ctx context.Context, controllerURL *url.URL) (*x509.Certificate, error) {
//nolint:gosec // without InsecureSkipVerify our VerifyConnection won't be called
insecureTLSConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: emptyPool,
InsecureSkipVerify: true,
}

Expand Down Expand Up @@ -182,7 +243,11 @@ func probeControllerCertificate(controllerURL *url.URL) (*x509.Certificate, erro
}
}

conn, err := tls.Dial("tcp", controllerURL.Host, insecureTLSConfig)
dialer := tls.Dialer{
Config: insecureTLSConfig,
}

conn, err := dialer.DialContext(ctx, "tcp", controllerURL.Host)
if err != nil {
return nil, err
}
Expand Down
23 changes: 12 additions & 11 deletions internal/config/context.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package config

import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/cirruslabs/orchard/internal/netconstants"
)

type Context struct {
Expand All @@ -14,20 +13,22 @@ type Context struct {
ServiceAccountToken string `yaml:"serviceAccountToken,omitempty"`
}

func (context *Context) TLSConfig() (*tls.Config, error) {
func (context *Context) TrustedCertificate() (*x509.Certificate, error) {
if len(context.Certificate) == 0 {
return nil, nil
}

privatePool := x509.NewCertPool()
block, _ := pem.Decode(context.Certificate)
if block == nil {
return nil, fmt.Errorf("%w: failed to load context's certificate: no PEM data found",
ErrConfigReadFailed)
}

if ok := privatePool.AppendCertsFromPEM(context.Certificate); !ok {
return nil, fmt.Errorf("%w: failed to load context's certificate", ErrConfigReadFailed)
trustedCertificate, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("%w: failed to load context's certificate: %v",
ErrConfigReadFailed, err)
}

return &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: netconstants.DefaultControllerServerName,
RootCAs: privatePool,
}, nil
return trustedCertificate, nil
}
66 changes: 45 additions & 21 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/cirruslabs/orchard/internal/config"
"github.com/cirruslabs/orchard/internal/netconstants"
"github.com/cirruslabs/orchard/rpc"
"golang.org/x/net/websocket"
"google.golang.org/grpc/credentials"
Expand All @@ -21,13 +23,15 @@ import (

var (
ErrFailed = errors.New("API client failed")
ErrAPI = errors.New("API client encountered an API error")
ErrInvalidState = errors.New("invalid state")
)

type Client struct {
address string
insecure bool
tlsConfig *tls.Config
address string
insecure bool
trustedCertificate *x509.Certificate
tlsConfig *tls.Config

httpClient *http.Client
baseURL *url.URL
Expand All @@ -51,28 +55,23 @@ func New(opts ...Option) (*Client, error) {

// Apply defaults
if client.address == "" {
configHandle, err := config.NewHandle()
if err != nil {
return nil, err
}

defaultContext, err := configHandle.DefaultContext()
if err != nil {
if err := client.configureFromDefaultContext(); err != nil {
return nil, err
}
}

client.address = defaultContext.URL
client.serviceAccountName = defaultContext.ServiceAccountName
client.serviceAccountToken = defaultContext.ServiceAccountToken
if client.trustedCertificate != nil {
privatePool := x509.NewCertPool()
privatePool.AddCert(client.trustedCertificate)

tlsConfig, err := defaultContext.TLSConfig()
if err != nil {
return nil, err
client.tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: privatePool,
ServerName: netconstants.DefaultControllerServerName,
}
client.tlsConfig = tlsConfig
}

// Instantiate client
// Instantiate the HTTP client
client.httpClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: client.tlsConfig,
Expand Down Expand Up @@ -125,6 +124,31 @@ func (client *Client) GPRCMetadata() metadata.MD {
return metadata.New(result)
}

func (client *Client) configureFromDefaultContext() error {
configHandle, err := config.NewHandle()
if err != nil {
return err
}

defaultContext, err := configHandle.DefaultContext()
if err != nil {
return err
}

client.address = defaultContext.URL
client.serviceAccountName = defaultContext.ServiceAccountName
client.serviceAccountToken = defaultContext.ServiceAccountToken

if client.trustedCertificate == nil {
client.trustedCertificate, err = defaultContext.TrustedCertificate()
if err != nil {
return err
}
}

return nil
}

func (client *Client) request(
ctx context.Context,
method string,
Expand Down Expand Up @@ -174,18 +198,18 @@ func (client *Client) request(

if response.StatusCode != http.StatusOK {
return fmt.Errorf("%w to make a request: %d %s%s",
ErrFailed, response.StatusCode, http.StatusText(response.StatusCode),
ErrAPI, response.StatusCode, http.StatusText(response.StatusCode),
detailsFromErrorResponseBody(response.Body))
}

if out != nil {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("%w to read response body: %v", ErrFailed, err)
return fmt.Errorf("%w to read response body: %v", ErrAPI, err)
}

if err := json.Unmarshal(bodyBytes, out); err != nil {
return fmt.Errorf("%w to unmarshal response body: %v", ErrFailed, err)
return fmt.Errorf("%w to unmarshal response body: %v", ErrAPI, err)
}
}

Expand Down
Loading

0 comments on commit b274d8c

Please sign in to comment.