Improve handling of bad JSON protobuf bodies

This commit is contained in:
Herman Slatman 2022-04-18 23:38:13 +02:00
parent 2ca5c0170f
commit def9438ad6
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
4 changed files with 90 additions and 77 deletions

View file

@ -10,7 +10,6 @@ import (
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -24,62 +23,55 @@ func JSON(r io.Reader, v interface{}) error {
} }
// ProtoJSON reads JSON from the request body and stores it in the value // ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v. // pointed to by v.
func ProtoJSON(r io.Reader, m proto.Message) error { func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return errs.BadRequestErr(err, "error reading request body") return errs.BadRequestErr(err, "error reading request body")
} }
return protojson.Unmarshal(data, m)
}
// ProtoJSONWithCheck reads JSON from the request body and stores it in the value
// pointed to by m. Returns false if an error was written; true if not.
// TODO(hs): refactor this after the API flow changes are in (or before if that works)
func ProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool {
data, err := io.ReadAll(r)
if err != nil {
var wrapper = struct {
Status int `json:"code"`
Message string `json:"message"`
}{
Status: http.StatusBadRequest,
Message: err.Error(),
}
errData, err := json.Marshal(wrapper)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write(errData)
return false
}
if err := protojson.Unmarshal(data, m); err != nil { if err := protojson.Unmarshal(data, m); err != nil {
if errors.Is(err, proto.Error) { if errors.Is(err, proto.Error) {
var wrapper = struct { return newBadProtoJSONError(err)
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}{
Type: "badRequest",
Detail: "bad request",
Message: err.Error(),
}
errData, err := json.Marshal(wrapper)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write(errData)
return false
} }
}
return err
}
// fallback to the default error writer // BadProtoJSONError is an error type that is used when a proto
render.Error(w, err) // message cannot be unmarshaled. Usually this is caused by an error
return false // in the request body.
type BadProtoJSONError struct {
err error
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}
// newBadProtoJSONError returns a new instance of BadProtoJSONError
// This error type is always caused by an error in the request body.
func newBadProtoJSONError(err error) *BadProtoJSONError {
return &BadProtoJSONError{
err: err,
Type: "badRequest",
Detail: "bad request",
Message: err.Error(),
}
}
// Error implements the error interface
func (e *BadProtoJSONError) Error() string {
return e.err.Error()
}
// Render implements render.RenderableError for BadProtoError
func (e *BadProtoJSONError) Render(w http.ResponseWriter) {
errData, err := json.Marshal(e)
if err != nil {
panic(err)
} }
return true w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write(errData)
} }

View file

@ -80,7 +80,8 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return return
} }
@ -120,7 +121,8 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return return
} }
@ -195,7 +197,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return return
} }
@ -228,7 +231,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return return
} }
@ -297,7 +301,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return return
} }
@ -324,7 +329,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return return
} }

View file

@ -167,7 +167,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
statusCode: 409, statusCode: 409,
} }
}, },
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test { "fail/read.ProtoJSON": func(t *testing.T) test {
ctx := context.Background() ctx := context.Background()
adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
@ -410,7 +410,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
statusCode: 404, statusCode: 404,
} }
}, },
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test { "fail/read.ProtoJSON": func(t *testing.T) test {
policy := &linkedca.Policy{ policy := &linkedca.Policy{
X509: &linkedca.X509Policy{ X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{ Allow: &linkedca.X509Names{
@ -871,7 +871,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
statusCode: 409, statusCode: 409,
} }
}, },
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test { "fail/read.ProtoJSON": func(t *testing.T) test {
prov := &linkedca.Provisioner{ prov := &linkedca.Provisioner{
Name: "provName", Name: "provName",
} }
@ -1060,7 +1060,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
statusCode: 404, statusCode: 404,
} }
}, },
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test { "fail/read.ProtoJSON": func(t *testing.T) test {
policy := &linkedca.Policy{ policy := &linkedca.Policy{
X509: &linkedca.X509Policy{ X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{ Allow: &linkedca.X509Names{
@ -1472,7 +1472,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
statusCode: 409, statusCode: 409,
} }
}, },
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test { "fail/read.ProtoJSON": func(t *testing.T) test {
prov := &linkedca.Provisioner{ prov := &linkedca.Provisioner{
Name: "provName", Name: "provName",
} }
@ -1637,7 +1637,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
statusCode: 404, statusCode: 404,
} }
}, },
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test { "fail/read.ProtoJSON": func(t *testing.T) test {
policy := &linkedca.Policy{ policy := &linkedca.Policy{
X509: &linkedca.X509Policy{ X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{ Allow: &linkedca.X509Names{

View file

@ -8,18 +8,21 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
) )
func TestHandler_GetProvisioner(t *testing.T) { func TestHandler_GetProvisioner(t *testing.T) {
@ -335,12 +338,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
return test{ return test{
ctx: context.Background(), ctx: context.Background(),
body: body, body: body,
statusCode: 500, statusCode: 400,
err: &admin.Error{ // TODO(hs): this probably needs a better error err: &admin.Error{
Type: "", Type: "badRequest",
Status: 500, Status: 400,
Detail: "", Detail: "bad request",
Message: "", Message: "proto: syntax error (line 1:2): invalid value !",
}, },
} }
}, },
@ -423,9 +426,15 @@ func TestHandler_CreateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(tc.err.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return return
} }
@ -616,12 +625,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{ return test{
ctx: context.Background(), ctx: context.Background(),
body: body, body: body,
statusCode: 500, statusCode: 400,
err: &admin.Error{ // TODO(hs): this probably needs a better error err: &admin.Error{
Type: "", Type: "badRequest",
Status: 500, Status: 400,
Detail: "", Detail: "bad request",
Message: "", Message: "proto: syntax error (line 1:2): invalid value !",
}, },
} }
}, },
@ -1074,9 +1083,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(tc.err.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return return
} }