package main import ( "context" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "fmt" "math/big" "net" "net/http" "os" "path" "testing" "time" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) const ( expHeaderKey = "Foo" expHeaderValue = "Bar" ) func TestHTTP_TLS(t *testing.T) { ctx := context.Background() certPath, keyPath := prepareTestCerts(t) tlsListener, err := newServer(ctx, ServerInfo{ Address: ":0", TLS: ServerTLSInfo{ Enabled: true, CertFile: certPath, KeyFile: keyPath, }, }) require.NoError(t, err) port := tlsListener.Listener().Addr().(*net.TCPAddr).Port addr := fmt.Sprintf("https://localhost:%d", port) go func() { _ = fasthttp.Serve(tlsListener.Listener(), testHandler) }() tlsClientConfig := &tls.Config{ InsecureSkipVerify: true, } cliHTTP := http.Client{Transport: &http.Transport{}} cliHTTPS := http.Client{Transport: &http.Transport{TLSClientConfig: tlsClientConfig}} req, err := http.NewRequest("GET", addr, nil) require.NoError(t, err) req.Header[expHeaderKey] = []string{expHeaderValue} resp, err := cliHTTPS.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) _, err = cliHTTP.Do(req) require.ErrorContains(t, err, "failed to verify certificate") } func testHandler(ctx *fasthttp.RequestCtx) { hdr := ctx.Request.Header.Peek(expHeaderKey) if len(hdr) == 0 || string(hdr) != expHeaderValue { ctx.Response.SetStatusCode(http.StatusBadRequest) } else { ctx.Response.SetStatusCode(http.StatusOK) } } func prepareTestCerts(t *testing.T) (certPath, keyPath string) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: "localhost"}, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour * 24 * 365), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) require.NoError(t, err) dir := t.TempDir() certPath = path.Join(dir, "cert.pem") keyPath = path.Join(dir, "key.pem") certFile, err := os.Create(certPath) require.NoError(t, err) defer certFile.Close() keyFile, err := os.Create(keyPath) require.NoError(t, err) defer keyFile.Close() err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) require.NoError(t, err) err = pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) require.NoError(t, err) return certPath, keyPath }