PyTorch for Deep Learning

PyTorch is a Python library that provides tools for building deep learning models. Just as Python is a flexible general-purpose programming language, PyTorch offers flexible building blocks for deep learning. If you are learning deep learning or planning to start, knowing PyTorch will help you a great deal in building your own models.

It has proven to be one of the most flexible and capable tools for real-world problems while still delivering strong performance. PyTorch's core data structure is the tensor, a multidimensional array that behaves much like a NumPy array.
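
As a small illustration of the similarity between tensors and NumPy arrays, the sketch below builds a 2x3 tensor, applies NumPy-style indexing and arithmetic, and converts it to a NumPy array and back. The variable names and values are made up for this example.

```python
import numpy as np
import torch

# A 2x3 tensor, created much like a NumPy array
t = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# NumPy-style shape, slicing, and elementwise arithmetic
shape = tuple(t.shape)        # (2, 3)
first_column = t[:, 0]        # first element of every row
doubled = t * 2

# Convert to NumPy and back; for CPU tensors both share the same memory
arr = t.numpy()
back = torch.from_numpy(arr)

print(shape, first_column.tolist(), doubled.sum().item())  # (2, 3) [1.0, 4.0] 42.0
print(type(arr).__name__, torch.equal(t, back))            # ndarray True
```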

Why PyTorch for Deep Learning?

Deep learning lets us tackle a very wide range of complicated tasks, such as machine translation, playing strategy games, object detection, and many more. With PyTorch, you can approach these tasks in a flexible way. Now let’s understand PyTorch better by working through a real-world example.

Training a Classifier with PyTorch

I will follow these steps to build an image classifier with PyTorch:

  1. Load and normalize the CIFAR10 training and test datasets using torchvision
  2. Define a Convolutional Neural Network
  3. Define a loss function and an optimizer
  4. Train the network on the training data
  5. Test the network on the test data

Loading and Normalizing CIFAR10

Using torchvision, it’s very easy to load CIFAR10:

import torch
import torchvision
import torchvision.transforms as transforms

The outputs of torchvision datasets are PILImage images with values in the range [0, 1]. We transform them to Tensors normalized to the range [-1, 1].

# convert each image to a tensor, then normalize every channel to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
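
To see why a mean and standard deviation of 0.5 map the range [0, 1] onto [-1, 1], here is a minimal sketch of the per-channel arithmetic, (x - mean) / std, applied to a few made-up pixel values. It only reproduces the formula by hand; it does not call torchvision.

```python
import torch

# Pixel values as a tensor in [0, 1] (made-up sample values)
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

mean, std = 0.5, 0.5
normalized = (pixels - mean) / std      # the formula applied per channel
restored = normalized * std + mean      # the inverse, i.e. x / 2 + 0.5

print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]
print(restored.tolist())    # [0.0, 0.25, 0.5, 1.0]
```

The inverse step is the same "unnormalize" operation needed before displaying an image.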

Now let’s have a look at some of our training images:

import matplotlib.pyplot as plt
import numpy as np

# function to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize from [-1, 1] back to [0, 1]
    npimg = img.numpy()
    # reorder from (channels, height, width) to (height, width, channels)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
horse plane plane cat

Now, let’s define a Convolutional Neural Network using PyTorch:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
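
The input size of the first fully connected layer, 16 * 5 * 5, follows from how each layer shrinks a 32x32 CIFAR10 image. The sketch below traces that arithmetic with a small helper; conv_out is a hypothetical function written for this illustration, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32                                  # CIFAR10 images are 32x32
size = conv_out(size, kernel=5)            # conv1: 32 -> 28
size = conv_out(size, kernel=2, stride=2)  # pool:  28 -> 14
size = conv_out(size, kernel=5)            # conv2: 14 -> 10
size = conv_out(size, kernel=2, stride=2)  # pool:  10 -> 5

flat_features = 16 * size * size           # 16 channels come out of conv2
print(size, flat_features)  # 5 400
```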

Now I will define the loss function and the optimizer, using classification cross-entropy loss and SGD with momentum:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
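
For intuition about what the cross-entropy criterion computes, the following sketch compares nn.CrossEntropyLoss on one made-up sample with the same value worked out by hand: the negative log of the softmax probability assigned to the true class. The scores and the target index are invented for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])   # raw scores: one sample, three classes
target = torch.tensor([0])                  # index of the true class

loss = nn.CrossEntropyLoss()(logits, target)

# Same quantity by hand: -log(softmax(logits)[true class])
manual = -F.log_softmax(logits, dim=1)[0, target.item()]

print(loss.item(), manual.item())  # the two values agree
```

Note that the criterion takes raw scores, which is why the network's last layer has no softmax of its own.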

Now, let’s train the neural network. Here I will simply loop over our data iterator, feed the inputs to the network, and optimize its parameters:

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
[1,  2000] loss: 2.137
[1,  4000] loss: 1.810
[1,  6000] loss: 1.657
[1,  8000] loss: 1.569
[1, 10000] loss: 1.493
[1, 12000] loss: 1.451
[2,  2000] loss: 1.372
[2,  4000] loss: 1.354
[2,  6000] loss: 1.339
[2,  8000] loss: 1.334
[2, 10000] loss: 1.289
[2, 12000] loss: 1.280
Finished Training
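
The "zero the parameter gradients" step in the loop matters because PyTorch adds new gradients to whatever is already stored in each parameter's .grad. The sketch below demonstrates this on a single made-up tensor w rather than on the network itself.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

(w ** 2).sum().backward()
first = w.grad.clone()          # d/dw sum(w^2) = 2w

(w ** 2).sum().backward()       # second backward pass without zeroing
accumulated = w.grad.clone()    # gradients have been added together

w.grad.zero_()                  # manual reset; optimizer.zero_grad() plays this role in training
(w ** 2).sum().backward()
after_zero = w.grad.clone()

print(first.tolist(), accumulated.tolist(), after_zero.tolist())
# [2.0, 4.0] [4.0, 8.0] [2.0, 4.0]
```

Without the reset, each mini-batch would update the weights using the sum of all gradients seen so far instead of only its own.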

Now, before moving forward let’s quickly save our model:

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

Now, let’s test our trained neural network on the test data. We need to check whether the network has learnt anything at all. We will do this by predicting the class label for each image and comparing it against the ground truth. If the prediction is correct, we count the sample as a correct prediction.

dataiter = iter(testloader)
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
GroundTruth: cat ship ship plane

Now let’s load our saved model:

net = Net()
net.load_state_dict(torch.load(PATH))
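
Because load_state_dict expects a dictionary of parameter tensors, the file must contain the model's state_dict, and the model class has to be instantiated before loading. Here is a minimal round-trip sketch; it assumes a small nn.Linear layer as a stand-in for Net and an in-memory buffer in place of the file on disk.

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                 # stand-in for Net, to keep the sketch small

buffer = io.BytesIO()                   # in-memory file instead of a path on disk
torch.save(model.state_dict(), buffer)
buffer.seek(0)                          # rewind before reading

restored = nn.Linear(4, 2)              # same architecture, fresh random weights
restored.load_state_dict(torch.load(buffer))

print(list(model.state_dict().keys()))              # ['weight', 'bias']
print(torch.equal(model.weight, restored.weight))   # True
```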

Okay, now let us see what our trained neural network thinks these examples above are:

outputs = net(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))
Predicted: cat ship ship ship
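
The call torch.max(outputs, 1) returns two tensors: the highest score in each row and the index where it occurs. The index is the predicted class, which is why the first return value is discarded. A small sketch with made-up scores for two images and three classes:

```python
import torch

# Made-up scores: 2 images (rows) x 3 classes (columns)
outputs = torch.tensor([[0.5, 2.5, -0.25],
                        [1.75, 0.25, 1.0]])

values, predicted = torch.max(outputs, 1)   # maximum along dim 1, the class dimension

print(values.tolist())     # [2.5, 1.75]
print(predicted.tolist())  # [1, 0]
```

When only the indices are needed, outputs.argmax(dim=1) gives the same result.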

Now, let’s have a look at the accuracy of our trained neural network:

correct = 0
total = 0
# no gradients are needed for evaluation
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        # the class with the highest score is the prediction
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 54 %

Now, let’s look more closely at this accuracy. I want to see which classes the network predicts well and which it does not, so let’s compute the accuracy for each class separately:


class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 51 % 
Accuracy of car : 58 %
Accuracy of bird : 44 %
Accuracy of cat : 36 %
Accuracy of deer : 45 %
Accuracy of dog : 47 %
Accuracy of frog : 58 %
Accuracy of horse : 55 %
Accuracy of ship : 88 %
Accuracy of truck : 57 %
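
The per-class bookkeeping can also be expressed with boolean masks. The sketch below uses a shortened, hypothetical class list and six made-up labels and predictions to show how overall and per-class accuracy relate; it is an alternative formulation, not the loop used above.

```python
import torch

classes = ('cat', 'dog', 'ship')                 # shortened, hypothetical class list
labels = torch.tensor([0, 0, 1, 2, 2, 2])        # made-up ground truth
predicted = torch.tensor([0, 1, 1, 2, 2, 0])     # made-up predictions

correct = (predicted == labels)
overall = correct.sum().item() / labels.size(0)  # 4 of 6 correct

per_class = {}
for idx, name in enumerate(classes):
    mask = labels == idx                         # samples whose true class is idx
    per_class[name] = correct[mask].sum().item() / mask.sum().item()

print(overall, per_class)
```

Here 'cat' scores 1 of 2, 'dog' 1 of 1, and 'ship' 2 of 3, so a single overall figure can hide large differences between classes.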


I hope you liked this article on PyTorch for deep learning. Feel free to ask your questions in the comments section.

Aman Kharwal