82 lines
2.1 KiB
Go
82 lines
2.1 KiB
Go
//go:build go1.18
|
|
// +build go1.18
|
|
|
|
package provisioner
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/x509"
|
|
"os"
|
|
"testing"
|
|
)
|
|
|
|
func TestACME_GetAttestationRoots(t *testing.T) {
|
|
appleCA, err := os.ReadFile("testdata/certs/apple-att-ca.crt")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
yubicoCA, err := os.ReadFile("testdata/certs/yubico-piv-ca.crt")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
pool := x509.NewCertPool()
|
|
pool.AppendCertsFromPEM(appleCA)
|
|
pool.AppendCertsFromPEM(yubicoCA)
|
|
|
|
type fields struct {
|
|
Type string
|
|
Name string
|
|
AttestationRoots []byte
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
want *x509.CertPool
|
|
want1 bool
|
|
}{
|
|
{"ok", fields{"ACME", "acme", bytes.Join([][]byte{appleCA, yubicoCA}, []byte("\n"))}, pool, true},
|
|
{"nil", fields{"ACME", "acme", nil}, nil, false},
|
|
{"empty", fields{"ACME", "acme", []byte{}}, nil, false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
p := &ACME{
|
|
Type: tt.fields.Type,
|
|
Name: tt.fields.Name,
|
|
AttestationRoots: tt.fields.AttestationRoots,
|
|
}
|
|
if err := p.Init(Config{
|
|
Claims: globalProvisionerClaims,
|
|
Audiences: testAudiences,
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
got, got1 := p.GetAttestationRoots()
|
|
switch {
|
|
case tt.want == nil && got == nil:
|
|
break
|
|
case tt.want == nil && got != nil, tt.want != nil && got == nil:
|
|
t.Errorf("ACME.GetAttestationRoots() got = %v, want %v", got, tt.want)
|
|
default:
|
|
//nolint:staticcheck // this file only runs in go1.18
|
|
gotSubjects := got.Subjects()
|
|
//nolint:staticcheck // this file only runs in go1.18
|
|
wantSubjects := tt.want.Subjects()
|
|
if len(gotSubjects) != len(wantSubjects) {
|
|
t.Errorf("ACME.GetAttestationRoots() got = %v, want %v", got, tt.want)
|
|
} else {
|
|
for i, gotSub := range gotSubjects {
|
|
if !bytes.Equal(gotSub, wantSubjects[i]) {
|
|
t.Errorf("ACME.GetAttestationRoots() got = %v, want %v", got, tt.want)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if got1 != tt.want1 {
|
|
t.Errorf("ACME.GetAttestationRoots() got1 = %v, want %v", got1, tt.want1)
|
|
}
|
|
})
|
|
}
|
|
}
|