#!/usr/bin/env python
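"""Visualize depth-image segmentation predictions on point-cloud datasets.

Loads a pretrained segmentation checkpoint, runs it on a few random samples
from the test split, and shows the depth image, the predicted labels, and the
ground truth, both as images and as colored point clouds.
"""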
import os
from argparse import ArgumentParser

import numpy as np
import open3d as o3d
import torch

import datasets
from traversability_estimation.utils import visualize_imgs, visualize_cloud


def main():
    parser = ArgumentParser()
    # parser.add_argument('--dataset', type=str, default='Rellis3DClouds')
    # parser.add_argument('--dataset', type=str, default='TraversabilityClouds')
    parser.add_argument('--dataset', type=str, default='FlexibilityClouds')
    # parser.add_argument('--dataset', type=str, default='SemanticUSL')
    parser.add_argument('--weights', type=str,
                        default='deeplabv3_resnet101_lr_0.0001_bs_6_epoch_80_FlexibilityClouds_depth_64x1024_labels_flexibility_iou_0.790.pth')
    parser.add_argument('--device', type=str, default='cpu')
    args = parser.parse_args()
    print(args)

    pkg_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../../'))

    # Initialize model with the best available weights.
    model_name = args.weights
    assert args.dataset in model_name, 'Weights file name must contain the dataset name'
    model = torch.load(os.path.join(pkg_path, 'config/weights/depth_cloud', model_name), map_location=args.device)
    # model = torch.load(model_name, map_location=args.device)
    model.eval()

    # The input fields are parsed from the weights file name.
    data_fields = [f[1:-1] for f in ['_x_', '_y_', '_z_', '_intensity_', '_depth_'] if f in model_name]
    print('Model takes as input: %s' % ', '.join(data_fields))
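    # NOTE: this assumes the checkpoint naming convention used above, where the
    # file name encodes the input channels, e.g. '..._depth_64x1024_...' means
    # the model consumes a single 'depth' channel of shape 64x1024.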
    # Pick the output head and the label to ignore based on the weights name.
    if 'traversability' in model_name.lower():
        output = 'traversability'
        ignore_label = 255
    elif 'flexibility' in model_name.lower():
        output = 'flexibility'
        ignore_label = 255
    else:
        output = None
        ignore_label = 0
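    # NOTE: 255 is assumed to be the void / unannotated label of the
    # TraversabilityClouds and FlexibilityClouds datasets.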
    # Look the dataset class up by name (avoids eval()).
    Dataset = getattr(datasets, args.dataset)
    ds = Dataset(split='test', fields=data_fields,
                 output=output,
                 lidar_H_step=2, lidar_W_step=1)
    # Run inference on a few random test samples.
    for _ in range(5):
        # Apply inference preprocessing transforms.
        inpt, label = ds[np.random.choice(range(len(ds)))]
        depth_img = inpt[0]

        # Compress the depth range for display: take the (1 / power)-th root of
        # the valid (non-zero) pixels, then min-max normalize them to [0, 1].
        power = 16
        depth_img_vis = np.copy(depth_img).squeeze()  # depth channel
        valid = depth_img_vis > 0
        depth_img_vis[valid] = depth_img_vis[valid] ** (1 / power)
        depth_img_vis[valid] = (depth_img_vis[valid] - depth_img_vis[valid].min()) / \
                               (depth_img_vis[valid].max() - depth_img_vis[valid].min())
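        # NOTE: indexing the output with ['out'] below assumes a
        # torchvision-style segmentation model (the default weights are
        # DeepLabV3-ResNet101), which returns a dict of tensors. Softmax is
        # monotonic, so the argmax of the raw logits would give the same labels.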
        # Use the model and visualize the prediction.
        batch = torch.from_numpy(inpt).unsqueeze(0).to(args.device)
        with torch.no_grad():
            pred = model(batch)['out']
        pred = torch.softmax(pred.squeeze(0), dim=0).cpu().numpy()
        pred = np.argmax(pred, axis=0)
        # Mask out pixels that are unlabelled in the ground truth.
        pred_ign = pred.copy()
        pred_ign[label == ignore_label] = ignore_label
        # label_flex = pred == 1
        # depth_img_with_flex_points = (0.3 * depth_img_vis + 0.7 * label_flex).astype("float")

        color_pred = ds.label_to_color(pred)
        color_pred_ign = ds.label_to_color(pred_ign)
        color_gt = ds.label_to_color(label)
        visualize_imgs(layout='columns',
                       depth_img=depth_img_vis,
                       # depth_img_with_flex_points=depth_img_with_flex_points,
                       prediction=color_pred,
                       prediction_without_background=color_pred_ign,
                       ground_truth=color_gt,
                       )
        # visualize_cloud(xyz=ds.scan.proj_xyz.reshape((-1, 3)), color=color_pred.reshape((-1, 3)))
        # visualize_cloud(xyz=ds.scan.proj_xyz.reshape((-1, 3)), color=color_gt.reshape((-1, 3)))

        # Show the prediction and the ground truth side by side as point
        # clouds: rows are subsampled to match the dataset's lidar_H_step, and
        # the ground-truth cloud is shifted 50 m along the x-axis.
        pcd = o3d.geometry.PointCloud()
        xyz = ds.scan.proj_xyz[::ds.lidar_H_step]
        pcd.points = o3d.utility.Vector3dVector(xyz.reshape((-1, 3)))
        pcd.colors = o3d.utility.Vector3dVector(color_pred.reshape((-1, 3)) / color_pred.max())

        pcd_gt = o3d.geometry.PointCloud()
        pcd_gt.points = o3d.utility.Vector3dVector(xyz.reshape((-1, 3)) + np.asarray([50, 0, 0]))
        pcd_gt.colors = o3d.utility.Vector3dVector(color_gt.reshape((-1, 3)) / color_gt.max())
        o3d.visualization.draw_geometries([pcd, pcd_gt])
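        # draw_geometries() blocks until its window is closed, so the samples
        # are shown one after another.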

if __name__ == '__main__':
    main()