package sqlancer;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import sqlancer.StateToReproduce.OracleRunReproductionState;
import sqlancer.common.DBMSCommon;
import sqlancer.common.oracle.CompositeTestOracle;
import sqlancer.common.oracle.TestOracle;
import sqlancer.common.schema.AbstractSchema;
public abstract class ProviderAdapter, C>, O extends DBMSSpecificOptions extends OracleFactory>, C extends SQLancerDBConnection>
implements DatabaseProvider {
private final Class globalClass;
private final Class optionClass;
// Variables for QPG
Map queryPlanPool = new HashMap<>();
static double[] weightedAverageReward; // static variable for sharing across all threads
int currentSelectRewards;
int currentSelectCounts;
int currentMutationOperator = -1;
protected ProviderAdapter(Class globalClass, Class optionClass) {
this.globalClass = globalClass;
this.optionClass = optionClass;
}
@Override
public StateToReproduce getStateToReproduce(String databaseName) {
return new StateToReproduce(databaseName, this);
}
@Override
public Class getGlobalStateClass() {
return globalClass;
}
@Override
public Class getOptionClass() {
return optionClass;
}
@Override
public Reproducer generateAndTestDatabase(G globalState) throws Exception {
try {
generateDatabase(globalState);
checkViewsAreValid(globalState);
globalState.getManager().incrementCreateDatabase();
TestOracle oracle = getTestOracle(globalState);
for (int i = 0; i < globalState.getOptions().getNrQueries(); i++) {
try (OracleRunReproductionState localState = globalState.getState().createLocalState()) {
assert localState != null;
try {
oracle.check();
globalState.getManager().incrementSelectQueryCount();
} catch (IgnoreMeException ignored) {
} catch (AssertionError e) {
Reproducer reproducer = oracle.getLastReproducer();
if (reproducer != null) {
return reproducer;
}
throw e;
}
localState.executedWithoutError();
}
}
} finally {
globalState.getConnection().close();
}
return null;
}
protected abstract void checkViewsAreValid(G globalState) throws SQLException;
protected TestOracle getTestOracle(G globalState) throws Exception {
List extends OracleFactory> testOracleFactory = globalState.getDbmsSpecificOptions()
.getTestOracleFactory();
boolean testOracleRequiresMoreThanZeroRows = testOracleFactory.stream()
.anyMatch(OracleFactory::requiresAllTablesToContainRows);
boolean userRequiresMoreThanZeroRows = globalState.getOptions().testOnlyWithMoreThanZeroRows();
boolean checkZeroRows = testOracleRequiresMoreThanZeroRows || userRequiresMoreThanZeroRows;
if (checkZeroRows && globalState.getSchema().containsTableWithZeroRows(globalState)) {
if (globalState.getOptions().enableQPG()) {
addRowsToAllTables(globalState);
} else {
throw new IgnoreMeException();
}
}
if (testOracleFactory.size() == 1) {
return testOracleFactory.get(0).create(globalState);
} else {
return new CompositeTestOracle<>(testOracleFactory.stream().map(o -> {
try {
return o.create(globalState);
} catch (Exception e1) {
throw new AssertionError(e1);
}
}).collect(Collectors.toList()), globalState);
}
}
public abstract void generateDatabase(G globalState) throws Exception;
// QPG: entry function
@Override
public void generateAndTestDatabaseWithQueryPlanGuidance(G globalState) throws Exception {
if (weightedAverageReward == null) {
weightedAverageReward = initializeWeightedAverageReward(); // Same length as the list of mutators
}
try {
generateDatabase(globalState);
checkViewsAreValid(globalState);
globalState.getManager().incrementCreateDatabase();
Long executedQueryCount = 0L;
while (executedQueryCount < globalState.getOptions().getNrQueries()) {
int numOfNoNewQueryPlans = 0;
TestOracle oracle = getTestOracle(globalState);
while (executedQueryCount < globalState.getOptions().getNrQueries()) {
try (OracleRunReproductionState localState = globalState.getState().createLocalState()) {
assert localState != null;
try {
oracle.check();
String query = oracle.getLastQueryString();
executedQueryCount += 1;
if (addQueryPlan(query, globalState)) {
numOfNoNewQueryPlans = 0;
} else {
numOfNoNewQueryPlans++;
}
globalState.getManager().incrementSelectQueryCount();
} catch (IgnoreMeException e) {
}
localState.executedWithoutError();
}
// exit loop to mutate tables if no new query plans have been found after a while
if (numOfNoNewQueryPlans > globalState.getOptions().getQPGMaxMutationInterval()) {
mutateTables(globalState);
break;
}
}
}
} finally {
globalState.getConnection().close();
}
}
// QPG: mutate tables for a new database state
private synchronized boolean mutateTables(G globalState) throws Exception {
// Update rewards based on a set of newly generated queries in last iteration
if (currentMutationOperator != -1) {
weightedAverageReward[currentMutationOperator] += ((double) currentSelectRewards
/ (double) currentSelectCounts) * globalState.getOptions().getQPGk();
}
currentMutationOperator = -1;
// Choose mutator based on the rewards
int selectedActionIndex = 0;
if (Randomly.getPercentage() < globalState.getOptions().getQPGProbability()) {
selectedActionIndex = globalState.getRandomly().getInteger(0, weightedAverageReward.length);
} else {
selectedActionIndex = DBMSCommon.getMaxIndexInDoubleArray(weightedAverageReward);
}
int reward = 0;
try {
executeMutator(selectedActionIndex, globalState);
checkViewsAreValid(globalState); // Remove the invalid views
reward = checkQueryPlan(globalState);
} catch (IgnoreMeException | AssertionError e) {
} finally {
// Update rewards based on existing queries associated with the query plan pool
updateReward(selectedActionIndex, (double) reward / (double) queryPlanPool.size(), globalState);
currentMutationOperator = selectedActionIndex;
}
// Clear the variables for storing the rewards of the action on a set of newly generated queries
currentSelectRewards = 0;
currentSelectCounts = 0;
return true;
}
// QPG: add a query plan to the query plan pool and return true if the query plan is new
private boolean addQueryPlan(String selectStr, G globalState) throws Exception {
String queryPlan = getQueryPlan(selectStr, globalState);
if (globalState.getOptions().logQueryPlan()) {
globalState.getLogger().writeQueryPlan(queryPlan);
}
currentSelectCounts += 1;
if (queryPlanPool.containsKey(queryPlan)) {
return false;
} else {
queryPlanPool.put(queryPlan, selectStr);
currentSelectRewards += 1;
return true;
}
}
// Obtain the reward of the current action based on the queries associated with the query plan pool
private int checkQueryPlan(G globalState) throws Exception {
int newQueryPlanFound = 0;
HashMap modifiedQueryPlan = new HashMap<>();
for (Iterator> it = queryPlanPool.entrySet().iterator(); it.hasNext();) {
Map.Entry item = it.next();
String queryPlan = item.getKey();
String selectStr = item.getValue();
String newQueryPlan = getQueryPlan(selectStr, globalState);
if (newQueryPlan.isEmpty()) { // Invalid query
it.remove();
} else if (!queryPlan.equals(newQueryPlan)) { // A query plan has been changed
it.remove();
modifiedQueryPlan.put(newQueryPlan, selectStr);
if (!queryPlanPool.containsKey(newQueryPlan)) { // A new query plan is found
newQueryPlanFound++;
}
}
}
queryPlanPool.putAll(modifiedQueryPlan);
return newQueryPlanFound;
}
// QPG: update the reward of current action
private void updateReward(int actionIndex, double reward, G globalState) {
weightedAverageReward[actionIndex] += (reward - weightedAverageReward[actionIndex])
* globalState.getOptions().getQPGk();
}
// QPG: initialize the weighted average reward of all mutation operators (required implementation in specific DBMS)
protected double[] initializeWeightedAverageReward() {
throw new UnsupportedOperationException();
}
// QPG: obtain the query plan of a query (required implementation in specific DBMS)
protected String getQueryPlan(String selectStr, G globalState) throws Exception {
throw new UnsupportedOperationException();
}
// QPG: execute a mutation operator (required implementation in specific DBMS)
protected void executeMutator(int index, G globalState) throws Exception {
throw new UnsupportedOperationException();
}
// QPG: add rows to all tables (required implementation in specific DBMS when enabling PQS oracle for QPG)
protected boolean addRowsToAllTables(G globalState) throws Exception {
throw new UnsupportedOperationException();
}
}