Zero-shot classification: Winning Kaggle competitions with CLIP
Computer Vision (CV) is a field of Artificial Intelligence (AI) aimed at extracting information from visual inputs. CV has a range of applications in industry settings, for instance quickly analysing medical scans, images uploaded to a web platform, or the image stream from an autonomous vehicle. A common CV task is image classification, in which an AI model labels an image as a member of a category; for example, identifying whether an image posted on a social media platform is offensive.
This article will discuss two interesting approaches to image classification: training deep learning models and zero-shot classification. After training on a large dataset, deep learning models can achieve super-human performance in image classification, but they often struggle with images that differ from those they were trained on, or when only a small set of training images is available. This article details an exciting state-of-the-art alternative: using CLIP (Contrastive Language-Image Pretraining)1 for zero-shot classification, that is, image classification with no additional training. With zero-shot classification, CLIP can outperform large deep learning models on a difficult Kaggle CV problem: distinguishing game-related memes.
CLIP and zero-shot classification
State-of-the-art performance in CV and image classification often involves a Convolutional Neural Network2 (CNN) trained on a very large dataset. One example is ResNet3, which was trained to identify 1000 different categories in the 2012 ImageNet4 dataset, a collection of 1.28 million training images and a key benchmark in CV research. ResNet is trained by encoding an image into a feature vector (in this case a 512-dimensional vector, that is, a list of 512 numbers) and using that encoding to predict the image’s label. ResNet achieved state-of-the-art performance on ImageNet, with 76.2 per cent accuracy.
In contrast, CLIP was trained on both images and natural language. Using a dataset of 400 million images paired with text descriptions collected from the internet, CLIP encodes both the text and the image into feature vectors, and during training learns to match each image with its corresponding text description, as shown in Figure 1:
Figure 1. CLIP’s training architecture
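The matching objective in Figure 1 is a contrastive loss: within a batch of N image-text pairs, each image should score highest against its own description, and each description highest against its own image. The sketch below is a minimal pure-Python illustration, assuming the image and text encoders have already produced embedding vectors. The function names are hypothetical, and the fixed temperature of 0.07 stands in for CLIP’s learned temperature (0.07 is the value CLIP initialises it to).

```python
import math


def l2_normalise(v):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cross_entropy(logits, target):
    """Cross-entropy of one row of logits against the index of the correct class."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def contrastive_loss(image_embs, text_embs, temperature=0.07):
    """Symmetric contrastive loss over N matched (image, text) pairs.

    Pair i is the positive example; the other N-1 pairings in the batch act as negatives.
    """
    images = [l2_normalise(v) for v in image_embs]
    texts = [l2_normalise(v) for v in text_embs]
    n = len(images)
    # logits[i][j]: scaled cosine similarity between image i and text j
    logits = [
        [sum(a * b for a, b in zip(images[i], texts[j])) / temperature for j in range(n)]
        for i in range(n)
    ]
    # Image-to-text direction: row i should pick out text i
    loss_images = sum(cross_entropy(logits[i], i) for i in range(n)) / n
    # Text-to-image direction: column j should pick out image j
    loss_texts = sum(cross_entropy([logits[i][j] for i in range(n)], j) for j in range(n)) / n
    return (loss_images + loss_texts) / 2


# Matched pairs give a near-zero loss; swapping the captions gives a large one
aligned = contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned, swapped)
```

Training pushes the two encoders towards the "aligned" case for real image-caption pairs, which is what later makes it possible to match an image against label text CLIP has never seen.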
This kind of training gives the CLIP model some distinctive capabilities:
- Zero-shot classification
  - By learning to match images with text descriptions, CLIP can encode any set of text descriptions and select the one that best matches a particular image. This is the basis of zero-shot classification: CLIP can encode, and make decisions about, images and text it has not seen before.
- ImageNet performance
  - Using zero-shot classification, CLIP was tested on the same ImageNet data as ResNet and matched its 76.2 per cent accuracy without being trained on any of ImageNet’s 1.28 million training images.
  - Like other CNNs, ResNet scores well on the data it was trained on, but its performance drops quickly on a slightly different set of images. Interesting examples of this are shown in Figure 2: ResNet struggles to identify the same objects in new datasets, while CLIP’s zero-shot classification maintains much better performance.
Figure 2. ResNet vs CLIP when generalising to new data6
- Semantic encoding
  - Because images are matched with natural language during training, the encodings produced by CLIP capture not only the visual features of an image but also its semantic features. Consider the three images in Figure 3 (note 7):
Figure 3. Example of the semantic nature of CLIP encoding
The images of coffee beans, a cup of coffee and Deloitte’s Chief AI Officer Sulabh Soral are visually quite distinct. However, a human understands that the beans and the cup are connected, through a semantic understanding that goes beyond visual features. CLIP also captures this semantic level, as reflected in the similarity8 between the CLIP encodings of these images: the cup and Sulabh score 0.46, while the beans and the cup score 0.74. In comparison, ResNet captures only visual features in its encodings: the beans and the cup score only 0.17, while the cup and Sulabh score higher, at 0.20.
- Natural language processing
  - Being able to encode text means CLIP also has natural language processing (NLP) capabilities. These can be used to further strengthen CLIP’s zero-shot classification, as the following use case shows.
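The similarity scores quoted for Figure 3 compare encoding vectors. Assuming the measure is cosine similarity (the same measure used later in this article for text), the calculation can be sketched in plain Python. The three short vectors below are made-up stand-ins, not real CLIP or ResNet encodings, which have hundreds of dimensions.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1 means same direction, 0 means orthogonal."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical toy encodings, chosen only to illustrate the calculation
beans = [0.9, 0.4, 0.1]
cup = [0.8, 0.5, 0.3]
portrait = [0.1, 0.3, 0.9]

print(f"beans vs cup:    {cosine_similarity(beans, cup):.2f}")
print(f"cup vs portrait: {cosine_similarity(cup, portrait):.2f}")
```

Because the score depends only on the direction of the vectors, not their length, it compares what two encodings represent rather than how large their values are.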
Zero-shot classification in action
To demonstrate the surprisingly good accuracy of CLIP’s zero-shot classification, we will classify the images in a Kaggle dataset, ‘Doom or Animal Crossing?’9. The dataset contains memes about two different games, Doom Eternal and Animal Crossing: New Horizons, scraped from subreddits dedicated to each game. It poses an interesting challenge in terms of both size and content. The dataset contains only 1597 images, roughly 800 per class. Setting aside the usual 10 to 15 per cent for test and validation sets leaves around 700 training images per class, far fewer than the tens of thousands usually required for a good classifier. Content matters too: memes are a harder CV challenge than a standard object classification problem, because they often combine several pop culture references, which can make images in the same class visually very diverse. Figure 4 shows an example: two memes from the Animal Crossing class.
Figure 4. Two diverse images in the ‘Animal Crossing’ class
So how would we go about classifying these images?
Training a CNN and classifier from scratch
One option is to build a classifier from scratch. This involves defining a CNN architecture and training it on a subset of the data (I used 70 per cent for training and 30 per cent for testing)10. As mentioned, this is a very small dataset for training a neural network, so random image rotation and flipping were added. This simulates having additional data points, as a slightly rotated or flipped image appears as a completely new input to a CNN model.
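The augmentation and the 70/30 split described above can be sketched without any deep learning library, treating an image as a list of pixel rows. This is a simplified illustration: the article does not show its exact transforms, and rotating by multiples of 90 degrees is a stand-in for the smaller random angles that image libraries typically use. The helper names are hypothetical.

```python
import random


def hflip(image):
    """Mirror an image (a list of pixel rows) from left to right."""
    return [row[::-1] for row in image]


def rotate90(image):
    """Rotate an image 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


def augment(image, rng):
    """Randomly flip and rotate an image to produce a 'new' training example."""
    if rng.random() < 0.5:
        image = hflip(image)
    for _ in range(rng.randrange(4)):  # rotate by 0, 90, 180 or 270 degrees
        image = rotate90(image)
    return image


def train_test_split(items, train_fraction=0.7, seed=0):
    """Shuffle a copy of the items and split it into training and test sets."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


train, test = train_test_split(range(1597))  # 1597 images, as in the Kaggle dataset
print(len(train), len(test))
print(augment([[1, 2], [3, 4]], random.Random(42)))
```

Augmentation is applied only to the training images, usually afresh each epoch, so the model rarely sees exactly the same input twice; the test images are left untouched.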
Figure 5 shows the code used to set up the CNN, and the result after six epochs of training: 73 per cent accuracy. That is weak for a binary classification model, since on two roughly balanced classes random guessing would already score about 50 per cent.
Best accuracy: 73%
Training a classifier using ResNet
A popular modern approach is to start from a pretrained CNN model, to benefit from the large amount of training already completed. A new classifier is then trained that takes the pretrained model’s encodings as input and maps them to the new target labels (in this case, whether an image is a Doom or an Animal Crossing meme)11.
Figure 6 shows the code used to create this new model and its performance. The same augmentation (flipping and rotating) was applied, and the images were passed through a pretrained ResNet50 model before a new classifier. Replacing the from-scratch CNN with an already well-trained one yields a clear performance benefit: a more respectable 88 per cent accuracy on the test set after 20 epochs.
Figure 6. Using a pretrained ResNet50 CNN
Best accuracy: 88%
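The idea behind Figure 6, keeping the pretrained encoder fixed and training only a small classifier on its encodings, can be sketched as a logistic-regression head trained by gradient descent. This is a hedged illustration: the article does not specify its classifier, and the two-dimensional "encodings" below are made-up stand-ins for the much longer vectors ResNet50 would produce.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_head(features, labels, lr=0.5, epochs=200):
    """Fit weights and a bias on frozen encodings; labels are 1 ('Doom') or 0 ('Animal Crossing')."""
    dim = len(features[0])
    w, b = [0.0] * dim, 0.0
    n = len(features)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            # Gradient of binary cross-entropy with respect to the pre-activation is (p - y)
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            grad_w = [g + err * xi for g, xi in zip(grad_w, x)]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def predict(w, b, x):
    """Return 1 ('Doom') if the head's probability exceeds 0.5, else 0."""
    return 1 if sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) > 0.5 else 0


# Hypothetical frozen encodings for four training images
features = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
labels = [1, 1, 0, 0]
w, b = train_head(features, labels)
print([predict(w, b, x) for x in features])
```

Only `w` and `b` are learned; the encoder that produced `features` is untouched, which is why this approach needs far less data than training a CNN from scratch.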
Zero-shot classification with CLIP
Let us now see how CLIP does with zero-shot classification12. As described above, given an image and a set of sentences, CLIP predicts which sentence most likely describes that image, without any additional training on the input data. In this case, we will use the labels ‘Doom’ and ‘Animal Crossing’, as shown in Figure 7:
Figure 7. Example zero-shot classification of a ‘Doom’ image
CLIP correctly classifies the image in Figure 7 with very high confidence. We obtain a prediction in this way for every image in the dataset, asking CLIP whether the ‘Doom’ label or the ‘Animal Crossing’ label is more likely. The output is held in a Pandas DataFrame, the first five rows of which are shown in Figure 8, together with the overall accuracy of CLIP’s predictions: a strong 93.7 per cent. Given the difficulty of the task, and that there was no additional training or fine-tuning on these images, this is an impressive result.
Figure 8. The zero-shot performance of CLIP
Best accuracy: 93.7%
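Under the hood, zero-shot classification compares one image encoding with the encoding of each candidate label and turns the scaled similarities into probabilities with a softmax. A minimal sketch, assuming the encodings have already been computed by CLIP’s image and text encoders; the vectors below are hypothetical, and the scale of 100 is the cap CLIP places on its learned logit scale. A plain dictionary replaces the article’s Pandas DataFrame to keep the sketch dependency-free.

```python
import math


def normalise(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def zero_shot_predict(image_emb, label_embs, labels, logit_scale=100.0):
    """Return the best-matching label and its softmax probability."""
    image = normalise(image_emb)
    # Scaled cosine similarity between the image and each label's text encoding
    logits = [logit_scale * sum(a * b for a, b in zip(image, normalise(t))) for t in label_embs]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(labels)), key=lambda i: probs[i])
    return labels[best], probs[best]


def accuracy(predicted, actual):
    """Fraction of predictions that match the true labels."""
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


labels = ["Doom", "Animal Crossing"]
label_embs = [[0.9, 0.1, 0.2], [0.1, 0.8, 0.4]]  # hypothetical text encodings
images = {  # hypothetical image encodings with their true labels
    "meme_1.jpg": ([0.8, 0.2, 0.1], "Doom"),
    "meme_2.jpg": ([0.2, 0.7, 0.5], "Animal Crossing"),
}
predictions = {name: zero_shot_predict(emb, label_embs, labels) for name, (emb, _) in images.items()}
print(predictions)
print(accuracy([p for p, _ in predictions.values()], [y for _, y in images.values()]))
```

The large scale factor is why CLIP’s confidences are often close to 0 or 1: small differences in cosine similarity become large differences after the softmax.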
Boosting performance with natural language processing
As outlined above, an interesting feature of CLIP is its ability to process both images and text. Memes make image classification particularly difficult because they contain text that may be more informative than the image itself. An example is Figure 9, a data point CLIP misclassified:
Figure 9. A meme where text is more informative than the image
It would be difficult for a model or a human to classify the image in Figure 9 correctly without reading the text; once the text is read, the references to a ‘Marauder’ and ‘fighting’ suggest ‘Doom’ rather than ‘Animal Crossing’.
We can use CLIP’s NLP capabilities to process this text and improve the classification of tricky cases with these steps:
- Step-1: extract the text from the images using an Optical Character Recognition model
- Step-2: encode the text extracted from each image, and encode the target labels with CLIP (‘Doom’ and ‘Animal Crossing’ in this case)
- Step-3: calculate how similar each image’s extracted text is to each of the target labels, for example using cosine similarity. This gives a score of at most 1 (cosine similarity ranges from -1 to 1), with higher values indicating greater similarity
- Step-4: create a predicted label based on which target label (‘Doom’ or ‘Animal Crossing’) the extracted text is most similar to.
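Steps 2 to 4 can be sketched as follows, assuming Step 1 (OCR) and CLIP’s text encoder have already turned each image’s text and each label into embedding vectors. The embeddings and the function name are hypothetical; an image with no extracted text simply gets no NLP label.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two vectors, in the range -1 to 1."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def nlp_label(text_emb, label_embs, labels):
    """Return (label, similarity) for the label closest to the extracted text,
    or (None, None) if the image had no text."""
    if text_emb is None:
        return None, None
    sims = [cosine_similarity(text_emb, t) for t in label_embs]
    best = max(range(len(labels)), key=lambda i: sims[i])
    return labels[best], sims[best]


labels = ["Doom", "Animal Crossing"]
label_embs = [[0.9, 0.1, 0.2], [0.1, 0.8, 0.4]]  # hypothetical label encodings
ocr_text_emb = [0.7, 0.3, 0.1]  # hypothetical encoding of a meme's OCR text
print(nlp_label(ocr_text_emb, label_embs, labels))
print(nlp_label(None, label_embs, labels))
```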
We now have two predicted labels for each image that contains text: a zero-shot image classification label and an NLP-based cosine similarity label. There are different strategies for deciding which label to follow when the two disagree. One simple rule: where the zero-shot confidence is greater than the NLP similarity, take the zero-shot label; where the NLP similarity is greater than the zero-shot confidence, take the NLP label. Figure 10 shows the applied logic and final accuracy:
Figure 10. Combining zero-shot classification and NLP based classification
Best accuracy: 94.1%
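The decision rule described above can be written as a small function. This is a sketch of the rule as stated, with hypothetical argument names; the article does not say how ties are broken, so here they go to the zero-shot label. A softmax confidence and a cosine similarity are on different scales, so the comparison is a heuristic rather than a calibrated one.

```python
def combine(zs_label, zs_confidence, nlp_label=None, nlp_similarity=None):
    """Pick a final label from the zero-shot and NLP predictions."""
    # No extracted text, or both methods agree: keep the zero-shot label
    if nlp_label is None or nlp_label == zs_label:
        return zs_label
    # Disagreement: follow whichever score is higher (ties go to zero-shot)
    return zs_label if zs_confidence >= nlp_similarity else nlp_label


print(combine("Animal Crossing", 0.55, "Doom", 0.62))  # hypothetical scores
print(combine("Doom", 0.97, "Animal Crossing", 0.40))
print(combine("Doom", 0.80))  # image without text
```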
We can see this tweak tipped the accuracy over 94 per cent. This is a good classification score, the highest seen on this dataset14, and impressive given that it was achieved on a small and unusual set of images with no additional training. Of course, more work could be done to improve the CNN models; they would likely benefit from more regularisation and hyper-parameter tuning. While that might bring their accuracy closer to CLIP’s, finding the right balance of those parameters can be tricky and time-consuming, whereas CLIP offers a very strong classification score without any upfront investment in additional training.