Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a base class for ChainerMN communicators #65

Merged
merged 1 commit into the base branch on
May 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions chainermn/communicators/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,38 @@
from chainermn import nccl


class NodeAwareCommunicatorBase(object):
class CommunicatorBase(object):
    """Abstract base for ChainerMN communicators.

    Wraps an MPI communicator object and defines the minimal interface
    (``rank``, ``size``, ``broadcast_data``, ``allreduce_grad``) that
    concrete communicator implementations provide.
    """

    def __init__(self, mpi_comm):
        # Keep a reference to the underlying MPI communicator; the
        # rank/size properties delegate to it.
        self.mpi_comm = mpi_comm

    @property
    def rank(self):
        """Rank of this process within the underlying MPI communicator."""
        comm = self.mpi_comm
        return comm.rank

    @property
    def size(self):
        """Number of processes in the underlying MPI communicator."""
        comm = self.mpi_comm
        return comm.size

    def broadcast_data(self, model):
        """Broadcast model data across processes; subclasses must override."""
        raise NotImplementedError()

    def allreduce_grad(self, model):
        """All-reduce model gradients across processes; subclasses must override."""
        raise NotImplementedError()


class NodeAwareCommunicatorBase(CommunicatorBase):

def __init__(self, mpi_comm, use_nccl):
super(NodeAwareCommunicatorBase, self).__init__(mpi_comm)

if use_nccl and not nccl._available:
raise RuntimeError(
'NCCL is not available. '
'Please confirm that NCCL can be found by dynamic linkers, '
'and ChainerMN is installed without --no-nccl flag.'
)

self.mpi_comm = mpi_comm
self.use_nccl = use_nccl

self._init_ranks()
Expand All @@ -23,14 +44,6 @@ def __init__(self, mpi_comm, use_nccl):
if self.use_nccl:
self.intra_nccl_comm = None

@property
def rank(self):
return self.mpi_comm.rank

@property
def size(self):
return self.mpi_comm.size

def _init_ranks(self):
my_ranks = _communication_utility.init_ranks(self.mpi_comm)
assert my_ranks[0] == self.mpi_comm.rank
Expand All @@ -52,9 +65,3 @@ def _init_comms(self):
self.inter_mpi_comm = comms[1]
if self.use_nccl:
self.intra_nccl_comm = comms[2]

def broadcast_data(self, model):
raise NotImplementedError()

def allreduce_grad(self, model):
raise NotImplementedError()
13 changes: 3 additions & 10 deletions chainermn/communicators/naive_communicator.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import mpi4py.MPI

from chainermn.communicators import _base
from chainermn.communicators import _communication_utility
from chainermn.communicators import _memory_utility


class NaiveCommunicator(object):
class NaiveCommunicator(_base.CommunicatorBase):

def __init__(self, mpi_comm):
self.mpi_comm = mpi_comm

@property
def rank(self):
return self.mpi_comm.rank

@property
def size(self):
return self.mpi_comm.size
super(NaiveCommunicator, self).__init__(mpi_comm)

def broadcast_data(self, model):
_communication_utility.broadcast_naive(self.mpi_comm, model)
Expand Down