diff --git a/cmd/spire-agent/cli/api/api_posix_test.go b/cmd/spire-agent/cli/api/api_posix_test.go new file mode 100644 index 0000000000..f9616f27ed --- /dev/null +++ b/cmd/spire-agent/cli/api/api_posix_test.go @@ -0,0 +1,45 @@ +//go:build !windows +// +build !windows + +package api + +const ( + fetchJWTUsage = `Usage of fetch jwt: + -audience value + comma separated list of audience values + -format value + deprecated; use -output + -output value + Desired output format (pretty, json); default: pretty. + -socketPath string + Path to the SPIRE Agent API Unix domain socket (default "/tmp/spire-agent/public/api.sock") + -spiffeID string + SPIFFE ID subject (optional) + -timeout value + Time to wait for a response (default 5s) +` + fetchX509Usage = `Usage of fetch x509: + -output value + Desired output format (pretty, json); default: pretty. + -silent + Suppress stdout + -socketPath string + Path to the SPIRE Agent API Unix domain socket (default "/tmp/spire-agent/public/api.sock") + -timeout value + Time to wait for a response (default 5s) + -write string + Write SVID data to the specified path (optional; only available for pretty output format) +` + validateJWTUsage = `Usage of validate jwt: + -audience string + expected audience value + -output value + Desired output format (pretty, json); default: pretty. + -socketPath string + Path to the SPIRE Agent API Unix domain socket (default "/tmp/spire-agent/public/api.sock") + -svid string + JWT SVID + -timeout value + Time to wait for a response (default 5s) +` +) diff --git a/cmd/spire-agent/cli/api/api_test.go b/cmd/spire-agent/cli/api/api_test.go new file mode 100644 index 0000000000..a030fc8332 --- /dev/null +++ b/cmd/spire-agent/cli/api/api_test.go @@ -0,0 +1,559 @@ +package api + +import ( + "bytes" + "crypto" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "errors" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/mitchellh/cli" + "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/cmd/spire-server/cli/common" + commoncli "github.com/spiffe/spire/pkg/common/cli" + "github.com/spiffe/spire/pkg/common/x509util" + "github.com/spiffe/spire/test/fakes/fakeworkloadapi" + "github.com/spiffe/spire/test/spiretest" + "github.com/spiffe/spire/test/testca" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" +) + +var availableFormats = []string{"pretty", "json"} + +func TestFetchJWTCommandHelp(t *testing.T) { + test := setupTest(t, newFetchJWTCommandWithEnv) + test.cmd.Help() + require.Equal(t, fetchJWTUsage, test.stderr.String()) +} + +func TestFetchJWTCommandSynopsis(t *testing.T) { + test := setupTest(t, newFetchJWTCommandWithEnv) + require.Equal(t, "Fetches a JWT SVID from the Workload API", test.cmd.Synopsis()) +} + +func TestFetchJWTCommand(t *testing.T) { + td := spiffeid.RequireTrustDomainFromString("example.org") + ca := testca.New(t, td) + encodedSvid1 := ca.CreateJWTSVID(spiffeid.RequireFromString("spiffe://domain1.test"), []string{"foo"}).Marshal() + encodedSvid2 := ca.CreateJWTSVID(spiffeid.RequireFromString("spiffe://domain2.test"), []string{"foo"}).Marshal() + bundleJWKSBytes, err := ca.JWTBundle().Marshal() + require.NoError(t, err) + + tests := []struct { + name string + args []string + fakeRequests []*fakeworkloadapi.FakeRequest + expectedStderr string + expectedStdoutPretty []string + expectedStdoutJSON string + }{ + { + name: "success fetching jwt with bundles", + args: []string{"-audience", "foo", "-spiffeID", "spiffe://domain1.test"}, + fakeRequests: []*fakeworkloadapi.FakeRequest{ + { + Req: &workload.JWTBundlesRequest{}, + Resp: &workload.JWTBundlesResponse{ + Bundles: map[string][]byte{ + "spiffe://domain1.test": bundleJWKSBytes, + "spiffe://domain2.test": bundleJWKSBytes, + }, + }, + }, + { + Req: &workload.JWTSVIDRequest{ + Audience: []string{"foo"}, + SpiffeId: "spiffe://domain1.test", + }, + Resp: &workload.JWTSVIDResponse{ + Svids: []*workload.JWTSVID{ + { + SpiffeId: "spiffe://domain1.test", + Svid: encodedSvid1, + }, + { + SpiffeId: "spiffe://domain2.test", + Svid: encodedSvid2, + }, + }, + }, + }, + }, + expectedStdoutPretty: []string{ + fmt.Sprintf("token(spiffe://domain1.test):\n\t%s", encodedSvid1), + fmt.Sprintf("token(spiffe://domain2.test):\n\t%s", encodedSvid2), + fmt.Sprintf("bundle(spiffe://domain1.test):\n\t%s", bundleJWKSBytes), + fmt.Sprintf("bundle(spiffe://domain2.test):\n\t%s", bundleJWKSBytes), + }, + expectedStdoutJSON: fmt.Sprintf(`[ + { + "svids": [ + { + "spiffe_id": "spiffe://domain1.test", + "svid": "%s" + }, + { + "spiffe_id": "spiffe://domain2.test", + "svid": "%s" + } + ] + }, + { + "bundles": { + "spiffe://domain1.test": "%s", + "spiffe://domain2.test": "%s" + } + } +]`, encodedSvid1, encodedSvid2, base64.StdEncoding.EncodeToString(bundleJWKSBytes), base64.StdEncoding.EncodeToString(bundleJWKSBytes)), + }, + { + name: "fail with error fetching bundles", + args: []string{"-audience", "foo", "-spiffeID", "spiffe://domain1.test"}, + fakeRequests: []*fakeworkloadapi.FakeRequest{ + { + Req: &workload.JWTBundlesRequest{}, + Resp: &workload.JWTBundlesResponse{}, + Err: errors.New("error fetching bundles"), + }, + }, + expectedStderr: "rpc error: code = Unknown desc = error fetching bundles\n", + }, + { + name: "fail with error fetching svid", + args: []string{"-audience", "foo", "-spiffeID", "spiffe://domain1.test"}, + fakeRequests: []*fakeworkloadapi.FakeRequest{ + { + Req: &workload.JWTBundlesRequest{}, + Resp: &workload.JWTBundlesResponse{ + Bundles: map[string][]byte{ + "spiffe://domain1.test": bundleJWKSBytes, + }, + }, + }, + { + Req: &workload.JWTSVIDRequest{ + Audience: []string{"foo"}, + SpiffeId: "spiffe://domain1.test", + }, + Resp: &workload.JWTSVIDResponse{}, + Err: errors.New("error fetching svid"), + }, + }, + expectedStderr: "rpc error: code = Unknown desc = error fetching svid\n", + }, + { + name: "fail when audience is not provided", + expectedStderr: "audience must be specified\n", + }, + } + + for _, tt := range tests { + for _, format := range availableFormats { + t.Run(fmt.Sprintf("%s using %s format", tt.name, format), func(t *testing.T) { + test := setupTest(t, newFetchJWTCommandWithEnv, tt.fakeRequests...) + args := tt.args + args = append(args, "-output", format) + + rc := test.cmd.Run(test.args(args...)) + + if tt.expectedStderr != "" { + assert.Equal(t, 1, rc) + assert.Equal(t, tt.expectedStderr, test.stderr.String()) + return + } + + assertOutputBasedOnFormat(t, format, test.stdout.String(), tt.expectedStdoutJSON, tt.expectedStdoutPretty...) + assert.Empty(t, test.stderr.String()) + assert.Equal(t, 0, rc) + }) + } + } +} + +func TestFetchX509CommandHelp(t *testing.T) { + test := setupTest(t, newFetchX509Command) + test.cmd.Help() + require.Equal(t, fetchX509Usage, test.stderr.String()) +} + +func TestFetchX509CommandSynopsis(t *testing.T) { + test := setupTest(t, newFetchX509Command) + require.Equal(t, "Fetches X509 SVIDs from the Workload API", test.cmd.Synopsis()) +} + +func TestFetchX509Command(t *testing.T) { + testDir := t.TempDir() + td := spiffeid.RequireTrustDomainFromString("example.org") + ca := testca.New(t, td) + svid := ca.CreateX509SVID(spiffeid.RequireFromString("spiffe://example.org/foo")) + + tests := []struct { + name string + args []string + fakeRequests []*fakeworkloadapi.FakeRequest + expectedStderr string + expectedStdoutPretty string + expectedStdoutJSON string + expectedFileResult bool + }{ + { + name: "success fetching x509 svid", + fakeRequests: []*fakeworkloadapi.FakeRequest{ + { + Req: &workload.X509SVIDRequest{}, + Resp: &workload.X509SVIDResponse{ + Svids: []*workload.X509SVID{ + { + SpiffeId: svid.ID.String(), + X509Svid: x509util.DERFromCertificates(svid.Certificates), + X509SvidKey: pkcs8FromSigner(t, svid.PrivateKey), + Bundle: x509util.DERFromCertificates(ca.Bundle().X509Authorities()), + }, + }, + Crl: [][]byte{}, + FederatedBundles: map[string][]byte{}, + }, + }, + }, + expectedStdoutPretty: fmt.Sprintf(` +SPIFFE ID: spiffe://example.org/foo +SVID Valid After: %v +SVID Valid Until: %v +CA #1 Valid After: %v +CA #1 Valid Until: %v +`, + svid.Certificates[0].NotBefore, + svid.Certificates[0].NotAfter, + ca.Bundle().X509Authorities()[0].NotBefore, + ca.Bundle().X509Authorities()[0].NotAfter, + ), + expectedStdoutJSON: fmt.Sprintf(`{ + "crl": [], + "federated_bundles": {}, + "svids": [ + { + "bundle": "%s", + "spiffe_id": "spiffe://example.org/foo", + "x509_svid": "%s", + "x509_svid_key": "%s" + } + ] +}`, + base64.StdEncoding.EncodeToString(x509util.DERFromCertificates(ca.Bundle().X509Authorities())), + base64.StdEncoding.EncodeToString(x509util.DERFromCertificates(svid.Certificates)), + base64.StdEncoding.EncodeToString(pkcs8FromSigner(t, svid.PrivateKey)), + ), + }, + { + name: "success fetching x509 and writing to file", + args: []string{"-write", testDir}, + fakeRequests: []*fakeworkloadapi.FakeRequest{ + { + Req: &workload.X509SVIDRequest{}, + Resp: &workload.X509SVIDResponse{ + Svids: []*workload.X509SVID{ + { + SpiffeId: svid.ID.String(), + X509Svid: x509util.DERFromCertificates(svid.Certificates), + X509SvidKey: pkcs8FromSigner(t, svid.PrivateKey), + Bundle: x509util.DERFromCertificates(ca.Bundle().X509Authorities()), + }, + }, + Crl: [][]byte{}, + FederatedBundles: map[string][]byte{}, + }, + }, + }, + expectedStdoutPretty: fmt.Sprintf(` +SPIFFE ID: spiffe://example.org/foo +SVID Valid After: %v +SVID Valid Until: %v +CA #1 Valid After: %v +CA #1 Valid Until: %v + +Writing SVID #0 to file %s +Writing key #0 to file %s +Writing bundle #0 to file %s +`, + svid.Certificates[0].NotBefore, + svid.Certificates[0].NotAfter, + ca.Bundle().X509Authorities()[0].NotBefore, + ca.Bundle().X509Authorities()[0].NotAfter, + fmt.Sprintf("%s/svid.0.pem.", testDir), + fmt.Sprintf("%s/svid.0.key.", testDir), + fmt.Sprintf("%s/bundle.0.pem.", testDir), + ), + expectedStdoutJSON: fmt.Sprintf(`{ + "crl": [], + "federated_bundles": {}, + "svids": [ + { + "bundle": "%s", + "spiffe_id": "spiffe://example.org/foo", + "x509_svid": "%s", + "x509_svid_key": "%s" + } + ] +}`, + base64.StdEncoding.EncodeToString(x509util.DERFromCertificates(ca.Bundle().X509Authorities())), + base64.StdEncoding.EncodeToString(x509util.DERFromCertificates(svid.Certificates)), + base64.StdEncoding.EncodeToString(pkcs8FromSigner(t, svid.PrivateKey)), + ), + expectedFileResult: true, + }, + { + name: "fails fetching svid", + fakeRequests: []*fakeworkloadapi.FakeRequest{ + { + Req: &workload.X509SVIDRequest{}, + Resp: &workload.X509SVIDResponse{}, + Err: errors.New("error fetching svid"), + }, + }, + expectedStderr: "rpc error: code = Unknown desc = error fetching svid\n", + }, + } + for _, tt := range tests { + for _, format := range availableFormats { + t.Run(fmt.Sprintf("%s using %s format", tt.name, format), func(t *testing.T) { + test := setupTest(t, newFetchX509Command, tt.fakeRequests...) + args := tt.args + args = append(args, "-output", format) + + rc := test.cmd.Run(test.args(args...)) + + if tt.expectedStderr != "" { + assert.Equal(t, 1, rc) + assert.Equal(t, tt.expectedStderr, test.stderr.String()) + return + } + + assertOutputBasedOnFormat(t, format, test.stdout.String(), tt.expectedStdoutJSON, tt.expectedStdoutPretty) + assert.Empty(t, test.stderr.String()) + assert.Equal(t, 0, rc) + + if tt.expectedFileResult && format == "pretty" { + content, err := os.ReadFile(filepath.Join(testDir, "svid.0.pem")) + assert.NoError(t, err) + assert.Equal(t, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: svid.Certificates[0].Raw, + }), content) + + content, err = os.ReadFile(filepath.Join(testDir, "svid.0.key")) + assert.NoError(t, err) + assert.Equal(t, string(pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8FromSigner(t, svid.PrivateKey), + })), string(content)) + + content, err = os.ReadFile(filepath.Join(testDir, "bundle.0.pem")) + assert.NoError(t, err) + assert.Equal(t, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: ca.Bundle().X509Authorities()[0].Raw, + }), content) + } + }) + } + } +} + +func TestValidateJWTCommandHelp(t *testing.T) { + test := setupTest(t, newValidateJWTCommand) + test.cmd.Help() + require.Equal(t, validateJWTUsage, test.stderr.String()) +} + +func TestValidateJWTCommandSynopsis(t *testing.T) { + test := setupTest(t, newValidateJWTCommand) + require.Equal(t, "Validates a JWT SVID", test.cmd.Synopsis()) +} + +func TestValidateJWTCommand(t *testing.T) { + td := spiffeid.RequireTrustDomainFromString("example.org") + ca := testca.New(t, td) + encodedSvid := ca.CreateJWTSVID(spiffeid.RequireFromString("spiffe://domain1.test"), []string{"foo"}).Marshal() + + tests := []struct { + name string + args []string + fakeRequests []*fakeworkloadapi.FakeRequest + expectedStderr string + expectedStdoutPretty string + expectedStdoutJSON string + }{ + { + name: "valid svid", + args: []string{"-audience", "foo", "-svid", encodedSvid}, + fakeRequests: []*fakeworkloadapi.FakeRequest{ + { + Req: &workload.ValidateJWTSVIDRequest{ + Audience: "foo", + Svid: encodedSvid, + }, + Resp: &workload.ValidateJWTSVIDResponse{ + SpiffeId: "spiffe://example.org/foo", + Claims: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "aud": { + Kind: &structpb.Value_ListValue{ListValue: &structpb.ListValue{ + Values: []*structpb.Value{{Kind: &structpb.Value_StringValue{StringValue: "foo"}}}, + }, + }, + }, + }, + }, + }, + }, + }, + expectedStdoutPretty: `SVID is valid. +SPIFFE ID : spiffe://example.org/foo +Claims : {"aud":["foo"]}`, + expectedStdoutJSON: `{ + "claims": { + "aud": [ + "foo" + ] + }, + "spiffe_id": "spiffe://example.org/foo" +}`, + }, + { + name: "invalid svid", + args: []string{"-audience", "invalid", "-svid", "invalid"}, + fakeRequests: []*fakeworkloadapi.FakeRequest{ + { + Req: &workload.ValidateJWTSVIDRequest{ + Audience: "foo", + Svid: encodedSvid, + }, + Resp: &workload.ValidateJWTSVIDResponse{}, + Err: status.Error(codes.InvalidArgument, "invalid svid"), + }, + }, + expectedStderr: "SVID is not valid: invalid svid\n", + }, + { + name: "fail when audience is not provided", + expectedStderr: "audience must be specified\n", + }, + { + name: "fail when svid is not provided", + args: []string{"-audience", "foo"}, + expectedStderr: "svid must be specified\n", + }, + } + for _, tt := range tests { + for _, format := range availableFormats { + t.Run(fmt.Sprintf("%s using %s format", tt.name, format), func(t *testing.T) { + test := setupTest(t, newValidateJWTCommand, tt.fakeRequests...) + args := tt.args + args = append(args, "-output", format) + + rc := test.cmd.Run(test.args(args...)) + + if tt.expectedStderr != "" { + assert.Equal(t, 1, rc) + assert.Equal(t, tt.expectedStderr, test.stderr.String()) + return + } + + assertOutputBasedOnFormat(t, format, test.stdout.String(), tt.expectedStdoutJSON, tt.expectedStdoutPretty) + assert.Empty(t, test.stderr.String()) + assert.Equal(t, 0, rc) + }) + } + } +} + +func setupTest(t *testing.T, newCmd func(env *commoncli.Env, clientMaker workloadClientMaker) cli.Command, requests ...*fakeworkloadapi.FakeRequest) *apiTest { + workloadAPIServer := fakeworkloadapi.New(t, requests...) + + addr := spiretest.StartGRPCServer(t, func(s *grpc.Server) { + workload.RegisterSpiffeWorkloadAPIServer(s, workloadAPIServer) + }) + + stdin := new(bytes.Buffer) + stdout := new(bytes.Buffer) + stderr := new(bytes.Buffer) + + cmd := newCmd(&commoncli.Env{ + Stdin: stdin, + Stdout: stdout, + Stderr: stderr, + }, newWorkloadClient) + + test := &apiTest{ + addr: common.GetAddr(addr), + stdin: stdin, + stdout: stdout, + stderr: stderr, + workloadAPI: workloadAPIServer, + cmd: cmd, + } + + t.Cleanup(func() { + test.afterTest(t) + }) + + return test +} + +type apiTest struct { + stdin *bytes.Buffer + stdout *bytes.Buffer + stderr *bytes.Buffer + + addr string + workloadAPI *fakeworkloadapi.WorkloadAPI + + cmd cli.Command +} + +func (s *apiTest) afterTest(t *testing.T) { + t.Logf("TEST:%s", t.Name()) + t.Logf("STDOUT:\n%s", s.stdout.String()) + t.Logf("STDIN:\n%s", s.stdin.String()) + t.Logf("STDERR:\n%s", s.stderr.String()) +} + +func (s *apiTest) args(extra ...string) []string { + return append([]string{common.AddrArg, s.addr}, extra...) +} + +func assertOutputBasedOnFormat(t *testing.T, format, stdoutString, expectedStdoutJSON string, expectedStdoutPretty ...string) { + switch format { + case "pretty": + if len(expectedStdoutPretty) > 0 { + for _, expected := range expectedStdoutPretty { + require.Contains(t, stdoutString, expected) + } + } else { + require.Empty(t, stdoutString) + } + case "json": + if expectedStdoutJSON != "" { + require.JSONEq(t, expectedStdoutJSON, stdoutString) + } else { + require.Empty(t, stdoutString) + } + } +} + +func pkcs8FromSigner(t *testing.T, key crypto.Signer) []byte { + keyBytes, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + return keyBytes +} diff --git a/cmd/spire-agent/cli/api/api_windows_test.go b/cmd/spire-agent/cli/api/api_windows_test.go new file mode 100644 index 0000000000..cc1dd6a95c --- /dev/null +++ b/cmd/spire-agent/cli/api/api_windows_test.go @@ -0,0 +1,45 @@ +//go:build windows +// +build windows + +package api + +const ( + fetchJWTUsage = `Usage of fetch jwt: + -audience value + comma separated list of audience values + -format value + deprecated; use -output + -namedPipeName string + Pipe name of the SPIRE Agent API named pipe (default "\\spire-agent\\public\\api") + -output value + Desired output format (pretty, json); default: pretty. + -spiffeID string + SPIFFE ID subject (optional) + -timeout value + Time to wait for a response (default 5s) +` + fetchX509Usage = `Usage of fetch x509: + -namedPipeName string + Pipe name of the SPIRE Agent API named pipe (default "\\spire-agent\\public\\api") + -output value + Desired output format (pretty, json); default: pretty. + -silent + Suppress stdout + -timeout value + Time to wait for a response (default 5s) + -write string + Write SVID data to the specified path (optional; only available for pretty output format) +` + validateJWTUsage = `Usage of validate jwt: + -audience string + expected audience value + -namedPipeName string + Pipe name of the SPIRE Agent API named pipe (default "\\spire-agent\\public\\api") + -output value + Desired output format (pretty, json); default: pretty. + -svid string + JWT SVID + -timeout value + Time to wait for a response (default 5s) +` +) diff --git a/cmd/spire-agent/cli/api/fetch_jwt.go b/cmd/spire-agent/cli/api/fetch_jwt.go index 7e7654f9c6..f37b377c37 100644 --- a/cmd/spire-agent/cli/api/fetch_jwt.go +++ b/cmd/spire-agent/cli/api/fetch_jwt.go @@ -13,17 +13,18 @@ import ( ) func NewFetchJWTCommand() cli.Command { - return newFetchJWTCommand(commoncli.DefaultEnv, newWorkloadClient) + return newFetchJWTCommandWithEnv(commoncli.DefaultEnv, newWorkloadClient) } -func newFetchJWTCommand(env *commoncli.Env, clientMaker workloadClientMaker) cli.Command { - return adaptCommand(env, clientMaker, new(fetchJWTCommand)) +func newFetchJWTCommandWithEnv(env *commoncli.Env, clientMaker workloadClientMaker) cli.Command { + return adaptCommand(env, clientMaker, &fetchJWTCommand{env: env}) } type fetchJWTCommand struct { audience commoncli.CommaStringsFlag spiffeID string printer cliprinter.Printer + env *commoncli.Env } func (c *fetchJWTCommand) name() string { @@ -54,7 +55,7 @@ func (c *fetchJWTCommand) run(ctx context.Context, env *commoncli.Env, client *w func (c *fetchJWTCommand) appendFlags(fs *flag.FlagSet) { fs.Var(&c.audience, "audience", "comma separated list of audience values") fs.StringVar(&c.spiffeID, "spiffeID", "", "SPIFFE ID subject (optional)") - outputValue := cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, nil, printPrettyResult) + outputValue := cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, printPrettyResult) fs.Var(outputValue, "format", "deprecated; use -output") } @@ -77,27 +78,25 @@ func (c *fetchJWTCommand) fetchJWTBundles(ctx context.Context, client *workloadC return stream.Recv() } -func printPrettyResult(_ *commoncli.Env, results ...interface{}) error { - errMsg := "internal error: cli printer; please report this bug" - +func printPrettyResult(env *commoncli.Env, results ...interface{}) error { svidResp, ok := results[0].(*workload.JWTSVIDResponse) if !ok { - fmt.Println(errMsg) - return errors.New(errMsg) + env.Println(cliprinter.ErrInternalCustomPrettyFunc.Error()) + return cliprinter.ErrInternalCustomPrettyFunc } bundlesResp, ok := results[1].(*workload.JWTBundlesResponse) if !ok { - fmt.Println(errMsg) - return errors.New(errMsg) + env.Println(cliprinter.ErrInternalCustomPrettyFunc.Error()) + return cliprinter.ErrInternalCustomPrettyFunc } for _, svid := range svidResp.Svids { - fmt.Printf("token(%s):\n\t%s\n", svid.SpiffeId, svid.Svid) + env.Printf("token(%s):\n\t%s\n", svid.SpiffeId, svid.Svid) } for trustDomainID, jwksJSON := range bundlesResp.Bundles { - fmt.Printf("bundle(%s):\n\t%s\n", trustDomainID, string(jwksJSON)) + env.Printf("bundle(%s):\n\t%s\n", trustDomainID, string(jwksJSON)) } return nil diff --git a/cmd/spire-agent/cli/api/fetch_x509.go b/cmd/spire-agent/cli/api/fetch_x509.go index c1188a6d67..16d6d25f60 100644 --- a/cmd/spire-agent/cli/api/fetch_x509.go +++ b/cmd/spire-agent/cli/api/fetch_x509.go @@ -16,21 +16,25 @@ import ( "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/go-spiffe/v2/svid/x509svid" - common_cli "github.com/spiffe/spire/pkg/common/cli" + commoncli "github.com/spiffe/spire/pkg/common/cli" + "github.com/spiffe/spire/pkg/common/cliprinter" "github.com/spiffe/spire/pkg/common/diskutil" ) func NewFetchX509Command() cli.Command { - return newFetchX509Command(common_cli.DefaultEnv, newWorkloadClient) + return newFetchX509Command(commoncli.DefaultEnv, newWorkloadClient) } -func newFetchX509Command(env *common_cli.Env, clientMaker workloadClientMaker) cli.Command { - return adaptCommand(env, clientMaker, new(fetchX509Command)) +func newFetchX509Command(env *commoncli.Env, clientMaker workloadClientMaker) cli.Command { + return adaptCommand(env, clientMaker, &fetchX509Command{env: env}) } type fetchX509Command struct { silent bool writePath string + env *commoncli.Env + printer cliprinter.Printer + respTime time.Duration } func (*fetchX509Command) name() string { @@ -41,35 +45,21 @@ func (*fetchX509Command) synopsis() string { return "Fetches X509 SVIDs from the Workload API" } -func (c *fetchX509Command) run(ctx context.Context, env *common_cli.Env, client *workloadClient) error { +func (c *fetchX509Command) run(ctx context.Context, env *commoncli.Env, client *workloadClient) error { start := time.Now() resp, err := c.fetchX509SVID(ctx, client) - respTime := time.Since(start) + c.respTime = time.Since(start) if err != nil { return err } - svids, err := parseAndValidateX509SVIDResponse(resp) - if err != nil { - return err - } - - if !c.silent { - printX509SVIDResponse(svids, respTime) - } - - if c.writePath != "" { - if err := c.writeResponse(svids); err != nil { - return err - } - } - - return nil + return c.printer.PrintProto(resp) } func (c *fetchX509Command) appendFlags(fs *flag.FlagSet) { fs.BoolVar(&c.silent, "silent", false, "Suppress stdout") - fs.StringVar(&c.writePath, "write", "", "Write SVID data to the specified path (optional)") + fs.StringVar(&c.writePath, "write", "", "Write SVID data to the specified path (optional; only available for pretty output format)") + cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, c.prettyPrintFetchX509) } func (c *fetchX509Command) fetchX509SVID(ctx context.Context, client *workloadClient) (*workload.X509SVIDResponse, error) { @@ -90,19 +80,19 @@ func (c *fetchX509Command) writeResponse(svids []*X509SVID) error { keyPath := path.Join(c.writePath, fmt.Sprintf("svid.%v.key", i)) bundlePath := path.Join(c.writePath, fmt.Sprintf("bundle.%v.pem", i)) - fmt.Printf("Writing SVID #%d to file %s.\n", i, svidPath) + c.env.Printf("Writing SVID #%d to file %s.\n", i, svidPath) err := c.writeCerts(svidPath, svid.Certificates) if err != nil { return err } - fmt.Printf("Writing key #%d to file %s.\n", i, keyPath) + c.env.Printf("Writing key #%d to file %s.\n", i, keyPath) err = c.writeKey(keyPath, svid.PrivateKey) if err != nil { return err } - fmt.Printf("Writing bundle #%d to file %s.\n", i, bundlePath) + c.env.Printf("Writing bundle #%d to file %s.\n", i, bundlePath) err = c.writeCerts(bundlePath, svid.Bundle) if err != nil { return err @@ -116,7 +106,7 @@ func (c *fetchX509Command) writeResponse(svids []*X509SVID) error { for j, trustDomain := range federatedDomains { bundlePath := path.Join(c.writePath, fmt.Sprintf("federated_bundle.%d.%d.pem", i, j)) - fmt.Printf("Writing federated bundle #%d for trust domain %s to file %s.\n", j, trustDomain, bundlePath) + c.env.Printf("Writing federated bundle #%d for trust domain %s to file %s.\n", j, trustDomain, bundlePath) err = c.writeCerts(bundlePath, svid.FederatedBundles[trustDomain]) if err != nil { return err @@ -161,6 +151,30 @@ func (c *fetchX509Command) writeFile(filename string, data []byte) error { return diskutil.WritePubliclyReadableFile(filename, data) } +func (c *fetchX509Command) prettyPrintFetchX509(env *commoncli.Env, results ...interface{}) error { + resp, ok := results[0].(*workload.X509SVIDResponse) + if !ok { + return cliprinter.ErrInternalCustomPrettyFunc + } + + svids, err := parseAndValidateX509SVIDResponse(resp) + if err != nil { + return err + } + + if !c.silent { + printX509SVIDResponse(env, svids, c.respTime) + } + + if c.writePath != "" { + if err := c.writeResponse(svids); err != nil { + return err + } + } + + return nil +} + type X509SVID struct { SPIFFEID string Certificates []*x509.Certificate @@ -201,7 +215,7 @@ func parseX509SVIDResponse(resp *workload.X509SVIDResponse) ([]*X509SVID, error) for i, respSVID := range resp.Svids { svid, err := parseX509SVID(respSVID, federatedBundles) if err != nil { - return nil, fmt.Errorf("failed to parse svid entry %d for spiffe id %q: %w", i, svid.SPIFFEID, err) + return nil, fmt.Errorf("failed to parse svid entry %d for spiffe id %q: %w", i, respSVID.SpiffeId, err) } svids = append(svids, svid) } diff --git a/cmd/spire-agent/cli/api/printer.go b/cmd/spire-agent/cli/api/printer.go index 68656febbc..2ef61b77ce 100644 --- a/cmd/spire-agent/cli/api/printer.go +++ b/cmd/spire-agent/cli/api/printer.go @@ -4,50 +4,52 @@ import ( "crypto/x509" "fmt" "time" + + commoncli "github.com/spiffe/spire/pkg/common/cli" ) -func printX509SVIDResponse(svids []*X509SVID, respTime time.Duration) { +func printX509SVIDResponse(env *commoncli.Env, svids []*X509SVID, respTime time.Duration) { lenMsg := fmt.Sprintf("Received %d svid", len(svids)) if len(svids) != 1 { lenMsg += "s" } lenMsg += fmt.Sprintf(" after %s", respTime) - fmt.Println(lenMsg) + env.Println(lenMsg) for _, svid := range svids { - fmt.Println() - printX509SVID(svid) + env.Println() + printX509SVID(env, svid) for trustDomain, bundle := range svid.FederatedBundles { - printX509FederatedBundle(trustDomain, bundle) + printX509FederatedBundle(env, trustDomain, bundle) } } - fmt.Println() + env.Println() } -func printX509SVID(svid *X509SVID) { +func printX509SVID(env *commoncli.Env, svid *X509SVID) { // Print SPIFFE ID first so if we run into a problem, we // get to know which record it was - fmt.Printf("SPIFFE ID:\t\t%s\n", svid.SPIFFEID) + env.Printf("SPIFFE ID:\t\t%s\n", svid.SPIFFEID) - fmt.Printf("SVID Valid After:\t%v\n", svid.Certificates[0].NotBefore) - fmt.Printf("SVID Valid Until:\t%v\n", svid.Certificates[0].NotAfter) + env.Printf("SVID Valid After:\t%v\n", svid.Certificates[0].NotBefore) + env.Printf("SVID Valid Until:\t%v\n", svid.Certificates[0].NotAfter) for i, intermediate := range svid.Certificates[1:] { num := i + 1 - fmt.Printf("Intermediate #%v Valid After:\t%v\n", num, intermediate.NotBefore) - fmt.Printf("Intermediate #%v Valid Until:\t%v\n", num, intermediate.NotAfter) + env.Printf("Intermediate #%v Valid After:\t%v\n", num, intermediate.NotBefore) + env.Printf("Intermediate #%v Valid Until:\t%v\n", num, intermediate.NotAfter) } for i, ca := range svid.Bundle { num := i + 1 - fmt.Printf("CA #%v Valid After:\t%v\n", num, ca.NotBefore) - fmt.Printf("CA #%v Valid Until:\t%v\n", num, ca.NotAfter) + env.Printf("CA #%v Valid After:\t%v\n", num, ca.NotBefore) + env.Printf("CA #%v Valid Until:\t%v\n", num, ca.NotAfter) } } -func printX509FederatedBundle(trustDomain string, bundle []*x509.Certificate) { +func printX509FederatedBundle(env *commoncli.Env, trustDomain string, bundle []*x509.Certificate) { for i, ca := range bundle { num := i + 1 - fmt.Printf("[%s] CA #%v Valid After:\t%v\n", trustDomain, num, ca.NotBefore) - fmt.Printf("[%s] CA #%v Valid Until:\t%v\n", trustDomain, num, ca.NotAfter) + env.Printf("[%s] CA #%v Valid After:\t%v\n", trustDomain, num, ca.NotBefore) + env.Printf("[%s] CA #%v Valid Until:\t%v\n", trustDomain, num, ca.NotAfter) } } diff --git a/cmd/spire-agent/cli/api/validate_jwt.go b/cmd/spire-agent/cli/api/validate_jwt.go index 72e2b784ee..aa6400ee33 100644 --- a/cmd/spire-agent/cli/api/validate_jwt.go +++ b/cmd/spire-agent/cli/api/validate_jwt.go @@ -8,23 +8,26 @@ import ( "github.com/mitchellh/cli" "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload" - common_cli "github.com/spiffe/spire/pkg/common/cli" + commoncli "github.com/spiffe/spire/pkg/common/cli" + "github.com/spiffe/spire/pkg/common/cliprinter" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" ) func NewValidateJWTCommand() cli.Command { - return newValidateJWTCommand(common_cli.DefaultEnv, newWorkloadClient) + return newValidateJWTCommand(commoncli.DefaultEnv, newWorkloadClient) } -func newValidateJWTCommand(env *common_cli.Env, clientMaker workloadClientMaker) cli.Command { - return adaptCommand(env, clientMaker, new(validateJWTCommand)) +func newValidateJWTCommand(env *commoncli.Env, clientMaker workloadClientMaker) cli.Command { + return adaptCommand(env, clientMaker, &validateJWTCommand{env: env}) } type validateJWTCommand struct { audience string svid string + env *commoncli.Env + printer cliprinter.Printer } func (*validateJWTCommand) name() string { @@ -38,9 +41,10 @@ func (*validateJWTCommand) synopsis() string { func (c *validateJWTCommand) appendFlags(fs *flag.FlagSet) { fs.StringVar(&c.audience, "audience", "", "expected audience value") fs.StringVar(&c.svid, "svid", "", "JWT SVID") + cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintValidate) } -func (c *validateJWTCommand) run(ctx context.Context, env *common_cli.Env, client *workloadClient) error { +func (c *validateJWTCommand) run(ctx context.Context, env *commoncli.Env, client *workloadClient) error { if c.audience == "" { return errors.New("audience must be specified") } @@ -53,17 +57,7 @@ func (c *validateJWTCommand) run(ctx context.Context, env *common_cli.Env, clien return err } - if err := env.Println("SVID is valid."); err != nil { - return err - } - if err := env.Println("SPIFFE ID :", resp.SpiffeId); err != nil { - return err - } - claims, err := protojson.Marshal(resp.Claims) - if err != nil { - return fmt.Errorf("unable to unmarshal claims: %w", err) - } - return env.Println("Claims :", string(claims)) + return c.printer.PrintProto(resp) } func (c *validateJWTCommand) validateJWTSVID(ctx context.Context, client *workloadClient) (*workload.ValidateJWTSVIDResponse, error) { @@ -81,3 +75,21 @@ func (c *validateJWTCommand) validateJWTSVID(ctx context.Context, client *worklo } return resp, nil } + +func prettyPrintValidate(env *commoncli.Env, results ...interface{}) error { + resp, ok := results[0].(*workload.ValidateJWTSVIDResponse) + if !ok { + return cliprinter.ErrInternalCustomPrettyFunc + } + if err := env.Println("SVID is valid."); err != nil { + return err + } + if err := env.Println("SPIFFE ID :", resp.SpiffeId); err != nil { + return err + } + claims, err := protojson.Marshal(resp.Claims) + if err != nil { + return fmt.Errorf("unable to unmarshal claims: %w", err) + } + return env.Println("Claims :", string(claims)) +} diff --git a/cmd/spire-agent/cli/api/watch.go b/cmd/spire-agent/cli/api/watch.go index f08babc7b9..06d2993226 100644 --- a/cmd/spire-agent/cli/api/watch.go +++ b/cmd/spire-agent/cli/api/watch.go @@ -11,6 +11,7 @@ import ( "github.com/spiffe/go-spiffe/v2/workloadapi" "github.com/spiffe/spire/cmd/spire-agent/cli/common" + commoncli "github.com/spiffe/spire/pkg/common/cli" "github.com/spiffe/spire/pkg/common/util" ) @@ -98,7 +99,7 @@ func (w *watcher) OnX509ContextUpdate(x509Context *workloadapi.X509Context) { FederatedBundles: federatedBundles, }) } - printX509SVIDResponse(svids, time.Since(w.updateTime)) + printX509SVIDResponse(commoncli.DefaultEnv, svids, time.Since(w.updateTime)) w.updateTime = time.Now() } diff --git a/test/fakes/fakeworkloadapi/workloadapi.go b/test/fakes/fakeworkloadapi/workloadapi.go index 14ab17ec4f..a016e54743 100644 --- a/test/fakes/fakeworkloadapi/workloadapi.go +++ b/test/fakes/fakeworkloadapi/workloadapi.go @@ -3,68 +3,56 @@ package fakeworkloadapi import ( "context" "errors" + "fmt" "net" - "sync" "testing" "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload" "github.com/spiffe/spire/test/spiretest" "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" ) -type Result interface { - result() -} - -type fetchX509SVIDResult func(workload.SpiffeWorkloadAPI_FetchX509SVIDServer) (bool, error) - -func (fn fetchX509SVIDResult) do(stream workload.SpiffeWorkloadAPI_FetchX509SVIDServer) (done bool, err error) { - return fn(stream) -} - -func (fn fetchX509SVIDResult) result() {} - -func FetchX509SVIDErrorOnce(err error) Result { - return fetchX509SVIDResult(func(workload.SpiffeWorkloadAPI_FetchX509SVIDServer) (bool, error) { - return true, err - }) -} - -func FetchX509SVIDErrorAlways(err error) Result { - return fetchX509SVIDResult(func(workload.SpiffeWorkloadAPI_FetchX509SVIDServer) (bool, error) { - return false, err - }) -} - -func FetchX509SVIDResponses(responses ...*workload.X509SVIDResponse) Result { - return fetchX509SVIDResult(func(stream workload.SpiffeWorkloadAPI_FetchX509SVIDServer) (bool, error) { - for _, response := range responses { - if err := stream.Send(response); err != nil { - return true, err - } - } - return true, nil - }) +type FakeRequest struct { + Req proto.Message + Resp proto.Message + Err error } type WorkloadAPI struct { workload.UnimplementedSpiffeWorkloadAPIServer addr net.Addr + t *testing.T + + ExpFetchJWTSVIDReq *workload.JWTSVIDRequest + ExpFetchJWTBundlesReq *workload.JWTBundlesRequest - mu sync.Mutex - fetchX509SVIDResults []fetchX509SVIDResult + fetchX509SVIDRequest FakeRequest + fetchJWTSVIDRequest FakeRequest + fetchJWTBundlesRequest FakeRequest + validateJWTRequest FakeRequest } -func New(t *testing.T, results ...Result) *WorkloadAPI { +func New(t *testing.T, responses ...*FakeRequest) *WorkloadAPI { w := new(WorkloadAPI) + w.t = t - for _, result := range results { - switch result := result.(type) { - case fetchX509SVIDResult: - w.fetchX509SVIDResults = append(w.fetchX509SVIDResults, result) + for _, response := range responses { + if response == nil { + continue + } + switch response.Resp.(type) { + case *workload.X509SVIDResponse: + w.fetchX509SVIDRequest = *response + case *workload.JWTSVIDResponse: + w.fetchJWTSVIDRequest = *response + case *workload.JWTBundlesResponse: + w.fetchJWTBundlesRequest = *response + case *workload.ValidateJWTSVIDResponse: + w.validateJWTRequest = *response default: - require.FailNow(t, "unexpected result type %T", result) + require.FailNow(t, "unexpected result type %T", response.Resp) } } @@ -82,57 +70,82 @@ func (w *WorkloadAPI) FetchX509SVID(req *workload.X509SVIDRequest, stream worklo return err } - // service all of the results - for { - result := w.nextFetchX509SVIDResult() - if result == nil { - break - } + if w.fetchX509SVIDRequest.Err != nil { + return w.fetchX509SVIDRequest.Err + } - done, err := result.do(stream) - if done { - w.advanceFetchX509SVIDResult() - } + if request, ok := w.fetchX509SVIDRequest.Req.(*workload.X509SVIDRequest); ok { + spiretest.AssertProtoEqual(w.t, request, req) + } else { + require.FailNow(w.t, fmt.Sprintf("unexpected message type %T", w.fetchX509SVIDRequest.Req)) + } - if err != nil { - return err - } + if response, ok := w.fetchX509SVIDRequest.Resp.(*workload.X509SVIDResponse); ok { + _ = stream.Send(response) + <-stream.Context().Done() + } else { + require.FailNow(w.t, fmt.Sprintf("unexpected message type %T", w.fetchX509SVIDRequest.Resp)) } - // wait for the context to be canceled - <-stream.Context().Done() return nil } -func (w *WorkloadAPI) FetchJWTSVID(context.Context, *workload.JWTSVIDRequest) (*workload.JWTSVIDResponse, error) { - return nil, errors.New("unimplemented") -} +func (w *WorkloadAPI) FetchJWTSVID(ctx context.Context, req *workload.JWTSVIDRequest) (*workload.JWTSVIDResponse, error) { + if w.fetchJWTSVIDRequest.Err != nil { + return nil, w.fetchJWTSVIDRequest.Err + } + if request, ok := w.fetchJWTSVIDRequest.Req.(*workload.JWTSVIDRequest); ok { + spiretest.AssertProtoEqual(w.t, request, req) + } else { + require.FailNow(w.t, fmt.Sprintf("unexpected message type %T", w.fetchJWTSVIDRequest.Req)) + } -func (w *WorkloadAPI) FetchJWTBundles(*workload.JWTBundlesRequest, workload.SpiffeWorkloadAPI_FetchJWTBundlesServer) error { - return errors.New("unimplemented") + if response, ok := w.fetchJWTSVIDRequest.Resp.(*workload.JWTSVIDResponse); ok { + return response, nil + } + require.FailNow(w.t, fmt.Sprintf("unexpected message type %T", w.fetchJWTSVIDRequest.Resp)) + return nil, nil } -func (w *WorkloadAPI) ValidateJWTSVID(context.Context, *workload.ValidateJWTSVIDRequest) (*workload.ValidateJWTSVIDResponse, error) { - return nil, errors.New("unimplemented") -} +func (w *WorkloadAPI) FetchJWTBundles(req *workload.JWTBundlesRequest, stream workload.SpiffeWorkloadAPI_FetchJWTBundlesServer) error { + if err := checkSecurityHeader(stream.Context()); err != nil { + return err + } + + if w.fetchJWTBundlesRequest.Err != nil { + return w.fetchJWTBundlesRequest.Err + } -func (w *WorkloadAPI) nextFetchX509SVIDResult() fetchX509SVIDResult { - w.mu.Lock() - defer w.mu.Unlock() + if request, ok := w.fetchJWTBundlesRequest.Req.(*workload.JWTBundlesRequest); ok { + spiretest.AssertProtoEqual(w.t, request, req) + } else { + require.FailNow(w.t, fmt.Sprintf("unexpected message type %T", w.fetchJWTBundlesRequest.Req)) + } - if len(w.fetchX509SVIDResults) == 0 { - return nil + if response, ok := w.fetchJWTBundlesRequest.Resp.(*workload.JWTBundlesResponse); ok { + _ = stream.Send(response) + <-stream.Context().Done() + } else { + require.FailNow(w.t, fmt.Sprintf("unexpected message type %T", w.fetchJWTBundlesRequest.Resp)) } - return w.fetchX509SVIDResults[0] + return nil } -func (w *WorkloadAPI) advanceFetchX509SVIDResult() { - w.mu.Lock() - defer w.mu.Unlock() +func (w *WorkloadAPI) ValidateJWTSVID(ctx context.Context, req *workload.ValidateJWTSVIDRequest) (*workload.ValidateJWTSVIDResponse, error) { + if w.validateJWTRequest.Err != nil { + return nil, w.validateJWTRequest.Err + } + if request, ok := w.validateJWTRequest.Req.(*workload.ValidateJWTSVIDRequest); ok { + spiretest.AssertProtoEqual(w.t, request, req) + } else { + require.FailNow(w.t, fmt.Sprintf("unexpected message type %T", w.validateJWTRequest.Req)) + } - if len(w.fetchX509SVIDResults) > 0 { - w.fetchX509SVIDResults = w.fetchX509SVIDResults[1:] + if response, ok := w.validateJWTRequest.Resp.(*workload.ValidateJWTSVIDResponse); ok { + return response, nil } + require.FailNow(w.t, fmt.Sprintf("unexpected message type %T", w.validateJWTRequest.Resp)) + return nil, nil } func checkSecurityHeader(ctx context.Context) error {