-
Notifications
You must be signed in to change notification settings - Fork 211
Replace allclose with testing assert close #480
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace allclose with testing assert close #480
Conversation
| logits_with_saes = model(prompt) | ||
| assert not torch.allclose(original_logits, logits_with_saes) | ||
| with pytest.raises(AssertionError): | ||
| torch.testing.assert_close(original_logits, logits_with_saes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can make a helper in tests.helpers that does this? Seems annoying to need to remember to add the with pytest.raises(AssertionError) everywhere. Something like assert_not_close()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea! I've extracted all the with pytest.raises(AssertionError): blocks to assert_not_close() (which calls assert_close()).
| new_params[k], | ||
| atol=1e-5, | ||
| rtol=1e-5, | ||
| msg=lambda msg: f"Parameter {k} differs after {fold_fn}\n\n{msg}", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can make our own assert_close wrapper in tests.helpers that by default sets 1e-5 for both these params? Seems annoying need to to write atol=1e-5, rtol=1e-5 everywhere
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea! I've replaced calls to torch.testing.assert_close() with calls to assert_close(), which has the same default values for atol and rtol as torch.allclose() (to keep our tests the same as before), and abstracts away the lambda msg code.
chanind
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor nits about adding a helper function, but looks great otherwise
bbb3fcc to
986fb63
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #480 +/- ##
=======================================
Coverage 85.64% 85.64%
=======================================
Files 28 28
Lines 3567 3567
Branches 443 443
=======================================
Hits 3055 3055
Misses 331 331
Partials 181 181 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
986fb63 to
d53898e
Compare
* replaces torch.allclose() with torch.testing.assert_close() * removes type: ignore comment * makes changes to augment generated messages * extracts repeated code to helpers * removes unnecessary args * updates after rebase
Description
Replaces
assertand following calls totorch.allclose(), with calls totorch.testing.assert_close(), andassert notand following calls totorch.allclose(), withwith pytest.raise(AssertionError)and calls totorch.testing.assert_close(). We want the output when what's being asserted bytorch.testing.assert_close()isn't true, e.g.and PyTorch recommends using
torch.testing.assert_close().Fixes #479
Type of change
Please delete options that are not relevant.
Checklist:
You have tested formatting, typing and tests
make check-cito check format and linting. (you can runmake formatto format code if needed.)Performance Check.
If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:
Please links to wandb dashboards with a control and test group.