diff --git a/retry.go b/retry.go index af40b10..f739915 100644 --- a/retry.go +++ b/retry.go @@ -155,8 +155,8 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( lastErr = err - n++ config.onRetry(n, err) + n++ select { case <-config.timer.After(delay(config, n, err)): case <-config.context.Done(): diff --git a/retry_test.go b/retry_test.go index 649bbae..bd5492f 100644 --- a/retry_test.go +++ b/retry_test.go @@ -33,6 +33,7 @@ func TestDoWithDataAllFailed(t *testing.T) { #9: test #10: test` assert.Len(t, err, 10) + fmt.Println(err.Error()) assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format") assert.Equal(t, uint(45), retrySum, "right count of retry") } @@ -88,16 +89,17 @@ func TestRetryIf(t *testing.T) { } func TestRetryIf_ZeroAttempts(t *testing.T) { - var retryCount uint + var retryCount, onRetryCount uint err := Do( func() error { if retryCount >= 2 { return errors.New("special") } else { + retryCount++ return errors.New("test") } }, - OnRetry(func(n uint, err error) { retryCount++ }), + OnRetry(func(n uint, err error) { onRetryCount = n }), RetryIf(func(err error) bool { return err.Error() != "special" }), @@ -107,7 +109,7 @@ func TestRetryIf_ZeroAttempts(t *testing.T) { assert.Error(t, err) assert.Equal(t, "special", err.Error(), "retry error format") - assert.Equal(t, uint(2), retryCount, "right count of retry") + assert.Equal(t, retryCount, onRetryCount+1, "right count of retry") } func TestZeroAttemptsWithError(t *testing.T) {