api/read: reworked JSON & ProtoJSON to use buffers & added AdminJSON
This commit is contained in:
parent
2fd84227f0
commit
a715e57d04
2 changed files with 115 additions and 43 deletions
|
@ -2,30 +2,91 @@
|
||||||
package read
|
package read
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/internal/buffer"
|
||||||
)
|
)
|
||||||
|
|
||||||
// JSON reads JSON from the request body and stores it in the value
|
// JSON unmarshals from the given request's JSON body into v. In case of an
|
||||||
// pointed by v.
|
// error a HTTP Bad Request error will be written to w.
|
||||||
func JSON(r io.Reader, v interface{}) error {
|
func JSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
|
||||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
b := read(w, r)
|
||||||
return errs.BadRequestErr(err, "error decoding json")
|
if b == nil {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
return nil
|
defer buffer.Put(b)
|
||||||
|
|
||||||
|
if err := json.NewDecoder(b).Decode(v); err != nil {
|
||||||
|
err = fmt.Errorf("error decoding json: %w", err)
|
||||||
|
|
||||||
|
render.BadRequest(w, err)
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminJSON is obsolete; it's here for backwards compatibility.
|
||||||
|
//
|
||||||
|
// Please don't use.
|
||||||
|
func AdminJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
|
||||||
|
b := read(w, r)
|
||||||
|
if b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer buffer.Put(b)
|
||||||
|
|
||||||
|
if err := json.NewDecoder(b).Decode(v); err != nil {
|
||||||
|
e := admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")
|
||||||
|
admin.WriteError(w, e)
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProtoJSON reads JSON from the request body and stores it in the value
|
// ProtoJSON reads JSON from the request body and stores it in the value
|
||||||
// pointed by v.
|
// pointed by v.
|
||||||
func ProtoJSON(r io.Reader, m proto.Message) error {
|
func ProtoJSON(w http.ResponseWriter, r *http.Request, m proto.Message) bool {
|
||||||
data, err := io.ReadAll(r)
|
b := read(w, r)
|
||||||
if err != nil {
|
if b == nil {
|
||||||
return errs.BadRequestErr(err, "error reading request body")
|
return false
|
||||||
}
|
}
|
||||||
return protojson.Unmarshal(data, m)
|
defer buffer.Put(b)
|
||||||
|
|
||||||
|
if err := protojson.Unmarshal(b.Bytes(), m); err != nil {
|
||||||
|
err = fmt.Errorf("error decoding proto json: %w", err)
|
||||||
|
|
||||||
|
render.BadRequest(w, err)
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func read(w http.ResponseWriter, r *http.Request) *bytes.Buffer {
|
||||||
|
b := buffer.Get()
|
||||||
|
if _, err := b.ReadFrom(r.Body); err != nil {
|
||||||
|
buffer.Put(b)
|
||||||
|
|
||||||
|
err = fmt.Errorf("error reading request body: %w", err)
|
||||||
|
|
||||||
|
render.BadRequest(w, err)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,45 +2,56 @@ package read
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/iotest"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJSON(t *testing.T) {
|
func TestJSON(t *testing.T) {
|
||||||
type args struct {
|
cases := []struct {
|
||||||
r io.Reader
|
src io.Reader
|
||||||
v interface{}
|
exp interface{}
|
||||||
}
|
ok bool
|
||||||
tests := []struct {
|
code int
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
0: {
|
||||||
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
src: strings.NewReader(`{"foo":"bar"}`),
|
||||||
|
exp: map[string]interface{}{"foo": "bar"},
|
||||||
|
ok: true,
|
||||||
|
code: http.StatusOK,
|
||||||
|
},
|
||||||
|
1: {
|
||||||
|
src: strings.NewReader(`{"foo"}`),
|
||||||
|
code: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
2: {
|
||||||
|
src: io.MultiReader(
|
||||||
|
strings.NewReader(`{`),
|
||||||
|
iotest.ErrReader(assert.AnError),
|
||||||
|
strings.NewReader(`"foo":"bar"}`),
|
||||||
|
),
|
||||||
|
code: http.StatusBadRequest,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
err := JSON(tt.args.r, &tt.args.v)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantErr {
|
for caseIndex := range cases {
|
||||||
e, ok := err.(*errs.Error)
|
kase := cases[caseIndex]
|
||||||
if ok {
|
|
||||||
if code := e.StatusCode(); code != 400 {
|
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
|
||||||
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
req := httptest.NewRequest(http.MethodGet, "/", kase.src)
|
||||||
}
|
rec := httptest.NewRecorder()
|
||||||
} else {
|
|
||||||
t.Errorf("error type = %T, wants *Error", err)
|
var body interface{}
|
||||||
}
|
got := JSON(rec, req, &body)
|
||||||
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
|
||||||
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
assert.Equal(t, kase.ok, got)
|
||||||
}
|
assert.Equal(t, kase.code, rec.Result().StatusCode)
|
||||||
|
assert.Equal(t, kase.exp, body)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue