Skip to content

Commit

Permalink
Merge pull request #65 from levelfour/base-communicator
Browse files Browse the repository at this point in the history
Add a base class for ChainerMN communicators
  • Loading branch information
iwiwi authored May 22, 2017
2 parents 3569569 + 447632c commit 05e920b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
39 changes: 23 additions & 16 deletions chainermn/communicators/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,38 @@
from chainermn import nccl


class NodeAwareCommunicatorBase(object):
class CommunicatorBase(object):

def __init__(self, mpi_comm):
self.mpi_comm = mpi_comm

@property
def rank(self):
return self.mpi_comm.rank

@property
def size(self):
return self.mpi_comm.size

def broadcast_data(self, model):
raise NotImplementedError()

def allreduce_grad(self, model):
raise NotImplementedError()


class NodeAwareCommunicatorBase(CommunicatorBase):

def __init__(self, mpi_comm, use_nccl):
super(NodeAwareCommunicatorBase, self).__init__(mpi_comm)

if use_nccl and not nccl._available:
raise RuntimeError(
'NCCL is not available. '
'Please confirm that NCCL can be found by dynamic linkers, '
'and ChainerMN is installed without --no-nccl flag.'
)

self.mpi_comm = mpi_comm
self.use_nccl = use_nccl

self._init_ranks()
Expand All @@ -23,14 +44,6 @@ def __init__(self, mpi_comm, use_nccl):
if self.use_nccl:
self.intra_nccl_comm = None

@property
def rank(self):
return self.mpi_comm.rank

@property
def size(self):
return self.mpi_comm.size

def _init_ranks(self):
my_ranks = _communication_utility.init_ranks(self.mpi_comm)
assert my_ranks[0] == self.mpi_comm.rank
Expand All @@ -52,9 +65,3 @@ def _init_comms(self):
self.inter_mpi_comm = comms[1]
if self.use_nccl:
self.intra_nccl_comm = comms[2]

def broadcast_data(self, model):
raise NotImplementedError()

def allreduce_grad(self, model):
raise NotImplementedError()
13 changes: 3 additions & 10 deletions chainermn/communicators/naive_communicator.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import mpi4py.MPI

from chainermn.communicators import _base
from chainermn.communicators import _communication_utility
from chainermn.communicators import _memory_utility


class NaiveCommunicator(object):
class NaiveCommunicator(_base.CommunicatorBase):

def __init__(self, mpi_comm):
self.mpi_comm = mpi_comm

@property
def rank(self):
return self.mpi_comm.rank

@property
def size(self):
return self.mpi_comm.size
super(NaiveCommunicator, self).__init__(mpi_comm)

def broadcast_data(self, model):
_communication_utility.broadcast_naive(self.mpi_comm, model)
Expand Down

0 comments on commit 05e920b

Please sign in to comment.