forked from vdurnov/xview2_1st_place_solution
-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_masks.py
96 lines (76 loc) · 3.16 KB
/
create_masks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
import pathlib
import numpy as np
import numpy.typing as npt
import cv2
import shapely.geometry
import torch
from tqdm.autonotebook import tqdm
from metadamagenet.configs import DamageType
from metadamagenet.dataset import ImageData, DataTime, discover_directory
class MaskCreator:
    """
    Rasterise the building polygons of every image in a dataset directory into
    two single-channel uint8 masks and write them to disk:

    * localization mask (from the pre-disaster polygons): 1 inside any polygon,
      0 elsewhere;
    * classification mask (from the post-disaster polygons): the value from
      ``damage_type_color`` for each polygon's damage type, 0 elsewhere.
    """

    # Pixel value written into the classification mask for each damage type.
    # UN_CLASSIFIED shares NO_DAMAGE's value (1); 0 is left for background.
    damage_type_color = {
        DamageType.UN_CLASSIFIED: 1,
        DamageType.NO_DAMAGE: 1,
        DamageType.MINOR_DAMAGE: 2,
        DamageType.MAJOR_DAMAGE: 3,
        DamageType.DESTROYED: 4
    }

    @staticmethod
    def mask_for_polygon(poly: shapely.geometry.Polygon, im_size=(1024, 1024)):
        """
        Create a binary mask from a polygon.

        :param poly: polygon whose coordinates are rounded to the nearest integer
            and passed to ``cv2.fillPoly`` as (x, y) pixel positions.
        :param im_size: shape of the returned mask, passed to ``np.zeros``.
        :return: uint8 array of shape ``im_size``; 1 inside the polygon's exterior
            ring, 0 outside it and inside any of its interior rings (holes).
        """
        def int_coords(coords):
            return np.array(coords).round().astype(np.int32)

        img_mask = np.zeros(im_size, np.uint8)
        exteriors = [int_coords(poly.exterior.coords)]
        interiors = [int_coords(ring.coords) for ring in poly.interiors]
        cv2.fillPoly(img_mask, exteriors, 1)
        if interiors:
            # holes are carved back out of the already-filled exterior
            cv2.fillPoly(img_mask, interiors, 0)
        return img_mask

    @classmethod
    def create_loc_mask(cls, image_data: ImageData, im_size=(1024, 1024)) -> npt.NDArray:
        """
        Create the localization mask for ``image_data``.

        :param image_data: image whose pre-disaster polygons are rasterised.
        :param im_size: shape of the returned mask (default 1024x1024).
        :return: uint8 array of shape ``im_size``: 1 where any pre-disaster
            polygon covers the pixel, 0 elsewhere.
        """
        localization_mask = np.zeros(im_size, dtype=np.uint8)  # a mask-image including polygons
        # NOTE(review): items of polygons(DataTime.PRE) are used directly as polygons,
        # while create_cls_mask unpacks polygons(DataTime.POST) items as
        # (polygon, damage_type) pairs -- confirm ImageData.polygons really returns
        # different item shapes for PRE and POST.
        for polygon in image_data.polygons(DataTime.PRE):
            polygon_mask = cls.mask_for_polygon(polygon, im_size)
            localization_mask[polygon_mask > 0] = 1
        return localization_mask

    @classmethod
    def create_cls_mask(cls, image_data: ImageData, im_size=(1024, 1024)) -> npt.NDArray:
        """
        Create the classification (damage-level) mask for ``image_data``.

        :param image_data: image whose post-disaster polygons are rasterised.
        :param im_size: shape of the returned mask (default 1024x1024).
        :return: uint8 array of shape ``im_size``: pixels covered by a polygon hold
            ``damage_type_color[damage_type]``; all other pixels are 0.
        :raises KeyError: if a polygon's damage type is missing from
            ``damage_type_color``.
        """
        classification_mask = np.zeros(im_size, dtype=np.uint8)  # a mask-image with damage levels
        damage_type: DamageType
        for polygon, damage_type in image_data.polygons(DataTime.POST):
            polygon_mask = cls.mask_for_polygon(polygon, im_size)
            # where polygons overlap, the later polygon's value overwrites the earlier one
            classification_mask[polygon_mask > 0] = cls.damage_type_color[damage_type]
        # the original fell off with a bare `return`, handing None to save_masks
        return classification_mask

    @classmethod
    def save_masks(cls, image_data: ImageData, localization_msk: npt.NDArray,
                   classification_msk: npt.NDArray) -> None:
        """
        Write the two masks to ``image_data.mask(DataTime.PRE)`` (localization)
        and ``image_data.mask(DataTime.POST)`` (classification).

        The file format is chosen by OpenCV from each path's extension; the
        compression-level flag below only takes effect for PNG paths. Missing
        parent directories are created.

        :raises OSError: if OpenCV reports that either file could not be written.
        """
        targets = ((DataTime.PRE, localization_msk), (DataTime.POST, classification_msk))
        for data_time, mask in targets:
            path = pathlib.Path(image_data.mask(data_time))
            path.parent.mkdir(parents=True, exist_ok=True)
            # cv2.imwrite signals failure only through its boolean return value
            if not cv2.imwrite(str(path), mask, [cv2.IMWRITE_PNG_COMPRESSION, 9]):
                raise OSError(f"failed to write mask to {path}")

    def __init__(self, source: pathlib.Path):
        """
        :param source: dataset directory later passed to ``discover_directory``.
        :raises NotADirectoryError: if ``source`` does not exist or is not a directory.
        """
        # explicit check rather than `assert`, which is stripped under `python -O`;
        # is_dir() is already False for non-existent paths
        if not source.is_dir():
            raise NotADirectoryError(
                f"source {source.absolute()} does not exist or its not a directory")
        self._source: pathlib.Path = source

    def run(self):
        """
        Create and save the localization and classification masks for every
        image returned by ``discover_directory`` on the source directory.
        """
        image_dataset = discover_directory(self._source, check=False)
        for image_data in tqdm(image_dataset):
            localization_msk: npt.NDArray = self.create_loc_mask(image_data)
            classification_msk: npt.NDArray = self.create_cls_mask(image_data)
            self.save_masks(image_data, localization_msk, classification_msk)
def main():
    """Command-line entry point: build the masks for the directory given by ``--source``."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--source', required=True)
    source_dir = pathlib.Path(cli.parse_args().source)
    MaskCreator(source_dir).run()
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()