1cfed18aa7
This populates the authenticated user from the client certificate common name. Also added tests for the existing client certificate functionality.
393 lines
9.6 KiB
Go
393 lines
9.6 KiB
Go
package http
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestMiddlewareAuth(t *testing.T) {
|
|
servers := []struct {
|
|
name string
|
|
http Config
|
|
auth AuthConfig
|
|
user string
|
|
pass string
|
|
}{
|
|
{
|
|
name: "Basic",
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
},
|
|
auth: AuthConfig{
|
|
Realm: "test",
|
|
BasicUser: "test",
|
|
BasicPass: "test",
|
|
},
|
|
user: "test",
|
|
pass: "test",
|
|
},
|
|
{
|
|
name: "Htpasswd/MD5",
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
},
|
|
auth: AuthConfig{
|
|
Realm: "test",
|
|
HtPasswd: "./testdata/.htpasswd",
|
|
},
|
|
user: "md5",
|
|
pass: "md5",
|
|
},
|
|
{
|
|
name: "Htpasswd/SHA",
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
},
|
|
auth: AuthConfig{
|
|
Realm: "test",
|
|
HtPasswd: "./testdata/.htpasswd",
|
|
},
|
|
user: "sha",
|
|
pass: "sha",
|
|
},
|
|
{
|
|
name: "Htpasswd/Bcrypt",
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
},
|
|
auth: AuthConfig{
|
|
Realm: "test",
|
|
HtPasswd: "./testdata/.htpasswd",
|
|
},
|
|
user: "bcrypt",
|
|
pass: "bcrypt",
|
|
},
|
|
{
|
|
name: "Custom",
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
},
|
|
auth: AuthConfig{
|
|
Realm: "test",
|
|
CustomAuthFn: func(user, pass string) (value interface{}, err error) {
|
|
if user == "custom" && pass == "custom" {
|
|
return true, nil
|
|
}
|
|
return nil, errors.New("invalid credentials")
|
|
},
|
|
},
|
|
user: "custom",
|
|
pass: "custom",
|
|
},
|
|
}
|
|
|
|
for _, ss := range servers {
|
|
t.Run(ss.name, func(t *testing.T) {
|
|
s, err := NewServer(context.Background(), WithConfig(ss.http), WithAuth(ss.auth))
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
require.NoError(t, s.Shutdown())
|
|
}()
|
|
|
|
expected := []byte("secret-page")
|
|
s.Router().Mount("/", testEchoHandler(expected))
|
|
s.Serve()
|
|
|
|
url := testGetServerURL(t, s)
|
|
|
|
t.Run("NoCreds", func(t *testing.T) {
|
|
client := &http.Client{}
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := client.Do(req)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
|
|
require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "using no creds should return unauthorized")
|
|
|
|
wwwAuthHeader := resp.Header.Get("WWW-Authenticate")
|
|
require.NotEmpty(t, wwwAuthHeader, "resp should contain WWW-Authtentication header")
|
|
require.Contains(t, wwwAuthHeader, fmt.Sprintf("realm=%q", ss.auth.Realm), "WWW-Authtentication header should contain relam")
|
|
})
|
|
|
|
t.Run("BadCreds", func(t *testing.T) {
|
|
client := &http.Client{}
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
require.NoError(t, err)
|
|
|
|
req.SetBasicAuth(ss.user+"BAD", ss.pass+"BAD")
|
|
|
|
resp, err := client.Do(req)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
|
|
require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "using bad creds should return unauthorized")
|
|
|
|
wwwAuthHeader := resp.Header.Get("WWW-Authenticate")
|
|
require.NotEmpty(t, wwwAuthHeader, "resp should contain WWW-Authtentication header")
|
|
require.Contains(t, wwwAuthHeader, fmt.Sprintf("realm=%q", ss.auth.Realm), "WWW-Authtentication header should contain relam")
|
|
})
|
|
|
|
t.Run("GoodCreds", func(t *testing.T) {
|
|
client := &http.Client{}
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
require.NoError(t, err)
|
|
|
|
req.SetBasicAuth(ss.user, ss.pass)
|
|
|
|
resp, err := client.Do(req)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode, "using good creds should return ok")
|
|
|
|
testExpectRespBody(t, resp, expected)
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareAuthCertificateUser(t *testing.T) {
|
|
serverCertBytes := testReadTestdataFile(t, "local.crt")
|
|
serverKeyBytes := testReadTestdataFile(t, "local.key")
|
|
clientCertBytes := testReadTestdataFile(t, "client.crt")
|
|
clientKeyBytes := testReadTestdataFile(t, "client.key")
|
|
clientCert, err := tls.X509KeyPair(clientCertBytes, clientKeyBytes)
|
|
require.NoError(t, err)
|
|
emptyCertBytes := testReadTestdataFile(t, "emptyclient.crt")
|
|
emptyKeyBytes := testReadTestdataFile(t, "emptyclient.key")
|
|
emptyCert, err := tls.X509KeyPair(emptyCertBytes, emptyKeyBytes)
|
|
require.NoError(t, err)
|
|
invalidCert, err := tls.X509KeyPair(serverCertBytes, serverKeyBytes)
|
|
require.NoError(t, err)
|
|
|
|
servers := []struct {
|
|
name string
|
|
wantErr bool
|
|
status int
|
|
result string
|
|
http Config
|
|
auth AuthConfig
|
|
clientCerts []tls.Certificate
|
|
}{
|
|
{
|
|
name: "Missing",
|
|
wantErr: true,
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
TLSCertBody: serverCertBytes,
|
|
TLSKeyBody: serverKeyBytes,
|
|
MinTLSVersion: "tls1.0",
|
|
ClientCA: "./testdata/client-ca.crt",
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid",
|
|
wantErr: true,
|
|
clientCerts: []tls.Certificate{invalidCert},
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
TLSCertBody: serverCertBytes,
|
|
TLSKeyBody: serverKeyBytes,
|
|
MinTLSVersion: "tls1.0",
|
|
ClientCA: "./testdata/client-ca.crt",
|
|
},
|
|
},
|
|
{
|
|
name: "EmptyCommonName",
|
|
status: http.StatusUnauthorized,
|
|
result: fmt.Sprintf("%s\n", http.StatusText(http.StatusUnauthorized)),
|
|
clientCerts: []tls.Certificate{emptyCert},
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
TLSCertBody: serverCertBytes,
|
|
TLSKeyBody: serverKeyBytes,
|
|
MinTLSVersion: "tls1.0",
|
|
ClientCA: "./testdata/client-ca.crt",
|
|
},
|
|
},
|
|
{
|
|
name: "Valid",
|
|
status: http.StatusOK,
|
|
result: "rclone-dev-client",
|
|
clientCerts: []tls.Certificate{clientCert},
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
TLSCertBody: serverCertBytes,
|
|
TLSKeyBody: serverKeyBytes,
|
|
MinTLSVersion: "tls1.0",
|
|
ClientCA: "./testdata/client-ca.crt",
|
|
},
|
|
},
|
|
{
|
|
name: "CustomAuth/Invalid",
|
|
status: http.StatusUnauthorized,
|
|
result: fmt.Sprintf("%d %s\n", http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)),
|
|
clientCerts: []tls.Certificate{clientCert},
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
TLSCertBody: serverCertBytes,
|
|
TLSKeyBody: serverKeyBytes,
|
|
MinTLSVersion: "tls1.0",
|
|
ClientCA: "./testdata/client-ca.crt",
|
|
},
|
|
auth: AuthConfig{
|
|
Realm: "test",
|
|
CustomAuthFn: func(user, pass string) (value interface{}, err error) {
|
|
if user == "custom" && pass == "custom" {
|
|
return true, nil
|
|
}
|
|
return nil, errors.New("invalid credentials")
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "CustomAuth/Valid",
|
|
status: http.StatusOK,
|
|
result: "rclone-dev-client",
|
|
clientCerts: []tls.Certificate{clientCert},
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
TLSCertBody: serverCertBytes,
|
|
TLSKeyBody: serverKeyBytes,
|
|
MinTLSVersion: "tls1.0",
|
|
ClientCA: "./testdata/client-ca.crt",
|
|
},
|
|
auth: AuthConfig{
|
|
Realm: "test",
|
|
CustomAuthFn: func(user, pass string) (value interface{}, err error) {
|
|
fmt.Println("CUSTOMAUTH", user, pass)
|
|
if user == "rclone-dev-client" && pass == "" {
|
|
return true, nil
|
|
}
|
|
return nil, errors.New("invalid credentials")
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, ss := range servers {
|
|
t.Run(ss.name, func(t *testing.T) {
|
|
s, err := NewServer(context.Background(), WithConfig(ss.http), WithAuth(ss.auth))
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
require.NoError(t, s.Shutdown())
|
|
}()
|
|
|
|
s.Router().Mount("/", testAuthUserHandler())
|
|
s.Serve()
|
|
|
|
url := testGetServerURL(t, s)
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
Certificates: ss.clientCerts,
|
|
InsecureSkipVerify: true,
|
|
},
|
|
},
|
|
}
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := client.Do(req)
|
|
if ss.wantErr {
|
|
require.Error(t, err)
|
|
return
|
|
}
|
|
require.NoError(t, err)
|
|
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
|
|
require.Equal(t, ss.status, resp.StatusCode, fmt.Sprintf("should return status %d", ss.status))
|
|
|
|
testExpectRespBody(t, resp, []byte(ss.result))
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
// _testCORSHeaderKeys names the response headers that TestMiddlewareCORS
// requires to be present when MiddlewareCORS is installed.
//
// NOTE(review): in the Fetch/CORS spec, Access-Control-Request-Method is a
// request header, and the response counterpart is
// Access-Control-Allow-Methods. This list presumably mirrors what
// MiddlewareCORS actually emits; confirm before changing either side.
var _testCORSHeaderKeys = []string{"Access-Control-Allow-Origin", "Access-Control-Request-Method", "Access-Control-Allow-Headers"}
|
|
|
|
func TestMiddlewareCORS(t *testing.T) {
|
|
servers := []struct {
|
|
name string
|
|
http Config
|
|
origin string
|
|
}{
|
|
{
|
|
name: "EmptyOrigin",
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
},
|
|
origin: "",
|
|
},
|
|
{
|
|
name: "CustomOrigin",
|
|
http: Config{
|
|
ListenAddr: []string{"127.0.0.1:0"},
|
|
},
|
|
origin: "http://test.rclone.org",
|
|
},
|
|
}
|
|
|
|
for _, ss := range servers {
|
|
t.Run(ss.name, func(t *testing.T) {
|
|
s, err := NewServer(context.Background(), WithConfig(ss.http))
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
require.NoError(t, s.Shutdown())
|
|
}()
|
|
|
|
s.Router().Use(MiddlewareCORS(ss.origin))
|
|
|
|
expected := []byte("data")
|
|
s.Router().Mount("/", testEchoHandler(expected))
|
|
s.Serve()
|
|
|
|
url := testGetServerURL(t, s)
|
|
|
|
client := &http.Client{}
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := client.Do(req)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode, "should return ok")
|
|
|
|
testExpectRespBody(t, resp, expected)
|
|
|
|
for _, key := range _testCORSHeaderKeys {
|
|
require.Contains(t, resp.Header, key, "CORS headers should be sent")
|
|
}
|
|
|
|
expectedOrigin := url
|
|
if ss.origin != "" {
|
|
expectedOrigin = ss.origin
|
|
}
|
|
require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
|
|
})
|
|
}
|
|
}
|