Skip to content

Commit

Permalink
Request signing beta (#1072)
Browse files Browse the repository at this point in the history
  • Loading branch information
charliecruzan-stripe authored May 11, 2023
1 parent 528d5db commit c502ca2
Show file tree
Hide file tree
Showing 7 changed files with 498 additions and 2 deletions.
95 changes: 94 additions & 1 deletion pkg/config/profile.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package config

import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -58,14 +60,21 @@ var KeyRing keyring.Keyring

// CreateProfile creates a profile when logging in
func (p *Profile) CreateProfile() error {
writeErr := p.writeProfile(viper.GetViper())
// Remove existing profile first
v := p.deleteProfile(viper.GetViper())

writeErr := p.writeProfile(v)
if writeErr != nil {
return writeErr
}

return nil
}

func (p *Profile) deleteProfile(v *viper.Viper) *viper.Viper {
return p.safeRemove(v, p.ProfileName)
}

// GetColor gets the color setting for the user based on the flag or the
// persisted color stored in the config file
func (p *Profile) GetColor() (string, error) {
Expand Down Expand Up @@ -333,6 +342,9 @@ func (p *Profile) writeProfile(runtimeViper *viper.Viper) error {
runtimeViper = p.safeRemove(runtimeViper, "publishable_key")
}

// Remove experimental fields during login
runtimeViper = p.removeExperimentalFields(runtimeViper)

runtimeViper.SetConfigFile(profilesFile)

// Ensure we preserve the config file type
Expand Down Expand Up @@ -418,3 +430,84 @@ func isRedactedAPIKey(apiKey string) bool {
func getKeyExpiresAt() string {
return time.Now().AddDate(0, 0, KeyValidInDays).UTC().Format(DateStringFormat)
}

// ExperimentalFields are currently only used for request signing
type ExperimentalFields struct {
ContextualName string
PrivateKey string
StripeHeaders string
}

const (
experimentalPrefix = "experimental"
experimentalStripeHeaders = experimentalPrefix + "." + "stripe_headers"
experimentalContextualName = experimentalPrefix + "." + "contextual_name"
experimentalPrivateKey = experimentalPrefix + "." + "private_key"
)

// GetExperimentalFields returns a struct of the profile's experimental fields. These fields are only ever additive in functionality.
func (p *Profile) GetExperimentalFields() ExperimentalFields {
if err := viper.ReadInConfig(); err == nil {
name := viper.GetString(p.GetConfigField(experimentalContextualName))
privKey := viper.GetString(p.GetConfigField(experimentalPrivateKey))
headers := viper.GetString(p.GetConfigField(experimentalStripeHeaders))

return ExperimentalFields{
ContextualName: name,
PrivateKey: privKey,
StripeHeaders: headers,
}
}
return ExperimentalFields{
ContextualName: "",
PrivateKey: "",
StripeHeaders: "",
}
}

func (p *Profile) removeExperimentalFields(v *viper.Viper) *viper.Viper {
v = p.safeRemove(v, experimentalPrefix)
return v
}

// SessionCredentials are the credentials needed for this session
type SessionCredentials struct {
UAT string `json:"uat"`
PrivateKey string `json:"private_key"`
AccountID string `json:"account_id"`
}

// GetSessionCredentials retrieves the session credentials from the keyring
func (p *Profile) GetSessionCredentials() (*SessionCredentials, error) {
key := p.GetConfigField("stripe_cli_session")
ring, err := keyring.Open(keyring.Config{
KeychainTrustApplication: true,
ServiceName: KeyManagementService,
})
if err != nil {
return nil, err
}
keyringItem, err := ring.Get(key)
if err != nil {
if err == keyring.ErrKeyNotFound {
return nil, errors.New("no session")
}
return nil, err
}

creds := SessionCredentials{}
if err := json.Unmarshal(keyringItem.Data, &creds); err != nil {
return nil, err
}

currentAccountID, err := p.GetAccountID()
if err != nil {
return nil, err
}

if creds.AccountID == "" || creds.AccountID != currentAccountID {
return nil, errors.New("found a session, but it doesn't match your current account")
}

return &creds, nil
}
2 changes: 1 addition & 1 deletion pkg/config/profile_livemode_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (p *Profile) retrieveLivemodeValue(key string) (string, error) {
}
}

value := strings.TrimSpace(string(out[:]))
value := strings.TrimSpace(string(out))
// if the string has the well-known prefix, assume it's encoded
if strings.HasPrefix(value, encodingPrefix) {
dec, err := hex.DecodeString(value[len(encodingPrefix):])
Expand Down
43 changes: 43 additions & 0 deletions pkg/config/profile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,49 @@ test_mode_key_expires_at = '` + expiresAt + `'
cleanUp(c.ProfilesFile)
}

func TestExperimentalFields(t *testing.T) {
profilesFile := filepath.Join(os.TempDir(), "stripe", "config.toml")
p := Profile{
ProfileName: "tests",
DeviceName: "st-testing",
TestModeAPIKey: "sk_test_123",
DisplayName: "test-account-display-name",
}
c := &Config{
Color: "auto",
LogLevel: "info",
Profile: p,
ProfilesFile: profilesFile,
}
c.InitConfig()

v := viper.New()

v.SetConfigFile(profilesFile)
err := p.writeProfile(v)
require.NoError(t, err)

require.FileExists(t, c.ProfilesFile)

require.NoError(t, err)

experimentalFields := p.GetExperimentalFields()
require.Equal(t, "", experimentalFields.ContextualName)
require.Equal(t, "", experimentalFields.StripeHeaders)
require.Equal(t, "", experimentalFields.PrivateKey)

p.WriteConfigField("experimental.stripe_headers", "test-headers")
p.WriteConfigField("experimental.contextual_name", "test-name")
p.WriteConfigField("experimental.private_key", "test-key")

experimentalFields = p.GetExperimentalFields()
require.Equal(t, "test-name", experimentalFields.ContextualName)
require.Equal(t, "test-headers", experimentalFields.StripeHeaders)
require.Equal(t, "test-key", experimentalFields.PrivateKey)

cleanUp(c.ProfilesFile)
}

func helperLoadBytes(t *testing.T, name string) []byte {
bytes, err := os.ReadFile(name)
if err != nil {
Expand Down
43 changes: 43 additions & 0 deletions pkg/requests/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,17 @@ func (rb *Base) performRequest(ctx context.Context, client stripe.RequestPerform
return err
}
}

if rb.Profile != nil {
experimentalFields := rb.Profile.GetExperimentalFields()
if experimentalFields.StripeHeaders != "" {
err := rb.experimentalRequestSigning(req, experimentalFields)
if err != nil {
return err
}
}
}

return nil
}

Expand Down Expand Up @@ -488,3 +499,35 @@ func normalizePath(path string) string {

return "/v1/" + path
}

func (rb *Base) experimentalRequestSigning(req *http.Request, experimentalFields config.ExperimentalFields) error {
privKey := experimentalFields.PrivateKey

keyToValues := strings.Split(strings.Trim(experimentalFields.StripeHeaders, ";"), ";")
for _, pair := range keyToValues {
header := strings.Split(pair, "=")
if len(header) != 2 {
continue
}
headerName := header[0]
headerValue := header[1]
if headerName == stripeContextHeaderName {
displayMessage := fmt.Sprintf("Operating in %s %s\n", ansi.Bold(experimentalFields.ContextualName), ansi.Color(os.Stdout).Gray(10, "("+headerValue+")..."))
fmt.Print(ansi.Color(os.Stdout).Gray(10, displayMessage))
} else if headerName == authorizationHeaderName && privKey == "" {
creds, err := rb.Profile.GetSessionCredentials()
if err != nil {
return err
}
headerValue += creds.UAT
privKey = creds.PrivateKey
}
req.Header.Set(headerName, headerValue)
}
if len(keyToValues) > 0 {
// Must sign the request AFTER all headers have been set
SignRequest(req, privKey)
}

return nil
}
28 changes: 28 additions & 0 deletions pkg/requests/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"testing"

"github.com/stretchr/testify/require"

"github.com/stripe/stripe-cli/pkg/config"
)

func TestBuildDataForRequest(t *testing.T) {
Expand Down Expand Up @@ -297,3 +299,29 @@ func TestIsAPIKeyExpiredError(t *testing.T) {
require.False(t, IsAPIKeyExpiredError(fmt.Errorf("other")))
})
}

func TestRequestSigning(t *testing.T) {
rb := Base{}
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
err := rb.experimentalRequestSigning(req, config.ExperimentalFields{
StripeHeaders: "Stripe-Context=test-context;Authorization=TEST-PREFIX 123",
ContextualName: "test-name",
PrivateKey: "test-key",
})
require.NoError(t, err)
require.Equal(t, "test-context", req.Header.Get("Stripe-Context"))
require.Equal(t, "TEST-PREFIX 123", req.Header.Get("Authorization"))
}

func TestRequestSigningShouldNotBeCalled(t *testing.T) {
rb := Base{}
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
err := rb.experimentalRequestSigning(req, config.ExperimentalFields{
StripeHeaders: "",
ContextualName: "",
PrivateKey: "",
})
require.NoError(t, err)
require.Equal(t, "", req.Header.Get("Stripe-Context"))
require.Equal(t, "", req.Header.Get("Authorization"))
}
Loading

0 comments on commit c502ca2

Please sign in to comment.