Skip to content

Commit

Permalink
Make the validation condition for random distributions lenient (#550)
Browse files Browse the repository at this point in the history
* Make the validation condition for random distributions lenient

* Fix typo

* Catch too small standard variations against theoretical values as well

* Replace unnecessary NumPy calls with Python primitives

* Tighten the tolerance
  • Loading branch information
magnatelee authored Aug 20, 2022
1 parent 4b9b857 commit 0c95242
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions tests/integration/utils/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def zipf(self, alpha, shape, dtype):
return num.random.zipf(alpha, shape, dtype)


def assert_distribution(a, theo_mean, theo_stdev, tolerance=1e-2):
def assert_distribution(
a, theo_mean, theo_stdev, mean_tol=1e-2, stdev_tol=1.0
):
if True:
aa = np.array(a)
average = np.mean(aa)
Expand All @@ -154,12 +156,16 @@ def assert_distribution(a, theo_mean, theo_stdev, tolerance=1e-2):
num.mean((a - average) ** 2)
) # num.std(a) -> does not work
print(
f"average = {average} - expected {theo_mean}"
+ f", stdev = {stdev} - expected {theo_stdev}\n"
)
assert np.abs(theo_mean - average) < tolerance * np.max(
(1.0, np.abs(theo_mean))
)
assert np.abs(theo_stdev - stdev) < tolerance * np.max(
(1.0, np.abs(theo_stdev))
f"average = {average} - theoretical {theo_mean}"
+ f", stdev = {stdev} - theoretical {theo_stdev}\n"
)
assert abs(theo_mean - average) < mean_tol * max(1.0, abs(theo_mean))
# the theoretical standard deviation can't be 0
assert theo_stdev != 0
# TODO: this check is not a good proxy to validating that the samples
# respect the assumed random distribution unless we draw
# extremely many samples. until we find a better validation
# method, we make the check lenient to avoid random
# failures in the CI. (we still need the check to catch
# the cases that are obviously wrong.)
assert abs(theo_stdev - stdev) / min(theo_stdev, stdev) <= stdev_tol

0 comments on commit 0c95242

Please sign in to comment.