From b13ba41053de68a6a1189df82215c676824c159e Mon Sep 17 00:00:00 2001 From: Marina Biryukova Date: Thu, 3 Aug 2023 15:46:59 +0300 Subject: [PATCH] [#2] Add NNS context Signed-off-by: Marina Biryukova --- acme/api/account_test.go | 6 +++--- acme/api/handler.go | 2 +- acme/api/handler_test.go | 4 ++-- acme/challenge.go | 8 ++++++-- acme/common.go | 4 +++- acme/nns.go | 18 ++++++++++++++++++ authority/config/config.go | 1 + ca/bootstrap_test.go | 2 +- ca/ca.go | 15 ++++++++++++--- ca/tls_test.go | 2 +- 10 files changed, 48 insertions(+), 14 deletions(-) diff --git a/acme/api/account_test.go b/acme/api/account_test.go index c4cfaa02..e36c4adc 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -362,7 +362,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") req := httptest.NewRequest("GET", u, nil) req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -801,7 +801,7 @@ func TestHandler_NewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -1004,7 +1004,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(ctx) w := httptest.NewRecorder() diff --git a/acme/api/handler.go b/acme/api/handler.go index 16713cf7..42a47d1a 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -95,7 +95,7 @@ func (h *handler) Route(r api.Router) { if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil { ctx = authority.NewContext(ctx, ca) } - ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker) + ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker, "") next(w, r.WithContext(ctx)) } }) diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 7ef7cd68..0e3c3467 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -346,7 +346,7 @@ func TestHandler_GetAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -746,7 +746,7 @@ func TestHandler_GetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") req := httptest.NewRequest("GET", u, nil) req = req.WithContext(ctx) w := httptest.NewRecorder() diff --git a/acme/challenge.go b/acme/challenge.go index cc96bda7..fca3f1e1 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -507,9 +507,13 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose func nns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { domain := strings.TrimPrefix(ch.Value, "*.") + nnsCtx, ok := GetNNSContext(ctx) + if !ok { + return errors.New("error retrieving NNS context") + } + nns := NNS{} - // TODO: retrieve NNS server URL from config - err := nns.Dial("http://localhost:30333") + err := nns.Dial(nnsCtx.nnsServer) if err != nil { return err } diff --git a/acme/common.go b/acme/common.go index 7d58305f..e5e0075f 100644 --- a/acme/common.go +++ b/acme/common.go @@ -29,10 +29,12 @@ type CertificateAuthority interface { } // NewContext adds the given acme components to the context. -func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context { +func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker, + nnsServer string) context.Context { ctx = NewDatabaseContext(ctx, db) ctx = NewClientContext(ctx, client) ctx = NewLinkerContext(ctx, linker) + ctx = NewNNSContext(ctx, nnsServer) // Prerequisite checker is optional. if fn != nil { ctx = NewPrerequisitesCheckerContext(ctx, fn) diff --git a/acme/nns.go b/acme/nns.go index c6f928ef..0358f6cc 100644 --- a/acme/nns.go +++ b/acme/nns.go @@ -35,6 +35,24 @@ type NNS struct { client multiSchemeClient } +// NNSContext is used to store info about NNS server. +type NNSContext struct { + nnsServer string +} + +type nnsKey struct{} + +// NewNNSContext adds new NNSContext with given params to the context. +func NewNNSContext(ctx context.Context, nnsServer string) context.Context { + return context.WithValue(ctx, nnsKey{}, NNSContext{nnsServer: nnsServer}) +} + +// GetNNSContext returns NNSContext from the given context. +func GetNNSContext(ctx context.Context) (NNSContext, bool) { + c, ok := ctx.Value(nnsKey{}).(NNSContext) + return c, ok +} + // Dial connects to the address of the NNS server. // If URL address scheme is 'ws' or 'wss', then WebSocket protocol is used, otherwise HTTP. func (n *NNS) Dial(address string) error { diff --git a/authority/config/config.go b/authority/config/config.go index ae284fb9..3e38e7e1 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -79,6 +79,7 @@ type Config struct { CommonName string `json:"commonName,omitempty"` CRL *CRLConfig `json:"crl,omitempty"` SkipValidation bool `json:"-"` + NNSServer string `json:"nnsServer,omitempty"` // Keeps record of the filename the Config is read from loadedFromFilepath string diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 62c422d4..cf52c123 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -54,7 +54,7 @@ func startCABootstrapServer() *httptest.Server { if err != nil { panic(err) } - baseContext := buildContext(ca.auth, nil, nil, nil) + baseContext := buildContext(ca.auth, nil, nil, nil, "") srv.Config.Handler = ca.srv.Handler srv.Config.BaseContext = func(net.Listener) context.Context { return baseContext diff --git a/ca/ca.go b/ca/ca.go index b16b0c34..60cb923d 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -309,8 +309,16 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler = logger.Middleware(insecureHandler) } + var nnsServer string + if len(ca.opts.nnsServer) > 0 { + nnsServer = ca.opts.nnsServer + } else if len(ca.config.NNSServer) > 0 { + nnsServer = ca.config.NNSServer + } else { + return nil, errors.New("error configuring ACME NNS context: no URL of the NNS server provided") + } // Create context with all the necessary values. - baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) + baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker, nnsServer) ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv.BaseContext = func(net.Listener) context.Context { @@ -351,7 +359,8 @@ func (ca *CA) shouldServeInsecureServer() bool { } // buildContext builds the server base context. -func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { +func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker, + nnsServer string) context.Context { ctx := authority.NewContext(context.Background(), a) if authDB := a.GetDatabase(); authDB != nil { ctx = db.NewContext(ctx, authDB) @@ -363,7 +372,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB ctx = scep.NewContext(ctx, scepAuthority) } if acmeDB != nil { - ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) + ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil, nnsServer) } return ctx } diff --git a/ca/tls_test.go b/ca/tls_test.go index 24b8ef01..8a171846 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -79,7 +79,7 @@ func startCATestServer() *httptest.Server { } // Use a httptest.Server instead srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler) - baseContext := buildContext(ca.auth, nil, nil, nil) + baseContext := buildContext(ca.auth, nil, nil, nil, "") srv.Config.BaseContext = func(net.Listener) context.Context { return baseContext }