More unit tests for nebula.
This commit is contained in:
parent
99845d38bb
commit
de51c2edfb
1 changed files with 256 additions and 10 deletions
|
@ -7,11 +7,13 @@ import (
|
|||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/randutil"
|
||||
|
@ -26,7 +28,7 @@ func mustNebulaIPNet(t *testing.T, s string) *net.IPNet {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ip.To4() == nil {
|
||||
if ip = ip.To4(); ip == nil {
|
||||
t.Fatalf("nebula only supports ipv4, have %s", s)
|
||||
}
|
||||
ipNet.IP = ip
|
||||
|
@ -46,7 +48,7 @@ func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) {
|
|||
Ips: []*net.IPNet{
|
||||
mustNebulaIPNet(t, "10.1.0.0/16"),
|
||||
},
|
||||
Subnets: nil,
|
||||
Subnets: []*net.IPNet{},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(10 * time.Minute),
|
||||
PublicKey: pub,
|
||||
|
@ -72,16 +74,24 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
invertedGroups := make(map[string]struct{}, len(groups))
|
||||
for _, name := range groups {
|
||||
invertedGroups[name] = struct{}{}
|
||||
}
|
||||
|
||||
t1 := time.Now().Truncate(time.Second)
|
||||
nc := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: name,
|
||||
Ips: []*net.IPNet{ipNet},
|
||||
Groups: groups,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(5 * time.Minute),
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
Issuer: issuer,
|
||||
Name: name,
|
||||
Ips: []*net.IPNet{ipNet},
|
||||
Subnets: []*net.IPNet{},
|
||||
Groups: groups,
|
||||
NotBefore: t1,
|
||||
NotAfter: t1.Add(5 * time.Minute),
|
||||
PublicKey: pub,
|
||||
IsCA: false,
|
||||
Issuer: issuer,
|
||||
InvertedGroups: invertedGroups,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -244,6 +254,10 @@ func TestNebula_Init(t *testing.T) {
|
|||
{"fail type", fields{"", "Nebulous", ncPem, nil, nil}, args{cfg}, true},
|
||||
{"fail name", fields{"Nebula", "", ncPem, nil, nil}, args{cfg}, true},
|
||||
{"fail root", fields{"Nebula", "Nebulous", nil, nil, nil}, args{cfg}, true},
|
||||
{"fail bad root", fields{"Nebula", "Nebulous", ncPem[:16], nil, nil}, args{cfg}, true},
|
||||
{"fail bad claims", fields{"Nebula", "Nebulous", ncPem, &Claims{
|
||||
MinTLSDur: &Duration{Duration: 0},
|
||||
}, nil}, args{cfg}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -707,3 +721,235 @@ func TestNebula_AuthorizeSSHRekey(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNebula_authorizeToken(t *testing.T) {
|
||||
t1 := now()
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv)
|
||||
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{
|
||||
CertType: "host",
|
||||
KeyID: "test.lan",
|
||||
Principals: []string{"test.lan"},
|
||||
}, crt, priv)
|
||||
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv)
|
||||
|
||||
// Token with errors
|
||||
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
|
||||
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv)
|
||||
failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
|
||||
// Not a nebula token
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
simpleToken, err := generateSimpleToken("iss", "aud", jwk)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Provisioner with a different CA
|
||||
p2, _, _ := mustNebulaProvisioner(t)
|
||||
|
||||
x509Claims := jose.Claims{
|
||||
ID: "[REPLACEME]",
|
||||
Subject: "test.lan",
|
||||
Issuer: p.Name,
|
||||
IssuedAt: jose.NewNumericDate(t1),
|
||||
NotBefore: jose.NewNumericDate(t1),
|
||||
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
|
||||
Audience: []string{p.audiences.Sign[0]},
|
||||
}
|
||||
sshClaims := jose.Claims{
|
||||
ID: "[REPLACEME]",
|
||||
Subject: "test.lan",
|
||||
Issuer: p.Name,
|
||||
IssuedAt: jose.NewNumericDate(t1),
|
||||
NotBefore: jose.NewNumericDate(t1),
|
||||
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
|
||||
Audience: []string{p.audiences.SSHSign[0]},
|
||||
}
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
audiences []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
p *Nebula
|
||||
args args
|
||||
want *cert.NebulaCertificate
|
||||
want1 *jwtPayload
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{
|
||||
Claims: x509Claims,
|
||||
SANs: []string{"10.1.0.1"},
|
||||
}, false},
|
||||
{"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{
|
||||
Claims: x509Claims,
|
||||
}, false},
|
||||
{"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{
|
||||
Claims: sshClaims,
|
||||
Step: &stepPayload{
|
||||
SSH: &SignSSHOptions{
|
||||
CertType: "host",
|
||||
KeyID: "test.lan",
|
||||
Principals: []string{"test.lan"},
|
||||
},
|
||||
},
|
||||
}, false},
|
||||
{"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{
|
||||
Claims: sshClaims,
|
||||
}, false},
|
||||
{"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true},
|
||||
{"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want)
|
||||
t.Error(cmp.Equal(got, tt.want))
|
||||
}
|
||||
|
||||
if got1 != nil && tt.want1 != nil {
|
||||
tt.want1.ID = got1.ID
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, tt.want1) {
|
||||
t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_nebulaSANsValidator_Valid(t *testing.T) {
|
||||
ipNet := mustNebulaIPNet(t, "10.1.2.3/16")
|
||||
type fields struct {
|
||||
Name string
|
||||
IPs []*net.IPNet
|
||||
}
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"dns.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
DNSNames: []string{"dns.name"},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)},
|
||||
}}, false},
|
||||
{"ok name only", fields{"dns.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
DNSNames: []string{"dns.name"},
|
||||
}}, false},
|
||||
{"ok ip only", fields{"dns.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)},
|
||||
}}, false},
|
||||
{"ok email name", fields{"jane@doe.org", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
EmailAddresses: []string{"jane@doe.org"},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)},
|
||||
}}, false},
|
||||
{"ok uri name", fields{"urn:foobar", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
URIs: []*url.URL{{Scheme: "urn", Opaque: "foobar"}},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)},
|
||||
}}, false},
|
||||
{"ok ip name", fields{"127.0.0.1", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(10, 1, 2, 3)},
|
||||
}}, false},
|
||||
{"ok multiple ips", fields{"dns.name", []*net.IPNet{ipNet, mustNebulaIPNet(t, "10.2.2.3/8")}}, args{&x509.CertificateRequest{
|
||||
DNSNames: []string{"dns.name"},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3), net.IPv4(10, 2, 2, 3)},
|
||||
}}, false},
|
||||
{"fail dns", fields{"fail.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
DNSNames: []string{"dns.name"},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)},
|
||||
}}, true},
|
||||
{"fail email", fields{"fail@doe.org", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
EmailAddresses: []string{"jane@doe.org"},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)},
|
||||
}}, true},
|
||||
{"fail uri", fields{"urn:barfoo", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
URIs: []*url.URL{{Scheme: "urn", Opaque: "foobar"}},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)},
|
||||
}}, true},
|
||||
{"fail ip", fields{"127.0.0.1", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.IPv4(10, 1, 2, 1), net.IPv4(10, 1, 2, 3)},
|
||||
}}, true},
|
||||
{"fail nebula ip", fields{"dns.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{
|
||||
DNSNames: []string{"dns.name"},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 2, 2, 3)},
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := nebulaSANsValidator{
|
||||
Name: tt.fields.Name,
|
||||
IPs: tt.fields.IPs,
|
||||
}
|
||||
if err := v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("nebulaSANsValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_nebulaPrincipalsValidator_Valid(t *testing.T) {
|
||||
ipNet := mustNebulaIPNet(t, "10.1.2.3/16")
|
||||
|
||||
type fields struct {
|
||||
Name string
|
||||
IPs []*net.IPNet
|
||||
}
|
||||
type args struct {
|
||||
got SignSSHOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{
|
||||
Principals: []string{"dns.name", "10.1.2.3"},
|
||||
}}, false},
|
||||
{"ok name", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{
|
||||
Principals: []string{"dns.name"},
|
||||
}}, false},
|
||||
{"ok ip", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{
|
||||
Principals: []string{"10.1.2.3"},
|
||||
}}, false},
|
||||
{"fail name", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{
|
||||
Principals: []string{"foo.name", "10.1.2.3"},
|
||||
}}, true},
|
||||
{"fail ip", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{
|
||||
Principals: []string{"dns.name", "10.2.2.3"},
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := nebulaPrincipalsValidator{
|
||||
Name: tt.fields.Name,
|
||||
IPs: tt.fields.IPs,
|
||||
}
|
||||
if err := v.Valid(tt.args.got); (err != nil) != tt.wantErr {
|
||||
t.Errorf("nebulaPrincipalsValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue