Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable the test scripts to use different face detectors. #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ python face_warping_test.py -v 0 -p 1 -r -k

Command-line arguments:
```
-v VIDEO: Index of the webcam to use (start from 0) or
-v VIDEO: Index of the webcam to use (default=0) or
path of the input video file
-x WIDTH: Width of the warped frames (default=256)
-y HEIGHT: Height of the warped frames (default=256)
Expand All @@ -47,8 +47,14 @@ Command-line arguments:
-s: To use square-shaped detection box
-n: To use nearest-neighbour interpolation during restoration
-k: Keep aspect ratio in tanh-polar or tanh-circular warping
-d: Device to be used by PyTorch (default=cuda:0)
-d: Device to be used by the warping functions (default=cuda:0)
-b: Enable benchmark mode for CUDNN
-dt: Confidence threshold for face detection (default=0.8)
-dm: Face detection method, can be either RetinaFace (default)
      or S3FD
-dw: Weights to be loaded for face detection, can be either
resnet50 or mobilenet0.25 when using RetinaFace
-dd: Device to be used for face detection (default=cuda:0)
```

There is also a script to specifically test the transform from ROI-tanh-polar space to the Cartesian ROI-tanh space (or in the reverse direction).
Expand All @@ -59,7 +65,7 @@ python tanh_polar_to_cartesian_test.py -v 0 -r -k

Command-line arguments:
```
-v VIDEO: Index of the webcam to use (start from 0) or
-v VIDEO: Index of the webcam to use (default=0) or
path of the input video file
-x WIDTH: Width of the warped frames (default=256)
-y HEIGHT: Height of the warped frames (default=256)
Expand All @@ -70,6 +76,12 @@ Command-line arguments:
-s: To use square-shaped detection box
-k: Keep aspect ratio in tanh-polar or tanh-circular warping
-i: To perform computation in the reverse direction
-d: Device to be used by PyTorch (default=cuda:0)
-d: Device to be used by the warping functions (default=cuda:0)
-b: Enable benchmark mode for CUDNN
-dt: Confidence threshold for face detection (default=0.8)
-dm: Face detection method, can be either RetinaFace (default)
      or S3FD
-dw: Weights to be loaded for face detection, can be either
resnet50 or mobilenet0.25 when using RetinaFace
-dd: Device to be used for face detection (default=cuda:0)
```
35 changes: 28 additions & 7 deletions face_warping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from typing import Tuple, Optional
from argparse import ArgumentParser
from ibug.face_detection import RetinaFacePredictor
from ibug.face_detection import RetinaFacePredictor, S3FDPredictor

from ibug.roi_tanh_warping import *
from ibug.roi_tanh_warping import reference_impl as ref
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_reference_impl(frame: np.ndarray, face_box: np.ndarray, target_width: i

def main() -> None:
parser = ArgumentParser()
parser.add_argument('--video', '-v', help='Video source')
parser.add_argument('--video', '-v', help='Video source (default=0)', default=0)
parser.add_argument('--width', '-x', help='Width of the warped image (default=256)', type=int, default=256)
parser.add_argument('--height', '-y', help='Height of the warped image (default=256)', type=int, default=256)
parser.add_argument('--polar', '-p', help='Use polar coordinates', type=int, default=0)
Expand All @@ -117,17 +117,38 @@ def main() -> None:
action='store_true', default=False)
parser.add_argument('--keep-aspect-ratio', '-k', help='Keep aspect ratio in tanh-polar or tanh-circular warping',
action='store_true', default=False)
parser.add_argument('--device', '-d', help='Device to be used by PyTorch (default=cuda:0)', default='cuda:0')
parser.add_argument('--device', '-d', default='cuda:0',
help='Device to be used by the warping functions (default=cuda:0)')
parser.add_argument('--benchmark', '-b', help='Enable benchmark mode for CUDNN',
action='store_true', default=False)
parser.add_argument('--detection-threshold', '-dt', type=float, default=0.8,
help='Confidence threshold for face detection (default=0.8)')
parser.add_argument('--detection-method', '-dm', default='retinaface',
help='Face detection method, can be either RatinaFace or S3FD (default=RatinaFace)')
parser.add_argument('--detection-weights', '-dw', default=None,
help='Weights to be loaded for face detection, ' +
'can be either resnet50 or mobilenet0.25 when using RetinaFace')
parser.add_argument('--detection-device', '-dd', default='cuda:0',
help='Device to be used for face detection (default=cuda:0)')
args = parser.parse_args()

# Make the models run a bit faster
torch.backends.cudnn.benchmark = args.benchmark

# Create object detector
detector = RetinaFacePredictor(device=args.device, model=RetinaFacePredictor.get_model('mobilenet0.25'))
print('RetinaFace detector created using mobilenet0.25 backbone.')
# Create the face detector
args.detection_method = args.detection_method.lower()
if args.detection_method == 'retinaface':
face_detector = RetinaFacePredictor(threshold=args.detection_threshold, device=args.detection_device,
model=(RetinaFacePredictor.get_model(args.detection_weights)
if args.detection_weights else None))
print('Face detector created using RetinaFace.')
elif args.detection_method == 's3fd':
face_detector = S3FDPredictor(threshold=args.detection_threshold, device=args.detection_device,
model=(S3FDPredictor.get_model(args.detection_weights)
if args.detection_weights else None))
print('Face detector created using S3FD.')
else:
raise ValueError('detector-method must be set to either RetinaFace or S3FD')

# Open webcam
if os.path.exists(args.video):
Expand All @@ -148,7 +169,7 @@ def main() -> None:
break
else:
# Face detection
face_boxes = detector(frame, rgb=False)
face_boxes = face_detector(frame, rgb=False)
if len(face_boxes) > 0:
biggest_face_idx = int(np.argmax([(bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
for bbox in face_boxes]))
Expand Down
35 changes: 28 additions & 7 deletions tanh_polar_to_cartesian_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from typing import Tuple, Optional
from argparse import ArgumentParser
from ibug.face_detection import RetinaFacePredictor
from ibug.face_detection import RetinaFacePredictor, S3FDPredictor

from ibug.roi_tanh_warping import *
from ibug.roi_tanh_warping import reference_impl as ref
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_pytorch_impl(device: str, frame: np.ndarray, face_box: np.ndarray, targ

def main() -> None:
parser = ArgumentParser()
parser.add_argument('--video', '-v', help='Video source')
parser.add_argument('--video', '-v', help='Video source (default=0)', default=0)
parser.add_argument('--width', '-x', help='Width of the warped image (default=256)', type=int, default=256)
parser.add_argument('--height', '-y', help='Height of the warped image (default=256)', type=int, default=256)
parser.add_argument('--offset', '-o', help='Angular offset, only used when polar>0', type=float, default=0.0)
Expand All @@ -112,17 +112,38 @@ def main() -> None:
action='store_true', default=False)
parser.add_argument('--reverse', '-i', help='Perform computation in the reverse direction',
action='store_true', default=False)
parser.add_argument('--device', '-d', help='Device to be used (default=cuda:0)', default='cuda:0')
parser.add_argument('--device', '-d', default='cuda:0',
help='Device to be used by the warping functions (default=cuda:0)')
parser.add_argument('--benchmark', '-b', help='Enable benchmark mode for CUDNN',
action='store_true', default=False)
parser.add_argument('--detection-threshold', '-dt', type=float, default=0.8,
help='Confidence threshold for face detection (default=0.8)')
parser.add_argument('--detection-method', '-dm', default='retinaface',
help='Face detection method, can be either RatinaFace or S3FD (default=RatinaFace)')
parser.add_argument('--detection-weights', '-dw', default=None,
help='Weights to be loaded for face detection, ' +
'can be either resnet50 or mobilenet0.25 when using RetinaFace')
parser.add_argument('--detection-device', '-dd', default='cuda:0',
help='Device to be used for face detection (default=cuda:0)')
args = parser.parse_args()

# Make the models run a bit faster
torch.backends.cudnn.benchmark = args.benchmark

# Create face detector
detector = RetinaFacePredictor(device=args.device, model=RetinaFacePredictor.get_model('mobilenet0.25'))
print('RetinaFace detector created using mobilenet0.25 backbone.')
# Create the face detector
args.detection_method = args.detection_method.lower()
if args.detection_method == 'retinaface':
face_detector = RetinaFacePredictor(threshold=args.detection_threshold, device=args.detection_device,
model=(RetinaFacePredictor.get_model(args.detection_weights)
if args.detection_weights else None))
print('Face detector created using RetinaFace.')
elif args.detection_method == 's3fd':
face_detector = S3FDPredictor(threshold=args.detection_threshold, device=args.detection_device,
model=(S3FDPredictor.get_model(args.detection_weights)
if args.detection_weights else None))
print('Face detector created using S3FD.')
else:
raise ValueError('detector-method must be set to either RetinaFace or S3FD')

# Open webcam
if os.path.exists(args.video):
Expand All @@ -143,7 +164,7 @@ def main() -> None:
break
else:
# Face detection
face_boxes = detector(frame, rgb=False)
face_boxes = face_detector(frame, rgb=False)
if len(face_boxes) > 0:
biggest_face_idx = int(np.argmax([(bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
for bbox in face_boxes]))
Expand Down