Showing 15 changed files with 2,039 additions and 0 deletions.
@@ -0,0 +1,238 @@
#!/usr/bin/env python
# coding=utf-8
import torch
from collections import OrderedDict
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

from tqdm import tqdm

CONDITIONS = [
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
    "No Finding",
]
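# Note: these are the 14 observation classes scored by CheXbert; their order must
# match the order of the classifier's output heads.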

def load_chexbert(checkpoint_path):
    import sys

    sys.path.append("../CheXbert/src/")
    from models.bert_labeler import bert_labeler

    chexbert = bert_labeler()
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    new_state_dict = OrderedDict()
    for k, v in checkpoint["model_state_dict"].items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    chexbert.load_state_dict(new_state_dict, strict=False)
    print("Loaded reward model from {}".format(checkpoint_path))
    chexbert.eval()
    return chexbert.cuda()

def compute_ce_metric(references, hypotheses, is_temporals, chexbert, bert_tokenizer):
    def pad_strings(strs):
        max_len = max([len(s) for s in strs])
        return [s + " " * (max_len - len(s)) for s in strs]

    chexbert.eval()
    # CLASS_MAPPING = {0: "Blank", 1: "Positive", 2: "Negative", 3: "Uncertain"}
    CLASS_MAPPING = {0: "Blank", 1: "Positive", 2: "Negative", 3: "Positive"}
    NO_FINDING_CLASS_MAPPING = {0: "Negative", 1: "Positive"}
    LABEL_MAPPING = {0: 0, 1: 1, 2: 2, 3: 1}
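    # CheXbert emits 0=Blank, 1=Positive, 2=Negative, 3=Uncertain (original mapping
    # kept above for reference); here uncertain mentions are folded into Positive,
    # and LABEL_MAPPING binarizes the four statuses the same way.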
    TEM_keywords = {
        "bigger",
        "change",
        "cleared",
        "constant",
        "decrease",
        "decreased",
        "decreasing",
        "elevated",
        "elevation",
        "enlarged",
        "enlargement",
        "enlarging",
        "expanded",
        "greater",
        "growing",
        "improved",
        "improvement",
        "improving",
        "increase",
        "increased",
        "increasing",
        "larger",
        "new",
        "persistence",
        "persistent",
        "persisting",
        "progression",
        "progressive",
        "reduced",
        "removal",
        "resolution",
        "resolved",
        "resolving",
        "smaller",
        "stability",
        "stable",
        "stably",
        "unchanged",
        "unfolded",
        "worse",
        "worsen",
        "worsened",
        "worsening",
        "unaltered",
    }
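    # Keywords that signal temporal change; used at the end of this function to
    # compute the temporal entity matching (TEM) score.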
    batch_size = 128
    ref_observations = []
    hyp_observations = []
    y_preds = []
    y_trues = []
    macro_y_preds = []
    macro_y_trues = []
    for i in tqdm(range(0, len(references), batch_size), desc="Calculating CE Scores"):
        ref = references[i : i + batch_size]
        hyp = hypotheses[i : i + batch_size]
        ref_input = bert_tokenizer.batch_encode_plus(
            ref, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        hyp_input = bert_tokenizer.batch_encode_plus(
            hyp, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        ref_input = {k: v.cuda() for k, v in ref_input.items()}
        hyp_input = {k: v.cuda() for k, v in hyp_input.items()}
        ref_logits = chexbert(
            source_padded=ref_input["input_ids"],
            attention_mask=ref_input["attention_mask"],
        )
        hyp_logits = chexbert(
            source_padded=hyp_input["input_ids"],
            attention_mask=hyp_input["attention_mask"],
        )
        # One logits tensor per condition; argmax over classes gives each report's status.
        ref_status = [l.argmax(dim=1).tolist() for l in ref_logits]
        hyp_status = [l.argmax(dim=1).tolist() for l in hyp_logits]
        # Predictions are derived from the hypotheses, ground truth from the references.
        y_pred = np.zeros((len(hyp_status[0]), len(CONDITIONS)))
        y_true = np.zeros((len(ref_status[0]), len(CONDITIONS)))
        macro_y_pred = np.zeros((len(hyp_status[0]), len(CONDITIONS)))
        macro_y_true = np.zeros((len(ref_status[0]), len(CONDITIONS)))
        ref_obs = [[] for _ in range(len(ref_status[0]))]
        hyp_obs = [[] for _ in range(len(hyp_status[0]))]
        for cond_idx, c in enumerate(CONDITIONS):  # condition index, distinct from the batch index i
            i_ref_status = ref_status[cond_idx]
            i_hyp_status = hyp_status[cond_idx]
            if c == "No Finding":
                class_mapping = NO_FINDING_CLASS_MAPPING
            else:
                class_mapping = CLASS_MAPPING
            for j in range(len(i_hyp_status)):  # batch_size
                macro_y_pred[j][cond_idx] = i_hyp_status[j]
                macro_y_true[j][cond_idx] = i_ref_status[j]
                if LABEL_MAPPING[i_hyp_status[j]] == 1:
                    y_pred[j][cond_idx] = 1
                if LABEL_MAPPING[i_ref_status[j]] == 1:
                    y_true[j][cond_idx] = 1
                if i_hyp_status[j] != 0 or c == "No Finding":
                    hyp_obs[j].append(":".join((c, class_mapping[i_hyp_status[j]])))
                if i_ref_status[j] != 0 or c == "No Finding":
                    ref_obs[j].append(":".join((c, class_mapping[i_ref_status[j]])))

        y_preds.append(y_pred)
        y_trues.append(y_true)
        macro_y_preds.append(macro_y_pred)
        macro_y_trues.append(macro_y_true)
        ref_observations.extend(ref_obs)
        hyp_observations.extend(hyp_obs)
    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)
    macro_y_preds = np.concatenate(macro_y_preds, axis=0)
    macro_y_trues = np.concatenate(macro_y_trues, axis=0)
    ce_prf = [0, 0, 0]
    macro_ce_prf = [0, 0, 0]
    temporal_ce_prf = [0, 0, 0]
    macro_temporal_ce_prf = [0, 0, 0]

    print("--------------------------------------------------------------")
    pad_conditions = pad_strings(CONDITIONS)
    for i, c in enumerate(CONDITIONS):
        # for all reports
        y_true = y_trues[:, i]
        y_pred = y_preds[:, i]
        i_prf = precision_recall_fscore_support(
            y_true=y_true, y_pred=y_pred, average="binary", pos_label=1
        )
        ce_prf = [ce_prf[j] + i_prf[j] for j in range(3)]

        print(
            "%s\tPrec. %0.4f\tRec. %0.4f\tF1 %0.4f"
            % (pad_conditions[i], i_prf[0], i_prf[1], i_prf[2])
        )

        y_true = macro_y_trues[:, i]
        y_pred = macro_y_preds[:, i]
        i_prf = precision_recall_fscore_support(
            y_true=y_true, y_pred=y_pred, average="macro"
        )
        macro_ce_prf = [macro_ce_prf[j] + i_prf[j] for j in range(3)]

        # for reports with temporal information
        y_true = [z for z, k in zip(y_trues[:, i], is_temporals) if k]
        y_pred = [z for z, k in zip(y_preds[:, i], is_temporals) if k]
        i_prf = precision_recall_fscore_support(
            y_true=y_true, y_pred=y_pred, average="binary", pos_label=1
        )
        temporal_ce_prf = [temporal_ce_prf[j] + i_prf[j] for j in range(3)]

        y_true = [z for z, k in zip(macro_y_trues[:, i], is_temporals) if k]
        y_pred = [z for z, k in zip(macro_y_preds[:, i], is_temporals) if k]
        i_prf = precision_recall_fscore_support(
            y_true=y_true, y_pred=y_pred, average="macro"
        )
        macro_temporal_ce_prf = [macro_temporal_ce_prf[j] + i_prf[j] for j in range(3)]
    print("--------------------------------------------------------------")
    ce_prf = [ce_prf[j] / len(CONDITIONS) for j in range(3)]
    macro_ce_prf = [macro_ce_prf[j] / len(CONDITIONS) for j in range(3)]
    temporal_ce_prf = [temporal_ce_prf[j] / len(CONDITIONS) for j in range(3)]
    macro_temporal_ce_prf = [
        macro_temporal_ce_prf[j] / len(CONDITIONS) for j in range(3)
    ]
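    # Temporal entity matching (TEM): micro precision/recall/F1 over the temporal
    # keywords shared between each reference and its generated report, restricted
    # to reports flagged as temporal.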
    tp = 0
    count_gen = 0
    count_ref = 0
    for ref, hyp, is_temporal in zip(references, hypotheses, is_temporals):
        if not is_temporal:
            continue
        ref_tem = set([z for z in ref.split() if z in TEM_keywords])
        hyp_tem = set([z for z in hyp.split() if z in TEM_keywords])
        tp += len(ref_tem & hyp_tem)
        count_gen += len(hyp_tem)
        count_ref += len(ref_tem)
    tem_prec = tp / max(count_gen, 1)
    tem_rec = tp / max(count_ref, 1)
    tem_f1 = 2 * tem_prec * tem_rec / max((tem_prec + tem_rec), 0.1)
    tem_score = [tem_prec, tem_rec, tem_f1]
    return (
        ref_observations,
        hyp_observations,
        ce_prf,
        temporal_ce_prf,
        macro_ce_prf,
        macro_temporal_ce_prf,
        tem_score,
    )
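For context, a minimal driver for the metric above might look like the sketch below. The checkpoint path is a placeholder, using the bert-base-uncased tokenizer is an assumption based on CheXbert's underlying BERT encoder, and the example reports are illustrative only.

from transformers import BertTokenizer

chexbert = load_chexbert("./chexbert.pth")  # placeholder path to a CheXbert checkpoint
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed CheXbert tokenizer
refs = ["Small left pleural effusion, decreased from prior."]
hyps = ["There is a small left pleural effusion, smaller than before."]
outputs = compute_ce_metric(
    references=refs,
    hypotheses=hyps,
    is_temporals=[True],
    chexbert=chexbert,
    bert_tokenizer=tokenizer,
)
ref_obs, hyp_obs, ce_prf, temporal_ce_prf, macro_ce_prf, macro_temporal_ce_prf, tem_score = outputs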
@@ -0,0 +1,111 @@
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to the data we are going to input to our model for training and evaluation.
    """

    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    image_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the images."},
    )
    annotation_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the annotation file."},
    )
    miss_annotation_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the miss-annotation file."},
    )
    history: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the patient history data."},
    )
    graph_version: Optional[str] = field(
        default=None,
    )
    progression_graph: Optional[str] = field(
        default=None,
    )
    chexbert_label: Optional[str] = field(default=None)
    debug_model: Optional[bool] = field(default=False)
    max_tgt_length: Optional[int] = field(
        default=64,
    )
    is_stage1_pretrained: int = field(default=1)
    is_temporal: int = field(default=1)
    eval_on_gen: Optional[bool] = field(default=False)
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Optional input sequence length after tokenization. "
            "The training dataset will be truncated in blocks of this size for training. "
            "Defaults to the model's max input length for single-sentence inputs (taking special tokens into account)."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True,
        metadata={"help": "Whether to keep line breaks when using TXT files or not."},
    )
    alpha: Optional[float] = field(default=3)
    beta: Optional[float] = field(default=3)
    wo_op: Optional[int] = field(default=1)
    wo_obs: Optional[int] = field(default=1)
    wo_pro: Optional[int] = field(default=1)
    wo_prr: Optional[int] = field(default=1)
    topk: Optional[int] = field(default=10)
    lambda_: Optional[float] = field(default=0.5)
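These fields follow the HuggingFace dataclass-argument convention, so a training script would typically parse them with HfArgumentParser. A minimal sketch (the flag values in the comment are illustrative, not taken from this repository):

from transformers import HfArgumentParser

parser = HfArgumentParser(DataTrainingArguments)
# e.g. python train.py --annotation_file data/annotation.json --max_tgt_length 64
(data_args,) = parser.parse_args_into_dataclasses()
print(data_args.max_tgt_length, data_args.topk)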