diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index 41ce53af5..2cd35384d 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -456,6 +456,22 @@ def get_image_embedding(self) -> torch.Tensor: def device(self) -> torch.device: return self.model.device + def import_fields(self, fields): + self._features = fields["_features"] + self._orig_hw = fields["_orig_hw"] + self._is_image_set = True + + def export_fields(self): + """ + Exports the specified fields of the TestSamPredictor class. + Returns: + dict: A dictionary containing the values of features, orig_h, orig_w, input_h, and input_w. + """ + return { + "_features": self._features, + "_orig_hw": self._orig_hw + } + def reset_predictor(self) -> None: """ Resets the image embeddings and other state variables.