http: add client certificate user auth middleware
This populates the authenticated user from the client certificate common name. Also added tests for the existing client certificate functionality.
This commit is contained in:
parent
7751d5a00b
commit
1cfed18aa7
14 changed files with 458 additions and 29 deletions
|
@ -20,6 +20,16 @@ func testEchoHandler(data []byte) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func testAuthUserHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := CtxGetUser(r.Context())
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
_, _ = w.Write([]byte(userID))
|
||||
})
|
||||
}
|
||||
|
||||
func testExpectRespBody(t *testing.T, resp *http.Response, expected []byte) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
@ -234,19 +244,22 @@ func TestNewServerBaseURL(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNewServerTLS(t *testing.T) {
|
||||
certBytes := testReadTestdataFile(t, "local.crt")
|
||||
keyBytes := testReadTestdataFile(t, "local.key")
|
||||
serverCertBytes := testReadTestdataFile(t, "local.crt")
|
||||
serverKeyBytes := testReadTestdataFile(t, "local.key")
|
||||
clientCertBytes := testReadTestdataFile(t, "client.crt")
|
||||
clientKeyBytes := testReadTestdataFile(t, "client.key")
|
||||
clientCert, err := tls.X509KeyPair(clientCertBytes, clientKeyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TODO: generate a proper cert with SAN
|
||||
// TODO: generate CA, test mTLS
|
||||
// clientCert, err := tls.X509KeyPair(certBytes, keyBytes)
|
||||
// require.NoError(t, err, "should be testing with a valid self signed certificate")
|
||||
|
||||
servers := []struct {
|
||||
name string
|
||||
wantErr bool
|
||||
err error
|
||||
http Config
|
||||
name string
|
||||
clientCerts []tls.Certificate
|
||||
wantErr bool
|
||||
wantClientErr bool
|
||||
err error
|
||||
http Config
|
||||
}{
|
||||
{
|
||||
name: "FromFile/Valid",
|
||||
|
@ -303,8 +316,8 @@ func TestNewServerTLS(t *testing.T) {
|
|||
name: "FromBody/Valid",
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: certBytes,
|
||||
TLSKeyBody: keyBytes,
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.0",
|
||||
},
|
||||
},
|
||||
|
@ -315,7 +328,7 @@ func TestNewServerTLS(t *testing.T) {
|
|||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: nil,
|
||||
TLSKeyBody: keyBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.0",
|
||||
},
|
||||
},
|
||||
|
@ -325,7 +338,7 @@ func TestNewServerTLS(t *testing.T) {
|
|||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: []byte("JUNK DATA"),
|
||||
TLSKeyBody: keyBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.0",
|
||||
},
|
||||
},
|
||||
|
@ -335,7 +348,7 @@ func TestNewServerTLS(t *testing.T) {
|
|||
err: ErrTLSBodyMismatch,
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: certBytes,
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: nil,
|
||||
MinTLSVersion: "tls1.0",
|
||||
},
|
||||
|
@ -345,7 +358,7 @@ func TestNewServerTLS(t *testing.T) {
|
|||
wantErr: true,
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: certBytes,
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: []byte("JUNK DATA"),
|
||||
MinTLSVersion: "tls1.0",
|
||||
},
|
||||
|
@ -354,8 +367,8 @@ func TestNewServerTLS(t *testing.T) {
|
|||
name: "MinTLSVersion/Valid/1.1",
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: certBytes,
|
||||
TLSKeyBody: keyBytes,
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.1",
|
||||
},
|
||||
},
|
||||
|
@ -363,8 +376,8 @@ func TestNewServerTLS(t *testing.T) {
|
|||
name: "MinTLSVersion/Valid/1.2",
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: certBytes,
|
||||
TLSKeyBody: keyBytes,
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.2",
|
||||
},
|
||||
},
|
||||
|
@ -372,8 +385,8 @@ func TestNewServerTLS(t *testing.T) {
|
|||
name: "MinTLSVersion/Valid/1.3",
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: certBytes,
|
||||
TLSKeyBody: keyBytes,
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.3",
|
||||
},
|
||||
},
|
||||
|
@ -383,11 +396,46 @@ func TestNewServerTLS(t *testing.T) {
|
|||
err: ErrInvalidMinTLSVersion,
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: certBytes,
|
||||
TLSKeyBody: keyBytes,
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls0.9",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MutualTLS/InvalidCA",
|
||||
clientCerts: []tls.Certificate{clientCert},
|
||||
wantErr: true,
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.0",
|
||||
ClientCA: "./testdata/client-ca.crt.invalid",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MutualTLS/InvalidClient",
|
||||
clientCerts: []tls.Certificate{},
|
||||
wantClientErr: true,
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.0",
|
||||
ClientCA: "./testdata/client-ca.crt",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MutualTLS/Valid",
|
||||
clientCerts: []tls.Certificate{clientCert},
|
||||
http: Config{
|
||||
ListenAddr: []string{"127.0.0.1:0"},
|
||||
TLSCertBody: serverCertBytes,
|
||||
TLSKeyBody: serverKeyBytes,
|
||||
MinTLSVersion: "tls1.0",
|
||||
ClientCA: "./testdata/client-ca.crt",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, ss := range servers {
|
||||
|
@ -422,7 +470,7 @@ func TestNewServerTLS(t *testing.T) {
|
|||
return net.Dial("tcp", dest)
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
// Certificates: []tls.Certificate{clientCert},
|
||||
Certificates: ss.clientCerts,
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
|
@ -431,6 +479,12 @@ func TestNewServerTLS(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if ss.wantClientErr {
|
||||
require.Error(t, err, "new server client should return error")
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue