-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
19 changed files
with
1,116 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright 2022 iFLYTEK. All Rights Reserved. | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from iflearner.business.hetero.model.role import Role | ||
from iflearner.business.hetero.model.base_model import BaseModel | ||
from iflearner.business.hetero.model.logistic_regression import lr_guest, lr_host, lr_arbiter | ||
|
||
from iflearner.business.hetero.builder.model_builder import ModelBuilder | ||
|
||
|
||
class LRBuilder(ModelBuilder): | ||
|
||
def create_role_model_instance(self, role: str) -> BaseModel: | ||
"""Create a model instance base on specific role. | ||
Args: | ||
role (str): The role name. | ||
Returns: | ||
BaseModel: Return the base class. | ||
""" | ||
if role == Role.guest: | ||
return lr_guest.LRGuest() | ||
elif role == Role.host: | ||
return lr_host.LRHost() | ||
elif role == Role.arbiter: | ||
return lr_arbiter.LRArbiter() | ||
|
||
raise Exception(f"{role} is not existed.") | ||
|
||
def get_role_model_flow_file(self, role: str) -> str: | ||
"""Get model flow file by role name. | ||
Args: | ||
role (str): The role name. | ||
Returns: | ||
str: Return the filename. | ||
""" | ||
if role == Role.guest: | ||
return "lr_guest_flow.yaml" | ||
elif role == Role.host: | ||
return "lr_host_flow.yaml" | ||
elif role == Role.arbiter: | ||
return "lr_arbiter_flow.yaml" | ||
|
||
raise Exception(f"{role} is not existed.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright 2022 iFLYTEK. All Rights Reserved. | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from abc import ABC, abstractmethod | ||
from iflearner.business.hetero.model.base_model import BaseModel | ||
|
||
|
||
class ModelBuilder(ABC): | ||
"""Build a model instance base on the role you specify. | ||
""" | ||
|
||
@abstractmethod | ||
def create_role_model_instance(self, role: str) -> BaseModel: | ||
"""Create a model instance base on specific role. | ||
Args: | ||
role (str): The role name. | ||
Returns: | ||
BaseModel: Return the base class. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_role_model_flow_file(self, role: str) -> str: | ||
"""Get model flow file by role name. | ||
Args: | ||
role (str): The role name. | ||
Returns: | ||
str: Return the filename. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright 2022 iFLYTEK. All Rights Reserved. | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import os | ||
import time | ||
from os.path import join | ||
from loguru import logger | ||
|
||
from iflearner.business.hetero.parser import Parser | ||
from iflearner.business.hetero.builder.lr_builder import LRBuilder | ||
from iflearner.communication.hetero.hetero_network import HeteroNetwork | ||
|
||
parser = Parser() | ||
|
||
class Driver: | ||
"""Drive the entire process according to the flow yaml. | ||
Flow yaml format: | ||
role: string | ||
steps: | ||
- name: string | ||
upstreams: | ||
- role: string | ||
step: string | ||
""" | ||
|
||
def __init__(self) -> None: | ||
"""Init the class. | ||
""" | ||
logger.add(f"log/{parser.model_name}_{parser.role_name}.log", backtrace=True, diagnose=True) | ||
self._model = LRBuilder().create_role_model_instance(parser.role_name) | ||
parser.parse_model_flow_file(join("model", parser.model_name, LRBuilder().get_role_model_flow_file(parser.role_name))) | ||
|
||
logger.info(f"Model flow: {parser.model_flow}") | ||
logger.info(f"Network config: {parser.network_config}") | ||
self._network = HeteroNetwork(*parser.network_config) | ||
|
||
def _exec_flow(self) -> None: | ||
"""Execute flow. | ||
Raise: | ||
Exception(f"The return of handle {step_name} is illegal.") | ||
""" | ||
for step in parser.model_flow["steps"]: | ||
step_name = step["name"] | ||
upstreams = step["upstreams"] | ||
logger.info(f"Step: {step_name}, Upstreams: {upstreams}") | ||
if upstreams is not None: | ||
for upstream in upstreams: | ||
data_list = None | ||
while data_list is None: | ||
data_list = self._network.pull( | ||
upstream["role"], upstream["step"]) | ||
time.sleep(1) | ||
|
||
self._model.handle_upstream( | ||
upstream["role"], upstream["step"], data_list) | ||
|
||
result = self._model.handle_step(step_name) | ||
if result is None: | ||
continue | ||
|
||
if isinstance(result, tuple): | ||
self._network.push(result[0], None, step_name, result[1]) | ||
elif isinstance(result, dict): | ||
for party_name, data in result.items(): | ||
self._network.push(None, party_name, step_name, data) | ||
else: | ||
raise Exception(f"The return of handle {step_name} is illegal.") | ||
|
||
def run(self, epoch: int=1) -> None: | ||
"""Loop execution process. | ||
Args: | ||
epoch (int, optional): The number of epochs we need to run. Defaults to 1. | ||
""" | ||
for i in range(epoch): | ||
logger.info(f"Start epoch {i+1}") | ||
self._exec_flow() | ||
|
||
if __name__ == "__main__": | ||
parser.parse_task_configuration_file() | ||
driver = Driver() | ||
driver.run() | ||
os._exit(0) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright 2022 iFLYTEK. All Rights Reserved. | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from typing import Dict, Tuple, List, Union | ||
|
||
|
||
def handle_another_step(data: List[Tuple[str, bytes]]) -> None: | ||
"""Handle a step from another role. | ||
Args: | ||
data (List[Tuple[str, bytes]]): List data for all role members. | ||
""" | ||
pass | ||
|
||
|
||
def handle_own_step() -> Union[Tuple[str, bytes], Dict[str, bytes]]: | ||
"""Handle a own step. | ||
Returns: | ||
Union[Tuple[str, bytes], Dict[str, bytes]]: | ||
Tuple[str, bytes]: If you want to send the same data to all role members, you need to return this (k: role name, v: data). | ||
Dict[str, bytes]: If you want to send unique data to a specific role member, you need to return this (k: party name, v: data). | ||
""" | ||
pass | ||
|
||
|
||
class BaseModel: | ||
'''Define each step of model training and evaluation. | ||
You need to split the whole process into steps, then you need to implement it and register it with self._register_own_step. | ||
In most cases, a step depends on steps completed by other roles. So you need to implement the response of those upstream steps and register it with self._register_another_step. | ||
''' | ||
|
||
def __init__(self) -> None: | ||
self._another_steps: Dict[str, handle_another_step] = {} | ||
self._own_steps: Dict[str, handle_own_step] = {} | ||
|
||
def _register_another_step(self, role: str, step_name: str, func: handle_another_step) -> None: | ||
"""Register a another step handler. | ||
Args: | ||
role (str): The target role name. | ||
step_name (str): Unique name for the step. | ||
func (handle_another_step): The handler you implement. | ||
""" | ||
self._another_steps[f"{role}.{step_name}"] = func | ||
|
||
def _register_own_step(self, step_name: str, func: handle_own_step) -> None: | ||
"""Register a own step handler. | ||
Args: | ||
step_name (str): Unique name for the step. | ||
func (handle_own_step): The handler you implement. | ||
""" | ||
self._own_steps[step_name] = func | ||
|
||
def handle_upstream(self, role: str, step_name: str, data: List[Tuple[str, bytes]]) -> None: | ||
"""Handle specific upstream step from other role. | ||
Args: | ||
role (str): The target role name. | ||
step_name (str): Unique name for the step. | ||
data (List[Tuple[str, bytes]]): List data for all role members. | ||
""" | ||
key = f"{role}.{step_name}" | ||
assert key in self._another_steps, f"{key} is not implemented." | ||
|
||
self._another_steps[key](data) | ||
|
||
def handle_step(self, step_name: str) -> Union[Tuple[str, bytes], Dict[str, bytes]]: | ||
"""Handle own specific step. | ||
Args: | ||
step_name (str): Unique name for the step. | ||
Returns: | ||
Union[Tuple[str, bytes], Dict[str, bytes]]: | ||
Tuple[str, bytes]: If you want to send the same data to all role members, you need to return this (k: role name, v: data). | ||
Dict[str, bytes]: If you want to send unique data to a specific role member, you need to return this (k: party name, v: data). | ||
""" | ||
assert step_name in self._own_steps, f"{step_name} is not implemented." | ||
|
||
return self._own_steps[step_name]() |
Oops, something went wrong.