Building a One-shot Learning Network with PyTorch

Deep learning has become popular for image recognition and classification tasks in recent years thanks to its strong performance. However, traditional deep learning approaches usually require a large dataset to train a model to distinguish even a small number of classes, which is drastically different from how humans can learn from only a handful of examples.

Few-shot or one-shot learning is a classification problem that aims to categorize objects given only a limited number of samples, with the ultimate goal of creating a more human-like learning algorithm. In this article, we will dive into a deep learning approach to the one-shot learning problem using a special network structure: the Siamese Network. We will build the network with PyTorch, test it on the Omniglot handwritten character dataset, and run several experiments comparing different network structures and hyperparameters using a one-shot learning evaluation metric.

Omniglot Dataset

The Omniglot handwritten character dataset is a dataset designed for one-shot learning, proposed by Lake et al. It contains 1,623 different handwritten characters from 50 different alphabets, where each character was drawn by 20 different people. Each image is 105×105 pixels. The 50 alphabets are split 30:20 between training and testing, which means the test set consists of characters from alphabets the model has never seen before.

Computing Environment

The training and experiments were done entirely on Google Colab, with a range of GPUs including the Tesla K80 and P100. We used libraries including NumPy, Matplotlib, and PyTorch.

Method

Traditional deep networks usually don't work well with one-shot or few-shot learning, since having very few samples per class is very likely to cause overfitting. To prevent overfitting and generalize to unseen characters, we propose using a Siamese Network.

Figure 1. Convolutional Siamese Network Architecture

Figure 1 shows the backbone architecture of the Convolutional Siamese Network. Unlike a traditional CNN, which takes a single image as input and produces a one-hot vector indicating its category, the Siamese network takes in 2 images and feeds them into 2 CNNs with the same structure. The outputs are merged, in this case through their absolute difference, and fed into fully connected layers to produce one number representing the similarity of the two images. A larger number implies that the two images are more similar.

Instead of learning which image belongs to which class, the Siamese network learns how to determine the “similarity” between two images. After training, given a completely new image, the network could then compare the image to an image from each of the categories and determine which category is the most similar to the given image.

Dataset Preprocessing and Generation

Train and Validation Data Loader

To train the Siamese Network, we have to first generate the proper input (in pairs) and define the ground truth label for the model.

We first define two images from the same character in the same alphabet to have a similarity of 1, and 0 otherwise. We then select a pair of images to feed into the network based on the parity of the index in the dataloader iteration: if the current index is odd, we retrieve a pair of images from the same character, and otherwise a pair from two different characters. This ensures that our training dataset is balanced between the two types of outputs. Both images go through the same image transformation; since the goal is to determine the similarity of the two images, feeding them through different transformations wouldn't make sense.

The following is the code for generating the training set:
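Since the original code embed isn't reproduced here, below is a minimal sketch of such a pair dataset. It assumes the images have already been grouped per character (e.g. a list where each entry holds the paths of the 20 drawings of one character); the class and variable names are ours.

```python
import random
import torch
from PIL import Image
from torch.utils.data import Dataset

class OmniglotPairDataset(Dataset):
    """Yields (img1, img2, label) pairs: label 1.0 if same character, else 0.0."""

    def __init__(self, character_images, num_pairs=10000, transform=None):
        # character_images: list of lists, each inner list holding the
        # image paths of the 20 drawings of one character.
        self.character_images = character_images
        self.num_pairs = num_pairs
        self.transform = transform

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, index):
        if index % 2 == 1:
            # Odd index: two drawings of the same character -> label 1.
            char = random.choice(self.character_images)
            path1, path2 = random.sample(char, 2)
            label = 1.0
        else:
            # Even index: drawings of two different characters -> label 0.
            char1, char2 = random.sample(self.character_images, 2)
            path1, path2 = random.choice(char1), random.choice(char2)
            label = 0.0

        img1 = Image.open(path1).convert("L")
        img2 = Image.open(path2).convert("L")
        if self.transform:
            # Both images go through the same transformation.
            img1, img2 = self.transform(img1), self.transform(img2)
        return img1, img2, torch.tensor(label)
```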

We created 10,000 of these pairs as our training set, which is then further split randomly into training and validation sets with an 80:20 ratio.
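The random 80:20 split and the loaders can be done with PyTorch's built-in utilities; the variable names (e.g. `train_characters` for the per-character image lists of the 30 training alphabets) are illustrative:

```python
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

transform = transforms.ToTensor()
pairs = OmniglotPairDataset(train_characters, num_pairs=10000, transform=transform)
train_set, val_set = random_split(pairs, [8000, 2000])   # 80:20 split

train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = DataLoader(val_set, batch_size=128)
```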

Test Loader

A network's one-shot learning performance can be evaluated with an n-way one-shot learning metric: we pick n images representing n categories, plus one main image that belongs to one of those n categories. For our Siamese Network, we compute the similarity of the main image against all n images, and the pair with the highest similarity score determines which class the main image is predicted to belong to.

The test loader was structured to support this evaluation: a random main image is drawn along with n images representing n categories, one of which comes from the same category as the main image.

The following is the code for generating the test set:
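Again, the embedded gist isn't shown here, so the following is a sketch of how such an n-way test dataset could be built on the same per-character structure (names are ours; the transform is assumed to end with `ToTensor`):

```python
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class NWayOneShotDataset(Dataset):
    """Each item: a main image, n candidate images (one per category, exactly
    one of which matches the main image's character), and the matching index."""

    def __init__(self, character_images, n_way=4, num_trials=1000, transform=None):
        self.character_images = character_images
        self.n_way = n_way
        self.num_trials = num_trials
        self.transform = transform

    def __len__(self):
        return self.num_trials

    def __getitem__(self, index):
        # Pick n distinct characters; the first one is the "true" class.
        chars = random.sample(self.character_images, self.n_way)
        main_path, positive_path = random.sample(chars[0], 2)
        candidate_paths = [positive_path] + [random.choice(c) for c in chars[1:]]

        # Shuffle the candidates and remember where the positive one landed.
        order = list(range(self.n_way))
        random.shuffle(order)
        target = order.index(0)
        candidate_paths = [candidate_paths[i] for i in order]

        tf = self.transform or transforms.ToTensor()
        main_img = tf(Image.open(main_path).convert("L"))
        candidates = torch.stack(
            [tf(Image.open(p).convert("L")) for p in candidate_paths])
        return main_img, candidates, target
```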

For our final testing, we extended the evaluation to 4-way one-shot learning with a test set size of 1,000, and 20-way with a size of 200.

Experiment

Experiment 1. Traditional Siamese Network for one-shot learning

Figure 2. Siamese Network Architecture by Koch et al.

The major part of the Siamese Network is the double convolutional architecture shown previously. The first convolutional architecture we built is from Koch et al.'s paper "Siamese Neural Networks for One-shot Image Recognition", as portrayed in Figure 2. One thing to note is that after flattening, the absolute difference between the two convolutional branches' outputs is fed into the fully-connected layer rather than a single image's features.

The network in PyTorch is built as the following:
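Since the embedded code isn't reproduced here, the following is a minimal PyTorch reconstruction of that architecture. The layer sizes follow Koch et al., and the final linear layer outputs a raw logit so it can be paired with the BCE-with-logits loss used later:

```python
import torch
import torch.nn as nn

class SiameseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutional branch; weights are shared because the same module
        # processes both input images.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=10), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=7), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, kernel_size=4), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=4), nn.ReLU(inplace=True),
        )
        # For a 105x105 input the feature map here is 256 x 6 x 6.
        self.fc = nn.Sequential(nn.Linear(256 * 6 * 6, 4096), nn.Sigmoid())
        # Maps the absolute difference of the two embeddings to one
        # similarity logit (the sigmoid is applied inside the loss).
        self.out = nn.Linear(4096, 1)

    def embed(self, x):
        x = self.conv(x)
        return self.fc(x.view(x.size(0), -1))

    def forward(self, x1, x2):
        emb1, emb2 = self.embed(x1), self.embed(x2)
        return self.out(torch.abs(emb1 - emb2))
```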

and we can perform training with the following function:
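A sketch of such a training loop, written against the loaders defined earlier; saving the checkpoint with the lowest validation loss (discussed in the Results section) is included here:

```python
def train(model, criterion, optimizer, train_loader, val_loader,
          epochs, device="cuda"):
    model.to(device)
    best_val_loss = float("inf")
    for epoch in range(epochs):
        # --- training pass ---
        model.train()
        train_loss = 0.0
        for img1, img2, label in train_loader:
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            optimizer.zero_grad()
            loss = criterion(model(img1, img2).squeeze(1), label)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # --- validation pass ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for img1, img2, label in val_loader:
                img1, img2, label = img1.to(device), img2.to(device), label.to(device)
                val_loss += criterion(model(img1, img2).squeeze(1), label).item()

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        print(f"epoch {epoch + 1}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")

        # Keep only the checkpoint with the lowest validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_siamese.pt")
```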

Hyperparameters Setting

Batch Size: Since we are learning how similar two images are, the batch size needs to be fairly large for the model to generalise, especially on a dataset like this with many different categories. We therefore used a batch size of 128.

Learning Rate: We tested several learning rates between 0.0005 and 0.001, and selected 0.0006, which gave the best rate of loss decrease.

Optimizer and Loss: We adopted the standard Adam optimizer for this network with the binary cross-entropy (BCE) with logits loss.
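Putting the pieces together, the setup would look roughly like this (using the sketches above):

```python
model = SiameseNetwork()
criterion = nn.BCEWithLogitsLoss()                          # BCE with logits
optimizer = torch.optim.Adam(model.parameters(), lr=6e-4)   # lr = 0.0006

train(model, criterion, optimizer, train_loader, val_loader, epochs=30)
```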

Results

Figure 3. Original Network’s Training and Validation Loss

The network was trained for 30 epochs. Figure 3 plots the training and validation loss after every epoch; as we can see, the loss decreases dramatically and converges towards the end. The validation loss decreases roughly in step with the training loss, indicating that no overfitting occurred during training. During training, the model with the lowest validation loss was saved. We used the validation loss rather than the training loss because it indicates that the model performs well beyond the training set, whereas a low training loss alone could simply be a sign of overfitting.

Experiment 2. Adding Batch Normalisation

Figure 4. Model Architecture with BatchNorm

To further improve the network, we can add batch normalisation, which should make convergence faster and more stable. Figure 4 shows the updated architecture with a BatchNorm2d after every convolutional layer.
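Concretely, the only change to the sketch from Experiment 1 is the convolutional branch, with each convolution followed by a `BatchNorm2d` (this is our reconstruction of the change):

```python
# Replaces self.conv in the SiameseNetwork defined earlier.
self.conv = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=10), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, kernel_size=7), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Conv2d(128, 128, kernel_size=4), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Conv2d(128, 256, kernel_size=4), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
)
```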

Results

Figure 5. Training results after 10 epochs

As expected, both the training and validation loss decreased much faster than with the original network. Given the better result, we also trained the model for more epochs to see whether it would outperform Experiment 1.

Figure 6. Training results after 50 epochs

As shown in the loss graph, the results were slightly better than the original result from Experiment 1. Since the loss was only converging slowly between epochs 40 and 50, we stopped training at the 50th epoch. This is the best result we achieved.

Experiment 3. Swapping the ConvNet with a lightweight VGG16

After getting the original network to work fairly well, we also tested other well-established CNNs in our Siamese Network to see if we could achieve better results. Given the small 105×105 image size, we wanted a comparatively small network without too many layers that still produces decent results, so we borrowed the VGG16 architecture.

The original VGG16 was still a bit too large for our input size, where the final 5 convolutional layers are essentially operating on single pixels, so we removed them, ending up with the following network:
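Since the embedded code isn't shown here, the sketch below is our reconstruction of a trimmed VGG16-style branch, keeping the first 8 of VGG16's 13 convolutional layers (the exact cut point is an assumption). Its 4096-dimensional output plugs into the same absolute-difference head as before:

```python
import torch.nn as nn

class VGGSiameseBranch(nn.Module):
    """Trimmed VGG16-style feature extractor for one 105x105 grayscale image."""

    def __init__(self):
        super().__init__()

        def block(in_ch, out_ch, n_convs):
            # n_convs 3x3 convolutions followed by a 2x2 max pool.
            layers = []
            for i in range(n_convs):
                layers += [nn.Conv2d(in_ch if i == 0 else out_ch, out_ch,
                                     kernel_size=3, padding=1),
                           nn.ReLU(inplace=True)]
            layers.append(nn.MaxPool2d(2))
            return layers

        self.features = nn.Sequential(
            *block(1, 64, 2),     # 105 -> 52
            *block(64, 128, 2),   # 52 -> 26
            *block(128, 256, 3),  # 26 -> 13
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),      # 13 -> 6
        )
        self.fc = nn.Sequential(nn.Linear(512 * 6 * 6, 4096), nn.Sigmoid())

    def forward(self, x):
        x = self.features(x)
        return self.fc(x.view(x.size(0), -1))
```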

Results

Figure 7. VGG16 Siamese Network Result

As shown in the loss graph, the training loss decreases significantly more slowly than in the prior experiments. This could be because the kernel size of the convolutional layers is fairly small (3×3), which gives a small receptive field. For the problem of computing similarity between two images, it may be beneficial to look at the "bigger picture" of the two images rather than focusing on small details, and hence the larger receptive fields of the original network worked better.

Evaluation on the Model

The code for evaluating a network is implemented as the following:
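A sketch of that evaluation, written against the `NWayOneShotDataset` defined earlier: the candidate with the highest similarity logit is taken as the prediction.

```python
import torch

def evaluate_n_way(model, test_loader, device="cuda"):
    model.to(device).eval()
    correct = 0
    with torch.no_grad():
        for main_img, candidates, target in test_loader:
            # candidates has shape (batch, n_way, 1, H, W).
            main_img, candidates = main_img.to(device), candidates.to(device)
            n_way = candidates.size(1)
            # Compare the main image against each of the n candidates.
            scores = torch.cat(
                [model(main_img, candidates[:, i]) for i in range(n_way)], dim=1)
            correct += (scores.argmax(dim=1).cpu() == target).sum().item()
    return correct / len(test_loader.dataset)
```

For example, passing a loader built from `NWayOneShotDataset(test_characters, n_way=20, num_trials=200, transform=transform)` would give the 20-way one-shot accuracy.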

4-way one shot learning

We first tested 4-way one-shot learning using a completely new set of images for evaluation, where none of the test images were used during training and none of the characters were known to the model. The results showed approximately 90% accuracy, which suggests the model generalized well to unseen images and categories, achieving our goal of one-shot learning on the Omniglot dataset.

20-way one shot learning

Afterwards, we performed a 20-way one-shot learning evaluation over 200 sets, where the accuracy was still around 86%. We compared the results with the baselines provided by Lake et al.

Although we did not replicate the 92% accuracy reported in the paper (possibly due to details such as layer-wise learning rates), we came fairly close to it.

In addition, our model performed considerably better than many of the other baselines, including a plain Siamese network and Nearest Neighbor.

Conclusion

So there you have it! This is how to build a Convolutional Siamese Network for one-shot learning on the Omniglot dataset. The full code is also posted on GitHub in the following repository:

Thank you for making it this far 🙏! I will be posting more on different areas of computer vision/deep learning, so make sure to check out my other article on 3D reconstruction too!