diff --git a/go.mod b/go.mod index ceefe21be..39c6277c5 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/dop251/scsu v0.0.0-20200422003335-8fadfb689669 github.com/dropbox/dropbox-sdk-go-unofficial v1.0.1-0.20210114204226-41fdcdae8a53 github.com/gabriel-vasile/mimetype v1.2.0 + github.com/go-chi/chi/v5 v5.0.2 github.com/go-ole/go-ole v1.2.5 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/go-querystring v1.1.0 // indirect diff --git a/go.sum b/go.sum index e9dc313fc..84bdc7e9a 100644 --- a/go.sum +++ b/go.sum @@ -207,6 +207,8 @@ github.com/gabriel-vasile/mimetype v1.2.0/go.mod h1:6CDPel/o/3/s4+bp6kIbsWATq8pm github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE= github.com/glycerine/goconvey v0.0.0-20180728074245-46e3a41ad493/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24= +github.com/go-chi/chi/v5 v5.0.2 h1:4xKeALZdMEsuI5s05PU2Bm89Uc5iM04qFubUCl5LfAQ= +github.com/go-chi/chi/v5 v5.0.2/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= diff --git a/lib/http/http.go b/lib/http/http.go new file mode 100644 index 000000000..25e5ff82d --- /dev/null +++ b/lib/http/http.go @@ -0,0 +1,382 @@ +// Package http provides a registration interface for http services +package http + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/go-chi/chi/v5" + "github.com/pkg/errors" + "github.com/rclone/rclone/fs/config/flags" + "github.com/spf13/pflag" +) + +// Help contains text describing the http server to add to the command +// help. +var Help = ` +### Server options + +Use --addr to specify which IP address and port the server should +listen on, eg --addr 1.2.3.4:8000 or --addr :8080 to listen to all +IPs. By default it only listens on localhost. You can use port +:0 to let the OS choose an available port. + +If you set --addr to listen on a public or LAN accessible IP address +then using Authentication is advised - see the next section for info. + +--server-read-timeout and --server-write-timeout can be used to +control the timeouts on the server. Note that this is the total time +for a transfer. + +--max-header-bytes controls the maximum number of bytes the server will +accept in the HTTP header. + +--baseurl controls the URL prefix that rclone serves from. By default +rclone will serve from the root. If you used --baseurl "/rclone" then +rclone would serve from a URL starting with "/rclone/". This is +useful if you wish to proxy rclone serve. Rclone automatically +inserts leading and trailing "/" on --baseurl, so --baseurl "rclone", +--baseurl "/rclone" and --baseurl "/rclone/" are all treated +identically. + +#### SSL/TLS + +By default this will serve over http. If you want you can serve over +https. You will need to supply the --cert and --key flags. If you +wish to do client side certificate validation then you will need to +supply --client-ca also. + +--cert should be a either a PEM encoded certificate or a concatenation +of that with the CA certificate. --key should be the PEM encoded +private key and --client-ca should be the PEM encoded client +certificate authority certificate. +` + +// Middleware function signature required by chi.Router.Use() +type Middleware func(http.Handler) http.Handler + +// Options contains options for the http Server +type Options struct { + ListenAddr string // Port to listen on + BaseURL string // prefix to strip from URLs + ServerReadTimeout time.Duration // Timeout for server reading data + ServerWriteTimeout time.Duration // Timeout for server writing data + MaxHeaderBytes int // Maximum size of request header + SslCert string // SSL PEM key (concatenation of certificate and CA certificate) + SslKey string // SSL PEM Private key + ClientCA string // Client certificate authority to verify clients with +} + +// DefaultOpt is the default values used for Options +var DefaultOpt = Options{ + ListenAddr: "127.0.0.1:8080", + ServerReadTimeout: 1 * time.Hour, + ServerWriteTimeout: 1 * time.Hour, + MaxHeaderBytes: 4096, +} + +// Server interface of http server +type Server interface { + Router() chi.Router + Route(pattern string, fn func(r chi.Router)) chi.Router + Mount(pattern string, h http.Handler) + Shutdown() error +} + +type server struct { + addrs []net.Addr + tlsAddrs []net.Addr + listeners []net.Listener + tlsListeners []net.Listener + httpServer *http.Server + baseRouter chi.Router + closing *sync.WaitGroup + useSSL bool +} + +var ( + defaultServer *server + defaultServerOptions = DefaultOpt + defaultServerMutex sync.Mutex +) + +func useSSL(opt Options) bool { + return opt.SslKey != "" +} + +// NewServer instantiates a new http server using provided listeners and options +// This function is provided if the default http server does not meet a services requirements and should not generally be used +// A http server can listen using multiple listeners. For example, a listener for port 80, and a listener for port 443. +// tlsListeners are ignored if opt.SslKey is not provided +func NewServer(listeners, tlsListeners []net.Listener, opt Options) (Server, error) { + // Validate input + if len(listeners) == 0 && len(tlsListeners) == 0 { + return nil, errors.New("Can't create server without listeners") + } + + // Prepare TLS config + var tlsConfig *tls.Config = nil + + useSSL := useSSL(opt) + if (opt.SslCert != "") != useSSL { + err := errors.New("Need both -cert and -key to use SSL") + log.Fatalf(err.Error()) + return nil, err + } + + if useSSL { + tlsConfig = &tls.Config{ + MinVersion: tls.VersionTLS10, // disable SSL v3.0 and earlier + } + } else if len(listeners) == 0 && len(tlsListeners) != 0 { + return nil, errors.New("No SslKey or non-tlsListeners") + } + + if opt.ClientCA != "" { + if !useSSL { + err := errors.New("Can't use --client-ca without --cert and --key") + log.Fatalf(err.Error()) + return nil, err + } + certpool := x509.NewCertPool() + pem, err := ioutil.ReadFile(opt.ClientCA) + if err != nil { + log.Fatalf("Failed to read client certificate authority: %v", err) + return nil, err + } + if !certpool.AppendCertsFromPEM(pem) { + err := errors.New("Can't parse client certificate authority") + log.Fatalf(err.Error()) + return nil, err + } + tlsConfig.ClientCAs = certpool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + // Ignore passing "/" for BaseURL + opt.BaseURL = strings.Trim(opt.BaseURL, "/") + if opt.BaseURL != "" { + opt.BaseURL = "/" + opt.BaseURL + } + + // Build base router + var router chi.Router = chi.NewRouter() + router.MethodNotAllowed(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + }) + + handler := router.(http.Handler) + if opt.BaseURL != "" { + handler = http.StripPrefix(opt.BaseURL, handler) + } + + // Serve on listeners + httpServer := &http.Server{ + Handler: handler, + ReadTimeout: opt.ServerReadTimeout, + WriteTimeout: opt.ServerWriteTimeout, + MaxHeaderBytes: opt.MaxHeaderBytes, + ReadHeaderTimeout: 10 * time.Second, // time to send the headers + IdleTimeout: 60 * time.Second, // time to keep idle connections open + TLSConfig: tlsConfig, + } + + addrs, tlsAddrs := make([]net.Addr, len(listeners)), make([]net.Addr, len(tlsListeners)) + + wg := &sync.WaitGroup{} + + for i, l := range listeners { + addrs[i] = l.Addr() + } + + if useSSL { + for i, l := range tlsListeners { + tlsAddrs[i] = l.Addr() + } + } + + return &server{addrs, tlsAddrs, listeners, tlsListeners, httpServer, router, wg, useSSL}, nil +} + +func (s *server) Serve() { + serve := func(l net.Listener) { + defer s.closing.Done() + if err := s.httpServer.Serve(l); err != http.ErrServerClosed && err != nil { + log.Fatalf(err.Error()) + } + } + + s.closing.Add(len(s.listeners)) + for _, l := range s.listeners { + go serve(l) + } + + if s.useSSL { + s.closing.Add(len(s.tlsListeners)) + for _, l := range s.tlsListeners { + go serve(l) + } + } +} + +// Router returns the server base router +func (s *server) Router() chi.Router { + return s.baseRouter +} + +// Route mounts a sub-Router along a `pattern` string. +func (s *server) Route(pattern string, fn func(r chi.Router)) chi.Router { + return s.baseRouter.Route(pattern, fn) +} + +// Mount attaches another http.Handler along ./pattern/* +func (s *server) Mount(pattern string, h http.Handler) { + s.baseRouter.Mount(pattern, h) +} + +// Shutdown gracefully shuts down the server +func (s *server) Shutdown() error { + if err := s.httpServer.Shutdown(context.Background()); err != nil { + return err + } + s.closing.Wait() + return nil +} + +//---- Default HTTP server convenience functions ---- + +// Router returns the server base router +func Router() (chi.Router, error) { + if err := start(); err != nil { + return nil, err + } + return defaultServer.baseRouter, nil +} + +// Route mounts a sub-Router along a `pattern` string. +func Route(pattern string, fn func(r chi.Router)) (chi.Router, error) { + if err := start(); err != nil { + return nil, err + } + return defaultServer.Route(pattern, fn), nil +} + +// Mount attaches another http.Handler along ./pattern/* +func Mount(pattern string, h http.Handler) error { + if err := start(); err != nil { + return err + } + defaultServer.Mount(pattern, h) + return nil +} + +// Restart or start the default http server using the default options and no handlers +func Restart() error { + if e := Shutdown(); e != nil { + return e + } + + return start() +} + +// Start the default server +func start() error { + defaultServerMutex.Lock() + defer defaultServerMutex.Unlock() + + if defaultServer != nil { + // Server already started, do nothing + return nil + } + + var err error + var l net.Listener + l, err = net.Listen("tcp", defaultServerOptions.ListenAddr) + if err != nil { + return err + } + + var s Server + if useSSL(defaultServerOptions) { + s, err = NewServer([]net.Listener{}, []net.Listener{l}, defaultServerOptions) + } else { + s, err = NewServer([]net.Listener{l}, []net.Listener{}, defaultServerOptions) + } + if err != nil { + return err + } + defaultServer = s.(*server) + defaultServer.Serve() + return nil +} + +// Shutdown gracefully shuts down the default http server +func Shutdown() error { + defaultServerMutex.Lock() + defer defaultServerMutex.Unlock() + if defaultServer != nil { + s := defaultServer + defaultServer = nil + return s.Shutdown() + } + return nil +} + +// GetOptions thread safe getter for the default server options +func GetOptions() Options { + defaultServerMutex.Lock() + defer defaultServerMutex.Unlock() + return defaultServerOptions +} + +// SetOptions thread safe setter for the default server options +func SetOptions(opt Options) { + defaultServerMutex.Lock() + defer defaultServerMutex.Unlock() + defaultServerOptions = opt +} + +//---- Utility functions ---- + +// URL of default http server +func URL() string { + if defaultServer == nil { + panic("Server not running") + } + for _, a := range defaultServer.addrs { + return fmt.Sprintf("http://%s%s/", a.String(), defaultServerOptions.BaseURL) + } + for _, a := range defaultServer.tlsAddrs { + return fmt.Sprintf("https://%s%s/", a.String(), defaultServerOptions.BaseURL) + } + panic("Server is running with no listener") +} + +//---- Command line flags ---- + +// AddFlagsPrefix adds flags for the httplib +func AddFlagsPrefix(flagSet *pflag.FlagSet, prefix string, Opt *Options) { + flags.StringVarP(flagSet, &Opt.ListenAddr, prefix+"addr", "", Opt.ListenAddr, "IPaddress:Port or :Port to bind server to.") + flags.DurationVarP(flagSet, &Opt.ServerReadTimeout, prefix+"server-read-timeout", "", Opt.ServerReadTimeout, "Timeout for server reading data") + flags.DurationVarP(flagSet, &Opt.ServerWriteTimeout, prefix+"server-write-timeout", "", Opt.ServerWriteTimeout, "Timeout for server writing data") + flags.IntVarP(flagSet, &Opt.MaxHeaderBytes, prefix+"max-header-bytes", "", Opt.MaxHeaderBytes, "Maximum size of request header") + flags.StringVarP(flagSet, &Opt.SslCert, prefix+"cert", "", Opt.SslCert, "SSL PEM key (concatenation of certificate and CA certificate)") + flags.StringVarP(flagSet, &Opt.SslKey, prefix+"key", "", Opt.SslKey, "SSL PEM Private key") + flags.StringVarP(flagSet, &Opt.ClientCA, prefix+"client-ca", "", Opt.ClientCA, "Client certificate authority to verify clients with") + flags.StringVarP(flagSet, &Opt.BaseURL, prefix+"baseurl", "", Opt.BaseURL, "Prefix for URLs - leave blank for root.") + +} + +// AddFlags adds flags for the httplib +func AddFlags(flagSet *pflag.FlagSet) { + AddFlagsPrefix(flagSet, "", &defaultServerOptions) +} diff --git a/lib/http/http_test.go b/lib/http/http_test.go new file mode 100644 index 000000000..a499d5658 --- /dev/null +++ b/lib/http/http_test.go @@ -0,0 +1,438 @@ +package http + +import ( + "net" + "net/http" + "reflect" + "testing" + + "golang.org/x/net/nettest" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetOptions(t *testing.T) { + tests := []struct { + name string + want Options + }{ + {name: "basic", want: defaultServerOptions}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetOptions(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetOptions() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMount(t *testing.T) { + type args struct { + pattern string + h http.Handler + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "basic", args: args{ + pattern: "/", + h: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}), + }, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.wantErr { + require.Error(t, Mount(tt.args.pattern, tt.args.h)) + } else { + require.NoError(t, Mount(tt.args.pattern, tt.args.h)) + } + assert.NotNil(t, defaultServer) + assert.True(t, defaultServer.baseRouter.Match(chi.NewRouteContext(), "GET", tt.args.pattern), "Failed to match route after registering") + }) + if err := Shutdown(); err != nil { + t.Fatal(err) + } + } +} + +func TestNewServer(t *testing.T) { + type args struct { + listeners []net.Listener + tlsListeners []net.Listener + opt Options + } + listener, err := nettest.NewLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "default http", args: args{ + listeners: []net.Listener{listener}, + tlsListeners: []net.Listener{}, + opt: defaultServerOptions, + }, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewServer(tt.args.listeners, tt.args.tlsListeners, tt.args.opt) + if (err != nil) != tt.wantErr { + t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr) + return + } + s, ok := got.(*server) + require.True(t, ok, "NewServer returned unexpected type") + if len(tt.args.listeners) > 0 { + assert.Equal(t, listener.Addr(), s.addrs[0]) + } else { + assert.Empty(t, s.addrs) + } + if len(tt.args.tlsListeners) > 0 { + assert.Equal(t, listener.Addr(), s.tlsAddrs[0]) + } else { + assert.Empty(t, s.tlsAddrs) + } + if tt.args.opt.BaseURL != "" { + assert.NotSame(t, s.baseRouter, s.httpServer.Handler, "should have wrapped baseRouter") + } else { + assert.Same(t, s.baseRouter, s.httpServer.Handler, "should be baseRouter") + } + if useSSL(tt.args.opt) { + assert.NotNil(t, s.httpServer.TLSConfig, "missing SSL config") + } else { + assert.Nil(t, s.httpServer.TLSConfig, "unexpectedly has SSL config") + } + }) + } +} + +func TestRestart(t *testing.T) { + tests := []struct { + name string + started bool + wantErr bool + }{ + {name: "started", started: true, wantErr: false}, + {name: "stopped", started: false, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.started { + require.NoError(t, Restart()) // Call it twice basically + } else { + require.NoError(t, Shutdown()) + } + current := defaultServer + if err := Restart(); (err != nil) != tt.wantErr { + t.Errorf("Restart() error = %v, wantErr %v", err, tt.wantErr) + } + assert.NotNil(t, defaultServer, "failed to start default server") + assert.NotSame(t, current, defaultServer, "same server instance as before restart") + }) + } +} + +func TestRoute(t *testing.T) { + type args struct { + pattern string + fn func(r chi.Router) + } + tests := []struct { + name string + args args + test func(t *testing.T, r chi.Router) + }{ + { + name: "basic", + args: args{ + pattern: "/basic", + fn: func(r chi.Router) {}, + }, + test: func(t *testing.T, r chi.Router) { + require.Len(t, r.Routes(), 1) + assert.Equal(t, r.Routes()[0].Pattern, "/basic/*") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.NoError(t, Restart()) + _, err := Route(tt.args.pattern, tt.args.fn) + require.NoError(t, err) + tt.test(t, defaultServer.baseRouter) + }) + + if err := Shutdown(); err != nil { + t.Fatal(err) + } + } +} + +func TestSetOptions(t *testing.T) { + type args struct { + opt Options + } + tests := []struct { + name string + args args + }{ + { + name: "basic", + args: args{opt: Options{ + ListenAddr: "127.0.0.1:9999", + BaseURL: "/basic", + ServerReadTimeout: 1, + ServerWriteTimeout: 1, + MaxHeaderBytes: 1, + }}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + SetOptions(tt.args.opt) + require.Equal(t, tt.args.opt, defaultServerOptions) + require.NoError(t, Restart()) + if useSSL(tt.args.opt) { + assert.Equal(t, tt.args.opt.ListenAddr, defaultServer.tlsAddrs[0].String()) + } else { + assert.Equal(t, tt.args.opt.ListenAddr, defaultServer.addrs[0].String()) + } + assert.Equal(t, tt.args.opt.ServerReadTimeout, defaultServer.httpServer.ReadTimeout) + assert.Equal(t, tt.args.opt.ServerWriteTimeout, defaultServer.httpServer.WriteTimeout) + assert.Equal(t, tt.args.opt.MaxHeaderBytes, defaultServer.httpServer.MaxHeaderBytes) + if tt.args.opt.BaseURL != "" && tt.args.opt.BaseURL != "/" { + assert.NotSame(t, defaultServer.httpServer.Handler, defaultServer.baseRouter, "BaseURL ignored") + } + }) + SetOptions(DefaultOpt) + } +} + +func TestShutdown(t *testing.T) { + tests := []struct { + name string + started bool + wantErr bool + }{ + {name: "started", started: true, wantErr: false}, + {name: "stopped", started: false, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.started { + require.NoError(t, Restart()) + } else { + require.NoError(t, Shutdown()) // Call it twice basically + } + if err := Shutdown(); (err != nil) != tt.wantErr { + t.Errorf("Shutdown() error = %v, wantErr %v", err, tt.wantErr) + } + assert.Nil(t, defaultServer, "default server not deleted") + }) + } +} + +func TestURL(t *testing.T) { + tests := []struct { + name string + want string + }{ + {name: "basic", want: "http://127.0.0.1:8080/"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.NoError(t, Restart()) + if got := URL(); got != tt.want { + t.Errorf("URL() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_server_Mount(t *testing.T) { + type args struct { + pattern string + h http.Handler + } + tests := []struct { + name string + args args + opt Options + }{ + {name: "basic", args: args{ + pattern: "/", + h: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}), + }, opt: defaultServerOptions}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + listener, err := nettest.NewLocalListener("tcp") + require.NoError(t, err) + s, err2 := NewServer([]net.Listener{listener}, []net.Listener{}, tt.opt) + require.NoError(t, err2) + s.Mount(tt.args.pattern, tt.args.h) + srv, ok := s.(*server) + require.True(t, ok) + assert.NotNil(t, srv) + assert.True(t, srv.baseRouter.Match(chi.NewRouteContext(), "GET", tt.args.pattern), "Failed to Match() route after registering") + }) + } +} + +func Test_server_Route(t *testing.T) { + type args struct { + pattern string + fn func(r chi.Router) + } + tests := []struct { + name string + args args + opt Options + test func(t *testing.T, r chi.Router) + }{ + { + name: "basic", + args: args{ + pattern: "/basic", + fn: func(r chi.Router) { + + }, + }, + test: func(t *testing.T, r chi.Router) { + require.Len(t, r.Routes(), 1) + assert.Equal(t, r.Routes()[0].Pattern, "/basic/*") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + listener, err := nettest.NewLocalListener("tcp") + require.NoError(t, err) + s, err2 := NewServer([]net.Listener{listener}, []net.Listener{}, tt.opt) + require.NoError(t, err2) + s.Route(tt.args.pattern, tt.args.fn) + srv, ok := s.(*server) + require.True(t, ok) + assert.NotNil(t, srv) + tt.test(t, srv.baseRouter) + }) + } +} + +func Test_server_Shutdown(t *testing.T) { + tests := []struct { + name string + opt Options + wantErr bool + }{ + { + name: "basic", + opt: defaultServerOptions, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + listener, err := nettest.NewLocalListener("tcp") + require.NoError(t, err) + s, err2 := NewServer([]net.Listener{listener}, []net.Listener{}, tt.opt) + require.NoError(t, err2) + srv, ok := s.(*server) + require.True(t, ok) + if err := s.Shutdown(); (err != nil) != tt.wantErr { + t.Errorf("Shutdown() error = %v, wantErr %v", err, tt.wantErr) + } + assert.EqualError(t, srv.httpServer.Serve(listener), http.ErrServerClosed.Error()) + }) + } +} + +func Test_start(t *testing.T) { + tests := []struct { + name string + opt Options + wantErr bool + }{ + { + name: "basic", + opt: defaultServerOptions, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + SetOptions(tt.opt) + if err := start(); (err != nil) != tt.wantErr { + t.Errorf("start() error = %v, wantErr %v", err, tt.wantErr) + return + } + s := defaultServer + if useSSL(tt.opt) { + assert.Equal(t, tt.opt.ListenAddr, s.tlsAddrs[0].String()) + } else { + assert.Equal(t, tt.opt.ListenAddr, s.addrs[0].String()) + } + /* accessing s.httpServer.* can't be done synchronously and is a race condition + assert.Equal(t, tt.opt.ServerReadTimeout, defaultServer.httpServer.ReadTimeout) + assert.Equal(t, tt.opt.ServerWriteTimeout, defaultServer.httpServer.WriteTimeout) + assert.Equal(t, tt.opt.MaxHeaderBytes, defaultServer.httpServer.MaxHeaderBytes) + if tt.opt.BaseURL != "" && tt.opt.BaseURL != "/" { + assert.NotSame(t, s.baseRouter, s.httpServer.Handler, "should have wrapped baseRouter") + } else { + assert.Same(t, s.baseRouter, s.httpServer.Handler, "should be baseRouter") + } + if useSSL(tt.opt) { + require.NotNil(t, s.httpServer.TLSConfig, "missing SSL config") + assert.NotEmpty(t, s.httpServer.TLSConfig.Certificates, "missing SSL config") + } else if s.httpServer.TLSConfig != nil { + assert.Empty(t, s.httpServer.TLSConfig.Certificates, "unexpectedly has SSL config") + } + */ + }) + } +} + +func Test_useSSL(t *testing.T) { + type args struct { + opt Options + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "basic", + args: args{opt: Options{ + SslCert: "", + SslKey: "", + ClientCA: "", + }}, + want: false, + }, + { + name: "basic", + args: args{opt: Options{ + SslCert: "", + SslKey: "test", + ClientCA: "", + }}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := useSSL(tt.args.opt); got != tt.want { + t.Errorf("useSSL() = %v, want %v", got, tt.want) + } + }) + } +}