The official implementation of our paper "WATT: Weight Average Test-Time Adaptation of CLIP".
- We introduce a novel Test-Time Adaptation (TTA) method for CLIP that, for the first time, leverages weight averaging across various text templates at test time.
- WATT can improve performance using only a single image at test time, a capability not present in previous approaches.
- We evaluate WATT on multiple datasets with diverse types and degrees of domain shift, covering a total of 155 evaluation scenarios. Our experiments demonstrate the robustness and efficacy of WATT compared to existing adaptation methods.
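To make the idea of weight averaging concrete, here is a minimal sketch of an element-wise average over several sets of parameters, one per text template. It is only an illustration: the function name `average_weights`, the parameter names, and the toy values are assumptions, and the actual WATT procedure (which parameters are adapted, and how often averaging happens) is described in the paper.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def average_weights(per_template: List[Weights]) -> Weights:
    """Element-wise mean of parameter dictionaries that share keys and shapes."""
    if not per_template:
        raise ValueError("need at least one set of weights")
    n = len(per_template)
    averaged: Weights = {}
    for name in per_template[0]:
        # Group the i-th value of this parameter across all templates.
        columns = zip(*(weights[name] for weights in per_template))
        averaged[name] = [sum(column) / n for column in columns]
    return averaged


# Toy values (hypothetical): the same two parameters after adapting
# the model with three different text templates.
adapted = [
    {"ln.weight": [1.0, 2.0], "ln.bias": [0.0, 0.5]},
    {"ln.weight": [2.0, 2.0], "ln.bias": [0.5, 0.5]},
    {"ln.weight": [3.0, 5.0], "ln.bias": [1.0, 0.5]},
]
merged = average_weights(adapted)
print(merged)
```

With a real model, the same averaging would typically be applied to tensors from a state dictionary rather than to Python lists.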
git clone https://github.com/Mehrdad-Noori/WATT.git
Create an environment and install the requirements. The environment.yml file can be used to install the required dependencies:
cd WATT
conda env create -f environment.yml
Supported datasets:
To download all datasets, simply run the following command:
python download_datasets.py --data_dir ./data/
The ./bash directory contains bash files prepared to reproduce the results of the paper for different datasets and domains.
As an example, here is how to run the WATT-S method on different corruptions of the CIFAR100-C dataset, using the eight text templates from templates.yaml that were used in the paper. The corresponding bash file for reproducing the results in the paper is located at ./bash/watt_s/CIFAR/cifar100c.sh.
# dataset configuration
DATASET=cifar100
DATA_DIR=/path/to/data/CIFAR100/
# adaptation parameters
BATCH_SIZE=128
LR=1e-3
BACKBONE=ViT-B/32
# method to use for adaptation
METHOD=WATT
# WATT method configurations (please see the paper)
WATT_TYPE=sequential
WATT_L=2
WATT_M=5
# path to text templates that should be used during adaptation
WATT_TEMPLATE_DIR=./templates.yaml
# List of all corruption types for adaptation
ALL_CORRUPTIONS="gaussian_noise shot_noise"
# Execute the adaptation process with specified parameters
python main.py \
    --data_dir $DATA_DIR \
    --dataset $DATASET \
    --adapt \
    --method $METHOD \
    --save_dir ./save \
    --backbone $BACKBONE \
    --batch-size $BATCH_SIZE \
    --lr $LR \
    --watt_type $WATT_TYPE \
    --watt_l $WATT_L \
    --watt_m $WATT_M \
    --watt_temps $WATT_TEMPLATE_DIR \
    --corruptions_list $ALL_CORRUPTIONS
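If you prefer to launch runs from Python, the same command can be assembled as an argument list. The sketch below uses only the flags shown in the script above; the helper name `build_watt_command` and its defaults are assumptions made for illustration and are not part of the repository.

```python
import shlex


def build_watt_command(data_dir, corruptions, dataset="cifar100",
                       backbone="ViT-B/32", batch_size=128, lr="1e-3",
                       watt_type="sequential", watt_l=2, watt_m=5,
                       templates="./templates.yaml", save_dir="./save"):
    """Return the main.py invocation from the bash example as a list of arguments."""
    return [
        "python", "main.py",
        "--data_dir", data_dir,
        "--dataset", dataset,
        "--adapt",
        "--method", "WATT",
        "--save_dir", save_dir,
        "--backbone", backbone,
        "--batch-size", str(batch_size),
        "--lr", str(lr),
        "--watt_type", watt_type,
        "--watt_l", str(watt_l),
        "--watt_m", str(watt_m),
        "--watt_temps", templates,
        # Each corruption name is passed as its own argument.
        "--corruptions_list", *corruptions,
    ]


cmd = build_watt_command("/path/to/data/CIFAR100/",
                         ["gaussian_noise", "shot_noise"])
# Print a shell-quoted version for inspection before running it.
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root once the data path points to a real directory.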
Here are the results (accuracy, %) of our proposed WATT-P and WATT-S methods on different corruptions of the CIFAR100-C dataset. For a more detailed analysis and a complete table of the results, please refer to our paper.
Corruption | CLIP | TENT | TPT | CLIPArTT | WATT-P | WATT-S |
---|---|---|---|---|---|---|
Gaussian noise | 14.80 | 14.38 | 14.03 | 25.32 | 31.28 | 32.07 |
Shot noise | 16.03 | 17.34 | 15.25 | 27.90 | 33.44 | 34.36 |
Impulse noise | 13.85 | 10.03 | 13.01 | 25.62 | 29.40 | 30.33 |
Defocus blur | 36.74 | 49.05 | 37.60 | 49.88 | 52.32 | 52.99 |
Glass blur | 14.19 | 3.71 | 16.41 | 27.89 | 31.20 | 32.15 |
Motion blur | 36.14 | 46.62 | 37.52 | 47.93 | 49.72 | 50.53 |
Zoom blur | 40.24 | 51.84 | 42.99 | 52.70 | 54.72 | 55.30 |
Snow | 38.95 | 46.71 | 42.35 | 49.72 | 51.79 | 52.77 |
Frost | 40.56 | 44.90 | 43.31 | 49.63 | 53.04 | 53.79 |
Fog | 38.00 | 47.31 | 38.81 | 48.77 | 50.78 | 51.49 |
Brightness | 48.18 | 60.58 | 50.23 | 61.27 | 62.65 | 63.57 |
Contrast | 29.53 | 45.90 | 28.09 | 48.55 | 51.34 | 52.76 |
Elastic transform | 26.33 | 33.09 | 28.12 | 37.45 | 39.97 | 40.90 |
Pixelate | 21.98 | 26.47 | 20.43 | 33.88 | 39.59 | 40.97 |
JPEG compression | 25.91 | 29.89 | 28.82 | 36.07 | 38.99 | 39.59 |
Mean | 29.43 | 35.19 | 30.46 | 41.51 | 44.68 | 45.57 |
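
The Mean row is the plain average over the 15 corruption types. As a quick sanity check, the sketch below recomputes it for the CLIP and WATT-S columns using the values from the table above; the variable names are only illustrative.

```python
from statistics import mean

# Per-corruption values copied from the CLIP and WATT-S columns of the table.
clip = [14.80, 16.03, 13.85, 36.74, 14.19, 36.14, 40.24, 38.95,
        40.56, 38.00, 48.18, 29.53, 26.33, 21.98, 25.91]
watt_s = [32.07, 34.36, 30.33, 52.99, 32.15, 50.53, 55.30, 52.77,
          53.79, 51.49, 63.57, 52.76, 40.90, 40.97, 39.59]

clip_mean = round(mean(clip), 2)
watt_s_mean = round(mean(watt_s), 2)
# Average improvement of WATT-S over the unadapted CLIP baseline.
gain = round(mean(watt_s) - mean(clip), 2)

print(clip_mean, watt_s_mean, gain)
```

The two averages match the Mean row (29.43 and 45.57), which corresponds to an average improvement of about 16 points over the CLIP baseline.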
To add a new adaptation method to the framework, follow these steps:
Create a new file for your method in the adapt directory, for example adapt/new_method.py. Your method class should have the following four mandatory methods:
class NewMethod:
    def __init__(self, model, lr, param1='default1', param2=10, device='cpu'):
        """
        Initializes the NewMethod module.
        """

    def reset(self):
        pass

    def adapt(self, inputs, classes, templates):
        pass

    def evaluate(self, inputs, classes, templates):
        pass
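As a rough illustration of how the four methods above fit together, here is a toy class that follows the same interface. Everything in it is an assumption made for the example: the class name `ToyMethod`, the dictionary standing in for a model, the placeholder templates, the per-batch call order, and the return value of `evaluate`. The framework's real methods may behave differently.

```python
import copy


class ToyMethod:
    """Toy stand-in that follows the four-method interface."""

    def __init__(self, model, lr, device='cpu'):
        self.model = model
        self.lr = lr
        self.device = device
        # Keep a snapshot so reset() can undo any adaptation.
        self._initial = copy.deepcopy(model)

    def reset(self):
        # Restore the state captured at construction time.
        self.model = copy.deepcopy(self._initial)

    def adapt(self, inputs, classes, templates):
        # Placeholder update: a real method would take gradient steps here.
        self.model["steps"] += len(templates)

    def evaluate(self, inputs, classes, templates):
        # Placeholder prediction: class index 0 for every input.
        return [0 for _ in inputs]


# Hypothetical driver loop: reset, adapt, then evaluate on each batch.
method = ToyMethod(model={"steps": 0}, lr=1e-3)
classes = ["cat", "dog"]
templates = ["a photo of a {}", "a sketch of a {}"]  # made-up templates
predictions = []
for batch in [["img_a", "img_b"], ["img_c"]]:
    method.reset()
    method.adapt(batch, classes, templates)
    predictions.append(method.evaluate(batch, classes, templates))
print(predictions, method.model)
```

Resetting before each batch keeps batches independent of one another; whether that is the desired behavior depends on the adaptation protocol being reproduced.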
from .new_method import NewMethod  # Add this import

def get_method(args, device):
    # ... existing method branches ...
    elif args.method == 'new_method':  # Add this elif clause
        print(f"Selected method: NewMethod with parameters: model={args.backbone}, lr={args.lr}, param1={args.param1}, param2={args.param2}, device={device}")
        return NewMethod(args.backbone, args.lr, param1=args.param1, param2=args.param2, device=device)
Update add_method_specific_args in main.py:
def add_method_specific_args(parser, method):
    # ... existing method branches ...
    elif method == 'new_method':  # Add this elif clause
        parser.add_argument('--param1', type=str, default='default1', help='Description for param1')
        parser.add_argument('--param2', type=int, default=10, help='Description for param2')
    return parser
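To see how method-specific arguments end up on the parsed namespace, here is a standalone sketch using argparse. It contains only the new_method branch, and the two-stage parsing (read `--method` first, then register the extra arguments) is an assumption about how such a function might be called, not a description of main.py.

```python
import argparse


def add_method_specific_args(parser, method):
    # Standalone version with only the new branch, hence `if` instead of `elif`.
    if method == 'new_method':
        parser.add_argument('--param1', type=str, default='default1',
                            help='Description for param1')
        parser.add_argument('--param2', type=int, default=10,
                            help='Description for param2')
    return parser


argv = ['--method', 'new_method', '--param2', '3']

parser = argparse.ArgumentParser()
parser.add_argument('--method', type=str, default='new_method')

# Stage 1: read only --method and ignore arguments not registered yet.
known, _ = parser.parse_known_args(argv)

# Stage 2: register the method-specific arguments and parse everything.
parser = add_method_specific_args(parser, known.method)
args = parser.parse_args(argv)
print(args.method, args.param1, args.param2)
```

Here `--param1` falls back to its default while `--param2` is converted to an integer from the command line.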
This source code is released under the MIT license, which can be found here.
This project incorporates components from the following repositories. We extend our gratitude to the authors for open-sourcing their work: