Skip to content

Commit 7fa5805

Browse files
committed
Add common CERT oracle
1 parent 91b6f86 commit 7fa5805

2 files changed

Lines changed: 162 additions & 0 deletions

File tree

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package sqlancer.common.gen;
2+
3+
import java.util.List;
4+
5+
import sqlancer.common.ast.newast.Expression;
6+
import sqlancer.common.ast.newast.Join;
7+
import sqlancer.common.ast.newast.Select;
8+
import sqlancer.common.schema.AbstractTable;
9+
import sqlancer.common.schema.AbstractTableColumn;
10+
import sqlancer.common.schema.AbstractTables;
11+
12+
public interface CERTGenerator<S extends Select<J, E, T, C>, J extends Join<E, T, C>, E extends Expression<C>, T extends AbstractTable<C, ?, ?>, C extends AbstractTableColumn<?, ?>> {
13+
14+
CERTGenerator<S, J, E, T, C> setTablesAndColumns(AbstractTables<T, C> tables);
15+
16+
E generateBooleanExpression();
17+
18+
S generateSelect();
19+
20+
List<J> getRandomJoinClauses();
21+
22+
List<E> getTableRefs();
23+
24+
List<E> generateFetchColumns(boolean shouldCreateDummy);
25+
26+
String generateExplainQuery(S select);
27+
28+
boolean mutate(S select);
29+
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package sqlancer.common.oracle;
2+
3+
import java.io.IOException;
4+
import java.sql.SQLException;
5+
import java.util.ArrayList;
6+
import java.util.List;
7+
import java.util.Optional;
8+
9+
import sqlancer.IgnoreMeException;
10+
import sqlancer.Randomly;
11+
import sqlancer.SQLGlobalState;
12+
import sqlancer.common.DBMSCommon;
13+
import sqlancer.common.ast.newast.Expression;
14+
import sqlancer.common.ast.newast.Join;
15+
import sqlancer.common.ast.newast.Select;
16+
import sqlancer.common.gen.CERTGenerator;
17+
import sqlancer.common.query.ExpectedErrors;
18+
import sqlancer.common.query.SQLQueryAdapter;
19+
import sqlancer.common.query.SQLancerResultSet;
20+
import sqlancer.common.schema.AbstractSchema;
21+
import sqlancer.common.schema.AbstractTable;
22+
import sqlancer.common.schema.AbstractTableColumn;
23+
import sqlancer.common.schema.AbstractTables;
24+
25+
public class CERTOracle<Z extends Select<J, E, T, C>, J extends Join<E, T, C>, E extends Expression<C>, S extends AbstractSchema<?, T>, T extends AbstractTable<C, ?, ?>, C extends AbstractTableColumn<?, ?>, G extends SQLGlobalState<?, S>>
26+
implements TestOracle<G> {
27+
28+
private final G state;
29+
private final CheckedFunction<SQLancerResultSet, Optional<Long>> rowCountParser;
30+
private final CheckedFunction<SQLancerResultSet, Optional<String>> queryPlanParser;
31+
32+
private CERTGenerator<Z, J, E, T, C> gen;
33+
private final ExpectedErrors errors;
34+
35+
public CERTOracle(G state, CERTGenerator<Z, J, E, T, C> gen, ExpectedErrors expectedErrors,
36+
CheckedFunction<SQLancerResultSet, Optional<Long>> rowCountParser,
37+
CheckedFunction<SQLancerResultSet, Optional<String>> queryPlanParser) {
38+
if (state == null || gen == null || expectedErrors == null) {
39+
throw new IllegalArgumentException("Null variables used to initialize test oracle.");
40+
}
41+
this.state = state;
42+
this.gen = gen;
43+
this.errors = expectedErrors;
44+
this.rowCountParser = rowCountParser;
45+
this.queryPlanParser = queryPlanParser;
46+
}
47+
48+
@Override
49+
public void check() throws SQLException {
50+
S schema = state.getSchema();
51+
AbstractTables<T, C> targetTables = TestOracleUtils.getRandomTableNonEmptyTables(schema);
52+
gen = gen.setTablesAndColumns(targetTables);
53+
54+
List<E> fetchColumns = gen.generateFetchColumns(false);
55+
56+
Z select = gen.generateSelect();
57+
select.setFetchColumns(fetchColumns);
58+
select.setJoinClauses(gen.getRandomJoinClauses());
59+
select.setFromList(gen.getTableRefs());
60+
61+
if (Randomly.getBoolean()) {
62+
select.setWhereClause(gen.generateBooleanExpression());
63+
}
64+
if (Randomly.getBoolean()) {
65+
select.setGroupByClause(fetchColumns);
66+
if (Randomly.getBoolean()) {
67+
select.setHavingClause(gen.generateBooleanExpression());
68+
}
69+
}
70+
71+
List<String> queryPlan1Sequences = new ArrayList<>();
72+
List<String> queryPlan2Sequences = new ArrayList<>();
73+
74+
String queryString1 = gen.generateExplainQuery(select);
75+
long rowCount1 = getRow(state, queryString1, queryPlan1Sequences);
76+
77+
boolean increase = gen.mutate(select);
78+
String queryString2 = gen.generateExplainQuery(select);
79+
long rowCount2 = getRow(state, queryString2, queryPlan2Sequences);
80+
81+
if (DBMSCommon.editDistance(queryPlan1Sequences, queryPlan2Sequences) > 1) {
82+
return;
83+
}
84+
85+
// Check the results
86+
if (increase && rowCount1 > rowCount2 || !increase && rowCount1 < rowCount2) {
87+
throw new AssertionError("Inconsistent result for query: " + queryString1 + "; --" + rowCount1 + "\n"
88+
+ queryString2 + "; --" + rowCount2);
89+
}
90+
}
91+
92+
private Long getRow(SQLGlobalState<?, ?> globalState, String explainQuery, List<String> queryPlanSequences)
93+
throws AssertionError, SQLException {
94+
Optional<Long> row = Optional.empty();
95+
96+
// Log the query
97+
if (globalState.getOptions().logEachSelect()) {
98+
globalState.getLogger().writeCurrent(explainQuery);
99+
try {
100+
globalState.getLogger().getCurrentFileWriter().flush();
101+
} catch (IOException e) {
102+
e.printStackTrace();
103+
}
104+
}
105+
106+
// Get the row count
107+
SQLQueryAdapter q = new SQLQueryAdapter(explainQuery, errors);
108+
try (SQLancerResultSet rs = q.executeAndGet(globalState)) {
109+
if (rs != null) {
110+
while (rs.next()) {
111+
Optional<Long> rowCount = rowCountParser.apply(rs);
112+
if (row.isEmpty() && rowCount.isPresent()) {
113+
row = rowCount;
114+
}
115+
116+
Optional<String> queryPlanSequence = queryPlanParser.apply(rs);
117+
queryPlanSequence.ifPresent(qps -> queryPlanSequences.add(qps));
118+
}
119+
}
120+
} catch (IgnoreMeException e) {
121+
throw new IgnoreMeException();
122+
} catch (Exception e) {
123+
throw new AssertionError(q.getQueryString(), e);
124+
}
125+
126+
return row.orElseThrow(IgnoreMeException::new);
127+
}
128+
129+
@FunctionalInterface
130+
public interface CheckedFunction<T, R> {
131+
R apply(T t) throws SQLException;
132+
}
133+
}

0 commit comments

Comments
 (0)