diff --git a/modeling/inference_model.py b/modeling/inference_model.py index 1af06675..441335cc 100644 --- a/modeling/inference_model.py +++ b/modeling/inference_model.py @@ -46,8 +46,8 @@ def __enter__(self): ) if use_core_manipulations.sample: - use_core_manipulations.old_sample = transformers.GenerationMixin.sample - transformers.GenerationMixin.sample = use_core_manipulations.sample + use_core_manipulations.old_sample = transformers.GenerationMixin._sample + transformers.GenerationMixin._sample = use_core_manipulations.sample if use_core_manipulations.get_stopping_criteria: use_core_manipulations.old_get_stopping_criteria = ( @@ -69,7 +69,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback): ), "Patch leak: THE MONKEYS HAVE ESCAPED" if use_core_manipulations.old_sample: - transformers.GenerationMixin.sample = use_core_manipulations.old_sample + transformers.GenerationMixin._sample = use_core_manipulations.old_sample else: assert ( not use_core_manipulations.sample diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index fcdd9fb9..37eaf105 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -266,7 +266,7 @@ def new_sample(self, *args, **kwargs): kwargs.setdefault("pad_token_id", 2) return new_sample.old_sample(self, *args, **kwargs) - new_sample.old_sample = transformers.GenerationMixin.sample + new_sample.old_sample = transformers.GenerationMixin._sample use_core_manipulations.sample = new_sample # PEFT Loading. This MUST be done after all save_pretrained calls are