test_img (executable file)
#!/usr/bin/env python
# Run a trained image segmentation model on a few test images and show the
# input, the predicted mask and the ground-truth mask side by side.
from __future__ import absolute_import

import cv2
import numpy as np
import torch
from argparse import ArgumentParser
from matplotlib import pyplot as plt
import datasets
from traversability_estimation.utils import convert_label, convert_color
import yaml
import os


def main():
    parser = ArgumentParser()
    # parser.add_argument('--dataset', type=str, default='TraversabilityImagesFiftyone')
    parser.add_argument('--dataset', type=str, default='Rellis3DImages')
    parser.add_argument('--img_size', nargs='+', type=int, default=(192, 320))
    parser.add_argument('--device', type=str, default='cpu')
    args = parser.parse_args()
    print(args)

    # Look up the dataset class by name and build its test split.
    Dataset = getattr(datasets, args.dataset)
    ds = Dataset(crop_size=args.img_size, split='test')

    # Initialize model with the best available weights.
    model_name = 'fcn_resnet50_lr_1e-05_bs_2_epoch_44_TraversabilityImages_iou_0.86.pth'
    # model_name = 'fcn_resnet50_lr_1e-05_bs_1_epoch_3_TraversabilityImages_iou_0.71.pth'
    model_path = os.path.join('../../config/weights/image/', model_name)
    model = torch.load(model_path, map_location=args.device).eval()

    pkg_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../../'))
    label_config = os.path.join(pkg_path, 'config/rellis.yaml')
    with open(label_config, 'r') as f:
        data_cfg = yaml.safe_load(f)

    for i in range(5):
        # Apply inference preprocessing transforms.
        img, gt_mask = ds[i]
        # Undo the dataset normalization for visualization.
        img_vis = np.uint8(255 * (img * ds.std + ds.mean))
        if ds.split == 'test':
            img = img.transpose((2, 0, 1))  # (H x W x C) -> (C x H x W)
        batch = torch.from_numpy(img).unsqueeze(0).to(args.device)

        # Use the model and visualize the prediction.
        with torch.no_grad():
            pred = model(batch)['out']
        pred = torch.softmax(pred, dim=1)
        pred = pred.squeeze(0).cpu().numpy()
        mask = np.argmax(pred, axis=0)
        gt_mask = np.argmax(gt_mask, axis=0)
        # mask = convert_label(mask, inverse=True)

        # Resize the predicted mask back to (width, height) as expected by cv2.
        size = (args.img_size[1], args.img_size[0])
        mask = cv2.resize(mask.astype('float32'), size, interpolation=cv2.INTER_LINEAR).astype('int8')

        # Fixed three-class color map: class 0 -> black, 1 -> green, 2 -> red.
        # result = convert_color(mask, data_cfg['color_map'])
        result = convert_color(mask, {0: [0, 0, 0], 1: [0, 255, 0], 2: [255, 0, 0]})
        gt_result = convert_color(gt_mask, {0: [0, 0, 0], 1: [0, 255, 0], 2: [255, 0, 0]})

        # Input image, prediction and ground truth side by side.
        plt.figure(figsize=(20, 10))
        plt.subplot(1, 3, 1)
        plt.imshow(img_vis)
        plt.subplot(1, 3, 2)
        plt.imshow(result)
        plt.subplot(1, 3, 3)
        plt.imshow(gt_result)
        plt.show()


if __name__ == '__main__':
    main()
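
A possible invocation, assuming the script is launched from its own directory so that the relative weights path '../../config/weights/image/' resolves, and that the dataset name matches a class exported by the datasets module:

    ./test_img --dataset Rellis3DImages --img_size 192 320 --device cpu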