Metric Learning Using Siamese and Triplet Convolutional Neural Networks
Understanding the core concepts, the main methods, and online mining of training data in metric learning
Imagine you have a database containing face images of 1000 people, with only a few images of each person. Now you want to build a face recognition system based on this dataset. How would you do that?
Train a classification model? No! Each person has only a few face images, which is far from enough to train a classifier.
In fact, this task has been well studied in deep learning and computer vision over the past 15 years under the name metric learning.
Metric learning, as the name implies, is a technique for mapping images into a metric space in which images of the same class lie close together while images of different classes lie far apart. For example, in the face recognition task above, the model learns to cluster the face images of the same person while keeping the clusters of different people apart. During inference, a test face image is mapped into the learned metric space, and its nearest cluster gives the predicted person.
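The inference step can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the 2D vectors stand in for embeddings produced by some already trained network, each cluster is summarized by its centroid, and the helper names `enroll` and `identify` are hypothetical.

```python
import math

def enroll(gallery, name, embeddings):
    """Store the centroid (mean vector) of one person's embeddings."""
    dim = len(embeddings[0])
    gallery[name] = [
        sum(e[i] for e in embeddings) / len(embeddings) for i in range(dim)
    ]

def identify(gallery, query):
    """Return the enrolled name whose centroid is nearest to the query."""
    return min(gallery, key=lambda name: math.dist(gallery[name], query))

# Hand-made 2D "embeddings" (assumption: a real model would output these).
gallery = {}
enroll(gallery, "alice", [[0.9, 0.1], [1.1, -0.1]])
enroll(gallery, "bob", [[-1.0, 0.2], [-1.0, -0.2]])

print(identify(gallery, [0.8, 0.3]))  # prints "alice"
```

Enrolling another person only adds one more entry to `gallery`; the embedding model itself is not touched.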
Difference from Image Classification
As discussed above, an image classification model cannot be applied to this face recognition task because each person (class) has only a few images, which is far from enough to train a classifier. Furthermore, if a new person is added to the database, the classification model has to be retrained from scratch, which is expensive in time and GPU resources. Metric learning has neither problem: it needs only a few images per class, and a new person can be handled by mapping their images into the already learned metric space.
One-shot (Few-shot) Learning through Metric Learning
One-shot learning is a technique for learning an object category from one or only a few samples of that category. Metric learning is therefore a natural way to implement one-shot learning, since it works with only a few samples per category.
Basic Concept
Since metric learning aims to learn a latent space in which samples of different categories can be distinguished, the representation of samples inside that latent space is very important. In the techniques described below, each sample is represented as a 1D latent vector, and the Euclidean distance is used to measure the distance between two latent vectors.
So how do we get a latent vector from a training image? A convolutional neural network (CNN) is the answer. However, unlike tasks such as classification or semantic segmentation, in which one input sample is enough to produce one output, a CNN model for metric learning needs at least two input samples per training step. The model needs to know whether the two input samples belong to the same category, so that it can learn to pull the two latent vectors together or keep them apart by a pre-defined distance. In this spirit, two networks are commonly used for metric learning: the siamese network and the triplet network.
Siamese Network and Contrastive Loss
Fig. 1: Siamese network [Hadsell et al.]
A siamese network, as the name implies, takes a pair of input images and produces a pair of latent vectors. As shown in Fig. 1, two sample images Xa and Xb are fed into the base network one after the other to get latent vectors G(Xa) and G(Xb). There is only one base network, and its weights are shared between the two inputs. The distance D between the two latent vectors is then calculated in the latent embedding space. Finally, D is plugged into the loss function (Fig. 2), and the base network is tuned via backpropagation to produce better latent embeddings.
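The weight sharing described here can be sketched in PyTorch as follows. The tiny `BaseNet` architecture, the embedding size of 16, and the 28x28 single-channel inputs are all assumptions made for illustration, not the network used by Hadsell et al.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BaseNet(nn.Module):
    """Hypothetical small CNN G that maps an image to a latent vector."""

    def __init__(self, embedding_dim=16):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # global average pooling to (N, 8, 1, 1)
        )
        self.fc = nn.Linear(8, embedding_dim)

    def forward(self, x):
        return self.fc(self.features(x).flatten(1))

class SiameseNet(nn.Module):
    """Applies one shared base network to both inputs of a pair."""

    def __init__(self, base):
        super().__init__()
        self.base = base  # a single module, so the weights are shared

    def forward(self, xa, xb):
        ga, gb = self.base(xa), self.base(xb)
        return F.pairwise_distance(ga, gb)  # Euclidean distance D per pair

model = SiameseNet(BaseNet())
xa = torch.randn(4, 1, 28, 28)  # batch of 4 random stand-in images
xb = torch.randn(4, 1, 28, 28)
d = model(xa, xb)
print(d.shape)  # torch.Size([4])
```

Because `SiameseNet` holds a single `BaseNet` instance, a gradient step on the distance updates one set of weights for both branches.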
Fig. 2: Contrastive loss [Hadsell et al.]
To compute a loss from the label indicating whether the two input samples belong to the same category, Hadsell et al. proposed the contrastive loss in 2006. The loss function is shown in Fig. 2, in which Y is a binary label: Y=0 means the two input samples belong to the same category, and Y=1 means they do not. Looking closer, when Y=0 only the first term remains and Dw is minimized; when Y=1 only the second term remains and Dw is pushed up toward m, a user-defined margin hyper-parameter. Conceptually, when the two input samples belong to the same category, they are pulled as close together as possible; otherwise, they are pushed apart until their distance reaches m. If the distance between two latent vectors from different categories is already larger than m, the loss is zero and that pair contributes nothing to learning.
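Written out in the symbols used here, and taking Dw to be the Euclidean distance between G(Xa) and G(Xb), the contrastive loss for one pair has the following form. The factors of 1/2 follow the formulation in Hadsell et al.

```latex
D_W = \lVert G(X_a) - G(X_b) \rVert_2

L(Y, X_a, X_b) = (1 - Y)\,\tfrac{1}{2}\,D_W^{2}
               + Y\,\tfrac{1}{2}\,\bigl[\max(0,\; m - D_W)\bigr]^{2}
```

With Y=0 only the first term survives, and with Y=1 only the second, which vanishes as soon as Dw is at least m.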
A PyTorch implementation of contrastive loss is as follows:
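What follows is a minimal sketch rather than the author's original listing. It assumes the label convention described above (Y=0 for the same category, Y=1 otherwise), a float label tensor, and a default margin of 1.0 chosen only for illustration; the function name `contrastive_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(za, zb, y, margin=1.0):
    """Mean contrastive loss over a batch of pairs.

    za, zb: (N, D) latent vectors of the two inputs of each pair.
    y:      (N,) float labels, 0 = same category, 1 = different category.
    """
    d = F.pairwise_distance(za, zb)  # Euclidean distance per pair
    same_term = (1 - y) * 0.5 * d.pow(2)  # pulls same-category pairs together
    diff_term = y * 0.5 * torch.clamp(margin - d, min=0).pow(2)  # pushes others out to the margin
    return (same_term + diff_term).mean()

za = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
zb = torch.tensor([[3.0, 4.0], [3.0, 4.0]])  # both pairs are about 5 apart
y = torch.tensor([0.0, 1.0])                 # first pair same, second different
print(contrastive_loss(za, zb, y))
```

In this example the second pair is already farther apart than the margin, so it contributes zero, and the result (close to 6.25) comes entirely from the same-category pair.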
Triplet Network and Triplet Loss
Fig. 3: Triplet network [Schroff et al.]
The triplet network is an improvement on the siamese network. As the name implies, it takes three input images, called the anchor sample, the positive sample, and the negative sample. First an anchor sample is picked; then a positive sample is picked from the same category as the anchor, and a negative sample is picked from a different category. The triplet network is superior to the siamese network in that it learns positive and negative distances simultaneously, and the larger number of possible training combinations helps fight overfitting.
Fig. 4: Triplet loss [Schroff et al.]
Triplet loss computes the loss over the three input samples. Conceptually, as shown in Fig. 4, the triplet network learns to decrease the distance between the anchor and the positive while increasing the distance between the anchor and the negative (Fig. 4 left), until the anchor-negative distance exceeds the anchor-positive distance by at least alpha, a user-defined margin hyper-parameter (Fig. 4 right). In the loss function of Fig. 4, the first term is the distance between the anchor and the positive, and the second term is the distance between the anchor and the negative. Training drives the first term down and the second term up. Once the first term minus the second is smaller than minus alpha, the loss becomes zero and that triplet no longer updates the network parameters.
In some real-world applications of triplet loss, producing labeled data is challenging. A training triplet actually has two parts: an anchor-positive pair and an anchor-negative pair. Likewise, the triplet loss has two parts: the loss contributed by the anchor-positive pair and the loss contributed by the anchor-negative pair. Anchor-negative pairs are easy to obtain, while anchor-positive pairs are difficult. As a result, anchor-negative pairs far outnumber anchor-positive pairs, which can lead to overfitting on the anchor-positive pairs. One solution is to control the ratio between the numbers of the two kinds of pairs, for example 1:3. The idea is similar to anomaly detection tasks, in which labeled anomaly data is scarce.
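For reference, the triplet loss of Schroff et al. can be written as below, where f is the embedding network, the superscripts a, p, and n mark the anchor, positive, and negative of the i-th triplet, and N is the number of triplets. Note that the FaceNet formulation uses squared Euclidean distances.

```latex
L = \sum_{i=1}^{N} \max\Bigl(
      \lVert f(x_i^{a}) - f(x_i^{p}) \rVert_2^{2}
    - \lVert f(x_i^{a}) - f(x_i^{n}) \rVert_2^{2}
    + \alpha,\; 0 \Bigr)
```

The max with 0 is what makes the loss vanish once the first term minus the second drops below minus alpha.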
A PyTorch implementation of triplet loss is as follows:
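What follows is a minimal sketch rather than the author's original listing. It assumes squared Euclidean distances as in Schroff et al., averages over the batch instead of summing, and uses alpha=0.2 purely as an illustrative default; the function name `triplet_loss` is hypothetical.

```python
import torch

def triplet_loss(anchor, positive, negative, alpha=0.2):
    """Mean triplet loss over a batch; all inputs are (N, D) latent vectors."""
    d_ap = (anchor - positive).pow(2).sum(dim=1)  # squared anchor-positive distance
    d_an = (anchor - negative).pow(2).sum(dim=1)  # squared anchor-negative distance
    # Zero once the negative is farther than the positive by at least alpha.
    return torch.clamp(d_ap - d_an + alpha, min=0).mean()

anchor = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
positive = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
negative = torch.tensor([[3.0, 0.0], [1.0, 0.0]])  # first negative is far, second is too close
print(triplet_loss(anchor, positive, negative))
```

Only the second triplet violates the margin here, so it alone contributes to the loss. PyTorch also ships a built-in `torch.nn.TripletMarginLoss`, which by default uses the non-squared Euclidean distance.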
Online Triplet Mining
Triplet networks are tricky to train quickly and effectively. A main reason is the randomness in choosing training triplets online. As discussed above, the loss is zero for a triplet the network already handles correctly, that is, one in which the difference between the two distances already reaches alpha. Therefore, if many of the training triplets are already easy, they do not update the network parameters, and much time and computing resources are wasted. To deal with this problem, hard samples that the network does not yet handle correctly are selected for training. For example, a large batch size is used, and all the distances between anchors and positive/negative samples inside the batch are calculated. Training triplets are then selected according to these distances [Schroff et al.].
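One way to select triplets from in-batch distances is sketched below. It is a variant often called batch-hard mining (for each anchor, the farthest positive and the nearest negative in the batch), which is not necessarily the exact selection rule used by Schroff et al. The function name, the use of non-squared distances, and alpha=0.2 are assumptions for illustration.

```python
import torch

def batch_hard_triplet_loss(embeddings, labels, alpha=0.2):
    """Triplet loss using the hardest positive and negative of each anchor.

    embeddings: (N, D) latent vectors of one batch.
    labels:     (N,) integer category labels.
    Returns NaN if no anchor has both a positive and a negative in the batch.
    """
    dist = torch.cdist(embeddings, embeddings)  # (N, N) Euclidean distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos_mask = same & ~eye  # same category, excluding the anchor itself
    neg_mask = ~same        # different category

    # Hardest positive: the farthest sample of the same category.
    hardest_pos = dist.masked_fill(~pos_mask, float("-inf")).max(dim=1).values
    # Hardest negative: the nearest sample of a different category.
    hardest_neg = dist.masked_fill(~neg_mask, float("inf")).min(dim=1).values

    loss = torch.clamp(hardest_pos - hardest_neg + alpha, min=0)
    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    return loss[valid].mean()

labels = torch.tensor([0, 0, 1, 1])
separated = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0], [6.0, 0.0]])
mixed = torch.tensor([[0.0, 0.0], [3.0, 0.0], [1.0, 0.0], [4.0, 0.0]])
print(batch_hard_triplet_loss(separated, labels))  # clusters already apart
print(batch_hard_triplet_loss(mixed, labels))      # categories interleaved
```

The well-separated batch yields zero loss, which is exactly the wasted-computation case described above, while the interleaved batch yields a positive loss that drives learning.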
References
One-shot learning, Wikipedia
Dimensionality Reduction by Learning an Invariant Mapping, Hadsell et al., CVPR 2006
FaceNet: A Unified Embedding for Face Recognition and Clustering, Schroff et al., CVPR 2015