diff --git a/README.md b/README.md index 35b92af..ab1e368 100644 --- a/README.md +++ b/README.md @@ -40,15 +40,15 @@ hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,\ deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,\ randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,\ missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,\ -forward-pytorch,gradient-clear-pytorch,data-leakage-scikitlearn,\ +forward-pytorch,gradient-clear-pytorch,pipeline-not-used-scikitlearn,\ dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch \ ---output-format=json:report.json,text:report.txt,colorized \ +--output-format=text:report.txt,colorized \ --reports=y \ ``` [For Windows Users]: ``` -pylint --load-plugins=dslinter --disable=all --enable=import,unnecessary-iteration-pandas,unnecessary-iteration-tensorflow,nan-numpy,chain-indexing-pandas,datatype-pandas,column-selection-pandas,merge-parameter-pandas,inplace-pandas,dataframe-conversion-pandas,scaler-missing-scikitlearn,hyperparameters-scikitlearn,hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,forward-pytorch,gradient-clear-pytorch,data-leakage-scikitlearn,dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch --output-format=json:report.json,text:report.txt,colorized --reports=y +pylint --load-plugins=dslinter --disable=all --enable=import,unnecessary-iteration-pandas,unnecessary-iteration-tensorflow,nan-numpy,chain-indexing-pandas,datatype-pandas,column-selection-pandas,merge-parameter-pandas,inplace-pandas,dataframe-conversion-pandas,scaler-missing-scikitlearn,hyperparameters-scikitlearn,hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,forward-pytorch,gradient-clear-pytorch,pipeline-not-used-scikitlearn,dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch --output-format=text:report.txt,colorized --reports=y ``` Or place a [`.pylintrc` configuration file](https://github.com/Hynn01/dslinter/blob/main/docs/pylint-configuration-examples/pylintrc-with-only-dslinter-settings/.pylintrc) which contains above settings in the folder where you run your command on, and run: ``` @@ -141,7 +141,7 @@ poetry run pytest . - **W5517 | gradient-clear-pytorch | Gradient Clear Checker(PyTorch)**: The loss_fn.backward() and optimizer.step() should be used together with optimizer.zero_grad(). If the `.zero_grad()` is missing in the code, the rule is violated. -- **W5518 | data-leakage-scikitlearn | Data Leakage Checker(ScikitLearn)**: All scikit-learn estimators should be used inside Pipelines, to prevent data leakage between training and test data. +- **W5518 | pipeline-not-used-scikitlearn | Pipeline Checker(ScikitLearn)**: All scikit-learn estimators should be used inside Pipelines, to prevent data leakage between training and test data. - **W5519 | dependent-threshold-scikitlearn | Dependent Threshold Checker(TensorFlow)**: If threshold-dependent evaluation(e.g., f-score) is used in the code, check whether threshold-indenpendent evaluation(e.g., auc) metrics is also used in the code. diff --git a/STEPS_TO_FOLLOW.md b/STEPS_TO_FOLLOW.md index 6f5ee9f..02b5d4f 100644 --- a/STEPS_TO_FOLLOW.md +++ b/STEPS_TO_FOLLOW.md @@ -27,15 +27,15 @@ hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,\ deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,\ randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,\ missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,\ -forward-pytorch,gradient-clear-pytorch,data-leakage-scikitlearn,\ +forward-pytorch,gradient-clear-pytorch,pipeline-not-used-scikitlearn,\ dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch \ ---output-format=json:report.json,text:report.txt,colorized \ +--output-format=text:report.txt,colorized \ --reports=y \ ``` [For Windows Users]: ``` -pylint --load-plugins=dslinter --disable=all --enable=import,unnecessary-iteration-pandas,unnecessary-iteration-tensorflow,nan-numpy,chain-indexing-pandas,datatype-pandas,column-selection-pandas,merge-parameter-pandas,inplace-pandas,dataframe-conversion-pandas,scaler-missing-scikitlearn,hyperparameters-scikitlearn,hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,forward-pytorch,gradient-clear-pytorch,data-leakage-scikitlearn,dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch --output-format=json:report.json,text:report.txt,colorized --reports=y +pylint --load-plugins=dslinter --disable=all --enable=import,unnecessary-iteration-pandas,unnecessary-iteration-tensorflow,nan-numpy,chain-indexing-pandas,datatype-pandas,column-selection-pandas,merge-parameter-pandas,inplace-pandas,dataframe-conversion-pandas,scaler-missing-scikitlearn,hyperparameters-scikitlearn,hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,forward-pytorch,gradient-clear-pytorch,pipeline-not-used-scikitlearn,dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch --output-format=text:report.txt,colorized --reports=y ``` ## For Notebook: @@ -67,13 +67,13 @@ hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,\ deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,\ randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,\ missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,\ -forward-pytorch,gradient-clear-pytorch,data-leakage-scikitlearn,\ +forward-pytorch,gradient-clear-pytorch,pipeline-not-used-scikitlearn,\ dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch \ ---output-format=json:report.json,text:report.txt,colorized \ +--output-format=text:report.txt,colorized \ --reports=y \ ``` [For Windows Users]: ``` -pylint --load-plugins=dslinter --disable=all --enable=import,unnecessary-iteration-pandas,unnecessary-iteration-tensorflow,nan-numpy,chain-indexing-pandas,datatype-pandas,column-selection-pandas,merge-parameter-pandas,inplace-pandas,dataframe-conversion-pandas,scaler-missing-scikitlearn,hyperparameters-scikitlearn,hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,forward-pytorch,gradient-clear-pytorch,data-leakage-scikitlearn,dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch --output-format=json:report.json,text:report.txt,colorized --reports=y +pylint --load-plugins=dslinter --disable=all --enable=import,unnecessary-iteration-pandas,unnecessary-iteration-tensorflow,nan-numpy,chain-indexing-pandas,datatype-pandas,column-selection-pandas,merge-parameter-pandas,inplace-pandas,dataframe-conversion-pandas,scaler-missing-scikitlearn,hyperparameters-scikitlearn,hyperparameters-tensorflow,hyperparameters-pytorch,memory-release-tensorflow,deterministic-pytorch,randomness-control-numpy,randomness-control-scikitlearn,randomness-control-tensorflow,randomness-control-pytorch,randomness-control-dataloader-pytorch,missing-mask-tensorflow,missing-mask-pytorch,tensor-array-tensorflow,forward-pytorch,gradient-clear-pytorch,pipeline-not-used-scikitlearn,dependent-threshold-scikitlearn,dependent-threshold-tensorflow,dependent-threshold-pytorch --output-format=text:report.txt,colorized --reports=y ``` diff --git a/dslinter/checkers/deterministic_pytorch.py b/dslinter/checkers/deterministic_pytorch.py index 7949e08..0e2e70f 100644 --- a/dslinter/checkers/deterministic_pytorch.py +++ b/dslinter/checkers/deterministic_pytorch.py @@ -54,10 +54,15 @@ def visit_module(self, module: astroid.Module): if _import_pytorch is False: _import_pytorch = has_import(node, "torch") - if isinstance(node, astroid.nodes.Expr) and hasattr(node, "value"): - call_node = node.value + if isinstance(node, astroid.nodes.Expr): if _has_deterministic_algorithm_option is False: - _has_deterministic_algorithm_option = self._check_deterministic_algorithm_option(call_node) + _has_deterministic_algorithm_option = self._check_deterministic_algorithm_option_in_expr_node(node) + + if isinstance(node, astroid.nodes.FunctionDef): + for nod in node.body: + if isinstance(nod, astroid.nodes.Expr): + if _has_deterministic_algorithm_option is False: + _has_deterministic_algorithm_option = self._check_deterministic_algorithm_option_in_expr_node(nod) # check if the rules are violated if( @@ -70,7 +75,13 @@ def visit_module(self, module: astroid.Module): ExceptionHandler.handle(self, module) @staticmethod - def _check_deterministic_algorithm_option(call_node: astroid.Call): + def _check_deterministic_algorithm_option_in_expr_node(expr_node: astroid.Expr): + if hasattr(expr_node, "value"): + call_node = expr_node.value + return DeterministicAlgorithmChecker._check_deterministic_algorithm_option_in_call_node(call_node) + + @staticmethod + def _check_deterministic_algorithm_option_in_call_node(call_node: astroid.Call): # if torch.use_deterministic_algorithm() is call and the argument is True, # set _has_deterministic_algorithm_option to True if( diff --git a/dslinter/checkers/data_leakage_scikitlearn.py b/dslinter/checkers/pipeline_scikitlearn.py similarity index 92% rename from dslinter/checkers/data_leakage_scikitlearn.py rename to dslinter/checkers/pipeline_scikitlearn.py index 3be2222..f19e4d7 100644 --- a/dslinter/checkers/data_leakage_scikitlearn.py +++ b/dslinter/checkers/pipeline_scikitlearn.py @@ -10,17 +10,17 @@ from dslinter.utils.resources import Resources -class DataLeakageScikitLearnChecker(BaseChecker): +class PipelineScikitLearnChecker(BaseChecker): """Checker which checks rules for preventing data leakage between training and test data.""" __implements__ = IAstroidChecker - name = "data-leakage-scikitlearn" + name = "pipeline-not-used-scikitlearn" priority = -1 msgs = { "W5518": ( "There are both preprocessing and estimation operations in the code, but they are not used in a pipeline.", - "data-leakage-scikitlearn", + "pipeline-not-used-scikitlearn", "Scikit-learn preprocessors and estimators should be used inside pipelines, to prevent data leakage between training and test data.", ), } @@ -84,7 +84,7 @@ def visit_call(self, call_node: astroid.Call): if self._expr_is_preprocessor(value.func.expr): has_preprocessing_function = True if has_learning_function is True and has_preprocessing_function is True: - self.add_message("data-leakage-scikitlearn", node=call_node) + self.add_message("pipeline-not-used-scikitlearn", node=call_node) except: # pylint: disable=bare-except ExceptionHandler.handle(self, call_node) @@ -98,14 +98,14 @@ def _expr_is_estimator(expr: astroid.node_classes.NodeNG) -> bool: :return: True when the expression is an estimator. """ if isinstance(expr, astroid.Call) \ - and DataLeakageScikitLearnChecker._call_initiates_estimator(expr): + and PipelineScikitLearnChecker._call_initiates_estimator(expr): return True # If expr is a Name, check whether that name is assigned to an estimator. if isinstance(expr, astroid.Name): values = AssignUtil.assignment_values(expr) for value in values: - if DataLeakageScikitLearnChecker._expr_is_estimator(value): + if PipelineScikitLearnChecker._expr_is_estimator(value): return True return False @@ -120,7 +120,7 @@ def _call_initiates_estimator(call: astroid.Call) -> bool: return ( call.func is not None and hasattr(call.func, "name") - and call.func.name in DataLeakageScikitLearnChecker._get_estimator_classes() + and call.func.name in PipelineScikitLearnChecker._get_estimator_classes() ) @staticmethod diff --git a/dslinter/checkers/randomness_control_numpy.py b/dslinter/checkers/randomness_control_numpy.py index c3bb11f..a5ad543 100644 --- a/dslinter/checkers/randomness_control_numpy.py +++ b/dslinter/checkers/randomness_control_numpy.py @@ -60,10 +60,15 @@ def visit_module(self, module: astroid.Module): if _import_ml_libraries is False: _import_ml_libraries = has_importfrom_sklearn(node) - if isinstance(node, astroid.nodes.Expr) and hasattr(node, "value"): - call_node = node.value + if isinstance(node, astroid.nodes.Expr): if _has_numpy_manual_seed is False: - _has_numpy_manual_seed = self._check_numpy_manual_seed(call_node) + _has_numpy_manual_seed = self._check_numpy_manual_seed_in_expr_node(node) + + if isinstance(node, astroid.nodes.FunctionDef): + for nod in node.body: + if isinstance(nod, astroid.nodes.Expr): + if _has_numpy_manual_seed is False: + _has_numpy_manual_seed = self._check_numpy_manual_seed_in_expr_node(nod) # check if the rules are violated if( @@ -76,7 +81,13 @@ def visit_module(self, module: astroid.Module): ExceptionHandler.handle(self, module) @staticmethod - def _check_numpy_manual_seed(call_node: astroid.Call): + def _check_numpy_manual_seed_in_expr_node(expr_node: astroid.Expr): + if hasattr(expr_node, "value"): + call_node = expr_node.value + return RandomnessControlNumpyChecker._check_numpy_manual_seed_in_call_node(call_node) + + @staticmethod + def _check_numpy_manual_seed_in_call_node(call_node: astroid.Call): if( hasattr(call_node, "func") and hasattr(call_node.func, "attrname") diff --git a/dslinter/checkers/randomness_control_pytorch.py b/dslinter/checkers/randomness_control_pytorch.py index f84cda4..f98b698 100644 --- a/dslinter/checkers/randomness_control_pytorch.py +++ b/dslinter/checkers/randomness_control_pytorch.py @@ -53,10 +53,15 @@ def visit_module(self, module: astroid.Module): if _import_pytorch is False: _import_pytorch = has_import(node, "torch") - if isinstance(node, astroid.nodes.Expr) and hasattr(node, "value"): - call_node = node.value + if isinstance(node, astroid.nodes.Expr): if _has_pytorch_manual_seed is False: - _has_pytorch_manual_seed = self._check_pytorch_manual_seed(call_node) + _has_pytorch_manual_seed = self._check_pytorch_manual_seed_in_expr_node(node) + + if isinstance(node, astroid.nodes.FunctionDef): + for nod in node.body: + if isinstance(nod, astroid.nodes.Expr): + if _has_pytorch_manual_seed is False: + _has_pytorch_manual_seed = self._check_pytorch_manual_seed_in_expr_node(nod) # check if the rules are violated if( @@ -68,7 +73,13 @@ def visit_module(self, module: astroid.Module): ExceptionHandler.handle(self, module) @staticmethod - def _check_pytorch_manual_seed(call_node: astroid.Call): + def _check_pytorch_manual_seed_in_expr_node(expr_node: astroid.Expr): + if hasattr(expr_node, "value"): + call_node = expr_node.value + return RandomnessControlPytorchChecker._check_pytorch_manual_seed_in_call_node(call_node) + + @staticmethod + def _check_pytorch_manual_seed_in_call_node(call_node: astroid.Call): if( hasattr(call_node, "func") and hasattr(call_node.func, "attrname") diff --git a/dslinter/checkers/randomness_control_tensorflow.py b/dslinter/checkers/randomness_control_tensorflow.py index fdc0bfd..2644d25 100644 --- a/dslinter/checkers/randomness_control_tensorflow.py +++ b/dslinter/checkers/randomness_control_tensorflow.py @@ -53,10 +53,15 @@ def visit_module(self, module: astroid.Module): if _import_tensorflow is False: _import_tensorflow = has_import(node, "tensorflow") - if isinstance(node, astroid.nodes.Expr) and hasattr(node, "value"): - call_node = node.value + if isinstance(node, astroid.nodes.Expr): if _has_tensorflow_manual_seed is False: - _has_tensorflow_manual_seed = self._check_tensorflow_manual_seed(call_node) + _has_tensorflow_manual_seed = self._check_tensorflow_manual_seed_in_expr_node(node) + + if isinstance(node, astroid.nodes.FunctionDef): + for nod in node.body: + if isinstance(nod, astroid.nodes.Expr): + if _has_tensorflow_manual_seed is False: + _has_tensorflow_manual_seed = self._check_tensorflow_manual_seed_in_expr_node(nod) # check if the rules are violated if( @@ -68,7 +73,13 @@ def visit_module(self, module: astroid.Module): ExceptionHandler.handle(self, module) @staticmethod - def _check_tensorflow_manual_seed(call_node: astroid.Call): + def _check_tensorflow_manual_seed_in_expr_node(expr_node: astroid.Expr): + if hasattr(expr_node, "value"): + call_node = expr_node.value + return RandomnessControlTensorflowChecker._check_tensorflow_manual_seed_in_call_node(call_node) + + @staticmethod + def _check_tensorflow_manual_seed_in_call_node(call_node: astroid.Call): if( hasattr(call_node, "func") and hasattr(call_node.func, "attrname") diff --git a/dslinter/plugin.py b/dslinter/plugin.py index ec8ca31..012a5ca 100644 --- a/dslinter/plugin.py +++ b/dslinter/plugin.py @@ -27,7 +27,7 @@ from dslinter.checkers.unnecessary_iteration_pandas import UnnecessaryIterationPandasChecker from dslinter.checkers.unnecessary_iteration_tensorflow import UnnecessaryIterationTensorflowChecker from dslinter.checkers.deterministic_pytorch import DeterministicAlgorithmChecker -from dslinter.checkers.data_leakage_scikitlearn import DataLeakageScikitLearnChecker +from dslinter.checkers.pipeline_scikitlearn import PipelineScikitLearnChecker from dslinter.checkers.hyperparameters_pytorch import HyperparameterPyTorchChecker from dslinter.checkers.hyperparameters_tensorflow import HyperparameterTensorflowChecker # pylint: disable = line-too-long @@ -58,7 +58,7 @@ def register(linter): linter.register_checker(RandomnessControlDataloaderPytorchChecker(linter)) linter.register_checker(RandomnessControlTensorflowChecker(linter)) linter.register_checker(RandomnessControlNumpyChecker(linter)) - linter.register_checker(DataLeakageScikitLearnChecker(linter)) + linter.register_checker(PipelineScikitLearnChecker(linter)) linter.register_checker(DependentThresholdPytorchChecker(linter)) linter.register_checker(DependentThresholdTensorflowChecker(linter)) linter.register_checker(DependentThresholdScikitLearnChecker(linter)) diff --git a/dslinter/tests/checkers/test_deterministic_pytorch.py b/dslinter/tests/checkers/test_deterministic_pytorch.py index 0add4ba..5b24759 100644 --- a/dslinter/tests/checkers/test_deterministic_pytorch.py +++ b/dslinter/tests/checkers/test_deterministic_pytorch.py @@ -22,6 +22,20 @@ def test_with_deterministic_option_set(self): with self.assertNoMessages(): self.checker.visit_module(module) + def test_with_deterministic_option_set2(self): + """Test whether no message is added if the deterministic algorithm option is used.""" + script = """ + import torch #@ + def set_random_seed(): + torch.use_deterministic_algorithms(True) + + if __name__ == '__main__': + pass + """ + module = astroid.parse(script) + with self.assertNoMessages(): + self.checker.visit_module(module) + def test_without_deterministic_option_set(self): """Test whether a message is added if the deterministic algorithm option is not used""" script = """ diff --git a/dslinter/tests/checkers/test_data_leakage_scikitlearn.py b/dslinter/tests/checkers/test_pipeline_scikitlearn.py similarity index 91% rename from dslinter/tests/checkers/test_data_leakage_scikitlearn.py rename to dslinter/tests/checkers/test_pipeline_scikitlearn.py index 399b476..9d1cf13 100644 --- a/dslinter/tests/checkers/test_data_leakage_scikitlearn.py +++ b/dslinter/tests/checkers/test_pipeline_scikitlearn.py @@ -7,7 +7,7 @@ class TestDataLeakageScikitLearnChecker(pylint.testutils.CheckerTestCase): """Class which tests the DataLeakageChecker.""" - CHECKER_CLASS = dslinter.plugin.DataLeakageScikitLearnChecker + CHECKER_CLASS = dslinter.plugin.PipelineScikitLearnChecker def test_pipeline_violation_on_call(self): """Message should be added when learning function is called directly on a learning class.""" @@ -32,7 +32,7 @@ def test_pipeline_violation_on_call(self): SVC().fit(X_train, y_train) #@ """ call_node = astroid.extract_node(script) - with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="data-leakage-scikitlearn", node=call_node),): + with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="pipeline-not-used-scikitlearn", node=call_node),): self.checker.visit_call(call_node) def test_learning_function_without_preprocessor(self): @@ -65,7 +65,7 @@ def test_pipeline_violation_outside_block(self): model.fit(X_train) #@ """ call_node = astroid.extract_node(script) - with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="data-leakage-scikitlearn", node=call_node),): + with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="pipeline-not-used-scikitlearn", node=call_node),): self.checker.visit_call(call_node) def test_pipeline_violation_on_name(self): @@ -77,7 +77,7 @@ def test_pipeline_violation_on_name(self): kmeans.fit(X_train) #@ """ call_node = astroid.extract_node(script) - with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="data-leakage-scikitlearn", node=call_node),): + with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="pipeline-not-used-scikitlearn", node=call_node),): self.checker.visit_call(call_node) def test_pipeline_violation_on_name_twice(self): @@ -90,7 +90,7 @@ def test_pipeline_violation_on_name_twice(self): kmeans2.fit(X_train) #@ """ call_node = astroid.extract_node(script) - with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="data-leakage-scikitlearn", node=call_node),): + with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="pipeline-not-used-scikitlearn", node=call_node),): self.checker.visit_call(call_node) def test_pipeline_violation_in_function(self): @@ -103,7 +103,7 @@ def f(model): f(KMeans()) """ call_node = astroid.extract_node(script) - with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="data-leakage-scikitlearn", node=call_node),): + with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="pipeline-not-used-scikitlearn", node=call_node),): self.checker.visit_call(call_node) def test_pipeline_violation_in_function_arg_assigned(self): @@ -117,7 +117,7 @@ def f(model): f(kmeans_model) """ call_node = astroid.extract_node(script) - with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="data-leakage-scikitlearn", node=call_node),): + with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="pipeline-not-used-scikitlearn", node=call_node),): self.checker.visit_call(call_node) def test_pipeline_violation_in_second_function_argument(self): @@ -130,5 +130,5 @@ def f(x, model): f(0, KMeans()) """ call_node = astroid.extract_node(script) - with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="data-leakage-scikitlearn", node=call_node),): + with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="pipeline-not-used-scikitlearn", node=call_node),): self.checker.visit_call(call_node) diff --git a/dslinter/tests/checkers/test_randomness_control_numpy.py b/dslinter/tests/checkers/test_randomness_control_numpy.py index f45960d..63ffe82 100644 --- a/dslinter/tests/checkers/test_randomness_control_numpy.py +++ b/dslinter/tests/checkers/test_randomness_control_numpy.py @@ -24,6 +24,22 @@ def test_ml_code_with_numpy_randomness_control(self): with self.assertNoMessages(): self.checker.visit_module(module) + def test_ml_code_with_numpy_randomness_control2(self): + """Tests whether no message is added if manual seed is set.""" + script = """ + import numpy as np #@ + import torch #@ + def add_seed(): + np.random.seed(0) + np.random.rand(4) + + if __name__ == '__main__': + pass + """ + module = astroid.parse(script) + with self.assertNoMessages(): + self.checker.visit_module(module) + def test_non_ml_code_with_numpy_randomness_control(self): """Tests whether no message is added if manual seed is set.""" script = """ diff --git a/dslinter/tests/checkers/test_randomness_control_pytorch.py b/dslinter/tests/checkers/test_randomness_control_pytorch.py index af2d42a..8348335 100644 --- a/dslinter/tests/checkers/test_randomness_control_pytorch.py +++ b/dslinter/tests/checkers/test_randomness_control_pytorch.py @@ -23,6 +23,21 @@ def test_with_pytorch_randomness_control(self): with self.assertNoMessages(): self.checker.visit_module(module) + def test_with_pytorch_randomness_control2(self): + """Tests whether no message is added if manual seed is set.""" + script = """ + import torch #@ + def add_seed(): + torch.manual_seed(0) + torch.randn(10).index_copy(0, torch.tensor([0]), torch.randn(1)) + + if __name__ == '__main__': + pass + """ + module = astroid.parse(script) + with self.assertNoMessages(): + self.checker.visit_module(module) + def test_without_pytorch_randomness_control(self): """Tests whether a message is added if manual seed is not set""" script = """ diff --git a/dslinter/tests/checkers/test_randomness_control_tensorflow.py b/dslinter/tests/checkers/test_randomness_control_tensorflow.py index 1241d5a..7163b6b 100644 --- a/dslinter/tests/checkers/test_randomness_control_tensorflow.py +++ b/dslinter/tests/checkers/test_randomness_control_tensorflow.py @@ -23,6 +23,21 @@ def test_with_tensorflow_randomness_control(self): with self.assertNoMessages(): self.checker.visit_module(module) + def test_with_tensorflow_randomness_control2(self): + """Tests whether no message is added if manual seed is set.""" + script = """ + import tensorflow as tf #@ + def add_seed(): + tf.random.set_seed(0) + tf.random.uniform([1]) + + if __name__ == '__main__': + pass + """ + module = astroid.parse(script) + with self.assertNoMessages(): + self.checker.visit_module(module) + def test_without_tensorflow_randomness_control(self): """Tests whether a message is added if manual seed is not set""" script = """ diff --git a/dslinter/utils/ast.py b/dslinter/utils/ast.py index a041fe3..6714f2e 100644 --- a/dslinter/utils/ast.py +++ b/dslinter/utils/ast.py @@ -228,4 +228,7 @@ def _get_target_name(target: astroid.node_classes.NodeNG) -> str: return target.name if hasattr(target, "value"): return AssignUtil._get_target_name(target.value) - raise Exception("Target name cannot be retrieved.") + # raise Exception("Target name cannot be retrieved.") + # This is a quick fix + # TODO: make a stable fix + return "" diff --git a/dslinter/utils/type_inference.py b/dslinter/utils/type_inference.py index 04d604b..ace9e53 100644 --- a/dslinter/utils/type_inference.py +++ b/dslinter/utils/type_inference.py @@ -93,7 +93,7 @@ def line_to_add_call(node: astroid.node_classes.NodeNG): :param node: The node where a call is added to in the source code. :return: Line number where the reveal_type() call can be added. """ - if hasattr(node.parent, "blockstart_tolineno") and node not in node.parent.body: + if hasattr(node.parent, "blockstart_tolineno") and node not in node.parent.body and len(node.parent.body) > 0: return TypeInference.line_to_add_call(node.parent.body[0]) return node.tolineno diff --git a/pyproject.toml b/pyproject.toml index 37f2150..453703f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ skip = 'scripts' [tool.poetry] name = "dslinter" -version = "2.0.7" +version = "2.0.8" description = "`dslinter` is a pylint plugin for linting data science and machine learning code. We plan to support the following Python libraries: TensorFlow, PyTorch, Scikit-Learn, Pandas, NumPy and SciPy." license = "GPL-3.0 License" @@ -44,14 +44,6 @@ keywords = ["machine learning", "software engineering"] [tool.poetry.dependencies] python = "^3.7" # Compatible python versions must be declared here toml = "^0.10" -# Dependencies with extras -# requests = { version = "^2.13", extras = [ "security" ] } -# Python specific dependencies with prereleases allowed -# pathlib2 = { version = "^2.2", allow-prereleases = true } -# Git dependencies -# cleo = { git = "https://github.com/sdispater/cleo.git", branch = "master" } -# Optional dependencies (extras) -# pendulum = { version = "^1.4", optional = true } pylint = { version = "~2.12.2" } astroid = { version = "~2.9.3" } mypy = { version = "~0.931" }