diff --git a/PySDM/backends/numba.py b/PySDM/backends/numba.py index 5850ca43a..7f6bf594b 100644 --- a/PySDM/backends/numba.py +++ b/PySDM/backends/numba.py @@ -36,9 +36,9 @@ class Numba( # pylint: disable=too-many-ancestors,duplicate-code default_croupier = "local" - def __init__(self, formulae=None, double_precision=True): + def __init__(self, formulae=None, *, double_precision=True): if not double_precision: - raise NotImplementedError() + raise NotImplementedError() # TODO #1144 self.formulae = formulae or Formulae() CollisionsMethods.__init__(self) PairMethods.__init__(self) diff --git a/PySDM/backends/thrust_rtc.py b/PySDM/backends/thrust_rtc.py index 79ce65bd1..ca19583e3 100644 --- a/PySDM/backends/thrust_rtc.py +++ b/PySDM/backends/thrust_rtc.py @@ -45,7 +45,7 @@ class ThrustRTC( # pylint: disable=duplicate-code,too-many-ancestors default_croupier = "global" def __init__( - self, formulae=None, double_precision=False, debug=False, verbose=False + self, formulae=None, *, double_precision=False, debug=False, verbose=False ): self.formulae = formulae or Formulae() diff --git a/tests/unit_tests/backends/test_freezing_methods.py b/tests/unit_tests/backends/test_freezing_methods.py index 2fdcd5a0e..3f8aeffe8 100644 --- a/tests/unit_tests/backends/test_freezing_methods.py +++ b/tests/unit_tests/backends/test_freezing_methods.py @@ -96,12 +96,8 @@ def test_freeze_singular(backend_class): ) @staticmethod - @pytest.mark.parametrize("double_precision", (True, False)) # pylint: disable=too-many-locals - def test_freeze_time_dependent(backend_class, double_precision, plot=False): - if backend_class.__name__ == "Numba" and not double_precision: - pytest.skip() - + def test_freeze_time_dependent(backend_class, plot=False): # Arrange seed = 44 cases = ( @@ -151,12 +147,7 @@ def low(t): key = f"{case['dt']}:{case['N']}" output[key] = {"unfrozen_fraction": [], "dt": case["dt"], "N": case["N"]} - builder = Builder( - n_sd=n_sd, - backend=backend_class( - formulae=formulae, double_precision=double_precision - ), - ) + builder = Builder(n_sd=n_sd, backend=backend_class(formulae=formulae)) env = Box(dt=case["dt"], dv=d_v) builder.set_environment(env) builder.add_dynamic(Freezing(singular=False)) diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 3983e2127..c811e14fc 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -4,6 +4,26 @@ from PySDM.backends import CPU, GPU -@pytest.fixture(params=(CPU, GPU)) +class GPUFP32(GPU): # pylint: disable=too-many-ancestors + def __init__(self, formulae=None, **kwargs): + if "double_precision" in kwargs: + if kwargs["double_precision"]: + pytest.skip() + else: + del kwargs["double_precision"] + super().__init__(formulae=formulae, double_precision=False, **kwargs) + + +class GPUFP64(GPU): # pylint: disable=too-many-ancestors + def __init__(self, formulae=None, **kwargs): + if "double_precision" in kwargs: + if kwargs["double_precision"]: + del kwargs["double_precision"] + else: + pytest.skip() + super().__init__(formulae=formulae, double_precision=True, **kwargs) + + +@pytest.fixture(params=(CPU, GPUFP32, GPUFP64)) # TODO #1144 CPU float def backend_class(request): return request.param diff --git a/tests/unit_tests/dummy_particulator.py b/tests/unit_tests/dummy_particulator.py index 53db9258d..da975ff9e 100644 --- a/tests/unit_tests/dummy_particulator.py +++ b/tests/unit_tests/dummy_particulator.py @@ -8,7 +8,7 @@ class DummyParticulator(Builder, Particulator): def __init__(self, backend_class, n_sd=0, formulae=None, grid=None): - backend = backend_class(formulae, double_precision=True) + backend = backend_class(formulae=formulae, double_precision=True) Builder.__init__(self, n_sd, backend) Particulator.__init__(self, n_sd, backend) self.particulator = self diff --git a/tests/unit_tests/dynamics/collisions/test_croupiers.py b/tests/unit_tests/dynamics/collisions/test_croupiers.py index 447323e25..bdc41f2cc 100644 --- a/tests/unit_tests/dynamics/collisions/test_croupiers.py +++ b/tests/unit_tests/dynamics/collisions/test_croupiers.py @@ -11,8 +11,8 @@ @pytest.mark.parametrize("croupier", ["local", "global"]) def test_final_state(croupier, backend_class): - if backend_class is ThrustRTC: - return # TODO #330 + if issubclass(backend_class, ThrustRTC): + pytest.skip() # TODO #330 # Arrange n_part = 100000 diff --git a/tests/unit_tests/dynamics/test_relaxed_velocity.py b/tests/unit_tests/dynamics/test_relaxed_velocity.py index 1bed5beb4..8b369547f 100644 --- a/tests/unit_tests/dynamics/test_relaxed_velocity.py +++ b/tests/unit_tests/dynamics/test_relaxed_velocity.py @@ -47,9 +47,9 @@ def test_small_timescale(default_attributes, constant_timescale, backend_class): When the fall velocity is initialized to 0 and relaxation is very quick, the velocity should quickly approach the terminal velocity """ - builder = Builder( - n_sd=len(default_attributes["multiplicity"]), backend=backend_class() + n_sd=len(default_attributes["multiplicity"]), + backend=backend_class(), ) builder.set_environment(Box(dt=1, dv=1))