diff --git a/paradox/kernel/engine.py b/paradox/kernel/engine.py index f942870..cbb1635 100644 --- a/paradox/kernel/engine.py +++ b/paradox/kernel/engine.py @@ -105,11 +105,14 @@ def __compute_gradient(self, variable: Symbol): if hash(self.__symbol) == hash(variable): self.__gradients[variable] = broadcast(Constant(1), self.shape(self.__symbol)) return + current_operator = None for forward in variable.output: if self.gradient(forward) is not None: + if current_operator != forward.operator: + current_operator = forward.operator + index = -1 gradients = forward.operator.gradient(self, forward, *forward.input) - index = None - for i, _variable in enumerate(forward.input): + for i, _variable in enumerate(forward.input, start=index + 1): if hash(_variable) == hash(variable): index = i break