forked from TrueCloudLab/certificates
Improve handling of bad JSON protobuf bodies
This commit is contained in:
parent
2ca5c0170f
commit
def9438ad6
4 changed files with 90 additions and 77 deletions
|
@ -10,7 +10,6 @@ import (
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api/render"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,62 +23,55 @@ func JSON(r io.Reader, v interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProtoJSON reads JSON from the request body and stores it in the value
|
// ProtoJSON reads JSON from the request body and stores it in the value
|
||||||
// pointed by v.
|
// pointed to by v.
|
||||||
func ProtoJSON(r io.Reader, m proto.Message) error {
|
func ProtoJSON(r io.Reader, m proto.Message) error {
|
||||||
data, err := io.ReadAll(r)
|
data, err := io.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.BadRequestErr(err, "error reading request body")
|
return errs.BadRequestErr(err, "error reading request body")
|
||||||
}
|
}
|
||||||
return protojson.Unmarshal(data, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProtoJSONWithCheck reads JSON from the request body and stores it in the value
|
|
||||||
// pointed to by m. Returns false if an error was written; true if not.
|
|
||||||
// TODO(hs): refactor this after the API flow changes are in (or before if that works)
|
|
||||||
func ProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
var wrapper = struct {
|
|
||||||
Status int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}{
|
|
||||||
Status: http.StatusBadRequest,
|
|
||||||
Message: err.Error(),
|
|
||||||
}
|
|
||||||
errData, err := json.Marshal(wrapper)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
w.Write(errData)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err := protojson.Unmarshal(data, m); err != nil {
|
if err := protojson.Unmarshal(data, m); err != nil {
|
||||||
if errors.Is(err, proto.Error) {
|
if errors.Is(err, proto.Error) {
|
||||||
var wrapper = struct {
|
return newBadProtoJSONError(err)
|
||||||
Type string `json:"type"`
|
|
||||||
Detail string `json:"detail"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}{
|
|
||||||
Type: "badRequest",
|
|
||||||
Detail: "bad request",
|
|
||||||
Message: err.Error(),
|
|
||||||
}
|
|
||||||
errData, err := json.Marshal(wrapper)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
w.Write(errData)
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// fallback to the default error writer
|
// BadProtoJSONError is an error type that is used when a proto
|
||||||
render.Error(w, err)
|
// message cannot be unmarshaled. Usually this is caused by an error
|
||||||
return false
|
// in the request body.
|
||||||
|
type BadProtoJSONError struct {
|
||||||
|
err error
|
||||||
|
Type string `json:"type"`
|
||||||
|
Detail string `json:"detail"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// newBadProtoJSONError returns a new instance of BadProtoJSONError
|
||||||
|
// This error type is always caused by an error in the request body.
|
||||||
|
func newBadProtoJSONError(err error) *BadProtoJSONError {
|
||||||
|
return &BadProtoJSONError{
|
||||||
|
err: err,
|
||||||
|
Type: "badRequest",
|
||||||
|
Detail: "bad request",
|
||||||
|
Message: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface
|
||||||
|
func (e *BadProtoJSONError) Error() string {
|
||||||
|
return e.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render implements render.RenderableError for BadProtoError
|
||||||
|
func (e *BadProtoJSONError) Render(w http.ResponseWriter) {
|
||||||
|
|
||||||
|
errData, err := json.Marshal(e)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write(errData)
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,8 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPolicy = new(linkedca.Policy)
|
var newPolicy = new(linkedca.Policy)
|
||||||
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
|
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||||
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,7 +121,8 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPolicy = new(linkedca.Policy)
|
var newPolicy = new(linkedca.Policy)
|
||||||
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
|
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||||
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,7 +197,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPolicy = new(linkedca.Policy)
|
var newPolicy = new(linkedca.Policy)
|
||||||
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
|
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||||
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,7 +231,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPolicy = new(linkedca.Policy)
|
var newPolicy = new(linkedca.Policy)
|
||||||
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
|
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||||
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -297,7 +301,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPolicy = new(linkedca.Policy)
|
var newPolicy = new(linkedca.Policy)
|
||||||
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
|
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||||
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -324,7 +329,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPolicy = new(linkedca.Policy)
|
var newPolicy = new(linkedca.Policy)
|
||||||
if !read.ProtoJSONWithCheck(w, r.Body, newPolicy) {
|
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||||
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -167,7 +167,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
||||||
statusCode: 409,
|
statusCode: 409,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
|
"fail/read.ProtoJSON": func(t *testing.T) test {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
|
adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
|
||||||
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
|
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
|
||||||
|
@ -410,7 +410,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
|
"fail/read.ProtoJSON": func(t *testing.T) test {
|
||||||
policy := &linkedca.Policy{
|
policy := &linkedca.Policy{
|
||||||
X509: &linkedca.X509Policy{
|
X509: &linkedca.X509Policy{
|
||||||
Allow: &linkedca.X509Names{
|
Allow: &linkedca.X509Names{
|
||||||
|
@ -871,7 +871,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
||||||
statusCode: 409,
|
statusCode: 409,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
|
"fail/read.ProtoJSON": func(t *testing.T) test {
|
||||||
prov := &linkedca.Provisioner{
|
prov := &linkedca.Provisioner{
|
||||||
Name: "provName",
|
Name: "provName",
|
||||||
}
|
}
|
||||||
|
@ -1060,7 +1060,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
|
"fail/read.ProtoJSON": func(t *testing.T) test {
|
||||||
policy := &linkedca.Policy{
|
policy := &linkedca.Policy{
|
||||||
X509: &linkedca.X509Policy{
|
X509: &linkedca.X509Policy{
|
||||||
Allow: &linkedca.X509Names{
|
Allow: &linkedca.X509Names{
|
||||||
|
@ -1472,7 +1472,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
||||||
statusCode: 409,
|
statusCode: 409,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
|
"fail/read.ProtoJSON": func(t *testing.T) test {
|
||||||
prov := &linkedca.Provisioner{
|
prov := &linkedca.Provisioner{
|
||||||
Name: "provName",
|
Name: "provName",
|
||||||
}
|
}
|
||||||
|
@ -1637,7 +1637,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/read.ProtoJSONWithCheck": func(t *testing.T) test {
|
"fail/read.ProtoJSON": func(t *testing.T) test {
|
||||||
policy := &linkedca.Policy{
|
policy := &linkedca.Policy{
|
||||||
X509: &linkedca.X509Policy{
|
X509: &linkedca.X509Policy{
|
||||||
Allow: &linkedca.X509Names{
|
Allow: &linkedca.X509Names{
|
||||||
|
|
|
@ -8,18 +8,21 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"go.step.sm/linkedca"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"go.step.sm/linkedca"
|
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHandler_GetProvisioner(t *testing.T) {
|
func TestHandler_GetProvisioner(t *testing.T) {
|
||||||
|
@ -335,12 +338,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
body: body,
|
body: body,
|
||||||
statusCode: 500,
|
statusCode: 400,
|
||||||
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
err: &admin.Error{
|
||||||
Type: "",
|
Type: "badRequest",
|
||||||
Status: 500,
|
Status: 400,
|
||||||
Detail: "",
|
Detail: "bad request",
|
||||||
Message: "",
|
Message: "proto: syntax error (line 1:2): invalid value !",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -423,9 +426,15 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||||
|
|
||||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
|
||||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||||
|
|
||||||
|
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||||
|
assert.True(t, strings.Contains(tc.err.Message, "syntax error"))
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -616,12 +625,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
body: body,
|
body: body,
|
||||||
statusCode: 500,
|
statusCode: 400,
|
||||||
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
err: &admin.Error{
|
||||||
Type: "",
|
Type: "badRequest",
|
||||||
Status: 500,
|
Status: 400,
|
||||||
Detail: "",
|
Detail: "bad request",
|
||||||
Message: "",
|
Message: "proto: syntax error (line 1:2): invalid value !",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1074,9 +1083,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||||
|
|
||||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
|
||||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||||
|
|
||||||
|
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||||
|
assert.True(t, strings.Contains(tc.err.Message, "syntax error"))
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue