Skip to content

Commit

Permalink
Reshape api
Browse files Browse the repository at this point in the history
  • Loading branch information
xs233 committed Sep 23, 2022
1 parent 39384ac commit 5b04d96
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 141 deletions.
22 changes: 11 additions & 11 deletions iflearner/business/hetero/builder/lr_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from iflearner.business.hetero.model.role import Role
from iflearner.business.hetero.model.role import Role, Guest, Host, Arbiter
from iflearner.business.hetero.model.base_model import BaseModel
from iflearner.business.hetero.model.logistic_regression import lr_guest, lr_host, lr_arbiter

Expand All @@ -22,38 +22,38 @@

class LRBuilder(ModelBuilder):

def create_role_model_instance(self, role: str) -> BaseModel:
def create_role_model_instance(self, role: Role) -> BaseModel:
"""Create a model instance base on specific role.
Args:
role (str): The role name.
role (Role): The role name.
Returns:
BaseModel: Return the base class.
"""
if role == Role.guest:
if isinstance(role, Guest):
return lr_guest.LRGuest()
elif role == Role.host:
elif isinstance(role, Host):
return lr_host.LRHost()
elif role == Role.arbiter:
elif isinstance(role, Arbiter):
return lr_arbiter.LRArbiter()

raise Exception(f"{role} is not existed.")

def get_role_model_flow_file(self, role: str) -> str:
def get_role_model_flow_file(self, role: Role) -> str:
"""Get model flow file by role name.
Args:
role (str): The role name.
role (Role): The role name.
Returns:
str: Return the filename.
"""
if role == Role.guest:
if isinstance(role, Guest):
return "lr_guest_flow.yaml"
elif role == Role.host:
elif isinstance(role, Host):
return "lr_host_flow.yaml"
elif role == Role.arbiter:
elif isinstance(role, Arbiter):
return "lr_arbiter_flow.yaml"

raise Exception(f"{role} is not existed.")
9 changes: 5 additions & 4 deletions iflearner/business/hetero/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================

from abc import ABC, abstractmethod
from iflearner.business.hetero.model.role import Role
from iflearner.business.hetero.model.base_model import BaseModel


Expand All @@ -22,23 +23,23 @@ class ModelBuilder(ABC):
"""

@abstractmethod
def create_role_model_instance(self, role: str) -> BaseModel:
def create_role_model_instance(self, role: Role) -> BaseModel:
"""Create a model instance base on specific role.
Args:
role (str): The role name.
role (Role): The role name.
Returns:
BaseModel: Return the base class.
"""
pass

@abstractmethod
def get_role_model_flow_file(self, role: str) -> str:
def get_role_model_flow_file(self, role: Role) -> str:
"""Get model flow file by role name.
Args:
role (str): The role name.
role (Role): The role name.
Returns:
str: Return the filename.
Expand Down
25 changes: 13 additions & 12 deletions iflearner/business/hetero/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
# ==============================================================================

import os
import pickle
import time
from os.path import join
from loguru import logger

from iflearner.business.hetero.parser import Parser
from iflearner.business.hetero.builder.builders import Builders
from iflearner.business.hetero.model.role import Role, role_class
from iflearner.communication.hetero.hetero_network import HeteroNetwork

parser = Parser()
Expand Down Expand Up @@ -54,7 +56,7 @@ def _exec_flow(self) -> None:
"""Execute flow.
Raise:
Exception(f"The return of handle {step_name} is illegal.")
Exception(f"The return type of {step_name} is illegal.")
"""
for step in parser.model_flow["steps"]:
step_name = step["name"]
Expand All @@ -64,24 +66,23 @@ def _exec_flow(self) -> None:
for upstream in upstreams:
data_list = None
while data_list is None:
data_list = self._network.pull(
upstream["role"], upstream["step"])
data_list = self._network.pull(upstream["role"], upstream["step"])
time.sleep(1)

self._model.handle_upstream(
upstream["role"], upstream["step"], data_list)
self._model.handle_upstream(upstream["role"], upstream["step"], data_list)

result = self._model.handle_step(step_name)
if result is None:
continue

if isinstance(result, tuple):
self._network.push(result[0], None, step_name, result[1])
elif isinstance(result, dict):
for party_name, data in result.items():
self._network.push(None, party_name, step_name, data)
else:
raise Exception(f"The return of handle {step_name} is illegal.")
for name, data in result.items():
data = pickle.dumps(data)
if isinstance(name, Role):
self._network.push(str(name), None, step_name, data)
elif isinstance(name, str):
self._network.push(None, name, step_name, data)
else:
raise Exception(f"The return type of {step_name} is illegal.")

def run(self, epoch: int=1) -> None:
"""Loop execution process.
Expand Down
33 changes: 17 additions & 16 deletions iflearner/business/hetero/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,26 @@
# limitations under the License.
# ==============================================================================

from typing import Dict, Tuple, List, Union
from typing import Dict, Tuple, List, Union, Any
from iflearner.business.hetero.model.role import Role


def handle_another_step(data: List[Tuple[str, bytes]]) -> None:
def handle_another_step(data: Dict[str, Any]) -> None:
"""Handle a step from another role.
Args:
data (List[Tuple[str, bytes]]): List data for all role members.
data (Dict[str, Any]): List data for all role members. (str: party name, Any: data)
"""
pass


def handle_own_step() -> Union[Tuple[str, bytes], Dict[str, bytes]]:
def handle_own_step() -> Dict[Union[Role, str], Any]:
"""Handle a own step.
Returns:
Union[Tuple[str, bytes], Dict[str, bytes]]:
Tuple[str, bytes]: If you want to send the same data to all role members, you need to return this (k: role name, v: data).
Dict[str, bytes]: If you want to send unique data to a specific role member, you need to return this (k: party name, v: data).
Dict[Union[Role, str], Any]: Return each target and its data.
Union[Role, str]: The role class or party name.
Any: Return a python object, which we will serialize to bytes using pickle.dumps.
"""
pass

Expand All @@ -47,11 +48,11 @@ def __init__(self) -> None:
self._another_steps: Dict[str, handle_another_step] = {}
self._own_steps: Dict[str, handle_own_step] = {}

def _register_another_step(self, role: str, step_name: str, func: handle_another_step) -> None:
def _register_another_step(self, role: Role, step_name: str, func: handle_another_step) -> None:
"""Register a another step handler.
Args:
role (str): The target role name.
role (Role): The target role name.
step_name (str): Unique name for the step.
func (handle_another_step): The handler you implement.
"""
Expand All @@ -66,29 +67,29 @@ def _register_own_step(self, step_name: str, func: handle_own_step) -> None:
"""
self._own_steps[step_name] = func

def handle_upstream(self, role: str, step_name: str, data: List[Tuple[str, bytes]]) -> None:
def handle_upstream(self, role: Role, step_name: str, data: Dict[str, Any]) -> None:
"""Handle specific upstream step from other role.
Args:
role (str): The target role name.
role (Role): The target role name.
step_name (str): Unique name for the step.
data (List[Tuple[str, bytes]]): List data for all role members.
data (Dict[str, Any]): List data for all role members.
"""
key = f"{role}.{step_name}"
assert key in self._another_steps, f"{key} is not implemented."

self._another_steps[key](data)

def handle_step(self, step_name: str) -> Union[Tuple[str, bytes], Dict[str, bytes]]:
def handle_step(self, step_name: str) -> Dict[Union[Role, str], Any]:
"""Handle own specific step.
Args:
step_name (str): Unique name for the step.
Returns:
Union[Tuple[str, bytes], Dict[str, bytes]]:
Tuple[str, bytes]: If you want to send the same data to all role members, you need to return this (k: role name, v: data).
Dict[str, bytes]: If you want to send unique data to a specific role member, you need to return this (k: party name, v: data).
Dict[Union[Role, str], Any]: Return each target and its data.
Union[Role, str]: The role class or party name.
Any: Return a python object, which we will serialize to bytes using pickle.dumps.
"""
assert step_name in self._own_steps, f"{step_name} is not implemented."

Expand Down
31 changes: 13 additions & 18 deletions iflearner/business/hetero/model/logistic_regression/lr_arbiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,24 @@
# limitations under the License.
# ==============================================================================

from typing import List, Tuple
from iflearner.business.hetero.model.role import Role
from loguru import logger
from phe import paillier
from typing import Any, List, Dict, Union

from iflearner.business.hetero.model.role import Role, guest, host
from iflearner.business.hetero.model.base_model import BaseModel


class LRArbiter(BaseModel):
def __init__(self) -> None:
super().__init__()

self._register_own_step("step1", self.handle_own_step1)
self._register_own_step("step2", self.handle_own_step2)

self._register_another_step(
Role.host, "step2", self.handle_host_step2)

def handle_own_step1(self) -> Tuple[str, bytes]:
print("Arbiter step1")
return Role.guest, "Arbiter step1 completed.".encode("utf-8")

def handle_own_step2(self):
print("Arbiter step2")
return Role.guest, "Arbiter step2 completed.".encode("utf-8")
self._register_own_step("generate_he_keypair",
self.generate_he_keypair)

def handle_host_step2(self, data: List[Tuple[str, bytes]]):
for item in data:
print(item[0], item[1].decode("utf-8"))
def generate_he_keypair(self) -> Dict[Union[Role, str], Any]:
public_key, private_key = paillier.generate_paillier_keypair()
logger.info(f"Public key: {public_key}")
self._private_key = private_key
return {guest: public_key.n, host: public_key.n}

Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
role: arbiter
steps:
- name: step1
- name: generate_he_keypair
upstreams: null
- name: step2
upstreams:
- role: host
step: step2
45 changes: 16 additions & 29 deletions iflearner/business/hetero/model/logistic_regression/lr_guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,27 @@
# limitations under the License.
# ==============================================================================

from typing import List, Tuple
from iflearner.business.hetero.model.role import Role
from loguru import logger
from typing import Any, List, Dict, Union
from phe import paillier
from iflearner.business.hetero.model.role import Role, host, arbiter
from iflearner.business.hetero.model.base_model import BaseModel


class LRGuest(BaseModel):
def __init__(self) -> None:
super().__init__()

self._register_own_step("step1", self.handle_own_step1)
self._register_own_step("step2", self.handle_own_step2)

self._register_another_step(
Role.host, "step1", self.handle_host_step1)
self._register_another_step(
Role.arbiter, "step1", self.handle_arbiter_step1)
self._register_another_step(
Role.arbiter, "step2", self.handle_arbiter_step2)

def handle_own_step1(self) -> Tuple[str, bytes]:
print("Guest step1")
return Role.host, "Guest step1 completed.".encode("utf-8")

def handle_own_step2(self):
print("Guest step2")

def handle_host_step1(self, data: List[Tuple[str, bytes]]):
for item in data:
print(item[0], item[1].decode("utf-8"))

def handle_arbiter_step1(self, data: List[Tuple[str, bytes]]):
for item in data:
print(item[0], item[1].decode("utf-8"))

def handle_arbiter_step2(self, data: List[Tuple[str, bytes]]):
for item in data:
print(item[0], item[1].decode("utf-8"))
arbiter, "generate_he_keypair", self.received_he_public_key)
self._register_own_step("empty", self.empty)

def received_he_public_key(self, data: Dict[str, Any]) -> None:
for value in data.values():
public_key = paillier.PaillierPublicKey(value)
logger.info(f"Public key: {public_key}")
self._public_key = public_key
break

def empty(self) -> Dict[Union[Role, str], Any]:
pass
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
role: guest
steps:
- name: step1
- name: empty
upstreams:
- role: host
step: step1
- role: arbiter
step: step1
- name: step2
upstreams:
- role: arbiter
step: step2

step: generate_he_keypair
Loading

0 comments on commit 5b04d96

Please sign in to comment.