
We implement NNCLR and a novel clustering-based technique for contrastive learning that we call KMCLR. We show that applying a clustering technique to obtain prototype embeddings, and using these prototypes to form positive pairs for the contrastive loss, can achieve performance on par with NNCLR on CIFAR-100 while storing only 0.4% as many vectors.


mwritescode/nnclr-cifar100


Python PyTorch Scikit

Beyond Single Instance Positives For Contrastive Visual Representation Learning



Humans subconsciously compare new sensory inputs with things they have already experienced, and this may play an important role in how quickly they acquire new concepts. For instance, if we are asked to picture a Mangalitsa pig, even without ever having seen one, our brain will automatically link it to semantically similar classes such as pig or boar. NNCLR shows that giving self-supervised techniques the ability to find similarities between new and previously seen items can improve their representation learning capabilities.

Indeed, other instance discrimination approaches such as SimCLR or BYOL focus on learning what makes a specific image different from everything else in the dataset. They typically generate multiple views of each image through data augmentation, treat them as positive samples, and encourage their representations to be close in the embedding space. However, these single-instance positives may not capture all the variance within a given class and, more importantly, they cannot account for similarities between semantically close classes. NNCLR instead uses the nearest neighbor of an image view's embedding as its positive sample. In practice this requires a large support set of previous embeddings that is updated at every batch, and in principle the bigger the support set, the better the results. We argue that this brings their solution closer to a brute-force approach and look for an alternative.
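A minimal sketch of the NNCLR positive-selection step, written in NumPy rather than PyTorch to keep it short. It assumes a FIFO support set, cosine similarity and an InfoNCE loss with temperature 0.1; `SupportSet`, `info_nce` and the toy sizes are hypothetical and this is not the repository's implementation.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm so dot products are cosine similarities."""
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


class SupportSet:
    """Hypothetical FIFO queue of past embeddings (assumed design, not the repo's class)."""

    def __init__(self, size: int, dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Start from random unit vectors; during training it fills with real embeddings.
        self.queue = l2_normalize(rng.standard_normal((size, dim)))
        self.ptr = 0

    def nearest_neighbor(self, z: np.ndarray) -> np.ndarray:
        """Return, for each row of z, the most similar stored embedding."""
        sims = l2_normalize(z) @ self.queue.T  # shape (batch, size)
        return self.queue[sims.argmax(axis=1)]

    def update(self, z: np.ndarray) -> None:
        """Overwrite the oldest entries with the current batch (per-batch update)."""
        n = z.shape[0]
        idx = (self.ptr + np.arange(n)) % len(self.queue)
        self.queue[idx] = l2_normalize(z)
        self.ptr = (self.ptr + n) % len(self.queue)


def info_nce(anchors: np.ndarray, positives: np.ndarray, temperature: float = 0.1) -> float:
    """Cross-entropy in which row i of `anchors` must match row i of `positives`."""
    logits = l2_normalize(anchors) @ l2_normalize(positives).T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


# Toy stand-ins for the encoder outputs of two augmented views.
rng = np.random.default_rng(1)
z1 = rng.standard_normal((8, 128))
z2 = z1 + 0.1 * rng.standard_normal((8, 128))

support = SupportSet(size=1024, dim=128)
nn1 = support.nearest_neighbor(z1)  # nearest neighbour of view 1 becomes the positive
loss = info_nce(nn1, z2)            # ... and is contrasted against view 2
support.update(z1)                  # enqueue the current batch afterwards
```

The memory cost is the whole `queue`: every stored vector must be compared against each embedding in the batch, which is what the text above calls a brute-force search.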

In particular, we propose to generate prototype vectors via online clustering and then use, as the positive sample for an image, the prototype closest to its embedding. We expect this to capture both intra-class variance and similarities between classes while limiting the number of vectors we have to store. Due to computational and time constraints, all our experiments pre-train on CIFAR-100 for only 500 epochs, with 10 epochs of warm-up followed by cosine decay. As the figure below shows, our KMCLR approach with k = 400 achieves performance on par with NNCLR while storing only 0.4% as many vectors and without increasing the running time.
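The KMCLR idea can be sketched in the same style. The clustering algorithm is an assumption here: the sketch uses mini-batch (online) spherical k-means with a per-prototype 1/count step size, and `OnlinePrototypes`, the temperature of 0.1 and the toy sizes are hypothetical. It illustrates the idea described above, not the repository's code.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm so dot products are cosine similarities."""
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


class OnlinePrototypes:
    """k prototypes kept by mini-batch spherical k-means (assumed clustering scheme)."""

    def __init__(self, k: int, dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.centroids = l2_normalize(rng.standard_normal((k, dim)))
        self.counts = np.zeros(k)  # number of points assigned to each prototype so far

    def assign(self, z: np.ndarray) -> np.ndarray:
        """Index of the most similar prototype for each row of z."""
        return (l2_normalize(z) @ self.centroids.T).argmax(axis=1)

    def nearest(self, z: np.ndarray) -> np.ndarray:
        """Prototype used as the positive sample for each row of z."""
        return self.centroids[self.assign(z)]

    def update(self, z: np.ndarray) -> None:
        """Assign the whole batch first, then pull each prototype towards its points."""
        z = l2_normalize(z)
        for point, c in zip(z, self.assign(z)):
            self.counts[c] += 1
            eta = 1.0 / self.counts[c]  # step size shrinks as a prototype sees more points
            self.centroids[c] = (1.0 - eta) * self.centroids[c] + eta * point
        # Project prototypes back onto the unit sphere.
        self.centroids = l2_normalize(self.centroids)


def info_nce(anchors: np.ndarray, positives: np.ndarray, temperature: float = 0.1) -> float:
    """Cross-entropy in which row i of `anchors` must match row i of `positives`."""
    logits = l2_normalize(anchors) @ l2_normalize(positives).T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


# Toy stand-ins for the encoder outputs of two augmented views.
rng = np.random.default_rng(1)
z1 = rng.standard_normal((8, 32))
z2 = z1 + 0.1 * rng.standard_normal((8, 32))

protos = OnlinePrototypes(k=16, dim=32)
positives = protos.nearest(z1)  # nearest prototype replaces the nearest neighbour
loss = info_nce(positives, z2)
protos.update(z1)               # online clustering step after computing the loss
```

The only state kept between batches is `centroids` (k vectors) plus one counter per prototype, so memory does not depend on a support-set length.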

[Figure: top-5 accuracy]
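The pre-training schedule mentioned above (500 epochs, 10 epochs of warm-up, then cosine decay) can be sketched as a per-epoch learning-rate function. The linear shape of the warm-up, the decay towards `min_lr = 0`, the per-epoch (rather than per-step) granularity and the base learning rate are all assumptions; `lr_at_epoch` is a hypothetical helper, not the repository's scheduler.

```python
import math


def lr_at_epoch(epoch: int, base_lr: float, total_epochs: int = 500,
                warmup_epochs: int = 10, min_lr: float = 0.0) -> float:
    """Learning rate for a 0-indexed epoch: linear warm-up, then cosine decay."""
    if epoch < warmup_epochs:
        # Ramp linearly so that the last warm-up epoch reaches base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warm-up phase completed, in [0, 1).
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical base learning rate, purely for illustration.
schedule = [lr_at_epoch(e, base_lr=1.0) for e in range(500)]
```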

Note that all our experiments were tracked with Weights & Biases and can be publicly accessed here.

Requirements

The code has been tested with Python 3.10.6.
To install the required Python packages, we recommend that you (i) create a virtual environment, (ii) install torch and torchvision following the official instructions, and (iii) run the commands below to install the remaining requirements.

  pip install -U pip
  pip install -r requirements.txt

Usage

To pre-train a model and then evaluate it with a linear probe, run

 python pretrain.py path/to/your/config_file.yaml # for training
 python linear_eval.py path/to/your/config_file.yaml # for evaluation

where config_file.yaml is a YACS configuration file following those in the config folder. If you want to use wandb to log your metrics during training or evaluation, you also need to run wandb login before the commands above to log in to your account.
