Skip to content

Commit

Permalink
api/render, errs: moved StatusCoder & StackTracer to the render package
Browse files Browse the repository at this point in the history
  • Loading branch information
azazeal committed Mar 29, 2022
1 parent 691fbad commit 3d31449
Show file tree
Hide file tree
Showing 19 changed files with 266 additions and 253 deletions.
34 changes: 15 additions & 19 deletions api/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@ import (
"net/http"
"os"

"github.com/smallstep/certificates/errs"
"github.com/pkg/errors"

"github.com/smallstep/certificates/logging"
)

// StackTracedError is the set of errors implementing the StackTrace function.
//
// Errors implementing this interface have their stack traces logged when passed
// to the Error function of this package.
type StackTracedError interface {
error

StackTrace() errors.StackTrace
}

// Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
Expand All @@ -30,33 +41,18 @@ func Error(rw http.ResponseWriter, err error) {
return
}

e, ok := err.(errs.StackTracer)
e, ok := err.(StackTracedError)
if !ok {
e, ok = cause(err).(errs.StackTracer)
e, ok = errors.Cause(err).(StackTracedError)
}

if ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
"stack-trace": fmt.Sprintf("%+v", e.StackTrace()),
})
}
}

func cause(err error) error {
type causer interface {
Cause() error
}

for err != nil {
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return err
}

// EnabledResponse log the response object if it implements the EnableLogger
// interface.
func EnabledResponse(rw http.ResponseWriter, v interface{}) {
Expand Down
31 changes: 17 additions & 14 deletions api/render/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"google.golang.org/protobuf/proto"

"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/errs"
)

// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
Expand Down Expand Up @@ -87,32 +86,36 @@ func Error(w http.ResponseWriter, err error) {
JSONStatus(w, err, statusCodeFromError(err))
}

func statusCodeFromError(err error) (code int) {
code = http.StatusInternalServerError
// StatusCodedError is the set of errors that implement the basic StatusCode
// function.
//
// Errors that implement this interface will use the code reported by StatusCode
// as the HTTP response code when being rendered by this package.
type StatusCodedError interface {
error

// if the error implements err
if sc, ok := err.(errs.StatusCoder); ok {
code = sc.StatusCode()
StatusCode() int
}

return
}
func statusCodeFromError(err error) (code int) {
code = http.StatusInternalServerError

type causer interface {
Cause() error
}

for err != nil {
cause, ok := err.(causer)
if !ok {
if sc, ok := err.(StatusCodedError); ok {
code = sc.StatusCode()

break
}
err = cause.Cause()

if sc, ok := err.(errs.StatusCoder); ok {
code = sc.StatusCode()

cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}

return
Expand Down
49 changes: 26 additions & 23 deletions authority/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,26 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"errors"
"fmt"
"net/http"
"reflect"
"strconv"
"testing"
"time"

"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"

"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"

"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
)

var testAudiences = provisioner.Audiences{
Expand Down Expand Up @@ -310,8 +313,8 @@ func TestAuthority_authorizeToken(t *testing.T) {
p, err := tc.auth.authorizeToken(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand Down Expand Up @@ -396,8 +399,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) {

if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand Down Expand Up @@ -481,8 +484,8 @@ func TestAuthority_authorizeSign(t *testing.T) {
got, err := tc.auth.authorizeSign(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand Down Expand Up @@ -740,8 +743,8 @@ func TestAuthority_Authorize(t *testing.T) {
if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, got)
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())

Expand Down Expand Up @@ -853,7 +856,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
err := tc.auth.authorizeRenew(tc.cert)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
Expand Down Expand Up @@ -1001,8 +1004,8 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand Down Expand Up @@ -1118,8 +1121,8 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand Down Expand Up @@ -1218,8 +1221,8 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {

if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand Down Expand Up @@ -1311,8 +1314,8 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand Down
17 changes: 9 additions & 8 deletions authority/provisioner/acme_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package provisioner
import (
"context"
"crypto/x509"
"errors"
"fmt"
"net/http"
"testing"
"time"

"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/api/render"
)

func TestACME_Getters(t *testing.T) {
Expand Down Expand Up @@ -114,7 +115,7 @@ func TestACME_AuthorizeRenew(t *testing.T) {
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized,
err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
}
},
"ok": func(t *testing.T) test {
Expand All @@ -133,8 +134,8 @@ func TestACME_AuthorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
Expand Down Expand Up @@ -168,8 +169,8 @@ func TestACME_AuthorizeSign(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand All @@ -192,7 +193,7 @@ func TestACME_AuthorizeSign(t *testing.T) {
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
Expand Down
25 changes: 13 additions & 12 deletions authority/provisioner/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"net"
"net/http"
Expand All @@ -17,10 +18,10 @@ import (
"testing"
"time"

"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"

"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)

func TestAWS_Getters(t *testing.T) {
Expand Down Expand Up @@ -521,8 +522,8 @@ func TestAWS_authorizeToken(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
Expand Down Expand Up @@ -668,8 +669,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
case err != nil:
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
Expand Down Expand Up @@ -698,7 +699,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
Expand Down Expand Up @@ -802,8 +803,8 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
Expand Down Expand Up @@ -860,8 +861,8 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
Expand Down
Loading

0 comments on commit 3d31449

Please sign in to comment.