Image Segmentation

As a Machine Learning practitioner, you have probably worked on image classification, where the goal is to assign a single label or class to the whole input image. Now suppose you want to know where the object is in the image, what shape it has, or which pixels belong to which object. In that case you need to segment the image, which means assigning a label to every pixel. The goal of Image Segmentation is to train a neural network that returns a pixel-wise mask of the image.
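To make the idea of a pixel-wise mask concrete, here is a small NumPy sketch. The 5x6 mask is a hypothetical toy example (0 = background, 1 = object). It shows how a mask, unlike a single class label, directly tells you where the object is (its bounding box) and how large its shape is (its area in pixels).

```python
import numpy as np

# Hypothetical toy mask: 0 = background, 1 = object (values chosen for illustration only).
mask = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0, 0],
    [0, 1, 1, 1, 1, 0],
    [0, 0, 1, 1, 0, 0],
    [0, 0, 0, 0, 0, 0],
])

# Where is the object? Rows and columns that contain at least one object pixel.
rows = np.where(mask.any(axis=1))[0]
cols = np.where(mask.any(axis=0))[0]
bbox = tuple(int(v) for v in (rows.min(), cols.min(), rows.max(), cols.max()))  # (top, left, bottom, right)

# How big is it? The number of object pixels, i.e. its area.
area = int(mask.sum())

print("bounding box:", bbox)  # bounding box: (1, 1, 3, 4)
print("area:", area)          # area: 8
```

A classifier would only report "object present"; the mask additionally encodes location and shape, which is exactly the extra information segmentation provides.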

In the real world, Image Segmentation is used in many applications, including medical science, self-driving cars and satellite imaging. Because it assigns a label to every pixel, Image Segmentation analyzes the image at the finest possible level of detail.

In this article, I will take you through Image Segmentation with Deep Learning. I will start by importing the libraries that we need:

```python
import tensorflow as tf
# pix2pix comes from the TensorFlow Examples package:
# pip install git+https://github.com/tensorflow/examples.git
from tensorflow_examples.models.pix2pix import pix2pix

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

from IPython.display import clear_output
import matplotlib.pyplot as plt
```

I will use the Oxford-IIIT Pet dataset, which is available through TensorFlow Datasets (tfds):

```python
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
```

The code below prepares the images for Image Segmentation: it resizes every image and mask to 128x128, scales the pixel values to [0, 1], shifts the mask labels to start at 0, and applies a simple augmentation (a random horizontal flip) to the training set:

```python
def normalize(input_image, input_mask):
    # Scale pixel values to [0, 1]
    input_image = tf.cast(input_image, tf.float32) / 255.0
    # The mask labels are {1, 2, 3}; shift them to {0, 1, 2}
    input_mask -= 1
    return input_image, input_mask

@tf.function
def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    # Nearest-neighbor resizing keeps every mask value a valid class id
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128),
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Random horizontal flip, applied to the image and its mask together
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128),
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
```
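The same preprocessing steps can be sketched in plain NumPy to see what each one does. This is an illustrative mirror of `normalize` and the random flip, not the TensorFlow pipeline itself; `random_flip`, `bilinear_like` and `nearest` are hypothetical names introduced here. The last part shows why masks should be resized with nearest-neighbor interpolation: blending two class ids produces a value that is not a class at all.

```python
import numpy as np

def normalize(image, mask):
    """Scale pixels to [0, 1] and shift mask labels {1, 2, 3} down to {0, 1, 2}."""
    return image.astype(np.float32) / 255.0, mask - 1

def random_flip(image, mask, rng):
    """With probability 0.5, flip image and mask left-right together."""
    if rng.uniform() > 0.5:
        # Axis 1 is the width axis for (H, W, C) arrays.
        image, mask = image[:, ::-1], mask[:, ::-1]
    return image, mask

# Interpolating between two neighboring class ids invents a non-existent class.
labels = np.array([0.0, 2.0])        # two neighboring pixels: class 0 and class 2
positions = np.array([0.0, 1.0])
bilinear_like = float(np.interp(0.25, positions, labels))  # 0.5 -> not a valid class id
nearest = float(labels[int(round(0.25))])                  # 0.0 -> an existing class id

# Example usage on a tiny random "datapoint".
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)
mask = rng.integers(1, 4, size=(4, 4, 1), dtype=np.uint8)
image, mask = normalize(image, mask)
image, mask = random_flip(image, mask, rng)
print(image.min() >= 0.0, image.max() <= 1.0, np.unique(mask))
```

Flipping the image and mask in the same branch is the important design choice: if only one of them were flipped, every label would point at the wrong pixel.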

The dataset already comes with separate training and test splits, so I will use those splits as they are:

```python
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test = dataset['test'].map(load_image_test)

# Note: cache() comes after the random-flip map, so each image's flip is
# computed once and then reused in every epoch.
train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)
```
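Because `.repeat()` makes `train_dataset` endless, Keras cannot tell on its own where an epoch ends, which is why `STEPS_PER_EPOCH` is passed to `fit` during training. The arithmetic below is a small sketch of what that value works out to, assuming the 3,680-image train split listed in the TFDS catalog for `oxford_iiit_pet`; `leftover` is a hypothetical name used only for illustration.

```python
# Assumption: the TFDS catalog lists 3,680 training images for oxford_iiit_pet.
TRAIN_LENGTH = 3680
BATCH_SIZE = 64

STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE  # full batches per "epoch"
leftover = TRAIN_LENGTH % BATCH_SIZE          # images beyond those full batches

print(STEPS_PER_EPOCH, leftover)  # 57 32
```

Since the stream repeats, the 32 leftover images are not discarded; they simply show up in a later training step.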

Now let’s have a quick look at an image and its mask from the data:

```python
def display(display_list):
    plt.figure(figsize=(15, 15))

    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

for image, mask in train.take(1):
    sample_image, sample_mask = image, mask
display([sample_image, sample_mask])
```

Define the Image Segmentation Model

The model I will use here is a modified U-Net. A U-Net consists of an encoder (downsampler) and a decoder (upsampler). To learn robust features while reducing the number of trainable parameters, a pretrained model can be used as the encoder.


As mentioned above, the encoder is a pretrained MobileNetV2 model, which is available and ready to use in tf.keras.applications. The encoder exposes the outputs of specific intermediate layers, which the decoder later uses as skip connections. Note that the encoder's weights are frozen, so they will not be updated during training.

```python
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
layers = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

down_stack.trainable = False
```

The decoder/upsampler is simply a series of upsample blocks from the pix2pix module in TensorFlow Examples:

```python
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

def unet_model(output_channels):
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])
    x = inputs

    # Downsampling through the model
    skips = down_stack(x)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    # This is the last layer of the model
    last = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2,
        padding='same')  # 64x64 -> 128x128

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)
```
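To see how the skip connections line up, here is a pure-Python sketch that traces the tensor shapes through `unet_model` for a 128x128 input. The channel counts of the encoder outputs are assumptions based on the standard MobileNetV2 layer widths, and `trace_unet_shapes` is a hypothetical helper, not part of TensorFlow. Each upsample block doubles height and width (stride 2), and `Concatenate` stacks the channels of the upsampled map and the matching skip.

```python
# Encoder outputs for a 128x128 input as (height, width, channels).
# Assumption: channel counts follow the standard MobileNetV2 layer widths.
skips = [
    (64, 64, 96),   # block_1_expand_relu
    (32, 32, 144),  # block_3_expand_relu
    (16, 16, 192),  # block_6_expand_relu
    (8, 8, 576),    # block_13_expand_relu
    (4, 4, 320),    # block_16_project (bottleneck)
]
up_filters = [512, 256, 128, 64]  # filters of the four pix2pix.upsample blocks
OUTPUT_CHANNELS = 3

def trace_unet_shapes(skips, up_filters, output_channels):
    """Hypothetical helper: follow (H, W, C) through the decoder."""
    h, w, c = skips[-1]                       # start from the bottleneck
    trace = [("bottleneck", (h, w, c))]
    for filters, (sh, sw, sc) in zip(up_filters, reversed(skips[:-1])):
        h, w = h * 2, w * 2                   # upsample block: stride 2 doubles H and W
        assert (h, w) == (sh, sw), "skip and upsampled maps must match spatially"
        c = filters + sc                      # Concatenate stacks the channels
        trace.append(("up+concat", (h, w, c)))
    h, w = h * 2, w * 2                       # final Conv2DTranspose, stride 2
    trace.append(("logits", (h, w, output_channels)))
    return trace

for name, shape in trace_unet_shapes(skips, up_filters, OUTPUT_CHANNELS):
    print(name, shape)
```

The spatial check inside the loop is the key constraint of a U-Net: a skip can only be concatenated with a decoder map of the same height and width.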

Training the Image Segmentation Model

```python
OUTPUT_CHANNELS = 3  # three classes: pet, background, border

model = unet_model(OUTPUT_CHANNELS)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
```
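`SparseCategoricalCrossentropy(from_logits=True)` treats every pixel as a small classification problem: the model outputs one raw score per class per pixel, and the label is an integer class id. The NumPy sketch below computes the mean per-pixel loss and pixel accuracy by hand. It is an illustration under the assumption of no sample weights and labels without the trailing channel axis (the TensorFlow pipeline keeps a `(H, W, 1)` mask, which Keras handles); `sparse_ce_from_logits` and `pixel_accuracy` are hypothetical names.

```python
import numpy as np

def sparse_ce_from_logits(labels, logits):
    """Mean per-pixel cross-entropy.

    labels: integer class ids, shape (batch, H, W)
    logits: raw scores, shape (batch, H, W, num_classes)
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the true class at every pixel.
    picked = np.take_along_axis(log_probs, labels[..., np.newaxis], axis=-1)[..., 0]
    return float(-picked.mean())

def pixel_accuracy(labels, logits):
    """Fraction of pixels whose highest-scoring class equals the label."""
    return float((logits.argmax(axis=-1) == labels).mean())

# With all-zero logits every class gets probability 1/3,
# so the loss is log(3) whatever the labels are.
labels = np.array([[[0, 1], [2, 0]]])  # shape (1, 2, 2)
logits = np.zeros((1, 2, 2, 3))
print(sparse_ce_from_logits(labels, logits))  # about 1.0986, i.e. log(3)
print(pixel_accuracy(labels, logits))         # 0.5
```

Passing `from_logits=True` lets the loss apply the softmax internally in a numerically stable way, which is why the model's last layer has no activation.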

Before training, let’s take a quick look at the model architecture (plot_model requires the pydot and graphviz packages):

```python
tf.keras.utils.plot_model(model, show_shapes=True)
```

Let’s try out the model to see what it predicts before training:

```python
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])

show_predictions()
```
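`create_mask` turns the model's per-pixel class scores into a label mask. The NumPy mirror below (a hypothetical `create_mask_np`, written only for illustration) follows the same three steps on a tiny batch: argmax over the class axis, add back a channel axis, and keep the first image. Taking the argmax of raw logits gives the same result as taking it over softmax probabilities, because softmax preserves the ordering of the scores within each pixel.

```python
import numpy as np

def create_mask_np(pred_logits):
    """NumPy mirror of create_mask: (batch, H, W, C) logits -> (H, W, 1) mask."""
    pred = pred_logits.argmax(axis=-1)  # most likely class per pixel: (batch, H, W)
    pred = pred[..., np.newaxis]        # restore a channel axis: (batch, H, W, 1)
    return pred[0]                      # keep only the first image of the batch

logits = np.zeros((1, 2, 2, 3))
logits[0, 0, 0, 2] = 5.0               # make class 2 win at the top-left pixel
mask = create_mask_np(logits)
print(mask.shape, mask[..., 0].tolist())  # (2, 2, 1) [[2, 0], [0, 0]]
```

Where all scores tie (the zero logits here), argmax returns the first class, so those pixels become class 0.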

Now, let’s observe how the Image Segmentation model improves while it trains. To do this, the callback defined below shows a prediction on the sample image at the end of every epoch:

```python
class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        show_predictions()
        print('\nSample Prediction after epoch {}\n'.format(epoch+1))

EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])
```
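`VAL_SUBSPLITS` controls how much of the test set is used for validation after each epoch. As a rough sketch, assuming the 3,669-image test split listed in the TFDS catalog for `oxford_iiit_pet`, the arithmetic works out as follows (`full_test_batches` and `images_per_validation` are hypothetical names for illustration):

```python
# Assumption: the TFDS catalog lists 3,669 test images for oxford_iiit_pet.
TEST_LENGTH = 3669
BATCH_SIZE = 64
VAL_SUBSPLITS = 5

full_test_batches = TEST_LENGTH // BATCH_SIZE          # full batches in the test set
VALIDATION_STEPS = full_test_batches // VAL_SUBSPLITS  # batches per validation run
images_per_validation = VALIDATION_STEPS * BATCH_SIZE  # images per validation run

print(full_test_batches, VALIDATION_STEPS, images_per_validation)  # 57 11 704
```

Validating on roughly a fifth of the test set each epoch keeps training fast, at the cost of a noisier validation loss than a full pass would give.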

Now, let’s look at how the training and validation loss evolved over the epochs:

```python
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

epochs = range(EPOCHS)

plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()
```

Make Predictions using Image Segmentation

Let’s make some predictions on the test set. To save time, the number of epochs was kept small, but you can set it higher to get more accurate results:


```python
show_predictions(test_dataset, 3)
```
[Figures: Input Image, True Mask and Predicted Mask for three test images]


I hope you liked this article on Image Segmentation with Deep Learning. Feel free to ask your valuable questions in the comments section below.

Aman Kharwal

I am a programmer from India, and I am here to guide you with Data Science, Machine Learning, Python, and C++ for free. I hope you will learn a lot in your journey towards Coding, Machine Learning and Artificial Intelligence with me.
