fix: Eliminate RDD usage across SynapseML for Spark 4.0 compatibility #2517
BrendanWalsh wants to merge 1 commit into master
Conversation
Hey @BrendanWalsh 👋! We use semantic commit messages to streamline the release process. Examples of commit messages with semantic prefixes:
To test your commit locally, please follow our guide on building from source.
Dependency Review
✅ No vulnerabilities, license issues, or OpenSSF Scorecard issues found.
Scanned Files: None
Pull request overview
This PR removes or reduces reliance on SparkContext/RDD APIs across SynapseML to align with SPARK-48909 patterns and improve compatibility with Spark 4.0+ and restricted-RDD environments (e.g., Unity Catalog shared access mode).
Changes:
- Replace SparkContext/RDD-based persistence and small-data parallelization patterns with SparkSession/DataFrame equivalents.
- Refactor LightGBM and VowpalWabbit task/partition management to avoid .rdd.getNumPartitions usage.
- Rework several stages/utilities (e.g., StratifiedRepartition, Repartition, SyntheticEstimator indexing, ONNX file access) to avoid .rdd conversions where possible.
Reviewed changes
Copilot reviewed 25 out of 25 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitBaseSpark.scala | Removes SparkContext import usage. |
| vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitBaseProgressive.scala | Adjusts to new prepareDataSet return type including task count. |
| vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitBaseLearner.scala | Uses (DataFrame, numTasks) returned from prepareDataSet; documents barrier RDD need. |
| vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitBase.scala | Refactors dataset preparation to return (DataFrame, Int) and avoid .rdd.getNumPartitions. |
| lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/booster/LightGBMBooster.scala | Replaces sparkContext.parallelize with Seq(...).toDS() for writing model text. |
| lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMRanker.scala | Removes dependency on .rdd.getNumPartitions when repartitioning by grouping column. |
| lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala | Removes .rdd.getNumPartitions caps; simplifies repartition/coalesce logic; documents barrier RDD need. |
| deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModel.scala | Replaces sc.binaryFiles with Hadoop FS read for model payload loading. |
| deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXHub.scala | Replaces SparkContext access with SparkSession-based Hadoop configuration access. |
| core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/HTTPSourceV2.scala | Documents remaining SparkContext requirement (BlockManager access). |
| core/src/main/scala/org/apache/spark/sql/execution/streaming/HTTPSource.scala | Documents remaining RDD requirement for internalCreateDataFrame streaming path. |
| core/src/main/scala/org/apache/spark/sql/execution/streaming/DistributedHTTPSource.scala | Replaces small RDD parallelize with spark.range; documents remaining RDD streaming requirement. |
| core/src/main/scala/org/apache/spark/ml/Serializer.scala | Moves HDFS helpers from SparkContext to SparkSession; updates ObjectSerializer accordingly. |
| core/src/main/scala/org/apache/spark/ml/ComplexParamsSerializer.scala | Implements SPARK-48909-style metadata save using SparkSession/DataFrame write. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/train/ComputePerInstanceStatistics.scala | Replaces .rdd.distinct().count() with DataFrame .distinct().count(). |
| core/src/main/scala/com/microsoft/azure/synapse/ml/train/ComputeModelStatistics.scala | Adds documentation for required .rdd metrics conversions. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/stages/StratifiedRepartition.scala | Replaces RDD stratified sampling/partitioning with DataFrame-based sampling + repartitioning. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/stages/Repartition.scala | Replaces RDD repartition/createDataFrame with DataFrame repartition. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/stages/Lambda.scala | Replaces SparkContext emptyRDD with SparkSession empty DataFrame creation. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/recommendation/SARModel.scala | Documents required .rdd conversion for CoordinateMatrix. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/recommendation/RankingEvaluator.scala | Documents required .rdd conversion for RankingMetrics. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/ClusterUtil.scala | Replaces mapPartitionsWithIndex with spark_partition_id + aggregation. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/causal/linalg/VectorOps.scala | Replaces SparkContext parallelize with spark.range for distributed vector creation. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticEstimator.scala | Replaces rdd.zipWithIndex with a DataFrame-based row id approach. |
| core/src/main/scala/com/microsoft/azure/synapse/ml/automl/TuneHyperparameters.scala | Replaces MLUtils.kFold(df.rdd) with DataFrame-based fold assignment. |
Comments suppressed due to low confidence (3)
core/src/main/scala/com/microsoft/azure/synapse/ml/stages/StratifiedRepartition.scala:103
getEqualLabelCount repeats the max(spark_partition_id()) pattern and will throw on empty inputs for the same reason as in transform (max returns null on empty). It also inherits the same partition undercount risk when some partitions are empty. Please reuse a safe partition-count helper here that handles empty datasets and does not assume maxPid + 1 equals the partition count.
```scala
private def getEqualLabelCount(labelToCount: Array[(Int, Long)], dataset: Dataset[_]): Map[Int, Double] = {
  val numPartitions = dataset.toDF()
    .select(spark_partition_id().as("_pid")).agg(sqlMax("_pid")).head().getInt(0) + 1
  val maxLabelCount = Math.max(labelToCount.map { case (label, count) => count }.max, numPartitions)
```
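A safe partition-count helper along the lines the review asks for would normalize the aggregated max partition id before use. A minimal sketch of just that normalization logic, outside Spark (the helper name is hypothetical; the aggregate's None case models an empty DataFrame):

```scala
// Hypothetical helper: turn the result of max(spark_partition_id()) into a
// usable partition count. The aggregate yields no value on an empty DataFrame,
// and maxPid + 1 is only a lower bound when trailing partitions are empty,
// so callers should treat the result as "at least this many partitions".
def safePartitionCount(maxPidFromAgg: Option[Int]): Int =
  maxPidFromAgg match {
    case Some(maxPid) => maxPid + 1 // partition ids actually observed in the data
    case None         => 1          // empty input: fall back to a single partition
  }
```

This keeps the head()/getInt(0) call from throwing a NullPointerException-style failure on empty inputs, at the cost of still undercounting when trailing partitions hold no rows.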
lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala:449
determineNumTasks no longer caps the computed worker count by the input dataset's current number of partitions. Since the default is non-barrier mode (useBarrierExecutionMode defaults to false), prepareDataframe will call coalesce(numTasks), which cannot increase partitions, so numTasks can exceed the actual partition count. The driver-side NetworkManager then waits for numTasks executor connections and can hang or time out when the dataset starts with fewer partitions than numTasks (e.g., small inputs). Please ensure numTasks is <= the actual partition count in non-barrier mode, or switch to repartition(numTasks) when numTasks would increase partitions.
```scala
private def determineNumTasks(dataset: Dataset[_], configNumTasks: Int, numTasksPerExecutor: Int) = {
  // By default, we try to intelligently calculate the number of executors, but user can override this with numTasks
  if (configNumTasks > 0) configNumTasks
  else {
    ClusterUtil.getNumExecutorTasks(dataset.sparkSession, numTasksPerExecutor, log)
  }
```
deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXHub.scala:205
This method uses SparkSession.active.sparkContext.hadoopConfiguration while downloading. If no active session is set (common when called from library code on the driver thread), this will throw even though a Spark application exists. Use a safer session acquisition strategy (e.g., builder getOrCreate or a getActiveSession fallback) so model downloads do not depend on thread-local SparkSession state.
```scala
//noinspection ScalaStyle
private def downloadModel(url: URL, path: Path, fs: FileSystem): Unit = {
  FaultToleranceUtils.retryWithTimeout(retryCount, Duration.apply(retryTimeoutInSeconds, "sec")) {
    val urlCon = url.openConnection()
    urlCon.setConnectTimeout(connectTimeout)
    urlCon.setReadTimeout(readTimeout)
    using(new BufferedInputStream(urlCon.getInputStream)) { is =>
      using(fs.create(path)) { os =>
        HUtils.copyBytes(is, os, SparkSession.active.sparkContext.hadoopConfiguration)
      }
```
```scala
import org.apache.spark.sql.functions.{rand, lit}
val df = dataset.toDF
val nFolds = getNumFolds
// DataFrame-based k-fold splitting: assign each row to a fold using hash of random value
val dfWithFold = df.withColumn("_kfold_rand", rand(getSeed))
val splits = (0 until nFolds).map { fold =>
  val training = dfWithFold
    .filter((dfWithFold("_kfold_rand") * lit(nFolds)).cast("int") =!= lit(fold))
    .drop("_kfold_rand")
  val validation = dfWithFold
    .filter((dfWithFold("_kfold_rand") * lit(nFolds)).cast("int") === lit(fold))
    .drop("_kfold_rand")
  (training, validation)
}.toArray
```
dfWithFold is not cached, so each fold’s training and validation DataFrames will independently re-evaluate the rand(getSeed) expression and re-scan the input. Persisting dfWithFold once (and unpersisting after building splits) would both reduce work and ensure the fold assignment is identical across all derived DataFrames even if the optimizer rewrites the plan.
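The fold-assignment rule itself is easy to sanity-check outside Spark: with fold = (r * nFolds).toInt for r in [0, 1), each row lands in exactly one validation split and in the training split of every other fold. A minimal pure-Scala sketch of that invariant (function names are illustrative, not part of the PR):

```scala
// Mirror of the DataFrame expression ((_kfold_rand * nFolds).cast("int")).
def foldOf(r: Double, nFolds: Int): Int = (r * nFolds).toInt

// For a fixed random draw, the row belongs to exactly one validation split
// and to the remaining nFolds - 1 training splits.
def validationMemberships(r: Double, nFolds: Int): Int =
  (0 until nFolds).count(f => foldOf(r, nFolds) == f)
```

The invariant only holds if r is evaluated once per row, which is exactly why caching dfWithFold matters: re-evaluating rand(getSeed) under a rewritten plan can assign the same row to different folds across the derived DataFrames.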
```diff
 if (getUseBarrierExecutionMode) {
-  val numPartitions = df.rdd.getNumPartitions
-  if (numPartitions > numTasks) {
-    df.repartition(numTasks)
-  } else {
-    df
-  }
+  df.repartition(numTasks)
 } else {
   df.coalesce(numTasks)
 }
```
In non-barrier mode this uses df.coalesce(numTasks), but coalesce is a no-op when numTasks is greater than the current partition count. Combined with the updated determineNumTasks, this can leave the DataFrame with fewer partitions than numTasks, while the training logic still assumes numTasks workers. Consider guarding against partition increases here (or using repartition(numTasks) when increasing partitions) so the DataFrame partitioning matches the worker count used for coordination.
This issue also appears on line 444 of the same file.
```scala
// Derive partition count from input dataset without using .rdd.getNumPartitions
val numPartitions = df.select(spark_partition_id().as("_pid")).agg(sqlMax("_pid")).head().getInt(0) + 1
// Round-robin assignment within each label ensures all labels appear on all partitions
```
numPartitions is derived via max(spark_partition_id()), which can return null on an empty DataFrame; head().getInt(0) will then throw. Also, using max(partitionId) + 1 can undercount when trailing partitions are empty (partition IDs are not guaranteed to be contiguous in the data). Please handle the empty-input case safely and avoid assuming maxPid + 1 equals the real partition count.
This issue also appears on line 100 of the same file.
```scala
val windowSpec = Window.partitionBy(col(getLabelCol)).orderBy(rand(getSeed))
val withPartition = sampled.withColumn("_rr_idx", row_number().over(windowSpec) % lit(numPartitions))
withPartition.repartitionByRange(numPartitions, col("_rr_idx")).drop("_rr_idx")
```
The round-robin index calculation looks off by one: row_number() starts at 1, so row_number() % numPartitions yields values in 1..numPartitions-1 with 0 only every numPartitions rows, which skews distribution. Additionally, repartitionByRange(numPartitions, col("_rr_idx")) does not guarantee that each _rr_idx value maps to a single partition (range boundaries are sampled), so it may not achieve the intended per-partition label spread. Consider using a zero-based modulo (e.g., pmod(row_number()-1, numPartitions)) and repartition(numPartitions, col("_rr_idx")) if you need deterministic partition assignment by _rr_idx.
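The off-by-one the review describes is visible with plain integers: row_number() is 1-based, so i % n and (i - 1) % n shift which bucket fills first. A quick pure-Scala sketch mirroring the two expressions:

```scala
val n = 4
val rowNumbers = 1 to 8 // row_number() over a window is 1-based

// Current expression: row_number() % numPartitions
// Bucket 0 is only hit on every numPartitions-th row.
val oneBased = rowNumbers.map(_ % n)

// Suggested expression: pmod(row_number() - 1, numPartitions)
// Buckets fill in order 0, 1, 2, 3 from the first row.
val zeroBased = rowNumbers.map(i => (i - 1) % n)
```

With a row count that is not a multiple of numPartitions, the 1-based form systematically shortchanges bucket 0, which is the skew the comment points out.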
```diff
 logTransform[DataFrame]({
   if (getDisable)
     dataset.toDF
-  else if (getN < dataset.rdd.getNumPartitions)
-    dataset.coalesce(getN).toDF()
-  else
-    dataset.sqlContext.createDataFrame(
-      dataset.rdd.repartition(getN).asInstanceOf[RDD[Row]],
-      dataset.schema)
+  else
+    dataset.repartition(getN).toDF()
 }, dataset.columns.length)
```
Switching from conditional coalesce/RDD-based repartitioning to unconditional dataset.repartition(getN) means reducing partitions will now always incur a full shuffle (previously it could use coalesce without shuffle when decreasing partitions). If this transformer is expected to be used for downsizing partitions, consider adding a parameter to choose coalesce vs repartition, or reintroduce a non-shuffle path when getN is less than the current partition count (without relying on RDD APIs).
```scala
df.withColumn(colName, monotonically_increasing_id())
  .select(col(colName) +: df.columns.map(col): _*)
```
monotonically_increasing_id() does not produce dense, 0-based consecutive indices (it encodes partition ID and can create very large gaps). Downstream MatrixOps.size computes matrix dimensions via max(i)+1, so this change can massively inflate inferred sizes and lead to incorrect behavior or OOMs. Please generate a dense index instead (e.g., row_number() over an explicit Window.orderBy(...) and subtract 1) so indices are contiguous and bounded by the distinct count.
Suggested change:
```diff
-df.withColumn(colName, monotonically_increasing_id())
-  .select(col(colName) +: df.columns.map(col): _*)
+val windowSpec = Window.orderBy(df.columns.map(col): _*)
+val indexedDf = df.withColumn(colName, row_number().over(windowSpec) - 1)
+indexedDf.select(col(colName) +: df.columns.map(col): _*)
```
```diff
 .getOrElse({
   val home = new Path("placeholder")
-    .getFileSystem(SparkContext.getOrCreate().hadoopConfiguration)
+    .getFileSystem(SparkSession.active.sparkContext.hadoopConfiguration)
     .getHomeDirectory
   FileUtilities.join(home, ".cache", "onnx", "hub")
```
SparkSession.active will throw if there is no active Spark session on the current thread. DefaultCacheDir is a lazy val and can be initialized outside of Spark execution (or before SparkSession.setActiveSession), which would make ONNXHub fail unexpectedly. Consider using SparkSession.builder().getOrCreate() or SparkSession.getActiveSession.getOrElse(SparkSession.builder().getOrCreate()) to obtain a session more safely here.
This issue also appears on line 196 of the same file.
See below for a potential fix:
```scala
val spark = SparkSession.getActiveSession.getOrElse(SparkSession.builder().getOrCreate())
val home = new Path("placeholder")
  .getFileSystem(spark.sparkContext.hadoopConfiguration)
```
```scala
val fileLength = fs.getFileStatus(hadoopPath).getLen.toInt
val modelBytes = new Array[Byte](fileLength)
```
Converting getLen (Long) to Int can overflow for model files >2GB, leading to negative/incorrect array sizes and corrupted reads. Consider handling lengths > Int.MaxValue explicitly (e.g., fail with a clear error) or reading the stream without pre-allocating an Int-sized byte array.
Suggested change:
```diff
-val fileLength = fs.getFileStatus(hadoopPath).getLen.toInt
-val modelBytes = new Array[Byte](fileLength)
+val fileLength = fs.getFileStatus(hadoopPath).getLen
+if (fileLength > Int.MaxValue) {
+  throw new IllegalArgumentException(
+    s"ONNX model at $path is too large: $fileLength bytes (max supported: ${Int.MaxValue} bytes).")
+}
+val fileLengthInt = fileLength.toInt
+val modelBytes = new Array[Byte](fileLengthInt)
```
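The overflow hazard is concrete: for any length above Int.MaxValue (about 2 GB), toInt keeps only the low 32 bits and can go negative, and new Array[Byte](negative) throws. A small pure-Scala sketch of both the truncation and the guarded conversion in the suggestion above:

```scala
// A 3 GB file length, as Hadoop's FileStatus.getLen would report it.
val threeGb: Long = 3L * 1024 * 1024 * 1024

// Naive conversion silently truncates to the low 32 bits; the sign bit ends up set.
val truncated = threeGb.toInt

// Guarded conversion along the lines of the suggested fix.
def toArraySize(len: Long): Int =
  if (len > Int.MaxValue) throw new IllegalArgumentException(s"file too large: $len bytes")
  else len.toInt
```

Failing fast with a clear message is strictly better than allocating a wrongly-sized array and reading a corrupted model payload.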
```scala
if (getUseBarrierExecutionMode) {
  (dfSubset.repartition(numExecutorTasks), numExecutorTasks)
} else {
  (dfSubset.coalesce(numExecutorTasks), numExecutorTasks)
}
```
In non-barrier mode, dfSubset.coalesce(numExecutorTasks) will not increase the number of partitions, but this method still returns numExecutorTasks as numTasks. That can make downstream coordination (e.g., augmentVowpalWabbitArguments, --node numbering) think there are more tasks than Spark will actually run, which can lead to hangs/timeouts or incorrect distributed setup. Consider ensuring the returned task count matches the actual partition count (e.g., by using repartition(numExecutorTasks) when increasing partitions, or by deriving the current partition count from the DataFrame plan and capping numExecutorTasks).
Suggested change:
```diff
-if (getUseBarrierExecutionMode) {
-  (dfSubset.repartition(numExecutorTasks), numExecutorTasks)
-} else {
-  (dfSubset.coalesce(numExecutorTasks), numExecutorTasks)
-}
+val partitionedDf =
+  if (getUseBarrierExecutionMode) {
+    dfSubset.repartition(numExecutorTasks)
+  } else {
+    dfSubset.coalesce(numExecutorTasks)
+  }
+// Ensure returned task count matches actual partition count to avoid
+// overestimating tasks in non-barrier mode where coalesce cannot increase partitions.
+val actualNumTasks = partitionedDf.rdd.getNumPartitions
+(partitionedDf, actualNumTasks)
```
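The partition arithmetic behind this concern can be stated directly: coalesce(n) yields min(current, n) partitions because it only merges, while repartition(n) shuffles to exactly n. A tiny model of the two behaviors (pure Scala; the function names are illustrative, not Spark APIs):

```scala
// Resulting partition count after each operation, per Spark semantics:
// coalesce never increases the partition count; repartition always hits n.
def afterCoalesce(current: Int, requested: Int): Int = math.min(current, requested)
def afterRepartition(current: Int, requested: Int): Int = requested

// With 2 input partitions and 8 requested tasks, non-barrier coalesce
// leaves only 2 partitions, so reporting 8 tasks would overcount workers.
val actual = afterCoalesce(2, 8)
```

Returning afterCoalesce(current, requested) as the task count, rather than the requested value, is exactly what keeps the --node numbering and connection waits consistent with what Spark will actually schedule.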
Replace SparkContext/RDD APIs with DataFrame/SparkSession equivalents per the SPARK-48909 pattern. This enables SynapseML to work in environments where RDDs are restricted (e.g., Databricks Unity Catalog shared mode) and improves forward-compatibility with Spark 4.0+.

Key changes:
- ComplexParamsSerializer/Serializer: Replace sc.parallelize().saveAsTextFile() with spark.createDataFrame().write.text() for metadata serialization
- LightGBMBooster: Replace sc.parallelize() with Seq().toDS() for model I/O
- Lambda: Replace SparkContext.getOrCreate() + sc.emptyRDD with SparkSession
- ONNXModel: Replace sc.binaryFiles() with Hadoop FileSystem API
- ONNXHub: Replace SparkContext.hadoopConfiguration with SparkSession
- StratifiedRepartition: Replace RDD keyBy/sampleByKeyExact/RangePartitioner with DataFrame-based oversampling and round-robin partitioning
- Repartition: Replace .rdd.repartition() with DataFrame repartition
- ClusterUtil: Replace .rdd.mapPartitionsWithIndex with spark_partition_id()
- VectorOps: Replace sparkContext.parallelize with spark.range()
- SyntheticEstimator: Replace df.rdd.zipWithIndex with monotonically_increasing_id
- TuneHyperparameters: Replace MLUtils.kFold(df.rdd) with DataFrame-based k-fold
- VowpalWabbitBase: Refactor prepareDataSet to return (DataFrame, Int) tuple
- LightGBMBase/Ranker: Simplify partition management without .rdd.getNumPartitions
- DistributedHTTPSource: Replace sparkContext.parallelize with spark.range()

Remaining RDD usage (no DataFrame API alternatives):
- Barrier execution: VowpalWabbitBaseLearner, LightGBMBase (df.rdd.barrier())
- MLlib evaluators: ComputeModelStatistics, RankingEvaluator (require RDD input)
- MLlib linalg: SARModel (CoordinateMatrix requires RDD)
- Streaming internals: HTTPSource, DistributedHTTPSource, HTTPSourceV2

Closes #2401

Co-authored-by: Copilot <[email protected]>
Force-pushed from 1f69881 to 66f0afd.
Summary
Replace SparkContext/RDD APIs with DataFrame/SparkSession equivalents per SPARK-48909 pattern. This enables SynapseML to work in environments where RDDs are restricted (e.g., Databricks Unity Catalog shared access mode) and improves forward-compatibility with Spark 4.0+.
Closes #2401
Changes by category
Critical SPARK-48909 pattern (metadata serialization)
- Replace sc.parallelize(Seq(json), 1).saveAsTextFile() with spark.createDataFrame().write.text()
- Move writeToHDFS, readFromHDFS, and makeQualifiedPath from SparkContext to SparkSession params
- Replace sc.parallelize() with Seq().toDS() for model save/dump

SparkContext elimination
- Replace SparkContext.getOrCreate() + sc.emptyRDD with SparkSession + createDataFrame
- Replace sc.binaryFiles() with the Hadoop FileSystem API
- Replace SparkContext.hadoopConfiguration with SparkSession.active
- Replace .rdd.mapPartitionsWithIndex with spark_partition_id() + groupBy

DataFrame.rdd elimination
- Replace keyBy/sampleByKeyExact/RangePartitioner with DataFrame-based oversampling and round-robin partitioning
- Replace .rdd.repartition() with DataFrame repartition()
- Replace sparkContext.parallelize with spark.range()
- Replace df.rdd.zipWithIndex with monotonically_increasing_id()
- Replace .rdd.distinct().count() with .distinct().count()
- Replace MLUtils.kFold(df.rdd) with DataFrame-based k-fold splitting

getNumPartitions patterns
- Refactor prepareDataSet to return a (DataFrame, Int) tuple, eliminating .rdd.getNumPartitions
- Remove .rdd.getNumPartitions caps

Remaining RDD usage (no DataFrame API alternatives)
These are documented with comments in the code:
- df.rdd.barrier() in VowpalWabbitBaseLearner and LightGBMBase: no DataFrame equivalent exists
- Streaming: internalCreateDataFrame requires an RDD

Testing