diff --git a/lib/errors/errors.go b/lib/errors/errors.go new file mode 100644 index 000000000..a8a098972 --- /dev/null +++ b/lib/errors/errors.go @@ -0,0 +1,76 @@ +package errors + +import ( + "errors" + "fmt" + "reflect" +) + +// New returns an error that formats as the given text. +func New(text string) error { + return errors.New(text) +} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +func Errorf(format string, a ...interface{}) error { + return fmt.Errorf(format, a...) +} + +// WalkFunc is the signature of the Walk callback function. The function gets the +// current error in the chain and should return true if the chain processing +// should be aborted. +type WalkFunc func(error) bool + +// Walk invokes the given function for each error in the chain. If the +// provided functions returns true or no further cause can be found, the process +// is stopped and no further calls will be made. +// +// The next error in the chain is determined by the following rules: +// - If the current error has a `Cause() error` method (github.com/pkg/errors), +// the return value of this method is used. +// - If the current error has a `Unwrap() error` method (golang.org/x/xerrors), +// the return value of this method is used. +// - Common errors in the Go runtime that contain an Err field will use this value. +func Walk(err error, f WalkFunc) { + for prev := err; err != nil; prev = err { + if f(err) { + return + } + + switch e := err.(type) { + case causer: + err = e.Cause() + case wrapper: + err = e.Unwrap() + default: + // Unpack any struct or *struct with a field of name Err which satisfies + // the error interface. This includes *url.Error, *net.OpError, + // *os.SyscallError and many others in the stdlib. + errType := reflect.TypeOf(err) + errValue := reflect.ValueOf(err) + if errValue.IsValid() && errType.Kind() == reflect.Ptr { + errType = errType.Elem() + errValue = errValue.Elem() + } + if errValue.IsValid() && errType.Kind() == reflect.Struct { + if errField := errValue.FieldByName("Err"); errField.IsValid() { + errFieldValue := errField.Interface() + if newErr, ok := errFieldValue.(error); ok { + err = newErr + } + } + } + } + if err == prev { + break + } + } +} + +type causer interface { + Cause() error +} +type wrapper interface { + Unwrap() error +} diff --git a/lib/errors/errors_test.go b/lib/errors/errors_test.go new file mode 100644 index 000000000..1a7c3f5d5 --- /dev/null +++ b/lib/errors/errors_test.go @@ -0,0 +1,90 @@ +package errors_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ncw/rclone/lib/errors" +) + +func TestWalk(t *testing.T) { + origin := errors.New("origin") + + for _, test := range []struct { + err error + calls int + last error + }{ + {causerError{nil}, 1, causerError{nil}}, + {wrapperError{nil}, 1, wrapperError{nil}}, + {reflectError{nil}, 1, reflectError{nil}}, + {causerError{origin}, 2, origin}, + {wrapperError{origin}, 2, origin}, + {reflectError{origin}, 2, origin}, + {causerError{reflectError{origin}}, 3, origin}, + {wrapperError{causerError{origin}}, 3, origin}, + {reflectError{wrapperError{origin}}, 3, origin}, + {causerError{reflectError{causerError{origin}}}, 4, origin}, + {wrapperError{causerError{wrapperError{origin}}}, 4, origin}, + {reflectError{wrapperError{reflectError{origin}}}, 4, origin}, + + {stopError{nil}, 1, stopError{nil}}, + {stopError{causerError{nil}}, 1, stopError{causerError{nil}}}, + {stopError{wrapperError{nil}}, 1, stopError{wrapperError{nil}}}, + {stopError{reflectError{nil}}, 1, stopError{reflectError{nil}}}, + {causerError{stopError{origin}}, 2, stopError{origin}}, + {wrapperError{stopError{origin}}, 2, stopError{origin}}, + {reflectError{stopError{origin}}, 2, stopError{origin}}, + {causerError{reflectError{stopError{nil}}}, 3, stopError{nil}}, + {wrapperError{causerError{stopError{nil}}}, 3, stopError{nil}}, + {reflectError{wrapperError{stopError{nil}}}, 3, stopError{nil}}, + } { + var last error + calls := 0 + errors.Walk(test.err, func(err error) bool { + calls++ + last = err + _, stop := err.(stopError) + return stop + }) + assert.Equal(t, test.calls, calls) + assert.Equal(t, test.last, last) + } +} + +type causerError struct { + err error +} +type wrapperError struct { + err error +} +type reflectError struct { + Err error +} +type stopError struct { + err error +} + +func (e causerError) Error() string { + return fmt.Sprintf("causerError(%s)", e.err) +} +func (e causerError) Cause() error { + return e.err +} +func (e wrapperError) Unwrap() error { + return e.err +} +func (e wrapperError) Error() string { + return fmt.Sprintf("wrapperError(%s)", e.err) +} +func (e reflectError) Error() string { + return fmt.Sprintf("reflectError(%s)", e.Err) +} +func (e stopError) Error() string { + return fmt.Sprintf("stopError(%s)", e.err) +} +func (e stopError) Cause() error { + return e.err +}