Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiNodeBatchNormalization #106

Merged
merged 20 commits into from
Aug 24, 2017
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add explanation of tests
  • Loading branch information
iwiwi committed Aug 16, 2017
commit 71b25cb319720cde11aaafb933e5a8e1e237381c
27 changes: 27 additions & 0 deletions tests/links_tests/test_batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,33 @@ def test_version_check(self):
3, self.communicator)

def test_multi_node_bn(self):
"""Tests correctness of MultiNodeBatchNormalization.

This test verifies MultiNodeBatchNormalization by comparing
the following four configurations.
(1) Single worker, normal BatchNormalization
(2) Multiple workers, normal BatchNormalization
(3) Single worker, MultiNodeBatchNormalization
(4) Multiple workers, MultiNodeBatchNormalization

Single worker: only using the result of worker 0, which uses the whole
batch.
Multiple workers: Each worker uses the 1/n part of the whole batch,
where n is the number of nodes, and gradient is aggregated.

This test conducts the forward and backward computation once for the
deterministic model parameters and an input batch, and checks the
gradients of parameters.

The purpose of MultiNodeBatchNormalization is to make the results of
(4) to be exactly same as (1). Therefore, the essential part is to
check that the results of (1) and (4) are the same. The results of (3)
should also be also same as them. In contrast, the results of (2) is
not necessarily always same as them, and we can expect that it is
almost always different. Therefore, we also check that the results of
(2) is different from them, to see that this test working correctly.
"""

if chainer.__version__.startswith('1.'):
raise nose.plugins.skip.SkipTest()

Expand Down