This repository implements a mushroom type classifier in PyTorch and compares several models to improve performance. It also analyzes the model's behavior using Gradient-weighted Class Activation Mapping (Grad-CAM) visualizations.

TmohamedashrafT/Mushroom-Classification-with-Grad-cam

Mushroom-Classification-with-Grad-CAM

Dataset

The dataset was obtained from Kaggle, specifically from the "LOVE OF A LIFETIME" collection. It contains nine classes of mushrooms and was split into train (65%), validation (20%), and test (15%) sets. The same proportions were applied within every class, so each split keeps the original class balance.
Download and split.ipynb
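
As a rough illustration of a per-class 65/20/15 split, the sketch below uses only the standard library. The function name `stratified_split`, the `(path, label)` input format, and the fixed seed are assumptions made for this example; the notebook itself may do this differently.

```python
import random
from collections import defaultdict


def stratified_split(samples, ratios=(0.65, 0.20, 0.15), seed=0):
    """Split (path, label) pairs so every class keeps the same ratios.

    Hypothetical helper for illustration; not taken from the notebook.
    """
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append(path)

    rng = random.Random(seed)
    splits = {"train": [], "val": [], "test": []}
    for label, paths in sorted(by_class.items()):
        paths = sorted(paths)      # deterministic order before shuffling
        rng.shuffle(paths)
        n_train = round(len(paths) * ratios[0])
        n_val = round(len(paths) * ratios[1])
        splits["train"] += [(p, label) for p in paths[:n_train]]
        splits["val"] += [(p, label) for p in paths[n_train:n_train + n_val]]
        # Whatever remains after train and val goes to the test set.
        splits["test"] += [(p, label) for p in paths[n_train + n_val:]]
    return splits


# Toy data: 3 classes with 100 made-up file names each.
toy = [(f"class{c}/img{i}.jpg", c) for c in range(3) for i in range(100)]
parts = stratified_split(toy)
```

Splitting inside each class, rather than over the whole file list, is what keeps rare classes represented in the validation and test sets.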

image

Data Preprocessing

  • Resize to 299
  • Normalize channels with the mean and standard deviation PyTorch recommends for its pretrained models
    • mean = [0.485, 0.456, 0.406]
    • std = [0.229, 0.224, 0.225]

Data augmentation

  • RandomHorizontalFlip
  • RandomVerticalFlip
  • RandomRotation with maximum degrees 15

The data is loaded and augmented in Dataset_Generator.py

Models

models.py

List of models:

  • ResNet50
  • ConvNeXt

Each model is followed by:

  • Linear(out_features, out_features // 2)
  • BatchNorm(out_features // 2)
  • ReLU()
  • Dropout(0.3)
  • Linear(out_features // 2, out_features // 4)
  • BatchNorm(out_features // 4)
  • ReLU()
  • Dropout(0.2)
  • Linear(out_features // 4, number of classes)
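
The head listed above might look like this in PyTorch. The helper name `make_head` is hypothetical, `BatchNorm1d` is assumed because the head operates on flat feature vectors, and `out_features` is taken to mean the size of the backbone's pooled feature vector (2048 for ResNet50); models.py may organize this differently.

```python
import torch
from torch import nn


def make_head(out_features: int, num_classes: int) -> nn.Sequential:
    """Build the three-layer classification head (illustrative sketch)."""
    return nn.Sequential(
        nn.Linear(out_features, out_features // 2),
        nn.BatchNorm1d(out_features // 2),   # assumption: 1-D batch norm
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(out_features // 2, out_features // 4),
        nn.BatchNorm1d(out_features // 4),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(out_features // 4, num_classes),
    )


# ResNet50 produces 2048 pooled features; the dataset has nine classes.
head = make_head(out_features=2048, num_classes=9)
logits = head(torch.randn(4, 2048))   # batch of 4 feature vectors
```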

Gradient-weighted Class Activation Mapping (Grad-CAM)

Grad-CAM Overview

Grad-CAM is a visualization technique that shows what the network focuses on when it makes a decision about an image. It builds on the ideas of saliency maps and class activation maps. Grad-CAM computes the gradients of a class score with respect to the feature maps of a convolutional layer (typically the last one), averages them to weight each feature map, and combines the weighted maps to show which image regions contribute most to the network assigning a high score to that class.

By utilizing Grad-CAM, we can generate informative heatmaps that highlight the regions in the input image that are most influential in the network's decision-making process. These heatmaps help us interpret and analyze the model's behavior by visualizing the areas that the network pays the most attention to when classifying mushroom types.
image
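
To make the computation concrete, here is a minimal Grad-CAM sketch on a small stand-in network. `TinyCNN` and `grad_cam` are hypothetical names for this example only; the repository's Grad_cam.py targets the real backbones and may use forward and backward hooks instead of returning the feature maps from `forward`.

```python
import torch
from torch import nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """Small stand-in classifier that also returns its last feature maps."""

    def __init__(self, num_classes=9):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
        )
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        fmap = self.features(x)                       # (N, 16, H/2, W/2)
        return self.classifier(fmap.mean(dim=(2, 3))), fmap


def grad_cam(model, image, target_class=None):
    """Return a heatmap in [0, 1] at input resolution and the class used."""
    model.eval()
    logits, fmap = model(image)
    if target_class is None:
        target_class = int(logits.argmax(dim=1))      # default: predicted class
    score = logits[0, target_class]
    # Gradient of the class score with respect to the feature maps.
    grads = torch.autograd.grad(score, fmap)[0]
    # Channel weights: gradients averaged over the spatial dimensions.
    weights = grads.mean(dim=(2, 3), keepdim=True)
    # Weighted sum of feature maps; ReLU keeps only positive evidence.
    cam = F.relu((weights * fmap).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear",
                        align_corners=False)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)                    # scale to [0, 1]
    return cam[0, 0].detach(), target_class


torch.manual_seed(0)
model = TinyCNN()
image = torch.randn(1, 3, 32, 32)                     # random stand-in image
heatmap, cls = grad_cam(model, image)
```

The resulting `heatmap` can be overlaid on the input image with any plotting library to produce figures like the ones in this README.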

For more details

Grad-CAM is implemented in Grad_cam.py

Grad_cam_utils.py contains functions that generate Grad-CAM heatmaps and plot them.

Metrics

  • Accuracy
  • Recall
  • Precision
  • F1-score

The metrics are implemented in Metrics.py
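
For reference, the four metrics can be computed from per-class true positive, false positive, and false negative counts. The sketch below uses macro averaging (an unweighted mean over classes), which is an assumption; the README does not say which averaging Metrics.py uses, and `classification_metrics` is a name invented for this example.

```python
def classification_metrics(y_true, y_pred, num_classes):
    """Accuracy plus macro-averaged precision, recall, and F1-score."""
    tp = [0] * num_classes   # true positives per class
    fp = [0] * num_classes   # false positives per class
    fn = [0] * num_classes   # false negatives per class
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1       # predicted class p, but it was wrong
            fn[t] += 1       # true class t was missed

    def safe_div(a, b):
        return a / b if b else 0.0

    precision = [safe_div(tp[c], tp[c] + fp[c]) for c in range(num_classes)]
    recall = [safe_div(tp[c], tp[c] + fn[c]) for c in range(num_classes)]
    f1 = [safe_div(2 * p * r, p + r) for p, r in zip(precision, recall)]
    return {
        "accuracy": safe_div(sum(tp), len(y_true)),
        "precision": sum(precision) / num_classes,
        "recall": sum(recall) / num_classes,
        "f1": sum(f1) / num_classes,
    }


# Toy labels for three classes (made up for the example).
scores = classification_metrics([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], num_classes=3)
```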

Utility Functions

  • show_batch: shows random images from each class
  • show_aug_batch: shows 9 random images after the transformations are applied
  • plot_results: plots one metric for both the training and validation sets

Training settings

  • Learning rate : 1e-3 with cosine annealing scheduler
  • Optimizer : Adam
  • Epochs : 100
  • Loss : Cross entropy
  • Freeze the weights of the backbone
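
These settings map onto PyTorch roughly as shown below. The tiny `backbone` and `head` are stand-ins so the snippet runs without pretrained weights, and setting `T_max` to the number of epochs with one scheduler step per epoch is an assumption about how the cosine schedule was configured.

```python
import torch
from torch import nn

# Stand-in modules; the project uses ResNet50 or ConvNeXt plus a larger head.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)
head = nn.Linear(8, 9)
model = nn.Sequential(backbone, head)

# Freeze the backbone so that only the head is trained.
for param in backbone.parameters():
    param.requires_grad = False

epochs = 100
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
criterion = nn.CrossEntropyLoss()

# One optimization step on random data, followed by one scheduler step
# (in a real loop the scheduler step would run once per epoch).
images = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 9, (4,))
loss = criterion(model(images), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
current_lr = optimizer.param_groups[0]["lr"]
```

Passing only the trainable parameters to the optimizer keeps its state small and makes it explicit that the backbone is not updated.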

Results

In the ResNet notebook, ResNet50 does not handle this dataset well: training accuracy reaches 93.3%, but validation and test accuracy stay near 77%, a gap that points to poor generalization.

Split    Loss     Accuracy
Train    0.207    93.3%
Val      0.724    77.5%
Test     0.656    77.8%

A closer look at the dataset shows that mushrooms of the same class vary widely in shape and color. This variation makes the classes hard to tell apart, even for humans.

To handle this variability, a larger model, ConvNeXt, is used. Its greater capacity is intended to capture a broader range of features and patterns in the mushroom images and so better separate the classes.

image image

ConvNeXt outperforms ResNet50 despite the limited number of training samples and some ambiguity within the dataset. It performs well even under these challenging conditions.

Split    Loss     Accuracy
Train    0.16     99.8%
Val      0.419    91.4%
Test     0.403    99.2%

The model was trained for 80 epochs with standard cross-entropy loss, then for a further 20 epochs with class-weighted cross-entropy loss and label smoothing with a factor of 0.1.
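
The loss used in the second phase can be sketched with `nn.CrossEntropyLoss`, which accepts both per-class weights and a `label_smoothing` factor. The class counts below are hypothetical placeholders, and inverse-frequency weighting is an assumption, since the README does not state how the class weights were derived.

```python
import torch
from torch import nn

# Hypothetical number of training images per class (nine classes).
class_counts = torch.tensor([300., 250., 500., 400., 350., 200., 450., 600., 150.])

# Assumption: inverse-frequency weights, so rarer classes count more.
class_weights = class_counts.sum() / (len(class_counts) * class_counts)

# Phase 1 (first 80 epochs): standard cross-entropy.
plain_criterion = nn.CrossEntropyLoss()

# Phase 2 (last 20 epochs): class weights plus label smoothing of 0.1.
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

# Evaluate both losses on random logits to show the call signature.
torch.manual_seed(0)
logits = torch.randn(8, 9)
labels = torch.randint(0, 9, (8,))
plain_loss = plain_criterion(logits, labels)
weighted_loss = weighted_criterion(logits, labels)
```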

Graphs

Loss plot in the first 80 epochs

image

Loss plot in the last 20 epochs

image

Accuracy plot

image

Results on the test set

Report

image

Confusion matrix

image

For results on the training and validation sets, see model_analysis.ipynb

Examples of heatmaps from the test set

In some of the generated heatmaps, the network classifies the mushroom correctly and the heatmap highlights the expected regions. In other cases the network still predicts the correct class, but the heatmap looks misleading or does not line up with the expected regions.

Several factors can cause this mismatch. One possibility is that the network relies on features or patterns that are not visually apparent or easy for humans to interpret. Neural networks learn complex representations and can pick up distinguishing characteristics that are not immediately obvious to us.

image

The image above shows an example of a misleading heatmap. Here the region marked as most important for the classification is the fallen tree leaves, not the mushrooms. It is possible that most images in this class contain such leaves, which would lead the model to associate the class with their presence.

The following heatmaps are examples that align with the expected regions of importance for the corresponding mushroom:

image image image
