From 4d34ea98a5470d5c27ac33a91bc5b43d785d2af6 Mon Sep 17 00:00:00 2001 From: zheedong Date: Sun, 25 Feb 2024 15:50:02 +0000 Subject: [PATCH] Update Stage 1 Proposal --- .vscode/launch.json | 11 +- coco_dataloader.py | 2 +- configs/data/dci_llava.yaml | 4 + configs/data/laion_capsfusion.yaml | 5 + configs/data/laion_capsfusion_sc.yaml | 4 + configs/data/laion_capsfusion_val.yaml | 4 + configs/eval/seed_FID.yaml | 60 + configs/seed_training_proj_test.yaml | 2 +- configs/seed_unified_test_c2f.yaml | 78 + configs/seed_unified_test_sds.yaml | 38 +- configs/seed_unified_test_sds_debug.yaml | 75 + ...n_test.yaml => long_caption_training.yaml} | 20 +- configs/training/stage1/stage_1_training.yaml | 69 + .../stage1/stage_1_training_long_caption.yaml | 73 + .../stage_1_training_long_caption_debug.yaml | 73 + .../training/stage2/seed_stage2_training.yaml | 1 - data/DCI_LLaVA_wds | 1 + data/laion_capsfusion_wds | 1 + datamodules/c2f_datamodule.py | 363 +++++ datamodules/compression_datamodule.py | 35 +- datamodules/datasets/coco_val.py | 21 + datamodules/seed_llama_datamodule.py | 170 +- models/seed_llama_tokenizer.py | 1 + test.py | 7 +- test_shard.py | 1 + train_sds_wo_codebook.py | 1 + train_v7_FID.py | 247 +++ train_v7_unified.py | 30 +- train_v7_unified_c2f.py | 1386 +++++++++++++++++ train_v7_unified_llm.py | 1198 ++++++++++++++ train_v7_unified_sds.py | 100 +- train_v8_seed_stage1_long_caption.py | 63 +- utils/config.py | 2 +- 33 files changed, 3973 insertions(+), 173 deletions(-) create mode 100755 configs/data/dci_llava.yaml create mode 100755 configs/data/laion_capsfusion.yaml create mode 100644 configs/data/laion_capsfusion_sc.yaml create mode 100755 configs/data/laion_capsfusion_val.yaml create mode 100755 configs/eval/seed_FID.yaml create mode 100644 configs/seed_unified_test_c2f.yaml create mode 100644 configs/seed_unified_test_sds_debug.yaml rename configs/training/stage1/{long_caption_test.yaml => long_caption_training.yaml} (79%) create mode 100755 configs/training/stage1/stage_1_training.yaml create mode 100755 configs/training/stage1/stage_1_training_long_caption.yaml create mode 100755 configs/training/stage1/stage_1_training_long_caption_debug.yaml delete mode 120000 configs/training/stage2/seed_stage2_training.yaml create mode 120000 data/DCI_LLaVA_wds create mode 120000 data/laion_capsfusion_wds create mode 100644 datamodules/c2f_datamodule.py create mode 100644 datamodules/datasets/coco_val.py create mode 120000 test_shard.py create mode 100755 train_v7_FID.py create mode 100644 train_v7_unified_c2f.py create mode 100755 train_v7_unified_llm.py diff --git a/.vscode/launch.json b/.vscode/launch.json index de00bac..05b6cb4 100755 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,10 +9,13 @@ "type": "python", "request": "launch", "program": "${file}", - // "env": {"CUDA_VISIBLE_DEVICES":"6, 7, 8, 9"}, - "env": {"CUDA_VISIBLE_DEVICES":"9"}, - "args" : ["cfg_path=configs/training/stage2/seed_stage2_training.yaml"], - // "args" : ["cfg_path=configs/training/stage1/long_caption_test.yaml"], + // "env": {"CUDA_VISIBLE_DEVICES":"8, 9"}, + "env": {"CUDA_VISIBLE_DEVICES":"5"}, + // "env": {"CUDA_VISIBLE_DEVICES":"0,1,2,3"}, + // "args" : ["cfg_path=configs/training/stage2/seed_stage2_training.yaml"], + // "args" : ["cfg_path=configs/eval/seed_FID.yaml"], + // "args" : ["cfg_path=configs/training/stage1/long_caption_training.yaml"], + "args" : ["cfg_path=configs/training/stage1/stage_1_training_long_caption_debug.yaml"], "console": "integratedTerminal", "justMyCode": false, }, diff 
--git a/coco_dataloader.py b/coco_dataloader.py index 6b5a477..e19a234 100755 --- a/coco_dataloader.py +++ b/coco_dataloader.py @@ -34,7 +34,7 @@ def __init__(self, self.karpathy = json.load(f) self.start_index = start_index - self.end_index = end_index + self.end_index = None if end_index == "None" else end_index def __len__(self): if self.start_index is not None and self.end_index is not None: diff --git a/configs/data/dci_llava.yaml b/configs/data/dci_llava.yaml new file mode 100755 index 0000000..e053a8e --- /dev/null +++ b/configs/data/dci_llava.yaml @@ -0,0 +1,4 @@ +META: + - ["/ssd0/data/DCI_LLaVA_wds/{00000..00001}.tar", 20000] + +CONTAIN_TEXT: True diff --git a/configs/data/laion_capsfusion.yaml b/configs/data/laion_capsfusion.yaml new file mode 100755 index 0000000..da08fd1 --- /dev/null +++ b/configs/data/laion_capsfusion.yaml @@ -0,0 +1,5 @@ +META: + - ["/ssd0/data/laion_capsfusion_wds/{00001..01380}_000000.tar", 11152883] + # - ["/ssd0/data/laion_capsfusion_wds/{00000..00001}_000000.tar", 20000] + +CONTAIN_TEXT: False diff --git a/configs/data/laion_capsfusion_sc.yaml b/configs/data/laion_capsfusion_sc.yaml new file mode 100644 index 0000000..fb6f015 --- /dev/null +++ b/configs/data/laion_capsfusion_sc.yaml @@ -0,0 +1,4 @@ +META: + - ["/ssd0/data/laion_capsfusion_sc_wds/{00001..01380}_000000.tar", 11152883] + +CONTAIN_TEXT: False diff --git a/configs/data/laion_capsfusion_val.yaml b/configs/data/laion_capsfusion_val.yaml new file mode 100755 index 0000000..7423d7e --- /dev/null +++ b/configs/data/laion_capsfusion_val.yaml @@ -0,0 +1,4 @@ +META: + - ["/ssd0/data/laion_capsfusion_wds/{00000..00000}_000000.tar", 10000] + +CONTAIN_TEXT: False diff --git a/configs/eval/seed_FID.yaml b/configs/eval/seed_FID.yaml new file mode 100755 index 0000000..5a2c97e --- /dev/null +++ b/configs/eval/seed_FID.yaml @@ -0,0 +1,60 @@ +cfg_path: ??? 
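Note on the new data configs above (dci_llava.yaml and the laion_capsfusion*.yaml files): each entry in META pairs a brace-expanded shard pattern with a sample count, and the datamodules added later in this patch expand those patterns with braceexpand (see expand_urls in datamodules/c2f_datamodule.py) before building the WebDataset pipeline. A minimal sketch of that expansion, assuming only the pyyaml and braceexpand packages and the dci_llava.yaml shown above:

    import yaml
    import braceexpand

    # Data configs in this patch have the form:
    #   META: list of [shard pattern, sample count], plus an optional CONTAIN_TEXT flag.
    with open("configs/data/dci_llava.yaml") as f:
        config = yaml.safe_load(f)

    urls, total = [], 0
    for pattern, num_samples in config["META"]:
        # "{00000..00001}.tar" expands to 00000.tar and 00001.tar
        urls.extend(braceexpand.braceexpand(pattern))
        total += num_samples

    print(urls)   # ['/ssd0/data/DCI_LLaVA_wds/00000.tar', '/ssd0/data/DCI_LLaVA_wds/00001.tar']
    print(total)  # 20000
    print(config.get("CONTAIN_TEXT", True))  # True here; the CapsFusion configs set it to False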
+tokenizer_cfg_path: configs/tokenizer/seed_llama_tokenizer_hf.yaml +transform_cfg_path: configs/transform/clip_transform.yaml +model_cfg_path: configs/llm/seed_llama_8b.yaml +result_file_path: ./logs/seed_FID_not_bypass_codebook +checkpoint_path: + model_path: pretrained/seed_tokenizer/seed_quantizer.pt + diffusion_model_path: stabilityai/stable-diffusion-2-1-unclip + +resume: False +load_weight: False +weight_path: None +eval: False + +dist: + n_gpus: 4 + n_nodes: 1 + +dataset: + val_config: + karpathy_file_path: /ssd0/data/coco/annotations/karpathy/dataset_coco_test.json + root_dir: /ssd0/data/coco/images/val2014 + num_workers: 16 + shuffle: True + text_max_length: 128 + +stage1: + init: 'SEED' + +stage2: + bypass_codebook: False + load_diffusion: True + +experiment: + seed: 0 + stage: 2 + local_batch_size: 1024 + val_batch_size: 16 + test_split: train + max_epochs: 1 + deterministic: False + grad_accumulation: 1 + check_val_every_n_epoch: 1 + enable_checkpointing: True + log_every_n_steps: 1 + num_sanity_val_steps: 1 + num_warmup_steps: 200 + grad_clip_val: 0.5 + +optimizer: + vit_precision: 'fp16' + diffusion_precision: 'fp32' + precision: 'bf16' + max_lr: 7e-4 + +hyperparameters: + beta_1: 0.9 + beta_2: 0.999 + weight_decay: 1e-8 + diff --git a/configs/seed_training_proj_test.yaml b/configs/seed_training_proj_test.yaml index 92ea015..2c7a9b9 100755 --- a/configs/seed_training_proj_test.yaml +++ b/configs/seed_training_proj_test.yaml @@ -30,7 +30,7 @@ dataset: type: dalle-vqvae hparams: resolution: 256 - gt_text: Trueq + gt_text: True stage1: ema_update: False diff --git a/configs/seed_unified_test_c2f.yaml b/configs/seed_unified_test_c2f.yaml new file mode 100644 index 0000000..ad71a8a --- /dev/null +++ b/configs/seed_unified_test_c2f.yaml @@ -0,0 +1,78 @@ +cfg_path: ??? 
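Every top-level config introduced here starts with `cfg_path: ???`, OmegaConf's marker for a mandatory value, and launch.json supplies it as a `cfg_path=...` dotlist argument. build_config in utils/config.py is not part of this diff, so the following is only a sketch of the usual OmegaConf pattern these files imply (key names taken from seed_FID.yaml above):

    import sys
    from omegaconf import OmegaConf

    # '???' marks a mandatory value; reading it before it is provided raises
    # omegaconf.errors.MissingMandatoryValue.
    base = OmegaConf.load("configs/eval/seed_FID.yaml")
    cli = OmegaConf.from_cli(sys.argv[1:])   # e.g. ["cfg_path=configs/eval/seed_FID.yaml"]
    cfg = OmegaConf.merge(base, cli)         # CLI dotlist fills in / overrides keys

    print(cfg.experiment.stage)          # 2
    print(cfg.stage2.bypass_codebook)    # False
    print(cfg.dist.n_gpus)               # 4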
+tokenizer_cfg_path: configs/tokenizer/seed_llama_tokenizer_hf.yaml +transform_cfg_path: configs/transform/clip_transform.yaml +model_cfg_path: configs/llm/seed_llama_8b.yaml +result_file_path: ./logs/noexp +checkpoint_path: + model_path: pretrained/seed_tokenizer/seed_quantizer.pt + diffusion_model_path: stabilityai/stable-diffusion-2-1-unclip + +resume: False +load_weight: False +weight_path: None +eval: False + +dist: + n_gpus: 1 + n_nodes: 1 + +dataset: + train_config: + dataset_configs: ['configs/data/laion_capsfusion_sc.yaml'] + weights: [1] + shardshuffle: 100 + resampled: True + world_size: 1 + one_epoch_data_size: 1000000 + val_config: + karpathy_file_path: /ssd0/data/coco/annotations/karpathy/dataset_coco_test.json + root_dir: /ssd0/data/coco/images/val2014 + start_index: 0 + end_index: -1 + num_workers: 1 + shuffle: True + text_max_length: 128 + +stage1: + init: 'SEED' + +stage2: + bypass_codebook: True + load_diffusion: False + train_unet: False + +experiment: + seed: 0 + stage: 1 + local_batch_size: 2 + val_batch_size: 8 + test_split: train + max_epochs: 40 + deterministic: False + grad_accumulation: 8 +# check_val_every_n_epoch: 1 + val_check_interval: 400 + enable_checkpointing: True + log_every_n_steps: 1 + num_sanity_val_steps: 0 + num_warmup_steps: 200 + recon_loss_weight: 1.0 + sds_loss_weight: 0.1 + clip_loss_weight: 1.0 + use_sds_loss_schedule: True + cross_annealing: True + num_positive_samples: 4 + min_pos_weight: 0.3 + +optimizer: + vit_precision: 'fp16' + diffusion_precision: 'fp32' + precision: 'bf16' + max_lr: 7e-4 + grad_clip_val: 0.5 + +hyperparameters: + beta_1: 0.9 + beta_2: 0.999 + weight_decay: 1e-8 + diff --git a/configs/seed_unified_test_sds.yaml b/configs/seed_unified_test_sds.yaml index 45530d6..65a3f3b 100755 --- a/configs/seed_unified_test_sds.yaml +++ b/configs/seed_unified_test_sds.yaml @@ -2,7 +2,7 @@ cfg_path: ??? 
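For the seed_unified_test_c2f.yaml run just above, the trainer scripts in this patch pass grad_accumulation to Lightning's accumulate_grad_batches, so the effective batch per optimizer update works out as in this small arithmetic check (values copied from that config):

    # Effective batch per optimizer update for seed_unified_test_c2f.yaml.
    local_batch_size = 2     # samples loaded per process per step
    grad_accumulation = 8    # accumulate_grad_batches in the Trainer
    n_gpus = 1               # DDP processes

    effective_batch = local_batch_size * grad_accumulation * n_gpus
    print(effective_batch)   # 16 images per update; the C2F sampler pairs each image
                             # with num_positive_samples = 4 captions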
tokenizer_cfg_path: configs/tokenizer/seed_llama_tokenizer_hf.yaml transform_cfg_path: configs/transform/clip_transform.yaml model_cfg_path: configs/llm/seed_llama_8b.yaml -result_file_path: ./logs/sds_coco +result_file_path: ./logs/sds_coco2 checkpoint_path: model_path: pretrained/seed_tokenizer/seed_quantizer.pt diffusion_model_path: stabilityai/stable-diffusion-2-1-unclip @@ -18,12 +18,12 @@ dist: dataset: train_config: - dataset_configs: ['configs/data/cc15m.yaml'] - weights: [1] + dataset_configs: ['configs/data/cc15m.yaml', 'configs/data/laion-coco.yaml', 'configs/data/mscoco.yaml'] + weights: [1, 8, 1] shardshuffle: 100 resampled: True world_size: 1 - one_epoch_data_size: 3000000 + one_epoch_data_size: 1000000 val_config: karpathy_file_path: /ssd0/data/coco/annotations/karpathy/dataset_coco_test.json root_dir: /ssd0/data/coco/images/val2014 @@ -31,39 +31,43 @@ dataset: end_index: 256 num_workers: 16 shuffle: True + text_max_length: 128 stage1: - init: 'BLIP-2' + init: 'SEED' stage2: bypass_codebook: True load_diffusion: True - train_unet: False - use_clip_loss: True + train_unet: True experiment: seed: 0 stage: 2 - local_batch_size: 32 - val_batch_size: 4 + local_batch_size: 128 + val_batch_size: 8 test_split: train - max_epochs: 5 + max_epochs: 40 deterministic: False grad_accumulation: 8 - check_val_every_n_epoch: 1 +# check_val_every_n_epoch: 1 + val_check_interval: 400 enable_checkpointing: True log_every_n_steps: 1 num_sanity_val_steps: 1 - num_warmup_steps: 50 - grad_clip_val: 1 - val_check_interval: 200 - + num_warmup_steps: 200 + recon_loss_weight: 1.0 + sds_loss_weight: 0.1 + clip_loss_weight: 1.0 + use_sds_loss_schedule: True + cross_annealing: False optimizer: vit_precision: 'fp16' - diffusion_precision: 'fp16' + diffusion_precision: 'fp32' precision: 'bf16' - max_lr: 1e-4 + max_lr: 7e-4 + grad_clip_val: 0.5 hyperparameters: beta_1: 0.9 diff --git a/configs/seed_unified_test_sds_debug.yaml b/configs/seed_unified_test_sds_debug.yaml new file mode 100644 index 0000000..73796da --- /dev/null +++ b/configs/seed_unified_test_sds_debug.yaml @@ -0,0 +1,75 @@ +cfg_path: ??? 
+tokenizer_cfg_path: configs/tokenizer/seed_llama_tokenizer_hf.yaml +transform_cfg_path: configs/transform/clip_transform.yaml +model_cfg_path: configs/llm/seed_llama_8b.yaml +result_file_path: ./logs/noexp +checkpoint_path: + model_path: pretrained/seed_tokenizer/seed_quantizer.pt + diffusion_model_path: stabilityai/stable-diffusion-2-1-unclip + +resume: False +load_weight: False +weight_path: None +eval: False + +dist: + n_gpus: 1 + n_nodes: 1 + +dataset: + train_config: + dataset_configs: ['configs/data/cc15m.yaml', 'configs/data/laion-coco.yaml', 'configs/data/mscoco.yaml'] + weights: [1, 8, 1] + shardshuffle: 100 + resampled: True + world_size: 1 + one_epoch_data_size: 1000000 + val_config: + karpathy_file_path: /ssd0/data/coco/annotations/karpathy/dataset_coco_test.json + root_dir: /ssd0/data/coco/images/val2014 + start_index: 0 + end_index: 256 + num_workers: 1 + shuffle: True + text_max_length: 128 + +stage1: + init: 'SEED' + +stage2: + bypass_codebook: True + load_diffusion: True + train_unet: True + +experiment: + seed: 0 + stage: 2 + local_batch_size: 1 + val_batch_size: 1 + test_split: train + max_epochs: 40 + deterministic: True + grad_accumulation: 1 +# check_val_every_n_epoch: 1 + val_check_interval: 400 + enable_checkpointing: True + log_every_n_steps: 1 + num_sanity_val_steps: 0 + num_warmup_steps: 0 + recon_loss_weight: 1.0 + sds_loss_weight: 0.1 + clip_loss_weight: 1.0 + use_sds_loss_schedule: False + +optimizer: + vit_precision: 'fp16' + diffusion_precision: 'fp32' + precision: 'bf16' + max_lr: 7e-4 + grad_clip_val: 0.5 + +hyperparameters: + beta_1: 0.9 + beta_2: 0.999 + weight_decay: 1e-8 + diff --git a/configs/training/stage1/long_caption_test.yaml b/configs/training/stage1/long_caption_training.yaml similarity index 79% rename from configs/training/stage1/long_caption_test.yaml rename to configs/training/stage1/long_caption_training.yaml index 06cf7b1..768aafe 100755 --- a/configs/training/stage1/long_caption_test.yaml +++ b/configs/training/stage1/long_caption_training.yaml @@ -14,25 +14,15 @@ weight_path: None eval: False dist: - n_gpus: 4 + n_gpus: 2 n_nodes: 1 dataset: - train_config: - dataset_configs: ['configs/data/mscoco.yaml'] - weights: [10] - shardshuffle: 100 - resampled: True - world_size: 1 - one_epoch_data_size: 1000000 val_config: karpathy_file_path: /ssd0/data/coco/annotations/karpathy/dataset_coco_test.json root_dir: /ssd0/data/coco/images/val2014 - start_index: 0 - end_index: None num_workers: 8 shuffle: True - # Token level length text_max_length: 512 stage1: @@ -46,7 +36,7 @@ experiment: seed: 0 stage: 1 local_batch_size: 128 - val_batch_size: 256 + val_batch_size: 16 test_split: train max_epochs: 40 deterministic: False @@ -54,15 +44,15 @@ experiment: check_val_every_n_epoch: 1 enable_checkpointing: True log_every_n_steps: 1 - num_sanity_val_steps: 0 - num_warmup_steps: 200 - grad_clip_val: 0.5 + num_sanity_val_steps: 1 + num_warmup_steps: 50 optimizer: vit_precision: 'fp16' diffusion_precision: 'fp16' precision: 'bf16' max_lr: 5e-6 + grad_clip_val: 0.5 hyperparameters: beta_1: 0.9 diff --git a/configs/training/stage1/stage_1_training.yaml b/configs/training/stage1/stage_1_training.yaml new file mode 100755 index 0000000..861e59a --- /dev/null +++ b/configs/training/stage1/stage_1_training.yaml @@ -0,0 +1,69 @@ +cfg_path: ??? 
+tokenizer_cfg_path: configs/tokenizer/seed_llama_tokenizer_hf.yaml +transform_cfg_path: configs/transform/clip_transform.yaml +model_cfg_path: configs/llm/seed_llama_8b.yaml +result_file_path: ./logs/seed_unify_test +checkpoint_path: + model_path: pretrained/seed_tokenizer/seed_quantizer.pt + diffusion_model_path: stabilityai/stable-diffusion-2-1-unclip + +resume: False +load_weight: False +weight_path: None +eval: False + +dist: + n_gpus: 2 + n_nodes: 1 + +dataset: + train_config: + dataset_configs: ['configs/data/cc15m.yaml', 'configs/data/laion-coco.yaml', 'configs/data/mscoco.yaml'] + weights: [1, 8, 1] + shardshuffle: 100 + resampled: True + world_size: 1 + one_epoch_data_size: 1000000 + val_config: + karpathy_file_path: /ssd0/data/coco/annotations/karpathy/dataset_coco_test.json + root_dir: /ssd0/data/coco/images/val2014 + start_index: 0 + end_index: 256 + num_workers: 16 + shuffle: True + text_max_length: 128 + +stage1: + init: 'SEED' + +stage2: + bypass_codebook: False + load_diffusion: False + +experiment: + seed: 0 + stage: 1 + local_batch_size: 256 + val_batch_size: 16 + test_split: train + max_epochs: 40 + deterministic: False + grad_accumulation: 1 + check_val_every_n_epoch: 1 + enable_checkpointing: True + log_every_n_steps: 1 + num_sanity_val_steps: 1 + num_warmup_steps: 200 + grad_clip_val: 0.5 + +optimizer: + vit_precision: 'fp16' + diffusion_precision: 'fp16' + precision: 'bf16' + max_lr: 5e-6 + +hyperparameters: + beta_1: 0.9 + beta_2: 0.999 + weight_decay: 5e-2 + diff --git a/configs/training/stage1/stage_1_training_long_caption.yaml b/configs/training/stage1/stage_1_training_long_caption.yaml new file mode 100755 index 0000000..31adf43 --- /dev/null +++ b/configs/training/stage1/stage_1_training_long_caption.yaml @@ -0,0 +1,73 @@ +cfg_path: ??? 
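Both SEEDDataModule variants added in this patch derive a step budget as total_training_steps = ((max_epochs * one_epoch_data_size) / n_gpus) / local_batch_size. Plugging in the stage_1_training.yaml values shown above gives:

    # Step budget per the datamodule formula, using stage_1_training.yaml values.
    max_epochs = 40
    one_epoch_data_size = 1_000_000
    n_gpus = 2
    local_batch_size = 256

    total_training_steps = (max_epochs * one_epoch_data_size) / n_gpus / local_batch_size
    print(total_training_steps)  # 78125.0 dataloader steps per GPU over the full run
    # Note: the formula does not divide by grad_accumulation, so with accumulation > 1
    # the number of optimizer updates is proportionally smaller.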
+tokenizer_cfg_path: configs/tokenizer/seed_llama_tokenizer_hf.yaml +transform_cfg_path: configs/transform/clip_transform.yaml +model_cfg_path: configs/llm/seed_llama_8b.yaml +result_file_path: ./logs/capsfusion_dci_llava_long_caption_last_token +checkpoint_path: + model_path: pretrained/seed_tokenizer/seed_quantizer.pt + diffusion_model_path: stabilityai/stable-diffusion-2-1-unclip + +resume: False +load_weight: False +weight_path: None +eval: False + +dist: + n_gpus: 4 + n_nodes: 1 + +dataset: + train_config: + dataset_configs: ['configs/data/laion_capsfusion.yaml', 'configs/data/dci_llava.yaml', 'configs/data/coco_karpathy_train.yaml'] + weights: [1000, 1, 1] + shardshuffle: 100 + resampled: True + world_size: 1 + one_epoch_data_size: 1000000 + val_config: + use_coco_val: False + dataset_configs: ['configs/data/laion_capsfusion_val.yaml'] + weights: [1] + shardshuffle: 100 + resampled: True + world_size: 1 + one_epoch_data_size: 1000 + num_workers: 16 + shuffle: True + text_max_length: 512 + +stage1: + init: 'SEED' + +stage2: + bypass_codebook: False + load_diffusion: False + +experiment: + seed: 0 + stage: 1 + local_batch_size: 128 + val_batch_size: 256 + test_split: train + max_epochs: 40 + deterministic: False + grad_accumulation: 1 + val_check_interval: 0.2 + # check_val_every_n_epoch: 1 + enable_checkpointing: True + log_every_n_steps: 1 + num_sanity_val_steps: 1 + num_warmup_steps: 200 + grad_clip_val: 0.5 + +optimizer: + vit_precision: 'fp16' + diffusion_precision: 'fp16' + precision: 'bf16' + max_lr: 7e-7 + +hyperparameters: + beta_1: 0.9 + beta_2: 0.999 + weight_decay: 5e-2 + diff --git a/configs/training/stage1/stage_1_training_long_caption_debug.yaml b/configs/training/stage1/stage_1_training_long_caption_debug.yaml new file mode 100755 index 0000000..48949f4 --- /dev/null +++ b/configs/training/stage1/stage_1_training_long_caption_debug.yaml @@ -0,0 +1,73 @@ +cfg_path: ??? 
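The train weights [1000, 1, 1] in stage_1_training_long_caption.yaml are normalized by the CombinedDataset added in this patch and fed to random.choices, so the sampling mix is heavily skewed toward the CapsFusion shards, as this quick check shows:

    # Normalized sampling probabilities for stage_1_training_long_caption.yaml.
    weights = [1000, 1, 1]   # laion_capsfusion, dci_llava, coco_karpathy_train
    probs = [w / sum(weights) for w in weights]
    print([round(p, 4) for p in probs])   # [0.998, 0.001, 0.001]
    # i.e. roughly 998 of every 1000 training samples come from laion_capsfusion.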
+tokenizer_cfg_path: configs/tokenizer/seed_llama_tokenizer_hf.yaml +transform_cfg_path: configs/transform/clip_transform.yaml +model_cfg_path: configs/llm/seed_llama_8b.yaml +result_file_path: ./logs/capsfusion_dci_llava_long_caption_last_token_debug +checkpoint_path: + model_path: pretrained/seed_tokenizer/seed_quantizer.pt + diffusion_model_path: stabilityai/stable-diffusion-2-1-unclip + +resume: False +load_weight: False +weight_path: None +eval: False + +dist: + n_gpus: 1 + n_nodes: 1 + +dataset: + train_config: + dataset_configs: ['configs/data/laion_capsfusion.yaml', 'configs/data/dci_llava.yaml', 'configs/data/coco_karpathy_train.yaml'] + weights: [1000, 1, 1] + shardshuffle: 100 + resampled: True + world_size: 1 + one_epoch_data_size: 1000000 + val_config: + use_coco_val: False + dataset_configs: ['configs/data/laion_capsfusion_val.yaml'] + weights: [1] + shardshuffle: 100 + resampled: True + world_size: 1 + one_epoch_data_size: 1000 + num_workers: 16 + shuffle: True + text_max_length: 512 + +stage1: + init: 'SEED' + +stage2: + bypass_codebook: False + load_diffusion: False + +experiment: + seed: 0 + stage: 1 + local_batch_size: 4 + val_batch_size: 4 + test_split: train + max_epochs: 40 + deterministic: False + grad_accumulation: 1 + val_check_interval: 0.2 + # check_val_every_n_epoch: 1 + enable_checkpointing: True + log_every_n_steps: 1 + num_sanity_val_steps: 0 + num_warmup_steps: 200 + grad_clip_val: 0.5 + +optimizer: + vit_precision: 'fp16' + diffusion_precision: 'fp16' + precision: 'bf16' + max_lr: 5e-6 + +hyperparameters: + beta_1: 0.9 + beta_2: 0.999 + weight_decay: 5e-2 + diff --git a/configs/training/stage2/seed_stage2_training.yaml b/configs/training/stage2/seed_stage2_training.yaml deleted file mode 120000 index a2158b9..0000000 --- a/configs/training/stage2/seed_stage2_training.yaml +++ /dev/null @@ -1 +0,0 @@ -../seed_unified_test.yaml \ No newline at end of file diff --git a/data/DCI_LLaVA_wds b/data/DCI_LLaVA_wds new file mode 120000 index 0000000..3b97b59 --- /dev/null +++ b/data/DCI_LLaVA_wds @@ -0,0 +1 @@ +/ssd0/data/DCI_LLaVA_wds/ \ No newline at end of file diff --git a/data/laion_capsfusion_wds b/data/laion_capsfusion_wds new file mode 120000 index 0000000..e4ce2a5 --- /dev/null +++ b/data/laion_capsfusion_wds @@ -0,0 +1 @@ +/ssd0/data/laion_capsfusion_wds/ \ No newline at end of file diff --git a/datamodules/c2f_datamodule.py b/datamodules/c2f_datamodule.py new file mode 100644 index 0000000..c53a3e1 --- /dev/null +++ b/datamodules/c2f_datamodule.py @@ -0,0 +1,363 @@ +import torch +import yaml +from torch.utils.data import Dataset, DataLoader, IterableDataset +from PIL import Image +import json +import copy +import torchvision.transforms as transforms +import numpy as np +import os +import sys +sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) +import webdataset as wds +import braceexpand +from collections.abc import Callable +from pycocotools.coco import COCO +from coco_dataloader import CocoDataset +import torch.utils.data as data + +import random +import pytorch_lightning as pl + + + +class C2FFinetuneData(IterableDataset): + def __init__(self, + config_path, + num_positive_samples: int, + transform=None, + tokenizer=None, + shardshuffle=100, + resampled=True, + world_size=1, + rank=0, + ): + print(f"read dataset config from {config_path}") + with open(config_path, 'r') as f: + self.config = yaml.load(f, Loader=yaml.FullLoader) + print("DATASET CONFIG:") + print(self.config) + wds_urls = [] + self.total_num_samples = 0 + for 
urls, num_samples in self.config['META']: + urls = expand_urls(urls) + wds_urls.extend(urls) + self.total_num_samples += num_samples + + self.transform = transform + self.tokenizer = tokenizer + self.num_positive_samples = num_positive_samples + self.max_num_positive_samples = 32 + + self.target_ids = list(range(self.num_positive_samples)) + + if 'CONTAIN_TEXT' in self.config.keys(): + self.contain_txt = self.config['CONTAIN_TEXT'] + else: + self.contain_txt = True + + if not self.contain_txt: + self.dataset = ( + wds.WebDataset( + urls, + # 싱글 노드 방식: wds.single_node_only (default) + # 멀티 노드 방식: wds.split_by_node + # url level partioning + nodesplitter=wds.split_by_node, + # url level shuffle + shardshuffle=shardshuffle, + # deterministic shuffle (재현성) + detshuffle=False, + # infinite url + resampled=resampled, + handler=wds.ignore_and_continue, + ) + .shuffle( # sample level shuffle + size=(1 if shardshuffle is None else shardshuffle * 10), + initial=(0 if shardshuffle is None else 100), + ) + .decode("pil", handler=wds.ignore_and_continue) + .to_tuple("jpg", "json", handler=wds.ignore_and_continue) + .map_tuple( + transform, + self.identity, #self.tokenize, + handler=wds.ignore_and_continue, + ) + .with_length(int(int(self.total_num_samples) / world_size)) + ) + else: + self.dataset = ( + wds.WebDataset( + urls, + # 싱글 노드 방식: wds.single_node_only (default) + # 멀티 노드 방식: wds.split_by_node + # url level partioning + nodesplitter=wds.split_by_node, + # url level shuffle + shardshuffle=shardshuffle, + # deterministic shuffle (재현성) + detshuffle=False, + # infinite url + resampled=resampled, + handler=wds.ignore_and_continue, + ) + .shuffle( # sample level shuffle + size=(1 if shardshuffle is None else shardshuffle * 10), + initial=(0 if shardshuffle is None else 100), + ) + .decode("pil", handler=wds.ignore_and_continue) + .to_tuple("jpg", "txt", "json") + .map_tuple( + transform, + self.identity, + self.identity, #self.tokenize, + handler=wds.ignore_and_continue, + ) + .with_length(int(int(self.total_num_samples) / world_size)) + ) + + + self.world_size = world_size + self.rank = rank + + self.name = config_path + + def identity(self, x): + return x + + def tokenize(self, x): + return torch.tensor(self.tokenizer.encode(x), dtype=torch.int64)[1:] + + def __len__(self): + return len(self.dataset) + + def process_text(self, text): + text_tokens = self.tokenize(text) + return text_tokens + + def rescale_ids(self, ids, num_samples): + ret = [] + step_size = self.max_num_positive_samples // num_samples + for _id in ids: + st = _id * step_size + ed = min(st + step_size, self.max_num_positive_samples) + interval = list(range(st, ed)) + ret.append(random.choice(interval)) + return ret + + def __iter__(self): + if self.contain_txt: + for i, (img, txt_tokens, meta) in enumerate(self.dataset): + yield img, txt_tokens + else: + for i, (img, meta) in enumerate(self.dataset): + captions = meta['sc'] + ids = list(range(len(captions))) + sampled_ids = sorted(random.sample(ids, self.num_positive_samples)) + _captions = [captions[_id] for _id in sampled_ids][::-1] + pos_ids = 32 - np.array(self.rescale_ids(sampled_ids, len(captions))[::-1]) + + text_tokens = self.tokenizer( + _captions, + padding="max_length", + truncation=True, + max_length=300, + return_tensors="pt", + ) + + yield img, pos_ids, text_tokens.input_ids, text_tokens.attention_mask + + def groups(self): + return list(self.group_indices.values()) + + +def expand_urls(urls): + def decode(urls): + urllist = urls.split("::") + result = [] + for url 
in urllist: + result.extend(braceexpand.braceexpand(url)) + return result + + if isinstance(urls, str): + return decode(urls) + elif isinstance(urls, tuple): + results = [] + for urls_ in urls: + results += decode(urls_) + return results + else: + return list(urls) + +class ComibinedDatasetIterator: + def __init__(self, datasets, weights): + self.datasets = [iter(dataset) for dataset in datasets] + self.weights = weights + self.randome_generator = random.Random() + + def __next__(self): + (dataset, ) = self.randome_generator.choices(self.datasets, self.weights, k=1) + return next(dataset) + +class CombinedDataset(IterableDataset): + def __init__(self, datasets_configs, datasets, rank=None, world_size=None, weights=None, length=None): + self.datasets = datasets + + weights = weights if weights is not None else [1] * len(datasets) + self.weights = [w/sum(weights) for w in weights] + self.randome_generator = random.Random() + if length is None: + self.length = sum([len(dataset) for dataset in datasets]) + else: + self.length = length + + def __iter__(self): + return ComibinedDatasetIterator(self.datasets, self.weights) + + def __len__(self): + return int(self.length) + +def custom_collate_fn(batch): + images = [_[0] for _ in batch] + images = torch.stack(images, dim=0) + + texts = [_[1] for _ in batch] + return images, texts + + +reverse_transform = transforms.Compose([ + transforms.Normalize(mean=[0, 0, 0], std=[1/0.26862954, 1/0.26130258, 1/0.27577711]), + transforms.Normalize(mean=[-0.48145466, -0.4578275, -0.40821073], std=[1, 1, 1]), + ]) + +max_seq_len = 1024 + +def pack(token_sequences, max_seq_len=128, batch_size=3): + ## sort token_sequences by length + token_sequences = sorted(token_sequences, key=lambda x: len(x), reverse=True) + print(token_sequences) + + ## batch by snaek order + packed_ds = [] + curr_token_ids = [torch.tensor([], dtype=torch.int64) for _ in range(batch_size)] + + for i in range(0, len(token_sequences), batch_size): + for j in range(batch_size): + if i+j >= len(token_sequences): + break + if len(curr_token_ids[j]) + len(token_sequences[i+j]) < max_seq_len: + curr_token_ids[j] = torch.cat([curr_token_ids[j], token_sequences[i+j]], dim=0) + for j in range(batch_size): + packed_ds.append(curr_token_ids[j]) + + return packed_ds + +class SEEDDataModule(pl.LightningDataModule): + def __init__(self, cfg, tokenizer=None, transform=None, use_coco_val=True): + super().__init__() + self.cfg = cfg + self.dataset_configs = cfg.dataset.train_config.dataset_configs + self.shardshuffle = cfg.dataset.train_config.shardshuffle + self.resampled = cfg.dataset.train_config.resampled + self.world_size = cfg.dataset.train_config.world_size + self.weights = cfg.dataset.train_config.weights + self.val_weights = cfg.dataset.val_config.get('weights', None) + + self.local_batch_size = cfg.experiment.local_batch_size + self.val_batch_size = cfg.experiment.val_batch_size + self.num_workers = cfg.dataset.num_workers + self.n_gpus = cfg.dist.n_gpus + + self.tokenizer = tokenizer + self.transform = transform + + + self.one_epoch_data_size = cfg.dataset.train_config.one_epoch_data_size + self.total_training_steps = ((cfg.experiment.max_epochs * self.one_epoch_data_size) / self.n_gpus) / self.local_batch_size + + self.use_coco_val = use_coco_val + + def setup(self, stage=None): + datasets = [] + for config in self.dataset_configs: + datasets.append( + C2FFinetuneData( + config, + self.cfg.experiment.num_positive_samples, + tokenizer=self.tokenizer, + shardshuffle=self.shardshuffle, + 
resampled=self.resampled, + world_size=self.world_size, + transform=self.transform, + ) + ) + + self.train_dataset = CombinedDataset( + datasets_configs=self.dataset_configs, + datasets=datasets, + rank=0, + world_size=self.world_size, + length=self.one_epoch_data_size/self.n_gpus, + weights=self.weights, + ) + + if self.use_coco_val: + end_index = self.cfg.dataset.val_config.end_index if self.cfg.dataset.val_config.end_index is not None else None + self.validation_dataset = CocoDataset( + root_dir=self.cfg.dataset.val_config.root_dir, + karpathy_file=self.cfg.dataset.val_config.karpathy_file_path, + tokenizer=None, + start_index=self.cfg.dataset.val_config.start_index, + end_index=end_index, + ) + else: + self.val_dataset_configs = self.cfg.dataset.val_config.dataset_configs + val_datasets = [] + for config in self.val_dataset_configs: + val_datasets.append( + C2FFinetuneData( + config, + 1, + tokenizer=self.tokenizer, + shardshuffle=self.shardshuffle, + resampled=self.resampled, + world_size=self.world_size, + transform=self.transform, + ) + ) + self.validation_dataset = CombinedDataset( + datasets_configs=self.val_dataset_configs, + datasets=val_datasets, + rank=0, + world_size=self.world_size, + length=self.one_epoch_data_size/self.n_gpus, + weights=self.val_weights, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.local_batch_size, + # num_workers=self.num_workers, + ) + + def val_dataloader(self): + if self.use_coco_val: + return DataLoader( + self.validation_dataset, + batch_size=self.val_batch_size, + collate_fn=self.validation_dataset.collate_fn, + # num_workers=self.num_workers, + ) + else: + return DataLoader( + self.validation_dataset, + batch_size=self.val_batch_size, + # num_workers=self.num_workers, + ) + + def test_dataloader(self): + raise(NotImplementedError) + + def predict_dataloader(self): + raise(NotImplementedError) \ No newline at end of file diff --git a/datamodules/compression_datamodule.py b/datamodules/compression_datamodule.py index e0fc3a8..e08585c 100755 --- a/datamodules/compression_datamodule.py +++ b/datamodules/compression_datamodule.py @@ -1,7 +1,7 @@ import pytorch_lightning as pl from pytorch_lightning import LightningDataModule from torch.utils.data import Dataset, DataLoader -from torch.utils.data.dataset import ConcatDataset +from torch.utils.data.dataset import ConcatDataset, random_split from torch.utils.data import Sampler, DistributedSampler from io import BytesIO from PIL import Image @@ -11,7 +11,6 @@ import json import tarfile - class CompressionDataset(Dataset): def __init__(self, compression_level, transform=None): self.root = '/home/zheedong/Projects/SEED/data/cc3m_llava_long_caption' @@ -136,18 +135,38 @@ def __len__(self): return self.num_samples // self.batch_size class CompressionDataModule(LightningDataModule): - def __init__(self, batch_size=32, num_workers=4, transform=None, compression_level=0): + def __init__(self, cfg=None, transform=None, compression_level=0): super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + self.cfg = cfg + self.local_batch_size = cfg.experiment.local_batch_size + self.val_batch_size = cfg.experiment.val_batch_size + self.num_workers = cfg.dataset.num_workers self.transform = transform self.compression_level = compression_level def setup(self): self.datasets = [] self.datasets = CompressionDataset(compression_level=self.compression_level, transform=self.transform) + self.train_size = int(0.98 * len(self.datasets)) + 
self.var_size = len(self.datasets) - self.train_size + self.train_dataset, self.val_dataset = random_split(self.datasets, [self.train_size, self.var_size]) def train_dataloader(self): - return DataLoader(self.datasets, - batch_size=self.batch_size, - num_workers=self.num_workers,) \ No newline at end of file + return DataLoader( + self.train_dataset, + batch_size=self.local_batch_size, + num_workers=self.num_workers, + pin_memory=True, + # drop_last=True, + # shuffle=False, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + num_workers=self.num_workers, + pin_memory=True, + # drop_last=True, + # shuffle=False, + ) \ No newline at end of file diff --git a/datamodules/datasets/coco_val.py b/datamodules/datasets/coco_val.py new file mode 100644 index 0000000..1956b96 --- /dev/null +++ b/datamodules/datasets/coco_val.py @@ -0,0 +1,21 @@ +from torch.utils.data import Dataset +from PIL import Image +import json + +class COCOValDataSet(Dataset): + def __init__(self, transform): + self.transform = transform + self.coco_val_data_path = '/home/zheedong/Projects/SEED/coco/annotations/captions_val2014.json' + with open(self.coco_val_data_path, 'r') as f: + self.coco_data = json.load(f) + self.image_root = '/home/zheedong/Projects/SEED/coco/images/val2014' + + def __len__(self): + return 30000 + + def __getitem__(self, idx): + image_path = self.image_root + '/' + self.coco_data['images'][idx]['file_name'] + image = Image.open(image_path).convert('RGB') + if self.transform: + image = self.transform(image) + return image, '', self.coco_data['images'][idx]['file_name'] \ No newline at end of file diff --git a/datamodules/seed_llama_datamodule.py b/datamodules/seed_llama_datamodule.py index d35a101..f40c825 100755 --- a/datamodules/seed_llama_datamodule.py +++ b/datamodules/seed_llama_datamodule.py @@ -15,6 +15,7 @@ from collections.abc import Callable from pycocotools.coco import COCO from coco_dataloader import CocoDataset +import torch.utils.data as data import hydra from omegaconf import OmegaConf @@ -30,7 +31,7 @@ def __init__(self, shardshuffle=100, resampled=True, world_size=1, - rank=0 + rank=0, ): print(f"read dataset config from {config_path}") with open(config_path, 'r') as f: @@ -48,35 +49,71 @@ def __init__(self, self.max_words = max_words self.tokenizer = tokenizer - self.dataset = ( - wds.WebDataset( - urls, - # 싱글 노드 방식: wds.single_node_only (default) - # 멀티 노드 방식: wds.split_by_node - # url level partioning - nodesplitter=wds.split_by_node, - # url level shuffle - shardshuffle=shardshuffle, - # deterministic shuffle (재현성) - detshuffle=False, - # infinite url - resampled=resampled, - handler=wds.ignore_and_continue, - ) - .shuffle( # sample level shuffle - size=(1 if shardshuffle is None else shardshuffle * 10), - initial=(0 if shardshuffle is None else 100), + if 'CONTAIN_TEXT' in self.config.keys(): + self.contain_txt = self.config['CONTAIN_TEXT'] + else: + self.contain_txt = True + + if not self.contain_txt: + self.dataset = ( + wds.WebDataset( + urls, + # 싱글 노드 방식: wds.single_node_only (default) + # 멀티 노드 방식: wds.split_by_node + # url level partioning + nodesplitter=wds.split_by_node, + # url level shuffle + shardshuffle=shardshuffle, + # deterministic shuffle (재현성) + detshuffle=False, + # infinite url + resampled=resampled, + handler=wds.ignore_and_continue, + ) + .shuffle( # sample level shuffle + size=(1 if shardshuffle is None else shardshuffle * 10), + initial=(0 if shardshuffle is None else 100), + ) + .decode("pil", 
handler=wds.ignore_and_continue) + .to_tuple("jpg", "json", handler=wds.ignore_and_continue) + .map_tuple( + transform, + self.identity, #self.tokenize, + handler=wds.ignore_and_continue, + ) + .with_length(int(int(self.total_num_samples) / world_size)) ) - .decode("pil", handler=wds.ignore_and_continue) - .to_tuple("jpg", "txt", "json", handler=wds.ignore_and_continue) - .map_tuple( - transform, - self.identity, #self.tokenize, - self.identity, - handler=wds.ignore_and_continue, + else: + self.dataset = ( + wds.WebDataset( + urls, + # 싱글 노드 방식: wds.single_node_only (default) + # 멀티 노드 방식: wds.split_by_node + # url level partioning + nodesplitter=wds.split_by_node, + # url level shuffle + shardshuffle=shardshuffle, + # deterministic shuffle (재현성) + detshuffle=False, + # infinite url + resampled=resampled, + handler=wds.ignore_and_continue, + ) + .shuffle( # sample level shuffle + size=(1 if shardshuffle is None else shardshuffle * 10), + initial=(0 if shardshuffle is None else 100), + ) + .decode("pil", handler=wds.ignore_and_continue) + .to_tuple("jpg", "txt", "json") + .map_tuple( + transform, + self.identity, + self.identity, #self.tokenize, + handler=wds.ignore_and_continue, + ) + .with_length(int(int(self.total_num_samples) / world_size)) ) - .with_length(int(int(self.total_num_samples) / world_size)) - ) + self.world_size = world_size self.rank = rank @@ -97,9 +134,14 @@ def process_text(self, text): return text_tokens def __iter__(self): - for i, (img, txt_tokens, meta) in enumerate(self.dataset): - yield img, txt_tokens - + if self.contain_txt: + for i, (img, txt_tokens, meta) in enumerate(self.dataset): + yield img, txt_tokens + else: + for i, (img, meta) in enumerate(self.dataset): + caption = meta['capsfusion'] + yield img, caption + def groups(self): return list(self.group_indices.values()) @@ -148,7 +190,7 @@ def __iter__(self): return ComibinedDatasetIterator(self.datasets, self.weights) def __len__(self): - return self.length + return int(self.length) def custom_collate_fn(batch): images = [_[0] for _ in batch] @@ -186,7 +228,7 @@ def pack(token_sequences, max_seq_len=128, batch_size=3): return packed_ds class SEEDDataModule(pl.LightningDataModule): - def __init__(self, cfg, transform=None): + def __init__(self, cfg, transform=None, use_coco_val=True): super().__init__() self.cfg = cfg self.dataset_configs = cfg.dataset.train_config.dataset_configs @@ -194,6 +236,7 @@ def __init__(self, cfg, transform=None): self.resampled = cfg.dataset.train_config.resampled self.world_size = cfg.dataset.train_config.world_size self.weights = cfg.dataset.train_config.weights + self.val_weights = cfg.dataset.val_config.get('weights', None) self.local_batch_size = cfg.experiment.local_batch_size self.val_batch_size = cfg.experiment.val_batch_size @@ -204,6 +247,8 @@ def __init__(self, cfg, transform=None): self.one_epoch_data_size = cfg.dataset.train_config.one_epoch_data_size self.total_training_steps = ((cfg.experiment.max_epochs * self.one_epoch_data_size) / self.n_gpus) / self.local_batch_size + self.use_coco_val = use_coco_val + def setup(self, stage=None): datasets = [] for config in self.dataset_configs: @@ -225,14 +270,38 @@ def setup(self, stage=None): length=self.one_epoch_data_size/self.n_gpus, weights=self.weights, ) - self.validation_dataset = CocoDataset( - root_dir=self.cfg.dataset.val_config.root_dir, - karpathy_file=self.cfg.dataset.val_config.karpathy_file_path, - tokenizer=None, - start_index=self.cfg.dataset.val_config.start_index if self.cfg.dataset.val_config.start_index 
!= "None" else None, - end_index=self.cfg.dataset.val_config.end_index if self.cfg.dataset.val_config.end_index != "None" else None, - ) - + + if self.use_coco_val: + end_index = self.cfg.dataset.val_config.end_index if self.cfg.dataset.val_config.end_index is not None else None + self.validation_dataset = CocoDataset( + root_dir=self.cfg.dataset.val_config.root_dir, + karpathy_file=self.cfg.dataset.val_config.karpathy_file_path, + tokenizer=None, + start_index=self.cfg.dataset.val_config.start_index, + end_index=end_index, + ) + else: + self.val_one_epoch_data_size = self.cfg.dataset.val_config.one_epoch_data_size + self.val_dataset_configs = self.cfg.dataset.val_config.dataset_configs + val_datasets = [] + for config in self.val_dataset_configs: + val_datasets.append( + FinetuneData( + config, + shardshuffle=self.shardshuffle, + resampled=self.resampled, + world_size=self.world_size, + transform=self.transform, + ) + ) + self.validation_dataset = CombinedDataset( + datasets_configs=self.val_dataset_configs, + datasets=val_datasets, + rank=0, + world_size=self.world_size, + length=self.val_one_epoch_data_size/self.n_gpus, + weights=self.val_weights, + ) def train_dataloader(self): return DataLoader( @@ -242,12 +311,19 @@ def train_dataloader(self): ) def val_dataloader(self): - return DataLoader( - self.validation_dataset, - batch_size=self.val_batch_size, - collate_fn=self.validation_dataset.collate_fn, - num_workers=self.num_workers, - ) + if self.use_coco_val: + return DataLoader( + self.validation_dataset, + batch_size=self.val_batch_size, + collate_fn=self.validation_dataset.collate_fn, + num_workers=self.num_workers, + ) + else: + return DataLoader( + self.validation_dataset, + batch_size=self.val_batch_size, + num_workers=self.num_workers, + ) def test_dataloader(self): raise(NotImplementedError) diff --git a/models/seed_llama_tokenizer.py b/models/seed_llama_tokenizer.py index 4076d9f..3a42f88 100755 --- a/models/seed_llama_tokenizer.py +++ b/models/seed_llama_tokenizer.py @@ -38,6 +38,7 @@ def __init__(self, if not from_pretrained: model = Blip2QformerQuantizer(vit_precision=vit_precision, is_train=True, **kwargs) else: + print(f"Loading model from {model_path}, SEED Weight") model = Blip2QformerQuantizer.from_pretrained(pretrained_model_path=model_path, #vit_precision='fp16' if fp16 else 'fp32', vit_precision=vit_precision, diff --git a/test.py b/test.py index 49d99c3..d82e09a 100755 --- a/test.py +++ b/test.py @@ -2,10 +2,11 @@ from omegaconf import OmegaConf import torch from PIL import Image -from our_tokenizer import SEEDTrainingWrapper +#from our_tokenizer import SEEDTrainingWrapper +from train_v7_unified_llm import SEEDTrainingWrapper from einops import rearrange -cfg_path = './configs/our_seed_tokenizer.yaml' +cfg_path = 'configs/seed_unified_test.yaml' cfg = OmegaConf.load(cfg_path) visual_tokenizer = SEEDTrainingWrapper.load_from_checkpoint('/home/zheedong/Projects/SEED/logs/seed_stage2_proj/lightning_logs/stage2_w_codebook_40epoch/checkpoints/epoch=39-step=7840.ckpt', cfg=cfg, strict=False, map_location="cpu") visual_tokenizer.eval() @@ -45,7 +46,7 @@ print(img_ids) - +import pdb; pdb.set_trace() save_path_new = "images/cat_new.jpg" with torch.no_grad(): diff --git a/test_shard.py b/test_shard.py new file mode 120000 index 0000000..cd57d31 --- /dev/null +++ b/test_shard.py @@ -0,0 +1 @@ +/ssd0/data/laion_capsfusion/test_shard.py \ No newline at end of file diff --git a/train_sds_wo_codebook.py b/train_sds_wo_codebook.py index f66b5d8..f2c8540 100755 --- 
a/train_sds_wo_codebook.py +++ b/train_sds_wo_codebook.py @@ -480,6 +480,7 @@ def training_step(self, batch, batch_idx: int): with torch.no_grad(): clip_cosine_similarity = F.cosine_similarity(image_embeds, gt_img_clip_embeddings).mean() + # loss_recon = F.mse_loss(image_embeds, gt_img_clip_embeddings) self.log( "train/loss_sds", diff --git a/train_v7_FID.py b/train_v7_FID.py new file mode 100755 index 0000000..12b96cf --- /dev/null +++ b/train_v7_FID.py @@ -0,0 +1,247 @@ +import os +from typing import Any, List +import torch +from torch.cuda.amp import autocast +import torch.nn as nn +from torch.nn import functional as F +import torch.distributed as dist + +import json + +import hydra +import torchvision.transforms as transforms +import pytorch_lightning as pl +from pytorch_lightning import LightningModule, seed_everything +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from torchvision.transforms.functional import to_pil_image +import pyrootutils + +from torch.distributed.algorithms.ddp_comm_hooks import default_hooks +import torch.nn.functional as F +from pytorch_lightning.strategies import DDPStrategy, DDPFullyShardedStrategy +from einops import rearrange +import transformers + +from pytorch_lightning import loggers as pl_loggers +from functools import partial +from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.utilities import grad_norm + +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import matplotlib.pyplot as plt + +from models.seed_qformer.vit import Block +from models.seed_llama_tokenizer import ImageTokenizer + +from coco_dataloader import CocoDataset + +from datamodules.seed_llama_datamodule import SEEDDataModule + +from calculate_clip_score import calculate_clip_s_for_folder +from utils.config import build_config + +from lavis.models import load_model +from lavis.common.dist_utils import is_dist_avail_and_initialized + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +BOI_TOKEN = "" +EOI_TOKEN = "" +IMG_TOKEN = "" + +IMG_FLAG = "" +NUM_IMG_TOKNES = 32 +NUM_IMG_CODES = 8192 +IMAGE_ID_SHIFT = 32000 + +class SEEDTrainingWrapper(LightningModule): + """Training wrapper for SEED + + Args: + LightningModule (cfg, model): model should be ImageTokenizer + """ + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + # ImageTokenizer model + # Target model to train + self.image_tokenizer = ImageTokenizer( + model_path=cfg.checkpoint_path.model_path, + diffusion_model_path=cfg.checkpoint_path.diffusion_model_path, + load_diffusion=cfg.stage2.load_diffusion, + from_pretrained=True if cfg.stage1.init == "SEED" else False, + vit_precision=cfg.optimizer.vit_precision, + diffusion_precision=cfg.optimizer.diffusion_precision, + ) + + self.B = None + + self.transform_224 = transforms.Resize((224, 224), antialias=True) + + # For diffusion DDP + if self.image_tokenizer.diffusion_model is not None: + self.feature_extractor = self.image_tokenizer.diffusion_model.feature_extractor + self.image_encoder = self.image_tokenizer.diffusion_model.image_encoder + self.image_normalizer = self.image_tokenizer.diffusion_model.image_normalizer + self.image_noising_scheduler = self.image_tokenizer.diffusion_model.image_noising_scheduler + self.tokenizer = self.image_tokenizer.diffusion_model.tokenizer + self.text_encoder = self.image_tokenizer.diffusion_model.text_encoder + self.unet = self.image_tokenizer.diffusion_model.unet + self.scheduler = 
self.image_tokenizer.diffusion_model.scheduler + self.vae = self.image_tokenizer.diffusion_model.vae + + # For logging + self.pil_to_tensor = transforms.ToTensor() + self.sample_image_ind = 0 + self.logged_original_image = set() + + self.stage = cfg.experiment.stage + self.temp = nn.Parameter(0.07 * torch.ones([])) + + def get_causal_embeddings(self, image): + return self.image_tokenizer.model.get_causal_embeddings(image) + + def forward_stage_2(self, batch, batch_idx: int, bypass_codebook=False): + """_summary_ + Original forward function for stage 2 + Just to see how the forward function works + + Args: + batch (_type_): _description_ + batch_idx (int): _description_ + + Returns: + _type_: _description_ + """ + + # Causal embedding is trained in stage 1. + # [b, 32, 768] + if len(batch) == 3: + image, text, image_id = batch + elif len(batch) == 2: + image, text = batch + + with torch.no_grad(): + causal_embeddings = self.get_causal_embeddings(image) + + # [b, 32, 768] = > [b, 32, 32] + query_output_down = self.image_tokenizer.model.encode_task_layer(causal_embeddings) + + # quant [b, 32, 32], loss_embed [b, 32, 768], embed_ind [b, 32] + if bypass_codebook: + # Bypass codebook + print("Bypass codebook") + quant = query_output_down + loss_embed = None + embed_ind = None + else: + # Quantize + print("Quantize") + quant, loss_embed, embed_ind = self.image_tokenizer.model.quantize(query_output_down) + embed_ind = embed_ind.reshape(quant.shape[0], -1) + + # [b, 32, 32] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.decode_task_layer(quant) + + # [b, 32, 768] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) + + # [b, 32, 768] => [b, 32, 32] => [b, 1024] + reverse_output_proj = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) + + return reverse_output_proj + + @torch.no_grad() + def test_step(self, batch, batch_idx: int): + if self.logger is not None and isinstance(self.logger, pl.loggers.TensorBoardLogger): + tb_log_dir = self.logger.log_dir + else: + tb_log_dir = self.cfg.result_file_path # Fallback directory if logger is not set + + _, _, image_name = batch + bypass_codebook = self.cfg.stage2.bypass_codebook + + with torch.no_grad(): + image_embeds = self.forward_stage_2(batch, batch_idx, bypass_codebook) + reconstructed_images = self.image_tokenizer.diffusion_model( + image_embeds=image_embeds, + negative_image_embeds=None, + guidance_scale=10, + noise_level=0, + latents=self.image_tokenizer.latents, + ).images + + save_path = f"{tb_log_dir}/images/version_{self.logger.version}/COCOVal30000" + os.makedirs(save_path, exist_ok=True) + + for img, cur_name in zip(reconstructed_images, image_name): + # save PIL image to save_path + img.save(f"{save_path}/{cur_name}") + + return + + +if __name__ == "__main__": + cfg, cfg_yaml = build_config() + device = "cuda" if torch.cuda.is_available() else "cpu" + + seed_everything(cfg.experiment.seed, workers=True) + + transform_cfg = OmegaConf.load(cfg.transform_cfg_path) + transform = hydra.utils.instantiate(transform_cfg) + + os.makedirs(cfg.result_file_path, exist_ok=True) + + from datamodules.datasets.coco_val import COCOValDataSet + from torch.utils.data import DataLoader + + mode = 5000 # 30000 or 5000 + + print(f"Load {mode} images from COCO Validation.") + if mode == 30000: + test_dataset = COCOValDataSet(transform=transform) + else: + test_dataset = CocoDataset( + root_dir=cfg.dataset.val_config.root_dir, + 
karpathy_file=cfg.dataset.val_config.karpathy_file_path, + tokenizer=None, + start_index=0, + end_index=5000, + ) + + print(f"Test dataset length: {len(test_dataset)}") + + test_dataloader = DataLoader( + test_dataset, + batch_size=cfg.experiment.val_batch_size, + shuffle=True, + num_workers=cfg.dataset.num_workers, + pin_memory=True, + drop_last=True, + ) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.result_file_path) + + trainer = pl.Trainer( + accelerator=device, + num_nodes=cfg.dist.n_nodes, + devices=cfg.dist.n_gpus, + strategy="ddp", + max_epochs=1, + deterministic=True, + precision=str(cfg.optimizer.precision), + callbacks=[ModelSummary(max_depth=3)], + logger=tb_logger, + ) + + wrapper = SEEDTrainingWrapper(cfg).to(device) + # wrapper = SEEDTrainingWrapper.load_from_checkpoint(cfg.weight_path, cfg=cfg, strict=False).to(device) + + trainer.test(wrapper, dataloaders=test_dataloader) + + + diff --git a/train_v7_unified.py b/train_v7_unified.py index 62ea99d..3644ca3 100755 --- a/train_v7_unified.py +++ b/train_v7_unified.py @@ -421,7 +421,7 @@ def get_stage_1_loss_use_last_token(self, batch, batch_idx: int): text, padding="max_length", truncation=True, - max_length=128, + max_length=self.cfg.dataset.text_max_length, return_tensors="pt", ) @@ -532,7 +532,7 @@ def check_image_text_similarity(self, batch, batch_idx: int, save_dir="image_tex text, padding="max_length", truncation=True, - max_length=128, + max_length=self.cfg.dataset.text_max_length, return_tensors="pt", ) @@ -963,6 +963,8 @@ def validation_step(self, batch, batch_idx: int, save_path=None): tb_log_dir = self.cfg.result_file_path # Fallback directory if logger is not set if self.stage == 1: + if batch_idx > 4: + return save_path = f"{tb_log_dir}/histogram" os.makedirs(save_path, exist_ok=True) @@ -1088,7 +1090,7 @@ def configure_optimizers(self): os.makedirs(cfg.result_file_path, exist_ok=True) - datamodule = SEEDDataModule(cfg, transform=transform) + datamodule = SEEDDataModule(cfg, transform=transform, use_coco_val=cfg.dataset.val_config.use_coco_val) datamodule.setup() train_dataloader = datamodule.train_dataloader() val_dataloader = datamodule.val_dataloader() @@ -1097,10 +1099,15 @@ def configure_optimizers(self): tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.result_file_path) lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="step") + # checkpoint_callback = pl.callbacks.ModelCheckpoint( + # save_top_k=3, + # monitor="clip_score_coco_karpathy" if cfg.experiment.stage == 2 else "val/loss_itc_mean", + # mode="max" if cfg.experiment.stage == 2 else "min", + # save_last=True, + # ) checkpoint_callback = pl.callbacks.ModelCheckpoint( - save_top_k=3, - monitor="clip_score_coco_karpathy" if cfg.experiment.stage == 2 else "val/loss", - mode="max", + save_last=True, + every_n_train_steps=300, ) trainer = pl.Trainer( @@ -1112,14 +1119,15 @@ def configure_optimizers(self): deterministic=cfg.experiment.deterministic, logger=tb_logger, log_every_n_steps=cfg.experiment.log_every_n_steps, - # val_check_interval=cfg.experiment.val_check_interval, - check_val_every_n_epoch=cfg.experiment.check_val_every_n_epoch, + val_check_interval=cfg.experiment.val_check_interval, + # check_val_every_n_epoch=cfg.experiment.check_val_every_n_epoch, enable_checkpointing=cfg.experiment.enable_checkpointing, num_sanity_val_steps=cfg.experiment.num_sanity_val_steps, precision=str(cfg.optimizer.precision), callbacks=[ModelSummary(max_depth=3), lr_logger] + [checkpoint_callback] if cfg.experiment.enable_checkpointing else 
[], + # callbacks=[ModelSummary(max_depth=3), lr_logger], accumulate_grad_batches=cfg.experiment.grad_accumulation, - gradient_clip_val=cfg.optimizer.grad_clip_val, + gradient_clip_val=cfg.experiment.grad_clip_val, ) if cfg.load_weight and cfg.resume: @@ -1128,11 +1136,11 @@ def configure_optimizers(self): wrapper = SEEDTrainingWrapper(cfg).to(device) if cfg.load_weight: - wrapper.load_from_checkpoint(cfg.weight_path, cfg=cfg) + wrapper = wrapper.load_from_checkpoint(cfg.weight_path, cfg=cfg) print("Loaded model from checkpoint") else: if cfg.experiment.stage == 1: - print("Stage 1 init from BLIP-2") + print(f"Stage 1 init from {cfg.stage1.init}") elif cfg.experiment.stage == 2: print("Stage 2 init from Scratch") diff --git a/train_v7_unified_c2f.py b/train_v7_unified_c2f.py new file mode 100644 index 0000000..2d68420 --- /dev/null +++ b/train_v7_unified_c2f.py @@ -0,0 +1,1386 @@ +import os +import torch +from torch.cuda.amp import autocast +import torch.nn as nn +from torch.nn import functional as F +import torch.distributed as dist + +import json + +import hydra +import torchvision.transforms as transforms +import pytorch_lightning as pl +from pytorch_lightning import LightningModule, seed_everything +import numpy as np +from omegaconf import OmegaConf +import pyrootutils + +import torch.nn.functional as F +from einops import rearrange +import transformers + +from pytorch_lightning import loggers as pl_loggers +from functools import partial +from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.utilities import grad_norm + +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import matplotlib.pyplot as plt + +from models.seed_llama_tokenizer import ImageTokenizer + +from datamodules.c2f_datamodule import SEEDDataModule + +from calculate_clip_score import calculate_clip_s_for_folder +from utils.config import build_config + +from lavis.models import load_model +from lavis.common.dist_utils import is_dist_avail_and_initialized + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +BOI_TOKEN = "" +EOI_TOKEN = "" +IMG_TOKEN = "" + +IMG_FLAG = "" +NUM_IMG_TOKNES = 32 +NUM_IMG_CODES = 8192 +IMAGE_ID_SHIFT = 32000 + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. 
+ """ + + @staticmethod + def forward(ctx, x): + output = [ + torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + +class SEEDTrainingWrapper(LightningModule): + """Training wrapper for SEED + + Args: + LightningModule (cfg, model): model should be ImageTokenizer + """ + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + # ImageTokenizer model + # Target model to train + self.image_tokenizer = ImageTokenizer( + model_path=cfg.checkpoint_path.model_path, + diffusion_model_path=cfg.checkpoint_path.diffusion_model_path, + load_diffusion=cfg.stage2.load_diffusion, + from_pretrained=True if cfg.stage1.init == "SEED" else False, + vit_precision=cfg.optimizer.vit_precision, + diffusion_precision=cfg.optimizer.diffusion_precision, + ) + + self.B = None + + self.transform_224 = transforms.Resize((224, 224), antialias=True) + + # For diffusion DDP + if self.image_tokenizer.diffusion_model is not None: + self.feature_extractor = self.image_tokenizer.diffusion_model.feature_extractor + self.image_encoder = self.image_tokenizer.diffusion_model.image_encoder + self.image_normalizer = self.image_tokenizer.diffusion_model.image_normalizer + self.image_noising_scheduler = self.image_tokenizer.diffusion_model.image_noising_scheduler + self.tokenizer = self.image_tokenizer.diffusion_model.tokenizer + self.text_encoder = self.image_tokenizer.diffusion_model.text_encoder + self.unet = self.image_tokenizer.diffusion_model.unet + self.scheduler = self.image_tokenizer.diffusion_model.scheduler + self.vae = self.image_tokenizer.diffusion_model.vae + + # For SDS + t_range = [0.2, 0.6] + # t_range = [0.02, 0.98] + self.num_train_timesteps = 1000 + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.image_noising_scheduler.alphas_cumprod # for convenience + + # For logging + self.pil_to_tensor = transforms.ToTensor() + self.sample_image_ind = 0 + self.logged_original_image = set() + + self.stage = cfg.experiment.stage + self.temp = nn.Parameter(0.07 * torch.ones([])) + + # For C2F + self.setup_C2F() + + + def setup_C2F(self): + c2f_schedule = torch.linspace(1, self.cfg.experiment.min_pos_weight, 32) + bs = self.cfg.experiment.local_batch_size + vbs = self.cfg.experiment.val_batch_size + n_pos = self.cfg.experiment.num_positive_samples + + self.c2f_schedule = c2f_schedule.unsqueeze(0) + self.pos_or_neg = torch.eye(bs).unsqueeze(1).repeat(1, n_pos, 1) + self.pos_or_neg_valid = torch.eye(vbs).unsqueeze(1).repeat(1, n_pos, 1) + + + def setup(self, stage): + # Setup training parameter + self.image_tokenizer.model.train() + for param in self.image_tokenizer.model.parameters(): + param.requires_grad = True + + # Freeze ViT Encoder + for param in self.image_tokenizer.model.visual_encoder.parameters(): + param.requires_grad = False + + # Diffusion frozen + if self.image_tokenizer.diffusion_model is not None: + for param in self.image_tokenizer.diffusion_model.image_encoder.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.diffusion_model.image_normalizer.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.diffusion_model.text_encoder.parameters(): + param.requires_grad = False + # In this case, unet 
is frozen + for param in self.image_tokenizer.diffusion_model.unet.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.diffusion_model.vae.parameters(): + param.requires_grad = False + + if self.stage == 1: + if self.cfg.stage1.init == "BLIP-2": + print("Load init weights from BLIP-2") + blip_model = load_model("blip2", "pretrain") + # Update the model with the weights + filtered_state_dict = {k: v for k, v in blip_model.state_dict().items() if k in self.image_tokenizer.model.state_dict()} + self.image_tokenizer.model.load_state_dict(filtered_state_dict, strict=False) + elif self.cfg.stage1.init == "SEED": + print("Load init weights from SEED") + + print("Set stage 2 model not trainable") + for param in self.image_tokenizer.model.quantize.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.encode_task_layer.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.decode_task_layer.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.blocks.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.blocks_image.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.image_down.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.distill_image_proj.parameters(): + param.requires_grad = False + + print("Move stage 2 model to cpu") + self.image_tokenizer.model.quantize = self.image_tokenizer.model.quantize.to("cpu") + self.image_tokenizer.model.encode_task_layer = self.image_tokenizer.model.encode_task_layer.to("cpu") + self.image_tokenizer.model.decode_task_layer = self.image_tokenizer.model.decode_task_layer.to("cpu") + self.image_tokenizer.model.blocks = self.image_tokenizer.model.blocks.to("cpu") + self.image_tokenizer.model.blocks_image = self.image_tokenizer.model.blocks_image.to("cpu") + self.image_tokenizer.model.image_down = self.image_tokenizer.model.image_down.to("cpu") + self.image_tokenizer.model.distill_image_proj = self.image_tokenizer.model.distill_image_proj.to("cpu") + elif self.stage == 2: + self.random_initialize_stage2_model_weights() + if self.cfg.stage2.train_unet: + self.make_unet_trainable_for_img_embeds() + + ## make dump folder + os.makedirs(self.cfg.result_file_path, exist_ok=True) + + def make_unet_trainable_for_img_embeds(self): + for p in self.image_tokenizer.diffusion_model.unet.parameters(): + p.requires_grad = False + + for p in self.image_tokenizer.diffusion_model.unet.class_embedding.parameters(): + p.requires_grad = True + + for block in self.image_tokenizer.diffusion_model.unet.down_blocks: + try: + for resnet in block.resnets: + for p in resnet.time_emb_proj.parameters(): + p.requires_grad = True + except Exception as e: + print(e) + continue + + for block in self.image_tokenizer.diffusion_model.unet.up_blocks: + try: + for resnet in block.resnets: + for p in resnet.time_emb_proj.parameters(): + p.requires_grad = True + except Exception as e: + print(e) + continue + + for resnet in self.image_tokenizer.diffusion_model.unet.mid_block.resnets: + for p in resnet.time_emb_proj.parameters(): + p.requires_grad = True + + def random_initialize_stage2_model_weights(self): + """Random initialize stage 2 model weights + """ + # Random initialize stage 2 model weights + for param in self.image_tokenizer.model.parameters(): + param.requires_grad = False + + # unFreeze stage 2 model and initialize with random weights + for param in 
self.image_tokenizer.model.encode_task_layer.parameters(): + #nn.init.xavier_uniform_(param) + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + for param in self.image_tokenizer.model.quantize.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + for param in self.image_tokenizer.model.decode_task_layer.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + + for param in self.image_tokenizer.model.blocks_image.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + + for param in self.image_tokenizer.model.image_down.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + for param in self.image_tokenizer.model.distill_image_proj.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + + def save_config(self): + config_save_path = os.path.join(self.logger.log_dir, "config.yaml") + with open(config_save_path, "w") as f: + json.dump(self.cfg, f, indent=4) + + def get_clip_text_embedding(self, batch_text): + """CLIP text embedding + + Args: + batch_text (List): List contains text. [b, 32] + + Returns: + float: clip text embedding [b, 1024] + """ + gt_text_clip_embeddings = [] + with torch.no_grad(): + for idx in range(self.B): + gt_text_clip_embeddings.append( + self.tokenizer(batch_text[idx]).squeeze().to(self.device) + ) + gt_text_clip_embeddings = torch.stack(gt_text_clip_embeddings, dim=0) + + # gt_img_clip_embeddings = self.model_clip.encode_image(batch.img.to(self.device)) + gt_text_clip_embeddings = self.image_encoder.encode_text( + gt_text_clip_embeddings.to(self.device) + ) + return gt_text_clip_embeddings + + def get_clip_img_embedding(self, batch_img): + """CLIP image embedding + + Args: + batch_img (torch.Tensor): Image tensor [b, 3, 224, 224] + + Returns: + float: clip image embedding [b, 1024] + """ + return self.image_encoder(batch_img).image_embeds.to(self.device) + + def get_causal_embeddings(self, image): + return self.image_tokenizer.model.get_causal_embeddings(image) + + def forward_stage_2(self, batch, batch_idx: int, bypass_codebook=False): + """_summary_ + Original forward function for stage 2 + Just to see how the forward function works + + Args: + batch (_type_): _description_ + batch_idx (int): _description_ + + Returns: + _type_: _description_ + """ + + # Causal embedding is trained in stage 1. 
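+        # Overall stage-2 flow (shapes as noted in the comments below):
+        #   causal embeddings [b, 32, 768] -> encode_task_layer -> [b, 32, 32]
+        #   -> quantize (or bypass when bypass_codebook=True) -> decode_task_layer -> [b, 32, 768]
+        #   -> transformer decoder -> MLP projection -> generation embedding [b, 1024]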
+ # [b, 32, 768] + if len(batch) == 3: + image, text, image_id = batch + elif len(batch) == 2: + image, text = batch + + with torch.no_grad(): + causal_embeddings = self.get_causal_embeddings(image) + + # [b, 32, 768] = > [b, 32, 32] + query_output_down = self.image_tokenizer.model.encode_task_layer(causal_embeddings) + + # quant [b, 32, 32], loss_embed [b, 32, 768], embed_ind [b, 32] + if bypass_codebook: + # Bypass codebook + print("Bypass codebook") + quant = query_output_down + loss_embed = None + embed_ind = None + else: + # Quantize + print("Quantize") + quant, loss_embed, embed_ind = self.image_tokenizer.model.quantize(query_output_down) + embed_ind = embed_ind.reshape(quant.shape[0], -1) + + if bypass_codebook: + # # [b, 32, 32] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.decode_task_layer(quant) + else: + quant_embedding = self.image_tokenizer.model.quantize.get_codebook_entry(embed_ind) + # # [b, 32, 32] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.decode_task_layer(quant_embedding) + + # [b, 32, 768] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) + + # [b, 32, 768] => [b, 32, 32] => [b, 1024] + reverse_output_proj = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) + + return reverse_output_proj + + def logging_train_stage2(self, clip_cosine_similarity, loss_dict): + self.log( + "train/clip_cosine_similarity", + clip_cosine_similarity, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + for loss_name, loss_value in loss_dict.items(): + self.log( + f'train/{loss_name}', + loss_value, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + def all_gather_with_grad(self, tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + # tensor_all = GatherLayer.apply(tensors) + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) + + @torch.no_grad() + def concat_all_gather(self, tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
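+        Use this only where gradients are not needed (e.g. for logging or statistics);
+        use all_gather_with_grad above when the gathered tensors must stay differentiable.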
+ """ + # if use distributed training + if not is_dist_avail_and_initialized(): + return tensor + + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + def get_qformer_tokens(self, image): + with torch.no_grad(): + image_embeds = self.image_tokenizer.model.ln_vision( + self.image_tokenizer.model.visual_encoder(image) + ) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + query_tokens = self.image_tokenizer.model.query_tokens.expand(image_embeds.shape[0], -1, -1) + + # Assume image_embeds.shape[0] is the batch size (b) and you have 32 tokens (n) + b, n, _ = query_tokens.shape + + query_output = self.image_tokenizer.model.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return query_output + + + def get_stage1_loss_c2f(self, batch, batch_idx: int): + image, captions, token_ids, abs_diff = batch + query_output = self.get_qformer_tokens(image) + + text_tokens = self.image_tokenizer.model.tokenizer( + captions, + padding="max_length", + truncation=True, + max_length=300, + return_tensors="pt", + ) + + return None + + + def get_stage_1_loss_use_last_token(self, batch, batch_idx: int): + """ + Contrastive loss using last token of the query_output + Args: + batch (_type_): _description_ + batch_idx (int): _description_ + is_validation (bool, optional): _description_. Defaults to False. + """ + if len(batch) == 3: + image, text, image_id = batch + elif len(batch) == 2: + image, text = batch + + query_output = self.get_qformer_tokens(image) + + # Use last hidden state + # We have 32 tokens, and use last token as image embedding + # [b, 32, 768] + # TODO: Use 'final' causal embedding? Does it mean to use last token embedding? 
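+        # What follows: the last of the 32 query tokens is used as the image feature and
+        # the text CLS token as the text feature; both are L2-normalized and compared with
+        # a temperature-scaled image-text contrastive (ITC) loss,
+        #   loss_itc = ( CE(sim_i2t, targets) + CE(sim_t2i, targets) ) / 2
+        # where targets = rank * bs + arange(bs) marks the matching pair in the gathered batch.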
+ # Debug + image_feats = rearrange(query_output.last_hidden_state[:, -1, :], "b d -> b 1 d").contiguous() + image_feats = F.normalize(image_feats, dim=-1) + + text_tokens = self.image_tokenizer.model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=128, + return_tensors="pt", + ) + + text_output = self.image_tokenizer.model.Qformer.bert( + text_tokens.input_ids.to(self.device), + attention_mask=text_tokens.attention_mask.to(self.device), + return_dict=True, + ) + + # CLS token + # [b, 768] + text_feat = F.normalize(text_output.last_hidden_state[:, 0, :], dim=-1) + + ###============== Image-text Contrastive ===================### + # Compute for each query token + # image_feats_all = self.concat_all_gather( + image_feats_all = self.all_gather_with_grad( + image_feats + ) # [batch_size*num_gpu, num_query_tokens, embed_dim] + # text_feat_all = self.concat_all_gather(text_feat) # [batch_size*num_gpu, embed_dim] + text_feat_all = self.all_gather_with_grad(text_feat) # [batch_size*num_gpu, embed_dim] + + # image_feats.unsqueeze(1) : [batch_size, 1, num_query_tokens, embed_dim] + # text_feat_all.unsqueeze(-1) : [batch_size*num_gpu, embed_dim, 1] => broadcast to [batch_size, batch_size*num_gpu, embed_dim, 1] + # Last two dimensions are broadcasted to all other dimensions + # [j, 1, n, m] x [k, m, p] => [j, k, n, p] + # https://pytorch.org/docs/stable/generated/torch.matmul.html + # sim_q2t : [batch_size, batch_size*num_gpu, num_query_tokens] + sim_q2t = torch.matmul( + rearrange(image_feats, "bs n d -> bs 1 n d"), rearrange(text_feat_all, "(bs ngpus) d -> (bs ngpus) d 1", bs=b) + ).squeeze() + # [batch_size, batch_size*num_gpu, num_query_tokens] + + # Always use last token + # sim_i2t = sim_q2t[:, :, -1] + sim_i2t = sim_q2t + # Debug : Test Original BLIP-2 loss + # sim_i2t, _ = sim_q2t.max(-1) + sim_i2t = sim_i2t / self.temp + + # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens] + sim_t2q = torch.matmul( + rearrange(text_feat, "bs d -> bs 1 1 d"), rearrange(image_feats_all, "(bs ngpus) n d -> (bs ngpus) d n", bs=b) + ).squeeze() + # [batch_size, batch_size*num_gpu, num_query_tokens] + + # Always use last token + # sim_t2i = sim_t2q[:, :, -1] + sim_t2i = sim_t2q + # Debug : Test Original BLIP-2 loss + # sim_t2i, _ = sim_t2q.max(-1) + sim_t2i = sim_t2i / self.temp # [batch_size, batch_size*num_gpu] + + rank = dist.get_rank() + bs = image.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + image.device + ) + + loss_itc = ( + F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1) + ) / 2 + + self.log( + "train/loss_itc", + loss_itc, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return loss_itc + + @torch.no_grad() + def check_image_text_similarity(self, batch, batch_idx: int, save_dir="image_text_similarity"): + if len(batch) == 3: + image, text, image_id = batch + elif len(batch) == 2: + image, text = batch + rank = dist.get_rank() + + with torch.no_grad(): + image_embeds = self.image_tokenizer.model.ln_vision( + self.image_tokenizer.model.visual_encoder(image) + ) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + query_tokens = self.image_tokenizer.model.query_tokens.expand(image_embeds.shape[0], -1, -1) + + # Assume image_embeds.shape[0] is the batch size (b) and you have 32 tokens (n) + b, n, _ = query_tokens.shape + + query_output = self.image_tokenizer.model.Qformer.bert( + 
query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # Use last hidden state + # We have 32 tokens, and use last token as image embedding + # [b, 32, 768] + image_feats = F.normalize(query_output.last_hidden_state, dim=-1) + + text_tokens = self.image_tokenizer.model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=128, + return_tensors="pt", + ) + + text_output = self.image_tokenizer.model.Qformer.bert( + text_tokens.input_ids.to(self.device), + attention_mask=text_tokens.attention_mask.to(self.device), + return_dict=True, + ) + + # CLS token + # [b, 768] + text_feat = F.normalize(text_output.last_hidden_state[:, 0, :], dim=-1) + + ###============== Image-text Contrastive ===================### + + # Original BLIP-2 loss + # Compute for each query token + image_feats_all = image_feats # [batch_size, num_query_tokens, embed_dim] + text_feat_all = text_feat # [batch_size, embed_dim] + + # image_feats.unsqueeze(1) : [batch_size, 1, num_query_tokens, embed_dim] + # text_feat_all.unsqueeze(-1) : [batch_size*num_gpu, embed_dim, 1] => broadcast to [batch_size, batch_size*num_gpu, embed_dim, 1] + # Last two dimensions are broadcasted to all other dimensions + # [j, 1, n, m] x [k, m, p] => [j, k, n, p] + # https://pytorch.org/docs/stable/generated/torch.matmul.html + # sim_q2t : [batch_size, batch_size*num_gpu, num_query_tokens] + sim_q2t = torch.matmul( + rearrange(image_feats, "bs n d -> bs 1 n d"), rearrange(text_feat_all, "bs_X_ngpus d -> bs_X_ngpus d 1") + # image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1) + ).squeeze() + # [batch_size, batch_size*num_gpu, num_query_tokens] + + ########### 1. Debug: for check the similarity ############ + # Softmax for each row + dump = [] + for token_num in range(32): + dump.append(F.softmax(sim_q2t[:, :, token_num], dim=1)) + dump = torch.stack(dump, dim=2) + positive_token_similarity = torch.diagonal(dump, dim1=0, dim2=1).mean(dim=1) + # Save positive_token_similarity as bar graph + plt.figure(figsize=(18, 6)) + bars = plt.bar(list(range(32)), positive_token_similarity.cpu().numpy(), color='blue') + plt.xlabel('Token Number') + plt.ylabel('Value') + plt.title('Positive Token Similarity') + plt.xticks(list(range(32))) # Ensure all keys are shown in the x-axis + # Add a table of values next to the bars + cell_text = [[f"{val:.4f}"] for val in positive_token_similarity.cpu().numpy()] + plt.table(cellText=cell_text, colLabels=["Value"], loc='right', cellLoc='center') + + # Adjust layout to make room for the table: + plt.subplots_adjust(right=0.5) + plt.savefig(f"{save_dir}/positive_token_similarity_i2t_batch{batch_idx}_rank{rank}.png") + + ############################################################ + # Debug: for check the similarity + count_dict = {} + for token_num in range(32): + count_dict[token_num] = 0 + for row in range(b): + _, ind = sim_q2t[:, :, token_num][row].max(-1) + if row == ind: + print(f"In token {token_num}, in row {row}, max index is {ind}") + count_dict[token_num] += 1 + print(count_dict) + + # Extracting keys and values + keys = list(count_dict.keys()) + values = list(count_dict.values()) + + # Plotting the histogram + plt.figure(figsize=(10, 6)) + bars = plt.bar(keys, values, color='blue') + plt.xlabel('Token Number') + plt.ylabel('Value') + plt.title('Histogram of Token Values') + plt.xticks(keys) # Ensure all keys are shown in the x-axis + + # Adding the text on top of each bar + for bar in bars: + yval = bar.get_height() + 
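+            # bar height = number of batch items whose argmax similarity over the batch
+            # is the ground-truth pair for that query token (per-token retrieval count)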
plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom') + + os.makedirs(f"{save_dir}", exist_ok=True) + plt.savefig(f"{save_dir}/token_histogram_image_text_batch{batch_idx}_rank{rank}.png") + # plt.show() + ############################################################ + + # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens] + sim_t2q = torch.matmul( + rearrange(text_feat, "bs d -> bs 1 1 d"), rearrange(image_feats_all, "bs_X_ngpus n d -> bs_X_ngpus d n") + # text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1) + ).squeeze() + + # Debug: for check the similarity + count_dict = {} + for token_num in range(32): + count_dict[token_num] = 0 + for row in range(b): + _, ind = sim_t2q[:, :, token_num][row].max(-1) + if row == ind: + print(f"In token {token_num}, in row {row}, max index is {ind}") + count_dict[token_num] += 1 + print(count_dict) + + # Softmax for each row + dump = [] + for token_num in range(32): + dump.append(F.softmax(sim_t2q[:, :, token_num], dim=1)) + dump = torch.stack(dump, dim=2) + positive_token_similarity = torch.diagonal(dump, dim1=0, dim2=1).mean(dim=1) + # Save positive_token_similarity as bar graph + plt.figure(figsize=(18, 6)) + bars = plt.bar(list(range(32)), positive_token_similarity.cpu().numpy(), color='blue') + plt.xlabel('Token Number') + plt.ylabel('Value') + plt.title('Positive Token Similarity') + plt.xticks(list(range(32))) # Ensure all keys are shown in the x-axis + # Add a table of values next to the bars + cell_text = [[f"{val:.4f}"] for val in positive_token_similarity.cpu().numpy()] + plt.table(cellText=cell_text, colLabels=["Value"], loc='right', cellLoc='center') + + # Adjust layout to make room for the table: + plt.subplots_adjust(right=0.5) + plt.savefig(f"{save_dir}/positive_token_similarity_t2i_batch{batch_idx}_rank{rank}.png") + + # Extracting keys and values + keys = list(count_dict.keys()) + values = list(count_dict.values()) + + # Plotting the histogram + plt.figure(figsize=(10, 6)) + bars = plt.bar(keys, values, color='blue') + plt.xlabel('Token Number') + plt.ylabel('Value') + plt.title('Histogram of Token Values') + plt.xticks(keys) # Ensure all keys are shown in the x-axis + + # Adding the text on top of each bar + for bar in bars: + yval = bar.get_height() + plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom') + + plt.savefig(f"{save_dir}/token_histogram_text_image_batch{batch_idx}_rank{rank}.png") + + loss_mean = 0 + rank = dist.get_rank() + if rank == 0: + for token in range(32): + bs = image.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + image.device + ) + + sim_i2t = sim_q2t[:, :, token] / self.temp + sim_t2i = sim_t2q[:, :, token] / self.temp + + loss_itc = ( + F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1) + ) / 2 + print(f"Loss I2T in Token {token}: {loss_itc}") + loss_mean += loss_itc + + self.log( + f"val/loss_itc_{token}", + loss_itc, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + loss_mean /= 32 + self.log( + "val/loss_itc_mean", + loss_mean, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return + + def sds_loss( + self, + image_embeds, + clean_image, + guidance_scale=100, + grad_scale=1, + prompt=None, + prompt_embeds=None, + ): + """Score distillation sampling""" + if prompt is None and prompt_embeds is None: + # prompt = len(image) * [""] if isinstance(image, list) else "" + # Changed 
because we get image_embeds as input + prompt = image_embeds.shape[0] * [""] if isinstance(image_embeds, torch.Tensor) else "" + + # 2. Define call parameters + batch_size = image_embeds.shape[0] + + device = image_embeds.device + + # Convert images to latent space + # latents = self.vae.encode(clean_image).latent_dist.sample() + # latents = latents * self.vae.config.scaling_factor + + # 3. Encode input prompt + + # [b, 77, 1024] + # Now img2img, prompt_embeds is None + prompt_embeds = self.image_tokenizer.diffusion_model._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=None, + lora_scale=None, + ) + + image_embeds = self.image_tokenizer.diffusion_model._encode_image( + image=None, + device=device, + batch_size=batch_size, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + noise_level=0, + generator=None, + image_embeds=image_embeds, + negative_image_embeds=None, + ) + do_classifier_free_guidance = True + + latents = self.vae.encode(clean_image).latent_dist.sample() + latents = latents * self.vae.config.scaling_factor + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint( + self.min_step, self.max_step + 1, [1], dtype=torch.long, device=device + ) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + ).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # w(t), sigma_t^2 + self.alphas = self.alphas.to(device) + w = 1 - self.alphas[t] + grad = grad_scale * w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + targets = (latents - grad) + # loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + # Why not mean? + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='mean') + + return loss + + def clip_loss(self, image_embeds, gt_img_clip_embeddings): + similarity_target = torch.ones(image_embeds.shape[0], device=image_embeds.device) + loss_clip = torch.nn.functional.cosine_embedding_loss(image_embeds, gt_img_clip_embeddings, similarity_target) + return loss_clip + + def get_stage_2_loss_bypass_codebook(self, batch, batch_idx: int): + """_summary_ + + Args: + batch (_type_): _description_ + batch_idx (int): _description_ + """ + #------------------------ + # Stage 2 Training + #------------------------ + img, text = batch + + #------------------------ + # Stage 2 - 1 : Codebook Training + #------------------------ + with torch.no_grad(): + causal_embeddings = self.get_causal_embeddings(img) + gt_img_clip_embeddings = self.get_clip_img_embedding(img) + + # TODO: query_output should be trained to be similar with text embedding + # Image embedding is cross attentioned. + # Notice: query_output_down is match to clip embedding? 
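+        # Loss composition in this branch (codebook bypassed), as implemented below:
+        #   recon_loss_weight * MSE(decoded causal embeddings, causal embeddings)
+        #   + clip_loss_weight * MSE(projected embedding, CLIP image embedding), optionally
+        #     scaled by (1 - sds weight) when cross_annealing is enabled
+        #   + sds_loss_weight(step) * SDS loss from the frozen diffusion model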
+ # [b, 32, 32] + query_output_down = self.image_tokenizer.model.encode_task_layer(causal_embeddings) + + # quant [b, 32, 32], loss_embed [b, 32, 768], embed_ind [b, 32] + # bypass code book + quant = query_output_down + + #------------------------ + # Stage 2 - 2 : Reconstruction Caual Embedding + #------------------------ + + # quant embedding dimension is [b, 32, 32] + # decoder_task_layer upscale it to [b, 32, 768] + # [b, 32, 32] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.decode_task_layer(quant) + + # Transformer decoder + query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) + + # query_output_up_pos_image should be similar to original causal_embeddings + # Maximize cosine similarity between query_output_up_pos_image and causal_embeddings + + #loss_recon = F.cosine_similarity(query_output_up, causal_embeddings).mean() + loss_recon = F.mse_loss(query_output_up, causal_embeddings) + loss_dict = { + "loss_recon": loss_recon, + } + loss_total = self.cfg.experiment.recon_loss_weight * loss_recon + + #------------------------ + # Stage 2 - 3 : Reconstruction Generation Embedding + #------------------------ + + # MLP + # query_output_up = causal_embeddings + image_embeds = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) + gt_img_clip_embeddings.requires_grad = False + + sds_loss_weight = self.cfg.experiment.sds_loss_weight * self.sds_loss_weights[self.global_step] + if self.cfg.experiment.clip_loss_weight > 0: + loss_clip = F.mse_loss(image_embeds, gt_img_clip_embeddings) + # loss_clip = self.clip_loss(image_embeds, gt_img_clip_embeddings) + loss_dict['clip_loss'] = loss_clip + _loss_clip = self.cfg.experiment.clip_loss_weight * loss_clip + if self.cfg.experiment.cross_annealing: + _loss_clip *= (1 - sds_loss_weight) + loss_total += _loss_clip + + if sds_loss_weight > 0: + loss_sds = self.sds_loss( + image_embeds=image_embeds, + clean_image=img, + guidance_scale=10, + grad_scale=1, + ) + + loss_dict['loss_sds'] = loss_sds + loss_dict['sds_weight'] = sds_loss_weight + loss_total += sds_loss_weight * loss_sds + + loss_dict['loss'] = loss_total + #------------------------ + # Logging + #------------------------ + with torch.no_grad(): + clip_cosine_similarity = F.cosine_similarity(image_embeds, gt_img_clip_embeddings).mean() + + self.logging_train_stage2(clip_cosine_similarity, loss_dict) + + return loss_total + + def get_stage_2_loss(self, batch, batch_idx: int): + """_summary_ + + Args: + batch (_type_): _description_ + batch_idx (int): _description_ + """ + #------------------------ + # Stage 2 Training + #------------------------ + device = self.device + if len(batch) == 3: + img, text, image_id = batch + elif len(batch) == 2: + img, text = batch + + #------------------------ + # Stage 2 - 1 : Codebook Training + #------------------------ + with torch.no_grad(): + causal_embeddings = self.get_causal_embeddings(img) + + # TODO: query_output should be trained to be similar with text embedding + # Image embedding is cross attentioned. + # Notice: query_output_down is match to clip embedding? 
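+        # Loss composition in this branch (with vector quantization), as implemented below:
+        #   loss_embed (codebook loss from quantize)
+        #   + MSE(decoded causal embeddings, causal embeddings)   -> loss_recon
+        #   + MSE(projected embedding, CLIP image embedding)      -> loss_generation_embed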
+ # [b, 32, 32] + query_output_down = self.image_tokenizer.model.encode_task_layer(causal_embeddings) + + # quant [b, 32, 32], loss_embed [b, 32, 768], embed_ind [b, 32] + quant, loss_embed, embed_ind = self.image_tokenizer.model.quantize(query_output_down) + + #------------------------ + # Stage 2 - 2 : Reconstruction Caual Embedding + #------------------------ + + # quant embedding dimension is [b, 32, 32] + # decoder_task_layer upscale it to [b, 32, 768] + # [b, 32, 32] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.decode_task_layer(quant) + + # Transformer decoder + query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) + + # query_output_up_pos_image should be similar to original causal_embeddings + # Maximize cosine similarity between query_output_up_pos_image and causal_embeddings + + #loss_recon = F.cosine_similarity(query_output_up, causal_embeddings).mean() + loss_recon = F.mse_loss(query_output_up, causal_embeddings) + + #------------------------ + # Stage 2 - 3 : Reconstruction Generation Embedding + #------------------------ + + # MLP + # query_output_up = causal_embeddings + reverse_output_proj = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) + + gt_img_clip_embeddings = self.get_clip_img_embedding(img) + + loss_generation_embed = F.mse_loss(reverse_output_proj, gt_img_clip_embeddings) + + loss_total = loss_embed + loss_recon + loss_generation_embed + loss_total = loss_total.mean() + + # loss_dict = {"loss_embed": loss_embed, "loss_recon": loss_recon, + # "loss_generation_embed": loss_generation_embed, + # "loss": loss_total} + + loss_dict = {"loss_generation_embed": loss_generation_embed, + "loss_embed": loss_embed, + "loss_recon": loss_recon, + "loss": loss_total} + + #------------------------ + # Logging + #------------------------ + generation_embedding_cosine_similarity = F.cosine_similarity(reverse_output_proj, gt_img_clip_embeddings).mean() + + self.logging_train_stage2(generation_embedding_cosine_similarity, loss_dict) + + return loss_total + + def on_train_start(self): + print(f"\n====Traing Stage {self.stage}====") + if self.stage == 2 and self.cfg.stage2.bypass_codebook: + print("\n====Bypass codebook====") + + print("Save config") + self.save_config() + + def training_step(self, batch, batch_idx: int): + self.B = batch[0].shape[0] + + if self.stage == 1: + # loss = self.get_stage_1_loss_use_last_token(batch, batch_idx) + # loss = self.get_stage1_loss_c2f(batch, batch_idx) + image, pos_ids, text_tokens, text_attention_masks = batch + query_output = self.get_qformer_tokens(image) + image_feats = torch.stack([query_output.last_hidden_state[i].index_select(0, pos_ids[i]) for i in range(self.B)]) + image_feats = F.normalize(image_feats, dim=-1) + + batch_size, num_positive_samples, token_lens = text_tokens.shape + + text_output = self.image_tokenizer.model.Qformer.bert( + text_tokens.reshape(batch_size * num_positive_samples, -1), + attention_mask=text_attention_masks.reshape(batch_size * num_positive_samples, -1), + return_dict=True, + ) + + # CLS token + # [batch_size * num_positive_samples, embed_dim] + text_feats = F.normalize(text_output.last_hidden_state[:, 0, :], dim=-1) + text_feats = rearrange(text_feats, "bs*n d -> bs n d", bs=self.B) + text_feats = text_feats.reshape(batch_size, num_positive_samples, -1).contiguous() + + ###============== Image-text Contrastive ===================### + # [batch_size*num_gpu, num_positive_samples, embed_dim] + image_feats_all = 
self.all_gather_with_grad(image_feats) + + # [batch_size*num_gpu, num_positive_samples, embed_dim] + text_feats_all = self.all_gather_with_grad(text_feats) + + mat_image = rearrange(image_feats, "bs n d -> bs 1 1 n d") + mat_text_all = rearrange(text_feats_all, "(bs ngpus) n d -> (bs ngpus) n d 1", bs=self.B) + sim_i2t = torch.matmul(mat_image, mat_text_all).squeeze(-1) + + mat_text = rearrange(text_feats, "bs n d -> bs 1 1 n d") + mat_image_all = rearrange(image_feats_all, "(bs ngpus) n d -> (bs ngpus) n d 1", bs=self.B) + sim_t2i = torch.matmul(mat_text, mat_image_all).squeeze(-1) + + target_pos_idx = torch.randperm(num_positive_samples)[:self.B].unsqueeze(-1) + target_pos_ids = pos_ids.gather(1, target_pos_idx) + abs_diff = torch.abs(pos_ids - target_pos_ids) + + schedule = self.c2f_schedule.repeat(self.B, 1).gather(1, abs_diff) + targets = self.pos_or_neg * schedule.unsqueeze(-1) + + selected_sim_i2t = torch.stack([sim_i2t[i].index_select(1, target_pos_idx[i]) for i in range(self.B)]) + selected_sim_i2t = selected_sim_i2t.squeeze(-1) + selected_sim_t2i = torch.stack([sim_t2i[i].index_select(1, target_pos_idx[i]) for i in range(self.B)]) + selected_sim_t2i = selected_sim_t2i.squeeze(-1) + + loss_c2f_itc = ( + F.cross_entropy(selected_sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(selected_sim_i2t, targets, label_smoothing=0.1) + ) + + + + elif self.stage == 2: + if self.cfg.stage2.bypass_codebook: + loss = self.get_stage_2_loss_bypass_codebook(batch, batch_idx) + else: + loss = self.get_stage_2_loss(batch, batch_idx) + + return loss + + def on_before_optimizer_step(self, optimizer, optimizer_idx): + # Compute the 2-norm for each layer + # If using mixed precision, the gradients are already unscaled here + # {'grad_2.0_norm/weight': 0.0003, 'grad_2.0_norm/bias': 0.0, 'grad_2.0_norm_total': 0.0003} + if self.cfg.experiment.stage == 1: + norms_0 = grad_norm(self.image_tokenizer.model.Qformer.bert.encoder.layer[0].attention.self.value, norm_type=2) + for norm in norms_0.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/Qformer/bert/encoder/layer/0/attention/self/value/{norm}", + norms_0[norm], + global_step=self.global_step, + ) + norms_1 = grad_norm(self.image_tokenizer.model.Qformer.bert.encoder.layer[1].attention.self.value, norm_type=2) + for norm in norms_1.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/Qformer/bert/encoder/layer/1/attention/self/value/{norm}", + norms_1[norm], + global_step=self.global_step, + ) + norms_7 = grad_norm(self.image_tokenizer.model.Qformer.bert.encoder.layer[7].attention.self.value, norm_type=2) + for norm in norms_7.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/Qformer/bert/encoder/layer/7/attention/self/value/{norm}", + norms_7[norm], + global_step=self.global_step, + ) + elif self.cfg.experiment.stage == 2: + codebook_norm = grad_norm(self.image_tokenizer.model.quantize.embedding, norm_type=2) + for norm in codebook_norm.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/quantize/{norm}", + codebook_norm[norm], + global_step=self.global_step, + ) + + transformer_decoder_norm = grad_norm(self.image_tokenizer.model.blocks_image, norm_type=2) + for norm in transformer_decoder_norm.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/blocks_image/{norm}", + transformer_decoder_norm[norm], + global_step=self.global_step, + ) + + generation_mlp_norm = 
grad_norm(self.image_tokenizer.model.distill_image_proj, norm_type=2) + for norm in generation_mlp_norm.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/distill_image_proj/{norm}", + generation_mlp_norm[norm], + global_step=self.global_step, + ) + + def on_validation_epoch_start(self): + os.makedirs(f"{self.cfg.result_file_path}/{self.current_epoch}", exist_ok=True) + + @torch.no_grad() + def validation_step(self, batch, batch_idx: int, save_path=None): + if self.logger is not None and isinstance(self.logger, pl.loggers.TensorBoardLogger): + tb_log_dir = self.logger.log_dir + else: + tb_log_dir = self.cfg.result_file_path # Fallback directory if logger is not set + + if self.stage == 1: + save_path = f"{tb_log_dir}/histogram" + os.makedirs(save_path, exist_ok=True) + + save_path = f"{tb_log_dir}/histogram/epoch_{self.current_epoch}" + os.makedirs(save_path, exist_ok=True) + + self.check_image_text_similarity(batch, batch_idx, save_dir=save_path) + elif self.stage == 2: + image, captions, image_id = batch + bypass_codebook = self.cfg.stage2.bypass_codebook + + with torch.no_grad(): + image_embeds = self.forward_stage_2(batch, batch_idx, bypass_codebook) + reconstructed_images = self.image_tokenizer.diffusion_model( + image_embeds=image_embeds, + negative_image_embeds=None, + guidance_scale=10, + noise_level=0, + latents=self.image_tokenizer.latents, + ).images + + save_path = f"{tb_log_dir}/images/version_{self.logger.version}/epoch_{self.current_epoch}/images" + os.makedirs(save_path, exist_ok=True) + + tensor_images = [] + + for img, cur_id in zip(reconstructed_images, image_id): + # save PIL image to save_path + img.save(f"{save_path}/{cur_id}") + + # For tensorboard logging + tensor_images.append(self.pil_to_tensor(img).unsqueeze(0)) + + tensor_images = torch.cat(tensor_images, dim=0) + + # Check if image is already logged + if batch_idx not in self.logged_original_image: + self.logger.experiment.add_images( + f"original/image_batch_{batch_idx}", + image, + ) + + self.logger.experiment.add_images( + f"original/image_batch_{batch_idx}_seed_reconstructed", + tensor_images, + ) + + # logging original caption + for caption in captions: + self.logger.experiment.add_text( + f"original/gt_text_image_batch_{batch_idx}", + caption, + ) + + self.logged_original_image.add(batch_idx) + else: + self.logger.experiment.add_images( + f"images/image_batch_{batch_idx}", + tensor_images, + global_step=self.sample_image_ind, + ) + self.sample_image_ind += 1 + + def on_validation_epoch_end(self): + if self.logger is not None and isinstance(self.logger, pl_loggers.TensorBoardLogger): + tb_log_dir = self.logger.log_dir + else: + tb_log_dir = self.cfg.result_file_path + + original_image_dir = self.cfg.dataset.val_config.root_dir + generated_image_dir = f"{tb_log_dir}/images/version_{self.logger.version}/epoch_{self.current_epoch}/images" + clip_score = calculate_clip_s_for_folder(original_image_dir, generated_image_dir) + + print(f"clip score: {clip_score}") + self.log_dict({ + 'clip_score': clip_score, + },on_step=False, on_epoch=True, prog_bar=True, logger=True) + + self.log( + "clip_score_coco_karpathy", + clip_score, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + def configure_optimizers(self): + # optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-8) + lr = self.cfg.optimizer.max_lr + betas = (self.cfg.hyperparameters.beta_1, self.cfg.hyperparameters.beta_2) + weight_decay = 
self.cfg.hyperparameters.weight_decay + optimizer = torch.optim.AdamW(self.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) + + #scheduler = transformers.get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=100, num_training_steps=5000) + num_training_steps = self.cfg.experiment.total_training_steps + num_warmup_steps = self.cfg.experiment.num_warmup_steps + scheduler = transformers.get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps + ) + + if self.cfg.experiment.sds_loss_weight > 0 and self.cfg.experiment.use_sds_loss_schedule: + _num_training_steps = num_training_steps // 8 + def f(current_step: int): + return 1 - max(0.0, float(_num_training_steps - current_step) / float(_num_training_steps)) + else: + def f(current_step: int): + return 1 + + x = np.arange(0, num_training_steps) + self.sds_loss_weights = np.array(list(map(f, x))) + + lr_scheduler_config = { + "scheduler": scheduler, + "name": "learning_rate", + "interval": "step", + "frequency": 1, + } + + return {"optimizer": optimizer, + "lr_scheduler": lr_scheduler_config,} + +if __name__ == "__main__": + cfg, cfg_yaml = build_config() + device = "cuda" if torch.cuda.is_available() else "cpu" + + seed_everything(cfg.experiment.seed, workers=True) + + transform_cfg = OmegaConf.load(cfg.transform_cfg_path) + transform = hydra.utils.instantiate(transform_cfg) + + os.makedirs(cfg.result_file_path, exist_ok=True) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.result_file_path) + lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="step") + + trainer = pl.Trainer( + accelerator=device, + num_nodes=cfg.dist.n_nodes, + devices=cfg.dist.n_gpus, + strategy="ddp", + max_epochs=cfg.experiment.max_epochs, + deterministic=cfg.experiment.deterministic, + logger=tb_logger, + log_every_n_steps=cfg.experiment.log_every_n_steps, + val_check_interval=cfg.experiment.val_check_interval, + # check_val_every_n_epoch=cfg.experiment.check_val_every_n_epoch, + enable_checkpointing=cfg.experiment.enable_checkpointing, + num_sanity_val_steps=cfg.experiment.num_sanity_val_steps, + precision=str(cfg.optimizer.precision), + callbacks=[ModelSummary(max_depth=3), lr_logger], + accumulate_grad_batches=cfg.experiment.grad_accumulation, + gradient_clip_val=cfg.optimizer.grad_clip_val, + ) + + if cfg.load_weight and cfg.resume: + raise ValueError("Only checkpoint or finetune") + + wrapper = SEEDTrainingWrapper(cfg).to(device) + + datamodule = SEEDDataModule(cfg, tokenizer=wrapper.text_tokenizer, transform=transform) + datamodule.setup() + train_dataloader = datamodule.train_dataloader() + val_dataloader = datamodule.val_dataloader() + + cfg.experiment.total_training_steps = datamodule.total_training_steps + + if cfg.load_weight: + wrapper.load_from_checkpoint(cfg.weight_path, cfg=cfg) + print("Loaded model from checkpoint") + else: + if cfg.experiment.stage == 1: + print("Stage 1 init from BLIP-2") + elif cfg.experiment.stage == 2: + print("Stage 2 init from Scratch") + + wrapper.setup("fit") + + if cfg.resume: + # Resume training + if cfg.weight_path is None: + raise ValueError("checkpoint_path is None") + else: + print(f"Resume training from {cfg.weight_path}") + trainer.fit( + wrapper, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, + ckpt_path=cfg.weight_path + ) + else: + print("Start training") + trainer.fit( + wrapper, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader + ) + + 
trainer.strategy.barrier() diff --git a/train_v7_unified_llm.py b/train_v7_unified_llm.py new file mode 100755 index 0000000..5a6eee4 --- /dev/null +++ b/train_v7_unified_llm.py @@ -0,0 +1,1198 @@ +import os +from typing import Any, List +import torch +from torch.cuda.amp import autocast +import torch.nn as nn +from torch.nn import functional as F +import torch.distributed as dist + +import json + +import hydra +import torchvision.transforms as transforms +import pytorch_lightning as pl +from pytorch_lightning import LightningModule, seed_everything +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from torchvision.transforms.functional import to_pil_image +import pyrootutils + +from torch.distributed.algorithms.ddp_comm_hooks import default_hooks +import torch.nn.functional as F +from pytorch_lightning.strategies import DDPStrategy, DDPFullyShardedStrategy +from einops import rearrange +import transformers + +from pytorch_lightning import loggers as pl_loggers +from functools import partial +from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.utilities import grad_norm + +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import matplotlib.pyplot as plt + +from models.seed_qformer.vit import Block +from models.seed_llama_tokenizer import ImageTokenizer + +from datamodules.seed_llama_datamodule import SEEDDataModule + +from calculate_clip_score import calculate_clip_s_for_folder +from utils.config import build_config + +from lavis.models import load_model +from lavis.common.dist_utils import is_dist_avail_and_initialized + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +BOI_TOKEN = "" +EOI_TOKEN = "" +IMG_TOKEN = "" + +IMG_FLAG = "" +NUM_IMG_TOKNES = 32 +NUM_IMG_CODES = 8192 +IMAGE_ID_SHIFT = 32000 + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. 
+ """ + + @staticmethod + def forward(ctx, x): + output = [ + torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + +class SEEDTrainingWrapper(LightningModule): + """Training wrapper for SEED + + Args: + LightningModule (cfg, model): model should be ImageTokenizer + """ + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + # ImageTokenizer model + # Target model to train + self.image_tokenizer = ImageTokenizer( + model_path=cfg.checkpoint_path.model_path, + diffusion_model_path=cfg.checkpoint_path.diffusion_model_path, + load_diffusion=cfg.stage2.load_diffusion, + from_pretrained=True if cfg.stage1.init == "SEED" else False, + vit_precision=cfg.optimizer.vit_precision, + diffusion_precision=cfg.optimizer.diffusion_precision, + ) + + self.B = None + + self.transform_224 = transforms.Resize((224, 224), antialias=True) + + # For diffusion DDP + if self.image_tokenizer.diffusion_model is not None: + self.feature_extractor = self.image_tokenizer.diffusion_model.feature_extractor + self.image_encoder = self.image_tokenizer.diffusion_model.image_encoder + self.image_normalizer = self.image_tokenizer.diffusion_model.image_normalizer + self.image_noising_scheduler = self.image_tokenizer.diffusion_model.image_noising_scheduler + self.tokenizer = self.image_tokenizer.diffusion_model.tokenizer + self.text_encoder = self.image_tokenizer.diffusion_model.text_encoder + self.unet = self.image_tokenizer.diffusion_model.unet + self.scheduler = self.image_tokenizer.diffusion_model.scheduler + self.vae = self.image_tokenizer.diffusion_model.vae + + # For logging + self.pil_to_tensor = transforms.ToTensor() + self.sample_image_ind = 0 + self.logged_original_image = set() + + self.stage = cfg.experiment.stage + self.temp = nn.Parameter(0.07 * torch.ones([])) + + def setup(self, stage): + # Setup training parameter + self.image_tokenizer.model.train() + for param in self.image_tokenizer.model.parameters(): + param.requires_grad = True + + # Freeze ViT Encoder + for param in self.image_tokenizer.model.visual_encoder.parameters(): + param.requires_grad = False + + # Diffusion frozen + if self.image_tokenizer.diffusion_model is not None: + for param in self.image_tokenizer.diffusion_model.image_encoder.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.diffusion_model.image_normalizer.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.diffusion_model.text_encoder.parameters(): + param.requires_grad = False + # In this case, unet is frozen + for param in self.image_tokenizer.diffusion_model.unet.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.diffusion_model.vae.parameters(): + param.requires_grad = False + + if self.stage == 1: + if self.cfg.stage1.init == "BLIP-2": + print("Load init weights from BLIP-2") + blip_model = load_model("blip2", "pretrain") + # Update the model with the weights + filtered_state_dict = {k: v for k, v in blip_model.state_dict().items() if k in self.image_tokenizer.model.state_dict()} + self.image_tokenizer.model.load_state_dict(filtered_state_dict, strict=False) + elif self.cfg.stage1.init == "SEED": + print("Load init weights from SEED") + + print("Set stage 2 model not trainable") + for param in 
self.image_tokenizer.model.quantize.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.encode_task_layer.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.decode_task_layer.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.blocks.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.blocks_image.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.image_down.parameters(): + param.requires_grad = False + for param in self.image_tokenizer.model.distill_image_proj.parameters(): + param.requires_grad = False + + print("Move stage 2 model to cpu") + self.image_tokenizer.model.quantize = self.image_tokenizer.model.quantize.to("cpu") + self.image_tokenizer.model.encode_task_layer = self.image_tokenizer.model.encode_task_layer.to("cpu") + self.image_tokenizer.model.decode_task_layer = self.image_tokenizer.model.decode_task_layer.to("cpu") + self.image_tokenizer.model.blocks = self.image_tokenizer.model.blocks.to("cpu") + self.image_tokenizer.model.blocks_image = self.image_tokenizer.model.blocks_image.to("cpu") + self.image_tokenizer.model.image_down = self.image_tokenizer.model.image_down.to("cpu") + self.image_tokenizer.model.distill_image_proj = self.image_tokenizer.model.distill_image_proj.to("cpu") + elif self.stage == 2: + self.random_initialize_stage2_model_weights() + + ## make dump folder + os.makedirs(self.cfg.result_file_path, exist_ok=True) + + def random_initialize_stage2_model_weights(self): + """Random initialize stage 2 model weights + """ + # Random initialize stage 2 model weights + for param in self.image_tokenizer.model.parameters(): + param.requires_grad = False + + # unFreeze stage 2 model and initialize with random weights + for param in self.image_tokenizer.model.encode_task_layer.parameters(): + #nn.init.xavier_uniform_(param) + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + for param in self.image_tokenizer.model.quantize.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + for param in self.image_tokenizer.model.decode_task_layer.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + + for param in self.image_tokenizer.model.blocks_image.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + + for param in self.image_tokenizer.model.image_down.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + for param in self.image_tokenizer.model.distill_image_proj.parameters(): + nn.init.normal_(param, mean=0.0, std=0.02) + param.requires_grad = True + + def save_config(self): + config_save_path = os.path.join(self.logger.log_dir, "config.yaml") + with open(config_save_path, "w") as f: + json.dump(self.cfg, f, indent=4) + + def get_clip_text_embedding(self, batch_text): + """CLIP text embedding + + Args: + batch_text (List): List contains text. 
[b, 32] + + Returns: + float: clip text embedding [b, 1024] + """ + gt_text_clip_embeddings = [] + with torch.no_grad(): + for idx in range(self.B): + gt_text_clip_embeddings.append( + self.tokenizer(batch_text[idx]).squeeze().to(self.device) + ) + gt_text_clip_embeddings = torch.stack(gt_text_clip_embeddings, dim=0) + + # gt_img_clip_embeddings = self.model_clip.encode_image(batch.img.to(self.device)) + gt_text_clip_embeddings = self.image_encoder.encode_text( + gt_text_clip_embeddings.to(self.device) + ) + return gt_text_clip_embeddings + + def get_clip_img_embedding(self, batch_img): + """CLIP image embedding + + Args: + batch_img (torch.Tensor): Image tensor [b, 3, 224, 224] + + Returns: + float: clip image embedding [b, 1024] + """ + return self.image_encoder(batch_img).image_embeds.to(self.device) + + def get_causal_embeddings(self, image): + return self.image_tokenizer.model.get_causal_embeddings(image) + + def forward_stage_2(self, batch, batch_idx: int, bypass_codebook=False): + """_summary_ + Original forward function for stage 2 + Just to see how the forward function works + + Args: + batch (_type_): _description_ + batch_idx (int): _description_ + + Returns: + _type_: _description_ + """ + + # Causal embedding is trained in stage 1. + # [b, 32, 768] + if len(batch) == 3: + image, text, image_id = batch + elif len(batch) == 2: + image, text = batch + + with torch.no_grad(): + causal_embeddings = self.get_causal_embeddings(image) + + # [b, 32, 768] = > [b, 32, 32] + query_output_down = self.image_tokenizer.model.encode_task_layer(causal_embeddings) + + # quant [b, 32, 32], loss_embed [b, 32, 768], embed_ind [b, 32] + if bypass_codebook: + # Bypass codebook + print("Bypass codebook") + quant = query_output_down + loss_embed = None + embed_ind = None + else: + # Quantize + print("Quantize") + quant, loss_embed, embed_ind = self.image_tokenizer.model.quantize(query_output_down) + embed_ind = embed_ind.reshape(quant.shape[0], -1) + + # [b, 32, 32] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.decode_task_layer(quant) + + # [b, 32, 768] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) + + # [b, 32, 768] => [b, 32, 32] => [b, 1024] + reverse_output_proj = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) + + return reverse_output_proj + + def logging_train_stage2(self, generation_embedding_cosine_similarity, loss_dict): + self.log( + "train/generation_embedding_mse_loss", + loss_dict["loss_generation_embed"].mean(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + if "loss_embed" in loss_dict.keys(): + self.log( + "train/codebook_loss_embed", + loss_dict["loss_embed"].mean(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + self.log( + "train/reconstruction_loss", + loss_dict["loss_recon"].mean(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + self.log( + "train/total_loss", + loss_dict["loss"].mean(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + def all_gather_with_grad(self, tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. 
+ """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + # tensor_all = GatherLayer.apply(tensors) + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) + + @torch.no_grad() + def concat_all_gather(self, tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + # if use distributed training + if not is_dist_avail_and_initialized(): + return tensor + + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + def get_stage_1_loss_use_last_token(self, batch, batch_idx: int): + """ + Contrastive loss using last token of the query_output + Args: + batch (_type_): _description_ + batch_idx (int): _description_ + is_validation (bool, optional): _description_. Defaults to False. + """ + if len(batch) == 3: + image, text, image_id = batch + elif len(batch) == 2: + image, text = batch + + with torch.no_grad(): + image_embeds = self.image_tokenizer.model.ln_vision( + self.image_tokenizer.model.visual_encoder(image) + ) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + query_tokens = self.image_tokenizer.model.query_tokens.expand(image_embeds.shape[0], -1, -1) + + # Assume image_embeds.shape[0] is the batch size (b) and you have 32 tokens (n) + b, n, _ = query_tokens.shape + + query_output = self.image_tokenizer.model.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # Use last hidden state + # We have 32 tokens, and use last token as image embedding + # [b, 32, 768] + # TODO: Use 'final' causal embedding? Does it mean to use last token embedding? 
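+        # Shape note: image_feats below is [b, 1, 768] (last query token only) and the
+        # gathered text features are [b*ngpus, 768]; the matmuls broadcast these into a
+        # [b, b*ngpus] similarity matrix with one score per image-text pair.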
+ # Debug + image_feats = rearrange(query_output.last_hidden_state[:, -1, :], "b d -> b 1 d").contiguous() + image_feats = F.normalize(image_feats, dim=-1) + + text_tokens = self.image_tokenizer.model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=128, + return_tensors="pt", + ) + + text_output = self.image_tokenizer.model.Qformer.bert( + text_tokens.input_ids.to(self.device), + attention_mask=text_tokens.attention_mask.to(self.device), + return_dict=True, + ) + + # CLS token + # [b, 768] + text_feat = F.normalize(text_output.last_hidden_state[:, 0, :], dim=-1) + + ###============== Image-text Contrastive ===================### + # Compute for each query token + # image_feats_all = self.concat_all_gather( + image_feats_all = self.all_gather_with_grad( + image_feats + ) # [batch_size*num_gpu, num_query_tokens, embed_dim] + # text_feat_all = self.concat_all_gather(text_feat) # [batch_size*num_gpu, embed_dim] + text_feat_all = self.all_gather_with_grad(text_feat) # [batch_size*num_gpu, embed_dim] + + # image_feats.unsqueeze(1) : [batch_size, 1, num_query_tokens, embed_dim] + # text_feat_all.unsqueeze(-1) : [batch_size*num_gpu, embed_dim, 1] => broadcast to [batch_size, batch_size*num_gpu, embed_dim, 1] + # Last two dimensions are broadcasted to all other dimensions + # [j, 1, n, m] x [k, m, p] => [j, k, n, p] + # https://pytorch.org/docs/stable/generated/torch.matmul.html + # sim_q2t : [batch_size, batch_size*num_gpu, num_query_tokens] + sim_q2t = torch.matmul( + rearrange(image_feats, "bs n d -> bs 1 n d"), rearrange(text_feat_all, "(bs ngpus) d -> (bs ngpus) d 1", bs=b) + ).squeeze() + # [batch_size, batch_size*num_gpu, num_query_tokens] + + # Always use last token + # sim_i2t = sim_q2t[:, :, -1] + sim_i2t = sim_q2t + # Debug : Test Original BLIP-2 loss + # sim_i2t, _ = sim_q2t.max(-1) + sim_i2t = sim_i2t / self.temp + + # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens] + sim_t2q = torch.matmul( + rearrange(text_feat, "bs d -> bs 1 1 d"), rearrange(image_feats_all, "(bs ngpus) n d -> (bs ngpus) d n", bs=b) + ).squeeze() + # [batch_size, batch_size*num_gpu, num_query_tokens] + + # Always use last token + # sim_t2i = sim_t2q[:, :, -1] + sim_t2i = sim_t2q + # Debug : Test Original BLIP-2 loss + # sim_t2i, _ = sim_t2q.max(-1) + sim_t2i = sim_t2i / self.temp # [batch_size, batch_size*num_gpu] + + rank = dist.get_rank() + bs = image.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + image.device + ) + + loss_itc = ( + F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1) + ) / 2 + + self.log( + "train/loss_itc", + loss_itc, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return loss_itc + + @torch.no_grad() + def check_image_text_similarity(self, batch, batch_idx: int, save_dir="image_text_similarity"): + if len(batch) == 3: + image, text, image_id = batch + elif len(batch) == 2: + image, text = batch + rank = dist.get_rank() + + with torch.no_grad(): + image_embeds = self.image_tokenizer.model.ln_vision( + self.image_tokenizer.model.visual_encoder(image) + ) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + query_tokens = self.image_tokenizer.model.query_tokens.expand(image_embeds.shape[0], -1, -1) + + # Assume image_embeds.shape[0] is the batch size (b) and you have 32 tokens (n) + b, n, _ = query_tokens.shape + + query_output = self.image_tokenizer.model.Qformer.bert( + 
query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # Use last hidden state + # We have 32 tokens, and use last token as image embedding + # [b, 32, 768] + image_feats = F.normalize(query_output.last_hidden_state, dim=-1) + + text_tokens = self.image_tokenizer.model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=128, + return_tensors="pt", + ) + + text_output = self.image_tokenizer.model.Qformer.bert( + text_tokens.input_ids.to(self.device), + attention_mask=text_tokens.attention_mask.to(self.device), + return_dict=True, + ) + + # CLS token + # [b, 768] + text_feat = F.normalize(text_output.last_hidden_state[:, 0, :], dim=-1) + + ###============== Image-text Contrastive ===================### + + # Original BLIP-2 loss + # Compute for each query token + image_feats_all = image_feats # [batch_size, num_query_tokens, embed_dim] + text_feat_all = text_feat # [batch_size, embed_dim] + + # image_feats.unsqueeze(1) : [batch_size, 1, num_query_tokens, embed_dim] + # text_feat_all.unsqueeze(-1) : [batch_size*num_gpu, embed_dim, 1] => broadcast to [batch_size, batch_size*num_gpu, embed_dim, 1] + # Last two dimensions are broadcasted to all other dimensions + # [j, 1, n, m] x [k, m, p] => [j, k, n, p] + # https://pytorch.org/docs/stable/generated/torch.matmul.html + # sim_q2t : [batch_size, batch_size*num_gpu, num_query_tokens] + sim_q2t = torch.matmul( + rearrange(image_feats, "bs n d -> bs 1 n d"), rearrange(text_feat_all, "bs_X_ngpus d -> bs_X_ngpus d 1") + # image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1) + ).squeeze() + # [batch_size, batch_size*num_gpu, num_query_tokens] + + ########### 1. Debug: for check the similarity ############ + # Softmax for each row + dump = [] + for token_num in range(32): + dump.append(F.softmax(sim_q2t[:, :, token_num], dim=1)) + dump = torch.stack(dump, dim=2) + positive_token_similarity = torch.diagonal(dump, dim1=0, dim2=1).mean(dim=1) + # Save positive_token_similarity as bar graph + plt.figure(figsize=(18, 6)) + bars = plt.bar(list(range(32)), positive_token_similarity.cpu().numpy(), color='blue') + plt.xlabel('Token Number') + plt.ylabel('Value') + plt.title('Positive Token Similarity') + plt.xticks(list(range(32))) # Ensure all keys are shown in the x-axis + # Add a table of values next to the bars + cell_text = [[f"{val:.4f}"] for val in positive_token_similarity.cpu().numpy()] + plt.table(cellText=cell_text, colLabels=["Value"], loc='right', cellLoc='center') + + # Adjust layout to make room for the table: + plt.subplots_adjust(right=0.5) + plt.savefig(f"{save_dir}/positive_token_similarity_i2t_batch{batch_idx}_rank{rank}.png") + + ############################################################ + # Debug: for check the similarity + count_dict = {} + for token_num in range(32): + count_dict[token_num] = 0 + for row in range(b): + _, ind = sim_q2t[:, :, token_num][row].max(-1) + if row == ind: + print(f"In token {token_num}, in row {row}, max index is {ind}") + count_dict[token_num] += 1 + print(count_dict) + + # Extracting keys and values + keys = list(count_dict.keys()) + values = list(count_dict.values()) + + # Plotting the histogram + plt.figure(figsize=(10, 6)) + bars = plt.bar(keys, values, color='blue') + plt.xlabel('Token Number') + plt.ylabel('Value') + plt.title('Histogram of Token Values') + plt.xticks(keys) # Ensure all keys are shown in the x-axis + + # Adding the text on top of each bar + for bar in bars: + yval = bar.get_height() + 
plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom') + + os.makedirs(f"{save_dir}", exist_ok=True) + plt.savefig(f"{save_dir}/token_histogram_image_text_batch{batch_idx}_rank{rank}.png") + # plt.show() + ############################################################ + + # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens] + sim_t2q = torch.matmul( + rearrange(text_feat, "bs d -> bs 1 1 d"), rearrange(image_feats_all, "bs_X_ngpus n d -> bs_X_ngpus d n") + # text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1) + ).squeeze() + + # Debug: for check the similarity + count_dict = {} + for token_num in range(32): + count_dict[token_num] = 0 + for row in range(b): + _, ind = sim_t2q[:, :, token_num][row].max(-1) + if row == ind: + print(f"In token {token_num}, in row {row}, max index is {ind}") + count_dict[token_num] += 1 + print(count_dict) + + # Softmax for each row + dump = [] + for token_num in range(32): + dump.append(F.softmax(sim_t2q[:, :, token_num], dim=1)) + dump = torch.stack(dump, dim=2) + positive_token_similarity = torch.diagonal(dump, dim1=0, dim2=1).mean(dim=1) + # Save positive_token_similarity as bar graph + plt.figure(figsize=(18, 6)) + bars = plt.bar(list(range(32)), positive_token_similarity.cpu().numpy(), color='blue') + plt.xlabel('Token Number') + plt.ylabel('Value') + plt.title('Positive Token Similarity') + plt.xticks(list(range(32))) # Ensure all keys are shown in the x-axis + # Add a table of values next to the bars + cell_text = [[f"{val:.4f}"] for val in positive_token_similarity.cpu().numpy()] + plt.table(cellText=cell_text, colLabels=["Value"], loc='right', cellLoc='center') + + # Adjust layout to make room for the table: + plt.subplots_adjust(right=0.5) + plt.savefig(f"{save_dir}/positive_token_similarity_t2i_batch{batch_idx}_rank{rank}.png") + + # Extracting keys and values + keys = list(count_dict.keys()) + values = list(count_dict.values()) + + # Plotting the histogram + plt.figure(figsize=(10, 6)) + bars = plt.bar(keys, values, color='blue') + plt.xlabel('Token Number') + plt.ylabel('Value') + plt.title('Histogram of Token Values') + plt.xticks(keys) # Ensure all keys are shown in the x-axis + + # Adding the text on top of each bar + for bar in bars: + yval = bar.get_height() + plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom') + + plt.savefig(f"{save_dir}/token_histogram_text_image_batch{batch_idx}_rank{rank}.png") + + loss_mean = 0 + rank = dist.get_rank() + if rank == 0: + for token in range(32): + bs = image.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + image.device + ) + + sim_i2t = sim_q2t[:, :, token] / self.temp + sim_t2i = sim_t2q[:, :, token] / self.temp + + loss_itc = ( + F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1) + ) / 2 + print(f"Loss I2T in Token {token}: {loss_itc}") + loss_mean += loss_itc + + self.log( + f"val/loss_itc_{token}", + loss_itc, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + loss_mean /= 32 + self.log( + "val/loss_itc_mean", + loss_mean, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return + + def get_stage_2_loss_bypass_codebook(self, batch, batch_idx: int): + """_summary_ + + Args: + batch (_type_): _description_ + batch_idx (int): _description_ + """ + #------------------------ + # Stage 2 Training + #------------------------ + img, text = batch + + #------------------------ + # Stage 
2 - 1 : Codebook Training + #------------------------ + with torch.no_grad(): + causal_embeddings = self.get_causal_embeddings(img) + + # TODO: query_output should be trained to be similar to the text embedding + # Image embedding is cross-attended. + # Notice: does query_output_down match the CLIP embedding? + # [b, 32, 32] + query_output_down = self.image_tokenizer.model.encode_task_layer(causal_embeddings) + + # quant [b, 32, 32], loss_embed [b, 32, 768], embed_ind [b, 32] + # bypass the codebook + quant = query_output_down + + #------------------------ + # Stage 2 - 2 : Reconstruction of the Causal Embedding + #------------------------ + + # quant embedding dimension is [b, 32, 32] + # decoder_task_layer upscales it to [b, 32, 768] + # [b, 32, 32] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.decode_task_layer(quant) + + # Transformer decoder + query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) + + # query_output_up_pos_image should be similar to the original causal_embeddings + # Maximize cosine similarity between query_output_up_pos_image and causal_embeddings + + #loss_recon = F.cosine_similarity(query_output_up, causal_embeddings).mean() + loss_recon = F.mse_loss(query_output_up, causal_embeddings) + + #------------------------ + # Stage 2 - 3 : Reconstruction of the Generation Embedding + #------------------------ + + # MLP + # query_output_up = causal_embeddings + reverse_output_proj = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) + + gt_img_clip_embeddings = self.get_clip_img_embedding(img) + + loss_generation_embed = F.mse_loss(reverse_output_proj, gt_img_clip_embeddings) + + loss_total = loss_recon + loss_generation_embed + loss_total = loss_total.mean() + + loss_dict = {"loss_generation_embed": loss_generation_embed, + # "loss_embed": loss_embed, + "loss_recon": loss_recon, + "loss": loss_total} + + #------------------------ + # Logging + #------------------------ + generation_embedding_cosine_similarity = F.cosine_similarity(reverse_output_proj, gt_img_clip_embeddings).mean() + + self.logging_train_stage2(generation_embedding_cosine_similarity, loss_dict) + + return loss_total + + def get_stage_2_loss(self, batch, batch_idx: int): + """ + Stage 2 loss with codebook quantization. + + Args: + batch: (image, text) or (image, text, image_id) tuple from the dataloader + batch_idx (int): index of the current batch + """ + #------------------------ + # Stage 2 Training + #------------------------ + device = self.device + if len(batch) == 3: + img, text, image_id = batch + elif len(batch) == 2: + img, text = batch + + #------------------------ + # Stage 2 - 1 : Codebook Training + #------------------------ + with torch.no_grad(): + causal_embeddings = self.get_causal_embeddings(img) + + # TODO: query_output should be trained to be similar to the text embedding + # Image embedding is cross-attended. + # Notice: does query_output_down match the CLIP embedding?
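The quantize call just below returns a (quant, loss_embed, embed_ind) triple. The SEED codebook module itself is not part of this diff; purely as a reading aid, a generic VQ-VAE-style quantization step with that interface might look like the following sketch, where the codebook shape and the commitment weight beta are illustrative assumptions rather than values from the repo:

    import torch
    import torch.nn.functional as F

    def vq_sketch(z, codebook, beta=0.25):
        # z: [b, 32, 32] down-projected causal embeddings, codebook: [n_embed, 32]
        d = torch.cdist(z.flatten(0, 1), codebook)            # distance from every token to every code
        embed_ind = d.argmin(dim=-1)                          # nearest-code indices
        quant = codebook[embed_ind].view_as(z)                # quantized embeddings
        # codebook loss plus beta-weighted commitment loss
        loss_embed = F.mse_loss(quant, z.detach()) + beta * F.mse_loss(quant.detach(), z)
        quant = z + (quant - z).detach()                      # straight-through estimator for gradients
        return quant, loss_embed, embed_ind.view(z.shape[0], -1)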
+ # [b, 32, 32] + query_output_down = self.image_tokenizer.model.encode_task_layer(causal_embeddings) + + # quant [b, 32, 32], loss_embed [b, 32, 768], embed_ind [b, 32] + quant, loss_embed, embed_ind = self.image_tokenizer.model.quantize(query_output_down) + + #------------------------ + # Stage 2 - 2 : Reconstruction of the Causal Embedding + #------------------------ + + # quant embedding dimension is [b, 32, 32] + # decoder_task_layer upscales it to [b, 32, 768] + # [b, 32, 32] => [b, 32, 768] + query_output_up = self.image_tokenizer.model.decode_task_layer(quant) + + # Transformer decoder + query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) + + # query_output_up_pos_image should be similar to the original causal_embeddings + # Maximize cosine similarity between query_output_up_pos_image and causal_embeddings + + #loss_recon = F.cosine_similarity(query_output_up, causal_embeddings).mean() + loss_recon = F.mse_loss(query_output_up, causal_embeddings) + + #------------------------ + # Stage 2 - 3 : Reconstruction of the Generation Embedding + #------------------------ + + # MLP + # query_output_up = causal_embeddings + reverse_output_proj = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) + + gt_img_clip_embeddings = self.get_clip_img_embedding(img) + + loss_generation_embed = F.mse_loss(reverse_output_proj, gt_img_clip_embeddings) + + loss_total = loss_embed + loss_recon + loss_generation_embed + loss_total = loss_total.mean() + + # loss_dict = {"loss_embed": loss_embed, "loss_recon": loss_recon, + # "loss_generation_embed": loss_generation_embed, + # "loss": loss_total} + + loss_dict = {"loss_generation_embed": loss_generation_embed, + "loss_embed": loss_embed, + "loss_recon": loss_recon, + "loss": loss_total} + + #------------------------ + # Logging + #------------------------ + generation_embedding_cosine_similarity = F.cosine_similarity(reverse_output_proj, gt_img_clip_embeddings).mean() + + self.logging_train_stage2(generation_embedding_cosine_similarity, loss_dict) + + return loss_total + + def on_train_start(self): + print(f"\n====Training Stage {self.stage}====") + if self.stage == 2 and self.cfg.stage2.bypass_codebook: + print("\n====Bypass codebook====") + + print("Save config") + self.save_config() + + def training_step(self, batch, batch_idx: int): + if len(batch) == 3: + image, text, image_id = batch + elif len(batch) == 2: + image, text = batch + + self.B = image.shape[0] + + if self.stage == 1: + loss = self.get_stage_1_loss_use_last_token(batch, batch_idx) + elif self.stage == 2: + if self.cfg.stage2.bypass_codebook: + loss = self.get_stage_2_loss_bypass_codebook(batch, batch_idx) + else: + loss = self.get_stage_2_loss(batch, batch_idx) + + return loss + + def on_before_optimizer_step(self, optimizer, optimizer_idx): + # Compute the 2-norm for each layer + # If using mixed precision, the gradients are already unscaled here + # {'grad_2.0_norm/weight': 0.0003, 'grad_2.0_norm/bias': 0.0, 'grad_2.0_norm_total': 0.0003} + if self.cfg.experiment.stage == 1: + norms_0 = grad_norm(self.image_tokenizer.model.Qformer.bert.encoder.layer[0].attention.self.value, norm_type=2) + for norm in norms_0.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/Qformer/bert/encoder/layer/0/attention/self/value/{norm}", + norms_0[norm], + global_step=self.global_step, + ) + norms_1 = grad_norm(self.image_tokenizer.model.Qformer.bert.encoder.layer[1].attention.self.value, norm_type=2) + for norm in norms_1.keys(): +
self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/Qformer/bert/encoder/layer/1/attention/self/value/{norm}", + norms_1[norm], + global_step=self.global_step, + ) + norms_7 = grad_norm(self.image_tokenizer.model.Qformer.bert.encoder.layer[7].attention.self.value, norm_type=2) + for norm in norms_7.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/Qformer/bert/encoder/layer/7/attention/self/value/{norm}", + norms_7[norm], + global_step=self.global_step, + ) + elif self.cfg.experiment.stage == 2: + codebook_norm = grad_norm(self.image_tokenizer.model.quantize.embedding, norm_type=2) + for norm in codebook_norm.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/quantize/{norm}", + codebook_norm[norm], + global_step=self.global_step, + ) + + transformer_decoder_norm = grad_norm(self.image_tokenizer.model.blocks_image, norm_type=2) + for norm in transformer_decoder_norm.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/blocks_image/{norm}", + transformer_decoder_norm[norm], + global_step=self.global_step, + ) + + generation_mlp_norm = grad_norm(self.image_tokenizer.model.distill_image_proj, norm_type=2) + for norm in generation_mlp_norm.keys(): + self.logger.experiment.add_scalar( + f"grad_norm/image_tokenizer/model/distill_image_proj/{norm}", + generation_mlp_norm[norm], + global_step=self.global_step, + ) + + def on_validation_epoch_start(self): + return + + @torch.no_grad() + def validation_step(self, batch, batch_idx: int, save_path=None): + if self.logger is not None and isinstance(self.logger, pl.loggers.TensorBoardLogger): + tb_log_dir = self.logger.log_dir + else: + tb_log_dir = self.cfg.result_file_path # Fallback directory if logger is not set + + if self.stage == 1: + save_path = f"{tb_log_dir}/histogram" + os.makedirs(save_path, exist_ok=True) + + save_path = f"{tb_log_dir}/histogram/epoch_{self.current_epoch}" + os.makedirs(save_path, exist_ok=True) + + self.check_image_text_similarity(batch, batch_idx, save_dir=save_path) + elif self.stage == 2: + image, captions, image_id = batch + bypass_codebook = self.cfg.stage2.bypass_codebook + + with torch.no_grad(): + image_embeds = self.forward_stage_2(batch, batch_idx, bypass_codebook) + reconstructed_images = self.image_tokenizer.diffusion_model( + image_embeds=image_embeds, + negative_image_embeds=None, + guidance_scale=10, + noise_level=0, + latents=self.image_tokenizer.latents, + ).images + + save_path = f"{tb_log_dir}/images/version_{self.logger.version}/epoch_{self.current_epoch}/images" + os.makedirs(save_path, exist_ok=True) + + tensor_images = [] + + for img, cur_id in zip(reconstructed_images, image_id): + # save PIL image to save_path + img.save(f"{save_path}/{cur_id}") + + # For tensorboard logging + tensor_images.append(self.pil_to_tensor(img).unsqueeze(0)) + + tensor_images = torch.cat(tensor_images, dim=0) + + # Check if image is already logged + if batch_idx not in self.logged_original_image: + self.logger.experiment.add_images( + f"original/image_batch_{batch_idx}", + image, + ) + + self.logger.experiment.add_images( + f"original/image_batch_{batch_idx}_seed_reconstructed", + tensor_images, + ) + + # logging original caption + for caption in captions: + self.logger.experiment.add_text( + f"original/gt_text_image_batch_{batch_idx}", + caption, + ) + + self.logged_original_image.add(batch_idx) + else: + self.logger.experiment.add_images( + f"images/image_batch_{batch_idx}", + tensor_images, + 
global_step=self.sample_image_ind, + ) + self.sample_image_ind += 1 + + def on_validation_epoch_end(self): + if self.logger is not None and isinstance(self.logger, pl_loggers.TensorBoardLogger): + tb_log_dir = self.logger.log_dir + else: + tb_log_dir = self.cfg.result_file_path + + original_image_dir = self.cfg.dataset.val_config.root_dir + generated_image_dir = f"{tb_log_dir}/images/version_{self.logger.version}/epoch_{self.current_epoch}/images" + clip_score = calculate_clip_s_for_folder(original_image_dir, generated_image_dir) + + print(f"clip score: {clip_score}") + self.log_dict({ + 'clip_score': clip_score, + },on_step=False, on_epoch=True, prog_bar=True, logger=True) + + self.log( + "clip_score_coco_karpathy", + clip_score, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + def configure_optimizers(self): + # optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-8) + lr = self.cfg.optimizer.max_lr + betas = (self.cfg.hyperparameters.beta_1, self.cfg.hyperparameters.beta_2) + weight_decay = self.cfg.hyperparameters.weight_decay + optimizer = torch.optim.AdamW(self.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) + + #scheduler = transformers.get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=100, num_training_steps=5000) + num_training_steps = self.cfg.experiment.total_training_steps + num_warmup_steps = self.cfg.experiment.num_warmup_steps + scheduler = transformers.get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps + ) + + lr_scheduler_config = { + "scheduler": scheduler, + "name": "learning_rate", + "interval": "step", + "frequency": 1, + } + + return {"optimizer": optimizer, + "lr_scheduler": lr_scheduler_config,} + + @torch.no_grad() + def encode_image(self, image_torch): + with torch.no_grad(): + '''Convert a batch of img to code + Args: + model: The tokenizer model. + img: [b, c, h, w] + ''' + if len(image_torch.shape) == 3: + image_torch = image_torch.unsqueeze(0) + + causal_embeddings = self.get_causal_embeddings(image_torch) + query_output_down = self.image_tokenizer.model.encode_task_layer(causal_embeddings) + quant, loss_embed, embed_ind = self.image_tokenizer.model.quantize(query_output_down) + return embed_ind + + @torch.no_grad() + def decode_image(self, embed_ind): + with torch.no_grad(): + embed_ind = rearrange(embed_ind, 'n -> 1 n') + '''Convert a batch of code to img + Args: + model: The tokenizer model. 
+ code: [b, 32, 32] + ''' + quant = self.image_tokenizer.model.quantize.get_codebook_entry(embed_ind) + query_output_up = self.image_tokenizer.model.decode_task_layer(quant) + query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) + reverse_output_proj = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) + # print diffusion_model weight type + reconstructed_images = self.image_tokenizer.diffusion_model( + image_embeds=reverse_output_proj, + negative_image_embeds=None, + guidance_scale=10, + noise_level=0, + # latents=self.image_tokenizer.latents, + ).images + + return reconstructed_images + + +if __name__ == "__main__": + cfg, cfg_yaml = build_config() + device = "cuda" if torch.cuda.is_available() else "cpu" + + seed_everything(cfg.experiment.seed, workers=True) + + transform_cfg = OmegaConf.load(cfg.transform_cfg_path) + transform = hydra.utils.instantiate(transform_cfg) + + os.makedirs(cfg.result_file_path, exist_ok=True) + + datamodule = SEEDDataModule(cfg, transform=transform) + datamodule.setup() + train_dataloader = datamodule.train_dataloader() + val_dataloader = datamodule.val_dataloader() + + cfg.experiment.total_training_steps = datamodule.total_training_steps + + tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.result_file_path) + lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="step") + checkpoint_callback = pl.callbacks.ModelCheckpoint( + save_top_k=3, + monitor="clip_score_coco_karpathy" if cfg.experiment.stage == 2 else "val/loss", + mode="max", + ) + + trainer = pl.Trainer( + accelerator=device, + num_nodes=cfg.dist.n_nodes, + devices=cfg.dist.n_gpus, + strategy="ddp", + max_epochs=cfg.experiment.max_epochs, + deterministic=cfg.experiment.deterministic, + logger=tb_logger, + log_every_n_steps=cfg.experiment.log_every_n_steps, + # val_check_interval=cfg.experiment.val_check_interval, + check_val_every_n_epoch=cfg.experiment.check_val_every_n_epoch, + enable_checkpointing=cfg.experiment.enable_checkpointing, + num_sanity_val_steps=cfg.experiment.num_sanity_val_steps, + precision=str(cfg.optimizer.precision), + callbacks=[ModelSummary(max_depth=3), lr_logger] + [checkpoint_callback] if cfg.experiment.enable_checkpointing else [], + accumulate_grad_batches=cfg.experiment.grad_accumulation, + gradient_clip_val=cfg.optimizer.grad_clip_val, + ) + + if cfg.load_weight and cfg.resume: + raise ValueError("Only checkpoint or finetune") + + wrapper = SEEDTrainingWrapper(cfg).to(device) + + if cfg.load_weight: + wrapper.load_from_checkpoint(cfg.weight_path, cfg=cfg) + print("Loaded model from checkpoint") + else: + if cfg.experiment.stage == 1: + print("Stage 1 init from BLIP-2") + elif cfg.experiment.stage == 2: + print("Stage 2 init from Scratch") + + wrapper.setup("fit") + + if cfg.resume: + # Resume training + if cfg.weight_path is None: + raise ValueError("checkpoint_path is None") + else: + print(f"Resume training from {cfg.weight_path}") + trainer.fit( + wrapper, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, + ckpt_path=cfg.weight_path + ) + else: + print("Start training") + trainer.fit( + wrapper, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader + ) + + trainer.strategy.barrier() diff --git a/train_v7_unified_sds.py b/train_v7_unified_sds.py index 23cb7cf..2c0503f 100755 --- a/train_v7_unified_sds.py +++ b/train_v7_unified_sds.py @@ -44,6 +44,7 @@ from lavis.models import load_model from lavis.common.dist_utils import is_dist_avail_and_initialized +from 
functools import partial pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) @@ -122,7 +123,7 @@ def __init__(self, cfg): # For SDS t_range = [0.2, 0.6] - t_range = [0.02, 0.98] + # t_range = [0.02, 0.98] self.num_train_timesteps = 1000 self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) @@ -188,12 +189,42 @@ def setup(self, stage): self.image_tokenizer.model.distill_image_proj = self.image_tokenizer.model.distill_image_proj.to("cpu") elif self.stage == 2: self.random_initialize_stage2_model_weights() - if self.image_tokenizer.diffusion_model is not None: - for param in self.image_tokenizer.diffusion_model.unet.parameters(): - param.requires_grad = self.cfg.stage2.train_unet - + if self.cfg.stage2.train_unet: + self.make_unet_trainable_for_img_embeds() + ## make dump folder os.makedirs(self.cfg.result_file_path, exist_ok=True) + + def make_unet_trainable_for_img_embeds(self): + for p in self.image_tokenizer.diffusion_model.unet.parameters(): + p.requires_grad = False + + for p in self.image_tokenizer.diffusion_model.unet.class_embedding.parameters(): + p.requires_grad = True + + for block in self.image_tokenizer.diffusion_model.unet.down_blocks: + try: + for resnet in block.resnets: + for p in resnet.time_emb_proj.parameters(): + p.requires_grad = True + except Exception as e: + print(e) + continue + + for block in self.image_tokenizer.diffusion_model.unet.up_blocks: + try: + for resnet in block.resnets: + for p in resnet.time_emb_proj.parameters(): + p.requires_grad = True + except Exception as e: + print(e) + continue + + for resnet in self.image_tokenizer.diffusion_model.unet.mid_block.resnets: + for p in resnet.time_emb_proj.parameters(): + p.requires_grad = True + + def random_initialize_stage2_model_weights(self): """Random initialize stage 2 model weights @@ -862,7 +893,11 @@ def get_stage_2_loss_bypass_codebook(self, batch, batch_idx: int): #loss_recon = F.cosine_similarity(query_output_up, causal_embeddings).mean() loss_recon = F.mse_loss(query_output_up, causal_embeddings) - + loss_dict = { + "loss_recon": loss_recon, + } + loss_total = self.cfg.experiment.recon_loss_weight * loss_recon + #------------------------ # Stage 2 - 3 : Reconstruction Generation Embedding #------------------------ @@ -872,24 +907,27 @@ def get_stage_2_loss_bypass_codebook(self, batch, batch_idx: int): image_embeds = self.image_tokenizer.model.get_mlp_decoded_embedding(query_output_up) gt_img_clip_embeddings.requires_grad = False - loss_sds = self.sds_loss( - image_embeds=image_embeds, - clean_image=img, - guidance_scale=10, - grad_scale=1, - ) - - loss_total = loss_recon + loss_sds - loss_dict = { - "loss_recon": loss_recon, - "loss_sds": loss_sds, - } - - if self.cfg.stage2.use_clip_loss: - # loss_clip = F.mse_loss(image_embeds, gt_img_clip_embeddings) - loss_clip = self.clip_loss(image_embeds, gt_img_clip_embeddings) - loss_total += loss_clip + sds_loss_weight = self.cfg.experiment.sds_loss_weight * self.sds_loss_weights[self.global_step] + if self.cfg.experiment.clip_loss_weight > 0: + loss_clip = F.mse_loss(image_embeds, gt_img_clip_embeddings) + # loss_clip = self.clip_loss(image_embeds, gt_img_clip_embeddings) loss_dict['clip_loss'] = loss_clip + _loss_clip = self.cfg.experiment.clip_loss_weight * loss_clip + if self.cfg.experiment.cross_annealing: + _loss_clip *= (1 - sds_loss_weight) + loss_total += _loss_clip + + if sds_loss_weight > 0: + loss_sds = self.sds_loss( + image_embeds=image_embeds, + 
clean_image=img, + guidance_scale=10, + grad_scale=1, + ) + + loss_dict['loss_sds'] = loss_sds + loss_dict['sds_weight'] = sds_loss_weight + loss_total += sds_loss_weight * loss_sds loss_dict['loss'] = loss_total #------------------------ @@ -1083,9 +1121,8 @@ def validation_step(self, batch, batch_idx: int, save_path=None): image, captions, image_id = batch bypass_codebook = self.cfg.stage2.bypass_codebook - image_embeds = self.forward_stage_2(batch, batch_idx, bypass_codebook) - with torch.no_grad(): + image_embeds = self.forward_stage_2(batch, batch_idx, bypass_codebook) reconstructed_images = self.image_tokenizer.diffusion_model( image_embeds=image_embeds, negative_image_embeds=None, @@ -1177,6 +1214,17 @@ def configure_optimizers(self): num_training_steps=num_training_steps ) + if self.cfg.experiment.sds_loss_weight > 0 and self.cfg.experiment.use_sds_loss_schedule: + _num_training_steps = num_training_steps // 8 + def f(current_step: int): + return 1 - max(0.0, float(_num_training_steps - current_step) / float(_num_training_steps)) + else: + def f(current_step: int): + return 1 + + x = np.arange(0, num_training_steps) + self.sds_loss_weights = np.array(list(map(f, x))) + lr_scheduler_config = { "scheduler": scheduler, "name": "learning_rate", @@ -1233,7 +1281,7 @@ def configure_optimizers(self): wrapper = SEEDTrainingWrapper(cfg).to(device) if cfg.load_weight: - wrapper.load_from_checkpoint(cfg.weight_path) + wrapper.load_from_checkpoint(cfg.weight_path, cfg=cfg) print("Loaded model from checkpoint") else: if cfg.experiment.stage == 1: diff --git a/train_v8_seed_stage1_long_caption.py b/train_v8_seed_stage1_long_caption.py index bfc0190..769dcf5 100755 --- a/train_v8_seed_stage1_long_caption.py +++ b/train_v8_seed_stage1_long_caption.py @@ -152,8 +152,6 @@ def setup(self, stage): # Update the model with the weights filtered_state_dict = {k: v for k, v in blip_model.state_dict().items() if k in self.image_tokenizer.model.state_dict()} self.image_tokenizer.model.load_state_dict(filtered_state_dict, strict=False) - elif self.cfg.stage1.init == "SEED": - print("Load init weights from SEED") print("Set stage 2 model not trainable") for param in self.image_tokenizer.model.quantize.parameters(): @@ -296,15 +294,9 @@ def forward_stage_2(self, batch, batch_idx: int, bypass_codebook=False): quant, loss_embed, embed_ind = self.image_tokenizer.model.quantize(query_output_down) embed_ind = embed_ind.reshape(quant.shape[0], -1) - - # # [b, 32, 32] => [b, 32, 768] + # [b, 32, 32] => [b, 32, 768] query_output_up = self.image_tokenizer.model.decode_task_layer(quant) - quant_embedding = self.image_tokenizer.model.quantize.get_codebook_entry(embed_ind) - - # # [b, 32, 32] => [b, 32, 768] - query_output_up = self.image_tokenizer.model.decode_task_layer(quant_embedding) - # [b, 32, 768] => [b, 32, 768] query_output_up = self.image_tokenizer.model.get_transformer_decoded_embedding(query_output_up) @@ -321,7 +313,6 @@ def logging_train_stage2(self, generation_embedding_cosine_similarity, loss_dict on_epoch=True, prog_bar=True, logger=True, - batch_size=self.B, ) if "loss_embed" in loss_dict.keys(): @@ -486,7 +477,7 @@ def get_stage_1_loss_use_last_token(self, batch, batch_idx: int): rank = dist.get_rank() bs = image.size(0) targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( - self.device + image.device ) loss_itc = ( @@ -501,7 +492,6 @@ def get_stage_1_loss_use_last_token(self, batch, batch_idx: int): on_epoch=True, prog_bar=True, logger=True, - batch_size=self.B, ) return loss_itc 
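For reference on the SDS weighting added to configure_optimizers in train_v7_unified_sds.py above: when use_sds_loss_schedule is enabled, sds_loss_weights ramps linearly from 0 to 1 over the first eighth of the training steps and stays at 1 afterwards, and with cross_annealing the CLIP regression term is scaled by (1 - sds weight) so the two objectives trade off over training. A standalone numeric sketch, assuming cfg.experiment.sds_loss_weight == 1 and an illustrative step count:

    import numpy as np

    num_training_steps = 8000                    # illustrative, set from the datamodule in the patch
    ramp_steps = num_training_steps // 8         # schedule saturates after 1/8 of training

    def sds_weight(step: int) -> float:
        # 0.0 at step 0, linear up to 1.0 at ramp_steps, clamped at 1.0 afterwards
        return 1 - max(0.0, float(ramp_steps - step) / float(ramp_steps))

    sds_loss_weights = np.array([sds_weight(s) for s in range(num_training_steps)])

    # With cross_annealing the CLIP term is scaled by (1 - sds weight)
    step = ramp_steps // 2
    print(sds_loss_weights[step])                # 0.5 at the ramp midpoint
    print(1.0 * (1 - sds_loss_weights[step]))    # effective CLIP weight: 0.5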
@@ -713,8 +703,6 @@ def check_image_text_similarity(self, batch, batch_idx: int, save_dir="image_tex on_epoch=True, prog_bar=True, logger=True, - sync_dist=True, - batch_size=image.size(0), ) loss_mean /= 32 @@ -725,8 +713,6 @@ def check_image_text_similarity(self, batch, batch_idx: int, save_dir="image_tex on_epoch=True, prog_bar=True, logger=True, - sync_dist=True, - batch_size=image.size(0), ) return @@ -964,6 +950,9 @@ def on_before_optimizer_step(self, optimizer, optimizer_idx): generation_mlp_norm[norm], global_step=self.global_step, ) + + def on_validation_epoch_start(self): + return @torch.no_grad() def validation_step(self, batch, batch_idx: int, save_path=None): @@ -984,9 +973,8 @@ def validation_step(self, batch, batch_idx: int, save_path=None): image, captions, image_id = batch bypass_codebook = self.cfg.stage2.bypass_codebook - image_embeds = self.forward_stage_2(batch, batch_idx, bypass_codebook) - with torch.no_grad(): + image_embeds = self.forward_stage_2(batch, batch_idx, bypass_codebook) reconstructed_images = self.image_tokenizer.diffusion_model( image_embeds=image_embeds, negative_image_embeds=None, @@ -1036,9 +1024,6 @@ def validation_step(self, batch, batch_idx: int, save_path=None): global_step=self.sample_image_ind, ) self.sample_image_ind += 1 - - def on_validation_epoch_start(self): - return def on_validation_epoch_end(self): if self.logger is not None and isinstance(self.logger, pl_loggers.TensorBoardLogger): @@ -1102,23 +1087,25 @@ def configure_optimizers(self): os.makedirs(cfg.result_file_path, exist_ok=True) - train_datamodule = CompressionDataModule( - batch_size=cfg.experiment.local_batch_size, - num_workers=cfg.dataset.num_workers, + datamodule = CompressionDataModule( + cfg=cfg, transform=transform, - compression_level=31 + compression_level=0 ) - train_datamodule.setup() - train_dataloader = train_datamodule.train_dataloader() - - val_datamodule = SEEDDataModule(cfg, transform=transform) - val_datamodule.setup() - val_dataloader = val_datamodule.val_dataloader() + datamodule.setup() + train_dataloader = datamodule.train_dataloader() + val_dataloader = datamodule.val_dataloader() - cfg.experiment.total_training_steps = (len(train_dataloader) // cfg.dist.n_gpus) * cfg.experiment.max_epochs + cfg.experiment.total_training_steps = len(train_dataloader) * cfg.experiment.max_epochs // cfg.experiment.grad_accumulation + print(f"Training steps in one epoch: {len(train_dataloader)}") + print(f"Total training steps: {cfg.experiment.total_training_steps}") tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.result_file_path) lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="step") + checkpoint_callback = pl.callbacks.ModelCheckpoint( + save_top_k=2, + every_n_train_steps=300, + ) trainer = pl.Trainer( accelerator=device, @@ -1129,12 +1116,12 @@ def configure_optimizers(self): deterministic=cfg.experiment.deterministic, logger=tb_logger, log_every_n_steps=cfg.experiment.log_every_n_steps, - # val_check_interval=cfg.experiment.val_check_interval, - check_val_every_n_epoch=cfg.experiment.check_val_every_n_epoch, + val_check_interval=cfg.experiment.val_check_interval, + # check_val_every_n_epoch=cfg.experiment.check_val_every_n_epoch, enable_checkpointing=cfg.experiment.enable_checkpointing, num_sanity_val_steps=cfg.experiment.num_sanity_val_steps, precision=str(cfg.optimizer.precision), - callbacks=[ModelSummary(max_depth=3), lr_logger], + callbacks=[ModelSummary(max_depth=3), lr_logger, checkpoint_callback], 
accumulate_grad_batches=cfg.experiment.grad_accumulation, gradient_clip_val=cfg.optimizer.grad_clip_val, ) @@ -1145,11 +1132,11 @@ def configure_optimizers(self): wrapper = SEEDTrainingWrapper(cfg).to(device) if cfg.load_weight: - wrapper.load_from_checkpoint(cfg.weight_path, cfg=cfg) - print("Loaded model from checkpoint") + wrapper = wrapper.load_from_checkpoint(cfg.weight_path, cfg=cfg) + print(f"Loaded model from checkpoint {cfg.weight_path}") else: if cfg.experiment.stage == 1: - print("Stage 1 init from BLIP-2") + print(f"Stage 1 init from {cfg.stage1.init}") elif cfg.experiment.stage == 2: print("Stage 2 init from Scratch") diff --git a/utils/config.py b/utils/config.py index 7fadb4b..52f4ed3 100755 --- a/utils/config.py +++ b/utils/config.py @@ -22,7 +22,7 @@ def to_attr_dict(cfg): return cfg def build_config(struct=False): - cfg = OmegaConf.load('configs/overfitting_test.yaml') + cfg = OmegaConf.load('configs/seed_unified_test.yaml') OmegaConf.set_struct(cfg, struct) cfg = override_from_file_name(cfg) cfg = override_from_cli(cfg)
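For context on the build_config change above: the default YAML is loaded first and then layered with a cfg_path=<file> override and key=value CLI overrides through the repo's override_from_file_name and override_from_cli helpers, which are not shown in this hunk. The snippet below only illustrates the general OmegaConf layering pattern such helpers typically rely on; the keys and values are made up:

    from omegaconf import OmegaConf

    base = OmegaConf.create({"experiment": {"stage": 1, "max_epochs": 10}})
    file_override = OmegaConf.create({"experiment": {"stage": 2}})          # e.g. loaded from cfg_path=<yaml>
    cli_override = OmegaConf.from_dotlist(["experiment.max_epochs=20"])     # e.g. remaining sys.argv entries

    # Later sources take precedence over earlier ones
    cfg = OmegaConf.merge(base, file_override, cli_override)
    assert cfg.experiment.stage == 2 and cfg.experiment.max_epochs == 20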