api/read: reworked JSON & ProtoJSON to use buffers & added AdminJSON

This commit is contained in:
Panagiotis Siatras 2022-03-22 18:38:03 +02:00
parent 2fd84227f0
commit a715e57d04
No known key found for this signature in database
GPG key ID: 529695F03A572804
2 changed files with 115 additions and 43 deletions

View file

@ -2,30 +2,91 @@
package read package read
import ( import (
"bytes"
"encoding/json" "encoding/json"
"io" "fmt"
"net/http"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/internal/buffer"
) )
// JSON reads JSON from the request body and stores it in the value // JSON unmarshals from the given request's JSON body into v. In case of an
// pointed by v. // error a HTTP Bad Request error will be written to w.
func JSON(r io.Reader, v interface{}) error { func JSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
if err := json.NewDecoder(r).Decode(v); err != nil { b := read(w, r)
return errs.BadRequestErr(err, "error decoding json") if b == nil {
return false
} }
return nil defer buffer.Put(b)
if err := json.NewDecoder(b).Decode(v); err != nil {
err = fmt.Errorf("error decoding json: %w", err)
render.BadRequest(w, err)
return false
}
return true
}
// AdminJSON is obsolete; it's here for backwards compatibility.
//
// Please don't use.
func AdminJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
b := read(w, r)
if b == nil {
return false
}
defer buffer.Put(b)
if err := json.NewDecoder(b).Decode(v); err != nil {
e := admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")
admin.WriteError(w, e)
return false
}
return true
} }
// ProtoJSON reads JSON from the request body and stores it in the value // ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v. // pointed by v.
func ProtoJSON(r io.Reader, m proto.Message) error { func ProtoJSON(w http.ResponseWriter, r *http.Request, m proto.Message) bool {
data, err := io.ReadAll(r) b := read(w, r)
if err != nil { if b == nil {
return errs.BadRequestErr(err, "error reading request body") return false
} }
return protojson.Unmarshal(data, m) defer buffer.Put(b)
if err := protojson.Unmarshal(b.Bytes(), m); err != nil {
err = fmt.Errorf("error decoding proto json: %w", err)
render.BadRequest(w, err)
return false
}
return true
}
func read(w http.ResponseWriter, r *http.Request) *bytes.Buffer {
b := buffer.Get()
if _, err := b.ReadFrom(r.Body); err != nil {
buffer.Put(b)
err = fmt.Errorf("error reading request body: %w", err)
render.BadRequest(w, err)
return nil
}
return b
} }

View file

@ -2,45 +2,56 @@ package read
import ( import (
"io" "io"
"reflect" "net/http"
"net/http/httptest"
"strconv"
"strings" "strings"
"testing" "testing"
"testing/iotest"
"github.com/smallstep/certificates/errs" "github.com/stretchr/testify/assert"
) )
func TestJSON(t *testing.T) { func TestJSON(t *testing.T) {
type args struct { cases := []struct {
r io.Reader src io.Reader
v interface{} exp interface{}
} ok bool
tests := []struct { code int
name string
args args
wantErr bool
}{ }{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false}, 0: {
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true}, src: strings.NewReader(`{"foo":"bar"}`),
} exp: map[string]interface{}{"foo": "bar"},
for _, tt := range tests { ok: true,
t.Run(tt.name, func(t *testing.T) { code: http.StatusOK,
err := JSON(tt.args.r, &tt.args.v) },
if (err != nil) != tt.wantErr { 1: {
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr) src: strings.NewReader(`{"foo"}`),
code: http.StatusBadRequest,
},
2: {
src: io.MultiReader(
strings.NewReader(`{`),
iotest.ErrReader(assert.AnError),
strings.NewReader(`"foo":"bar"}`),
),
code: http.StatusBadRequest,
},
} }
if tt.wantErr { for caseIndex := range cases {
e, ok := err.(*errs.Error) kase := cases[caseIndex]
if ok {
if code := e.StatusCode(); code != 400 { t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
t.Errorf("error.StatusCode() = %v, wants 400", code) req := httptest.NewRequest(http.MethodGet, "/", kase.src)
} rec := httptest.NewRecorder()
} else {
t.Errorf("error type = %T, wants *Error", err) var body interface{}
} got := JSON(rec, req, &body)
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"}) assert.Equal(t, kase.ok, got)
} assert.Equal(t, kase.code, rec.Result().StatusCode)
assert.Equal(t, kase.exp, body)
}) })
} }
} }