Skip to content

Commit 3a3078d

Browse files
authored
Merge pull request #976 from malwaregarry/norec-datafusion
[DataFusion] Use common NoREC oracle
2 parents 8ddcb3c + 97095e8 commit 3a3078d

8 files changed

Lines changed: 150 additions & 91 deletions

File tree

src/sqlancer/datafusion/DataFusionErrors.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import static sqlancer.datafusion.DataFusionUtil.dfAssert;
44

5+
import java.util.ArrayList;
6+
import java.util.List;
7+
58
import sqlancer.common.query.ExpectedErrors;
69

710
public final class DataFusionErrors {
@@ -17,7 +20,8 @@ private DataFusionErrors() {
1720
* Note now it's implemented this way for simplicity This way might cause false negative, because Q1 and Q2 should
1821
* both succeed or both fail TODO(datafusion): ensure both succeed or both fail
1922
*/
20-
public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
23+
public static List<String> getExpectedExecutionErrors() {
24+
ArrayList<String> errors = new ArrayList<>();
2125
/*
2226
* Expected
2327
*/
@@ -40,5 +44,11 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
4044
errors.add("Physical plan does not support logical expression AggregateFunction"); // False positive: when aggr
4145
// is generated in where
4246
// clause
47+
48+
return errors;
49+
}
50+
51+
public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
52+
errors.addAll(getExpectedExecutionErrors());
4353
}
4454
}
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package sqlancer.datafusion.ast;
22

3-
public interface DataFusionExpression {
3+
import sqlancer.common.ast.newast.Expression;
4+
import sqlancer.datafusion.DataFusionSchema.DataFusionColumn;
5+
6+
public interface DataFusionExpression extends Expression<DataFusionColumn> {
47

58
}

src/sqlancer/datafusion/ast/DataFusionJoin.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import java.util.ArrayList;
44
import java.util.List;
5-
import java.util.stream.Collectors;
65

76
import sqlancer.Randomly;
7+
import sqlancer.common.ast.newast.Join;
88
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
99
import sqlancer.datafusion.DataFusionSchema;
1010
import sqlancer.datafusion.DataFusionSchema.DataFusionColumn;
@@ -14,12 +14,13 @@
1414
/*
1515
NOT IMPLEMENTED YET
1616
*/
17-
public class DataFusionJoin implements DataFusionExpression {
17+
public class DataFusionJoin
18+
implements DataFusionExpression, Join<DataFusionExpression, DataFusionTable, DataFusionColumn> {
1819

1920
private final DataFusionTableReference leftTable;
2021
private final DataFusionTableReference rightTable;
2122
private final JoinType joinType;
22-
private final DataFusionExpression onCondition;
23+
private DataFusionExpression onCondition;
2324

2425
public DataFusionJoin(DataFusionTableReference leftTable, DataFusionTableReference rightTable, JoinType joinType,
2526
DataFusionExpression whereCondition) {
@@ -29,11 +30,10 @@ public DataFusionJoin(DataFusionTableReference leftTable, DataFusionTableReferen
2930
this.onCondition = whereCondition;
3031
}
3132

32-
public static List<DataFusionExpression> getJoins(List<DataFusionTable> tables, DataFusionGlobalState globalState) {
33+
public static List<DataFusionJoin> getJoins(List<DataFusionTableReference> tableList,
34+
DataFusionGlobalState globalState) {
3335
// [t1_join_t2, t1_join_t3, ...]
34-
List<DataFusionTableReference> tableList = tables.stream().map(t -> new DataFusionTableReference(t))
35-
.collect(Collectors.toList());
36-
List<DataFusionExpression> joinExpressions = new ArrayList<>();
36+
List<DataFusionJoin> joinExpressions = new ArrayList<>();
3737
while (tableList.size() >= 2 && Randomly.getBooleanWithRatherLowProbability()) {
3838
DataFusionTableReference leftTable = tableList.remove(0);
3939
DataFusionTableReference rightTable = tableList.remove(0);
@@ -84,4 +84,8 @@ public static JoinType getRandom() {
8484
}
8585
}
8686

87+
@Override
88+
public void setOnClause(DataFusionExpression onClause) {
89+
onCondition = onClause;
90+
}
8791
}

src/sqlancer/datafusion/ast/DataFusionSelect.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66

77
import sqlancer.Randomly;
88
import sqlancer.common.ast.SelectBase;
9+
import sqlancer.common.ast.newast.Select;
910
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
1011
import sqlancer.datafusion.DataFusionSchema;
12+
import sqlancer.datafusion.DataFusionSchema.DataFusionColumn;
1113
import sqlancer.datafusion.DataFusionSchema.DataFusionTable;
14+
import sqlancer.datafusion.DataFusionToStringVisitor;
1215
import sqlancer.datafusion.gen.DataFusionExpressionGenerator;
1316

14-
public class DataFusionSelect extends SelectBase<DataFusionExpression> implements DataFusionExpression {
17+
public class DataFusionSelect extends SelectBase<DataFusionExpression> implements DataFusionExpression,
18+
Select<DataFusionJoin, DataFusionExpression, DataFusionTable, DataFusionColumn> {
1519
public Optional<String> fetchColumnsString = Optional.empty(); // When available, override `fetchColumns` in base
1620
// class's `Node` representation (for display)
1721
public DataFusionExpressionGenerator exprGen;
@@ -58,4 +62,21 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) {
5862
public void setFetchColumnsString(String selectExpr) {
5963
this.fetchColumnsString = Optional.of(selectExpr);
6064
}
65+
66+
@Override
67+
public void setJoinClauses(List<DataFusionJoin> joinStatements) {
68+
List<DataFusionExpression> expressions = joinStatements.stream().map(e -> (DataFusionExpression) e)
69+
.collect(Collectors.toList());
70+
setJoinList(expressions);
71+
}
72+
73+
@Override
74+
public List<DataFusionJoin> getJoinClauses() {
75+
return getJoinList().stream().map(e -> (DataFusionJoin) e).collect(Collectors.toList());
76+
}
77+
78+
@Override
79+
public String asString() {
80+
return DataFusionToStringVisitor.asString(this);
81+
}
6182
}

src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,31 @@
1313
import sqlancer.Randomly;
1414
import sqlancer.common.ast.BinaryOperatorNode.Operator;
1515
import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode;
16+
import sqlancer.common.gen.NoRECGenerator;
1617
import sqlancer.common.gen.TypedExpressionGenerator;
18+
import sqlancer.common.schema.AbstractTables;
1719
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
1820
import sqlancer.datafusion.DataFusionSchema.DataFusionColumn;
1921
import sqlancer.datafusion.DataFusionSchema.DataFusionDataType;
22+
import sqlancer.datafusion.DataFusionSchema.DataFusionTable;
23+
import sqlancer.datafusion.DataFusionToStringVisitor;
2024
import sqlancer.datafusion.ast.DataFusionBinaryOperation;
2125
import sqlancer.datafusion.ast.DataFusionColumnReference;
2226
import sqlancer.datafusion.ast.DataFusionExpression;
2327
import sqlancer.datafusion.ast.DataFusionFunction;
28+
import sqlancer.datafusion.ast.DataFusionJoin;
29+
import sqlancer.datafusion.ast.DataFusionSelect;
30+
import sqlancer.datafusion.ast.DataFusionTableReference;
2431
import sqlancer.datafusion.ast.DataFusionUnaryPostfixOperation;
2532
import sqlancer.datafusion.ast.DataFusionUnaryPrefixOperation;
2633
import sqlancer.datafusion.gen.DataFusionBaseExpr.ArgumentType;
2734
import sqlancer.datafusion.gen.DataFusionBaseExpr.DataFusionBaseExprType;
2835

2936
public final class DataFusionExpressionGenerator
30-
extends TypedExpressionGenerator<DataFusionExpression, DataFusionColumn, DataFusionDataType> {
37+
extends TypedExpressionGenerator<DataFusionExpression, DataFusionColumn, DataFusionDataType> implements
38+
NoRECGenerator<DataFusionSelect, DataFusionJoin, DataFusionExpression, DataFusionTable, DataFusionColumn> {
3139

40+
private List<DataFusionTable> tables;
3241
private final DataFusionGlobalState globalState;
3342

3443
public DataFusionExpressionGenerator(DataFusionGlobalState globalState) {
@@ -100,7 +109,8 @@ protected DataFusionExpression generateExpression(DataFusionDataType type, int d
100109
case BINARY:
101110
dfAssert(randomExpr.argTypes.size() == 2 && randomExpr.nArgs == 2,
102111
"Binrary expression should only have 2 argument" + randomExpr.argTypes);
103-
List<DataFusionDataType> argTypeList = new ArrayList<>(); // types of current expression's input arguments
112+
List<DataFusionDataType> argTypeList = new ArrayList<>(); // types of current expression's input
113+
// arguments
104114
for (ArgumentType argumentType : randomExpr.argTypes) {
105115
if (argumentType instanceof ArgumentType.Fixed) {
106116
ArgumentType.Fixed possibleArgTypes = (ArgumentType.Fixed) randomExpr.argTypes.get(0);
@@ -134,7 +144,8 @@ protected DataFusionExpression generateExpression(DataFusionDataType type, int d
134144
public DataFusionExpression generateFunctionExpression(DataFusionDataType type, int depth,
135145
DataFusionBaseExpr exprType) {
136146
if (exprType.isVariadic || Randomly.getBooleanWithSmallProbability()) {
137-
// TODO(datafusion) maybe add possible types. e.g. some function have signature variadic(INT/DOUBLE), then
147+
// TODO(datafusion) maybe add possible types. e.g. some function have signature
148+
// variadic(INT/DOUBLE), then
138149
// only randomly pick from INT and DOUBLE
139150
int nArgs = Randomly.smallNumber(); // 0, 2, 4, ... smaller one is more likely
140151
return new DataFusionFunction<DataFusionBaseExpr>(generateExpressions(nArgs), exprType);
@@ -222,4 +233,66 @@ public String getTextRepresentation() {
222233

223234
}
224235

236+
@Override
237+
public NoRECGenerator<DataFusionSelect, DataFusionJoin, DataFusionExpression, DataFusionTable, DataFusionColumn> setTablesAndColumns(
238+
AbstractTables<DataFusionTable, DataFusionColumn> tables) {
239+
List<DataFusionTable> randomTables = Randomly.nonEmptySubset(tables.getTables());
240+
int maxSize = Randomly.fromOptions(1, 2, 3, 4);
241+
if (randomTables.size() > maxSize) {
242+
randomTables = randomTables.subList(0, maxSize);
243+
}
244+
this.columns = DataFusionTable.getAllColumns(randomTables);
245+
this.tables = randomTables;
246+
247+
return this;
248+
}
249+
250+
@Override
251+
public DataFusionExpression generateBooleanExpression() {
252+
return generateExpression(DataFusionDataType.BOOLEAN);
253+
}
254+
255+
@Override
256+
public DataFusionSelect generateSelect() {
257+
return new DataFusionSelect();
258+
}
259+
260+
@Override
261+
public List<DataFusionJoin> getRandomJoinClauses() {
262+
List<DataFusionTableReference> tableList = tables.stream().map(t -> new DataFusionTableReference(t))
263+
.collect(Collectors.toList());
264+
List<DataFusionJoin> joins = DataFusionJoin.getJoins(tableList, globalState);
265+
tables = tableList.stream().map(t -> t.getTable()).collect(Collectors.toList());
266+
return joins;
267+
}
268+
269+
@Override
270+
public List<DataFusionExpression> getTableRefs() {
271+
return tables.stream().map(t -> new DataFusionTableReference(t)).collect(Collectors.toList());
272+
}
273+
274+
@Override
275+
public String generateOptimizedQueryString(DataFusionSelect select, DataFusionExpression whereCondition,
276+
boolean shouldUseAggregate) {
277+
if (shouldUseAggregate) {
278+
select.setFetchColumnsString("COUNT(*)");
279+
} else {
280+
List<DataFusionExpression> allColumns = columns.stream().map((c) -> new DataFusionColumnReference(c))
281+
.collect(Collectors.toList());
282+
select.setFetchColumns(allColumns);
283+
}
284+
select.setWhereClause(whereCondition);
285+
286+
return select.asString();
287+
}
288+
289+
@Override
290+
public String generateUnoptimizedQueryString(DataFusionSelect select, DataFusionExpression whereCondition) {
291+
String fetchColumn = String.format("COUNT(CASE WHEN %S THEN 1 ELSE NULL END)",
292+
DataFusionToStringVisitor.asString(whereCondition));
293+
select.setFetchColumnsString(fetchColumn);
294+
select.setWhereClause(null);
295+
296+
return select.asString();
297+
}
225298
}
Lines changed: 24 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,44 @@
11
package sqlancer.datafusion.test;
22

3-
import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.ERROR;
4-
import static sqlancer.datafusion.ast.DataFusionSelect.getRandomSelect;
5-
63
import java.sql.SQLException;
7-
import java.util.List;
84

9-
import sqlancer.ComparatorHelper;
10-
import sqlancer.common.oracle.NoRECBase;
5+
import sqlancer.Reproducer;
6+
import sqlancer.common.oracle.NoRECOracle;
117
import sqlancer.common.oracle.TestOracle;
8+
import sqlancer.common.query.ExpectedErrors;
129
import sqlancer.datafusion.DataFusionErrors;
1310
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
14-
import sqlancer.datafusion.DataFusionToStringVisitor;
15-
import sqlancer.datafusion.DataFusionUtil;
11+
import sqlancer.datafusion.DataFusionSchema;
12+
import sqlancer.datafusion.DataFusionSchema.DataFusionColumn;
13+
import sqlancer.datafusion.DataFusionSchema.DataFusionTable;
14+
import sqlancer.datafusion.ast.DataFusionExpression;
15+
import sqlancer.datafusion.ast.DataFusionJoin;
1616
import sqlancer.datafusion.ast.DataFusionSelect;
17+
import sqlancer.datafusion.gen.DataFusionExpressionGenerator;
1718

18-
public class DataFusionNoRECOracle extends NoRECBase<DataFusionGlobalState>
19-
implements TestOracle<DataFusionGlobalState> {
19+
public class DataFusionNoRECOracle implements TestOracle<DataFusionGlobalState> {
2020

21-
private final DataFusionGlobalState state;
21+
NoRECOracle<DataFusionSelect, DataFusionJoin, DataFusionExpression, DataFusionSchema, DataFusionTable, DataFusionColumn, DataFusionGlobalState> oracle;
2222

2323
public DataFusionNoRECOracle(DataFusionGlobalState globalState) {
24-
super(globalState);
25-
this.state = globalState;
26-
DataFusionErrors.registerExpectedExecutionErrors(errors);
24+
DataFusionExpressionGenerator gen = new DataFusionExpressionGenerator(globalState);
25+
ExpectedErrors errors = ExpectedErrors.newErrors().with(DataFusionErrors.getExpectedExecutionErrors())
26+
.with("canceling statement due to statement timeout").build();
27+
this.oracle = new NoRECOracle<>(globalState, gen, errors);
2728
}
2829

29-
/*
30-
* Non-Optimizing Reference Engine Construction q1: SELECT [expr1] FROM [expr2] WHERE [expr3] q2: SELECT [expr3]
31-
* FROM [expr2]
32-
*
33-
* Oracle Check: q1's result size equals to `true` count in q2's result set
34-
*/
3530
@Override
3631
public void check() throws SQLException {
37-
/*
38-
* Setup Q1 and Q2
39-
*/
40-
// generate a random:
41-
// SELECT [expr1] FROM [expr2] WHERE [expr3]
42-
DataFusionSelect randomSelect = getRandomSelect(state);
43-
// Q1: SELECT count(*) FROM [expr2] WHERE [expr3]
44-
DataFusionSelect q1 = new DataFusionSelect();
45-
q1.setFetchColumnsString("COUNT(*)");
46-
q1.setFromList(randomSelect.getFromList());
47-
q1.setWhereClause(randomSelect.getWhereClause());
48-
// Q2: SELECT count(case when [expr3] then 1 else null end) FROM [expr2]
49-
DataFusionSelect q2 = new DataFusionSelect();
50-
String selectExpr = String.format("COUNT(CASE WHEN %S THEN 1 ELSE NULL END)",
51-
DataFusionToStringVisitor.asString(randomSelect.getWhereClause()));
52-
q2.setFetchColumnsString(selectExpr);
53-
q2.setFromList(randomSelect.getFromList());
54-
q2.setWhereClause(null);
55-
56-
/*
57-
* Execute Q1 and Q2
58-
*/
59-
String q1String = DataFusionToStringVisitor.asString(q1);
60-
String q2String = DataFusionToStringVisitor.asString(q2);
61-
List<String> q1ResultSet = null;
62-
List<String> q2ResultSet = null;
63-
try {
64-
q1ResultSet = ComparatorHelper.getResultSetFirstColumnAsString(q1String, errors, state);
65-
q2ResultSet = ComparatorHelper.getResultSetFirstColumnAsString(q2String, errors, state);
66-
} catch (AssertionError e) {
67-
// Append detailed error message
68-
String replay = DataFusionUtil.getReplay(state.getDatabaseName());
69-
String newMessage = e.getMessage() + "\n" + e.getCause() + "\n" + replay + "\n";
70-
state.dfLogger.appendToLog(ERROR, newMessage);
71-
72-
throw new AssertionError(newMessage);
73-
}
74-
75-
/*
76-
* NoREC check
77-
*/
78-
int count1 = q1ResultSet != null ? Integer.parseInt(q1ResultSet.get(0)) : -1;
79-
int count2 = q2ResultSet != null ? Integer.parseInt(q2ResultSet.get(0)) : -1;
80-
if (count1 != count2) {
81-
StringBuilder errorMessage = new StringBuilder().append("NoREC oracle violated:\n")
82-
.append(" Q1(result size ").append(count1).append("):").append(q1String).append(";\n")
83-
.append(" Q2(result size ").append(count2).append("):").append(q2String).append(";\n")
84-
.append("=======================================\n").append("Reproducer: \n");
85-
86-
String replay = DataFusionUtil.getReplay(state.getDatabaseName());
32+
oracle.check();
33+
}
8734

88-
String errorLog = errorMessage.toString() + replay + "\n";
89-
String indentedErrorLog = errorLog.replaceAll("(?m)^", " ");
90-
state.dfLogger.appendToLog(ERROR, errorLog);
35+
@Override
36+
public Reproducer<DataFusionGlobalState> getLastReproducer() {
37+
return oracle.getLastReproducer();
38+
}
9139

92-
throw new AssertionError("\n\n" + indentedErrorLog);
93-
}
40+
@Override
41+
public String getLastQueryString() {
42+
return oracle.getLastQueryString();
9443
}
9544
}

test/sqlancer/dbms/TestConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ public class TestConfig {
66

77
public static final String COCKROACHDB_ENV = "COCKROACHDB_AVAILABLE";
88
public static final String DATABEND_ENV = "DATABEND_AVAILABLE";
9+
public static final String DATAFUSION_ENV = "DATAFUSION_AVAILABLE";
910
public static final String DORIS_ENV = "DORIS_AVAILABLE";
1011
public static final String MARIADB_ENV = "MARIADB_AVAILABLE";
1112
public static final String POSTGRES_ENV = "POSTGRES_AVAILABLE";

test/sqlancer/dbms/TestDataFusion.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
public class TestDataFusion {
1111
@Test
1212
public void testDataFusion() {
13-
String datafusionAvailable = System.getenv("DATAFUSION_AVAILABLE");
14-
boolean datafusionIsAvailable = datafusionAvailable != null && datafusionAvailable.equalsIgnoreCase("true");
15-
assumeTrue(datafusionIsAvailable);
13+
assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.DATAFUSION_ENV));
1614

1715
assertEquals(0, Main.executeMain("--random-seed", "0", "--num-threads", "1", // TODO(datafusion) update when
1816
// multithread is supported

0 commit comments

Comments
 (0)