System Info
v4.46.2
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
The NEFTune noise computation is probably wrong with packed training, because the scaling factor is alpha / sqrt(length * d) and the length used here is the full packed length:
transformers/src/transformers/trainer_utils.py
Lines 126 to 149 in a06a0d1
````python
def neftune_post_forward_hook(module, input, output):
    """
    Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
    layers. This method is slightly adapted from the original source code that can be found here:
    https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
    ```python
    model = ...
    model.embed_tokens.neftune_noise_alpha = 0.1
    model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
    ```
    Args:
        module (`torch.nn.Module`):
            The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
            the desired noise alpha value.
        input (`torch.Tensor`):
            The input tensor to the model.
        output (`torch.Tensor`):
            The output tensor of the model (i.e. the embeddings).
    """
    if module.training:
        dims = torch.tensor(output.size(1) * output.size(2))
        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output
````
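To make the effect concrete, here is a quick illustrative calculation; the alpha, hidden size, and sequence lengths below are assumptions for the example, not values from this report:

```python
import math

alpha, d = 5.0, 4096              # illustrative values only
seq_len, n_packed = 512, 8        # eight sequences of 512 tokens in one row
packed_len = seq_len * n_packed   # 4096 tokens after packing

mag_packed = alpha / math.sqrt(packed_len * d)   # what the current hook computes
mag_single = alpha / math.sqrt(seq_len * d)      # per-sequence scaling

print(mag_single / mag_packed)  # sqrt(n_packed) ≈ 2.83: noise is ~2.8x too small
```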
Expected behavior
The noise scaling should take the length of each individual sequence in the pack into account, rather than the total packed length. One possible approach is sketched below.
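A minimal sketch of what a length-aware hook could look like, assuming the trainer exposes the per-sequence lengths of each packed row on the embedding module as a hypothetical `neftune_seq_lengths` attribute (a list, per batch row, of sub-sequence lengths; this attribute does not exist in transformers):

```python
import math

import torch


def neftune_packed_post_forward_hook(module, input, output):
    # Hypothetical variant that scales the noise per packed sub-sequence
    # instead of using the full packed length.
    if module.training:
        d = output.size(2)
        noise = torch.zeros_like(output)
        for row, lengths in enumerate(module.neftune_seq_lengths):
            start = 0
            for length in lengths:
                # Scale by this sub-sequence's length, not the packed length.
                mag_norm = module.neftune_noise_alpha / math.sqrt(length * d)
                noise[row, start : start + length].uniform_(-mag_norm, mag_norm)
                start += length
        output = output + noise
    return output
```

In practice the embedding hook only sees the token ids, so the sequence boundaries (e.g. from position_ids resetting, or the collator's length metadata) would have to be plumbed through to the module by the trainer.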