Siamese Networks: Algorithm, Applications And PyTorch Implementation

Since siamese networks are getting increasingly popular in Deep Learning research and applications, I decided to dedicate a blog post to this extremely powerful technique. I will explain what siamese networks are and conclude with a simple example of a siamese CNN in PyTorch.

What Are Siamese Networks?

Siamese networks (Bromley, Jane, et al. “Signature verification using a ‘siamese’ time delay neural network.” Advances in neural information processing systems. 1994.) are neural networks containing two or more identical subnetwork components. A siamese network may look like this:

Example of a siamese network (source: Rao et al.)

It is important that not only the architecture of the subnetworks is identical, but that their weights are shared as well; only then is the network called “siamese”. The main idea behind siamese networks is that they can learn useful data descriptors that can then be used to compare the inputs of the respective subnetworks. The inputs can be anything from numerical data (in this case the subnetworks are usually formed by fully-connected layers) and image data (with CNNs as subnetworks) to sequential data such as sentences or time signals (with RNNs as subnetworks).
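In PyTorch, weight sharing comes for free if the subnetwork is defined once and applied to every input. The following is a minimal sketch of that idea, assuming a small fully-connected encoder for 10-dimensional numerical inputs; the `SiameseWrapper` class and the layer sizes are hypothetical and only illustrate the principle.

```python
import torch
import torch.nn as nn


class SiameseWrapper(nn.Module):
    """Applies ONE encoder to both inputs, so both branches share weights."""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder  # a single module = a single set of weights

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        # Both "subnetworks" are the same module object, so their weights are
        # identical by construction and gradients from both branches
        # accumulate into the same parameters.
        return self.encoder(x1), self.encoder(x2)


# Hypothetical fully-connected encoder for 10-dimensional numerical inputs.
encoder = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 4))
model = SiameseWrapper(encoder)

x = torch.randn(3, 10)
emb_a, emb_b = model(x, x)
# Identical inputs yield identical descriptors because the weights are shared.
print(torch.equal(emb_a, emb_b))  # True
```

Note that the wrapper has no parameters of its own: counting `model.parameters()` gives exactly the encoder's parameters, not twice as many.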

Usually, siamese networks perform binary classification at the output, classifying whether the inputs are of the same class or not. Different loss functions may be used during training. One of the most popular is the binary cross-entropy loss, which can be calculated as

$$L(y, p) = -y \log p - (1 - y) \log(1 - p)$$

where L is the loss function, y the class label (0 or 1) and p the predicted probability that the inputs belong to the same class. In order to train the network to distinguish between similar and dissimilar objects, we may feed it one positive and one negative example at a time and add up the losses:

$$L = L(1, p_+) + L(0, p_-) = -\log p_+ - \log(1 - p_-)$$

where p_+ and p_- are the predictions for the positive and the negative pair, respectively.
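To make the summed loss concrete, here is a small numerical sketch with made-up predictions for one positive and one negative pair. It computes the loss by hand and cross-checks it against PyTorch's `F.binary_cross_entropy` with `reduction="sum"`; the probabilities 0.9 and 0.2 are purely illustrative.

```python
import math

import torch
import torch.nn.functional as F


def bce(y: float, p: float) -> float:
    """Binary cross-entropy L(y, p) = -y log p - (1 - y) log(1 - p)."""
    return -y * math.log(p) - (1 - y) * math.log(1 - p)


# Hypothetical predicted probabilities of "same class" for one positive pair
# and one negative pair.
p_pos, p_neg = 0.9, 0.2

# Positive pair has target y = 1, negative pair has target y = 0.
total = bce(1.0, p_pos) + bce(0.0, p_neg)  # = -log(0.9) - log(0.8)

# Cross-check with PyTorch (reduction="sum" adds the two per-pair losses).
ref = F.binary_cross_entropy(
    torch.tensor([p_pos, p_neg]), torch.tensor([1.0, 0.0]), reduction="sum"
)
print(round(total, 4), round(ref.item(), 4))  # 0.3285 0.3285
```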

Another possibility is to use the triplet loss (Schroff, Florian, Dmitry Kalenichenko, and James Philbin. “FaceNet: A unified embedding for face recognition and clustering.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.):

$$L(a, p, n) = \max\big(d(a, p) - d(a, n) + m,\ 0\big)$$

Here, d is a distance function (e.g. the squared Euclidean distance), a is a sample of the dataset (the anchor), p is a random positive sample (same class as a) and n is a negative sample (different class). m is an arbitrary margin: the loss only becomes zero once the negative is at least m farther from the anchor than the positive, which enforces a clear separation between positive and negative distances.
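The triplet loss translates almost directly into tensor operations. The sketch below uses the squared Euclidean distance for d, as FaceNet does, and an illustrative margin of 0.2; `triplet_loss` is a hypothetical helper (PyTorch also ships `nn.TripletMarginLoss`, which uses the non-squared distance by default).

```python
import torch


def triplet_loss(a, p, n, margin=0.2):
    """max(d(a, p) - d(a, n) + margin, 0), averaged over the batch.

    d is the squared Euclidean distance; margin=0.2 is only an example value.
    """
    d_ap = (a - p).pow(2).sum(dim=1)  # anchor <-> positive distance
    d_an = (a - n).pow(2).sum(dim=1)  # anchor <-> negative distance
    return torch.clamp(d_ap - d_an + margin, min=0.0).mean()


# Toy 2-D embeddings (hypothetical): the positive is close to the anchor and
# the negative is far away, so the margin is already satisfied.
a = torch.tensor([[0.0, 0.0]])
p = torch.tensor([[0.1, 0.0]])
n = torch.tensor([[1.0, 1.0]])
print(triplet_loss(a, p, n).item())  # 0.0

# Swapping positive and negative violates the margin: 2 - 0.01 + 0.2 = 2.19.
print(triplet_loss(a, n, p).item())
```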

Applications Of Siamese Networks

Siamese networks have wide-ranging applications. Here are a few of them:

Example: Classifying MNIST Images Using A Siamese Network In PyTorch

Having explained the fundamentals of siamese networks, we will now build a network in PyTorch to classify whether a pair of MNIST images shows the same digit. We will use the binary cross-entropy loss as our training loss function and evaluate the network on a test dataset using accuracy. Below is the entire code for this post:

As you can see, most of the code consists of building an appropriate Dataset class that provides us with random image samples. For training the network, it is crucial to obtain a balanced dataset with as many positive as negative samples. Therefore, on each iteration, we provide both at the same time. The code for the dataset is quite long but ultimately simple: for each number (class) 0–9, we provide a positive pair (another image of the same number) and a negative pair (an image of a random different number).
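The pair-sampling logic described above can be sketched as a PyTorch `Dataset` that returns one positive and one negative pair for every anchor image. This is not the post's dataset code: the class name `PairDataset`, its constructor arguments and the use of in-memory tensors of images and labels (for example loaded from MNIST) are assumptions.

```python
import random

import torch
from torch.utils.data import Dataset


class PairDataset(Dataset):
    """Returns one positive and one negative pair for every anchor image.

    Hypothetical sketch: `images` is an (N, 1, 28, 28) float tensor and
    `labels` an (N,) tensor of digit classes.
    """

    def __init__(self, images, labels, seed=0):
        self.images = images
        self.labels = [int(y) for y in labels]
        self.rng = random.Random(seed)  # seeded for reproducibility
        # Indices of all images per class, for fast sampling.
        self.by_class = {}
        for idx, label in enumerate(self.labels):
            self.by_class.setdefault(label, []).append(idx)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        label = self.labels[i]
        anchor = self.images[i]
        # Positive: a different image of the same class (falls back to the
        # anchor itself if the class contains only one image).
        same = [j for j in self.by_class[label] if j != i] or [i]
        pos = self.images[self.rng.choice(same)]
        # Negative: an image from a randomly chosen different class.
        other = self.rng.choice([c for c in self.by_class if c != label])
        neg = self.images[self.rng.choice(self.by_class[other])]
        # Targets: 1 = same class, 0 = different class.
        return (anchor, pos, torch.tensor(1)), (anchor, neg, torch.tensor(0))


# Toy data standing in for MNIST: 20 random "images", two per digit 0-9.
images = torch.randn(20, 1, 28, 28)
labels = torch.arange(20) % 10
dataset = PairDataset(images, labels)
(anchor, pos, y_pos), (_, neg, y_neg) = dataset[0]
```

Wrapped in a `DataLoader`, PyTorch's default collate function batches the nested tuples element-wise, so every training step receives equally many positive and negative pairs.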

The network itself, defined in the Net class, is a siamese convolutional neural network consisting of two identical subnetworks, each containing three convolutional layers with kernel sizes of 7, 5 and 5 and a pooling layer in between. After the convolutional layers, the network builds a one-dimensional descriptor of each input by flattening the features and passing them through a linear layer with 512 output features. Note that the layers in the two subnetworks share the same weights. This allows the network to learn meaningful descriptors for each input and makes the output symmetrical (the order of the two inputs should be irrelevant to our goal).
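A single subnetwork along these lines might look as follows. This is a sketch, not the post's Net class: the channel counts (64, 128, 256), the ReLU activations and placing a single max-pooling layer after the first convolution are assumptions; only the kernel sizes 7, 5, 5 and the 512-dimensional output come from the description above. With 28x28 MNIST inputs and no padding, these choices leave a 256x3x3 feature map, i.e. 2304 features before the linear layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """One branch of the siamese CNN: conv7 -> pool -> conv5 -> conv5 -> linear(512).

    Channel counts and pooling placement are assumptions for illustration.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7)     # 28x28 -> 22x22
        self.pool = nn.MaxPool2d(2)                       # 22x22 -> 11x11
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5)   # 11x11 -> 7x7
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5)  # 7x7 -> 3x3
        self.fc = nn.Linear(256 * 3 * 3, 512)            # 2304 -> 512-d descriptor

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, start_dim=1)  # (batch, 2304)
        return F.relu(self.fc(x))


# The same Encoder instance is applied to both images, so weights are shared.
encoder = Encoder()
x1, x2 = torch.randn(4, 1, 28, 28), torch.randn(4, 1, 28, 28)
e1, e2 = encoder(x1), encoder(x2)
print(e1.shape)  # torch.Size([4, 512])
```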

The crucial step of the whole procedure is the next one: we calculate the squared differences between the two feature vectors. In principle, we could train the network with the triplet loss applied to these squared differences. However, I obtained better results (faster convergence) using the binary cross-entropy loss. Therefore, we attach one more linear layer with 2 output features (equal number, different number) to the network to obtain the logits.
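The comparison step can be sketched as a small head module that takes the two 512-dimensional descriptors, computes their element-wise squared difference and maps it to two logits. The class name `SiameseHead` and the convention that logit index 1 means "same number" (matching the target of 1 for positive pairs) are assumptions for illustration.

```python
import torch
import torch.nn as nn


class SiameseHead(nn.Module):
    """Maps two descriptors to 2 logits (index 0 = different, 1 = same)."""

    def __init__(self, dim: int = 512):
        super().__init__()
        self.fc = nn.Linear(dim, 2)

    def forward(self, e1, e2):
        # Element-wise squared difference: symmetric in (e1, e2).
        diff = (e1 - e2).pow(2)
        return self.fc(diff)  # raw logits; the loss applies the softmax


head = SiameseHead()
e1, e2 = torch.randn(4, 512), torch.randn(4, 512)
logits = head(e1, e2)
print(logits.shape)  # torch.Size([4, 2])
# Swapping the inputs leaves the logits unchanged.
print(torch.allclose(logits, head(e2, e1)))  # True
```

Because the squared difference is symmetric, swapping the two images cannot change the prediction, which is exactly the property the shared weights are meant to provide.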

There are three main functions in the code: the train function, the test function and the predict function.

In the train function, we feed the network a positive and a negative sample (two pairs of images). We calculate the losses for each of these and add them up (with the positive sample having a target of 1 and the negative sample having a target of 0).
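A training step along these lines could look like the sketch below. Since the network outputs two logits, it uses `F.cross_entropy` with class-index targets; for two classes this equals the binary cross-entropy on the softmax probability of the "same" class. `TinySiamese` is a deliberately small, hypothetical stand-in for the post's Net so the example runs quickly, and the SGD learning rate of 0.1 is arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySiamese(nn.Module):
    """Hypothetical stand-in for the post's Net: shared encoder + 2-logit head."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32), nn.ReLU())
        self.head = nn.Linear(32, 2)

    def forward(self, x1, x2):
        return self.head((self.encoder(x1) - self.encoder(x2)).pow(2))


def train_step(model, optimizer, pos_pair, neg_pair):
    """One update on a positive pair batch (target 1) and a negative one (target 0)."""
    model.train()
    optimizer.zero_grad()
    (a, p), (a2, n) = pos_pair, neg_pair
    pos_logits = model(a, p)
    neg_logits = model(a2, n)
    pos_target = torch.ones(a.size(0), dtype=torch.long)
    neg_target = torch.zeros(a2.size(0), dtype=torch.long)
    # Sum of the two losses, as in the text.
    loss = F.cross_entropy(pos_logits, pos_target) + F.cross_entropy(neg_logits, neg_target)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = TinySiamese()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Hypothetical fixed batch: positives are identical images, negatives unrelated.
anchors = torch.randn(8, 1, 28, 28)
pos_pair = (anchors, anchors.clone())
neg_pair = (anchors, torch.randn(8, 1, 28, 28))
losses = [train_step(model, optimizer, pos_pair, neg_pair) for _ in range(50)]
print(losses[0], losses[-1])  # the loss should drop on this fixed batch
```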

The test function measures the accuracy of the network on the test dataset. We run it after each training epoch to monitor the training progress and to detect overfitting.
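An accuracy measurement in the spirit of the test function might look like this. It is a sketch assuming the model takes two image batches and returns two logits per pair, with index 1 meaning "same"; `evaluate_accuracy` and the rule-based `ThresholdModel` used to exercise it are hypothetical.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def evaluate_accuracy(model, batches):
    """Share of pairs whose argmax logit matches the target (1 = same, 0 = different)."""
    model.eval()  # disables dropout / batch-norm updates, if the model has any
    correct, total = 0, 0
    for x1, x2, target in batches:
        pred = model(x1, x2).argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.numel()
    return correct / total


class ThresholdModel(nn.Module):
    """Hypothetical stand-in: 'same' iff the mean squared pixel difference < 0.5."""

    def forward(self, x1, x2):
        d = (x1 - x2).pow(2).flatten(1).mean(dim=1)
        # Column 0 = "different", column 1 = "same".
        return torch.stack([d - 0.5, 0.5 - d], dim=1)


x1 = torch.zeros(4, 1, 2, 2)
x2 = torch.cat([torch.zeros(2, 1, 2, 2), torch.ones(2, 1, 2, 2)])
target = torch.tensor([1, 0, 0, 1])
acc = evaluate_accuracy(ThresholdModel(), [(x1, x2, target)])
print(acc)  # 0.5
```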

The predict function, given a pair of MNIST images, simply predicts whether they are of the same class or not. You can use predict after training is finished by setting the global variable do_learn to False.
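A prediction helper for a single pair could be as simple as the following sketch. `predict_same` and the cosine-similarity stand-in model are hypothetical; with the trained siamese network you would pass that model instead, again assuming that logit index 1 means "same".

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def predict_same(model, img1, img2):
    """True if `model` predicts that two single (1, 28, 28) images show the same digit."""
    model.eval()
    # Add a batch dimension of size 1 to each image.
    logits = model(img1.unsqueeze(0), img2.unsqueeze(0))
    return bool(logits.argmax(dim=1).item() == 1)  # index 1 = "same"


class CosineStandIn(nn.Module):
    """Hypothetical stand-in model: 'same' iff cosine similarity exceeds 0.9."""

    def forward(self, x1, x2):
        s = F.cosine_similarity(x1.flatten(1), x2.flatten(1), dim=1)
        return torch.stack([0.9 - s, s - 0.9], dim=1)


model = CosineStandIn()
img = torch.rand(1, 28, 28)
print(predict_same(model, img, img))   # True
print(predict_same(model, img, -img))  # False
```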

Using the implementation above, I was able to achieve 96% accuracy on the MNIST test set.