api/read: reworked JSON & ProtoJSON to use buffers & added AdminJSON

This commit is contained in:
Panagiotis Siatras 2022-03-22 18:38:03 +02:00
parent 2fd84227f0
commit a715e57d04
No known key found for this signature in database
GPG key ID: 529695F03A572804
2 changed files with 115 additions and 43 deletions

View file

@ -2,30 +2,91 @@
package read
import (
"bytes"
"encoding/json"
"io"
"fmt"
"net/http"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/internal/buffer"
)
// JSON reads JSON from the request body and stores it in the value
// pointed by v.
func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
// JSON unmarshals from the given request's JSON body into v. In case of an
// error a HTTP Bad Request error will be written to w.
func JSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
b := read(w, r)
if b == nil {
return false
}
return nil
defer buffer.Put(b)
if err := json.NewDecoder(b).Decode(v); err != nil {
err = fmt.Errorf("error decoding json: %w", err)
render.BadRequest(w, err)
return false
}
return true
}
// AdminJSON is obsolete; it's here for backwards compatibility.
//
// Please don't use.
func AdminJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
b := read(w, r)
if b == nil {
return false
}
defer buffer.Put(b)
if err := json.NewDecoder(b).Decode(v); err != nil {
e := admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")
admin.WriteError(w, e)
return false
}
return true
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
func ProtoJSON(w http.ResponseWriter, r *http.Request, m proto.Message) bool {
b := read(w, r)
if b == nil {
return false
}
return protojson.Unmarshal(data, m)
defer buffer.Put(b)
if err := protojson.Unmarshal(b.Bytes(), m); err != nil {
err = fmt.Errorf("error decoding proto json: %w", err)
render.BadRequest(w, err)
return false
}
return true
}
func read(w http.ResponseWriter, r *http.Request) *bytes.Buffer {
b := buffer.Get()
if _, err := b.ReadFrom(r.Body); err != nil {
buffer.Put(b)
err = fmt.Errorf("error reading request body: %w", err)
render.BadRequest(w, err)
return nil
}
return b
}

View file

@ -2,45 +2,56 @@ package read
import (
"io"
"reflect"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"testing/iotest"
"github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert"
)
func TestJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
cases := []struct {
src io.Reader
exp interface{}
ok bool
code int
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
0: {
src: strings.NewReader(`{"foo":"bar"}`),
exp: map[string]interface{}{"foo": "bar"},
ok: true,
code: http.StatusOK,
},
1: {
src: strings.NewReader(`{"foo"}`),
code: http.StatusBadRequest,
},
2: {
src: io.MultiReader(
strings.NewReader(`{`),
iotest.ErrReader(assert.AnError),
strings.NewReader(`"foo":"bar"}`),
),
code: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := JSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
for caseIndex := range cases {
kase := cases[caseIndex]
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", kase.src)
rec := httptest.NewRecorder()
var body interface{}
got := JSON(rec, req, &body)
assert.Equal(t, kase.ok, got)
assert.Equal(t, kase.code, rec.Result().StatusCode)
assert.Equal(t, kase.exp, body)
})
}
}