|
5 | 5 | import java.util.List; |
6 | 6 |
|
7 | 7 | import sqlancer.Randomly; |
8 | | -import sqlancer.SQLConnection; |
9 | | -import sqlancer.StateToReproduce; |
| 8 | +import sqlancer.common.gen.NoRECGenerator; |
| 9 | +import sqlancer.common.schema.AbstractTables; |
10 | 10 | import sqlancer.mariadb.MariaDBProvider; |
11 | 11 | import sqlancer.mariadb.MariaDBSchema.MariaDBColumn; |
12 | 12 | import sqlancer.mariadb.MariaDBSchema.MariaDBDataType; |
| 13 | +import sqlancer.mariadb.MariaDBSchema.MariaDBTable; |
| 14 | +import sqlancer.mariadb.ast.MariaDBAggregate; |
| 15 | +import sqlancer.mariadb.ast.MariaDBAggregate.MariaDBAggregateFunction; |
13 | 16 | import sqlancer.mariadb.ast.MariaDBBinaryOperator; |
14 | 17 | import sqlancer.mariadb.ast.MariaDBBinaryOperator.MariaDBBinaryComparisonOperator; |
15 | 18 | import sqlancer.mariadb.ast.MariaDBColumnName; |
|
18 | 21 | import sqlancer.mariadb.ast.MariaDBFunction; |
19 | 22 | import sqlancer.mariadb.ast.MariaDBFunctionName; |
20 | 23 | import sqlancer.mariadb.ast.MariaDBInOperation; |
| 24 | +import sqlancer.mariadb.ast.MariaDBJoin; |
21 | 25 | import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation; |
22 | 26 | import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation.MariaDBPostfixUnaryOperator; |
| 27 | +import sqlancer.mariadb.ast.MariaDBSelectStatement; |
| 28 | +import sqlancer.mariadb.ast.MariaDBSelectStatement.MariaDBSelectType; |
| 29 | +import sqlancer.mariadb.ast.MariaDBTableReference; |
| 30 | +import sqlancer.mariadb.ast.MariaDBText; |
23 | 31 | import sqlancer.mariadb.ast.MariaDBUnaryPrefixOperation; |
24 | 32 | import sqlancer.mariadb.ast.MariaDBUnaryPrefixOperation.MariaDBUnaryPrefixOperator; |
25 | 33 |
|
26 | | -public class MariaDBExpressionGenerator { |
| 34 | +public class MariaDBExpressionGenerator |
| 35 | + implements NoRECGenerator<MariaDBSelectStatement, MariaDBJoin, MariaDBExpression, MariaDBTable, MariaDBColumn> { |
27 | 36 |
|
28 | 37 | private final Randomly r; |
| 38 | + private List<MariaDBTable> targetTables = new ArrayList<>(); |
29 | 39 | private List<MariaDBColumn> columns = new ArrayList<>(); |
30 | 40 |
|
31 | 41 | public MariaDBExpressionGenerator(Randomly r) { |
@@ -66,14 +76,6 @@ public MariaDBExpressionGenerator setColumns(List<MariaDBColumn> columns) { |
66 | 76 | return this; |
67 | 77 | } |
68 | 78 |
|
69 | | - public MariaDBExpressionGenerator setCon(SQLConnection con) { |
70 | | - return this; |
71 | | - } |
72 | | - |
73 | | - public MariaDBExpressionGenerator setState(StateToReproduce state) { |
74 | | - return this; |
75 | | - } |
76 | | - |
77 | 79 | private enum ExpressionType { |
78 | 80 | LITERAL, COLUMN, BINARY_COMPARISON, UNARY_POSTFIX_OPERATOR, UNARY_PREFIX_OPERATOR, FUNCTION, IN |
79 | 81 | } |
@@ -146,4 +148,64 @@ public MariaDBExpression getRandomExpression() { |
146 | 148 | return getRandomExpression(0); |
147 | 149 | } |
148 | 150 |
|
| 151 | + @Override |
| 152 | + public MariaDBExpressionGenerator setTablesAndColumns(AbstractTables<MariaDBTable, MariaDBColumn> targetTables) { |
| 153 | + this.targetTables = targetTables.getTables(); |
| 154 | + this.columns = targetTables.getColumns(); |
| 155 | + return this; |
| 156 | + } |
| 157 | + |
| 158 | + @Override |
| 159 | + public List<MariaDBExpression> getTableRefs() { |
| 160 | + List<MariaDBExpression> tableRefs = new ArrayList<>(); |
| 161 | + for (MariaDBTable t : targetTables) { |
| 162 | + MariaDBTableReference tableRef = new MariaDBTableReference(t); |
| 163 | + tableRefs.add(tableRef); |
| 164 | + } |
| 165 | + return tableRefs; |
| 166 | + } |
| 167 | + |
| 168 | + @Override |
| 169 | + public MariaDBExpression generateBooleanExpression() { |
| 170 | + return getRandomExpression(); |
| 171 | + } |
| 172 | + |
| 173 | + @Override |
| 174 | + public MariaDBSelectStatement generateSelect() { |
| 175 | + return new MariaDBSelectStatement(); |
| 176 | + } |
| 177 | + |
| 178 | + @Override |
| 179 | + public List<MariaDBJoin> getRandomJoinClauses() { |
| 180 | + return MariaDBJoin.getRandomJoinClauses(targetTables, r); |
| 181 | + } |
| 182 | + |
| 183 | + @Override |
| 184 | + public String generateOptimizedQueryString(MariaDBSelectStatement select, MariaDBExpression whereCondition, |
| 185 | + boolean shouldUseAggregate) { |
| 186 | + if (shouldUseAggregate) { |
| 187 | + MariaDBAggregate aggr = new MariaDBAggregate( |
| 188 | + new MariaDBColumnName(new MariaDBColumn("*", MariaDBDataType.INT, false, 0)), |
| 189 | + MariaDBAggregateFunction.COUNT); |
| 190 | + select.setFetchColumns(Arrays.asList(aggr)); |
| 191 | + } else { |
| 192 | + MariaDBColumnName aggr = new MariaDBColumnName(MariaDBColumn.createDummy("*")); |
| 193 | + select.setFetchColumns(Arrays.asList(aggr)); |
| 194 | + } |
| 195 | + |
| 196 | + select.setWhereClause(whereCondition); |
| 197 | + select.setSelectType(MariaDBSelectType.ALL); |
| 198 | + return select.asString(); |
| 199 | + } |
| 200 | + |
| 201 | + @Override |
| 202 | + public String generateUnoptimizedQueryString(MariaDBSelectStatement select, MariaDBExpression whereCondition) { |
| 203 | + MariaDBPostfixUnaryOperation isTrue = new MariaDBPostfixUnaryOperation(MariaDBPostfixUnaryOperator.IS_TRUE, |
| 204 | + whereCondition); |
| 205 | + MariaDBText asText = new MariaDBText(isTrue, " as count", false); |
| 206 | + select.setFetchColumns(Arrays.asList(asText)); |
| 207 | + select.setSelectType(MariaDBSelectType.ALL); |
| 208 | + |
| 209 | + return "SELECT SUM(count) FROM (" + select.asString() + ") as asdf"; |
| 210 | + } |
149 | 211 | } |
0 commit comments