Use methods in the step package

* rather than variables set at execution time, which may not match the
actual current context
This commit is contained in:
max furman 2021-10-20 12:41:24 -07:00
parent ed4b56732e
commit e5951fd84c
2 changed files with 21 additions and 23 deletions

View file

@ -10,6 +10,7 @@ import (
"net/url"
"github.com/pkg/errors"
"go.step.sm/cli-utils/step"
)
// Client wraps http.Client with a transport using the step root and identity.
@ -27,21 +28,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json
func LoadClient() (*Client, error) {
b, err := ioutil.ReadFile(DefaultsFile)
defaultsFile := step.DefaultsFile()
b, err := ioutil.ReadFile(defaultsFile)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", DefaultsFile)
return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
}
var defaults defaultsConfig
if err := json.Unmarshal(b, &defaults); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile)
return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile)
}
if err := defaults.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile)
return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
}
caURL, err := url.Parse(defaults.CaURL)
if err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile)
return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
}
if caURL.Scheme == "" {
caURL.Scheme = "https"
@ -52,7 +54,7 @@ func LoadClient() (*Client, error) {
return nil, err
}
if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", IdentityFile)
return nil, errors.Wrapf(err, "error validating %s", step.IdentityFile())
}
if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)

View file

@ -39,12 +39,6 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute
// IdentityFile contains the location of the identity file.
var IdentityFile = filepath.Join(step.ProfilePath(), "config", "identity.json")
// DefaultsFile contains the location of the defaults file.
var DefaultsFile = filepath.Join(step.ProfilePath(), "config", "defaults.json")
// Identity represents the identity file that can be used to authenticate with
// the CA.
type Identity struct {
@ -74,23 +68,25 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(IdentityFile)
return LoadIdentity(step.IdentityFile())
}
// configDir and identityDir are used in WriteDefaultIdentity for testing
// purposes.
var (
configDir = filepath.Join(step.ProfilePath(), "config")
identityDir = filepath.Join(step.ProfilePath(), "identity")
)
func profileConfigDir() string {
return filepath.Join(step.ProfilePath(), "config")
}
func profileIdentityDir() string {
return filepath.Join(step.ProfilePath(), "identity")
}
// WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(configDir, 0700); err != nil {
if err := os.MkdirAll(profileConfigDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory")
}
identityDir := profileIdentityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory")
}
@ -127,7 +123,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil {
return errors.Wrap(err, "error writing identity json")
}
if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil {
if err := ioutil.WriteFile(step.IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}
@ -136,7 +132,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(identityDir, "identity.crt")
filename := filepath.Join(profileIdentityDir(), "identity.crt")
return writeCertificate(filename, certChain)
}
@ -319,7 +315,7 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate")
}
}
certFilename := filepath.Join(identityDir, "identity.crt")
certFilename := filepath.Join(profileIdentityDir(), "identity.crt")
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}