
Commit 65764d8

[FLINK-38793][table] Reuse result of functions if they are used multiple times as input for other functions
1 parent d720618 commit 65764d8

6 files changed: +254 −6 lines


flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FlinkRelUtil.java

Lines changed: 91 additions & 0 deletions
@@ -24,19 +24,25 @@
 import org.apache.calcite.rel.core.Project;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLocalRef;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexProgram;
 import org.apache.calcite.rex.RexProgramBuilder;
 import org.apache.calcite.rex.RexSlot;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.rex.RexVisitorImpl;
+import org.apache.calcite.sql.SqlFunction;
+import org.apache.calcite.sql.SqlKind;

 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;

 /** Utilities for {@link RelNode}. */
@@ -93,6 +99,9 @@ public static boolean isMergeable(Project topProject, Project bottomProject) {
         final int[] topInputRefCounter =
                 initializeArray(topProject.getInput().getRowType().getFieldCount(), 0);

+        if (functionResultShouldBeReused(topProject, bottomProject)) {
+            return false;
+        }
         return mergeable(topInputRefCounter, topProject.getProjects(), bottomProject.getProjects());
     }

@@ -104,6 +113,10 @@ public static boolean isMergeable(Project topProject, Project bottomProject) {
     public static boolean isMergeable(Calc topCalc, Calc bottomCalc) {
         final RexProgram topProgram = topCalc.getProgram();
         final RexProgram bottomProgram = bottomCalc.getProgram();
+        if (functionResultShouldBeReused(topProgram, bottomProgram)) {
+            return false;
+        }
+
         final int[] topInputRefCounter =
                 initializeArray(topCalc.getInput().getRowType().getFieldCount(), 0);

@@ -122,6 +135,84 @@ public static boolean isMergeable(Calc topCalc, Calc bottomCalc) {
         return mergeable(topInputRefCounter, topInputRefs, bottomProjects);
     }

+    private static boolean functionResultShouldBeReused(Project topProject, Project bottomProject) {
+        Set<Integer> indexSet = new HashSet<>();
+        List<RexNode> bottomProjectList = bottomProject.getProjects();
+        for (int i = 0; i < bottomProjectList.size(); i++) {
+            RexNode project = bottomProjectList.get(i);
+            if (project instanceof RexCall
+                    // && SqlKind.FUNCTION.contains(((RexCall) project).op.getKind())
+                    && ((RexCall) project).op.isDeterministic()) {
+                indexSet.add(i);
+            }
+        }
+        if (indexSet.isEmpty()) {
+            return false;
+        }
+
+        Set<RexNode> rexNodes = new HashSet<>();
+        List<RexNode> topProjectList = topProject.getProjects();
+        for (RexNode rex : topProjectList) {
+            if (!(rex instanceof RexCall)) {
+                continue;
+            }
+            RexCall rCall = (RexCall) rex;
+            if (!(rCall.op instanceof SqlFunction)) {
+                continue;
+            }
+            List<RexNode> operands = rCall.operands;
+            for (RexNode op : operands) {
+                if (op instanceof RexSlot) {
+                    if (indexSet.contains(((RexSlot) op).getIndex()) && !rexNodes.add(op)) {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        return false;
+    }
+
+    private static boolean functionResultShouldBeReused(
+            RexProgram topProgram, RexProgram bottomProgram) {
+        Set<Integer> indexSet = new HashSet<>();
+        List<RexLocalRef> bottomProjectList = bottomProgram.getProjectList();
+        for (int i = 0; i < bottomProjectList.size(); i++) {
+            int index = bottomProjectList.get(i).getIndex();
+            RexNode rexNode = bottomProgram.getExprList().get(index);
+            if (rexNode instanceof RexCall
+                    && SqlKind.FUNCTION.contains(((RexCall) rexNode).op.getKind())
+                    && ((RexCall) rexNode).op.isDeterministic()) {
+                indexSet.add(i);
+            }
+        }
+        if (indexSet.isEmpty()) {
+            return false;
+        }
+
+        Set<RexNode> rexNodes = new HashSet<>();
+        List<RexNode> topExprList = topProgram.getExprList();
+        for (RexNode rex : topExprList) {
+            if (!(rex instanceof RexCall)) {
+                continue;
+            }
+            RexCall rCall = (RexCall) rex;
+            if (!(rCall.op instanceof SqlFunction)) {
+                continue;
+            }
+            List<RexNode> operands = rCall.operands;
+            for (RexNode op : operands) {
+                if (op instanceof RexSlot) {
+                    if (indexSet.contains(((RexSlot) op).getIndex()) && !rexNodes.add(op)) {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        return false;
+    }
+
     /**
      * Merges the programs of two {@link Calc} instances and returns a new {@link Calc} instance
      * with the merged program.
flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonCalc.scala

Lines changed: 34 additions & 3 deletions
@@ -27,8 +27,8 @@ import org.apache.calcite.rel.{RelNode, RelWriter}
 import org.apache.calcite.rel.core.Calc
 import org.apache.calcite.rel.hint.RelHint
 import org.apache.calcite.rel.metadata.RelMetadataQuery
-import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexProgram}
-import org.apache.calcite.sql.SqlExplainLevel
+import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexLocalRef, RexNode, RexProgram, RexShuttle}
+import org.apache.calcite.sql.{SqlExplainLevel, SqlKind}

 import java.util.Collections

@@ -49,12 +49,26 @@ abstract class CommonCalc(
     // conditions, etc. We only want to account for computations, not for simple projections.
     // CASTs in RexProgram are reduced as far as possible by ReduceExpressionsRule
     // in normalization stage. So we should ignore CASTs here in optimization stage.
-    val compCnt = calcProgram.getProjectList.map(calcProgram.expandLocalRef).toList.count {
+    val map = new java.util.HashMap[RexNode, Integer]()
+    val shuttle = new FunctionCounter(calcProgram.getExprList, map)
+    calcProgram.getProjectList.map(rf => rf.accept(shuttle))
+    val compCnt1 = calcProgram.getProjectList.map(calcProgram.expandLocalRef).toList.count {
       case _: RexInputRef => false
       case _: RexLiteral => false
       case c: RexCall if c.getOperator.getName.equals("CAST") => false
       case _ => true
     }
+    val offset = map
+      .filterKeys {
+        case _: RexInputRef => false
+        case _: RexLiteral => false
+        case c: RexCall if c.getOperator.getName.equals("CAST") => false
+        case _ => true
+      }
+      .values
+      .foldLeft(0)(_ + _)
+
+    val compCnt = Math.max(compCnt1, offset)
     val newRowCnt = mq.getRowCount(this)
     // TODO use inputRowCnt to compute cpu cost
     planner.getCostFactory.makeCost(newRowCnt, newRowCnt * compCnt, 0)
@@ -102,4 +116,21 @@
       .mkString(", ")
   }

+  class FunctionCounter(
+      private val exprs: java.util.List[RexNode],
+      val map: java.util.Map[RexNode, Integer])
+    extends RexShuttle {
+    override def visitLocalRef(localRef: RexLocalRef): RexNode = {
+      val tree: RexNode = this.exprs.get(localRef.getIndex)
+      if (
+        SqlKind.FUNCTION.contains(tree.getKind)
+        && tree.isInstanceOf[RexCall]
+        && tree.asInstanceOf[RexCall].op.isDeterministic
+      ) {
+        map.merge(tree, 1, (x, y) => x + y)
+      }
+      tree.accept(this)
+    }
+  }
+
 }
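
With this cost change, a Calc's CPU weight appears to become the larger of the old non-trivial projection count and the total number of references from its project list to deterministic function calls, so a plan that inlines the same call into several consumers is costed higher than one that computes it once and reuses the column. A rough standalone sketch of that max-of-counts idea (map keys and figures are illustrative only):

import java.util.HashMap;
import java.util.Map;

// Standalone illustration of the cost tweak in CommonCalc: count every reference
// to a deterministic function expression, then take the max of that total and the
// old "non-trivial projection" count.
public class CalcCostSketch {

    static double computeCpuWeight(int nonTrivialProjections, Map<String, Integer> functionRefCounts) {
        int referencedFunctionCalls =
                functionRefCounts.values().stream().mapToInt(Integer::intValue).sum();
        return Math.max(nonTrivialProjections, referencedFunctionCalls);
    }

    public static void main(String[] args) {
        // A merged Calc for LTRIM(TRIM(c)), RTRIM(TRIM(c)): TRIM(c) is referenced twice.
        Map<String, Integer> merged = new HashMap<>();
        merged.put("TRIM(c)", 2);
        merged.put("LTRIM(...)", 1);
        merged.put("RTRIM(...)", 1);
        System.out.println(computeCpuWeight(2, merged)); // 4.0

        // Two separate Calcs evaluate TRIM(c) once and reuse its output column.
        Map<String, Integer> bottom = Map.of("TRIM(c)", 1);
        Map<String, Integer> top = Map.of("LTRIM(q)", 1, "RTRIM(q)", 1);
        System.out.println(computeCpuWeight(1, bottom) + computeCpuWeight(2, top)); // 3.0
    }
}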

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalCalc.scala

Lines changed: 58 additions & 2 deletions
@@ -18,6 +18,7 @@
 package org.apache.flink.table.planner.plan.nodes.physical.stream

 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
+import org.apache.flink.table.planner.functions.sql.BuiltInSqlFunction
 import org.apache.flink.table.planner.plan.nodes.exec.{ExecNode, InputProperty}
 import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecCalc
 import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTableConfig
@@ -26,7 +27,12 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.core.Calc
-import org.apache.calcite.rex.RexProgram
+import org.apache.calcite.rex.{RexCall, RexLocalRef, RexNode, RexProgram, RexShuttle}
+import org.apache.calcite.sql.{SqlFunction, SqlKind}
+
+import java.util
+import java.util.List
+import java.util.function.BiFunction

 import scala.collection.JavaConversions._

@@ -44,7 +50,12 @@ class StreamPhysicalCalc(
   }

   override def translateToExecNode(): ExecNode[_] = {
-    val projection = calcProgram.getProjectList.map(calcProgram.expandLocalRef)
+    val funtionToCountMap = new util.HashMap[RexNode, Integer]()
+    val shuttle = new FunctionRefCounter(calcProgram.getExprList, funtionToCountMap)
+
+    calcProgram.getProjectList.map(ref => ref.accept(shuttle))
+    val projection = calcProgram.getProjectList.map(
+      ref => ref.accept(new ExpansionShuttle(calcProgram.getExprList, funtionToCountMap)))
     val condition = if (calcProgram.getCondition != null) {
       calcProgram.expandLocalRef(calcProgram.getCondition)
     } else {
@@ -59,4 +70,49 @@
       FlinkTypeFactory.toLogicalRowType(getRowType),
       getRelDetailedDescription)
   }
+
+  private def isDeterministicFunction(rexNode: RexNode): Boolean = {
+    SqlKind.FUNCTION.contains(rexNode.getKind) && rexNode.isInstanceOf[RexCall] && rexNode
+      .asInstanceOf[RexCall]
+      .op
+      .isInstanceOf[SqlFunction] && rexNode
+      .asInstanceOf[RexCall]
+      .op
+      .asInstanceOf[SqlFunction]
+      .isDeterministic
+  }
+
+  private class ExpansionShuttle(
+      private val exprs: util.List[RexNode],
+      val map: util.Map[RexNode, Integer])
+    extends RexShuttle {
+    override def visitLocalRef(localRef: RexLocalRef): RexNode = {
+      val tree: RexNode = this.exprs.get(localRef.getIndex)
+      if (
+        isDeterministicFunction(tree) && map
+          .get(tree) > 1
+      ) {
+        for (op <- tree.asInstanceOf[RexCall].operands) {
+          if (op.isInstanceOf[RexLocalRef]) {
+            return tree.accept(this)
+          }
+        }
+        return localRef
+      }
+      tree.accept(this)
+    }
+  }
+
+  private class FunctionRefCounter(
+      private val exprs: util.List[RexNode],
+      val map: util.Map[RexNode, Integer])
+    extends RexShuttle {
+    override def visitLocalRef(localRef: RexLocalRef): RexNode = {
+      val tree: RexNode = this.exprs.get(localRef.getIndex)
+      if (isDeterministicFunction(tree)) {
+        map.merge(tree, 1, (x, y) => x + y)
+      }
+      tree.accept(this)
+    }
+  }
 }
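
Translation to the exec node now runs two passes over the projection: FunctionRefCounter counts how often each deterministic function call is referenced, and ExpansionShuttle then keeps the RexLocalRef, rather than inlining the call, when the call is referenced more than once and has no RexLocalRef operands of its own, so the shared result is evaluated a single time. A simplified standalone sketch of that count-then-decide pattern (the string-based expression table is illustrative, not Calcite's RexProgram):

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Standalone illustration of the two-pass scheme: pass 1 counts how often each
// deterministic call is referenced from the projection, pass 2 inlines a reference
// only when its target is referenced at most once; otherwise the reference is kept
// so the call is evaluated a single time.
public class ExpansionSketch {

    static List<String> expandProjection(
            List<String> exprTexts, Set<Integer> deterministicCalls, List<Integer> projection) {
        Map<Integer, Integer> refCounts = new HashMap<>();
        for (int ref : projection) { // pass 1: count references to deterministic calls
            if (deterministicCalls.contains(ref)) {
                refCounts.merge(ref, 1, Integer::sum);
            }
        }
        List<String> result = new ArrayList<>();
        for (int ref : projection) { // pass 2: expand or keep the reference
            if (refCounts.getOrDefault(ref, 0) > 1) {
                result.add("$" + ref); // reused result: reference it instead of inlining
            } else {
                result.add(exprTexts.get(ref)); // unique use: inline the expression
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> exprTexts = List.of("c", "deterministic_udf(c)");
        // Both projected columns point at expression #1.
        System.out.println(expandProjection(exprTexts, Set.of(1), List.of(1, 1)));
        // -> [$1, $1]
    }
}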

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedScalarFunctions.java

Lines changed: 20 additions & 0 deletions
@@ -141,6 +141,26 @@ public boolean isDeterministic() {
         }
     }

+    /** Deterministic scalar function. */
+    public static class DeterministicUdf extends ScalarFunction {
+        public int eval() {
+            return 0;
+        }
+
+        public int eval(@DataTypeHint("INT") int v) {
+            return v;
+        }
+
+        public String eval(String v) {
+            return v;
+        }
+
+        @Override
+        public boolean isDeterministic() {
+            return true;
+        }
+    }
+
     /** Test for Python Scalar Function. */
     public static class PythonScalarFunction extends ScalarFunction implements PythonFunction {
         private final String name;

flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml

Lines changed: 38 additions & 0 deletions
@@ -678,6 +678,44 @@ LogicalProject(EXPR$0=[$0])
       <![CDATA[
 Calc(select=[a AS EXPR$0])
 +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testReusedFunctionInProject">
+    <Resource name="sql">
+      <![CDATA[SELECT LTRIM(q), RTRIM(q) FROM (SELECT TRIM(c) as q FROM MyTable) t]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(EXPR$0=[LTRIM($0)], EXPR$1=[RTRIM($0)])
++- LogicalProject(q=[TRIM(FLAG(BOTH), _UTF-16LE' ', $2)])
+   +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+Calc(select=[LTRIM(q) AS EXPR$0, RTRIM(q) AS EXPR$1])
++- Calc(select=[TRIM(BOTH, ' ', c) AS q])
+   +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testReusedFunctionInProjectWithDeterministicUdf">
+    <Resource name="sql">
+      <![CDATA[SELECT JSON_VALUE(json_data, '$.id'), JSON_VALUE(json_data, '$.name') FROM (SELECT deterministic_udf(c) as json_data FROM MyTable) t]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(EXPR$0=[JSON_VALUE($0, _UTF-16LE'$.id')], EXPR$1=[JSON_VALUE($0, _UTF-16LE'$.name')])
++- LogicalProject(json_data=[deterministic_udf($2)])
+   +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+Calc(select=[JSON_VALUE(json_data, '$.id') AS EXPR$0, JSON_VALUE(json_data, '$.name') AS EXPR$1])
++- Calc(select=[deterministic_udf(c) AS json_data])
+   +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
 ]]>
     </Resource>
   </TestCase>

flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala

Lines changed: 13 additions & 1 deletion
@@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.stream.sql

 import org.apache.flink.table.api._
 import org.apache.flink.table.planner.plan.utils.MyPojo
-import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.NonDeterministicUdf
+import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.{DeterministicUdf, NonDeterministicUdf}
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.StringSplit
 import org.apache.flink.table.planner.utils.TableTestBase
 import org.apache.flink.table.types.AbstractDataType
@@ -36,13 +36,25 @@ class CalcTest extends TableTestBase {
   def setup(): Unit = {
     util.addTableSource[(Long, Int, String)]("MyTable", 'a, 'b, 'c)
     util.addTemporarySystemFunction("random_udf", new NonDeterministicUdf)
+    util.addTemporarySystemFunction("deterministic_udf", new DeterministicUdf)
   }

   @Test
   def testOnlyProject(): Unit = {
     util.verifyExecPlan("SELECT a, c FROM MyTable")
   }

+  @Test
+  def testReusedFunctionInProject(): Unit = {
+    util.verifyExecPlan("SELECT LTRIM(q), RTRIM(q) FROM (SELECT TRIM(c) as q FROM MyTable) t")
+  }
+
+  @Test
+  def testReusedFunctionInProjectWithDeterministicUdf(): Unit = {
+    util.verifyExecPlan(
+      "SELECT JSON_VALUE(json_data, '$.id'), JSON_VALUE(json_data, '$.name') FROM (SELECT deterministic_udf(c) as json_data FROM MyTable) t")
+  }
+
   @Test
   def testProjectWithNaming(): Unit = {
     util.verifyExecPlan("SELECT `1-_./Ü`, b, c FROM (SELECT a as `1-_./Ü`, b, c FROM MyTable)")

0 commit comments