package main import ( "context" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "fmt" "math/big" "net" "net/http" "os" "path" "testing" "time" "github.com/stretchr/testify/require" "golang.org/x/net/http2" ) const ( expHeaderKey = "Foo" expHeaderValue = "Bar" ) func TestHTTP2TLS(t *testing.T) { ctx := context.Background() certPath, keyPath := prepareTestCerts(t) srv := &http.Server{ Handler: http.HandlerFunc(testHandler), } tlsListener, err := newServer(ctx, ServerInfo{ Address: ":0", TLS: ServerTLSInfo{ Enabled: true, CertFile: certPath, KeyFile: keyPath, }, }) require.NoError(t, err) port := tlsListener.Listener().Addr().(*net.TCPAddr).Port addr := fmt.Sprintf("https://localhost:%d", port) go func() { _ = srv.Serve(tlsListener.Listener()) }() // Server is running, now send HTTP/2 request tlsClientConfig := &tls.Config{ InsecureSkipVerify: true, } cliHTTP1 := http.Client{Transport: &http.Transport{TLSClientConfig: tlsClientConfig}} cliHTTP2 := http.Client{Transport: &http2.Transport{TLSClientConfig: tlsClientConfig}} req, err := http.NewRequest("GET", addr, nil) require.NoError(t, err) req.Header[expHeaderKey] = []string{expHeaderValue} resp, err := cliHTTP1.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) resp, err = cliHTTP2.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) } func testHandler(resp http.ResponseWriter, req *http.Request) { hdr, ok := req.Header[expHeaderKey] if !ok || len(hdr) != 1 || hdr[0] != expHeaderValue { resp.WriteHeader(http.StatusBadRequest) } else { resp.WriteHeader(http.StatusOK) } } func prepareTestCerts(t *testing.T) (certPath, keyPath string) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: "localhost"}, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour * 24 * 365), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) require.NoError(t, err) dir := t.TempDir() certPath = path.Join(dir, "cert.pem") keyPath = path.Join(dir, "key.pem") certFile, err := os.Create(certPath) require.NoError(t, err) defer certFile.Close() keyFile, err := os.Create(keyPath) require.NoError(t, err) defer keyFile.Close() err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) require.NoError(t, err) err = pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) require.NoError(t, err) return certPath, keyPath }