Skip to content

Commit

Permalink
Complete base design
Browse files Browse the repository at this point in the history
  • Loading branch information
xs233 committed Sep 27, 2022
1 parent 5b04d96 commit 92e8342
Show file tree
Hide file tree
Showing 11 changed files with 478 additions and 51 deletions.
64 changes: 41 additions & 23 deletions iflearner/business/hetero/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pickle
import time
from os.path import join
from typing import Any
from loguru import logger

from iflearner.business.hetero.parser import Parser
Expand Down Expand Up @@ -46,43 +47,60 @@ def __init__(self) -> None:
raise Exception(f"{parser.model_name} is not existed.")

self._model = Builders[parser.model_name].create_role_model_instance(parser.role_name)
parser.parse_model_flow_file(join("model", parser.model_name, Builders[parser.model_name].get_role_model_flow_file(parser.role_name)))

self._model.set_hyper_params(parser.hyper_params)

parser.parse_model_flow_file(join("model", parser.model_name, Builders[parser.model_name].get_role_model_flow_file(parser.role_name)))
logger.info(f"Model flow: {parser.model_flow}")
logger.info(f"Network config: {parser.network_config}")
self._network = HeteroNetwork(*parser.network_config)
self._init_steps = False

def _exec_flow(self) -> None:
"""Execute flow.
def _exec_steps(self, steps: Any) -> None:
"""Execute steps.
Args:
steps (Any): Details of the steps.
Raise:
Exception(f"The return type of {step_name} is illegal.")
Exception(f"The return type of {step.name} is illegal.")
"""
for step in parser.model_flow["steps"]:
step_name = step["name"]
upstreams = step["upstreams"]
logger.info(f"Step: {step_name}, Upstreams: {upstreams}")
if upstreams is not None:
for upstream in upstreams:
data_list = None
while data_list is None:
data_list = self._network.pull(upstream["role"], upstream["step"])
time.sleep(1)
self._model.handle_upstream(upstream["role"], upstream["step"], data_list)
for step in steps:
logger.info(f"{step}")
for upstream in step.upstreams:
data_list = None
while data_list is None:
data_list = self._network.pull(upstream.role, upstream.step)
time.sleep(1)

self._model.handle_upstream(upstream.role, upstream.step, data_list)

if step.virtual is True:
continue

result = self._model.handle_step(step_name)
result = self._model.handle_step(step.name)
if result is None:
continue

for name, data in result.items():
data = pickle.dumps(data)
if isinstance(name, Role):
self._network.push(str(name), None, step_name, data)
self._network.push(str(name), None, step.name, data)
elif isinstance(name, str):
self._network.push(None, name, step_name, data)
self._network.push(None, name, step.name, data)
else:
raise Exception(f"The return type of {step_name} is illegal.")
raise Exception(f"The return type of {step.name} is illegal.")

def _exec_model_flow(self) -> None:
"""Execute model flow.
Raise:
Exception(f"The return type of {step.name} is illegal.")
"""
if not self._init_steps and parser.model_flow.init_steps is not None:
self._exec_steps(parser.model_flow.init_steps)
self._init_steps = True

self._exec_steps(parser.model_flow.steps)

def run(self, epoch: int=1) -> None:
"""Loop execution process.
Expand All @@ -92,11 +110,11 @@ def run(self, epoch: int=1) -> None:
"""
for i in range(epoch):
logger.info(f"Start epoch {i+1}")
self._exec_flow()
self._exec_model_flow()

if __name__ == "__main__":
parser.parse_task_configuration_file()
driver = Driver()
driver.run()
driver.run(parser.epochs)
os._exit(0)

12 changes: 11 additions & 1 deletion iflearner/business/hetero/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

from abc import ABC, abstractmethod
from typing import Dict, Tuple, List, Union, Any
from iflearner.business.hetero.model.role import Role

Expand All @@ -37,7 +38,7 @@ def handle_own_step() -> Dict[Union[Role, str], Any]:
pass


class BaseModel:
class BaseModel(ABC):
'''Define each step of model training and evaluation.
You need to split the whole process into steps, then you need to implement it and register it with self._register_own_step.
Expand All @@ -48,6 +49,15 @@ def __init__(self) -> None:
self._another_steps: Dict[str, handle_another_step] = {}
self._own_steps: Dict[str, handle_own_step] = {}

@abstractmethod
def set_hyper_params(self, hyper_params: Any) -> None:
"""Set hyper params.
Args:
hyper_params (Any): Details of the hyper params.
"""
pass

def _register_another_step(self, role: Role, step_name: str, func: handle_another_step) -> None:
"""Register a another step handler.
Expand Down
55 changes: 55 additions & 0 deletions iflearner/business/hetero/model/logistic_regression/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2022 iFLYTEK. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def load_data():
breast = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(
breast.data, breast.target, random_state=1)
std = StandardScaler()
x_train = std.fit_transform(x_train)
x_test = std.transform(x_test)
return x_train, y_train, x_test, y_test


def vertically_partition_data(X, X_test, A_idx, B_idx):
XA = X[:, A_idx]
XB = X[:, B_idx]
# print(X.shape[0], np.ones(X.shape[0]))
# print(X.shape[1], np.ones(X.shape[1]))
XB = np.c_[np.ones(X.shape[0]), XB]
XA_test = X_test[:, A_idx]
XB_test = X_test[:, B_idx]
XB_test = np.c_[np.ones(XB_test.shape[0]), XB_test]
return XA, XB, XA_test, XB_test


def get_guest_data():
x, y, x_test, y_test = load_data()
XA, XB, XA_test, XB_test = vertically_partition_data(x, x_test, [
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
return XB, y


def get_host_data():
x, y, x_test, y_test = load_data()
XA, XB, XA_test, XB_test = vertically_partition_data(x, x_test, [
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
return XA
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(self, A_d_shape, B_d_shape, config):
super().__init__(config)
self.A_data_shape = A_d_shape
self.B_data_shape = B_d_shape
print(A_d_shape, B_d_shape)
self.public_key = None
self.private_key = None
# 保存训练中的损失值(泰展开近似)
Expand Down
61 changes: 59 additions & 2 deletions iflearner/business/hetero/model/logistic_regression/lr_arbiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================

import math
import numpy as np
from loguru import logger
from phe import paillier
from typing import Any, List, Dict, Union
Expand All @@ -25,12 +27,67 @@ class LRArbiter(BaseModel):
def __init__(self) -> None:
super().__init__()

self._register_own_step("generate_he_keypair",
self.generate_he_keypair)
self._register_another_step(guest, "calc_final_result_with_host", self.received_guest_encrypted_data)
self._register_another_step(host, "calc_final_result_with_guest", self.received_host_encrypted_data)

self._register_own_step("generate_he_keypair", self.generate_he_keypair)
self._register_own_step("decrypt_guest_data", self.decrypt_guest_data)
self._register_own_step("decrypt_host_data", self.decrypt_host_data)

def set_hyper_params(self, hyper_params: Any) -> None:
"""Set hyper params.
Args:
hyper_params (Any): Details of the hyper params.
"""
super().set_hyper_params(hyper_params)

def generate_he_keypair(self) -> Dict[Union[Role, str], Any]:
"""Generate HE public key and private key.
Returns:
Dict[Union[Role, str], Any]: Return the HE public key to the guest and host.
"""
public_key, private_key = paillier.generate_paillier_keypair()
logger.info(f"Public key: {public_key}")
self._private_key = private_key
return {guest: public_key.n, host: public_key.n}

def received_guest_encrypted_data(self, data: Dict[str, Any]) -> None:
"""Save encrypted data from the guest.
Args:
data (Dict[str, Any]): Guest party name and encrypted data.
"""
self._encrypted_masked_dJ_b, encrypted_loss, shape = list(data.values())[0]
loss = self._private_key.decrypt(encrypted_loss) / shape + math.log(2)
logger.info(f"Loss: {loss}")

def received_host_encrypted_data(self, data: Dict[str, Any]) -> None:
"""Save encrypted data from the host.
Args:
data (Dict[str, Any]): Host party name and encrypted data.
"""
self._encrypted_masked_dJ_a = list(data.values())[0]

def decrypt_guest_data(self) -> Dict[Union[Role, str], Any]:
"""Decrypt guest data.
Returns:
Dict[Union[Role, str], Any]: Return guest role name and its decrypted data.
"""
masked_dJ_b = np.asarray([self._private_key.decrypt(x) for x in self._encrypted_masked_dJ_b])
return {guest: masked_dJ_b}

def decrypt_host_data(self) -> Dict[Union[Role, str], Any]:
"""Decrypt host data.
Returns:
Dict[Union[Role, str], Any]: Return host role name and its decrypted data.
"""
masked_dJ_a = np.asarray([self._private_key.decrypt(x) for x in self._encrypted_masked_dJ_a])
return {host: masked_dJ_a}



Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
role: arbiter
steps:
init_steps:
- name: generate_he_keypair
upstreams: null
steps:
- name: decrypt_guest_data
upstreams:
- role: guest
step: calc_final_result_with_host
- name: decrypt_host_data
upstreams:
- role: host
step: calc_final_result_with_guest
Loading

0 comments on commit 92e8342

Please sign in to comment.