Simplifying the Image Classification Workflow with Lightning & Comet ML

A guide to performing end-to-end computer vision projects with PyTorch-Lightning, Comet ML and Gradio

Tirendaz AI
Heartbeat


Image by Freepik

Computer vision is a buzzword at the moment. However, building successful projects in this field comes with many challenges, because these projects require a lot of mathematical knowledge, computing power, and time. Fortunately, many modern frameworks and libraries have emerged to address these challenges.

Today, I’ll walk you through how to implement an end-to-end image classification project with the Lightning, Comet ML, and Gradio libraries. First, we’ll build a deep learning model with Lightning. After that, we’ll track hyperparameters, monitor metrics, and save the model with Comet ML. Lastly, we’ll create a cancer detection app with Gradio. When we’re done, the app will look like this:

Cancer Detection App (Video by Author)

Here are the topics we’ll cover in this blog:

  • What are PyTorch-Lightning & Comet ML?
  • Preparing data with Lightning Data Module
  • Building the model with Lightning Module
  • Monitoring metrics with Comet ML
  • Creating a cancer detection app with Gradio

Let’s first take a bird’s-eye view of what Lightning and Comet ML are.

PyTorch-Lightning

As you know, PyTorch is a popular framework for building deep learning models. However, as your models grow in complexity, the code is likely to get messy. This is where Lightning comes in.

PyTorch vs Lightning

Lightning is a high-level, flexible deep learning framework that makes it easier to organize your PyTorch code. With Lightning, you write the model-specific logic, and the framework automates much of the engineering around it, such as the training loop, data loading, model checkpointing, and logging. To get started with this framework, Lightning in 15 Minutes is a great tutorial.
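To make this split of responsibilities concrete, here is a toy, pure-Python illustration of the idea behind LightningModule and Trainer. It is not Lightning’s implementation: ToyLinearModule and toy_fit are hypothetical names, and the “optimizer” is plain gradient descent on a one-parameter linear model. The module only says what happens in a step; the generic loop decides when each hook runs.

```python
import random


class ToyLinearModule:
    """Hypothetical module that fits y = w * x; hook names mirror Lightning's."""

    def __init__(self, learning_rate=0.2):
        self.w = 0.0
        self.learning_rate = learning_rate

    def training_step(self, batch):
        # Mean squared error for one batch and its gradient with respect to w
        xs, ys = batch
        n = len(xs)
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        return loss, grad

    def configure_optimizers(self):
        # A plain gradient-descent update standing in for a torch optimizer
        def step(grad):
            self.w -= self.learning_rate * grad
        return step


def toy_fit(module, batches, max_epochs):
    """Hypothetical stand-in for Trainer.fit: owns the loop and calls the hooks."""
    optimizer_step = module.configure_optimizers()
    epoch_losses = []
    for _ in range(max_epochs):
        step_losses = []
        for batch in batches:
            loss, grad = module.training_step(batch)
            optimizer_step(grad)
            step_losses.append(loss)
        epoch_losses.append(sum(step_losses) / len(step_losses))
    return epoch_losses


random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(64)]
ys = [3.0 * x for x in xs]  # the true slope is 3
batches = [(xs[i:i + 16], ys[i:i + 16]) for i in range(0, 64, 16)]
module = ToyLinearModule()
history = toy_fit(module, batches, max_epochs=50)
print(f"learned slope: {module.w:.4f}")
```

In real Lightning code, the same roles are played by your LightningModule’s training_step and configure_optimizers hooks and by Trainer.fit, which additionally takes care of devices, checkpointing, and logging.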

Comet ML

At the end of the day, our goal is to obtain the best model. To do this, we strive to find the optimal combination of hyperparameters. This is where Comet ML comes into play.

ML Workflow with Comet ML (Image by Author)

Comet ML is an MLOps platform that helps you track your hyperparameters, monitor your model metrics, and optimize your model. To use the Comet ML platform, you can create a free account by visiting Comet. If you are using Comet ML for the first time, you might check out the quick-start guide.

Let’s go ahead and put what we’ve talked about into action with a project.

Image Classification for Cancer Detection

As we all know, cancer is a complex and common disease that affects millions of people worldwide. Luckily, recent computer vision techniques have made it possible to help detect cancer in medical images. In this section, we’re going to cover how to do this.

The dataset we’re going to use is the histopathologic cancer detection dataset.

The histopathologic cancer detection dataset

This dataset consists of 327,680 color images extracted from histopathologic scans of lymph node sections. If you’re not familiar with the field, don’t worry. Our goal is to build a model that can identify the presence of metastatic cancer tissue in an image.

Please keep in mind that you can find the notebook we’re going to use in this blog here. Let’s start by examining the image labels.

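# Importing the libraries used for data exploration
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

# Reading the labels
cancer_labels = pd.read_csv("train_labels.csv")

# Visualizing the labels (value_counts sorts by frequency, so the majority class 0 = no cancer comes first)
labels_count = cancer_labels.label.value_counts()
plt.figure(figsize=(16, 16))  # create the figure first so that figsize applies to the pie chart
plt.pie(labels_count, labels=['No Cancer', 'Cancer'], startangle=180, autopct='%1.1f%%')
plt.show()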
Cancer vs non-cancer labels

As you can see, there are somewhat more non-cancer images than cancer images, so the classes are moderately imbalanced. Let’s move on to taking a look at a few images in the dataset.
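To put a number on the imbalance, a short pandas sketch can compute the class proportions and “balanced” inverse-frequency class weights, n_samples / (n_classes * count_of_class). The class_weights helper and the toy labels are hypothetical; on the real data you would call it with cancer_labels.label. The model in this article does not use class weights; they are shown only as one option, for example through the weight argument of nn.CrossEntropyLoss.

```python
import pandas as pd


def class_weights(labels: pd.Series) -> dict:
    """Balanced inverse-frequency weights: n_samples / (n_classes * count_of_class)."""
    counts = labels.value_counts().sort_index()
    n_samples = len(labels)
    n_classes = len(counts)
    return {int(cls): n_samples / (n_classes * int(count)) for cls, count in counts.items()}


# Hypothetical toy labels: 6 "no cancer" (0) and 4 "cancer" (1)
toy_labels = pd.Series([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
proportions = toy_labels.value_counts(normalize=True).sort_index()
weights = class_weights(toy_labels)
print(proportions.to_dict())  # {0: 0.6, 1: 0.4}
print(weights)                # {0: 0.8333..., 1: 1.25}
```

The minority class gets a weight above 1 and the majority class a weight below 1, so that, summed over the dataset, both classes contribute equally to a weighted loss.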

# Visualizing 20 random images with their labels
base_dir = 'histopathologic-cancer-detection/'
fig = plt.figure(figsize=(25, 6))
train_imgs = os.listdir(base_dir + "train")
for idx, img in enumerate(np.random.choice(train_imgs, 20, replace=False)):
    ax = fig.add_subplot(2, 20 // 2, idx + 1, xticks=[], yticks=[])
    im = Image.open(base_dir + "train/" + img)
    plt.imshow(im)
    # The file name without its extension is the image id in the labels file
    lab = cancer_labels.loc[cancer_labels['id'] == img.split('.')[0], 'label'].values[0]
    ax.set_title(f'Label: {lab}')
plt.show()
Some images in the dataset

As you can see, each image is annotated with a binary label showing the presence of metastatic tissue. Let’s move on to the data preprocessing step.

Data Preprocessing

As we mentioned, the dataset consists of 327,680 color images, but we’re going to use only 10,000 randomly selected images from it. To do this, let’s first choose 8,000 image files for the training set and 2,000 for the test set, sampling without replacement so that no image can end up in both sets.

# Randomly selecting 10,000 distinct image files from the training folder
np.random.seed(0)
train_imgs_orig = sorted(os.listdir(os.path.join(base_dir, "train")))  # sorted for reproducibility
# replace=False: with NumPy's default (replace=True) the same file could be drawn twice
# and land in both the training and the test set
selected_image_list = list(np.random.choice(train_imgs_orig, 10000, replace=False))

# Creating index variables for the training and test sets
np.random.shuffle(selected_image_list)
cancer_train_idx = selected_image_list[:8000]
cancer_test_idx = selected_image_list[8000:10000]
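Why replace=False matters: np.random.choice samples with replacement by default, and with 10,000 draws from N files you should expect roughly 10,000² / (2N) repeated draws, i.e. on the order of a hundred or more for a folder of a few hundred thousand images. Every repeat that lands on both sides of the split leaks a test image into training. The sketch below uses NumPy’s Generator API; split_file_names and the synthetic file names are hypothetical.

```python
import numpy as np


def split_file_names(file_names, n_train, n_test, seed=0):
    """Hypothetical helper: draw n_train + n_test distinct files, then split them."""
    rng = np.random.default_rng(seed)
    # replace=False guarantees each file is drawn at most once,
    # so the training and test lists cannot share an image
    selected = rng.choice(sorted(file_names), size=n_train + n_test, replace=False)
    return selected[:n_train].tolist(), selected[n_train:].tolist()


# Synthetic file names standing in for the contents of the train folder
names = [f"img_{i:05d}.tif" for i in range(50_000)]
train_names, test_names = split_file_names(names, 8000, 2000)

# For comparison: drawing WITH replacement (NumPy's default) produces duplicates
drawn = np.random.default_rng(0).choice(names, size=10_000)
n_unique_with_replacement = len(set(drawn.tolist()))
print(len(set(train_names) & set(test_names)))  # 0
print(10_000 - n_unique_with_replacement)       # repeated draws, several hundred here
```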

Awesome, we created the index variables to select images. Now, let’s save the images to the train and test directories using these indexes.

import shutil

# Saving the training images (exist_ok avoids an error when the cell is re-run)
os.makedirs('cancer_train_dataset', exist_ok=True)
for fname in cancer_train_idx:
    src = os.path.join(base_dir, 'train', fname)
    dst = os.path.join('cancer_train_dataset', fname)
    shutil.copyfile(src, dst)

# Saving the test images
os.makedirs('cancer_test_dataset', exist_ok=True)
for fname in cancer_test_idx:
    src = os.path.join(base_dir, 'train', fname)
    dst = os.path.join('cancer_test_dataset', fname)
    shutil.copyfile(src, dst)

Now let’s create a Pandas dataframe containing the labels of the training and test images.

# Creating a dataframe with the labels of the selected images
selected_image_labels = pd.DataFrame()
id_list = []
label_list = []

for img in selected_image_list:
    label_tuple = cancer_labels.loc[cancer_labels['id'] == img.split('.')[0]]
    id_list.append(label_tuple['id'].values[0])
    label_list.append(label_tuple['label'].values[0])

selected_image_labels['id'] = id_list
selected_image_labels['label'] = label_list

# Creating a dictionary (image id -> label) to use during data loading
img_class_dict = {k: v for k, v in zip(selected_image_labels.id, selected_image_labels.label)}
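The loop above scans the whole labels table once per selected image, which gets slow as the number of images grows. A sketch of a vectorized alternative, assuming the same id and label columns: set the id column as the index once and look up all ids in a single call. build_label_dict and the toy frames are hypothetical.

```python
import pandas as pd


def build_label_dict(file_names, labels_df):
    """Hypothetical helper: map image id (file name without extension) to its label."""
    ids = [name.split('.')[0] for name in file_names]
    # set_index builds an index once, so all ids are looked up in one call
    # instead of filtering the whole labels table for every image
    looked_up = labels_df.set_index('id').loc[ids, 'label']
    return dict(zip(ids, looked_up.tolist()))


# Toy stand-ins for cancer_labels and selected_image_list
labels_df = pd.DataFrame({'id': ['a1', 'b2', 'c3', 'd4'], 'label': [0, 1, 1, 0]})
file_names = ['c3.tif', 'a1.tif']
label_dict = build_label_dict(file_names, labels_df)
print(label_dict)  # {'c3': 1, 'a1': 0}
```

With the article’s variables, build_label_dict(selected_image_list, cancer_labels) should produce the same mapping as img_class_dict.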

Nice, we performed the data preprocessing steps. We’re ready to load the dataset with Lightning.


Loading the Dataset

Managing the data used to train a model can be difficult. It involves all the necessary steps for downloading, processing, and transforming the data, as well as building the training, validation, test, and prediction data loaders.

In this section, we’re going to see how to perform these steps with the Lightning DataModule, which helps you easily build dataset-agnostic models.

Note that the DataModule does not read the images itself; it organizes datasets and data loaders. So, to load the data, we first need to create a custom PyTorch Dataset class like this:

import os
from PIL import Image
from torch.utils.data import Dataset

# Creating the custom class to load the data
class LoadCancerDataset(Dataset):
    def __init__(self, datafolder, transform, labels_dict=None):
        self.datafolder = datafolder
        self.image_files_list = sorted(os.listdir(datafolder))
        self.transform = transform
        # Avoiding a mutable default argument
        self.labels_dict = labels_dict if labels_dict is not None else {}
        self.labels = [self.labels_dict[i.split('.')[0]] for i in self.image_files_list]

    # Number of images in the folder
    def __len__(self):
        return len(self.image_files_list)

    # Loading, transforming, and labeling one image
    def __getitem__(self, idx):
        img_name = os.path.join(self.datafolder, self.image_files_list[idx])
        image = Image.open(img_name)
        image = self.transform(image)
        img_name_short = self.image_files_list[idx].split('.')[0]
        label = self.labels_dict[img_name_short]
        return image, label

This class helps us read all the images in the folder. Let’s go ahead and create a DataModule class for performing data processing steps.

import pytorch_lightning as pl
import torchvision.transforms as T
from torch.utils.data import DataLoader, random_split

class CancerDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, num_workers, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    # Preparing the dataset
    def prepare_data(self):
        """
        The prepare_data method is intentionally left empty because the dataset is already in our directory.
        """
        pass

    # Setting up the datasets
    def setup(self, stage=None):

        # Assigning train and val datasets to use in dataloaders
        if stage == "fit" or stage is None:
            train_set_full = LoadCancerDataset(
                datafolder=f'{self.data_dir}/cancer_train_dataset',
                transform=T.Compose([
                    T.Resize(224),
                    T.RandomHorizontalFlip(),
                    T.RandomVerticalFlip(),
                    T.ToTensor()
                ]),
                labels_dict=img_class_dict
            )
            # 90% for training, 10% for validation
            train_set_size = int(len(train_set_full) * 0.9)
            valid_set_size = len(train_set_full) - train_set_size
            self.train_ds, self.val_ds = random_split(train_set_full, [train_set_size, valid_set_size])

        # Assigning test dataset to use in dataloader
        if stage == "test" or stage is None:
            self.test_ds = LoadCancerDataset(
                datafolder=f'{self.data_dir}/cancer_test_dataset',
                transform=T.Compose([
                    T.Resize(224),
                    T.ToTensor()
                ]),
                labels_dict=img_class_dict
            )

    # Creating the dataloaders
    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)
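For the 8,000 training images, the split in setup() gives int(8000 * 0.9) = 7,200 training and 800 validation images. Because random_split is called without a seed here, the split changes from run to run; a seeded split is one way to make runs comparable. The NumPy sketch below (hypothetical helper names) mirrors the arithmetic and the shuffle-then-cut idea; it is a stand-in, not torch’s implementation.

```python
import numpy as np


def split_sizes(n_total, train_fraction=0.9):
    # Same arithmetic as CancerDataModule.setup()
    train_size = int(n_total * train_fraction)
    return train_size, n_total - train_size


def seeded_split_indices(n_total, train_fraction=0.9, seed=42):
    """NumPy stand-in for a seeded random_split: shuffle indices, then cut."""
    train_size, _ = split_sizes(n_total, train_fraction)
    perm = np.random.default_rng(seed).permutation(n_total)
    return perm[:train_size], perm[train_size:]


train_size, val_size = split_sizes(8000)
train_idx, val_idx = seeded_split_indices(8000)
print(train_size, val_size)  # 7200 800
```

In the DataModule itself, the equivalent change would be passing generator=torch.Generator().manual_seed(42) to random_split. Also note that both subsets come from the same train_set_full, so the validation images receive the random flips as well.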

Note that if you wrote these steps in pure PyTorch, you would need more boilerplate code. With the Lightning DataModule, they are a piece of cake to carry out, and they all live in one reusable class.

Nice, we’ve created a class to manage the data. Let’s move on to building the model.

Building the Model with Lightning

So far, we’ve prepared the data that we’ll use to train the model. In this section, we’re going to create our model architecture for image classification. To do this, we’re going to leverage the transfer learning technique.

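Transfer learning is a technique that allows us to reuse a model that has already been trained on a large dataset. What we need to do is adapt the pre-trained model to our task. The pre-trained model we’re going to utilize is ResNet-50, whose default torchvision weights were trained on ImageNet. This architecture is often used for image classification. It is 50 layers deep and is built mostly from convolutional layers with residual (skip) connections, plus pooling layers and a final fully connected layer. We’ll freeze the pre-trained layers and replace that final layer with a new classification head for our two classes.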
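Freezing the backbone means only the new head is trained. A quick count shows how small that head is: a fully connected layer has in_features * out_features weights plus out_features biases, and ReLU and Dropout have no parameters. The linear_params helper is hypothetical; 2048 is the number of input features of ResNet-50’s final layer.

```python
def linear_params(in_features, out_features, bias=True):
    # A fully connected layer stores an in x out weight matrix plus one bias per output
    return in_features * out_features + (out_features if bias else 0)


# ResNet-50's final layer receives 2048 features (pretrain_model.fc.in_features)
resnet50_fc_in = 2048
head_params = linear_params(resnet50_fc_in, 1024) + linear_params(1024, 2)  # ReLU/Dropout add none
print(f"{head_params:,}")  # 2,100,226
```

With the model defined in the next step, sum(p.numel() for p in my_model.parameters() if p.requires_grad) should report the same number, since every pre-trained parameter has requires_grad=False.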

Now, we’re going to use a LightningModule to build our ResNet model. This module helps us organize our PyTorch code and easily manage the training, validation, and testing steps.

import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torchmetrics import Accuracy, F1Score
from torchvision.models import resnet50, ResNet50_Weights

class CancerImageClassifier(pl.LightningModule):

    def __init__(self, learning_rate=0.001, num_classes=2):
        super().__init__()
        # Recording learning_rate and num_classes so that loggers such as CometLogger log them
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.loss_fn = nn.CrossEntropyLoss()
        self.num_classes = num_classes
        # Defining metrics (predictions are class indices 0/1, so the binary task applies)
        self.accuracy = Accuracy(task="binary", num_classes=num_classes)
        self.f1_score = F1Score(task="binary", num_classes=num_classes)
        self.history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

        # Defining the model architecture: ResNet-50 with ImageNet weights, all layers frozen
        self.pretrain_model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.pretrain_model.eval()
        for param in self.pretrain_model.parameters():
            param.requires_grad = False

        # Replacing the final fully connected layer with a trainable two-class head
        self.pretrain_model.fc = nn.Sequential(
            nn.Linear(self.pretrain_model.fc.in_features, 1024),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(1024, self.num_classes)
        )

    # Running data through the model
    def forward(self, input):
        output = self.pretrain_model(input)
        return output

    def training_step(self, batch, batch_idx):
        outputs, targets, loss, preds = self._common_step(batch, batch_idx)
        train_accuracy = self.accuracy(preds, targets)
        self.history['train_loss'].append(loss.item())
        self.history['train_acc'].append(train_accuracy.item())
        self.log_dict(
            {"train_loss": loss, "train_acc": train_accuracy},
            on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "train_acc": train_accuracy}

    def validation_step(self, batch, batch_idx):
        outputs, targets, loss, preds = self._common_step(batch, batch_idx)
        val_accuracy = self.accuracy(preds, targets)
        self.history['val_loss'].append(loss.item())
        self.history['val_acc'].append(val_accuracy.item())
        self.log_dict(
            {"val_loss": loss, "val_acc": val_accuracy},
            on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "val_acc": val_accuracy}

    def test_step(self, batch, batch_idx):
        outputs, targets, loss, preds = self._common_step(batch, batch_idx)
        test_accuracy = self.accuracy(preds, targets)
        f1_score = self.f1_score(preds, targets)
        self.log_dict(
            {"test_loss": loss,
             "test_acc": test_accuracy,
             "test_f1_score": f1_score},
            on_step=False, on_epoch=True, prog_bar=True)
        return {"test_loss": loss,
                "test_accuracy": test_accuracy,
                "test_f1_score": f1_score}

    # Shared logic: forward pass, loss on the raw logits, and predicted class indices
    def _common_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.forward(inputs)
        loss = self.loss_fn(outputs, targets)
        preds = torch.argmax(outputs, dim=1)
        return outputs, targets, loss, preds

    def configure_optimizers(self):
        # Only the new head requires gradients, so only its parameters are optimized
        params = [p for p in self.parameters() if p.requires_grad]
        optimizer = optim.Adam(params=params, lr=self.learning_rate)
        return optimizer
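In _common_step, the loss is computed on the raw outputs (logits), because nn.CrossEntropyLoss applies log-softmax internally, and the predictions are taken with argmax directly on the logits. The NumPy sketch below re-implements that loss formula for illustration (it is not PyTorch’s code) and checks that argmax over logits and over softmax probabilities agree, since softmax preserves the ordering of the scores.

```python
import numpy as np


def softmax(logits, axis=1):
    # Subtracting the row maximum keeps exp() numerically stable
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def cross_entropy(logits, targets):
    # Mean over the batch of -log softmax(logits)[target], the formula behind
    # nn.CrossEntropyLoss with its default mean reduction
    log_probs = np.log(softmax(logits, axis=1))
    return -log_probs[np.arange(len(targets)), targets].mean()


logits = np.array([[2.0, 0.0], [0.5, 1.5], [-1.0, 1.0]])
targets = np.array([0, 1, 1])
preds_from_logits = logits.argmax(axis=1)
preds_from_probs = softmax(logits).argmax(axis=1)
loss = cross_entropy(logits, targets)
print(preds_from_logits, round(float(loss), 4))
```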

As you can see, the code used to build the model is easier to read and monitor than the equivalent pure PyTorch code.

One of my favorite features of Lightning is how easy it is to monitor metrics. Here, we’ve used the self.log_dict method to log the metrics in the training, validation, and test steps.
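Because log_dict is called with on_step=False and on_epoch=True, Lightning accumulates the per-step values and reports their mean at the end of each epoch; as far as I understand its documentation, that mean is weighted by batch size (inferred from the batch, or passed via the batch_size argument of self.log). The self.history lists, by contrast, store raw per-step values, so a plain average over them can differ when the last batch is smaller. The helper and numbers below are hypothetical.

```python
import numpy as np


def epoch_mean(step_values, batch_sizes, weighted=True):
    """Aggregate per-step metric values into one epoch value (hypothetical helper)."""
    values = np.asarray(step_values, dtype=float)
    sizes = np.asarray(batch_sizes, dtype=float)
    if weighted:
        return float((values * sizes).sum() / sizes.sum())
    return float(values.mean())


# Hypothetical validation epoch: 800 images in batches of 128 -> six full batches plus one of 32
sizes = [128] * 6 + [32]
step_accs = [0.875] * 6 + [0.5]  # 112/128 correct in each full batch, 16/32 in the last one
weighted_acc = epoch_mean(step_accs, sizes)                    # 688 / 800 = 0.86
unweighted_acc = epoch_mean(step_accs, sizes, weighted=False)  # 5.75 / 7, about 0.821
```

The weighted value equals the true fraction of correctly classified images (688 of 800), which is why it is the more faithful epoch-level number.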

Ok, we’ve created our model architecture. Let’s go ahead and take a look at how to track our hyperparameters, metrics, and model.

Tracking with Comet Logger

I have great news for you. Lightning supports the Comet ML platform. This means you can track model hyperparameters, monitor model metrics and register your model in your Comet ML account.

Lightning has a CometLogger class that can be utilized to seamlessly log metrics, hyperparameters, model weights and more. Just instantiate a CometLogger object and pass it to Lightning’s Trainer. Let’s do this.

from pytorch_lightning.loggers import CometLogger

# Creating an experiment with your API key
comet_logger = CometLogger(
    api_key="your_api_key",
    workspace="your_workspace",
    project_name="your_project_name"
)

Nice, we now have an object to track hyperparameters and metrics. What we need to do is pass this object to the Trainer as shown below. Now, we’re ready to train the model.
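Hard-coding the API key in a notebook makes it easy to leak when the notebook is shared. One common alternative, sketched here under the assumption that you export the key in the COMET_API_KEY environment variable (the variable Comet’s own configuration reads), is to look it up at runtime. comet_logger_kwargs is a hypothetical helper, not part of Lightning or Comet.

```python
import os


def comet_logger_kwargs(project_name, workspace=None):
    """Build CometLogger arguments without hard-coding the API key (hypothetical helper)."""
    api_key = os.environ.get("COMET_API_KEY")
    if not api_key:
        raise RuntimeError("Please set the COMET_API_KEY environment variable first.")
    kwargs = {"api_key": api_key, "project_name": project_name}
    if workspace is not None:
        kwargs["workspace"] = workspace
    return kwargs


# Usage, assuming pytorch_lightning and comet_ml are installed and the variable is set:
# comet_logger = CometLogger(**comet_logger_kwargs("your_project_name", "your_workspace"))
```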

Training the Model

After organizing our PyTorch code into a LightningModule, the Trainer handles everything else automatically. Note that the Trainer object offers much more than “training”. It lets us set options such as the hardware to use (CPU/GPU/TPU), the number of training epochs, logging with Comet, 16-bit (mixed-precision) training, and so on.

Before creating a trainer object, let’s instantiate a DataModule object, as follows:

my_dataloader = CancerDataModule(
batch_size=128,
num_workers=2,
data_dir="./")

After that, let’s create a model object from the CancerImageClassifier class.

my_model = CancerImageClassifier(
num_classes=2,
learning_rate=0.001)

Next, let’s instantiate a Trainer object and then call the fit method to train the model.

my_trainer = pl.Trainer(
logger=comet_logger,
accelerator="auto",
devices="auto",
max_epochs=15)

my_trainer.fit(my_model, my_dataloader)
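It helps to know how many steps this run involves. With the 90/10 split in setup(), the 8,000 training images become 7,200 for training and 800 for validation; with batch_size=128 and the DataLoader default of keeping the last, smaller batch, that gives the counts below. batches_per_epoch is a hypothetical helper.

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    # DataLoader keeps the final, smaller batch unless drop_last=True
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


train_batches = batches_per_epoch(7200, 128)  # 57: 56 full batches + one of 32 images
val_batches = batches_per_epoch(800, 128)     # 7: 6 full batches + one of 32 images
total_training_steps = train_batches * 15     # max_epochs=15 -> 855 optimizer steps
```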

When we run this code, the training process starts. You can monitor the metrics in your Comet ML dashboard, as follows:

Comet ML Dashboard (Video by Author)

So far, we’ve trained our model and examined its training and validation metrics, such as accuracy and loss, in the Comet ML dashboard. Let’s move on to model evaluation.

Model Evaluation

We now have a trained image classification model, but how does it perform on the test set? To find out, we can call the Trainer object’s test method.

# Model evaluation
my_dataloader.setup()
my_trainer.test(model=my_model, dataloaders=my_dataloader.test_dataloader())
Test Metrics
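Since the classes are imbalanced, the F1 score logged in test_step is a useful companion to accuracy: it combines precision and recall for the positive (cancer) class, so a model that never predicts cancer gets a decent accuracy but an F1 of zero. The pure-Python helper below is a hypothetical illustration of the formulas, not torchmetrics’ implementation.

```python
def binary_metrics(preds, targets):
    """Accuracy, precision, recall, and F1 for 0/1 labels (1 = cancer). Hypothetical helper."""
    pairs = list(zip(preds, targets))
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    tn = sum(1 for p, t in pairs if p == 0 and t == 0)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)
    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


# Toy targets with a 60/40 split between "no cancer" (0) and "cancer" (1)
targets = [0] * 6 + [1] * 4
always_normal = [0] * 10                     # a classifier that never predicts cancer
some_model = [0, 0, 0, 0, 0, 1, 1, 1, 1, 0]  # one false positive, one false negative
print(binary_metrics(always_normal, targets))  # (0.6, 0.0, 0.0, 0.0)
print(binary_metrics(some_model, targets))     # (0.8, 0.75, 0.75, 0.75)
```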

As you can see, the metrics on the test set are close to the metrics on the train set. Lastly, let’s save our model in Comet ML and finish our experiment, as follows:

# Saving Model in Comet-ML
from comet_ml.integration.pytorch import log_model
log_model(comet_logger.experiment, my_model, model_name="my_pl_model")

# Ending our experiment
comet_logger.experiment.end()

Building an App with Gradio

Gradio is a Python library that lets you build demo applications. After training a model, you can leverage this library to create a straightforward web interface that anyone can review. First, we’re going to define a function to predict the label of an image.

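import torch
import torchvision.transforms as transforms

# Defining a function to predict the label of an image with the trained model (my_model)
def predict(inp):
    # Same preprocessing as the test set: resize to 224 x 224 and convert to a tensor
    image_transform = transforms.Compose([transforms.Resize(size=(224, 224)), transforms.ToTensor()])
    labels = ['normal', 'cancer']
    inp = image_transform(inp.convert("RGB")).unsqueeze(dim=0)  # adding a batch dimension
    my_model.eval()  # disabling dropout for inference
    with torch.no_grad():
        # Softmax over the class dimension turns the logits into probabilities
        prediction = torch.nn.functional.softmax(my_model(inp), dim=1)
    confidences = {labels[i]: float(prediction.squeeze()[i]) for i in range(len(labels))}
    return confidences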
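The app only gives sensible answers if its preprocessing matches what the model saw during training: the same resizing, ToTensor’s conversion to channel-first floats in [0, 1] (no extra normalization was used in training), and a leading batch dimension. The NumPy sketch below mimics the last two steps for illustration; it is an approximation of what torchvision does, resizing is left to torchvision, and the helper names are hypothetical.

```python
import numpy as np


def to_chw_float(image_hwc_uint8):
    # Rough NumPy equivalent of ToTensor for an 8-bit RGB image:
    # (H, W, C) uint8 values in [0, 255] -> (C, H, W) float32 values in [0.0, 1.0]
    arr = np.asarray(image_hwc_uint8, dtype=np.float32) / 255.0
    return arr.transpose(2, 0, 1)


def add_batch_dim(chw):
    # Equivalent of .unsqueeze(dim=0): (C, H, W) -> (1, C, H, W)
    return chw[np.newaxis, ...]


# Hypothetical all-white 96 x 96 RGB patch (the dataset's native tile size)
patch = np.full((96, 96, 3), 255, dtype=np.uint8)
batch = add_batch_dim(to_chw_float(patch))
print(batch.shape, batch.dtype, batch.max())  # (1, 3, 96, 96) float32 1.0
```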

Next, let’s create a Gradio interface using the following code.

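import gradio as gr

# Placeholder texts for the interface (replace them with your own)
title = "Cancer Detection"
description = "Upload a histopathologic image to estimate whether it contains metastatic tissue."
article = ""

# Creating a Gradio interface
gr.Interface(fn=predict,
             inputs=gr.Image(type="pil"),
             outputs=gr.Label(num_top_classes=2),
             title=title,
             description=description,
             article=article,
             examples=['image-1.jpg', 'image-2.jpg']).launch()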

And that’s it! Our app is now ready and looks like this:

Cancer Detection App (Image by Author)

Using this app, you can predict whether a histopathologic image contains metastatic cancer tissue.

Note that you can find the files I used to build this app here.

Wrap-Up

When it comes to implementing complex deep learning models, you encounter many challenges. Fortunately, you can overcome these challenges with modern libraries like Lightning and Comet ML.

In this blog, we’ve learned how to carry out an end-to-end deep learning project with modern libraries. Lightning enables us to easily organize our PyTorch code for model building. The CometLogger class in Lightning helps us track our hyperparameters and monitor our model metrics. Gradio allows us to turn the model into a cancer detection app.

That’s it. Thanks for reading. Let’s connect: YouTube | Twitter | LinkedIn.


Editor’s Note: Heartbeat is a contributor-driven online publication and community dedicated to providing premier educational resources for data science, machine learning, and deep learning practitioners. We’re committed to supporting and inspiring developers and engineers from all walks of life.

Editorially independent, Heartbeat is sponsored and published by Comet, an MLOps platform that enables data scientists & ML teams to track, compare, explain, & optimize their experiments. We pay our contributors, and we don’t sell ads.

If you’d like to contribute, head on over to our call for contributors. You can also sign up to receive our weekly newsletter (Deep Learning Weekly), check out the Comet blog, join us on Slack, and follow Comet on Twitter and LinkedIn for resources, events, and much more that will help you build better ML models, faster.

