Semi-supervised learning method for image classification (1k classes). Structure: a CNN with residual connections stacked on top of the encoder of a convolutional autoencoder (PyTorch). The encoder weights are first trained on unlabeled images and then frozen. The residual CNN is then stacked on top of the frozen encoder and trained on labeled data for classification.
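The stacking described above can be sketched in PyTorch as follows. This is a minimal illustration, not the repository's implementation: the class names ConvEncoder, ResidualBlock and StackedClassifier, the two-layer encoder, and all layer sizes are hypothetical. It shows the two key points: the pretrained encoder's parameters are frozen (and kept in eval mode), and only the residual head receives gradients.

```python
import torch
import torch.nn as nn


class ConvEncoder(nn.Module):
    """Hypothetical CAE encoder: two strided convolutions (assumed layout)."""
    def __init__(self, in_ch=3, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, hidden, 3, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3, stride=2, padding=1), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 convs plus an identity skip connection."""
    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + x)  # add the skip connection before the final ReLU


class StackedClassifier(nn.Module):
    """Frozen pretrained encoder followed by a trainable residual classifier."""
    def __init__(self, encoder, ch=64, num_classes=1000, num_blocks=2):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False  # freeze the unsupervised-pretrained weights
        self.head = nn.Sequential(
            *[ResidualBlock(ch) for _ in range(num_blocks)],
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(ch, num_classes),
        )

    def forward(self, x):
        with torch.no_grad():  # no graph is built through the frozen encoder
            feats = self.encoder(x)
        return self.head(feats)

    def train(self, mode=True):
        super().train(mode)
        self.encoder.eval()  # keep the frozen encoder in eval mode at all times
        return self


# Usage: only the head's parameters are handed to the optimizer.
model = StackedClassifier(ConvEncoder())
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)
logits = model(torch.randn(2, 3, 64, 64))  # shape (2, 1000)
loss = nn.functional.cross_entropy(logits, torch.tensor([3, 7]))
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Overriding train() keeps any BatchNorm or dropout layers inside the encoder from changing behavior when the classifier is switched to training mode, which is one common way to keep a frozen feature extractor truly fixed.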
Run the Python code directly, or submit the HPC job CAE_train.sbatch, to train and evaluate this semi-supervised learning strategy.
The Mean Teacher code is adapted from the Mean Teacher repository. The original code comes with the resnet152 and cifar_shakeshake26 model architectures. Here, we added a ResNet18 to the available architectures, but feel free to use any model that suits your purposes. We also modified the script so that it works with the latest PyTorch version.
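For readers unfamiliar with Mean Teacher, the core idea can be sketched as below: the teacher is an exponential moving average (EMA) of the student's weights, and the student is additionally trained to agree with the teacher's predictions on a differently perturbed input. This is a minimal sketch, not the repository's code; the function names create_ema_model, update_ema and softmax_mse_loss, the noise used as perturbation, and the tiny nn.Linear model are illustrative assumptions. Note the keyword form add_(tensor, alpha=...): the older positional add_(scalar, tensor) overload is deprecated in current PyTorch.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def create_ema_model(model: nn.Module) -> nn.Module:
    """Teacher starts as a copy of the student and is never touched by the optimizer."""
    ema_model = copy.deepcopy(model)
    for p in ema_model.parameters():
        p.requires_grad_(False)
    return ema_model


@torch.no_grad()
def update_ema(student: nn.Module, teacher: nn.Module, alpha: float, global_step: int) -> None:
    # Use a smaller decay early on so the teacher quickly catches up with the student.
    alpha = min(1.0 - 1.0 / (global_step + 1), alpha)
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        # teacher = alpha * teacher + (1 - alpha) * student
        t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)


def softmax_mse_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """Consistency loss: squared error between class probabilities, summed, divided by #classes."""
    num_classes = student_logits.size(1)
    return F.mse_loss(F.softmax(student_logits, dim=1),
                      F.softmax(teacher_logits, dim=1),
                      reduction="sum") / num_classes


# Usage on a toy model (unlabeled batch only, so just the consistency term).
torch.manual_seed(0)
student = nn.Linear(8, 10)
teacher = create_ema_model(student)
opt = torch.optim.SGD(student.parameters(), lr=0.1)
x = torch.randn(4, 8)
for step in range(3):
    with torch.no_grad():
        t_logits = teacher(x + 0.1 * torch.randn_like(x))  # teacher sees a noised view
    loss = softmax_mse_loss(student(x), t_logits)
    opt.zero_grad()
    loss.backward()
    opt.step()
    update_ema(student, teacher, alpha=0.999, global_step=step)
```

In a full training loop the consistency term would be added to the usual cross-entropy loss on the labeled part of each batch; a ResNet18 or any other nn.Module can take the place of the toy nn.Linear.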
Before running the code, put the training and validation sets under the folder images/ilsvrc2012, in subfolders named 'train' and 'val'. Put the records of your labeled data in a txt file, in a separate folder under data-local/labels, using the desired format.
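A quick sanity check of this layout can save a failed HPC job. The sketch below is an assumption-laden helper, not part of the repository: check_layout and read_labels are hypothetical names, and read_labels assumes a label file with one "<image filename> <class name>" pair per line. Adapt the parser to whatever format your label files actually use.

```python
import tempfile
from pathlib import Path


def check_layout(root="images/ilsvrc2012"):
    """Fail early if the expected 'train' and 'val' subfolders are missing."""
    root = Path(root)
    missing = [name for name in ("train", "val") if not (root / name).is_dir()]
    if missing:
        raise FileNotFoundError(f"missing split folder(s) under {root}: {missing}")
    return root / "train", root / "val"


def read_labels(label_file):
    """Assumed format: one '<image filename> <class name>' pair per line."""
    labels = {}
    with open(label_file) as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue  # ignore blank lines
            parts = line.split()
            if len(parts) != 2:
                raise ValueError(f"{label_file}:{line_no}: expected '<filename> <class>'")
            filename, class_name = parts
            labels[filename] = class_name
    return labels


# Usage on a throwaway directory that mimics the expected layout.
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "val"):
        (Path(tmp) / split).mkdir()
    train_dir, val_dir = check_layout(tmp)
    label_path = Path(tmp) / "labels.txt"
    label_path.write_text("n01440764_10026.JPEG n01440764\nn01443537_1000.JPEG n01443537\n")
    demo_labels = read_labels(label_path)
```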
The fine-tune.sbatch file suggests a set of hyperparameters to use, but they are not necessarily optimal. The sbatch file can be used to submit jobs directly to an HPC cluster.
To see the available training options, run the training script with the --help flag.