From def9438ad62e2c71975e2ac1a4f96b6f5ec663a5 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 18 Apr 2022 23:38:13 +0200 Subject: [PATCH] Improve handling of bad JSON protobuf bodies --- api/read/read.go | 88 +++++++++++-------------- authority/admin/api/policy.go | 18 +++-- authority/admin/api/policy_test.go | 12 ++-- authority/admin/api/provisioner_test.go | 49 +++++++++----- 4 files changed, 90 insertions(+), 77 deletions(-) diff --git a/api/read/read.go b/api/read/read.go index 2f5175d9..7482c272 100644 --- a/api/read/read.go +++ b/api/read/read.go @@ -10,7 +10,6 @@ import ( "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" - "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) @@ -24,62 +23,55 @@ func JSON(r io.Reader, v interface{}) error { } // ProtoJSON reads JSON from the request body and stores it in the value -// pointed by v. +// pointed to by v. func ProtoJSON(r io.Reader, m proto.Message) error { data, err := io.ReadAll(r) if err != nil { return errs.BadRequestErr(err, "error reading request body") } - return protojson.Unmarshal(data, m) -} - -// ProtoJSONWithCheck reads JSON from the request body and stores it in the value -// pointed to by m. Returns false if an error was written; true if not. -// TODO(hs): refactor this after the API flow changes are in (or before if that works) -func ProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool { - data, err := io.ReadAll(r) - if err != nil { - var wrapper = struct { - Status int `json:"code"` - Message string `json:"message"` - }{ - Status: http.StatusBadRequest, - Message: err.Error(), - } - errData, err := json.Marshal(wrapper) - if err != nil { - panic(err) - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - w.Write(errData) - return false - } if err := protojson.Unmarshal(data, m); err != nil { if errors.Is(err, proto.Error) { - var wrapper = struct { - Type string `json:"type"` - Detail string `json:"detail"` - Message string `json:"message"` - }{ - Type: "badRequest", - Detail: "bad request", - Message: err.Error(), - } - errData, err := json.Marshal(wrapper) - if err != nil { - panic(err) - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - w.Write(errData) - return false + return newBadProtoJSONError(err) } + } + return err +} - // fallback to the default error writer - render.Error(w, err) - return false +// BadProtoJSONError is an error type that is used when a proto +// message cannot be unmarshaled. Usually this is caused by an error +// in the request body. +type BadProtoJSONError struct { + err error + Type string `json:"type"` + Detail string `json:"detail"` + Message string `json:"message"` +} + +// newBadProtoJSONError returns a new instance of BadProtoJSONError +// This error type is always caused by an error in the request body. +func newBadProtoJSONError(err error) *BadProtoJSONError { + return &BadProtoJSONError{ + err: err, + Type: "badRequest", + Detail: "bad request", + Message: err.Error(), + } +} + +// Error implements the error interface +func (e *BadProtoJSONError) Error() string { + return e.err.Error() +} + +// Render implements render.RenderableError for BadProtoError +func (e *BadProtoJSONError) Render(w http.ResponseWriter) { + + errData, err := json.Marshal(e) + if err != nil { + panic(err) } - return true + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write(errData) } diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index 17bc454c..b7c7855f 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -80,7 +80,8 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r } var newPolicy = new(linkedca.Policy) - if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) return } @@ -120,7 +121,8 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r } var newPolicy = new(linkedca.Policy) - if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) return } @@ -195,7 +197,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, } var newPolicy = new(linkedca.Policy) - if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) return } @@ -228,7 +231,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, } var newPolicy = new(linkedca.Policy) - if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) return } @@ -297,7 +301,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, } var newPolicy = new(linkedca.Policy) - if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) return } @@ -324,7 +329,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, } var newPolicy = new(linkedca.Policy) - if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) { + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) return } diff --git a/authority/admin/api/policy_test.go b/authority/admin/api/policy_test.go index 5717e73a..cc4f64fb 100644 --- a/authority/admin/api/policy_test.go +++ b/authority/admin/api/policy_test.go @@ -167,7 +167,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { statusCode: 409, } }, - "fail/read.ProtoJSONWithCheck": func(t *testing.T) test { + "fail/read.ProtoJSON": func(t *testing.T) test { ctx := context.Background() adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" @@ -410,7 +410,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { statusCode: 404, } }, - "fail/read.ProtoJSONWithCheck": func(t *testing.T) test { + "fail/read.ProtoJSON": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ @@ -871,7 +871,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { statusCode: 409, } }, - "fail/read.ProtoJSONWithCheck": func(t *testing.T) test { + "fail/read.ProtoJSON": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } @@ -1060,7 +1060,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { statusCode: 404, } }, - "fail/read.ProtoJSONWithCheck": func(t *testing.T) test { + "fail/read.ProtoJSON": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ @@ -1472,7 +1472,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { statusCode: 409, } }, - "fail/read.ProtoJSONWithCheck": func(t *testing.T) test { + "fail/read.ProtoJSON": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } @@ -1637,7 +1637,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { statusCode: 404, } }, - "fail/read.ProtoJSONWithCheck": func(t *testing.T) test { + "fail/read.ProtoJSON": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go index 6d5024f2..de7c3646 100644 --- a/authority/admin/api/provisioner_test.go +++ b/authority/admin/api/provisioner_test.go @@ -8,18 +8,21 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/types/known/timestamppb" ) func TestHandler_GetProvisioner(t *testing.T) { @@ -335,12 +338,12 @@ func TestHandler_CreateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, - statusCode: 500, - err: &admin.Error{ // TODO(hs): this probably needs a better error - Type: "", - Status: 500, - Detail: "", - Message: "", + statusCode: 400, + err: &admin.Error{ + Type: "badRequest", + Status: 400, + Detail: "bad request", + Message: "proto: syntax error (line 1:2): invalid value !", }, } }, @@ -423,9 +426,15 @@ func TestHandler_CreateProvisioner(t *testing.T) { assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) - assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(tc.err.Message, "syntax error")) + } else { + assert.Equals(t, tc.err.Message, adminErr.Message) + } + return } @@ -616,12 +625,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, - statusCode: 500, - err: &admin.Error{ // TODO(hs): this probably needs a better error - Type: "", - Status: 500, - Detail: "", - Message: "", + statusCode: 400, + err: &admin.Error{ + Type: "badRequest", + Status: 400, + Detail: "bad request", + Message: "proto: syntax error (line 1:2): invalid value !", }, } }, @@ -1074,9 +1083,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) { assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) - assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(tc.err.Message, "syntax error")) + } else { + assert.Equals(t, tc.err.Message, adminErr.Message) + } + return }