fs/fserrors: make sure Cause never returns nil

This commit is contained in:
Nick Craig-Wood 2018-07-13 10:31:40 +01:00
parent b1bd17a220
commit a3d9a38f51
2 changed files with 35 additions and 4 deletions

View file

@ -188,6 +188,12 @@ func Cause(cause error) (retriable bool, err error) {
// Unwrap 1 level if possible
err = errors.Cause(err)
if err == nil {
// errors.Cause can return nil which isn't
// desirable so pick the previous error in
// this case.
err = prev
}
if err == prev {
// Unpack any struct or *struct with a field
// of name Err which satisfies the error
@ -196,11 +202,11 @@ func Cause(cause error) (retriable bool, err error) {
// others in the stdlib
errType := reflect.TypeOf(err)
errValue := reflect.ValueOf(err)
if errType.Kind() == reflect.Ptr {
if errValue.IsValid() && errType.Kind() == reflect.Ptr {
errType = errType.Elem()
errValue = errValue.Elem()
}
if errType.Kind() == reflect.Struct {
if errValue.IsValid() && errType.Kind() == reflect.Struct {
if errField := errValue.FieldByName("Err"); errField.IsValid() {
errFieldValue := errField.Interface()
if newErr, ok := errFieldValue.(error); ok {

View file

@ -39,7 +39,15 @@ type myError2 struct {
Err error
}
func (e *myError2) Error() string { return e.Err.Error() }
func (e *myError2) Error() string {
if e == nil {
return "myError2(nil)"
}
if e.Err == nil {
return "myError2{Err: nil}"
}
return e.Err.Error()
}
type myError3 struct {
Err int
@ -53,11 +61,23 @@ type myError4 struct {
func (e *myError4) Error() string { return e.e.Error() }
type errorCause struct {
e error
}
func (e *errorCause) Error() string { return fmt.Sprintf("%#v", e) }
func (e *errorCause) Cause() error { return e.e }
func TestCause(t *testing.T) {
e3 := &myError3{3}
e4 := &myError4{io.EOF}
eNil1 := &myError2{nil}
eNil2 := &myError2{Err: (*myError2)(nil)}
errPotato := errors.New("potato")
nilCause1 := &errorCause{nil}
nilCause2 := &errorCause{(*myError2)(nil)}
for i, test := range []struct {
err error
wantRetriable bool
@ -70,10 +90,15 @@ func TestCause(t *testing.T) {
{errUseOfClosedNetworkConnection, false, errUseOfClosedNetworkConnection},
{makeNetErr(syscall.EAGAIN), true, syscall.EAGAIN},
{makeNetErr(syscall.Errno(123123123)), false, syscall.Errno(123123123)},
{eNil1, false, eNil1},
{eNil2, false, eNil2.Err},
{myError1{io.EOF}, false, io.EOF},
{&myError2{io.EOF}, false, io.EOF},
{e3, false, e3},
{e4, false, e4},
{&errorCause{errPotato}, false, errPotato},
{nilCause1, false, nilCause1},
{nilCause2, false, nilCause2.e},
} {
gotRetriable, gotErr := Cause(test.err)
what := fmt.Sprintf("test #%d: %v", i, test.err)