Understanding Siamese Network with example and codes
One-Shot Learning with Siamese Network trained using Contrastive loss
In my previous post, I discussed the basics of One-Shot learning, along with a brief overview of the different types of One-Shot learning for parametric machine learning models. As promised, it's time to discuss the models that can be used for One-Shot learning, starting with the Siamese Network.
Siamese networks are not classification models but comparator models, i.e., they check whether two samples (images) passed to them are similar or not. That said, they can be adapted to tasks like binary classification as well.
So let's first discuss the architecture and the loss functions we can use, and then train our Siamese network on a dummy dataset.
Siamese Network architecture
- It is a combination of 2 shallow (few hidden layers), identical CNNs. The structure of each CNN can be anything you wish.
- The parameters of these CNNs are shared, i.e., the same weights and biases are used by both CNNs. Only one set of weights is trained and used for both.
- It uses the triplet or the contrastive loss function. Never heard of them? We will discuss them shortly.
- The final expected output is binary (0 or 1), where 1 = similar images and 0 = dissimilar images.
Why is parameter sharing done?
One major reason is to reduce the number of parameters so as to converge faster (isn't that the whole idea of One-Shot Learning?). Shared parameters mean fewer parameters to estimate, and quicker convergence with less data. Also, since the parameters are the same for both networks, two similar images (each fed to one network) will produce similar outputs, because both branches apply exactly the same learned function. This is what lets a Siamese network achieve its goal.
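To make weight sharing concrete, here is a minimal NumPy sketch. The tiny encoder (one random weight matrix followed by ReLU) is a hypothetical stand-in for the shallow CNN, not the network trained later in this post. The point is only that both branches use the same weights, so the parameter count does not double, identical inputs map to identical embeddings, and the distance is symmetric.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the shallow CNN: a single dense layer with ReLU.
# One weight matrix and one bias vector are shared by BOTH branches.
W = rng.normal(size=(32 * 32, 16))
b = np.zeros(16)

def encode(image):
    """Map a 32x32 image to a 16-dim embedding using the shared weights."""
    return np.maximum(0.0, image.reshape(-1) @ W + b)

def siamese_distance(img1, img2):
    """Euclidean distance between the embeddings of the two images."""
    return np.linalg.norm(encode(img1) - encode(img2))

n_shared_params = W.size + b.size  # one set of weights, used by both branches
img_a = rng.random((32, 32))
img_b = rng.random((32, 32))

print(n_shared_params)                 # 16400, not 2 * 16400
print(siamese_distance(img_a, img_a))  # 0.0: identical inputs give identical embeddings
print(np.isclose(siamese_distance(img_a, img_b), siamese_distance(img_b, img_a)))  # True
```

With two independent networks, the same comparison would need twice the parameters, and identical images would only produce identical embeddings if the two networks happened to learn the same weights.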
Coming to the loss functions, we have two options:
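- Contrastive loss:
$L = Y \cdot \tfrac{1}{2} D^2 + (1 - Y) \cdot \tfrac{1}{2} \left(\max(0,\, m - D)\right)^2$
Where
Y = ground-truth label of the pair (1 = similar images, 0 = dissimilar images, as in the rest of this post)
$D = \sqrt{\sum_i \left(\text{Siamese}_1(X_1)_i - \text{Siamese}_2(X_2)_i\right)^2}$
m = margin: dissimilar pairs are penalized only while their distance is below m. Without this term, the network could minimize the loss trivially by mapping every image to the same output (for example, by driving all weights to 0).
X_1 = input image 1
X_2 = input image 2
So D is simply the Euclidean distance between the outputs of the Siamese network for the two images.
If you look closely, to minimize the loss the model will try to make D = 0 for similar pairs and push D to at least m for dissimilar pairs. Also, don't be confused by Siamese_1 and Siamese_2: it is the same network, applied to each of the two images.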
- Triplet loss:
At times, for complex objects, contrastive loss may not work well. This is one reason triplet loss is widely used with Siamese networks nowadays.
$L = \max\left(0,\; d(A, P) - d(A, B) + \alpha\right)$
Where
A = anchor (actual) image
P = sample similar to A (positive)
B = sample very different from A (negative)
d() = distance function
α = margin, a positive constant
The basic idea behind triplet loss is that the distance between A and P (the similar image) should be smaller than the distance between A and B (the very different image) by at least α.
Triplet loss is generally considered to perform better than contrastive loss because it uses a positive and a negative example together for each training sample, while contrastive loss sees one pair (either similar or dissimilar) at a time. Hence, triplet loss tends to learn the boundaries between classes better and faster!
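The two losses above can be sketched in a few lines of NumPy. This is a minimal illustration, not a training implementation: the contrastive version follows the labeling used throughout this post (1 = similar images) with margin m = 1, and the triplet version assumes squared Euclidean distance for d() and α = 0.2, both illustrative choices. Hand-made distances and 2-D embeddings keep the arithmetic checkable by hand.

```python
import numpy as np

def contrastive_loss(y, d, margin=1.0):
    """Mean contrastive loss; y = 1 for similar pairs, 0 for dissimilar; d = distance."""
    similar_term = y * 0.5 * d ** 2
    dissimilar_term = (1 - y) * 0.5 * np.maximum(0.0, margin - d) ** 2
    return float(np.mean(similar_term + dissimilar_term))

def triplet_loss(anchor, positive, negative, alpha=0.2):
    """max(0, d(A, P) - d(A, B) + alpha), with d = squared Euclidean distance."""
    d_ap = np.sum((anchor - positive) ** 2)
    d_ab = np.sum((anchor - negative) ** 2)
    return float(max(0.0, d_ap - d_ab + alpha))

# Contrastive loss with margin m = 1
print(contrastive_loss(np.array([1.0]), np.array([0.2])))  # similar pair: 0.5 * 0.2**2, about 0.02
print(contrastive_loss(np.array([0.0]), np.array([0.2])))  # dissimilar, too close: 0.5 * 0.8**2, about 0.32
print(contrastive_loss(np.array([0.0]), np.array([1.5])))  # dissimilar, beyond the margin: 0.0

# Triplet loss with hand-made 2-D embeddings
A = np.array([0.0, 0.0])
P = np.array([0.1, 0.0])       # d(A, P) = 0.01
B_far = np.array([1.0, 0.0])   # d(A, B) = 1.0
B_near = np.array([0.2, 0.0])  # d(A, B) = 0.04
print(triplet_loss(A, P, B_far))   # 0.0: B is already farther than P by more than alpha
print(triplet_loss(A, P, B_near))  # 0.01 - 0.04 + 0.2, about 0.17
```

Note how a dissimilar pair contributes nothing once it is beyond the margin, and a triplet contributes nothing once the negative is far enough away: in both cases training effort goes only to the examples that are still "wrong".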
Now that we are done with the architecture and loss functions, it's time to train a Siamese Network on a dummy dataset with a handful of samples (50 samples). I will be using the contrastive loss function for now, but the code can be adapted for triplet loss as well. The code flow is as follows:
1. Train a shallow CNN on the MNIST dataset for classification, with categorical_crossentropy as the loss function.
2. Pair up the samples of the actual dataset we wish to train the Siamese network on with (a) similar images, i.e. other samples from the dataset itself, and (b) random images, to generate the training samples. Hence, each training sample has 2 images as input and 0/1 as the target, where 1 means similar images and 0 means different images.
3. Use the shallow CNN trained on MNIST as the base of the Siamese network: drop its last dense layer, switch the loss function to contrastive loss, and train on the paired-image dataset prepared above.
All this is digestible, but why are we training the shallow CNN on MNIST for classification?
Transfer Learning is the answer
The idea is to first make the model understand the basic features of images in general by training the CNN on any large enough image dataset, and then, using our scanty dataset, make the model specific to our problem statement. Hence, transfer learning can be seen as a 2-step process:
1. Train the model on a general dataset so that its weights are initialized with decent estimates. Alternatively, you can start from a model pre-trained on a dataset like ImageNet.
2. Train the same model on the dataset for our specific problem, after adding/removing a few layers. This way, we don't need much data, as the weights are already in decent shape (not completely random), and convergence also takes less time.
Enough theory, time for some action.
1. Import libraries
%matplotlib inline
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from randimage import get_random_image, show_array
from sklearn.datasets import load_digits

# all Keras imports come from tensorflow.keras so the layers, optimizers and
# backend functions belong to the same Keras installation
# (the K.* backend functions used below may need changes under Keras 3)
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, Lambda, Dense, Convolution2D, MaxPooling2D, Flatten
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.utils import to_categorical

# run tf.functions eagerly (simplifies debugging, at some cost in speed)
tf.config.run_functions_eagerly(True)
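2. Load the MNIST-like digits toy dataset from sklearn
digits_ = load_digits()
# one-hot encode the 10 digit classes for categorical_crossentropy
target_ = to_categorical(digits_['target'], num_classes=10)
digits = digits_['data'].reshape(-1, 8, 8)
digits_resize = np.zeros((len(digits), 32, 32))
for x, y in enumerate(digits):
    # upscale each 8x8 digit to 32x32 with bicubic interpolation
    # note: load_digits pixel values range from 0 to 16, so (v + 1) / 2 does not
    # map them to [0, 1]; dividing by 16 instead would roughly do so
    digits_resize[x] = (cv2.resize(y, dsize=(32, 32), interpolation=cv2.INTER_CUBIC) + 1) / 2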
The loop at the end resizes the 8×8 digit images to 32×32 using OpenCV, because the images in our actual dataset are 32×32. (Strictly speaking, sklearn's digits dataset is a small MNIST-like dataset rather than MNIST itself.)
3. Build the base CNN
def build_base_network(input_shape=(32, 32, 1)):
    model = Sequential()
    model.add(Convolution2D(16, (8, 8), strides=(1, 1), activation="relu", input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(32, (4, 4), strides=(1, 1), activation="relu"))
    model.add(Flatten())
    model.add(Dense(512, activation="relu"))
    model.add(Dense(256, activation="relu"))    # becomes the Siamese embedding later
    model.add(Dense(10, activation="softmax"))  # 10 digit classes
    return model

model = build_base_network()
# RMSprop is the optimizer used here; optimizers.Adam(learning_rate=0.0001) is an alternative
rms = RMSprop()
# stop once validation loss has not improved for 3 epochs, keeping the best weights
earlyStopping = EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=3,
                              verbose=1,
                              restore_best_weights=True)
callback_early_stop_reduceLROnPlateau = [earlyStopping]
model.compile(loss='categorical_crossentropy', optimizer=rms, metrics=["accuracy"])
model.summary()
batch_size = 32
model.fit(digits_resize.reshape(-1, 32, 32, 1), target_, validation_split=.20,
          batch_size=batch_size, verbose=1, epochs=10, callbacks=callback_early_stop_reduceLROnPlateau)
This section involves training a basic CNN, so I'll skip its explanation. The network is first trained on MNIST and then on the actual dataset. The results below show the final accuracy of the CNN on the validation set.
(Figure: CNN architecture and results for MNIST classification)
Nearly 92%, not bad!
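As a sanity check on the architecture, the layer output sizes and parameter counts can be worked out by hand. The sketch below assumes Keras' default "valid" padding and a pooling stride equal to the pool size, which is what the code above uses implicitly. It also shows that the layer just before the softmax, the one the Siamese network later reuses, produces 256 features.

```python
def conv_out(size, kernel, stride=1):
    """Output size of a 'valid' (unpadded) convolution or pooling window."""
    return (size - kernel) // stride + 1

side = 32
side = conv_out(side, 8)            # Convolution2D(16, (8, 8)) -> 25
side = conv_out(side, 2, stride=2)  # MaxPooling2D((2, 2))      -> 12
side = conv_out(side, 4)            # Convolution2D(32, (4, 4)) -> 9
flat = side * side * 32             # Flatten                   -> 2592

# weights + biases per layer
params = {
    "conv1": 8 * 8 * 1 * 16 + 16,
    "conv2": 4 * 4 * 16 * 32 + 32,
    "dense512": flat * 512 + 512,
    "dense256": 512 * 256 + 256,
    "softmax10": 256 * 10 + 10,
}

print(side, flat)            # 9 2592
print(params)
print(sum(params.values()))  # 1470778
```

Almost all of the roughly 1.47 million parameters sit in the first dense layer, which is worth keeping in mind when the whole network is later fine-tuned on only 50 pairs.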
4. Generate a dummy dataset for training the Siamese network. The class we want the network to distinguish from random images has a filled rectangle at a random place in the image.
one = []
zero = []
img_size = (32, 32)
for x in range(200):
    img = get_random_image(img_size)  # random RGB image, assumed to hold floats in [0, 1]
    # pick random a, b, c, d coordinates for the rectangle
    # (integer bounds: randrange() raises TypeError for floats in Python 3.12+)
    a, b = random.randrange(0, img_size[0] // 4), random.randrange(0, img_size[0] // 4)
    c, d = random.randrange(img_size[0] // 2, img_size[0]), random.randrange(img_size[0] // 2, img_size[0])
    has_rectangle = random.choice([True, False])
    if has_rectangle:
        # draw the rectangle on all 3 channels
        # note: 25 lies outside [0, 1], so (img * 255).astype(np.uint8) below overflows;
        # a value in [0, 1] would be safer
        img[a:c, b:d, 0] = 25
        img[a:c, b:d, 1] = 25
        img[a:c, b:d, 2] = 25
        # convert the RGB image to single-channel grayscale, scaled back to [0, 1]
        img = np.asarray(Image.fromarray((img * 255).astype(np.uint8)).convert('L')) / 255
        one.append(img)
    else:
        img = np.asarray(Image.fromarray((img * 255).astype(np.uint8)).convert('L')) / 255
        zero.append(img)
A few points to note:
- List 'one' holds our actual dataset (images with rectangles), while 'zero' holds random images. We will use these to create image pairs later.
- We convert the RGB images to single-channel images using convert('L').
- (a, b) and (c, d) are the top-left and bottom-right corners (row, column) of the rectangle in the positive-class images.
Let’s see a few samples
(Figure: samples from the actual dataset, with rectangles at random places, alongside random images)
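The generation logic can also be tried without the randimage package. The NumPy-only sketch below is an approximation under stated assumptions: uniform noise stands in for get_random_image, the image is already single-channel, and the hypothetical helper make_rectangle_image fills the rectangle with a value inside [0, 1]. It also makes the coordinate ranges explicit: a and b fall in [0, 8) and c and d in [16, 32), so every rectangle spans at least 9 pixels in each direction.

```python
import random
import numpy as np

def make_rectangle_image(img_size=(32, 32), fill=1.0, rng=None):
    """Hypothetical helper: grayscale noise image in [0, 1] with one filled rectangle."""
    if rng is None:
        rng = np.random.default_rng()
    h, w = img_size
    img = rng.random((h, w))  # stand-in for a grayscale get_random_image(...) output
    # (a, b) = top-left corner and (c, d) = bottom-right corner, as (row, column)
    a, b = random.randrange(0, h // 4), random.randrange(0, w // 4)
    c, d = random.randrange(h // 2, h), random.randrange(w // 2, w)
    img[a:c, b:d] = fill  # fill stays inside [0, 1], so later uint8 scaling cannot overflow
    return img, (a, b, c, d)

random.seed(0)
img, (a, b, c, d) = make_rectangle_image(rng=np.random.default_rng(0))
print(img.shape, (a, b, c, d))
print(c - a >= 9 and d - b >= 9)  # True: every rectangle is at least 9 pixels on each side
```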
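5. Define the contrastive loss function and other utilities
def euclidean_distance(vects):
    # Euclidean distance between the two embeddings: the Siamese network's output
    x, y = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    # clamp at epsilon so the gradient of sqrt stays finite when both embeddings coincide
    return K.sqrt(K.maximum(sum_square, K.epsilon()))

def eucl_dist_output_shape(shapes):
    # one distance per pair: (batch_size, 1)
    shape1, shape2 = shapes
    return (shape1[0], 1)

def compute_accuracy(predictions, labels):
    # NumPy accuracy on predicted distances: distance < 0.5 means "similar" (label 1)
    return np.mean((predictions.ravel() < 0.5) == labels.ravel())

def accuracy(y_true, y_pred):
    # same thresholding rule, as a Keras metric
    return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def contrastive_loss(y_true, y_pred):
    # y_true = 1 for similar pairs, 0 for dissimilar; y_pred = predicted distance
    margin = 1
    y_true = K.cast(y_true, y_pred.dtype)
    return K.mean(y_true * K.square(y_pred) + (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))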
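Because the network outputs a distance rather than a probability, the accuracy metric above turns each distance into a label by thresholding: a distance below 0.5 counts as "similar" (label 1). Here is a NumPy sketch of that rule on made-up distances and labels, using a hypothetical helper pair_accuracy. The 0.5 threshold is a heuristic sitting halfway to the margin of 1; the loss itself does not enforce it.

```python
import numpy as np

def pair_accuracy(distances, labels, threshold=0.5):
    """Fraction of pairs whose thresholded distance matches the label (1 = similar)."""
    predicted_similar = (np.ravel(distances) < threshold).astype(float)
    return np.mean(predicted_similar == np.ravel(labels))

# made-up example: 4 pairs, the third one (distance 0.4, label 0) is misclassified
distances = np.array([[0.1], [0.9], [0.4], [1.3]])
labels = np.array([[1], [0], [0], [0]])
print(pair_accuracy(distances, labels))  # 0.75
```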
6. Generate the training and validation datasets
total_sample_size = 50   # training pairs: deliberately tiny for One/Few-Shot learning
test_sample_size = 200   # validation pairs
dim1, dim2 = 32, 32

def make_pairs(n_pairs):
    # pairs has shape (n_pairs, 2, 1, dim1, dim2): two single-channel images per pair
    pairs = np.zeros([n_pairs, 2, 1, dim1, dim2])
    labels = np.zeros([n_pairs, 1])
    for i in range(n_pairs):
        if random.choice([True, False]):
            # similar pair: two images with rectangles (sampled with replacement)
            pair = random.choices(one, k=2)
            pairs[i, 0, 0] = pair[0]
            pairs[i, 1, 0] = pair[1]
            labels[i] = 1
        else:
            # dissimilar pair: an image with a rectangle and a random image
            pairs[i, 0, 0] = random.choice(one)
            pairs[i, 1, 0] = random.choice(zero)
            labels[i] = 0
    return pairs, labels

x_pair, y = make_pairs(total_sample_size)
x_pair_test, y_test = make_pairs(test_sample_size)
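What we are doing:
- Using a training set of 50 pairs (as we want to demonstrate One/Few-shot learning), randomly pairing images with rectangles with either other rectangle images from the dataset (label 1) or random images (label 0).
- Repeating the same steps for a validation set of 200 pairs. The validation set is deliberately larger than the training set so that we can be more confident in the results. Note that both sets are drawn from the same pool of generated images, so a validation pair may reuse images seen during training.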
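The pair arrays have shape (n_pairs, 2, 1, 32, 32): pair index, position within the pair, channel, height, width. Keras' Conv2D layers expect channels last by default, so before training each side of the pair is sliced out and reshaped to (n_pairs, 32, 32, 1). A quick NumPy check of that layout, with zero arrays as placeholders for the real images:

```python
import numpy as np

n_pairs, dim1, dim2 = 50, 32, 32
x_pair = np.zeros([n_pairs, 2, 1, dim1, dim2])
x_pair[0, 1, 0] = 1.0  # mark the second image of the first pair

left = x_pair[:, 0].reshape(-1, dim1, dim2, 1)   # first image of every pair
right = x_pair[:, 1].reshape(-1, dim1, dim2, 1)  # second image of every pair

print(x_pair[:, 0].shape)  # (50, 1, 32, 32)
print(left.shape)          # (50, 32, 32, 1)
print(right[0].sum(), left[0].sum())  # 1024.0 0.0
```

Because there is only one channel, moving the channel axis from the front to the back with reshape keeps every pixel in place; with 3 channels, np.transpose would be needed instead.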
7. Modify the CNN
# remove the final (softmax) output layer we added for MNIST classification
model2 = Model(inputs=model.input, outputs=model.layers[-2].output)
input_dim = x_pair.shape[3:] + (1,)  # (32, 32, 1)
img_a = Input(shape=input_dim)
img_b = Input(shape=input_dim)
# the same model2 instance processes both inputs, so the weights are shared
feat_vecs_a = model2(img_a)
feat_vecs_b = model2(img_b)
# Siamese output, using the utility functions declared above
distance = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([feat_vecs_a, feat_vecs_b])
So what we just did:
- Removed the last dense layer from the CNN we trained for MNIST classification. Hence, it now outputs an embedding of length 256 (the size of the last remaining Dense layer) for each image fed to it.
- Using the feature embeddings of the two images (the pair we will feed), we calculate the Euclidean distance. This will be our output!
Hence,
Input: Pair of images,
Output: Euclidean distance between feature embeddings of the images
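The distance layer computes, for each pair in a batch, the square root of the summed squared differences along the embedding axis, keeping a trailing axis so the output has shape (batch, 1). Below is a NumPy mirror of that computation, with random 256-dim vectors standing in for the CNN features (an assumption, for illustration only). The sketch also clamps the sum at a small epsilon before the square root, a common safeguard because the gradient of sqrt is infinite at 0; eps = 1e-7 matches Keras' default K.epsilon().

```python
import numpy as np

def euclidean_distance_np(feat_a, feat_b, eps=1e-7):
    """One Euclidean distance per pair, shape (batch, 1)."""
    squared = np.sum((feat_a - feat_b) ** 2, axis=1, keepdims=True)
    # clamping at eps keeps sqrt (and its gradient 1 / (2 * sqrt(x))) finite for identical pairs
    return np.sqrt(np.maximum(squared, eps))

rng = np.random.default_rng(0)
feat_a = rng.normal(size=(8, 256))  # stand-in for model2(img_a) on a batch of 8
feat_b = rng.normal(size=(8, 256))  # stand-in for model2(img_b)

dist = euclidean_distance_np(feat_a, feat_b)
print(dist.shape)                                   # (8, 1)
print(euclidean_distance_np(feat_a, feat_a)[0, 0])  # sqrt(1e-7), about 3.16e-4, not exactly 0
```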
8. Compile the Siamese Network
# Adam with learning rate 1e-4 (beta_1=0.9 and beta_2=0.999 are the defaults)
adam = optimizers.Adam(learning_rate=0.0001)
earlyStopping = EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=3,
                              verbose=1,
                              restore_best_weights=True)
callback_early_stop_reduceLROnPlateau = [earlyStopping]
# the Siamese model: two image inputs, one distance output
model = Model(inputs=[img_a, img_b], outputs=distance)
model.compile(loss=contrastive_loss, optimizer=adam, metrics=[accuracy])
model.summary()
The architecture now looks like this
(Figure: modified CNN for transfer learning)
9. Train the model and save the weights
# split each pair into the two Siamese inputs, in channels-last layout (n, 32, 32, 1)
img1 = x_pair[:, 0].reshape(-1, 32, 32, 1)
img2 = x_pair[:, 1].reshape(-1, 32, 32, 1)
img3 = x_pair_test[:, 0].reshape(-1, 32, 32, 1)
img4 = x_pair_test[:, 1].reshape(-1, 32, 32, 1)
batch_size = 8
history = model.fit([img1, img2], y, validation_data=([img3, img4], y_test),
                    batch_size=batch_size, verbose=1, epochs=10,
                    callbacks=callback_early_stop_reduceLROnPlateau)
# Keras 3 requires weight files to end in ".weights.h5"; tf.keras 2.x accepts this name too
model.save_weights('model.weights.h5')
with open('model_architecture.json', 'w') as f:
    f.write(model.to_json())
print('saved')
The results look good
We got an accuracy of 0.99 on the validation set, with just 50 labeled training pairs! I could have gone for even fewer samples, but let's keep it at 50 for now.
So One-Shot Learning actually works and isn't limited to research papers. We'll wrap up here. Let's catch up on another One-Shot algorithm in my next post!