End-To-End PyTorch Example of Image Classification with Convolutional Neural Networks

Image classification in PyTorch with popular models such as ResNet and its variants: an end-to-end solution for the CIFAR-10/100 and ImageNet datasets.
Apr 13, 2021


  Image classification remains one of the most widely studied problems in Computer Vision. For decades, researchers have tackled it with many different algorithms and approaches. Convolutional Neural Networks gave the field a significant boost, most notably through the well-known AlexNet model. AlexNet is relatively small and shallow by today's standards, and in recent years much deeper and more complex models have outperformed these older solutions.

Benchmark Datasets and Competitions

  Several competitions are held every year in which participants try to outperform each other on this problem. One of the best known is the ImageNet Large Scale Visual Recognition Challenge (ILSVRC), whose dataset remains one of the main benchmarks for the community.


  The dataset spans 1000 object classes, such as airplanes, cars, animals, and indoor objects, and contains 1,281,167 training images, 50,000 validation images, and 100,000 test images. It is available for download through a Kaggle competition. Obtaining the full set directly requires a separate access request, which is subject to terms of use.

CIFAR 10/100

  CIFAR-10 and CIFAR-100 are labeled subsets of the well-known 80 Million Tiny Images dataset. Their images have a resolution of 32x32, and they contain 10 and 100 object classes, respectively. CIFAR-10 consists of 60,000 color images in 10 classes, with 6,000 images per class, split into 50,000 training images and 10,000 test images. CIFAR-100 has the same total size and split but only 600 images per class (500 for training and 100 for testing), and its classes are grouped into 20 superclasses. CIFAR-10 covers vehicles and animals, while CIFAR-100 adds categories such as fish, flowers, people, and trees. The official website provides Python and MATLAB versions of both datasets, along with a description of the data layout and a short Python routine for loading it.
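The following sketch shows what loading the Python version involves. It assumes the "CIFAR-10 python version" archive has already been extracted. Each batch file is a pickled dictionary. Its `b'data'` entry is an N x 3072 uint8 array: 1024 red values, then 1024 green, then 1024 blue, with each channel stored row by row. Its `b'labels'` entry holds the class indices. CIFAR-100 files store labels under `b'fine_labels'` and `b'coarse_labels'` instead. The helper names `unpickle` and `to_hwc_images` and the example path are illustrative.

```python
import pickle

import numpy as np


def unpickle(path):
    """Load one CIFAR batch file (written by Python 2, hence encoding='bytes')."""
    with open(path, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')


def to_hwc_images(data):
    """Reshape an (N, 3072) uint8 array into (N, 32, 32, 3) images.

    Each row holds the red plane, then green, then blue, each 32x32 in
    row-major order, so we split into (N, 3, 32, 32) first and then move
    the channel axis to the end.
    """
    return np.asarray(data, dtype=np.uint8).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)


# Example usage (the path is an assumption about where the archive was extracted):
# batch = unpickle('cifar-10-batches-py/data_batch_1')
# images = to_hwc_images(batch[b'data'])   # shape (10000, 32, 32, 3)
# labels = batch[b'labels']                # 10000 ints in [0, 9]
```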

Residual Networks - ResNets

  Deep Residual Networks (ResNets), proposed in 2015 by Microsoft Research, are among the most widely used models in image classification, semantic and instance segmentation, object detection, and many other Computer Vision problems. Many solutions to these problems build on some version of ResNet, and they owe their improved results both to the architecture itself and to transfer learning from pretrained weights. Residual Networks are very deep networks with shortcut (skip) connections that add a block's input to its output. These connections mitigate the vanishing gradient problem, which makes it possible to train networks deep enough to extract high-level semantic information. The best-known variants are ResNet18, ResNet34, ResNet50, ResNet101, and ResNet152. Their structures differ in depth and block type. ResNet18 and ResNet34 stack basic blocks of two 3x3 convolutions. ResNet50, ResNet101, and ResNet152 use bottleneck blocks of 1x1, 3x3, and 1x1 convolutions.
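The block counts behind these names fit in a few lines. Each ResNet has four stages of residual blocks, called conv2_x to conv5_x in the original paper. Counting the stem convolution, the convolutions inside the blocks, and the final fully connected layer gives exactly the depth in each model's name. The dictionary below reproduces the stage configuration from the ResNet paper, which torchvision also uses. The names `RESNET_STAGES` and `depth` are illustrative.

```python
# Blocks per stage (conv2_x .. conv5_x) and block type for each ResNet variant.
# A basic block has two 3x3 convolutions; a bottleneck block has three
# convolutions (1x1, 3x3, 1x1).
RESNET_STAGES = {
    'resnet18': ('basic', [2, 2, 2, 2]),
    'resnet34': ('basic', [3, 4, 6, 3]),
    'resnet50': ('bottleneck', [3, 4, 6, 3]),
    'resnet101': ('bottleneck', [3, 4, 23, 3]),
    'resnet152': ('bottleneck', [3, 8, 36, 3]),
}

CONVS_PER_BLOCK = {'basic': 2, 'bottleneck': 3}


def depth(name):
    """Count weighted layers: stem conv + block convolutions + final FC layer.

    Projection shortcuts (1x1 convs on the skip path) are not counted,
    following the convention behind the model names.
    """
    block, stages = RESNET_STAGES[name]
    return 1 + CONVS_PER_BLOCK[block] * sum(stages) + 1


for name in RESNET_STAGES:
    print(name, depth(name))  # e.g. resnet50 50
```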

PyTorch Implementation of ResNets

  Most deep learning frameworks ship official ResNet implementations. In PyTorch, they are part of the Torchvision package, and you can use them as follows:

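import torchvision.models as models

# resnet18, resnet34, resnet50, resnet101, resnet152
# torchvision >= 0.13 selects pretrained weights with `weights=`;
# older versions use models.resnet50(pretrained=True) instead
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)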

End-To-End Image Classification Example

  First, import all the packages needed for training and validation.

import torch
import argparse
import torch.optim
import torch.nn as nn
import torch.utils.data
import torchvision.models as models
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

To make the script general enough for different models and hyperparameters, define an argument parser that lets you pass parameters from the command line.

parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', default='cifar10',
                    choices=['cifar10', 'cifar100', 'imagenet'],
                    help='dataset to train on')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--num_classes', default=10, type=int,
                    help='Number of classes for the network '
                         '(10 for CIFAR-10, 100 for CIFAR-100, 1000 for ImageNet).')

args = parser.parse_args()
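To check how the flags map to attributes without launching a run, you can pass `parse_args` an explicit list of strings. The sketch below rebuilds a reduced copy of the parser with only a few options, for illustration. It shows that `--batch-size` becomes `args.batch_size` (argparse converts dashes to underscores) and that `dest='lr'` stores `--learning-rate` under `args.lr`. The helper name `build_parser` is illustrative.

```python
import argparse


def build_parser():
    # Reduced copy of the script's parser, for illustration only.
    parser = argparse.ArgumentParser(description='PyTorch Example')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('-b', '--batch-size', default=128, type=int,
                        metavar='N', help='mini-batch size')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
    return parser


args = build_parser().parse_args(['./data', '-b', '64', '--learning-rate', '0.05'])
print(args.data, args.batch_size, args.lr, args.gpu)  # ./data 64 0.05 None
```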

Then it's time to create the chosen network and define the loss function and the optimizer.

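model = models.resnet50(num_classes=args.num_classes)

if args.gpu is not None:
    model = model.cuda(args.gpu)
    cudnn.benchmark = True  # let cuDNN pick the fastest convolution algorithms

# define loss function (criterion) and optimizer;
# CrossEntropyLoss has no parameters, so it does not need to be moved to the GPU
criterion = nn.CrossEntropyLoss()

# momentum=0.9 and weight_decay=1e-4 are assumed, commonly used ResNet settings
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=0.9, weight_decay=1e-4)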
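`nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) and integer class indices, and it applies log-softmax internally. That is why torchvision's ResNet ends with a plain linear layer rather than a softmax. The small check below confirms the equivalence on random logits. The batch size and label values are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # batch of 4, 10 classes (e.g. CIFAR-10)
labels = torch.tensor([3, 0, 9, 1])    # class indices, not one-hot vectors

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# The same value in two explicit steps: log-softmax, then negative log-likelihood.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(loss.item(), manual.item())
```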

Since we introduced three different datasets above, each one needs its own dataset class for the data loader. The datasets are stored in different structures, and Torchvision already implements a class for each of them.

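if args.dataset in ('cifar10', 'cifar100'):
    # CIFAR-10 and CIFAR-100 share the same constructor signature
    dataset_cls = datasets.CIFAR10 if args.dataset == 'cifar10' else datasets.CIFAR100
    normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = dataset_cls(args.data, train=True, download=True, transform=train_transform)
    test_dataset = dataset_cls(args.data, train=False, download=True, transform=test_transform)
elif args.dataset == 'imagenet':
    # torchvision does not download ImageNet: the ILSVRC2012 archives
    # must already be placed in args.data
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),  # images vary in size, so crop to 224x224
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.ImageNet(args.data, split='train', transform=train_transform)
    test_dataset = datasets.ImageNet(args.data, split='val', transform=test_transform)

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=4)

test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size,
        shuffle=False, num_workers=4)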

The validation set is built the same way. For the CIFAR dataset classes you pass train=False (for ImageNet, split='val'), and you leave out random augmentations such as RandomCrop.

*Remember, if you use a low-resolution dataset such as CIFAR-10 or CIFAR-100, you need to modify the stem (the first layers) of the ResNet module and replace its first convolution layer. A 7x7 convolution with stride 2 is a poor start for small images, because it halves their 32x32 resolution in the very first layer.

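# first convolution in torchvision's ResNet.__init__
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
# replace with a 3x3, stride-1 convolution that keeps the 32x32 resolution
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=(3, 3), padding=(1, 1), bias=False)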

To cover the training and testing phases of each epoch, we define a separate function for each.

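def train(epoch):
    model.train()  # training mode: batch norm uses per-batch statistics
    train_loss = 0
    correct = 0
    total = 0
    for images, labels in train_loader:
        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            labels = labels.cuda(args.gpu, non_blocking=True)

        outputs = model(images)
        loss = criterion(outputs, labels)

        # clear old gradients, backpropagate, and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    # report the average loss per batch and the accuracy over the epoch
    print('Epoch {} | Train Loss: {:.3f} | Acc: {:.2f}%'.format(
        epoch, train_loss / len(train_loader), 100. * correct / total))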

def test(epoch):
    model.eval()  # evaluation mode: batch norm uses its running statistics
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                labels = labels.cuda(args.gpu, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    acc = 100. * correct / total
    print('Epoch {} | Test Loss: {:.3f} | Acc: {:.2f}%'.format(
        epoch, test_loss / len(test_loader), acc))
    return acc
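You can check the core of both functions on a toy problem. The sketch below trains a linear classifier on synthetic 2-D points, purely for illustration. Each training step clears the old gradients with `optimizer.zero_grad()`, backpropagates with `loss.backward()`, and updates the weights with `optimizer.step()`. Evaluation then switches to `model.eval()` under `torch.no_grad()` and counts top-1 hits with `outputs.max(1)`. The data, model, and number of steps are assumptions chosen only to keep the example fast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Synthetic, linearly separable data: the label is 1 when x0 + x1 > 0.
x = torch.randn(512, 2)
y = (x.sum(dim=1) > 0).long()

model = nn.Linear(2, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Training: the same three calls as in a training loop, on the full toy batch.
model.train()
losses = []
for step in range(100):
    outputs = model(x)
    loss = criterion(outputs, y)
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # compute gradients of the loss w.r.t. the weights
    optimizer.step()       # apply the SGD-with-momentum update
    losses.append(loss.item())

# Evaluation: no gradient tracking, predictions taken with outputs.max(1).
model.eval()
with torch.no_grad():
    _, predicted = model(x).max(1)
    accuracy = 100. * predicted.eq(y).sum().item() / y.size(0)

print('loss {:.3f} -> {:.3f}, accuracy {:.1f}%'.format(losses[0], losses[-1], accuracy))
```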

The image classification script is almost complete. We just need to combine these functions and iterate over the number of epochs passed from the command line.

for epoch in range(args.epochs):
    train(epoch)
    test(epoch)
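A learning-rate schedule is not specified above. One common choice for ResNets, also used in PyTorch's reference ImageNet example, is to divide the learning rate by 10 every 30 epochs. The sketch below assumes that schedule via `torch.optim.lr_scheduler.StepLR` and also saves the checkpoint with the best test accuracy. `fit` and its parameters are hypothetical. It expects `train_fn(epoch)` and `test_fn(epoch)` callables, with `test_fn` returning an accuracy.

```python
import torch
import torch.nn as nn


def fit(model, optimizer, train_fn, test_fn, epochs,
        ckpt_path='best_model.pth', step_size=30):
    """Hypothetical driver: StepLR schedule plus best-checkpoint saving."""
    # Multiply the learning rate by 0.1 every `step_size` epochs (assumed schedule).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
    best_acc = 0.0
    for epoch in range(epochs):
        train_fn(epoch)
        acc = test_fn(epoch)
        scheduler.step()  # once per epoch, after that epoch's optimizer steps
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), ckpt_path)
    return best_acc
```

With a `test` function that returns its accuracy, the loop above would become `fit(model, optimizer, train, test, args.epochs)`.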

Now the script is ready to train on the dataset of your choice. With a learning rate of 0.1 and a batch size of 128, you should get roughly the following top-1 accuracies:

CIFAR-10, ResNet50: 93.62%
CIFAR-100, ResNet50: 61.06%
ImageNet, ResNet50: 76.9%
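To put the batch size of 128 in perspective, the number of optimizer steps per epoch is the training-set size divided by the batch size, rounded up. The last, smaller batch still counts because `DataLoader` uses `drop_last=False` by default. A quick calculation with the training-set sizes given earlier and the default of 90 epochs:

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Optimizer steps per epoch; the last, smaller batch still counts
    because DataLoader uses drop_last=False by default."""
    return math.ceil(num_samples / batch_size)


BATCH_SIZE, EPOCHS = 128, 90
train_sizes = {'CIFAR-10/100': 50_000, 'ImageNet': 1_281_167}

for name, n in train_sizes.items():
    steps = steps_per_epoch(n, BATCH_SIZE)
    print(f'{name}: {steps} steps per epoch, {steps * EPOCHS} over {EPOCHS} epochs')
# CIFAR-10/100: 391 steps per epoch, 35190 over 90 epochs
# ImageNet: 10010 steps per epoch, 900900 over 90 epochs
```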
