
Commit

move "PYTORCH_ENABLE_MPS_FALLBACK" to the beginning before importing …
Browse files Browse the repository at this point in the history
…torch
ronghanghu committed Aug 10, 2024
1 parent 5ee4cdb commit 91f4175
Showing 3 changed files with 9 additions and 12 deletions.
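Why the move matters (my reading; the commit message itself only states the change): PYTORCH_ENABLE_MPS_FALLBACK appears to be honored only if it is already in the environment when torch is imported, so assigning it later inside the elif device.type == "mps" branch, after import torch, came too late to take effect. A minimal sketch of the ordering the commit enforces:

import os

# Set the MPS fallback flag before importing torch; setting it after the
# import (as the removed code in the hunks below did) is presumably too late.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # unsupported MPS ops can now fall back to the CPU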
7 changes: 3 additions & 4 deletions notebooks/automatic_mask_generator_example.ipynb
@@ -107,6 +107,8 @@
 "outputs": [],
 "source": [
 "import os\n",
+"# if using Apple MPS, fall back to CPU for unsupported ops\n",
+"os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
 "import numpy as np\n",
 "import torch\n",
 "import matplotlib.pyplot as plt\n",
@@ -143,10 +145,7 @@
 "    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
 "    if torch.cuda.get_device_properties(0).major >= 8:\n",
 "        torch.backends.cuda.matmul.allow_tf32 = True\n",
-"        torch.backends.cudnn.allow_tf32 = True\n",
-"elif device.type == \"mps\":\n",
-"    # fall back to CPU for unsupported ops\n",
-"    os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\""
+"        torch.backends.cudnn.allow_tf32 = True"
 ]
 },
 {
7 changes: 3 additions & 4 deletions notebooks/image_predictor_example.ipynb
@@ -115,6 +115,8 @@
 "outputs": [],
 "source": [
 "import os\n",
+"# if using Apple MPS, fall back to CPU for unsupported ops\n",
+"os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
 "import numpy as np\n",
 "import torch\n",
 "import matplotlib.pyplot as plt\n",
@@ -151,10 +153,7 @@
 "    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
 "    if torch.cuda.get_device_properties(0).major >= 8:\n",
 "        torch.backends.cuda.matmul.allow_tf32 = True\n",
-"        torch.backends.cudnn.allow_tf32 = True\n",
-"elif device.type == \"mps\":\n",
-"    # fall back to CPU for unsupported ops\n",
-"    os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\""
+"        torch.backends.cudnn.allow_tf32 = True"
 ]
 },
 {
7 changes: 3 additions & 4 deletions notebooks/video_predictor_example.ipynb
@@ -111,6 +111,8 @@
 "outputs": [],
 "source": [
 "import os\n",
+"# if using Apple MPS, fall back to CPU for unsupported ops\n",
+"os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
 "import numpy as np\n",
 "import torch\n",
 "import matplotlib.pyplot as plt\n",
@@ -147,10 +149,7 @@
 "    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
 "    if torch.cuda.get_device_properties(0).major >= 8:\n",
 "        torch.backends.cuda.matmul.allow_tf32 = True\n",
-"        torch.backends.cudnn.allow_tf32 = True\n",
-"elif device.type == \"mps\":\n",
-"    # fall back to CPU for unsupported ops\n",
-"    os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\""
+"        torch.backends.cudnn.allow_tf32 = True"
 ]
 },
 {
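For reference, after this commit the setup cell in each of the three notebooks reads roughly as below. This is reconstructed from the hunks above; the device-selection lines between the two hunks are not shown in this diff, so that part is an assumption based on the usual pattern in these notebooks.

import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt

# device selection (assumed; not part of the hunks shown above)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

if device.type == "cuda":  # enclosing check implied by the removed elif branch; other setup inside it is omitted here
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
# the old elif device.type == "mps" branch that set the env var here is gone;
# the fallback is now enabled at the top of the cell, before torch is imported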
