allow running SAM 2 on CPU or non-CUDA devices
ronghanghu committed Aug 10, 2024
1 parent 778e112 commit 7213380
Showing 8 changed files with 152 additions and 59 deletions.
72 changes: 51 additions & 21 deletions notebooks/automatic_mask_generator_example.ipynb

Large diffs are not rendered by default.

55 changes: 39 additions & 16 deletions notebooks/image_predictor_example.ipynb

Large diffs are not rendered by default.

35 changes: 27 additions & 8 deletions notebooks/video_predictor_example.ipynb
@@ -111,8 +111,10 @@
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"# if using Apple MPS, fall back to CPU for unsupported ops\n",
"os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image"
]
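The import reorder in this cell is deliberate: the diff moves `import torch` below the `os.environ` assignment, which indicates the variable must be in the environment before torch is imported. A standalone script mirroring the cell (a minimal sketch):

```python
import os

# Enable CPU fallback for ops that MPS doesn't implement; set it before
# importing torch, mirroring the notebook cell's ordering.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  -- deliberately imported after the env var is set
```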
@@ -122,15 +124,32 @@
"execution_count": 5,
"id": "08ba49d8-8c22-4eba-a2ab-46eee839287f",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"using device: cuda\n"
]
}
],
"source": [
"# use bfloat16 for the entire notebook\n",
"torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16).__enter__()\n",
"# select the device for computation\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
"print(f\"using device: {device}\")\n",
"\n",
"if torch.cuda.get_device_properties(0).major >= 8:\n",
"if device.type == \"cuda\":\n",
" # use bfloat16 for the entire notebook\n",
" torch.autocast(\"cuda\", dtype=torch.bfloat16).__enter__()\n",
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
" torch.backends.cudnn.allow_tf32 = True"
" if torch.cuda.get_device_properties(0).major >= 8:\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
" torch.backends.cudnn.allow_tf32 = True"
]
},
{
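Stripped of the notebook's JSON escaping, the new cell's device-selection logic reads as follows (the same code as in the diff, reformatted as a plain script):

```python
import torch

# select the device for computation: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook (CUDA-only, hence the device check)
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (compute capability 8.x and newer)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
```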
@@ -153,7 +172,7 @@
"sam2_checkpoint = \"../checkpoints/sam2_hiera_large.pt\"\n",
"model_cfg = \"sam2_hiera_l.yaml\"\n",
"\n",
"predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)"
"predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)"
]
},
{
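For reference, a hedged usage sketch combining the device selection above with the predictor construction shown in this hunk (checkpoint and config paths as in the notebook):

```python
import torch
from sam2.build_sam import build_sam2_video_predictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

# The model is moved to the chosen device at build time rather than
# implicitly assuming CUDA.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
```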
4 changes: 3 additions & 1 deletion sam2/automatic_mask_generator.py
@@ -284,7 +284,9 @@ def _process_batch(
orig_h, orig_w = orig_size

# Run model on this batch
points = torch.as_tensor(points, device=self.predictor.device)
points = torch.as_tensor(
points, dtype=torch.float32, device=self.predictor.device
)
in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size
)
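The explicit `dtype=torch.float32` matters off-CUDA: `torch.as_tensor` otherwise preserves the input's dtype, and the mask generator's point grids are NumPy arrays, which default to float64 -- a dtype MPS does not support. A small sketch of the pitfall this avoids (only exercises MPS where available):

```python
import numpy as np
import torch

points = np.random.rand(8, 2)  # NumPy defaults to float64

if torch.backends.mps.is_available():
    mps = torch.device("mps")
    # torch.as_tensor(points, device=mps) would try to create a float64
    # tensor, which MPS rejects; requesting float32 up front avoids it.
    t = torch.as_tensor(points, dtype=torch.float32, device=mps)
    print(t.dtype)  # torch.float32
```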
2 changes: 1 addition & 1 deletion sam2/modeling/position_encoding.py
@@ -211,6 +211,6 @@ def apply_rotary_enc(
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
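The rewrite is value-for-value identical to the old `repeat`: both tile `freqs_cis` `r` times along the sequence dimension. The `unsqueeze/expand/flatten` form presumably sidesteps `Tensor.repeat` on a complex tensor, which not every backend handles. A quick equivalence check, assuming the 4-D layout the new code expects (shapes are illustrative):

```python
import torch

B, H, S, D, r = 2, 3, 4, 8, 5
# Complex frequency tensor, as used for rotary position encoding.
freqs_cis = torch.polar(torch.ones(B, H, S, D), torch.randn(B, H, S, D))

old = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
new = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
assert torch.equal(old, new)  # identical values, no floating-point drift
```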
4 changes: 2 additions & 2 deletions sam2/modeling/sam2_base.py
@@ -567,10 +567,10 @@ def _prepare_memory_conditioned_features(
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True)
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (
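Swapping `.cuda(...)` for `.to(device, ...)` is the migration in miniature: `.to` is a no-op when the tensor already lives on `device`, and `non_blocking=True` still overlaps pinned CPU-to-CUDA copies while degrading to an ordinary synchronous copy elsewhere. A minimal sketch of the pattern:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for "maskmem_features" that a demo offloaded to CPU.
feats = torch.randn(1, 64, 32, 32)

# No-op if feats is already on `device`; async for pinned CPU -> CUDA copies.
feats = feats.to(device, non_blocking=True)
```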
12 changes: 8 additions & 4 deletions sam2/sam2_video_predictor.py
@@ -45,11 +45,13 @@ def init_state(
async_loading_frames=False,
):
"""Initialize a inference state."""
compute_device = next(self.parameters()).device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
@@ -65,11 +67,11 @@
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
inference_state["device"] = compute_device
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
inference_state["storage_device"] = compute_device
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
@@ -270,7 +272,8 @@ def add_new_points_or_box(
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
@@ -793,7 +796,8 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
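Taken together, `init_state` no longer hard-codes CUDA anywhere: the compute device comes from the model's parameters, and the storage device honors `offload_state_to_cpu`. A hypothetical non-CUDA session, continuing from the predictor built earlier (frame directory assumed):

```python
# Frames, cached features, and mask logits now stay on the model's device,
# or on CPU when offloading is requested.
inference_state = predictor.init_state(
    video_path="./videos/bedroom",  # assumed directory of JPEG frames
    offload_video_to_cpu=True,      # keep the frame tensor on CPU
    offload_state_to_cpu=True,      # keep per-frame state on CPU as well
)
```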
27 changes: 21 additions & 6 deletions sam2/utils/misc.py
@@ -106,7 +106,15 @@ class AsyncVideoFrameLoader:
A list of video frames to be loaded asynchronously without blocking session start.
"""

def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
def __init__(
self,
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
):
self.img_paths = img_paths
self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu
@@ -119,6 +127,7 @@ def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_st
# video_height and video_width will be filled when loading the first image
self.video_height = None
self.video_width = None
self.compute_device = compute_device

# load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ def __getitem__(self, index):
img -= self.img_mean
img /= self.img_std
if not self.offload_video_to_cpu:
img = img.cuda(non_blocking=True)
img = img.to(self.compute_device, non_blocking=True)
self.images[index] = img
return img

@@ -167,6 +176,7 @@ def load_video_frames(
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
compute_device=torch.device("cuda"),
):
"""
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -196,17 +206,22 @@

if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
)
return lazy_images, lazy_images.video_height, lazy_images.video_width

images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
img_std = img_std.to(compute_device)
# normalize by mean and std
images -= img_mean
images /= img_std
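Because the new parameter defaults to `torch.device("cuda")`, existing callers are unaffected; non-CUDA callers opt in explicitly. A hypothetical direct call (frame directory and values assumed):

```python
import torch
from sam2.utils.misc import load_video_frames

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

images, video_height, video_width = load_video_frames(
    video_path="./videos/bedroom",  # directory of <frame_index>.jpg files
    image_size=1024,                # assumed SAM 2 input resolution
    offload_video_to_cpu=False,     # normalize and keep frames on `device`
    compute_device=device,
)
```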
