Skip to content

Commit

Permalink
Use document to register method automaticly.
Browse files Browse the repository at this point in the history
  • Loading branch information
xs233 committed Oct 9, 2022
1 parent fbc2534 commit f7142cf
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 27 deletions.
46 changes: 46 additions & 0 deletions iflearner/business/hetero/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================

from abc import ABC, abstractmethod
from inspect import getmembers, ismethod
from typing import Dict, Tuple, List, Union, Any
from iflearner.business.hetero.model.role import Role

Expand Down Expand Up @@ -48,6 +49,7 @@ class BaseModel(ABC):
def __init__(self) -> None:
self._another_steps: Dict[str, handle_another_step] = {}
self._own_steps: Dict[str, handle_own_step] = {}
self._bind_methods()

@abstractmethod
def set_hyper_params(self, hyper_params: Any) -> None:
Expand All @@ -58,6 +60,50 @@ def set_hyper_params(self, hyper_params: Any) -> None:
"""
pass

def _bind_methods(self):
"""Analyze method documents and then register them to specific steps.
Format:
Bind:
step: The step name.
role (optional): The role name.(guest host arbiter)
If role is None, it means the current method is your own step handler.
If role is not None, it means the current method is to handle the step of other role.
"""
bind_tag = "Bind:"
step_tag = "step:"
role_tag = "role:"

functions_list = [o for o in getmembers(self) if ismethod(o[1])]
for func in functions_list:
if func[1].__doc__ is None:
continue

lines = func[1].__doc__.split('\n')
catch = False
step = None
role = None
for line in lines:
line = line.strip()
if catch and len(line) > 0:
if line.startswith(step_tag):
step = line[len(step_tag):].strip()
elif line.startswith(role_tag):
role = line[len(role_tag):].strip()
else:
if step is not None:
if role is not None:
self._register_another_step(
role, step, getattr(self, func[0]))
else:
self._register_own_step(
step, getattr(self, func[0]))
break

if line.lower() == bind_tag.lower():
catch = True

def _register_another_step(self, role: Role, step_name: str, func: handle_another_step) -> None:
"""Register a another step handler.
Expand Down
27 changes: 22 additions & 5 deletions iflearner/business/hetero/model/logistic_regression/lr_arbiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ class LRArbiter(BaseModel):
def __init__(self) -> None:
super().__init__()

self._register_another_step(guest, "calc_final_result_with_host", self.received_guest_encrypted_data)
self._register_another_step(host, "calc_final_result_with_guest", self.received_host_encrypted_data)
# self._register_another_step(guest, "calc_final_result_with_host", self.received_guest_encrypted_data)
# self._register_another_step(host, "calc_final_result_with_guest", self.received_host_encrypted_data)

self._register_own_step("generate_he_keypair", self.generate_he_keypair)
self._register_own_step("decrypt_guest_data", self.decrypt_guest_data)
self._register_own_step("decrypt_host_data", self.decrypt_host_data)
# self._register_own_step("generate_he_keypair", self.generate_he_keypair)
# self._register_own_step("decrypt_guest_data", self.decrypt_guest_data)
# self._register_own_step("decrypt_host_data", self.decrypt_host_data)

def set_hyper_params(self, hyper_params: Any) -> None:
"""Set hyper params.
Expand All @@ -45,6 +45,9 @@ def set_hyper_params(self, hyper_params: Any) -> None:
def generate_he_keypair(self) -> Dict[Union[Role, str], Any]:
"""Generate HE public key and private key.
Bind:
step: generate_he_keypair
Returns:
Dict[Union[Role, str], Any]: Return the HE public key to the guest and host.
"""
Expand All @@ -56,6 +59,10 @@ def generate_he_keypair(self) -> Dict[Union[Role, str], Any]:
def received_guest_encrypted_data(self, data: Dict[str, Any]) -> None:
"""Save encrypted data from the guest.
Bind:
step: calc_final_result_with_host
role: guest
Args:
data (Dict[str, Any]): Guest party name and encrypted data.
"""
Expand All @@ -66,6 +73,10 @@ def received_guest_encrypted_data(self, data: Dict[str, Any]) -> None:
def received_host_encrypted_data(self, data: Dict[str, Any]) -> None:
"""Save encrypted data from the host.
Bind:
step: calc_final_result_with_guest
role: host
Args:
data (Dict[str, Any]): Host party name and encrypted data.
"""
Expand All @@ -74,6 +85,9 @@ def received_host_encrypted_data(self, data: Dict[str, Any]) -> None:
def decrypt_guest_data(self) -> Dict[Union[Role, str], Any]:
"""Decrypt guest data.
Bind:
step: decrypt_guest_data
Returns:
Dict[Union[Role, str], Any]: Return guest role name and its decrypted data.
"""
Expand All @@ -83,6 +97,9 @@ def decrypt_guest_data(self) -> Dict[Union[Role, str], Any]:
def decrypt_host_data(self) -> Dict[Union[Role, str], Any]:
"""Decrypt host data.
Bind:
step: decrypt_host_data
Returns:
Dict[Union[Role, str], Any]: Return host role name and its decrypted data.
"""
Expand Down
40 changes: 29 additions & 11 deletions iflearner/business/hetero/model/logistic_regression/lr_guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ def __init__(self) -> None:
self._x, self._y = get_guest_data()
self._weights = np.zeros(self._x.shape[1])

self._register_another_step(
arbiter, "generate_he_keypair", self.received_he_public_key)
self._register_another_step(
host, "calc_host_partial_result", self.received_host_partial_result)
self._register_another_step(
arbiter, "decrypt_guest_data", self.received_weights)

self._register_own_step("calc_guest_partial_result",
self.calc_guest_partial_result)
self._register_own_step(
"calc_final_result_with_host", self.calc_final_result_with_host)
# self._register_another_step(
# arbiter, "generate_he_keypair", self.received_he_public_key)
# self._register_another_step(
# host, "calc_host_partial_result", self.received_host_partial_result)
# self._register_another_step(
# arbiter, "decrypt_guest_data", self.received_weights)

# self._register_own_step("calc_guest_partial_result",
# self.calc_guest_partial_result)
# self._register_own_step(
# "calc_final_result_with_host", self.calc_final_result_with_host)

def set_hyper_params(self, hyper_params: Any) -> None:
"""Set hyper params.
Expand All @@ -58,6 +58,10 @@ def set_hyper_params(self, hyper_params: Any) -> None:
def received_he_public_key(self, data: Dict[str, Any]) -> None:
"""Save the HE public key received from the arbiter.
Bind:
step: generate_he_keypair
role: arbiter
Args:
data (Dict[str, Any]): Arbiter party name and public key.
"""
Expand All @@ -68,6 +72,9 @@ def received_he_public_key(self, data: Dict[str, Any]) -> None:
def calc_guest_partial_result(self) -> Dict[Union[Role, str], Any]:
"""Calculate your own partial results.
Bind:
step: calc_guest_partial_result
Returns:
Dict[Union[Role, str], Any]: Return HE-encrypted data to the host.
"""
Expand All @@ -80,6 +87,10 @@ def calc_guest_partial_result(self) -> Dict[Union[Role, str], Any]:
def received_host_partial_result(self, data: Dict[str, Any]) -> None:
"""Save the host partial result.
Bind:
step: calc_host_partial_result
role: host
Args:
data (Dict[str, Any]): Host party name and its data.
"""
Expand All @@ -89,6 +100,9 @@ def received_host_partial_result(self, data: Dict[str, Any]) -> None:
def calc_final_result_with_host(self) -> Dict[Union[Role, str], Any]:
"""Calculate the final result combined with the host.
Bind:
step: calc_final_result_with_host
Returns:
Dict[Union[Role, str], Any]: Return the encrypted result to the arbiter.
"""
Expand All @@ -107,6 +121,10 @@ def calc_final_result_with_host(self) -> Dict[Union[Role, str], Any]:
def received_weights(self, data: Dict[str, Any]) -> None:
"""Received weights from the arbiter.
Bind:
step: decrypt_guest_data
role: arbiter
Args:
data (Dict[str, Any]): The decrypted data from the arbiter.
"""
Expand Down
40 changes: 29 additions & 11 deletions iflearner/business/hetero/model/logistic_regression/lr_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ def __init__(self) -> None:
self._x = get_host_data()
self._weights = np.zeros(self._x.shape[1])

self._register_another_step(
arbiter, "generate_he_keypair", self.received_he_public_key)
self._register_another_step(
guest, "calc_guest_partial_result", self.received_guest_partial_result)
self._register_another_step(
arbiter, "decrypt_host_data", self.received_weights)

self._register_own_step("calc_host_partial_result",
self.calc_host_partial_result)
self._register_own_step(
"calc_final_result_with_guest", self.calc_final_result_with_guest)
# self._register_another_step(
# arbiter, "generate_he_keypair", self.received_he_public_key)
# self._register_another_step(
# guest, "calc_guest_partial_result", self.received_guest_partial_result)
# self._register_another_step(
# arbiter, "decrypt_host_data", self.received_weights)

# self._register_own_step("calc_host_partial_result",
# self.calc_host_partial_result)
# self._register_own_step(
# "calc_final_result_with_guest", self.calc_final_result_with_guest)

def set_hyper_params(self, hyper_params: Any) -> None:
"""Set hyper params.
Expand All @@ -57,6 +57,10 @@ def set_hyper_params(self, hyper_params: Any) -> None:
def received_he_public_key(self, data: Dict[str, Any]) -> None:
"""Save the HE public key received from the arbiter.
Bind:
step: generate_he_keypair
role: arbiter
Args:
data (Dict[str, Any]): Arbiter party name and public key.
"""
Expand All @@ -67,6 +71,9 @@ def received_he_public_key(self, data: Dict[str, Any]) -> None:
def calc_host_partial_result(self) -> Dict[Union[Role, str], Any]:
"""Calculate your own partial results.
Bind:
step: calc_host_partial_result
Returns:
Dict[Union[Role, str], Any]: Return HE-encrypted data to the guest.
"""
Expand All @@ -82,6 +89,10 @@ def calc_host_partial_result(self) -> Dict[Union[Role, str], Any]:
def received_guest_partial_result(self, data: Dict[str, Any]) -> None:
"""Save the guest partial result.
Bind:
step: calc_guest_partial_result
role: guest
Args:
data (Dict[str, Any]): Guest party name and its data.
"""
Expand All @@ -90,6 +101,9 @@ def received_guest_partial_result(self, data: Dict[str, Any]) -> None:
def calc_final_result_with_guest(self) -> Dict[Union[Role, str], Any]:
"""Calculate the final result combined with the guest.
Bind:
step: calc_final_result_with_guest
Returns:
Dict[Union[Role, str], Any]: Return the encrypted result to the arbiter.
"""
Expand All @@ -103,6 +117,10 @@ def calc_final_result_with_guest(self) -> Dict[Union[Role, str], Any]:
def received_weights(self, data: Dict[str, Any]) -> None:
"""Received weights from the arbiter.
Bind:
step: decrypt_host_data
role: arbiter
Args:
data (Dict[str, Any]): The decrypted data from the arbiter.
"""
Expand Down

0 comments on commit f7142cf

Please sign in to comment.