Add method to renew the identity.
This commit is contained in:
parent
aa58940582
commit
839fe6b952
2 changed files with 155 additions and 2 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
@ -191,6 +192,62 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Renewer is that interface that a renew client must implement.
|
||||
type Renewer interface {
|
||||
GetRootCAs() *x509.CertPool
|
||||
Renew(tr http.RoundTripper) (*api.SignResponse, error)
|
||||
}
|
||||
|
||||
// Renew renews the current identity certificate using a client with a renew
|
||||
// method.
|
||||
func (i *Identity) Renew(client Renewer) error {
|
||||
switch i.Kind() {
|
||||
case Disabled:
|
||||
return nil
|
||||
case MutualTLS:
|
||||
cert, err := i.TLSCertificate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.TLSClientConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
RootCAs: client.GetRootCAs(),
|
||||
PreferServerCipherSuites: true,
|
||||
}
|
||||
|
||||
sign, err := client.Renew(tr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sign.CertChainPEM == nil || len(sign.CertChainPEM) == 0 {
|
||||
sign.CertChainPEM = []api.Certificate{sign.ServerPEM, sign.CaPEM}
|
||||
}
|
||||
|
||||
// Write certificate
|
||||
buf := new(bytes.Buffer)
|
||||
for _, crt := range sign.CertChainPEM {
|
||||
block := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: crt.Raw,
|
||||
}
|
||||
if err := pem.Encode(buf, block); err != nil {
|
||||
return errors.Wrap(err, "error encoding identity certificate")
|
||||
}
|
||||
}
|
||||
certFilename := filepath.Join(identityDir, "identity.crt")
|
||||
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
|
||||
return errors.Wrap(err, "error writing identity certificate")
|
||||
}
|
||||
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf("unsupported identity type %s", i.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func fileExists(filename string) error {
|
||||
info, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
|
|
|
@ -3,15 +3,17 @@ package identity
|
|||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestLoadDefaultIdentity(t *testing.T) {
|
||||
|
@ -252,3 +254,97 @@ func TestWriteDefaultIdentity(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
type renewer struct {
|
||||
pool *x509.CertPool
|
||||
sign *api.SignResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *renewer) GetRootCAs() *x509.CertPool {
|
||||
return r.pool
|
||||
}
|
||||
|
||||
func (r *renewer) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
||||
return r.sign, r.err
|
||||
}
|
||||
|
||||
func TestIdentity_Renew(t *testing.T) {
|
||||
tmpDir, err := ioutil.TempDir(os.TempDir(), "go-tests")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
oldIdentityDir := identityDir
|
||||
defer func() {
|
||||
identityDir = oldIdentityDir
|
||||
os.RemoveAll(tmpDir)
|
||||
}()
|
||||
|
||||
certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ok := &renewer{
|
||||
sign: &api.SignResponse{
|
||||
ServerPEM: api.Certificate{Certificate: certs[0]},
|
||||
CaPEM: api.Certificate{Certificate: certs[1]},
|
||||
CertChainPEM: []api.Certificate{
|
||||
{Certificate: certs[0]},
|
||||
{Certificate: certs[1]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
okOld := &renewer{
|
||||
sign: &api.SignResponse{
|
||||
ServerPEM: api.Certificate{Certificate: certs[0]},
|
||||
CaPEM: api.Certificate{Certificate: certs[1]},
|
||||
},
|
||||
}
|
||||
|
||||
fail := &renewer{
|
||||
err: fmt.Errorf("an error"),
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Type string
|
||||
Certificate string
|
||||
Key string
|
||||
}
|
||||
type args struct {
|
||||
client Renewer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare func()
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, false},
|
||||
{"ok old", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{okOld}, false},
|
||||
{"ok disabled", func() {}, fields{}, args{nil}, false},
|
||||
{"fail type", func() {}, fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
|
||||
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
|
||||
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
|
||||
{"fail write identity", func() {
|
||||
identityDir = filepath.Join(tmpDir, "bad-dir")
|
||||
os.MkdirAll(identityDir, 0600)
|
||||
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.prepare()
|
||||
i := &Identity{
|
||||
Type: tt.fields.Type,
|
||||
Certificate: tt.fields.Certificate,
|
||||
Key: tt.fields.Key,
|
||||
}
|
||||
if err := i.Renew(tt.args.client); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Identity.Renew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue