Can the attention distillation loss in LwM produce gradients? #9
Comments
Hi @NUAA-XSF,
and used in the line
@mmasana

def backward_hook(module, grad_input, grad_output):
    self.gradients = grad_output[0].detach()

This gradient is used to generate the attention map; that part I have understood. What I really want to ask is: when the two attention maps from M_t and M_{t-1} are used to compute the attention distillation loss, can gradients still be generated when that loss is propagated back? I can't find the relevant code.

# retain_graph = False in your code (line 251 in lwm.py)
# I think retain_graph should be True, because the attention distillation loss
# will be calculated later
score.mean().backward(retain_graph=self.retain_graph)
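For context, here is a minimal sketch of how an attention map can be kept differentiable so that a loss computed on it produces gradients for the new model M_t. This is not the code from lwm.py; gradcam_attention, feats and the other names are illustrative. The key point is to obtain the gradient of the class score with torch.autograd.grad(..., create_graph=True) rather than through a hook that detaches it, so the map of M_t stays in the autograd graph while the map of the frozen M_{t-1} can be left non-differentiable.

import torch
import torch.nn.functional as F

def gradcam_attention(feats, logits, class_indices, differentiable=False):
    # One class score per sample (each sample's own selected class).
    scores = logits.gather(1, class_indices.unsqueeze(1)).squeeze(1)
    # Gradient of the scores w.r.t. the feature maps; create_graph=True keeps
    # this computation in the autograd graph so a loss on the attention map
    # can itself be backpropagated.
    grads = torch.autograd.grad(scores.sum(), feats, create_graph=differentiable)[0]
    weights = grads.mean(dim=(2, 3), keepdim=True)   # global-average-pool the gradients
    return F.relu((weights * feats).sum(dim=1))      # B x H x W Grad-CAM-style map

In such a sketch, the map of the current model would be computed with differentiable=True and the map of the previous model with differentiable=False, so that only M_t receives gradients from the attention distillation loss.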
@mmasana

if class_indices is None:
    class_indices = logits.argmax(dim=1)
score = logits[:, class_indices].squeeze()
self.model.zero_grad()
score.mean().backward(retain_graph=self.retain_graph)

I think there is a problem with the third line of this code. For example:

logits = torch.tensor([[1, 2, 3, 4], [8, 7, 6, 5]]).float()
class_indices = logits.argmax(dim=1)            # class_indices: tensor([3, 0])
score = logits[:, class_indices].squeeze()      # score: tensor([[4., 1.], [5., 8.]])

The result is confusing because it contains other values that are not the maximum.
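Purely as an illustration (a sketch, not necessarily how the repository code was eventually changed), per-sample indexing avoids the cross-product that logits[:, class_indices] produces:

import torch

logits = torch.tensor([[1, 2, 3, 4], [8, 7, 6, 5]]).float()
class_indices = logits.argmax(dim=1)                              # tensor([3, 0])

# Select each sample's own class score instead of all row/column combinations.
score = logits.gather(1, class_indices.unsqueeze(1)).squeeze(1)   # tensor([4., 8.])
# Equivalent: logits[torch.arange(logits.size(0)), class_indices]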
@NUAA-XSF
It seems it returns the values of the maximum for each entry, instead of applying it element-wise.

This one should just be the averaging of the batch before doing the backward pass. If it were a sum, then when the batch size is smaller (i.e. the last batch of an epoch) the backpropagation would be done at a smaller scale in comparison to a "full" batch. However, since there seem to be more values than there should be in the

Finally, another user mentioned that in the distillation loss we should use torch.nn.functional.normalize instead of torch.norm. We will update it to be:

I thought it would also be good to mention it in case it is useful for you. Let me know if you have any other issues.
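The updated code itself is not quoted in this thread. Purely as a sketch of what a distillation loss based on torch.nn.functional.normalize could look like (the function and argument names here are assumptions, not necessarily the maintainers' actual change):

import torch
import torch.nn.functional as F

def attention_distillation_loss(att_new, att_old):
    # Vectorize each attention map, rescale it to unit L2 norm with F.normalize
    # (instead of dividing by torch.norm by hand), then take the L1 distance.
    q_new = F.normalize(att_new.flatten(start_dim=1), p=2, dim=1)
    q_old = F.normalize(att_old.flatten(start_dim=1), p=2, dim=1)
    return (q_new - q_old).abs().sum(dim=1).mean()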
Hello!
Thank you for your nice work.
I have a question:
The LwM (Learning without Memorizing) paper uses an attention distillation loss. In your code (lwm.py):
I feel that such code does not seem to produce gradients when backpropagating.
Looking forward to your reply.
Thank you.