Refactor SplitShellStrings
This commit is contained in:
parent
345b6c4694
commit
34f27edc03
3 changed files with 26 additions and 33 deletions
|
@ -179,7 +179,12 @@ func (r *SFTP) IsNotExist(err error) bool {
|
||||||
|
|
||||||
func buildSSHCommand(cfg Config) (cmd string, args []string, err error) {
|
func buildSSHCommand(cfg Config) (cmd string, args []string, err error) {
|
||||||
if cfg.Command != "" {
|
if cfg.Command != "" {
|
||||||
return backend.SplitShellArgs(cfg.Command)
|
args, err := backend.SplitShellStrings(cfg.Command)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return args[0], args[1:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd = "ssh"
|
cmd = "ssh"
|
||||||
|
|
|
@ -41,8 +41,8 @@ func (s *shellSplitter) isSplitChar(c rune) bool {
|
||||||
return c == '\\' || unicode.IsSpace(c)
|
return c == '\\' || unicode.IsSpace(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SplitShellArgs returns the list of arguments from a shell command string.
|
// SplitShellStrings returns the list of shell strings from a shell command string.
|
||||||
func SplitShellArgs(data string) (cmd string, args []string, err error) {
|
func SplitShellStrings(data string) (strs []string, err error) {
|
||||||
s := &shellSplitter{}
|
s := &shellSplitter{}
|
||||||
|
|
||||||
// derived from strings.SplitFunc
|
// derived from strings.SplitFunc
|
||||||
|
@ -50,7 +50,7 @@ func SplitShellArgs(data string) (cmd string, args []string, err error) {
|
||||||
for i, rune := range data {
|
for i, rune := range data {
|
||||||
if s.isSplitChar(rune) {
|
if s.isSplitChar(rune) {
|
||||||
if fieldStart >= 0 {
|
if fieldStart >= 0 {
|
||||||
args = append(args, data[fieldStart:i])
|
strs = append(strs, data[fieldStart:i])
|
||||||
fieldStart = -1
|
fieldStart = -1
|
||||||
}
|
}
|
||||||
} else if fieldStart == -1 {
|
} else if fieldStart == -1 {
|
||||||
|
@ -58,21 +58,19 @@ func SplitShellArgs(data string) (cmd string, args []string, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if fieldStart >= 0 { // Last field might end at EOF.
|
if fieldStart >= 0 { // Last field might end at EOF.
|
||||||
args = append(args, data[fieldStart:])
|
strs = append(strs, data[fieldStart:])
|
||||||
}
|
}
|
||||||
|
|
||||||
switch s.quote {
|
switch s.quote {
|
||||||
case '\'':
|
case '\'':
|
||||||
return "", nil, errors.New("single-quoted string not terminated")
|
return nil, errors.New("single-quoted string not terminated")
|
||||||
case '"':
|
case '"':
|
||||||
return "", nil, errors.New("double-quoted string not terminated")
|
return nil, errors.New("double-quoted string not terminated")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args) == 0 {
|
if len(strs) == 0 {
|
||||||
return "", nil, errors.New("command string is empty")
|
return nil, errors.New("command string is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd, args = args[0], args[1:]
|
return strs, nil
|
||||||
|
|
||||||
return cmd, args, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,59 +8,53 @@ import (
|
||||||
func TestShellSplitter(t *testing.T) {
|
func TestShellSplitter(t *testing.T) {
|
||||||
var tests = []struct {
|
var tests = []struct {
|
||||||
data string
|
data string
|
||||||
cmd string
|
|
||||||
args []string
|
args []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
`foo`,
|
`foo`,
|
||||||
"foo", []string{},
|
[]string{"foo"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
`'foo'`,
|
`'foo'`,
|
||||||
"foo", []string{},
|
[]string{"foo"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
`foo bar baz`,
|
`foo bar baz`,
|
||||||
"foo", []string{"bar", "baz"},
|
[]string{"foo", "bar", "baz"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
`foo 'bar' baz`,
|
`foo 'bar' baz`,
|
||||||
"foo", []string{"bar", "baz"},
|
[]string{"foo", "bar", "baz"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
`'bar box' baz`,
|
`'bar box' baz`,
|
||||||
"bar box", []string{"baz"},
|
[]string{"bar box", "baz"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
`"bar 'box'" baz`,
|
`"bar 'box'" baz`,
|
||||||
"bar 'box'", []string{"baz"},
|
[]string{"bar 'box'", "baz"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
`'bar "box"' baz`,
|
`'bar "box"' baz`,
|
||||||
`bar "box"`, []string{"baz"},
|
[]string{`bar "box"`, "baz"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
`\"bar box baz`,
|
`\"bar box baz`,
|
||||||
`"bar`, []string{"box", "baz"},
|
[]string{`"bar`, "box", "baz"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
`"bar/foo/x" "box baz"`,
|
`"bar/foo/x" "box baz"`,
|
||||||
"bar/foo/x", []string{"box baz"},
|
[]string{"bar/foo/x", "box baz"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
cmd, args, err := SplitShellArgs(test.data)
|
args, err := SplitShellStrings(test.data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd != test.cmd {
|
|
||||||
t.Fatalf("wrong cmd returned, want:\n %#v\ngot:\n %#v",
|
|
||||||
test.cmd, cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(args, test.args) {
|
if !reflect.DeepEqual(args, test.args) {
|
||||||
t.Fatalf("wrong args returned, want:\n %#v\ngot:\n %#v",
|
t.Fatalf("wrong args returned, want:\n %#v\ngot:\n %#v",
|
||||||
test.args, args)
|
test.args, args)
|
||||||
|
@ -94,7 +88,7 @@ func TestShellSplitterInvalid(t *testing.T) {
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
cmd, args, err := SplitShellArgs(test.data)
|
args, err := SplitShellStrings(test.data)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error not found: %v", test.err)
|
t.Fatalf("expected error not found: %v", test.err)
|
||||||
}
|
}
|
||||||
|
@ -103,10 +97,6 @@ func TestShellSplitterInvalid(t *testing.T) {
|
||||||
t.Fatalf("expected error not found, want:\n %q\ngot:\n %q", test.err, err.Error())
|
t.Fatalf("expected error not found, want:\n %q\ngot:\n %q", test.err, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd != "" {
|
|
||||||
t.Fatalf("splitter returned cmd from invalid data: %v", cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
t.Fatalf("splitter returned fields from invalid data: %v", args)
|
t.Fatalf("splitter returned fields from invalid data: %v", args)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue