api/read: reworked JSON & ProtoJSON to use buffers & added AdminJSON
This commit is contained in:
parent
2fd84227f0
commit
a715e57d04
2 changed files with 115 additions and 43 deletions
|
@ -2,30 +2,91 @@
|
|||
package read
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
|
||||
"github.com/smallstep/certificates/internal/buffer"
|
||||
)
|
||||
|
||||
// JSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func JSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return errs.BadRequestErr(err, "error decoding json")
|
||||
// JSON unmarshals from the given request's JSON body into v. In case of an
|
||||
// error a HTTP Bad Request error will be written to w.
|
||||
func JSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
|
||||
b := read(w, r)
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
return nil
|
||||
defer buffer.Put(b)
|
||||
|
||||
if err := json.NewDecoder(b).Decode(v); err != nil {
|
||||
err = fmt.Errorf("error decoding json: %w", err)
|
||||
|
||||
render.BadRequest(w, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AdminJSON is obsolete; it's here for backwards compatibility.
|
||||
//
|
||||
// Please don't use.
|
||||
func AdminJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
|
||||
b := read(w, r)
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
defer buffer.Put(b)
|
||||
|
||||
if err := json.NewDecoder(b).Decode(v); err != nil {
|
||||
e := admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")
|
||||
admin.WriteError(w, e)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ProtoJSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func ProtoJSON(r io.Reader, m proto.Message) error {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return errs.BadRequestErr(err, "error reading request body")
|
||||
func ProtoJSON(w http.ResponseWriter, r *http.Request, m proto.Message) bool {
|
||||
b := read(w, r)
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
return protojson.Unmarshal(data, m)
|
||||
defer buffer.Put(b)
|
||||
|
||||
if err := protojson.Unmarshal(b.Bytes(), m); err != nil {
|
||||
err = fmt.Errorf("error decoding proto json: %w", err)
|
||||
|
||||
render.BadRequest(w, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func read(w http.ResponseWriter, r *http.Request) *bytes.Buffer {
|
||||
b := buffer.Get()
|
||||
if _, err := b.ReadFrom(r.Body); err != nil {
|
||||
buffer.Put(b)
|
||||
|
||||
err = fmt.Errorf("error reading request body: %w", err)
|
||||
|
||||
render.BadRequest(w, err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -2,45 +2,56 @@ package read
|
|||
|
||||
import (
|
||||
"io"
|
||||
"reflect"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
type args struct {
|
||||
r io.Reader
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
cases := []struct {
|
||||
src io.Reader
|
||||
exp interface{}
|
||||
ok bool
|
||||
code int
|
||||
}{
|
||||
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
||||
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
||||
0: {
|
||||
src: strings.NewReader(`{"foo":"bar"}`),
|
||||
exp: map[string]interface{}{"foo": "bar"},
|
||||
ok: true,
|
||||
code: http.StatusOK,
|
||||
},
|
||||
1: {
|
||||
src: strings.NewReader(`{"foo"}`),
|
||||
code: http.StatusBadRequest,
|
||||
},
|
||||
2: {
|
||||
src: io.MultiReader(
|
||||
strings.NewReader(`{`),
|
||||
iotest.ErrReader(assert.AnError),
|
||||
strings.NewReader(`"foo":"bar"}`),
|
||||
),
|
||||
code: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := JSON(tt.args.r, &tt.args.v)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
e, ok := err.(*errs.Error)
|
||||
if ok {
|
||||
if code := e.StatusCode(); code != 400 {
|
||||
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("error type = %T, wants *Error", err)
|
||||
}
|
||||
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
||||
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
||||
}
|
||||
for caseIndex := range cases {
|
||||
kase := cases[caseIndex]
|
||||
|
||||
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", kase.src)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
var body interface{}
|
||||
got := JSON(rec, req, &body)
|
||||
|
||||
assert.Equal(t, kase.ok, got)
|
||||
assert.Equal(t, kase.code, rec.Result().StatusCode)
|
||||
assert.Equal(t, kase.exp, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue