See What Your Model Sees with Model Explainability
As new models continue to light the ML world on fire, more and more focus is going into just how these models work. Model explainability is a fantastic field to follow right now, as new methods roll out that give more insight into what our models are seeing, helping us uncover hidden biases and understand misclassifications.
As model sophistication increases, so will the need for model explainability. I will cover how you can leverage CAM (Class Activation Mapping) methods with FiftyOne to unlock a whole new understanding of your models. Let’s dive into what CAM is and learn more about model explainability.
What Exactly is CAM?
Without diving too deep into the specifics, CAM is an algorithm that produces a heatmap over an image showing where the model focuses most for a specific class. The original CAM method relies on a global average pooling layer feeding the final classifier to produce the heatmap, so it could not be applied to every model. Later methods have since relaxed this requirement and allow for more general use cases.
GradCam is one of the most popular of these methods, and it uses gradients to produce the heatmap. It weights each feature map of the chosen layer by the average gradient of the class score with respect to that map, sums the weighted maps, and keeps only the positive values. Because it only needs gradients, any differentiable layer that still holds spatial information can be chosen as the CAM layer, not just one followed by global average pooling.
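To make the weighting step concrete, here is a minimal NumPy sketch of how GradCam combines one layer's feature maps, assuming the activations and the gradients of the class score have already been captured for a single image (pytorch_grad_cam collects them with hooks). The function name grad_cam is hypothetical and is not the library's API.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Combine feature maps into a GradCam-style heatmap (illustrative sketch).

    activations: (K, H, W) feature maps from the chosen layer.
    gradients:   (K, H, W) gradient of the class score w.r.t. those maps.
    Returns an (H, W) heatmap scaled to [0, 1].
    """
    # One weight per channel: the spatially averaged gradient.
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of the feature maps over the channel axis -> (H, W).
    cam = np.tensordot(weights, activations, axes=1)
    # Keep only regions that push the class score up.
    cam = np.maximum(cam, 0)
    # Normalize for display; guard against an all-zero map.
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Toy example: channel 0 fires top-left, channel 1 fires bottom-right.
acts = np.zeros((2, 3, 3))
acts[0, 0, 0] = 1.0
acts[1, 2, 2] = 1.0
# Channel 0 raises the class score, channel 1 lowers it.
grads = np.stack([np.full((3, 3), 0.5), np.full((3, 3), -0.5)])
heatmap = grad_cam(acts, grads)
print(heatmap)
```

Only the region driven by the positively weighted channel survives the clipping, which is why GradCam maps highlight evidence *for* a class rather than against it.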
GradCam Heatmap for Brown Bear Class
While the methods may differ, the main idea of CAM methods is to find the spatial information preserved and expressed in the feature maps and depict it, per class, as a heatmap on the original image. The heatmaps help us understand what features our model is learning and how it arrives at its prediction.
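The final "depict it on the original image" step is mostly resizing and blending. Below is a minimal sketch, assuming the heatmap's resolution divides the image size evenly; it uses nearest-neighbour upsampling and a plain red tint, whereas libraries typically use smoother interpolation and a colormap. overlay_heatmap is a hypothetical helper.

```python
import numpy as np


def overlay_heatmap(image: np.ndarray, cam: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a low-resolution heatmap onto an RGB image (illustrative sketch).

    image: (H, W, 3) float array in [0, 1].
    cam:   (h, w) heatmap in [0, 1], where H and W are multiples of h and w.
    """
    scale_h = image.shape[0] // cam.shape[0]
    scale_w = image.shape[1] // cam.shape[1]
    # Nearest-neighbour upsampling: repeat each heatmap cell into a block of pixels.
    upsampled = np.kron(cam, np.ones((scale_h, scale_w)))
    # Simple red-channel tint in place of a real colormap.
    color = np.zeros_like(image)
    color[..., 0] = upsampled
    return (1 - alpha) * image + alpha * color


image = np.zeros((4, 4, 3))
cam = np.array([[1.0, 0.0], [0.0, 0.0]])
blended = overlay_heatmap(image, cam)
print(blended[..., 0])
```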
Let’s look at an example of unwanted bias in our dataset found with GradCam. The heatmap shows us where the model was looking when it made its decision. In the image below, the sky lights up with positive values, meaning that instead of learning the features of the plane, the model bases its prediction mostly on the large amount of sky in the picture rather than on the plane itself.
Objects should be predicted based on their own features, not their environments. Spotting biases like this one helps prevent trouble in the future. By relying on the sky to predict planes, we open ourselves up to misclassifications on rainy days, or give bad actors a way to bypass classification by exploiting the bias.
Thankfully, by evaluating models with FiftyOne and CAM methods, we can spot these biases and understand our models better. Let’s take a look at two different examples.
Follow along in the notebook example for the full code!
Exploring GradCam and FiftyOne
For the two examples we will be looking at, we will use pytorch_grad_cam, an incredible open source package that makes working with GradCam very easy. The repo also has other excellent tutorials worth checking out.
To start, we will run GradCam on a semantic segmentation model to get a better picture of how it makes its predictions. The model we will be using is deeplabv3-resnet50-coco-torch. We can get started with our model and the quickstart example dataset with the following:
```python
import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart")
fo_model = foz.load_zoo_model("deeplabv3-resnet50-coco-torch")
```
We first need to collect predictions from our model so that we can analyze them later. We can run inference on every sample in our FiftyOne dataset with:
```python
dataset.apply_model(fo_model, label_field="resnet50_seg")
```
The apply_model() method adds a new label field to each sample containing the predicted segmentation mask, such as the one shown below.
Now, to prepare the model for GradCam, we need to make sure the GradCam function can clearly identify four things:
- The input tensor of our model, preprocessed and prepped for inference
- The output of our model as an array or tensor
- The target class expressed as the target output of the model, needed for backpropagation
- The target layer of the model that we would like to see the class activations for
Let’s tackle these one by one!
Preprocessing the Input Tensor
Luckily for us, FiftyOne makes it easy to grab our model’s preprocessing.
```python
transforms = fo_model.transforms
```
If a Model Zoo model includes preprocessing, it is stored under the transforms attribute. We can then preprocess an image with:
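```python
from PIL import Image

sample = dataset.first()  # any sample from the dataset
rgb_img = Image.open(sample.filepath).convert("RGB")
input_tensor = transforms(rgb_img).unsqueeze(0)  # add a batch dimension
```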
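To see what that preprocessing line produces, here is a NumPy sketch of the ToTensor, Normalize, and unsqueeze(0) steps applied to an RGB array. The ImageNet mean and standard deviation below are an assumption about this zoo model's transforms, and resizing is omitted; inspect fo_model.transforms for the real pipeline. to_input_tensor is a hypothetical name.

```python
import numpy as np

# Assumed ImageNet statistics; check the zoo model's actual transforms.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_input_tensor(rgb: np.ndarray) -> np.ndarray:
    """Mimic ToTensor + Normalize + unsqueeze(0) on an (H, W, 3) uint8 image."""
    x = rgb.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = (x - MEAN) / STD                # Normalize each channel
    x = x.transpose(2, 0, 1)            # HWC -> CHW, as torch expects
    return x[np.newaxis]                # add batch dimension -> (1, 3, H, W)


batch = to_input_tensor(np.zeros((4, 5, 3), dtype=np.uint8))
print(batch.shape)  # (1, 3, 4, 5)
```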
Getting Only the Model Output
Our model in this case outputs a dictionary, but we only want the mask output. We can create a thin wrapper around our model whose forward method returns only model(x)["out"].
```python
import torch


class ModelWrapper(torch.nn.Module):
    """Return only the segmentation logits from a model that outputs a dict."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)["out"]


model = ModelWrapper(fo_model._model)
```
The forward method should return whatever value you want to backpropagate from. This can be tricky in some object detection cases like YOLO, but it likewise just requires wrapping the model so that it returns the appropriate value.
Defining the Target Class
We need to decide what target GradCam will explain. We can either look at one pixel for every class or at all pixels for one class; I chose the latter for this example. The target stores the class index and a mask, and sums the model’s output for that class across all the pixels belonging to it.
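```python
import torch


class SemanticSegmentationTarget:
    """Sum a class's output over the pixels in a binary mask."""

    def __init__(self, category, mask):
        self.category = category
        self.mask = torch.from_numpy(mask)
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        # model_output has shape (num_classes, H, W) for a single image
        return (model_output[self.category, :, :] * self.mask).sum()
```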
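Stripped of torch and CUDA handling, the SemanticSegmentationTarget above computes a single number: the class's scores summed over the masked pixels. The NumPy sketch below mirrors that computation so it can be checked by hand; segmentation_score is a hypothetical name.

```python
import numpy as np


def segmentation_score(model_output: np.ndarray, category: int, mask: np.ndarray) -> float:
    """Scalar that GradCam would backpropagate from (illustrative sketch).

    model_output: (C, H, W) per-class scores for one image.
    mask:         (H, W) float mask, 1.0 where the pixel belongs to the class.
    """
    # Select the class channel and sum its scores over the masked pixels only.
    return float((model_output[category] * mask).sum())


output = np.arange(8, dtype=float).reshape(2, 2, 2)  # 2 classes on a 2x2 image
mask = np.array([[1.0, 0.0], [0.0, 1.0]])
print(segmentation_score(output, category=1, mask=mask))  # 11.0
```

Because the score is a masked sum, pixels outside the mask contribute nothing to the gradient, so the resulting heatmap explains only the pixels the model labeled as the class.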
In our example, we will look specifically at the “person” class when generating our heatmaps.
Choosing the Model Layer
Choosing the correct model layer can be tricky at times. The most important requirements for GradCam are that the target layer is differentiable, so it produces a gradient, and that its output still expresses the spatial information of the model. ResNet50 is an easy architecture to work with due to its block structure.
We can grab the outputs of stages 2 through 5 (layer1 to layer4 in torchvision’s ResNet) and perform GradCam on each of them. This lets us follow how the model works through different feature sets before arriving at its prediction. We target them with:
```python
# One list per stage, since GradCAM expects a list of target layers
target_layers = [
    [model.model.backbone.layer1],
    [model.model.backbone.layer2],
    [model.model.backbone.layer3],
    [model.model.backbone.layer4],
]
```
Performing GradCam
To get started, we first create the cam object we will be using. This requires our model and the target layer.
```python
from pytorch_grad_cam import GradCAM

index = 0  # which stage to explain: 0 = layer1, ..., 3 = layer4
cam = GradCAM(model=model, target_layers=target_layers[index])
```
Afterwards we can put it all together and create a loop to iterate through all of our samples to perform GradCam!
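```python
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.eval().to(device)

# classes_dict maps class names to output channel indices (defined in the notebook)
person_category = classes_dict["person"]

for index, layer in enumerate(target_layers):
    # Create one CAM object per stage
    cam = GradCAM(model=model, target_layers=layer)

    for sample in dataset:
        # Load the image and preprocess it
        rgb_img = Image.open(sample.filepath).convert("RGB")
        input_tensor = transforms(rgb_img).unsqueeze(0).to(device)

        # Generate the mask output
        output = model(input_tensor)

        # Create the target: every pixel predicted as "person"
        normalized_masks = torch.nn.functional.softmax(output, dim=1).cpu()
        predicted_classes = normalized_masks[0].argmax(dim=0).detach().numpy()
        person_mask_float = np.float32(predicted_classes == person_category)
        targets = [SemanticSegmentationTarget(person_category, person_mask_float)]

        # Perform GradCam; the batch holds a single image
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

        # Save the heatmap to the sample, one field per stage
        sample[f"person_grad_layer_{index}"] = fo.Heatmap(map=grayscale_cam)
        sample.save()
```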
For full reference, don’t forget to refer to the notebook!
We get some stunning results from our GradCam runs. We can first see how the model works its way through edge detection and low-level features, then homes in on the person in the image.
In an even more stunning example, watch how the model below precisely finds the edges of the train, then drops them in later layers when it focuses only on the person. Remember that the heatmaps are only for the person class, so it is remarkable to see the model “ignore” the train towards the end after previously “looking” at it.
GradCam with Vision Transformers
GradCam was originally designed with CNNs in mind. Since then, transformers have exploded into a staple of many different model architectures. Vision Transformers, or ViTs for short, are now one of the main architectures used for computer vision tasks across the domain. The main difference is that we no longer have stacks of convolutional layers to perform GradCam on, so what do we do?
The answer lies in finding the spatial data in the architecture and transforming it (pun intended) into a form GradCam can work with. To begin, let’s quickly go over what powers transformer models. Vision Transformers work by leveraging multi-head self-attention. However, instead of performing attention over words, they create tokens from slices of the image. These slices are usually square patches, from which the model can learn both global and local dependencies within the image to make accurate predictions.
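The slicing itself can be sketched in NumPy. Assuming a 224×224 RGB input and 16×16 patches, as in deit_tiny_patch16_224, the image splits into a 14×14 grid of 196 patches; the real model then embeds each patch and prepends a class token, giving 197 tokens. patchify is a hypothetical helper that shows only the slicing, not the learned embedding.

```python
import numpy as np


def patchify(image: np.ndarray, patch: int = 16) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping square patches.

    Returns an array of shape (num_patches, patch * patch * C), one row per
    token, ordered row by row across the image.
    """
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image must divide evenly into patches"
    rows, cols = h // patch, w // patch
    # (rows, patch, cols, patch, C) -> (rows, cols, patch, patch, C)
    patches = image.reshape(rows, patch, cols, patch, c).transpose(0, 2, 1, 3, 4)
    # Flatten each patch into a single row.
    return patches.reshape(rows * cols, patch * patch * c)


tokens = patchify(np.zeros((224, 224, 3)))
print(tokens.shape)  # (196, 768)
```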
Here is a helpful visual of how the model is slicing up the image. The model should learn that slices in and around the bear’s face are important for the classification, while the background squares are not.
All of this helps us understand where the spatial data lies and how we can extract it. We just need to take, from the activations, the data belonging to each patch and how much it activated for the class we are interested in. To do this, we apply a simple transform to the output of the target layer (a normalization layer in our case) that drops the extra class token and rearranges the remaining patch tokens into a 2D grid:
```python
def reshape_transform(tensor, height=14, width=14):
    # Drop the class token and lay the patch tokens out on a 14x14 grid
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    # Bring the channels to the first dimension, like in CNNs
    result = result.transpose(2, 3).transpose(1, 2)
    return result
```
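A NumPy mirror of reshape_transform makes the shapes easy to check. It assumes DeiT-tiny's 197 tokens (one class token plus 196 patch tokens) with a 192-dimensional embedding; reshape_transform_np is a hypothetical name, and its single transpose(0, 3, 1, 2) is equivalent to the two transposes in the torch version.

```python
import numpy as np


def reshape_transform_np(tokens: np.ndarray, height: int = 14, width: int = 14) -> np.ndarray:
    """NumPy mirror of reshape_transform: (B, 1 + H*W, D) -> (B, D, H, W)."""
    batch, _, dim = tokens.shape
    # Drop the class token, then lay the patch tokens out on the patch grid.
    grid = tokens[:, 1:, :].reshape(batch, height, width, dim)
    # Move the embedding dimension to axis 1 so it plays the role of CNN channels.
    return grid.transpose(0, 3, 1, 2)


tokens = np.arange(197 * 192, dtype=float).reshape(1, 197, 192)
grid = reshape_transform_np(tokens)
print(grid.shape)  # (1, 192, 14, 14)
```

The patch in row 0, column 1 of the grid comes from token 2, since token 0 is the class token and patches are numbered row by row.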
We can now apply the transform and try it ourselves with a vision transformer. We will use DeiT, loaded into FiftyOne from Torch Hub with a preprocessing pipeline and an output processor, like the following:
```python
import os

import eta.constants
import torchvision

import fiftyone.utils.torch as fout

transforms = torchvision.transforms.Compose([
    fout.ToPILImage(),
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

labels_path = os.path.join(
    eta.constants.RESOURCES_DIR, "imagenet-labels-no-background.txt"
)

fo_model = fout.load_torch_hub_image_model(
    "facebookresearch/deit:main",
    "deit_tiny_patch16_224",
    hub_kwargs=dict(pretrained=True),
    transforms=transforms,
    output_processor_cls=fout.ClassifierOutputProcessor,
    labels_path=labels_path,
)
```
We will also need the underlying Torch model for the CAM method later, which we grab with:
```python
import torch

model = fo_model._model
model.eval()
if torch.cuda.is_available():
    model = model.cuda()  # use the GPU when one is available
```
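As mentioned above, we need to decide which layer to target for our CAM methods. For ViTs, the library’s examples target the first normalization layer (norm1) of the last transformer block. Layers after the final attention step only reach the classifier through the class token, so the patch tokens there receive no useful gradient. In DeiT that layer is: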
```python
target_layers = [model.blocks[-1].norm1]
```
Afterwards, we can pass our model, target layers, and transform into pytorch-grad-cam and generate a new CAM. We will pass None for targets, which makes the library target the highest-scoring class.
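```python
# methods maps a name to a CAM class (defined below); method is one of its keys.
# For AblationCAM on ViTs, the library's ViT example also passes
# ablation_layer=AblationLayerVit().
cam = methods[method](
    model=model,
    target_layers=target_layers,
    reshape_transform=reshape_transform,
)
```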
We have many methods to choose from in the library, so we will compare all of them across our images. While they all operate in a similar fashion, each uses a different technique to compute the class activations, so one method may highlight features that another misses, and a bias may show up under one method but not another. The eight methods we will be testing are:
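```python
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    EigenGradCAM,
    GradCAM,
    GradCAMPlusPlus,
    LayerCAM,
    ScoreCAM,
    XGradCAM,
)

methods = {
    "gradcam": GradCAM,
    "scorecam": ScoreCAM,
    "gradcam++": GradCAMPlusPlus,
    "ablationcam": AblationCAM,
    "xgradcam": XGradCAM,
    "eigencam": EigenCAM,
    "eigengradcam": EigenGradCAM,
    "layercam": LayerCAM,
}
```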
You can learn more about each one and the papers they are derived from on the library’s README.
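Not every method in that list relies on gradients. EigenCAM, for example, projects the layer's activations onto their first principal component, so it highlights the dominant activation pattern without using a class target. The sketch below follows that idea in NumPy; eigen_cam is a hypothetical helper, not the library's implementation, and because the sign of a singular vector is arbitrary, the map may need flipping before display.

```python
import numpy as np


def eigen_cam(activations: np.ndarray) -> np.ndarray:
    """Project (K, H, W) activations onto their first principal component.

    Gradient-free: no class target is used, so the map shows the dominant
    activation pattern of the layer. The sign of the result is arbitrary.
    """
    k, h, w = activations.shape
    # One row per spatial location, one column per channel.
    flat = activations.reshape(k, h * w).T
    # Center each channel before the SVD.
    flat = flat - flat.mean(axis=0)
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    # Score each location by its coordinate along the first principal direction.
    projection = flat @ vt[0]
    return projection.reshape(h, w)


pattern = np.array([[1.0, 0.0], [0.0, 0.0]])
acts = np.stack([1.0 * pattern, 2.0 * pattern])  # two channels sharing one pattern
print(np.abs(eigen_cam(acts)))
```

When every channel shares the same spatial pattern, the projection recovers that pattern (centered and scaled), which is the intuition behind using the first principal component as a class-agnostic heatmap.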
To find the full loop that produces each heatmap, refer to the notebook. The core step generates a heatmap from the cam object and an image’s input_tensor:
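```python
targets = None  # explain the highest-scoring class
grayscale_cam = cam(
    input_tensor=input_tensor,
    targets=targets,
    eigen_smooth=True,  # smooth the map using its first principal component
    aug_smooth=True,  # average over test-time augmentations
)
```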
Here are some of the more interesting results!
Learning The Wrong Features
Sometimes our models make decisions based not on the features of the object we hoped they would learn, but on the features of the environment around it. This often happens when objects are always at the center of the image, so the model learns to find the center and look around it, or when the background gives the class away, such as a background that is all sky or all water. Here we quickly find two cases of the latter.
Above, we can see a surfer riding a wave, and the class activation maps tell us that the model is heavily focused on the splashing water, not the person. This indicates that the model strongly associates splashing water with the snorkel class instead of learning the actual features of a snorkel, which leads to misclassifications like this one. While there are many ways to address this in the next training run, some form of data curation or augmentation should help the model learn the right features and improve its performance.
Again, we can see that the model makes an accurate prediction above, this time of a warplane. However, according to the heatmaps, the model is barely looking at the planes! How is this possible? Most likely, the model notices the objects in the center and that the rest of the image is just sky, and the sky then heavily pushes its decision toward the warplane class. While this still produces correct predictions, this line of reasoning could spell trouble down the line.
Understanding Misclassifications
Observing our class activation maps can help us understand why the model came to the wrong conclusion. In the image below, the cat is incorrectly classified as “german shepherd”. However, the model did pick up all the correct edges of the cat and focused on the right region; the “german shepherd” class simply had a higher score. In the future, we could compare this map with the activation maps of other cat images to see how they differ.
Focusing On What Matters
The last example to look at is how well the model focuses on the subject of the picture. In some cases, it’s obvious through the class activations that the model was able to find the subject with high precision. In the example below, the model clearly activates for the zebras and pays no attention to the background.
In other cases, despite making a correct or nearly correct prediction, as with the cat image below, the model activates strongly for shapes and edges elsewhere in the image.
The circular piece of glass the cat is looking through seemingly distracts the model, potentially swaying its predictions. Such edges and shapes should certainly not be considered features of the cat, and they could hurt the model’s performance.
Conclusion
GradCam should be a tool in every computer vision engineer’s toolkit. It provides valuable insight into how your model predicts and, more importantly, can help uncover biases in your datasets or predictions. If you are interested in learning more about GradCam, I recommend checking out the pytorch-grad-cam documentation. Its incredible tutorials made putting together this example possible!