diff --git a/contrib/token-server/main.go b/contrib/token-server/main.go index 6a4c1778b..1cebc0e16 100644 --- a/contrib/token-server/main.go +++ b/contrib/token-server/main.go @@ -183,6 +183,18 @@ func filterAccessList(ctx context.Context, scope string, requestedAccessList []a return grantedAccessList } +type acctSubject struct{} + +func (acctSubject) String() string { return "acctSubject" } + +type requestedAccess struct{} + +func (requestedAccess) String() string { return "requestedAccess" } + +type grantedAccess struct{} + +func (grantedAccess) String() string { return "grantedAccess" } + // getToken handles authenticating the request and authorizing access to the // requested scopes. func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *http.Request) { @@ -225,17 +237,17 @@ func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *h username := context.GetStringValue(ctx, "auth.user.name") - ctx = context.WithValue(ctx, "acctSubject", username) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "acctSubject")) + ctx = context.WithValue(ctx, acctSubject{}, username) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, acctSubject{})) context.GetLogger(ctx).Info("authenticated client") - ctx = context.WithValue(ctx, "requestedAccess", requestedAccessList) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "requestedAccess")) + ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, requestedAccess{})) grantedAccessList := filterAccessList(ctx, username, requestedAccessList) - ctx = context.WithValue(ctx, "grantedAccess", grantedAccessList) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "grantedAccess")) + ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, grantedAccess{})) token, err := ts.issuer.CreateJWT(username, service, grantedAccessList) if err != nil { @@ -347,17 +359,17 @@ func (ts *tokenServer) postToken(ctx context.Context, w http.ResponseWriter, r * return } - ctx = context.WithValue(ctx, "acctSubject", subject) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "acctSubject")) + ctx = context.WithValue(ctx, acctSubject{}, subject) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, acctSubject{})) context.GetLogger(ctx).Info("authenticated client") - ctx = context.WithValue(ctx, "requestedAccess", requestedAccessList) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "requestedAccess")) + ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, requestedAccess{})) grantedAccessList := filterAccessList(ctx, subject, requestedAccessList) - ctx = context.WithValue(ctx, "grantedAccess", grantedAccessList) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "grantedAccess")) + ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, grantedAccess{})) token, err := ts.issuer.CreateJWT(subject, service, grantedAccessList) if err != nil { diff --git a/registry/auth/silly/access_test.go b/registry/auth/silly/access_test.go index a7c14cb9d..0a5103e6c 100644 --- a/registry/auth/silly/access_test.go +++ b/registry/auth/silly/access_test.go @@ -16,7 +16,7 @@ func TestSillyAccessController(t *testing.T) { } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(nil, "http.request", r) + ctx := context.WithRequest(context.Background(), r) authCtx, err := ac.Authorized(ctx) if err != nil { switch err := err.(type) { diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index 27206f9b4..cbfe2a6b4 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -354,7 +354,7 @@ func TestAccessController(t *testing.T) { Action: "baz", } - ctx := context.WithValue(nil, "http.request", req) + ctx := context.WithRequest(context.Background(), req) authCtx, err := accessController.Authorized(ctx, testAccess) challenge, ok := err.(auth.Challenge) if !ok { diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 4df15ae6e..fde2a4acc 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -461,6 +461,8 @@ func (app *App) configureEvents(configuration *configuration.Configuration) { } } +type redisStartAtKey struct{} + func (app *App) configureRedis(configuration *configuration.Configuration) { if configuration.Redis.Addr == "" { ctxu.GetLogger(app).Infof("redis not configured") @@ -470,11 +472,11 @@ func (app *App) configureRedis(configuration *configuration.Configuration) { pool := &redis.Pool{ Dial: func() (redis.Conn, error) { // TODO(stevvooe): Yet another use case for contextual timing. - ctx := context.WithValue(app, "redis.connect.startedat", time.Now()) + ctx := context.WithValue(app, redisStartAtKey{}, time.Now()) done := func(err error) { logger := ctxu.GetLoggerWithField(ctx, "redis.connect.duration", - ctxu.Since(ctx, "redis.connect.startedat")) + ctxu.Since(ctx, redisStartAtKey{})) if err != nil { logger.Errorf("redis: error connecting: %v", err) } else { @@ -707,6 +709,18 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler { }) } +type errCodeKey struct{} + +func (errCodeKey) String() string { return "err.code" } + +type errMessageKey struct{} + +func (errMessageKey) String() string { return "err.message" } + +type errDetailKey struct{} + +func (errDetailKey) String() string { return "err.detail" } + func (app *App) logError(context context.Context, errors errcode.Errors) { for _, e1 := range errors { var c ctxu.Context @@ -714,23 +728,23 @@ func (app *App) logError(context context.Context, errors errcode.Errors) { switch e1.(type) { case errcode.Error: e, _ := e1.(errcode.Error) - c = ctxu.WithValue(context, "err.code", e.Code) - c = ctxu.WithValue(c, "err.message", e.Code.Message()) - c = ctxu.WithValue(c, "err.detail", e.Detail) + c = ctxu.WithValue(context, errCodeKey{}, e.Code) + c = ctxu.WithValue(c, errMessageKey{}, e.Code.Message()) + c = ctxu.WithValue(c, errDetailKey{}, e.Detail) case errcode.ErrorCode: e, _ := e1.(errcode.ErrorCode) - c = ctxu.WithValue(context, "err.code", e) - c = ctxu.WithValue(c, "err.message", e.Message()) + c = ctxu.WithValue(context, errCodeKey{}, e) + c = ctxu.WithValue(c, errMessageKey{}, e.Message()) default: // just normal go 'error' - c = ctxu.WithValue(context, "err.code", errcode.ErrorCodeUnknown) - c = ctxu.WithValue(c, "err.message", e1.Error()) + c = ctxu.WithValue(context, errCodeKey{}, errcode.ErrorCodeUnknown) + c = ctxu.WithValue(c, errMessageKey{}, e1.Error()) } c = ctxu.WithLogger(c, ctxu.GetLogger(c, - "err.code", - "err.message", - "err.detail")) + errCodeKey{}, + errMessageKey{}, + errDetailKey{})) ctxu.GetResponseLogger(c).Errorf("response completed with error") } } diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index e808f7606..b3c9d37c3 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -76,8 +76,8 @@ const noStorageClass = "NONE" // validRegions maps known s3 region identifiers to region descriptors var validRegions = map[string]struct{}{} -// validObjectAcls contains known s3 object Acls -var validObjectAcls = map[string]struct{}{} +// validObjectACLs contains known s3 object Acls +var validObjectACLs = map[string]struct{}{} //DriverParameters A struct that encapsulates all of the driver parameters after all values have been set type DriverParameters struct { @@ -97,7 +97,7 @@ type DriverParameters struct { RootDirectory string StorageClass string UserAgent string - ObjectAcl string + ObjectACL string } func init() { @@ -118,7 +118,7 @@ func init() { validRegions[region] = struct{}{} } - for _, objectAcl := range []string{ + for _, objectACL := range []string{ s3.ObjectCannedACLPrivate, s3.ObjectCannedACLPublicRead, s3.ObjectCannedACLPublicReadWrite, @@ -127,7 +127,7 @@ func init() { s3.ObjectCannedACLBucketOwnerRead, s3.ObjectCannedACLBucketOwnerFullControl, } { - validObjectAcls[objectAcl] = struct{}{} + validObjectACLs[objectACL] = struct{}{} } // Register this as the default s3 driver in addition to s3aws @@ -153,7 +153,7 @@ type driver struct { MultipartCopyThresholdSize int64 RootDirectory string StorageClass string - ObjectAcl string + ObjectACL string } type baseEmbed struct { @@ -313,18 +313,18 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { userAgent = "" } - objectAcl := s3.ObjectCannedACLPrivate - objectAclParam := parameters["objectacl"] - if objectAclParam != nil { - objectAclString, ok := objectAclParam.(string) + objectACL := s3.ObjectCannedACLPrivate + objectACLParam := parameters["objectacl"] + if objectACLParam != nil { + objectACLString, ok := objectACLParam.(string) if !ok { - return nil, fmt.Errorf("Invalid value for objectacl parameter: %v", objectAclParam) + return nil, fmt.Errorf("Invalid value for objectacl parameter: %v", objectACLParam) } - if _, ok = validObjectAcls[objectAclString]; !ok { - return nil, fmt.Errorf("Invalid value for objectacl parameter: %v", objectAclParam) + if _, ok = validObjectACLs[objectACLString]; !ok { + return nil, fmt.Errorf("Invalid value for objectacl parameter: %v", objectACLParam) } - objectAcl = objectAclString + objectACL = objectACLString } params := DriverParameters{ @@ -344,7 +344,7 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { fmt.Sprint(rootDirectory), storageClass, fmt.Sprint(userAgent), - objectAcl, + objectACL, } return New(params) @@ -459,7 +459,7 @@ func New(params DriverParameters) (*Driver, error) { MultipartCopyThresholdSize: params.MultipartCopyThresholdSize, RootDirectory: params.RootDirectory, StorageClass: params.StorageClass, - ObjectAcl: params.ObjectAcl, + ObjectACL: params.ObjectACL, } return &Driver{ @@ -912,7 +912,7 @@ func (d *driver) getContentType() *string { } func (d *driver) getACL() *string { - return aws.String(d.ObjectAcl) + return aws.String(d.ObjectACL) } func (d *driver) getStorageClass() *string { diff --git a/registry/storage/driver/s3-aws/s3_test.go b/registry/storage/driver/s3-aws/s3_test.go index 16c579cbb..eb7ee5195 100644 --- a/registry/storage/driver/s3-aws/s3_test.go +++ b/registry/storage/driver/s3-aws/s3_test.go @@ -33,7 +33,7 @@ func init() { secure := os.Getenv("S3_SECURE") v4Auth := os.Getenv("S3_V4_AUTH") region := os.Getenv("AWS_REGION") - objectAcl := os.Getenv("S3_OBJECT_ACL") + objectACL := os.Getenv("S3_OBJECT_ACL") root, err := ioutil.TempDir("", "driver-") regionEndpoint := os.Getenv("REGION_ENDPOINT") if err != nil { @@ -83,7 +83,7 @@ func init() { rootDirectory, storageClass, driverName + "-test", - objectAcl, + objectACL, } return New(parameters)