Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiNodeBatchNormalization #106

Merged
merged 20 commits
Aug 24, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix PEP8
  • Loading branch information
iwiwi committed Aug 10, 2017
commit 12d84756ffa96fc7b9523e4423424541ddf3c673
3 changes: 2 additions & 1 deletion chainermn/functions/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def _xhat(x, mean, std, expander):
class MultiNodeBatchNormalizationFunction(function.Function):

def __init__(self, comm, eps=2e-5, mean=None, var=None, decay=0.9):
chainer.utils.experimental('chainermn.functions.MultiNodeBatchNormalizationFunction')
chainer.utils.experimental(
'chainermn.functions.MultiNodeBatchNormalizationFunction')

self.comm = comm
self.running_mean = mean
Expand Down
5 changes: 3 additions & 2 deletions chainermn/links/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ class MultiNodeBatchNormalization(link.Link):
def __init__(self, size, comm, decay=0.9, eps=2e-5, dtype=numpy.float32,
use_gamma=True, use_beta=True,
initial_gamma=None, initial_beta=None):
chainer.utils.experimental('chainermn.links.MultiNodeBatchNormalization')
chainer.utils.experimental(
'chainermn.links.MultiNodeBatchNormalization')

super(MultiNodeBatchNormalization, self).__init__()
self.comm = comm
Expand Down Expand Up @@ -86,4 +87,4 @@ def start_finetuning(self):
starting the fine-tuning mode again.

"""
self.N = 0
self.N = 0
47 changes: 24 additions & 23 deletions tests/links_tests/test_batch_normalization.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
import unittest

import mpi4py.MPI
import numpy as np

from chainermn.communicators.naive_communicator import NaiveCommunicator

import copy
import nose.plugins.skip
import unittest

import chainer
import chainer.testing
import chainer.testing.attr
import mpi4py.MPI
import numpy
import unittest

import chainermn
from chainermn.communicators.naive_communicator import NaiveCommunicator
import chainermn.links


Expand Down Expand Up @@ -64,17 +56,24 @@ def test_multi_node_bn(self):
local_batchsize = 10
global_batchsize = 10 * comm.size
ndim = 3
np.random.seed(71)
x = np.random.random((global_batchsize, ndim)).astype(np.float32)
y = np.random.randint(0, 1, size=global_batchsize, dtype=np.int32)
x_local = comm.mpi_comm.scatter(x.reshape(comm.size, local_batchsize, ndim))
y_local = comm.mpi_comm.scatter(y.reshape(comm.size, local_batchsize))
print(x.shape, y.shape, x_local.shape, y_local.shape)

m1 = chainer.links.Classifier(ModelNormalBN()) # Single Normal
m2 = chainer.links.Classifier(ModelNormalBN()) # Distributed Normal
m3 = chainer.links.Classifier(ModelDistributedBN(comm)) # Distributed BN
m4 = chainer.links.Classifier(ModelDistributedBN(comm)) # Sequential Normal
numpy.random.seed(71)
x = numpy.random.random(
(global_batchsize, ndim)).astype(numpy.float32)
y = numpy.random.randint(
0, 1, size=global_batchsize, dtype=numpy.int32)
x_local = comm.mpi_comm.scatter(
x.reshape(comm.size, local_batchsize, ndim))
y_local = comm.mpi_comm.scatter(
y.reshape(comm.size, local_batchsize))

cls = chainer.links.Classifier
m1 = cls(ModelNormalBN()) # Single worker
m2 = cls(ModelNormalBN()) # Multi worker, Ghost BN
m3 = cls(ModelDistributedBN(comm)) # Single worker, MNBN
m4 = cls(ModelDistributedBN(comm)) # Multi worker, MNBN
# NOTE: m1, m3 and m4 should behave in the same way.
# m2 may be different.

m2.copyparams(m1)
m3.copyparams(m1)
m4.copyparams(m1)
Expand Down Expand Up @@ -108,6 +107,8 @@ def test_multi_node_bn(self):
assert(p3[0] == name)
assert(p4[0] == name)

# TODO: check p1[1].grad != p2[1].grad (to confirm that this test is valid)
chainer.testing.assert_allclose(p1[1].grad, p3[1].grad)
chainer.testing.assert_allclose(p1[1].grad, p4[1].grad)

# TODO: check p1[1].grad != p2[1].grad
# (to confirm that this test is valid)