Skip to content

Commit

Permalink
Support infinite loop
Browse files Browse the repository at this point in the history
  • Loading branch information
xs233 committed Sep 29, 2022
1 parent 806522b commit bed9fbc
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
6 changes: 6 additions & 0 deletions iflearner/business/hetero/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def run(self, epoch: int=1) -> None:
Args:
epoch (int, optional): The number of epochs we need to run. Defaults to 1.
"""

if epoch == -1:
while True:
logger.info("Infinite loop")
self._exec_model_flow()

for i in range(epoch):
logger.info(f"Start epoch {i+1}")
self._exec_model_flow()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from tkinter import Y
import numpy as np
from phe import paillier
import pandas as pd
Expand All @@ -8,6 +9,7 @@
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.linear_model import LogisticRegression


class Client:
Expand All @@ -29,9 +31,10 @@ def send_data(self, data, target_client):


class ClientA(Client):
def __init__(self, X, config):
def __init__(self, X, X_test, config):
super().__init__(config)
self.X = X
self.X_test = X_test
self.weights = np.zeros(X.shape[1])

def compute_z_a(self):
Expand Down Expand Up @@ -91,12 +94,20 @@ def task_3(self):
print(f"A weight: {self.weights}")
return

def predict(self, client_B_name):
z_a = self.X.dot(self.weights)
z_a_test = self.X_test.dot(self.weights)
data_to_B = {'z_a': z_a, 'z_a_test': z_a_test}
self.send_data(data_to_B, self.other_client[client_B_name])


class ClientB(Client):
def __init__(self, X, y, config):
def __init__(self, X, y, X_test, y_test, config):
super().__init__(config)
self.X = X
self.y = y
self.X_test = X_test
self.y_test = y_test
self.weights = np.zeros(X.shape[1])
self.data = {}

Expand Down Expand Up @@ -173,6 +184,27 @@ def task_3(self):
print(f"B weight: {self.weights}")
return

def _predict(self, z_a, x, y):
z_b = x.dot(self.weights)
results = 1 / (1 + np.exp(-(z_a + z_b)))
# print(results)
y_pred = []
for result in results:
if result > 0.5:
y_pred.append(1)
else:
y_pred.append(0)
# print(y_pred)
print(sum(1 for x, y in zip(y, y_pred) if x == y) / len(y_pred))

def predict(self):
dt = self.data
assert "z_a" in dt.keys(
), "Error: 'z_a' from A in step predict not successfully received."

self._predict(dt["z_a"], self.X, self.y)
self._predict(dt["z_a_test"], self.X_test, self.y_test)


class ClientC(Client):
"""
Expand Down Expand Up @@ -263,7 +295,10 @@ def vertically_partition_data(X, X_test, A_idx, B_idx):
XB = X[:, B_idx]
# print(X.shape[0], np.ones(X.shape[0]))
# print(X.shape[1], np.ones(X.shape[1]))
print(XB)
print(np.ones(X.shape[0]))
XB = np.c_[np.ones(X.shape[0]), XB]
print(XB)
XA_test = X_test[:, A_idx]
XB_test = X_test[:, B_idx]
XB_test = np.c_[np.ones(XB_test.shape[0]), XB_test]
Expand All @@ -287,9 +322,9 @@ def vertical_logistic_regression(X, y, X_test, y_test, config):
print('XA:', XA.shape, ' XB:', XB.shape)

# 各参与方的初始化
client_A = ClientA(XA, config)
client_A = ClientA(XA, XA_test, config)
print("Client_A successfully initialized.")
client_B = ClientB(XB, y, config)
client_B = ClientB(XB, y, XB_test, y_test, config)
print("Client_B successfully initialized.")
client_C = ClientC(XA.shape, XB.shape, config)
print("Client_C successfully initialized.")
Expand All @@ -312,14 +347,26 @@ def vertical_logistic_regression(X, y, X_test, y_test, config):
client_C.task_2("A", "B")
client_A.task_3()
client_B.task_3()

client_A.predict("B")
client_B.predict()

lr = LogisticRegression(max_iter=config['n_iter'])
lr.fit(X, y)
y_pred = lr.predict(X)
print(sum(1 for a, b in zip(y, y_pred) if a == b) / len(y_pred))

y_pred = lr.predict(X_test)
print(sum(1 for a, b in zip(y_test, y_pred) if a == b) / len(y_pred))

print("All process done.")
return True


config = {
'n_iter': 100,
'n_iter': 10,
'lambda': 10,
'lr': 0.05,
'lr': 0.45,
'A_idx': [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
'B_idx': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
}
Expand Down

0 comments on commit bed9fbc

Please sign in to comment.