|
13 | 13 | import sqlancer.Randomly; |
14 | 14 | import sqlancer.common.ast.BinaryOperatorNode.Operator; |
15 | 15 | import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; |
| 16 | +import sqlancer.common.gen.NoRECGenerator; |
16 | 17 | import sqlancer.common.gen.TypedExpressionGenerator; |
| 18 | +import sqlancer.common.schema.AbstractTables; |
17 | 19 | import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; |
18 | 20 | import sqlancer.datafusion.DataFusionSchema.DataFusionColumn; |
19 | 21 | import sqlancer.datafusion.DataFusionSchema.DataFusionDataType; |
| 22 | +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; |
| 23 | +import sqlancer.datafusion.DataFusionToStringVisitor; |
20 | 24 | import sqlancer.datafusion.ast.DataFusionBinaryOperation; |
21 | 25 | import sqlancer.datafusion.ast.DataFusionColumnReference; |
22 | 26 | import sqlancer.datafusion.ast.DataFusionExpression; |
23 | 27 | import sqlancer.datafusion.ast.DataFusionFunction; |
| 28 | +import sqlancer.datafusion.ast.DataFusionJoin; |
| 29 | +import sqlancer.datafusion.ast.DataFusionSelect; |
| 30 | +import sqlancer.datafusion.ast.DataFusionTableReference; |
24 | 31 | import sqlancer.datafusion.ast.DataFusionUnaryPostfixOperation; |
25 | 32 | import sqlancer.datafusion.ast.DataFusionUnaryPrefixOperation; |
26 | 33 | import sqlancer.datafusion.gen.DataFusionBaseExpr.ArgumentType; |
27 | 34 | import sqlancer.datafusion.gen.DataFusionBaseExpr.DataFusionBaseExprType; |
28 | 35 |
|
29 | 36 | public final class DataFusionExpressionGenerator |
30 | | - extends TypedExpressionGenerator<DataFusionExpression, DataFusionColumn, DataFusionDataType> { |
| 37 | + extends TypedExpressionGenerator<DataFusionExpression, DataFusionColumn, DataFusionDataType> implements |
| 38 | + NoRECGenerator<DataFusionSelect, DataFusionJoin, DataFusionExpression, DataFusionTable, DataFusionColumn> { |
31 | 39 |
|
| 40 | + private List<DataFusionTable> tables; |
32 | 41 | private final DataFusionGlobalState globalState; |
33 | 42 |
|
34 | 43 | public DataFusionExpressionGenerator(DataFusionGlobalState globalState) { |
@@ -100,7 +109,8 @@ protected DataFusionExpression generateExpression(DataFusionDataType type, int d |
100 | 109 | case BINARY: |
101 | 110 | dfAssert(randomExpr.argTypes.size() == 2 && randomExpr.nArgs == 2, |
102 | 111 | "Binrary expression should only have 2 argument" + randomExpr.argTypes); |
103 | | - List<DataFusionDataType> argTypeList = new ArrayList<>(); // types of current expression's input arguments |
| 112 | + List<DataFusionDataType> argTypeList = new ArrayList<>(); // types of current expression's input |
| 113 | + // arguments |
104 | 114 | for (ArgumentType argumentType : randomExpr.argTypes) { |
105 | 115 | if (argumentType instanceof ArgumentType.Fixed) { |
106 | 116 | ArgumentType.Fixed possibleArgTypes = (ArgumentType.Fixed) randomExpr.argTypes.get(0); |
@@ -134,7 +144,8 @@ protected DataFusionExpression generateExpression(DataFusionDataType type, int d |
134 | 144 | public DataFusionExpression generateFunctionExpression(DataFusionDataType type, int depth, |
135 | 145 | DataFusionBaseExpr exprType) { |
136 | 146 | if (exprType.isVariadic || Randomly.getBooleanWithSmallProbability()) { |
137 | | - // TODO(datafusion) maybe add possible types. e.g. some function have signature variadic(INT/DOUBLE), then |
| 147 | + // TODO(datafusion) maybe add possible types. e.g. some function have signature |
| 148 | + // variadic(INT/DOUBLE), then |
138 | 149 | // only randomly pick from INT and DOUBLE |
139 | 150 | int nArgs = Randomly.smallNumber(); // 0, 2, 4, ... smaller one is more likely |
140 | 151 | return new DataFusionFunction<DataFusionBaseExpr>(generateExpressions(nArgs), exprType); |
@@ -222,4 +233,66 @@ public String getTextRepresentation() { |
222 | 233 |
|
223 | 234 | } |
224 | 235 |
|
| 236 | + @Override |
| 237 | + public NoRECGenerator<DataFusionSelect, DataFusionJoin, DataFusionExpression, DataFusionTable, DataFusionColumn> setTablesAndColumns( |
| 238 | + AbstractTables<DataFusionTable, DataFusionColumn> tables) { |
| 239 | + List<DataFusionTable> randomTables = Randomly.nonEmptySubset(tables.getTables()); |
| 240 | + int maxSize = Randomly.fromOptions(1, 2, 3, 4); |
| 241 | + if (randomTables.size() > maxSize) { |
| 242 | + randomTables = randomTables.subList(0, maxSize); |
| 243 | + } |
| 244 | + this.columns = DataFusionTable.getAllColumns(randomTables); |
| 245 | + this.tables = randomTables; |
| 246 | + |
| 247 | + return this; |
| 248 | + } |
| 249 | + |
| 250 | + @Override |
| 251 | + public DataFusionExpression generateBooleanExpression() { |
| 252 | + return generateExpression(DataFusionDataType.BOOLEAN); |
| 253 | + } |
| 254 | + |
| 255 | + @Override |
| 256 | + public DataFusionSelect generateSelect() { |
| 257 | + return new DataFusionSelect(); |
| 258 | + } |
| 259 | + |
| 260 | + @Override |
| 261 | + public List<DataFusionJoin> getRandomJoinClauses() { |
| 262 | + List<DataFusionTableReference> tableList = tables.stream().map(t -> new DataFusionTableReference(t)) |
| 263 | + .collect(Collectors.toList()); |
| 264 | + List<DataFusionJoin> joins = DataFusionJoin.getJoins(tableList, globalState); |
| 265 | + tables = tableList.stream().map(t -> t.getTable()).collect(Collectors.toList()); |
| 266 | + return joins; |
| 267 | + } |
| 268 | + |
| 269 | + @Override |
| 270 | + public List<DataFusionExpression> getTableRefs() { |
| 271 | + return tables.stream().map(t -> new DataFusionTableReference(t)).collect(Collectors.toList()); |
| 272 | + } |
| 273 | + |
| 274 | + @Override |
| 275 | + public String generateOptimizedQueryString(DataFusionSelect select, DataFusionExpression whereCondition, |
| 276 | + boolean shouldUseAggregate) { |
| 277 | + if (shouldUseAggregate) { |
| 278 | + select.setFetchColumnsString("COUNT(*)"); |
| 279 | + } else { |
| 280 | + List<DataFusionExpression> allColumns = columns.stream().map((c) -> new DataFusionColumnReference(c)) |
| 281 | + .collect(Collectors.toList()); |
| 282 | + select.setFetchColumns(allColumns); |
| 283 | + } |
| 284 | + select.setWhereClause(whereCondition); |
| 285 | + |
| 286 | + return select.asString(); |
| 287 | + } |
| 288 | + |
| 289 | + @Override |
| 290 | + public String generateUnoptimizedQueryString(DataFusionSelect select, DataFusionExpression whereCondition) { |
| 291 | + String fetchColumn = String.format("COUNT(CASE WHEN %S THEN 1 ELSE NULL END)", |
| 292 | + DataFusionToStringVisitor.asString(whereCondition)); |
| 293 | + select.setFetchColumnsString(fetchColumn); |
| 294 | + select.setWhereClause(null); |
| 295 | + |
| 296 | + return select.asString(); |
| 297 | + } |
225 | 298 | } |
0 commit comments