Skip to content

Commit

Permalink
Merge pull request #128 from NielsRogge/add_hf
Browse files Browse the repository at this point in the history
Integrate with Hugging Face
  • Loading branch information
haithamkhedr authored Aug 7, 2024
2 parents 511199d + 9b58611 commit 6ba4c65
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 0 deletions.
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,42 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):

Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.

## Load from 🤗 Hugging Face

Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).

For image prediction:

```python
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
```

For video prediction:

```python
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
state = predictor.init_state(<your_video>)

# add new prompts and instantly get the output on the same frame
frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):

# propagate the prompts to get masklets throughout the video
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
...
```

## Model Description

| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
Expand Down
38 changes: 38 additions & 0 deletions sam2/build_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,44 @@ def build_sam2_video_predictor(
return model


def build_sam2_hf(model_id, **kwargs):

from huggingface_hub import hf_hub_download

model_id_to_filenames = {
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
"facebook/sam2-hiera-base-plus": (
"sam2_hiera_b+.yaml",
"sam2_hiera_base_plus.pt",
),
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
}
config_name, checkpoint_name = model_id_to_filenames[model_id]
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)


def build_sam2_video_predictor_hf(model_id, **kwargs):

from huggingface_hub import hf_hub_download

model_id_to_filenames = {
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
"facebook/sam2-hiera-base-plus": (
"sam2_hiera_b+.yaml",
"sam2_hiera_base_plus.pt",
),
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
}
config_name, checkpoint_name = model_id_to_filenames[model_id]
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
return build_sam2_video_predictor(
config_file=config_name, ckpt_path=ckpt_path, **kwargs
)


def _load_checkpoint(model, ckpt_path):
if ckpt_path is not None:
sd = torch.load(ckpt_path, map_location="cpu")["model"]
Expand Down
17 changes: 17 additions & 0 deletions sam2/sam2_image_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ def __init__(
(64, 64),
]

@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2ImagePredictor): The loaded model.
"""
from sam2.build_sam import build_sam2_hf

sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model)

@torch.no_grad()
def set_image(
self,
Expand Down
17 changes: 17 additions & 0 deletions sam2/sam2_video_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ def init_state(
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state

@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2VideoPredictor): The loaded model.
"""
from sam2.build_sam import build_sam2_video_predictor_hf

sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return cls(sam_model)

def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
Expand Down

0 comments on commit 6ba4c65

Please sign in to comment.