diff --git a/cmd/serve/sftp/server.go b/cmd/serve/sftp/server.go index a617cf2ad..278630851 100644 --- a/cmd/serve/sftp/server.go +++ b/cmd/serve/sftp/server.go @@ -78,6 +78,39 @@ func (s *server) getVFS(what string, sshConn *ssh.ServerConn) (VFS *vfs.VFS) { return VFS } +// Accept a single connection - run in a go routine as the ssh +// authentication can block +func (s *server) acceptConnection(nConn net.Conn) { + what := describeConn(nConn) + + // Before use, a handshake must be performed on the incoming net.Conn. + sshConn, chans, reqs, err := ssh.NewServerConn(nConn, s.config) + if err != nil { + fs.Errorf(what, "SSH login failed: %v", err) + return + } + + fs.Infof(what, "SSH login from %s using %s", sshConn.User(), sshConn.ClientVersion()) + + // Discard all global out-of-band Requests + go ssh.DiscardRequests(reqs) + + c := &conn{ + what: what, + vfs: s.getVFS(what, sshConn), + } + if c.vfs == nil { + fs.Infof(what, "Closing unauthenticated connection (couldn't find VFS)") + _ = nConn.Close() + return + } + c.handlers = newVFSHandler(c.vfs) + + // Accept all channels + go c.handleChannels(chans) +} + +// Accept connections and call them in a go routine func (s *server) acceptConnections() { for { nConn, err := s.listener.Accept() @@ -88,33 +121,7 @@ func (s *server) acceptConnections() { fs.Errorf(nil, "Failed to accept incoming connection: %v", err) continue } - what := describeConn(nConn) - - // Before use, a handshake must be performed on the incoming net.Conn. - sshConn, chans, reqs, err := ssh.NewServerConn(nConn, s.config) - if err != nil { - fs.Errorf(what, "SSH login failed: %v", err) - continue - } - - fs.Infof(what, "SSH login from %s using %s", sshConn.User(), sshConn.ClientVersion()) - - // Discard all global out-of-band Requests - go ssh.DiscardRequests(reqs) - - c := &conn{ - what: what, - vfs: s.getVFS(what, sshConn), - } - if c.vfs == nil { - fs.Infof(what, "Closing unauthenticated connection (couldn't find VFS)") - _ = nConn.Close() - continue - } - c.handlers = newVFSHandler(c.vfs) - - // Accept all channels - go c.handleChannels(chans) + go s.acceptConnection(nConn) } }