Refactor SplitShellStrings

This commit is contained in:
Alexander Neumann 2018-03-13 20:50:37 +01:00
parent 345b6c4694
commit 34f27edc03
3 changed files with 26 additions and 33 deletions

View file

@ -179,7 +179,12 @@ func (r *SFTP) IsNotExist(err error) bool {
func buildSSHCommand(cfg Config) (cmd string, args []string, err error) { func buildSSHCommand(cfg Config) (cmd string, args []string, err error) {
if cfg.Command != "" { if cfg.Command != "" {
return backend.SplitShellArgs(cfg.Command) args, err := backend.SplitShellStrings(cfg.Command)
if err != nil {
return "", nil, err
}
return args[0], args[1:], nil
} }
cmd = "ssh" cmd = "ssh"

View file

@ -41,8 +41,8 @@ func (s *shellSplitter) isSplitChar(c rune) bool {
return c == '\\' || unicode.IsSpace(c) return c == '\\' || unicode.IsSpace(c)
} }
// SplitShellArgs returns the list of arguments from a shell command string. // SplitShellStrings returns the list of shell strings from a shell command string.
func SplitShellArgs(data string) (cmd string, args []string, err error) { func SplitShellStrings(data string) (strs []string, err error) {
s := &shellSplitter{} s := &shellSplitter{}
// derived from strings.SplitFunc // derived from strings.SplitFunc
@ -50,7 +50,7 @@ func SplitShellArgs(data string) (cmd string, args []string, err error) {
for i, rune := range data { for i, rune := range data {
if s.isSplitChar(rune) { if s.isSplitChar(rune) {
if fieldStart >= 0 { if fieldStart >= 0 {
args = append(args, data[fieldStart:i]) strs = append(strs, data[fieldStart:i])
fieldStart = -1 fieldStart = -1
} }
} else if fieldStart == -1 { } else if fieldStart == -1 {
@ -58,21 +58,19 @@ func SplitShellArgs(data string) (cmd string, args []string, err error) {
} }
} }
if fieldStart >= 0 { // Last field might end at EOF. if fieldStart >= 0 { // Last field might end at EOF.
args = append(args, data[fieldStart:]) strs = append(strs, data[fieldStart:])
} }
switch s.quote { switch s.quote {
case '\'': case '\'':
return "", nil, errors.New("single-quoted string not terminated") return nil, errors.New("single-quoted string not terminated")
case '"': case '"':
return "", nil, errors.New("double-quoted string not terminated") return nil, errors.New("double-quoted string not terminated")
} }
if len(args) == 0 { if len(strs) == 0 {
return "", nil, errors.New("command string is empty") return nil, errors.New("command string is empty")
} }
cmd, args = args[0], args[1:] return strs, nil
return cmd, args, nil
} }

View file

@ -8,59 +8,53 @@ import (
func TestShellSplitter(t *testing.T) { func TestShellSplitter(t *testing.T) {
var tests = []struct { var tests = []struct {
data string data string
cmd string
args []string args []string
}{ }{
{ {
`foo`, `foo`,
"foo", []string{}, []string{"foo"},
}, },
{ {
`'foo'`, `'foo'`,
"foo", []string{}, []string{"foo"},
}, },
{ {
`foo bar baz`, `foo bar baz`,
"foo", []string{"bar", "baz"}, []string{"foo", "bar", "baz"},
}, },
{ {
`foo 'bar' baz`, `foo 'bar' baz`,
"foo", []string{"bar", "baz"}, []string{"foo", "bar", "baz"},
}, },
{ {
`'bar box' baz`, `'bar box' baz`,
"bar box", []string{"baz"}, []string{"bar box", "baz"},
}, },
{ {
`"bar 'box'" baz`, `"bar 'box'" baz`,
"bar 'box'", []string{"baz"}, []string{"bar 'box'", "baz"},
}, },
{ {
`'bar "box"' baz`, `'bar "box"' baz`,
`bar "box"`, []string{"baz"}, []string{`bar "box"`, "baz"},
}, },
{ {
`\"bar box baz`, `\"bar box baz`,
`"bar`, []string{"box", "baz"}, []string{`"bar`, "box", "baz"},
}, },
{ {
`"bar/foo/x" "box baz"`, `"bar/foo/x" "box baz"`,
"bar/foo/x", []string{"box baz"}, []string{"bar/foo/x", "box baz"},
}, },
} }
for _, test := range tests { for _, test := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
cmd, args, err := SplitShellArgs(test.data) args, err := SplitShellStrings(test.data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if cmd != test.cmd {
t.Fatalf("wrong cmd returned, want:\n %#v\ngot:\n %#v",
test.cmd, cmd)
}
if !reflect.DeepEqual(args, test.args) { if !reflect.DeepEqual(args, test.args) {
t.Fatalf("wrong args returned, want:\n %#v\ngot:\n %#v", t.Fatalf("wrong args returned, want:\n %#v\ngot:\n %#v",
test.args, args) test.args, args)
@ -94,7 +88,7 @@ func TestShellSplitterInvalid(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
cmd, args, err := SplitShellArgs(test.data) args, err := SplitShellStrings(test.data)
if err == nil { if err == nil {
t.Fatalf("expected error not found: %v", test.err) t.Fatalf("expected error not found: %v", test.err)
} }
@ -103,10 +97,6 @@ func TestShellSplitterInvalid(t *testing.T) {
t.Fatalf("expected error not found, want:\n %q\ngot:\n %q", test.err, err.Error()) t.Fatalf("expected error not found, want:\n %q\ngot:\n %q", test.err, err.Error())
} }
if cmd != "" {
t.Fatalf("splitter returned cmd from invalid data: %v", cmd)
}
if len(args) > 0 { if len(args) > 0 {
t.Fatalf("splitter returned fields from invalid data: %v", args) t.Fatalf("splitter returned fields from invalid data: %v", args)
} }