| Code | HuggingFace Model |
| --- | --- |
| MoH-ViT | 🤗 MoH-ViT-B-75, MoH-ViT-B-50, MoH-ViT-S-80, MoH-ViT-S-75 |
| MoH-DiT | 🤗 MoH-DiT-90 |
| MoH-LLaMA3-8B | 🤗 MoH-LLaMA3-8B |
We provide an `environment.yml` file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
```shell
conda env create -f environment.yml
conda activate DiT
```
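For reference, a CPU-only variant of such an environment file might look like the sketch below. This is illustrative only: the package names and version pins are assumptions, not the contents of this repository's actual `environment.yml`.

```yaml
# Illustrative CPU-only Conda environment (assumed package set, not the repo's file).
name: DiT
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.9
  - pytorch        # cudatoolkit / pytorch-cuda entries removed for CPU-only use
  - torchvision
  - pip
```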
If you've trained a new MoH-DiT model with `train.py` (see below), you can add the `--ckpt` argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 MoH-DiT-XL/2-90 model, run:
```shell
python sample.py --model MoH-DiT-XL/2-90 --image-size 256 --ckpt /path/to/model.pt
```
We provide a training script for MoH-DiT in `train.py`. This script can be used to train class-conditional MoH-DiT models, but it can be easily modified to support other types of conditioning. To launch MoH-DiT-XL/2-90 (256x256) training with 8 GPUs on one node:
```shell
torchrun --nnodes=1 \
  --nproc_per_node=8 train.py \
  --model MoH-DiT-XL/2-90 \
  --data-path /path/to/imagenet/train \
  --results-dir results/MoH-DiT-XL-2-90
```
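A note on batch sizing for the launch above: DiT-style training scripts typically divide one global batch evenly across all ranks. The sketch below illustrates the arithmetic; the global batch size of 256 is an assumption (DiT's common default), not a value taken from this repository's `train.py`.

```python
# Sketch: how a global batch is split across GPUs in a torchrun launch.
world_size = 8            # matches --nproc_per_node=8 in the command above
global_batch_size = 256   # assumed default; check train.py for the real flag/value

# Each rank processes an equal slice of the global batch per step.
per_gpu_batch = global_batch_size // world_size
print(per_gpu_batch)  # 32
```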
We include a `sample_ddp.py` script which samples a large number of images from a MoH-DiT model in parallel. This script generates a folder of samples as well as a `.npz` file which can be used directly with ADM's TensorFlow evaluation suite to compute FID, Inception Score, and other metrics. For example, to sample 50K images from our pre-trained MoH-DiT-XL/2-90 model over 8 GPUs, run:
```shell
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --model MoH-DiT-XL/2-90 --num-fid-samples 50000
```
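ADM-style evaluation consumes an `.npz` containing a single uint8 array of samples. The sketch below shows the file format using random data in place of real model outputs; the array shape convention and the `arr_0` key are assumptions based on common usage, not checked against this repo's `sample_ddp.py`.

```python
import numpy as np

# Fabricated stand-ins for generated images: (N, H, W, 3) uint8 RGB arrays.
num_samples, image_size = 16, 256
samples = np.random.randint(
    0, 256, size=(num_samples, image_size, image_size, 3), dtype=np.uint8
)

# np.savez stores an unnamed positional array under the key "arr_0",
# which is the key ADM-style evaluators typically read (an assumption).
np.savez("samples.npz", samples)

loaded = np.load("samples.npz")["arr_0"]
print(loaded.shape, loaded.dtype)  # (16, 256, 256, 3) uint8
```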
There are several additional options; see `sample_ddp.py` for details.