From 13173ec8a2d3fdc4150304aeb60937bb4ddf20df Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Sun, 1 May 2022 22:29:17 +0200 Subject: [PATCH 1/2] Fix SCEP GET requests --- scep/api/api.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scep/api/api.go b/scep/api/api.go index 31f0f10d..fcabfc58 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -86,7 +86,7 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) { case opnGetCACaps: res, err = h.GetCACaps(ctx) case opnPKIOperation: - // TODO: implement the GET for PKI operation? Default CACAPS doesn't specify this is in use, though + res, err = h.PKIOperation(ctx, req) default: err = fmt.Errorf("unknown operation: %s", req.Operation) } @@ -151,8 +151,8 @@ func decodeRequest(r *http.Request) (request, error) { if _, ok := query["message"]; ok { message = query.Get("message") } - // TODO: verify this; it seems like it should be StdEncoding instead of URLEncoding - decodedMessage, err := base64.URLEncoding.DecodeString(message) + // TODO: verify this; right type of encoding? Needs additional transformations? + decodedMessage, err := base64.StdEncoding.DecodeString(message) if err != nil { return request{}, err } From 688ae837a4beaa8ec1c147a8c33aa2f4c22ef561 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Sat, 7 May 2022 00:26:18 +0200 Subject: [PATCH 2/2] Add some tests for SCEP request decoding --- scep/api/api_test.go | 113 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 scep/api/api_test.go diff --git a/scep/api/api_test.go b/scep/api/api_test.go new file mode 100644 index 00000000..bdb51594 --- /dev/null +++ b/scep/api/api_test.go @@ -0,0 +1,113 @@ +// Package api implements a SCEP HTTP server. +package api + +import ( + "bytes" + "errors" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "testing/iotest" +) + +func Test_decodeRequest(t *testing.T) { + type args struct { + r *http.Request + } + tests := []struct { + name string + args args + want request + wantErr bool + }{ + { + name: "fail/unsupported-method", + args: args{ + r: httptest.NewRequest(http.MethodPatch, "http://scep:8080/?operation=AnUnsupportOperation", nil), + }, + want: request{}, + wantErr: true, + }, + { + name: "fail/get-unsupported-operation", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=AnUnsupportOperation", nil), + }, + want: request{}, + wantErr: true, + }, + { + name: "fail/get-PKIOperation", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=PKIOperation&message='somewronginput'", nil), + }, + want: request{}, + wantErr: true, + }, + { + name: "fail/post-PKIOperation", + args: args{ + r: httptest.NewRequest(http.MethodPost, "http://scep:8080/?operation=PKIOperation", iotest.ErrReader(errors.New("a read error"))), + }, + want: request{}, + wantErr: true, + }, + { + name: "ok/get-GetCACert", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=GetCACert", nil), + }, + want: request{ + Operation: "GetCACert", + Message: []byte{}, + }, + wantErr: false, + }, + { + name: "ok/get-GetCACaps", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=GetCACaps", nil), + }, + want: request{ + Operation: "GetCACaps", + Message: []byte{}, + }, + wantErr: false, + }, + { + name: "ok/get-PKIOperation", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=PKIOperation&message=MTIzNA==", nil), + }, + want: request{ + Operation: "PKIOperation", + Message: []byte("1234"), + }, + wantErr: false, + }, + { + name: "ok/post-PKIOperation", + args: args{ + r: httptest.NewRequest(http.MethodPost, "http://scep:8080/?operation=PKIOperation", bytes.NewBufferString("1234")), + }, + want: request{ + Operation: "PKIOperation", + Message: []byte("1234"), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := decodeRequest(tt.args.r) + if (err != nil) != tt.wantErr { + t.Errorf("decodeRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("decodeRequest() = %v, want %v", got, tt.want) + } + }) + } +}