From 06dbabbf16e7901b97a4a1e049a3cc28e8531c0a Mon Sep 17 00:00:00 2001 From: Samrat Thapa Date: Sat, 7 Dec 2024 16:15:07 +0900 Subject: [PATCH] added tests --- src/rocker/extensions.py | 2 +- test/test_extension.py | 70 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/src/rocker/extensions.py b/src/rocker/extensions.py index c6de230..ff8dddf 100644 --- a/src/rocker/extensions.py +++ b/src/rocker/extensions.py @@ -484,7 +484,7 @@ def get_docker_args(self, cliargs): args = '' shm_size = cliargs.get('shm_size', None) if shm_size: - args += f' --shm-size={shm_size} ' + args += f' --shm-size {shm_size} ' return args @staticmethod diff --git a/test/test_extension.py b/test/test_extension.py index d264c35..26558ab 100644 --- a/test/test_extension.py +++ b/test/test_extension.py @@ -617,3 +617,73 @@ def test_group_add_extension(self): args = p.get_docker_args(mock_cliargs) self.assertIn('--group-add sudo', args) self.assertIn('--group-add docker', args) + +class ShmSizeExtensionTest(unittest.TestCase): + + def setUp(self): + # Work around interference between empy Interpreter + # stdout proxy and test runner. empy installs a proxy on stdout + # to be able to capture the information. + # And the test runner creates a new stdout object for each test. + # This breaks empy as it assumes that the proxy has persistent + # between instances of the Interpreter class + # empy will error with the exception + # "em.Error: interpreter stdout proxy lost" + em.Interpreter._wasProxyInstalled = False + + @pytest.mark.docker + def test_shm_size_extension(self): + plugins = list_plugins() + shm_size_plugin = plugins['shm_size'] + self.assertEqual(shm_size_plugin.get_name(), 'shm_size') + + p = shm_size_plugin() + self.assertTrue(plugin_load_parser_correctly(shm_size_plugin)) + + mock_cliargs = {} + self.assertEqual(p.get_snippet(mock_cliargs), '') + self.assertEqual(p.get_preamble(mock_cliargs), '') + args = p.get_docker_args(mock_cliargs) + self.assertNotIn('--shm-size', args) + + mock_cliargs = {'shm_size': '12g'} + args = p.get_docker_args(mock_cliargs) + self.assertIn('--shm-size 12g', args) + +class GpusExtensionTest(unittest.TestCase): + + def setUp(self): + # Work around interference between empy Interpreter + # stdout proxy and test runner. empy installs a proxy on stdout + # to be able to capture the information. + # And the test runner creates a new stdout object for each test. + # This breaks empy as it assumes that the proxy has persistent + # between instances of the Interpreter class + # empy will error with the exception + # "em.Error: interpreter stdout proxy lost" + em.Interpreter._wasProxyInstalled = False + + @pytest.mark.docker + def test_gpus_extension(self): + plugins = list_plugins() + gpus_plugin = plugins['gpus'] + self.assertEqual(gpus_plugin.get_name(), 'gpus') + + p = gpus_plugin() + self.assertTrue(plugin_load_parser_correctly(gpus_plugin)) + + # Test when no GPUs are specified + mock_cliargs = {} + self.assertEqual(p.get_snippet(mock_cliargs), '') + self.assertEqual(p.get_preamble(mock_cliargs), '') + args = p.get_docker_args(mock_cliargs) + self.assertNotIn('--gpus', args) + + # Test when GPUs are specified + mock_cliargs = {'gpus': 'all'} + args = p.get_docker_args(mock_cliargs) + self.assertIn('--gpus all', args) + + mock_cliargs = {'gpus': '0,1'} + args = p.get_docker_args(mock_cliargs) + self.assertIn('--gpus 0,1', args)