Integrate STC method
xs233 committed Nov 3, 2022
1 parent b318a16 commit 040693a
Showing 13 changed files with 351 additions and 66 deletions.
42 changes: 20 additions & 22 deletions examples/homo/imagenet/imagenet.py
@@ -25,7 +25,11 @@
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import module_modification
-from torch.optim.lr_scheduler import StepLR
+from torch.optim.lr_scheduler import StepLR, LambdaLR
+
+from iflearner.communication.homo import homo_pb2
+from iflearner.business.homo import pytorch_trainer, train_client
+from iflearner.business.homo.argument import parser

sys.path.append("../../../")

@@ -40,7 +44,6 @@
)

# parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
-from iflearner.business.homo.argument import parser

parser.add_argument(
"data",
@@ -54,7 +57,8 @@
metavar="ARCH",
default="resnet18",
choices=model_names,
help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
help="model architecture: " +
" | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument(
"-j",
@@ -92,7 +96,8 @@
help="initial learning rate",
dest="lr",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument("--momentum", default=0.9, type=float,
metavar="M", help="momentum")
parser.add_argument(
"--wd",
"--weight-decay",
@@ -158,16 +163,14 @@
"multi node data parallel training",
)

parser.add_argument("--scaffold", help="Enable scaffold (1 | 0).", type=int, default=0)
parser.add_argument(
"--scaffold", help="Enable scaffold (1 | 0).", type=int, default=0)
parser.add_argument(
"--dp", help="Enable differential privacy (1 | 0).", type=int, default=0
)

best_acc1 = 0

-from iflearner.business.homo import pytorch_trainer, train_client
-from iflearner.communication.homo import homo_pb2


class ImageNet(pytorch_trainer.PyTorchTrainer):
def __init__(
@@ -261,17 +264,6 @@ def evaluate(self, epoch):

return {"top1": float(top1), "top5": float(top5)}

-    def get(self, param_type=""):
-        parameters = dict()
-        for name, p in self._model.named_parameters():
-            if p.requires_grad:
-                parameters[name] = (
-                    self._old_weights[name].cpu().detach().numpy()
-                    - self._new_weights[name].cpu().detach().numpy()
-                )
-
-        return parameters


def main():
args = parser.parse_args()
@@ -307,7 +299,8 @@ def main():
args.world_size = ngpus_per_node * args.world_size
# Use torch.multiprocessing.spawn to launch distributed processes: the
# main_worker process function
-        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
+        mp.spawn(main_worker, nprocs=ngpus_per_node,
+                 args=(ngpus_per_node, args))
else:
# Simply call main_worker function
main_worker(args.gpu, ngpus_per_node, args)
@@ -341,6 +334,8 @@ def main_worker(gpu, ngpus_per_node, args):
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()

+    print(model)

if not torch.cuda.is_available():
print("using CPU, this will be slow")
elif args.distributed:
@@ -354,7 +349,8 @@ def main_worker(gpu, ngpus_per_node, args):
# DistributedDataParallel, we need to divide the batch size
# ourselves based on the total number of GPUs of the current node.
args.batch_size = int(args.batch_size / ngpus_per_node)
-            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
+            args.workers = int(
+                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.gpu]
)
@@ -391,6 +387,7 @@ def main_worker(gpu, ngpus_per_node, args):

"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
+    # scheduler = LambdaLR(optimizer, lambda epoch: 1.0)

# optionally resume from a checkpoint
if args.resume:
@@ -440,7 +437,8 @@ def main_worker(gpu, ngpus_per_node, args):
)

if args.distributed:
-        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        train_sampler = torch.utils.data.distributed.DistributedSampler(
+            train_dataset)
else:
train_sampler = None

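A note on the scheduler hunk above: the commented-out LambdaLR line would pin the learning rate, since a multiplier of 1.0 every epoch leaves the initial LR untouched, whereas the active StepLR decays it 10x every 30 epochs. A minimal standalone sketch contrasting the two (the tiny model and numbers are illustrative, not from the repo):

import torch
from torch.optim.lr_scheduler import StepLR, LambdaLR

model = torch.nn.Linear(4, 2)                      # illustrative model only
opt = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = StepLR(opt, step_size=30, gamma=0.1)   # lr drops 10x every 30 epochs
# scheduler = LambdaLR(opt, lambda epoch: 1.0)     # multiplier 1.0: lr never changes

for epoch in range(31):
    opt.step()        # no gradients here; a real loop would backprop first
    scheduler.step()

print(opt.param_groups[0]["lr"])  # ~0.01 after the epoch-30 decay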
10 changes: 8 additions & 2 deletions iflearner/business/homo/aggregate_server.py
@@ -26,6 +26,7 @@
fednova_server,
fedopt_server,
qfedavg_server,
+    stc_server,
)
from iflearner.business.homo.strategy.strategy_server import StrategyServer
from iflearner.communication.base import base_server
@@ -59,6 +60,10 @@ def __init__(
num_clients, epochs, True, **strategy_params
)

+        elif strategy == message_type.STRATEGY_STC:
+            self._strategy_server = stc_server.STCServer(
+                num_clients, epochs)

elif strategy == message_type.STRATEGY_FEDOPT:
if strategy_params.get("opt") is None:
raise Exception("expect 'opt' when you use fedopt sever")
@@ -75,7 +80,8 @@
opt=opt,
) # type: ignore
logger.info(
" ".join([f"{k}:{v}" for k, v in strategy_params.items()])
" ".join(
[f"{k}:{v}" for k, v in strategy_params.items()])
)

elif strategy == message_type.STRATEGY_qFEDAVG:
Expand Down Expand Up @@ -127,7 +133,7 @@ def main():
)
parser.add_argument(
"--strategy",
help="the aggregation starategy (FedAvg | Scaffold | FedOpt | qFedAvg | FedNova)",
help="the aggregation starategy (FedAvg | Scaffold | STC | FedOpt | qFedAvg | FedNova)",
default="FedAvg",
type=str,
)
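For context on what an STC aggregation round consumes: the new stc_client.py below pickles each layer's sparse update as a tuple (mean, pos_h, pos_v, neg_h, neg_v); see _encode_sparse_array in the new file. The matching stc_server.py decode is among the 13 changed files but is not shown in this excerpt, so as a rough sketch under that assumption, the tuple can be expanded back into a dense flat array like this (decode_sparse_array is a hypothetical helper name; the 256-wide rows follow from the client wrapping its horizontal index at the uint8 maximum):

import pickle
import numpy as np

def decode_sparse_array(data: bytes, size: int) -> np.ndarray:
    """Hypothetical inverse of STCClient._encode_sparse_array."""
    mean, pos_h, pos_v, neg_h, neg_v = pickle.loads(data)
    row_len = np.iinfo(np.uint8).max + 1          # client wraps columns at 256
    out = np.zeros(size)
    out[pos_v.astype(np.int64) * row_len + pos_h.astype(np.int64)] = mean
    out[neg_v.astype(np.int64) * row_len + neg_h.astype(np.int64)] = -mean
    return out  # reshape with the Parameter's shape field to recover the tensor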
147 changes: 147 additions & 0 deletions iflearner/business/homo/strategy/stc_client.py
@@ -0,0 +1,147 @@
# Copyright 2022 iFLYTEK. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import pickle
import numpy as np
from loguru import logger
from typing import Any, Dict, Optional, Tuple
from iflearner.communication.homo import homo_pb2
from iflearner.business.homo.strategy.strategy_client import StrategyClient


class STCClient(StrategyClient):
"""Implement the STC method base on the paper (https://ieeexplore.ieee.org/document/8889996).
"""

def __init__(self) -> None:
super().__init__()

self._top_fraction: float = 0.1
        self._weights: Optional[Dict] = None
        self._residuals: Optional[Dict] = None
self._enable_residuals: bool = True
logger.info(
f"STC client, top fraction: {self._top_fraction}, enable residuals: {self._enable_residuals}")

    def generate_upload_param(self, epoch: int, data: Dict[Any, Any], metrics: Optional[Dict[str, float]] = None) -> homo_pb2.UploadParam:
if self._residuals is None:
self._residuals = {}
self._weights = {}
for k, v in data.items():
self._residuals[k] = np.zeros(v.size)
self._weights[k] = np.zeros(v.size)

compressed_data = {}
for k, v in data.items():
compressed_data[k] = homo_pb2.Parameter(shape=v.shape)

ravel_v = v.ravel()

weight_difference = ravel_v - self._weights[k]
if self._enable_residuals:
self._residuals[k] += weight_difference

self._weights[k] = ravel_v

out = None
mean = 0.0
if self._enable_residuals:
out, mean = self._compression(self._residuals[k])
self._residuals[k] -= out
else:
out, mean = self._compression(weight_difference)

sparse_data = self._encode_sparse_array(out, mean)
compressed_data[k].custom_values = sparse_data

return homo_pb2.UploadParam(epoch=epoch, parameters=compressed_data, metrics=metrics)

def aggregate_result(self) -> homo_pb2.AggregateResult:
for k, v in self._aggregate_result_np.items():
self._weights[k] = v.flatten()

return self._aggregate_result_np

    def _compression(self, T: np.ndarray) -> Tuple[np.ndarray, np.float64]:
        """Compress an array with top-k sparsification and ternary quantization.
        Args:
            T (np.ndarray): The array that needs to be compressed.
        Returns:
            Tuple[np.ndarray, np.float64]: The sparse ternary array and the mean magnitude of its retained top-k elements.
        """
T_abs = np.absolute(T)
n_top = int(np.ceil(T_abs.size * self._top_fraction))
topk = T_abs[np.argpartition(T_abs, -n_top)[-n_top:]]
mean = np.mean(topk)
min_topk = topk.min()
out_ = np.where(T >= min_topk, mean, 0.0)
out = np.where(T <= -min_topk, -mean, out_)
return out, mean

    def _encode_sparse_array(self, arr: np.ndarray, mean_value: np.float64) -> bytes:
        """Encode a sparse ternary array to bytes.
        Args:
            arr (np.ndarray): A sparse ternary array.
            mean_value (np.float64): The magnitude shared by every nonzero element.
        Returns:
            bytes: A pickle dump of the mean value and the positive/negative nonzero coordinates.
        """
        logger.info(f"Encode a sparse array, size: {arr.size}")

positive_horizontal_coordinates = []
positive_vertical_coordinates = []
negative_horizontal_coordinates = []
negative_vertical_coordinates = []
        # Nonzero entries are stored as (horizontal, vertical) coordinates on a
        # grid whose rows are 256 entries wide (the loop below wraps the
        # horizontal index at the uint8 maximum); widen the coordinate dtypes
        # for very large arrays.
        horizontal_coordinate_type = np.uint8
        vertical_coordinate_type = np.uint8

        uint8_len = np.iinfo(np.uint8).max + 1
        uint16_len = np.iinfo(np.uint16).max + 1
        if arr.size > uint8_len * uint16_len:
            horizontal_coordinate_type = np.uint16
        if arr.size > uint8_len * uint8_len:
            vertical_coordinate_type = np.uint16

horizontal_index = 0
vertical_index = 0
for item in arr:
if horizontal_index > np.iinfo(np.uint8).max:
horizontal_index = 0
vertical_index += 1
if item > 0:
positive_horizontal_coordinates.append(
horizontal_coordinate_type(horizontal_index))
positive_vertical_coordinates.append(
vertical_coordinate_type(vertical_index))
elif item < 0:
negative_horizontal_coordinates.append(
horizontal_coordinate_type(horizontal_index))
negative_vertical_coordinates.append(
vertical_coordinate_type(vertical_index))

horizontal_index += 1

        my_np_tuple = (mean_value,
                       np.array(positive_horizontal_coordinates),
                       np.array(positive_vertical_coordinates),
                       np.array(negative_horizontal_coordinates),
                       np.array(negative_vertical_coordinates))
data = pickle.dumps(my_np_tuple)
logger.info(
f"After encoding, positive coordinates num: {len(positive_horizontal_coordinates)}, negative coordinates num: {len(negative_horizontal_coordinates)}, the size is {len(data)}")

return data
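
To make the round-trip concrete, here is a minimal standalone sketch of one upload round, mirroring _compression and the residual accumulation in generate_upload_param above (compress is a local stand-in for the private method; the array values are illustrative):

import numpy as np

def compress(T: np.ndarray, top_fraction: float = 0.1):
    """Sparse ternary compression: zero everything except the largest-
    magnitude fraction of entries, which become +/- the mean top-k magnitude."""
    T_abs = np.absolute(T)
    n_top = int(np.ceil(T_abs.size * top_fraction))
    topk = T_abs[np.argpartition(T_abs, -n_top)[-n_top:]]
    mean = np.mean(topk)
    min_topk = topk.min()
    out = np.where(T >= min_topk, mean, 0.0)
    out = np.where(T <= -min_topk, -mean, out)
    return out, mean

# One round of the client loop: fold the new weight delta into the residual,
# upload the compressed residual, and keep whatever was not transmitted.
residual = np.zeros(10)
delta = np.array([0.05, -0.9, 0.2, 0.8, -0.1, 0.03, 0.7, -0.6, 0.01, 0.4])
residual += delta
sent, mu = compress(residual, top_fraction=0.3)
residual -= sent   # e.g. index 1 keeps -0.9 - (-0.8) = -0.1 for the next round
print(sent)        # [ 0.  -0.8  0.   0.8  0.   0.   0.8  0.   0.   0. ]

Because the untransmitted mass stays in the residual, small but persistent updates accumulate across rounds and are eventually sent.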