Add support for Apple's MPS backend #123
Changes from 7 commits
setup.py

@@ -6,6 +6,7 @@
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import subprocess

# Package metadata
NAME = "SAM 2"

@@ -29,15 +30,26 @@
    "hydra-core>=1.3.2",
    "iopath>=0.1.10",
    "pillow>=9.4.0",
    "scipy>=1.14.0",
]

EXTRA_PACKAGES = {
    "demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"],
Review comments on this line:

When trying to work off this branch, I ran into an issue with […]. Separately, I had to comment out […].

I ran into the same problem with […].
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"], | ||
} | ||
|
||
def find_cuda(): | ||
try: | ||
subprocess.check_output(["nvcc", "--version"]) | ||
return True | ||
except subprocess.CalledProcessError: | ||
return False | ||
except FileNotFoundError: | ||
return False | ||
|
||
def get_extensions(): | ||
if not find_cuda(): | ||
return [] | ||
srcs = ["sam2/csrc/connected_components.cu"] | ||
compile_args = { | ||
"cxx": [], | ||
|
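The hunk ends partway through get_extensions(). For context, here is a minimal sketch of how the function presumably finishes and plugs into setup(), following the usual torch.utils.cpp_extension pattern; the extension name "sam2._C", the nvcc flags, and the setup() keywords are illustrative assumptions, not taken from this diff.

    # Sketch only; names not shown in the diff ("sam2._C", the nvcc flags,
    # and the setup() keyword values) are assumptions for illustration.
    def get_extensions():
        if not find_cuda():
            # Without a CUDA toolchain (e.g. on Apple Silicon), skip the CUDA
            # extension entirely so the package can still be installed.
            return []
        srcs = ["sam2/csrc/connected_components.cu"]
        compile_args = {"cxx": [], "nvcc": ["-O3"]}
        return [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]

    setup(
        name=NAME,
        extras_require=EXTRA_PACKAGES,
        ext_modules=get_extensions(),
        cmdclass={"build_ext": BuildExtension},
    )

Returning [] from get_extensions() leaves ext_modules empty on machines without nvcc, so the package can install on Apple Silicon without attempting a CUDA build.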
Review comment on this hunk:
You may want to consider using torch.backends.cuda.is_available() directly here instead of checking for nvcc.
Docs: https://pytorch.org/docs/stable/backends.html

Author reply:
Thank you for the advice! I have updated setup.py to use torch.cuda.is_available() instead of checking nvcc.
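A minimal sketch of what that updated check could look like, assuming find_cuda() is kept and simply delegates to PyTorch; the actual code in the later commits may differ.

    import torch

    def find_cuda():
        # Ask PyTorch directly whether CUDA is usable, instead of shelling
        # out to nvcc via subprocess.
        return torch.cuda.is_available()

Note the trade-off: torch.cuda.is_available() reports whether a usable CUDA runtime and device are present when setup.py runs, whereas the nvcc check only detects the compiler, so the two can disagree on machines that have the toolkit installed but no visible GPU.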