Skip to content

Commit

Permalink
Fix ModelDebugger.step to use only intermediate_outputs
Browse files Browse the repository at this point in the history
`record_intermediate_output` only expects intermediate outputs, not model outputs which could be of type `PIL.Image`. Also added corresponding test on `_get_activation_calibration_stats` which is currently the only user entrypoint to `ModelDebugger.step`.
  • Loading branch information
Zerui18 committed Dec 23, 2024
1 parent bdc9810 commit 71b7246
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -325,5 +325,5 @@ def step(
key: value for key, value in outputs.items() if key not in model_output_names
}

for output_name, output_value in outputs.items():
for output_name, output_value in intermediate_outputs.items():
self.record_intermediate_output(output_value, output_name, activation_stats_dict)
21 changes: 21 additions & 0 deletions coremltools/test/optimize/coreml/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3830,3 +3830,24 @@ def test_get_activation_calibration_stats_concat_surrounding_ops(self):
# Since mlmodel has a concat with 2 inputs and 1 output, we should see at least 3 rmin/rmax pairs are identical in activation_stats.
# If we dedup rmin/rmax pairs with identical values, the length of unique values should at least reduced by 2 compared with original one.
assert len(activation_stats) - len(activation_stats_unique) >= 2

def test_get_activation_calibration_stats_excludes_model_outputs(self):
"""
The activation calibration stats shouldn't include the model's final outputs.
"""
# Prepare sample data
sample_data = []
for _ in range(3):
input_data = np.random.rand(5, 10, 4, 4)
sample_data.append({"data": input_data})

# Loading a floating point mlmodel
mlmodel = self._get_test_mlmodel_conv_relu()

activation_stats = _get_activation_calibration_stats(mlmodel, sample_data)

model_spec = mlmodel.get_spec()
output_count = len(mlmodel.get_spec().description.output)
for i in range(0, output_count):
output_name = model_spec.description.output[i].name
assert output_name not in activation_stats

0 comments on commit 71b7246

Please sign in to comment.