Skip to content

Commit

Permalink
parameter passing bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
paer committed Jul 10, 2023
1 parent 51e4e70 commit 8ff5658
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 45 deletions.
11 changes: 8 additions & 3 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
send,
start_recv_proxy,
start_send_proxy,
SendProxy
SendProxy,
RecvProxy
)
from fed.config import CrossSiloCommConfig

Expand All @@ -50,6 +51,7 @@ def init(
logging_level: str = 'info',
enable_waiting_for_other_parties_ready: bool = False,
send_proxy_cls: SendProxy = None,
recv_proxy_cls: RecvProxy = None,
global_cross_silo_comm_config: Optional[CrossSiloCommConfig] = None,
**kwargs,
):
Expand Down Expand Up @@ -182,20 +184,23 @@ def init(
logger.info(f'Started rayfed with {cluster_config}')
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=global_cross_silo_comm_config.exit_on_sending_failure)

if recv_proxy_cls is None:
from fed.proxy.grpc_proxy import GrpcRecvProxy
recv_proxy_cls = GrpcRecvProxy
# Start recv proxy
start_recv_proxy(
cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=None,
proxy_cls=recv_proxy_cls,
proxy_config=global_cross_silo_comm_config
)

if send_proxy_cls is None:
from fed.proxy.grpc_proxy import GrpcSendProxy
send_proxy_cls = GrpcSendProxy

start_send_proxy(
cluster=cluster,
party=party,
Expand Down
31 changes: 15 additions & 16 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,6 @@ class CrossSiloCommConfig:
"""A class to store parameters used for Proxy Actor
Attributes:
grpc_retry_policy: a dict that describes the retry policy for
cross-silo RPC calls. If None, the following default retry policy
will be used. For more details, please refer to
`retry-policy <https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy>`_. # noqa
.. code:: python
{
"maxAttempts": 4,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": [
"UNAVAILABLE"
]
}
proxier_fo_max_retries: The max restart times for the send proxy.
serializing_allowed_list: The list of packages or classes allowed for
serializing (deserializing) across silos. It's used for avoiding pickle
Expand Down Expand Up @@ -156,7 +141,21 @@ class CrossSiloGrpcCommConfig(CrossSiloCommConfig):
"""A class to store parameters used for GRPC communication
Attributes:
grpc_retry_policy:
grpc_retry_policy: a dict that describes the retry policy for
cross-silo RPC calls. If None, the following default retry policy
will be used. For more details, please refer to
`retry-policy <https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy>`_. # noqa
.. code:: python
{
"maxAttempts": 4,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": [
"UNAVAILABLE"
]
}
grpc_channel_options: A list of tuples to store GRPC channel options,
e.g. [
('grpc.enable_retries', 1),
Expand Down
30 changes: 15 additions & 15 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@
from typing import Dict, Optional

import cloudpickle
import grpc
import ray

import fed.config as fed_config
import fed.utils as fed_utils
from fed._private import constants

from fed.config import get_job_config, CrossSiloCommConfig
Expand Down Expand Up @@ -62,12 +60,14 @@ def pop_from_two_dim_dict(the_dict, key_a, key_b):

class SendProxy(abc.ABC):
def __init__(
self,
cluster: Dict,
party: str,
proxy_config = None) -> None:
self,
cluster: Dict,
party: str,
tls_config: Dict,
proxy_config = None) -> None:
self._cluster = cluster
self._party = party
self._tls_config = tls_config
self._proxy_config = proxy_config


Expand All @@ -81,7 +81,7 @@ async def send(
):
pass

async def is_ready():
async def is_ready(self):
return True

class RecvProxy(abc.ABC):
Expand Down Expand Up @@ -135,10 +135,11 @@ def __init__(
self._party = party
self._tls_config = tls_config
cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config
self.proxy_instance: SendProxy = proxy_cls(cluster, party, cross_silo_comm_config)
self.proxy_instance: SendProxy = proxy_cls(cluster, party, tls_config, cross_silo_comm_config)

async def is_ready(self):
return self.proxy_instance.is_ready()
res = await self.proxy_instance.is_ready()
return res

async def send(
self,
Expand Down Expand Up @@ -197,20 +198,19 @@ def __init__(
self._party = party
self._tls_config = tls_config
cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config
set_max_message_length(cross_silo_comm_config.messages_max_size_in_bytes)
self._proxy_instance: RecvProxy = proxy_cls(listen_addr, party, cross_silo_comm_config)
self._proxy_instance: RecvProxy = proxy_cls(listen_addr, party, tls_config, cross_silo_comm_config)

async def start(self):
await self._proxy_instance.start()

async def is_ready(self):
return self._proxy_instance.is_ready()

async def is_ready(self):
res = await self._proxy_instance.is_ready()
return res

async def get_data(self, src_party, upstream_seq_id, curr_seq_id):
self._stats["receive_op_count"] += 1
data = await self._proxy_instance.get_data(src_party, upstream_seq_id, curr_seq_id)
return cloudpickle.loads(data)
return data

async def _get_stats(self):
return self._stats
Expand Down
26 changes: 17 additions & 9 deletions fed/proxy/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@


class GrpcSendProxy(SendProxy):
def __init__(self, cluster: Dict, party: str, proxy_config=None) -> None:
super().__init__(cluster, party, proxy_config)
def __init__(self, cluster: Dict, party: str, tls_config: Dict, proxy_config=None) -> None:
super().__init__(cluster, party, tls_config, proxy_config)
self._grpc_metadata = proxy_config.http_header
set_max_message_length(proxy_config.messages_max_size_in_bytes)
self._retry_policy = None
if isinstance(proxy_config, CrossSiloGrpcCommConfig):
self._retry_policy = proxy_config.grpc_retry_policy
# Mapping the destination party name to the reused client stub.
self._stubs = {}

Expand All @@ -44,7 +47,7 @@ async def send(
dest_party_grpc_config = self.setup_grpc_config(dest_party)
tls_enabled = fed_utils.tls_enabled(self._tls_config)
grpc_options = dest_party_grpc_config['grpc_options']
grpc_options = get_grpc_options(retry_policy=self.retry_policy) if \
grpc_options = get_grpc_options(retry_policy=self._retry_policy) if \
grpc_options is None else fed_utils.dict2tuple(grpc_options)
if dest_party not in self._stubs:
if tls_enabled:
Expand All @@ -62,11 +65,13 @@ async def send(
stub = fed_pb2_grpc.GrpcServiceStub(channel)
self._stubs[dest_party] = stub

timeout = self._proxy_config.timeout_in_seconds
response = await send_data_grpc(
data=data,
stub=self._stubs[dest_party],
upstream_seq_id=upstream_seq_id,
downstream_seq_id=downstream_seq_id,
timeout=timeout,
metadata=dest_party_grpc_config['grpc_metadata'],
)
return response
Expand All @@ -83,7 +88,7 @@ def setup_grpc_config(self, dest_party):
dest_party_grpc_config['grpc_metadata'] = {
**global_grpc_metadata, **dest_party_grpc_metadata}

global_grpc_options = dict(get_grpc_options(self.retry_policy))
global_grpc_options = dict(get_grpc_options(self._retry_policy))
dest_party_grpc_options = dict(
self._cluster[dest_party].get('grpc_options', {})
)
Expand All @@ -101,9 +106,10 @@ async def send_data_grpc(
stub,
upstream_seq_id,
downstream_seq_id,
timeout,
metadata=None,
):
cluster_config = fed_config.get_cluster_config()
job_config = fed_config.get_job_config()
data = cloudpickle.dumps(data)
request = fed_pb2.SendDataRequest(
data=data,
Expand All @@ -114,7 +120,7 @@ async def send_data_grpc(
response = await stub.SendData(
request,
metadata=fed_utils.dict2tuple(metadata),
timeout=cluster_config.cross_silo_timeout,
timeout=timeout,
)
logger.debug(
f'Received data response from seq_id {downstream_seq_id}, '
Expand All @@ -124,8 +130,9 @@ async def send_data_grpc(


class GrpcRecvProxy(RecvProxy):
def __init__(self, listen_addr: str, party: str, proxy_config: CrossSiloCommConfig) -> None:
super().__init__(listen_addr, party, proxy_config)
def __init__(self, listen_addr: str, party: str, tls_config: Dict, proxy_config: CrossSiloCommConfig) -> None:
super().__init__(listen_addr, party, tls_config, proxy_config)
set_max_message_length(proxy_config.messages_max_size_in_bytes)
# Flag to see whether grpc server starts
self._server_ready_future = asyncio.Future()
self._retry_policy = None
Expand Down Expand Up @@ -159,7 +166,8 @@ async def start(self):

async def is_ready(self):
await self._server_ready_future
return self._server_ready_future.result()
res = self._server_ready_future.result()
return res

async def get_data(self, src_party, upstream_seq_id, curr_seq_id):
data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_party}"
Expand Down
10 changes: 10 additions & 0 deletions fed/tmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import ray
import fed
# Initialise Ray with its default arguments before setting up fed.
ray.init()

# This process joins the two-party cluster below as 'alice'.
party = 'alice'
cluster = dict(
    alice={'address': '127.0.0.1:11012'},
    bob={'address': '127.0.0.1:11011'},
)
fed.init(cluster, party)
4 changes: 2 additions & 2 deletions tests/test_retry_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import fed._private.compatible_utils as compatible_utils
import ray

from fed.config import CrossSiloCommConfig
from fed.config import CrossSiloGrpcCommConfig


@fed.remote
Expand Down Expand Up @@ -53,7 +53,7 @@ def run(party, is_inner_party):
fed.init(
cluster=cluster,
party=party,
cross_silo_comm_config=CrossSiloCommConfig(
cross_silo_comm_config=CrossSiloGrpcCommConfig(
grpc_retry_policy=retry_policy
)
)
Expand Down

0 comments on commit 8ff5658

Please sign in to comment.