Skip to content

Commit 3d37edc

Browse files
authored
mysql: cert init (#837)
* mysql: cert init * CERT oracle base class * mysql: stream API optimization * cert: abstract mutate method * cert: getRandomExcept
1 parent 1f80de1 commit 3d37edc

6 files changed

Lines changed: 359 additions & 57 deletions

File tree

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package sqlancer.common.oracle;
2+
3+
import java.util.Arrays;
4+
import java.util.List;
5+
6+
import sqlancer.Randomly;
7+
import sqlancer.SQLGlobalState;
8+
import sqlancer.common.query.ExpectedErrors;
9+
10+
public abstract class CERTOracleBase<S extends SQLGlobalState<?, ?>> implements TestOracle<S> {
11+
12+
protected final S state;
13+
protected final ExpectedErrors errors;
14+
protected List<String> queryPlan1Sequences;
15+
protected List<String> queryPlan2Sequences;
16+
17+
protected enum Mutator {
18+
JOIN, DISTINCT, WHERE, GROUPBY, HAVING, AND, OR, LIMIT;
19+
20+
public static Mutator getRandomExcept(Mutator... exclude) {
21+
Mutator[] values = Arrays.stream(values()).filter(m -> !Arrays.asList(exclude).contains(m))
22+
.toArray(Mutator[]::new);
23+
return Randomly.fromOptions(values);
24+
}
25+
}
26+
27+
protected CERTOracleBase(S state) {
28+
this.state = state;
29+
this.errors = new ExpectedErrors();
30+
}
31+
32+
protected boolean mutate(Mutator... exclude) {
33+
Mutator m = Mutator.getRandomExcept(exclude);
34+
switch (m) {
35+
case JOIN:
36+
return mutateJoin();
37+
case DISTINCT:
38+
return mutateDistinct();
39+
case WHERE:
40+
return mutateWhere();
41+
case GROUPBY:
42+
return mutateGroupBy();
43+
case HAVING:
44+
return mutateHaving();
45+
case AND:
46+
return mutateAnd();
47+
case OR:
48+
return mutateOr();
49+
case LIMIT:
50+
return mutateLimit();
51+
default:
52+
throw new AssertionError(m);
53+
}
54+
}
55+
56+
protected boolean mutateJoin() {
57+
throw new UnsupportedOperationException();
58+
}
59+
60+
protected boolean mutateDistinct() {
61+
throw new UnsupportedOperationException();
62+
}
63+
64+
protected boolean mutateWhere() {
65+
throw new UnsupportedOperationException();
66+
}
67+
68+
protected boolean mutateGroupBy() {
69+
throw new UnsupportedOperationException();
70+
}
71+
72+
protected boolean mutateHaving() {
73+
throw new UnsupportedOperationException();
74+
}
75+
76+
protected boolean mutateAnd() {
77+
throw new UnsupportedOperationException();
78+
}
79+
80+
protected boolean mutateOr() {
81+
throw new UnsupportedOperationException();
82+
}
83+
84+
protected boolean mutateLimit() {
85+
throw new UnsupportedOperationException();
86+
}
87+
88+
}

src/sqlancer/mysql/MySQLOptions.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import sqlancer.OracleFactory;
1212
import sqlancer.common.oracle.TestOracle;
1313
import sqlancer.mysql.MySQLOptions.MySQLOracleFactory;
14+
import sqlancer.mysql.oracle.MySQLCERTOracle;
1415
import sqlancer.mysql.oracle.MySQLPivotedQuerySynthesisOracle;
1516
import sqlancer.mysql.oracle.MySQLTLPWhereOracle;
1617

@@ -45,7 +46,18 @@ public boolean requiresAllTablesToContainRows() {
4546
return true;
4647
}
4748

48-
}
49+
},
50+
CERT {
51+
@Override
52+
public TestOracle<MySQLGlobalState> create(MySQLGlobalState globalState) throws SQLException {
53+
return new MySQLCERTOracle(globalState);
54+
}
55+
56+
@Override
57+
public boolean requiresAllTablesToContainRows() {
58+
return true;
59+
}
60+
};
4961
}
5062

5163
@Override

src/sqlancer/mysql/MySQLProvider.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import sqlancer.SQLProviderAdapter;
1919
import sqlancer.StatementExecutor;
2020
import sqlancer.common.DBMSCommon;
21+
import sqlancer.common.query.ExpectedErrors;
2122
import sqlancer.common.query.SQLQueryAdapter;
2223
import sqlancer.common.query.SQLQueryProvider;
24+
import sqlancer.mysql.MySQLSchema.MySQLColumn;
2325
import sqlancer.mysql.MySQLSchema.MySQLTable;
2426
import sqlancer.mysql.gen.MySQLAlterTable;
2527
import sqlancer.mysql.gen.MySQLDeleteGenerator;
@@ -36,6 +38,7 @@
3638
import sqlancer.mysql.gen.tblmaintenance.MySQLChecksum;
3739
import sqlancer.mysql.gen.tblmaintenance.MySQLOptimize;
3840
import sqlancer.mysql.gen.tblmaintenance.MySQLRepair;
41+
import sqlancer.mysql.oracle.MySQLCERTOracle;
3942

4043
@AutoService(DatabaseProvider.class)
4144
public class MySQLProvider extends SQLProviderAdapter<MySQLGlobalState, MySQLOptions> {
@@ -153,6 +156,24 @@ public void generateDatabase(MySQLGlobalState globalState) throws Exception {
153156
}
154157
});
155158
se.executeStatements();
159+
160+
if (globalState.getDbmsSpecificOptions().getTestOracleFactory().size() == 1
161+
&& globalState.getDbmsSpecificOptions().getTestOracleFactory().get(0)
162+
.create(globalState) instanceof MySQLCERTOracle) {
163+
// Enfore statistic collected for all tables
164+
ExpectedErrors errors = new ExpectedErrors();
165+
MySQLErrors.addExpressionErrors(errors);
166+
for (MySQLTable table : globalState.getSchema().getDatabaseTables()) {
167+
StringBuilder sb = new StringBuilder();
168+
sb.append("ANALYZE TABLE ");
169+
sb.append(table.getName());
170+
sb.append(" UPDATE HISTOGRAM ON ");
171+
String columns = table.getColumns().stream().map(MySQLColumn::getName)
172+
.collect(Collectors.joining(", "));
173+
sb.append(columns + ";");
174+
globalState.executeStatement(new SQLQueryAdapter(sb.toString(), errors));
175+
}
176+
}
156177
}
157178

158179
@Override

src/sqlancer/mysql/ast/MySQLSelect.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ public void setSelectType(SelectType fromOptions) {
1818
this.setFromOptions(fromOptions);
1919
}
2020

21-
public SelectType getSelectType() {
22-
return fromOptions;
23-
}
24-
2521
public SelectType getFromOptions() {
2622
return fromOptions;
2723
}
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
package sqlancer.mysql.oracle;
2+
3+
import java.io.IOException;
4+
import java.sql.SQLException;
5+
import java.util.ArrayList;
6+
import java.util.List;
7+
import java.util.stream.Collectors;
8+
9+
import sqlancer.IgnoreMeException;
10+
import sqlancer.Randomly;
11+
import sqlancer.SQLGlobalState;
12+
import sqlancer.common.DBMSCommon;
13+
import sqlancer.common.oracle.CERTOracleBase;
14+
import sqlancer.common.oracle.TestOracle;
15+
import sqlancer.common.query.SQLQueryAdapter;
16+
import sqlancer.common.query.SQLancerResultSet;
17+
import sqlancer.mysql.MySQLErrors;
18+
import sqlancer.mysql.MySQLGlobalState;
19+
import sqlancer.mysql.MySQLSchema.MySQLTables;
20+
import sqlancer.mysql.MySQLVisitor;
21+
import sqlancer.mysql.ast.MySQLBinaryLogicalOperation;
22+
import sqlancer.mysql.ast.MySQLBinaryLogicalOperation.MySQLBinaryLogicalOperator;
23+
import sqlancer.mysql.ast.MySQLColumnReference;
24+
import sqlancer.mysql.ast.MySQLExpression;
25+
import sqlancer.mysql.ast.MySQLSelect;
26+
import sqlancer.mysql.ast.MySQLTableReference;
27+
import sqlancer.mysql.gen.MySQLExpressionGenerator;
28+
29+
public class MySQLCERTOracle extends CERTOracleBase<MySQLGlobalState> implements TestOracle<MySQLGlobalState> {
30+
private MySQLExpressionGenerator gen;
31+
private MySQLSelect select;
32+
33+
public MySQLCERTOracle(MySQLGlobalState globalState) {
34+
super(globalState);
35+
MySQLErrors.addExpressionErrors(errors);
36+
}
37+
38+
@Override
39+
public void check() throws SQLException {
40+
queryPlan1Sequences = new ArrayList<>();
41+
queryPlan2Sequences = new ArrayList<>();
42+
43+
// Randomly generate a query
44+
MySQLTables tables = state.getSchema().getRandomTableNonEmptyTables();
45+
gen = new MySQLExpressionGenerator(state).setColumns(tables.getColumns());
46+
List<MySQLExpression> fetchColumns = new ArrayList<>();
47+
fetchColumns.addAll(Randomly.nonEmptySubset(tables.getColumns()).stream()
48+
.map(c -> new MySQLColumnReference(c, null)).collect(Collectors.toList()));
49+
List<MySQLExpression> tableList = tables.getTables().stream().map(t -> new MySQLTableReference(t))
50+
.collect(Collectors.toList());
51+
52+
select = new MySQLSelect();
53+
select.setFetchColumns(fetchColumns);
54+
select.setFromList(tableList);
55+
56+
select.setSelectType(Randomly.fromOptions(MySQLSelect.SelectType.values()));
57+
if (Randomly.getBoolean()) {
58+
select.setWhereClause(gen.generateExpression());
59+
}
60+
if (Randomly.getBoolean()) {
61+
select.setGroupByExpressions(fetchColumns);
62+
if (Randomly.getBoolean()) {
63+
select.setHavingClause(gen.generateExpression());
64+
}
65+
}
66+
67+
// Set the join. Todo: to make it random
68+
// List<MySQLExpression> joinExpressions = getJoins(tableList, state);
69+
// select.setJoinList(joinExpressions);
70+
71+
// Get the result of the first query
72+
String queryString1 = MySQLVisitor.asString(select);
73+
int rowCount1 = getRow(state, queryString1, queryPlan1Sequences);
74+
75+
boolean increase = mutate(Mutator.JOIN, Mutator.LIMIT);
76+
77+
// Get the result of the second query
78+
String queryString2 = MySQLVisitor.asString(select);
79+
int rowCount2 = getRow(state, queryString2, queryPlan2Sequences);
80+
81+
// Check structural equivalence
82+
if (DBMSCommon.editDistance(queryPlan1Sequences, queryPlan2Sequences) > 1) {
83+
return;
84+
}
85+
86+
// Check the results
87+
if (increase && rowCount1 > rowCount2 || !increase && rowCount1 < rowCount2) {
88+
throw new AssertionError("Inconsistent result for query: EXPLAIN " + queryString1 + "; --" + rowCount1
89+
+ "\nEXPLAIN " + queryString2 + "; --" + rowCount2);
90+
}
91+
}
92+
93+
@Override
94+
protected boolean mutateDistinct() {
95+
MySQLSelect.SelectType selectType = select.getFromOptions();
96+
if (selectType != MySQLSelect.SelectType.ALL) {
97+
select.setSelectType(MySQLSelect.SelectType.ALL);
98+
return true;
99+
} else {
100+
select.setSelectType(MySQLSelect.SelectType.DISTINCT);
101+
return false;
102+
}
103+
}
104+
105+
@Override
106+
protected boolean mutateWhere() {
107+
boolean increase = select.getWhereClause() != null;
108+
if (increase) {
109+
select.setWhereClause(null);
110+
} else {
111+
select.setWhereClause(gen.generateExpression());
112+
}
113+
return increase;
114+
}
115+
116+
@Override
117+
protected boolean mutateGroupBy() {
118+
boolean increase = select.getGroupByExpressions().size() > 0;
119+
if (increase) {
120+
select.clearGroupByExpressions();
121+
} else {
122+
select.setGroupByExpressions(select.getFetchColumns());
123+
}
124+
return increase;
125+
}
126+
127+
@Override
128+
protected boolean mutateHaving() {
129+
if (select.getGroupByExpressions().size() == 0) {
130+
select.setGroupByExpressions(select.getFetchColumns());
131+
select.setHavingClause(gen.generateExpression());
132+
return false;
133+
} else {
134+
if (select.getHavingClause() == null) {
135+
select.setHavingClause(gen.generateExpression());
136+
return false;
137+
} else {
138+
select.setHavingClause(null);
139+
return true;
140+
}
141+
}
142+
}
143+
144+
@Override
145+
protected boolean mutateAnd() {
146+
if (select.getWhereClause() == null) {
147+
select.setWhereClause(gen.generateExpression());
148+
} else {
149+
MySQLExpression newWhere = new MySQLBinaryLogicalOperation(select.getWhereClause(),
150+
gen.generateExpression(), MySQLBinaryLogicalOperator.AND);
151+
select.setWhereClause(newWhere);
152+
}
153+
return false;
154+
}
155+
156+
@Override
157+
protected boolean mutateOr() {
158+
if (select.getWhereClause() == null) {
159+
select.setWhereClause(gen.generateExpression());
160+
return false;
161+
} else {
162+
MySQLExpression newWhere = new MySQLBinaryLogicalOperation(select.getWhereClause(),
163+
gen.generateExpression(), MySQLBinaryLogicalOperator.OR);
164+
select.setWhereClause(newWhere);
165+
return true;
166+
}
167+
}
168+
169+
// The limit clause only accpets positive integers, which is not supported yet
170+
// private boolean mutateLimit() {
171+
// boolean increase = select.getLimitClause() != null;
172+
// if (increase) {
173+
// select.setLimitClause(null);
174+
// } else {
175+
// select.setLimitClause(gen.generateConstant());
176+
// }
177+
// return increase;
178+
// }
179+
180+
private int getRow(SQLGlobalState<?, ?> globalState, String selectStr, List<String> queryPlanSequences)
181+
throws AssertionError, SQLException {
182+
int row = -1;
183+
String explainQuery = "EXPLAIN " + selectStr;
184+
185+
// Log the query
186+
if (globalState.getOptions().logEachSelect()) {
187+
globalState.getLogger().writeCurrent(explainQuery);
188+
try {
189+
globalState.getLogger().getCurrentFileWriter().flush();
190+
} catch (IOException e) {
191+
e.printStackTrace();
192+
}
193+
}
194+
195+
// Get the row count
196+
SQLQueryAdapter q = new SQLQueryAdapter(explainQuery, errors);
197+
try (SQLancerResultSet rs = q.executeAndGet(globalState)) {
198+
if (rs != null) {
199+
while (rs.next()) {
200+
int estRows = rs.getInt(10);
201+
if (row == -1) {
202+
row = estRows;
203+
}
204+
String operation = rs.getString(2);
205+
queryPlanSequences.add(operation);
206+
}
207+
}
208+
} catch (Exception e) {
209+
throw new AssertionError(q.getQueryString(), e);
210+
}
211+
if (row == -1) {
212+
throw new IgnoreMeException();
213+
}
214+
return row;
215+
}
216+
217+
}

0 commit comments

Comments
 (0)