Add MultiNodeBatchNormalization #106

Merged · 20 commits · Aug 24, 2017
Changes from 1 commit
Fix import error for Chainer v1
iwiwi committed Aug 17, 2017
commit 73efe61903b6151ad6618651c98e86235429deec
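
The change itself is mechanical: the module-level "from chainer import configuration" import is removed from both files, and every use site reaches the submodule through the already-imported package as chainer.configuration.config.train. Chainer v1 has no configuration submodule (the config mechanism arrived in v2), so the "from ... import" line raised ImportError the moment ChainerMN was imported; an attribute lookup, by contrast, is only resolved when the code actually runs, which keeps the modules importable under v1. A minimal sketch of the difference, with an illustrative function name rather than ChainerMN's actual code:

import chainer

# Eager form: fails at import time under Chainer v1, because the
# chainer.configuration submodule does not exist there:
#
#     from chainer import configuration   # ImportError under v1
#
# Lazy form: the attribute is looked up only when the function is
# called, so merely importing this module succeeds under v1 (the
# lookup itself would still fail if this path ever ran there).

def in_train_mode():
    # Under Chainer v2+ this reads the thread-local configuration.
    return chainer.configuration.config.train
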
9 changes: 4 additions & 5 deletions chainermn/functions/batch_normalization.py
@@ -1,5 +1,4 @@
 import chainer
-from chainer import configuration
 from chainer import cuda
 from chainer import function
 import chainer.utils
@@ -84,7 +83,7 @@ def check_type_forward(self, in_types):
     def forward(self, inputs):
         xp = cuda.get_array_module(*inputs)
         x, gamma, beta = inputs[:3]
-        if configuration.config.train:
+        if chainer.configuration.config.train:
             if self.running_mean is None:
                 self.running_mean = xp.zeros_like(gamma)
                 self.running_var = xp.zeros_like(gamma)
@@ -113,7 +112,7 @@ def forward(self, inputs):

         cudnn_updated_running_stats = False
 
-        if configuration.config.train:
+        if chainer.configuration.config.train:
             axis = (0,) + tuple(range(head_ndim, x.ndim))
 
             mpi_comm = self.comm.mpi_comm
@@ -150,7 +149,7 @@ def forward(self, inputs):
                 'bn_fwd')(x, mean[expander], self.std[expander], gamma,
                           beta)
 
-        if configuration.config.train and (not cudnn_updated_running_stats):
+        if chainer.configuration.config.train and (not cudnn_updated_running_stats):
             # Note: If in training mode, the cuDNN forward training function
             # will do this for us, so
             # only run following code if cuDNN was not used.
@@ -193,7 +192,7 @@ def backward(self, inputs, grad_outputs):
             return gx, ggamma, gbeta, gmean, gvar
 
         # Note: If length of inputs is not 5, we must be in train mode.
-        assert configuration.config.train
+        assert chainer.configuration.config.train
 
         # It is wrong to multiply m by mpi_comm.size
         # (instead of multiplying 1/size to gbeta, ggamma)
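
The backward hunk above is cut off right at a note about gradient scaling across workers: reduced quantities such as ggamma and gbeta should be averaged over the communicator (multiplied by 1/size) rather than having the local element count m multiplied by mpi_comm.size. The same averaging idea underlies the forward statistics. A rough sketch of the pattern using mpi4py, with hypothetical helper names rather than ChainerMN's actual implementation:

from mpi4py import MPI

def allreduce_mean(comm, value):
    # Hypothetical helper: element-wise sum of a NumPy array over all
    # workers, divided by the worker count to give the global average.
    return comm.allreduce(value) / comm.size

def global_batch_stats(comm, x):
    # Per-worker moments averaged across workers; with equal local
    # batch sizes this matches the mean and variance computed over the
    # combined global batch.
    mean = allreduce_mean(comm, x.mean(axis=0))
    sq_mean = allreduce_mean(comm, (x * x).mean(axis=0))
    return mean, sq_mean - mean * mean

# Usage: each worker passes its own local activations.
#     mean, var = global_batch_stats(MPI.COMM_WORLD, local_x)
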
3 changes: 1 addition & 2 deletions chainermn/links/batch_normalization.py
@@ -1,5 +1,4 @@
 import chainer
-from chainer import configuration
 from chainer import cuda
 from chainer.functions.normalization import batch_normalization
 from chainer import initializers
@@ -91,7 +90,7 @@ def __call__(self, x, finetune=False):
             beta = variable.Variable(self.xp.zeros(
                 self.avg_mean.shape, dtype=x.dtype))
 
-        if configuration.config.train:
+        if chainer.configuration.config.train:
             if finetune:
                 self.N += 1
                 decay = 1. - 1. / self.N