-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import pandas as pd
+import pyarrow as pa
+import ray
 from ray.data import Dataset


+class RemoteDatasetProxy:
+    """Proxy class that executes Ray Data operations remotely on cluster workers."""
+
+    def __init__(self, dataset_ref: Any):
+        """Initialize with a reference to the remote dataset."""
+        self._dataset_ref = dataset_ref
+
+    def map_batches(self, func, **kwargs) -> "RemoteDatasetProxy":
+        """Execute map_batches remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_map_batches(dataset, function, batch_kwargs):
+            return dataset.map_batches(function, **batch_kwargs)
+
+        new_ref = _remote_map_batches.remote(self._dataset_ref, func, kwargs)
+        return RemoteDatasetProxy(new_ref)
+
+    def filter(self, fn) -> "RemoteDatasetProxy":
+        """Execute filter remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_filter(dataset, filter_fn):
+            return dataset.filter(filter_fn)
+
+        new_ref = _remote_filter.remote(self._dataset_ref, fn)
+        return RemoteDatasetProxy(new_ref)
+
+    def to_pandas(self) -> pd.DataFrame:
+        """Execute to_pandas remotely and transfer result to client."""
+
+        @ray.remote
+        def _remote_to_pandas(dataset):
+            return dataset.to_pandas()
+
+        result_ref = _remote_to_pandas.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def to_arrow(self) -> pa.Table:
+        """Execute to_arrow remotely and transfer result to client."""
+
+        @ray.remote
+        def _remote_to_arrow(dataset):
+            return dataset.to_arrow()
+
+        result_ref = _remote_to_arrow.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def schema(self) -> Any:
+        """Get dataset schema."""
+
+        @ray.remote
+        def _remote_schema(dataset):
+            return dataset.schema()
+
+        schema_ref = _remote_schema.remote(self._dataset_ref)
+        return ray.get(schema_ref)
+
+    def sort(self, key, descending=False) -> "RemoteDatasetProxy":
+        """Execute sort remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_sort(dataset, sort_key, desc):
+            return dataset.sort(sort_key, descending=desc)
+
+        new_ref = _remote_sort.remote(self._dataset_ref, key, descending)
+        return RemoteDatasetProxy(new_ref)
+
+    def limit(self, count) -> "RemoteDatasetProxy":
+        """Execute limit remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_limit(dataset, limit_count):
+            return dataset.limit(limit_count)
+
+        new_ref = _remote_limit.remote(self._dataset_ref, count)
+        return RemoteDatasetProxy(new_ref)
+
+    def union(self, other) -> "RemoteDatasetProxy":
+        """Execute union remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_union(dataset1, dataset2):
+            return dataset1.union(dataset2)
+
+        new_ref = _remote_union.remote(self._dataset_ref, other._dataset_ref)
+        return RemoteDatasetProxy(new_ref)
+
+    def materialize(self) -> "RemoteDatasetProxy":
+        """Execute materialize remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_materialize(dataset):
+            return dataset.materialize()
+
+        new_ref = _remote_materialize.remote(self._dataset_ref)
+        return RemoteDatasetProxy(new_ref)
+
+    def count(self) -> int:
+        """Execute count remotely and return result."""
+
+        @ray.remote
+        def _remote_count(dataset):
+            return dataset.count()
+
+        result_ref = _remote_count.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def take(self, n=20) -> list:
+        """Execute take remotely and return result."""
+
+        @ray.remote
+        def _remote_take(dataset, num):
+            return dataset.take(num)
+
+        result_ref = _remote_take.remote(self._dataset_ref, n)
+        return ray.get(result_ref)
+
+    def __getattr__(self, name):
+        """Catch any method calls that we haven't explicitly implemented."""
+        raise AttributeError(f"RemoteDatasetProxy has no attribute '{name}'")
+
+
+def is_ray_data(data: Any) -> bool:
+    """Check if data is a Ray Dataset or RemoteDatasetProxy."""
+    return isinstance(data, (Dataset, RemoteDatasetProxy))
+
+
 def normalize_timestamp_columns(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     columns: Union[str, List[str]],
     inplace: bool = False,
     exclude_columns: Optional[List[str]] = None,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     column_list = [columns] if isinstance(columns, str) else columns
     exclude_columns = exclude_columns or []

@@ -21,7 +150,7 @@ def apply_normalization(series: pd.Series) -> pd.Series:
             .astype("datetime64[ns, UTC]")
         )

-    if isinstance(data, Dataset):
+    if is_ray_data(data):

         def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:
             for column in column_list:
@@ -35,6 +164,7 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:

         return data.map_batches(normalize_batch, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         for column in column_list:
@@ -44,13 +174,13 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:


 def ensure_timestamp_compatibility(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     timestamp_fields: List[str],
     inplace: bool = False,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     from feast.utils import make_df_tzaware

-    if isinstance(data, Dataset):
+    if is_ray_data(data):

         def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:
             batch = make_df_tzaware(batch)
@@ -65,6 +195,7 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:

         return data.map_batches(ensure_compatibility, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         from feast.utils import make_df_tzaware
@@ -77,22 +208,24 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:


 def apply_field_mapping(
-    data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str]
-) -> Union[pd.DataFrame, Dataset]:
+    data: Union[pd.DataFrame, Dataset, Any],
+    field_mapping: Dict[str, str],
+) -> Union[pd.DataFrame, Dataset, Any]:
     def rename_columns(df: pd.DataFrame) -> pd.DataFrame:
         return df.rename(columns=field_mapping)

-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(rename_columns, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         return data.rename(columns=field_mapping)


 def deduplicate_by_keys_and_timestamp(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     join_keys: List[str],
     timestamp_columns: List[str],
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
         if batch.empty:
             return batch
@@ -110,9 +243,10 @@ def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
             return deduped_batch
         return batch

-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(deduplicate_batch, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         return deduplicate_batch(data)


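As a quick orientation, here is a minimal usage sketch (not part of this diff) of how RemoteDatasetProxy might be driven from a client process. It assumes a Ray runtime is already initialized and that RemoteDatasetProxy is importable from the module changed above; the build_dataset task and the sample data are invented for illustration.

import pandas as pd
import ray


@ray.remote
def build_dataset():
    # Build the dataset on a worker so the client only ever holds an ObjectRef.
    df = pd.DataFrame({"id": [1, 2, 3], "value": [10.0, 20.0, 30.0]})
    return ray.data.from_pandas(df)


dataset_ref = build_dataset.remote()     # ObjectRef pointing at a Ray Dataset
proxy = RemoteDatasetProxy(dataset_ref)  # assumes the class above is imported

# Each chained call submits a remote task and hands back a new proxy.
filtered = proxy.filter(lambda row: row["value"] > 15.0).limit(1)

print(filtered.count())      # executes remotely, returns a plain int
print(filtered.to_pandas())  # materializes and transfers the result to the client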
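The module-level helpers keep a single code path for pandas and Ray Data inputs via is_ray_data. A small sketch of that dual dispatch, again outside the diff, with made-up column names and values, assuming the helpers are importable and Ray is initialized:

import pandas as pd
import ray

df = pd.DataFrame(
    {
        "driver_id": [1, 1, 2],
        "event_timestamp": pd.to_datetime(
            ["2024-01-01 00:00:00", "2024-01-01 00:05:00", "2024-01-01 00:00:00"]
        ),
        "conv_rate": [0.1, 0.2, 0.3],
    }
)

# pandas path: plain DataFrame in, renamed DataFrame out.
renamed_df = apply_field_mapping(df, {"conv_rate": "rate"})

# Ray path: the same call returns a lazy Dataset with the rename applied per batch.
ds = ray.data.from_pandas(df)
renamed_ds = apply_field_mapping(ds, {"conv_rate": "rate"})

# Deduplicate on the join key and timestamp columns with the shared helper, then collect.
deduped = deduplicate_by_keys_and_timestamp(
    renamed_ds, join_keys=["driver_id"], timestamp_columns=["event_timestamp"]
)
print(deduped.to_pandas())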