Keras — Transfer Learning with Cats vs Dogs

Sep 11, 2022


In this article, we learn how to set up a data pipeline for the Cats vs Dogs dataset and train an image classifier with Keras using transfer learning.

  1. Understanding the trainable layers of a neural network
  2. Setting up our data
  3. Building our model for transfer learning
  4. Performing fine-tuning

```python
# Import libraries
import numpy as np
import tensorflow as tf
from tensorflow import keras
```

Trainable Layers

Layers & models have three weight attributes:

  • weights is the list of all weight variables of the layer.
  • trainable_weights is the list of those that are meant to be updated (via gradient descent) to minimize the loss during training.
  • non_trainable_weights is the list of those that aren't meant to be trained. Typically they are updated by the model during the forward pass.

Example: the Dense layer has 2 trainable weights (kernel & bias)

```python
layer = keras.layers.Dense(4)

# Create the weights using layer.build
layer.build((None, 2))

print(f'Number of weights: {len(layer.weights)}')  # 2
print(f'Number of trainable_weights: {len(layer.trainable_weights)}')  # 2
print(f'Number of non_trainable_weights: {len(layer.non_trainable_weights)}')  # 0
```

In general, all weights are trainable. The only built-in layer with non-trainable weights is BatchNormalization, which uses them to keep track of the mean and variance of its inputs during training.
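A quick way to see this split is to build a BatchNormalization layer and count its weights. A minimal sketch; the input width of 4 is arbitrary:

```python
from tensorflow import keras

# Build a BatchNormalization layer for inputs with 4 features
bn = keras.layers.BatchNormalization()
bn.build((None, 4))

# gamma and beta are trainable; moving_mean and moving_variance are not
print(f'Number of weights: {len(bn.weights)}')  # 4
print(f'Number of trainable_weights: {len(bn.trainable_weights)}')  # 2
print(f'Number of non_trainable_weights: {len(bn.non_trainable_weights)}')  # 2
```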

Layers & models also feature a boolean attribute trainable.

Its value can be changed: setting layer.trainable to False moves all the layer's weights from trainable to non-trainable.

This is called “freezing” the layer: the state of a frozen layer won’t be updated during training (either when training with fit() or when training with any custom loop that relies on trainable_weights to apply gradient updates).
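Before training anything, the effect of freezing is already visible in the weight lists. A minimal sketch with a Dense layer of arbitrary size:

```python
from tensorflow import keras

layer = keras.layers.Dense(3)
layer.build((None, 4))  # create kernel and bias

layer.trainable = False  # freeze the layer

# Both weights are still there, but now counted as non-trainable
print(f'Number of weights: {len(layer.weights)}')  # 2
print(f'Number of trainable_weights: {len(layer.trainable_weights)}')  # 0
print(f'Number of non_trainable_weights: {len(layer.non_trainable_weights)}')  # 2
```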

Example: setting trainable to False

```python
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model on random data
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
```

Output:

```
1/1 [==============================] - 1s 521ms/step - loss: 0.0446
<keras.callbacks.History at 0x7f270aea7c10>
```

```python
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()

# Compare the kernel (index 0) and the bias (index 1) element-wise
if np.allclose(initial_layer1_weights_values[0], final_layer1_weights_values[0]):
    print('Weights unchanged')

if np.allclose(initial_layer1_weights_values[1], final_layer1_weights_values[1]):
    print('Weights unchanged')
```

Output:

```
Weights unchanged
Weights unchanged
```

Note: trainable is recursive. Setting trainable = False on a model, or on any layer that has sublayers, makes all child layers non-trainable as well.
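The recursion can be checked by nesting one model inside another and freezing only the outer one. A minimal sketch; the layer sizes are arbitrary:

```python
from tensorflow import keras

inner_model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(3, activation="relu"),
    keras.layers.Dense(3, activation="relu"),
])
model = keras.Sequential([
    keras.Input(shape=(3,)),
    inner_model,
    keras.layers.Dense(3, activation="sigmoid"),
])

model.trainable = False  # freeze only the outer model...

# ...and the nested model and its layers are frozen with it
print(inner_model.trainable)  # False
print(all(not l.trainable for l in inner_model.layers))  # True
# 3 Dense layers x (kernel + bias) = 6 weights, all non-trainable
print(len(model.trainable_weights), len(model.non_trainable_weights))  # 0 6
```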

Implementing Transfer Learning

Transfer-learning workflow

  1. Instantiate a base model and load pre-trained weights into it.
  2. Freeze all layers in the base model by setting trainable = False.
  3. Create a new model on top of the output of one (or several) layers from the base model.
  4. Train the new model on the new dataset.
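The four steps can be sketched end to end on a toy problem before we apply them to real images. This is a minimal sketch, not the pipeline used below: the tiny stand-in base model (hypothetical, and randomly initialized rather than pre-trained) and the random data are assumptions chosen so the example runs in seconds. It confirms that training touches only the new head:

```python
import numpy as np
from tensorflow import keras

# 1. A tiny stand-in "base model" (hypothetical; in practice this would be a
#    pre-trained network such as keras.applications.Xception)
base_model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
], name="stand_in_base")
base_before = [w.copy() for w in base_model.get_weights()]

# 2. Freeze every layer in the base model
base_model.trainable = False

# 3. Stack a new trainable head on top of the base model's output
inputs = keras.Input(shape=(8,))
x = base_model(inputs, training=False)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
head_before = [w.copy() for w in model.layers[-1].get_weights()]

# 4. Train the new model on the new data (random toy data here)
model.compile(optimizer="adam",
              loss=keras.losses.BinaryCrossentropy(from_logits=True))
x_new = np.random.random((64, 8)).astype("float32")
y_new = np.random.randint(0, 2, size=(64, 1)).astype("float32")
model.fit(x_new, y_new, epochs=2, verbose=0)

# The frozen base is untouched; the head has been updated
base_unchanged = all(np.array_equal(a, b)
                     for a, b in zip(base_before, base_model.get_weights()))
head_changed = any(not np.array_equal(a, b)
                   for a, b in zip(head_before, model.layers[-1].get_weights()))
print(base_unchanged, head_changed)  # True True
```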

2. Setting up our data

```python
import tensorflow_datasets as tfds

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print(f'Number of training samples: {int(train_ds.cardinality())}')
print(f'Number of validation samples: {int(validation_ds.cardinality())}')
print(f'Number of test samples: {int(test_ds.cardinality())}')
```

Output:

```
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /root/tensorflow_datasets/cats_vs_dogs/4.0.0...
WARNING:absl:1738 images were corrupted and were skipped
Shuffling and writing examples to /root/tensorflow_datasets/cats_vs_dogs/4.0.0.incomplete1DQMSY/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /root/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326
```
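These counts follow from the percent slices. A small worked check, assuming the full split holds 25,000 images minus the 1,738 corrupted ones reported in the log (23,262 examples), and assuming tfds rounds each percent boundary to the closest example index. Note that the three slices only cover 60% of the data; the last 40% is not used:

```python
# Assumption: 25,000 images in the full "train" split, minus the
# 1,738 corrupted ones tfds reported skipping
total = 25_000 - 1_738

# Percent boundaries used above: [:40%], [40%:50%], [50%:60%]
boundaries = [0, 40, 50, 60]

# Convert each percent boundary to an example index, rounding to the
# closest integer (assumed to mirror tfds' default rounding)
cuts = [round(total * p / 100) for p in boundaries]
sizes = [end - start for start, end in zip(cuts, cuts[1:])]
print(sizes)  # [9305, 2326, 2326]
```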

These are the first 9 images in the training dataset. As you can see, they are all different sizes.

```python
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title('Cat' if int(label) == 0 else 'Dog')
    plt.axis('off')
```


Standardize Our Data

  • Standardize to a fixed image size. We pick 150x150.
  • Normalize pixel values to between -1 and 1. We'll do this using a Rescaling layer as part of the model itself.

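```python
size = (150, 150)

# Resize every image to 150x150; labels pass through unchanged
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))
```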
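One detail worth knowing: tf.image.resize returns float32 tensors but keeps the original 0 to 255 pixel scale, which is why the [-1, 1] rescaling still has to happen inside the model. A minimal sketch with a random stand-in image of arbitrary size:

```python
import tensorflow as tf

size = (150, 150)

# A stand-in uint8 image with an arbitrary non-square shape
image = tf.cast(
    tf.random.uniform((375, 500, 3), maxval=256, dtype=tf.int32), tf.uint8
)

resized = tf.image.resize(image, size)  # bilinear by default
print(resized.shape, resized.dtype.name)  # (150, 150, 3) float32
```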

We’ll batch the data and use caching & prefetching to optimize loading speed.

```python
batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)
```
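Batching also fixes the number of steps per epoch. With 9,305 training samples and batches of 32, the last partial batch is kept, so an epoch is ceil(9305 / 32) = 291 steps, which matches the 291/291 in the training logs later on. A small check on a toy pipeline of the same length; it uses tf.data.AUTOTUNE instead of the fixed buffer_size=10 above, a swapped-in choice that lets tf.data tune the prefetch buffer itself:

```python
import math
import tensorflow as tf

batch_size = 32
num_train = 9305  # training samples reported above

# Without drop_remainder, the last partial batch is kept: ceil(n / batch)
steps_per_epoch = math.ceil(num_train / batch_size)

# The same count from a toy tf.data pipeline of the same length
toy_ds = tf.data.Dataset.range(num_train).batch(batch_size).prefetch(tf.data.AUTOTUNE)
print(steps_per_epoch, int(toy_ds.cardinality()))  # 291 291
```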

Introduce some random data augmentation

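```python
from tensorflow import keras
from tensorflow.keras import layers

# Assumed augmentation layers: random horizontal flips and small random rotations
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
    ]
)
```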

Visualize our Data Augmentations

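```python
import numpy as np

# Apply the augmentation 9 times to the first image of one batch
for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        # Resized pixels are floats in [0, 255]; cast to int for imshow
        plt.imshow(augmented_image[0].numpy().astype('int32'))
        plt.axis('off')
```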




3. Building our model

Now let's build a model that follows the blueprint we explained earlier.

Note that:

  • We add a Rescaling layer to scale input values (initially in the [0, 255] range) to the [-1, 1] range.
  • We add a Dropout layer before the classification layer, for regularization.
  • We make sure to pass training=False when calling the base model, so that it runs in inference mode. This way its batchnorm statistics don't get updated, even after we unfreeze the base model for fine-tuning.
  • We’ll be using the Xception Model as our base.

```python
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,  # Do not include the ImageNet classifier at the top.
)

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights require that input be scaled
# from (0, 255) to a range of (-1., +1.); the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
```
Output:

```
Downloading data from
83689472/83683744 [==============================] - 3s 0us/step
83697664/83683744 [==============================] - 3s 0us/step
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_3 (InputLayer)        [(None, 150, 150, 3)]     0
 sequential_1 (Sequential)   (None, 150, 150, 3)       0
 rescaling (Rescaling)       (None, 150, 150, 3)       0
 xception (Functional)       (None, 5, 5, 2048)        20861480
 global_average_pooling2d (G  (None, 2048)             0
 lobalAveragePooling2D)
 dropout (Dropout)           (None, 2048)              0
 dense_3 (Dense)             (None, 1)                 2049
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________
```

Now let's train our top layer

Note from the summary above that we only have 2,049 trainable parameters.
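Those 2,049 parameters are exactly the Dense(1) head on the 2,048 pooled Xception features: 2,048 kernel weights plus 1 bias. A small check that builds the head on its own:

```python
from tensorflow import keras

# The classification head: a single-unit Dense layer on 2048 pooled features
head = keras.layers.Dense(1)
head.build((None, 2048))

# kernel: 2048 x 1 weights, bias: 1 weight -> 2048 + 1 = 2049
print(head.count_params())  # 2049

# Total = frozen Xception params + head params
print(20_861_480 + head.count_params())  # 20863529
```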






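```python
# Assumed compile settings: Adam with its default learning rate, binary
# cross-entropy on the raw logits of Dense(1), and a binary accuracy metric
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
```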
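Because the head is Dense(1) with no activation, the model outputs logits rather than probabilities. A minimal sketch of what from_logits=True means for the loss, and how to turn logits into Cat/Dog predictions; the three logits and labels are made-up values:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Made-up raw outputs (logits) from a Dense(1) head, plus labels
logits = tf.constant([[-2.0], [0.3], [4.0]])
y_true = tf.constant([[0.0], [1.0], [1.0]])  # 0 = Cat, 1 = Dog

# from_logits=True applies the sigmoid inside the loss
bce = keras.losses.BinaryCrossentropy(from_logits=True)
keras_loss = float(bce(y_true, logits))
manual_loss = float(tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=logits)))
print(np.isclose(keras_loss, manual_loss))  # True

# Thresholding the sigmoid at 0.5 is the same as thresholding the logit at 0
probs = tf.sigmoid(logits)
predicted = tf.cast(probs > 0.5, tf.int32)
print(predicted.numpy().ravel())  # [0 1 1]
```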

Output:

```
Epoch 1/20
291/291 [==============================] - 689s 2s/step - loss: 0.1793 - binary_accuracy: 0.9192 - val_loss: 0.0832 - val_binary_accuracy: 0.9678
Epoch 2/20
291/291 [==============================] - 660s 2s/step - loss: 0.1163 - binary_accuracy: 0.9515 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 3/20
291/291 [==============================] - 694s 2s/step - loss: 0.1073 - binary_accuracy: 0.9557 - val_loss: 0.0725 - val_binary_accuracy: 0.9721
Epoch 4/20
291/291 [==============================] - 703s 2s/step - loss: 0.1085 - binary_accuracy: 0.9556 - val_loss: 0.0731 - val_binary_accuracy: 0.9721
Epoch 5/20
291/291 [==============================] - 697s 2s/step - loss: 0.1018 - binary_accuracy: 0.9589 - val_loss: 0.0737 - val_binary_accuracy: 0.9703
Epoch 6/20
 26/291 [=>............................] - ETA: 8:32 - loss: 0.0861 - binary_accuracy: 0.9627
```

4. Fine Tuning

We unfreeze the base model and train the entire model end-to-end with a low learning rate.

Note that although the base model becomes trainable, it still runs in inference mode, since we passed training=False when calling it when we built the model.

This means that the batch normalization layers inside won't update their batch statistics. If they did, they would wreak havoc on the representations the model has learned so far.
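The difference between the two modes can be seen on a single BatchNormalization layer. A minimal sketch with a random batch: the layer stays trainable throughout, but its moving mean only changes when it is called with training=True:

```python
import numpy as np
from tensorflow import keras

bn = keras.layers.BatchNormalization()  # trainable=True by default
batch = np.random.normal(loc=5.0, scale=2.0, size=(64, 4)).astype("float32")

bn(batch, training=False)  # first call builds the layer
initial_mean = np.array(bn.moving_mean)  # starts at zeros

# Inference mode: the moving statistics are used, not updated
bn(batch, training=False)
unchanged_in_inference = np.array_equal(initial_mean, np.array(bn.moving_mean))

# Training mode: the moving statistics move toward the batch mean
bn(batch, training=True)
changed_in_training = not np.array_equal(initial_mean, np.array(bn.moving_mean))

print(unchanged_in_inference, changed_in_training)  # True True
```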

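```python
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

# Recompile so the change to `trainable` takes effect
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
```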

Output:

```
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_11 (InputLayer)       [(None, 150, 150, 3)]     0
 sequential_1 (Sequential)   (None, 150, 150, 3)       0
 rescaling_1 (Rescaling)     (None, 150, 150, 3)       0
 xception (Functional)       (None, 5, 5, 2048)        20861480
 global_average_pooling2d_1  (None, 2048)              0
  (GlobalAveragePooling2D)
 dropout_1 (Dropout)         (None, 2048)              0
 dense_4 (Dense)             (None, 1)                 2049
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 2885s 10s/step - loss: 0.1086 - binary_accuracy: 0.9552 - val_loss: 0.0574 - val_binary_accuracy: 0.9755
Epoch 2/10
291/291 [==============================] - 2847s 10s/step - loss: 0.0644 - binary_accuracy: 0.9747 - val_loss: 0.0499 - val_binary_accuracy: 0.9798
Epoch 3/10
291/291 [==============================] - 2849s 10s/step - loss: 0.0525 - binary_accuracy: 0.9799 - val_loss: 0.0440 - val_binary_accuracy: 0.9815
Epoch 4/10
227/291 [======================>.......] - ETA: 9:51 - loss: 0.0396 - binary_accuracy: 0.9851
```
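Even with the base model unfrozen, 54,528 parameters stay non-trainable: these are the moving means and variances of the BatchNormalization layers inside Xception. A sketch that checks this without downloading anything, assuming that building the same architecture with weights=None gives the same parameter layout as the ImageNet version:

```python
import numpy as np
from tensorflow import keras

# weights=None builds the same architecture without downloading ImageNet weights
base = keras.applications.Xception(
    weights=None, include_top=False, input_shape=(150, 150, 3)
)
base.trainable = True  # unfrozen, as in the fine-tuning step

def num_params(variables):
    # Total number of scalars across a collection of variables
    return sum(int(np.prod(tuple(v.shape))) for v in variables)

non_trainable = num_params(base.non_trainable_weights)
bn_stats = num_params(
    v
    for layer in base.layers
    if isinstance(layer, keras.layers.BatchNormalization)
    for v in (layer.moving_mean, layer.moving_variance)
)
print(non_trainable, bn_stats)  # 54528 54528
```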

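Finally, after fine-tuning, the training binary_accuracy reaches 0.9851 partway through epoch 4, and the validation binary_accuracy is 0.9815 after epoch 3.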

See you! These are my learning notes.

Follow along with me step by step!

Thanks for reading!



