fs/fserrors: make sure Cause never returns nil
This commit is contained in:
parent
b1bd17a220
commit
a3d9a38f51
2 changed files with 35 additions and 4 deletions
|
@ -188,6 +188,12 @@ func Cause(cause error) (retriable bool, err error) {
|
|||
|
||||
// Unwrap 1 level if possible
|
||||
err = errors.Cause(err)
|
||||
if err == nil {
|
||||
// errors.Cause can return nil which isn't
|
||||
// desirable so pick the previous error in
|
||||
// this case.
|
||||
err = prev
|
||||
}
|
||||
if err == prev {
|
||||
// Unpack any struct or *struct with a field
|
||||
// of name Err which satisfies the error
|
||||
|
@ -196,11 +202,11 @@ func Cause(cause error) (retriable bool, err error) {
|
|||
// others in the stdlib
|
||||
errType := reflect.TypeOf(err)
|
||||
errValue := reflect.ValueOf(err)
|
||||
if errType.Kind() == reflect.Ptr {
|
||||
if errValue.IsValid() && errType.Kind() == reflect.Ptr {
|
||||
errType = errType.Elem()
|
||||
errValue = errValue.Elem()
|
||||
}
|
||||
if errType.Kind() == reflect.Struct {
|
||||
if errValue.IsValid() && errType.Kind() == reflect.Struct {
|
||||
if errField := errValue.FieldByName("Err"); errField.IsValid() {
|
||||
errFieldValue := errField.Interface()
|
||||
if newErr, ok := errFieldValue.(error); ok {
|
||||
|
|
|
@ -39,7 +39,15 @@ type myError2 struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
func (e *myError2) Error() string { return e.Err.Error() }
|
||||
func (e *myError2) Error() string {
|
||||
if e == nil {
|
||||
return "myError2(nil)"
|
||||
}
|
||||
if e.Err == nil {
|
||||
return "myError2{Err: nil}"
|
||||
}
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
type myError3 struct {
|
||||
Err int
|
||||
|
@ -53,11 +61,23 @@ type myError4 struct {
|
|||
|
||||
func (e *myError4) Error() string { return e.e.Error() }
|
||||
|
||||
type errorCause struct {
|
||||
e error
|
||||
}
|
||||
|
||||
func (e *errorCause) Error() string { return fmt.Sprintf("%#v", e) }
|
||||
|
||||
func (e *errorCause) Cause() error { return e.e }
|
||||
|
||||
func TestCause(t *testing.T) {
|
||||
e3 := &myError3{3}
|
||||
e4 := &myError4{io.EOF}
|
||||
|
||||
eNil1 := &myError2{nil}
|
||||
eNil2 := &myError2{Err: (*myError2)(nil)}
|
||||
errPotato := errors.New("potato")
|
||||
nilCause1 := &errorCause{nil}
|
||||
nilCause2 := &errorCause{(*myError2)(nil)}
|
||||
|
||||
for i, test := range []struct {
|
||||
err error
|
||||
wantRetriable bool
|
||||
|
@ -70,10 +90,15 @@ func TestCause(t *testing.T) {
|
|||
{errUseOfClosedNetworkConnection, false, errUseOfClosedNetworkConnection},
|
||||
{makeNetErr(syscall.EAGAIN), true, syscall.EAGAIN},
|
||||
{makeNetErr(syscall.Errno(123123123)), false, syscall.Errno(123123123)},
|
||||
{eNil1, false, eNil1},
|
||||
{eNil2, false, eNil2.Err},
|
||||
{myError1{io.EOF}, false, io.EOF},
|
||||
{&myError2{io.EOF}, false, io.EOF},
|
||||
{e3, false, e3},
|
||||
{e4, false, e4},
|
||||
{&errorCause{errPotato}, false, errPotato},
|
||||
{nilCause1, false, nilCause1},
|
||||
{nilCause2, false, nilCause2.e},
|
||||
} {
|
||||
gotRetriable, gotErr := Cause(test.err)
|
||||
what := fmt.Sprintf("test #%d: %v", i, test.err)
|
||||
|
|
Loading…
Reference in a new issue