[Numpy] Further simplifying log of Softmax for NLL loss, to use LSE trick for numerical stability #21
NLLLoss is currently computed directly on the log of the softmax, and with some higher learning rates you can get unstable training / `inf` loss and a RuntimeWarning from the div by 0 in `nlls = -np.log(probs_targets)`.
Example with the same `lr = 0.1` in the numpy and pytorch implementations:
To bring it more in line with pytorch, I further simplify the log of the softmax so that it uses the LogSumExp trick.
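Concretely, with $m = \max_j z_j$ the log-softmax simplifies to

$$\log \mathrm{softmax}(z)_i = \log \frac{e^{z_i - m}}{\sum_j e^{z_j - m}} = (z_i - m) - \log \sum_j e^{z_j - m},$$

so the log is never taken of a probability that has underflowed to exactly 0.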
You had already done `exp_logits = np.exp(logits - logits_max)`; I just take the log of the sum of that. I get the actual probabilities back for the cache by taking the exponential of `log_probs`.
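Roughly what the forward pass looks like with this change (a minimal sketch; the names `logits`, `targets`, and the cached `probs` follow the existing code, but the exact shapes and caching details are assumptions):

```python
import numpy as np

def log_softmax_nll(logits, targets):
    """NLL loss computed from log-softmax via the LogSumExp trick.

    logits:  (N, C) array of unnormalized scores
    targets: (N,) array of integer class indices
    """
    # Shift by the row-wise max for numerical stability (as before)
    logits_max = np.max(logits, axis=1, keepdims=True)
    shifted = logits - logits_max

    # log softmax = shifted - log(sum(exp(shifted))); no log of a raw probability
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))

    # NLL loss: negative log-probability of each target class
    nlls = -log_probs[np.arange(len(targets)), targets]
    loss = nlls.mean()

    # Recover the actual probabilities for the cache / backward pass
    probs = np.exp(log_probs)
    return loss, probs
```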
This doesn't have much impact on the result, tbh, and could simply be left as an exercise idea.