diff --git a/ca/identity/client.go b/ca/identity/client.go index 4377638f..6f862115 100644 --- a/ca/identity/client.go +++ b/ca/identity/client.go @@ -10,6 +10,7 @@ import ( "net/url" "github.com/pkg/errors" + "go.step.sm/cli-utils/step" ) // Client wraps http.Client with a transport using the step root and identity. @@ -27,21 +28,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL { // $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/identity.json func LoadClient() (*Client, error) { - b, err := ioutil.ReadFile(DefaultsFile) + defaultsFile := step.DefaultsFile() + b, err := ioutil.ReadFile(defaultsFile) if err != nil { - return nil, errors.Wrapf(err, "error reading %s", DefaultsFile) + return nil, errors.Wrapf(err, "error reading %s", defaultsFile) } var defaults defaultsConfig if err := json.Unmarshal(b, &defaults); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile) + return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile) } if err := defaults.Validate(); err != nil { - return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) + return nil, errors.Wrapf(err, "error validating %s", defaultsFile) } caURL, err := url.Parse(defaults.CaURL) if err != nil { - return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) + return nil, errors.Wrapf(err, "error validating %s", defaultsFile) } if caURL.Scheme == "" { caURL.Scheme = "https" @@ -52,7 +54,7 @@ func LoadClient() (*Client, error) { return nil, err } if err := identity.Validate(); err != nil { - return nil, errors.Wrapf(err, "error validating %s", IdentityFile) + return nil, errors.Wrapf(err, "error validating %s", step.IdentityFile()) } if kind := identity.Kind(); kind != MutualTLS { return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) diff --git a/ca/identity/identity.go b/ca/identity/identity.go index 7d80ef70..e8760c50 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -39,12 +39,6 @@ const TunnelTLS Type = "tTLS" // DefaultLeeway is the duration for matching not before claims. const DefaultLeeway = 1 * time.Minute -// IdentityFile contains the location of the identity file. -var IdentityFile = filepath.Join(step.ProfilePath(), "config", "identity.json") - -// DefaultsFile contains the location of the defaults file. -var DefaultsFile = filepath.Join(step.ProfilePath(), "config", "defaults.json") - // Identity represents the identity file that can be used to authenticate with // the CA. type Identity struct { @@ -74,23 +68,25 @@ func LoadIdentity(filename string) (*Identity, error) { // LoadDefaultIdentity loads the default identity. func LoadDefaultIdentity() (*Identity, error) { - return LoadIdentity(IdentityFile) + return LoadIdentity(step.IdentityFile()) } -// configDir and identityDir are used in WriteDefaultIdentity for testing -// purposes. -var ( - configDir = filepath.Join(step.ProfilePath(), "config") - identityDir = filepath.Join(step.ProfilePath(), "identity") -) +func profileConfigDir() string { + return filepath.Join(step.ProfilePath(), "config") +} + +func profileIdentityDir() string { + return filepath.Join(step.ProfilePath(), "identity") +} // WriteDefaultIdentity writes the given certificates and key and the // identity.json pointing to the new files. func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { - if err := os.MkdirAll(configDir, 0700); err != nil { + if err := os.MkdirAll(profileConfigDir(), 0700); err != nil { return errors.Wrap(err, "error creating config directory") } + identityDir := profileIdentityDir() if err := os.MkdirAll(identityDir, 0700); err != nil { return errors.Wrap(err, "error creating identity directory") } @@ -127,7 +123,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er }); err != nil { return errors.Wrap(err, "error writing identity json") } - if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil { + if err := ioutil.WriteFile(step.IdentityFile(), buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } @@ -136,7 +132,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er // WriteIdentityCertificate writes the identity certificate to disk. func WriteIdentityCertificate(certChain []api.Certificate) error { - filename := filepath.Join(identityDir, "identity.crt") + filename := filepath.Join(profileIdentityDir(), "identity.crt") return writeCertificate(filename, certChain) } @@ -319,7 +315,7 @@ func (i *Identity) Renew(client Renewer) error { return errors.Wrap(err, "error encoding identity certificate") } } - certFilename := filepath.Join(identityDir, "identity.crt") + certFilename := filepath.Join(profileIdentityDir(), "identity.crt") if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") }