diff --git a/plugin/metadata/metadata.go b/plugin/metadata/metadata.go index 1e840d3fd..e7560d403 100644 --- a/plugin/metadata/metadata.go +++ b/plugin/metadata/metadata.go @@ -24,15 +24,16 @@ func (m *Metadata) Name() string { return "metadata" } // ServeDNS implements the plugin.Handler interface. func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - md, ctx := newMD(ctx) + ctx = context.WithValue(ctx, metadataKey{}, M{}) + md, _ := FromContext(ctx) state := request.Request{W: w, Req: r} if plugin.Zones(m.Zones).Matches(state.Name()) != "" { - // Go through all Providers and collect metadata + // Go through all Providers and collect metadata. for _, provider := range m.Providers { for _, varName := range provider.MetadataVarNames() { - if val, ok := provider.Metadata(ctx, w, r, varName); ok { - md.setValue(varName, val) + if val, ok := provider.Metadata(ctx, state, varName); ok { + md.SetValue(varName, val) } } } @@ -47,8 +48,8 @@ func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms func (m *Metadata) MetadataVarNames() []string { return variables.All } // Metadata implements the plugin.Provider interface. -func (m *Metadata) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, varName string) (interface{}, bool) { - if val, err := variables.GetValue(varName, w, r); err == nil { +func (m *Metadata) Metadata(ctx context.Context, state request.Request, varName string) (interface{}, bool) { + if val, err := variables.GetValue(state, varName); err == nil { return val, true } return nil, false diff --git a/plugin/metadata/metadata_test.go b/plugin/metadata/metadata_test.go index 413ba874e..8bbff4c34 100644 --- a/plugin/metadata/metadata_test.go +++ b/plugin/metadata/metadata_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" "github.com/miekg/dns" ) @@ -20,12 +21,12 @@ func (m testProvider) MetadataVarNames() []string { return keys } -func (m testProvider) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, key string) (val interface{}, ok bool) { +func (m testProvider) Metadata(ctx context.Context, state request.Request, key string) (val interface{}, ok bool) { value, ok := m[key] return value, ok } -// testHandler implements plugin.Handler +// testHandler implements plugin.Handler. type testHandler struct{ ctx context.Context } func (m *testHandler) Name() string { return "testHandler" } @@ -35,7 +36,7 @@ func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns return 0, nil } -func TestMetadataServDns(t *testing.T) { +func TestMetadataServeDNS(t *testing.T) { expectedMetadata := []testProvider{ testProvider{"testkey1": "testvalue1"}, testProvider{"testkey2": 2, "testkey3": "testvalue3"}, @@ -45,9 +46,8 @@ func TestMetadataServDns(t *testing.T) { for _, e := range expectedMetadata { providers = append(providers, e) } - // Fake handler which stores the resulting context - next := &testHandler{} + next := &testHandler{} // fake handler which stores the resulting context metadata := Metadata{ Zones: []string{"."}, Providers: providers, diff --git a/plugin/metadata/metadataer.go b/plugin/metadata/provider.go similarity index 57% rename from plugin/metadata/metadataer.go rename to plugin/metadata/provider.go index bff12e92d..e13f9c896 100644 --- a/plugin/metadata/metadataer.go +++ b/plugin/metadata/provider.go @@ -3,7 +3,7 @@ package metadata import ( "context" - "github.com/miekg/dns" + "github.com/coredns/coredns/request" ) // Provider interface needs to be implemented by each plugin willing to provide @@ -16,38 +16,32 @@ type Provider interface { // Metadata is expected to return a value with metadata information by the key // from 4th argument. Value can be later retrieved from context by any other plugin. // If value is not available by some reason returned boolean value should be false. - Metadata(context.Context, dns.ResponseWriter, *dns.Msg, string) (interface{}, bool) + Metadata(ctx context.Context, state request.Request, variable string) (interface{}, bool) } -// MD is metadata information storage -type MD map[string]interface{} +// M is metadata information storage. +type M map[string]interface{} -// metadataKey defines the type of key that is used to save metadata into the context -type metadataKey struct{} - -// newMD initializes MD and attaches it to context -func newMD(ctx context.Context) (MD, context.Context) { - m := MD{} - return m, context.WithValue(ctx, metadataKey{}, m) -} - -// FromContext retrieves MD struct from context. -func FromContext(ctx context.Context) (md MD, ok bool) { +// FromContext retrieves the metadata from the context. +func FromContext(ctx context.Context) (M, bool) { if metadata := ctx.Value(metadataKey{}); metadata != nil { - if md, ok := metadata.(MD); ok { - return md, true + if m, ok := metadata.(M); ok { + return m, true } } - return MD{}, false + return M{}, false } // Value returns metadata value by key. -func (m MD) Value(key string) (value interface{}, ok bool) { +func (m M) Value(key string) (value interface{}, ok bool) { value, ok = m[key] return value, ok } -// setValue adds metadata value. -func (m MD) setValue(key string, val interface{}) { +// SetValue sets the metadata value under key. +func (m M) SetValue(key string, val interface{}) { m[key] = val } + +// metadataKey defines the type of key that is used to save metadata into the context. +type metadataKey struct{} diff --git a/plugin/metadata/metadataer_test.go b/plugin/metadata/provider_test.go similarity index 65% rename from plugin/metadata/metadataer_test.go rename to plugin/metadata/provider_test.go index 53096feb8..1a074aeaa 100644 --- a/plugin/metadata/metadataer_test.go +++ b/plugin/metadata/provider_test.go @@ -25,23 +25,24 @@ func TestMD(t *testing.T) { // Using one same md and ctx for all test cases ctx := context.TODO() - md, ctx := newMD(ctx) + ctx = context.WithValue(ctx, metadataKey{}, M{}) + m, _ := FromContext(ctx) for i, tc := range tests { for k, v := range tc.addValues { - md.setValue(k, v) + m.SetValue(k, v) } - if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(md)) { - t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, md) + if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(m)) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, m) } - // Make sure that MD is recieved from context successfullly - mdFromContext, ok := FromContext(ctx) + // Make sure that md is recieved from context successfullly + mFromContext, ok := FromContext(ctx) if !ok { - t.Errorf("Test %d: MD is not recieved from the context", i) + t.Errorf("Test %d: md is not recieved from the context", i) } - if !reflect.DeepEqual(md, mdFromContext) { - t.Errorf("Test %d: MD recieved from context differs from initial. Initial: %v, from context: %v", i, md, mdFromContext) + if !reflect.DeepEqual(m, mFromContext) { + t.Errorf("Test %d: md recieved from context differs from initial. Initial: %v, from context: %v", i, m, mFromContext) } } } diff --git a/plugin/pkg/variables/variables.go b/plugin/pkg/variables/variables.go index da1dccbee..8e1cdbe77 100644 --- a/plugin/pkg/variables/variables.go +++ b/plugin/pkg/variables/variables.go @@ -7,8 +7,6 @@ import ( "strconv" "github.com/coredns/coredns/request" - - "github.com/miekg/dns" ) const ( @@ -26,35 +24,32 @@ var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverI // GetValue calculates and returns the data specified by the variable name. // Supported varNames are listed in allProvidedVars. -func GetValue(varName string, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { - req := request.Request{W: w, Req: r} +func GetValue(state request.Request, varName string) ([]byte, error) { switch varName { case queryName: - //Query name is written as ascii string - return []byte(req.QName()), nil + return []byte(state.QName()), nil case queryType: - return uint16ToWire(req.QType()), nil + return uint16ToWire(state.QType()), nil case clientIP: - return ipToWire(req.Family(), req.IP()) + return ipToWire(state.Family(), state.IP()) case clientPort: - return portToWire(req.Port()) + return portToWire(state.Port()) case protocol: - // Proto is written as ascii string - return []byte(req.Proto()), nil + return []byte(state.Proto()), nil case serverIP: - ip, _, err := net.SplitHostPort(w.LocalAddr().String()) + ip, _, err := net.SplitHostPort(state.W.LocalAddr().String()) if err != nil { - ip = w.RemoteAddr().String() + ip = state.W.RemoteAddr().String() } - return ipToWire(family(w.RemoteAddr()), ip) + return ipToWire(state.Family(), ip) case serverPort: - _, port, err := net.SplitHostPort(w.LocalAddr().String()) + _, port, err := net.SplitHostPort(state.W.LocalAddr().String()) if err != nil { port = "0" } diff --git a/plugin/pkg/variables/variables_test.go b/plugin/pkg/variables/variables_test.go index 939add323..e0ff64c19 100644 --- a/plugin/pkg/variables/variables_test.go +++ b/plugin/pkg/variables/variables_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + "github.com/miekg/dns" ) @@ -63,8 +65,9 @@ func TestGetValue(t *testing.T) { m := new(dns.Msg) m.SetQuestion("example.com.", dns.TypeA) m.Question[0].Qclass = dns.ClassINET + state := request.Request{W: &test.ResponseWriter{}, Req: m} - value, err := GetValue(tc.varName, &test.ResponseWriter{}, m) + value, err := GetValue(state, tc.varName) if tc.shouldErr && err == nil { t.Errorf("Test %d: Expected error, but didn't recieve", i) diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go index f8b65d468..2391936c7 100644 --- a/plugin/rewrite/edns0.go +++ b/plugin/rewrite/edns0.go @@ -202,7 +202,8 @@ func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWrite } } } else { // No metadata available means metadata plugin is disabled. Try to get the value directly. - return variables.GetValue(rule.variable, w, r) + state := request.Request{W: w, Req: r} // TODO(miek): every rule needs to take a request.Request. + return variables.GetValue(state, rule.variable) } return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable) }