# Source code for tensornet.gradcam.visual

import cv2
import torch
import numpy as np

import matplotlib.pyplot as plt

from tensornet.gradcam.gradcam import GradCAM
from tensornet.gradcam.gradcam_pp import GradCAMPP
from import to_numpy, unnormalize
from typing import Tuple, List, Dict, Union, Optional

def visualize_cam(mask: torch.Tensor, img: torch.Tensor, alpha: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Make heatmap from mask and synthesize GradCAM result image using heatmap and img.

    Args:
        mask (torch.tensor): mask shape of (1, 1, H, W) and each element has
            value in range [0, 1]
        img (torch.tensor): img shape of (1, 3, H, W) and each pixel value is
            in range [0, 1]
        alpha (float, optional): Factor the heatmap is multiplied by before it
            is added to ``img``. (default: 1.0)

    Returns:
        2-element tuple containing

        - (*torch.tensor*): heatmap img shape of (3, H, W)
        - (*torch.tensor*): synthesized GradCAM result of same shape with heatmap.
    """
    # cv2.applyColorMap expects a 2-D uint8 array. Detach first so .numpy()
    # also works on masks that are still attached to the autograd graph.
    heatmap = (255 * mask.squeeze()).detach().to(torch.uint8).cpu().numpy()
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255)

    # OpenCV produces BGR channel order; reorder to RGB to match ``img``.
    b, g, r = heatmap.split(1)
    heatmap = torch.cat([r, g, b]) * alpha

    result = heatmap + img.cpu()
    max_value = result.max()
    if max_value > 0:  # avoid 0/0 -> NaN when the overlay is all zeros
        result = result.div(max_value)
    result = result.squeeze()

    return heatmap, result
class GradCAMView:
    """Create GradCAM and GradCAM++.

    *Note*: The current implementation of `GradCAM` and `GradCAM++` supports
    only ResNet models. The class can be extended to add support for other
    models.

    Args:
        model (torch.nn.Module): Trained model.
        layers (list): List of layers to show GradCAM on.
        device (:obj:`str` or :obj:`torch.device`): GPU or CPU.
        mean (:obj:`float` or :obj:`tuple`): Mean of the dataset.
        std (:obj:`float` or :obj:`tuple`): Standard Deviation of the dataset.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        layers: List[str],
        device: Union[str, torch.device],
        mean: Union[float, tuple],
        std: Union[float, tuple]
    ):
        self.model = model
        self.layers = layers
        self.device = device
        self.mean = mean
        self.std = std

        self._gradcam()
        self._gradcam_pp()

        print('Mode set to GradCAM.')
        # Active per-layer extractors; switch_mode() swaps this between a copy
        # of self.gradcam and a copy of self.gradcam_pp.
        self.grad = self.gradcam.copy()

        # CAM results produced by cam(); one dict per processed image.
        self.views = []

    def _gradcam(self):
        """Initialize a GradCAM instance for every layer in ``self.layers``."""
        self.gradcam = {}
        for layer in self.layers:
            self.gradcam[layer] = GradCAM(self.model, layer)

    def _gradcam_pp(self):
        """Initialize a GradCAM++ instance for every layer in ``self.layers``."""
        self.gradcam_pp = {}
        for layer in self.layers:
            self.gradcam_pp[layer] = GradCAMPP(self.model, layer)

    def switch_mode(self):
        """Switch between GradCAM and GradCAM++."""
        # self.grad is a shallow copy, so it compares equal to self.gradcam
        # exactly when the GradCAM extractors are active.
        if self.grad == self.gradcam:
            print('Mode switched to GradCAM++.')
            self.grad = self.gradcam_pp.copy()
        else:
            print('Mode switched to GradCAM.')
            self.grad = self.gradcam.copy()

    def _cam_image(
        self, norm_image: torch.Tensor, class_idx: Optional[int] = None
    ) -> Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]:
        """Get CAM for an image using the currently active mode.

        Args:
            norm_image (torch.Tensor): Normalized image.
            class_idx (:obj:`int`, optional): Class index for calculating GradCAM.
                If not specified, the class index that makes the highest model
                prediction score will be used.

        Returns:
            Dictionary containing unnormalized image, heatmap and CAM result.
            'heatmap' and 'result' map each layer name to a numpy array.
        """
        image = unnormalize(norm_image, self.mean, self.std)  # Unnormalized image
        # Add a batch dimension and move the input to the model's device.
        norm_image_cuda = norm_image.clone().unsqueeze_(0).to(self.device)

        heatmap, result = {}, {}
        # Use the *active* extractors (self.grad) so switch_mode() takes effect.
        for layer, gc in self.grad.items():
            mask, _ = gc(norm_image_cuda, class_idx=class_idx)
            cam_heatmap, cam_result = visualize_cam(
                mask,
                image.clone().unsqueeze_(0).to(self.device)
            )
            heatmap[layer], result[layer] = to_numpy(cam_heatmap), to_numpy(cam_result)

        return {
            'image': to_numpy(image),
            'heatmap': heatmap,
            'result': result
        }

    def cam(self, norm_img_class_list: List[Union[Dict[str, Union[torch.Tensor, int]], torch.Tensor]]):
        """Get CAM for a list of images and append the results to ``self.views``.

        Args:
            norm_img_class_list (list): List of dictionaries or list of images.
                If dict, each dict contains keys 'image' and 'class'
                having values 'normalized_image' and 'class_idx' respectively.
                class_idx is optional. If class_idx is not given then the
                model prediction will be used and the parameter should just be
                a list of images. Each image should be of type torch.Tensor
        """
        for norm_image_class in norm_img_class_list:
            class_idx = None
            norm_image = norm_image_class
            if isinstance(norm_image_class, dict):
                # 'class' is optional: a missing key falls back to None, which
                # lets the model prediction pick the class.
                class_idx = norm_image_class.get('class')
                norm_image = norm_image_class['image']
            self.views.append(self._cam_image(norm_image, class_idx=class_idx))

    def __call__(
        self, norm_img_class_list: List[Union[Dict[str, Union[torch.Tensor, int]], torch.Tensor]]
    ) -> List[Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]]:
        """Get GradCAM for a list of images.

        Results from any previous call are discarded, so the returned list
        corresponds exactly to ``norm_img_class_list``.

        Args:
            norm_img_class_list (list): List of dictionaries or list of images.
                If dict, each dict contains keys 'image' and 'class'
                having values 'normalized_image' and 'class_idx' respectively.
                class_idx is optional. If class_idx is not given then the
                model prediction will be used and the parameter should just be
                a list of images. Each image should be of type torch.Tensor

        Returns:
            List with one dictionary (image, heatmap, result) per input image.
        """
        self.views = []
        self.cam(norm_img_class_list)
        return self.views