Compositional Generalization

This repository is the official implementation of our paper Compositional Generalization by Learning Analytical Expressions.

If you find our code useful, please consider citing our paper:

@inproceedings{qian2020compositional,
  title={Compositional Generalization by Learning Analytical Expressions},
  author={Qian Liu and Shengnan An and Jian{-}Guang Lou and Bei Chen and Zeqi Lin and Yan Gao and Bin Zhou and Nanning Zheng and Dongmei Zhang},
  booktitle={Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual},
  year={2020}
}

Content

Requirements

Our code officially supports Python 3.7. The main dependencies are pytorch and tensorboardX. You can install all requirements with the following command:

❱❱❱ pip install -r requirements.txt
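
To quickly check that the main dependencies are importable after installation, you can run a one-liner such as the following (a minimal sanity check, not part of the official setup):

❱❱❱ python -c "import torch, tensorboardX; print(torch.__version__)"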

Training

To train our model on the different tasks of the SCAN and SCAN-ext datasets, you can use the following command:

❱❱❱ python main.py --mode train --checkpoint <model_dir> --task <task_name>

📋 Note that <model_dir> specifies the folder in which model checkpoints are stored, and <task_name> is the task name. Available task names are [simple, addjump, around_right, length, mcd1, mcd2, mcd3, extend].

For example, you can train a model on the addjump task with the following command:

❱❱❱ python main.py --mode train --checkpoint addjump_model --task addjump
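
If you want to train models for every task in sequence, a shell loop over the task names listed above works. This is only a convenience sketch: the checkpoint directory names (<task>_model) are our own convention, not required by the code, and note the remarks below about --simplicity-ratio for around_right and mcd3.

❱❱❱ for task in simple addjump around_right length mcd1 mcd2 mcd3 extend; do python main.py --mode train --checkpoint ${task}_model --task ${task}; done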

📋 Since reinforcement learning is known to be hard to train, there is a chance that training will not converge. In that case, you can choose another random seed and try again.

📋 Meanwhile, please note that model training is sensitive to the coefficient of the simplicity-based reward (i.e., --simplicity-ratio in args). When it is higher (e.g., 0.5 or 1.0), the model is harder to converge, meaning that training accuracy may not reach 100%. We are still investigating the reason behind this. If you cannot obtain good results after trying several random seeds, you can try to reproduce the other results (not applicable to around_right and mcd3, as stated in the paper) with a simplicity-ratio of 0 (the current default). We will update the code when we find a better training strategy.

Therefore, please use the following command for the around_right and mcd3 tasks:

❱❱❱ python main.py --mode train --checkpoint around_right_model --task around_right --simplicity-ratio 0.5

The corresponding log and model weights will be stored at checkpoint/logs/<model_dir>.log and checkpoint/models/<model_dir>/*.mdl, respectively (e.g., checkpoint/logs/addjump_model.log for the addjump example above).
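
To follow training progress, you can simply tail the log file; the path below assumes the addjump example above:

❱❱❱ tail -f checkpoint/logs/addjump_model.log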

Evaluation

To evaluate our model on different tasks, run:

❱❱❱ python main.py --mode test --checkpoint <model_weight_file> --task <task_name>

📋 Note that <model_weight_file> specifies a concrete model weight file with the .mdl suffix, and <task_name> is the task name.

For example, you can evaluate a trained model weight file weight.mdl on the addjump task with the following command:

❱❱❱ python main.py --mode test --checkpoint weight.mdl --task addjump

Pre-trained Models

You can find pretrained model weights for the above tasks under the pretrained_weights folder.
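
For instance, a pretrained weight for the addjump task could be evaluated as follows; the exact file name (addjump.mdl here) is only an assumption, so please check the actual file names under pretrained_weights:

❱❱❱ python main.py --mode test --checkpoint pretrained_weights/addjump.mdl --task addjump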

Results

Our model is expected to achieve 100% accuracy on all tasks if training succeeds.