Soft Attention Image Captioning

TensorFlow implementation of Show, Attend and Tell, presented at ICML'15.

Heavily refactored since the last update; compatible with TensorFlow r1.0 and later.
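
The core idea of soft attention can be sketched in a few lines of NumPy. This is a minimal illustration, not the code in model.py. It assumes an annotation matrix `a` of shape (L, D), where the paper uses 14 x 14 x 512 VGG features (L = 196, D = 512), a previous hidden state `h`, and hypothetical randomly initialised weights `W_a`, `W_h` and `w`. Each image location is scored by a small MLP conditioned on `h`, the scores are normalised with a softmax, and the context vector is the weighted average of the annotations.

```python
import numpy as np


def soft_attention(a, h, W_a, W_h, w):
    """One soft-attention step in the style of Show, Attend and Tell.

    a:   (L, D) annotation vectors, one per image location
    h:   (H,)   previous decoder hidden state
    W_a: (D, K), W_h: (H, K), w: (K,)  hypothetical MLP weights
    Returns (alpha, z): attention weights (L,) and context vector (D,).
    """
    # Score every location: e_i = w . tanh(a_i W_a + h W_h); h W_h broadcasts over L
    e = np.tanh(a @ W_a + h @ W_h) @ w          # shape (L,)
    # Numerically stable softmax over the L locations
    ex = np.exp(e - e.max())
    alpha = ex / ex.sum()
    # Deterministic "soft" context: expectation of the annotations under alpha
    z = alpha @ a                                # shape (D,)
    return alpha, z


rng = np.random.default_rng(0)
L, D, H, K = 196, 512, 1024, 256                 # H and K are arbitrary sizes
a = rng.standard_normal((L, D))
h = rng.standard_normal(H)
W_a = rng.standard_normal((D, K)) * 0.01
W_h = rng.standard_normal((H, K)) * 0.01
w = rng.standard_normal(K) * 0.01

alpha, z = soft_attention(a, h, W_a, W_h, w)
print(alpha.shape, z.shape, round(alpha.sum(), 6))   # (196,) (512,) 1.0
```

The paper also describes a stochastic "hard" variant that samples a single location from alpha; the soft version is differentiable end to end, so it can be trained with ordinary backpropagation.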

Prerequisites

Data

Preparation

  1. Clone this repo and create the data/ and log/ folders:

    git clone https://github.com/markdtw/soft-attention-image-captioning.git
    cd soft-attention-image-captioning
    mkdir data
    mkdir log
  2. Download and extract the pre-trained Inception V4 and VGG 19 checkpoints from tf.slim for feature extraction.
    Save the .ckpt files in cnns/ as inception_v4_imagenet.ckpt and vgg_19_imagenet.ckpt.

  3. The following files are needed in the data/ folder:

    • coco_raw.json
    • coco_processed.json
    • coco_dictionary.pkl
    • coco_final.json
    • train2014_vgg.npy and val2014_vgg.npy (or train2014_inception.npy and val2014_inception.npy for Inception V4 features)

    These files can be generated with utils.py; please read through it before running it.

  4. If you cannot extract the features yourself, here is a download link for them:

    • It may take a long time to download.
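
Before training, a quick check that data/ holds everything listed in step 3 can save a failed run. The helper below is hypothetical (it is not part of utils.py) and assumes the feature files are named train2014_vgg.npy / val2014_vgg.npy, or train2014_inception.npy / val2014_inception.npy when Inception V4 features are used.

```python
import os

# Files step 3 expects in data/, independent of the CNN used.
COMMON_FILES = [
    "coco_raw.json",
    "coco_processed.json",
    "coco_dictionary.pkl",
    "coco_final.json",
]


def missing_data_files(data_dir="data", cnn="vgg"):
    """Return the expected files that are not present in data_dir.

    cnn is "vgg" or "inception"; it selects the assumed feature file
    names train2014_<cnn>.npy and val2014_<cnn>.npy.
    """
    if cnn not in ("vgg", "inception"):
        raise ValueError("cnn must be 'vgg' or 'inception'")
    expected = COMMON_FILES + [f"train2014_{cnn}.npy", f"val2014_{cnn}.npy"]
    return [name for name in expected
            if not os.path.isfile(os.path.join(data_dir, name))]


if __name__ == "__main__":
    missing = missing_data_files()
    if missing:
        print("Missing from data/:", ", ".join(missing))
    else:
        print("All required files found.")
```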

Train

Train from scratch with default settings:

python main.py --train

Resume training from the model saved at epoch X:

python main.py --train --model_path=log/model.ckpt-X

Check out the tunable arguments:

python main.py

Generate a caption

Using the default (latest) model:

python main.py --generate --img_path=/path/to/image.jpg

Using the model from epoch X:

python main.py --generate --img_path=/path/to/image.jpg --model_path=log/model.ckpt-X

Others

  • The extracted features take up around 16 + 8 GB, so make sure you have enough CPU (host) memory to load the data.
  • GPU memory usage is around 8 GB with batch_size 128.
  • The RNN is implemented with tf.while_loop; feature extraction uses the tf.slim models from their GitHub page.
  • A GRU cell is also implemented; enable it by setting --use_gru=True when training.
  • Features can also be extracted with Inception V4; in that case, set model.ctx_dim in model.py to (64, 1536). Other modifications are also needed.
  • Issues are welcome!
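
A back-of-the-envelope estimate shows how the feature size scales with the CNN and the storage type. The sketch assumes the standard COCO 2014 split sizes (82,783 train and 40,504 val images), one context array per image, and context shapes of (196, 512) for VGG 19 conv5 features and (64, 1536) for Inception V4. The bytes per value is a parameter because this README does not say which dtype utils.py uses.

```python
# Assumed COCO 2014 image counts and per-image context shapes (L, D).
NUM_IMAGES = {"train2014": 82_783, "val2014": 40_504}
CTX_DIM = {"vgg": (196, 512), "inception": (64, 1536)}


def feature_gb(split, cnn="vgg", bytes_per_value=4):
    """Approximate size in GB (10**9 bytes) of one split's feature array."""
    L, D = CTX_DIM[cnn]
    return NUM_IMAGES[split] * L * D * bytes_per_value / 1e9


for nbytes, name in ((4, "float32"), (2, "float16")):
    sizes = {s: round(feature_gb(s, "vgg", nbytes), 1) for s in NUM_IMAGES}
    print(name, sizes)
# float32 {'train2014': 33.2, 'val2014': 16.3}
# float16 {'train2014': 16.6, 'val2014': 8.1}
```

Under these assumptions, 2-byte storage lands close to the 16 + 8 GB quoted above, while 4-byte storage would roughly double it.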

Resources