Skip to content

Commit

Permalink
per party config & timeout ms & listen IPv6
Browse files Browse the repository at this point in the history
Signed-off-by: paer <[email protected]>
  • Loading branch information
paer committed Jul 13, 2023
1 parent 2f9ce57 commit dec6153
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 91 deletions.
6 changes: 0 additions & 6 deletions fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@

KEY_OF_CROSS_SILO_COMM_CONFIG = "CROSS_SILO_COMM_CONFIG"

KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST = "CROSS_SILO_SERIALIZING_ALLOWED_LIST" # noqa

KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES = "CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES" # noqa

KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS = "CROSS_SILO_TIMEOUT_IN_SECONDS"

RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa

RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S"
Expand Down
13 changes: 0 additions & 13 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ def init(
# used if not provided.
'listen_addr': '0.0.0.0:10001',
'cross_silo_comm_config': CrossSiloCommConfig
# (Optional) The party specific metadata sent with grpc requests
'grpc_metadata': (('token', 'alice-token'),),
'grpc_options': [
('grpc.default_authority', 'alice'),
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
},
'bob': {
# The address for other parties.
Expand Down Expand Up @@ -122,13 +116,6 @@ def init(
global_cross_silo_comm_config: Global cross-silo communication related
config that are applied to all connections. Supported configs
can refer to CrossSiloCommConfig in config.py.
dest_party_comm_config: Communication config for the destination party
specifed by the key. E.g.
.. code:: python
{
'alice': alice_CrossSiloCommConfig,
'bob': bob_CrossSiloCommConfig
}
Examples:
>>> import fed
Expand Down
167 changes: 126 additions & 41 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import fed._private.compatible_utils as compatible_utils
import fed._private.constants as fed_constants
import cloudpickle
from typing import Dict, List, Optional
import json

from typing import Dict, List, Optional
from dataclasses import dataclass


class ClusterConfig:
"""A local cache of cluster configuration items."""
Expand All @@ -28,18 +30,6 @@ def current_party(self):
def tls_config(self):
return self._data[fed_constants.KEY_OF_TLS_CONFIG]

@property
def serializing_allowed_list(self):
return self._data[fed_constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST]

@property
def cross_silo_timeout(self):
return self._data[fed_constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS]

@property
def cross_silo_messages_max_size(self):
return self._data[fed_constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES]


class JobConfig:
def __init__(self, raw_bytes: bytes) -> None:
Expand Down Expand Up @@ -81,6 +71,99 @@ def get_job_config():
return _job_config


# class CrossSiloCommConfig:
# """A class to store parameters used for Proxy Actor

# Attributes:
# proxy_max_restarts: The max restart times for the send proxy.
# serializing_allowed_list: The package or class list allowed for
# serializing(deserializing) cross silos. It's used for avoiding pickle
# deserializing execution attack when crossing silos.
# send_resource_label: Customized resource label, the SendProxyActor
# will be scheduled based on the declared resource label. For example,
# when setting to `{"my_label": 1}`, then the SendProxyActor will be started
# only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
# recv_resource_label: Customized resource label, the RecverProxyActor
# will be scheduled based on the declared resource label. For example,
# when setting to `{"my_label": 1}`, then the RecverProxyActor will be started
# only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
# exit_on_sending_failure: whether exit when failure on
# cross-silo sending. If True, a SIGTERM will be signaled to self
# if failed to sending cross-silo data.
# messages_max_size_in_bytes: The maximum length in bytes of
# cross-silo messages.
# If None, the default value of 500 MB is specified.
# timeout_in_seconds: The timeout in seconds of a cross-silo RPC call.
# It's 60 by default.
# http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request.
# This won't override basic tcp headers, such as `user-agent`, but concat
# them together.
# """
# def __init__(
# self,
# proxy_max_restarts: int = None,
# timeout_in_seconds: int = 60,
# messages_max_size_in_bytes: int = None,
# exit_on_sending_failure: Optional[bool] = False,
# serializing_allowed_list: Optional[Dict[str, str]] = None,
# send_resource_label: Optional[Dict[str, str]] = None,
# recv_resource_label: Optional[Dict[str, str]] = None,
# http_header: Optional[Dict[str, str]] = None) -> None:
# self.proxy_max_restarts = proxy_max_restarts
# self.timeout_in_seconds = timeout_in_seconds
# self.messages_max_size_in_bytes = messages_max_size_in_bytes
# self.exit_on_sending_failure = exit_on_sending_failure
# self.serializing_allowed_list = serializing_allowed_list
# self.send_resource_label = send_resource_label
# self.recv_resource_label = recv_resource_label
# self.http_header = http_header

# def __json__(self):
# return json.dumps(self.__dict__)

# @classmethod
# def from_json(cls, json_str):
# data = json.loads(json_str)
# return cls(**data)


# class CrossSiloGrpcCommConfig(CrossSiloCommConfig):
# """A class to store parameters used for GRPC communication

# Attributes:
# grpc_retry_policy: a dict describing the retry policy for
# cross silo rpc call. If None, the following default retry policy
# will be used. More details please refer to
# `retry-policy <https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy>`_. # noqa

# .. code:: python
# {
# "maxAttempts": 4,
# "initialBackoff": "0.1s",
# "maxBackoff": "1s",
# "backoffMultiplier": 2,
# "retryableStatusCodes": [
# "UNAVAILABLE"
# ]
# }
# grpc_channel_options: A list of tuples to store GRPC channel options,
# e.g. [
# ('grpc.enable_retries', 1),
# ('grpc.max_send_message_length', 50 * 1024 * 1024)
# ]
# """
# def __init__(self,
# grpc_channel_options: List = None,
# grpc_retry_policy: Dict[str, str] = None,
# *args,
# **kwargs):
# super().__init__(*args, **kwargs)
# self.grpc_retry_policy = grpc_retry_policy
# self.grpc_channel_options = grpc_channel_options



@dataclass
class CrossSiloCommConfig:
"""A class to store parameters used for Proxy Actor
Expand All @@ -103,30 +186,20 @@ class CrossSiloCommConfig:
messages_max_size_in_bytes: The maximum length in bytes of
cross-silo messages.
If None, the default value of 500 MB is specified.
timeout_in_seconds: The timeout in seconds of a cross-silo RPC call.
It's 60 by default.
timeout_in_ms: The timeout in milliseconds of a cross-silo RPC call.
It's 60000 by default.
http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request.
This won't override basic tcp headers, such as `user-agent`, but concat
them together.
"""
def __init__(
self,
proxy_max_restarts: int = None,
timeout_in_seconds: int = 60,
messages_max_size_in_bytes: int = None,
exit_on_sending_failure: Optional[bool] = False,
serializing_allowed_list: Optional[Dict[str, str]] = None,
send_resource_label: Optional[Dict[str, str]] = None,
recv_resource_label: Optional[Dict[str, str]] = None,
http_header: Optional[Dict[str, str]] = None) -> None:
self.proxy_max_restarts = proxy_max_restarts
self.timeout_in_seconds = timeout_in_seconds
self.messages_max_size_in_bytes = messages_max_size_in_bytes
self.exit_on_sending_failure = exit_on_sending_failure
self.serializing_allowed_list = serializing_allowed_list
self.send_resource_label = send_resource_label
self.recv_resource_label = recv_resource_label
self.http_header = http_header
proxy_max_restarts: int = None
timeout_in_ms: int = 60000
messages_max_size_in_bytes: int = None
exit_on_sending_failure: Optional[bool] = False
serializing_allowed_list: Optional[Dict[str, str]] = None
send_resource_label: Optional[Dict[str, str]] = None
recv_resource_label: Optional[Dict[str, str]] = None
http_header: Optional[Dict[str, str]] = None

def __json__(self):
return json.dumps(self.__dict__)
Expand All @@ -136,7 +209,25 @@ def from_json(cls, json_str):
data = json.loads(json_str)
return cls(**data)

@classmethod
def from_dict(cls, data: Dict):
    """Initialize CrossSiloCommConfig from a dictionary.

    Keys in ``data`` that do not name a constructor field of ``cls`` are
    ignored. Fields inherited from parent dataclasses are honoured, so this
    also works when called on a dataclass subclass.

    Args:
        data (Dict): Dictionary with keys as member variable names.

    Returns:
        CrossSiloCommConfig: An instance of ``cls`` built from the
            recognised keys; missing keys keep their field defaults.

    Raises:
        TypeError: If ``cls`` is not a dataclass.
    """
    # Function-scope import keeps this method self-contained.
    from dataclasses import fields

    # NOTE: `cls.__annotations__` only contains the annotations declared
    # directly on `cls`, so on a subclass it would silently drop every field
    # inherited from the base config. `fields()` covers the whole dataclass
    # hierarchy; `init` filters out fields the constructor does not accept.
    attrs = {f.name for f in fields(cls) if f.init}
    # Filter the dictionary to only include keys that are fields of the class
    filtered_data = {key: value for key, value in data.items() if key in attrs}
    return cls(**filtered_data)



@dataclass
class CrossSiloGrpcCommConfig(CrossSiloCommConfig):
"""A class to store parameters used for GRPC communication
Expand All @@ -162,11 +253,5 @@ class CrossSiloGrpcCommConfig(CrossSiloCommConfig):
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
"""
def __init__(self,
grpc_channel_options: List = None,
grpc_retry_policy: Dict[str, str] = None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.grpc_retry_policy = grpc_retry_policy
self.grpc_channel_options = grpc_channel_options
grpc_channel_options: List = None
grpc_retry_policy: Dict[str, str] = None
4 changes: 2 additions & 2 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def start_recv_proxy(
proxy_cls=proxy_cls
)
recver_proxy_actor.start.remote()
timeout = proxy_config.timeout_in_seconds if proxy_config is not None else 60
timeout = proxy_config.timeout_in_ms / 1000 if proxy_config is not None else 60
server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout)
assert server_state[0], server_state[1]
logger.info("RecverProxy has successfully created.")
Expand Down Expand Up @@ -308,7 +308,7 @@ def start_send_proxy(
logging_level=logging_level,
proxy_cls=proxy_cls
)
timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds
timeout = get_job_config().cross_silo_comm_config.timeout_in_ms / 1000
assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout)
logger.info("SendProxyActor has successfully created.")

Expand Down
4 changes: 3 additions & 1 deletion fed/proxy/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def send(
stub = fed_pb2_grpc.GrpcServiceStub(channel)
self._stubs[dest_party] = stub

timeout = self._proxy_config.timeout_in_seconds
timeout = self._proxy_config.timeout_in_ms / 1000
response = await send_data_grpc(
data=data,
stub=self._stubs[dest_party],
Expand Down Expand Up @@ -274,6 +274,7 @@ async def _run_grpc_server(
port, event, all_data, party, lock,
server_ready_future, tls_config=None, grpc_options=None
):
print(f"ReceiveProxy binding port {port}, options: {grpc_options}...")
server = grpc.aio.server(options=grpc_options)
fed_pb2_grpc.add_GrpcServiceServicer_to_server(
SendDataService(event, all_data, party, lock), server
Expand All @@ -290,6 +291,7 @@ async def _run_grpc_server(
server.add_secure_port(f'[::]:{port}', server_credentials)
else:
server.add_insecure_port(f'[::]:{port}')
# server.add_insecure_port(f'[::]:{port}')

msg = f"Succeeded to add port {port}."
await server.start()
Expand Down
56 changes: 32 additions & 24 deletions tests/test_listen_addr.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,29 +72,36 @@ def run(party):

compatible_utils.init_ray(address='local')
occupied_port = 11020
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
# NOTE(NKcqx): First try to bind IPv6 because the grpc server will do so.
# Otherwise this UT will fail because the socket binds $occupied_port
# on an IPv4 address while the grpc server listens on an IPv6 address.
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
# Pre-occupying the port
s.bind(("localhost", occupied_port))

cluster = {
'alice': {
'address': '127.0.0.1:11012',
'listen_addr': f'0.0.0.0:{occupied_port}'},
'bob': {
'address': '127.0.0.1:11011',
'listen_addr': '0.0.0.0:11011'},
}

# Starting grpc server on an used port will cause AssertionError
with pytest.raises(AssertionError):
fed.init(cluster=cluster, party=party)

import time

time.sleep(5)
s.close()
fed.shutdown()
ray.shutdown()
s.bind(("::", occupied_port))
except OSError:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("127.0.0.1", occupied_port))

cluster = {
'alice': {
'address': '127.0.0.1:11012',
'listen_addr': f'0.0.0.0:{occupied_port}'},
'bob': {
'address': '127.0.0.1:11011',
'listen_addr': '0.0.0.0:11011'},
}

# Starting grpc server on an used port will cause AssertionError
with pytest.raises(AssertionError):
fed.init(cluster=cluster, party=party)

import time

time.sleep(5)
s.close()
fed.shutdown()
ray.shutdown()

p_alice = multiprocessing.Process(target=run, args=('alice',))
p_alice.start()
Expand All @@ -103,6 +110,7 @@ def run(party):


if __name__ == "__main__":
import sys
# import sys

sys.exit(pytest.main(["-sv", __file__]))
# sys.exit(pytest.main(["-sv", __file__]))
test_listen_used_addr()
2 changes: 1 addition & 1 deletion tests/test_setup_proxy_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def run_failure(party):
global_cross_silo_comm_config=CrossSiloCommConfig(
send_resource_label=send_proxy_resources,
recv_resource_label=recv_proxy_resources,
timeout_in_seconds=10,
timeout_in_ms=10*1000,
)
)

Expand Down
3 changes: 0 additions & 3 deletions tests/test_transport_proxy_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ def test_n_to_1_transport():
constants.KEY_OF_CLUSTER_ADDRESSES: "",
constants.KEY_OF_CURRENT_PARTY_NAME: "",
constants.KEY_OF_TLS_CONFIG: tls_config,
constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None,
constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {},
constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60,
}

global_context.get_global_context().get_cleanup_manager().start()
Expand Down

0 comments on commit dec6153

Please sign in to comment.