diff --git a/face_parsing/tester.py b/face_parsing/tester.py index f1ff8e9..33cf31f 100644 --- a/face_parsing/tester.py +++ b/face_parsing/tester.py @@ -104,8 +104,8 @@ def test(self): imgs = torch.stack(imgs) imgs = imgs.cuda() labels_predict = self.G(imgs) - labels_predict_plain = generate_label_plain(labels_predict) - labels_predict_color = generate_label(labels_predict) + labels_predict_plain = generate_label_plain(labels_predict,self.imsize) + labels_predict_color = generate_label(labels_predict,self.imsize) for k in range(self.batch_size): cv2.imwrite(os.path.join(self.test_label_path, str(i * self.batch_size + k) +'.png'), labels_predict_plain[k]) save_image(labels_predict_color[k], os.path.join(self.test_color_label_path, str(i * self.batch_size + k) +'.png'))