cat: extract parameter validation and add a test
This commit is contained in:
parent
705556f134
commit
fe54912a46
2 changed files with 48 additions and 12 deletions
|
@ -13,8 +13,6 @@ import (
|
|||
"github.com/restic/restic/internal/restic"
|
||||
)
|
||||
|
||||
var allowedCmds = []string{"config", "index", "snapshot", "key", "masterkey", "lock", "pack", "blob", "tree"}
|
||||
|
||||
var cmdCat = &cobra.Command{
|
||||
Use: "cat [flags] [masterkey|config|pack ID|blob ID|snapshot ID|index ID|key ID|lock ID|tree snapshot:subfolder]",
|
||||
Short: "Print internal objects to stdout",
|
||||
|
@ -36,21 +34,21 @@ func init() {
|
|||
cmdRoot.AddCommand(cmdCat)
|
||||
}
|
||||
|
||||
func validateParam(param string) bool {
|
||||
for _, v := range allowedCmds {
|
||||
if v == param {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func validateCatArgs(args []string) error {
|
||||
var allowedCmds = []string{"config", "index", "snapshot", "key", "masterkey", "lock", "pack", "blob", "tree"}
|
||||
|
||||
func runCat(ctx context.Context, gopts GlobalOptions, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.Fatal("type not specified")
|
||||
}
|
||||
|
||||
if ok := validateParam(args[0]); !ok {
|
||||
validType := false
|
||||
for _, v := range allowedCmds {
|
||||
if v == args[0] {
|
||||
validType = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !validType {
|
||||
return errors.Fatalf("invalid type %q, must be one of [%s]", args[0], strings.Join(allowedCmds, "|"))
|
||||
}
|
||||
|
||||
|
@ -58,6 +56,14 @@ func runCat(ctx context.Context, gopts GlobalOptions, args []string) error {
|
|||
return errors.Fatal("ID not specified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runCat(ctx context.Context, gopts GlobalOptions, args []string) error {
|
||||
if err := validateCatArgs(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
repo, err := OpenRepository(ctx, gopts)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
30
cmd/restic/cmd_cat_test.go
Normal file
30
cmd/restic/cmd_cat_test.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
rtest "github.com/restic/restic/internal/test"
|
||||
)
|
||||
|
||||
func TestCatArgsValidation(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
args []string
|
||||
err string
|
||||
}{
|
||||
{[]string{}, "Fatal: type not specified"},
|
||||
{[]string{"masterkey"}, ""},
|
||||
{[]string{"invalid"}, `Fatal: invalid type "invalid"`},
|
||||
{[]string{"snapshot"}, "Fatal: ID not specified"},
|
||||
{[]string{"snapshot", "12345678"}, ""},
|
||||
} {
|
||||
t.Run("", func(t *testing.T) {
|
||||
err := validateCatArgs(test.args)
|
||||
if test.err == "" {
|
||||
rtest.Assert(t, err == nil, "unexpected error %q", err)
|
||||
} else {
|
||||
rtest.Assert(t, strings.Contains(err.Error(), test.err), "unexpected error expected %q to contain %q", err, test.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue