Yet Another Siamese Neural Network Example Using PyTorch
A Siamese neural network uses a specialized architecture to compute the dissimilarity between two data items (almost always images). A Siamese network can be used for something called one-shot learning. A Siamese network uses a special kind of loss function called contrastive loss (although there are alternatives).
I reviewed the handful of existing examples of Siamese networks that I found on the Internet. I discovered that a.) Siamese networks seem simple but are in fact very subtle, b.) Siamese networks are much more difficult to implement than I expected, c.) there is a lot of misleading, or just plain incorrect, information about Siamese networks on the Internet.
I spent several hours and put together a demo. My demo uses a 1,000-item subset of the MNIST image dataset. The result is a model that accepts two MNIST images and emits two output vectors, one per image. The two output vectors are then compared to compute a dissimilarity value. In other words, dissimilarity is not computed directly by the Siamese network. This is one of at least a dozen tricky ideas related to Siamese networks.
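To make the comparison step concrete, here is a minimal sketch in plain Python. The two vectors are made-up stand-ins for the network's outputs, and Euclidean distance is assumed as the comparison (the demo code uses PyTorch's pairwise_distance, which computes essentially the same quantity).

```python
import math

def dissimilarity(vec1, vec2):
  # Euclidean distance between two equal-length output vectors
  if len(vec1) != len(vec2):
    raise ValueError("vectors must have the same length")
  return math.sqrt(sum((a - b) ** 2 for a, b in zip(vec1, vec2)))

# hypothetical network outputs for two images (values are made up)
out1 = [0.0, 3.0, 0.0]
out2 = [4.0, 0.0, 0.0]
print(dissimilarity(out1, out2))  # 5.0
```

A distance near zero means the network maps the two images to nearly the same point, which is what training tries to achieve for images of the same class.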
Note: The two output vectors are representations of the two input images. It's possible to extend the Siamese network design presented in this blog post by adding a Linear layer that condenses the two output vectors (using sigmoid activation) to a single output value between 0 and 1, where the output is a measure of similarity (not dissimilarity). So a final output value greater than 0.5 means the two images are the same class, and an output value less than 0.5 means they are different classes. The details are extremely tricky.
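A rough sketch of what such an extra layer might compute, in plain Python. Everything here is an assumption for illustration: the name pair_score is hypothetical, combining the two vectors by element-wise absolute difference is just one common choice, and the weights and bias are made up (in a real network they would be learned).

```python
import math

def pair_score(vec1, vec2, weights, bias):
  # combine the two vectors by element-wise absolute difference
  # (an assumed choice), then apply a linear layer with a single
  # output node followed by sigmoid activation
  diffs = [abs(a - b) for a, b in zip(vec1, vec2)]
  z = sum(w * d for w, d in zip(weights, diffs)) + bias
  return 1.0 / (1.0 + math.exp(-z))

# hypothetical values; in practice weights and bias are learned
w = [0.5, -1.0, 2.0]
b = 0.0
print(pair_score([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], w, b))  # 0.5
```

The result is always strictly between 0 and 1. Which side of 0.5 counts as "same class" depends on how the same/different target values are encoded during training.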
In the image below, the computed dissimilarity between two different ‘1’ digits is 0.000592. The computed dissimilarity between a ‘1’ and a ‘6’ is 0.00532 — a larger dissimilarity.
By itself, computing dissimilarity isn’t very useful. But dissimilarity can be used for one-shot learning — and that’s another topic for another blog post.
The demo program implements a special Dataset that serves up pairs of MNIST images along with a flag that indicates if the two images have the same class, such as two ‘3’ digits, or if the two images are different classes, such as a ‘1’ and a ‘6’. This was a surprisingly difficult sub-problem.
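The core of the pairing logic can be sketched without PyTorch. This is a simplified illustration, not the demo's Dataset class: pick_partner is a hypothetical helper, the labels are made up, and the scan-forward-with-wraparound search is an assumption about one reasonable way to find a partner.

```python
import random

def pick_partner(labels, idx1, want_same, rng):
  # start at a random position and scan forward (with wraparound)
  # until an item with the required label relationship is found
  n = len(labels)
  idx2 = rng.randrange(n)
  for _ in range(n):
    same = (labels[idx2] == labels[idx1])
    if same == want_same:
      return idx2
    idx2 = (idx2 + 1) % n
  raise ValueError("no suitable partner in the dataset")

labels = [1, 6, 3, 1, 6, 3]  # made-up class labels
rng = random.Random(0)
flag = rng.randint(0, 1)  # 0 = same class, 1 = different class
idx2 = pick_partner(labels, 0, flag == 0, rng)
print(flag, labels[0], labels[idx2])
```

One detail that is easy to get wrong: Python's random.randint(0, 1) includes the upper bound, but NumPy's randint(0, 1) excludes it and so always returns 0. With NumPy, a coin flip is randint(0, 2).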
The Siamese network is a variation of a convolutional neural network — also a very difficult topic. The two inputs are two images. Dealing with the shapes is tricky. The two outputs are vectors of size 5 where the size 5 is a hyperparameter. These outputs are indirect measures of dissimilarity.
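The shape bookkeeping can be checked with a few lines of arithmetic. The sketch below assumes the layer settings used in the demo code (two 5x5 convolutions with no padding, each followed by 2x2 max pooling with stride 2, and 64 output channels from the second convolution). The helper conv_out is a hypothetical name.

```python
def conv_out(size, kernel, stride=1, padding=0):
  # output width/height of a convolution or pooling layer
  return (size + 2 * padding - kernel) // stride + 1

size = 28                                   # MNIST image is 28x28
size = conv_out(size, kernel=5)             # conv 5x5     -> 24
size = conv_out(size, kernel=2, stride=2)   # max pool 2x2 -> 12
size = conv_out(size, kernel=5)             # conv 5x5     -> 8
size = conv_out(size, kernel=2, stride=2)   # max pool 2x2 -> 4
flat = 64 * size * size                     # 64 channels  -> 1024
print(size, flat)  # 4 1024
```

The value 1024 is the number of inputs the first Linear layer must accept. If the kernel size or the number of channels changes, that number has to be recomputed.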
Normal neural network training is accomplished using a loss function that compares a computed output, such as a class label, to an expected output from the training data. A Siamese network uses a special contrastive loss function that can deal with pairs of inputs that are flagged as being the same class, or flagged as being different classes. This is another non-trivial topic.
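A small worked example may help. The sketch below computes contrastive loss for a single pair in plain Python, using the same formula and default margin as the ContrastiveLoss class in the demo code (the demo averages this quantity over a batch). The distances passed in are made-up values.

```python
def contrastive_loss(dist, flag, margin=2.0):
  # flag = 0: pair is the same class, penalize any distance
  # flag = 1: pair is different classes, penalize only
  #           distances smaller than the margin
  same_term = (1 - flag) * dist ** 2
  diff_term = flag * max(margin - dist, 0.0) ** 2
  return same_term + diff_term

print(contrastive_loss(0.5, flag=0))  # 0.25
print(contrastive_loss(0.5, flag=1))  # 2.25
print(contrastive_loss(3.0, flag=1))  # 0.0
```

A same-class pair is penalized for any separation. A different-class pair is penalized only while its distance is inside the margin, and contributes nothing once the two outputs are far enough apart.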
Because Siamese networks have many different components, there are many, many different possible designs. My demo is just one approach.
The moral of all this is that Siamese neural networks aren’t magic rocket science, but if you want to understand them, you need to be prepared to put in a lot of time and effort.
Siamese cats in Disney animation. Left: Si and Am, two troublemakers from “Lady and the Tramp” (1955). Center: The piano playing Shun Gon from “The Aristocats” (1970). Right: The Siamese Twin Gang from “Chip ‘n Dale: Rescue Rangers” TV series (1988-1990).
Demo code.
# mnist_siamese.py
# PyTorch 1.10.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11

import numpy as np
import matplotlib.pyplot as plt
import torch as T

device = T.device('cpu')

# -----------------------------------------------------------

class Siamese_Dataset(T.utils.data.Dataset):
  # 784 tab-delim pixel values (0-255) then label (0-9)
  def __init__(self, src_file):
    self.rnd = np.random.RandomState(0)
    all_xy = np.loadtxt(src_file, usecols=range(785),
      delimiter="\t", comments="#", dtype=np.float32)

    # rows stay in file order so that each image remains
    # paired with its label
    tmp_x = all_xy[:, 0:784]  # all rows, cols [0,783]
    tmp_x /= 255.0            # normalize
    tmp_x = tmp_x.reshape(-1, 1, 28, 28)  # bs, chnls, 28x28
    tmp_y = all_xy[:, 784]    # 1-D required

    self.x_data = \
      T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = \
      T.tensor(tmp_y, dtype=T.int64).to(device)
    self.n = len(self.x_data)

  def __len__(self):
    return self.n

  def __getitem__(self, idx1):
    # NumPy randint excludes the upper bound
    flag = self.rnd.randint(0,2)  # 0 = same class, 1 = different
    y = self.y_data[idx1]
    idx2 = self.rnd.randint(0,self.n)  # random start in [0, n-1]
    if flag == 0:  # get two images with same label
      while self.y_data[idx2] != y:
        idx2 += 1
        if idx2 == self.n: idx2 = 0
    elif flag == 1:  # get images with different labels
      while self.y_data[idx2] == y:
        idx2 += 1
        if idx2 == self.n: idx2 = 0

    pixels1 = self.x_data[idx1]
    label1 = self.y_data[idx1]
    pixels2 = self.x_data[idx2]
    label2 = self.y_data[idx2]
    flag = T.tensor(flag, dtype=T.float32).to(device)
    return (pixels1, label1, pixels2, label2, flag)

# -----------------------------------------------------------

def display_mult_images(images, titles, rows, cols):
  # images is a list of 28x28 arrays
  for i in range(len(images)):
    images[i] = images[i].reshape(28,28)  # if necessary

  figure, ax = plt.subplots(rows,cols)  # array of axes
  for idx, img in enumerate(images):
    ax.ravel()[idx].imshow(img, cmap=plt.get_cmap('gray_r'))
    ax.ravel()[idx].set_title(titles[idx])
    # ax.ravel()[idx].set_axis_off()
  plt.tight_layout()
  plt.show()

# -----------------------------------------------------------

class SiameseNet(T.nn.Module):
  def __init__(self):
    super(SiameseNet, self).__init__()  # pre 3.3 syntax

    self.conv1 = T.nn.Conv2d(1, 32, 5)  # chnl-in, out, krnl
    self.conv2 = T.nn.Conv2d(32, 64, 5)

    self.fc1 = T.nn.Linear(1024, 512)   # [64*4*4, x]
    self.fc2 = T.nn.Linear(512, 256)
    self.fc3 = T.nn.Linear(256, 5)      # n values (size 5)

    self.pool1 = T.nn.MaxPool2d(2, stride=2)  # kernel, stride
    self.pool2 = T.nn.MaxPool2d(2, stride=2)
    self.drop1 = T.nn.Dropout(0.25)
    self.drop2 = T.nn.Dropout(0.50)

  def feed(self, x):
    # convolution phase         # x is [bs, 1, 28, 28]
    z = T.relu(self.conv1(x))   # Size([bs, 32, 24, 24])
    z = self.pool1(z)           # Size([bs, 32, 12, 12])
    z = self.drop1(z)
    z = T.relu(self.conv2(z))   # Size([bs, 64, 8, 8])
    z = self.pool2(z)           # Size([bs, 64, 4, 4])

    # neural network phase
    z = z.reshape(-1, 1024)     # Size([bs, 1024])
    z = T.relu(self.fc1(z))     # Size([bs, 512])
    z = self.drop2(z)
    z = T.relu(self.fc2(z))     # Size([bs, 256])
    z = self.fc3(z)             # Size([bs, n])
    return z

  def forward(self, x1, x2):
    # both images pass through the same layers (shared weights)
    oupt1 = self.feed(x1)
    oupt2 = self.feed(x2)
    return oupt1, oupt2

# -----------------------------------------------------------

class ContrastiveLoss(T.nn.Module):
  def __init__(self, m=2.0):
    super(ContrastiveLoss, self).__init__()  # pre 3.3 syntax
    self.m = m  # margin or radius

  def forward(self, y1, y2, flag):
    # flag = 0 means y1 and y2 are supposed to be same
    # flag = 1 means y1 and y2 are supposed to be different

    euc_dist = T.nn.functional.pairwise_distance(y1, y2)

    loss = T.mean((1-flag) * T.pow(euc_dist, 2) +
      (flag) * T.pow(T.clamp(self.m - euc_dist, min=0.0), 2))

    return loss

# -----------------------------------------------------------

def siamese_dissim(siamese_model, image1, image2):
  # images are shape [1, chnls, 28, 28]
  # assumes model is in eval() mode
  image1 = image1.reshape(1,1,28,28)  # if necessary
  image2 = image2.reshape(1,1,28,28)
  with T.no_grad():
    oupt1, oupt2 = siamese_model(image1, image2)
  dissim = T.nn.functional.pairwise_distance(oupt1, oupt2)
  return np.round(dissim.item(), 6)

# -----------------------------------------------------------

def main():
  # 0. setup
  print("\nBegin MNIST Siamese network demo ")
  np.random.seed(1)
  T.manual_seed(1)

  # 1. create Dataset
  # print("\nLoading 1000-item train Dataset from text file ")
  train_file = ".\\Data\\mnist_train_1000.txt"
  train_ds = Siamese_Dataset(train_file)

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating Siamese network (2 conv, 3 linear) ")
  net = SiameseNet().to(device)

  # 3. train model
  max_epochs = 40
  ep_log_interval = 4
  lrn_rate = 0.005

  loss_func = ContrastiveLoss()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + "ContrastiveLoss()" )
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training")
  net.train()  # set mode
  for epoch in range(0, max_epochs):
    ep_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X1, y1, X2, y2, flag = batch
      oupt1, oupt2 = net(X1, X2)

      optimizer.zero_grad()  # reset gradients
      loss_val = loss_func(oupt1, oupt2, flag)
      ep_loss += loss_val.item()  # accumulate loss
      loss_val.backward()  # compute grads
      optimizer.step()     # update weights

    if epoch % ep_log_interval == 0:
      print("epoch = %4d | loss = %10.4f" % (epoch, ep_loss))
  print("Done ")

  # 4. TODO: save trained model

# -----------------------------------------------------------

  # 5. use model
  print("\nUsing trained Siamese model ")

  pixels1 = train_ds.x_data[0]  # a '1'
  pixels2 = train_ds.x_data[3]  # a different '1'
  pixels3 = train_ds.x_data[4]  # a '6'

  net.eval()
  dissim_12 = siamese_dissim(net, pixels1, pixels2)
  dissim_13 = siamese_dissim(net, pixels1, pixels3)

  display_mult_images([pixels1, pixels2],
    ["", "dissim = " + str(dissim_12)], 1, 2)  # 1 row, 2 cols
  display_mult_images([pixels1, pixels3],
    ["", "dissim = " + str(dissim_13)], 1, 2)

  print("\nEnd MNIST Siamese demo ")

if __name__ == "__main__":
  main()