-
Notifications
You must be signed in to change notification settings - Fork 3
/
dataset.py
118 lines (88 loc) · 4.08 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import json
import yt_dlp
import pandas as pd
from tqdm import tqdm
class Taxonomy():
    """Lookup tables between COIN taxonomy ids and human-readable labels.

    Built from a CSV file expected to have the columns
    "Target Id", "Target Label", "Action Id", "Action Label".
    """

    def __init__(self, taxonomy_path) -> None:
        self.taxonomy_path = taxonomy_path
        df_taxonomy = pd.read_csv(taxonomy_path)
        # Column-wise zip builds the mappings in one pass instead of the
        # original row-by-row iterrows() loops (same result: for duplicate
        # ids/labels the last row wins in both versions).
        self.target_id2label = dict(zip(df_taxonomy["Target Id"], df_taxonomy["Target Label"]))
        self.target_label2id = {label: tid for tid, label in self.target_id2label.items()}
        self.action_id2label = dict(zip(df_taxonomy["Action Id"], df_taxonomy["Action Label"]))
class Dataset():
    """COIN dataset wrapper: selects a labelled subset and downloads its videos.

    Reads the COIN annotation JSON plus a taxonomy CSV, builds a DataFrame of
    annotated video segments for chosen target classes, and fetches the
    corresponding videos with yt-dlp into per-class folders.
    """

    def __init__(self, coin_json_path, taxonomy_path) -> None:
        self.coin_json_path = coin_json_path
        self.taxonomy_path = taxonomy_path
        self.taxonomy = Taxonomy(taxonomy_path=taxonomy_path)
        self.raw_data = self.load_coin_json_from_file(coin_json_path)

    def load_coin_json_from_file(self, file_path):
        """Return the "database" mapping (video id -> sample) from a COIN JSON file."""
        # Context manager guarantees the handle is closed even on a parse
        # error (the original opened the file and never closed it).
        with open(file_path) as f:
            return json.load(f)["database"]

    def create_dataset(self, target_label_list):
        """Build a DataFrame of annotated segments for the requested classes.

        Parameters
        ----------
        target_label_list : list[list[str]]
            One list of taxonomy target labels per output class; the position
            in the outer list becomes the integer class label.

        Returns
        -------
        pd.DataFrame
            Columns ["label", "action id", "url", "segment", "action label"].
            Also sets ``self.classes`` to the unique class labels present.
        """
        # class index -> taxonomy target ids belonging to that class
        target_ids = {
            i: [self.taxonomy.target_label2id[label] for label in labels]
            for i, labels in enumerate(target_label_list)
        }
        # invert to: taxonomy target id -> class index
        target_ids_reverse = {}
        for class_idx, id_list in target_ids.items():
            for target_id in id_list:
                target_ids_reverse[target_id] = class_idx
        dataset_list = []
        for sample_id, sample in self.raw_data.items():
            recipe_id = sample["recipe_type"]
            if recipe_id not in target_ids_reverse:
                continue
            video_url = sample["video_url"]
            for ann in sample["annotation"]:
                # Encode the [start, end] segment as "start_end" for use in filenames.
                segment = f"{ann['segment'][0]}_{ann['segment'][1]}"
                dataset_list.append(
                    [target_ids_reverse[recipe_id], recipe_id, video_url, segment, ann["label"]]
                )
        pd_limited_data = pd.DataFrame(
            dataset_list, columns=["label", "action id", "url", "segment", "action label"]
        )
        self.classes = pd_limited_data["label"].unique()
        return pd_limited_data

    def _make_class_dirs(self, save_folder):
        """Create save_folder plus one sub-folder per class in self.classes."""
        os.makedirs(save_folder, exist_ok=True)
        for class_label in self.classes:
            os.makedirs(os.path.join(save_folder, str(class_label)), exist_ok=True)

    def _video_save_path(self, save_folder, label, url):
        """Local .mp4 path for a video: <save_folder>/<label>/<last URL segment>.mp4."""
        video_url_id = url.split("/")[-1]
        return os.path.join(save_folder, str(label), video_url_id) + ".mp4"

    def add_download_local_paths(self, df_dataset: pd.DataFrame, save_folder, drop_none) -> pd.DataFrame:
        """Attach a "paths" column pointing at already-downloaded videos.

        Rows whose video is not on disk get None; when ``drop_none`` is true
        those rows are removed and the index reset.
        """
        self._make_class_dirs(save_folder)
        paths = []
        for _, sample in tqdm(df_dataset.iterrows(), desc="downloading COIN subset", total=len(df_dataset)):
            save_path = self._video_save_path(save_folder, sample.label, sample.url)
            paths.append(save_path if os.path.exists(save_path) else None)
        df_dataset["paths"] = paths
        if drop_none:
            # Only a missing download should remove a row, not a NaN in some
            # other column (the original dropna() checked every column).
            df_dataset = df_dataset.dropna(subset=["paths"]).reset_index(drop=True)
        return df_dataset

    def download_dataset(self, df_dataset: pd.DataFrame, save_folder, drop_none) -> pd.DataFrame:
        """Download every video in df_dataset with yt-dlp and record local paths.

        Already-downloaded videos are kept and still get a path entry — the
        original's bare ``continue`` skipped the append, leaving ``paths``
        shorter than the frame and crashing the column assignment below.
        Failed downloads get None; when ``drop_none`` those rows are dropped.
        """
        self._make_class_dirs(save_folder)
        paths = []
        for _, sample in df_dataset.iterrows():
            save_path = self._video_save_path(save_folder, sample.label, sample.url)
            if os.path.exists(save_path):
                # Bug fix: keep paths aligned one-to-one with DataFrame rows.
                paths.append(save_path)
                continue
            ydl_opts = {
                'format': 'mp4',
                'outtmpl': save_path
            }
            try:
                with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                    ydl.download([sample.url])
                paths.append(save_path)
            except Exception:
                # Best-effort download: record the miss instead of aborting the
                # whole run (narrowed from a bare except that also swallowed
                # KeyboardInterrupt/SystemExit).
                paths.append(None)
        df_dataset["paths"] = paths
        if drop_none:
            df_dataset = df_dataset.dropna(subset=["paths"]).reset_index(drop=True)
        return df_dataset