Skip to content

Commit 608d8dd

Browse files
vanitabhagwat, vbhagwatEXPE, and bdodla
authored
feat: Add centralized rate limiter (#305)
* add centralized rate limiter * fixed linting errors * fixed formatting * fixed formatting * fix comments * fix formatting * Added comments * Address review comments * updated based on the review comments * removed accidentally added file * removed accidentally added file * fixed formating * fix type error * fix formatting * Updated the logs * Update sdk/python/feast/rate_limiter.py Co-authored-by: Bhargav Dodla <[email protected]> * Address review comments * fix linitng * Improve error logging in integration test utils * Clean up blank lines in go_integration_test_utils.go Removed unnecessary blank lines in the integration test utility functions. * Restore copyright notice and reformat imports * Restore copyright notice and imports in cli.py * Restore copyright notice and imports in cli.py * Restore copyright notice and imports in cli.py * Add back __init__.py file from master Co-Authored-By: Claude Opus 4.6 <[email protected]> * Added some integration tests * No op commit * No op commit --------- Co-authored-by: vbhagwat <[email protected]> Co-authored-by: Bhargav Dodla <[email protected]>
1 parent f2cd866 commit 608d8dd

6 files changed

Lines changed: 1037 additions & 43 deletions

File tree

sdk/python/feast/infra/online_stores/cassandra_online_store/cassandra_online_store.py

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,6 @@
4848
from feast.protos.feast.core.SortedFeatureView_pb2 import SortOrder
4949
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
5050
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
51-
from feast.rate_limiter import SlidingWindowRateLimiter
5251
from feast.repo_config import FeastConfigBaseModel
5352
from feast.sorted_feature_view import SortedFeatureView
5453
from feast.types import (
@@ -418,9 +417,7 @@ def on_failure(exc, concurrent_queue):
418417
ttl_feature_view = table.ttl or timedelta(seconds=0)
419418
ttl_online_store_config = online_store_config.key_ttl_seconds or 0
420419
write_concurrency = online_store_config.write_concurrency
421-
write_rate_limit = online_store_config.write_rate_limit
422420
concurrent_queue: Queue = Queue(maxsize=write_concurrency)
423-
rate_limiter = SlidingWindowRateLimiter(write_rate_limit, 1)
424421
feast_array_types = [
425422
"bytes_list_val",
426423
"string_list_val",
@@ -540,7 +537,6 @@ def on_failure(exc, concurrent_queue):
540537
and 0 < online_store_config.write_batch_size <= batch_count
541538
):
542539
CassandraOnlineStore._apply_batch(
543-
rate_limiter,
544540
batch,
545541
progress,
546542
session,
@@ -553,7 +549,6 @@ def on_failure(exc, concurrent_queue):
553549

554550
if batch_count > 0:
555551
CassandraOnlineStore._apply_batch(
556-
rate_limiter,
557552
batch,
558553
progress,
559554
session,
@@ -592,7 +587,6 @@ def on_failure(exc, concurrent_queue):
592587
and 0 < online_store_config.write_batch_size <= batch_count
593588
):
594589
CassandraOnlineStore._apply_batch(
595-
rate_limiter,
596590
batch,
597591
progress,
598592
session,
@@ -605,7 +599,6 @@ def on_failure(exc, concurrent_queue):
605599

606600
if batch_count > 0:
607601
CassandraOnlineStore._apply_batch(
608-
rate_limiter,
609602
batch,
610603
progress,
611604
session,
@@ -952,9 +945,11 @@ def _build_sorted_table_cql(
952945
"""
953946
sort_key_names = [sk.name for sk in table.sort_keys]
954947
feature_columns = ", ".join(
955-
f"{feature.name} {self._get_cql_type(feature.dtype)}"
956-
if feature.name in sort_key_names
957-
else f"{feature.name} BLOB"
948+
(
949+
f"{feature.name} {self._get_cql_type(feature.dtype)}"
950+
if feature.name in sort_key_names
951+
else f"{feature.name} BLOB"
952+
)
958953
for feature in table.features
959954
)
960955

@@ -1023,19 +1018,13 @@ def _get_cql_statement(
10231018

10241019
@staticmethod
10251020
def _apply_batch(
1026-
rate_limiter: SlidingWindowRateLimiter,
10271021
batch: BatchStatement,
10281022
progress: Optional[Callable[[int], Any]],
10291023
session: Session,
10301024
concurrent_queue: Queue,
10311025
on_success,
10321026
on_failure,
10331027
):
1034-
# Wait until the rate limiter allows
1035-
if not rate_limiter.acquire():
1036-
while not rate_limiter.acquire():
1037-
time.sleep(0.001)
1038-
10391028
future = session.execute_async(batch)
10401029
concurrent_queue.put(future)
10411030
future.add_callbacks(

sdk/python/feast/infra/passthrough_provider.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
4747
from feast.protos.feast.types.Value_pb2 import RepeatedValue
4848
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
49+
from feast.rate_limiter import TokenBucketRateLimiter # provider-level write limiter
4950
from feast.repo_config import BATCH_ENGINE_CLASS_FOR_TYPE, RepoConfig
5051
from feast.saved_dataset import SavedDataset
5152
from feast.sorted_feature_view import SortedFeatureView
@@ -62,15 +63,14 @@
6263

6364

6465
class PassthroughProvider(Provider):
65-
"""
66-
The passthrough provider delegates all operations to the underlying online and offline stores.
67-
"""
66+
"""The passthrough provider delegates all operations to the underlying online and offline stores."""
6867

6968
def __init__(self, config: RepoConfig):
7069
self.repo_config = config
7170
self._offline_store = None
7271
self._online_store = None
7372
self._batch_engine: Optional[ComputeEngine] = None
73+
self._write_token_limiters: Dict[str, TokenBucketRateLimiter] = {}
7474

7575
@property
7676
def online_store(self):
@@ -199,8 +199,66 @@ def online_write_batch(
199199
],
200200
progress: Optional[Callable[[int], Any]],
201201
) -> None:
202-
if self.online_store:
203-
self.online_store.online_write_batch(config, table, data, progress)
202+
"""
203+
Write data to the online store in rate-limited batches.
204+
Uses TokenBucketRateLimiter to throttle writes.
205+
"""
206+
207+
# Resolve configured rate limit
208+
rate_limit = self._resolve_write_rate_limit(config, table)
209+
fv_name = getattr(table, "name", "global") if table is not None else "global"
210+
limiter_key = f"{config.project}:{fv_name}"
211+
212+
# If rate limit is 0 or unset, bypass limiter
213+
if rate_limit <= 0:
214+
if self.online_store:
215+
self.online_store.online_write_batch(config, table, data, progress)
216+
return
217+
218+
# Create or reuse per-feature-view limiter
219+
# Calculate percent_usage based on available CPU cores
220+
# More processes = lower percent_usage to reduce token contention
221+
num_spark_driver_cores = int(os.environ.get("SPARK_DRIVER_CORES", 1))
222+
223+
if num_spark_driver_cores > 2:
224+
num_processes = num_spark_driver_cores - 1
225+
# Decrease percent_usage as processes increase to allow fair sharing
226+
# 2 processes -> 0.50, 4 processes -> 0.40, 8 processes -> 0.30
227+
percent_usage = max(0.6 / (num_processes / 2), 0.25)
228+
else:
229+
# Single process - can use more tokens per batch
230+
percent_usage = 0.9
231+
232+
interval = 1.0 # seconds
233+
234+
limiter = self._write_token_limiters.get(limiter_key)
235+
if limiter is None or limiter.rate != rate_limit:
236+
limiter = TokenBucketRateLimiter(
237+
rate=rate_limit, interval=interval, percent_usage=percent_usage
238+
)
239+
self._write_token_limiters[limiter_key] = limiter
240+
logger.info(
241+
f"[Limiter] Initialized rate limiter for {limiter_key} at {rate_limit} writes/sec"
242+
)
243+
244+
# Process data in dynamically sized batches based on token availability
245+
total_records = len(data)
246+
index = 0
247+
248+
while index < total_records:
249+
available = limiter.get_available_tokens()
250+
# Ensure we always make progress (at least 1 record)
251+
batch_size = min(max(available, 1), total_records - index)
252+
253+
batch = data[index : index + batch_size]
254+
limiter.wait_for_tokens(len(batch)) # blocks until tokens available
255+
256+
if self.online_store:
257+
self.online_store.online_write_batch(config, table, batch, progress)
258+
259+
index += batch_size
260+
if progress:
261+
progress(batch_size)
204262

205263
async def online_write_batch_async(
206264
self,
@@ -216,6 +274,41 @@ async def online_write_batch_async(
216274
config, table, data, progress
217275
)
218276

277+
def _resolve_write_rate_limit(
    self,
    config: RepoConfig,
    table: Union[FeatureView, BaseFeatureView, OnDemandFeatureView],
) -> int:
    """Resolve the effective write rate limit (writes/sec) for a feature view.

    Precedence:
        1. Feature view tag ``write_rate_limit``.
        2. ``config.online_store.write_rate_limit``.
        3. Fallback ``0`` (no rate limiting).

    Returns:
        The resolved limit as an int; 0 disables rate limiting.
    """
    # 1) Feature view tag override.
    if table is not None and hasattr(table, "tags") and table.tags:
        tag_val = table.tags.get("write_rate_limit")
        if tag_val is not None:
            try:
                return int(tag_val)
            except (TypeError, ValueError):
                # Malformed tag value; fall through to store-level config.
                logger.warning(
                    "Invalid write_rate_limit on feature view %s: %s; falling back",
                    getattr(table, "name", "<unknown>"),
                    tag_val,
                )

    # 2) Project / online store level config. getattr with a default
    # replaces the redundant hasattr + getattr pair; AttributeError is
    # kept for an absent/odd online_store object.
    try:
        if config.online_store:
            return int(getattr(config.online_store, "write_rate_limit", 0) or 0)
    except (TypeError, ValueError, AttributeError):
        logger.warning(
            "Invalid write_rate_limit on online_store config; falling back to 0"
        )

    # 3) Fallback to 0 (no rate limit).
    return 0
311+
219312
def offline_write_batch(
220313
self,
221314
config: RepoConfig,
@@ -407,7 +500,6 @@ def ingest_df(
407500

408501
# Input table is split into smaller chunks and processed in parallel
409502
chunks = self.split_table(num_processes, table)
410-
411503
chunks_to_parallelize = [
412504
(chunk, feature_view, join_keys) for chunk in chunks
413505
]
@@ -465,7 +557,6 @@ def ingest_df_to_offline_store(self, feature_view: FeatureView, table: pa.Table)
465557
table = _run_pyarrow_field_mapping(
466558
table, feature_view.batch_source.field_mapping
467559
)
468-
469560
self.offline_write_batch(self.repo_config, feature_view, table, None)
470561

471562
def materialize_single_feature_view(
@@ -544,7 +635,6 @@ def get_historical_features(
544635
full_feature_names=full_feature_names,
545636
**kwargs,
546637
)
547-
548638
return job
549639

550640
def retrieve_saved_dataset(
@@ -554,10 +644,8 @@ def retrieve_saved_dataset(
554644
ref.replace(":", "__") if dataset.full_feature_names else ref.split(":")[1]
555645
for ref in dataset.features
556646
]
557-
558647
# ToDo: replace hardcoded value
559648
event_ts_column = "event_timestamp"
560-
561649
return self.offline_store.pull_all_from_table_or_query(
562650
config=config,
563651
data_source=dataset.storage.to_data_source(),
@@ -578,7 +666,6 @@ def write_feature_service_logs(
578666
assert feature_service.logging_config is not None, (
579667
"Logging should be configured for the feature service before calling this function"
580668
)
581-
582669
self.offline_store.write_logged_features(
583670
config=config,
584671
data=logs,
@@ -598,7 +685,6 @@ def retrieve_feature_service_logs(
598685
assert feature_service.logging_config is not None, (
599686
"Logging should be configured for the feature service before calling this function"
600687
)
601-
602688
logging_source = FeatureServiceLoggingSource(feature_service, config.project)
603689
schema = logging_source.get_schema(registry)
604690
logging_config = feature_service.logging_config

sdk/python/feast/rate_limiter.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,70 @@
1+
import math
import threading
import time
from typing import Optional


class TokenBucketRateLimiter:
    """Thread-safe token-bucket rate limiter.

    Tokens are replenished continuously at ``rate`` tokens per ``interval``
    seconds, up to a burst capacity of ``rate`` tokens. Only a fraction
    (``percent_usage``) of the bucket is exposed to callers, leaving headroom
    when several processes share the same downstream write quota.
    """

    def __init__(self, rate: float, interval: float = 1.0, percent_usage: float = 0.6):
        """
        Args:
            rate: Maximum tokens added per interval (writes per interval).
            interval: Refill interval in seconds.
            percent_usage: Fraction of available tokens allowed for writing.
        """
        self.rate = float(rate)
        self.interval = float(interval)
        # Burst capacity equals one interval's worth of tokens.
        self.max_tokens = float(rate)
        # Start with a full bucket so the first batch is not throttled.
        self.tokens = float(rate)
        self.last_refill = time.monotonic()
        self.lock = threading.Lock()
        self.cond = threading.Condition(self.lock)
        self.percent_usage = float(percent_usage)

    def _refill(self) -> None:
        """Add tokens for the time elapsed since the last refill.

        Must be called with ``self.lock`` held.
        """
        now = time.monotonic()
        elapsed = now - self.last_refill
        if elapsed <= 0:
            return

        added = (self.rate * elapsed) / self.interval
        if added > 0:
            self.tokens = min(self.max_tokens, self.tokens + added)
            self.last_refill = now

    def get_available_tokens(self) -> int:
        """Return the number of tokens currently usable for writing.

        Applies ``percent_usage`` and never returns a negative count, even
        if a prior oversized request drove the balance below zero.
        """
        with self.lock:
            self._refill()
            return max(0, math.floor(self.tokens * self.percent_usage))

    def wait_for_tokens(self, num: int, timeout: Optional[float] = None) -> bool:
        """Block until ``num`` tokens are available, then consume them.

        Args:
            num: Tokens to consume; values <= 0 succeed immediately.
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Returns:
            True once the tokens were consumed, False on timeout.
        """
        if num <= 0:
            return True

        end_time = None if timeout is None else (time.monotonic() + timeout)
        with self.cond:
            # Bug fix: a request larger than the usable capacity
            # (max_tokens * percent_usage) could previously never be
            # satisfied, so the caller blocked forever (or always timed
            # out). Clamp the wait threshold to the usable capacity so
            # callers always make progress; consuming the full `num`
            # below may drive the balance negative, which simply delays
            # later requests and preserves the long-run average rate.
            required = min(num, self.max_tokens * self.percent_usage)
            while True:
                self._refill()
                if self.tokens * self.percent_usage >= required:
                    # Consume atomically while holding the lock.
                    self.tokens -= num
                    self.cond.notify_all()
                    return True

                if end_time is not None:
                    remaining = end_time - time.monotonic()
                    if remaining <= 0:
                        return False
                    wait_time = min(0.05, remaining)
                else:
                    wait_time = 0.05
                # Sleep briefly (releases the lock) before re-checking.
                self.cond.wait(wait_time)

0 commit comments

Comments
 (0)