Skip to content

Commit

Permalink
predictions are made
Browse files Browse the repository at this point in the history
  • Loading branch information
idan-tankel committed Aug 30, 2023
1 parent 546bbc4 commit 1377966
Show file tree
Hide file tree
Showing 3 changed files with 243,995 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"--model",
"instruct_blip",
"--anno_path",
"${workspaceFolder}/SEED-Bench/SEED_filtered_debug.json",
"${workspaceFolder}/SEED-Bench/SEED-Bench.json",
"--output-dir",
"results"
],
Expand Down
12 changes: 10 additions & 2 deletions SEED-Bench/eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import json
import argparse

import pandas as pd
import torch
from tqdm import tqdm
import numpy as np
Expand Down Expand Up @@ -32,8 +32,10 @@ def build_model(model_name):
if model_name == 'instruct_blip':
from instruct_blip_interface import build

model = build()

model = build()
if model_name == "blip2ForConditionalGeneration":
pass
return model


Expand Down Expand Up @@ -68,7 +70,11 @@ def run_inference(model, qa_anno, output_dir):
data_info['data_path'] = data_path

# losses: loss values of 4 choices, torch tensor, shape=[4]
# prompt the model with the code {question} + {choice}/
# prompt that 4 times
# choose the one that minimizes the loss out of these 4 times
losses = model(data_info)

class_ranks = torch.argsort(losses, dim=-1).cpu()
pred_id = ['A', 'B', 'C', 'D'][class_ranks[0]]
gt = qa_item['answer']
Expand Down Expand Up @@ -123,6 +129,8 @@ def run_inference(model, qa_anno, output_dir):
if 'questions' in qa_anno.keys():
qa_anno = qa_anno['questions']

x = pd.DataFrame(qa_anno)
x = x[x.question_type_id == 1]
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)

Expand Down
Loading

0 comments on commit 1377966

Please sign in to comment.