Skip to content

Commit

Permalink
modify sampler (#1)
Browse files Browse the repository at this point in the history
* modify sampler

* add split dataset  example
  • Loading branch information
HoyTiger authored Aug 30, 2022
1 parent 996747d commit 50bf38c
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 30 deletions.
19 changes: 19 additions & 0 deletions examples/spilt_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pandas as pd

from iflearner.datasets.mnist import MNIST
from iflearner.datasets.sampler import Sampler


# Example: split the MNIST training labels across three parties with the
# Dirichlet sampler and print how many samples of each class every party got.
clients = ['party1', 'party2', 'party3']
dataset = MNIST('./data', True)
sampler = Sampler(dataset.train_labels, clients, 'dirichlet', alpha=2)
clients_index = sampler.client_index

for name, index in clients_index.items():
    # client_index maps each client name to a set of sample indices; convert
    # it to a list so it can be used for fancy indexing into the label array.
    index = list(index)
    p = pd.Series(dataset.train_labels[index].astype('int64'))

    # The printed suffix is Chinese for "per-class counts:".
    print(name+'各个类别:')
    print(p.value_counts())
    print()
62 changes: 36 additions & 26 deletions iflearner/datasets/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@

import numpy as np

from iflearner.datasets.utils import partition_class_samples_with_dirichlet_distribution


class Sampler:
def __init__(self, label: list, clients: list, method="iid", **kwargs):
    """Partition sample indices among clients with the chosen strategy.

    The result is stored in ``self.client_index``.

    Args:
        label: Class label of every sample; converted to an int64 array.
        clients: Client names, one entry per client.
        method: ``"iid"``, ``"noniid"`` or ``"dirichlet"``.
        **kwargs: Strategy options.
            ``alpha`` (default 1) is forwarded to
            ``dirichlet_distribution_non_iid``.
            ``num_shards`` (default ``len(clients) * 15``) is forwarded to
            ``noniid``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    self.targets = np.array(label, dtype="int64")
    self.clients = clients

    if method == "iid":
        self.client_index = self.iid()
    elif method == "noniid":
        # noniid() takes a mandatory num_shards argument, so it must be
        # supplied here; fall back to 15 shards per client when not given.
        self.client_index = self.noniid(
            kwargs.get("num_shards", len(clients) * 15)
        )
    elif method == "dirichlet":
        self.client_index = self.dirichlet_distribution_non_iid(
            kwargs.get("alpha", 1)
        )
    else:
        # Fail loudly instead of leaving client_index undefined.
        raise ValueError(
            f"unknown sampling method {method!r}; "
            "expected 'iid', 'noniid' or 'dirichlet'"
        )

def iid(self):
clients_index = defaultdict(set)
Expand All @@ -24,13 +31,34 @@ def iid(self):
clients_index[self.clients[i]] = set(
np.random.choice(all_idxs, length // clients_num, replace=False)
)
all_idxs = list(set(all_idxs) - clients_index[i])
all_idxs = list(set(all_idxs) - clients_index[self.clients[i]])

return clients_index

def dirichlet_distribution_non_iid(self, alpha):
    """Split sample indices among clients using per-class Dirichlet draws.

    For every distinct label in ``self.targets`` the indices carrying that
    label are handed to
    ``partition_class_samples_with_dirichlet_distribution``, which appends
    a share of them to each client's running index list.

    Args:
        alpha: Concentration parameter forwarded to the partition helper.

    Returns:
        A ``defaultdict(set)`` mapping each entry of ``self.clients`` to the
        set of sample indices assigned to it.
    """
    clients_index = defaultdict(set)

    N = len(self.targets)
    client_num = len(self.clients)
    idx_batch = [[] for _ in range(client_num)]
    # Walk the label values that actually occur rather than range(K):
    # labels are not guaranteed to be exactly 0..K-1, and any value outside
    # that range would otherwise never be assigned to a client.
    for k in np.unique(self.targets):
        # indices of all samples that belong to label k
        idx_k = np.where(self.targets == k)[0]

        idx_batch = partition_class_samples_with_dirichlet_distribution(
            N, alpha, client_num, idx_batch, idx_k
        )
    for client, indices in zip(self.clients, idx_batch):
        np.random.shuffle(indices)
        clients_index[client] = set(indices)
    return clients_index

def noniid(self):
def noniid(self, num_shards):
clients_num = len(self.clients)
num_shards, num_imgs = clients_num * 15, len(self.targets) // (clients_num * 15)
num_shards, num_imgs = num_shards, len(self.targets) // (num_shards)
if num_imgs < 1:
raise Exception('too many shards, shard number cannot be more than length of targets')
idx_shard = [i for i in range(num_shards)]
clients_index = {i: np.array([], dtype="int64") for i in range(clients_num)}
idxs = np.arange(len(self.targets))
Expand All @@ -46,9 +74,9 @@ def noniid(self):
rand_set = set(np.random.choice(idx_shard, 2, replace=False))
idx_shard = list(set(idx_shard) - rand_set)
for rand in rand_set:
clients_index[i] = np.concatenate(
clients_index[self.clients[i]] = np.concatenate(
(
clients_index[i],
clients_index[self.clients[i]],
idxs[rand * num_imgs : (rand + 1) * num_imgs],
),
axis=0,
Expand All @@ -57,21 +85,3 @@ def noniid(self):
pass
return clients_index


if __name__ == "__main__":
import pandas as pd
from mnist import MNIST

data = MNIST("./data", True)
s = Sampler(data.train_labels, ["1", "2", "3"], "noniid")
index = s.client_index
print(index)
print(s.targets)
for name, indexes in s.client_index.items():
print(s.targets[indexes])

df = pd.DataFrame(
{name: pd.Series(s.targets[index]) for name, index in s.client_index.items()}
)
for col in df.columns:
print(df[col].value_counts())
36 changes: 32 additions & 4 deletions iflearner/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def download_url(


def list_dir(root: str, prefix: bool = False) -> List[str]:
"""List all directories at a given root
"""List all directories at a given root.
Args:
root (str): Path to directory whose folders need to be listed
Expand All @@ -107,7 +107,7 @@ def list_dir(root: str, prefix: bool = False) -> List[str]:


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
"""List all files ending with a suffix at a given root
"""List all files ending with a suffix at a given root.
Args:
root (str): Path to directory whose folders need to be listed
Expand Down Expand Up @@ -287,7 +287,9 @@ def get_int(b: bytes) -> int:

def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
Decompression occurs when argument `path` is a string and ends with
'.gz' or '.xz'.
"""
import torch

Expand All @@ -313,7 +315,9 @@ def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]
def read_sn3_pascalvincent_tensor(
path: Union[str, IO], strict: bool = True
) -> np.ndarray:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
Expand Down Expand Up @@ -378,3 +382,27 @@ def verify_str_arg(
raise ValueError(msg)

return value


def partition_class_samples_with_dirichlet_distribution(
    N, alpha, client_num, idx_batch, idx_k
):
    """Distribute one class's sample indices among clients.

    The indices are shuffled, a proportion vector is drawn from a symmetric
    Dirichlet distribution, and consecutive slices of the shuffled indices
    are appended to each client's list.

    Args:
        N: Total sample count; ``N / client_num`` is the size at which a
            client stops receiving further samples.
        alpha: Dirichlet concentration parameter, repeated ``client_num``
            times.
        client_num: Number of clients.
        idx_batch: One list of already-assigned indices per client.
        idx_k: Indices of the samples belonging to the current class. The
            caller's object is left untouched.

    Returns:
        A new list with one index list per client: the previous contents
        followed by that client's share of ``idx_k``.
    """
    # Shuffle a private copy so the caller's array is not reordered.
    idx_k = np.array(idx_k)
    np.random.shuffle(idx_k)

    # Unbalanced share per client, e.g. for client_num = 4:
    # [0.29543505 0.38414498 0.31998781 0.00043216], summing to 1.
    proportions = np.random.dirichlet(np.repeat(alpha, client_num))

    # Zero the share of every client that already holds N / client_num
    # samples or more, so remaining samples go to clients with room left.
    masked = np.array(
        [p * (len(idx_j) < N / client_num) for p, idx_j in zip(proportions, idx_batch)]
    )
    total = masked.sum()
    if total > 0:
        proportions = masked / total
    # Otherwise every client is full: keep the raw Dirichlet draw, which
    # already sums to 1, instead of dividing by zero and producing NaNs.

    # Turn cumulative shares into split points inside idx_k.
    split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]

    # Append each client's slice to what it already has.
    return [
        idx_j + idx.tolist()
        for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))
    ]

0 comments on commit 50bf38c

Please sign in to comment.