Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiNodeBatchNormalization #106

Merged
merged 20 commits into from
Aug 24, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add MultiNodeBatchNormalization to document
  • Loading branch information
iwiwi committed Aug 15, 2017
commit 568c117832947c1de26f0a214623cb092e8f9ce2
25 changes: 25 additions & 0 deletions chainermn/links/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,31 @@

class MultiNodeBatchNormalization(link.Link):

"""Batch normalization layer that can use the whole batch stats.

When using chainer.link.BatchNormalization, batch mean and std are
computed independently for the local batch in each worker. When local
batch size is too small, training is unstable due to unreliable batch
stats.

In contrast, when using this MultiNodeBatchNormalization, workers
communicate to conduct 'correct' batch normalization (e.g., obtaining
mean and std for the whole global batch).

Args:
size (int or tuple of ints): Size (or shape) of channel
dimensions.
comm (ChainerMN communicator): communicator to share
the batch stats.
decay (float): Decay rate of moving average. It is used on training.
eps (float): Epsilon value for numerical stability.
dtype (numpy.dtype): Type to use in computing.
use_gamma (bool): If ``True``, use scaling parameter. Otherwise, use
unit(1) which makes no effect.
use_beta (bool): If ``True``, use shifting parameter. Otherwise, use
unit(0) which makes no effect.
"""

def __init__(self, size, comm, decay=0.9, eps=2e-5, dtype=numpy.float32,
use_gamma=True, use_beta=True,
initial_gamma=None, initial_beta=None):
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Links

.. autoclass:: MultiNodeChainList
:members: add_link
.. autoclass:: chainermn.links.MultiNodeBatchNormalization


Optimizers and Evaluators
Expand Down