diff --git a/cmd/serve/docker/docker.go b/cmd/serve/docker/docker.go index a13ceb490..060e0f2b5 100644 --- a/cmd/serve/docker/docker.go +++ b/cmd/serve/docker/docker.go @@ -4,6 +4,7 @@ package docker import ( "context" _ "embed" + "net" "path/filepath" "strings" "syscall" @@ -13,6 +14,7 @@ import ( "github.com/rclone/rclone/cmd" "github.com/rclone/rclone/cmd/mountlib" "github.com/rclone/rclone/fs/config/flags" + "github.com/rclone/rclone/lib/systemd" "github.com/rclone/rclone/vfs" "github.com/rclone/rclone/vfs/vfsflags" ) @@ -70,15 +72,25 @@ var Command = &cobra.Command{ return err } srv := NewServer(drv) + + var listener net.Listener if socketAddr == "" { // Listen on unix socket at /run/docker/plugins/.sock - return srv.ServeUnix(pluginName, socketGid) - } - if filepath.IsAbs(socketAddr) { + listener, err = srv.ListenUnix(pluginName, socketGid) + } else if filepath.IsAbs(socketAddr) { // Listen on unix socket at given path - return srv.ServeUnix(socketAddr, socketGid) + listener, err = srv.ListenUnix(socketAddr, socketGid) + } else { + listener, err = srv.ListenTCP(socketAddr, "", nil, noSpec) } - return srv.ServeTCP(socketAddr, "", nil, noSpec) + if err != nil { + return err + } + + // notify systemd + defer systemd.Notify()() + + return srv.Serve(listener) }) }, } diff --git a/cmd/serve/docker/docker_test.go b/cmd/serve/docker/docker_test.go index a4b1cd320..0253da360 100644 --- a/cmd/serve/docker/docker_test.go +++ b/cmd/serve/docker/docker_test.go @@ -342,13 +342,16 @@ func testMountAPI(t *testing.T, sockAddr string) { srv := docker.NewServer(drv) go func() { - var errServe error + var listener net.Listener + var err error if unixPath != "" { - errServe = srv.ServeUnix(unixPath, os.Getgid()) + listener, err = srv.ListenUnix(unixPath, os.Getgid()) } else { - errServe = srv.ServeTCP(sockAddr, testDir, nil, false) + listener, err = srv.ListenTCP(sockAddr, testDir, nil, false) } - assert.ErrorIs(t, errServe, http.ErrServerClosed) + assert.NoError(t, err) + err = srv.Serve(listener) + assert.ErrorIs(t, err, http.ErrServerClosed) }() defer func() { err := srv.Shutdown(ctx) diff --git a/cmd/serve/docker/driver.go b/cmd/serve/docker/driver.go index 86885d336..593f53130 100644 --- a/cmd/serve/docker/driver.go +++ b/cmd/serve/docker/driver.go @@ -12,7 +12,6 @@ import ( "sync" "time" - "github.com/coreos/go-systemd/v22/daemon" "github.com/rclone/rclone/cmd/mountlib" "github.com/rclone/rclone/fs" "github.com/rclone/rclone/fs/config" @@ -84,11 +83,6 @@ func NewDriver(ctx context.Context, root string, mntOpt *mountlib.Options, vfsOp drv.exitOnce.Do(drv.Exit) }) - // notify systemd - if _, err := daemon.SdNotify(false, daemon.SdNotifyReady); err != nil { - return nil, fmt.Errorf("failed to notify systemd: %w", err) - } - return drv, nil } @@ -98,10 +92,6 @@ func (drv *Driver) Exit() { drv.mu.Lock() defer drv.mu.Unlock() - reportErr(func() error { - _, err := daemon.SdNotify(false, daemon.SdNotifyStopping) - return err - }()) drv.monChan <- true // ask monitor to exit for _, vol := range drv.volumes { reportErr(vol.unmountAll()) diff --git a/cmd/serve/docker/serve.go b/cmd/serve/docker/serve.go index 9fa6e5642..fbd2ec3d1 100644 --- a/cmd/serve/docker/serve.go +++ b/cmd/serve/docker/serve.go @@ -29,40 +29,38 @@ func (s *Server) Shutdown(ctx context.Context) error { return hs.Shutdown(ctx) } -func (s *Server) serve(listener net.Listener, addr, tempFile string) error { - if tempFile != "" { - atexit.Register(func() { - // remove spec file or self-created unix socket - fs.Debugf(nil, "Removing stale file %s", tempFile) - _ = os.Remove(tempFile) - }) - } +// Serve requests using the listener +func (s *Server) Serve(listener net.Listener) error { hs := (*http.Server)(s) return hs.Serve(listener) } -// ServeUnix makes the handler to listen for requests in a unix socket. +// ListenUnix returns a unix socket listener. // It also creates the socket file in the right directory for docker to read. -func (s *Server) ServeUnix(path string, gid int) error { +func (s *Server) ListenUnix(path string, gid int) (net.Listener, error) { listener, socketPath, err := newUnixListener(path, gid) if err != nil { - return err + return nil, err } if socketPath != "" { - path = socketPath - fs.Infof(nil, "Serving unix socket: %s", path) + fs.Infof(nil, "Listening on unix socket: %s", socketPath) + atexit.Register(func() { + // remove self-created unix socket + fs.Debugf(nil, "Removing stale unix socket file %s", socketPath) + _ = os.Remove(socketPath) + }) } else { - fs.Infof(nil, "Serving systemd socket") + fs.Infof(nil, "Listening on systemd socket") } - return s.serve(listener, path, socketPath) + return listener, nil } -// ServeTCP makes the handler listen for request on a given TCP address. +// ListenTCP returns a TCP listener for the given TCP address. // It also writes the spec file in the right directory for docker to read. -func (s *Server) ServeTCP(addr, specDir string, tlsConfig *tls.Config, noSpec bool) error { +func (s *Server) ListenTCP(addr, specDir string, tlsConfig *tls.Config, noSpec bool) (net.Listener, error) { listener, err := net.Listen("tcp", addr) if err != nil { - return err + return nil, err } if tlsConfig != nil { tlsConfig.NextProtos = []string{"http/1.1"} @@ -73,11 +71,16 @@ func (s *Server) ServeTCP(addr, specDir string, tlsConfig *tls.Config, noSpec bo if !noSpec { specFile, err = writeSpecFile(addr, "tcp", specDir) if err != nil { - return err + return nil, err } + atexit.Register(func() { + // remove spec file + fs.Debugf(nil, "Removing stale spec file %s", specFile) + _ = os.Remove(specFile) + }) } - fs.Infof(nil, "Serving TCP socket: %s", addr) - return s.serve(listener, addr, specFile) + fs.Infof(nil, "Listening on TCP socket: %s", addr) + return listener, nil } func writeSpecFile(addr, proto, specDir string) (string, error) {