02. Neural Network Classification with TensorFlow – Zero to Mastery TensorFlow for Deep Learning

# Note: The following confusion matrix code is a remix of Scikit-Learn's

# plot_confusion_matrix function - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html

# and Made with ML's introductory notebook - https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/08_Neural_Networks.ipynb

import

itertools

from

sklearn.metrics

import

confusion_matrix

# Our function needs a different name to sklearn's plot_confusion_matrix

def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15):
    """Plot a labelled confusion matrix comparing predictions and ground truth labels.

    Each cell is annotated with the raw count and with the percentage that
    count represents of its row (i.e. of all samples with that true label).
    If `classes` is passed, the axes are labelled with it; otherwise integer
    class indices are used.

    Note: `np` and `plt` are not imported by this function; they are expected
    to be available at module level.

    Args:
      y_true: Array of truth labels (must be same shape as y_pred).
      y_pred: Array of predicted labels (must be same shape as y_true).
      classes: Sequence or array of class labels (e.g. string form). If `None`
        or empty, integer labels are used.
      figsize: Size of output figure (default=(10, 10)).
      text_size: Size of the text drawn inside each cell (default=15).

    Returns:
      None. The confusion matrix is drawn on a newly created figure.

    Example usage:
      make_confusion_matrix(y_true=test_labels,  # ground truth test labels
                            y_pred=y_preds,  # predicted labels
                            classes=class_names,  # array of class label names
                            figsize=(15, 15),
                            text_size=10)
    """
    # Create the confusion matrix (rows = true labels, columns = predictions)
    cm = confusion_matrix(y_true, y_pred)
    n_classes = cm.shape[0]  # number of classes we're dealing with

    # Normalize each row by its total. A row can sum to zero when a class only
    # ever appears in y_pred; leave those cells at 0 instead of dividing by
    # zero (which would give NaN and a RuntimeWarning).
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    cm_norm = np.divide(
        cm.astype("float"),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    # Plot the figure and make it pretty
    fig, ax = plt.subplots(figsize=figsize)
    # Colors represent how 'correct' a class is, darker == better
    cax = ax.matshow(cm, cmap=plt.cm.Blues)
    fig.colorbar(cax)

    # Is there a list of classes? Compare against None and check the length
    # explicitly: `if classes:` raises ValueError for multi-element arrays.
    if classes is not None and len(classes) > 0:
        labels = classes
    else:
        labels = np.arange(n_classes)

    # Label the axes
    ax.set(
        title="Confusion Matrix",
        xlabel="Predicted label",
        ylabel="True label",
        xticks=np.arange(n_classes),  # create enough axis slots for each class
        yticks=np.arange(n_classes),
        xticklabels=labels,  # axes are labelled with class names (if given) or ints
        yticklabels=labels,
    )

    # Make x-axis labels appear on bottom
    ax.xaxis.set_label_position("bottom")
    ax.xaxis.tick_bottom()

    # Cells above this count get white text so it stays readable on dark cells
    threshold = (cm.max() + cm.min()) / 2.

    # Plot the text on each cell (x = column index, y = row index)
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(
            j,
            i,
            f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
            horizontalalignment="center",
            color="white" if cm[i, j] > threshold else "black",
            size=text_size,
        )