From eaa9da0be33bf7c4d8cd7c1131a5f38ac3af06aa Mon Sep 17 00:00:00 2001 From: Derek McGowan Date: Mon, 25 Jan 2016 15:42:05 -0800 Subject: [PATCH] Add simple implementation of token server Token server implementation currently functional with existing docker 1.9.x release and latest distribution release. Signed-off-by: Derek McGowan (github: dmcgowan) --- contrib/token-server/main.go | 202 ++++++++++++++++++++++++++++++++++ contrib/token-server/token.go | 168 ++++++++++++++++++++++++++++ 2 files changed, 370 insertions(+) create mode 100644 contrib/token-server/main.go create mode 100644 contrib/token-server/token.go diff --git a/contrib/token-server/main.go b/contrib/token-server/main.go new file mode 100644 index 000000000..303ed9ed0 --- /dev/null +++ b/contrib/token-server/main.go @@ -0,0 +1,202 @@ +package main + +import ( + "encoding/json" + "flag" + "net/http" + "strings" + + "github.com/Sirupsen/logrus" + "github.com/docker/distribution/context" + "github.com/docker/distribution/registry/api/errcode" + "github.com/docker/distribution/registry/auth" + _ "github.com/docker/distribution/registry/auth/htpasswd" + "github.com/docker/libtrust" + "github.com/gorilla/mux" +) + +func main() { + var ( + issuer = &TokenIssuer{} + pkFile string + addr string + debug bool + err error + + passwdFile string + realm string + + cert string + certKey string + ) + + flag.StringVar(&issuer.Issuer, "issuer", "distribution-token-server", "Issuer string for token") + flag.StringVar(&pkFile, "key", "", "Private key file") + flag.StringVar(&addr, "addr", "localhost:8080", "Address to listen on") + flag.BoolVar(&debug, "debug", false, "Debug mode") + + flag.StringVar(&passwdFile, "passwd", ".htpasswd", "Passwd file") + flag.StringVar(&realm, "realm", "", "Authentication realm") + + flag.StringVar(&cert, "tlscert", "", "Certificate file for TLS") + flag.StringVar(&certKey, "tlskey", "", "Certificate key for TLS") + + flag.Parse() + + if debug { + logrus.SetLevel(logrus.DebugLevel) + } + + if pkFile == "" { + issuer.SigningKey, err = libtrust.GenerateECP256PrivateKey() + if err != nil { + logrus.Fatalf("Error generating private key: %v", err) + } + logrus.Debugf("Using newly generated key with id %s", issuer.SigningKey.KeyID()) + } else { + issuer.SigningKey, err = libtrust.LoadKeyFile(pkFile) + if err != nil { + logrus.Fatalf("Error loading key file %s: %v", pkFile, err) + } + logrus.Debugf("Loaded private key with id %s", issuer.SigningKey.KeyID()) + } + + if realm == "" { + logrus.Fatalf("Must provide realm") + } + + ac, err := auth.GetAccessController("htpasswd", map[string]interface{}{ + "realm": realm, + "path": passwdFile, + }) + if err != nil { + logrus.Fatalf("Error initializing access controller: %v", err) + } + + ctx := context.Background() + + ts := &tokenServer{ + issuer: issuer, + accessController: ac, + } + + router := mux.NewRouter() + router.Path("/token/").Methods("GET").Handler(handlerWithContext(ctx, ts.getToken)) + + if cert == "" { + err = http.ListenAndServe(addr, router) + } else if certKey == "" { + logrus.Fatalf("Must provide certficate and key") + } else { + err = http.ListenAndServeTLS(addr, cert, certKey, router) + } + + if err != nil { + logrus.Infof("Error serving: %v", err) + } + +} + +// handlerWithContext wraps the given context-aware handler by setting up the +// request context from a base context. +func handlerWithContext(ctx context.Context, handler func(context.Context, http.ResponseWriter, *http.Request)) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithRequest(ctx, r) + logger := context.GetRequestLogger(ctx) + ctx = context.WithLogger(ctx, logger) + + handler(ctx, w, r) + }) +} + +func handleError(ctx context.Context, err error, w http.ResponseWriter) { + ctx, w = context.WithResponseWriter(ctx, w) + + if serveErr := errcode.ServeJSON(w, err); serveErr != nil { + context.GetResponseLogger(ctx).Errorf("error sending error response: %v", serveErr) + return + } + + context.GetResponseLogger(ctx).Info("application error") +} + +type tokenServer struct { + issuer *TokenIssuer + accessController auth.AccessController +} + +// getToken handles authenticating the request and authorizing access to the +// requested scopes. +func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *http.Request) { + context.GetLogger(ctx).Info("getToken") + + params := r.URL.Query() + service := params.Get("service") + scopeSpecifiers := params["scope"] + + requestedAccessList := ResolveScopeSpecifiers(scopeSpecifiers) + + authorizedCtx, err := ts.accessController.Authorized(ctx, requestedAccessList...) + if err != nil { + challenge, ok := err.(auth.Challenge) + if !ok { + handleError(ctx, err, w) + return + } + + // Get response context. + ctx, w = context.WithResponseWriter(ctx, w) + + challenge.SetHeaders(w) + handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail(challenge.Error()), w) + + context.GetResponseLogger(ctx).Info("authentication challenged") + + return + } + ctx = authorizedCtx + + // TODO(dmcgowan): handle case where this could panic? + username := ctx.Value("auth.user.name").(string) + + ctx = context.WithValue(ctx, "acctSubject", username) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, "acctSubject")) + + context.GetLogger(ctx).Info("authenticated client") + + ctx = context.WithValue(ctx, "requestedAccess", requestedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, "requestedAccess")) + + scopePrefix := username + "/" + grantedAccessList := make([]auth.Access, 0, len(requestedAccessList)) + for _, access := range requestedAccessList { + if access.Type != "repository" { + context.GetLogger(ctx).Debugf("Skipping unsupported resource type: %s", access.Type) + continue + } + if !strings.HasPrefix(access.Name, scopePrefix) { + context.GetLogger(ctx).Debugf("Resource scope not allowed: %s", access.Name) + continue + } + grantedAccessList = append(grantedAccessList, access) + } + + ctx = context.WithValue(ctx, "grantedAccess", grantedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, "grantedAccess")) + + token, err := ts.issuer.CreateJWT(username, service, grantedAccessList) + if err != nil { + handleError(ctx, err, w) + return + } + + context.GetLogger(ctx).Info("authorized client") + + // Get response context. + ctx, w = context.WithResponseWriter(ctx, w) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"token": token}) + + context.GetResponseLogger(ctx).Info("getToken complete") +} diff --git a/contrib/token-server/token.go b/contrib/token-server/token.go new file mode 100644 index 000000000..917d6ee3e --- /dev/null +++ b/contrib/token-server/token.go @@ -0,0 +1,168 @@ +package main + +import ( + "crypto" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "strings" + "time" + + "github.com/docker/distribution/registry/auth" + "github.com/docker/distribution/registry/auth/token" + "github.com/docker/libtrust" +) + +// ResolveScopeSpecifiers converts a list of scope specifiers from a token +// request's `scope` query parameters into a list of standard access objects. +func ResolveScopeSpecifiers(scopeSpecs []string) []auth.Access { + requestedAccessSet := make(map[auth.Access]struct{}, 2*len(scopeSpecs)) + + for _, scopeSpecifier := range scopeSpecs { + // There should be 3 parts, separated by a `:` character. + parts := strings.SplitN(scopeSpecifier, ":", 3) + + if len(parts) != 3 { + // Ignore malformed scope specifiers. + continue + } + + resourceType, resourceName, actions := parts[0], parts[1], parts[2] + + // Actions should be a comma-separated list of actions. + for _, action := range strings.Split(actions, ",") { + requestedAccess := auth.Access{ + Resource: auth.Resource{ + Type: resourceType, + Name: resourceName, + }, + Action: action, + } + + // Add this access to the requested access set. + requestedAccessSet[requestedAccess] = struct{}{} + } + } + + requestedAccessList := make([]auth.Access, 0, len(requestedAccessSet)) + for requestedAccess := range requestedAccessSet { + requestedAccessList = append(requestedAccessList, requestedAccess) + } + + return requestedAccessList +} + +// TokenIssuer represents an issuer capable of generating JWT tokens +type TokenIssuer struct { + Issuer string + SigningKey libtrust.PrivateKey + Expiration time.Duration +} + +// CreateJWT creates and signs a JSON Web Token for the given account and +// audience with the granted access. +func (issuer *TokenIssuer) CreateJWT(subject string, audience string, grantedAccessList []auth.Access) (string, error) { + // Make a set of access entries to put in the token's claimset. + resourceActionSets := make(map[auth.Resource]map[string]struct{}, len(grantedAccessList)) + for _, access := range grantedAccessList { + actionSet, exists := resourceActionSets[access.Resource] + if !exists { + actionSet = map[string]struct{}{} + resourceActionSets[access.Resource] = actionSet + } + actionSet[access.Action] = struct{}{} + } + + accessEntries := make([]token.ResourceActions, 0, len(resourceActionSets)) + for resource, actionSet := range resourceActionSets { + actions := make([]string, 0, len(actionSet)) + for action := range actionSet { + actions = append(actions, action) + } + + accessEntries = append(accessEntries, token.ResourceActions{ + Type: resource.Type, + Name: resource.Name, + Actions: actions, + }) + } + + randomBytes := make([]byte, 15) + _, err := io.ReadFull(rand.Reader, randomBytes) + if err != nil { + return "", err + } + randomID := base64.URLEncoding.EncodeToString(randomBytes) + + now := time.Now() + + signingHash := crypto.SHA256 + var alg string + switch issuer.SigningKey.KeyType() { + case "RSA": + alg = "RS256" + case "EC": + alg = "ES256" + default: + panic(fmt.Errorf("unsupported signing key type %q", issuer.SigningKey.KeyType())) + } + + joseHeader := map[string]interface{}{ + "typ": "JWT", + "alg": alg, + } + + if x5c := issuer.SigningKey.GetExtendedField("x5c"); x5c != nil { + joseHeader["x5c"] = x5c + } else { + joseHeader["jwk"] = issuer.SigningKey.PublicKey() + } + + exp := issuer.Expiration + if exp == 0 { + exp = 5 * time.Minute + } + + claimSet := map[string]interface{}{ + "iss": issuer.Issuer, + "sub": subject, + "aud": audience, + "exp": now.Add(exp).Unix(), + "nbf": now.Unix(), + "iat": now.Unix(), + "jti": randomID, + + "access": accessEntries, + } + + var ( + joseHeaderBytes []byte + claimSetBytes []byte + ) + + if joseHeaderBytes, err = json.Marshal(joseHeader); err != nil { + return "", fmt.Errorf("unable to encode jose header: %s", err) + } + if claimSetBytes, err = json.Marshal(claimSet); err != nil { + return "", fmt.Errorf("unable to encode claim set: %s", err) + } + + encodedJoseHeader := joseBase64Encode(joseHeaderBytes) + encodedClaimSet := joseBase64Encode(claimSetBytes) + encodingToSign := fmt.Sprintf("%s.%s", encodedJoseHeader, encodedClaimSet) + + var signatureBytes []byte + if signatureBytes, _, err = issuer.SigningKey.Sign(strings.NewReader(encodingToSign), signingHash); err != nil { + return "", fmt.Errorf("unable to sign jwt payload: %s", err) + } + + signature := joseBase64Encode(signatureBytes) + + return fmt.Sprintf("%s.%s", encodingToSign, signature), nil +} + +func joseBase64Encode(data []byte) string { + return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=") +}