[#2] Add NNS context #4
10 changed files with 48 additions and 14 deletions
|
@ -362,7 +362,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -801,7 +801,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -1004,7 +1004,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
|
@ -95,7 +95,7 @@ func (h *handler) Route(r api.Router) {
|
||||||
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
|
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
|
||||||
ctx = authority.NewContext(ctx, ca)
|
ctx = authority.NewContext(ctx, ca)
|
||||||
}
|
}
|
||||||
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
|
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker, "")
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -346,7 +346,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -746,7 +746,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
|
@ -507,9 +507,13 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
|
||||||
func nns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
func nns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||||
domain := strings.TrimPrefix(ch.Value, "*.")
|
domain := strings.TrimPrefix(ch.Value, "*.")
|
||||||
|
|
||||||
|
nnsCtx, ok := GetNNSContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("error retrieving NNS context")
|
||||||
|
}
|
||||||
|
|
||||||
nns := NNS{}
|
nns := NNS{}
|
||||||
// TODO: retrieve NNS server URL from config
|
err := nns.Dial(nnsCtx.nnsServer)
|
||||||
err := nns.Dial("http://localhost:30333")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,10 +29,12 @@ type CertificateAuthority interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewContext adds the given acme components to the context.
|
// NewContext adds the given acme components to the context.
|
||||||
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
|
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker,
|
||||||
|
nnsServer string) context.Context {
|
||||||
ctx = NewDatabaseContext(ctx, db)
|
ctx = NewDatabaseContext(ctx, db)
|
||||||
ctx = NewClientContext(ctx, client)
|
ctx = NewClientContext(ctx, client)
|
||||||
ctx = NewLinkerContext(ctx, linker)
|
ctx = NewLinkerContext(ctx, linker)
|
||||||
|
ctx = NewNNSContext(ctx, nnsServer)
|
||||||
// Prerequisite checker is optional.
|
// Prerequisite checker is optional.
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
||||||
|
|
18
acme/nns.go
18
acme/nns.go
|
@ -35,6 +35,24 @@ type NNS struct {
|
||||||
client multiSchemeClient
|
client multiSchemeClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NNSContext is used to store info about NNS server.
|
||||||
|
type NNSContext struct {
|
||||||
|
nnsServer string
|
||||||
|
}
|
||||||
|
|
||||||
|
type nnsKey struct{}
|
||||||
|
|
||||||
|
// NewNNSContext adds new NNSContext with given params to the context.
|
||||||
|
func NewNNSContext(ctx context.Context, nnsServer string) context.Context {
|
||||||
|
return context.WithValue(ctx, nnsKey{}, NNSContext{nnsServer: nnsServer})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNNSContext returns NNSContext from the given context.
|
||||||
|
func GetNNSContext(ctx context.Context) (NNSContext, bool) {
|
||||||
|
c, ok := ctx.Value(nnsKey{}).(NNSContext)
|
||||||
|
return c, ok
|
||||||
|
}
|
||||||
|
|
||||||
// Dial connects to the address of the NNS server.
|
// Dial connects to the address of the NNS server.
|
||||||
// If URL address scheme is 'ws' or 'wss', then WebSocket protocol is used, otherwise HTTP.
|
// If URL address scheme is 'ws' or 'wss', then WebSocket protocol is used, otherwise HTTP.
|
||||||
func (n *NNS) Dial(address string) error {
|
func (n *NNS) Dial(address string) error {
|
||||||
|
|
|
@ -79,6 +79,7 @@ type Config struct {
|
||||||
CommonName string `json:"commonName,omitempty"`
|
CommonName string `json:"commonName,omitempty"`
|
||||||
CRL *CRLConfig `json:"crl,omitempty"`
|
CRL *CRLConfig `json:"crl,omitempty"`
|
||||||
SkipValidation bool `json:"-"`
|
SkipValidation bool `json:"-"`
|
||||||
|
NNSServer string `json:"nnsServer,omitempty"`
|
||||||
|
|
||||||
// Keeps record of the filename the Config is read from
|
// Keeps record of the filename the Config is read from
|
||||||
loadedFromFilepath string
|
loadedFromFilepath string
|
||||||
|
|
|
@ -54,7 +54,7 @@ func startCABootstrapServer() *httptest.Server {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
baseContext := buildContext(ca.auth, nil, nil, nil)
|
baseContext := buildContext(ca.auth, nil, nil, nil, "")
|
||||||
srv.Config.Handler = ca.srv.Handler
|
srv.Config.Handler = ca.srv.Handler
|
||||||
srv.Config.BaseContext = func(net.Listener) context.Context {
|
srv.Config.BaseContext = func(net.Listener) context.Context {
|
||||||
return baseContext
|
return baseContext
|
||||||
|
|
15
ca/ca.go
15
ca/ca.go
|
@ -309,8 +309,16 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
insecureHandler = logger.Middleware(insecureHandler)
|
insecureHandler = logger.Middleware(insecureHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var nnsServer string
|
||||||
|
if len(ca.opts.nnsServer) > 0 {
|
||||||
|
nnsServer = ca.opts.nnsServer
|
||||||
|
} else if len(ca.config.NNSServer) > 0 {
|
||||||
|
nnsServer = ca.config.NNSServer
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("error configuring ACME NNS context: no URL of the NNS server provided")
|
||||||
|
}
|
||||||
// Create context with all the necessary values.
|
// Create context with all the necessary values.
|
||||||
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
|
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker, nnsServer)
|
||||||
|
|
||||||
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
||||||
ca.srv.BaseContext = func(net.Listener) context.Context {
|
ca.srv.BaseContext = func(net.Listener) context.Context {
|
||||||
|
@ -351,7 +359,8 @@ func (ca *CA) shouldServeInsecureServer() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildContext builds the server base context.
|
// buildContext builds the server base context.
|
||||||
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context {
|
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker,
|
||||||
|
nnsServer string) context.Context {
|
||||||
ctx := authority.NewContext(context.Background(), a)
|
ctx := authority.NewContext(context.Background(), a)
|
||||||
if authDB := a.GetDatabase(); authDB != nil {
|
if authDB := a.GetDatabase(); authDB != nil {
|
||||||
ctx = db.NewContext(ctx, authDB)
|
ctx = db.NewContext(ctx, authDB)
|
||||||
|
@ -363,7 +372,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB
|
||||||
ctx = scep.NewContext(ctx, scepAuthority)
|
ctx = scep.NewContext(ctx, scepAuthority)
|
||||||
}
|
}
|
||||||
if acmeDB != nil {
|
if acmeDB != nil {
|
||||||
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil)
|
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil, nnsServer)
|
||||||
}
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,7 @@ func startCATestServer() *httptest.Server {
|
||||||
}
|
}
|
||||||
// Use a httptest.Server instead
|
// Use a httptest.Server instead
|
||||||
srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
|
srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
|
||||||
baseContext := buildContext(ca.auth, nil, nil, nil)
|
baseContext := buildContext(ca.auth, nil, nil, nil, "")
|
||||||
srv.Config.BaseContext = func(net.Listener) context.Context {
|
srv.Config.BaseContext = func(net.Listener) context.Context {
|
||||||
return baseContext
|
return baseContext
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue