Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Apple's MPS backend #123

Closed
wants to merge 9 commits into from
22 changes: 15 additions & 7 deletions notebooks/automatic_mask_generator_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,26 @@
"metadata": {},
"outputs": [],
"source": [
"device = \"cuda\" # uncomment this line for CUDA environment\n",
"# device = \"mps\" # uncomment this line for MPS environment\n",
"\n",
"import os\n",
"if device==\"mps\":\n",
" os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] ='1'\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"\n",
"# use bfloat16 for the entire notebook\n",
"torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16).__enter__()\n",
"if device==\"cuda\":\n",
" # use bfloat16 for the entire notebook\n",
" torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16).__enter__()\n",
"\n",
"if torch.cuda.get_device_properties(0).major >= 8:\n",
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
" torch.backends.cudnn.allow_tf32 = True"
" if torch.cuda.get_device_properties(0).major >= 8:\n",
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
" torch.backends.cudnn.allow_tf32 = True"
]
},
{
Expand Down Expand Up @@ -176,7 +184,7 @@
"sam2_checkpoint = \"../checkpoints/sam2_hiera_large.pt\"\n",
"model_cfg = \"sam2_hiera_l.yaml\"\n",
"\n",
"sam2 = build_sam2(model_cfg, sam2_checkpoint, device ='cuda', apply_postprocessing=False)\n",
"sam2 = build_sam2(model_cfg, sam2_checkpoint, device = device, apply_postprocessing=False)\n",
"\n",
"mask_generator = SAM2AutomaticMaskGenerator(sam2)"
]
Expand Down
22 changes: 15 additions & 7 deletions notebooks/image_predictor_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@
"metadata": {},
"outputs": [],
"source": [
"device = \"cuda\" # uncomment this line for CUDA environment\n",
"# device = \"mps\" # uncomment this line for MPS environment\n",
"\n",
"import os\n",
"if device==\"mps\":\n",
" os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] ='1'\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
Expand All @@ -82,13 +89,14 @@
"metadata": {},
"outputs": [],
"source": [
"# use bfloat16 for the entire notebook\n",
"torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16).__enter__()\n",
"if device==\"cuda\":\n",
" # use bfloat16 for the entire notebook\n",
" torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16).__enter__()\n",
"\n",
"if torch.cuda.get_device_properties(0).major >= 8:\n",
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
" torch.backends.cudnn.allow_tf32 = True"
" if torch.cuda.get_device_properties(0).major >= 8:\n",
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
" torch.backends.cudnn.allow_tf32 = True"
]
},
{
Expand Down Expand Up @@ -211,7 +219,7 @@
"sam2_checkpoint = \"../checkpoints/sam2_hiera_large.pt\"\n",
"model_cfg = \"sam2_hiera_l.yaml\"\n",
"\n",
"sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=\"cuda\")\n",
"sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)\n",
"\n",
"predictor = SAM2ImagePredictor(sam2_model)"
]
Expand Down
20 changes: 13 additions & 7 deletions notebooks/video_predictor_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@
"metadata": {},
"outputs": [],
"source": [
"device = \"cuda\" # uncomment this line for CUDA environment\n",
"# device = \"mps\" # uncomment this line for MPS environment\n",
"\n",
"import os\n",
"if device==\"mps\":\n",
" os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] ='1'\n",
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
Expand All @@ -55,13 +60,14 @@
"metadata": {},
"outputs": [],
"source": [
"# use bfloat16 for the entire notebook\n",
"torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16).__enter__()\n",
"if device==\"cuda\":\n",
" # use bfloat16 for the entire notebook\n",
" torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16).__enter__()\n",
"\n",
"if torch.cuda.get_device_properties(0).major >= 8:\n",
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
" torch.backends.cudnn.allow_tf32 = True"
" if torch.cuda.get_device_properties(0).major >= 8:\n",
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
" torch.backends.cudnn.allow_tf32 = True"
]
},
{
Expand All @@ -84,7 +90,7 @@
"sam2_checkpoint = \"../checkpoints/sam2_hiera_large.pt\"\n",
"model_cfg = \"sam2_hiera_l.yaml\"\n",
"\n",
"predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)"
"predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion sam2/automatic_mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _process_batch(
orig_h, orig_w = orig_size

# Run model on this batch
points = torch.as_tensor(points, device=self.predictor.device)
points = torch.as_tensor(points.astype('float32'), device=self.predictor.device)
in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size
)
Expand Down
3 changes: 2 additions & 1 deletion sam2/modeling/position_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def apply_rotary_enc(
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
freqs_cis = freqs_cis.unsqueeze(2).expand(*([-1]*(freqs_cis.ndim-2)), r, *([-1]*(freqs_cis.ndim-2))).reshape(*freqs_cis.shape[:2], r*freqs_cis.shape[2], freqs_cis.shape[3])

xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
4 changes: 2 additions & 2 deletions sam2/modeling/sam2_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,10 +567,10 @@ def _prepare_memory_conditioned_features(
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True)
feats = prev["maskmem_features"].to(self.device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = prev["maskmem_pos_enc"][-1].to(self.device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (
Expand Down
9 changes: 5 additions & 4 deletions sam2/sam2_video_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def init_state(
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
device=self.device,
)
inference_state = {}
inference_state["images"] = images
Expand All @@ -64,11 +65,11 @@ def init_state(
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
inference_state["device"] = torch.device(self.device)
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
inference_state["storage_device"] = torch.device(self.device)
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
Expand Down Expand Up @@ -215,7 +216,7 @@ def add_new_points(
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
prev_sam_mask_logits = prev_out["pred_masks"].to(self.device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
Expand Down Expand Up @@ -734,7 +735,7 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
image = inference_state["images"][frame_idx].to(self.device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
Expand Down
37 changes: 32 additions & 5 deletions sam2/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,35 @@ def get_connected_components(mask):
- counts: A tensor of shape (N, 1, H, W) containing the area of the connected
components for foreground pixels and 0 for background pixels.
"""
from sam2 import _C

return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
if torch.cuda.is_available():
from sam2 import _C

return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())


# if cuda is not available use scipy to get connected components
from scipy.ndimage import label as scipy_label

labels = torch.zeros_like(mask, dtype=torch.int32)
counts = torch.zeros_like(mask, dtype=torch.int32)

mask_np = mask.cpu().numpy()
for i in range(mask.shape[0]):
mask_i = mask_np[i, 0]
labeled_array, num_features = scipy_label(mask_i, structure=np.ones((3, 3)))
labels_np = np.zeros_like(labeled_array)
counts_np = np.zeros_like(labeled_array)

for feature in range(1, num_features + 1):
labels_np[labeled_array == feature] = feature
counts_np[labeled_array == feature] = (labeled_array == feature).sum()

labels[i, 0] = torch.tensor(labels_np, dtype=torch.int32)
counts[i, 0] = torch.tensor(counts_np, dtype=torch.int32)
labels = labels.to(mask.device)
counts = counts.to(mask.device)
return labels, counts


def mask_to_box(masks: torch.Tensor):
Expand Down Expand Up @@ -164,6 +190,7 @@ def load_video_frames(
video_path,
image_size,
offload_video_to_cpu,
device,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
Expand Down Expand Up @@ -204,9 +231,9 @@ def load_video_frames(
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
images = images.to(device)
img_mean = img_mean.to(device)
img_std = img_std.to(device)
# normalize by mean and std
images -= img_mean
images /= img_std
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to consider using torch.backends.cuda.is_available() directly here instead of checking for nvcc.
Docs: https://pytorch.org/docs/stable/backends.html

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the advice! I have updated setup.py to use torch.cuda.is_available() instead of checking nvcc.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from setuptools import find_packages, setup
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Package metadata
Expand All @@ -29,15 +30,17 @@
"hydra-core>=1.3.2",
"iopath>=0.1.10",
"pillow>=9.4.0",
"scipy>=1.14.0",
]

EXTRA_PACKAGES = {
"demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"],
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When trying to work off this branch, I ran into an issue with matplotlib>=3.9.1. Doing matplotlib>=3.9.0 helped.

Separately, I had to comment out ext_modules in the main setup() function.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran into the same problem with matplotlib>=3.9.1 and >=3.9.0 fixed it. Thanks!
Didn’t have any issues with ext_modules though.

"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
}


def get_extensions():
if torch.cuda.is_available() is False:
return []
srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
Expand Down