Skip to content

Commit

Permalink
Adjust role name
Browse files Browse the repository at this point in the history
  • Loading branch information
xs233 committed Sep 29, 2022
1 parent 92e8342 commit 806522b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 30 deletions.
11 changes: 1 addition & 10 deletions iflearner/business/hetero/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from iflearner.business.hetero.parser import Parser
from iflearner.business.hetero.builder.builders import Builders
from iflearner.business.hetero.model.role import Role, role_class
from iflearner.communication.hetero.hetero_network import HeteroNetwork

parser = Parser()
Expand Down Expand Up @@ -60,9 +59,6 @@ def _exec_steps(self, steps: Any) -> None:
Args:
steps (Any): Details of the steps.
Raise:
Exception(f"The return type of {step.name} is illegal.")
"""
for step in steps:
logger.info(f"{step}")
Expand All @@ -83,12 +79,7 @@ def _exec_steps(self, steps: Any) -> None:

for name, data in result.items():
data = pickle.dumps(data)
if isinstance(name, Role):
self._network.push(str(name), None, step.name, data)
elif isinstance(name, str):
self._network.push(None, name, step.name, data)
else:
raise Exception(f"The return type of {step.name} is illegal.")
self._network.push(name, step.name, data)

def _exec_model_flow(self) -> None:
"""Execute model flow.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def received_he_public_key(self, data: Dict[str, Any]) -> None:
logger.info(f"Public key: {public_key}")
self._public_key = public_key

def get_he_public_key(self) -> None:
pass

def calc_guest_partial_result(self) -> Dict[Union[Role, str], Any]:
"""Calculate your own partial results.
Expand Down
33 changes: 16 additions & 17 deletions iflearner/communication/hetero/hetero_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import time
from threading import Thread
from loguru import logger
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union
from iflearner.business.hetero.model.role import Role
from iflearner.communication.base.base_server import start_server
from iflearner.communication.hetero.hetero_client import HeteroClient
from iflearner.communication.hetero.hetero_server import HeteroServer
Expand Down Expand Up @@ -77,39 +78,37 @@ def pull(self, role: str, step_name: str) -> Dict[str, bytes]:

return self._server.messages.pop(key)

def push(self, role: str, party_name: str, step_name: str, data: bytes) -> None:
def push(self, name: Union[Role, str], step_name: str, data: bytes) -> None:
"""Push a message to a specific destination, which you can specify using a role or party name.
If you use a role name, we will send the data to all role members.
If you use a party name, we will only send the data to the specific target.
Args:
role (str): The role name.
party_name (str): The party name.
name (Union[Role, str]): The role or party name.
step_name (str): Current step name.
data (bytes): The data needs to be sent.
Raises:
Exception(f"{role} is not existed."): Role is not existed.
Exception(f"{party_name} is not existed."): Party is not existed.
Exception(f"Role {name} is not existed."): Role is not existed.
Exception(f"Party {name} is not existed."): Party is not existed.
"""
logger.info(
f"Post message, role: {role}, party: {party_name}, step: {step_name}, data length: {len(data)}")
if role is not None:
if role not in self._parties_index_role_name:
raise Exception(f"{role} is not existed.")
f"Post message, name: {name}, step: {step_name}, data length: {len(data)}")
if isinstance(name, Role):
name = str(name)
if name not in self._parties_index_role_name:
raise Exception(f"Role {name} is not existed.")

for client in self._parties_index_role_name[role]:
for client in self._parties_index_role_name[name]:
while True:
try:
client.post(step_name, data)
break
except Exception as e:
logger.warning(e)
time.sleep(3)
elif party_name is not None:
if party_name not in self._parties_index_party_name:
raise Exception(f"{party_name} is not existed.")

self._parties_index_party_name[party_name].post(step_name, data)
else:
raise Exception("You need to specify one of role and party name.")
if name not in self._parties_index_party_name:
raise Exception(f"Party {name} is not existed.")

self._parties_index_party_name[name].post(step_name, data)

0 comments on commit 806522b

Please sign in to comment.