Chest X-Ray based Multi-Class Disease Classification using DenseNet121 — Transfer Learning Approach

Naveenraj Kamalakannan
Jul 22, 2021 · 5 min read

Hey, what's up, guys!? Today, we're gonna do some fascinating fireworks with neural networks! Okay, lemme not keep you waiting, let's jump straight into the world of machine learning!

Good day, people! :) I'm Naveenraj Kamalakannan from India, and I'll walk you through my multi-class disease classification project, where I've used transfer learning with DenseNet121. Sounds intimidating, right? Chill, I got you! Let's get right into the project.

Importing Required Packages

Note: I faced runtime errors while training the network, so I've disabled eager execution.
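For reference, here is a minimal sketch of the imports and the eager-execution switch as I would set them up; the exact import list and package versions may differ from my notebook.

```python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

# Work around the runtime errors mentioned above by falling back to graph mode
tf.compat.v1.disable_eager_execution()
```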

The chest X-ray dataset used for this implementation is from the National Institutes of Health (NIH) and comprises 112,120 X-ray images with disease labels from 30,805 unique patients. Clinical diagnosis from a chest X-ray can be challenging, and sometimes more difficult than diagnosis via chest CT imaging. The dataset is publicly available from the NIH.

P.S.: I didn't use the complete dataset because it takes a really long time to run on my laptop. Even with these 12,500 images, the code ran for hours. Hey, wait! Why don't you make a donation so that I can buy a GPU and run the code faster! 😉

Check for Patient Overlap

This step has its own significance: it prevents the model's performance estimate from being overly optimistic. During training, the model can pick up patient-specific features, which inflates accuracy when it is later tested on images from the same patient. Therefore, we split the dataset so that no patient appears in more than one split.
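Here is a minimal sketch of how such a check could look. The dataframe names (train_df, valid_df, test_df) and the patient-ID column name are my own assumptions; adjust them to match your CSV.

```python
def patient_overlap(df1, df2, patient_col="PatientId"):
    """Return the set of patient IDs that appear in both dataframes."""
    return set(df1[patient_col].unique()) & set(df2[patient_col].unique())

# The splits should share zero patients
print("Train/Valid overlap:", len(patient_overlap(train_df, valid_df)))
print("Train/Test  overlap:", len(patient_overlap(train_df, test_df)))
print("Valid/Test  overlap:", len(patient_overlap(valid_df, test_df)))
```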

Analyzing a Sample Image

This is not necessarily a part of the main code, but it helps us analyze the variation in pixel intensity and the maximum and minimum pixel intensities of the images. We will plot the pixel intensity variation as a histplot for both the original and standardized images. To standardize the images, we use the ImageDataGenerator class for all three splits: train, validation, and test.
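A sketch of the sample-image analysis; the file name below is just a placeholder for any one image from the dataset.

```python
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Placeholder path -- replace with any image from the dataset
sample_img = img_to_array(load_img("images/00000001_000.png", color_mode="grayscale"))

print(f"Max pixel value : {sample_img.max():.2f}")
print(f"Min pixel value : {sample_img.min():.2f}")
print(f"Mean pixel value: {sample_img.mean():.2f}")
print(f"Std of pixels   : {sample_img.std():.2f}")

# Pixel intensity distribution of the original (unstandardized) image
sns.histplot(sample_img.ravel(), kde=False)
plt.title("Pixel Intensity Variation - Original Image")
plt.xlabel("Pixel intensity")
plt.ylabel("Number of pixels")
plt.show()
```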

Sample Original Image
Sample Image - Pixel Intensity Variation

Standardizing Images using ImageDataGenerator Class

First things first, we have to normalize the images before we proceed to the Model building phase.

Standardizing the pixel intensity values

In the real world, we won't be processing X-ray images in batches but one by one. With that in mind, if you carefully examine the generator setup sketched below, you'll find that the training dataset is samplewise centered and normalized, while the test dataset is featurewise centered and normalized using statistics computed from the training data.
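Here is a sketch of that setup. The dataframe and column names, image size, and label subset are placeholders of mine; the key point is samplewise statistics for training and featurewise statistics (fitted on a training batch) for validation/test.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_DIR = "images/"                            # assumed image folder
TARGET_SIZE = (320, 320)                         # assumed input size
labels = ["Cardiomegaly", "Effusion", "Mass"]    # illustrative subset of the 14 labels

# Training generator: each image is centered and scaled by its OWN mean/std
train_datagen = ImageDataGenerator(samplewise_center=True,
                                   samplewise_std_normalization=True)
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df, directory=IMAGE_DIR,
    x_col="Image", y_col=labels, class_mode="raw",
    batch_size=8, shuffle=True, target_size=TARGET_SIZE)

# Validation/test generator: centered and scaled with statistics learned from
# a batch of TRAINING images, since test images arrive one by one in practice
norm_datagen = ImageDataGenerator(featurewise_center=True,
                                  featurewise_std_normalization=True)
batch_x, _ = next(train_generator)    # one batch of training images
norm_datagen.fit(batch_x)             # learn mean/std from the training data

test_generator = norm_datagen.flow_from_dataframe(
    dataframe=test_df, directory=IMAGE_DIR,
    x_col="Image", y_col=labels, class_mode="raw",
    batch_size=8, shuffle=False, target_size=TARGET_SIZE)
```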

Sample Standardized Image
Sample Standardized Image — Pixel Intensity Variation

Addressing the Class Imbalance Problem Using a Customized Loss Function

In this walkthrough, we will implement a weighted cross-entropy loss function instead of a plain loss such as the common MSE, since the positive and negative frequencies of each class differ hugely.

One way to overcome this is to downsample or upsample the positive and negative samples of each class to a common number 'n' so that they contribute equally to the loss function.

Another way is to weight the positive and negative examples of each class so that the total positive contribution and the total negative contribution to the loss become equal. This is the approach we'll take here, as sketched further below.

Before weighing the classes

We can clearly see that the frequency of samples that don't have a specific disease class is far higher than the frequency of samples that do. This gives our model an easy shortcut: if it predicts '0' (negative) for every sample, the loss function returns a negligible loss and the model claims high accuracy.

Okay, that seems to be a problem :( Let's work on it!
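Here is a sketch of the cross-weighting idea: each class's positive weight is set to its negative frequency and vice versa, so the two sides contribute equally to the loss. The label columns reuse the assumed names from the generator sketch above.

```python
import numpy as np

# Label matrix of the training split: shape (num_examples, num_classes)
y_train = train_df[labels].values.astype("float32")

freq_pos = y_train.mean(axis=0)   # fraction of positive examples per class
freq_neg = 1.0 - freq_pos         # fraction of negative examples per class

# Cross-weighting: pos_weights * freq_pos == neg_weights * freq_neg for every class
pos_weights = freq_neg
neg_weights = freq_pos

print("Positive frequencies:", np.round(freq_pos, 3))
print("Positive weights    :", np.round(pos_weights, 3))
```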

After weighing the loss

Now that our class imbalance problem is addressed, our model can't get away with simply predicting the majority class. Cool! Let's define our loss function.
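A sketch of a weighted binary cross-entropy along those lines, using the weight arrays from above and a small epsilon for numerical stability; this is my own phrasing of it, not necessarily identical to the notebook.

```python
from tensorflow.keras import backend as K

def get_weighted_loss(pos_weights, neg_weights, epsilon=1e-7):
    """Build a weighted binary cross-entropy loss over all classes."""
    def weighted_loss(y_true, y_pred):
        loss = 0.0
        for i in range(len(pos_weights)):
            # Positive and negative terms of class i, each scaled by its weight
            loss += K.mean(
                -(pos_weights[i] * y_true[:, i] * K.log(y_pred[:, i] + epsilon)
                  + neg_weights[i] * (1 - y_true[:, i]) * K.log(1 - y_pred[:, i] + epsilon))
            )
        return loss
    return weighted_loss
```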

Building the Model

Woah! We've crossed 75% of the ocean… just a little more to go. Cheer up! In this project, we use DenseNet121 to train our model. But why a PRE-TRAINED model?

A pre-trained model has been previously trained on a dataset and contains the weights and biases that represent the features of whichever dataset it was trained on.

You can access the weights and pre-trained files from my GitHub repository (links attached below).

A dense block with 5 layers and a growth rate of 4. Source: https://github.com/liuzhuang13/DenseNet
A deep DenseNet with three dense blocks. Source: https://github.com/liuzhuang13/DenseNet
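Here is a sketch of how the pre-trained backbone can be wired up for multi-label prediction, reusing the names from the sketches above. The ImageNet weights are a stand-in; you can equally load the .h5 weights file from the repository. The head (global average pooling plus one sigmoid per disease) is a common choice, not necessarily identical to my notebook.

```python
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

# DenseNet121 backbone without its original 1000-class ImageNet head
base_model = DenseNet121(weights="imagenet", include_top=False)

x = GlobalAveragePooling2D()(base_model.output)            # pool the final feature maps
predictions = Dense(len(labels), activation="sigmoid")(x)  # one independent sigmoid per disease

model = Model(inputs=base_model.input, outputs=predictions)
model.compile(optimizer="adam",
              loss=get_weighted_loss(pos_weights, neg_weights))

history = model.fit(train_generator, epochs=5)   # epoch count chosen arbitrarily for the sketch
```

Plotting history.history["loss"] against the epoch index gives a training loss curve like the one below.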
Training Loss Curve

Analyzing the Performance

Let's plot ROC curves to see how well the model has performed.
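A sketch of the per-class ROC/AUC computation with scikit-learn, assuming the model and the (unshuffled) test generator from the sketches above:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve

# Ground truth from the dataframe (test_generator was created with shuffle=False)
y_true = test_df[labels].values
y_pred = model.predict(test_generator)

plt.figure(figsize=(8, 6))
for i, label in enumerate(labels):
    auc = roc_auc_score(y_true[:, i], y_pred[:, i])
    fpr, tpr, _ = roc_curve(y_true[:, i], y_pred[:, i])
    plt.plot(fpr, tpr, label=f"{label} (AUC = {auc:.3f})")

plt.plot([0, 1], [0, 1], "k--")          # chance line
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.title("ROC Curves per Disease Class")
plt.legend(loc="lower right")
plt.show()
```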

Computed using pretrained_model weights from Coursera*

Note: The reason behind this average AUC score is the lack of a dedicated GPU to train on more samples. I took a good chunk of the training dataset with batch_size = 1, WIDTH = 512, and HEIGHT = 512. With appropriate batch statistics, a good pre-trained model, and enough samples, we should be able to achieve higher AUC values.

Chest X-Ray GradCam Computation
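Grad-CAM highlights the regions of the X-ray that contributed most to a particular disease prediction by weighting the last convolutional feature maps with the gradients of that class's score. Below is a minimal Grad-CAM sketch of my own that works in graph mode (since eager execution is disabled above); "conv5_block16_concat" is the last dense-block concatenation in Keras' DenseNet121, and the helper is an illustration, not necessarily the repo's exact implementation.

```python
import numpy as np
import cv2
from tensorflow.keras import backend as K

def grad_cam(model, image, class_index, layer_name="conv5_block16_concat"):
    """Grad-CAM heatmap for one preprocessed image and one disease class.

    Uses K.gradients / K.function, which require graph mode
    (eager execution is disabled at the top of this post).
    """
    class_score = model.output[:, class_index]
    conv_output = model.get_layer(layer_name).output
    grads = K.gradients(class_score, conv_output)[0]
    fetch = K.function([model.input], [conv_output, grads])

    conv_val, grads_val = fetch([image[np.newaxis, ...]])
    conv_val, grads_val = conv_val[0], grads_val[0]

    # Weight each feature map by its average gradient, keep positive evidence only
    weights = grads_val.mean(axis=(0, 1))
    cam = np.maximum(conv_val @ weights, 0)

    # Upsample to the input resolution and normalize to [0, 1] for overlaying
    cam = cv2.resize(cam, (image.shape[1], image.shape[0]))
    return cam / (cam.max() + 1e-8)
```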

Okay, great! We did it. Now get some hands-on experience yourself: I've uploaded my code and all the necessary files. Refer to the links below. Thanks, mate!

Clap for this article and make a donation, so that I can work harder and create more interesting, hands-on content.

Check out my GitHub Repository for the code and pre-trained models.
Link: github.com/therealnaveenkamal/Chest_XRay_Disease_Classification

References:

  1. Before I wrap this up, I would like to recommend that all my readers take a look at the AI for Medical Diagnosis course by Pranav Rajpurkar. I had a great learning experience with it :D*
  2. Stanford ML Group — CheXNeXt: https://stanfordmlgroup.github.io/projects/chexnext/
  3. Paper Link: https://journals.plos.org/plosmedicine/article?id=10.1371/journal.pmed.1002686


Naveenraj Kamalakannan

A resolute programmer in Python and Java. I've worked on Android apps and ML model deployment, and I'm strongest in data analytics and bioinformatics. I love to code :)