43e502590f
It's possible to configure log fields in the configuration file, and we would like these fields to be included in all logs. Previously these fields were included only in logs produced using the main routine's context, meaning that any logs from a request handler were missing the fields since those use a context based on the HTTP request's context. Add a configurable default logger to the `context` package, and set it when configuring logging at startup time. Signed-off-by: Adam Wolfe Gordon <awg@digitalocean.com>
407 lines
12 KiB
Go
407 lines
12 KiB
Go
package registry
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/distribution/distribution/v3/configuration"
|
|
dcontext "github.com/distribution/distribution/v3/context"
|
|
_ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
|
|
"github.com/sirupsen/logrus"
|
|
"gopkg.in/yaml.v2"
|
|
)
|
|
|
|
// Tests to ensure nextProtos returns the correct protocols when:
|
|
// * config.HTTP.HTTP2.Disabled is not explicitly set => [h2 http/1.1]
|
|
// * config.HTTP.HTTP2.Disabled is explicitly set to false [h2 http/1.1]
|
|
// * config.HTTP.HTTP2.Disabled is explicitly set to true [http/1.1]
|
|
func TestNextProtos(t *testing.T) {
|
|
config := &configuration.Configuration{}
|
|
protos := nextProtos(config)
|
|
if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
|
|
t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
|
|
}
|
|
config.HTTP.HTTP2.Disabled = false
|
|
protos = nextProtos(config)
|
|
if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
|
|
t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
|
|
}
|
|
config.HTTP.HTTP2.Disabled = true
|
|
protos = nextProtos(config)
|
|
if !reflect.DeepEqual(protos, []string{"http/1.1"}) {
|
|
t.Fatalf("expected protos to equal [http/1.1], got %s", protos)
|
|
}
|
|
}
|
|
|
|
// registryTLSConfig carries the TLS settings used to start a test registry.
type registryTLSConfig struct {
	// cipherSuites holds the cipher suite names copied into the registry's
	// HTTP TLS configuration.
	cipherSuites []string
	// certificatePath and privateKeyPath are the on-disk locations of the
	// certificate and private key handed to the registry's HTTP TLS
	// configuration.
	certificatePath, privateKeyPath string
	// certificate is the in-memory form of the same key pair.
	certificate *tls.Certificate
}
|
|
|
|
func setupRegistry(tlsCfg *registryTLSConfig, addr string) (*Registry, error) {
|
|
config := &configuration.Configuration{}
|
|
// TODO: this needs to change to something ephemeral as the test will fail if there is any server
|
|
// already listening on port 5000
|
|
config.HTTP.Addr = addr
|
|
config.HTTP.DrainTimeout = time.Duration(10) * time.Second
|
|
if tlsCfg != nil {
|
|
config.HTTP.TLS.CipherSuites = tlsCfg.cipherSuites
|
|
config.HTTP.TLS.Certificate = tlsCfg.certificatePath
|
|
config.HTTP.TLS.Key = tlsCfg.privateKeyPath
|
|
}
|
|
config.Storage = map[string]configuration.Parameters{"inmemory": map[string]interface{}{}}
|
|
return NewRegistry(context.Background(), config)
|
|
}
|
|
|
|
func TestGracefulShutdown(t *testing.T) {
|
|
registry, err := setupRegistry(nil, ":5000")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// run registry server
|
|
var errchan chan error
|
|
go func() {
|
|
errchan <- registry.ListenAndServe()
|
|
}()
|
|
select {
|
|
case err = <-errchan:
|
|
t.Fatalf("Error listening: %v", err)
|
|
default:
|
|
}
|
|
|
|
// Wait for some unknown random time for server to start listening
|
|
time.Sleep(3 * time.Second)
|
|
|
|
// send incomplete request
|
|
conn, err := net.Dial("tcp", "localhost:5000")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
fmt.Fprintf(conn, "GET /v2/ ")
|
|
|
|
// send stop signal
|
|
quit <- os.Interrupt
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// try connecting again. it shouldn't
|
|
_, err = net.Dial("tcp", "localhost:5000")
|
|
if err == nil {
|
|
t.Fatal("Managed to connect after stopping.")
|
|
}
|
|
|
|
// make sure earlier request is not disconnected and response can be received
|
|
fmt.Fprintf(conn, "HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
|
|
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if resp.Status != "200 OK" {
|
|
t.Error("response status is not 200 OK: ", resp.Status)
|
|
}
|
|
if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
|
|
t.Error("Body is not {}; ", string(body))
|
|
}
|
|
}
|
|
|
|
func TestGetCipherSuite(t *testing.T) {
|
|
resp, err := getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA"})
|
|
if err != nil || len(resp) != 1 || resp[0] != tls.TLS_RSA_WITH_AES_128_CBC_SHA {
|
|
t.Errorf("expected cipher suite %q, got %q",
|
|
"TLS_RSA_WITH_AES_128_CBC_SHA",
|
|
strings.Join(getCipherSuiteNames(resp), ","),
|
|
)
|
|
}
|
|
|
|
resp, err = getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_AES_128_GCM_SHA256"})
|
|
if err != nil || len(resp) != 2 ||
|
|
resp[0] != tls.TLS_RSA_WITH_AES_128_CBC_SHA || resp[1] != tls.TLS_AES_128_GCM_SHA256 {
|
|
t.Errorf("expected cipher suites %q, got %q",
|
|
"TLS_RSA_WITH_AES_128_CBC_SHA,TLS_AES_128_GCM_SHA256",
|
|
strings.Join(getCipherSuiteNames(resp), ","),
|
|
)
|
|
}
|
|
|
|
_, err = getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA", "bad_input"})
|
|
if err == nil {
|
|
t.Error("did not return expected error about unknown cipher suite")
|
|
}
|
|
}
|
|
|
|
func buildRegistryTLSConfig(name, keyType string, cipherSuites []string) (*registryTLSConfig, error) {
|
|
var priv interface{}
|
|
var pub crypto.PublicKey
|
|
var err error
|
|
switch keyType {
|
|
case "rsa":
|
|
priv, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create rsa private key: %v", err)
|
|
}
|
|
rsaKey := priv.(*rsa.PrivateKey)
|
|
pub = rsaKey.Public()
|
|
case "ecdsa":
|
|
priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create ecdsa private key: %v", err)
|
|
}
|
|
ecdsaKey := priv.(*ecdsa.PrivateKey)
|
|
pub = ecdsaKey.Public()
|
|
default:
|
|
return nil, fmt.Errorf("unsupported key type: %v", keyType)
|
|
}
|
|
|
|
notBefore := time.Now()
|
|
notAfter := notBefore.Add(time.Minute)
|
|
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
|
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create serial number: %v", err)
|
|
}
|
|
cert := x509.Certificate{
|
|
SerialNumber: serialNumber,
|
|
Subject: pkix.Name{
|
|
Organization: []string{"registry_test"},
|
|
},
|
|
NotBefore: notBefore,
|
|
NotAfter: notAfter,
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
|
DNSNames: []string{"localhost"},
|
|
IsCA: true,
|
|
}
|
|
derBytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, pub, priv)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create certificate: %v", err)
|
|
}
|
|
if _, err := os.Stat(os.TempDir()); os.IsNotExist(err) {
|
|
os.Mkdir(os.TempDir(), 1777)
|
|
}
|
|
|
|
certPath := path.Join(os.TempDir(), name+".pem")
|
|
certOut, err := os.Create(certPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create pem: %v", err)
|
|
}
|
|
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
|
return nil, fmt.Errorf("failed to write data to %s: %v", certPath, err)
|
|
}
|
|
if err := certOut.Close(); err != nil {
|
|
return nil, fmt.Errorf("error closing %s: %v", certPath, err)
|
|
}
|
|
|
|
keyPath := path.Join(os.TempDir(), name+".key")
|
|
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
|
}
|
|
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to marshal private key: %v", err)
|
|
}
|
|
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
|
|
return nil, fmt.Errorf("failed to write data to key.pem: %v", err)
|
|
}
|
|
if err := keyOut.Close(); err != nil {
|
|
return nil, fmt.Errorf("error closing %s: %v", keyPath, err)
|
|
}
|
|
|
|
tlsCert := tls.Certificate{
|
|
Certificate: [][]byte{derBytes},
|
|
PrivateKey: priv,
|
|
}
|
|
|
|
tlsTestCfg := registryTLSConfig{
|
|
cipherSuites: cipherSuites,
|
|
certificatePath: certPath,
|
|
privateKeyPath: keyPath,
|
|
certificate: &tlsCert,
|
|
}
|
|
|
|
return &tlsTestCfg, nil
|
|
}
|
|
|
|
func TestRegistrySupportedCipherSuite(t *testing.T) {
|
|
name := "registry_test_server_supported_cipher"
|
|
cipherSuites := []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}
|
|
serverTLS, err := buildRegistryTLSConfig(name, "rsa", cipherSuites)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
registry, err := setupRegistry(serverTLS, ":5001")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// run registry server
|
|
var errchan chan error
|
|
go func() {
|
|
errchan <- registry.ListenAndServe()
|
|
}()
|
|
select {
|
|
case err = <-errchan:
|
|
t.Fatalf("Error listening: %v", err)
|
|
default:
|
|
}
|
|
|
|
// Wait for some unknown random time for server to start listening
|
|
time.Sleep(3 * time.Second)
|
|
|
|
// send tls request with server supported cipher suite
|
|
clientCipherSuites, err := getCipherSuites(cipherSuites)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
clientTLS := tls.Config{
|
|
InsecureSkipVerify: true,
|
|
CipherSuites: clientCipherSuites,
|
|
}
|
|
dialer := net.Dialer{
|
|
Timeout: time.Second * 5,
|
|
}
|
|
conn, err := tls.DialWithDialer(&dialer, "tcp", "127.0.0.1:5001", &clientTLS)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
fmt.Fprintf(conn, "GET /v2/ HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
|
|
|
|
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if resp.Status != "200 OK" {
|
|
t.Error("response status is not 200 OK: ", resp.Status)
|
|
}
|
|
if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
|
|
t.Error("Body is not {}; ", string(body))
|
|
}
|
|
|
|
// send stop signal
|
|
quit <- os.Interrupt
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
|
|
func TestRegistryUnsupportedCipherSuite(t *testing.T) {
|
|
name := "registry_test_server_unsupported_cipher"
|
|
serverTLS, err := buildRegistryTLSConfig(name, "rsa", []string{"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA358"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
registry, err := setupRegistry(serverTLS, ":5002")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// run registry server
|
|
var errchan chan error
|
|
go func() {
|
|
errchan <- registry.ListenAndServe()
|
|
}()
|
|
select {
|
|
case err = <-errchan:
|
|
t.Fatalf("Error listening: %v", err)
|
|
default:
|
|
}
|
|
|
|
// Wait for some unknown random time for server to start listening
|
|
time.Sleep(3 * time.Second)
|
|
|
|
// send tls request with server unsupported cipher suite
|
|
clientTLS := tls.Config{
|
|
InsecureSkipVerify: true,
|
|
CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
}
|
|
dialer := net.Dialer{
|
|
Timeout: time.Second * 5,
|
|
}
|
|
_, err = tls.DialWithDialer(&dialer, "tcp", "127.0.0.1:5002", &clientTLS)
|
|
if err == nil {
|
|
t.Error("expected TLS connection to timeout")
|
|
}
|
|
|
|
// send stop signal
|
|
quit <- os.Interrupt
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
|
|
func TestConfigureLogging(t *testing.T) {
|
|
yamlConfig := `---
|
|
log:
|
|
level: warn
|
|
fields:
|
|
foo: bar
|
|
baz: xyzzy
|
|
`
|
|
|
|
var config configuration.Configuration
|
|
err := yaml.Unmarshal([]byte(yamlConfig), &config)
|
|
if err != nil {
|
|
t.Fatal("failed to parse config: ", err)
|
|
}
|
|
|
|
ctx, err := configureLogging(context.Background(), &config)
|
|
if err != nil {
|
|
t.Fatal("failed to configure logging: ", err)
|
|
}
|
|
|
|
// Check that the log level was set to Warn.
|
|
if logrus.IsLevelEnabled(logrus.InfoLevel) {
|
|
t.Error("expected Info to be disabled, is enabled")
|
|
}
|
|
|
|
// Check that the returned context's logger includes the right fields.
|
|
logger := dcontext.GetLogger(ctx)
|
|
entry, ok := logger.(*logrus.Entry)
|
|
if !ok {
|
|
t.Fatalf("expected logger to be a *logrus.Entry, is: %T", entry)
|
|
}
|
|
val, ok := entry.Data["foo"].(string)
|
|
if !ok || val != "bar" {
|
|
t.Error("field foo not configured correctly; expected 'bar' got: ", val)
|
|
}
|
|
val, ok = entry.Data["baz"].(string)
|
|
if !ok || val != "xyzzy" {
|
|
t.Error("field baz not configured correctly; expected 'xyzzy' got: ", val)
|
|
}
|
|
|
|
// Get a logger for a new, empty context and make sure it also has the right fields.
|
|
logger = dcontext.GetLogger(context.Background())
|
|
entry, ok = logger.(*logrus.Entry)
|
|
if !ok {
|
|
t.Fatalf("expected logger to be a *logrus.Entry, is: %T", entry)
|
|
}
|
|
val, ok = entry.Data["foo"].(string)
|
|
if !ok || val != "bar" {
|
|
t.Error("field foo not configured correctly; expected 'bar' got: ", val)
|
|
}
|
|
val, ok = entry.Data["baz"].(string)
|
|
if !ok || val != "xyzzy" {
|
|
t.Error("field baz not configured correctly; expected 'xyzzy' got: ", val)
|
|
}
|
|
}
|