Skip to content

Commit

Permalink
Add plotting for test_permutation_global_uniform_distribution
Browse files Browse the repository at this point in the history
Co-authored-by: Sylwester Arabas <[email protected]>
  • Loading branch information
abulenok and slayoo committed May 16, 2023
1 parent 7ab6619 commit 911695c
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions tests/unit_tests/impl/test_particle_attributes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
from collections import Counter

import numpy as np
import pytest
from matplotlib import pyplot
from scipy import stats

from PySDM import Formulae
from PySDM.backends import CPU, GPU, ThrustRTC
from PySDM.backends.impl_common.index import make_Index
from PySDM.backends.impl_common.indexed_storage import make_IndexedStorage
Expand Down Expand Up @@ -295,23 +295,24 @@ def test_permutation_local_repeatable(backend_class):
)

@staticmethod
@pytest.mark.parametrize("seed", (1, 2, 3))
# pylint: disable=redefined-outer-name
def test_permutation_global_uniform_distribution(backend_class):
if backend_class is ThrustRTC:
return # TODO #328

def test_permutation_global_uniform_distribution(
seed, backend_class=CPU, plot=False
):
n_sd = 4
possible_permutations_num = np.math.factorial(n_sd)
coverage = 1000

random_numbers = np.linspace(
0.0, 1.0, n_sd * possible_permutations_num * coverage
)
np.random.seed(1)
np.random.seed(seed)
np.random.shuffle(random_numbers)

# Arrange
particulator = DummyParticulator(CPU, n_sd=n_sd)
particulator = DummyParticulator(CPU, n_sd=n_sd, formulae=Formulae(seed=seed))

sut = ParticleAttributesFactory.empty_particles(particulator, n_sd)
idx_length = len(sut._ParticleAttributes__idx)
sut._ParticleAttributes__tmp_idx = make_indexed_storage(
Expand All @@ -332,6 +333,25 @@ def test_permutation_global_uniform_distribution(backend_class):
sut._ParticleAttributes__idx, idx_length
)

_, uniformity = stats.chisquare(list(Counter(permutation_ids).values()))
# Plot
counts, _ = np.histogram(permutation_ids, bins=possible_permutations_num)
_, uniformity = stats.chisquare(counts)

avg = np.mean(counts)
std = np.std(counts)

pyplot.plot(counts, marker=".")
pyplot.xlabel("permutation id")
pyplot.ylabel("occurrence count")
pyplot.xlim(0, possible_permutations_num)
pyplot.axhline(coverage, color="black", label="coverage")
pyplot.axhline(avg, color="green", label="mean +/- std")
for offset in (-std, +std):
pyplot.axhline(avg + offset, color="green", linestyle="--")
pyplot.legend()
if plot:
pyplot.show()

# Assert
assert abs(avg - coverage) / coverage < 1e-6
assert uniformity > 0.9

0 comments on commit 911695c

Please sign in to comment.