Use methods in the step package

* rather than variables set at execution time, which may not match the
actual current context
This commit is contained in:
max furman 2021-10-20 12:41:24 -07:00
parent ed4b56732e
commit e5951fd84c
2 changed files with 21 additions and 23 deletions

View file

@ -10,6 +10,7 @@ import (
"net/url" "net/url"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/cli-utils/step"
) )
// Client wraps http.Client with a transport using the step root and identity. // Client wraps http.Client with a transport using the step root and identity.
@ -27,21 +28,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json // $STEPPATH/config/identity.json
func LoadClient() (*Client, error) { func LoadClient() (*Client, error) {
b, err := ioutil.ReadFile(DefaultsFile) defaultsFile := step.DefaultsFile()
b, err := ioutil.ReadFile(defaultsFile)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", DefaultsFile) return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
} }
var defaults defaultsConfig var defaults defaultsConfig
if err := json.Unmarshal(b, &defaults); err != nil { if err := json.Unmarshal(b, &defaults); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile) return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile)
} }
if err := defaults.Validate(); err != nil { if err := defaults.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
caURL, err := url.Parse(defaults.CaURL) caURL, err := url.Parse(defaults.CaURL)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
if caURL.Scheme == "" { if caURL.Scheme == "" {
caURL.Scheme = "https" caURL.Scheme = "https"
@ -52,7 +54,7 @@ func LoadClient() (*Client, error) {
return nil, err return nil, err
} }
if err := identity.Validate(); err != nil { if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", IdentityFile) return nil, errors.Wrapf(err, "error validating %s", step.IdentityFile())
} }
if kind := identity.Kind(); kind != MutualTLS { if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)

View file

@ -39,12 +39,6 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims. // DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute const DefaultLeeway = 1 * time.Minute
// IdentityFile contains the location of the identity file.
var IdentityFile = filepath.Join(step.ProfilePath(), "config", "identity.json")
// DefaultsFile contains the location of the defaults file.
var DefaultsFile = filepath.Join(step.ProfilePath(), "config", "defaults.json")
// Identity represents the identity file that can be used to authenticate with // Identity represents the identity file that can be used to authenticate with
// the CA. // the CA.
type Identity struct { type Identity struct {
@ -74,23 +68,25 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity. // LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) { func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(IdentityFile) return LoadIdentity(step.IdentityFile())
} }
// configDir and identityDir are used in WriteDefaultIdentity for testing func profileConfigDir() string {
// purposes. return filepath.Join(step.ProfilePath(), "config")
var ( }
configDir = filepath.Join(step.ProfilePath(), "config")
identityDir = filepath.Join(step.ProfilePath(), "identity") func profileIdentityDir() string {
) return filepath.Join(step.ProfilePath(), "identity")
}
// WriteDefaultIdentity writes the given certificates and key and the // WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files. // identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(configDir, 0700); err != nil { if err := os.MkdirAll(profileConfigDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory") return errors.Wrap(err, "error creating config directory")
} }
identityDir := profileIdentityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil { if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory") return errors.Wrap(err, "error creating identity directory")
} }
@ -127,7 +123,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil { }); err != nil {
return errors.Wrap(err, "error writing identity json") return errors.Wrap(err, "error writing identity json")
} }
if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil { if err := ioutil.WriteFile(step.IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -136,7 +132,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk. // WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error { func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(identityDir, "identity.crt") filename := filepath.Join(profileIdentityDir(), "identity.crt")
return writeCertificate(filename, certChain) return writeCertificate(filename, certChain)
} }
@ -319,7 +315,7 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate") return errors.Wrap(err, "error encoding identity certificate")
} }
} }
certFilename := filepath.Join(identityDir, "identity.crt") certFilename := filepath.Join(profileIdentityDir(), "identity.crt")
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }