Refactor Basic Authentication package

This change refactors the basic authentication implementation to better follow
Go coding standards. Many types are no longer exported. The parser is now a
separate function from the authentication code. The standard functions
(*http.Request).BasicAuth/SetBasicAuth are now used where appropriate.

Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
Stephen J Day 2015-06-08 18:56:48 -07:00
parent abd142855a
commit ffe56ebe41
4 changed files with 142 additions and 138 deletions

View file

@ -26,10 +26,6 @@ storage:
maintenance: maintenance:
uploadpurging: uploadpurging:
enabled: false enabled: false
auth:
basic:
realm: test-realm
path: /tmp/registry-dev/.htpasswd
http: http:
addr: :5000 addr: :5000
secret: asecretforlocaldevelopment secret: asecretforlocaldevelopment

View file

@ -15,23 +15,20 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
) )
var (
// ErrInvalidCredential is returned when the auth token does not authenticate correctly.
ErrInvalidCredential = errors.New("invalid authorization credential")
// ErrAuthenticationFailure returned when authentication failure to be presented to agent.
ErrAuthenticationFailure = errors.New("authentication failured")
)
type accessController struct { type accessController struct {
realm string realm string
htpasswd *htpasswd htpasswd *htpasswd
} }
type challenge struct {
realm string
err error
}
var _ auth.AccessController = &accessController{} var _ auth.AccessController = &accessController{}
var (
// ErrPasswordRequired Returned when no auth token is given.
ErrPasswordRequired = errors.New("authorization credential required")
// ErrInvalidCredential is returned when the auth token does not authenticate correctly.
ErrInvalidCredential = errors.New("invalid authorization credential")
)
func newAccessController(options map[string]interface{}) (auth.AccessController, error) { func newAccessController(options map[string]interface{}) (auth.AccessController, error) {
realm, present := options["realm"] realm, present := options["realm"]
@ -53,28 +50,29 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
return nil, err return nil, err
} }
authHeader := req.Header.Get("Authorization") username, password, ok := req.BasicAuth()
if authHeader == "" {
challenge := challenge{
realm: ac.realm,
}
return nil, &challenge
}
user, pass, ok := req.BasicAuth()
if !ok { if !ok {
return nil, errors.New("Invalid Authorization header") return nil, &challenge{
}
if res, _ := ac.htpasswd.AuthenticateUser(user, pass); !res {
challenge := challenge{
realm: ac.realm, realm: ac.realm,
err: ErrInvalidCredential,
} }
challenge.err = ErrInvalidCredential
return nil, &challenge
} }
return auth.WithUser(ctx, auth.UserInfo{Name: user}), nil if err := ac.htpasswd.authenticateUser(ctx, username, password); err != nil {
ctxu.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err)
return nil, &challenge{
realm: ac.realm,
err: ErrAuthenticationFailure,
}
}
return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil
}
// challenge implements the auth.Challenge interface.
type challenge struct {
realm string
err error
} }
func (ch *challenge) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (ch *challenge) ServeHTTP(w http.ResponseWriter, r *http.Request) {

View file

@ -1,14 +1,13 @@
package basic package basic
import ( import (
"encoding/base64"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth" "github.com/docker/distribution/registry/auth"
"golang.org/x/net/context"
) )
func TestBasicAccessController(t *testing.T) { func TestBasicAccessController(t *testing.T) {
@ -33,6 +32,7 @@ func TestBasicAccessController(t *testing.T) {
"realm": testRealm, "realm": testRealm,
"path": tempFile.Name(), "path": tempFile.Name(),
} }
ctx := context.Background()
accessController, err := newAccessController(options) accessController, err := newAccessController(options)
if err != nil { if err != nil {
@ -44,7 +44,7 @@ func TestBasicAccessController(t *testing.T) {
var userNumber = 0 var userNumber = 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(nil, "http.request", r) ctx := context.WithRequest(ctx, r)
authCtx, err := accessController.Authorized(ctx) authCtx, err := accessController.Authorized(ctx)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
@ -87,13 +87,14 @@ func TestBasicAccessController(t *testing.T) {
for i := 0; i < len(testUsers); i++ { for i := 0; i < len(testUsers); i++ {
userNumber = i userNumber = i
req, _ = http.NewRequest("GET", server.URL, nil) req, err := http.NewRequest("GET", server.URL, nil)
sekrit := testUsers[i] + ":" + testPasswords[i] if err != nil {
credential := "Basic " + base64.StdEncoding.EncodeToString([]byte(sekrit)) t.Fatalf("error allocating new request: %v", err)
}
req.SetBasicAuth(testUsers[i], testPasswords[i])
req.Header.Set("Authorization", credential)
resp, err = client.Do(req) resp, err = client.Do(req)
if err != nil { if err != nil {
t.Fatalf("unexpected error during GET: %v", err) t.Fatalf("unexpected error during GET: %v", err)
} }
@ -101,7 +102,7 @@ func TestBasicAccessController(t *testing.T) {
// Request should be authorized // Request should be authorized
if resp.StatusCode != http.StatusNoContent { if resp.StatusCode != http.StatusNoContent {
t.Fatalf("unexpected non-success response status: %v != %v for %s %s %s", resp.StatusCode, http.StatusNoContent, testUsers[i], testPasswords[i], credential) t.Fatalf("unexpected non-success response status: %v != %v for %s %s", resp.StatusCode, http.StatusNoContent, testUsers[i], testPasswords[i])
} }
} }

View file

@ -1,54 +1,66 @@
package basic package basic
import ( import (
"bufio"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"encoding/csv" "io"
"errors"
"os" "os"
"regexp" "regexp"
"strings" "strings"
"github.com/docker/distribution/context"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
// ErrAuthenticationFailure A generic error message for authentication failure to be presented to agent. // htpasswd holds a path to a system .htpasswd file and the machinery to parse it.
var ErrAuthenticationFailure = errors.New("Bad username or password")
// htpasswd Holds a path to a system .htpasswd file and the machinery to parse it.
type htpasswd struct { type htpasswd struct {
path string path string
reader *csv.Reader
} }
// AuthType Represents a particular hash function used in the htpasswd file. // authType represents a particular hash function used in the htpasswd file.
type AuthType int type authType int
const ( const (
// PlainText Plain-text password storage (htpasswd -p) authTypePlainText authType = iota // Plain-text password storage (htpasswd -p)
PlainText AuthType = iota authTypeSHA1 // sha hashed password storage (htpasswd -s)
// SHA1 sha hashed password storage (htpasswd -s) authTypeApacheMD5 // apr iterated md5 hashing (htpasswd -m)
SHA1 authTypeBCrypt // BCrypt adapative password hashing (htpasswd -B)
// ApacheMD5 apr iterated md5 hashing (htpasswd -m) authTypeCrypt // System crypt() hashes. (htpasswd -d)
ApacheMD5
// BCrypt BCrypt adapative password hashing (htpasswd -B)
BCrypt
// Crypt System crypt() hashes. (htpasswd -d)
Crypt
) )
var bcryptPrefixRegexp = regexp.MustCompile(`^\$2[ab]?y\$`)
// detectAuthCredentialType inspects the credential and resolves the encryption scheme.
func detectAuthCredentialType(cred string) authType {
if strings.HasPrefix(cred, "{SHA}") {
return authTypeSHA1
}
if strings.HasPrefix(cred, "$apr1$") {
return authTypeApacheMD5
}
if bcryptPrefixRegexp.MatchString(cred) {
return authTypeBCrypt
}
// There's just not a great way to distinguish between these next two...
if len(cred) == 13 {
return authTypeCrypt
}
return authTypePlainText
}
// String Returns a text representation of the AuthType // String Returns a text representation of the AuthType
func (at AuthType) String() string { func (at authType) String() string {
switch at { switch at {
case PlainText: case authTypePlainText:
return "plaintext" return "plaintext"
case SHA1: case authTypeSHA1:
return "sha1" return "sha1"
case ApacheMD5: case authTypeApacheMD5:
return "md5" return "md5"
case BCrypt: case authTypeBCrypt:
return "bcrypt" return "bcrypt"
case Crypt: case authTypeCrypt:
return "system crypt" return "system crypt"
} }
return "unknown" return "unknown"
@ -59,83 +71,80 @@ func newHTPasswd(htpath string) *htpasswd {
return &htpasswd{path: htpath} return &htpasswd{path: htpath}
} }
var bcryptPrefixRegexp = regexp.MustCompile(`^\$2[ab]?y\$`) // AuthenticateUser checks a given user:password credential against the
// receiving HTPasswd's file. If the check passes, nil is returned. Note that
// GetAuthCredentialType Inspect an htpasswd file credential and guess the encryption algorithm used. // this parses the htpasswd file on each request so ensure that updates are
func GetAuthCredentialType(cred string) AuthType { // available.
if strings.HasPrefix(cred, "{SHA}") { func (htpasswd *htpasswd) authenticateUser(ctx context.Context, username string, password string) error {
return SHA1
}
if strings.HasPrefix(cred, "$apr1$") {
return ApacheMD5
}
if bcryptPrefixRegexp.MatchString(cred) {
return BCrypt
}
// There's just not a great way to distinguish between these next two...
if len(cred) == 13 {
return Crypt
}
return PlainText
}
// AuthenticateUser Check a given user:password credential against the receiving HTPasswd's file.
func (htpasswd *htpasswd) AuthenticateUser(user string, pwd string) (bool, error) {
// Open the file. // Open the file.
in, err := os.Open(htpasswd.path) in, err := os.Open(htpasswd.path)
if err != nil { if err != nil {
return false, err return err
}
defer in.Close()
for _, entry := range parseHTPasswd(ctx, in) {
if entry.username != username {
continue // wrong entry
}
switch t := detectAuthCredentialType(entry.password); t {
case authTypeSHA1:
sha := sha1.New()
sha.Write([]byte(password))
hash := base64.StdEncoding.EncodeToString(sha.Sum(nil))
if entry.password[5:] != hash {
return ErrAuthenticationFailure
}
return nil
case authTypeBCrypt:
err := bcrypt.CompareHashAndPassword([]byte(entry.password), []byte(password))
if err != nil {
return ErrAuthenticationFailure
}
return nil
case authTypePlainText:
if password != entry.password {
return ErrAuthenticationFailure
}
return nil
default:
context.GetLogger(ctx).Errorf("unsupported basic authentication type: %v", t)
}
} }
// Parse the contents of the standard .htpasswd until we hit the end or find a match. return ErrAuthenticationFailure
reader := csv.NewReader(in) }
reader.Comma = ':'
reader.Comment = '#' // htpasswdEntry represents a line in an htpasswd file.
reader.TrimLeadingSpace = true type htpasswdEntry struct {
for entry, readerr := reader.Read(); entry != nil || readerr != nil; entry, readerr = reader.Read() { username string // username, plain text
if readerr != nil { password string // stores hashed passwd
return false, readerr }
}
if len(entry) == 0 { // parseHTPasswd parses the contents of htpasswd. Bad entries are skipped and
// logged, so this may return empty. This will read all the entries in the
// file, whether or not they are needed.
func parseHTPasswd(ctx context.Context, rd io.Reader) []htpasswdEntry {
entries := []htpasswdEntry{}
scanner := bufio.NewScanner(rd)
for scanner.Scan() {
t := strings.TrimSpace(scanner.Text())
i := strings.Index(t, ":")
if i < 0 || i >= len(t) {
context.GetLogger(ctx).Errorf("bad entry in htpasswd: %q", t)
continue continue
} }
if entry[0] == user {
credential := entry[1] entries = append(entries, htpasswdEntry{
credType := GetAuthCredentialType(credential) username: t[:i],
switch credType { password: t[i+1:],
case SHA1: })
{
sha := sha1.New()
sha.Write([]byte(pwd))
hash := base64.StdEncoding.EncodeToString(sha.Sum(nil))
return entry[1][5:] == hash, nil
}
case ApacheMD5:
{
return false, errors.New(ApacheMD5.String() + " htpasswd hash function not yet supported")
}
case BCrypt:
{
err := bcrypt.CompareHashAndPassword([]byte(credential), []byte(pwd))
if err != nil {
return false, err
}
return true, nil
}
case Crypt:
{
return false, errors.New(Crypt.String() + " htpasswd hash function not yet supported")
}
case PlainText:
{
if pwd == credential {
return true, nil
}
return false, ErrAuthenticationFailure
}
}
}
} }
return false, ErrAuthenticationFailure
return entries
} }