From b3e98ff7881b175ae0ef5be7c75c9fcc811e8dfe Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Fri, 6 Dec 2024 10:55:58 +0800
Subject: [PATCH] Fix the issue when using `Model.compile` multiple times.

---
 keras/src/trainers/trainer.py      | 16 ----------------
 keras/src/trainers/trainer_test.py | 32 ++++++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py
index ff103b535c3..27bcfea9381 100644
--- a/keras/src/trainers/trainer.py
+++ b/keras/src/trainers/trainer.py
@@ -140,7 +140,6 @@ def compile(
                 wrapped in a `LossScaleOptimizer`, which will dynamically
                 scale the loss to prevent underflow.
         """
-        self._clear_previous_trainer_metrics()
         optimizer = optimizers.get(optimizer)
         self.optimizer = optimizer
         if (
@@ -287,21 +286,6 @@ def _get_own_metrics(self):
         metrics.extend(self._metrics)
         return metrics
 
-    def _clear_previous_trainer_metrics(self):
-        for layer in self._flatten_layers(include_self=False):
-            if not isinstance(layer, Trainer):
-                continue
-            # A sublayer might be a Trainer. In that case, we need to clear
-            # the Trainer-related metrics, as they are not usable when a
-            # new Trainer is instantiated.
-            for m in self._get_own_metrics():
-                layer._tracker.untrack(m)
-            layer._loss_tracker = None
-            layer._compile_metrics = None
-            if layer._compile_loss is not None:
-                layer._compile_loss._metrics.clear()
-            layer._metrics.clear()
-
     def compute_loss(
         self,
         x=None,
diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py
index 434fe47c969..4e25f4d3437 100644
--- a/keras/src/trainers/trainer_test.py
+++ b/keras/src/trainers/trainer_test.py
@@ -407,6 +407,38 @@ def test_nested_trainer_metrics_without_compile(self):
         self.assertEqual(new_model.metrics[0], new_model._loss_tracker)
         self.assertEqual(new_model.metrics[1], new_model._compile_metrics)
 
+    def test_multiple_compiles(self):
+        # https://github.com/keras-team/keras/issues/20474
+        model1 = ExampleModel(units=3)
+        model2 = ExampleModel(units=3)
+        model1.compile(
+            optimizer=optimizers.SGD(),
+            loss=losses.MeanSquaredError(),
+            metrics=[metrics.MeanSquaredError()],
+        )
+
+        # Combine these 2 models into `combined`.
+        inputs = keras.Input(shape=(4,))
+        x = model1(inputs)
+        outputs = model2(x)
+        combined = models.Model(inputs, outputs)
+        combined.compile(
+            optimizer=optimizers.SGD(),
+            loss=losses.MeanSquaredError(),
+            metrics=[metrics.MeanSquaredError()],
+        )
+
+        self.assertLen(model1.metrics, 2)
+        self.assertIsNotNone(model1._loss_tracker)
+        self.assertEqual(model1.metrics[0], model1._loss_tracker)
+        self.assertEqual(model1.metrics[1], model1._compile_metrics)
+
+        # `combined.metrics` will not include `model1.metrics`.
+        self.assertLen(combined.metrics, 2)
+        self.assertIsNotNone(combined._loss_tracker)
+        self.assertEqual(combined.metrics[0], combined._loss_tracker)
+        self.assertEqual(combined.metrics[1], combined._compile_metrics)
+
     @pytest.mark.skipif(
         backend.backend() != "torch",
         reason="torch backend runs in eager mode for jit_compile='auto'",