Skip to content

Commit

Permalink
Merge pull request #1 from gremlinflat/master
Browse files Browse the repository at this point in the history
xinntao#584 - custom backend device (auto, cuda, m1, cpu)
  • Loading branch information
guimondmm authored Jan 20, 2024
2 parents 5ca1078 + 6b5815c commit 5249f94
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,7 @@ dmypy.json

# Pyre type checker
.pyre/

# VSCode
.vscode/
.vscode/*
21 changes: 21 additions & 0 deletions inference_realesrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

from torch.cuda import is_available as cudaIsAvailable
from torch.backends.mps import is_available as mpsIsAvailable

def main():
"""Inference demo for Real-ESRGAN.
Expand Down Expand Up @@ -52,6 +54,8 @@ def main():
parser.add_argument(
'-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')

parser.add_argument('--backend_type', type=str, default='auto', choices=['auto', 'cuda', 'cpu', 'mps'], help='backend type. Options: auto(cuda-cpu) | cuda | cpu | mps')

args = parser.parse_args()

# determine models according to model names
Expand Down Expand Up @@ -103,6 +107,21 @@ def main():
model_path = [model_path, wdn_model_path]
dni_weight = [args.denoise_strength, 1 - args.denoise_strength]

# deternime backend type (cpu, cuda, mps)
if args.backend_type == 'auto':
if cudaIsAvailable():
backend_type = 'cuda'
elif mpsIsAvailable():
backend_type = 'mps'
else:
backend_type = 'cpu'
elif args.backend_type == 'cuda' and cudaIsAvailable():
backend_type = 'cuda'
elif args.backend_type == 'mps' and mpsIsAvailable():
backend_type = 'mps'
else:
backend_type = 'cpu'

# restorer
upsampler = RealESRGANer(
scale=netscale,
Expand All @@ -113,6 +132,7 @@ def main():
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=not args.fp32,
device=backend_type,
gpu_id=args.gpu_id)

if args.face_enhance: # Use GFPGAN for face enhancement
Expand All @@ -122,6 +142,7 @@ def main():
upscale=args.outscale,
arch='clean',
channel_multiplier=2,
device='cpu', # <--- MPS is not supported yet, crash pas runtime. TODO: FIX THIS
bg_upsampler=upsampler)
os.makedirs(args.output, exist_ok=True)

Expand Down
2 changes: 1 addition & 1 deletion realesrgan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def pre_process(self, img):
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
"""
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
self.img = img.unsqueeze(0).contiguous().to(self.device)
if self.half:
self.img = self.img.half()

Expand Down

0 comments on commit 5249f94

Please sign in to comment.