import torch
import torch.nn.functional as F
from tensornet.gradcam.gradcam import GradCAM
class GradCAMPP(GradCAM):
    """Calculate GradCAM++ saliency map.

    It inherits the `GradCAM` class, so the definition of all the methods
    is exactly the same as its parent class.

    *Note*: The current implementation supports only ResNet models. The
    class can be extended to add support for other models.
    """

    def _forward(self, input, class_idx=None, retain_graph=False):
        """Compute the GradCAM++ saliency map for ``input``.

        Args:
            input (torch.Tensor): Input image batch of shape (B, C, H, W).
            class_idx (int, optional): Class index for which the saliency
                map is computed. Defaults to the model's predicted
                (argmax) class.
            retain_graph (bool, optional): Forwarded to ``backward``;
                set True if the graph is needed for further backward calls.

        Returns:
            tuple: ``(saliency_map, logit)`` where ``saliency_map`` has
            shape (1, 1, H, W), min-max normalized to [0, 1], and
            ``logit`` is the raw model output.
        """
        b, c, h, w = input.size()

        logit = self.model(input)
        if class_idx is None:
            # Score each sample at its own predicted class. The previous
            # `logit[:, logit.max(1)[-1]]` indexed every row by every
            # sample's argmax, which is only correct for batch size 1.
            score = logit.gather(1, logit.argmax(dim=1, keepdim=True)).squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value']      # dS/dA
        activations = self.activations['value']  # A
        b, k, u, v = gradients.size()

        # GradCAM++ pixel-wise weights alpha (Eq. 19 of the paper):
        # alpha = grad^2 / (2*grad^2 + sum(A * grad^3)).
        alpha_num = gradients.pow(2)
        alpha_denom = alpha_num.mul(2) + \
            activations.mul(gradients.pow(3)).view(b, k, u * v).sum(-1).view(b, k, 1, 1)
        # Guard against division by zero where the denominator vanishes.
        alpha_denom = torch.where(
            alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))
        alpha = alpha_num.div(alpha_denom + 1e-7)

        # ReLU(dY/dA) == ReLU(exp(S) * dS/dA)
        positive_gradients = F.relu(score.exp() * gradients)
        weights = (alpha * positive_gradients).view(b, k, u * v).sum(-1).view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        # F.upsample is deprecated; F.interpolate is the current API with
        # identical semantics for these arguments.
        saliency_map = F.interpolate(
            saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        # Min-max normalize the map to [0, 1].
        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        saliency_map = (saliency_map - saliency_map_min).div(
            saliency_map_max - saliency_map_min).data

        return saliency_map, logit