From bacbf85aa30718bec0fd8c79e32cc683b11b2b48 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 17 Jan 2019 14:48:33 -0800 Subject: [PATCH] Add new bootstrap method that creates a listener. --- ca/bootstrap.go | 53 +++++++++++++++++++++++++++++++++ ca/bootstrap_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 8989b3c0..6c532d5c 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -2,6 +2,8 @@ package ca import ( "context" + "crypto/tls" + "net" "net/http" "strings" @@ -145,3 +147,54 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (* Transport: transport, }, nil } + +// BootstrapListener is a helper function that using the given token returns a +// TLS listener which accepts connections from an inner listener and wraps each +// connection with Server. +// +// Without any extra option the server will be configured for mTLS, it will +// require and verify clients certificates, but options can be used to drop this +// requirement, the most common will be only verify the certs if given with +// ca.VerifyClientCertIfGiven(), or add extra CAs with +// ca.AddClientCA(*x509.Certificate). +// +// Usage: +// inner, err := net.Listen("tcp", ":443") +// if err != nil { +// return nil +// } +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// lis, err := ca.BootstrapListener(ctx, token, inner) +// if err != nil { +// return err +// } +// srv := grpc.NewServer() +// ... // register services +// srv.Serve(lis) +func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) { + client, err := Bootstrap(token) + if err != nil { + return nil, err + } + + req, pk, err := CreateSignRequest(token) + if err != nil { + return nil, err + } + + sign, err := client.Sign(req) + if err != nil { + return nil, err + } + + // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs + options = append(options, AddRootsToCAs()) + + tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) + if err != nil { + return nil, err + } + + return tls.NewListener(inner, tlsConfig), nil +} diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 6a0a7159..560d9c47 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -8,13 +8,13 @@ import ( "net/http" "net/http/httptest" "reflect" + "sync" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" - "github.com/smallstep/cli/crypto/randutil" stepJOSE "github.com/smallstep/cli/jose" jose "gopkg.in/square/go-jose.v2" @@ -530,3 +530,71 @@ func doReload(ca *CA) error { newCA.srv.Addr = ca.srv.Addr return ca.srv.Reload(newCA.srv) } + +func TestBootstrapListener(t *testing.T) { + srv := startCABootstrapServer() + defer srv.Close() + token := func() string { + return generateBootstrapToken(srv.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { + token string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{token()}, false}, + {"fail", args{"bad-token"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inner := newLocalListener() + defer inner.Close() + lis, err := BootstrapListener(context.Background(), tt.args.token, inner) + if (err != nil) != tt.wantErr { + t.Errorf("BootstrapListener() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if lis != nil { + t.Errorf("BootstrapListener() = %v, want nil", lis) + } + return + } + wg := new(sync.WaitGroup) + go func() { + wg.Add(1) + http.Serve(lis, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + wg.Done() + }() + defer wg.Wait() + defer lis.Close() + + client, err := BootstrapClient(context.Background(), token()) + if err != nil { + t.Errorf("BootstrapClient() error = %v", err) + return + } + println("https://" + lis.Addr().String()) + resp, err := client.Get("https://" + lis.Addr().String()) + if err != nil { + t.Errorf("client.Get() error = %v", err) + return + } + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("ioutil.ReadAll() error = %v", err) + return + } + if string(b) != "ok" { + t.Errorf("client.Get() = %s, want ok", string(b)) + return + } + }) + } +}