[#2] Add NNS context #4
10 changed files with 48 additions and 14 deletions
|
@ -362,7 +362,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -801,7 +801,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -1004,7 +1004,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
|
|
@ -95,7 +95,7 @@ func (h *handler) Route(r api.Router) {
|
|||
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
|
||||
ctx = authority.NewContext(ctx, ca)
|
||||
}
|
||||
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
|
||||
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker, "")
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
})
|
||||
|
|
|
@ -346,7 +346,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -746,7 +746,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
|
|
@ -507,9 +507,13 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
|
|||
func nns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
domain := strings.TrimPrefix(ch.Value, "*.")
|
||||
|
||||
nnsCtx, ok := GetNNSContext(ctx)
|
||||
if !ok {
|
||||
return errors.New("error retrieving NNS context")
|
||||
}
|
||||
|
||||
nns := NNS{}
|
||||
// TODO: retrieve NNS server URL from config
|
||||
err := nns.Dial("http://localhost:30333")
|
||||
err := nns.Dial(nnsCtx.nnsServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -29,10 +29,12 @@ type CertificateAuthority interface {
|
|||
}
|
||||
|
||||
// NewContext adds the given acme components to the context.
|
||||
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
|
||||
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker,
|
||||
nnsServer string) context.Context {
|
||||
ctx = NewDatabaseContext(ctx, db)
|
||||
ctx = NewClientContext(ctx, client)
|
||||
ctx = NewLinkerContext(ctx, linker)
|
||||
ctx = NewNNSContext(ctx, nnsServer)
|
||||
// Prerequisite checker is optional.
|
||||
if fn != nil {
|
||||
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
||||
|
|
18
acme/nns.go
18
acme/nns.go
|
@ -35,6 +35,24 @@ type NNS struct {
|
|||
client multiSchemeClient
|
||||
}
|
||||
|
||||
// NNSContext is used to store info about NNS server.
|
||||
type NNSContext struct {
|
||||
nnsServer string
|
||||
}
|
||||
|
||||
type nnsKey struct{}
|
||||
|
||||
// NewNNSContext adds new NNSContext with given params to the context.
|
||||
func NewNNSContext(ctx context.Context, nnsServer string) context.Context {
|
||||
return context.WithValue(ctx, nnsKey{}, NNSContext{nnsServer: nnsServer})
|
||||
}
|
||||
|
||||
// GetNNSContext returns NNSContext from the given context.
|
||||
func GetNNSContext(ctx context.Context) (NNSContext, bool) {
|
||||
c, ok := ctx.Value(nnsKey{}).(NNSContext)
|
||||
return c, ok
|
||||
}
|
||||
|
||||
// Dial connects to the address of the NNS server.
|
||||
// If URL address scheme is 'ws' or 'wss', then WebSocket protocol is used, otherwise HTTP.
|
||||
func (n *NNS) Dial(address string) error {
|
||||
|
|
|
@ -79,6 +79,7 @@ type Config struct {
|
|||
CommonName string `json:"commonName,omitempty"`
|
||||
CRL *CRLConfig `json:"crl,omitempty"`
|
||||
SkipValidation bool `json:"-"`
|
||||
NNSServer string `json:"nnsServer,omitempty"`
|
||||
|
||||
// Keeps record of the filename the Config is read from
|
||||
loadedFromFilepath string
|
||||
|
|
|
@ -54,7 +54,7 @@ func startCABootstrapServer() *httptest.Server {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
baseContext := buildContext(ca.auth, nil, nil, nil)
|
||||
baseContext := buildContext(ca.auth, nil, nil, nil, "")
|
||||
srv.Config.Handler = ca.srv.Handler
|
||||
srv.Config.BaseContext = func(net.Listener) context.Context {
|
||||
return baseContext
|
||||
|
|
15
ca/ca.go
15
ca/ca.go
|
@ -309,8 +309,16 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
insecureHandler = logger.Middleware(insecureHandler)
|
||||
}
|
||||
|
||||
var nnsServer string
|
||||
if len(ca.opts.nnsServer) > 0 {
|
||||
nnsServer = ca.opts.nnsServer
|
||||
} else if len(ca.config.NNSServer) > 0 {
|
||||
nnsServer = ca.config.NNSServer
|
||||
} else {
|
||||
return nil, errors.New("error configuring ACME NNS context: no URL of the NNS server provided")
|
||||
}
|
||||
// Create context with all the necessary values.
|
||||
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
|
||||
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker, nnsServer)
|
||||
|
||||
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
||||
ca.srv.BaseContext = func(net.Listener) context.Context {
|
||||
|
@ -351,7 +359,8 @@ func (ca *CA) shouldServeInsecureServer() bool {
|
|||
}
|
||||
|
||||
// buildContext builds the server base context.
|
||||
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context {
|
||||
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker,
|
||||
nnsServer string) context.Context {
|
||||
ctx := authority.NewContext(context.Background(), a)
|
||||
if authDB := a.GetDatabase(); authDB != nil {
|
||||
ctx = db.NewContext(ctx, authDB)
|
||||
|
@ -363,7 +372,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB
|
|||
ctx = scep.NewContext(ctx, scepAuthority)
|
||||
}
|
||||
if acmeDB != nil {
|
||||
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil)
|
||||
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil, nnsServer)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
|
|
@ -79,7 +79,7 @@ func startCATestServer() *httptest.Server {
|
|||
}
|
||||
// Use a httptest.Server instead
|
||||
srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
|
||||
baseContext := buildContext(ca.auth, nil, nil, nil)
|
||||
baseContext := buildContext(ca.auth, nil, nil, nil, "")
|
||||
srv.Config.BaseContext = func(net.Listener) context.Context {
|
||||
return baseContext
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue