Skip to content

Commit

Permalink
feat: support join error match for eris.Is and eris.As (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gogomoe authored Jun 14, 2024
1 parent c2b8c9a commit 78185bb
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 2 deletions.
20 changes: 19 additions & 1 deletion eris.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,18 @@ func (e *rootError) Format(s fmt.State, verb rune) {
printError(e, s, verb)
}

// Is returns true if both errors have the same message and code. Ignores additional KV pairs.
// Is returns true if both errors have the same message and code.
// In case of a joined error, returns true if at least one of the joined errors is equal to target.
// Ignores additional KV pairs.
func (e *rootError) Is(target error) bool {
if joinErr, ok := e.ext.(joinError); ok {
for _, err := range joinErr.Unwrap() {
if Is(err, target) {
return true
}
}
return false
}
if err, ok := target.(*rootError); ok {
return e.msg == err.msg && e.code == err.code && reflect.DeepEqual(e.kvs, err.kvs)
}
Expand All @@ -446,6 +456,14 @@ func (e *rootError) Is(target error) bool {

// As returns true if the error message in the target error is equivalent to the error message in the root error.
func (e *rootError) As(target any) bool {
if joinErr, ok := e.ext.(joinError); ok {
for _, err := range joinErr.Unwrap() {
if As(err, target) {
return true
}
}
return false
}
t := reflect.Indirect(reflect.ValueOf(target)).Interface()
if err, ok := t.(*rootError); ok {
if e.msg == err.msg {
Expand Down
73 changes: 72 additions & 1 deletion eris_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ func TestErrorUnwrap(t *testing.T) {
}

func TestErrorIs(t *testing.T) {
rootErr := eris.New("root error")
externalErr := errors.New("external error")
customErr := withLayer{
msg: "additional context",
Expand Down Expand Up @@ -518,6 +519,53 @@ func TestErrorIs(t *testing.T) {
compare: nil,
output: true,
},
"join error (external)": {
cause: eris.Join(externalErr, rootErr),
compare: externalErr,
output: true,
},
"join error (root)": {
cause: eris.Join(externalErr, rootErr),
compare: rootErr,
output: true,
},
"join error (nil)": {
cause: eris.Join(nil, nil),
compare: nil,
output: true,
},
"join error (wrap)": {
cause: eris.Join(externalErr, eris.Wrap(rootErr, "eris wrap error")),
compare: eris.New("eris wrap error").WithCode(eris.CodeInternal),
output: true,
},
"join error not found (code don't match)": {
cause: eris.Join(externalErr, eris.Wrap(rootErr, "eris wrap error")),
compare: eris.New("eris wrap error"),
output: false,
},
"join error not found (message don't match)": {
cause: eris.Join(externalErr, eris.Wrap(rootErr, "eris wrap error")),
compare: eris.New("eris root error message wrong"),
output: false,
},
"join error not found (external don't match)": {
cause: eris.Join(externalErr, rootErr),
compare: errors.New("external error not match"),
output: false,
},
"wrapped join error (match join errors)": {
cause: eris.Join(externalErr, rootErr),
input: []string{"additional context"},
compare: rootErr,
output: true,
},
"wrapped join error (match wrap)": {
cause: eris.Join(externalErr, rootErr),
input: []string{"additional context"},
compare: eris.New("additional context").WithCode(eris.CodeUnknown),
output: true,
},
}

for desc, tc := range tests {
Expand All @@ -535,6 +583,7 @@ func TestErrorIs(t *testing.T) {
func TestErrorAs(t *testing.T) {
externalError := errors.New("external error")
rootErr := eris.New("root error").WithCode(eris.CodeUnknown)
anotherRootErr := eris.New("another root error").WithCode(eris.CodeUnknown)
wrappedErr := eris.WithCode(eris.Wrap(rootErr, "additional context"), eris.CodeUnknown)
customErr := withLayer{
msg: "additional context",
Expand All @@ -544,7 +593,6 @@ func TestErrorAs(t *testing.T) {
},
},
}

tests := map[string]struct {
cause error // root error
target any // errors for comparison
Expand Down Expand Up @@ -641,6 +689,29 @@ func TestErrorAs(t *testing.T) {
match: true,
output: customErr,
},
"join error (external)": {
cause: eris.Join(externalError, rootErr),
target: &externalError,
match: true,
output: externalError,
},
"join error (root)": {
cause: eris.Join(externalError, rootErr),
target: &rootErr,
match: true,
output: rootErr,
},
"join error (custom)": {
cause: eris.Join(externalError, withMessage{"test"}),
target: &withMessage{""},
match: true,
output: withMessage{"test"},
},
"join error not found (message don't match)": {
cause: eris.Join(externalError, rootErr),
target: &anotherRootErr,
match: false,
},
}

for desc, tc := range tests {
Expand Down

0 comments on commit 78185bb

Please sign in to comment.