diff --git a/.gitignore b/.gitignore index 50b9875ec..121d46aa5 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build/* _C.* outputs/* checkpoints/*.pt +demo/backend/checkpoints/*.pt diff --git a/INSTALL.md b/INSTALL.md index 7f32564f7..9480ba1bb 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -2,7 +2,7 @@ ### Requirements -- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this. +- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this. * Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`. - [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command. - If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu. @@ -121,9 +121,9 @@ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version. -In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`. +In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`. -We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0. +We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
diff --git a/README.md b/README.md index 65654f5a0..85a7eb958 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,12 @@ ## Latest updates +**12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking** + +- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor`, leading to a major speedup for VOS inference. +- We update the implementation of `SAM2VideoPredictor` to support independent per-object inference, allowing us to relax the assumption of prompting for multi-object tracking and adding new objects after tracking starts. +- See [`RELEASE_NOTES.md`](RELEASE_NOTES.md) for full details. + **09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released** - A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details. @@ -23,7 +29,7 @@ ## Installation -SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using: +SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.5.1` and `torchvision>=0.20.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using: ```bash git clone https://github.com/facebookresearch/sam2.git && cd sam2 @@ -39,7 +45,7 @@ pip install -e ".[notebooks]" ``` Note: -1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`. +1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.5.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.5.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`. 2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version. 3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases). @@ -158,10 +164,10 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024. | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | -| sam2.1_hiera_tiny
([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 47.2 | 76.5 | 71.8 | 77.3 | -| sam2.1_hiera_small
([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 76.6 | 73.5 | 78.3 | -| sam2.1_hiera_base_plus
([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 78.2 | 73.7 | 78.2 | -| sam2.1_hiera_large
([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 79.5 | 74.6 | 80.6 | +| sam2.1_hiera_tiny
([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 | +| sam2.1_hiera_small
([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 | +| sam2.1_hiera_base_plus
([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 | +| sam2.1_hiera_large
([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 | ### SAM 2 checkpoints @@ -169,13 +175,12 @@ The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | -| sam2_hiera_tiny
([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 | -| sam2_hiera_small
([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 | -| sam2_hiera_base_plus
([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 | -| sam2_hiera_large
([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 | - -\* Compile the model by setting `compile_image_encoder: True` in the config. +| sam2_hiera_tiny
([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 | +| sam2_hiera_small
([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 | +| sam2_hiera_base_plus
([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 | +| sam2_hiera_large
([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 | +Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config). ## Segment Anything Video Dataset See [sav_dataset/README.md](sav_dataset/README.md) for details. diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md new file mode 100644 index 000000000..ee65ae7f4 --- /dev/null +++ b/RELEASE_NOTES.md @@ -0,0 +1,27 @@ +## SAM 2 release notes + +### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking + +- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`). + * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS. + * In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag. + * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model. + * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts. +- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`: + * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features. + * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage). + * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class. + +### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released + +- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details. + * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below. +- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started. +- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details. + +### 07/29/2024 -- SAM 2 is released + +- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos. + * SAM 2 code: https://github.com/facebookresearch/sam2 + * SAM 2 demo: https://sam2.metademolab.com/ + * SAM 2 paper: https://arxiv.org/abs/2408.00714 diff --git a/backend.Dockerfile b/backend.Dockerfile index adec61d56..54a32967b 100644 --- a/backend.Dockerfile +++ b/backend.Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime +ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime ARG MODEL_SIZE=base_plus FROM ${BASE_IMAGE} diff --git a/demo/README.md b/demo/README.md index 2abe2aa0d..2f80be7a5 100644 --- a/demo/README.md +++ b/demo/README.md @@ -105,7 +105,7 @@ cd demo/backend/server/ ```bash PYTORCH_ENABLE_MPS_FALLBACK=1 \ APP_ROOT="$(pwd)/../../../" \ -APP_URL=http://localhost:7263 \ +API_URL=http://localhost:7263 \ MODEL_SIZE=base_plus \ DATA_PATH="$(pwd)/../../data" \ DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \ diff --git a/pyproject.toml b/pyproject.toml index f7e865232..f84317dbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ "setuptools>=61.0", - "torch>=2.3.1", + "torch>=2.5.1", ] build-backend = "setuptools.build_meta" diff --git a/sam2/benchmark.py b/sam2/benchmark.py new file mode 100644 index 000000000..6519534c8 --- /dev/null +++ b/sam2/benchmark.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time + +import numpy as np +import torch +from tqdm import tqdm + +from sam2.build_sam import build_sam2_video_predictor + +# Only cuda supported +assert torch.cuda.is_available() +device = torch.device("cuda") + +torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() +if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +# Config and checkpoint +sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt" +model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml" + +# Build video predictor with vos_optimized=True setting +predictor = build_sam2_video_predictor( + model_cfg, sam2_checkpoint, device=device, vos_optimized=True +) + + +# Initialize with video +video_dir = "notebooks/videos/bedroom" +# scan all the JPEG frame names in this directory +frame_names = [ + p + for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] +] +frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) +inference_state = predictor.init_state(video_path=video_dir) + + +# Number of runs, warmup etc +warm_up, runs = 5, 25 +verbose = True +num_frames = len(frame_names) +total, count = 0, 0 +torch.cuda.empty_cache() + +# We will select an object with a click. +# See video_predictor_example.ipynb for more detailed explanation +ann_frame_idx, ann_obj_id = 0, 1 +# Add a positive click at (x, y) = (210, 350) +# For labels, `1` means positive click +points = np.array([[210, 350]], dtype=np.float32) +labels = np.array([1], np.int32) + +_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, +) + +# Warmup and then average FPS over several runs +with torch.autocast("cuda", torch.bfloat16): + with torch.inference_mode(): + for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"): + start = time.time() + # Start tracking + for ( + out_frame_idx, + out_obj_ids, + out_mask_logits, + ) in predictor.propagate_in_video(inference_state): + pass + + end = time.time() + total += end - start + count += 1 + if i == warm_up - 1: + print("Warmup FPS: ", count * num_frames / total) + total = 0 + count = 0 + +print("FPS: ", count * num_frames / total) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index 7cfc45139..3a3bef1e5 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -104,11 +104,18 @@ def build_sam2_video_predictor( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + vos_optimized=False, **kwargs, ): hydra_overrides = [ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", ] + if vos_optimized: + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS", + "++model.compile_image_encoder=True", # Let sam2_base handle this + ] + if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() hydra_overrides_extra += [ diff --git a/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml index cbee3cf9b..d7172f9b0 100644 --- a/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +++ b/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml @@ -36,7 +36,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -47,7 +47,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/sam2/configs/sam2.1/sam2.1_hiera_l.yaml index 33c9097f3..23073ea7a 100644 --- a/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +++ b/sam2/configs/sam2.1/sam2.1_hiera_l.yaml @@ -40,7 +40,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -51,7 +51,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/sam2/configs/sam2.1/sam2.1_hiera_s.yaml index 8e803dfea..fd8d40465 100644 --- a/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +++ b/sam2/configs/sam2.1/sam2.1_hiera_s.yaml @@ -39,7 +39,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -50,7 +50,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/sam2/configs/sam2.1/sam2.1_hiera_t.yaml index 983c2ea03..e762aec93 100644 --- a/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +++ b/sam2/configs/sam2.1/sam2.1_hiera_t.yaml @@ -39,7 +39,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -50,7 +50,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml index 204679146..9b6faa79f 100644 --- a/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +++ b/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml @@ -97,7 +97,7 @@ trainer: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -108,7 +108,7 @@ trainer: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2/sam2_hiera_b+.yaml b/sam2/configs/sam2/sam2_hiera_b+.yaml index 58f3eb815..0f435af02 100644 --- a/sam2/configs/sam2/sam2_hiera_b+.yaml +++ b/sam2/configs/sam2/sam2_hiera_b+.yaml @@ -36,7 +36,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -47,7 +47,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2/sam2_hiera_l.yaml b/sam2/configs/sam2/sam2_hiera_l.yaml index 918667f50..1092802b1 100644 --- a/sam2/configs/sam2/sam2_hiera_l.yaml +++ b/sam2/configs/sam2/sam2_hiera_l.yaml @@ -40,7 +40,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -51,7 +51,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2/sam2_hiera_s.yaml b/sam2/configs/sam2/sam2_hiera_s.yaml index 26e5d4d39..174e414f1 100644 --- a/sam2/configs/sam2/sam2_hiera_s.yaml +++ b/sam2/configs/sam2/sam2_hiera_s.yaml @@ -39,7 +39,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -50,7 +50,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2/sam2_hiera_t.yaml b/sam2/configs/sam2/sam2_hiera_t.yaml index a62c903aa..121447aab 100644 --- a/sam2/configs/sam2/sam2_hiera_t.yaml +++ b/sam2/configs/sam2/sam2_hiera_t.yaml @@ -39,7 +39,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -50,7 +50,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py index 32d55c754..930b1b762 100644 --- a/sam2/modeling/backbones/utils.py +++ b/sam2/modeling/backbones/utils.py @@ -32,9 +32,7 @@ def window_partition(x, window_size): Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) - windows = ( - x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - ) + windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) return windows, (Hp, Wp) @@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw): Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) - x = windows.view( + x = windows.reshape( B, Hp // window_size, Wp // window_size, window_size, window_size, -1 ) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) if Hp > H or Wp > W: - x = x[:, :H, :W, :].contiguous() + x = x[:, :H, :W, :] return x diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py index 52ac22674..2241d4cf1 100644 --- a/sam2/modeling/position_encoding.py +++ b/sam2/modeling/position_encoding.py @@ -25,6 +25,11 @@ def __init__( temperature: int = 10000, normalize: bool = True, scale: Optional[float] = None, + # Following settings only relevant + # for warmping up cache for compilation + warmup_cache: bool = True, + image_size: int = 1024, + strides: Tuple[int] = (4, 8, 16, 32), ): super().__init__() assert num_pos_feats % 2 == 0, "Expecting even model width" @@ -38,6 +43,12 @@ def __init__( self.scale = scale self.cache = {} + if warmup_cache and torch.cuda.is_available(): + # Warmup cache for cuda, to help with compilation + device = torch.device("cuda") + for stride in strides: + cache_key = (image_size // stride, image_size // stride) + self._pe(1, device, *cache_key) def _encode_xy(self, x, y): # The positions are expected to be normalized @@ -76,19 +87,20 @@ def encode_points(self, x, y, labels): return pos @torch.no_grad() - def forward(self, x: torch.Tensor): - cache_key = (x.shape[-2], x.shape[-1]) + def _pe(self, B, device, *cache_key): + H, W = cache_key if cache_key in self.cache: - return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) + y_embed = ( - torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + torch.arange(1, H + 1, dtype=torch.float32, device=device) .view(1, -1, 1) - .repeat(x.shape[0], 1, x.shape[-1]) + .repeat(B, 1, W) ) x_embed = ( - torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + torch.arange(1, W + 1, dtype=torch.float32, device=device) .view(1, 1, -1) - .repeat(x.shape[0], x.shape[-2], 1) + .repeat(B, H, 1) ) if self.normalize: @@ -96,7 +108,7 @@ def forward(self, x: torch.Tensor): y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -111,6 +123,12 @@ def forward(self, x: torch.Tensor): self.cache[cache_key] = pos[0] return pos + @torch.no_grad() + def forward(self, x: torch.Tensor): + B = x.shape[0] + cache_key = (x.shape[-2], x.shape[-1]) + return self._pe(B, x.device, *cache_key) + class PositionEmbeddingRandom(nn.Module): """ diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py index 6b3bbb95b..c57876264 100644 --- a/sam2/modeling/sam/prompt_encoder.py +++ b/sam2/modeling/sam/prompt_encoder.py @@ -92,12 +92,32 @@ def _embed_points( point_embedding = self.pe_layer.forward_with_coords( points, self.input_image_size ) - point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight - point_embedding[labels == 0] += self.point_embeddings[0].weight - point_embedding[labels == 1] += self.point_embeddings[1].weight - point_embedding[labels == 2] += self.point_embeddings[2].weight - point_embedding[labels == 3] += self.point_embeddings[3].weight + + point_embedding = torch.where( + (labels == -1).unsqueeze(-1), + torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 0).unsqueeze(-1), + point_embedding + self.point_embeddings[0].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 1).unsqueeze(-1), + point_embedding + self.point_embeddings[1].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 2).unsqueeze(-1), + point_embedding + self.point_embeddings[2].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 3).unsqueeze(-1), + point_embedding + self.point_embeddings[3].weight, + point_embedding, + ) return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py index b5b6fa2f8..f9fe9a3fb 100644 --- a/sam2/modeling/sam/transformer.py +++ b/sam2/modeling/sam/transformer.py @@ -4,9 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import contextlib import math -import warnings from functools import partial from typing import Tuple, Type @@ -16,29 +14,6 @@ from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis from sam2.modeling.sam2_utils import MLP -from sam2.utils.misc import get_sdpa_settings - -warnings.simplefilter(action="ignore", category=FutureWarning) -# Check whether Flash Attention is available (and use it by default) -OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() -# A fallback setting to allow all available kernels if Flash Attention fails -ALLOW_ALL_KERNELS = False - - -def sdp_kernel_context(dropout_p): - """ - Get the context for the attention scaled dot-product kernel. We use Flash Attention - by default, but fall back to all available kernels if Flash Attention fails. - """ - if ALLOW_ALL_KERNELS: - return contextlib.nullcontext() - - return torch.backends.cuda.sdp_kernel( - enable_flash=USE_FLASH_ATTN, - # if Flash attention kernel is off, then math kernel needs to be enabled - enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, - enable_mem_efficient=OLD_GPU, - ) class TwoWayTransformer(nn.Module): @@ -265,20 +240,7 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: - # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower speed).", - category=UserWarning, - stacklevel=2, - ) - global ALLOW_ALL_KERNELS - ALLOW_ALL_KERNELS = True - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) @@ -296,7 +258,7 @@ def __init__( # whether to repeat q rope to match k length # this is needed for cross-attention to memories rope_k_repeat=False, - feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution **kwargs, ): super().__init__(*args, **kwargs) @@ -305,7 +267,9 @@ def __init__( compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta ) freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) - self.freqs_cis = freqs_cis + self.freqs_cis = ( + freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis + ) self.rope_k_repeat = rope_k_repeat def forward( @@ -339,20 +303,7 @@ def forward( dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: - # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower speed).", - category=UserWarning, - stacklevel=2, - ) - global ALLOW_ALL_KERNELS - ALLOW_ALL_KERNELS = True - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py index a5d243adc..8aa1a0b11 100644 --- a/sam2/modeling/sam2_base.py +++ b/sam2/modeling/sam2_base.py @@ -628,7 +628,11 @@ def _prepare_memory_conditioned_features( if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = torch.tensor(pos_list, device=device) + obj_pos = ( + torch.tensor(pos_list) + .pin_memory() + .to(device=device, non_blocking=True) + ) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index c7e01ccf9..463770667 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -8,6 +8,7 @@ from collections import OrderedDict import torch +import torch.nn.functional as F from tqdm import tqdm @@ -26,8 +27,6 @@ def __init__( # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) clear_non_cond_mem_around_input=False, - # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). - clear_non_cond_mem_for_multi_obj=False, # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames add_all_frames_to_correct_as_cond=False, @@ -37,7 +36,6 @@ def __init__( self.fill_hole_area = fill_hole_area self.non_overlap_masks = non_overlap_masks self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input - self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond @torch.inference_mode() @@ -87,11 +85,6 @@ def init_state( inference_state["obj_id_to_idx"] = OrderedDict() inference_state["obj_idx_to_id"] = OrderedDict() inference_state["obj_ids"] = [] - # A storage to hold the model's tracking results and states on each frame - inference_state["output_dict"] = { - "cond_frame_outputs": {}, # dict containing {frame_idx: } - "non_cond_frame_outputs": {}, # dict containing {frame_idx: } - } # Slice (view) of each object tracking results, sharing the same memory with "output_dict" inference_state["output_dict_per_obj"] = {} # A temporary storage to hold new outputs when user interact with a frame @@ -99,13 +92,8 @@ def init_state( inference_state["temp_output_dict_per_obj"] = {} # Frames that already holds consolidated outputs from click or mask inputs # (we directly use their consolidated outputs during tracking) - inference_state["consolidated_frame_inds"] = { - "cond_frame_outputs": set(), # set containing frame indices - "non_cond_frame_outputs": set(), # set containing frame indices - } # metadata for each tracking frame (e.g. which direction it's tracked) - inference_state["tracking_has_started"] = False - inference_state["frames_already_tracked"] = {} + inference_state["frames_tracked_per_obj"] = {} # Warm up the visual backbone and cache the image feature on frame 0 self._get_image_feature(inference_state, frame_idx=0, batch_size=1) return inference_state @@ -133,9 +121,8 @@ def _obj_id_to_idx(self, inference_state, obj_id): if obj_idx is not None: return obj_idx - # This is a new object id not sent to the server before. We only allow adding - # new objects *before* the tracking starts. - allow_new_object = not inference_state["tracking_has_started"] + # We always allow adding new objects (including after tracking starts). + allow_new_object = True if allow_new_object: # get the next object slot obj_idx = len(inference_state["obj_id_to_idx"]) @@ -153,6 +140,7 @@ def _obj_id_to_idx(self, inference_state, obj_id): "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } + inference_state["frames_tracked_per_obj"][obj_idx] = {} return obj_idx else: raise RuntimeError( @@ -213,15 +201,6 @@ def add_new_points_or_box( "box prompt must be provided before any point prompt " "(please use clear_old_points=True instead)" ) - if inference_state["tracking_has_started"]: - warnings.warn( - "You are adding a box after tracking starts. SAM 2 may not always be " - "able to incorporate a box prompt for *refinement*. If you intend to " - "use box prompt as an *initial* input before tracking, please call " - "'reset_state' on the inference state to restart from scratch.", - category=UserWarning, - stacklevel=2, - ) if not isinstance(box, torch.Tensor): box = torch.tensor(box, dtype=torch.float32, device=points.device) box_coords = box.reshape(1, 2, 2) @@ -251,12 +230,13 @@ def add_new_points_or_box( # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. - is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx] + is_init_cond_frame = frame_idx not in obj_frames_tracked # whether to track in reverse time order if is_init_cond_frame: reverse = False else: - reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + reverse = obj_frames_tracked[frame_idx]["reverse"] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or @@ -305,7 +285,6 @@ def add_new_points_or_box( inference_state, frame_idx, is_cond=is_cond, - run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( @@ -356,12 +335,13 @@ def add_new_mask( # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. - is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx] + is_init_cond_frame = frame_idx not in obj_frames_tracked # whether to track in reverse time order if is_init_cond_frame: reverse = False else: - reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + reverse = obj_frames_tracked[frame_idx]["reverse"] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or @@ -393,7 +373,6 @@ def add_new_mask( inference_state, frame_idx, is_cond=is_cond, - run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( @@ -428,7 +407,6 @@ def _consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond, - run_mem_encoder, consolidate_at_video_res=False, ): """ @@ -445,7 +423,6 @@ def _consolidate_temp_output_across_obj( # Optionally, we allow consolidating the temporary outputs at the original # video resolution (to provide a better editing experience for mask prompts). if consolidate_at_video_res: - assert not run_mem_encoder, "memory encoder cannot run at video resolution" consolidated_H = inference_state["video_height"] consolidated_W = inference_state["video_width"] consolidated_mask_key = "pred_masks_video_res" @@ -458,30 +435,13 @@ def _consolidate_temp_output_across_obj( # constraints to object scores. Its "pred_masks" are prefilled with a large # negative value (NO_OBJ_SCORE) to represent missing objects. consolidated_out = { - "maskmem_features": None, - "maskmem_pos_enc": None, consolidated_mask_key: torch.full( size=(batch_size, 1, consolidated_H, consolidated_W), fill_value=NO_OBJ_SCORE, dtype=torch.float32, device=inference_state["storage_device"], ), - "obj_ptr": torch.full( - size=(batch_size, self.hidden_dim), - fill_value=NO_OBJ_SCORE, - dtype=torch.float32, - device=inference_state["device"], - ), - "object_score_logits": torch.full( - size=(batch_size, 1), - # default to 10.0 for object_score_logits, i.e. assuming the object is - # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` - fill_value=10.0, - dtype=torch.float32, - device=inference_state["device"], - ), } - empty_mask_ptr = None for obj_idx in range(batch_size): obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] @@ -498,16 +458,6 @@ def _consolidate_temp_output_across_obj( # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE # placeholder above) and set its object pointer to be a dummy pointer. if out is None: - # Fill in dummy object pointers for those objects without any inputs or - # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, - # i.e. when we need to build the memory for tracking). - if run_mem_encoder: - if empty_mask_ptr is None: - empty_mask_ptr = self._get_empty_mask_ptr( - inference_state, frame_idx - ) - # fill object pointer with a dummy pointer (based on an empty mask) - consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr continue # Add the temporary object output mask to consolidated output mask obj_mask = out["pred_masks"] @@ -523,141 +473,74 @@ def _consolidate_temp_output_across_obj( align_corners=False, ) consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask - consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] - consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ - "object_score_logits" - ] - - # Optionally, apply non-overlapping constraints on the consolidated scores - # and rerun the memory encoder - if run_mem_encoder: - device = inference_state["device"] - high_res_masks = torch.nn.functional.interpolate( - consolidated_out["pred_masks"].to(device, non_blocking=True), - size=(self.image_size, self.image_size), - mode="bilinear", - align_corners=False, - ) - if self.non_overlap_masks_for_mem_enc: - high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) - maskmem_features, maskmem_pos_enc = self._run_memory_encoder( - inference_state=inference_state, - frame_idx=frame_idx, - batch_size=batch_size, - high_res_masks=high_res_masks, - object_score_logits=consolidated_out["object_score_logits"], - is_mask_from_pts=True, # these frames are what the user interacted with - ) - consolidated_out["maskmem_features"] = maskmem_features - consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc return consolidated_out - def _get_empty_mask_ptr(self, inference_state, frame_idx): - """Get a dummy object pointer based on an empty mask on the current frame.""" - # A dummy (empty) mask with a single object - batch_size = 1 - mask_inputs = torch.zeros( - (batch_size, 1, self.image_size, self.image_size), - dtype=torch.float32, - device=inference_state["device"], - ) - - # Retrieve correct image features - ( - _, - _, - current_vision_feats, - current_vision_pos_embeds, - feat_sizes, - ) = self._get_image_feature(inference_state, frame_idx, batch_size) - - # Feed the empty mask and image feature above to get a dummy object pointer - current_out = self.track_step( - frame_idx=frame_idx, - is_init_cond_frame=True, - current_vision_feats=current_vision_feats, - current_vision_pos_embeds=current_vision_pos_embeds, - feat_sizes=feat_sizes, - point_inputs=None, - mask_inputs=mask_inputs, - output_dict={}, - num_frames=inference_state["num_frames"], - track_in_reverse=False, - run_mem_encoder=False, - prev_sam_mask_logits=None, - ) - return current_out["obj_ptr"] - @torch.inference_mode() def propagate_in_video_preflight(self, inference_state): """Prepare inference_state and consolidate temporary outputs before tracking.""" - # Tracking has started and we don't allow adding new objects until session is reset. - inference_state["tracking_has_started"] = True + # Check and make sure that every object has received input points or masks. batch_size = self._get_obj_num(inference_state) + if batch_size == 0: + raise RuntimeError( + "No input points or masks are provided for any object; please add inputs first." + ) # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and # add them into "output_dict". - temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] - output_dict = inference_state["output_dict"] - # "consolidated_frame_inds" contains indices of those frames where consolidated - # temporary outputs have been added (either in this call or any previous calls - # to `propagate_in_video_preflight`). - consolidated_frame_inds = inference_state["consolidated_frame_inds"] - for is_cond in [False, True]: - # Separately consolidate conditioning and non-conditioning temp outputs - storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" - # Find all the frames that contain temporary outputs for any objects - # (these should be the frames that have just received clicks for mask inputs - # via `add_new_points_or_box` or `add_new_mask`) - temp_frame_inds = set() - for obj_temp_output_dict in temp_output_dict_per_obj.values(): - temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) - consolidated_frame_inds[storage_key].update(temp_frame_inds) - # consolidate the temporary output across all objects on this frame - for frame_idx in temp_frame_inds: - consolidated_out = self._consolidate_temp_output_across_obj( - inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True - ) - # merge them into "output_dict" and also create per-object slices - output_dict[storage_key][frame_idx] = consolidated_out - self._add_output_per_object( - inference_state, frame_idx, consolidated_out, storage_key - ) - clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( - self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = ( + "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" ) - if clear_non_cond_mem: - # clear non-conditioning memory of the surrounding frames - self._clear_non_cond_mem_around_input(inference_state, frame_idx) + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + for frame_idx, out in obj_temp_output_dict[storage_key].items(): + # Run memory encoder on the temporary outputs (if the memory feature is missing) + if out["maskmem_features"] is None: + high_res_masks = torch.nn.functional.interpolate( + out["pred_masks"].to(inference_state["device"]), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + high_res_masks=high_res_masks, + object_score_logits=out["object_score_logits"], + # these frames are what the user interacted with + is_mask_from_pts=True, + ) + out["maskmem_features"] = maskmem_features + out["maskmem_pos_enc"] = maskmem_pos_enc + + obj_output_dict[storage_key][frame_idx] = out + if self.clear_non_cond_mem_around_input: + # clear non-conditioning memory of the surrounding frames + self._clear_obj_non_cond_mem_around_input( + inference_state, frame_idx, obj_idx + ) - # clear temporary outputs in `temp_output_dict_per_obj` - for obj_temp_output_dict in temp_output_dict_per_obj.values(): + # clear temporary outputs in `temp_output_dict_per_obj` obj_temp_output_dict[storage_key].clear() - # edge case: if an output is added to "cond_frame_outputs", we remove any prior - # output on the same frame in "non_cond_frame_outputs" - for frame_idx in output_dict["cond_frame_outputs"]: - output_dict["non_cond_frame_outputs"].pop(frame_idx, None) - for obj_output_dict in inference_state["output_dict_per_obj"].values(): + # check and make sure that every object has received input points or masks + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + if len(obj_output_dict["cond_frame_outputs"]) == 0: + obj_id = self._obj_idx_to_id(inference_state, obj_idx) + raise RuntimeError( + f"No input points or masks are provided for object id {obj_id}; please add inputs first." + ) + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" for frame_idx in obj_output_dict["cond_frame_outputs"]: obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) - for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: - assert frame_idx in output_dict["cond_frame_outputs"] - consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) - - # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames - # with either points or mask inputs (which should be true under a correct workflow). - all_consolidated_frame_inds = ( - consolidated_frame_inds["cond_frame_outputs"] - | consolidated_frame_inds["non_cond_frame_outputs"] - ) - input_frames_inds = set() - for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): - input_frames_inds.update(point_inputs_per_frame.keys()) - for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): - input_frames_inds.update(mask_inputs_per_frame.keys()) - assert all_consolidated_frame_inds == input_frames_inds @torch.inference_mode() def propagate_in_video( @@ -670,21 +553,18 @@ def propagate_in_video( """Propagate the input points across frames to track in the entire video.""" self.propagate_in_video_preflight(inference_state) - output_dict = inference_state["output_dict"] - consolidated_frame_inds = inference_state["consolidated_frame_inds"] obj_ids = inference_state["obj_ids"] num_frames = inference_state["num_frames"] batch_size = self._get_obj_num(inference_state) - if len(output_dict["cond_frame_outputs"]) == 0: - raise RuntimeError("No points are provided; please add points first") - clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( - self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 - ) # set start index, end index, and processing order if start_frame_idx is None: # default: start from the earliest frame with input points - start_frame_idx = min(output_dict["cond_frame_outputs"]) + start_frame_idx = min( + t + for obj_output_dict in inference_state["output_dict_per_obj"].values() + for t in obj_output_dict["cond_frame_outputs"] + ) if max_frame_num_to_track is None: # default: track all the frames in the video max_frame_num_to_track = num_frames @@ -701,78 +581,53 @@ def propagate_in_video( processing_order = range(start_frame_idx, end_frame_idx + 1) for frame_idx in tqdm(processing_order, desc="propagate in video"): - # We skip those frames already in consolidated outputs (these are frames - # that received input clicks or mask). Note that we cannot directly run - # batched forward on them via `_run_single_frame_inference` because the - # number of clicks on each object might be different. - if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: - storage_key = "cond_frame_outputs" - current_out = output_dict[storage_key][frame_idx] - pred_masks = current_out["pred_masks"] - if clear_non_cond_mem: - # clear non-conditioning memory of the surrounding frames - self._clear_non_cond_mem_around_input(inference_state, frame_idx) - elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: - storage_key = "non_cond_frame_outputs" - current_out = output_dict[storage_key][frame_idx] - pred_masks = current_out["pred_masks"] - else: - storage_key = "non_cond_frame_outputs" - current_out, pred_masks = self._run_single_frame_inference( - inference_state=inference_state, - output_dict=output_dict, - frame_idx=frame_idx, - batch_size=batch_size, - is_init_cond_frame=False, - point_inputs=None, - mask_inputs=None, - reverse=reverse, - run_mem_encoder=True, - ) - output_dict[storage_key][frame_idx] = current_out - # Create slices of per-object outputs for subsequent interaction with each - # individual object after tracking. - self._add_output_per_object( - inference_state, frame_idx, current_out, storage_key - ) - inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + pred_masks_per_obj = [None] * batch_size + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in obj_output_dict["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = obj_output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if self.clear_non_cond_mem_around_input: + # clear non-conditioning memory of the surrounding frames + self._clear_obj_non_cond_mem_around_input( + inference_state, frame_idx, obj_idx + ) + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + obj_output_dict[storage_key][frame_idx] = current_out + + inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { + "reverse": reverse + } + pred_masks_per_obj[obj_idx] = pred_masks # Resize the output mask to the original video resolution (we directly use # the mask scores on GPU for output to avoid any CPU conversion in between) + if len(pred_masks_per_obj) > 1: + all_pred_masks = torch.cat(pred_masks_per_obj, dim=0) + else: + all_pred_masks = pred_masks_per_obj[0] _, video_res_masks = self._get_orig_video_res_output( - inference_state, pred_masks + inference_state, all_pred_masks ) yield frame_idx, obj_ids, video_res_masks - def _add_output_per_object( - self, inference_state, frame_idx, current_out, storage_key - ): - """ - Split a multi-object output into per-object output slices and add them into - `output_dict_per_obj`. The resulting slices share the same tensor storage. - """ - maskmem_features = current_out["maskmem_features"] - assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) - - maskmem_pos_enc = current_out["maskmem_pos_enc"] - assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) - - output_dict_per_obj = inference_state["output_dict_per_obj"] - for obj_idx, obj_output_dict in output_dict_per_obj.items(): - obj_slice = slice(obj_idx, obj_idx + 1) - obj_out = { - "maskmem_features": None, - "maskmem_pos_enc": None, - "pred_masks": current_out["pred_masks"][obj_slice], - "obj_ptr": current_out["obj_ptr"][obj_slice], - "object_score_logits": current_out["object_score_logits"][obj_slice], - } - if maskmem_features is not None: - obj_out["maskmem_features"] = maskmem_features[obj_slice] - if maskmem_pos_enc is not None: - obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] - obj_output_dict[storage_key][frame_idx] = obj_out - @torch.inference_mode() def clear_all_prompts_in_frame( self, inference_state, frame_idx, obj_id, need_output=True @@ -788,41 +643,14 @@ def clear_all_prompts_in_frame( temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) - # Check and see if there are still any inputs left on this frame - batch_size = self._get_obj_num(inference_state) - frame_has_input = False - for obj_idx2 in range(batch_size): - if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: - frame_has_input = True - break - if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: - frame_has_input = True - break - - # If this frame has no remaining inputs for any objects, we further clear its - # conditioning frame status - if not frame_has_input: - output_dict = inference_state["output_dict"] - consolidated_frame_inds = inference_state["consolidated_frame_inds"] - consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) - consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) - # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) - out = output_dict["cond_frame_outputs"].pop(frame_idx, None) - if out is not None: - # The frame is not a conditioning frame anymore since it's not receiving inputs, - # so we "downgrade" its output (if exists) to a non-conditioning frame output. - output_dict["non_cond_frame_outputs"][frame_idx] = out - inference_state["frames_already_tracked"].pop(frame_idx, None) - # Similarly, do it for the sliced output on each object. - for obj_idx2 in range(batch_size): - obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] - obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) - if obj_out is not None: - obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out - - # If all the conditioning frames have been removed, we also clear the tracking outputs - if len(output_dict["cond_frame_outputs"]) == 0: - self._reset_tracking_results(inference_state) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + obj_output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None) if not need_output: return @@ -836,7 +664,6 @@ def clear_all_prompts_in_frame( inference_state, frame_idx, is_cond=is_cond, - run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( @@ -856,6 +683,7 @@ def reset_state(self, inference_state): inference_state["mask_inputs_per_obj"].clear() inference_state["output_dict_per_obj"].clear() inference_state["temp_output_dict_per_obj"].clear() + inference_state["frames_tracked_per_obj"].clear() def _reset_tracking_results(self, inference_state): """Reset all tracking inputs and results across the videos.""" @@ -869,12 +697,8 @@ def _reset_tracking_results(self, inference_state): for v in inference_state["temp_output_dict_per_obj"].values(): v["cond_frame_outputs"].clear() v["non_cond_frame_outputs"].clear() - inference_state["output_dict"]["cond_frame_outputs"].clear() - inference_state["output_dict"]["non_cond_frame_outputs"].clear() - inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() - inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() - inference_state["tracking_has_started"] = False - inference_state["frames_already_tracked"].clear() + for v in inference_state["frames_tracked_per_obj"].values(): + v.clear() def _get_image_feature(self, inference_state, frame_idx, batch_size): """Compute the image features on a given frame.""" @@ -1092,8 +916,6 @@ def remove_object(self, inference_state, obj_id, strict=False, need_output=True) inference_state["obj_ids"] = new_obj_ids # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. - # (note that "consolidated_frame_inds" doesn't need to be updated in this step as - # it's already handled in Step 0) def _map_keys(container): new_kvs = [] for k in old_obj_inds: @@ -1106,30 +928,9 @@ def _map_keys(container): _map_keys(inference_state["mask_inputs_per_obj"]) _map_keys(inference_state["output_dict_per_obj"]) _map_keys(inference_state["temp_output_dict_per_obj"]) + _map_keys(inference_state["frames_tracked_per_obj"]) - # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. - def _slice_state(output_dict, storage_key): - for frame_idx, out in output_dict[storage_key].items(): - out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] - out["maskmem_pos_enc"] = [ - x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] - ] - # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it - out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) - out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] - out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] - out["object_score_logits"] = out["object_score_logits"][ - remain_old_obj_inds - ] - # also update the per-object slices - self._add_output_per_object( - inference_state, frame_idx, out, storage_key - ) - - _slice_state(inference_state["output_dict"], "cond_frame_outputs") - _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") - - # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which # could show an updated mask for objects previously occluded by the object being removed if need_output: temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] @@ -1142,7 +943,6 @@ def _slice_state(output_dict, storage_key): inference_state, frame_idx, is_cond=is_cond, - run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( @@ -1164,9 +964,259 @@ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): r = self.memory_temporal_stride_for_eval frame_idx_begin = frame_idx - r * self.num_maskmem frame_idx_end = frame_idx + r * self.num_maskmem - output_dict = inference_state["output_dict"] - non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] - for t in range(frame_idx_begin, frame_idx_end + 1): - non_cond_frame_outputs.pop(t, None) - for obj_output_dict in inference_state["output_dict_per_obj"].values(): - obj_output_dict["non_cond_frame_outputs"].pop(t, None) + batch_size = self._get_obj_num(inference_state) + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + + +class SAM2VideoPredictorVOS(SAM2VideoPredictor): + """Optimized for the VOS setting""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._compile_all_components() + + def _compile_all_components(self): + print("Compiling all components for VOS setting. First time may be very slow.") + self.memory_encoder.forward = torch.compile( + self.memory_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + self.memory_attention.forward = torch.compile( + self.memory_attention.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, # Num. of memories varies + ) + + self.sam_prompt_encoder.forward = torch.compile( + self.sam_prompt_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + self.sam_mask_decoder.forward = torch.compile( + self.sam_mask_decoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + def forward_image(self, img_batch: torch.Tensor): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the backbone features and pos encoding to enable compilation. + """ + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + # Clone to help torch.compile + for i in range(len(backbone_out["backbone_fpn"])): + backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone() + backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][ + i + ].clone() + return backbone_out + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the outputs of prompt_encoder and mask_decoder to enable compilation. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + # Clone image_pe and the outputs of sam_prompt_encoder + # to enable compilation + sparse_embeddings = sparse_embeddings.clone() + dense_embeddings = dense_embeddings.clone() + image_pe = self.sam_prompt_encoder.get_dense_pe().clone() + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + # Clone the output of sam_mask_decoder + # to enable compilation + low_res_multimasks = low_res_multimasks.clone() + ious = ious.clone() + sam_output_tokens = sam_output_tokens.clone() + object_score_logits = object_score_logits.clone() + + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the memories and their pos enc to enable compilation. + """ + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + # Clone the feats and pos_enc to enable compilation + maskmem_features = maskmem_out["vision_features"].clone() + maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + + return maskmem_features, maskmem_pos_enc diff --git a/sam2/sam2_video_predictor_legacy.py b/sam2/sam2_video_predictor_legacy.py new file mode 100644 index 000000000..c7e01ccf9 --- /dev/null +++ b/sam2/sam2_video_predictor_legacy.py @@ -0,0 +1,1172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return sam_model + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ + "object_score_logits" + ] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + object_score_logits=consolidated_out["object_score_logits"], + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + "object_score_logits": current_out["object_score_logits"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def clear_all_prompts_in_frame( + self, inference_state, frame_idx, obj_id, need_output=True + ): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Check and see if there are still any inputs left on this frame + batch_size = self._get_obj_num(inference_state) + frame_has_input = False + for obj_idx2 in range(batch_size): + if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + + # If this frame has no remaining inputs for any objects, we further clear its + # conditioning frame status + if not frame_has_input: + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + out = output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_already_tracked"].pop(frame_idx, None) + # Similarly, do it for the sliced output on each object. + for obj_idx2 in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] + obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if obj_out is not None: + obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out + + # If all the conditioning frames have been removed, we also clear the tracking outputs + if len(output_dict["cond_frame_outputs"]) == 0: + self._reset_tracking_results(inference_state) + + if not need_output: + return + # Finally, output updated masks per object (after removing the inputs above) + obj_ids = inference_state["obj_ids"] + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + object_score_logits = current_out["object_score_logits"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + "object_score_logits": object_score_logits, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, + inference_state, + frame_idx, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + @torch.inference_mode() + def remove_object(self, inference_state, obj_id, strict=False, need_output=True): + """ + Remove an object id from the tracking state. If strict is True, we check whether + the object id actually exists and raise an error if it doesn't exist. + """ + old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) + updated_frames = [] + # Check whether this object_id to remove actually exists and possibly raise an error. + if old_obj_idx_to_rm is None: + if not strict: + return inference_state["obj_ids"], updated_frames + raise RuntimeError( + f"Cannot remove object id {obj_id} as it doesn't exist. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + # If this is the only remaining object id, we simply reset the state. + if len(inference_state["obj_id_to_idx"]) == 1: + self.reset_state(inference_state) + return inference_state["obj_ids"], updated_frames + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. + # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + obj_input_frames_inds = set() + obj_input_frames_inds.update( + inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] + ) + obj_input_frames_inds.update( + inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] + ) + for frame_idx in obj_input_frames_inds: + self.clear_all_prompts_in_frame( + inference_state, frame_idx, obj_id, need_output=False + ) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + # (note that "consolidated_frame_inds" doesn't need to be updated in this step as + # it's already handled in Step 0) + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + + # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. + def _slice_state(output_dict, storage_key): + for frame_idx, out in output_dict[storage_key].items(): + out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] + out["maskmem_pos_enc"] = [ + x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] + ] + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) + out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] + out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] + out["object_score_logits"] = out["object_score_logits"][ + remain_old_obj_inds + ] + # also update the per-object slices + self._add_output_per_object( + inference_state, frame_idx, out, storage_key + ) + + _slice_state(inference_state["output_dict"], "cond_frame_outputs") + _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") + + # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # could show an updated mask for objects previously occluded by the object being removed + if need_output: + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for frame_idx in obj_input_frames_inds: + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + updated_frames.append((frame_idx, video_res_masks)) + + return inference_state["obj_ids"], updated_frames + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/setup.py b/setup.py index c67a949fc..78a634cdd 100644 --- a/setup.py +++ b/setup.py @@ -22,8 +22,8 @@ # Required dependencies REQUIRED_PACKAGES = [ - "torch>=2.3.1", - "torchvision>=0.18.1", + "torch>=2.5.1", + "torchvision>=0.20.1", "numpy>=1.24.4", "tqdm>=4.66.1", "hydra-core>=1.3.2", @@ -58,7 +58,7 @@ "scikit-image>=0.24.0", "tensorboard>=2.17.0", "pycocotools>=2.0.8", - "tensordict>=0.5.0", + "tensordict>=0.6.0", "opencv-python>=4.7.0", "submitit>=1.5.1", ], diff --git a/tools/vos_inference.py b/tools/vos_inference.py index 5c40cda9e..ef3e8c674 100644 --- a/tools/vos_inference.py +++ b/tools/vos_inference.py @@ -375,7 +375,7 @@ def main(): parser.add_argument( "--sam2_checkpoint", type=str, - default="./checkpoints/sam2.1_hiera_b+.pt", + default="./checkpoints/sam2.1_hiera_base_plus.pt", help="path to the SAM 2 model checkpoint", ) parser.add_argument( @@ -434,6 +434,11 @@ def main(): help="whether to track objects that appear later in the video (i.e. not on the first frame; " "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)", ) + parser.add_argument( + "--use_vos_optimized_video_predictor", + action="store_true", + help="whether to use vos optimized video predictor with all modules compiled", + ) args = parser.parse_args() # if we use per-object PNG files, they could possibly overlap in inputs and outputs @@ -445,6 +450,7 @@ def main(): ckpt_path=args.sam2_checkpoint, apply_postprocessing=args.apply_postprocessing, hydra_overrides_extra=hydra_overrides_extra, + vos_optimized=args.use_vos_optimized_video_predictor, ) if args.use_all_masks: