diff --git a/ca/tls_test.go b/ca/tls_test.go index b88e825a..7d2f2d46 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/cli/crypto/randutil" @@ -412,6 +413,64 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { } } +func TestClient_GetCertificateRenewer(t *testing.T) { + reset := setMinCertDuration(1 * time.Second) + defer reset() + + // Start CA + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second) + + badOption := func(ctx *TLSOptionCtx) error { + return errors.New("foo") + } + + type args struct { + sign *api.SignResponse + pk crypto.PrivateKey + options []TLSOption + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{sr, pk, nil}, false}, + {"bad-pk", args{sr, []byte("foo"), nil}, true}, + {"bad-option", args{sr, pk, []TLSOption{badOption}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := client.GetCertificateRenewer(tt.args.sign, tt.args.pk, tt.args.options...) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetCertificateRenewer() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr == false { + fn := got.RenewCertificate + var i int + got.RenewCertificate = func() (*tls.Certificate, error) { + cert, err := fn() + if err != nil { + t.Errorf("TLSRenewer.RenewCertificate() error = %v", err) + } else { + i++ + } + return cert, err + } + got.Run() + time.Sleep(5 * time.Second) + if i == 0 { + t.Errorf("Client.GetCertificateRenewer() certificate was not renewed") + } + got.Stop() + } + }) + } +} + func TestCertificate(t *testing.T) { cert := parseCertificate(certPEM) ok := &api.SignResponse{