From 7eb8aeb1f199c9b26cffe962a473f8286895bb3e Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 5 Nov 2018 12:22:10 -0800 Subject: [PATCH] Add tests for bootstrap functions. --- ca/bootstrap.go | 2 +- ca/bootstrap_test.go | 219 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 ca/bootstrap_test.go diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 522fa91b..880d4652 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -24,7 +24,7 @@ func Bootstrap(token string) (*Client, error) { } var claims tokenClaims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, errors.Wrap(err, "error parsing ott") + return nil, errors.Wrap(err, "error parsing token") } // Validate bootstrap token diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go new file mode 100644 index 00000000..821b82a0 --- /dev/null +++ b/ca/bootstrap_test.go @@ -0,0 +1,219 @@ +package ca + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" + + "github.com/smallstep/cli/crypto/randutil" + stepJOSE "github.com/smallstep/cli/jose" + jose "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +func startCABootstrapServer() *httptest.Server { + config, err := authority.LoadConfiguration("testdata/ca.json") + if err != nil { + panic(err) + } + srv := httptest.NewUnstartedServer(nil) + config.Address = srv.Listener.Addr().String() + ca, err := New(config) + if err != nil { + panic(err) + } + srv.Config.Handler = ca.srv.Handler + srv.TLS = ca.srv.TLSConfig + srv.StartTLS() + // Force the use of GetCertificate on IPs + srv.TLS.Certificates = nil + return srv +} + +func generateBootstrapToken(ca, subject, sha string) string { + now := time.Now() + jwk, err := stepJOSE.ParseKey("testdata/secrets/ott_mariano_priv.jwk", stepJOSE.WithPassword([]byte("password"))) + if err != nil { + panic(err) + } + opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts) + if err != nil { + panic(err) + } + id, err := randutil.ASCII(64) + if err != nil { + panic(err) + } + cl := struct { + SHA string `json:"sha"` + jwt.Claims + }{ + SHA: sha, + Claims: jwt.Claims{ + ID: id, + Subject: subject, + Issuer: "mariano", + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Minute)), + Audience: []string{ca + "/sign"}, + }, + } + raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + if err != nil { + panic(err) + } + return raw +} + +func TestBootstrap(t *testing.T) { + srv := startCABootstrapServer() + defer srv.Close() + token := generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + type args struct { + token string + } + tests := []struct { + name string + args args + want *Client + wantErr bool + }{ + {"ok", args{token}, client, false}, + {"token err", args{"badtoken"}, nil, true}, + {"bad claims", args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.foo.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"}, nil, true}, + {"bad sha", args{generateBootstrapToken(srv.URL, "subject", "")}, nil, true}, + {"bad aud", args{generateBootstrapToken("", "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Bootstrap(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("Bootstrap() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Bootstrap() = %v, want %v", got, tt.want) + } + } else { + if got == nil { + t.Error("Bootstrap() = nil, want not nil") + } else { + if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) { + t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint) + } + if !reflect.DeepEqual(got.certPool, tt.want.certPool) { + t.Errorf("Bootstrap() certPool = %v, want %v", got.certPool, tt.want.certPool) + } + } + } + }) + } +} + +func TestBootstrapServer(t *testing.T) { + srv := startCABootstrapServer() + defer srv.Close() + token := func() string { + return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + }) + type args struct { + addr string + token string + handler http.Handler + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{":0", token(), handler}, false}, + {"fail", args{":0", "bad-token", handler}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := BootstrapServer(tt.args.addr, tt.args.token, tt.args.handler) + if (err != nil) != tt.wantErr { + t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if got != nil { + t.Errorf("BootstrapServer() = %v, want nil", got) + } + } else { + if !reflect.DeepEqual(got.Addr, tt.args.addr) { + t.Errorf("BootstrapServer() Addr = %v, want %v", got.Addr, tt.args.addr) + } + if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil { + t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig) + } + } + }) + } +} + +func TestBootstrapClient(t *testing.T) { + srv := startCABootstrapServer() + defer srv.Close() + token := func() string { + return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { + token string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{token()}, false}, + {"fail", args{"bad-token"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := BootstrapClient(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("BootstrapClient() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if got != nil { + t.Errorf("BootstrapClient() = %v, want nil", got) + } + } else { + tlsConfig := got.Transport.(*http.Transport).TLSClientConfig + if tlsConfig == nil || tlsConfig.ClientCAs != nil || tlsConfig.GetClientCertificate == nil || tlsConfig.RootCAs == nil || tlsConfig.GetCertificate != nil { + t.Errorf("BootstrapClient() invalid Transport = %#v", tlsConfig) + } + resp, err := got.Post(srv.URL+"/renew", "application/json", http.NoBody) + if err != nil { + t.Errorf("BootstrapClient() failed renewing certificate") + return + } + var renewal api.SignResponse + if err := readJSON(resp.Body, &renewal); err != nil { + t.Errorf("BootstrapClient() error reading response: %v", err) + return + } + if renewal.CaPEM.Certificate == nil || renewal.ServerPEM.Certificate == nil { + t.Errorf("BootstrapClient() invalid renewal response: %v", renewal) + } + } + }) + } +}