
Conversation

@anthonyduong9 (Collaborator) commented May 15, 2025

Description

Replaces assert statements followed by calls to torch.allclose() with calls to torch.testing.assert_close(), and replaces assert not statements followed by calls to torch.allclose() with pytest.raises(AssertionError) blocks wrapping calls to torch.testing.assert_close(). When an assertion made by torch.testing.assert_close() fails, we get informative output, e.g.

E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 193048 / 193048 (100.0%)
E       Greatest absolute difference: 21.887332916259766 at index (0, 1, 1593) (up to 1e-05 allowed)
E       Greatest relative difference: 371709.78125 at index (0, 3, 19725) (up to 1.3e-06 allowed)

and PyTorch recommends using torch.testing.assert_close().
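A minimal sketch of the replacement pattern described above (the tensor names are illustrative, not taken from the test suite):

```python
import pytest
import torch

a = torch.ones(3)
b = torch.ones(3)
c = torch.full((3,), 2.0)  # deliberately far from a

# Before: a bare boolean check, which fails with an uninformative AssertionError
assert torch.allclose(a, b)
# After: assert_close reports mismatch counts and max differences on failure
torch.testing.assert_close(a, b)

# Before: assert not torch.allclose(a, c)
# After: expect assert_close to raise for tensors that are not close
with pytest.raises(AssertionError):
    torch.testing.assert_close(a, c)
```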

Fixes #479

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and tests

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

Performance Check.

If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:

  • L0
  • CE Loss
  • MSE Loss
  • Feature Dashboard Interpretability

Please link to wandb dashboards with a control and a test group.

@anthonyduong9 anthonyduong9 marked this pull request as ready for review May 15, 2025 14:05
@anthonyduong9 anthonyduong9 requested a review from chanind May 15, 2025 14:05
logits_with_saes = model(prompt)
assert not torch.allclose(original_logits, logits_with_saes)
with pytest.raises(AssertionError):
    torch.testing.assert_close(original_logits, logits_with_saes)
@chanind (Collaborator) commented May 20, 2025


Maybe we can make a helper in tests.helpers that does this? Seems annoying to need to remember to add the with pytest.raises(AssertionError) everywhere. Something like assert_not_close()

@anthonyduong9 (Collaborator, Author) replied

Good idea! I've extracted all the with pytest.raises(AssertionError): blocks to assert_not_close() (which calls assert_close()).

new_params[k],
atol=1e-5,
rtol=1e-5,
msg=lambda msg: f"Parameter {k} differs after {fold_fn}\n\n{msg}",
@chanind (Collaborator) commented

Maybe we can make our own assert_close wrapper in tests.helpers that by default sets 1e-5 for both these params? Seems annoying to need to write atol=1e-5, rtol=1e-5 everywhere.

@anthonyduong9 (Collaborator, Author) replied

Good idea! I've replaced calls to torch.testing.assert_close() with calls to assert_close(), which has the same default values for atol and rtol as torch.allclose() (to keep our tests the same as before), and abstracts away the lambda msg code.
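A sketch of what such a wrapper could look like. torch.allclose's defaults are rtol=1e-05 and atol=1e-08; the parameter name msg_prefix is hypothetical, and the real helper in tests.helpers may be shaped differently:

```python
import torch


def assert_close(actual, expected, *, rtol=1e-5, atol=1e-8, msg_prefix=None, **kwargs):
    """torch.testing.assert_close with torch.allclose's default tolerances,
    optionally prepending a context string to the failure message."""
    # msg may be a callable that transforms the generated failure message
    msg = (lambda m: f"{msg_prefix}\n\n{m}") if msg_prefix else None
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol, msg=msg, **kwargs)
```

Call sites then pass a plain string for context, e.g. assert_close(params[k], new_params[k], msg_prefix=f"Parameter {k} differs after {fold_fn}"), instead of repeating the lambda at each call.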

@chanind (Collaborator) left a comment

Minor nits about adding a helper function, but looks great otherwise

Base automatically changed from refactor-arch-configs to alpha May 22, 2025 15:30
@anthonyduong9 anthonyduong9 force-pushed the replace-allclose-with-testing-assert_close branch from bbb3fcc to 986fb63 Compare May 24, 2025 09:04
codecov bot commented May 24, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.64%. Comparing base (07196e1) to head (d53898e).
Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #480   +/-   ##
=======================================
  Coverage   85.64%   85.64%           
=======================================
  Files          28       28           
  Lines        3567     3567           
  Branches      443      443           
=======================================
  Hits         3055     3055           
  Misses        331      331           
  Partials      181      181           

☔ View full report in Codecov by Sentry.
Base automatically changed from alpha to main July 14, 2025 22:25
@anthonyduong9 anthonyduong9 force-pushed the replace-allclose-with-testing-assert_close branch from 986fb63 to d53898e Compare July 16, 2025 20:12
@anthonyduong9 anthonyduong9 merged commit f39763f into main Jul 16, 2025
5 checks passed
@anthonyduong9 anthonyduong9 deleted the replace-allclose-with-testing-assert_close branch July 16, 2025 20:20
v-realm pushed a commit to realmlabs-ai/SAELens that referenced this pull request Nov 15, 2025
* replaces torch.allclose() with torch.testing.assert_close()

* removes type: ignore comment

* makes changes to augment generated messages

* extracts repeated code to helpers

* removes unnecessary args

* updates after rebase


Development

Successfully merging this pull request may close these issues.

[Proposal] Replace torch.allclose() with torch.testing.assert_close()

3 participants