Skip to content

Commit

Permalink
tmp save
Browse files Browse the repository at this point in the history
  • Loading branch information
NKcqx committed Jul 10, 2023
1 parent a046d27 commit aa397a8
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 26 deletions.
30 changes: 19 additions & 11 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def init(
tls_config: Dict = None,
logging_level: str = 'info',
enable_waiting_for_other_parties_ready: bool = False,
cross_silo_comm_config: Optional[CrossSiloCommConfig] = None,
global_cross_silo_comm_config: Optional[CrossSiloCommConfig] = None,
dest_party_comm_config: Optional[Dict[CrossSiloCommConfig]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -108,8 +109,16 @@ def init(
`warning`, `error`, `critical`, not case sensititive.
enable_waiting_for_other_parties_ready: ping other parties until they
are all ready if True.
cross_silo_comm_config: Cross-silo communication related config, supported
configs can refer to CrossSiloCommConfig in config.py
global_cross_silo_comm_config: Global cross-silo communication related
config that are applied to all connections. Supported configs
can refer to CrossSiloCommConfig in config.py.
dest_party_comm_config: Communication config for the destination party
specifed by the key. E.g.
.. code:: python
{
'alice': alice_CrossSiloCommConfig,
'bob': bob_CrossSiloCommConfig
}
Examples:
>>> import fed
Expand All @@ -135,7 +144,7 @@ def init(
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

cross_silo_comm_config = cross_silo_comm_config or CrossSiloCommConfig()
global_cross_silo_comm_config = global_cross_silo_comm_config or CrossSiloCommConfig()
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()

Expand All @@ -147,7 +156,7 @@ def init(

job_config = {
constants.KEY_OF_CROSS_SILO_COMM_CONFIG:
cross_silo_comm_config,
global_cross_silo_comm_config,
}
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
Expand All @@ -164,25 +173,24 @@ def init(
)

logger.info(f'Started rayfed with {cluster_config}')
set_exit_on_failure_sending(cross_silo_comm_config.exit_on_sending_failure)
set_exit_on_failure_sending(global_cross_silo_comm_config.exit_on_sending_failure)
# Start recv proxy
start_recv_proxy(
cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
retry_policy=cross_silo_comm_config.grpc_retry_policy,
actor_config=cross_silo_comm_config
proxy_cls=None,
proxy_config=global_cross_silo_comm_config
)

start_send_proxy(
cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
retry_policy=cross_silo_comm_config.grpc_retry_policy,
max_retries=cross_silo_comm_config.proxier_fo_max_retries,
actor_config=cross_silo_comm_config
proxy_cls=None,
proxy_config=global_cross_silo_comm_config # retry_policy=cross_silo_comm_config.grpc_retry_policy,
)

if enable_waiting_for_other_parties_ready:
Expand Down
33 changes: 22 additions & 11 deletions fed/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(
tls_config: Dict = None,
logging_level: str = None,
retry_policy: Dict = None,
proxy_cls = None
):
setup_logger(
logging_level=logging_level,
Expand Down Expand Up @@ -221,6 +222,7 @@ async def send(
upstream_seq_id,
downstream_seq_id,
):
# proxy_cls.send()

self._stats["send_op_count"] += 1
assert (
Expand Down Expand Up @@ -291,7 +293,9 @@ def __init__(
party: str,
logging_level: str,
tls_config=None,
retry_policy: Dict = None,
proxy_cls = None,
# retry_policy: Dict = None,

):
setup_logger(
logging_level=logging_level,
Expand All @@ -317,6 +321,8 @@ def __init__(
self._lock = threading.Lock()

async def run_grpc_server(self):
# proxy_cls.run_grpc_server()

try:
port = self._listen_addr[self._listen_addr.index(':') + 1 :]
await _run_grpc_server(
Expand All @@ -340,6 +346,10 @@ async def is_ready(self):
return self._server_ready_future.result()

async def get_data(self, src_aprty, upstream_seq_id, curr_seq_id):
# subscriber

# proxy_cls.get_data() # get from broker channel

self._stats["receive_op_count"] += 1
data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_aprty}"
logger.debug(f"Getting {data_log_msg}")
Expand Down Expand Up @@ -380,7 +390,7 @@ def start_recv_proxy(
logging_level: str,
tls_config=None,
retry_policy=None,
actor_config: Optional[fed_config.CrossSiloCommConfig] = None
proxy_config: Optional[fed_config.CrossSiloCommConfig] = None
):

# Create RecevrProxyActor
Expand All @@ -392,8 +402,8 @@ def start_recv_proxy(
listen_addr = party_addr['address']

actor_options = copy.deepcopy(_DEFAULT_RECV_PROXY_OPTIONS)
if actor_config is not None and actor_config.recv_resource_label is not None:
actor_options.update({"resources": actor_config.recv_resource_label})
if proxy_config is not None and proxy_config.recv_resource_label is not None:
actor_options.update({"resources": proxy_config.recv_resource_label})

logger.debug(f"Starting RecvProxyActor with options: {actor_options}")

Expand Down Expand Up @@ -425,20 +435,20 @@ def start_send_proxy(
logging_level: str,
tls_config: Dict = None,
retry_policy=None,
max_retries=None,
actor_config: Optional[fed_config.CrossSiloCommConfig] = None
proxy_cls=None,
proxy_config: Optional[fed_config.CrossSiloCommConfig] = None
):
# Create SendProxyActor
global _SEND_PROXY_ACTOR

actor_options = copy.deepcopy(_DEFAULT_SEND_PROXY_OPTIONS)
if max_retries is not None:
if proxy_config and proxy_config.proxier_fo_max_retries:
actor_options.update({
"max_task_retries": max_retries,
"max_task_retries": proxy_config.proxier_fo_max_retries,
"max_restarts": 1,
})
if actor_config is not None and actor_config.send_resource_label is not None:
actor_options.update({"resources": actor_config.send_resource_label})
if proxy_config and proxy_config.send_resource_label:
actor_options.update({"resources": proxy_config.send_resource_label})

logger.debug(f"Starting SendProxyActor with options: {actor_options}")
_SEND_PROXY_ACTOR = SendProxyActor.options(
Expand All @@ -449,7 +459,8 @@ def start_send_proxy(
party=party,
tls_config=tls_config,
logging_level=logging_level,
retry_policy=retry_policy,
# retry_policy=retry_policy,
# starter=server_starter
)
timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds
assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout)
Expand Down
47 changes: 43 additions & 4 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ class CrossSiloCommConfig:
If None, the default value of 500 MB is specified.
timeout_in_seconds: The timeout in seconds of a cross-silo RPC call.
It's 60 by default.
http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request. This won't override
basic tcp headers, such as `user-agent`, but concat them together.
http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request.
This won't override basic tcp headers, such as `user-agent`, but concat
them together.
"""
def __init__(
self,
grpc_retry_policy: Dict = None,
proxier_fo_max_retries: int = None,
timeout_in_seconds: int = 60,
messages_max_size_in_bytes: int = None,
Expand All @@ -134,7 +134,6 @@ def __init__(
send_resource_label: Optional[Dict[str, str]] = None,
recv_resource_label: Optional[Dict[str, str]] = None,
http_header: Optional[Dict[str, str]] = None) -> None:
self.grpc_retry_policy = grpc_retry_policy
self.proxier_fo_max_retries = proxier_fo_max_retries
self.timeout_in_seconds = timeout_in_seconds
self.messages_max_size_in_bytes = messages_max_size_in_bytes
Expand All @@ -151,3 +150,43 @@ def __json__(self):
def from_json(cls, json_str):
data = json.loads(json_str)
return cls(**data)


class CrossSiloGRPCConfig(CrossSiloCommConfig):
"""A class to store parameters used for GRPC communication
Attributes:
grpc_retry_policy:
grpc_channel_options: A list of tuples to store GRPC channel options,
e.g. [
('grpc.enable_retries', 1),
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
"""
def __init__(self,
grpc_channel_options,
grpc_retry_policy,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.grpc_retry_policy = grpc_retry_policy
self.grpc_channel_options = grpc_channel_options


class CrossSiloBRPCConfig(CrossSiloCommConfig):
"""A class to store parameters used for GRPC communication
Attributes:
grpc_retry_policy:
grpc_channel_options: A list of tuples to store GRPC channel options,
e.g. [
('grpc.enable_retries', 1),
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
"""
def __init__(self,
brpc_options,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.brpc_options = brpc_options

0 comments on commit aa397a8

Please sign in to comment.