-
Notifications
You must be signed in to change notification settings - Fork 416
/
gen_label_sthv2.py
executable file
·50 lines (45 loc) · 2.07 KB
/
gen_label_sthv2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, [email protected]
# ------------------------------------------------------
# Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py
# Processes the raw annotation data of the Something-Something-V2 video dataset.
import os
import json
def _load_categories(label_file):
    """Return the category names from ``label_file`` ordered by their index.

    ``label_file`` is a JSON object mapping each category name to its index
    (any value ``int()`` accepts). Each name is placed at the list position
    equal to its index, independent of the key order inside the JSON file.

    Raises:
        ValueError: if an index is outside ``0..len-1`` or is used twice,
            i.e. the indices are not exactly a permutation of ``0..len-1``.
    """
    with open(label_file, encoding='utf-8') as f:
        data = json.load(f)
    categories = [None] * len(data)
    for cat, idx in data.items():
        i = int(idx)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let wrong label indices through silently.
        if not 0 <= i < len(categories) or categories[i] is not None:
            raise ValueError('%s: invalid or duplicate index %r for category %r'
                             % (label_file, idx, cat))
        categories[i] = cat
    return categories


def _write_split(filename_input, filename_output, dict_categories,
                 frames_root, labelled):
    """Write one ``'<id> <entry count> <label index>'`` line per video.

    Args:
        filename_input: JSON file holding a list of annotation entries. Each
            entry's ``'id'`` names its frame folder under ``frames_root``.
        filename_output: text file to (over)write, lines joined by ``'\\n'``.
        dict_categories: category name -> label index.
        frames_root: directory containing one sub-folder per video id.
        labelled: when True, an entry's ``'template'`` with ``[``/``]``
            removed is looked up in ``dict_categories`` (KeyError if absent);
            when False every video gets the placeholder label 0.
    """
    with open(filename_input, encoding='utf-8') as f:
        data = json.load(f)
    output = []
    for i, item in enumerate(data):
        folder = str(item['id'])
        if labelled:
            template = item['template'].replace('[', '').replace(']', '')
            label = dict_categories[template]
        else:
            label = 0
        # The frame count is the number of entries in the video's folder.
        num_frames = len(os.listdir(os.path.join(frames_root, folder)))
        output.append('%s %d %d' % (folder, num_frames, label))
        print('%d/%d' % (i, len(data)))
    with open(filename_output, 'w', encoding='utf-8') as f:
        f.write('\n'.join(output))


def main(dataset_name='something-something-v2',
         frames_root='20bn-something-something-v2-frames'):
    """Generate ``category.txt`` and the val/train/test video-list files.

    All paths are relative to the current working directory. Reads
    ``<dataset_name>-labels.json`` and
    ``<dataset_name>-{validation,train,test}.json``; writes ``category.txt``
    (one category name per line, in index order) and
    ``{val,train,test}_videofolder.txt``.

    Args:
        dataset_name: prefix of the annotation JSON files. The script also
            mentions ``'jester-v1'``; TODO confirm that dataset's annotation
            layout matches before using it here.
        frames_root: directory holding one sub-folder of frames per video id.

    Raises:
        ValueError: if the label file's indices are not ``0..n-1``.
    """
    categories = _load_categories('%s-labels.json' % dataset_name)
    with open('category.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(categories))
    dict_categories = {category: i for i, category in enumerate(categories)}
    splits = [
        # (annotation JSON, output list file, look up labels via 'template')
        ('%s-validation.json' % dataset_name, 'val_videofolder.txt', True),
        ('%s-train.json' % dataset_name, 'train_videofolder.txt', True),
        ('%s-test.json' % dataset_name, 'test_videofolder.txt', False),
    ]
    for filename_input, filename_output, labelled in splits:
        _write_split(filename_input, filename_output, dict_categories,
                     frames_root, labelled)


if __name__ == '__main__':
    main()