Inference on custom images. Add/Update readme files. #14

Open · wants to merge 6 commits into master
15 changes: 11 additions & 4 deletions README.md
@@ -78,19 +78,26 @@ cd src
./train.sh
```
* We implement our method in PyTorch and conduct experiments on 2 NVIDIA 2080Ti GPUs.
* We adopt pre-trained [ResNet-18](https://download.pytorch.org/models/resnet18-5c106cde.pth) and [Swin-B-224](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) as backbone networks, which are saved in PRE folder.
* We adopt pre-trained [ResNet-18](https://download.pytorch.org/models/resnet18-5c106cde.pth) and [Swin-B-224](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) as backbone networks, which are saved in the **pre** folder.
* We train our method on 3 settings: DUTS-TR, DUTS-TR+HRSOD and UHRSD_TR+HRSOD_TR.
* After training, the trained models will be saved in MODEL folder.
* After training, the trained models will be saved in the **model** folder.

### Test
The trained model can be downloaded here: [Google Drive](https://drive.google.com/drive/folders/1hXwCvrdmvkaRePXWPTw5tjFXmrrzHPtt?usp=sharing)
Rename the downloaded file to *model-31* and save it in the **model** folder.

To test on the datasets, change working directory to **src** and run *test.py* as follows:
```
cd src
python test.py
```
* After testing, saliency maps will be saved in RESULT folder

To run inference on custom images in a folder, change working directory to **src** and run *test_images.py* as follows:
```
cd src
python test_images.py /path/to/folder
```
* After testing, saliency maps will be saved in the **result** folder.



4 changes: 4 additions & 0 deletions model/PGNet_DUT+HR/.gitignore
@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore
4 changes: 4 additions & 0 deletions model/README.md
@@ -0,0 +1,4 @@
The PGNet_DUT+HR trained model can be downloaded here: [Google Drive](https://drive.google.com/drive/folders/1hXwCvrdmvkaRePXWPTw5tjFXmrrzHPtt?usp=sharing)
1. Download the trained model file.
2. Move it into the PGNet_DUT+HR folder within this folder.
3. Rename the downloaded file to *model-31*.
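
A quick way to confirm the renamed checkpoint is in place and readable is shown below (a minimal sketch, run from the repository root; the path simply follows the steps above):
```
import torch

# Load the renamed checkpoint on the CPU only to verify that the file deserializes.
state = torch.load('model/PGNet_DUT+HR/model-31', map_location='cpu')
print(type(state))
if isinstance(state, dict):
    print(f'{len(state)} entries, e.g. {next(iter(state))}')
```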
5 changes: 5 additions & 0 deletions pre/.gitignore
@@ -0,0 +1,5 @@
# Ignore everything in this directory
*
# Except these files
!.gitignore
!README.md
4 changes: 4 additions & 0 deletions pre/README.md
@@ -0,0 +1,4 @@
We adopt pre-trained [ResNet-18](https://download.pytorch.org/models/resnet18-5c106cde.pth) and [Swin-B-224](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) as backbone networks.
1. Download both pre-trained models and move them to this folder.
2. Rename the downloaded ResNet-18 model file to *resnet18.pth*.
3. Rename the downloaded Swin-B-224 model file to *swin224.pth*.
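
Equivalently, the downloads and renames can be scripted in one go; a minimal sketch (run from inside the **pre** folder, using the two URLs linked above):
```
import urllib.request

# Fetch both backbone checkpoints and save them under the file names requested above.
urls = {
    'resnet18.pth': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'swin224.pth': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
}
for filename, url in urls.items():
    print(f'Downloading {filename} ...')
    urllib.request.urlretrieve(url, filename)
```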
5 changes: 5 additions & 0 deletions result/.gitignore
@@ -0,0 +1,5 @@
# Ignore everything in this directory
*
# Except these files
!.gitignore
!README.md
18 changes: 18 additions & 0 deletions src/dataset.py
100644 → 100755
@@ -92,6 +92,7 @@ def __init__(self, cfg):
            img_name = each.split("/")[-1]
            img_name = img_name.split(".")[0]
            self.samples.append(img_name)

    def __getitem__(self, idx):
        name = self.samples[idx]
        tig='.jpg'
@@ -129,6 +130,23 @@ def __len__(self):
        return len(self.samples)


class DataImage(Data):
    def __init__(self, cfg):
        super().__init__(cfg)
        # Every file in datapath is treated as an input image.
        self.samples = [os.path.join(self.cfg.datapath, i) for i in os.listdir(self.cfg.datapath)]

    def __getitem__(self, idx):
        name = self.samples[idx]
        image = cv2.imread(name).astype(np.float32)
        image = image[:,:,::-1].copy()  # BGR -> RGB

        # Custom images have no ground truth; reuse one image channel as a placeholder mask.
        mask = image[:,:,0]
        shape = mask.shape
        image, mask = self.normalize(image, mask)
        image, mask = self.resize(image, mask)
        image, mask = self.totensor(image, mask)
        return image, mask, shape, name

########################### Testing Script ###########################
if __name__=='__main__':
    import matplotlib.pyplot as plt
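For context, the new `DataImage` dataset added above is consumed by the test scripts below roughly as follows (a sketch; the folder path is a placeholder and the `Config` arguments mirror those scripts):
```
from torch.utils.data import DataLoader
from dataset import Config, DataImage

# Any folder of images can be used; snapshot is read by the network, not by the dataset.
cfg = Config(datapath='/path/to/folder', snapshot='../model/PGNet_DUT+HR/model-31', mode='test')
loader = DataLoader(DataImage(cfg), batch_size=1, shuffle=False)

for image, mask, shape, name in loader:
    # image: (1, 3, H, W) after resizing, mask: placeholder channel, shape: original (H, W), name: file path
    print(name[0], image.shape)
```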
12 changes: 10 additions & 2 deletions src/test.py
@@ -41,6 +41,14 @@ def save(self):

if __name__=='__main__':
    for path in ['../data/DAVIS-S','../data/UHRSD_TE','../data/HRSOD_TE','../data/DUT-OMRON','../data/HKU-IS','../data/ECSSD','../data/DUTS-TE','../data/PASCAL-S']:
-        for model in ['model-27','model-28','model-29','model-30','model-31','model-32']:
-            t = Test(dataset,PGNet, path,'./PGNet_DUT+HR/'+model)
+        if not os.path.isdir(path):
+            print(f'Skipping dataset. Directory does not exist: {path}')
+            continue
+        for model in ['model-27','model-28','model-29','model-30','model-31','model-32']:
+            model_path = os.path.join('model', 'PGNet_DUT+HR', model)
+            if os.path.isfile(model_path):
+                print(f'Testing model {model_path}')
+                t = Test(dataset,PGNet, path, model_path)
+                t.save()
+            else:
+                print(f'Skipping {model_path} because file does not exist.')
81 changes: 81 additions & 0 deletions src/test_bigbird.py
@@ -0,0 +1,81 @@
import os
import sys
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataset import Data, Config, DataImage
from PGNet import PGNet


class Test(object):
    def __init__(self, network, path, model):
        ## dataset
        self.model = model
        self.cfg = Config(datapath=path, snapshot=model, mode='test')
        self.data = DataImage(self.cfg)
        self.loader = DataLoader(self.data, batch_size=1, shuffle=False, num_workers=2)
        ## network
        self.net = network(self.cfg)
        self.net.train(False)
        self.net.cuda()

    def save(self, path, forwards=1):
        print(f'Saving results in {path}')
        os.makedirs(path, exist_ok=True)

        with torch.no_grad():
            for image, mask, shape, name in self.loader:
                image = image.cuda().float()
                mask = mask.cuda().float()

                # Successive iteration of forwards on previous results
                for i in range(forwards):
                    p = self.net(image, shape=None)
                    # Replicate the 1-channel prediction into 3 channels and feed it back as the next input
                    image = p[0].expand(-1, 3, -1, -1)

                # Resize and save
                out_resize = F.interpolate(p[0], size=shape, mode='bilinear')
                pred = torch.sigmoid(out_resize[0,0])
                pred = (pred*255).cpu().numpy()
                name = os.path.basename(name[0])
                out = os.path.join(path, name)
                cv2.imwrite(out, np.round(pred))


if __name__=='__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Saliency detection on cropped BigBird images')
    parser.add_argument('--model', help='path to model')
    parser.add_argument('--root', help='root folder containing cropped bigbird object instances', required=True)
    parser.add_argument('--objects', help='only crop specific objects', nargs='*', default=None)
    parser.add_argument('--in-folder', help='name of folder where cropped images are stored', required=True)
    parser.add_argument('--out', help='output root path (default=root)', default=None)
    parser.add_argument('--out-folder', help='name of folder where output saliency maps are to be stored', required=True)
    parser.add_argument('--forwards', help='number of forward passes (default=1)', type=int, default=1)
    args = parser.parse_args()

    for k, v in args.__dict__.items():
        print(f'{k:->20} : {v}')

    # Load object names
    objects = args.objects
    if objects is None:
        objects = [i for i in os.listdir(args.root) if os.path.isdir(os.path.join(args.root, i))]
    print(f'Inferencing on {len(objects)} objects')

    # Get output root
    out = args.out
    if out is None:
        out = args.root

    # Iterate over each object
    for obj_idx, obj in enumerate(objects, 1):
        print(f'\nInferencing on \t[{obj_idx:>3d}/{len(objects)}] : \t{obj}')
        obj_img_path = os.path.join(args.root, obj, args.in_folder)
        obj_out_path = os.path.join(out, obj, args.out_folder)
        t = Test(PGNet, obj_img_path, args.model)
        t.save(obj_out_path, args.forwards)
        print('-'*70)
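A typical invocation from **src** might look like this (the BigBird paths and folder names are placeholders; the flags are the ones defined by the argument parser above):
```
cd src
python test_bigbird.py --model ../model/PGNet_DUT+HR/model-31 \
    --root /path/to/bigbird --in-folder masked --out-folder saliency --forwards 1
```
With that layout, the saliency map for /path/to/bigbird/<object>/masked/<image> is written to /path/to/bigbird/<object>/saliency/<image>; --out redirects the output tree and --objects restricts the run to specific object folders.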
53 changes: 53 additions & 0 deletions src/test_images.py
@@ -0,0 +1,53 @@
#!/usr/bin/python3
#coding=utf-8

import os
import sys
sys.path.insert(0, '../')
sys.dont_write_bytecode = True
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataset import Data, Config, DataImage
from PGNet import PGNet

class Test(object):
    def __init__(self, Network, path, model):
        ## dataset
        self.model = model
        self.cfg = Config(datapath=path, snapshot=model, mode='test')
        self.data = DataImage(self.cfg)
        self.loader = DataLoader(self.data, batch_size=1, shuffle=False, num_workers=2)
        ## network
        self.net = Network(self.cfg)
        self.net.train(False)
        self.net.cuda()

    def save(self):
        # e.g. ../result/model/PGNet_DUT+HR/model-31/<input-folder-name>
        head = os.path.join('../result', self.model[3:], self.cfg.datapath.split(os.sep)[-1])
        if not os.path.exists(head):
            os.makedirs(head)
        print(f'Saving results at {head}')

        with torch.no_grad():
            for image, mask, shape, name in self.loader:

                image = image.cuda().float()
                mask = mask.cuda().float()
                p = self.net(image, shape=None)
                out_resize = F.interpolate(p[0], size=shape, mode='bilinear')
                pred = torch.sigmoid(out_resize[0,0])
                pred = (pred*255).cpu().numpy()

                name = os.path.basename(name[0])
                out = os.path.join(head, name.split('.')[0]+'_mask.png')
                print(out)
                cv2.imwrite(out, np.round(pred))

if __name__=='__main__':
    if len(sys.argv) < 2:
        sys.exit('usage: python test_images.py /path/to/folder')
    img_root = sys.argv[1]
    for model in ['model-31']:
        t = Test(PGNet, img_root, os.path.join('../model', 'PGNet_DUT+HR', model))
        t.save()
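As a concrete example (with the trained model in place as described in the README and an illustrative folder name), running
```
cd src
python test_images.py /path/to/holiday_photos
```
writes one mask per image, e.g. /path/to/holiday_photos/beach.jpg produces ../result/model/PGNet_DUT+HR/model-31/holiday_photos/beach_mask.png relative to **src**.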