diff --git a/api/api_test.go b/api/api_test.go index ec567a48..1c02e6da 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -18,6 +18,7 @@ import ( "github.com/go-chi/chi" "github.com/smallstep/cli/crypto/tlsutil" + "github.com/smallstep/cli/jose" ) const ( @@ -95,6 +96,18 @@ Q7vMNPBWrJWu+A++vHY61WGET+h4lY3GFr2I8OE4IiHPQi1D7Y0+fwOmStwuRPM4 58jHzJwr1K7cx0lpWfGTtc5bseCGtTKmDBXTziw04yl8eE1+ZFOganixGwCtl4Tt DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` + + pubKey = `{ + "use": "sig", + "kty": "EC", + "kid": "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00", + "crv": "P-256", + "alg": "ES256", + "x": "p9QX4tzjxUrB0fgqRWLKUuPolDtBW681f2Qyh-uVNhk", + "y": "CNSEloc4oLDFTX0Vywj0WiqOlh516sFQwCj6WtM8LT8" +}` + + privKey = "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiNEhBYjE0WDQ5OFM4LWxSb29JTnpqZyJ9.RbkJXGzI3kOsaP20KmZs0ELFLgpRddAE49AJHlEblw-uH_gg6SV3QA.M3MArEpHgI171lhm.gBlFySpzK9F7riBJbtLSNkb4nAw_gWokqs1jS-ZK1qxuqTK-9mtX5yILjRnftx9P9uFp5xt7rvv4Mgom1Ed4V9WtIyfNP_Cz3Pme1Eanp5nY68WCe_yG6iSB1RJdMDBUb2qBDZiBdhJim1DRXsOfgedOrNi7GGbppMlD77DEpId118owR5izA-c6Q_hg08hIE3tnMAnebDNQoF9jfEY99_AReVRH8G4hgwZEPCfXMTb3J-lowKGG4vXIbK5knFLh47SgOqG4M2M51SMS-XJ7oBz1Vjoamc90QIqKV51rvZ5m0N_sPFtxzcfV4E9yYH3XVd4O-CG4ydVKfKVyMtQ.mcKFZqBHp_n7Ytj2jz9rvw" ) func parseCertificate(data string) *x509.Certificate { @@ -377,13 +390,15 @@ func TestSignRequest_Validate(t *testing.T) { } type mockAuthority struct { - ret1, ret2 interface{} - err error - authorize func(ott string) ([]Claim, error) - getTLSOptions func() *tlsutil.TLSOptions - root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, opts SignOptions, claims ...Claim) (*x509.Certificate, *x509.Certificate, error) - renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) + ret1, ret2 interface{} + err error + authorize func(ott string) ([]Claim, error) + getTLSOptions func() *tlsutil.TLSOptions + root func(shasum string) (*x509.Certificate, error) + sign func(cr *x509.CertificateRequest, opts SignOptions, claims ...Claim) (*x509.Certificate, *x509.Certificate, error) + renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) + getProvisioners func() (map[string]*jose.JSONWebKeySet, error) + getEncryptedKey func(kid string) (string, error) } func (m *mockAuthority) Authorize(ott string) ([]Claim, error) { @@ -429,6 +444,44 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509. return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err } +func (m *mockAuthority) GetProvisioners() (map[string]*jose.JSONWebKeySet, error) { + if m.getProvisioners != nil { + return m.getProvisioners() + } + return m.ret1.(map[string]*jose.JSONWebKeySet), m.err +} + +func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { + if m.getEncryptedKey != nil { + return m.getEncryptedKey(kid) + } + return m.ret1.(string), m.err +} + +func Test_caHandler_Route(t *testing.T) { + type fields struct { + Authority Authority + } + type args struct { + r Router + } + tests := []struct { + name string + fields fields + args args + }{ + {"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &caHandler{ + Authority: tt.fields.Authority, + } + h.Route(tt.args.r) + }) + } +} + func Test_caHandler_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", nil) w := httptest.NewRecorder() @@ -616,3 +669,135 @@ func Test_caHandler_Renew(t *testing.T) { }) } } + +func Test_caHandler_Provisioners(t *testing.T) { + type fields struct { + Authority Authority + } + type args struct { + w http.ResponseWriter + r *http.Request + } + + req, err := http.NewRequest("GET", "http://example.com/provisioners", nil) + if err != nil { + t.Fatal(err) + } + + var key jose.JSONWebKey + if err := json.Unmarshal([]byte(pubKey), &key); err != nil { + t.Fatal(err) + } + + p := map[string]*jose.JSONWebKeySet{ + "p1": &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{key}, + }, + "p2": &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{key}, + }, + } + + tests := []struct { + name string + fields fields + args args + statusCode int + }{ + {"ok", fields{&mockAuthority{ret1: p}}, args{httptest.NewRecorder(), req}, 200}, + {"fail", fields{&mockAuthority{ret1: p, err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, + } + + expectedKey, err := json.Marshal(key) + if err != nil { + t.Fatal(err) + } + expected := []byte(`{"provisioners":{"p1":{"keys":[` + string(expectedKey) + `]},"p2":{"keys":[` + string(expectedKey) + `]}}}`) + expectedError := []byte(`{"status":500,"message":"Internal Server Error"}`) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &caHandler{ + Authority: tt.fields.Authority, + } + h.Provisioners(tt.args.w, tt.args.r) + + rec := tt.args.w.(*httptest.ResponseRecorder) + res := rec.Result() + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Provisioners unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), expected) { + t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected) + } + } else { + if !bytes.Equal(bytes.TrimSpace(body), expectedError) { + t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError) + } + } + }) + } +} + +func Test_caHandler_ProvisionerKey(t *testing.T) { + type fields struct { + Authority Authority + } + type args struct { + w http.ResponseWriter + r *http.Request + } + + // Request with chi context + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("kid", "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00") + req := httptest.NewRequest("GET", "http://example.com/provisioners/oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00/encrypted-key", nil) + req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) + + tests := []struct { + name string + fields fields + args args + statusCode int + }{ + {"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200}, + {"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404}, + } + + expected := []byte(`{"key":"` + privKey + `"}`) + expectedError := []byte(`{"status":404,"message":"Not Found"}`) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &caHandler{ + Authority: tt.fields.Authority, + } + h.ProvisionerKey(tt.args.w, tt.args.r) + + rec := tt.args.w.(*httptest.ResponseRecorder) + res := rec.Result() + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Provisioners unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), expected) { + t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected) + } + } else { + if !bytes.Equal(bytes.TrimSpace(body), expectedError) { + t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError) + } + } + }) + } +}