You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi, first, a big thank you for publishing this work.
I am trying to use a trained model and query it with a new probe image.
This seems to me a very important piece of functionality; after all, that is what you train the network for, right?
But I couldn't find it anywhere. I tried writing something myself, but I get poor results.
Here is what I came up with:
Any insights would be most appreciated.
Thanks,
Omer
import os
import cv2
import numpy as np
from model import DCGAN
from utils import get_image, image_save, save_images
import tensorflow as tf
from scipy.misc import imresize
# Command-line configuration for the evaluation script, declared through
# TensorFlow's flag helper. The values are read back via FLAGS.<name>.
flags = tf.app.flags
# Number of images fed to the network in one run.
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
# Side length, in pixels, of the (square) full-resolution images.
flags.DEFINE_integer("image_size", 128, "The size of image to use")
# Where the trained model checkpoint is read from.
# NOTE(review): the default is a machine-specific absolute path; override it.
flags.DEFINE_string("checkpoint_dir", "/home/omer/work/sub_pixel/models",
"Directory name to read the checkpoints [checkpoint]")
# Directory whose image files are pushed through the model.
# NOTE(review): the default is a machine-specific absolute path; override it.
flags.DEFINE_string("test_image_dir", "/home/omer/work/sub_pixel/data/celebA/valid",
"Directory name of the images to evaluate")
# Directory the generated images are written to.
# NOTE(review): presumably must already exist - confirm against image_save.
flags.DEFINE_string("out_dir", "/home/omer/work/sub_pixel/out", "Directory name of to save results in")
FLAGS = flags.FLAGS
def doresize(x, shape):
    """Rescale an image to 8-bit values and resize it to ``shape``.

    Args:
        x: Image array. The ``(x + 1) * 127.5`` mapping suggests values in
            [-1, 1] are expected - TODO confirm against ``get_image``.
        shape: Target size handed straight to ``imresize``.

    Returns:
        Whatever ``imresize`` returns for the 8-bit image and ``shape``.
    """
    # Shift/scale first, then truncate to unsigned bytes for imresize.
    scaled = np.copy((x + 1.) * 127.5)
    pixels = scaled.astype("uint8")
    return imresize(pixels, shape)
def main():
    """Run a trained DCGAN over every image file in ``FLAGS.test_image_dir``.

    Builds the model inside a TensorFlow session, restores its weights from
    ``FLAGS.checkpoint_dir`` and hands the images to ``batch_ready`` in groups
    of ``FLAGS.batch_size``.

    Behaviour worth knowing:
      * Sub-directories are ignored.
      * An image that cannot be loaded is reported and skipped; it does not
        occupy a slot, so ``files[k]`` always names the image in slot ``k``.
      * A trailing, partially filled batch is still processed; its unused
        slots stay zero-filled so the array keeps the full batch shape.
      * A failure while processing one batch is reported and the run
        continues with the next batch.

    Returns:
        None. Returns early (after printing a message) when the checkpoint
        cannot be loaded.
    """
    with tf.Session() as sess:
        dcgan = DCGAN(sess, image_size=FLAGS.image_size,
                      image_shape=[FLAGS.image_size, FLAGS.image_size, 3],
                      batch_size=FLAGS.batch_size,
                      dataset_name='celebA', is_crop=False,
                      checkpoint_dir=FLAGS.checkpoint_dir)
        if not dcgan.load(FLAGS.checkpoint_dir):
            print("failed loading model from path:" + FLAGS.checkpoint_dir)
            return

        batch_size = FLAGS.batch_size
        batch_shape = (batch_size, FLAGS.image_size, FLAGS.image_size, 3)

        # List the directory once (instead of once per image) and drop
        # sub-directories up front so they never influence batch positions.
        image_names = [
            name for name in os.listdir(FLAGS.test_image_dir)
            if not os.path.isdir(os.path.join(FLAGS.test_image_dir, name))
        ]
        # Ceiling division: a trailing partial batch counts as a batch, and
        # the result stays an integer on both Python 2 and Python 3.
        num_batches = (len(image_names) + batch_size - 1) // batch_size
        last_index = len(image_names) - 1

        files = []
        input_images = np.zeros(shape=batch_shape)
        completed_batches = 0
        for index, f in enumerate(image_names):
            img_path = os.path.join(FLAGS.test_image_dir, f)
            try:
                img = get_image(img_path, FLAGS.image_size, False)
                # The next free slot is always len(files); record the name
                # only after the slot was written successfully.
                input_images[len(files)] = img
                files.append(f)
            except Exception as e:
                # Best effort: get_image is an opaque helper, so any failure
                # on a single file is reported and that file is skipped.
                print("problem working on:" + f)
                print(str(e))

            batch_full = len(files) == batch_size
            leftovers_at_end = index == last_index and len(files) > 0
            if not (batch_full or leftovers_at_end):
                continue

            try:
                batch_ready(dcgan, input_images, sess, files)
                completed_batches += 1
                print('done batch {0} out of {1}'.format(completed_batches, num_batches))
            except Exception as e:
                print("problem working on:" + f)
                print(str(e))
            # Start a fresh batch whether or not this one succeeded, so one
            # bad batch cannot poison the following ones.
            input_images = np.zeros(shape=batch_shape)
            files = []
def batch_ready(dcgan, input_images, sess, files):
    """Run one batch through the generator and save the results.

    Each image is passed through ``doresize`` with a target of (32, 32, 3)
    to form the network input; the unmodified images are fed alongside.
    The fetched ``dcgan.G`` output is handed to ``save_results``.

    Args:
        dcgan: Model object exposing ``inputs``, ``images`` and ``G``.
        input_images: Batch of full-resolution images.
        sess: Session used to evaluate the generator.
        files: File names matching the leading entries of ``input_images``.
    """
    low_res_shape = (32, 32, 3)
    low_res_batch = np.array(
        [doresize(image, low_res_shape) for image in input_images]
    ).astype(np.float32)
    full_res_batch = np.array(input_images).astype(np.float32)
    generated = sess.run(
        fetches=[dcgan.G],
        feed_dict={dcgan.inputs: low_res_batch, dcgan.images: full_res_batch},
    )
    save_results(generated, files)
def save_results(output_images, files):
    """Write one generated image per entry of ``files`` into ``FLAGS.out_dir``.

    Args:
        output_images: Sequence whose first element holds the generated
            images, indexed in the same order as ``files``.
        files: Source file names; each output is saved as ``<name>_.png``.
    """
    for index, name in enumerate(files):
        destination = os.path.join(FLAGS.out_dir, name + '_.png')
        image_save(output_images[0][index], destination)
# Script entry point: only run the evaluation when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
The text was updated successfully, but these errors were encountered:
Hi, first, a big thank you for publishing this work.
I am trying to use a trained model and query it with a new probe image.
This seems to me a very important piece of functionality; after all, that is what you train the network for, right?
But I couldn't find it anywhere. I tried writing something myself, but I get poor results.
Here is what I came up with:
Any insights would be most appreciated.
Thanks,
Omer
The text was updated successfully, but these errors were encountered: