Skip to content

Commit 9504896

Browse files
authored
Merge pull request sqlancer#961 from malwaregarry/norec-mariadb
[MariaDB] Use generic NoREC oracle
2 parents 9c74a95 + f216c41 commit 9504896

10 files changed

Lines changed: 136 additions & 127 deletions

src/sqlancer/mariadb/MariaDBSchema.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ public boolean isPrimaryKey() {
5555
return isPrimaryKey;
5656
}
5757

58+
public static MariaDBColumn createDummy(String name) {
59+
return new MariaDBColumn(name, MariaDBDataType.INT, false, 1);
60+
}
61+
5862
}
5963

6064
public static class MariaDBTables {
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package sqlancer.mariadb.ast;
22

3-
public interface MariaDBExpression {
3+
import sqlancer.common.ast.newast.Expression;
4+
import sqlancer.mariadb.MariaDBSchema.MariaDBColumn;
5+
6+
public interface MariaDBExpression extends Expression<MariaDBColumn> {
47

58
}

src/sqlancer/mariadb/ast/MariaDBJoin.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import java.util.List;
66

77
import sqlancer.Randomly;
8-
import sqlancer.mariadb.MariaDBProvider.MariaDBGlobalState;
8+
import sqlancer.common.ast.newast.Join;
99
import sqlancer.mariadb.MariaDBSchema.MariaDBColumn;
1010
import sqlancer.mariadb.MariaDBSchema.MariaDBTable;
1111
import sqlancer.mariadb.gen.MariaDBExpressionGenerator;
1212

13-
public class MariaDBJoin implements MariaDBExpression {
13+
public class MariaDBJoin implements MariaDBExpression, Join<MariaDBExpression, MariaDBTable, MariaDBColumn> {
1414

1515
public enum JoinType {
1616
NATURAL, INNER, STRAIGHT, LEFT, RIGHT, CROSS;
@@ -36,6 +36,7 @@ public MariaDBTable getTable() {
3636
return table;
3737
}
3838

39+
@Override
3940
public MariaDBExpression getOnClause() {
4041
return onClause;
4142
}
@@ -44,6 +45,7 @@ public JoinType getType() {
4445
return type;
4546
}
4647

48+
@Override
4749
public void setOnClause(MariaDBExpression onClause) {
4850
this.onClause = onClause;
4951
}
@@ -52,7 +54,7 @@ public void setType(JoinType type) {
5254
this.type = type;
5355
}
5456

55-
public static List<MariaDBJoin> getRandomJoinClauses(List<MariaDBTable> tables, MariaDBGlobalState globalState) {
57+
public static List<MariaDBJoin> getRandomJoinClauses(List<MariaDBTable> tables, Randomly r) {
5658
List<MariaDBJoin> joinStatements = new ArrayList<>();
5759
List<JoinType> options = new ArrayList<>(Arrays.asList(JoinType.values()));
5860
List<MariaDBColumn> columns = new ArrayList<>();
@@ -68,8 +70,7 @@ public static List<MariaDBJoin> getRandomJoinClauses(List<MariaDBTable> tables,
6870
MariaDBTable table = Randomly.fromList(tables);
6971
tables.remove(table);
7072
columns.addAll(table.getColumns());
71-
MariaDBExpressionGenerator joinGen = new MariaDBExpressionGenerator(globalState.getRandomly())
72-
.setColumns(columns);
73+
MariaDBExpressionGenerator joinGen = new MariaDBExpressionGenerator(r).setColumns(columns);
7374
MariaDBExpression joinClause = joinGen.getRandomExpression();
7475
JoinType selectedOption = Randomly.fromList(options);
7576
if (selectedOption == JoinType.NATURAL) {

src/sqlancer/mariadb/ast/MariaDBSelectStatement.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@
44
import java.util.List;
55

66
import sqlancer.common.ast.SelectBase;
7+
import sqlancer.common.ast.newast.Select;
8+
import sqlancer.mariadb.MariaDBSchema.MariaDBColumn;
9+
import sqlancer.mariadb.MariaDBSchema.MariaDBTable;
710

8-
public class MariaDBSelectStatement extends SelectBase<MariaDBExpression> implements MariaDBExpression {
11+
public class MariaDBSelectStatement extends SelectBase<MariaDBExpression>
12+
implements MariaDBExpression, Select<MariaDBJoin, MariaDBExpression, MariaDBTable, MariaDBColumn> {
913

1014
public enum MariaDBSelectType {
1115
ALL, DISTINCT, DISTINCTROW;
1216
}
1317

1418
private List<MariaDBExpression> groupBys = new ArrayList<>();
1519
private List<MariaDBExpression> columns = new ArrayList<>();
20+
private List<MariaDBJoin> joinClauses = new ArrayList<>();
1621
private MariaDBSelectType selectType = MariaDBSelectType.ALL;
1722
private MariaDBExpression whereCondition;
1823

@@ -52,4 +57,18 @@ public MariaDBExpression getWhereCondition() {
5257
return whereCondition;
5358
}
5459

60+
@Override
61+
public List<MariaDBJoin> getJoinClauses() {
62+
return joinClauses;
63+
}
64+
65+
@Override
66+
public void setJoinClauses(List<MariaDBJoin> joinClauses) {
67+
this.joinClauses = joinClauses;
68+
}
69+
70+
@Override
71+
public String asString() {
72+
return MariaDBVisitor.asString(this);
73+
}
5574
}

src/sqlancer/mariadb/ast/MariaDBStringVisitor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public void visit(MariaDBSelectStatement s) {
4747
}
4848
visit(s.getFromList().get(j));
4949
}
50-
for (MariaDBExpression j : s.getJoinList()) {
50+
for (MariaDBExpression j : s.getJoinClauses()) {
5151
visit(j);
5252
}
5353
if (s.getWhereCondition() != null) {

src/sqlancer/mariadb/gen/MariaDBExpressionGenerator.java

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
import java.util.List;
66

77
import sqlancer.Randomly;
8-
import sqlancer.SQLConnection;
9-
import sqlancer.StateToReproduce;
8+
import sqlancer.common.gen.NoRECGenerator;
9+
import sqlancer.common.schema.AbstractTables;
1010
import sqlancer.mariadb.MariaDBProvider;
1111
import sqlancer.mariadb.MariaDBSchema.MariaDBColumn;
1212
import sqlancer.mariadb.MariaDBSchema.MariaDBDataType;
13+
import sqlancer.mariadb.MariaDBSchema.MariaDBTable;
14+
import sqlancer.mariadb.ast.MariaDBAggregate;
15+
import sqlancer.mariadb.ast.MariaDBAggregate.MariaDBAggregateFunction;
1316
import sqlancer.mariadb.ast.MariaDBBinaryOperator;
1417
import sqlancer.mariadb.ast.MariaDBBinaryOperator.MariaDBBinaryComparisonOperator;
1518
import sqlancer.mariadb.ast.MariaDBColumnName;
@@ -18,14 +21,21 @@
1821
import sqlancer.mariadb.ast.MariaDBFunction;
1922
import sqlancer.mariadb.ast.MariaDBFunctionName;
2023
import sqlancer.mariadb.ast.MariaDBInOperation;
24+
import sqlancer.mariadb.ast.MariaDBJoin;
2125
import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation;
2226
import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation.MariaDBPostfixUnaryOperator;
27+
import sqlancer.mariadb.ast.MariaDBSelectStatement;
28+
import sqlancer.mariadb.ast.MariaDBSelectStatement.MariaDBSelectType;
29+
import sqlancer.mariadb.ast.MariaDBTableReference;
30+
import sqlancer.mariadb.ast.MariaDBText;
2331
import sqlancer.mariadb.ast.MariaDBUnaryPrefixOperation;
2432
import sqlancer.mariadb.ast.MariaDBUnaryPrefixOperation.MariaDBUnaryPrefixOperator;
2533

26-
public class MariaDBExpressionGenerator {
34+
public class MariaDBExpressionGenerator
35+
implements NoRECGenerator<MariaDBSelectStatement, MariaDBJoin, MariaDBExpression, MariaDBTable, MariaDBColumn> {
2736

2837
private final Randomly r;
38+
private List<MariaDBTable> targetTables = new ArrayList<>();
2939
private List<MariaDBColumn> columns = new ArrayList<>();
3040

3141
public MariaDBExpressionGenerator(Randomly r) {
@@ -66,14 +76,6 @@ public MariaDBExpressionGenerator setColumns(List<MariaDBColumn> columns) {
6676
return this;
6777
}
6878

69-
public MariaDBExpressionGenerator setCon(SQLConnection con) {
70-
return this;
71-
}
72-
73-
public MariaDBExpressionGenerator setState(StateToReproduce state) {
74-
return this;
75-
}
76-
7779
private enum ExpressionType {
7880
LITERAL, COLUMN, BINARY_COMPARISON, UNARY_POSTFIX_OPERATOR, UNARY_PREFIX_OPERATOR, FUNCTION, IN
7981
}
@@ -146,4 +148,64 @@ public MariaDBExpression getRandomExpression() {
146148
return getRandomExpression(0);
147149
}
148150

151+
@Override
152+
public MariaDBExpressionGenerator setTablesAndColumns(AbstractTables<MariaDBTable, MariaDBColumn> targetTables) {
153+
this.targetTables = targetTables.getTables();
154+
this.columns = targetTables.getColumns();
155+
return this;
156+
}
157+
158+
@Override
159+
public List<MariaDBExpression> getTableRefs() {
160+
List<MariaDBExpression> tableRefs = new ArrayList<>();
161+
for (MariaDBTable t : targetTables) {
162+
MariaDBTableReference tableRef = new MariaDBTableReference(t);
163+
tableRefs.add(tableRef);
164+
}
165+
return tableRefs;
166+
}
167+
168+
@Override
169+
public MariaDBExpression generateBooleanExpression() {
170+
return getRandomExpression();
171+
}
172+
173+
@Override
174+
public MariaDBSelectStatement generateSelect() {
175+
return new MariaDBSelectStatement();
176+
}
177+
178+
@Override
179+
public List<MariaDBJoin> getRandomJoinClauses() {
180+
return MariaDBJoin.getRandomJoinClauses(targetTables, r);
181+
}
182+
183+
@Override
184+
public String generateOptimizedQueryString(MariaDBSelectStatement select, MariaDBExpression whereCondition,
185+
boolean shouldUseAggregate) {
186+
if (shouldUseAggregate) {
187+
MariaDBAggregate aggr = new MariaDBAggregate(
188+
new MariaDBColumnName(new MariaDBColumn("*", MariaDBDataType.INT, false, 0)),
189+
MariaDBAggregateFunction.COUNT);
190+
select.setFetchColumns(Arrays.asList(aggr));
191+
} else {
192+
MariaDBColumnName aggr = new MariaDBColumnName(MariaDBColumn.createDummy("*"));
193+
select.setFetchColumns(Arrays.asList(aggr));
194+
}
195+
196+
select.setWhereClause(whereCondition);
197+
select.setSelectType(MariaDBSelectType.ALL);
198+
return select.asString();
199+
}
200+
201+
@Override
202+
public String generateUnoptimizedQueryString(MariaDBSelectStatement select, MariaDBExpression whereCondition) {
203+
MariaDBPostfixUnaryOperation isTrue = new MariaDBPostfixUnaryOperation(MariaDBPostfixUnaryOperator.IS_TRUE,
204+
whereCondition);
205+
MariaDBText asText = new MariaDBText(isTrue, " as count", false);
206+
select.setFetchColumns(Arrays.asList(asText));
207+
select.setSelectType(MariaDBSelectType.ALL);
208+
209+
return "SELECT SUM(count) FROM (" + select.asString() + ") as asdf";
210+
}
149211
}

src/sqlancer/mariadb/oracle/MariaDBDQPOracle.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ public MariaDBDQPOracle(MariaDBGlobalState globalState) {
3838
@Override
3939
public void check() throws Exception {
4040
MariaDBTables tables = s.getRandomTableNonEmptyTables();
41-
gen = new MariaDBExpressionGenerator(state.getRandomly()).setColumns(tables.getColumns())
42-
.setCon(state.getConnection()).setState(state.getState());
41+
gen = new MariaDBExpressionGenerator(state.getRandomly()).setColumns(tables.getColumns());
4342

4443
List<MariaDBExpression> fetchColumns = new ArrayList<>();
4544
fetchColumns.addAll(Randomly.nonEmptySubset(tables.getColumns()).stream().map(c -> new MariaDBColumnName(c))
@@ -57,8 +56,8 @@ public void check() throws Exception {
5756
}
5857

5958
// Set the join.
60-
List<MariaDBJoin> joinExpressions = MariaDBJoin.getRandomJoinClauses(tables.getTables(), state);
61-
select.setJoinList(joinExpressions.stream().map(j -> (MariaDBExpression) j).collect(Collectors.toList()));
59+
List<MariaDBJoin> joinExpressions = MariaDBJoin.getRandomJoinClauses(tables.getTables(), state.getRandomly());
60+
select.setJoinClauses(joinExpressions);
6261

6362
// Set the from clause from the tables that are not used in the join.
6463
select.setFromList(

0 commit comments

Comments
 (0)