Skip to content

Commit

Permalink
lint codes
Browse files Browse the repository at this point in the history
  • Loading branch information
paer committed Jul 10, 2023
1 parent 8ff5658 commit 582f523
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 61 deletions.
7 changes: 4 additions & 3 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def init(
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

global_cross_silo_comm_config = global_cross_silo_comm_config or CrossSiloCommConfig()
global_cross_silo_comm_config = \
global_cross_silo_comm_config or CrossSiloCommConfig()
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()

Expand Down Expand Up @@ -184,7 +185,7 @@ def init(
logger.info(f'Started rayfed with {cluster_config}')
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=global_cross_silo_comm_config.exit_on_sending_failure)

if recv_proxy_cls is None:
from fed.proxy.grpc_proxy import GrpcRecvProxy
recv_proxy_cls = GrpcRecvProxy
Expand All @@ -207,7 +208,7 @@ def init(
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=send_proxy_cls,
proxy_config=global_cross_silo_comm_config # retry_policy=cross_silo_comm_config.grpc_retry_policy,
proxy_config=global_cross_silo_comm_config
)

if enable_waiting_for_other_parties_ready:
Expand Down
6 changes: 0 additions & 6 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,6 @@ class CrossSiloBRPCConfig(CrossSiloCommConfig):
"""A class to store parameters used for GRPC communication
Attributes:
grpc_retry_policy:
grpc_channel_options: A list of tuples to store GRPC channel options,
e.g. [
('grpc.enable_retries', 1),
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
"""
def __init__(self,
brpc_options,
Expand Down
33 changes: 16 additions & 17 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@
# limitations under the License.

import abc
import asyncio
import logging
import threading
import time
import copy
from typing import Dict, Optional

import cloudpickle
import ray

import fed.config as fed_config
Expand Down Expand Up @@ -64,13 +61,13 @@ def __init__(
cluster: Dict,
party: str,
tls_config: Dict,
proxy_config = None) -> None:
proxy_config=None
) -> None:
self._cluster = cluster
self._party = party
self._tls_config = tls_config
self._proxy_config = proxy_config


@abc.abstractmethod
async def send(
self,
Expand All @@ -84,19 +81,20 @@ async def send(
async def is_ready(self):
return True


class RecvProxy(abc.ABC):
def __init__(
self,
listen_addr: str,
party: str,
tls_config: Dict,
proxy_config: CrossSiloCommConfig) -> None:
proxy_config: CrossSiloCommConfig
) -> None:
self._listen_addr = listen_addr
self._party = party
self._tls_config = tls_config
self._proxy_config = proxy_config


@abc.abstractmethod
def start(self):
pass
Expand All @@ -121,7 +119,7 @@ def __init__(
party: str,
tls_config: Dict = None,
logging_level: str = None,
proxy_cls = None
proxy_cls=None
):
setup_logger(
logging_level=logging_level,
Expand All @@ -135,7 +133,8 @@ def __init__(
self._party = party
self._tls_config = tls_config
cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config
self.proxy_instance: SendProxy = proxy_cls(cluster, party, tls_config, cross_silo_comm_config)
self.proxy_instance: SendProxy = proxy_cls(
cluster, party, tls_config, cross_silo_comm_config)

async def is_ready(self):
res = await self.proxy_instance.is_ready()
Expand All @@ -161,18 +160,17 @@ async def send(
' credentials.'
)
try:
response = await self.proxy_instance.send(dest_party, data, upstream_seq_id, downstream_seq_id)
response = await self.proxy_instance.send(
dest_party, data, upstream_seq_id, downstream_seq_id)
except Exception as e:
logger.error(f'Failed to {send_log_msg}, error: {e}')
return False
logger.debug(f"Succeeded to send {send_log_msg}. Response is {response}")
return True # True indicates it's sent successfully.


async def _get_stats(self):
return self._stats


async def _get_cluster_info(self):
return self._cluster

Expand All @@ -185,7 +183,7 @@ def __init__(
party: str,
logging_level: str,
tls_config=None,
proxy_cls = None,
proxy_cls=None,
):
setup_logger(
logging_level=logging_level,
Expand All @@ -198,7 +196,8 @@ def __init__(
self._party = party
self._tls_config = tls_config
cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config
self._proxy_instance: RecvProxy = proxy_cls(listen_addr, party, tls_config, cross_silo_comm_config)
self._proxy_instance: RecvProxy = proxy_cls(
listen_addr, party, tls_config, cross_silo_comm_config)

async def start(self):
await self._proxy_instance.start()
Expand All @@ -208,15 +207,15 @@ async def is_ready(self):
return res

async def get_data(self, src_party, upstream_seq_id, curr_seq_id):
self._stats["receive_op_count"] += 1
data = await self._proxy_instance.get_data(src_party, upstream_seq_id, curr_seq_id)
self._stats["receive_op_count"] += 1
data = await self._proxy_instance.get_data(
src_party, upstream_seq_id, curr_seq_id)
return data

async def _get_stats(self):
return self._stats



_DEFAULT_RECV_PROXY_OPTIONS = {
"max_concurrency": 1000,
}
Expand Down
54 changes: 29 additions & 25 deletions fed/proxy/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Dict


import fed.config as fed_config
import fed.utils as fed_utils

from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig
Expand All @@ -25,9 +24,14 @@
logger = logging.getLogger(__name__)



class GrpcSendProxy(SendProxy):
def __init__(self, cluster: Dict, party: str, tls_config: Dict, proxy_config=None) -> None:
def __init__(
self,
cluster: Dict,
party: str,
tls_config: Dict,
proxy_config=None
) -> None:
super().__init__(cluster, party, tls_config, proxy_config)
self._grpc_metadata = proxy_config.http_header
set_max_message_length(proxy_config.messages_max_size_in_bytes)
Expand All @@ -50,20 +54,20 @@ async def send(
grpc_options = get_grpc_options(retry_policy=self._retry_policy) if \
grpc_options is None else fed_utils.dict2tuple(grpc_options)
if dest_party not in self._stubs:
if tls_enabled:
ca_cert, private_key, cert_chain = fed_utils.load_cert_config(
self._tls_config)
credentials = grpc.ssl_channel_credentials(
certificate_chain=cert_chain,
private_key=private_key,
root_certificates=ca_cert,
)
channel = grpc.aio.secure_channel(
dest_addr, credentials, options=grpc_options)
else:
channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options)
stub = fed_pb2_grpc.GrpcServiceStub(channel)
self._stubs[dest_party] = stub
if tls_enabled:
ca_cert, private_key, cert_chain = fed_utils.load_cert_config(
self._tls_config)
credentials = grpc.ssl_channel_credentials(
certificate_chain=cert_chain,
private_key=private_key,
root_certificates=ca_cert,
)
channel = grpc.aio.secure_channel(
dest_addr, credentials, options=grpc_options)
else:
channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options)
stub = fed_pb2_grpc.GrpcServiceStub(channel)
self._stubs[dest_party] = stub

timeout = self._proxy_config.timeout_in_seconds
response = await send_data_grpc(
Expand All @@ -75,7 +79,7 @@ async def send(
metadata=dest_party_grpc_config['grpc_metadata'],
)
return response

def setup_grpc_config(self, dest_party):
dest_party_grpc_config = {}
global_grpc_metadata = (
Expand All @@ -100,7 +104,6 @@ async def _get_grpc_options(self):
return get_grpc_options()



async def send_data_grpc(
data,
stub,
Expand All @@ -109,7 +112,6 @@ async def send_data_grpc(
timeout,
metadata=None,
):
job_config = fed_config.get_job_config()
data = cloudpickle.dumps(data)
request = fed_pb2.SendDataRequest(
data=data,
Expand All @@ -130,7 +132,13 @@ async def send_data_grpc(


class GrpcRecvProxy(RecvProxy):
def __init__(self, listen_addr: str, party: str, tls_config: Dict, proxy_config: CrossSiloCommConfig) -> None:
def __init__(
self,
listen_addr: str,
party: str,
tls_config: Dict,
proxy_config: CrossSiloCommConfig
) -> None:
super().__init__(listen_addr, party, tls_config, proxy_config)
set_max_message_length(proxy_config.messages_max_size_in_bytes)
# Flag to see whether grpc server starts
Expand Down Expand Up @@ -163,7 +171,6 @@ async def start(self):
f' when calling `fed.init`. Grpc error msg: {err}'
self._server_ready_future.set_result((False, msg))


async def is_ready(self):
await self._server_ready_future
res = self._server_ready_future.result()
Expand Down Expand Up @@ -195,8 +202,6 @@ async def _get_grpc_options(self):
return get_grpc_options()




class SendDataService(fed_pb2_grpc.GrpcServiceServicer):
def __init__(self, all_events, all_data, party, lock):
self._events = all_events
Expand Down Expand Up @@ -258,4 +263,3 @@ async def _run_grpc_server(
)
server_ready_future.set_result((True, msg))
await server.wait_for_termination()

10 changes: 0 additions & 10 deletions fed/tmp.py

This file was deleted.

0 comments on commit 582f523

Please sign in to comment.