Merge pull request #106 from chainer/distributed-batch-normalization
Add MultiNodeBatchNormalization
Showing 5 changed files with 524 additions and 0 deletions.
chainermn/functions/batch_normalization.py (new file)
@@ -0,0 +1,235 @@
# This file is heavily based on Chainer's batch normalization implementation.
# See: chainer/functions/normalization/batch_normalization.py (dbb650)

import chainer
from chainer import cuda
from chainer import function
import chainer.utils
from chainer.utils import type_check
import numpy


if cuda.cudnn_enabled:
    cudnn = cuda.cudnn
    libcudnn = cudnn.cudnn


def _as4darray(arr):
    if arr.ndim == 0:
        return arr.reshape(1, 1, 1, 1)
    elif arr.ndim == 4:
        return arr
    else:
        return arr.reshape(arr.shape[0], -1, 1, 1)


def _xhat(x, mean, std, expander):
    x_mu = x - mean[expander]
    x_mu /= std[expander]
    return x_mu


class MultiNodeBatchNormalizationFunction(function.Function):

    def __init__(self, comm, eps=2e-5, mean=None, var=None, decay=0.9):
        chainer.utils.experimental(
            'chainermn.functions.MultiNodeBatchNormalizationFunction')

        self.comm = comm
        self.running_mean = mean
        self.running_var = var

        # Note: cuDNN v5 requires that eps be greater than 1e-5. Otherwise, an
        # error will occur.
        # See CUDNN_BN_MIN_EPSILON value in cudnn.h to verify minimum allowable
        # value.
        self.eps = eps
        if chainer.should_use_cudnn('>=auto'):
            if eps < 1e-5:
                msg = 'cuDNN does not allow an eps value less than 1e-5.'
                raise RuntimeError(msg)
        self.mean_cache = None
        self.decay = decay

        # We need to delay importing MPI4py (and modules that import MPI4py)
        import chainermn.communicators._memory_utility as memory_utility_module
        from mpi4py import MPI as mpi4py_module
        self.memory_utility_module = memory_utility_module
        self.mpi4py_module = mpi4py_module

    def check_type_forward(self, in_types):
        n_in = type_check.eval(in_types.size())
        if n_in != 3 and n_in != 5:
            raise type_check.InvalidType(
                '%s or %s' % (in_types.size() == 3, in_types.size() == 5),
                '%s == %s' % (in_types.size(), n_in))
        x_type, gamma_type, beta_type = in_types[:3]
        M = type_check.eval(gamma_type.ndim)
        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim >= gamma_type.ndim + 1,
            x_type.shape[1:1 + M] == gamma_type.shape,
            # TODO(beam2d): Check shape
            gamma_type.dtype == x_type.dtype,
            beta_type.dtype == x_type.dtype,
            gamma_type.shape == beta_type.shape,
        )
        if len(in_types) == 5:
            mean_type, var_type = in_types[3:]
            type_check.expect(
                mean_type.dtype == x_type.dtype,
                mean_type.shape == gamma_type.shape,
                var_type.dtype == x_type.dtype,
                var_type.shape == gamma_type.shape,
            )

    def forward(self, inputs):
        xp = cuda.get_array_module(*inputs)
        x, gamma, beta = inputs[:3]
        if chainer.configuration.config.train:
            if self.running_mean is None:
                self.running_mean = xp.zeros_like(gamma)
                self.running_var = xp.zeros_like(gamma)
            else:
                self.running_mean = xp.array(self.running_mean)
                self.running_var = xp.array(self.running_var)
        elif len(inputs) == 5:
            self.fixed_mean = inputs[3]
            self.fixed_var = inputs[4]

        head_ndim = gamma.ndim + 1
        expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
        gamma = gamma[expander]
        beta = beta[expander]

        # cuDNN only supports these tensor dimensions because they are
        # the most commonly used. If there is a need to support other
        # dimensions with cuDNN, we could consider reshaping the input
        # into a 2-dim array with channels as second dim and m=<product
        # of all dimensions except the 2nd dimension> as the first
        # dimension.
        cudnn_dim_ok = x.ndim == 2 or (x.ndim == 4 and head_ndim == 2)
        # TODO(bkvogel): Check for float16 support again in next cuDNN version.
        # cuDNN v5 batch normalization does not seem to support float16.
        self._can_use_cudnn = cudnn_dim_ok and x[0].dtype != numpy.float16

        cudnn_updated_running_stats = False

        if chainer.configuration.config.train:
            axis = (0,) + tuple(range(head_ndim, x.ndim))

            # ChainerMN diff (1/2) begins
            mpi_comm = self.comm.mpi_comm
            tmp = xp.empty(gamma.size * 2, dtype=x.dtype)
            x.mean(axis=axis, out=tmp[:gamma.size])
            xp.square(x).mean(axis=axis, out=tmp[gamma.size:])
            if xp is not numpy:
                chainer.cuda.Stream.null.synchronize()
            mpi_comm.Allreduce(
                self.mpi4py_module.IN_PLACE,
                self.memory_utility_module.array_to_buffer_object(tmp))
            tmp *= 1.0 / mpi_comm.size

            mean = tmp[:gamma.size]
            sqmean = tmp[gamma.size:]
            var = sqmean - xp.square(mean)
            # ChainerMN diff (1/2) ends

            var += self.eps
        else:
            mean = self.fixed_mean
            var = self.fixed_var + self.eps
        self.std = xp.sqrt(var, dtype=var.dtype)
        if xp is numpy:
            self.x_hat = _xhat(x, mean, self.std, expander)
            y = gamma * self.x_hat
            y += beta
        else:
            self.x_hat, y = cuda.elementwise(
                'T x, T mean, T std, T gamma, T beta', 'T x_hat, T y',
                '''
                x_hat = (x - mean) / std;
                y = gamma * x_hat + beta;
                ''',
                'bn_fwd')(x, mean[expander], self.std[expander], gamma,
                          beta)

        if chainer.configuration.config.train and \
                (not cudnn_updated_running_stats):
            # Note: If in training mode, the cuDNN forward training function
            # will do this for us, so only run the following code if cuDNN
            # was not used.
            # Update running statistics:
            m = x.size // gamma.size
            adjust = m / max(m - 1., 1.)  # unbiased estimation
            self.running_mean *= self.decay
            temp_ar = xp.array(mean)
            temp_ar *= (1 - self.decay)
            self.running_mean += temp_ar
            del temp_ar
            self.running_var *= self.decay
            temp_ar = xp.array(var)
            temp_ar *= (1 - self.decay) * adjust
            self.running_var += temp_ar
            del temp_ar
        return y,

    def backward(self, inputs, grad_outputs):
        x, gamma = inputs[:2]
        gy = grad_outputs[0]
        head_ndim = gamma.ndim + 1
        expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
        m = gamma.dtype.type(x.size // gamma.size)
        axis = (0,) + tuple(range(head_ndim, x.ndim))
        xp = cuda.get_array_module(x)
        if len(inputs) == 5:
            # This case is unlikely to be used in practice and so does not
            # need to be optimized for performance.
            mean = inputs[3]
            var = inputs[4]
            std = xp.sqrt(var, dtype=var.dtype)
            gs = gamma / std
            gbeta = gy.sum(axis=axis)
            x_hat = _xhat(x, mean, std, expander)
            ggamma = (gy * x_hat).sum(axis=axis)
            gmean = -gs * gbeta
            gvar = -0.5 * gamma / var * ggamma
            gx = gs[expander] * gy
            return gx, ggamma, gbeta, gmean, gvar

        # Note: If length of inputs is not 5, we must be in train mode.
        assert chainer.configuration.config.train

        # ChainerMN diff (2/2) begins
        # Note: It is wrong to multiply m by mpi_comm.size
        # (instead of multiplying 1/size to gbeta, ggamma)
        mpi_comm = self.comm.mpi_comm
        tmp = xp.empty(gamma.size * 2, dtype=x.dtype)
        gy.sum(axis=axis, out=tmp[:gamma.size])
        (gy * self.x_hat).sum(axis=axis, out=tmp[gamma.size:])
        if xp is not numpy:
            chainer.cuda.Stream.null.synchronize()
        mpi_comm.Allreduce(
            self.mpi4py_module.IN_PLACE,
            self.memory_utility_module.array_to_buffer_object(tmp))
        tmp *= 1.0 / mpi_comm.size
        gbeta = tmp[:gamma.size]
        ggamma = tmp[gamma.size:]
        # ChainerMN diff (2/2) ends

        if xp is numpy:
            gx = (gamma / self.std)[expander] * (
                gy - (self.x_hat * ggamma[expander] + gbeta[expander]) / m)
        else:
            inv_m = numpy.float32(1) / m
            gx = cuda.elementwise(
                'T gy, T x_hat, T gamma, T std, T ggamma, T gbeta, \
                T inv_m',
                'T gx',
                'gx = (gamma / std) * (gy - (x_hat * ggamma + gbeta) * \
                inv_m)',
                'bn_bwd')(gy, self.x_hat, gamma[expander],
                          self.std[expander], ggamma[expander],
                          gbeta[expander], inv_m)

        return gx, ggamma, gbeta
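
The only multi-node additions to this otherwise stock Chainer function are the two blocks marked "ChainerMN diff": the forward pass packs the per-worker E[x] and E[x^2] into a single buffer, runs one in-place Allreduce, divides by the communicator size, and recovers the global variance as Var[x] = E[x^2] - (E[x])^2; the backward pass does the same for gbeta and ggamma. Packing both moments into one buffer means one collective per pass instead of two. Below is a minimal sketch of that exchange using plain mpi4py and NumPy; it is not part of this commit, the shapes and variable names are invented for illustration, and the real code routes the buffer through ChainerMN's array_to_buffer_object helper so that the same Allreduce also works on GPU arrays.

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
x = np.random.randn(32, 8).astype(np.float32)   # local batch: (batch, channel)

n_ch = x.shape[1]
buf = np.empty(2 * n_ch, dtype=np.float32)
buf[:n_ch] = x.mean(axis=0)                     # local E[x] per channel
buf[n_ch:] = np.square(x).mean(axis=0)          # local E[x^2] per channel

comm.Allreduce(MPI.IN_PLACE, buf)               # sum over all workers in one message
buf /= comm.size                                # average; assumes equal local batch sizes

global_mean = buf[:n_ch]
global_var = buf[n_ch:] - np.square(global_mean)

When run under MPI (e.g. mpiexec -n 4 python sketch.py), every rank ends up with the same global_mean and global_var, which is exactly what the forward pass above needs before normalizing.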
@@ -0,0 +1 @@
from chainermn.links.batch_normalization import MultiNodeBatchNormalization  # NOQA
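
This one-line re-export is presumably the chainermn.links package initializer; with it in place, user code can import the new link from the package rather than from the submodule, as in the hypothetical usage below.

from chainermn.links import MultiNodeBatchNormalization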
chainermn/links/batch_normalization.py (new file)
@@ -0,0 +1,122 @@
import chainer
from chainer import cuda
from chainer.functions.normalization import batch_normalization
from chainer import initializers
from chainer import link
import chainer.utils
from chainer import variable
import numpy

from chainermn.functions.batch_normalization import \
    MultiNodeBatchNormalizationFunction


class MultiNodeBatchNormalization(link.Link):

    """Batch normalization layer that can use the whole batch stats.

    When using chainer.link.BatchNormalization, batch mean and std are
    computed independently for the local batch in each worker. When the local
    batch size is too small, training is unstable due to unreliable batch
    stats.

    In contrast, when using this MultiNodeBatchNormalization, workers
    communicate to conduct 'correct' batch normalization (e.g., obtaining
    mean and std for the whole global batch).

    This link works only with Chainer >= 2.0.0.

    Args:
        size (int or tuple of ints): Size (or shape) of channel
            dimensions.
        comm (ChainerMN communicator): communicator to share
            the batch stats.
        decay (float): Decay rate of moving average. It is used during
            training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Type to use in computing.
        use_gamma (bool): If ``True``, use scaling parameter. Otherwise, use
            unit(1), which has no effect.
        use_beta (bool): If ``True``, use shifting parameter. Otherwise, use
            unit(0), which has no effect.
    """

    def __init__(self, size, comm, decay=0.9, eps=2e-5, dtype=numpy.float32,
                 use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None):
        chainer.utils.experimental(
            'chainermn.links.MultiNodeBatchNormalization')

        if chainer.__version__.startswith('1.'):
            raise RuntimeError(
                'MultiNodeBatchNormalization works only with '
                'chainer >= 2.0.0.')

        super(MultiNodeBatchNormalization, self).__init__()
        self.comm = comm
        self.avg_mean = numpy.zeros(size, dtype=dtype)
        self.register_persistent('avg_mean')
        self.avg_var = numpy.zeros(size, dtype=dtype)
        self.register_persistent('avg_var')
        self.N = 0
        self.register_persistent('N')
        self.decay = decay
        self.eps = eps

        with self.init_scope():
            if use_gamma:
                if initial_gamma is None:
                    initial_gamma = 1
                initial_gamma = initializers._get_initializer(initial_gamma)
                initial_gamma.dtype = dtype
                self.gamma = variable.Parameter(initial_gamma, size)
            if use_beta:
                if initial_beta is None:
                    initial_beta = 0
                initial_beta = initializers._get_initializer(initial_beta)
                initial_beta.dtype = dtype
                self.beta = variable.Parameter(initial_beta, size)

    def __call__(self, x, finetune=False):
        if hasattr(self, 'gamma'):
            gamma = self.gamma
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype))
        if hasattr(self, 'beta'):
            beta = self.beta
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype))

        if chainer.configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            func = MultiNodeBatchNormalizationFunction(
                self.comm, self.eps, self.avg_mean, self.avg_var, decay)
            ret = func(x, gamma, beta)

            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean)
            var = variable.Variable(self.avg_var)
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret

    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        This method can be skipped if it is the first time to use the
        fine-tuning mode. Otherwise, this method should be called before
        starting the fine-tuning mode again.
        """
        self.N = 0
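
A hedged usage sketch, not part of this commit: the link takes the channel size and a ChainerMN communicator and otherwise drops in where chainer.links.BatchNormalization would go. The model below is invented for illustration; chainermn.create_communicator and chainermn.create_multi_node_optimizer are the existing ChainerMN entry points assumed here.

import chainer
import chainer.functions as F
import chainer.links as L
import chainermn
from chainermn.links import MultiNodeBatchNormalization


class SmallConvNet(chainer.Chain):
    # Hypothetical model: conv -> multi-node BN -> ReLU -> linear classifier.

    def __init__(self, comm, n_out=10):
        super(SmallConvNet, self).__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(None, 16, ksize=3, pad=1)
            self.bn = MultiNodeBatchNormalization(16, comm)
            self.fc = L.Linear(None, n_out)

    def __call__(self, x):
        h = F.relu(self.bn(self.conv(x)))
        return self.fc(h)


comm = chainermn.create_communicator()          # one communicator per MPI process
model = SmallConvNet(comm)
optimizer = chainermn.create_multi_node_optimizer(
    chainer.optimizers.MomentumSGD(), comm)
optimizer.setup(model)

Because the forward pass issues an MPI Allreduce, every worker has to call the link the same number of times per iteration; an uneven number of calls across ranks would leave the collective waiting.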