Yet Another Siamese Neural Network Example Using PyTorch

A Siamese neural network uses a specialized architecture to compute the dissimilarity between two data items (almost always images). A Siamese network can be used for something called one-shot learning. A Siamese network uses a special kind of loss function called contrastive loss (although there are alternatives).

I reviewed the handful of existing examples of Siamese networks that I found on the Internet. I discovered that a.) Siamese networks seem simple but are in fact very subtle, b.) Siamese networks are much more difficult to implement than I expected, and c.) there is a lot of misleading, or just plain incorrect, information about Siamese networks on the Internet.

I spent several hours and put together a demo. My demo uses a 1,000-item subset of the MNIST image dataset. The result is a model that accepts two MNIST images and emits two output vectors, one per image. The two output vectors are then compared to compute a dissimilarity value. In other words, dissimilarity is not computed directly by the Siamese network. This is one of at least a dozen tricky ideas related to Siamese networks.
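To make the "compare the two outputs" step concrete, here is a minimal sketch that computes Euclidean distance between two output vectors. The vector values and the function name `dissimilarity` are made up for illustration. The demo itself uses PyTorch's pairwise_distance(), which by default adds a tiny epsilon, so its results can differ very slightly from a plain norm.

```python
import numpy as np

def dissimilarity(v1, v2):
    # Euclidean (L2) distance between two output vectors
    return float(np.linalg.norm(np.asarray(v1) - np.asarray(v2)))

# hypothetical network outputs for two images (size 5 each)
out1 = [0.0, 0.0, 0.0, 0.0, 0.0]
out2 = [3.0, 4.0, 0.0, 0.0, 0.0]

print(dissimilarity(out1, out2))  # 5.0
print(dissimilarity(out1, out1))  # 0.0
```

A small distance suggests the two images are the same class, and a large distance suggests different classes.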

Note: The two output vectors are representations of the two input images. It’s possible to extend the Siamese network design presented in this blog post by adding a Linear layer that condenses the two output vectors (using sigmoid activation) to a single output value between 0 and 1 where the output is a measure of dissimilarity. So a final output value less than 0.5 means same images, and an output value greater than 0.5 means different images. The details are extremely tricky.
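One possible shape for that extension is sketched below. This is not part of the demo. The class name `DissimHead`, the choice to concatenate the two vectors, and the vector size of 5 are all assumptions made for illustration, and the sketch shows only the forward pass, not how to train it.

```python
import torch as T

class DissimHead(T.nn.Module):
    # hypothetical add-on layer, not part of the demo program
    def __init__(self, n=5):
        super().__init__()
        self.fc = T.nn.Linear(2 * n, 1)   # both vectors in, one value out

    def forward(self, oupt1, oupt2):
        z = T.cat((oupt1, oupt2), dim=1)  # Size([bs, 2n])
        return T.sigmoid(self.fc(z))      # Size([bs, 1]), values in (0, 1)

T.manual_seed(1)
head = DissimHead(n=5)
v1 = T.randn(3, 5)   # stand-ins for Siamese outputs, batch of 3
v2 = T.randn(3, 5)
p = head(v1, v2)
print(p.shape)       # torch.Size([3, 1])
```

Because the head is untrained here, the values in p are meaningless. A plausible way to train it would be binary cross entropy against the same/different flag, but that is an assumption too.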

In the image below, the computed dissimilarity between two different ‘1’ digits is 0.000592. The computed dissimilarity between a ‘1’ and a ‘6’ is 0.00532 — a larger dissimilarity.

By itself, computing dissimilarity isn’t very useful. But dissimilarity can be used for one-shot learning — and that’s another topic for another blog post.

The demo program implements a special Dataset that serves up pairs of MNIST images along with a flag that indicates if the two images have the same class, such as two ‘3’ digits, or if the two images are different classes, such as a ‘1’ and a ‘6’. This was a surprisingly difficult sub-problem.
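The core of the pairing logic can be sketched without any images at all. The function name `pick_partner` and the tiny label array below are hypothetical. The sketch flips a coin for the same/different flag, picks a random starting index, and then scans forward (wrapping around) until it finds a label that matches the flag. Note that the upper bound of NumPy's randint() is exclusive.

```python
import numpy as np

def pick_partner(labels, idx1, rnd):
    # flag: 0 = same class wanted, 1 = different class wanted
    flag = rnd.randint(0, 2)   # high is exclusive, so result is 0 or 1
    n = len(labels)
    idx2 = rnd.randint(0, n)   # random start position in [0, n-1]
    want_same = (flag == 0)
    # scan forward with wrap-around until the label relation fits the flag
    while (labels[idx2] == labels[idx1]) != want_same:
        idx2 = (idx2 + 1) % n
    return idx2, flag

labels = np.array([1, 6, 3, 1, 6, 3, 1, 0])  # made-up class labels
rnd = np.random.RandomState(0)
for idx1 in range(4):
    idx2, flag = pick_partner(labels, idx1, rnd)
    print(idx1, idx2, flag, labels[idx1], labels[idx2])
```

When a class has only one item, a "same" pair ends up being the item paired with itself. The scan for a "different" pair assumes the data has at least two classes.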

The Siamese network is a variation of a convolutional neural network — also a very difficult topic. The two inputs are two images. Dealing with the shapes is tricky. The two outputs are vectors of size 5 where the size 5 is a hyperparameter. These outputs are indirect measures of dissimilarity.
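The shape bookkeeping is just arithmetic. The sketch below traces one side of a 28x28 image through the layer settings used in the demo code (5x5 convolution kernels with no padding, 2x2 max pooling with stride 2, 64 channels after the second convolution). The helper name `conv_out` is made up for this example.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # output width (or height) of a conv or pooling layer
    return (size + 2 * padding - kernel) // stride + 1

size = 28
size = conv_out(size, kernel=5)            # conv1: 24
size = conv_out(size, kernel=2, stride=2)  # pool1: 12
size = conv_out(size, kernel=5)            # conv2: 8
size = conv_out(size, kernel=2, stride=2)  # pool2: 4

flat = 64 * size * size   # 64 channels, each 4x4
print(size, flat)         # 4 1024
```

The value 1024 is why the first Linear layer in the demo has 1024 input nodes.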

Normal neural network training is accomplished using a loss function that compares a computed output, such as a class label, to an expected output from the training data. A Siamese network uses a special contrastive loss function that can deal with pairs of inputs that are flagged as being the same class, or flagged as being different classes. This is another non-trivial topic.
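A worked example for single pairs may help. The sketch below is a plain-Python version of the contrastive loss formula used in the demo code, applied to one distance at a time, with the demo's default margin of 2.0. The distance values are made up.

```python
def contrastive_loss(dist, flag, m=2.0):
    # flag = 0: pair should be the same class, so penalize distance
    # flag = 1: pair should be different, so penalize distance below margin m
    same_term = (1 - flag) * dist ** 2
    diff_term = flag * max(m - dist, 0.0) ** 2
    return same_term + diff_term

print(contrastive_loss(0.5, flag=0))  # 0.25  same class, close: small loss
print(contrastive_loss(0.5, flag=1))  # 2.25  different class, close: big loss
print(contrastive_loss(3.0, flag=1))  # 0.0   different class, beyond margin
```

In words: same-class pairs are pulled together, and different-class pairs are pushed apart only until they are at least the margin distance from each other.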

Because Siamese networks have many different components, there are many, many different possible designs. My demo is just one approach.

The moral of all this is that Siamese neural networks aren’t magic rocket science, but if you want to understand them, you need to be prepared to put in a lot of time and effort.


Siamese cats in Disney animation. Left: Si and Am, two troublemakers from “Lady and the Tramp” (1955). Center: The piano playing Shun Gon from “The Aristocats” (1970). Right: The Siamese Twin Gang from “Chip ‘n Dale: Rescue Rangers” TV series (1988-1990).

Demo code.

# mnist_siamese.py
# PyTorch 1.10.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11 

import numpy as np
import matplotlib.pyplot as plt
import torch as T

device = T.device('cpu')

# -----------------------------------------------------------

class Siamese_Dataset(T.utils.data.Dataset):
  # 784 tab-delim pixel values (0-255) then label (0-9)
  def __init__(self, src_file):
    self.rnd = np.random.RandomState(0)
    all_xy = np.loadtxt(src_file, usecols=range(785),
      delimiter="\t", comments="#", dtype=np.float32)

    tmp_x = all_xy[:, 0:784]  # all rows, cols [0,783]
    tmp_x /= 255.0            # normalize
    tmp_x = tmp_x.reshape(-1, 1, 28, 28)  # bs, chnls, 28x28
    tmp_y = all_xy[:, 784]    # 1-D required

    self.x_data = \
      T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = \
     T.tensor(tmp_y, dtype=T.int64).to(device) 

    self.n = len(self.x_data)

  def __len__(self):
    return self.n
    
  def __getitem__(self, idx1):
    flag = self.rnd.randint(0,2)  # 0 = same class, or 1
    y = self.y_data[idx1]
    idx2 = self.rnd.randint(0,self.n)  # [0, n-1], high is exclusive
    
    if flag == 0:  # get two images with same label
      while self.y_data[idx2] != y:
        idx2 += 1
        if idx2 == self.n: idx2 = 0
    elif flag == 1:  # get images different labels
      while self.y_data[idx2] == y:
        idx2 += 1
        if idx2 == self.n: idx2 = 0
 
    pixels1 = self.x_data[idx1]
    label1 = self.y_data[idx1] 
    pixels2 = self.x_data[idx2]
    label2 = self.y_data[idx2] 
    flag = T.tensor(flag, dtype=T.float32).to(device)
    
    return (pixels1, label1, pixels2, label2, flag)

# -----------------------------------------------------------

def display_mult_images(images, titles, rows, cols):
  # images is a list of 28x28 arrays
  for i in range(len(images)):
    images[i] = images[i].reshape(28,28)  # if necessary
  figure, ax = plt.subplots(rows,cols)  # array of axes

  for idx, img in enumerate(images):  
    ax.ravel()[idx].imshow(img,
      cmap=plt.get_cmap('gray_r'))
    ax.ravel()[idx].set_title(titles[idx])
    # ax.ravel()[idx].set_axis_off()
  plt.tight_layout()
  plt.show() 

# -----------------------------------------------------------

class SiameseNet(T.nn.Module):
  def __init__(self):
    super(SiameseNet, self).__init__()  # pre-Python 3 syntax

    self.conv1 = T.nn.Conv2d(1, 32, 5)  # chnl-in, out, krnl
    self.conv2 = T.nn.Conv2d(32, 64, 5)
    self.fc1 = T.nn.Linear(1024, 512)   # [64*4*4, x]
    self.fc2 = T.nn.Linear(512, 256)
    self.fc3 = T.nn.Linear(256, 5)     # n values 
    self.pool1 = T.nn.MaxPool2d(2, stride=2) # kernel, stride
    self.pool2 = T.nn.MaxPool2d(2, stride=2)
    self.drop1 = T.nn.Dropout(0.25)
    self.drop2 = T.nn.Dropout(0.50)

  def feed(self, x):
    # convolution phase         # x is [bs, 1, 28, 28]
    z = T.relu(self.conv1(x))   # Size([bs, 32, 24, 24])
    z = self.pool1(z)           # Size([bs, 32, 12, 12])
    z = self.drop1(z)
    z = T.relu(self.conv2(z))   # Size([bs, 64, 8, 8])
    z = self.pool2(z)           # Size([bs, 64, 4, 4])
   
    # neural network phase
    z = z.reshape(-1, 1024)     # Size([bs, 1024])
    z = T.relu(self.fc1(z))     # Size([bs, 512])
    z = self.drop2(z)
    z = T.relu(self.fc2(z))     # Size([bs, 256])
    z = self.fc3(z)             # Size([bs, n])
    return z

  def forward(self, x1, x2):
    oupt1 = self.feed(x1)
    oupt2 = self.feed(x2)
    return oupt1, oupt2

# -----------------------------------------------------------

class ContrastiveLoss(T.nn.Module):
  def __init__(self, m=2.0):
    super(ContrastiveLoss, self).__init__()  # pre-Python 3 syntax
    self.m = m  # margin or radius

  def forward(self, y1, y2, flag):
    # flag = 0 means y1 and y2 are supposed to be same
    # flag = 1 means y1 and y2 are supposed to be different
    
    euc_dist = T.nn.functional.pairwise_distance(y1, y2)

    loss = T.mean((1-flag) * T.pow(euc_dist, 2) +
      (flag) * T.pow(T.clamp(self.m - euc_dist, min=0.0), 2))

    return loss
    
# -----------------------------------------------------------

def siamese_dissim(siamese_model, image1, image2):
  # images are shape [1, chnls, 28, 28]
  # assumes model is in eval() mode
  image1 = image1.reshape(1,1,28,28)  # if necessary
  image2 = image2.reshape(1,1,28,28)
  with T.no_grad():
    oupt1, oupt2 = siamese_model(image1, image2)
  dissim = T.nn.functional.pairwise_distance(oupt1, oupt2)
  return np.round(dissim.item(), 6)

# -----------------------------------------------------------

def main():
  # 0. setup
  print("\nBegin MNIST Siamese network demo ")
  np.random.seed(1)
  T.manual_seed(1)

  # 1. create Dataset
  # print("\nLoading 1000-item train Dataset from text file ")
  train_file = ".\\Data\\mnist_train_1000.txt" 
  train_ds = Siamese_Dataset(train_file)

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating Siamese network (2 conv, 3 linear) ")
  net = SiameseNet().to(device)

  # 3. train model
  max_epochs = 40
  ep_log_interval = 4
  lrn_rate = 0.005
  loss_func = ContrastiveLoss()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)
  
  print("\nbat_size = %3d " % bat_size)
  print("loss = " + "ContrastiveLoss()" )
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training")
  net.train()  # set mode
  for epoch in range(0, max_epochs):
    ep_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X1, y1, X2, y2, flag = batch
      oupt1, oupt2 = net(X1, X2)
     
      optimizer.zero_grad()       # reset gradients
      loss_val = loss_func(oupt1, oupt2, flag)

      ep_loss += loss_val.item()  # accumulate loss
      loss_val.backward()         # compute grads
      optimizer.step()            # update weights
    if epoch % ep_log_interval == 0:
      print("epoch = %4d  |  loss = %10.4f" % (epoch, ep_loss))
  print("Done ") 

  # 4. TODO: save trained model

# -----------------------------------------------------------

  # 5. use model
  print("\nUsing trained Siamese model ")
  pixels1 = train_ds.x_data[0]  # a '1' 
  pixels2 = train_ds.x_data[3]  # a different '1'
  pixels3 = train_ds.x_data[4]  # a '6'

  net.eval()
  dissim_12 = siamese_dissim(net, pixels1, pixels2)
  dissim_13 = siamese_dissim(net, pixels1, pixels3)

  display_mult_images([pixels1, pixels2],
    ["", "dissim = " + str(dissim_12)], 1, 2)  # 1 row, 2 cols

  display_mult_images([pixels1, pixels3],
    ["", "dissim = " + str(dissim_13)], 1, 2)

  print("\nEnd MNIST Siamese demo ")

if __name__ == "__main__":
  main()
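Step 4 in the demo is left as a TODO. Below is a minimal sketch of one common way to save and reload a PyTorch model using its state dictionary. The file name is hypothetical, and a small Linear layer stands in for the trained SiameseNet so the snippet runs on its own.

```python
import os
import tempfile
import torch as T

net = T.nn.Linear(5, 2)  # stand-in for the trained SiameseNet

# save only the weights and biases (the state dictionary)
path = os.path.join(tempfile.gettempdir(), "mnist_siamese_model.pt")
T.save(net.state_dict(), path)

# later: create a network with the same architecture, then load
net2 = T.nn.Linear(5, 2)
net2.load_state_dict(T.load(path))
net2.eval()  # set mode before using the model
```

With this approach, the program that loads the model needs access to the class definition of the network.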
