first commit

This commit is contained in:
max furman 2018-10-05 21:48:36 +00:00
commit c284a2c0ab
65 changed files with 7138 additions and 0 deletions

5
.gitignore vendored Normal file
View file

@ -0,0 +1,5 @@
*.swp
/.GOPATH
/bin
/run
vendor

13
.travis.yml Normal file
View file

@ -0,0 +1,13 @@
language: go
go:
- 1.10.x
before_script:
- make bootstrap
script:
- make
notifications:
email: false
slack:
secure: d0a8OzlFZtvPs1wyRkM4CTN+SdFi+S9+lpv5a3AExoTJYKmbA+HTzIKaOmcVAln8STj7/OzaKAJdCE49ZYpQnmUWAmoUcJdlHEE5cccNPMnpY48VT7G69XPfs4QYHYLjlJxiWBaXMDF+AgNoMe6UMrehB69rqLnhIq/yZJWy5s6XG80LxQjQATv53kLbLJCCyORmfKI5/zyQeMyvOsjLOmE7b1CkSXQWTXU/rVTzSSsTC2uozh8WqBGkVWBR45JGgB8JL2AIJHY7q8IjGVAbMuiig/vHInk3SKSqikrAA5NEewlpAMEAum3Xobtt9vx+Ox0KO4qpNkqCuoAowhnXQwZO8MKrQkE6dKRmi6w2wq64pCqvvOQb1X6hDsreUoz9XsxEt1GLSzelhH9HMD8hb2tMebZw0EeDrSL9E37wPdbAWSYrJwRG+PK0NGWPqh6JMHlDYR5nGUGoJ/q7EBOTH6ZAPoYjAuOjypI43xkMqN/72F0a7FbNXvelVQipiLD+zhZatiJNhlxKPrwE17+X9t/CDHFGeSllJT1YdqLYnAxqRJixP59MTdlx3YzXKW/PBb8dkURMKRw/erY4CzrZ62QiSzG0Ya8x18sCxcjKZtP5ChrhDnnFi7hFIQY/svHVcpO6e9xsesMLYzQJ43MJ+kCkppWS5B7zs0Tu6LXHfvQ=
dd:
secure: MyM0SW+xZvTo2PbFf46yzUGprbaWqcndXwLPWSjSHxoDXsR8+6RufZCFdpKdXPGujE8lmujEcY8bXNveLaiq8TmlkwlCmj7p8MRJcXLBSUgAw8+iCkbu8JnEQ9xN5hw6OSe7dhWGlZxPm+KTij7KqI3j0UqI5TbeloASCsHZEYMPB03Ku5uHslTjeAfv+ryJbceIUJ1s3IaDfGH8jJMees0XQjZoMJN/aZejdyAaXlH6Gqs/dC0v0N8/ZH3vxMnBKU0qY89V025b5JwpvxGEeieOZ+w8lp4j+MZH3GVsd9If/9Gokp0VdApJaYK4laA2EwfXcf6bVF1gFXYw69seS3X1NMmuikvO7xMs2bFBz04ATNGB8lrXvZNk8IISkHcL4jZch5F4P4n2TIoNsHLrrGL5CthHgugNtF1TeA0IDNRhNu9wM/1wHASISPH+977hEBEj0Bp1rQBW+w+ELiZ+qXd7I2NtHCbYnxEgJmS9bApZXblhxDhKiObGajNpRXMYYcOSydK2+lLZVGmsVV91HeHtsjZlO/ltxZthLA6Z8LC2YZ56X8wdI1tA2EQ1DZhTfZNe6Swsx+gNHlCjSkI+d1iel8FW77oKAAyK7kuzGKgqu2xmoiXgnxO0Ufzpu4VCUEIvh6+K33omRcrdJqzyo9LRQRjjU3P0o973oGVuN9Q=

326
Gopkg.lock generated Normal file
View file

@ -0,0 +1,326 @@
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
[[projects]]
digest = "1:304cb78c285eaf02ab529ad02a257cad9b4845022915e6c82f87860ac53222d8"
name = "github.com/alecthomas/gometalinter"
packages = ["."]
pruneopts = "UT"
revision = "bae2f1293d092fd8167939d5108d1b025eaef9de"
[[projects]]
branch = "master"
digest = "1:c198fdc381e898e8fb62b8eb62758195091c313ad18e52a3067366e1dda2fb3c"
name = "github.com/alecthomas/units"
packages = ["."]
pruneopts = "UT"
revision = "2efee857e7cfd4f3d0138cc3cbb1b4966962b93a"
[[projects]]
digest = "1:848ef40f818e59905140552cc49ff3dc1a15f955e4b56d1c5c2cc4b54dbadf0c"
name = "github.com/client9/misspell"
packages = [
".",
"cmd/misspell",
]
pruneopts = "UT"
revision = "b90dc15cfd220ecf8bbc9043ecb928cef381f011"
version = "v0.3.4"
[[projects]]
branch = "master"
digest = "1:b872acdc9ad7bc072f54163ed43c44ba00dbf0411301f96db6631266c5935d43"
name = "github.com/go-chi/chi"
packages = ["."]
pruneopts = "UT"
revision = "44932d207a10cd2f26b06095a61843c9807188ea"
[[projects]]
branch = "master"
digest = "1:4ee452f8994700dcab9e816aef1cb9eb2317218734c6ccf5135746e6c19f3dce"
name = "github.com/golang/lint"
packages = ["golint"]
pruneopts = "UT"
revision = "06c8688daad7faa9da5a0c2f163a3d14aac986ca"
[[projects]]
digest = "1:d2754cafcab0d22c13541618a8029a70a8959eb3525ff201fe971637e2274cd0"
name = "github.com/google/go-cmp"
packages = [
"cmp",
"cmp/cmpopts",
"cmp/internal/diff",
"cmp/internal/function",
"cmp/internal/value",
]
pruneopts = "UT"
revision = "3af367b6b30c263d47e8895973edcca9a49cf029"
version = "v0.2.0"
[[projects]]
branch = "master"
digest = "1:750e747d0aad97b79f4a4e00034bae415c2ea793fd9e61438d966ee9c79579bf"
name = "github.com/google/shlex"
packages = ["."]
pruneopts = "UT"
revision = "6f45313302b9c56850fc17f99e40caebce98c716"
[[projects]]
branch = "master"
digest = "1:78010c43f45797f252007611599b5eb3d3752775305f9aa68669c318a54c6230"
name = "github.com/gordonklaus/ineffassign"
packages = ["."]
pruneopts = "UT"
revision = "3fd9b69f2fb179405773f03d33c68a00f3a1ca4a"
[[projects]]
digest = "1:266d082179f3a29a4bdcf1dcc49d4a304f5c7107e65bd22d1fecacf45f1ac348"
name = "github.com/newrelic/go-agent"
packages = [
".",
"internal",
"internal/cat",
"internal/jsonx",
"internal/logger",
"internal/sysinfo",
"internal/utilization",
]
pruneopts = "UT"
revision = "f5bce3387232559bcbe6a5f8227c4bf508dac1ba"
version = "v1.11.0"
[[projects]]
digest = "1:07140002dbf37da92090f731b46fa47be4820b82fe5c14a035203b0e813d0ec2"
name = "github.com/nicksnyder/go-i18n"
packages = [
"i18n",
"i18n/bundle",
"i18n/language",
"i18n/translation",
]
pruneopts = "UT"
revision = "0dc1626d56435e9d605a29875701721c54bc9bbd"
version = "v1.10.0"
[[projects]]
digest = "1:95741de3af260a92cc5c7f3f3061e85273f5a81b5db20d4bd68da74bd521675e"
name = "github.com/pelletier/go-toml"
packages = ["."]
pruneopts = "UT"
revision = "c01d1270ff3e442a8a57cddc1c92dc1138598194"
version = "v1.2.0"
[[projects]]
digest = "1:40e195917a951a8bf867cd05de2a46aaf1806c50cf92eebf4c16f78cd196f747"
name = "github.com/pkg/errors"
packages = ["."]
pruneopts = "UT"
revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
version = "v0.8.0"
[[projects]]
digest = "1:757b110984b77e820e01c60d3ac03a376a0fdb05c990dd9d6bd4f9ba0d606261"
name = "github.com/rs/xid"
packages = ["."]
pruneopts = "UT"
revision = "2c7e97ce663ff82c49656bca3048df0fdd83c5f9"
version = "v1.2.0"
[[projects]]
digest = "1:d867dfa6751c8d7a435821ad3b736310c2ed68945d05b50fb9d23aee0540c8cc"
name = "github.com/sirupsen/logrus"
packages = ["."]
pruneopts = "UT"
revision = "3e01752db0189b9157070a0e1668a620f9a85da2"
version = "v1.0.6"
[[projects]]
branch = "master"
digest = "1:4d1f0640875aefefdb2151f297c144518a71f5729c4b9f9423f09df501f699c5"
name = "github.com/smallstep/assert"
packages = ["."]
pruneopts = "UT"
revision = "de77670473b5492f5d0bce155b5c01534c2d13f7"
[[projects]]
branch = "ca-commands-wip"
digest = "1:723d56910291478edfd50fa2146e52fc6d8f5b5e67ddd6e5b8e89291313256a2"
name = "github.com/smallstep/cli"
packages = [
"crypto/keys",
"crypto/pemutil",
"crypto/randutil",
"crypto/tlsutil",
"crypto/x509util",
"errs",
"jose",
"pkg/x509",
"utils",
]
pruneopts = "UT"
revision = "75ee5a0262bdbb305c75dcb98e7f806540537678"
[[projects]]
branch = "master"
digest = "1:ae5dbd6e0922625debc1d0b3a74a4d97b4f89d2d861e4f0e0886c03b6b28ced7"
name = "github.com/smallstep/go-makefile"
packages = ["."]
pruneopts = "UT"
revision = "c6025f797567554133ce98a3fcc224b3691a9f05"
[[projects]]
branch = "master"
digest = "1:ba52e5a5fb800ce55108b7a5f181bb809aab71c16736051312b0aa969f82ad39"
name = "github.com/tsenart/deadcode"
packages = ["."]
pruneopts = "UT"
revision = "210d2dc333e90c7e3eedf4f2242507a8e83ed4ab"
[[projects]]
branch = "master"
digest = "1:189a0e6e9c657bb662bafc41a796360d11c88eed7614b1b6f003b8fbc8847e5e"
name = "github.com/urfave/cli"
packages = ["."]
pruneopts = "UT"
revision = "8e01ec4cd3e2d84ab2fe90d8210528ffbb06d8ff"
[[projects]]
branch = "master"
digest = "1:82590d674737712213caa196f58716ee00f2711d860451dc8bd36e847015209a"
name = "golang.org/x/crypto"
packages = [
"cryptobyte",
"cryptobyte/asn1",
"ed25519",
"ed25519/internal/edwards25519",
"pbkdf2",
"ssh/terminal",
]
pruneopts = "UT"
revision = "aabede6cba87e37f413b3e60ebfc214f8eeca1b0"
[[projects]]
branch = "master"
digest = "1:9238d4d6fdc7b3859e37764c86d02625b74e0c76cd1faae3677735d5c5129724"
name = "golang.org/x/lint"
packages = ["."]
pruneopts = "UT"
revision = "06c8688daad7faa9da5a0c2f163a3d14aac986ca"
[[projects]]
branch = "master"
digest = "1:88a792a03a354a98ee468d774bc9a882e6d9d666b8f0069deb66d896ba83c163"
name = "golang.org/x/net"
packages = [
"http/httpguts",
"http2",
"http2/hpack",
"idna",
]
pruneopts = "UT"
revision = "4dfa2610cdf3b287375bbba5b8f2a14d3b01d8de"
[[projects]]
branch = "master"
digest = "1:2f71657f09ff05e4567909e9e0de7ad799828c96d402c540b41dc044a6590fb2"
name = "golang.org/x/sys"
packages = [
"unix",
"windows",
]
pruneopts = "UT"
revision = "1c9583448a9c3aa0f9a6a5241bf73c0bd8aafded"
[[projects]]
digest = "1:a2ab62866c75542dd18d2b069fec854577a20211d7c0ea6ae746072a1dccdd18"
name = "golang.org/x/text"
packages = [
"collate",
"collate/build",
"internal/colltab",
"internal/gen",
"internal/tag",
"internal/triegen",
"internal/ucd",
"language",
"secure/bidirule",
"transform",
"unicode/bidi",
"unicode/cldr",
"unicode/norm",
"unicode/rangetable",
]
pruneopts = "UT"
revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0"
version = "v0.3.0"
[[projects]]
branch = "master"
digest = "1:3d35f43f18787f661e7a4c8d8bcd424e96b046ef63769f02ae23e34aa57ff661"
name = "golang.org/x/tools"
packages = [
"go/ast/astutil",
"go/gcexportdata",
"go/internal/gcimporter",
"go/types/typeutil",
]
pruneopts = "UT"
revision = "7d1dc997617fb662918b6ea95efc19faa87e1cf8"
[[projects]]
digest = "1:39efb07a0d773dc09785b237ada4e10b5f28646eb6505d97bc18f8d2ff439362"
name = "gopkg.in/alecthomas/kingpin.v3-unstable"
packages = ["."]
pruneopts = "UT"
revision = "63abe20a23e29e80bbef8089bd3dee3ac25e5306"
[[projects]]
digest = "1:7fbe10f3790dc4e6296c7c844c5a9b35513e5521c29c47e10ba99cd2956a2719"
name = "gopkg.in/square/go-jose.v2"
packages = [
".",
"cipher",
"json",
"jwt",
]
pruneopts = "UT"
revision = "ef984e69dd356202fd4e4910d4d9c24468bdf0b8"
version = "v2.1.9"
[[projects]]
digest = "1:342378ac4dcb378a5448dd723f0784ae519383532f5e70ade24132c4c8693202"
name = "gopkg.in/yaml.v2"
packages = ["."]
pruneopts = "UT"
revision = "5420a8b6744d3b0345ab293f6fcba19c978f1183"
version = "v2.2.1"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
input-imports = [
"github.com/alecthomas/gometalinter",
"github.com/client9/misspell/cmd/misspell",
"github.com/go-chi/chi",
"github.com/golang/lint/golint",
"github.com/gordonklaus/ineffassign",
"github.com/newrelic/go-agent",
"github.com/pkg/errors",
"github.com/rs/xid",
"github.com/sirupsen/logrus",
"github.com/smallstep/assert",
"github.com/smallstep/cli/crypto/keys",
"github.com/smallstep/cli/crypto/pemutil",
"github.com/smallstep/cli/crypto/randutil",
"github.com/smallstep/cli/crypto/tlsutil",
"github.com/smallstep/cli/crypto/x509util",
"github.com/smallstep/cli/jose",
"github.com/smallstep/cli/pkg/x509",
"github.com/smallstep/go-makefile",
"github.com/tsenart/deadcode",
"golang.org/x/net/http2",
"gopkg.in/square/go-jose.v2",
"gopkg.in/square/go-jose.v2/jwt",
]
solver-name = "gps-cdcl"
solver-version = 1

66
Gopkg.toml Normal file
View file

@ -0,0 +1,66 @@
# Gopkg.toml example
#
# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html
# for detailed Gopkg.toml documentation.
#
# required = ["github.com/user/thing/cmd/thing"]
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
#
# [[constraint]]
# name = "github.com/user/project"
# version = "1.0.0"
#
# [[constraint]]
# name = "github.com/user/project2"
# branch = "dev"
# source = "github.com/myfork/project2"
#
# [[override]]
# name = "github.com/x/y"
# version = "2.4.0"
#
# [prune]
# non-go = false
# go-tests = true
# unused-packages = true
required = [
"github.com/alecthomas/gometalinter",
"github.com/golang/lint/golint",
"github.com/client9/misspell/cmd/misspell",
"github.com/gordonklaus/ineffassign",
"github.com/tsenart/deadcode",
"github.com/smallstep/go-makefile"
]
[[constraint]]
name = "github.com/alecthomas/gometalinter"
revision = "bae2f1293d092fd8167939d5108d1b025eaef9de"
[[override]]
name = "gopkg.in/alecthomas/kingpin.v3-unstable"
revision = "63abe20a23e29e80bbef8089bd3dee3ac25e5306"
[[constraint]]
branch = "master"
name = "github.com/go-chi/chi"
[[constraint]]
branch = "ca-commands-wip"
name = "github.com/smallstep/cli"
[prune]
go-tests = true
unused-packages = true
[[constraint]]
name = "github.com/newrelic/go-agent"
version = "1.11.0"
[[constraint]]
name = "github.com/sirupsen/logrus"
version = "1.0.6"
[[constraint]]
name = "gopkg.in/square/go-jose.v2"
version = "2.1.9"

245
Makefile Normal file
View file

@ -0,0 +1,245 @@
PKG?=github.com/smallstep/ca-component/cmd/step-ca
BINNAME?=step-ca
# Set V to 1 for verbose output from the Makefile
Q=$(if $V,,@)
PREFIX?=
SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*")
GOOS_OVERRIDE ?=
# Set shell to bash for `echo -e`
SHELL := /bin/bash
all: build lint test
.PHONY: all
#########################################
# Bootstrapping
#########################################
bootstra%:
$Q which dep || go get github.com/golang/dep/cmd/dep
$Q dep ensure
vendor: Gopkg.lock
$Q dep ensure
BOOTSTRAP=\
github.com/golang/lint/golint \
github.com/client9/misspell/cmd/misspell \
github.com/gordonklaus/ineffassign \
github.com/tsenart/deadcode \
github.com/alecthomas/gometalinter
define VENDOR_BIN_TMPL
vendor/bin/$(notdir $(1)): vendor
$Q go build -o $$@ ./vendor/$(1)
VENDOR_BINS += vendor/bin/$(notdir $(1))
endef
$(foreach pkg,$(BOOTSTRAP),$(eval $(call VENDOR_BIN_TMPL,$(pkg))))
.PHONY: bootstra% vendor
#################################################
# Determine the type of `push` and `version`
#################################################
# Version flags to embed in the binaries
VERSION ?= $(shell [ -d .git ] && git describe --tags --always --dirty="-dev")
VERSION := $(shell echo $(VERSION) | sed 's/^v//')
# If TRAVIS_TAG is set then we know this ref has been tagged.
ifdef TRAVIS_TAG
PUSHTYPE=release
else
PUSHTYPE=master
endif
#########################################
# Build
#########################################
DATE := $(shell date -u '+%Y-%m-%d %H:%M UTC')
LDFLAGS := -ldflags='-w -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"'
GOFLAGS := CGO_ENABLED=0
build: $(PREFIX)bin/$(BINNAME)
@echo "Build Complete!"
$(PREFIX)bin/$(BINNAME): vendor $(call rwildcard,*.go)
$Q mkdir -p $(@D)
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG)
# Target for building without calling dep ensure
simple:
$Q mkdir -p bin/
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o bin/$(BINNAME) $(LDFLAGS) $(PKG)
@echo "Build Complete!"
.PHONY: build simple
#########################################
# Go generate
#########################################
generate:
$Q go generate ./...
.PHONY: generate
#########################################
# Test
#########################################
test:
$Q $(GOFLAGS) go test -short -cover ./...
vtest:
$(Q)for d in $$(go list ./... | grep -v vendor); do \
echo -e "TESTS FOR: for \033[0;35m$$d\033[0m"; \
$(GOFLAGS) go test -v -bench=. -run=. -short -coverprofile=profile.coverage.out -covermode=atomic $$d; \
out=$$?; \
if [[ $$out -ne 0 ]]; then ret=$$out; fi;\
rm -f profile.coverage.out; \
done; exit $$ret;
.PHONY: test vtest
integrate: integration
integration: bin/$(BINNAME)
$Q $(GOFLAGS) go test -tags=integration ./integration/...
.PHONY: integrate integration
#########################################
# Linting
#########################################
LINTERS=\
gofmt \
golint \
vet \
misspell \
ineffassign \
deadcode
$(patsubst %,%-bin,$(filter-out gofmt vet,$(LINTERS))): %-bin: vendor/bin/%
gofmt-bin vet-bin:
$(LINTERS): %: vendor/bin/gometalinter %-bin vendor
$Q PATH=`pwd`/vendor/bin:$$PATH gometalinter --tests --disable-all --vendor \
--deadline=5m -s data -s pkg --enable $@ ./...
fmt:
$Q gofmt -l -w $(SRC)
lint: $(LINTERS)
.PHONY: $(LINTERS) lint fmt
#########################################
# Install
#########################################
INSTALL_PREFIX?=/usr/
install: $(PREFIX)bin/$(BINNAME)
$Q install -D $(PREFIX)bin/$(BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(BINNAME)
uninstall:
$Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(BINNAME)
.PHONY: install uninstall
#########################################
# Debian
#########################################
debian:
$Q mkdir -p $(RELEASE); \
OUTPUT=../step-ca_*.deb; \
rm $$OUTPUT; \
dpkg-buildpackage -b -rfakeroot -us -uc && cp $$OUTPUT $(RELEASE)/
distclean: clean
.PHONY: debian distclean
#################################################
# Build statically compiled step binary for various operating systems
#################################################
OUTPUT_ROOT=output/
BINARY_OUTPUT=$(OUTPUT_ROOT)binary/
BUNDLE_MAKE=v=$v GOOS_OVERRIDE='GOOS=$(1) GOARCH=$(2)' PREFIX=$(3) make $(3)bin/step
RELEASE=./.travis-releases
binary-linux:
$(call BUNDLE_MAKE,linux,amd64,$(BINARY_OUTPUT)linux/)
binary-darwin:
$(call BUNDLE_MAKE,darwin,amd64,$(BINARY_OUTPUT)darwin/)
define BUNDLE
$(q)BUNDLE_DIR=$(BINARY_OUTPUT)$(1)/bundle; \
stepName=step_$(2); \
mkdir -p $$BUNDLE_DIR $(RELEASE); \
TMP=$$(mktemp -d $$BUNDLE_DIR/tmp.XXXX); \
trap "rm -rf $$TMP" EXIT INT QUIT TERM; \
newdir=$$TMP/$$stepName; \
mkdir -p $$newdir/bin; \
cp $(BINARY_OUTPUT)$(1)/bin/step $$newdir/bin/; \
cp README.md $$newdir/; \
NEW_BUNDLE=$(RELEASE)/step_$(2)_$(1)_$(3).tar.gz; \
rm -f $$NEW_BUNDLE; \
tar -zcvf $$NEW_BUNDLE -C $$TMP $$stepName;
endef
bundle-linux: binary-linux
$(call BUNDLE,linux,$(VERSION),amd64)
bundle-darwin: binary-darwin
$(call BUNDLE,darwin,$(VERSION),amd64)
.PHONY: binary-linux binary-darwin bundle-linux bundle-darwin
#################################################
# Targets for creating OS specific artifacts
#################################################
artifacts-linux-tag: bundle-linux debian
artifacts-darwin-tag: bundle-darwin
artifacts-tag: artifacts-linux-tag artifacts-darwin-tag
.PHONY: artifacts-linux-tag artifacts-darwin-tag artifacts-tag
#################################################
# Targets for creating step artifacts
#################################################
# For all builds that are not tagged
artifacts-master:
# For all builds with a release tag
artifacts-release: artifacts-tag
# This command is called by travis directly *after* a successful build
artifacts: artifacts-$(PUSHTYPE)
.PHONY: artifacts-master artifacts-release artifacts
#########################################
# Clean
#########################################
clean:
@echo "You will need to run 'make bootstrap' or 'dep ensure' directly to re-download any dependencies."
$Q rm -rf vendor
ifneq ($(BINNAME),"")
$Q rm -f bin/$(BINNAME)
endif
.PHONY: clean

317
api/api.go Normal file
View file

@ -0,0 +1,317 @@
package api
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"net/http"
"strings"
"time"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
)
// Minimum and maximum validity of an end-entity (not root or intermediate) certificate.
// They will be overwritten with the values configured in the authority
var (
minCertDuration = 5 * time.Minute
maxCertDuration = 24 * time.Hour
)
// Claim interface is implemented by types used to validate specific claims in a
// certificate request.
// TODO(mariano): Rename?
type Claim interface {
Valid(cr *x509.CertificateRequest) error
}
// SignOptions contains the options that can be passed to the Authority.Sign
// method.
type SignOptions struct {
NotAfter time.Time `json:"notAfter"`
NotBefore time.Time `json:"notBefore"`
}
// Authority is the interface implemented by a CA authority.
type Authority interface {
Authorize(ott string) ([]Claim, error)
GetTLSOptions() *tlsutil.TLSOptions
GetMinDuration() time.Duration
GetMaxDuration() time.Duration
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts SignOptions, claims ...Claim) (*x509.Certificate, *x509.Certificate, error)
Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
}
// Certificate wraps a *x509.Certificate and adds the json.Marshaler interface.
type Certificate struct {
*x509.Certificate
}
// NewCertificate is a helper method that returns a Certificate from a
// *x509.Certificate.
func NewCertificate(cr *x509.Certificate) Certificate {
return Certificate{
Certificate: cr,
}
}
// MarshalJSON implements the json.Marshaler interface. The certificate is
// quoted string using the PEM encoding.
func (c Certificate) MarshalJSON() ([]byte, error) {
if c.Certificate == nil {
return []byte("null"), nil
}
block := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: c.Raw,
})
return json.Marshal(string(block))
}
// UnmarshalJSON implements the json.Unmarshaler interface. The certificate is
// expected to be a quoted string using the PEM encoding.
func (c *Certificate) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return errors.Wrap(err, "error decoding certificate")
}
block, _ := pem.Decode([]byte(s))
if block == nil {
return errors.New("error decoding certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return errors.Wrap(err, "error decoding certificate")
}
c.Certificate = cert
return nil
}
// CertificateRequest wraps a *x509.CertificateRequest and adds the
// json.Unmarshaler interface.
type CertificateRequest struct {
*x509.CertificateRequest
}
// NewCertificateRequest is a helper method that returns a CertificateRequest
// from a *x509.CertificateRequest.
func NewCertificateRequest(cr *x509.CertificateRequest) CertificateRequest {
return CertificateRequest{
CertificateRequest: cr,
}
}
// MarshalJSON implements the json.Marshaler interface. The certificate request
// is a quoted string using the PEM encoding.
func (c CertificateRequest) MarshalJSON() ([]byte, error) {
if c.CertificateRequest == nil {
return []byte("null"), nil
}
block := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: c.Raw,
})
return json.Marshal(string(block))
}
// UnmarshalJSON implements the json.Unmarshaler interface. The certificate
// request is expected to be a quoted string using the PEM encoding.
func (c *CertificateRequest) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return errors.Wrap(err, "error decoding csr")
}
block, _ := pem.Decode([]byte(s))
if block == nil {
return errors.New("error decoding csr")
}
cr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return errors.Wrap(err, "error decoding csr")
}
c.CertificateRequest = cr
return nil
}
// Router defines a common router interface.
type Router interface {
// MethodFunc adds routes for `pattern` that matches
// the `method` HTTP method.
MethodFunc(method, pattern string, h http.HandlerFunc)
}
// RouterHandler is the interface that a HTTP handler that manages multiple
// endpoints will implement.
type RouterHandler interface {
Route(r Router)
}
// HealthResponse is the response object that returns the health of the server.
type HealthResponse struct {
Status string `json:"status"`
}
// RootResponse is the response object that returns the PEM of a root certificate.
type RootResponse struct {
RootPEM Certificate `json:"ca"`
}
// SignRequest is the request body for a certificate signature request.
type SignRequest struct {
CsrPEM CertificateRequest `json:"csr"`
OTT string `json:"ott"`
NotAfter time.Time `json:"notAfter"`
NotBefore time.Time `json:"notBefore"`
}
// Validate checks the fields of the SignRequest and returns nil if they are ok
// or an error if something is wrong.
func (s *SignRequest) Validate() error {
if s.CsrPEM.CertificateRequest == nil {
return BadRequest(errors.New("missing csr"))
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return BadRequest(errors.Wrap(err, "invalid csr"))
}
if s.OTT == "" {
return BadRequest(errors.New("missing ott"))
}
now := time.Now()
if s.NotBefore.IsZero() {
s.NotBefore = now
}
if s.NotAfter.IsZero() {
s.NotAfter = now.Add(x509util.DefaultCertValidity)
}
if s.NotAfter.Before(now) {
return BadRequest(errors.New("notAfter < now"))
}
if s.NotAfter.Before(s.NotBefore) {
return BadRequest(errors.New("notAfter < notBefore"))
}
requestedDuration := s.NotAfter.Sub(s.NotBefore)
if requestedDuration < minCertDuration {
return BadRequest(errors.New("requested certificate validity duration is too short"))
}
if requestedDuration > maxCertDuration {
return BadRequest(errors.New("requested certificate validity duration is too long"))
}
return nil
}
// SignResponse is the response object of the certificate signature request.
type SignResponse struct {
ServerPEM Certificate `json:"crt"`
CaPEM Certificate `json:"ca"`
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
TLS *tls.ConnectionState `json:"-"`
}
// caHandler is the type used to implement the different CA HTTP endpoints.
type caHandler struct {
Authority Authority
}
// New creates a new RouterHandler with the CA endpoints.
func New(authority Authority) RouterHandler {
minCertDuration = authority.GetMinDuration()
maxCertDuration = authority.GetMaxDuration()
return &caHandler{
Authority: authority,
}
}
func (h *caHandler) Route(r Router) {
r.MethodFunc("GET", "/health", h.Health)
r.MethodFunc("GET", "/root/{sha}", h.Root)
r.MethodFunc("POST", "/sign", h.Sign)
r.MethodFunc("POST", "/renew", h.Renew)
}
// Health is an HTTP handler that returns the status of the server.
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
JSON(w, HealthResponse{Status: "ok"})
}
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
// certificate for the given SHA256.
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
sha := chi.URLParam(r, "sha")
sum := strings.ToLower(strings.Replace(sha, "-", "", -1))
// Load root certificate with the
cert, err := h.Authority.Root(sum)
if err != nil {
WriteError(w, NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI)))
return
}
JSON(w, &RootResponse{RootPEM: Certificate{cert}})
}
// Sign is an HTTP handler that reads a certificate request and an
// one-time-token (ott) from the body and creates a new certificate with the
// information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
return
}
claims, err := h.Authority.Authorize(body.OTT)
if err != nil {
WriteError(w, Unauthorized(err))
return
}
opts := SignOptions{
NotBefore: body.NotBefore,
NotAfter: body.NotAfter,
}
cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, claims...)
if err != nil {
WriteError(w, Forbidden(err))
return
}
w.WriteHeader(http.StatusCreated)
JSON(w, &SignResponse{
ServerPEM: Certificate{cert},
CaPEM: Certificate{root},
TLSOptions: h.Authority.GetTLSOptions(),
})
}
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, BadRequest(errors.New("missing peer certificate")))
return
}
cert, root, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
if err != nil {
WriteError(w, Forbidden(err))
return
}
w.WriteHeader(http.StatusCreated)
JSON(w, &SignResponse{
ServerPEM: Certificate{cert},
CaPEM: Certificate{root},
TLSOptions: h.Authority.GetTLSOptions(),
})
}

618
api/api_test.go Normal file
View file

@ -0,0 +1,618 @@
package api
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/smallstep/cli/crypto/tlsutil"
)
const (
rootPEM = `-----BEGIN CERTIFICATE-----
MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT
MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i
YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG
EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy
bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP
VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv
h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE
ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ
EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC
DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7
qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD
VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g
K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI
KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n
ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB
BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY
/iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/
zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza
HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto
WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6
yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx
-----END CERTIFICATE-----`
certPEM = `-----BEGIN CERTIFICATE-----
MIIDujCCAqKgAwIBAgIIE31FZVaPXTUwDQYJKoZIhvcNAQEFBQAwSTELMAkGA1UE
BhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRl
cm5ldCBBdXRob3JpdHkgRzIwHhcNMTQwMTI5MTMyNzQzWhcNMTQwNTI5MDAwMDAw
WjBpMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwN
TW91bnRhaW4gVmlldzETMBEGA1UECgwKR29vZ2xlIEluYzEYMBYGA1UEAwwPbWFp
bC5nb29nbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfRrObuSW5T7q
5CnSEqefEmtH4CCv6+5EckuriNr1CjfVvqzwfAhopXkLrq45EQm8vkmf7W96XJhC
7ZM0dYi1/qOCAU8wggFLMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAa
BgNVHREEEzARgg9tYWlsLmdvb2dsZS5jb20wCwYDVR0PBAQDAgeAMGgGCCsGAQUF
BwEBBFwwWjArBggrBgEFBQcwAoYfaHR0cDovL3BraS5nb29nbGUuY29tL0dJQUcy
LmNydDArBggrBgEFBQcwAYYfaHR0cDovL2NsaWVudHMxLmdvb2dsZS5jb20vb2Nz
cDAdBgNVHQ4EFgQUiJxtimAuTfwb+aUtBn5UYKreKvMwDAYDVR0TAQH/BAIwADAf
BgNVHSMEGDAWgBRK3QYWG7z2aLV29YG2u2IaulqBLzAXBgNVHSAEEDAOMAwGCisG
AQQB1nkCBQEwMAYDVR0fBCkwJzAloCOgIYYfaHR0cDovL3BraS5nb29nbGUuY29t
L0dJQUcyLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAH6RYHxHdcGpMpFE3oxDoFnP+
gtuBCHan2yE2GRbJ2Cw8Lw0MmuKqHlf9RSeYfd3BXeKkj1qO6TVKwCh+0HdZk283
TZZyzmEOyclm3UGFYe82P/iDFt+CeQ3NpmBg+GoaVCuWAARJN/KfglbLyyYygcQq
0SgeDh8dRKUiaW3HQSoYvTvdTuqzwK4CXsr3b5/dAOY8uMuG/IAR3FgwTbZ1dtoW
RvOTa8hYiU6A475WuZKyEHcwnGYe57u2I2KbMgcKjPniocj4QzgYsVAVKW3IwaOh
yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`
csrPEM = `-----BEGIN CERTIFICATE REQUEST-----
MIIEYjCCAkoCAQAwHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29tMIICIjAN
BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuCpifZfoZhYNywfpnPa21NezXgtn
wrWBFE6xhVzE7YDSIqtIsj8aR7R8zwEymxfv5j5298LUy/XSmItVH31CsKyfcGqN
QM0PZr9XY3z5V6qchGMqjzt/jqlYMBHujcxIFBfz4HATxSgKyvHqvw14ESsS2huu
7jowx+XTKbFYgKcXrjBkvOej5FXD3ehkg0jDA2UAJNdfKmrc1BBEaaqOtfh7eyU2
HU7+5gxH8C27IiCAmNj719E0B99Nu2MUw6aLFIM4xAcRga33Avevx6UuXZZIEepe
V1sihrkcnDK9Vsxkme5erXzvAoOiRusiC2iIomJHJrdRM5ReEU+N+Tl1Kxq+rk7H
/qAq78wVm07M1/GGi9SUMObZS4WuJpM6whlikIAEbv9iV+CK0sv/Jr/AADdGMmQU
lwk+Q0ZNE8p4ZuWILv/dtLDtDVBpnrrJ9e8duBtB0lGcG8MdaUCQ346EI4T0Sgx0
hJ+wMq8zYYFfPIZEHC8o9p1ywWN9ySpJ8Zj/5ubmx9v2bY67GbuVFEa8iAp+S00x
/Z8nD6/JsoKtexuHyGr3ixWFzlBqXDuugukIDFUOVDCbuGw4Io4/hEMu4Zz0TIFk
Uu/wf2z75Tt8EkosKLu2wieKcY7n7Vhog/0tqexqWlWtJH0tvq4djsGoSvA62WPs
0iXXj+aZIARPNhECAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQA0vyHIndAkIs/I
Nnz5yZWCokRjokoKv3Aj4VilyjncL+W0UIPULLU/47ZyoHVSUj2t8gknr9xu/Kd+
g/2z0RiF3CIp8IUH49w/HYWaR95glzVNAAzr8qD9UbUqloLVQW3lObSRGtezhdZO
sspw5dC+inhAb1LZhx8PVxB3SAeJ8h11IEBr0s2Hxt9viKKd7YPtIFZkZdOkVx4R
if1DMawj1P6fEomf8z7m+dmbUYTqqosbCbRL01mzEga/kF6JyH/OzpNlcsAiyM8e
BxPWH6TtPqwmyy4y7j1outmM0RnyUw5A0HmIbWh+rHpXiHVsnNqse0XfzmaxM8+z
dxYeDax8aMWZKfvY1Zew+xIxl7DtEy1BpxrZcawumJYt5+LL+bwF/OtL0inQLnw8
zyqydsXNdrpIQJnfmWPld7ThWbQw2FBE70+nFSxHeG2ULnpF3M9xf6ZNAF4gqaNE
Q7vMNPBWrJWu+A++vHY61WGET+h4lY3GFr2I8OE4IiHPQi1D7Y0+fwOmStwuRPM4
2rARcJChNdiYBkkuvs4kixKTTjdXhB8RQtuBSrJ0M1tzq2qMbm7F8G01rOg4KlXU
58jHzJwr1K7cx0lpWfGTtc5bseCGtTKmDBXTziw04yl8eE1+ZFOganixGwCtl4Tt
DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w==
-----END CERTIFICATE REQUEST-----`
)
func parseCertificate(data string) *x509.Certificate {
block, _ := pem.Decode([]byte(data))
if block == nil {
panic("failed to parse certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
panic("failed to parse certificate: " + err.Error())
}
return cert
}
func parseCertificateRequest(data string) *x509.CertificateRequest {
block, _ := pem.Decode([]byte(csrPEM))
if block == nil {
panic("failed to parse certificate request PEM")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
panic("failed to parse certificate request: " + err.Error())
}
return csr
}
func TestNewCertificate(t *testing.T) {
cert := parseCertificate(rootPEM)
if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) {
t.Errorf("NewCertificate failed, got %v, wants %v", NewCertificate(cert), Certificate{Certificate: cert})
}
}
func TestCertificate_MarshalJSON(t *testing.T) {
type fields struct {
Certificate *x509.Certificate
}
tests := []struct {
name string
fields fields
want []byte
wantErr bool
}{
{"nil", fields{Certificate: nil}, []byte("null"), false},
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false},
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := Certificate{
Certificate: tt.fields.Certificate,
}
got, err := c.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("Certificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Certificate.MarshalJSON() = %s, want %s", got, tt.want)
}
})
}
}
func TestCertificate_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr bool
}{
{"no data", nil, true},
{"empty string", []byte(`""`), true},
{"incomplete string 1", []byte(`"foobar`), true}, {"incomplete string 2", []byte(`foobar"`), true},
{"invalid string", []byte(`"foobar"`), true},
{"invalid bytes 0", []byte{}, true}, {"invalid bytes 1", []byte{1}, true},
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), true},
{"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true},
{"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false},
{"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var c Certificate
if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr {
t.Errorf("Certificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && c.Certificate == nil {
t.Error("Certificate.UnmarshalJSON() failed, Certificate is nil")
}
})
}
}
func TestCertificate_UnmarshalJSON_json(t *testing.T) {
tests := []struct {
name string
data string
wantErr bool
}{
{"invalid type (null)", `{"crt":null}`, true},
{"invalid type (bool)", `{"crt":true}`, true},
{"invalid type (number)", `{"crt":123}`, true},
{"invalid type (object)", `{"crt":{}}`, true},
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, true},
{"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, false},
}
type request struct {
Cert Certificate `json:"crt"`
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var body request
if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr {
t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}
switch tt.wantErr {
case false:
if body.Cert.Certificate == nil {
t.Error("json.Unmarshal() failed, Certificate is nil")
}
case true:
if body.Cert.Certificate != nil {
t.Error("json.Unmarshal() failed, Certificate is not nil")
}
}
})
}
}
func TestNewCertificateRequest(t *testing.T) {
csr := parseCertificateRequest(csrPEM)
if !reflect.DeepEqual(CertificateRequest{CertificateRequest: csr}, NewCertificateRequest(csr)) {
t.Errorf("NewCertificateRequest failed, got %v, wants %v", NewCertificateRequest(csr), CertificateRequest{CertificateRequest: csr})
}
}
func TestCertificateRequest_MarshalJSON(t *testing.T) {
type fields struct {
CertificateRequest *x509.CertificateRequest
}
tests := []struct {
name string
fields fields
want []byte
wantErr bool
}{
{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := CertificateRequest{
CertificateRequest: tt.fields.CertificateRequest,
}
got, err := c.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("CertificateRequest.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CertificateRequest.MarshalJSON() = %s, want %s", got, tt.want)
}
})
}
}
func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr bool
}{
{"no data", nil, true},
{"empty string", []byte(`""`), true},
{"incomplete string 1", []byte(`"foobar`), true}, {"incomplete string 2", []byte(`foobar"`), true},
{"invalid string", []byte(`"foobar"`), true},
{"invalid bytes 0", []byte{}, true}, {"invalid bytes 1", []byte{1}, true},
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), true},
{"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true},
{"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var c CertificateRequest
if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr {
t.Errorf("CertificateRequest.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && c.CertificateRequest == nil {
t.Error("CertificateRequest.UnmarshalJSON() failed, CertificateRequet is nil")
}
})
}
}
func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
tests := []struct {
name string
data string
wantErr bool
}{
{"invalid type (null)", `{"csr":null}`, true},
{"invalid type (bool)", `{"csr":true}`, true},
{"invalid type (number)", `{"csr":123}`, true},
{"invalid type (object)", `{"csr":{}}`, true},
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, true},
{"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, false},
}
type request struct {
CSR CertificateRequest `json:"csr"`
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var body request
if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr {
t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}
switch tt.wantErr {
case false:
if body.CSR.CertificateRequest == nil {
t.Error("json.Unmarshal() failed, CertificateRequest is nil")
}
case true:
if body.CSR.CertificateRequest != nil {
t.Error("json.Unmarshal() failed, CertificateRequest is not nil")
}
}
})
}
}
func TestSignRequest_Validate(t *testing.T) {
now := time.Now()
csr := parseCertificateRequest(csrPEM)
bad := parseCertificateRequest(csrPEM)
bad.Signature[0]++
type fields struct {
CsrPEM CertificateRequest
OTT string
NotBefore time.Time
NotAfter time.Time
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{CertificateRequest{csr}, "foobarzar", time.Time{}, time.Time{}}, false},
{"ok 5m", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(5 * time.Minute)}, false},
{"ok 24h", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(24 * time.Hour)}, false},
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, true},
{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, true},
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, true},
{"notAfter < now", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(-5 * time.Minute)}, true},
{"notAfter < notBefore", fields{CertificateRequest{csr}, "foobarzar", now.Add(5 * time.Minute), now.Add(4 * time.Minute)}, true},
{"too short", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(4 * time.Minute)}, true},
{"too long", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(24 * time.Hour).Add(1 * time.Minute)}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &SignRequest{
CsrPEM: tt.fields.CsrPEM,
OTT: tt.fields.OTT,
NotAfter: tt.fields.NotAfter,
NotBefore: tt.fields.NotBefore,
}
if err := s.Validate(); (err != nil) != tt.wantErr {
t.Errorf("SignRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
type mockAuthority struct {
ret1, ret2 interface{}
err error
authorize func(ott string) ([]Claim, error)
getTLSOptions func() *tlsutil.TLSOptions
root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts SignOptions, claims ...Claim) (*x509.Certificate, *x509.Certificate, error)
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
}
func (m *mockAuthority) Authorize(ott string) ([]Claim, error) {
if m.authorize != nil {
return m.authorize(ott)
}
return m.ret1.([]Claim), m.err
}
func (m *mockAuthority) GetTLSOptions() *tlsutil.TLSOptions {
if m.getTLSOptions != nil {
return m.getTLSOptions()
}
return m.ret1.(*tlsutil.TLSOptions)
}
func (m *mockAuthority) GetMinDuration() time.Duration {
return minCertDuration
}
func (m *mockAuthority) GetMaxDuration() time.Duration {
return maxCertDuration
}
func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
if m.root != nil {
return m.root(shasum)
}
return m.ret1.(*x509.Certificate), m.err
}
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts SignOptions, claims ...Claim) (*x509.Certificate, *x509.Certificate, error) {
if m.sign != nil {
return m.sign(cr, opts, claims...)
}
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
}
func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) {
if m.renew != nil {
return m.renew(cert)
}
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
}
func Test_caHandler_Health(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/health", nil)
w := httptest.NewRecorder()
h := New(&mockAuthority{}).(*caHandler)
h.Health(w, req)
res := w.Result()
if res.StatusCode != 200 {
t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Health unexpected error = %v", err)
}
expected := []byte("{\"status\":\"ok\"}\n")
if !bytes.Equal(body, expected) {
t.Errorf("caHandler.Health Body = %s, wants %s", body, expected)
}
}
func Test_caHandler_Root(t *testing.T) {
tests := []struct {
name string
root *x509.Certificate
err error
statusCode int
}{
{"ok", parseCertificate(rootPEM), nil, 200},
{"fail", nil, fmt.Errorf("not found"), 404},
}
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("sha", "efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36")
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
w := httptest.NewRecorder()
h.Root(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err)
}
if tt.statusCode == 200 {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
}
}
})
}
}
func Test_caHandler_Sign(t *testing.T) {
csr := parseCertificateRequest(csrPEM)
valid, err := json.Marshal(SignRequest{
CsrPEM: CertificateRequest{csr},
OTT: "foobarzar",
})
if err != nil {
t.Fatal(err)
}
invalid, err := json.Marshal(SignRequest{
CsrPEM: CertificateRequest{csr},
OTT: "",
})
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
input string
claims []Claim
autherr error
cert *x509.Certificate
root *x509.Certificate
signErr error
statusCode int
}{
{"ok", string(valid), nil, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"json read error", "{", nil, nil, nil, nil, nil, http.StatusBadRequest},
{"validate error", string(invalid), nil, nil, nil, nil, nil, http.StatusBadRequest},
{"authorize error", string(valid), nil, fmt.Errorf("an error"), nil, nil, nil, http.StatusUnauthorized},
{"sign error", string(valid), nil, nil, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
}
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
authorize: func(ott string) ([]Claim, error) {
return tt.claims, tt.autherr
},
getTLSOptions: func() *tlsutil.TLSOptions {
return nil
},
}).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
w := httptest.NewRecorder()
h.Sign(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
}
}
})
}
}
func Test_caHandler_Renew(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
tests := []struct {
name string
tls *tls.ConnectionState
cert *x509.Certificate
root *x509.Certificate
err error
statusCode int
}{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
}
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err,
getTLSOptions: func() *tlsutil.TLSOptions {
return nil
},
}).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Renew(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
}
}
})
}
}

142
api/errors.go Normal file
View file

@ -0,0 +1,142 @@
package api
import (
"encoding/json"
"fmt"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/logging"
)
// StatusCoder interface is used by errors that returns the HTTP response code.
type StatusCoder interface {
StatusCode() int
}
// StackTracer must be by those errors that return an stack trace.
type StackTracer interface {
StackTrace() errors.StackTrace
}
// Error represents the CA API errors.
type Error struct {
Status int
Err error
}
// ErrorResponse represents an error in JSON format.
type ErrorResponse struct {
Status int `json:"status"`
Message string `json:"message"`
}
// Cause implements the errors.Causer interface and returns the original error.
func (e *Error) Cause() error {
return e.Err
}
// Error implements the error interface and returns the error string.
func (e *Error) Error() string {
return e.Err.Error()
}
// StatusCode implements the StatusCoder interface and returns the HTTP response
// code.
func (e *Error) StatusCode() int {
return e.Status
}
// MarshalJSON implements json.Marshaller interface for the Error struct.
func (e *Error) MarshalJSON() ([]byte, error) {
return json.Marshal(&ErrorResponse{Status: e.Status, Message: http.StatusText(e.Status)})
}
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
func (e *Error) UnmarshalJSON(data []byte) error {
var er ErrorResponse
if err := json.Unmarshal(data, &er); err != nil {
return err
}
e.Status = er.Status
e.Err = fmt.Errorf(er.Message)
return nil
}
// NewError returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func NewError(status int, err error) error {
if sc, ok := err.(StatusCoder); ok {
return &Error{Status: sc.StatusCode(), Err: err}
}
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
return &Error{Status: sc.StatusCode(), Err: err}
}
return &Error{Status: status, Err: err}
}
// InternalServerError returns a 500 error with the given error.
func InternalServerError(err error) error {
return NewError(http.StatusInternalServerError, err)
}
// BadRequest returns an 400 error with the given error.
func BadRequest(err error) error {
return NewError(http.StatusBadRequest, err)
}
// Unauthorized returns an 401 error with the given error.
func Unauthorized(err error) error {
return NewError(http.StatusUnauthorized, err)
}
// Forbidden returns an 403 error with the given error.
func Forbidden(err error) error {
return NewError(http.StatusForbidden, err)
}
// NotFound returns an 404 error with the given error.
func NotFound(err error) error {
return NewError(http.StatusNotFound, err)
}
// WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err error) {
w.Header().Set("Content-Type", "application/json")
cause := errors.Cause(err)
if sc, ok := err.(StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
if sc, ok := cause.(StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
w.WriteHeader(http.StatusInternalServerError)
}
}
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.(StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
} else {
if e, ok := cause.(StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
LogError(w, err)
}
}

41
api/utils.go Normal file
View file

@ -0,0 +1,41 @@
package api
import (
"encoding/json"
"io"
"log"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/logging"
)
// LogError adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func LogError(rw http.ResponseWriter, err error) {
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
} else {
log.Println(err)
}
}
// JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(v); err != nil {
LogError(w, err)
}
}
// ReadJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return BadRequest(errors.Wrap(err, "error decoding json"))
}
return nil
}

124
api/utils_test.go Normal file
View file

@ -0,0 +1,124 @@
package api
import (
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/logging"
)
func TestLogError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
LogError(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}
func TestJSON(t *testing.T) {
type args struct {
rw http.ResponseWriter
v interface{}
}
tests := []struct {
name string
args args
ok bool
}{
{"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true},
{"fail", args{httptest.NewRecorder(), make(chan int)}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rw := logging.NewResponseLogger(tt.args.rw)
JSON(rw, tt.args.v)
rr, ok := tt.args.rw.(*httptest.ResponseRecorder)
if !ok {
t.Error("ResponseWriter does not implement *httptest.ResponseRecorder")
return
}
fields := rw.Fields()
if tt.ok {
if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" {
t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body)
}
if len(fields) != 0 {
t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields)
}
} else {
if body := rr.Body.String(); body != "" {
t.Errorf("Unexpected body = %s, want empty string", body)
}
if len(fields) != 1 {
t.Errorf("ResponseLogger fields = %v, wants 1 element", fields)
}
}
})
}
}
func TestReadJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ReadJSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

99
authority/authority.go Normal file
View file

@ -0,0 +1,99 @@
package authority
import (
"crypto/sha256"
realx509 "crypto/x509"
"encoding/hex"
"sync"
"time"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/x509util"
)
// Authority implements the Certificate Authority internal interface.
type Authority struct {
config *Config
rootX509Crt *realx509.Certificate
intermediateIdentity *x509util.Identity
validateOnce bool
certificates *sync.Map
ottMap *sync.Map
startTime time.Time
provisionerIDIndex *sync.Map
encryptedKeyIndex *sync.Map
provisionerKeySetIndex *sync.Map
// Do not re-initialize
initOnce bool
}
// New creates and initiates a new Authority type.
func New(config *Config) (*Authority, error) {
if err := config.Validate(); err != nil {
return nil, err
}
var a = &Authority{
config: config,
certificates: new(sync.Map),
ottMap: new(sync.Map),
provisionerIDIndex: new(sync.Map),
encryptedKeyIndex: new(sync.Map),
provisionerKeySetIndex: new(sync.Map),
}
if err := a.init(); err != nil {
return nil, err
}
return a, nil
}
// init performs validation and initializes the fields of an Authority struct.
func (a *Authority) init() error {
// Check if handler has already been validated/initialized.
if a.initOnce {
return nil
}
var err error
// First load the root using our modified pem/x509 package.
a.rootX509Crt, err = pemutil.ReadCertificate(a.config.Root)
if err != nil {
return err
}
// Add root certificate to the certificate map
sum := sha256.Sum256(a.rootX509Crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
// Decrypt and load intermediate public / private key pair.
if len(a.config.Password) > 0 {
//fmt.Printf("Decrypting intermediate... ")
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(
a.config.IntermediateCert,
a.config.IntermediateKey,
pemutil.WithPassword([]byte(a.config.Password)),
)
if err != nil {
return err
}
//fmt.Printf("all done.\n")
} else {
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(a.config.IntermediateCert, a.config.IntermediateKey)
if err != nil {
return err
}
}
for _, p := range a.config.AuthorityConfig.Provisioners {
a.provisionerIDIndex.Store(p.Key.KeyID, p)
if len(p.EncryptedKey) != 0 {
a.encryptedKeyIndex.Store(p.Key.KeyID, p.EncryptedKey)
}
}
a.startTime = time.Now()
// Set flag indicating that initialization has been completed, and should
// not be repeated.
a.initOnce = true
return nil
}

View file

@ -0,0 +1,41 @@
package authority
import (
"testing"
"github.com/smallstep/assert"
stepJOSE "github.com/smallstep/cli/jose"
)
func testAuthority(t *testing.T) *Authority {
maxjwk, err := stepJOSE.ParseKey("testdata/secrets/max_pub.jwk")
assert.FatalError(t, err)
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
assert.FatalError(t, err)
p := []*Provisioner{
{
Issuer: "Max",
Type: "JWK",
Key: maxjwk,
},
{
Issuer: "step-cli",
Type: "JWK",
Key: clijwk,
},
}
c := &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: &AuthConfig{
Provisioners: p,
},
}
a, err := New(c)
assert.FatalError(t, err)
return a
}

108
authority/authorize.go Normal file
View file

@ -0,0 +1,108 @@
package authority
import (
"net/http"
"time"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/api"
"gopkg.in/square/go-jose.v2/jwt"
)
type idUsed struct {
UsedAt int64 `json:"ua,omitempty"`
Subject string `json:"sub,omitempty"`
}
var (
validTokenAudience = []string{"https://ca/sign", "step-certificate-authority"}
)
func containsAtLeastOneAudience(claim []string, expected []string) bool {
if len(expected) == 0 {
return true
}
if len(claim) == 0 {
return false
}
for _, exp := range expected {
for _, cl := range claim {
if exp == cl {
return true
}
}
}
return false
}
// Authorize authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request.
func (a *Authority) Authorize(ott string) ([]api.Claim, error) {
var (
errContext = map[string]interface{}{"ott": ott}
claims = jwt.Claims{}
// Claims to check in the Sign method
downstreamClaims []api.Claim
)
// Validate payload
token, err := jwt.ParseSigned(ott)
if err != nil {
return nil, &apiError{errors.Wrapf(err, "error parsing OTT to JSONWebToken"),
http.StatusUnauthorized, errContext}
}
kid := token.Headers[0].KeyID // JWT will only have 1 header.
if len(kid) == 0 {
return nil, &apiError{errors.New("keyID cannot be empty"),
http.StatusUnauthorized, errContext}
}
val, ok := a.provisionerIDIndex.Load(kid)
if !ok {
return nil, &apiError{errors.Errorf("Provisioner with KeyID %s could not be found", kid),
http.StatusUnauthorized, errContext}
}
p, ok := val.(*Provisioner)
if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *Provisioner"),
http.StatusInternalServerError, context{}}
}
if err = token.Claims(p.Key, &claims); err != nil {
return nil, &apiError{err, http.StatusUnauthorized, errContext}
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes.
if err = claims.ValidateWithLeeway(jwt.Expected{
Issuer: p.Issuer,
}, time.Minute); err != nil {
return nil, &apiError{errors.Wrapf(err, "error validating OTT"),
http.StatusUnauthorized, errContext}
}
if !containsAtLeastOneAudience(claims.Audience, validTokenAudience) {
return nil, &apiError{errors.New("invalid audience"), http.StatusUnauthorized,
errContext}
}
if claims.Subject == "" {
return nil, &apiError{errors.New("OTT sub cannot be empty"),
http.StatusUnauthorized, errContext}
}
downstreamClaims = append(downstreamClaims, &commonNameClaim{claims.Subject})
downstreamClaims = append(downstreamClaims, &dnsNamesClaim{claims.Subject})
downstreamClaims = append(downstreamClaims, &ipAddressesClaim{claims.Subject})
// Store the token to protect against reuse.
if _, ok := a.ottMap.LoadOrStore(claims.ID, &idUsed{
UsedAt: time.Now().Unix(),
Subject: claims.Subject,
}); ok {
return nil, &apiError{errors.Errorf("token already used"), http.StatusUnauthorized,
errContext}
}
return downstreamClaims, nil
}

160
authority/authorize_test.go Normal file
View file

@ -0,0 +1,160 @@
package authority
import (
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/ca-component/api"
"github.com/smallstep/cli/crypto/keys"
stepJOSE "github.com/smallstep/cli/jose"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestAuthorize(t *testing.T) {
a := testAuthority(t)
jwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_priv.jwk",
stepJOSE.WithPassword([]byte("pass")))
assert.FatalError(t, err)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
now := time.Now()
validIssuer := "step-cli"
type authorizeTest struct {
ott string
err *apiError
claims []api.Claim
}
tests := map[string]func(t *testing.T) *authorizeTest{
"invalid-ott": func(t *testing.T) *authorizeTest {
return &authorizeTest{
ott: "foo",
err: &apiError{errors.New("error parsing OTT"),
http.StatusUnauthorized, context{"ott": "foo"}},
claims: nil}
},
"invalid-issuer": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "subject",
Issuer: "invalid-issuer",
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validTokenAudience,
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
ott: raw,
err: &apiError{errors.New("error validating OTT"),
http.StatusUnauthorized, context{"ott": raw}},
claims: nil}
},
"empty-subject": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validTokenAudience,
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
ott: raw,
err: &apiError{errors.New("OTT sub cannot be empty"),
http.StatusUnauthorized, context{"ott": raw}},
claims: nil}
},
"verify-sig-failure": func(t *testing.T) *authorizeTest {
_, priv2, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
invalidKeySig, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.ES256,
Key: priv2,
}, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
cl := jwt.Claims{
Subject: "foo",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validTokenAudience,
}
raw, err := jwt.Signed(invalidKeySig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
ott: raw,
err: &apiError{errors.New("square/go-jose: error in cryptographic primitive"),
http.StatusUnauthorized, context{"ott": raw}},
claims: nil}
},
"token-already-used": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "foo",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validTokenAudience,
ID: "42",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
_, err = a.Authorize(raw)
assert.FatalError(t, err)
return &authorizeTest{
ott: raw,
err: &apiError{errors.New("token already used"),
http.StatusUnauthorized, context{"ott": raw}},
claims: nil}
},
"success": func(t *testing.T) *authorizeTest {
cl := jwt.Claims{
Subject: "foo",
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validTokenAudience,
ID: "43",
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
ott: raw,
claims: []api.Claim{&commonNameClaim{"foo"}, &dnsNamesClaim{"foo"}, &ipAddressesClaim{"foo"}},
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
assert.FatalError(t, err)
claims, err := a.Authorize(tc.ott)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, claims, tc.claims)
}
}
})
}
}

76
authority/claims.go Normal file
View file

@ -0,0 +1,76 @@
package authority
import (
"crypto/x509"
"net"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/api"
)
// ValidateClaims returns nil if all the claims are validated, it will return
// the first error if a claim fails.
func ValidateClaims(cr *x509.CertificateRequest, claims []api.Claim) (err error) {
for _, c := range claims {
if err = c.Valid(cr); err != nil {
return err
}
}
return
}
// commonNameClaim validates the common name of a certificate request.
type commonNameClaim struct {
name string
}
// Valid checks that certificate request common name matches the one configured.
func (c *commonNameClaim) Valid(cr *x509.CertificateRequest) error {
if cr.Subject.CommonName == "" {
return errors.New("common name cannot be empty")
}
if cr.Subject.CommonName != c.name {
return errors.Errorf("common name claim failed - got %s, want %s", cr.Subject.CommonName, c.name)
}
return nil
}
type dnsNamesClaim struct {
name string
}
// Valid checks that certificate request common name matches the one configured.
func (c *dnsNamesClaim) Valid(cr *x509.CertificateRequest) error {
if len(cr.DNSNames) == 0 {
return nil
}
for _, name := range cr.DNSNames {
if name != c.name {
return errors.Errorf("DNS names claim failed - got %s, want %s", name, c.name)
}
}
return nil
}
type ipAddressesClaim struct {
name string
}
// Valid checks that certificate request common name matches the one configured.
func (c *ipAddressesClaim) Valid(cr *x509.CertificateRequest) error {
if len(cr.IPAddresses) == 0 {
return nil
}
// If it's an IP validate that only that ip is in IP addresses
if requestedIP := net.ParseIP(c.name); requestedIP != nil {
for _, ip := range cr.IPAddresses {
if !ip.Equal(requestedIP) {
return errors.Errorf("IP addresses claim failed - got %s, want %s", ip, requestedIP)
}
}
return nil
}
return errors.Errorf("IP addresses claim failed - got %v, want none", cr.IPAddresses)
}

127
authority/claims_test.go Normal file
View file

@ -0,0 +1,127 @@
package authority
import (
"crypto/x509"
"crypto/x509/pkix"
"net"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/ca-component/api"
)
func TestCommonNameClaim_Valid(t *testing.T) {
tests := map[string]struct {
cnc api.Claim
crt *x509.CertificateRequest
err error
}{
"empty-common-name": {
cnc: &commonNameClaim{name: "foo"},
crt: &x509.CertificateRequest{},
err: errors.New("common name cannot be empty"),
},
"wrong-common-name": {
cnc: &commonNameClaim{name: "foo"},
crt: &x509.CertificateRequest{Subject: pkix.Name{CommonName: "bar"}},
err: errors.New("common name claim failed - got bar, want foo"),
},
"ok": {
cnc: &commonNameClaim{name: "foo"},
crt: &x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo"}},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
err := tc.cnc.Valid(tc.crt)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestIPAddressesClaim_Valid(t *testing.T) {
tests := map[string]struct {
iac api.Claim
crt *x509.CertificateRequest
err error
}{
"unexpected-ip": {
iac: &ipAddressesClaim{name: "127.0.0.1"},
crt: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
err: errors.New("IP addresses claim failed - got 1.1.1.1, want 127.0.0.1"),
},
"invalid-matcher-nonempty-ips": {
iac: &ipAddressesClaim{name: "invalid"},
crt: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
err: errors.New("IP addresses claim failed - got [127.0.0.1], want none"),
},
"ok": {
iac: &ipAddressesClaim{name: "127.0.0.1"},
crt: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
},
"ok-empty-ips": {
iac: &ipAddressesClaim{name: "127.0.0.1"},
crt: &x509.CertificateRequest{IPAddresses: []net.IP{}},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
err := tc.iac.Valid(tc.crt)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestDNSNamesClaim_Valid(t *testing.T) {
tests := map[string]struct {
dnc api.Claim
crt *x509.CertificateRequest
err error
}{
"wrong-dns-name": {
dnc: &dnsNamesClaim{name: "foo"},
crt: &x509.CertificateRequest{DNSNames: []string{"foo", "bar"}},
err: errors.New("DNS names claim failed - got bar, want foo"),
},
"ok": {
dnc: &dnsNamesClaim{name: "foo"},
crt: &x509.CertificateRequest{DNSNames: []string{"foo"}},
},
"ok-empty-dnsNames": {
dnc: &dnsNamesClaim{"foo"},
crt: &x509.CertificateRequest{},
},
"ok-multiple-identical-dns-entries": {
dnc: &dnsNamesClaim{name: "foo"},
crt: &x509.CertificateRequest{DNSNames: []string{"foo", "foo", "foo"}},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
err := tc.dnc.Valid(tc.crt)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

171
authority/config.go Normal file
View file

@ -0,0 +1,171 @@
package authority
import (
"encoding/json"
"os"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
jose "gopkg.in/square/go-jose.v2"
)
// DefaultTLSOptions represents the default TLS version as well as the cipher
// suites used in the TLS certificates.
var DefaultTLSOptions = tlsutil.TLSOptions{
CipherSuites: x509util.CipherSuites{
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
},
MinVersion: 1.2,
MaxVersion: 1.2,
Renegotiation: false,
}
const (
// minCertDuration is the minimum validity of an end-entity (not root or intermediate) certificate.
minCertDuration = 5 * time.Minute
// maxCertDuration is the maximum validity of an end-entity (not root or intermediate) certificate.
maxCertDuration = 24 * time.Hour
)
type duration struct {
time.Duration
}
// UnmarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) UnmarshalJSON(data []byte) (err error) {
var s string
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
if d.Duration, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
return
}
// Provisioner - authorized entity that can sign tokens necessary for signature requests.
type Provisioner struct {
Issuer string `json:"issuer,omitempty"`
Type string `json:"type,omitempty"`
Key *jose.JSONWebKey `json:"key,omitempty"`
EncryptedKey string `json:"encryptedKey,omitempty"`
}
// Config represents the CA configuration and it's mapped to a JSON object.
type Config struct {
Root string `json:"root"`
IntermediateCert string `json:"crt"`
IntermediateKey string `json:"key"`
Address string `json:"address"`
DNSNames []string `json:"dnsNames"`
Logger json.RawMessage `json:"logger,omitempty"`
Monitoring json.RawMessage `json:"monitoring,omitempty"`
AuthorityConfig *AuthConfig `json:"authority,omitempty"`
TLS *tlsutil.TLSOptions `json:"tls,omitempty"`
Password string `json:"password,omitempty"`
}
// AuthConfig represents the configuration options for the authority.
type AuthConfig struct {
Provisioners []*Provisioner `json:"provisioners,omitempty"`
Template *x509util.ASN1DN `json:"template,omitempty"`
MinCertDuration *duration `json:"minCertDuration,omitempty"`
MaxCertDuration *duration `json:"maxCertDuration,omitempty"`
}
// Validate validates the authority configuration.
func (c *AuthConfig) Validate() error {
switch {
case c == nil:
return errors.New("authority cannot be undefined")
case len(c.Provisioners) == 0:
return errors.New("authority.provisioners cannot be empty")
default:
if c.Template == nil {
c.Template = &x509util.ASN1DN{}
}
return nil
}
}
// LoadConfiguration parses the given filename in JSON format and returns the
// configuration struct.
func LoadConfiguration(filename string) (*Config, error) {
f, err := os.Open(filename)
if err != nil {
return nil, errors.Wrapf(err, "error opening %s", filename)
}
defer f.Close()
var c Config
if err := json.NewDecoder(f).Decode(&c); err != nil {
return nil, errors.Wrapf(err, "error parsing %s", filename)
}
return &c, nil
}
// Save saves the configuration to the given filename.
func (c *Config) Save(filename string) error {
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return errors.Wrapf(err, "error opening %s", filename)
}
defer f.Close()
enc := json.NewEncoder(f)
enc.SetIndent("", "\t")
return errors.Wrapf(enc.Encode(c), "error writing %s", filename)
}
// Validate validates the configuration.
func (c *Config) Validate() error {
switch {
case c.Address == "":
return errors.New("address cannot be empty")
case c.Root == "":
return errors.New("root cannot be empty")
case c.IntermediateCert == "":
return errors.New("crt cannot be empty")
case c.IntermediateKey == "":
return errors.New("key cannot be empty")
case len(c.DNSNames) == 0:
return errors.New("dnsNames cannot be empty")
}
if c.TLS == nil {
c.TLS = &DefaultTLSOptions
} else {
if len(c.TLS.CipherSuites) == 0 {
c.TLS.CipherSuites = DefaultTLSOptions.CipherSuites
}
if c.TLS.MaxVersion == 0 {
c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion
}
if c.TLS.MinVersion == 0 {
c.TLS.MinVersion = c.TLS.MaxVersion
}
if c.TLS.MinVersion > c.TLS.MaxVersion {
return errors.New("tls minVersion cannot exceed tls maxVersion")
}
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
}
if err := c.AuthorityConfig.Validate(); err != nil {
return err
}
return nil
}

282
authority/config_test.go Normal file
View file

@ -0,0 +1,282 @@
package authority
import (
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
stepJOSE "github.com/smallstep/cli/jose"
)
func TestConfigValidate(t *testing.T) {
maxjwk, err := stepJOSE.ParseKey("testdata/secrets/max_pub.jwk")
assert.FatalError(t, err)
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
assert.FatalError(t, err)
ac := &AuthConfig{
Provisioners: []*Provisioner{
{
Issuer: "Max",
Type: "JWK",
Key: maxjwk,
},
{
Issuer: "step-cli",
Type: "JWK",
Key: clijwk,
},
},
}
type ConfigValidateTest struct {
config *Config
err error
tls tlsutil.TLSOptions
}
tests := map[string]func(*testing.T) ConfigValidateTest{
"empty-address": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Root: "testdata/secrets/root_ca.crt",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: ac,
},
err: errors.New("address cannot be empty"),
}
},
"empty-root": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: ac,
},
err: errors.New("root cannot be empty"),
}
},
"empty-intermediate-cert": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: ac,
},
err: errors.New("crt cannot be empty"),
}
},
"empty-intermediate-key": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: ac,
},
err: errors.New("key cannot be empty"),
}
},
"empty-dnsNames": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
Password: "pass",
AuthorityConfig: ac,
},
err: errors.New("dnsNames cannot be empty"),
}
},
"empty-TLS": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: ac,
},
tls: DefaultTLSOptions,
}
},
"empty-TLS-values": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: ac,
TLS: &tlsutil.TLSOptions{},
},
tls: DefaultTLSOptions,
}
},
"custom-tls-values": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: ac,
TLS: &tlsutil.TLSOptions{
CipherSuites: x509util.CipherSuites{
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
},
MinVersion: 1.0,
MaxVersion: 1.1,
Renegotiation: true,
},
},
tls: tlsutil.TLSOptions{
CipherSuites: x509util.CipherSuites{
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
},
MinVersion: 1.0,
MaxVersion: 1.1,
Renegotiation: true,
},
}
},
"tls-min>max": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
AuthorityConfig: ac,
TLS: &tlsutil.TLSOptions{
CipherSuites: x509util.CipherSuites{
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
},
MinVersion: 1.2,
MaxVersion: 1.1,
Renegotiation: true,
},
},
err: errors.New("tls minVersion cannot exceed tls maxVersion"),
}
},
}
for name, get := range tests {
t.Run(name, func(t *testing.T) {
tc := get(t)
err := tc.config.Validate()
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, *tc.config.TLS, tc.tls)
}
}
})
}
}
func TestAuthConfigValidate(t *testing.T) {
asn1dn := x509util.ASN1DN{
Country: "Tazmania",
Organization: "Acme Co",
Locality: "Landscapes",
Province: "Sudden Cliffs",
StreetAddress: "TNT",
CommonName: "test",
}
maxjwk, err := stepJOSE.ParseKey("testdata/secrets/max_pub.jwk")
assert.FatalError(t, err)
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
assert.FatalError(t, err)
p := []*Provisioner{
{
Issuer: "Max",
Type: "JWK",
Key: maxjwk,
},
{
Issuer: "step-cli",
Type: "JWK",
Key: clijwk,
},
}
type AuthConfigValidateTest struct {
ac *AuthConfig
asn1dn x509util.ASN1DN
err error
}
tests := map[string]func(*testing.T) AuthConfigValidateTest{
"nil": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: nil,
err: errors.New("authority cannot be undefined"),
}
},
"empty-provisioners": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{},
err: errors.New("authority.provisioners cannot be empty"),
}
},
"empty-asn1dn-template": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{
Provisioners: p,
},
asn1dn: x509util.ASN1DN{},
}
},
"custom-asn1dn": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{
Provisioners: p,
Template: &asn1dn,
},
asn1dn: asn1dn,
}
},
}
for name, get := range tests {
t.Run(name, func(t *testing.T) {
tc := get(t)
err := tc.ac.Validate()
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, *tc.ac.Template, tc.asn1dn)
}
}
})
}
}

43
authority/error.go Normal file
View file

@ -0,0 +1,43 @@
package authority
import (
"net/http"
)
type context map[string]interface{}
// Error implements the api.Error interface and adds context to error messages.
type apiError struct {
err error
code int
context context
}
// Cause implements the errors.Causer interface and returns the original error.
func (e *apiError) Cause() error {
return e.err
}
// Error returns an error message with additional context.
func (e *apiError) Error() string {
ret := e.err.Error()
/*
if len(e.context) > 0 {
ret += "\n\nContext:"
for k, v := range e.context {
ret += fmt.Sprintf("\n %s: %v", k, v)
}
}
*/
return ret
}
// StatusCode returns an http status code indicating the type and severity of
// the error.
func (e *apiError) StatusCode() int {
if e.code == 0 {
return http.StatusInternalServerError
}
return e.code
}

47
authority/provisioners.go Normal file
View file

@ -0,0 +1,47 @@
package authority
import (
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
val, ok := a.encryptedKeyIndex.Load(kid)
if !ok {
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
http.StatusNotFound, context{}}
}
key, ok := val.(string)
if !ok {
return "", &apiError{errors.Errorf("stored value is not a string"),
http.StatusInternalServerError, context{}}
}
return key, nil
}
// GetProvisioners returns a map listing each provisioner and the JWK Key Set
// with their public keys.
func (a *Authority) GetProvisioners() (map[string]*jose.JSONWebKeySet, error) {
pks := map[string]*jose.JSONWebKeySet{}
a.provisionerIDIndex.Range(func(key, val interface{}) bool {
p, ok := val.(*Provisioner)
if !ok {
return false
}
ks, found := pks[p.Issuer]
if found {
ks.Keys = append(ks.Keys, *p.Key)
} else {
ks = new(jose.JSONWebKeySet)
ks.Keys = []jose.JSONWebKey{*p.Key}
pks[p.Issuer] = ks
}
return true
})
return pks, nil
}

29
authority/root.go Normal file
View file

@ -0,0 +1,29 @@
package authority
import (
"crypto/x509"
"net/http"
"github.com/pkg/errors"
)
// Root returns the certificate corresponding to the given SHA sum argument.
func (a *Authority) Root(sum string) (*x509.Certificate, error) {
val, ok := a.certificates.Load(sum)
if !ok {
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
http.StatusNotFound, context{}}
}
crt, ok := val.(*x509.Certificate)
if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"),
http.StatusInternalServerError, context{}}
}
return crt, nil
}
// GetRootCertificate returns the server root certificate.
func (a *Authority) GetRootCertificate() *x509.Certificate {
return a.rootX509Crt
}

45
authority/root_test.go Normal file
View file

@ -0,0 +1,45 @@
package authority
import (
"net/http"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
)
func TestRoot(t *testing.T) {
a := testAuthority(t)
a.certificates.Store("invaliddata", "a string") // invalid cert for testing
tests := map[string]struct {
sum string
err *apiError
}{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
crt, err := a.Root(tc.sum)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, crt, a.rootX509Crt)
}
}
})
}
}

View file

@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIBxTCCAWugAwIBAgIQfkaUVV4yh8gQZa/EsIECpTAKBggqhkjOPQQDAjAcMRow
GAYDVQQDExFzbWFsbHN0ZXAgUm9vdCBDQTAeFw0xODA4MTgxOTAxNDZaFw0yODA4
MTUxOTAxNDZaMCQxIjAgBgNVBAMTGXNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew
WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATfuJeqP7FHMaVq1uMU9avTZ9JW+VzL
NS7rJrkhs41j38Oru9UpZWCqXr5uNNioqElRLB6xRfTPd1mCNctQoTUpo4GGMIGD
MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw
EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU1rz/ojOuK6vKFH4Qi8mwpXtv
OzkwHwYDVR0jBBgwFoAUjoa24fWu22FipFrMI2rjBkzVDhEwCgYIKoZIzj0EAwID
SAAwRQIgWDEWlEaleq5ubnm21k4Zc+agdh1pwOQ41uS4GxXEY5ACIQDkY+MvTLLe
uBjherwnoVagcftox+GmRwgFpLJC/gRLzw==
-----END CERTIFICATE-----

View file

@ -0,0 +1,8 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,856c18a6a0d6654d0e3aed6e3211a285
J1j4qQjtBsh6+ETLy/wlG4eSmQSkmxNQkyzt5zkpqFozS8yssAmTdkIFM6JGnQcc
e0jGRXCy+Sx/vYQCY1uKR5FKlVpcT9I02r1nwgNHfd6zVmbQcXuYKvZQjJKLP27p
gqluC9+nPA+NLJM/oP0GjNtQGasCc7oX6jYP4f1XFpw=
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,7 @@
{
"protected": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IkpsNkZLWUp4V1UwdGRIbG9UanA1aGcifQ",
"encrypted_key": "Qy0EP6u5-t0ggOweoc3Z1DCzR5BllsQi",
"iv": "KUkviZ_TJKY4c0Mi",
"ciphertext": "h7QZqgh_Fl2MZpmVy4h375yC0DORjB1dQULbNqc6MuUCW2iweWVRysFImUXiXMUKRarJC5adwWy1GhyAqUj6Xj1iOZDGLjYnqMETGWcI0rKDBwcSU7y7Y-2VYBRDSM2b7aWtTBfz3_kvEaw_vc3b5CEPJ86UlZc-jhKFRr_IcGWU-vXX5-bppoH15IPreyzi55YdjCll338lYpDecB_Paym3XBXotyd2iGXXUwoA1npEFwuyRMMEhl9zLp7rVcMW6A_32EzB8cZANEnA0C4FXGHQalY6u_2UeqxcC8_FuXPay6VIYODyRqcABvvkft3nwOcrI0pYDGBdk2w2Euk",
"tag": "kOAFq3Tg6s4vBGS_plMpSw"
}

View file

@ -0,0 +1,9 @@
{
"use": "sig",
"kty": "EC",
"kid": "IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk",
"crv": "P-256",
"alg": "ES256",
"x": "XmaY0c9Cc_kjfn9uhimiDiKnKn00gmFzzsvElg4KxoE",
"y": "ZhYcFQBqtErdC_pA7sOXrO7AboCEPIKP9Ik4CHJqANk"
}

10
authority/testdata/secrets/root_ca.crt vendored Normal file
View file

@ -0,0 +1,10 @@
-----BEGIN CERTIFICATE-----
MIIBezCCASGgAwIBAgIQO4IwgRBrTxUIHlMdV9j5NDAKBggqhkjOPQQDAjAcMRow
GAYDVQQDExFzbWFsbHN0ZXAgUm9vdCBDQTAeFw0xODA4MTgxOTAxNDZaFw0yODA4
MTUxOTAxNDZaMBwxGjAYBgNVBAMTEXNtYWxsc3RlcCBSb290IENBMFkwEwYHKoZI
zj0CAQYIKoZIzj0DAQcDQgAEsA5O9AoNi/LslXQ2LRXrcWsTH3Urlyrw4RNLs4nK
Fep6C/kRk83eD4eGr0Nfh0EYvUc4J6kYIQl62/bD2RjqCqNFMEMwDgYDVR0PAQH/
BAQDAgGmMBIGA1UdEwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFI6GtuH1rtthYqRa
zCNq4wZM1Q4RMAoGCCqGSM49BAMCA0gAMEUCIQCiC+3oVXGMmUp1xeQ/vOwRWTat
I96I5ms2tY8LA6z9RQIgdhiWiYwvvgIMlm57sGpol7evVuAibYH6CE3Mqn4jIE4=
-----END CERTIFICATE-----

View file

@ -0,0 +1,8 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,e2c9c7cdad45b5032f1990b929cf83fd
k3Yd307VgDrdllCBGN7PP8dOMQvEAUkq1lYtyxAWa7u/DuxeDP7SYlDB+xEk/UL8
bgoYYCProydEElYFzGg8Z98WYAzbNoP2p6PPPpAhOZsxJjc5OfTHf/OQleR8PjD5
ryN4woGuq7Tiq5xritlyhluPc91ODqMsm4P98X1sPYA=
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,4 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE7ZdAAMZCFU4XwgblI5RfZouBi8lY
mF6DlZusNNnsbm+xCvYl3PAPZ+DKvKYERdazEPEU2OOo3riostJst0tn1g==
-----END PUBLIC KEY-----

View file

@ -0,0 +1,7 @@
{
"protected": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ",
"encrypted_key": "XaN9zcPQeWt49zchUDm34FECUTHfQTn_",
"iv": "tmNHPQDqR3ebsWfd",
"ciphertext": "9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw",
"tag": "thPcx3t1AUcWuEygXIY3Fg"
}

View file

@ -0,0 +1,9 @@
{
"use": "sig",
"kty": "EC",
"kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc",
"crv": "P-256",
"alg": "ES256",
"x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8",
"y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y"
}

238
authority/tls.go Normal file
View file

@ -0,0 +1,238 @@
package authority
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"net/http"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/api"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
stepx509 "github.com/smallstep/cli/pkg/x509"
)
// GetMinDuration returns the minimum validity of an end-entity (not root or
// intermediate) certificate.
func (a *Authority) GetMinDuration() time.Duration {
if a.config.AuthorityConfig.MinCertDuration == nil {
return minCertDuration
}
return a.config.AuthorityConfig.MinCertDuration.Duration
}
// GetMaxDuration returns the maximum validity of an end-entity (not root or
// intermediate) certificate.
func (a *Authority) GetMaxDuration() time.Duration {
if a.config.AuthorityConfig.MaxCertDuration == nil {
return maxCertDuration
}
return a.config.AuthorityConfig.MaxCertDuration.Duration
}
// GetTLSOptions returns the tls options configured.
func (a *Authority) GetTLSOptions() *tlsutil.TLSOptions {
return a.config.TLS
}
func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
return func(p x509util.Profile) error {
if def == nil {
return errors.New("default ASN1DN template cannot be nil")
}
crt := p.Subject()
if len(crt.Subject.Country) == 0 && def.Country != "" {
crt.Subject.Country = append(crt.Subject.Country, def.Country)
}
if len(crt.Subject.Organization) == 0 && def.Organization != "" {
crt.Subject.Organization = append(crt.Subject.Organization, def.Organization)
}
if len(crt.Subject.OrganizationalUnit) == 0 && def.OrganizationalUnit != "" {
crt.Subject.OrganizationalUnit = append(crt.Subject.OrganizationalUnit, def.OrganizationalUnit)
}
if len(crt.Subject.Locality) == 0 && def.Locality != "" {
crt.Subject.Locality = append(crt.Subject.Locality, def.Locality)
}
if len(crt.Subject.Province) == 0 && def.Province != "" {
crt.Subject.Province = append(crt.Subject.Province, def.Province)
}
if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" {
crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress)
}
return nil
}
}
// Sign creates a signed certificate from a certificate signing request.
func (a *Authority) Sign(csr *x509.CertificateRequest, opts api.SignOptions, claims ...api.Claim) (*x509.Certificate, *x509.Certificate, error) {
if err := ValidateClaims(csr, claims); err != nil {
return nil, nil, &apiError{err, http.StatusUnauthorized, context{}}
}
stepCSR, err := stepx509.ParseCertificateRequest(csr.Raw)
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error converting x509 csr to stepx509 csr"),
http.StatusInternalServerError, context{}}
}
// DNSNames and IPAddresses are validated but to avoid duplications we will
// clean them as x509util.NewLeafProfileWithCSR will set the right values.
stepCSR.DNSNames = nil
stepCSR.IPAddresses = nil
issIdentity := a.intermediateIdentity
leaf, err := x509util.NewLeafProfileWithCSR(stepCSR, issIdentity.Crt,
issIdentity.Key, x509util.WithHosts(csr.Subject.CommonName),
x509util.WithNotBeforeAfter(opts.NotBefore, opts.NotAfter),
withDefaultASN1DN(a.config.AuthorityConfig.Template))
if err != nil {
return nil, nil, &apiError{err, http.StatusInternalServerError, context{}}
}
crtBytes, err := leaf.CreateCertificate()
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error creating new leaf certificate from input csr"),
http.StatusInternalServerError, context{}}
}
serverCert, err := x509.ParseCertificate(crtBytes)
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error parsing new server certificate"),
http.StatusInternalServerError, context{}}
}
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"),
http.StatusInternalServerError, context{}}
}
return serverCert, caCert, nil
}
// Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'.
func (a *Authority) Renew(ocx *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) {
issIdentity := a.intermediateIdentity
// Convert a realx509.Certificate to the step x509 Certificate.
oldCert, err := stepx509.ParseCertificate(ocx.Raw)
if err != nil {
return nil, nil, &apiError{
errors.Wrap(err, "error converting x509.Certificate to stepx509.Certificate"),
http.StatusInternalServerError, context{},
}
}
now := time.Now().UTC()
duration := oldCert.NotAfter.Sub(oldCert.NotBefore)
newCert := &stepx509.Certificate{
PublicKey: oldCert.PublicKey,
Issuer: issIdentity.Crt.Subject,
Subject: oldCert.Subject,
NotBefore: now,
NotAfter: now.Add(duration),
KeyUsage: oldCert.KeyUsage,
Extensions: oldCert.Extensions,
ExtraExtensions: oldCert.ExtraExtensions,
UnhandledCriticalExtensions: oldCert.UnhandledCriticalExtensions,
ExtKeyUsage: oldCert.ExtKeyUsage,
UnknownExtKeyUsage: oldCert.UnknownExtKeyUsage,
BasicConstraintsValid: oldCert.BasicConstraintsValid,
IsCA: oldCert.IsCA,
MaxPathLen: oldCert.MaxPathLen,
MaxPathLenZero: oldCert.MaxPathLenZero,
OCSPServer: oldCert.OCSPServer,
IssuingCertificateURL: oldCert.IssuingCertificateURL,
DNSNames: oldCert.DNSNames,
EmailAddresses: oldCert.EmailAddresses,
IPAddresses: oldCert.IPAddresses,
URIs: oldCert.URIs,
PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical,
PermittedDNSDomains: oldCert.PermittedDNSDomains,
ExcludedDNSDomains: oldCert.ExcludedDNSDomains,
PermittedIPRanges: oldCert.PermittedIPRanges,
ExcludedIPRanges: oldCert.ExcludedIPRanges,
PermittedEmailAddresses: oldCert.PermittedEmailAddresses,
ExcludedEmailAddresses: oldCert.ExcludedEmailAddresses,
PermittedURIDomains: oldCert.PermittedURIDomains,
ExcludedURIDomains: oldCert.ExcludedURIDomains,
CRLDistributionPoints: oldCert.CRLDistributionPoints,
PolicyIdentifiers: oldCert.PolicyIdentifiers,
}
leaf, err := x509util.NewLeafProfileWithTemplate(newCert,
issIdentity.Crt, issIdentity.Key)
if err != nil {
return nil, nil, &apiError{err, http.StatusInternalServerError, context{}}
}
crtBytes, err := leaf.CreateCertificate()
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"),
http.StatusInternalServerError, context{}}
}
serverCert, err := x509.ParseCertificate(crtBytes)
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error parsing new server certificate"),
http.StatusInternalServerError, context{}}
}
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
if err != nil {
return nil, nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"),
http.StatusInternalServerError, context{}}
}
return serverCert, caCert, nil
}
// GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server.
func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
profile, err := x509util.NewLeafProfile("Step Online CA",
a.intermediateIdentity.Crt, a.intermediateIdentity.Key,
x509util.WithHosts(strings.Join(a.config.DNSNames, ",")))
if err != nil {
return nil, err
}
crtBytes, err := profile.CreateCertificate()
if err != nil {
return nil, err
}
keyPEM, err := pemutil.Serialize(profile.SubjectPrivateKey())
if err != nil {
return nil, err
}
crtPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: crtBytes,
})
// Load the x509 key pair (combining server and intermediate blocks)
// to a tls.Certificate.
intermediatePEM, err := pemutil.Serialize(a.intermediateIdentity.Crt)
if err != nil {
return nil, err
}
tlsCrt, err := tls.X509KeyPair(append(crtPEM,
pem.EncodeToMemory(intermediatePEM)...),
pem.EncodeToMemory(keyPEM))
if err != nil {
return nil, errors.Wrap(err, "error creating tls certificate")
}
// Get the 'leaf' certificate and set the attribute accordingly.
leaf, err := x509.ParseCertificate(tlsCrt.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "error parsing tls certificate")
}
tlsCrt.Leaf = leaf
return &tlsCrt, nil
}

407
authority/tls_test.go Normal file
View file

@ -0,0 +1,407 @@
package authority
import (
"crypto/rand"
"crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/ca-component/api"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
)
func getCSR(t *testing.T, priv interface{}) *x509.CertificateRequest {
_csr := &x509.CertificateRequest{
Subject: pkix.Name{CommonName: "test"},
DNSNames: []string{"test.smallstep.com"},
}
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, _csr, priv)
assert.FatalError(t, err)
csr, err := x509.ParseCertificateRequest(csrBytes)
assert.FatalError(t, err)
return csr
}
func TestSign(t *testing.T) {
pub, priv, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
a := testAuthority(t)
assert.FatalError(t, err)
a.config.AuthorityConfig.Template = &x509util.ASN1DN{
Country: "Tazmania",
Organization: "Acme Co",
Locality: "Landscapes",
Province: "Sudden Cliffs",
StreetAddress: "TNT",
CommonName: "test",
}
now := time.Now()
type signTest struct {
auth *Authority
csr *x509.CertificateRequest
opts api.SignOptions
claims []api.Claim
err *apiError
}
tests := map[string]func(*testing.T) *signTest{
"fail-validate-claims": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
return &signTest{
auth: a,
csr: csr,
opts: api.SignOptions{
NotBefore: now,
NotAfter: now.Add(time.Minute * 5),
},
claims: []api.Claim{&commonNameClaim{"foo"}},
err: &apiError{errors.New("common name claim failed - got test, want foo"),
http.StatusUnauthorized, context{}},
}
},
"fail-convert-stepCSR": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
csr.Raw = []byte("foo")
return &signTest{
auth: a,
csr: csr,
opts: api.SignOptions{
NotBefore: now,
NotAfter: now.Add(time.Minute * 5),
},
claims: []api.Claim{&commonNameClaim{"test"}},
err: &apiError{errors.New("error converting x509 csr to stepx509 csr"),
http.StatusInternalServerError, context{}},
}
},
"fail-merge-default-ASN1DN": func(t *testing.T) *signTest {
_a := testAuthority(t)
_a.config.AuthorityConfig.Template = nil
csr := getCSR(t, priv)
return &signTest{
auth: _a,
csr: csr,
opts: api.SignOptions{
NotBefore: now,
NotAfter: now.Add(time.Minute * 5),
},
claims: []api.Claim{&commonNameClaim{"test"}},
err: &apiError{errors.New("default ASN1DN template cannot be nil"),
http.StatusInternalServerError, context{}},
}
},
"fail-create-cert": func(t *testing.T) *signTest {
_a := testAuthority(t)
_a.intermediateIdentity.Key = nil
csr := getCSR(t, priv)
return &signTest{
auth: _a,
csr: csr,
opts: api.SignOptions{
NotBefore: now,
NotAfter: now.Add(time.Minute * 5),
},
claims: []api.Claim{&commonNameClaim{"test"}},
err: &apiError{errors.New("error creating new leaf certificate from input csr"),
http.StatusInternalServerError, context{}},
}
},
"success": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
return &signTest{
auth: a,
csr: csr,
opts: api.SignOptions{
NotBefore: now,
NotAfter: now.Add(time.Minute * 5),
},
claims: []api.Claim{&commonNameClaim{"test"}},
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
leaf, intermediate, err := tc.auth.Sign(tc.csr, tc.opts, tc.claims...)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, leaf.NotBefore, tc.opts.NotBefore.UTC().Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, tc.opts.NotAfter.UTC().Truncate(time.Second))
tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject),
fmt.Sprintf("%v", &pkix.Name{
Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province},
CommonName: tmplt.CommonName,
}))
assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
assert.Equals(t, leaf.ExtKeyUsage,
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"test"})
pubBytes, err := x509.MarshalPKIXPublicKey(pub)
assert.FatalError(t, err)
hash := sha1.Sum(pubBytes)
assert.Equals(t, leaf.SubjectKeyId, hash[:])
assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId)
realIntermediate, err := x509.ParseCertificate(a.intermediateIdentity.Crt.Raw)
assert.FatalError(t, err)
assert.Equals(t, intermediate, realIntermediate)
}
}
})
}
}
func TestRenew(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
a := testAuthority(t)
a.config.AuthorityConfig.Template = &x509util.ASN1DN{
Country: "Tazmania",
Organization: "Acme Co",
Locality: "Landscapes",
Province: "Sudden Cliffs",
StreetAddress: "TNT",
CommonName: "renew",
}
now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7)
na1 := now
so := &api.SignOptions{
NotBefore: nb1,
NotAfter: na1,
}
leaf, err := x509util.NewLeafProfile("renew", a.intermediateIdentity.Crt,
a.intermediateIdentity.Key,
x509util.WithNotBeforeAfter(so.NotBefore, so.NotAfter),
withDefaultASN1DN(a.config.AuthorityConfig.Template),
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"))
assert.FatalError(t, err)
crtBytes, err := leaf.CreateCertificate()
assert.FatalError(t, err)
crt, err := x509.ParseCertificate(crtBytes)
assert.FatalError(t, err)
type renewTest struct {
auth *Authority
crt *x509.Certificate
err *apiError
}
tests := map[string]func() (*renewTest, error){
"fail-conversion-stepx509": func() (*renewTest, error) {
return &renewTest{
crt: &x509.Certificate{Raw: []byte("foo")},
err: &apiError{errors.New("error converting x509.Certificate to stepx509.Certificate"),
http.StatusInternalServerError, context{}},
}, nil
},
"fail-create-cert": func() (*renewTest, error) {
_a := testAuthority(t)
_a.intermediateIdentity.Key = nil
return &renewTest{
auth: _a,
crt: crt,
err: &apiError{errors.New("error renewing certificate from existing server certificate"),
http.StatusInternalServerError, context{}},
}, nil
},
"success": func() (*renewTest, error) {
return &renewTest{
crt: crt,
}, nil
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc, err := genTestCase()
assert.FatalError(t, err)
var leaf, intermediate *x509.Certificate
if tc.auth != nil {
leaf, intermediate, err = tc.auth.Renew(tc.crt)
} else {
leaf, intermediate, err = a.Renew(tc.crt)
}
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), crt.NotAfter.Sub(crt.NotBefore))
assert.True(t, leaf.NotBefore.After(now.Add(-time.Minute)))
assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute)))
expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject),
fmt.Sprintf("%v", &pkix.Name{
Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province},
CommonName: tmplt.CommonName,
}))
assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
assert.Equals(t, leaf.ExtKeyUsage,
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"})
pubBytes, err := x509.MarshalPKIXPublicKey(pub)
assert.FatalError(t, err)
hash := sha1.Sum(pubBytes)
assert.Equals(t, leaf.SubjectKeyId, hash[:])
assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId)
realIntermediate, err := x509.ParseCertificate(a.intermediateIdentity.Crt.Raw)
assert.FatalError(t, err)
assert.Equals(t, intermediate, realIntermediate)
}
}
})
}
}
func TestGetMinDuration(t *testing.T) {
type renewTest struct {
auth *Authority
d time.Duration
}
tests := map[string]func() (*renewTest, error){
"default": func() (*renewTest, error) {
a := testAuthority(t)
return &renewTest{auth: a, d: time.Minute * 5}, nil
},
"non-default": func() (*renewTest, error) {
a := testAuthority(t)
a.config.AuthorityConfig.MinCertDuration = &duration{time.Minute * 7}
return &renewTest{auth: a, d: time.Minute * 7}, nil
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc, err := genTestCase()
assert.FatalError(t, err)
d := tc.auth.GetMinDuration()
assert.Equals(t, d, tc.d)
})
}
}
func TestGetMaxDuration(t *testing.T) {
type renewTest struct {
auth *Authority
d time.Duration
}
tests := map[string]func() (*renewTest, error){
"default": func() (*renewTest, error) {
a := testAuthority(t)
return &renewTest{auth: a, d: time.Hour * 24}, nil
},
"non-default": func() (*renewTest, error) {
a := testAuthority(t)
a.config.AuthorityConfig.MaxCertDuration = &duration{time.Minute * 7}
return &renewTest{auth: a, d: time.Minute * 7}, nil
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc, err := genTestCase()
assert.FatalError(t, err)
d := tc.auth.GetMaxDuration()
assert.Equals(t, d, tc.d)
})
}
}
func TestGetTLSOptions(t *testing.T) {
type renewTest struct {
auth *Authority
opts *tlsutil.TLSOptions
}
tests := map[string]func() (*renewTest, error){
"default": func() (*renewTest, error) {
a := testAuthority(t)
return &renewTest{auth: a, opts: &DefaultTLSOptions}, nil
},
"non-default": func() (*renewTest, error) {
a := testAuthority(t)
a.config.TLS = &tlsutil.TLSOptions{
CipherSuites: x509util.CipherSuites{
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
},
MinVersion: 1.0,
MaxVersion: 1.1,
Renegotiation: true,
}
return &renewTest{auth: a, opts: a.config.TLS}, nil
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc, err := genTestCase()
assert.FatalError(t, err)
opts := tc.auth.GetTLSOptions()
assert.Equals(t, opts, tc.opts)
})
}
}

198
ca/ca.go Normal file
View file

@ -0,0 +1,198 @@
package ca
import (
"crypto/tls"
"crypto/x509"
"net/http"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/api"
"github.com/smallstep/ca-component/authority"
"github.com/smallstep/ca-component/logging"
"github.com/smallstep/ca-component/monitoring"
"github.com/smallstep/ca-component/server"
)
type options struct {
configFile string
password []byte
}
func (o *options) apply(opts []Option) {
for _, fn := range opts {
fn(o)
}
}
// Option is the type of options passed to the CA constructor.
type Option func(o *options)
// WithConfigFile sets the given name as the configuration file name in the CA
// options.
func WithConfigFile(name string) Option {
return func(o *options) {
o.configFile = name
}
}
// WithPassword sets the given password as the configured password in the CA
// options.
func WithPassword(password []byte) Option {
return func(o *options) {
o.password = password
}
}
// CA is the type used to build the complete certificate authority. It builds
// the HTTP server, set ups the middlewares and the HTTP handlers.
type CA struct {
auth *authority.Authority
config *authority.Config
srv *server.Server
opts *options
renewer *TLSRenewer
}
// New creates and initializes the CA with the given configuration and options.
func New(config *authority.Config, opts ...Option) (*CA, error) {
ca := &CA{
config: config,
opts: new(options),
}
ca.opts.apply(opts)
return ca.Init(config)
}
// Init initializes the CA with the given configuration.
func (ca *CA) Init(config *authority.Config) (*CA, error) {
if l := len(ca.opts.password); l > 0 {
ca.config.Password = string(ca.opts.password)
}
auth, err := authority.New(config)
if err != nil {
return nil, err
}
tlsConfig, err := ca.getTLSConfig(auth)
if err != nil {
return nil, err
}
// Using chi as the main router
mux := chi.NewRouter()
handler := http.Handler(mux)
// Add api endpoints in / and /1.0
routerHandler := api.New(auth)
routerHandler.Route(mux)
mux.Route("/1.0", func(r chi.Router) {
routerHandler.Route(r)
})
// Add monitoring if configured
if len(config.Monitoring) > 0 {
m, err := monitoring.New(config.Monitoring)
if err != nil {
return nil, err
}
handler = m.Middleware(handler)
}
// Add logger if configured
if len(config.Logger) > 0 {
logger, err := logging.New("ca", config.Logger)
if err != nil {
return nil, err
}
handler = logger.Middleware(handler)
}
ca.auth = auth
ca.srv = server.New(config.Address, handler, tlsConfig)
return ca, nil
}
// Run starts the CA calling to the server ListenAndServe method.
func (ca *CA) Run() error {
return ca.srv.ListenAndServe()
}
// Stop stops the CA calling to the server Shutdown method.
func (ca *CA) Stop() error {
return ca.srv.Shutdown()
}
// Reload reloads the configuration of the CA and calls to the server Reload
// method.
func (ca *CA) Reload() error {
if ca.opts.configFile == "" {
return errors.New("error reloading ca: configuration file is not set")
}
config, err := authority.LoadConfiguration(ca.opts.configFile)
if err != nil {
return errors.Wrap(err, "error reloading ca")
}
newCA, err := New(config, WithPassword(ca.opts.password), WithConfigFile(ca.opts.configFile))
if err != nil {
return errors.Wrap(err, "error reloading ca")
}
return ca.srv.Reload(newCA.srv)
}
// getTLSConfig returns a TLSConfig for the CA server with a self-renewing
// server certificate.
func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
// Create initial TLS certificate
tlsCrt, err := auth.GetTLSCertificate()
if err != nil {
return nil, err
}
// Start tls renewer with the new certificate.
// If a renewer was started, attempt to stop it before.
if ca.renewer != nil {
ca.renewer.Stop()
}
ca.renewer, err = NewTLSRenewer(tlsCrt, auth.GetTLSCertificate)
if err != nil {
return nil, err
}
ca.renewer.Run()
var tlsConfig *tls.Config
if ca.config.TLS != nil {
tlsConfig = ca.config.TLS.TLSConfig()
} else {
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
certPool := x509.NewCertPool()
certPool.AddCert(auth.GetRootCertificate())
// GetCertificate will only be called if the client supplies SNI
// information or if tlsConfig.Certificates is empty.
// When client requests are made using an IP address (as opposed to a domain
// name) the server does not receive any SNI and may fallback to using the
// first entry in the Certificates attribute; by setting the attribute to
// empty we are implicitly forcing GetCertificate to be the only mechanism
// by which the server can find it's own leaf Certificate.
tlsConfig.Certificates = []tls.Certificate{}
tlsConfig.GetCertificate = ca.renewer.GetCertificate
// Add support for mutual tls to renew certificates
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
tlsConfig.ClientCAs = certPool
// Use server's most preferred ciphersuite
tlsConfig.PreferServerCipherSuites = true
return tlsConfig, nil
}

499
ca/ca_test.go Normal file
View file

@ -0,0 +1,499 @@
package ca
import (
"bytes"
"crypto/rand"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/ca-component/api"
"github.com/smallstep/ca-component/authority"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/crypto/x509util"
stepJOSE "github.com/smallstep/cli/jose"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
type ClosingBuffer struct {
*bytes.Buffer
}
func (cb *ClosingBuffer) Close() error {
return nil
}
func getCSR(priv interface{}) (*x509.CertificateRequest, error) {
_csr := &x509.CertificateRequest{
Subject: pkix.Name{CommonName: "test.smallstep.com"},
DNSNames: []string{"test.smallstep.com"},
}
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, _csr, priv)
if err != nil {
return nil, err
}
return x509.ParseCertificateRequest(csrBytes)
}
func TestCASign(t *testing.T) {
pub, priv, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
asn1dn := &x509util.ASN1DN{
Country: "Tazmania",
Organization: "Acme Co",
Locality: "Landscapes",
Province: "Sudden Cliffs",
StreetAddress: "TNT",
CommonName: "test.smallstep.com",
}
config, err := authority.LoadConfiguration("testdata/ca.json")
assert.FatalError(t, err)
config.AuthorityConfig.Template = asn1dn
ca, err := New(config)
assert.FatalError(t, err)
intermediateIdentity, err := x509util.LoadIdentityFromDisk("testdata/secrets/intermediate_ca.crt",
"testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("password")))
assert.FatalError(t, err)
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_priv.jwk",
stepJOSE.WithPassword([]byte("pass")))
assert.FatalError(t, err)
fmt.Printf("clijwk.KeyID = %+v\n", clijwk.KeyID)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: clijwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", clijwk.KeyID))
assert.FatalError(t, err)
now := time.Now().UTC()
leafExpiry := now.Add(time.Minute * 5)
type signTest struct {
ca *CA
body string
status int
errMsg string
}
tests := map[string]func(t *testing.T) *signTest{
"invalid-json-body": func(t *testing.T) *signTest {
return &signTest{
ca: ca,
body: "invalid json",
status: http.StatusBadRequest,
errMsg: "Bad Request",
}
},
"invalid-csr-sig": func(t *testing.T) *signTest {
der := []byte(`-----BEGIN CERTIFICATE REQUEST-----
MIIDNjCCAh4CAQAwYzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQH
DA1TYW4gRnJhbmNpc2NvMRIwEAYDVQQKDAlzbWFsbHN0ZXAxGzAZBgNVBAMMEnRl
c3Quc21hbGxzdGVwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
ANPahliigZ38QpBLmQMS3MVKKZ5gapNjqR7LIEYoYWa4lTFiUnbwg8tSfIFcgLZr
jNIxn7/98+JOJHKgS03NhFJoS5hej0LyypleOGJ0nk2qawYVKnn1ftoKjkfxkfZI
a/5rsDF1jhNBspB/KPHWE0eimKQJbUiVG1zA1sExnXDecF3vJfBj+DPDWngx4yxR
/jYEKjt4tQ6Ei752TbosrCHYeYXzkr6iAwiNz6vT/ewLb6b8JmuN8X6Y1I9ogDGx
hntBJ1jAK8x3IGTjYbkm+mqVuCyhNcHtGfEHcBnUEzLAPrVFn8kGiAnU17FJ0uQ7
1C9CtUzgBRZCxSBm6Qs+Zs8CAwEAAaCBjTCBigYJKoZIhvcNAQkOMX0wezAMBgNV
HRMBAf8EAjAAMB0GA1UdJQQWMBQGCCsGAQUFBwMCBggrBgEFBQcDATAOBgNVHQ8B
Af8EBAMCBaAwHQYDVR0RBBYwFIISdGVzdC5zbWFsbHN0ZXAuY29tMB0GA1UdDgQW
BBQj6N4RTAAjhV3UBYXH72mkdOGpqzANBgkqhkiG9w0BAQsFAAOCAQEAN0/ivCBk
FD53SqtRmqqc7C9saoRNvV+wDi4Sg6YGLFQLjbZPJrqQURWdHtV9O3sb3p8O5erX
9Kgq3C7fqd//0mro4GZ1GTpjsPKIMocZFfH7zEhAZlvQLRKWICjoBaOwxQum2qY/
B3+ltAXb4uqGdbI0jPkkyWGN5CQhK+ZHoYe/zGtTEmHBcPxRtJJkukQQjUgZhjU2
Z7K+w3AjOxj47XLNHHlW83QYUJ2mN+mEZF9DhrZb2ydYOlpy0V2NJwv7QrmnFaDj
R0v3BFLTblIp100li3oV2QaM/yESrgo9XIjEEGzCGz5cNs5ovNadufUZDCJyyT4q
ZEp7knvU2psWRw==
-----END CERTIFICATE REQUEST-----`)
block, _ := pem.Decode(der)
assert.NotNil(t, block)
csr, err := x509.ParseCertificateRequest(block.Bytes)
assert.FatalError(t, err)
body, err := json.Marshal(&api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: csr},
OTT: "foo",
})
assert.FatalError(t, err)
return &signTest{
ca: ca,
body: string(body),
status: http.StatusBadRequest,
errMsg: "Bad Request",
}
},
"unauthorized-ott": func(t *testing.T) *signTest {
csr, err := getCSR(priv)
assert.FatalError(t, err)
body, err := json.Marshal(&api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: csr},
OTT: "foo",
})
assert.FatalError(t, err)
return &signTest{
ca: ca,
body: string(body),
status: http.StatusUnauthorized,
errMsg: "Unauthorized",
}
},
"fail-commonname-claim": func(t *testing.T) *signTest {
jti, err := randutil.ASCII(32)
assert.FatalError(t, err)
cl := jwt.Claims{
Subject: "invalid",
Issuer: "step-cli",
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: []string{"step-certificate-authority"},
ID: jti,
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
csr, err := getCSR(priv)
assert.FatalError(t, err)
body, err := json.Marshal(&api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: csr},
OTT: raw,
})
assert.FatalError(t, err)
return &signTest{
ca: ca,
body: string(body),
status: http.StatusUnauthorized,
errMsg: "Unauthorized",
}
},
"success": func(t *testing.T) *signTest {
jti, err := randutil.ASCII(32)
assert.FatalError(t, err)
cl := jwt.Claims{
Subject: "test.smallstep.com",
Issuer: "step-cli",
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: []string{"step-certificate-authority"},
ID: jti,
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
csr, err := getCSR(priv)
assert.FatalError(t, err)
body, err := json.Marshal(&api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: csr},
OTT: raw,
NotBefore: now,
NotAfter: leafExpiry,
})
assert.FatalError(t, err)
return &signTest{
ca: ca,
body: string(body),
status: http.StatusCreated,
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
rq, err := http.NewRequest("POST", "/sign", strings.NewReader(tc.body))
assert.FatalError(t, err)
rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq)
if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
if rr.Code < http.StatusBadRequest {
var sign api.SignResponse
assert.FatalError(t, readJSON(body, &sign))
leaf := sign.ServerPEM.Certificate
intermediate := sign.CaPEM.Certificate
assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject),
fmt.Sprintf("%v", &pkix.Name{
Country: []string{asn1dn.Country},
Organization: []string{asn1dn.Organization},
Locality: []string{asn1dn.Locality},
StreetAddress: []string{asn1dn.StreetAddress},
Province: []string{asn1dn.Province},
CommonName: asn1dn.CommonName,
}))
assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
assert.Equals(t, leaf.ExtKeyUsage,
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})
pubBytes, err := x509.MarshalPKIXPublicKey(pub)
assert.FatalError(t, err)
hash := sha1.Sum(pubBytes)
assert.Equals(t, leaf.SubjectKeyId, hash[:])
assert.Equals(t, leaf.AuthorityKeyId, intermediateIdentity.Crt.SubjectKeyId)
realIntermediate, err := x509.ParseCertificate(intermediateIdentity.Crt.Raw)
assert.FatalError(t, err)
assert.Equals(t, intermediate, realIntermediate)
} else {
err := readError(body)
if len(tc.errMsg) == 0 {
assert.FatalError(t, errors.New("must validate response error"))
}
assert.HasPrefix(t, err.Error(), tc.errMsg)
}
}
})
}
}
func TestCARoot(t *testing.T) {
config, err := authority.LoadConfiguration("testdata/ca.json")
assert.FatalError(t, err)
ca, err := New(config)
assert.FatalError(t, err)
rootCrt, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
assert.FatalError(t, err)
type rootTest struct {
ca *CA
sha string
status int
errMsg string
}
tests := map[string]func(t *testing.T) *rootTest{
"not-found": func(t *testing.T) *rootTest {
return &rootTest{
ca: ca,
sha: "foo",
status: http.StatusNotFound,
errMsg: "Not Found",
}
},
"success": func(t *testing.T) *rootTest {
return &rootTest{
ca: ca,
sha: "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7",
status: http.StatusOK,
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
rq, err := http.NewRequest("GET", fmt.Sprintf("/root/%s", tc.sha), strings.NewReader(""))
assert.FatalError(t, err)
rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq)
if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
if rr.Code < http.StatusBadRequest {
var root api.RootResponse
assert.FatalError(t, readJSON(body, &root))
assert.Equals(t, root.RootPEM.Certificate, rootCrt)
} else {
err := readError(body)
if len(tc.errMsg) == 0 {
assert.FatalError(t, errors.New("must validate response error"))
}
assert.HasPrefix(t, err.Error(), tc.errMsg)
}
}
})
}
}
func TestCAHealth(t *testing.T) {
config, err := authority.LoadConfiguration("testdata/ca.json")
assert.FatalError(t, err)
ca, err := New(config)
assert.FatalError(t, err)
type rootTest struct {
ca *CA
status int
}
tests := map[string]func(t *testing.T) *rootTest{
"success": func(t *testing.T) *rootTest {
return &rootTest{
ca: ca,
status: http.StatusOK,
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
rq, err := http.NewRequest("GET", "/health", strings.NewReader(""))
assert.FatalError(t, err)
rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq)
if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
if rr.Code < http.StatusBadRequest {
var health api.HealthResponse
assert.FatalError(t, readJSON(body, &health))
assert.Equals(t, health, api.HealthResponse{Status: "ok"})
}
}
})
}
}
func TestCARenew(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
asn1dn := &x509util.ASN1DN{
Country: "Tazmania",
Organization: "Acme Co",
Locality: "Landscapes",
Province: "Sudden Cliffs",
StreetAddress: "TNT",
CommonName: "test",
}
config, err := authority.LoadConfiguration("testdata/ca.json")
assert.FatalError(t, err)
config.AuthorityConfig.Template = asn1dn
ca, err := New(config)
assert.FatalError(t, err)
assert.FatalError(t, err)
intermediateIdentity, err := x509util.LoadIdentityFromDisk("testdata/secrets/intermediate_ca.crt",
"testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("password")))
assert.FatalError(t, err)
now := time.Now().UTC()
leafExpiry := now.Add(time.Minute * 5)
type renewTest struct {
ca *CA
tlsConnState *tls.ConnectionState
status int
errMsg string
}
tests := map[string]func(t *testing.T) *renewTest{
"request-missing-tls": func(t *testing.T) *renewTest {
return &renewTest{
ca: ca,
tlsConnState: nil,
status: http.StatusBadRequest,
errMsg: "Bad Request",
}
},
"request-missing-peer-certificate": func(t *testing.T) *renewTest {
return &renewTest{
ca: ca,
tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}},
status: http.StatusBadRequest,
errMsg: "Bad Request",
}
},
"success": func(t *testing.T) *renewTest {
profile, err := x509util.NewLeafProfile("test", intermediateIdentity.Crt,
intermediateIdentity.Key, x509util.WithPublicKey(pub),
x509util.WithNotBeforeAfter(now, leafExpiry), x509util.WithHosts("funk"))
assert.FatalError(t, err)
crtBytes, err := profile.CreateCertificate()
assert.FatalError(t, err)
crt, err := x509.ParseCertificate(crtBytes)
assert.FatalError(t, err)
return &renewTest{
ca: ca,
tlsConnState: &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{crt},
},
status: http.StatusCreated,
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
rq, err := http.NewRequest("POST", "/renew", strings.NewReader(""))
assert.FatalError(t, err)
rq.TLS = tc.tlsConnState
rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq)
if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
if rr.Code < http.StatusBadRequest {
var sign api.SignResponse
assert.FatalError(t, readJSON(body, &sign))
leaf := sign.ServerPEM.Certificate
intermediate := sign.CaPEM.Certificate
assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject),
fmt.Sprintf("%v", &pkix.Name{
CommonName: asn1dn.CommonName,
}))
assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
assert.Equals(t, leaf.ExtKeyUsage,
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"funk"})
pubBytes, err := x509.MarshalPKIXPublicKey(pub)
assert.FatalError(t, err)
hash := sha1.Sum(pubBytes)
assert.Equals(t, leaf.SubjectKeyId, hash[:])
assert.Equals(t, leaf.AuthorityKeyId, intermediateIdentity.Crt.SubjectKeyId)
realIntermediate, err := x509.ParseCertificate(intermediateIdentity.Crt.Raw)
assert.FatalError(t, err)
assert.Equals(t, intermediate, realIntermediate)
assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
} else {
err := readError(body)
if len(tc.errMsg) == 0 {
assert.FatalError(t, errors.New("must validate response error"))
}
assert.HasPrefix(t, err.Error(), tc.errMsg)
}
}
})
}
}

350
ca/client.go Normal file
View file

@ -0,0 +1,350 @@
package ca
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"encoding/json"
"encoding/pem"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/api"
"golang.org/x/net/http2"
"gopkg.in/square/go-jose.v2/jwt"
)
// ClientOption is the type of options passed to the Client constructor.
type ClientOption func(o *clientOptions) error
type clientOptions struct {
transport http.RoundTripper
rootSHA256 string
rootFilename string
}
func (o *clientOptions) apply(opts []ClientOption) (err error) {
for _, fn := range opts {
if err = fn(o); err != nil {
return
}
}
return
}
// checkTransport checks if other ways to set up a transport have been provided.
// If they have it returns an error.
func (o *clientOptions) checkTransport() error {
if o.transport != nil || o.rootFilename != "" || o.rootSHA256 != "" {
return errors.New("multiple transport methods have been configured")
}
return nil
}
// getTransport returns the transport configured in the clientOptions.
func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err error) {
if o.transport != nil {
tr = o.transport
}
if o.rootFilename != "" {
if tr, err = getTransportFromFile(o.rootFilename); err != nil {
return nil, err
}
}
if o.rootSHA256 != "" {
if tr, err = getTransportFromSHA256(endpoint, o.rootSHA256); err != nil {
return nil, err
}
}
if tr == nil {
return nil, errors.New("a transport, a root cert, or a root sha256 must be used")
}
return tr, nil
}
// WithTransport adds a custom transport to the Client. If the transport is
// given is given it will have preference over WithRootFile and WithRootSHA256.
func WithTransport(tr http.RoundTripper) ClientOption {
return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil {
return err
}
o.transport = tr
return nil
}
}
// WithRootFile will create the transport using the given root certificate. If
// the root file is given it will have preference over WithRootSHA256, but less
// preference than WithTransport.
func WithRootFile(filename string) ClientOption {
return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil {
return err
}
o.rootFilename = filename
return nil
}
}
// WithRootSHA256 will create the transport using an insecure client to retrieve the
// root certificate. It has less preference than WithTransport and WithRootFile.
func WithRootSHA256(sum string) ClientOption {
return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil {
return err
}
o.rootSHA256 = sum
return nil
}
}
func getTransportFromFile(filename string) (http.RoundTripper, error) {
data, err := ioutil.ReadFile(filename)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename)
}
block, _ := pem.Decode(data)
if block == nil {
return nil, errors.Errorf("error decoding %s", filename)
}
root, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrapf(err, "error parsing %s", filename)
}
pool := x509.NewCertPool()
pool.AddCert(root)
return getDefaultTransport(&tls.Config{
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}
func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
client, err := NewClient(endpoint)
if err != nil {
return nil, err
}
root, err := client.Root(sum)
if err != nil {
return nil, err
}
pool := x509.NewCertPool()
pool.AddCert(root.RootPEM.Certificate)
return getDefaultTransport(&tls.Config{
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}
// Client implements an HTTP client for the CA server.
type Client struct {
client *http.Client
endpoint *url.URL
certPool *x509.CertPool
}
// NewClient creates a new Client with the given endpoint and options.
func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
// Validate endpoint
u, err := url.Parse(endpoint)
if err != nil {
return nil, errors.Wrap(err, "error parsing endpoint")
}
if u.Scheme == "" || u.Host == "" {
return nil, errors.New("error parsing endpoint: url is not valid")
}
// Retrieve transport from options.
o := new(clientOptions)
if err := o.apply(opts); err != nil {
return nil, err
}
tr, err := o.getTransport(endpoint)
if err != nil {
return nil, err
}
var cp *x509.CertPool
switch tr := tr.(type) {
case *http.Transport:
if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil {
cp = tr.TLSClientConfig.RootCAs
}
case *http2.Transport:
if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil {
cp = tr.TLSClientConfig.RootCAs
}
}
return &Client{
client: &http.Client{
Transport: tr,
},
endpoint: u,
certPool: cp,
}, nil
}
// Health performs the health request to the CA and returns the
// api.HealthResponse struct.
func (c *Client) Health() (*api.HealthResponse, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var health api.HealthResponse
if err := readJSON(resp.Body, &health); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return &health, nil
}
// Root performs the root request to the CA with the given SHA256 and returns
// the api.RootResponse struct. It uses an insecure client, but it checks the
// resulting root certificate with the given SHA256, returning an error if they
// do not match.
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
resp, err := getInsecureClient().Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var root api.RootResponse
if err := readJSON(resp.Body, &root); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
// verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw)
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
return nil, errors.New("root certificate SHA256 fingerprint do not match")
}
return &root, nil
}
// Sign performs the sign request to the CA and returns the api.SignResponse
// struct.
func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, errors.Wrap(err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
// Add tls.ConnectionState:
// We'll extract the root certificate from the verified chains
sign.TLS = resp.TLS
return &sign, nil
}
// Renew performs the renew request to the CA and returns the api.SignResponse
// struct.
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
client := &http.Client{Transport: tr}
resp, err := client.Post(u.String(), "application/json", http.NoBody)
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return &sign, nil
}
// CreateSignRequest is a helper function that given an x509 OTT returns a
// simple but secure sign request as well as the private key used.
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {
token, err := jwt.ParseSigned(ott)
if err != nil {
return nil, nil, errors.Wrap(err, "error parsing ott")
}
var claims jwt.Claims
if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, nil, errors.Wrap(err, "error parsing ott")
}
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, errors.Wrap(err, "error generating key")
}
template := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: claims.Subject,
},
SignatureAlgorithm: x509.ECDSAWithSHA256,
DNSNames: []string{claims.Subject},
}
csr, err := x509.CreateCertificateRequest(rand.Reader, template, pk)
if err != nil {
return nil, nil, errors.Wrap(err, "error creating certificate request")
}
cr, err := x509.ParseCertificateRequest(csr)
if err != nil {
return nil, nil, errors.Wrap(err, "error parsing certificate request")
}
if err := cr.CheckSignature(); err != nil {
return nil, nil, errors.Wrap(err, "error signing certificate request")
}
return &api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: cr},
OTT: ott,
}, pk, nil
}
func getInsecureClient() *http.Client {
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
}
func readJSON(r io.ReadCloser, v interface{}) error {
defer r.Close()
return json.NewDecoder(r).Decode(v)
}
func readError(r io.ReadCloser) error {
defer r.Close()
apiErr := new(api.Error)
if err := json.NewDecoder(r).Decode(apiErr); err != nil {
return err
}
return apiErr
}

388
ca/client_test.go Normal file
View file

@ -0,0 +1,388 @@
package ca
import (
"bytes"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/smallstep/ca-component/api"
)
const (
rootPEM = `-----BEGIN CERTIFICATE-----
MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT
MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i
YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG
EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy
bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP
VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv
h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE
ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ
EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC
DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7
qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD
VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g
K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI
KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n
ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB
BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY
/iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/
zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza
HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto
WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6
yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx
-----END CERTIFICATE-----`
certPEM = `-----BEGIN CERTIFICATE-----
MIIDujCCAqKgAwIBAgIIE31FZVaPXTUwDQYJKoZIhvcNAQEFBQAwSTELMAkGA1UE
BhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRl
cm5ldCBBdXRob3JpdHkgRzIwHhcNMTQwMTI5MTMyNzQzWhcNMTQwNTI5MDAwMDAw
WjBpMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwN
TW91bnRhaW4gVmlldzETMBEGA1UECgwKR29vZ2xlIEluYzEYMBYGA1UEAwwPbWFp
bC5nb29nbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfRrObuSW5T7q
5CnSEqefEmtH4CCv6+5EckuriNr1CjfVvqzwfAhopXkLrq45EQm8vkmf7W96XJhC
7ZM0dYi1/qOCAU8wggFLMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAa
BgNVHREEEzARgg9tYWlsLmdvb2dsZS5jb20wCwYDVR0PBAQDAgeAMGgGCCsGAQUF
BwEBBFwwWjArBggrBgEFBQcwAoYfaHR0cDovL3BraS5nb29nbGUuY29tL0dJQUcy
LmNydDArBggrBgEFBQcwAYYfaHR0cDovL2NsaWVudHMxLmdvb2dsZS5jb20vb2Nz
cDAdBgNVHQ4EFgQUiJxtimAuTfwb+aUtBn5UYKreKvMwDAYDVR0TAQH/BAIwADAf
BgNVHSMEGDAWgBRK3QYWG7z2aLV29YG2u2IaulqBLzAXBgNVHSAEEDAOMAwGCisG
AQQB1nkCBQEwMAYDVR0fBCkwJzAloCOgIYYfaHR0cDovL3BraS5nb29nbGUuY29t
L0dJQUcyLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAH6RYHxHdcGpMpFE3oxDoFnP+
gtuBCHan2yE2GRbJ2Cw8Lw0MmuKqHlf9RSeYfd3BXeKkj1qO6TVKwCh+0HdZk283
TZZyzmEOyclm3UGFYe82P/iDFt+CeQ3NpmBg+GoaVCuWAARJN/KfglbLyyYygcQq
0SgeDh8dRKUiaW3HQSoYvTvdTuqzwK4CXsr3b5/dAOY8uMuG/IAR3FgwTbZ1dtoW
RvOTa8hYiU6A475WuZKyEHcwnGYe57u2I2KbMgcKjPniocj4QzgYsVAVKW3IwaOh
yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`
csrPEM = `-----BEGIN CERTIFICATE REQUEST-----
MIIEYjCCAkoCAQAwHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29tMIICIjAN
BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuCpifZfoZhYNywfpnPa21NezXgtn
wrWBFE6xhVzE7YDSIqtIsj8aR7R8zwEymxfv5j5298LUy/XSmItVH31CsKyfcGqN
QM0PZr9XY3z5V6qchGMqjzt/jqlYMBHujcxIFBfz4HATxSgKyvHqvw14ESsS2huu
7jowx+XTKbFYgKcXrjBkvOej5FXD3ehkg0jDA2UAJNdfKmrc1BBEaaqOtfh7eyU2
HU7+5gxH8C27IiCAmNj719E0B99Nu2MUw6aLFIM4xAcRga33Avevx6UuXZZIEepe
V1sihrkcnDK9Vsxkme5erXzvAoOiRusiC2iIomJHJrdRM5ReEU+N+Tl1Kxq+rk7H
/qAq78wVm07M1/GGi9SUMObZS4WuJpM6whlikIAEbv9iV+CK0sv/Jr/AADdGMmQU
lwk+Q0ZNE8p4ZuWILv/dtLDtDVBpnrrJ9e8duBtB0lGcG8MdaUCQ346EI4T0Sgx0
hJ+wMq8zYYFfPIZEHC8o9p1ywWN9ySpJ8Zj/5ubmx9v2bY67GbuVFEa8iAp+S00x
/Z8nD6/JsoKtexuHyGr3ixWFzlBqXDuugukIDFUOVDCbuGw4Io4/hEMu4Zz0TIFk
Uu/wf2z75Tt8EkosKLu2wieKcY7n7Vhog/0tqexqWlWtJH0tvq4djsGoSvA62WPs
0iXXj+aZIARPNhECAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQA0vyHIndAkIs/I
Nnz5yZWCokRjokoKv3Aj4VilyjncL+W0UIPULLU/47ZyoHVSUj2t8gknr9xu/Kd+
g/2z0RiF3CIp8IUH49w/HYWaR95glzVNAAzr8qD9UbUqloLVQW3lObSRGtezhdZO
sspw5dC+inhAb1LZhx8PVxB3SAeJ8h11IEBr0s2Hxt9viKKd7YPtIFZkZdOkVx4R
if1DMawj1P6fEomf8z7m+dmbUYTqqosbCbRL01mzEga/kF6JyH/OzpNlcsAiyM8e
BxPWH6TtPqwmyy4y7j1outmM0RnyUw5A0HmIbWh+rHpXiHVsnNqse0XfzmaxM8+z
dxYeDax8aMWZKfvY1Zew+xIxl7DtEy1BpxrZcawumJYt5+LL+bwF/OtL0inQLnw8
zyqydsXNdrpIQJnfmWPld7ThWbQw2FBE70+nFSxHeG2ULnpF3M9xf6ZNAF4gqaNE
Q7vMNPBWrJWu+A++vHY61WGET+h4lY3GFr2I8OE4IiHPQi1D7Y0+fwOmStwuRPM4
2rARcJChNdiYBkkuvs4kixKTTjdXhB8RQtuBSrJ0M1tzq2qMbm7F8G01rOg4KlXU
58jHzJwr1K7cx0lpWfGTtc5bseCGtTKmDBXTziw04yl8eE1+ZFOganixGwCtl4Tt
DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w==
-----END CERTIFICATE REQUEST-----`
)
func parseCertificate(data string) *x509.Certificate {
block, _ := pem.Decode([]byte(data))
if block == nil {
panic("failed to parse certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
panic("failed to parse certificate: " + err.Error())
}
return cert
}
func parseCertificateRequest(data string) *x509.CertificateRequest {
block, _ := pem.Decode([]byte(csrPEM))
if block == nil {
panic("failed to parse certificate request PEM")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
panic("failed to parse certificate request: " + err.Error())
}
return csr
}
func equalJSON(t *testing.T, a interface{}, b interface{}) bool {
if reflect.DeepEqual(a, b) {
return true
}
ab, err := json.Marshal(a)
if err != nil {
t.Error(err)
return false
}
bb, err := json.Marshal(b)
if err != nil {
t.Error(err)
return false
}
return bytes.Equal(ab, bb)
}
func TestClient_Health(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"}
nok := api.InternalServerError(fmt.Errorf("Internal Server Error"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
}{
{"ok", ok, 200, false},
{"not ok", nok, 500, true},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.Health()
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Health() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Health() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Health() error = %v, want %v", err, tt.response)
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Health() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Root(t *testing.T) {
ok := &api.RootResponse{
RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
}
notFound := api.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
shasum string
response interface{}
responseCode int
wantErr bool
}{
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false},
{"not found", "invalid", notFound, 404, true},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
expected := "/root/" + tt.shasum
if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
}
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.Root(tt.shasum)
if (err != nil) != tt.wantErr {
t.Errorf("Client.Root() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Root() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Root() error = %v, want %v", err, tt.response)
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Root() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Sign(t *testing.T) {
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
}
request := &api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)},
OTT: "the-ott",
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 1, 0),
}
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
request *api.SignRequest
response interface{}
responseCode int
wantErr bool
}{
{"ok", request, ok, 200, false},
{"unauthorized", request, unauthorized, 401, true},
{"empty request", &api.SignRequest{}, badRequest, 403, true},
{"nil request", nil, badRequest, 403, true},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest)
if err := api.ReadJSON(req.Body, body); err != nil {
api.WriteError(w, badRequest)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
if !reflect.DeepEqual(body, &api.SignRequest{}) {
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
}
} else {
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
}
}
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.Sign(tt.request)
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Sign() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Sign() error = %v, want %v", err, tt.response)
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Renew(t *testing.T) {
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
}
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.Renew(nil)
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Renew() error = %v, want %v", err, tt.response)
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
}
}
})
}
}

147
ca/renew.go Normal file
View file

@ -0,0 +1,147 @@
package ca
import (
"context"
"crypto/tls"
"math/rand"
"sync"
"time"
"github.com/pkg/errors"
)
// RenewFunc defines the type of the functions used to get a new tls
// certificate.
type RenewFunc func() (*tls.Certificate, error)
// TLSRenewer renews automatically a tls certificate with a given function.
type TLSRenewer struct {
sync.RWMutex
RenewCertificate RenewFunc
cert *tls.Certificate
timer *time.Timer
renewBefore time.Duration
renewJitter time.Duration
}
type tlsRenewerOptions func(r *TLSRenewer) error
// WithRenewBefore modifies a tlsRenewer by setting the renewBefore attribute.
func WithRenewBefore(b time.Duration) func(r *TLSRenewer) error {
return func(r *TLSRenewer) error {
r.renewBefore = b
return nil
}
}
// WithRenewJitter modifies a tlsRenewer by setting the renewJitter attribute.
func WithRenewJitter(j time.Duration) func(r *TLSRenewer) error {
return func(r *TLSRenewer) error {
r.renewJitter = j
return nil
}
}
// NewTLSRenewer creates a TLSRenewer for the given cert. It will use the given
// function to get a new certificate when required.
func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOptions) (*TLSRenewer, error) {
r := &TLSRenewer{
RenewCertificate: fn,
cert: cert,
}
for _, f := range opts {
if err := f(r); err != nil {
return nil, errors.Wrap(err, "error applying options")
}
}
period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore)
if period < time.Minute {
return nil, errors.Errorf("period must be greater than or equal to 1 Minute, but got %v.", period)
}
// By default we will try to renew the cert before 2/3 of the validity
// period have expired.
if r.renewBefore == 0 {
r.renewBefore = period / 3
}
// By default we set the jitter to 1/20th of the validity period.
if r.renewJitter == 0 {
r.renewJitter = period / 20
}
return r, nil
}
// Run starts the certificate renewer for the given certificate.
func (r *TLSRenewer) Run() {
cert := r.getCertificate()
next := r.nextRenewDuration(cert.Leaf.NotAfter)
r.timer = time.AfterFunc(next, r.renewCertificate)
}
// RunContext starts the certificate renewer for the given certificate.
func (r *TLSRenewer) RunContext(ctx context.Context) {
r.Run()
go func() {
<-ctx.Done()
r.Stop()
}()
}
// Stop prevents the renew timer from firing.
func (r *TLSRenewer) Stop() bool {
return r.timer.Stop()
}
// GetCertificate returns the current server certificate.
//
// This method is set in the tls.Config GetCertificate property.
func (r *TLSRenewer) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return r.getCertificate(), nil
}
// GetClientCertificate returns the current client certificate.
//
// This method is set in the tls.Config GetClientCertificate property.
func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return r.getCertificate(), nil
}
// getCertificate returns the certificate using a read-only lock.
func (r *TLSRenewer) getCertificate() *tls.Certificate {
r.RLock()
cert := r.cert
r.RUnlock()
return cert
}
// setCertificate updates the certificate using a read-write lock.
func (r *TLSRenewer) setCertificate(cert *tls.Certificate) {
r.Lock()
r.cert = cert
r.Unlock()
}
func (r *TLSRenewer) renewCertificate() {
var next time.Duration
cert, err := r.RenewCertificate()
if err != nil {
next = r.renewJitter / 2
next += time.Duration(rand.Int63n(int64(next)))
} else {
r.setCertificate(cert)
next = r.nextRenewDuration(cert.Leaf.NotAfter)
}
r.timer = time.AfterFunc(next, r.renewCertificate)
}
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {
d := notAfter.Sub(time.Now()) - r.renewBefore
n := rand.Int63n(int64(r.renewJitter))
d -= time.Duration(n)
if d < 0 {
d = 0
}
return d
}

49
ca/signal.go Normal file
View file

@ -0,0 +1,49 @@
package ca
import (
"log"
"os"
"os/signal"
"syscall"
)
// StopReloader is the interface that external commands can implement to stop
// the server and reload the configuration while running.
type StopReloader interface {
Stop() error
Reload() error
}
// StopReloaderHandler watches SIGINT, SIGTERM and SIGHUP on a list of servers
// implementing the StopReloader interface, and when one of those signals is
// caught we'll run Stop (SIGINT, SIGTERM) or Reload (SIGHUP) on all servers.
func StopReloaderHandler(servers ...StopReloader) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
defer signal.Stop(signals)
for {
select {
case sig := <-signals:
switch sig {
case syscall.SIGHUP:
log.Println("reloading ...")
for _, server := range servers {
err := server.Reload()
if err != nil {
log.Printf("error reloading server: %+v", err)
}
}
case syscall.SIGINT, syscall.SIGTERM:
log.Println("shutting down ...")
for _, server := range servers {
err := server.Stop()
if err != nil {
log.Printf("error stopping server: %s", err.Error())
}
}
return
}
}
}
}

95
ca/testdata/ca.json vendored Normal file
View file

@ -0,0 +1,95 @@
{
"root": "testdata/secrets/root_ca.crt",
"crt": "testdata/secrets/intermediate_ca.crt",
"key": "testdata/secrets/intermediate_ca_key",
"password": "password",
"address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"],
"logger": {"format": "text"},
"tls": {
"minVersion": 1.2,
"maxVersion": 1.2,
"renegotiation": false,
"cipherSuites": [
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
]
},
"authority": {
"minCertDuration": "1m",
"provisioners": [
{
"issuer": "max",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IkpsNkZLWUp4V1UwdGRIbG9UanA1aGcifQ.Qy0EP6u5-t0ggOweoc3Z1DCzR5BllsQi.KUkviZ_TJKY4c0Mi.h7QZqgh_Fl2MZpmVy4h375yC0DORjB1dQULbNqc6MuUCW2iweWVRysFImUXiXMUKRarJC5adwWy1GhyAqUj6Xj1iOZDGLjYnqMETGWcI0rKDBwcSU7y7Y-2VYBRDSM2b7aWtTBfz3_kvEaw_vc3b5CEPJ86UlZc-jhKFRr_IcGWU-vXX5-bppoH15IPreyzi55YdjCll338lYpDecB_Paym3XBXotyd2iGXXUwoA1npEFwuyRMMEhl9zLp7rVcMW6A_32EzB8cZANEnA0C4FXGHQalY6u_2UeqxcC8_FuXPay6VIYODyRqcABvvkft3nwOcrI0pYDGBdk2w2Euk.kOAFq3Tg6s4vBGS_plMpSw",
"key": {
"use": "sig",
"kty": "EC",
"kid": "IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk",
"crv": "P-256",
"alg": "ES256",
"x": "XmaY0c9Cc_kjfn9uhimiDiKnKn00gmFzzsvElg4KxoE",
"y": "ZhYcFQBqtErdC_pA7sOXrO7AboCEPIKP9Ik4CHJqANk"
}
}, {
"issuer": "max",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlZsWnl0dUxrWTR5enlqZXJybnN0aGcifQ.QP15wQYjZ12BLgl-XTq2Vb12G3OHAfic.X35QqAaXwnlmeCUU._2qIUp0TI8yDI7c2e9upIRdrnmB5OvtLfrYN-Su2NLBpaoYtr9O55Wo0Iryc0W2pYqnVDPvgPPes4P4nQAnzw5WhFYc1Xf1ZEetfdNhwi1x2FNwPbACBAgxm5AW40O5AAlbLcWushYASfeMBZocTGXuSGUzwFqoWD-5EDJ80TWQ7cAj3ttHrJ_3QV9hi4O9KJUCiXngN-Yz2zXrhBL4NOH2fmRbaf5c0rF8xUJIIW-TcyYJeX_Fbx1IzzKKPd9USUwkDhxD4tLa51I345xVqjuwG1PEn6nF8JKqLRVUKEKFin-ShXrfE61KceyAvm4YhWKrbJWIm3bH5Hxaphy4.TexIrIhsRxJStpE3EJ925Q",
"key": {
"use": "sig",
"kty": "EC",
"kid": "DC06fatJ5nALkfEubR3VVgQ2XNy_DXSKZhwGoRO8cWU",
"crv": "P-256",
"alg": "ES256",
"x": "SuaL-GJ3LmgBF43Da9ZCY-BzmvlkMJ61MAZ1UELPpTw",
"y": "wnqZSMuXpmUxORq20t83LyY4BDYmqDGV9P7FGR6mw84"
}
}, {
"issuer": "step-cli",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg",
"key": {
"use": "sig",
"kty": "EC",
"kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc",
"crv": "P-256",
"alg": "ES256",
"x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8",
"y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y"
}
}, {
"issuer": "mariano",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
"key": {
"use": "sig",
"kty": "EC",
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
"crv": "P-256",
"alg": "ES256",
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
}
}, {
"issuer": "mariano",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6Ik5SLTk5ZkVMSm1CLW1FZGllUlFFc3cifQ.Fr314BEUGTda4ICJl2uxFdjpEUGGqJEV.gBbu_DZE1ONDu14r.X-7MKMyokZIF1HTCVqqL0tTWgaC1ZGZBLLltd11ZUhQTswo_8kvgiTv3cFShj7ATF0tAY8HStyJmzLO8mKPVOPDXSwjdNsPriZclI6JWGi9iOu8pEiN9pZM6-itxan1JMcDUNg2U-P1BmKppHRbDKsOTivymfRyeUk51dBIlS54p5xNK1HFLc1YtWC1Rc_ngYVqOgqlhIrCHArAEBe3jrfUaH2ym-8fkVdwVqtxmte3XXK9g8FchsygRNnOKtRcr0TyzTUV-7bPi8_t02Zi-EHLFaSawVXWV_Qk1GeLYJR22Rp74beo-b5-lCNVp10btO0xdGySUWmCJ4v4_QZw.c8unwWycwtfdJMM_0b0fuA",
"key": {
"use": "sig",
"kty": "EC",
"kid": "kA5qxq_k8VFc2vzriBUU1FdzHpRfQ5Uq4W3803l1m5U",
"crv": "P-256",
"alg": "ES256",
"x": "qGXXrT1vgRKVpqLoVwdgIut5VjvxrHa_V4xhh2kQvY0",
"y": "8YHQPb031kQ9gMG8ue-YRy0Fm8Gc-v6TnYYLxRGcSjw"
}
}
],
"template": {
"country": "US",
"locality": "San Francisco",
"organization": "Smallstep"
}
}
}

12
ca/testdata/secrets/intermediate_ca.crt vendored Normal file
View file

@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIB0DCCAXWgAwIBAgIQaYEAv6hTHRU+ZEnIJ6VB7zAKBggqhkjOPQQDAjAhMR8w
HQYDVQQDExZTbWFsbHN0ZXAgVGVzdCBSb290IENBMB4XDTE4MDkyNzE4MTgwOVoX
DTI4MDkyNDE4MTgwOVowKTEnMCUGA1UEAxMeU21hbGxzdGVwIFRlc3QgSW50ZXJt
ZWRpYXRlIENBMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUnFoY688av7AhSsP
vAMXHuA66zdzujzw/Wx0F/ZkWagbo52zskTxElrTt/Qkiotv33EKTUaJ7mSV/ZhW
DaI6TqOBhjCBgzAOBgNVHQ8BAf8EBAMCAaYwHQYDVR0lBBYwFAYIKwYBBQUHAwEG
CCsGAQUFBwMCMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFAKELAm5/V3t
40xrDbKcDn5VWYThMB8GA1UdIwQYMBaAFAdgQF1Ej2WxY52Olc2wKVePE596MAoG
CCqGSM49BAMCA0kAMEYCIQCoCUGx0W5wv3iQjlGIhux/zWZiDkyIbGj3ASeUL5v9
QgIhAJ8dVOcqW3oq2TF9hHv8tXjhwmK44krO/FMK4gHljo4i
-----END CERTIFICATE-----

View file

@ -0,0 +1,8 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-256-CBC,62bb1ccb9ed22ed553a479e34a4a0765
6lqTXwNel3jJjj+LdkA1E3Xr7bbeSukQLouFq2cbjh9Zyqb2xuhS2goxWZw0DDmG
rhCCKyiQnR+ImuHAwZnKBouWvp6po8CR4C1STNAX45wPfIhPV3UA49xbiA1sM+AE
QrlwCWVk9x/JhkZURK0T/3TWtdk9llcnhSKfAXnekAA=
-----END EC PRIVATE KEY-----

8
ca/testdata/secrets/ott_key vendored Normal file
View file

@ -0,0 +1,8 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-256-CBC,f6870a50902e9397844faaf37f6196fc
BVotbStC8KUiRyR6azjNu5nM1ER3/DtrdS/DxzDWJdWCPfayvQAU47DwoZdZ8Id2
Cu92bfKB0gQsgckPSfQhMC6sCd9JEiV7NqyLztDLnJJBmhml6fPMhoQaHAZy+qgW
RiVrBaYXR92DTbtzFuYb03nmHeUVCjAT/R8Q21SCAfE=
-----END EC PRIVATE KEY-----

4
ca/testdata/secrets/ott_key.public vendored Normal file
View file

@ -0,0 +1,4 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEtTKthEHN7RuybhkaC43J2oLfBG99
5FNSWbtahLAiK7Z7fDJxfBUHfroXTAsTkn2AimrwQhDj3TSccE2kgZ5sQA==
-----END PUBLIC KEY-----

View file

@ -0,0 +1,7 @@
{
"protected": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ",
"encrypted_key": "7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB",
"iv": "u-54daK2y-0UO9na",
"ciphertext": "3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts",
"tag": "vSYfxsi2UU9LQeySDjAnnQ"
}

View file

@ -0,0 +1,9 @@
{
"use": "sig",
"kty": "EC",
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
"crv": "P-256",
"alg": "ES256",
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
}

11
ca/testdata/secrets/root_ca.crt vendored Normal file
View file

@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf
MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla
Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg
Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN
Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw
QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU
B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c
ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET
/A8LXNH4M06A7vE=
-----END CERTIFICATE-----

8
ca/testdata/secrets/root_ca_key vendored Normal file
View file

@ -0,0 +1,8 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-256-CBC,3e0252253bf2ca8a21087f2f36c3bb4d
YlSY9zZ7jEMEWqgk3IT3B+WuJrnAMn9OBtMeWMo9FL1eQFLfAJBwKiKdEUYyeAwi
qi4nxx4MvfpkN02B53rmObUmAWQsxOPlMY3/KVkwQ1ovT/+eC/BGieBMvm/1aOYu
7/rnNAvI/3gWrbQ59mW6pr2qjK2eHr08s6S6GUx3C2E=
-----END EC PRIVATE KEY-----

8
ca/testdata/secrets/step_cli_key vendored Normal file
View file

@ -0,0 +1,8 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,e2c9c7cdad45b5032f1990b929cf83fd
k3Yd307VgDrdllCBGN7PP8dOMQvEAUkq1lYtyxAWa7u/DuxeDP7SYlDB+xEk/UL8
bgoYYCProydEElYFzGg8Z98WYAzbNoP2p6PPPpAhOZsxJjc5OfTHf/OQleR8PjD5
ryN4woGuq7Tiq5xritlyhluPc91ODqMsm4P98X1sPYA=
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,4 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE7ZdAAMZCFU4XwgblI5RfZouBi8lY
mF6DlZusNNnsbm+xCvYl3PAPZ+DKvKYERdazEPEU2OOo3riostJst0tn1g==
-----END PUBLIC KEY-----

View file

@ -0,0 +1,7 @@
{
"protected": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ",
"encrypted_key": "XaN9zcPQeWt49zchUDm34FECUTHfQTn_",
"iv": "tmNHPQDqR3ebsWfd",
"ciphertext": "9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw",
"tag": "thPcx3t1AUcWuEygXIY3Fg"
}

View file

@ -0,0 +1,9 @@
{
"use": "sig",
"kty": "EC",
"kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc",
"crv": "P-256",
"alg": "ES256",
"x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8",
"y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y"
}

244
ca/tls.go Normal file
View file

@ -0,0 +1,244 @@
package ca
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"net"
"net/http"
"time"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/api"
"golang.org/x/net/http2"
)
// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The certificate will automatically rotate before expiring.
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
cert, err := TLSCertificate(sign, pk)
if err != nil {
return nil, err
}
renewer, err := NewTLSRenewer(cert, nil)
if err != nil {
return nil, err
}
tlsConfig := getDefaultTLSConfig(sign)
// Note that with GetClientCertificate tlsConfig.Certificates is not used.
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
// Build RootCAs with given root certificate
if pool := c.getCertPool(sign); pool != nil {
tlsConfig.RootCAs = pool
}
// Parse Certificates and build NameToCertificate
tlsConfig.BuildNameToCertificate()
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
return nil, err
}
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
// Start renewer
renewer.RunContext(ctx)
return tlsConfig, nil
}
// GetServerTLSConfig returns a tls.Config for server use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The certificate will automatically rotate before expiring.
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
cert, err := TLSCertificate(sign, pk)
if err != nil {
return nil, err
}
renewer, err := NewTLSRenewer(cert, nil)
if err != nil {
return nil, err
}
tlsConfig := getDefaultTLSConfig(sign)
// Note that GetCertificate will only be called if the client supplies SNI
// information or if tlsConfig.Certificates is empty.
tlsConfig.GetCertificate = renewer.GetCertificate
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
// Build RootCAs with given root certificate
if pool := c.getCertPool(sign); pool != nil {
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
// Add RootCAs for refresh client
tlsConfig.RootCAs = pool
}
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
return nil, err
}
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
// Start renewer
renewer.RunContext(ctx)
return tlsConfig, nil
}
// Transport returns an http.Transport configured to use the client certificate from the sign response.
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)
if err != nil {
return nil, err
}
return getDefaultTransport(tlsConfig)
}
// getCertPool returns the transport x509.CertPool or the one from the sign
// request.
func (c *Client) getCertPool(sign *api.SignResponse) *x509.CertPool {
// Return the transport certPool
if c.certPool != nil {
return c.certPool
}
// Return certificate used in sign request.
if root, err := RootCertificate(sign); err == nil {
pool := x509.NewCertPool()
pool.AddCert(root)
return pool
}
return nil
}
// Certificate returns the server or client certificate from the sign response.
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
if sign.ServerPEM.Certificate == nil {
return nil, errors.New("ca: certificate does not exists")
}
return sign.ServerPEM.Certificate, nil
}
// IntermediateCertificate returns the CA intermediate certificate from the sign
// response.
func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
if sign.CaPEM.Certificate == nil {
return nil, errors.New("ca: certificate does not exists")
}
return sign.CaPEM.Certificate, nil
}
// RootCertificate returns the root certificate from the sign response.
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
if sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
return nil, errors.New("ca: certificate does not exists")
}
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
if len(lastChain) == 0 {
return nil, errors.New("ca: certificate does not exists")
}
return lastChain[len(lastChain)-1], nil
}
// TLSCertificate creates a new TLS certificate from the sign response and the
// private key used.
func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certificate, error) {
certPEM, err := getPEM(sign.ServerPEM)
if err != nil {
return nil, err
}
caPEM, err := getPEM(sign.CaPEM)
if err != nil {
return nil, err
}
keyPEM, err := getPEM(pk)
if err != nil {
return nil, err
}
chain := append(certPEM, caPEM...)
cert, err := tls.X509KeyPair(chain, keyPEM)
if err != nil {
return nil, errors.Wrap(err, "error creating tls certificate")
}
leaf, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "error parsing tls certificate")
}
cert.Leaf = leaf
return &cert, nil
}
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
if sign.TLSOptions != nil {
return sign.TLSOptions.TLSConfig()
}
return &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
// getDefaultTransport returns an http.Transport with the same parameters than
// http.DefaultTransport, but adds the given tls.Config and configures the
// transport for HTTP/2.
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig,
}
if err := http2.ConfigureTransport(tr); err != nil {
return nil, errors.Wrap(err, "error configuring transport")
}
return tr, nil
}
func getPEM(i interface{}) ([]byte, error) {
block := new(pem.Block)
switch i := i.(type) {
case api.Certificate:
block.Type = "CERTIFICATE"
block.Bytes = i.Raw
case *x509.Certificate:
block.Type = "CERTIFICATE"
block.Bytes = i.Raw
case *rsa.PrivateKey:
block.Type = "RSA PRIVATE KEY"
block.Bytes = x509.MarshalPKCS1PrivateKey(i)
case *ecdsa.PrivateKey:
var err error
block.Type = "EC PRIVATE KEY"
block.Bytes, err = x509.MarshalECPrivateKey(i)
if err != nil {
return nil, errors.Wrap(err, "error marshaling private key")
}
default:
return nil, errors.Errorf("unsupported key type %T", i)
}
return pem.EncodeToMemory(block), nil
}
func getRenewFunc(client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
return func() (*tls.Certificate, error) {
sign, err := client.Renew(tr)
if err != nil {
return nil, err
}
return TLSCertificate(sign, pk)
}
}

397
ca/tls_test.go Normal file
View file

@ -0,0 +1,397 @@
package ca
import (
"bytes"
"context"
"crypto"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/smallstep/ca-component/api"
"github.com/smallstep/ca-component/authority"
"github.com/smallstep/cli/crypto/randutil"
stepJOSE "github.com/smallstep/cli/jose"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
func generateOTT(subject string) string {
now := time.Now()
jwk, err := stepJOSE.ParseKey("testdata/secrets/ott_mariano_priv.jwk", stepJOSE.WithPassword([]byte("password")))
if err != nil {
panic(err)
}
opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts)
if err != nil {
panic(err)
}
id, err := randutil.ASCII(64)
if err != nil {
panic(err)
}
cl := jwt.Claims{
ID: id,
Subject: subject,
Issuer: "mariano",
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: []string{"https://ca/sign"},
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
if err != nil {
panic(err)
}
return raw
}
func startTestServer(tlsConfig *tls.Config, handler http.Handler) *httptest.Server {
srv := httptest.NewUnstartedServer(handler)
srv.TLS = tlsConfig
srv.StartTLS()
// Force the use of GetCertificate on IPs
srv.TLS.Certificates = nil
return srv
}
func startCATestServer() *httptest.Server {
config, err := authority.LoadConfiguration("testdata/ca.json")
if err != nil {
panic(err)
}
ca, err := New(config)
if err != nil {
panic(err)
}
// Use a httptest.Server instead
return startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
}
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
srv := startCATestServer()
defer srv.Close()
return signDuration(srv, domain, 0)
}
func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) {
req, pk, err := CreateSignRequest(generateOTT(domain))
if err != nil {
panic(err)
}
if duration > 0 {
req.NotBefore = time.Now()
req.NotAfter = req.NotBefore.Add(duration)
}
client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil {
panic(err)
}
sr, err := client.Sign(req)
if err != nil {
panic(err)
}
return client, sr, pk
}
func TestClient_GetServerTLSConfig_http(t *testing.T) {
client, sr, pk := sign("127.0.0.1")
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
w.Write([]byte("ok"))
}))
defer srv.Close()
tests := []struct {
name string
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
}{
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tr, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.Transport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
return nil
}
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, sr, pk := sign(clientDomain)
cli := tt.getClient(t, client, sr, pk)
if cli != nil {
resp, err := cli.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
}
})
}
}
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
// Start CA
ca := startCATestServer()
defer ca.Close()
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
fingerprints := make(map[string]struct{})
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
// Add serial number to check rotation
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
w.Write([]byte("ok"))
}))
defer srv.Close()
// Clients: transport and tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tr1, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.Transport() error = %v", err)
}
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
}
tr2, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
// Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true
tr2.DisableKeepAlives = true
tests := []struct {
name string
client *http.Client
}{
{"with transport", &http.Client{Transport: tr1}},
{"with tlsConfig", &http.Client{Transport: tr2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
if l := len(fingerprints); l != 2 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
}
// Wait for renewal 40s == 1m-1m/3
log.Printf("Sleeping for %s ...\n", 40*time.Second)
time.Sleep(40 * time.Second)
for _, tt := range tests {
t.Run("renewed "+tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
if l := len(fingerprints); l != 4 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
}
}
func TestCertificate(t *testing.T) {
cert := parseCertificate(certPEM)
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: cert},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
}
tests := []struct {
name string
sign *api.SignResponse
want *x509.Certificate
wantErr bool
}{
{"ok", ok, cert, false},
{"fail", &api.SignResponse{}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Certificate(tt.sign)
if (err != nil) != tt.wantErr {
t.Errorf("Certificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Certificate() = %v, want %v", got, tt.want)
}
})
}
}
func TestIntermediateCertificate(t *testing.T) {
intermediate := parseCertificate(rootPEM)
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: intermediate},
}
tests := []struct {
name string
sign *api.SignResponse
want *x509.Certificate
wantErr bool
}{
{"ok", ok, intermediate, false},
{"fail", &api.SignResponse{}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := IntermediateCertificate(tt.sign)
if (err != nil) != tt.wantErr {
t.Errorf("IntermediateCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("IntermediateCertificate() = %v, want %v", got, tt.want)
}
})
}
}
func TestRootCertificateCertificate(t *testing.T) {
root := parseCertificate(rootPEM)
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{
{root, root},
}},
}
noTLS := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
}
tests := []struct {
name string
sign *api.SignResponse
want *x509.Certificate
wantErr bool
}{
{"ok", ok, root, false},
{"fail", &api.SignResponse{}, nil, true},
{"no tls", noTLS, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := RootCertificate(tt.sign)
if (err != nil) != tt.wantErr {
t.Errorf("RootCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("RootCertificate() = %v, want %v", got, tt.want)
}
})
}
}

69
cmd/step-ca/main.go Normal file
View file

@ -0,0 +1,69 @@
package main
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"net/http"
"os"
"path"
"unicode"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/authority"
"github.com/smallstep/ca-component/ca"
)
func usage() {
fmt.Fprintf(os.Stderr, "Usage: %s [options] <config.json>\n\n", path.Base(os.Args[0]))
flag.PrintDefaults()
}
func main() {
var configFile, passFile string
flag.StringVar(&passFile, "password-file", "", "Path to file containing a password")
flag.Usage = usage
flag.Parse()
if flag.NArg() != 1 {
flag.Usage()
os.Exit(1)
}
configFile = flag.Arg(0)
config, err := authority.LoadConfiguration(configFile)
if err != nil {
fatal(err)
}
var password []byte
if passFile != "" {
if password, err = ioutil.ReadFile(passFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", passFile))
}
password = bytes.TrimRightFunc(password, unicode.IsSpace)
}
srv, err := ca.New(config, ca.WithConfigFile(configFile), ca.WithPassword(password))
if err != nil {
fatal(err)
}
go ca.StopReloaderHandler(srv)
if err = srv.Run(); err != nil && err != http.ErrServerClosed {
fatal(err)
}
}
// fatal writes the passed error on the standard error and exits with the exit
// code 1. If the environment variable STEPDEBUG is set to 1 it shows the
// stack trace of the error.
func fatal(err error) {
if os.Getenv("STEPDEBUG") == "1" {
fmt.Fprintf(os.Stderr, "%+v\n", err)
} else {
fmt.Fprintln(os.Stderr, err)
}
os.Exit(2)
}

16
config.json Normal file
View file

@ -0,0 +1,16 @@
{
"address": "127.0.0.1:9000",
"dnsNames": "ca.smallstep.com",
"root": "/Users/max/src/github.com/smallstep/step/.step/secrets/root_ca.crt",
"crt": "/Users/max/src/github.com/smallstep/step/.step/secrets/intermediate_ca.crt",
"key": "/Users/max/src/github.com/smallstep/step/.step/secrets/intermediate_ca_key",
"password": "pass",
"monitoring": {
"newrelic": {
"key": "57e1214ddccf694de9eef9aefdec538b6425cbbb",
"name": "step-foo"
}
},
"ottPublicKey": "/Users/max/src/github.com/smallstep/step/.step/secrets/ott_key.public"
}

6
examples/config.json Normal file
View file

@ -0,0 +1,6 @@
{
"caPath": "/path/to/intermediate-certificate",
"caPrivateKeyPath": "/path/to/intermediate-private-key",
"caPasscode": "very-secure-passcode",
"listenAddress": "127.0.0.1:9000"
}

6
examples/csr-config.yaml Normal file
View file

@ -0,0 +1,6 @@
country: USA
locality: San Francisco
organization: smallstep
common_name: internal.smallstep.com
key_type: rsa
rsa_bits: 4096

77
logging/clf.go Normal file
View file

@ -0,0 +1,77 @@
package logging
import (
"bytes"
"fmt"
"strconv"
"time"
"github.com/sirupsen/logrus"
)
var clfFields = [...]string{
"request-id", "remote-address", "name", "user-id", "time", "duration", "method", "path", "protocol", "status", "size",
}
// CommonLogFormat implements the logrus.Formatter interface it writes logrus
// entries using a CLF format prepended by the request-id.
type CommonLogFormat struct{}
// Format implements the logrus.Formatter interface. It returns the given
// logrus entry as a CLF line with the following format:
// <request-id> <remote-address> <name> <user-id> <time> <duration> "<method> <path> <protocol>" <status> <size>
// If a field is not known, the hyphen symbol (-) will be used.
func (f *CommonLogFormat) Format(entry *logrus.Entry) ([]byte, error) {
data := make([]string, len(clfFields))
for i, name := range clfFields {
if v, ok := entry.Data[name]; ok {
switch v := v.(type) {
case error:
data[i] = v.Error()
case string:
if v == "" {
data[i] = "-"
} else {
data[i] = v
}
case time.Time:
data[i] = v.Format(time.RFC3339)
case time.Duration:
data[i] = strconv.FormatInt(int64(v/time.Millisecond), 10)
case int:
data[i] = strconv.FormatInt(int64(v), 10)
case int64:
data[i] = strconv.FormatInt(v, 10)
default:
data[i] = fmt.Sprintf("%v", v)
}
} else {
data[i] = "-"
}
}
var buf bytes.Buffer
buf.WriteString(data[0])
buf.WriteByte(' ')
buf.WriteString(data[1])
buf.WriteByte(' ')
buf.WriteString(data[2])
buf.WriteByte(' ')
buf.WriteString(data[3])
buf.WriteByte(' ')
buf.WriteString(data[4])
buf.WriteByte(' ')
buf.WriteString(data[5])
buf.WriteString(" \"")
buf.WriteString(data[6])
buf.WriteByte(' ')
buf.WriteString(data[7])
buf.WriteByte(' ')
buf.WriteString(data[8])
buf.WriteString("\" ")
buf.WriteString(data[9])
buf.WriteByte(' ')
buf.WriteString(data[10])
buf.WriteByte('\n')
return buf.Bytes(), nil
}

66
logging/context.go Normal file
View file

@ -0,0 +1,66 @@
package logging
import (
"context"
"net/http"
"github.com/rs/xid"
)
type key int
const (
// RequestIDKey is the context key that should store the request identifier.
RequestIDKey key = iota
// UserIDKey is the context key that should store the user identifier.
UserIDKey
)
// NewRequestID creates a new request id using github.com/rs/xid.
func NewRequestID() string {
return xid.New().String()
}
// RequestID returns a new middleware that gets the given header and sets it
// in the context so it can be written in the logger. If the header does not
// exists or it's the empty string, it uses github.com/rs/xid to create a new
// one.
func RequestID(headerName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(headerName)
if requestID == "" {
requestID = NewRequestID()
req.Header.Set(headerName, requestID)
}
ctx := WithRequestID(req.Context(), requestID)
next.ServeHTTP(w, req.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
}
// WithRequestID returns a new context with the given requestID added to the
// context.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, RequestIDKey, requestID)
}
// GetRequestID returns the request id from the context if it exists.
func GetRequestID(ctx context.Context) (string, bool) {
v, ok := ctx.Value(RequestIDKey).(string)
return v, ok
}
// WithUserID decodes the token, extracts the user from the payload and stores
// it in the context.
func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, UserIDKey, userID)
}
// GetUserID returns the request id from the context if it exists.
func GetUserID(ctx context.Context) (string, bool) {
v, ok := ctx.Value(UserIDKey).(string)
return v, ok
}

100
logging/handler.go Normal file
View file

@ -0,0 +1,100 @@
package logging
import (
"net"
"net/http"
"time"
"github.com/sirupsen/logrus"
)
// LoggerHandler creates a logger handler
type LoggerHandler struct {
name string
logger *logrus.Logger
next http.Handler
}
// NewLoggerHandler returns the given http.Handler with the logger integrated.
func NewLoggerHandler(name string, logger *Logger, next http.Handler) http.Handler {
h := RequestID(logger.GetTraceHeader())
return h(&LoggerHandler{
name: name,
logger: logger.GetImpl(),
next: next,
})
}
// ServeHTTP implements the http.Handler and call to the handler to log with a
// custom http.ResponseWriter that records the response code and the number of
// bytes sent.
func (l *LoggerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
t := time.Now()
rw := NewResponseLogger(w)
l.next.ServeHTTP(rw, r)
d := time.Now().Sub(t)
l.writeEntry(rw, r, t, d)
}
// writeEntry writes to the Logger writer the request information in the logger.
func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Time, d time.Duration) {
var reqID, user string
ctx := r.Context()
if v, ok := ctx.Value(RequestIDKey).(string); ok && v != "" {
reqID = v
}
if v, ok := ctx.Value(UserIDKey).(string); ok && v != "" {
user = v
}
// Remote hostname
addr, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
addr = r.RemoteAddr
}
// From https://github.com/gorilla/handlers
uri := r.RequestURI
// Requests using the CONNECT method over HTTP/2.0 must use
// the authority field (aka r.Host) to identify the target.
// Refer: https://httpwg.github.io/specs/rfc7540.html#CONNECT
if r.ProtoMajor == 2 && r.Method == "CONNECT" {
uri = r.Host
}
if uri == "" {
uri = r.URL.RequestURI()
}
status := w.StatusCode()
fields := logrus.Fields{
"request-id": reqID,
"remote-address": addr,
"name": l.name,
"user-id": user,
"time": t.Format(time.RFC3339),
"duration-ns": d.Nanoseconds(),
"duration": d.String(),
"method": r.Method,
"path": uri,
"protocol": r.Proto,
"status": status,
"size": w.Size(),
"referer": r.Referer(),
"user-agent": r.UserAgent(),
}
for k, v := range w.Fields() {
fields[k] = v
}
switch {
case status < http.StatusBadRequest:
l.logger.WithFields(fields).Info()
case status < http.StatusInternalServerError:
l.logger.WithFields(fields).Warn()
default:
l.logger.WithFields(fields).Error()
}
}

77
logging/logger.go Normal file
View file

@ -0,0 +1,77 @@
package logging
import (
"encoding/json"
"net/http"
"strings"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// defaultTraceIdHeader is the default header used as a trace id.
const defaultTraceIDHeader = "X-Smallstep-Id"
// ErrorKey defines the key used to log errors.
var ErrorKey = logrus.ErrorKey
// Logger is an alias of logrus.Logger.
type Logger struct {
*logrus.Logger
name string
traceHeader string
}
// loggerConfig represents the configuration options for the logger.
type loggerConfig struct {
Format string `json:"format"`
TraceHeader string `json:"traceHeader"`
}
// New initializes the logger with the given options.
func New(name string, raw json.RawMessage) (*Logger, error) {
var config loggerConfig
if err := json.Unmarshal(raw, &config); err != nil {
return nil, errors.Wrap(err, "error unmarshalling logging attribute")
}
var formatter logrus.Formatter
switch strings.ToLower(config.Format) {
case "", "text":
case "json":
formatter = new(logrus.JSONFormatter)
case "common":
formatter = new(CommonLogFormat)
default:
return nil, errors.Errorf("unsupported logger.format '%s'", config.Format)
}
logger := &Logger{
Logger: logrus.New(),
name: name,
traceHeader: config.TraceHeader,
}
if formatter != nil {
logger.Formatter = formatter
}
return logger, nil
}
// GetImpl returns the real implementation of the logger.
func (l *Logger) GetImpl() *logrus.Logger {
return l.Logger
}
// GetTraceHeader returns the trace header configured
func (l *Logger) GetTraceHeader() string {
if l.traceHeader == "" {
return defaultTraceIDHeader
}
return l.traceHeader
}
// Middleware returns the logger middleware that will trace the request of the
// given handler.
func (l *Logger) Middleware(next http.Handler) http.Handler {
return NewLoggerHandler(l.name, l, next)
}

125
logging/responselogger.go Normal file
View file

@ -0,0 +1,125 @@
package logging
import (
"bufio"
"net"
"net/http"
)
// ResponseLogger defines an interface that a responseWrite can implement to
// support the capture of the status code, the number of bytes written and
// extra log entry fields.
type ResponseLogger interface {
http.ResponseWriter
Size() int
StatusCode() int
Fields() map[string]interface{}
WithFields(map[string]interface{})
}
// NewResponseLogger wraps the given response writer with methods to capture
// the status code, the number of bytes written, and methods to add new log
// entries. It won't wrap the response writer if it's already a
// ResponseLogger.
func NewResponseLogger(w http.ResponseWriter) ResponseLogger {
if rw, ok := w.(ResponseLogger); ok {
return rw
}
return wrapLogger(w)
}
func wrapLogger(w http.ResponseWriter) (rw ResponseLogger) {
rw = &rwDefault{w, 200, 0, nil}
if c, ok := w.(http.CloseNotifier); ok {
rw = &rwCloseNotifier{rw, c}
}
if f, ok := w.(http.Flusher); ok {
rw = &rwFlusher{rw, f}
}
if h, ok := w.(http.Hijacker); ok {
rw = &rwHijacker{rw, h}
}
if p, ok := w.(http.Pusher); ok {
rw = &rwPusher{rw, p}
}
return
}
type rwDefault struct {
http.ResponseWriter
code int
size int
fields map[string]interface{}
}
func (r *rwDefault) Header() http.Header {
return r.ResponseWriter.Header()
}
func (r *rwDefault) Write(p []byte) (n int, err error) {
n, err = r.ResponseWriter.Write(p)
r.size += n
return
}
func (r *rwDefault) WriteHeader(code int) {
r.ResponseWriter.WriteHeader(code)
r.code = code
}
func (r *rwDefault) Size() int {
return r.size
}
func (r *rwDefault) StatusCode() int {
return r.code
}
func (r *rwDefault) Fields() map[string]interface{} {
return r.fields
}
func (r *rwDefault) WithFields(fields map[string]interface{}) {
if r.fields == nil {
r.fields = make(map[string]interface{}, len(fields))
}
for k, v := range fields {
r.fields[k] = v
}
}
type rwCloseNotifier struct {
ResponseLogger
c http.CloseNotifier
}
func (r *rwCloseNotifier) CloseNotify() <-chan bool {
return r.CloseNotify()
}
type rwFlusher struct {
ResponseLogger
f http.Flusher
}
func (r *rwFlusher) Flush() {
r.f.Flush()
}
type rwHijacker struct {
ResponseLogger
h http.Hijacker
}
func (r *rwHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.h.Hijack()
}
type rwPusher struct {
ResponseLogger
p http.Pusher
}
func (rw *rwPusher) Push(target string, opts *http.PushOptions) error {
return rw.p.Push(target, opts)
}

115
monitoring/monitoring.go Normal file
View file

@ -0,0 +1,115 @@
package monitoring
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
newrelic "github.com/newrelic/go-agent"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/logging"
)
// Middleware is a function returns another http.Handler that wraps the given
// handler.
type Middleware func(next http.Handler) http.Handler
// Monitoring is the type holding a middleware that traces the request to an
// application.
type Monitoring struct {
middleware Middleware
}
// monitoring config represents the JSON attributes used for configuration. At
// this moment only fields for NewRelic are supported.
type monitoringConfig struct {
Type string `json:"type,omitempty"`
Name string `json:"name"`
Key string `json:"key"`
}
// New initializes the monitoring with the given configuration.
// Right now it only supports newrelic as the monitoring backend.
func New(raw json.RawMessage) (*Monitoring, error) {
var config monitoringConfig
if err := json.Unmarshal(raw, &config); err != nil {
return nil, errors.Wrap(err, "error unmarshalling monitoring attribute")
}
m := new(Monitoring)
switch strings.ToLower(config.Type) {
case "", "newrelic":
app, err := newrelic.NewApplication(newrelic.NewConfig(config.Name, config.Key))
if err != nil {
return nil, errors.Wrap(err, "error loading New Relic application")
}
m.middleware = newRelicMiddleware(app)
default:
return nil, errors.Errorf("unsupported monitoring.type '%s'", config.Type)
}
return m, nil
}
// Middleware is an HTTP middleware that traces the request with the configured
// monitoring backednd.
func (m *Monitoring) Middleware(next http.Handler) http.Handler {
return m.middleware(next)
}
func newRelicMiddleware(app newrelic.Application) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Start transaction
txn := app.StartTransaction(transactionName(r), w, r)
defer txn.End()
// Wrap request writer if necessary
rw := logging.NewResponseLogger(w)
// Call next handler
next.ServeHTTP(rw, r)
// Report status (using same key NewRelic uses by default)
status := rw.StatusCode()
txn.AddAttribute("httpResponseCode", strconv.Itoa(status))
// Add custom attributes
if v, ok := logging.GetRequestID(r.Context()); ok {
txn.AddAttribute("request.id", v)
}
// Report errors if necessary
if status >= http.StatusBadRequest {
var errorNoticed bool
if fields := rw.Fields(); fields != nil {
if v, ok := fields["error"]; ok {
if err, ok := v.(error); ok {
txn.NoticeError(err)
errorNoticed = true
}
}
}
if !errorNoticed {
txn.NoticeError(fmt.Errorf("request failed with status code %d", status))
}
}
})
}
}
func transactionName(r *http.Request) string {
// From https://github.com/gorilla/handlers
uri := r.RequestURI
// Requests using the CONNECT method over HTTP/2.0 must use
// the authority field (aka r.Host) to identify the target.
// Refer: https://httpwg.github.io/specs/rfc7540.html#CONNECT
if r.ProtoMajor == 2 && r.Method == "CONNECT" {
uri = r.Host
}
if uri == "" {
uri = r.URL.RequestURI()
}
return uri
}

175
server/server.go Normal file
View file

@ -0,0 +1,175 @@
package server
import (
"context"
"crypto/tls"
"log"
"net"
"net/http"
"os"
"time"
"github.com/pkg/errors"
)
// ServerShutdownTimeout is the default time to wait before closing
// connections on shutdown.
const ServerShutdownTimeout = 60 * time.Second
// Server is a incomplete component that implements a basic HTTP/HTTPS
// server.
type Server struct {
*http.Server
listener *net.TCPListener
reloadCh chan net.Listener
shutdownCh chan struct{}
}
// New creates a new HTTP/HTTPS server configured with the passed
// address, http.Handler and tls.Config.
func New(addr string, handler http.Handler, tlsConfig *tls.Config) *Server {
return &Server{
reloadCh: make(chan net.Listener),
shutdownCh: make(chan struct{}),
Server: newHTTPServer(addr, handler, tlsConfig),
}
}
// newHTTPServer creates a new http.Server with the TCP address, handler and
// tls.Config.
func newHTTPServer(addr string, handler http.Handler, tlsConfig *tls.Config) *http.Server {
return &http.Server{
Addr: addr,
Handler: handler,
TLSConfig: tlsConfig,
WriteTimeout: 15 * time.Second,
ReadTimeout: 15 * time.Second,
IdleTimeout: 15 * time.Second,
ErrorLog: log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Llongfile),
}
}
// ListenAndServe listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming connections.
func (srv *Server) ListenAndServe() error {
ln, err := net.Listen("tcp", srv.Addr)
if err != nil {
return err
}
return srv.Serve(ln)
}
// Serve runs Serve or ServetTLS on the underlaying http.Server and listen to
// channels to reload or shutdown the server.
func (srv *Server) Serve(ln net.Listener) error {
var err error
// Store the current listener.
// In reloads we'll create a copy of the underlying os.File so the close of the server one does not affect the copy.
srv.listener = ln.(*net.TCPListener)
for {
// Start server
if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
log.Printf("Serving HTTP on %s ...", srv.Addr)
err = srv.Server.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)})
} else {
log.Printf("Serving HTTPS on %s ...", srv.Addr)
err = srv.Server.ServeTLS(tcpKeepAliveListener{ln.(*net.TCPListener)}, "", "")
}
// log unexpected errors
if err != http.ErrServerClosed {
log.Println(errors.Wrap(err, "unexpected error"))
}
select {
case ln = <-srv.reloadCh:
srv.listener = ln.(*net.TCPListener)
case <-srv.shutdownCh:
return http.ErrServerClosed
}
}
}
// Shutdown gracefully shuts down the server without interrupting any active
// connections.
func (srv *Server) Shutdown() error {
ctx, cancel := context.WithTimeout(context.Background(), ServerShutdownTimeout)
defer cancel() // release resources if Shutdown ends before the timeout
defer close(srv.shutdownCh) // close shutdown channel
return srv.Server.Shutdown(ctx)
}
func (srv *Server) reloadShutdown() error {
ctx, cancel := context.WithTimeout(context.Background(), ServerShutdownTimeout)
defer cancel() // release resources if Shutdown ends before the timeout
return srv.Server.Shutdown(ctx)
}
// Reload reloads the current server with the configuration of the passed
// server.
func (srv *Server) Reload(ns *Server) error {
var err error
var ln net.Listener
if srv.Addr != ns.Addr {
// Open new address
ln, err = net.Listen("tcp", ns.Addr)
if err != nil {
return errors.WithStack(err)
}
} else {
// Get a copy of the underlying os.File
fd, err := srv.listener.File()
if err != nil {
return errors.WithStack(err)
}
// Make sure to close the copy
defer fd.Close()
// Creates a new listener copying fd
ln, err = net.FileListener(fd)
if err != nil {
return errors.WithStack(err)
}
}
// Close old server without sending a signal
if err := srv.reloadShutdown(); err != nil {
return err
}
// Update old server
srv.Server = ns.Server
srv.reloadCh <- ln
return nil
}
// Forbidden writes on the http.ResponseWriter a text/plain forbidden
// response.
func (srv *Server) Forbidden(w http.ResponseWriter) {
header := w.Header()
header.Set("Content-Type", "text/plain; charset=utf-8")
header.Set("Content-Length", "11")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte("Forbidden.\n"))
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}