From 91f41750c99fa65254974e1f0f1bf9443a475a5b Mon Sep 17 00:00:00 2001 From: Ronghang Hu Date: Sat, 10 Aug 2024 03:15:37 +0000 Subject: [PATCH] move "PYTORCH_ENABLE_MPS_FALLBACK" to the beginning before importing torch --- notebooks/automatic_mask_generator_example.ipynb | 7 +++---- notebooks/image_predictor_example.ipynb | 7 +++---- notebooks/video_predictor_example.ipynb | 7 +++---- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/notebooks/automatic_mask_generator_example.ipynb b/notebooks/automatic_mask_generator_example.ipynb index a18f56cb..2ea68a7b 100644 --- a/notebooks/automatic_mask_generator_example.ipynb +++ b/notebooks/automatic_mask_generator_example.ipynb @@ -107,6 +107,8 @@ "outputs": [], "source": [ "import os\n", + "# if using Apple MPS, fall back to CPU for unsupported ops\n", + "os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", @@ -143,10 +145,7 @@ " # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n", " if torch.cuda.get_device_properties(0).major >= 8:\n", " torch.backends.cuda.matmul.allow_tf32 = True\n", - " torch.backends.cudnn.allow_tf32 = True\n", - "elif device.type == \"mps\":\n", - " # fall back to CPU for unsupported ops\n", - " os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"" + " torch.backends.cudnn.allow_tf32 = True" ] }, { diff --git a/notebooks/image_predictor_example.ipynb b/notebooks/image_predictor_example.ipynb index e2245182..0f27b54e 100644 --- a/notebooks/image_predictor_example.ipynb +++ b/notebooks/image_predictor_example.ipynb @@ -115,6 +115,8 @@ "outputs": [], "source": [ "import os\n", + "# if using Apple MPS, fall back to CPU for unsupported ops\n", + "os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", @@ -151,10 +153,7 @@ " # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n", " if torch.cuda.get_device_properties(0).major >= 8:\n", " torch.backends.cuda.matmul.allow_tf32 = True\n", - " torch.backends.cudnn.allow_tf32 = True\n", - "elif device.type == \"mps\":\n", - " # fall back to CPU for unsupported ops\n", - " os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"" + " torch.backends.cudnn.allow_tf32 = True" ] }, { diff --git a/notebooks/video_predictor_example.ipynb b/notebooks/video_predictor_example.ipynb index 23c4ad92..534e18d1 100644 --- a/notebooks/video_predictor_example.ipynb +++ b/notebooks/video_predictor_example.ipynb @@ -111,6 +111,8 @@ "outputs": [], "source": [ "import os\n", + "# if using Apple MPS, fall back to CPU for unsupported ops\n", + "os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", @@ -147,10 +149,7 @@ " # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n", " if torch.cuda.get_device_properties(0).major >= 8:\n", " torch.backends.cuda.matmul.allow_tf32 = True\n", - " torch.backends.cudnn.allow_tf32 = True\n", - "elif device.type == \"mps\":\n", - " # fall back to CPU for unsupported ops\n", - " os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"" + " torch.backends.cudnn.allow_tf32 = True" ] }, {