diff --git a/ca/provisioner.go b/ca/provisioner.go index bc1acb94..3f86c068 100644 --- a/ca/provisioner.go +++ b/ca/provisioner.go @@ -22,6 +22,7 @@ type Provisioner struct { name string kid string audience string + sshAudience string fingerprint string jwk *jose.JSONWebKey tokenLifetime time.Duration @@ -60,6 +61,7 @@ func NewProvisioner(name, kid, caURL string, password []byte, opts ...ClientOpti name: name, kid: jwk.KeyID, audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(), + sshAudience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/ssh/sign"}).String(), fingerprint: fp, jwk: jwk, tokenLifetime: tokenLifetime, @@ -116,6 +118,39 @@ func (p *Provisioner) Token(subject string, sans ...string) (string, error) { return tok.SignedString(p.jwk.Algorithm, p.jwk.Key) } +func (p *Provisioner) SSHToken(certType, keyID string, principals []string) (string, error) { + jwtID, err := randutil.Hex(64) + if err != nil { + return "", err + } + + notBefore := time.Now() + notAfter := notBefore.Add(tokenLifetime) + tokOptions := []token.Options{ + token.WithJWTID(jwtID), + token.WithKid(p.kid), + token.WithIssuer(p.name), + token.WithAudience(p.sshAudience), + token.WithValidity(notBefore, notAfter), + token.WithSSH(provisioner.SSHOptions{ + CertType: certType, + Principals: principals, + KeyID: keyID, + }), + } + + if p.fingerprint != "" { + tokOptions = append(tokOptions, token.WithSHA(p.fingerprint)) + } + + tok, err := provision.New(keyID, tokOptions...) + if err != nil { + return "", err + } + + return tok.SignedString(p.jwk.Algorithm, p.jwk.Key) +} + func decryptProvisionerJWK(encryptedKey string, password []byte) (*jose.JSONWebKey, error) { enc, err := jose.ParseEncrypted(encryptedKey) if err != nil { diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go index 40015df7..fcfaeb10 100644 --- a/ca/provisioner_test.go +++ b/ca/provisioner_test.go @@ -198,3 +198,105 @@ func TestProvisioner_Token(t *testing.T) { }) } } + +func TestProvisioner_SSHToken(t *testing.T) { + p := getTestProvisioner(t, "https://127.0.0.1:9000") + sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" + + type fields struct { + name string + kid string + fingerprint string + jwk *jose.JSONWebKey + tokenLifetime time.Duration + } + type args struct { + certType string + keyID string + principals []string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"user", "foo@smallstep.com", []string{"foo"}}, false}, + {"ok host", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"host", "foo.smallstep.com", []string{"foo.smallstep.com"}}, false}, + {"ok multiple principals", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"user", "foo@smallstep.com", []string{"foo", "bar"}}, false}, + {"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"user", "", []string{"foo"}}, true}, + {"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"user", "foo@smallstep.com", []string{"foo"}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Provisioner{ + name: tt.fields.name, + kid: tt.fields.kid, + audience: "https://127.0.0.1:9000/1.0/sign", + sshAudience: "https://127.0.0.1:9000/1.0/ssh/sign", + fingerprint: tt.fields.fingerprint, + jwk: tt.fields.jwk, + tokenLifetime: tt.fields.tokenLifetime, + } + got, err := p.SSHToken(tt.args.certType, tt.args.keyID, tt.args.principals) + if (err != nil) != tt.wantErr { + t.Errorf("Provisioner.SSHToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr == false { + jwt, err := jose.ParseSigned(got) + if err != nil { + t.Error(err) + return + } + var claims jose.Claims + if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil { + t.Error(err) + return + } + if err := claims.ValidateWithLeeway(jose.Expected{ + Audience: []string{"https://127.0.0.1:9000/1.0/ssh/sign"}, + Issuer: tt.fields.name, + Subject: tt.args.keyID, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + t.Error(err) + return + } + lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time()) + if lifetime != tt.fields.tokenLifetime { + t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime) + } + allClaims := make(map[string]interface{}) + if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil { + t.Error(err) + return + } + if v, ok := allClaims["sha"].(string); !ok || v != sha { + t.Errorf("Claim sha = %s, want %s", v, sha) + } + + principals := make([]interface{}, len(tt.args.principals)) + for i, p := range tt.args.principals { + principals[i] = p + } + want := map[string]interface{}{ + "ssh": map[string]interface{}{ + "certType": tt.args.certType, + "keyID": tt.args.keyID, + "principals": principals, + "validAfter": "", + "validBefore": "", + }, + } + if !reflect.DeepEqual(allClaims["step"], want) { + t.Errorf("Claim step = %s, want %s", allClaims["step"], want) + } + if v, ok := allClaims["jti"].(string); !ok || v == "" { + t.Errorf("Claim jti = %s, want not blank", v) + } + } + }) + } +}