Fix some provisioner tests

This commit is contained in:
Mariano Cano 2019-11-14 15:26:37 -08:00 committed by max furman
parent 29be322b1c
commit 7db7b1ee4c
9 changed files with 39 additions and 23 deletions

View file

@ -360,7 +360,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
} }
} }
func TestAWS_AuthorizeSign_SSH(t *testing.T) { func TestAWS_AuthorizeSSHSign(t *testing.T) {
tm, fn := mockNow() tm, fn := mockNow()
defer fn() defer fn()
@ -425,9 +425,9 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token) got, err := tt.aws.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {

View file

@ -310,7 +310,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
} }
} }
func TestAzure_AuthorizeSign_SSH(t *testing.T) { func TestAzure_AuthorizeSSHSign(t *testing.T) {
tm, fn := mockNow() tm, fn := mockNow()
defer fn() defer fn()
@ -365,9 +365,9 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token) got, err := tt.azure.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {

View file

@ -345,7 +345,7 @@ func TestGCP_AuthorizeSign(t *testing.T) {
} }
} }
func TestGCP_AuthorizeSign_SSH(t *testing.T) { func TestGCP_AuthorizeSSHSign(t *testing.T) {
tm, fn := mockNow() tm, fn := mockNow()
defer fn() defer fn()
@ -412,9 +412,9 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token) got, err := tt.gcp.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {

View file

@ -207,6 +207,10 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
if !opts.ValidBefore.IsZero() { if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
} }
// Make sure to define the the KeyID
if opts.KeyID == "" {
signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject))
}
// Default to a user certificate with no principals if not set // Default to a user certificate with no principals if not set
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert}) signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})

View file

@ -329,7 +329,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
} }
} }
func TestJWK_AuthorizeSign_SSH(t *testing.T) { func TestJWK_AuthorizeSSHSign(t *testing.T) {
tm, fn := mockNow() tm, fn := mockNow()
defer fn() defer fn()
@ -338,7 +338,7 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
jwk, err := decryptJSONWebKey(p1.EncryptedKey) jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
iss, aud := p1.Name, testAudiences.Sign[0] iss, aud := p1.Name, testAudiences.SSHSign[0]
t1, err := generateSimpleSSHUserToken(iss, aud, jwk) t1, err := generateSimpleSSHUserToken(iss, aud, jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -400,9 +400,9 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token) got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
@ -432,7 +432,7 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
jwk, err := decryptJSONWebKey(p1.EncryptedKey) jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
sub, iss, aud, iat := "subject@smallstep.com", p1.Name, testAudiences.Sign[0], time.Now() sub, iss, aud, iat := "subject@smallstep.com", p1.Name, testAudiences.SSHSign[0], time.Now()
key, err := generateJSONWebKey() key, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -514,8 +514,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk) token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
if got, err := tt.prov.AuthorizeSign(ctx, token); (err != nil) != tt.wantErr { if got, err := tt.prov.AuthorizeSSHSign(ctx, token); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
} else if !tt.wantErr && assert.NotNil(t, got) { } else if !tt.wantErr && assert.NotNil(t, got) {
var opts SSHOptions var opts SSHOptions
if tt.args.userSSHOpts != nil { if tt.args.userSSHOpts != nil {

View file

@ -330,7 +330,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
} }
} }
func TestOIDC_AuthorizeSign_SSH(t *testing.T) { func TestOIDC_AuthorizeSSHSign(t *testing.T) {
tm, fn := mockNow() tm, fn := mockNow()
defer fn() defer fn()
@ -427,9 +427,9 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token) got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {

View file

@ -74,15 +74,18 @@ func (o SSHOptions) Modify(cert *ssh.Certificate) error {
cert.KeyId = o.KeyID cert.KeyId = o.KeyID
cert.ValidPrincipals = o.Principals cert.ValidPrincipals = o.Principals
t := now()
if !o.ValidAfter.IsZero() { if !o.ValidAfter.IsZero() {
cert.ValidAfter = uint64(o.ValidAfter.Time().Unix()) cert.ValidAfter = uint64(o.ValidAfter.RelativeTime(t).Unix())
} }
if !o.ValidBefore.IsZero() { if !o.ValidBefore.IsZero() {
cert.ValidBefore = uint64(o.ValidBefore.Time().Unix()) cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix())
} }
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
return errors.New("ssh certificate valid after cannot be greater than valid before") return errors.New("ssh certificate valid after cannot be greater than valid before")
} }
return nil return nil
} }

View file

@ -38,8 +38,12 @@ var (
EnableSSHCA: &defaultEnableSSHCA, EnableSSHCA: &defaultEnableSSHCA,
} }
testAudiences = Audiences{ testAudiences = Audiences{
Sign: []string{"https://ca.smallstep.com/sign", "https://ca.smallstep.com/1.0/sign"}, Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},
Revoke: []string{"https://ca.smallstep.com/revoke", "https://ca.smallstep.com/1.0/revoke"}, Revoke: []string{"https://ca.smallstep.com/1.0/revoke", "https://ca.smallstep.com/revoke"},
SSHSign: []string{"https://ca.smallstep.com/1.0/ssh/sign"},
SSHRevoke: []string{"https://ca.smallstep.com/1.0/ssh/revoke"},
SSHRenew: []string{"https://ca.smallstep.com/1.0/ssh/renew"},
SSHRekey: []string{"https://ca.smallstep.com/1.0/ssh/rekey"},
} }
) )

View file

@ -235,6 +235,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
} }
// Add modifiers from custom claims // Add modifiers from custom claims
// FIXME: this is also set in the sign method using SSHOptions.Modify.
if opts.CertType != "" { if opts.CertType != "" {
signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType)) signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType))
} }
@ -248,6 +249,10 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
if !opts.ValidBefore.IsZero() { if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
} }
// Make sure to define the the KeyID
if opts.KeyID == "" {
signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject))
}
// Default to a user certificate with no principals if not set // Default to a user certificate with no principals if not set
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert}) signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})