Image similarity estimation using a Siamese Network with a triplet loss
class
SiameseModel
(
Model
):
"""The Siamese Network model with a custom training and testing loops.
Computes the triplet loss using the three embeddings produced by the
Siamese Network.
The triplet loss is defined as:
L(A, P, N) = max(‖f(A) - f(P)‖² - ‖f(A) - f(N)‖² + margin, 0)
"""
def
__init__
(
self
,
siamese_network
,
margin
=
0.5
):
super
()
.
__init__
()
self
.
siamese_network
=
siamese_network
self
.
margin
=
margin
self
.
loss_tracker
=
metrics
.
Mean
(
name
=
"loss"
)
def
call
(
self
,
inputs
):
return
self
.
siamese_network
(
inputs
)
def
train_step
(
self
,
data
):
# GradientTape is a context manager that records every operation that
# you do inside. We are using it here to compute the loss so we can get
# the gradients and apply them using the optimizer specified in
# `compile()`.
with
tf
.
GradientTape
()
as
tape
:
loss
=
self
.
_compute_loss
(
data
)
# Storing the gradients of the loss function with respect to the
# weights/parameters.
gradients
=
tape
.
gradient
(
loss
,
self
.
siamese_network
.
trainable_weights
)
# Applying the gradients on the model using the specified optimizer
self
.
optimizer
.
apply_gradients
(
zip
(
gradients
,
self
.
siamese_network
.
trainable_weights
)
)
# Let's update and return the training loss metric.
self
.
loss_tracker
.
update_state
(
loss
)
return
{
"loss"
:
self
.
loss_tracker
.
result
()}
def
test_step
(
self
,
data
):
loss
=
self
.
_compute_loss
(
data
)
# Let's update and return the loss metric.
self
.
loss_tracker
.
update_state
(
loss
)
return
{
"loss"
:
self
.
loss_tracker
.
result
()}
def
_compute_loss
(
self
,
data
):
# The output of the network is a tuple containing the distances
# between the anchor and the positive example, and the anchor and
# the negative example.
ap_distance
,
an_distance
=
self
.
siamese_network
(
data
)
# Computing the Triplet Loss by subtracting both distances and
# making sure we don't get a negative value.
loss
=
ap_distance
-
an_distance
loss
=
tf
.
maximum
(
loss
+
self
.
margin
,
0.0
)
return
loss
@property
def
metrics
(
self
):
# We need to list our metrics here so the `reset_states()` can be
# called automatically.
return
[
self
.
loss_tracker
]