Commit dcb5812

[FLINK-38793][table] Reuse result of functions if they are used multiple times as input for other functions
1 parent f79e63e commit dcb5812

File tree: 5 files changed, +177 −3 lines changed
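
The change targets queries where a sub-query computes a function once and several outer function calls consume that result, as exercised by the new CalcTest cases below. A minimal way to reproduce the pattern could look like the following sketch; the datagen table, its schema, and reaching into the test-scope DeterministicUdf class are assumptions made purely for illustration, not part of the commit.

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.DeterministicUdf;

public class ReusedFunctionExample {
    public static void main(String[] args) {
        TableEnvironment env = TableEnvironment.create(EnvironmentSettings.inStreamingMode());

        // Hypothetical source table mirroring MyTable(a, b, c) from the planner tests.
        env.executeSql(
                "CREATE TEMPORARY TABLE MyTable (a BIGINT, b INT, c STRING) "
                        + "WITH ('connector' = 'datagen')");

        // The deterministic UDF added to JavaUserDefinedScalarFunctions in this commit.
        env.createTemporarySystemFunction("deterministic_udf", DeterministicUdf.class);

        // deterministic_udf(c) is produced once in the sub-query but consumed by two
        // JSON_VALUE calls in the outer projection, which is the shape this commit cares about.
        System.out.println(
                env.sqlQuery(
                                "SELECT JSON_VALUE(json_data, '$.id'), JSON_VALUE(json_data, '$.name') "
                                        + "FROM (SELECT deterministic_udf(c) AS json_data FROM MyTable) t")
                        .explain());
    }
}

The same query appears in the new test resource testReusedFunctionInProjectWithDeterministicUdf together with the plan it compiles to.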

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FlinkRelUtil.java

Lines changed: 50 additions & 0 deletions
@@ -24,19 +24,24 @@
 import org.apache.calcite.rel.core.Project;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexProgram;
 import org.apache.calcite.rex.RexProgramBuilder;
 import org.apache.calcite.rex.RexSlot;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.rex.RexVisitorImpl;
+import org.apache.calcite.sql.SqlFunction;
+import org.apache.calcite.sql.SqlKind;

 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;

 /** Utilities for {@link RelNode}. */
@@ -93,6 +98,9 @@ public static boolean isMergeable(Project topProject, Project bottomProject) {
         final int[] topInputRefCounter =
                 initializeArray(topProject.getInput().getRowType().getFieldCount(), 0);

+        if (functionResultShouldBeReused(topProject.getProjects(), bottomProject.getProjects())) {
+            return false;
+        }
         return mergeable(topInputRefCounter, topProject.getProjects(), bottomProject.getProjects());
     }

@@ -104,6 +112,11 @@ public static boolean isMergeable(Project topProject, Project bottomProject) {
     public static boolean isMergeable(Calc topCalc, Calc bottomCalc) {
         final RexProgram topProgram = topCalc.getProgram();
         final RexProgram bottomProgram = bottomCalc.getProgram();
+        if (functionResultShouldBeReused(
+                topProgram.getProjectList(), bottomProgram.getProjectList())) {
+            return false;
+        }
+
         final int[] topInputRefCounter =
                 initializeArray(topCalc.getInput().getRowType().getFieldCount(), 0);

@@ -122,6 +135,43 @@ public static boolean isMergeable(Calc topCalc, Calc bottomCalc) {
         return mergeable(topInputRefCounter, topInputRefs, bottomProjects);
     }

+    private static boolean functionResultShouldBeReused(
+            List<? extends RexNode> topProjectList, List<? extends RexNode> bottomProjectList) {
+        Set<Integer> indexSet = new HashSet<>();
+        for (int i = 0; i < bottomProjectList.size(); i++) {
+            RexNode project = bottomProjectList.get(i);
+            if (project instanceof RexCall
+                    && SqlKind.FUNCTION.contains(((RexCall) project).op.getKind())
+                    && ((RexCall) project).op.isDeterministic()) {
+                indexSet.add(i);
+            }
+        }
+        if (indexSet.isEmpty()) {
+            return false;
+        }
+
+        Set<RexNode> rexNodes = new HashSet<>();
+        for (RexNode rex : topProjectList) {
+            if (!(rex instanceof RexCall)) {
+                continue;
+            }
+            RexCall rCall = (RexCall) rex;
+            if (!(rCall.op instanceof SqlFunction)) {
+                continue;
+            }
+            List<RexNode> operands = rCall.operands;
+            for (RexNode op : operands) {
+                if (op instanceof RexSlot) {
+                    if (indexSet.contains(((RexSlot) op).getIndex()) && !rexNodes.add(op)) {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        return false;
+    }
+
     /**
      * Merges the programs of two {@link Calc} instances and returns a new {@link Calc} instance
      * with the merged program.
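
The new functionResultShouldBeReused check works in two passes over the projection lists: it first records which outputs of the bottom node are deterministic function calls, then reports reuse as soon as one of those outputs shows up a second time among the operands of function calls in the top node, in which case isMergeable answers false and the two nodes are left unmerged. Below is a minimal, self-contained sketch of that idea on a toy model; the ReuseCheckSketch class, the LowerProject/UpperCall records, and the index-based operands are hypothetical stand-ins, not Calcite or Flink types.

import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class ReuseCheckSketch {

    // A lower-node project: either a plain column pass-through or a deterministic function call.
    record LowerProject(String description, boolean isDeterministicCall) {}

    // An upper-node function call whose operands reference lower-node outputs by index.
    record UpperCall(String function, List<Integer> operandIndexes) {}

    // Mirrors the shape of functionResultShouldBeReused: true when some deterministic
    // function computed below is consumed more than once by function calls above, i.e.
    // merging the two nodes would duplicate that function call.
    static boolean functionResultReused(List<UpperCall> upper, List<LowerProject> lower) {
        Set<Integer> callOutputs = new HashSet<>();
        for (int i = 0; i < lower.size(); i++) {
            if (lower.get(i).isDeterministicCall()) {
                callOutputs.add(i);
            }
        }
        if (callOutputs.isEmpty()) {
            return false;
        }

        Set<Integer> seen = new HashSet<>();
        for (UpperCall call : upper) {
            for (int operand : call.operandIndexes()) {
                if (callOutputs.contains(operand) && !seen.add(operand)) {
                    return true; // second reference to the same computed result
                }
            }
        }
        return false;
    }

    public static void main(String[] args) {
        // SELECT LTRIM(q), RTRIM(q) FROM (SELECT TRIM(c) AS q FROM MyTable) t
        List<LowerProject> lower = List.of(new LowerProject("TRIM(c)", true));
        List<UpperCall> upper =
                List.of(new UpperCall("LTRIM", List.of(0)), new UpperCall("RTRIM", List.of(0)));
        System.out.println(functionResultReused(upper, lower)); // true -> do not merge
    }
}

For the LTRIM/RTRIM-over-TRIM query from the new tests, both outer calls reference lower output 0, so the check fires on the second reference.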

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalCalc.scala

Lines changed: 58 additions & 2 deletions
@@ -18,6 +18,7 @@
 package org.apache.flink.table.planner.plan.nodes.physical.stream

 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
+import org.apache.flink.table.planner.functions.sql.BuiltInSqlFunction
 import org.apache.flink.table.planner.plan.nodes.exec.{ExecNode, InputProperty}
 import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecCalc
 import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTableConfig
@@ -26,7 +27,12 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.core.Calc
-import org.apache.calcite.rex.RexProgram
+import org.apache.calcite.rex.{RexCall, RexLocalRef, RexNode, RexProgram, RexShuttle}
+import org.apache.calcite.sql.{SqlFunction, SqlKind}
+
+import java.util
+import java.util.List
+import java.util.function.BiFunction

 import scala.collection.JavaConversions._

@@ -44,7 +50,12 @@ class StreamPhysicalCalc(
   }

   override def translateToExecNode(): ExecNode[_] = {
-    val projection = calcProgram.getProjectList.map(calcProgram.expandLocalRef)
+    val funtionToCountMap = new util.HashMap[RexNode, Integer]()
+    val shuttle = new FunctionRefCounter(calcProgram.getExprList, funtionToCountMap)
+
+    calcProgram.getProjectList.map(ref => ref.accept(shuttle))
+    val projection = calcProgram.getProjectList.map(
+      ref => ref.accept(new ExpansionShuttle(calcProgram.getExprList, funtionToCountMap)))
     val condition = if (calcProgram.getCondition != null) {
       calcProgram.expandLocalRef(calcProgram.getCondition)
     } else {
@@ -59,4 +70,49 @@
       FlinkTypeFactory.toLogicalRowType(getRowType),
       getRelDetailedDescription)
   }
+
+  private def isDeterministicFunction(rexNode: RexNode): Boolean = {
+    SqlKind.FUNCTION.contains(rexNode.getKind) && rexNode.isInstanceOf[RexCall] && rexNode
+      .asInstanceOf[RexCall]
+      .op
+      .isInstanceOf[SqlFunction] && rexNode
+      .asInstanceOf[RexCall]
+      .op
+      .asInstanceOf[SqlFunction]
+      .isDeterministic
+  }
+
+  private class ExpansionShuttle(
+      private val exprs: util.List[RexNode],
+      val map: util.Map[RexNode, Integer])
+    extends RexShuttle {
+    override def visitLocalRef(localRef: RexLocalRef): RexNode = {
+      val tree: RexNode = this.exprs.get(localRef.getIndex)
+      if (
+        isDeterministicFunction(tree) && map
+          .get(tree) > 1
+      ) {
+        for (op <- tree.asInstanceOf[RexCall].operands) {
+          if (op.isInstanceOf[RexLocalRef]) {
+            return tree.accept(this)
+          }
+        }
+        return localRef
+      }
+      tree.accept(this)
+    }
+  }
+
+  private class FunctionRefCounter(
+      private val exprs: util.List[RexNode],
+      val map: util.Map[RexNode, Integer])
+    extends RexShuttle {
+    override def visitLocalRef(localRef: RexLocalRef): RexNode = {
+      val tree: RexNode = this.exprs.get(localRef.getIndex)
+      if (isDeterministicFunction(tree)) {
+        map.merge(tree, 1, (x, y) => x + y)
+      }
+      tree.accept(this)
+    }
+  }
 }
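
On the physical side, translateToExecNode now walks the program twice: FunctionRefCounter counts how often each deterministic function expression is reached from the projection, and ExpansionShuttle then expands local references while keeping multiply-referenced calls as references so their result can be shared. The sketch below mirrors that two-pass shape on a toy expression list rather than Calcite's RexProgram; the Expr record, the index-based refs, and the string rendering are illustrative assumptions, and the real shuttle has extra handling for operands that are themselves local refs.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class LocalRefExpansionSketch {

    // Toy expression: a plain column, or a deterministic call whose operands are indexes
    // into the expression list (the analogue of RexLocalRef).
    record Expr(String op, List<Integer> operandRefs, boolean isDeterministicCall) {
        static Expr column(String name) {
            return new Expr(name, List.of(), false);
        }

        static Expr call(String op, Integer... refs) {
            return new Expr(op, List.of(refs), true);
        }
    }

    // Pass 1: count how often each deterministic call is reached from the projection.
    static Map<Integer, Integer> countRefs(List<Expr> exprs, List<Integer> projects) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int ref : projects) {
            visit(exprs, ref, counts);
        }
        return counts;
    }

    private static void visit(List<Expr> exprs, int ref, Map<Integer, Integer> counts) {
        Expr e = exprs.get(ref);
        if (e.isDeterministicCall()) {
            counts.merge(ref, 1, Integer::sum);
        }
        for (int operand : e.operandRefs()) {
            visit(exprs, operand, counts);
        }
    }

    // Pass 2: expand a ref into a printable tree, but keep multiply-used calls as "$n"
    // references so the shared result is evaluated only once.
    static String expand(List<Expr> exprs, int ref, Map<Integer, Integer> counts) {
        Expr e = exprs.get(ref);
        if (e.isDeterministicCall() && counts.getOrDefault(ref, 0) > 1) {
            return "$" + ref;
        }
        if (e.operandRefs().isEmpty()) {
            return e.op();
        }
        List<String> args = new ArrayList<>();
        for (int operand : e.operandRefs()) {
            args.add(expand(exprs, operand, counts));
        }
        return e.op() + "(" + String.join(", ", args) + ")";
    }

    public static void main(String[] args) {
        // $0 = c, $1 = TRIM($0), $2 = LTRIM($1), $3 = RTRIM($1); the projection is ($2, $3).
        List<Expr> exprs =
                List.of(
                        Expr.column("c"),
                        Expr.call("TRIM", 0),
                        Expr.call("LTRIM", 1),
                        Expr.call("RTRIM", 1));
        List<Integer> projects = List.of(2, 3);
        Map<Integer, Integer> counts = countRefs(exprs, projects);
        for (int p : projects) {
            System.out.println(expand(exprs, p, counts)); // prints LTRIM($1) and RTRIM($1)
        }
    }
}

In this toy run the first pass counts two references to $1, so the second pass keeps $1 shared instead of inlining TRIM twice.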

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedScalarFunctions.java

Lines changed: 20 additions & 0 deletions
@@ -141,6 +141,26 @@ public boolean isDeterministic() {
         }
     }

+    /** Deterministic scalar function. */
+    public static class DeterministicUdf extends ScalarFunction {
+        public int eval() {
+            return 0;
+        }
+
+        public int eval(@DataTypeHint("INT") int v) {
+            return v;
+        }
+
+        public String eval(String v) {
+            return v;
+        }
+
+        @Override
+        public boolean isDeterministic() {
+            return true;
+        }
+    }
+
     /** Test for Python Scalar Function. */
     public static class PythonScalarFunction extends ScalarFunction implements PythonFunction {
         private final String name;

flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml

Lines changed: 36 additions & 0 deletions
@@ -678,6 +678,42 @@ LogicalProject(EXPR$0=[$0])
 <![CDATA[
 Calc(select=[a AS EXPR$0])
 +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testReusedFunctionInProject">
+    <Resource name="sql">
+      <![CDATA[SELECT LTRIM(q), RTRIM(q) FROM (SELECT TRIM(c) as q FROM MyTable) t]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(EXPR$0=[LTRIM($0)], EXPR$1=[RTRIM($0)])
++- LogicalProject(q=[TRIM(FLAG(BOTH), _UTF-16LE' ', $2)])
+   +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+Calc(select=[LTRIM(TRIM(BOTH, ' ', c)) AS EXPR$0, RTRIM(TRIM(BOTH, ' ', c)) AS EXPR$1])
++- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testReusedFunctionInProjectWithDeterministicUdf">
+    <Resource name="sql">
+      <![CDATA[SELECT JSON_VALUE(json_data, '$.id'), JSON_VALUE(json_data, '$.name') FROM (SELECT deterministic_udf(c) as json_data FROM MyTable) t]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(EXPR$0=[JSON_VALUE($0, _UTF-16LE'$.id')], EXPR$1=[JSON_VALUE($0, _UTF-16LE'$.name')])
++- LogicalProject(json_data=[deterministic_udf($2)])
+   +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+Calc(select=[JSON_VALUE(deterministic_udf(c), '$.id') AS EXPR$0, JSON_VALUE(deterministic_udf(c), '$.name') AS EXPR$1])
++- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
 ]]>
     </Resource>
   </TestCase>

flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala

Lines changed: 13 additions & 1 deletion
@@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.stream.sql

 import org.apache.flink.table.api._
 import org.apache.flink.table.planner.plan.utils.MyPojo
-import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.NonDeterministicUdf
+import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.{DeterministicUdf, NonDeterministicUdf}
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.StringSplit
 import org.apache.flink.table.planner.utils.TableTestBase
 import org.apache.flink.table.types.AbstractDataType
@@ -36,13 +36,25 @@ class CalcTest extends TableTestBase {
   def setup(): Unit = {
     util.addTableSource[(Long, Int, String)]("MyTable", 'a, 'b, 'c)
     util.addTemporarySystemFunction("random_udf", new NonDeterministicUdf)
+    util.addTemporarySystemFunction("deterministic_udf", new DeterministicUdf)
   }

   @Test
   def testOnlyProject(): Unit = {
     util.verifyExecPlan("SELECT a, c FROM MyTable")
   }

+  @Test
+  def testReusedFunctionInProject(): Unit = {
+    util.verifyExecPlan("SELECT LTRIM(q), RTRIM(q) FROM (SELECT TRIM(c) as q FROM MyTable) t")
+  }
+
+  @Test
+  def testReusedFunctionInProjectWithDeterministicUdf(): Unit = {
+    util.verifyExecPlan(
+      "SELECT JSON_VALUE(json_data, '$.id'), JSON_VALUE(json_data, '$.name') FROM (SELECT deterministic_udf(c) as json_data FROM MyTable) t")
+  }
+
   @Test
   def testProjectWithNaming(): Unit = {
     util.verifyExecPlan("SELECT `1-_./Ü`, b, c FROM (SELECT a as `1-_./Ü`, b, c FROM MyTable)")
