Refactor Basic Authentication package
This change refactors the basic authentication implementation to better follow Go coding standards. Many types are no longer exported. The parser is now a separate function from the authentication code. The standard functions (*http.Request).BasicAuth/SetBasicAuth are now used where appropriate. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
parent
3504445680
commit
427c457801
3 changed files with 142 additions and 134 deletions
|
@ -15,23 +15,20 @@ import (
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrInvalidCredential is returned when the auth token does not authenticate correctly.
|
||||||
|
ErrInvalidCredential = errors.New("invalid authorization credential")
|
||||||
|
|
||||||
|
// ErrAuthenticationFailure returned when authentication failure to be presented to agent.
|
||||||
|
ErrAuthenticationFailure = errors.New("authentication failured")
|
||||||
|
)
|
||||||
|
|
||||||
type accessController struct {
|
type accessController struct {
|
||||||
realm string
|
realm string
|
||||||
htpasswd *htpasswd
|
htpasswd *htpasswd
|
||||||
}
|
}
|
||||||
|
|
||||||
type challenge struct {
|
|
||||||
realm string
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ auth.AccessController = &accessController{}
|
var _ auth.AccessController = &accessController{}
|
||||||
var (
|
|
||||||
// ErrPasswordRequired Returned when no auth token is given.
|
|
||||||
ErrPasswordRequired = errors.New("authorization credential required")
|
|
||||||
// ErrInvalidCredential is returned when the auth token does not authenticate correctly.
|
|
||||||
ErrInvalidCredential = errors.New("invalid authorization credential")
|
|
||||||
)
|
|
||||||
|
|
||||||
func newAccessController(options map[string]interface{}) (auth.AccessController, error) {
|
func newAccessController(options map[string]interface{}) (auth.AccessController, error) {
|
||||||
realm, present := options["realm"]
|
realm, present := options["realm"]
|
||||||
|
@ -53,28 +50,29 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authHeader := req.Header.Get("Authorization")
|
username, password, ok := req.BasicAuth()
|
||||||
if authHeader == "" {
|
|
||||||
challenge := challenge{
|
|
||||||
realm: ac.realm,
|
|
||||||
}
|
|
||||||
return nil, &challenge
|
|
||||||
}
|
|
||||||
|
|
||||||
user, pass, ok := req.BasicAuth()
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Invalid Authorization header")
|
return nil, &challenge{
|
||||||
}
|
|
||||||
|
|
||||||
if res, _ := ac.htpasswd.AuthenticateUser(user, pass); !res {
|
|
||||||
challenge := challenge{
|
|
||||||
realm: ac.realm,
|
realm: ac.realm,
|
||||||
|
err: ErrInvalidCredential,
|
||||||
}
|
}
|
||||||
challenge.err = ErrInvalidCredential
|
|
||||||
return nil, &challenge
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return auth.WithUser(ctx, auth.UserInfo{Name: user}), nil
|
if err := ac.htpasswd.authenticateUser(ctx, username, password); err != nil {
|
||||||
|
ctxu.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err)
|
||||||
|
return nil, &challenge{
|
||||||
|
realm: ac.realm,
|
||||||
|
err: ErrAuthenticationFailure,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// challenge implements the auth.Challenge interface.
|
||||||
|
type challenge struct {
|
||||||
|
realm string
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *challenge) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (ch *challenge) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
package basic
|
package basic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/docker/distribution/context"
|
||||||
"github.com/docker/distribution/registry/auth"
|
"github.com/docker/distribution/registry/auth"
|
||||||
"golang.org/x/net/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBasicAccessController(t *testing.T) {
|
func TestBasicAccessController(t *testing.T) {
|
||||||
|
@ -33,6 +32,7 @@ func TestBasicAccessController(t *testing.T) {
|
||||||
"realm": testRealm,
|
"realm": testRealm,
|
||||||
"path": tempFile.Name(),
|
"path": tempFile.Name(),
|
||||||
}
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
accessController, err := newAccessController(options)
|
accessController, err := newAccessController(options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -44,7 +44,7 @@ func TestBasicAccessController(t *testing.T) {
|
||||||
var userNumber = 0
|
var userNumber = 0
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := context.WithValue(nil, "http.request", r)
|
ctx := context.WithRequest(ctx, r)
|
||||||
authCtx, err := accessController.Authorized(ctx)
|
authCtx, err := accessController.Authorized(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err := err.(type) {
|
switch err := err.(type) {
|
||||||
|
@ -87,13 +87,14 @@ func TestBasicAccessController(t *testing.T) {
|
||||||
|
|
||||||
for i := 0; i < len(testUsers); i++ {
|
for i := 0; i < len(testUsers); i++ {
|
||||||
userNumber = i
|
userNumber = i
|
||||||
req, _ = http.NewRequest("GET", server.URL, nil)
|
req, err := http.NewRequest("GET", server.URL, nil)
|
||||||
sekrit := testUsers[i] + ":" + testPasswords[i]
|
if err != nil {
|
||||||
credential := "Basic " + base64.StdEncoding.EncodeToString([]byte(sekrit))
|
t.Fatalf("error allocating new request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetBasicAuth(testUsers[i], testPasswords[i])
|
||||||
|
|
||||||
req.Header.Set("Authorization", credential)
|
|
||||||
resp, err = client.Do(req)
|
resp, err = client.Do(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error during GET: %v", err)
|
t.Fatalf("unexpected error during GET: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -101,7 +102,7 @@ func TestBasicAccessController(t *testing.T) {
|
||||||
|
|
||||||
// Request should be authorized
|
// Request should be authorized
|
||||||
if resp.StatusCode != http.StatusNoContent {
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
t.Fatalf("unexpected non-success response status: %v != %v for %s %s %s", resp.StatusCode, http.StatusNoContent, testUsers[i], testPasswords[i], credential)
|
t.Fatalf("unexpected non-success response status: %v != %v for %s %s", resp.StatusCode, http.StatusNoContent, testUsers[i], testPasswords[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,54 +1,66 @@
|
||||||
package basic
|
package basic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/csv"
|
"io"
|
||||||
"errors"
|
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/docker/distribution/context"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrAuthenticationFailure A generic error message for authentication failure to be presented to agent.
|
// htpasswd holds a path to a system .htpasswd file and the machinery to parse it.
|
||||||
var ErrAuthenticationFailure = errors.New("Bad username or password")
|
|
||||||
|
|
||||||
// htpasswd Holds a path to a system .htpasswd file and the machinery to parse it.
|
|
||||||
type htpasswd struct {
|
type htpasswd struct {
|
||||||
path string
|
path string
|
||||||
reader *csv.Reader
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthType Represents a particular hash function used in the htpasswd file.
|
// authType represents a particular hash function used in the htpasswd file.
|
||||||
type AuthType int
|
type authType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// PlainText Plain-text password storage (htpasswd -p)
|
authTypePlainText authType = iota // Plain-text password storage (htpasswd -p)
|
||||||
PlainText AuthType = iota
|
authTypeSHA1 // sha hashed password storage (htpasswd -s)
|
||||||
// SHA1 sha hashed password storage (htpasswd -s)
|
authTypeApacheMD5 // apr iterated md5 hashing (htpasswd -m)
|
||||||
SHA1
|
authTypeBCrypt // BCrypt adapative password hashing (htpasswd -B)
|
||||||
// ApacheMD5 apr iterated md5 hashing (htpasswd -m)
|
authTypeCrypt // System crypt() hashes. (htpasswd -d)
|
||||||
ApacheMD5
|
|
||||||
// BCrypt BCrypt adapative password hashing (htpasswd -B)
|
|
||||||
BCrypt
|
|
||||||
// Crypt System crypt() hashes. (htpasswd -d)
|
|
||||||
Crypt
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var bcryptPrefixRegexp = regexp.MustCompile(`^\$2[ab]?y\$`)
|
||||||
|
|
||||||
|
// detectAuthCredentialType inspects the credential and resolves the encryption scheme.
|
||||||
|
func detectAuthCredentialType(cred string) authType {
|
||||||
|
if strings.HasPrefix(cred, "{SHA}") {
|
||||||
|
return authTypeSHA1
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(cred, "$apr1$") {
|
||||||
|
return authTypeApacheMD5
|
||||||
|
}
|
||||||
|
if bcryptPrefixRegexp.MatchString(cred) {
|
||||||
|
return authTypeBCrypt
|
||||||
|
}
|
||||||
|
// There's just not a great way to distinguish between these next two...
|
||||||
|
if len(cred) == 13 {
|
||||||
|
return authTypeCrypt
|
||||||
|
}
|
||||||
|
return authTypePlainText
|
||||||
|
}
|
||||||
|
|
||||||
// String Returns a text representation of the AuthType
|
// String Returns a text representation of the AuthType
|
||||||
func (at AuthType) String() string {
|
func (at authType) String() string {
|
||||||
switch at {
|
switch at {
|
||||||
case PlainText:
|
case authTypePlainText:
|
||||||
return "plaintext"
|
return "plaintext"
|
||||||
case SHA1:
|
case authTypeSHA1:
|
||||||
return "sha1"
|
return "sha1"
|
||||||
case ApacheMD5:
|
case authTypeApacheMD5:
|
||||||
return "md5"
|
return "md5"
|
||||||
case BCrypt:
|
case authTypeBCrypt:
|
||||||
return "bcrypt"
|
return "bcrypt"
|
||||||
case Crypt:
|
case authTypeCrypt:
|
||||||
return "system crypt"
|
return "system crypt"
|
||||||
}
|
}
|
||||||
return "unknown"
|
return "unknown"
|
||||||
|
@ -59,83 +71,80 @@ func newHTPasswd(htpath string) *htpasswd {
|
||||||
return &htpasswd{path: htpath}
|
return &htpasswd{path: htpath}
|
||||||
}
|
}
|
||||||
|
|
||||||
var bcryptPrefixRegexp = regexp.MustCompile(`^\$2[ab]?y\$`)
|
// AuthenticateUser checks a given user:password credential against the
|
||||||
|
// receiving HTPasswd's file. If the check passes, nil is returned. Note that
|
||||||
// GetAuthCredentialType Inspect an htpasswd file credential and guess the encryption algorithm used.
|
// this parses the htpasswd file on each request so ensure that updates are
|
||||||
func GetAuthCredentialType(cred string) AuthType {
|
// available.
|
||||||
if strings.HasPrefix(cred, "{SHA}") {
|
func (htpasswd *htpasswd) authenticateUser(ctx context.Context, username string, password string) error {
|
||||||
return SHA1
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(cred, "$apr1$") {
|
|
||||||
return ApacheMD5
|
|
||||||
}
|
|
||||||
if bcryptPrefixRegexp.MatchString(cred) {
|
|
||||||
return BCrypt
|
|
||||||
}
|
|
||||||
// There's just not a great way to distinguish between these next two...
|
|
||||||
if len(cred) == 13 {
|
|
||||||
return Crypt
|
|
||||||
}
|
|
||||||
return PlainText
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthenticateUser Check a given user:password credential against the receiving HTPasswd's file.
|
|
||||||
func (htpasswd *htpasswd) AuthenticateUser(user string, pwd string) (bool, error) {
|
|
||||||
|
|
||||||
// Open the file.
|
// Open the file.
|
||||||
in, err := os.Open(htpasswd.path)
|
in, err := os.Open(htpasswd.path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return err
|
||||||
|
}
|
||||||
|
defer in.Close()
|
||||||
|
|
||||||
|
for _, entry := range parseHTPasswd(ctx, in) {
|
||||||
|
if entry.username != username {
|
||||||
|
continue // wrong entry
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t := detectAuthCredentialType(entry.password); t {
|
||||||
|
case authTypeSHA1:
|
||||||
|
sha := sha1.New()
|
||||||
|
sha.Write([]byte(password))
|
||||||
|
hash := base64.StdEncoding.EncodeToString(sha.Sum(nil))
|
||||||
|
|
||||||
|
if entry.password[5:] != hash {
|
||||||
|
return ErrAuthenticationFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
case authTypeBCrypt:
|
||||||
|
err := bcrypt.CompareHashAndPassword([]byte(entry.password), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
return ErrAuthenticationFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
case authTypePlainText:
|
||||||
|
if password != entry.password {
|
||||||
|
return ErrAuthenticationFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
context.GetLogger(ctx).Errorf("unsupported basic authentication type: %v", t)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the contents of the standard .htpasswd until we hit the end or find a match.
|
return ErrAuthenticationFailure
|
||||||
reader := csv.NewReader(in)
|
}
|
||||||
reader.Comma = ':'
|
|
||||||
reader.Comment = '#'
|
// htpasswdEntry represents a line in an htpasswd file.
|
||||||
reader.TrimLeadingSpace = true
|
type htpasswdEntry struct {
|
||||||
for entry, readerr := reader.Read(); entry != nil || readerr != nil; entry, readerr = reader.Read() {
|
username string // username, plain text
|
||||||
if readerr != nil {
|
password string // stores hashed passwd
|
||||||
return false, readerr
|
}
|
||||||
}
|
|
||||||
if len(entry) == 0 {
|
// parseHTPasswd parses the contents of htpasswd. Bad entries are skipped and
|
||||||
|
// logged, so this may return empty. This will read all the entries in the
|
||||||
|
// file, whether or not they are needed.
|
||||||
|
func parseHTPasswd(ctx context.Context, rd io.Reader) []htpasswdEntry {
|
||||||
|
entries := []htpasswdEntry{}
|
||||||
|
scanner := bufio.NewScanner(rd)
|
||||||
|
for scanner.Scan() {
|
||||||
|
t := strings.TrimSpace(scanner.Text())
|
||||||
|
i := strings.Index(t, ":")
|
||||||
|
if i < 0 || i >= len(t) {
|
||||||
|
context.GetLogger(ctx).Errorf("bad entry in htpasswd: %q", t)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if entry[0] == user {
|
|
||||||
credential := entry[1]
|
entries = append(entries, htpasswdEntry{
|
||||||
credType := GetAuthCredentialType(credential)
|
username: t[:i],
|
||||||
switch credType {
|
password: t[i+1:],
|
||||||
case SHA1:
|
})
|
||||||
{
|
|
||||||
sha := sha1.New()
|
|
||||||
sha.Write([]byte(pwd))
|
|
||||||
hash := base64.StdEncoding.EncodeToString(sha.Sum(nil))
|
|
||||||
return entry[1][5:] == hash, nil
|
|
||||||
}
|
|
||||||
case ApacheMD5:
|
|
||||||
{
|
|
||||||
return false, errors.New(ApacheMD5.String() + " htpasswd hash function not yet supported")
|
|
||||||
}
|
|
||||||
case BCrypt:
|
|
||||||
{
|
|
||||||
err := bcrypt.CompareHashAndPassword([]byte(credential), []byte(pwd))
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
case Crypt:
|
|
||||||
{
|
|
||||||
return false, errors.New(Crypt.String() + " htpasswd hash function not yet supported")
|
|
||||||
}
|
|
||||||
case PlainText:
|
|
||||||
{
|
|
||||||
if pwd == credential {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, ErrAuthenticationFailure
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return false, ErrAuthenticationFailure
|
|
||||||
|
return entries
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue