diff --git a/internal/backend/sftp/sftp.go b/internal/backend/sftp/sftp.go index 47b6871c1..1f98cf56b 100644 --- a/internal/backend/sftp/sftp.go +++ b/internal/backend/sftp/sftp.go @@ -179,7 +179,12 @@ func (r *SFTP) IsNotExist(err error) bool { func buildSSHCommand(cfg Config) (cmd string, args []string, err error) { if cfg.Command != "" { - return backend.SplitShellArgs(cfg.Command) + args, err := backend.SplitShellStrings(cfg.Command) + if err != nil { + return "", nil, err + } + + return args[0], args[1:], nil } cmd = "ssh" diff --git a/internal/backend/shell_split.go b/internal/backend/shell_split.go index d28ea6034..eff527616 100644 --- a/internal/backend/shell_split.go +++ b/internal/backend/shell_split.go @@ -41,8 +41,8 @@ func (s *shellSplitter) isSplitChar(c rune) bool { return c == '\\' || unicode.IsSpace(c) } -// SplitShellArgs returns the list of arguments from a shell command string. -func SplitShellArgs(data string) (cmd string, args []string, err error) { +// SplitShellStrings returns the list of shell strings from a shell command string. +func SplitShellStrings(data string) (strs []string, err error) { s := &shellSplitter{} // derived from strings.SplitFunc @@ -50,7 +50,7 @@ func SplitShellArgs(data string) (cmd string, args []string, err error) { for i, rune := range data { if s.isSplitChar(rune) { if fieldStart >= 0 { - args = append(args, data[fieldStart:i]) + strs = append(strs, data[fieldStart:i]) fieldStart = -1 } } else if fieldStart == -1 { @@ -58,21 +58,19 @@ func SplitShellArgs(data string) (cmd string, args []string, err error) { } } if fieldStart >= 0 { // Last field might end at EOF. - args = append(args, data[fieldStart:]) + strs = append(strs, data[fieldStart:]) } switch s.quote { case '\'': - return "", nil, errors.New("single-quoted string not terminated") + return nil, errors.New("single-quoted string not terminated") case '"': - return "", nil, errors.New("double-quoted string not terminated") + return nil, errors.New("double-quoted string not terminated") } - if len(args) == 0 { - return "", nil, errors.New("command string is empty") + if len(strs) == 0 { + return nil, errors.New("command string is empty") } - cmd, args = args[0], args[1:] - - return cmd, args, nil + return strs, nil } diff --git a/internal/backend/shell_split_test.go b/internal/backend/shell_split_test.go index bb7963d21..40ae84c63 100644 --- a/internal/backend/shell_split_test.go +++ b/internal/backend/shell_split_test.go @@ -8,59 +8,53 @@ import ( func TestShellSplitter(t *testing.T) { var tests = []struct { data string - cmd string args []string }{ { `foo`, - "foo", []string{}, + []string{"foo"}, }, { `'foo'`, - "foo", []string{}, + []string{"foo"}, }, { `foo bar baz`, - "foo", []string{"bar", "baz"}, + []string{"foo", "bar", "baz"}, }, { `foo 'bar' baz`, - "foo", []string{"bar", "baz"}, + []string{"foo", "bar", "baz"}, }, { `'bar box' baz`, - "bar box", []string{"baz"}, + []string{"bar box", "baz"}, }, { `"bar 'box'" baz`, - "bar 'box'", []string{"baz"}, + []string{"bar 'box'", "baz"}, }, { `'bar "box"' baz`, - `bar "box"`, []string{"baz"}, + []string{`bar "box"`, "baz"}, }, { `\"bar box baz`, - `"bar`, []string{"box", "baz"}, + []string{`"bar`, "box", "baz"}, }, { `"bar/foo/x" "box baz"`, - "bar/foo/x", []string{"box baz"}, + []string{"bar/foo/x", "box baz"}, }, } for _, test := range tests { t.Run("", func(t *testing.T) { - cmd, args, err := SplitShellArgs(test.data) + args, err := SplitShellStrings(test.data) if err != nil { t.Fatal(err) } - if cmd != test.cmd { - t.Fatalf("wrong cmd returned, want:\n %#v\ngot:\n %#v", - test.cmd, cmd) - } - if !reflect.DeepEqual(args, test.args) { t.Fatalf("wrong args returned, want:\n %#v\ngot:\n %#v", test.args, args) @@ -94,7 +88,7 @@ func TestShellSplitterInvalid(t *testing.T) { for _, test := range tests { t.Run("", func(t *testing.T) { - cmd, args, err := SplitShellArgs(test.data) + args, err := SplitShellStrings(test.data) if err == nil { t.Fatalf("expected error not found: %v", test.err) } @@ -103,10 +97,6 @@ func TestShellSplitterInvalid(t *testing.T) { t.Fatalf("expected error not found, want:\n %q\ngot:\n %q", test.err, err.Error()) } - if cmd != "" { - t.Fatalf("splitter returned cmd from invalid data: %v", cmd) - } - if len(args) > 0 { t.Fatalf("splitter returned fields from invalid data: %v", args) }