Merge pull request #310 from jlhawn/improve_context_pkg

context: improve context package
This commit is contained in:
Stephen Day 2015-04-01 16:39:59 -07:00
commit da9d49d186
9 changed files with 97 additions and 63 deletions

4
.gitignore vendored
View file

@ -31,3 +31,7 @@ bin/*
# Cover profiles # Cover profiles
*.out *.out
# Editor/IDE specific files.
*.sublime-project
*.sublime-workspace

View file

@ -224,7 +224,7 @@ func configureLogging(ctx context.Context, config *configuration.Configuration)
fields = append(fields, k) fields = append(fields, k)
} }
ctx = withMapContext(ctx, config.Log.Fields) ctx = ctxu.WithValues(ctx, config.Log.Fields)
ctx = ctxu.WithLogger(ctx, ctxu.GetLogger(ctx, fields...)) ctx = ctxu.WithLogger(ctx, ctxu.GetLogger(ctx, fields...))
} }
@ -241,36 +241,6 @@ func logLevel(level configuration.Loglevel) log.Level {
return l return l
} }
// stringMapContext is a simple context implementation that checks a map for a
// key, falling back to a parent if not present.
type stringMapContext struct {
context.Context
m map[string]string
}
// withMapContext returns a context that proxies lookups through a map.
func withMapContext(ctx context.Context, m map[string]string) context.Context {
mo := make(map[string]string, len(m)) // make our own copy.
for k, v := range m {
mo[k] = v
}
return stringMapContext{
Context: ctx,
m: mo,
}
}
func (smc stringMapContext) Value(key interface{}) interface{} {
if ks, ok := key.(string); ok {
if v, ok := smc.m[ks]; ok {
return v
}
}
return smc.Context.Value(key)
}
// debugServer starts the debug server with pprof, expvar among other // debugServer starts the debug server with pprof, expvar among other
// endpoints. The addr should not be exposed externally. For most of these to // endpoints. The addr should not be exposed externally. For most of these to
// work, tls cannot be enabled on the endpoint, so it is generally separate. // work, tls cannot be enabled on the endpoint, so it is generally separate.

View file

@ -28,7 +28,7 @@ type Configuration struct {
// Fields allows users to specify static string fields to include in // Fields allows users to specify static string fields to include in
// the logger context. // the logger context.
Fields map[string]string `yaml:"fields"` Fields map[string]interface{} `yaml:"fields"`
} }
// Loglevel is the level at which registry operations are logged. This is // Loglevel is the level at which registry operations are logged. This is

View file

@ -17,11 +17,11 @@ func Test(t *testing.T) { TestingT(t) }
var configStruct = Configuration{ var configStruct = Configuration{
Version: "0.1", Version: "0.1",
Log: struct { Log: struct {
Level Loglevel `yaml:"level"` Level Loglevel `yaml:"level"`
Formatter string `yaml:"formatter"` Formatter string `yaml:"formatter"`
Fields map[string]string `yaml:"fields"` Fields map[string]interface{} `yaml:"fields"`
}{ }{
Fields: map[string]string{"environment": "test"}, Fields: map[string]interface{}{"environment": "test"},
}, },
Loglevel: "info", Loglevel: "info",
Storage: Storage{ Storage: Storage{
@ -340,7 +340,7 @@ func copyConfig(config Configuration) *Configuration {
configCopy.Version = MajorMinorVersion(config.Version.Major(), config.Version.Minor()) configCopy.Version = MajorMinorVersion(config.Version.Major(), config.Version.Minor())
configCopy.Loglevel = config.Loglevel configCopy.Loglevel = config.Loglevel
configCopy.Log = config.Log configCopy.Log = config.Log
configCopy.Log.Fields = make(map[string]string, len(config.Log.Fields)) configCopy.Log.Fields = make(map[string]interface{}, len(config.Log.Fields))
for k, v := range config.Log.Fields { for k, v := range config.Log.Fields {
configCopy.Log.Fields[k] = v configCopy.Log.Fields[k] = v
} }

53
context/context.go Normal file
View file

@ -0,0 +1,53 @@
package context
import (
"golang.org/x/net/context"
)
// Context is a copy of Context from the golang.org/x/net/context package.
type Context interface {
context.Context
}
// Background returns a non-nil, empty Context.
func Background() Context {
return context.Background()
}
// WithValue returns a copy of parent in which the value associated with key is
// val. Use context Values only for request-scoped data that transits processes
// and APIs, not for passing optional parameters to functions.
func WithValue(parent Context, key, val interface{}) Context {
return context.WithValue(parent, key, val)
}
// stringMapContext is a simple context implementation that checks a map for a
// key, falling back to a parent if not present.
type stringMapContext struct {
context.Context
m map[string]interface{}
}
// WithValues returns a context that proxies lookups through a map. Only
// supports string keys.
func WithValues(ctx context.Context, m map[string]interface{}) context.Context {
mo := make(map[string]interface{}, len(m)) // make our own copy.
for k, v := range m {
mo[k] = v
}
return stringMapContext{
Context: ctx,
m: mo,
}
}
func (smc stringMapContext) Value(key interface{}) interface{} {
if ks, ok := key.(string); ok {
if v, ok := smc.m[ks]; ok {
return v
}
}
return smc.Context.Value(key)
}

View file

@ -11,7 +11,6 @@ import (
"code.google.com/p/go-uuid/uuid" "code.google.com/p/go-uuid/uuid"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"golang.org/x/net/context"
) )
// Common errors used with this package. // Common errors used with this package.
@ -50,12 +49,25 @@ func RemoteAddr(r *http.Request) string {
return r.RemoteAddr return r.RemoteAddr
} }
// RemoteIP extracts the remote IP of the request, taking into
// account proxy headers.
func RemoteIP(r *http.Request) string {
addr := RemoteAddr(r)
// Try parsing it as "IP:port"
if ip, _, err := net.SplitHostPort(addr); err == nil {
return ip
}
return addr
}
// WithRequest places the request on the context. The context of the request // WithRequest places the request on the context. The context of the request
// is assigned a unique id, available at "http.request.id". The request itself // is assigned a unique id, available at "http.request.id". The request itself
// is available at "http.request". Other common attributes are available under // is available at "http.request". Other common attributes are available under
// the prefix "http.request.". If a request is already present on the context, // the prefix "http.request.". If a request is already present on the context,
// this method will panic. // this method will panic.
func WithRequest(ctx context.Context, r *http.Request) context.Context { func WithRequest(ctx Context, r *http.Request) Context {
if ctx.Value("http.request") != nil { if ctx.Value("http.request") != nil {
// NOTE(stevvooe): This needs to be considered a programming error. It // NOTE(stevvooe): This needs to be considered a programming error. It
// is unlikely that we'd want to have more than one request in // is unlikely that we'd want to have more than one request in
@ -74,7 +86,7 @@ func WithRequest(ctx context.Context, r *http.Request) context.Context {
// GetRequest returns the http request in the given context. Returns // GetRequest returns the http request in the given context. Returns
// ErrNoRequestContext if the context does not have an http request associated // ErrNoRequestContext if the context does not have an http request associated
// with it. // with it.
func GetRequest(ctx context.Context) (*http.Request, error) { func GetRequest(ctx Context) (*http.Request, error) {
if r, ok := ctx.Value("http.request").(*http.Request); r != nil && ok { if r, ok := ctx.Value("http.request").(*http.Request); r != nil && ok {
return r, nil return r, nil
} }
@ -83,13 +95,13 @@ func GetRequest(ctx context.Context) (*http.Request, error) {
// GetRequestID attempts to resolve the current request id, if possible. An // GetRequestID attempts to resolve the current request id, if possible. An
// error is return if it is not available on the context. // error is return if it is not available on the context.
func GetRequestID(ctx context.Context) string { func GetRequestID(ctx Context) string {
return GetStringValue(ctx, "http.request.id") return GetStringValue(ctx, "http.request.id")
} }
// WithResponseWriter returns a new context and response writer that makes // WithResponseWriter returns a new context and response writer that makes
// interesting response statistics available within the context. // interesting response statistics available within the context.
func WithResponseWriter(ctx context.Context, w http.ResponseWriter) (context.Context, http.ResponseWriter) { func WithResponseWriter(ctx Context, w http.ResponseWriter) (Context, http.ResponseWriter) {
irw := &instrumentedResponseWriter{ irw := &instrumentedResponseWriter{
ResponseWriter: w, ResponseWriter: w,
Context: ctx, Context: ctx,
@ -107,7 +119,7 @@ var getVarsFromRequest = mux.Vars
// example, if looking for the variable "name", it can be accessed as // example, if looking for the variable "name", it can be accessed as
// "vars.name". Implementations that are accessing values need not know that // "vars.name". Implementations that are accessing values need not know that
// the underlying context is implemented with gorilla/mux vars. // the underlying context is implemented with gorilla/mux vars.
func WithVars(ctx context.Context, r *http.Request) context.Context { func WithVars(ctx Context, r *http.Request) Context {
return &muxVarsContext{ return &muxVarsContext{
Context: ctx, Context: ctx,
vars: getVarsFromRequest(r), vars: getVarsFromRequest(r),
@ -117,7 +129,7 @@ func WithVars(ctx context.Context, r *http.Request) context.Context {
// GetRequestLogger returns a logger that contains fields from the request in // GetRequestLogger returns a logger that contains fields from the request in
// the current context. If the request is not available in the context, no // the current context. If the request is not available in the context, no
// fields will display. Request loggers can safely be pushed onto the context. // fields will display. Request loggers can safely be pushed onto the context.
func GetRequestLogger(ctx context.Context) Logger { func GetRequestLogger(ctx Context) Logger {
return GetLogger(ctx, return GetLogger(ctx,
"http.request.id", "http.request.id",
"http.request.method", "http.request.method",
@ -133,7 +145,7 @@ func GetRequestLogger(ctx context.Context) Logger {
// Because the values are read at call time, pushing a logger returned from // Because the values are read at call time, pushing a logger returned from
// this function on the context will lead to missing or invalid data. Only // this function on the context will lead to missing or invalid data. Only
// call this at the end of a request, after the response has been written. // call this at the end of a request, after the response has been written.
func GetResponseLogger(ctx context.Context) Logger { func GetResponseLogger(ctx Context) Logger {
l := getLogrusLogger(ctx, l := getLogrusLogger(ctx,
"http.response.written", "http.response.written",
"http.response.status", "http.response.status",
@ -142,7 +154,7 @@ func GetResponseLogger(ctx context.Context) Logger {
duration := Since(ctx, "http.request.startedat") duration := Since(ctx, "http.request.startedat")
if duration > 0 { if duration > 0 {
l = l.WithField("http.response.duration", duration) l = l.WithField("http.response.duration", duration.String())
} }
return l return l
@ -150,7 +162,7 @@ func GetResponseLogger(ctx context.Context) Logger {
// httpRequestContext makes information about a request available to context. // httpRequestContext makes information about a request available to context.
type httpRequestContext struct { type httpRequestContext struct {
context.Context Context
startedAt time.Time startedAt time.Time
id string id string
@ -209,7 +221,7 @@ fallback:
} }
type muxVarsContext struct { type muxVarsContext struct {
context.Context Context
vars map[string]string vars map[string]string
} }
@ -235,7 +247,7 @@ func (ctx *muxVarsContext) Value(key interface{}) interface{} {
// context. // context.
type instrumentedResponseWriter struct { type instrumentedResponseWriter struct {
http.ResponseWriter http.ResponseWriter
context.Context Context
mu sync.Mutex mu sync.Mutex
status int status int

View file

@ -8,8 +8,6 @@ import (
"reflect" "reflect"
"testing" "testing"
"time" "time"
"golang.org/x/net/context"
) )
func TestWithRequest(t *testing.T) { func TestWithRequest(t *testing.T) {
@ -23,7 +21,7 @@ func TestWithRequest(t *testing.T) {
req.Header.Set("Referer", "foo.com/referer") req.Header.Set("Referer", "foo.com/referer")
req.Header.Set("User-Agent", "test/0.1") req.Header.Set("User-Agent", "test/0.1")
ctx := WithRequest(context.Background(), &req) ctx := WithRequest(Background(), &req)
for _, testcase := range []struct { for _, testcase := range []struct {
key string key string
expected interface{} expected interface{}
@ -132,7 +130,7 @@ func (trw *testResponseWriter) Flush() {
func TestWithResponseWriter(t *testing.T) { func TestWithResponseWriter(t *testing.T) {
trw := testResponseWriter{} trw := testResponseWriter{}
ctx, rw := WithResponseWriter(context.Background(), &trw) ctx, rw := WithResponseWriter(Background(), &trw)
if ctx.Value("http.response") != &trw { if ctx.Value("http.response") != &trw {
t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), &trw) t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), &trw)
@ -183,7 +181,7 @@ func TestWithVars(t *testing.T) {
return vars return vars
} }
ctx := WithVars(context.Background(), &req) ctx := WithVars(Background(), &req)
for _, testcase := range []struct { for _, testcase := range []struct {
key string key string
expected interface{} expected interface{}

View file

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
"golang.org/x/net/context"
) )
// Logger provides a leveled-logging interface. // Logger provides a leveled-logging interface.
@ -41,8 +40,8 @@ type Logger interface {
} }
// WithLogger creates a new context with provided logger. // WithLogger creates a new context with provided logger.
func WithLogger(ctx context.Context, logger Logger) context.Context { func WithLogger(ctx Context, logger Logger) Context {
return context.WithValue(ctx, "logger", logger) return WithValue(ctx, "logger", logger)
} }
// GetLogger returns the logger from the current context, if present. If one // GetLogger returns the logger from the current context, if present. If one
@ -51,7 +50,7 @@ func WithLogger(ctx context.Context, logger Logger) context.Context {
// argument passed to GetLogger will be passed to fmt.Sprint when expanded as // argument passed to GetLogger will be passed to fmt.Sprint when expanded as
// a logging key field. If context keys are integer constants, for example, // a logging key field. If context keys are integer constants, for example,
// its recommended that a String method is implemented. // its recommended that a String method is implemented.
func GetLogger(ctx context.Context, keys ...interface{}) Logger { func GetLogger(ctx Context, keys ...interface{}) Logger {
return getLogrusLogger(ctx, keys...) return getLogrusLogger(ctx, keys...)
} }
@ -59,7 +58,7 @@ func GetLogger(ctx context.Context, keys ...interface{}) Logger {
// are provided, they will be resolved on the context and included in the // are provided, they will be resolved on the context and included in the
// logger. Only use this function if specific logrus functionality is // logger. Only use this function if specific logrus functionality is
// required. // required.
func getLogrusLogger(ctx context.Context, keys ...interface{}) *logrus.Entry { func getLogrusLogger(ctx Context, keys ...interface{}) *logrus.Entry {
var logger *logrus.Entry var logger *logrus.Entry
// Get a logger, if it is present. // Get a logger, if it is present.

View file

@ -2,14 +2,12 @@ package context
import ( import (
"time" "time"
"golang.org/x/net/context"
) )
// Since looks up key, which should be a time.Time, and returns the duration // Since looks up key, which should be a time.Time, and returns the duration
// since that time. If the key is not found, the value returned will be zero. // since that time. If the key is not found, the value returned will be zero.
// This is helpful when inferring metrics related to context execution times. // This is helpful when inferring metrics related to context execution times.
func Since(ctx context.Context, key interface{}) time.Duration { func Since(ctx Context, key interface{}) time.Duration {
startedAtI := ctx.Value(key) startedAtI := ctx.Value(key)
if startedAtI != nil { if startedAtI != nil {
if startedAt, ok := startedAtI.(time.Time); ok { if startedAt, ok := startedAtI.(time.Time); ok {
@ -22,7 +20,7 @@ func Since(ctx context.Context, key interface{}) time.Duration {
// GetStringValue returns a string value from the context. The empty string // GetStringValue returns a string value from the context. The empty string
// will be returned if not found. // will be returned if not found.
func GetStringValue(ctx context.Context, key string) (value string) { func GetStringValue(ctx Context, key string) (value string) {
stringi := ctx.Value(key) stringi := ctx.Value(key)
if stringi != nil { if stringi != nil {
if valuev, ok := stringi.(string); ok { if valuev, ok := stringi.(string); ok {