Skip to content

Commit

Permalink
fix mock proxy UT
Browse files Browse the repository at this point in the history
  • Loading branch information
paer committed Jul 13, 2023
1 parent 7ffe367 commit 2f9ce57
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 42 deletions.
21 changes: 4 additions & 17 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, raw_bytes: bytes) -> None:

@property
def cross_silo_comm_config(self):
return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, {})
return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, CrossSiloCommConfig())


# A module level cache for the cluster configurations.
Expand Down Expand Up @@ -85,7 +85,7 @@ class CrossSiloCommConfig:
"""A class to store parameters used for Proxy Actor
Attributes:
proxier_fo_max_retries: The max restart times for the send proxy.
proxy_max_restarts: The max restart times for the send proxy.
serializing_allowed_list: The package or class list allowed for
serializing(deserializating) cross silos. It's used for avoiding pickle
deserializing execution attack when crossing solis.
Expand All @@ -111,15 +111,15 @@ class CrossSiloCommConfig:
"""
def __init__(
self,
proxier_fo_max_retries: int = None,
proxy_max_restarts: int = None,
timeout_in_seconds: int = 60,
messages_max_size_in_bytes: int = None,
exit_on_sending_failure: Optional[bool] = False,
serializing_allowed_list: Optional[Dict[str, str]] = None,
send_resource_label: Optional[Dict[str, str]] = None,
recv_resource_label: Optional[Dict[str, str]] = None,
http_header: Optional[Dict[str, str]] = None) -> None:
self.proxier_fo_max_retries = proxier_fo_max_retries
self.proxy_max_restarts = proxy_max_restarts
self.timeout_in_seconds = timeout_in_seconds
self.messages_max_size_in_bytes = messages_max_size_in_bytes
self.exit_on_sending_failure = exit_on_sending_failure
Expand Down Expand Up @@ -170,16 +170,3 @@ def __init__(self,
super().__init__(*args, **kwargs)
self.grpc_retry_policy = grpc_retry_policy
self.grpc_channel_options = grpc_channel_options


class CrossSiloBRPCConfig(CrossSiloCommConfig):
"""A class to store parameters used for GRPC communication
Attributes:
"""
def __init__(self,
brpc_options,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.brpc_options = brpc_options
6 changes: 3 additions & 3 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
listen_addr: str,
party: str,
tls_config: Dict,
proxy_config: CrossSiloCommConfig
proxy_config: CrossSiloCommConfig=None
) -> None:
self._listen_addr = listen_addr
self._party = party
Expand Down Expand Up @@ -289,9 +289,9 @@ def start_send_proxy(
global _SEND_PROXY_ACTOR

actor_options = copy.deepcopy(_DEFAULT_SEND_PROXY_OPTIONS)
if proxy_config and proxy_config.proxier_fo_max_retries:
if proxy_config and proxy_config.proxy_max_restarts:
actor_options.update({
"max_task_retries": proxy_config.proxier_fo_max_retries,
"max_task_retries": proxy_config.proxy_max_restarts,
"max_restarts": 1,
})
if proxy_config and proxy_config.send_resource_label:
Expand Down
66 changes: 44 additions & 22 deletions tests/test_transport_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,17 @@


import fed._private.compatible_utils as compatible_utils
from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig
from fed._private import constants
from fed._private import global_context
from fed.grpc import fed_pb2, fed_pb2_grpc
from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy
from fed.proxy.barriers import (
send,
start_recv_proxy,
start_send_proxy,
RecvProxy
)
from fed.proxy.grpc_proxy import GrpcSendProxy, GrpcRecvProxy


def test_n_to_1_transport():
Expand All @@ -38,9 +45,6 @@ def test_n_to_1_transport():
constants.KEY_OF_CLUSTER_ADDRESSES: "",
constants.KEY_OF_CURRENT_PARTY_NAME: "",
constants.KEY_OF_TLS_CONFIG: "",
constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None,
constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {},
constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60,
}
compatible_utils._init_internal_kv()
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
Expand All @@ -50,12 +54,21 @@ def test_n_to_1_transport():
SERVER_ADDRESS = "127.0.0.1:12344"
party = 'test_party'
cluster_config = {'test_party': {'address': SERVER_ADDRESS}}
config = CrossSiloGrpcCommConfig()
start_recv_proxy(
cluster_config,
party,
logging_level='info',
proxy_cls=GrpcRecvProxy,
proxy_config=config
)
start_send_proxy(
cluster_config,
party,
logging_level='info',
proxy_cls=GrpcSendProxy,
proxy_config=config
)
start_send_proxy(cluster_config, party, logging_level='info')

sent_objs = []
get_objs = []
Expand Down Expand Up @@ -147,7 +160,7 @@ def _test_start_recv_proxy(
).remote(
listen_addr=listen_addr,
party=party,
expected_metadata=expected_metadata,
expected_metadata=expected_metadata
)
recver_proxy_actor.run_grpc_server.remote()
assert ray.get(recver_proxy_actor.is_ready.remote())
Expand All @@ -159,14 +172,14 @@ def test_send_grpc_with_meta():
constants.KEY_OF_CLUSTER_ADDRESSES: "",
constants.KEY_OF_CURRENT_PARTY_NAME: "",
constants.KEY_OF_TLS_CONFIG: "",
constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None,
constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {},
constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60,
}
metadata = {"key": "value"}
send_proxy_config = CrossSiloCommConfig(
http_header=metadata
)
job_config = {
constants.KEY_OF_GRPC_METADATA: {
"key": "value"
}
constants.KEY_OF_CROSS_SILO_COMM_CONFIG:
send_proxy_config,
}
compatible_utils._init_internal_kv()
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
Expand All @@ -180,9 +193,14 @@ def test_send_grpc_with_meta():
cluster_config = {'test_party': {'address': SERVER_ADDRESS}}
_test_start_recv_proxy(
cluster_config, party, logging_level='info',
expected_metadata={"key": "value"},
expected_metadata=metadata,
)
start_send_proxy(cluster_config, party, logging_level='info')
start_send_proxy(
cluster_config,
party,
logging_level='info',
proxy_cls=GrpcSendProxy,
proxy_config=CrossSiloGrpcCommConfig())
sent_objs = []
sent_obj = send(party, "data", 0, 1)
sent_objs.append(sent_obj)
Expand All @@ -200,14 +218,12 @@ def test_send_grpc_with_party_specific_meta():
constants.KEY_OF_CLUSTER_ADDRESSES: "",
constants.KEY_OF_CURRENT_PARTY_NAME: "",
constants.KEY_OF_TLS_CONFIG: "",
constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None,
constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {},
constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60,
}
send_proxy_config = CrossSiloCommConfig(
http_header={"key": "value"})
job_config = {
constants.KEY_OF_GRPC_METADATA: {
"key": "value"
}
constants.KEY_OF_CROSS_SILO_COMM_CONFIG:
send_proxy_config,
}
compatible_utils._init_internal_kv()
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
Expand All @@ -221,14 +237,20 @@ def test_send_grpc_with_party_specific_meta():
cluster_parties_config = {
'test_party': {
'address': SERVER_ADDRESS,
'grpc_metadata': (('token', 'test-party-token'),)
'cross_silo_comm_config': CrossSiloCommConfig(
http_header={"token": "test-party-token"})
}
}
_test_start_recv_proxy(
cluster_parties_config, party, logging_level='info',
expected_metadata={"key": "value", "token": "test-party-token"},
)
start_send_proxy(cluster_parties_config, party, logging_level='info')
start_send_proxy(
cluster_parties_config,
party,
logging_level='info',
proxy_cls=GrpcSendProxy,
proxy_config=send_proxy_config)
sent_objs = []
sent_obj = send(party, "data", 0, 1)
sent_objs.append(sent_obj)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_transport_proxy_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from fed._private import constants
from fed._private import global_context
from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy
from fed.proxy.grpc_proxy import GrpcSendProxy, GrpcRecvProxy
from fed.config import CrossSiloGrpcCommConfig


def test_n_to_1_transport():
Expand Down Expand Up @@ -58,17 +60,22 @@ def test_n_to_1_transport():
SERVER_ADDRESS = "127.0.0.1:65422"
party = 'test_party'
cluster_config = {'test_party': {'address': SERVER_ADDRESS}}
config = CrossSiloGrpcCommConfig()
start_recv_proxy(
cluster_config,
party,
logging_level='info',
tls_config=tls_config,
proxy_cls=GrpcRecvProxy,
proxy_config=config
)
start_send_proxy(
cluster_config,
party,
logging_level='info',
tls_config=tls_config,
proxy_cls=GrpcSendProxy,
proxy_config=config
)

sent_objs = []
Expand Down

0 comments on commit 2f9ce57

Please sign in to comment.