add split dataset (#2)
* modify sampler

* add split dataset  example

* add:split-dataset
HoyTiger authored Sep 13, 2022
1 parent ba0271e commit 3cdc852
Showing 8 changed files with 216 additions and 239 deletions.
4 changes: 4 additions & 0 deletions iflearner/datasets/__init__.py
@@ -0,0 +1,4 @@
# Relative imports so the package works under `import iflearner.datasets`.
from .cifar import *
from .mnist import *

SUPPORT_DATASET = ['MNIST', 'FashionMNIST', 'KMNIST', 'EMNIST', 'CIFAR10', 'CIFAR100']
1 change: 0 additions & 1 deletion iflearner/datasets/cifar.py
@@ -19,7 +19,6 @@ class CIFAR10(FLDateset):
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

base_folder = "cifar-10-batches-py"
9 changes: 6 additions & 3 deletions iflearner/datasets/mnist.py
@@ -64,7 +64,8 @@ def _check_exists(self) -> bool:
) and os.path.exists(os.path.join(self.processed_folder, self.test_file))

    def download(self) -> None:
-        """Download the MNIST data if it doesn't exist in processed_folder already."""
+        """Download the MNIST data if it doesn't exist in processed_folder
+        already."""

if self._check_exists():
return
@@ -150,7 +151,8 @@ class KMNIST(MNIST):


class EMNIST(MNIST):
-    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.
+    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_researc
+    h/emnist>`_ Dataset.
Args:
root (string): Root directory of dataset where ``EMNIST/processed/training.pt``
@@ -220,7 +222,8 @@ def _test_file(split) -> str:
return "test_{}.pt".format(split)

    def download(self) -> None:
-        """Download the EMNIST data if it doesn't exist in processed_folder already."""
+        """Download the EMNIST data if it doesn't exist in processed_folder
+        already."""
import shutil

if self._check_exists():
40 changes: 40 additions & 0 deletions iflearner/datasets/readme.md
@@ -0,0 +1,40 @@
## How to use split_dataset

Run ``python iflearner/datasets/split_dataset.py -h`` to see the command's help message.

```
optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       the path of the config file
  --iid                 split the dataset IID (default is a non-IID split)
  --noniid NONIID       the type of non-IID split
  --alpha ALPHA         controls the degree of non-IID skew; only used when
                        iid is false
  --save_path SAVE_PATH
                        the directory in which to save the split dataset
  --dataset DATASET     the name of the dataset
  --data_path DATA_PATH
                        the full data in .npy format; only used when dataset
                        is None
  --label_path LABEL_PATH
                        the full labels in .npy format; only used when
                        dataset is None
  --save_test_set_path SAVE_TEST_SET_PATH
                        the directory in which to save the test set (omit to
                        skip saving it)
  --clients CLIENTS [CLIENTS ...]
                        client names, e.g. client1 client2
```
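
For example, the following invocation (paths and client names are illustrative) splits MNIST across two clients with the default non-IID scheme and also saves the test set:

```
python iflearner/datasets/split_dataset.py --dataset MNIST --save_path data_train --save_test_set_path data_test --clients client1 client2
```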

If the `config` argument is not None, all other arguments must be set in the config file in ``yaml`` format, for example:

```yaml
iid: False
noniid: "noniid"
alpha: 1
save_path: "data_train"
dataset: "MNIST"
data_path: "/data/hanyuhu/iflearber-github/data/0-train_data.npy"
label_path: "/data/hanyuhu/iflearber-github/data/0-train_label.npy"
save_test_set_path: "data_test"
clients:
- client1
- client2
```
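
After splitting, each client's shard is saved as ``{name}-train_data.npy`` and ``{name}-train_label.npy`` under ``save_path``. A minimal sketch of loading one shard back with NumPy (the paths assume the config above):

```python
import numpy as np

# Load the shard that split_dataset.py wrote for client1.
X = np.load("data_train/client1-train_data.npy")
y = np.load("data_train/client1-train_label.npy")
print(X.shape, y.shape)  # e.g. (N, 28, 28) and (N,) for MNIST
```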
38 changes: 28 additions & 10 deletions iflearner/datasets/sampler.py
@@ -1,15 +1,15 @@
from collections import defaultdict
from typing import Union

import numpy as np

from iflearner.datasets.utils import partition_class_samples_with_dirichlet_distribution


class Sampler:
-    def __init__(self, label: list, clients: list, method="iid", **kwargs):
+    def __init__(self, label, clients: list, method="iid", **kwargs):
self.targets = np.array(label, dtype="int64")
self.clients = clients

if method == "iid":
self.client_index = self.iid()
elif method == "noniid":
@@ -21,17 +21,20 @@ def __init__(self, label: list, clients: list, method="iid", **kwargs):
)
else:
self.client_index = self.dirichlet_distribution_non_iid(1)
@property
def get_client_index(self):
return self.client_index

def iid(self):
clients_index = defaultdict(set)
length = len(self.targets)
clients_num = len(self.clients)
all_idxs = np.arange(length)
-        for i in range(clients_num):
-            clients_index[self.clients[i]] = set(
+        for i in self.clients:
+            clients_index[i] = set(
                np.random.choice(all_idxs, length // clients_num, replace=False)
            )
-            all_idxs = list(set(all_idxs) - clients_index[self.clients[i]])
+            all_idxs = list(set(all_idxs) - clients_index[i])

return clients_index

@@ -54,13 +57,11 @@ def dirichlet_distribution_non_iid(self, alpha):
clients_index[self.clients[i]] = set(idx_batch[i])
return clients_index

-    def noniid(self, num_shards):
+    def noniid(self):
        clients_num = len(self.clients)
-        num_shards, num_imgs = num_shards, len(self.targets) // (num_shards)
-        if num_imgs < 1:
-            raise Exception('too many shards, shard number cannot be more than length of targets')
+        num_shards, num_imgs = clients_num * 15, len(self.targets) // (clients_num * 15)
        idx_shard = [i for i in range(num_shards)]
-        clients_index = {i: np.array([], dtype="int64") for i in range(clients_num)}
+        clients_index = {i: np.array([], dtype="int64") for i in self.clients}
idxs = np.arange(len(self.targets))
labels = np.array(self.targets)

@@ -85,3 +86,20 @@ def noniid(self, num_shards):
pass
return clients_index


if __name__ == "__main__":
import pandas as pd
from mnist import MNIST

data = MNIST("./data", True)
s = Sampler(data.train_labels, ["1", "2", "3"], "dirichlet")
d = {}
for name, indexes in s.client_index.items():
# print(type(s.targets[indexes]))
indexes = list(indexes)
d[name] = pd.Series(s.targets[indexes])
print(s.targets[indexes])

df = pd.DataFrame(d)
for col in df.columns:
print(df[col].value_counts())
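
The `dirichlet` method partitions each class with a Dirichlet draw, so `alpha` controls heterogeneity: smaller values concentrate a class on fewer clients, larger values approach an even split. A standalone sketch of the effect (plain NumPy, independent of the iflearner helpers; the counts are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n_clients, samples_per_class = 3, 12

for alpha in (0.1, 1.0, 10.0):
    # One Dirichlet draw per class: each client's share of that class.
    shares = rng.dirichlet([alpha] * n_clients)
    counts = (shares * samples_per_class).round().astype(int)
    print(alpha, counts)  # small alpha -> skewed counts; large alpha -> nearly even
```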
94 changes: 94 additions & 0 deletions iflearner/datasets/split_dataset.py
@@ -0,0 +1,94 @@
import argparse
import os.path

import numpy as np
import yaml  # needed by read_yaml_file below

from iflearner.datasets import *
from iflearner.datasets.fl_dataset import FLDateset
from iflearner.datasets.sampler import Sampler
from iflearner.datasets.utils import read_yaml

SUPPORT_DATASET = ['MNIST', 'FashionMNIST',
'KMNIST', 'EMNIST', 'CIFAR10', 'CIFAR100']


def read_yaml_file(file_path):
with open(file_path, mode='r', encoding='utf-8') as fd:
data = yaml.load(fd, Loader=yaml.FullLoader)
return data


def main(args):
if args.config is not None:
args = read_yaml_file(args.config)
args = argparse.Namespace(**args)
print(args)

if args.dataset:
if args.dataset not in SUPPORT_DATASET:
raise Exception(f'only support {", ".join(SUPPORT_DATASET)}')

        # eval resolves the class name against the classes star-imported from
        # iflearner.datasets (e.g. MNIST, CIFAR10).
        dataset = eval(args.dataset)('./data', True)
data = np.array(dataset.train_data)
label = np.array(dataset.train_labels)
if args.save_test_set_path:
if not os.path.exists(args.save_test_set_path):
os.makedirs(args.save_test_set_path)
test_data = np.array(dataset.test_data)
np.save(os.path.join(args.save_test_set_path, 'test_data'), test_data)
test_label = np.array(dataset.test_labels)
np.save(os.path.join(args.save_test_set_path,
'test_label'), test_label)

else:
if args.data_path is None or args.label_path is None:
            raise Exception('either a dataset name or data_path and label_path must be set')

data = np.load(args.data_path)
label = np.load(args.label_path)
assert len(data) == len(label)
clients = args.clients
    if not clients:
        raise Exception('clients must not be empty')
if args.iid:
sampler = Sampler(label, clients, 'iid')
else:
sampler = Sampler(label, clients, args.noniid)

client_index = sampler.get_client_index
for name, indexes in client_index.items():
print(name)
indexes = list(indexes)
target = label[indexes]
X = data[indexes]
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
np.save(os.path.join(args.save_path, f'{name}-train_data'), X)
np.save(os.path.join(args.save_path, f'{name}-train_label'), target)


if __name__ == '__main__':
    parse = argparse.ArgumentParser()
    parse.add_argument('--config', type=str, help='the path of the config file')
    # A boolean flag should use store_true; type=bool would treat any
    # non-empty string (including "False") as True.
    parse.add_argument('--iid', action='store_true', default=False,
                       help='split the dataset IID (default is a non-IID split)')
    parse.add_argument('--noniid', type=str, default='noniid',
                       help='the type of non-IID split')
    parse.add_argument('--alpha', type=float, default=1,
                       help='controls the degree of non-IID skew; only used when iid is false')
    parse.add_argument('--save_path', type=str, default='./data',
                       help='the directory in which to save the split dataset')
    parse.add_argument('--dataset', type=str, default='MNIST',
                       help='the name of the dataset')
    parse.add_argument('--data_path', type=str,
                       help='the full data in .npy format; only used when dataset is None')
    parse.add_argument('--label_path', type=str,
                       help='the full labels in .npy format; only used when dataset is None')
    # The test-set path is joined with file names in main(), so it is a
    # directory path (string), not a bool.
    parse.add_argument('--save_test_set_path', type=str, default=None,
                       help='the directory in which to save the test set (omit to skip saving it)')
    parse.add_argument('--clients', type=str, nargs='+',
                       default=['client2', 'client1'],
                       help='client names, e.g. client1 client2')

    args = parse.parse_args()
    main(args)
