
The question of whether attention distillation loss in LwM can produce gradient. #9

Closed
NUAA-XSF opened this issue Aug 27, 2021 · 4 comments

NUAA-XSF commented Aug 27, 2021

Hello!
Thank you for your nice work.
I have a question:
The LwM (Learning without Memorizing) paper uses an attention distillation loss. In your code (lwm.py):

# in class GradCAM
def __call__(self, input, class_indices=None, return_outputs=False):
    # pass input & backpropagate for selected class
    if input.dim() == 3:
        input = input.view([1] + list(input.size()))
    self.model.eval()
    model_output = self.model(input)
    logits = torch.cat(model_output, dim=1)
    if class_indices is None:
        class_indices = logits.argmax(dim=1)
    score = logits[:, class_indices].squeeze()
    self.model.zero_grad()
    score.mean().backward(retain_graph=self.retain_graph)
    model_output = [o.detach() for o in model_output]

    # create map based on gradients and activations
    with torch.no_grad():
        weights = F.adaptive_avg_pool2d(self.gradients, 1)
        att_map = (weights * self.activations).sum(dim=1, keepdim=True)
        att_map = F.relu(att_map)
        del self.activations
        del self.gradients
        return (att_map, model_output) if return_outputs else att_map

It seems to me that such code does not produce gradients when the attention distillation loss is backpropagated.
Looking forward to your reply.
Thank you.

@NUAA-XSF NUAA-XSF changed the title The question of whether attention differentiation loss in LwM can produce gradient. The question of whether attention distillation loss in LwM can produce gradient. Aug 27, 2021

mmasana commented Aug 27, 2021

Hi @NUAA-XSF,
happy you like our work! The gradients are saved when doing the backward pass (line 229 of lwm.py):

def __enter__(self):
    # register hooks to collect activations and gradients
    def forward_hook(module, input, output):
        self.activations = output.detach()

    def backward_hook(module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    # hook to final layer
    self.fhandle = self.model_layer.register_forward_hook(forward_hook)
    self.bhandle = self.model_layer.register_backward_hook(backward_hook)
    return self

and they are then used in the line weights = F.adaptive_avg_pool2d(self.gradients, 1) in the code you posted.
Was this what you were asking? Or did you mean something else?


NUAA-XSF commented Aug 28, 2021

@mmasana
Thank you for your reply.

def backward_hook(module, grad_input, grad_output):
    self.gradients = grad_output[0].detach()

I understand that this gradient is used to generate the attention map. What I really want to ask is: when the two attention maps from M_t and M_{t-1} are used to compute the attention distillation loss, can gradients be generated when this loss is propagated back? I can't find the relevant code.

# retain_graph = False in your code (line 251 in lwm.py)
# I think retain_graph should be True, because the attention distillation loss
# will be calculated later
score.mean().backward(retain_graph=self.retain_graph)
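
To illustrate what I mean, here is a minimal toy sketch (my own example, not code from lwm.py) showing that a map built only from detached tensors carries no computation graph, so a loss computed on it cannot reach the model parameters:

import torch
import torch.nn.functional as F

# Toy sketch (an assumption, not lwm.py code): the hooks store detached tensors,
# so an attention map built from them has no graph back to the input or weights.
x = torch.randn(2, 8, 4, 4, requires_grad=True)   # stands in for a feature map
activations = (x * 2).detach()                    # mimics self.activations = output.detach()
gradients = torch.randn_like(activations)         # mimics self.gradients = grad_output[0].detach()
weights = F.adaptive_avg_pool2d(gradients, 1)
att_map = F.relu((weights * activations).sum(dim=1, keepdim=True))
print(att_map.requires_grad)  # False -> a distillation loss on att_map cannot backpropagate to x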


NUAA-XSF commented Aug 28, 2021

@mmasana
There are two other things that confuse me (line 247 in lwm.py):

if class_indices is None:
    class_indices = logits.argmax(dim=1)
score = logits[:, class_indices].squeeze()
self.model.zero_grad()
score.mean().backward(retain_graph=self.retain_graph)

I think there is a problem with the third line of code. For example:

logits = torch.tensor([[1,2,3,4],[8,7,6,5]]).float()
class_indices = logits.argmax(dim=1)  # class_indices:tensor([3, 0])
score = logits[:, class_indices].squeeze() # score:tensor([[4., 1.], [5., 8.]])

The result is confusing because it contains other values that are not the maximum.
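I guess the intended behaviour is something like the following (just my assumption about what was meant, not code from the repository):

import torch

# My guess at the intended per-sample selection (an assumption, not lwm.py code):
# take each row's own class score instead of indexing every row with every argmax.
logits = torch.tensor([[1, 2, 3, 4], [8, 7, 6, 5]]).float()
class_indices = logits.argmax(dim=1)                          # tensor([3, 0])
score = logits[torch.arange(logits.size(0)), class_indices]   # tensor([4., 8.])
# equivalently: logits.gather(1, class_indices.unsqueeze(1)).squeeze(1)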
Another thing I don’t understand is why score.mean() is used.


mmasana commented Aug 30, 2021

@NUAA-XSF
thanks for pointing this out. What you mention about the third line of code does look strange indeed, so I will check it out. I have not looked at this code in some time, but I remember we based the implementation on this other repository. This paper was already a bit tricky because of the lack of code and because the hyperparameters of the attention-distillation loss (γ in the original paper) were not disclosed.

The result is confusing because it contains other values that are not the maximum.

It seems it gathers the values at all the argmax indices for every entry (giving a batch × batch result), instead of selecting each entry's own maximum.

Another thing I don’t understand is why score.mean() is used.

This one should just be the averaging over the batch before doing the backward pass. If it were a sum, then when the batch size is smaller (i.e. the last batch of an epoch) the backpropagation would be done at a smaller scale in comparison to a "full" batch. However, since there seem to be more values than there should be in the score tensor, it might not be necessary if that part is modified.
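
Just to illustrate the scale argument with a quick toy check (a sketch of mine, not the repository code):

import torch

# Toy check (a sketch, not lwm.py code): with .sum() the gradient magnitude grows
# with the batch size, while .mean() keeps it independent of the batch size.
w = torch.ones(1, requires_grad=True)
for n in (32, 8):                        # a "full" batch vs. a smaller last batch
    (w * torch.ones(n)).sum().backward()
    g_sum, w.grad = w.grad.item(), None
    (w * torch.ones(n)).mean().backward()
    g_mean, w.grad = w.grad.item(), None
    print(n, g_sum, g_mean)              # g_sum equals n, g_mean stays 1.0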

Finally, another user mentioned that in the distillation loss we should use torch.nn.functional.normalize instead of torch.norm. We will update it to be:

def attention_distillation_loss(self, attention_map1, attention_map2):
    # L1 distance between the two attention maps after flattening and L2-normalizing each one
    attention_map1 = torch.nn.functional.normalize(attention_map1.view(attention_map1.size(0), -1), p=2, dim=1, eps=1e-12)
    attention_map2 = torch.nn.functional.normalize(attention_map2.view(attention_map2.size(0), -1), p=2, dim=1, eps=1e-12)
    return torch.norm(attention_map2 - attention_map1, p=1, dim=1).mean()
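
As a quick sanity check of the updated loss, this is how it behaves as a standalone function (a sketch with made-up shapes, dropping the unused self):

import torch

# Standalone sanity check (a sketch, not the repository code; the shapes are made up).
def attention_distillation_loss(attention_map1, attention_map2):
    attention_map1 = torch.nn.functional.normalize(attention_map1.view(attention_map1.size(0), -1), p=2, dim=1)
    attention_map2 = torch.nn.functional.normalize(attention_map2.view(attention_map2.size(0), -1), p=2, dim=1)
    return torch.norm(attention_map2 - attention_map1, p=1, dim=1).mean()

att_new = torch.rand(4, 1, 7, 7, requires_grad=True)   # attention map from the current model
att_old = torch.rand(4, 1, 7, 7)                        # attention map from the frozen old model
loss = attention_distillation_loss(att_new, att_old)
loss.backward()
print(loss.item(), att_new.grad is not None)            # gradients only flow if the new map is not detached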

I thought it would also be good to mention it in case it is useful for you. Let me know if you have any other issues.
