diff --git a/dslinter/checkers/forward_pytorch.py b/dslinter/checkers/forward_pytorch.py index c397a69..5165d4f 100644 --- a/dslinter/checkers/forward_pytorch.py +++ b/dslinter/checkers/forward_pytorch.py @@ -4,6 +4,7 @@ from pylint.interfaces import IAstroidChecker from dslinter.utils.exception_handler import ExceptionHandler +from dslinter.utils.randomness_control_helper import has_import class ForwardPytorchChecker(BaseChecker): @@ -22,6 +23,12 @@ class ForwardPytorchChecker(BaseChecker): } options = () + _import_torch = False + + def visit_import(self, import_node: astroid.Import): + if self._import_torch is False: + self._import_torch = has_import(import_node, "torch") + def visit_call(self, call_node: astroid.Call): """ When a Call node is visited, check whether it violated the rule in this checker. @@ -46,7 +53,17 @@ def visit_call(self, call_node: astroid.Call): and call_node.func.expr.func.name == "super" ): _call_from_super = True - if _has_forward is True and (_call_from_self is False and _call_from_super is False): + if( + self._import_torch is True + and _has_forward is True + and ( + _call_from_self is False + and _call_from_super is False + ) + ): self.add_message("forward-pytorch", node=call_node) except: # pylint: disable = bare-except ExceptionHandler.handle(self, call_node) + + def leave_module(self, module: astroid.Module): + self._import_torch = False diff --git a/dslinter/tests/checkers/test_forward_pytorch.py b/dslinter/tests/checkers/test_forward_pytorch.py index 2d0d045..1a21375 100644 --- a/dslinter/tests/checkers/test_forward_pytorch.py +++ b/dslinter/tests/checkers/test_forward_pytorch.py @@ -13,7 +13,7 @@ class TestForwardPytorchChecker(pylint.testutils.CheckerTestCase): def test_use_forward(self): """Message will be added if the self.net.forward() is used in the code rather than self.net().""" script = """ - import torch.nn as nn + import torch.nn as nn #@ class Net(nn.Module): def __init__(self): super().__init__() @@ -33,14 +33,16 @@ def forward(self, x): x = self.fc3(x) return x """ - call_node = astroid.extract_node(script).value + import_node, assign_node = astroid.extract_node(script) + call_node = assign_node.value with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id="forward-pytorch", node=call_node)): + self.checker.visit_import(import_node) self.checker.visit_call(call_node) def test_not_use_forward(self): """No message will be added if self.net() is used in the code.""" script = """ - import torch.nn as nn + import torch.nn as nn #@ class Net(nn.Module): def __init__(self): super().__init__() @@ -60,25 +62,31 @@ def forward(self, x): x = self.fc3(x) return x """ - call_node = astroid.extract_node(script).value + import_node, assign_node = astroid.extract_node(script) + call_node = assign_node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(call_node) def test_use_self_forward(self): """No Message will be added if the self.forward() is used in the code.""" script = """ + import torch #@ def training_step(self, batch, batch_nb): idx = batch['idx'] loss = self.forward(batch)[0] #@ return {'loss': loss, 'idx': idx} """ - call_node = astroid.extract_node(script).value.value + import_node, assign_node = astroid.extract_node(script) + call_node = assign_node.value.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(call_node) def test_use_super_forward(self): """No Message will be added if the super().forward() is used in the code.""" script = """ + import torch #@ class SpatialDropout(nn.Dropout2d): def forward(self, x): x = x.unsqueeze(2) # (N, T, 1, K) @@ -88,6 +96,8 @@ def forward(self, x): x = x.squeeze(2) # (N, T, K) return x """ - call_node = astroid.extract_node(script).value + import_node, assign_node = astroid.extract_node(script) + call_node = assign_node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(call_node) diff --git a/dslinter/utils/randomness_control_helper.py b/dslinter/utils/randomness_control_helper.py index 5bfafc5..b035d61 100644 --- a/dslinter/utils/randomness_control_helper.py +++ b/dslinter/utils/randomness_control_helper.py @@ -20,7 +20,7 @@ def check_main_module(module: astroid.Module) -> bool: def has_import(node: astroid.Import, library_name: str): for name, _ in node.names: - if name == library_name: + if name == library_name or library_name in name: return True return False diff --git a/pyproject.toml b/pyproject.toml index 453703f..85d0c28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ skip = 'scripts' [tool.poetry] name = "dslinter" -version = "2.0.8" +version = "2.0.9" description = "`dslinter` is a pylint plugin for linting data science and machine learning code. We plan to support the following Python libraries: TensorFlow, PyTorch, Scikit-Learn, Pandas, NumPy and SciPy." license = "GPL-3.0 License"