[#2] Add NNS context

Signed-off-by: Marina Biryukova <m.biryukova@yadro.com>
This commit is contained in:
Marina Biryukova 2023-08-03 15:46:59 +03:00
parent f4dfed3bf3
commit b13ba41053
10 changed files with 48 additions and 14 deletions

View file

@ -362,7 +362,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -801,7 +801,7 @@ func TestHandler_NewAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -1004,7 +1004,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View file

@ -95,7 +95,7 @@ func (h *handler) Route(r api.Router) {
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil { if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca) ctx = authority.NewContext(ctx, ca)
} }
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker) ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker, "")
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
}) })

View file

@ -346,7 +346,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -746,7 +746,7 @@ func TestHandler_GetChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "")
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View file

@ -507,9 +507,13 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
func nns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { func nns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
domain := strings.TrimPrefix(ch.Value, "*.") domain := strings.TrimPrefix(ch.Value, "*.")
nnsCtx, ok := GetNNSContext(ctx)
if !ok {
return errors.New("error retrieving NNS context")
}
nns := NNS{} nns := NNS{}
// TODO: retrieve NNS server URL from config err := nns.Dial(nnsCtx.nnsServer)
err := nns.Dial("http://localhost:30333")
if err != nil { if err != nil {
return err return err
} }

View file

@ -29,10 +29,12 @@ type CertificateAuthority interface {
} }
// NewContext adds the given acme components to the context. // NewContext adds the given acme components to the context.
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context { func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker,
nnsServer string) context.Context {
ctx = NewDatabaseContext(ctx, db) ctx = NewDatabaseContext(ctx, db)
ctx = NewClientContext(ctx, client) ctx = NewClientContext(ctx, client)
ctx = NewLinkerContext(ctx, linker) ctx = NewLinkerContext(ctx, linker)
ctx = NewNNSContext(ctx, nnsServer)
// Prerequisite checker is optional. // Prerequisite checker is optional.
if fn != nil { if fn != nil {
ctx = NewPrerequisitesCheckerContext(ctx, fn) ctx = NewPrerequisitesCheckerContext(ctx, fn)

View file

@ -35,6 +35,24 @@ type NNS struct {
client multiSchemeClient client multiSchemeClient
} }
// NNSContext is used to store info about NNS server.
type NNSContext struct {
nnsServer string
}
type nnsKey struct{}
// NewNNSContext adds new NNSContext with given params to the context.
func NewNNSContext(ctx context.Context, nnsServer string) context.Context {
return context.WithValue(ctx, nnsKey{}, NNSContext{nnsServer: nnsServer})
}
// GetNNSContext returns NNSContext from the given context.
func GetNNSContext(ctx context.Context) (NNSContext, bool) {
c, ok := ctx.Value(nnsKey{}).(NNSContext)
return c, ok
}
// Dial connects to the address of the NNS server. // Dial connects to the address of the NNS server.
// If URL address scheme is 'ws' or 'wss', then WebSocket protocol is used, otherwise HTTP. // If URL address scheme is 'ws' or 'wss', then WebSocket protocol is used, otherwise HTTP.
func (n *NNS) Dial(address string) error { func (n *NNS) Dial(address string) error {

View file

@ -79,6 +79,7 @@ type Config struct {
CommonName string `json:"commonName,omitempty"` CommonName string `json:"commonName,omitempty"`
CRL *CRLConfig `json:"crl,omitempty"` CRL *CRLConfig `json:"crl,omitempty"`
SkipValidation bool `json:"-"` SkipValidation bool `json:"-"`
NNSServer string `json:"nnsServer,omitempty"`
// Keeps record of the filename the Config is read from // Keeps record of the filename the Config is read from
loadedFromFilepath string loadedFromFilepath string

View file

@ -54,7 +54,7 @@ func startCABootstrapServer() *httptest.Server {
if err != nil { if err != nil {
panic(err) panic(err)
} }
baseContext := buildContext(ca.auth, nil, nil, nil) baseContext := buildContext(ca.auth, nil, nil, nil, "")
srv.Config.Handler = ca.srv.Handler srv.Config.Handler = ca.srv.Handler
srv.Config.BaseContext = func(net.Listener) context.Context { srv.Config.BaseContext = func(net.Listener) context.Context {
return baseContext return baseContext

View file

@ -309,8 +309,16 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
insecureHandler = logger.Middleware(insecureHandler) insecureHandler = logger.Middleware(insecureHandler)
} }
var nnsServer string
if len(ca.opts.nnsServer) > 0 {
nnsServer = ca.opts.nnsServer
} else if len(ca.config.NNSServer) > 0 {
nnsServer = ca.config.NNSServer
} else {
return nil, errors.New("error configuring ACME NNS context: no URL of the NNS server provided")
}
// Create context with all the necessary values. // Create context with all the necessary values.
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker, nnsServer)
ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv = server.New(cfg.Address, handler, tlsConfig)
ca.srv.BaseContext = func(net.Listener) context.Context { ca.srv.BaseContext = func(net.Listener) context.Context {
@ -351,7 +359,8 @@ func (ca *CA) shouldServeInsecureServer() bool {
} }
// buildContext builds the server base context. // buildContext builds the server base context.
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker,
nnsServer string) context.Context {
ctx := authority.NewContext(context.Background(), a) ctx := authority.NewContext(context.Background(), a)
if authDB := a.GetDatabase(); authDB != nil { if authDB := a.GetDatabase(); authDB != nil {
ctx = db.NewContext(ctx, authDB) ctx = db.NewContext(ctx, authDB)
@ -363,7 +372,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB
ctx = scep.NewContext(ctx, scepAuthority) ctx = scep.NewContext(ctx, scepAuthority)
} }
if acmeDB != nil { if acmeDB != nil {
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil, nnsServer)
} }
return ctx return ctx
} }

View file

@ -79,7 +79,7 @@ func startCATestServer() *httptest.Server {
} }
// Use a httptest.Server instead // Use a httptest.Server instead
srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler) srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
baseContext := buildContext(ca.auth, nil, nil, nil) baseContext := buildContext(ca.auth, nil, nil, nil, "")
srv.Config.BaseContext = func(net.Listener) context.Context { srv.Config.BaseContext = func(net.Listener) context.Context {
return baseContext return baseContext
} }