package sqlancer;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.JCommander.Builder;
import sqlancer.FoundBugException.Reproducer;
import sqlancer.arangodb.ArangoDBProvider;
import sqlancer.citus.CitusProvider;
import sqlancer.clickhouse.ClickHouseProvider;
import sqlancer.cockroachdb.CockroachDBProvider;
import sqlancer.common.log.Loggable;
import sqlancer.common.query.Query;
import sqlancer.common.query.SQLQueryAdapter;
import sqlancer.common.query.SQLancerResultSet;
import sqlancer.cosmos.CosmosProvider;
import sqlancer.duckdb.DuckDBProvider;
import sqlancer.h2.H2Provider;
import sqlancer.mariadb.MariaDBProvider;
import sqlancer.mongodb.MongoDBProvider;
import sqlancer.mysql.MySQLProvider;
import sqlancer.postgres.PostgresProvider;
import sqlancer.sqlite3.SQLite3GlobalState;
import sqlancer.sqlite3.SQLite3Provider;
import sqlancer.tidb.TiDBProvider;
public final class Main {
/** Directory that holds the per-DBMS log directories; created during class initialization if it is missing. */
public static final File LOG_DIRECTORY = new File("logs");
/** Number of queries reported through {@code QueryManager.incrementSelectQueryCount()}. */
public static volatile AtomicLong nrQueries = new AtomicLong();
/** Number of databases reported through {@code QueryManager.incrementCreateDatabase()}. */
public static volatile AtomicLong nrDatabases = new AtomicLong();
/** Incremented by {@code QueryManager} each time {@code execute}/{@code executeAndGet} returns normally. */
public static volatile AtomicLong nrSuccessfulActions = new AtomicLong();
/** Not incremented in this file; presumably updated by DBMS-specific code (TODO confirm). */
public static volatile AtomicLong nrUnsuccessfulActions = new AtomicLong();
/**
 * Number of worker tasks that have finished. Worker threads write it and the progress monitor thread reads it, so it
 * is {@code volatile} for visibility. {@code volatile} does not make {@code ++} atomic: update it while holding the
 * {@code Main.class} monitor.
 */
static volatile int threadsShutdown;
/** Whether the progress monitor has been started; only accessed from the synchronized startProgressMonitor(). */
static boolean progressMonitorStarted;

static {
    // Lower the default level of SLF4J's SimpleLogger to ERROR.
    System.setProperty(org.slf4j.impl.SimpleLogger.DEFAULT_LOG_LEVEL_KEY, "ERROR");
    if (!LOG_DIRECTORY.exists()) {
        LOG_DIRECTORY.mkdir();
    }
}

/** Not instantiable: all members are static. */
private Main() {
}
public static final class StateLogger {
private final File loggerFile;
private final File reducedFile;
private File curFile;
private FileWriter logFileWriter;
public FileWriter currentFileWriter;
private static final List INITIALIZED_PROVIDER_NAMES = new ArrayList<>();
private final boolean logEachSelect;
private final DatabaseProvider, ?, ?> databaseProvider;
private static final class AlsoWriteToConsoleFileWriter extends FileWriter {
AlsoWriteToConsoleFileWriter(File file) throws IOException {
super(file);
}
@Override
public Writer append(CharSequence arg0) throws IOException {
System.err.println(arg0);
return super.append(arg0);
}
@Override
public void write(String str) throws IOException {
System.err.println(str);
super.write(str);
}
}
public StateLogger(String databaseName, DatabaseProvider, ?, ?> provider, MainOptions options) {
File dir = new File(LOG_DIRECTORY, provider.getDBMSName());
if (dir.exists() && !dir.isDirectory()) {
throw new AssertionError(dir);
}
ensureExistsAndIsEmpty(dir, provider);
loggerFile = new File(dir, databaseName + ".log");
reducedFile = new File(dir, databaseName + "-reduced.log");
logEachSelect = options.logEachSelect();
if (logEachSelect) {
curFile = new File(dir, databaseName + "-cur.log");
}
this.databaseProvider = provider;
}
private void ensureExistsAndIsEmpty(File dir, DatabaseProvider, ?, ?> provider) {
if (INITIALIZED_PROVIDER_NAMES.contains(provider.getDBMSName())) {
return;
}
synchronized (INITIALIZED_PROVIDER_NAMES) {
if (!dir.exists()) {
try {
Files.createDirectories(dir.toPath());
} catch (IOException e) {
throw new AssertionError(e);
}
}
File[] listFiles = dir.listFiles();
assert listFiles != null : "directory was just created, so it should exist";
for (File file : listFiles) {
if (!file.isDirectory()) {
file.delete();
}
}
INITIALIZED_PROVIDER_NAMES.add(provider.getDBMSName());
}
}
private FileWriter getLogFileWriter() {
if (logFileWriter == null) {
try {
logFileWriter = new AlsoWriteToConsoleFileWriter(loggerFile);
} catch (IOException e) {
throw new AssertionError(e);
}
}
return logFileWriter;
}
private FileWriter getReducedWriter() {
// if (logFileWriter == null) {
try {
logFileWriter = new FileWriter(reducedFile);
} catch (IOException e) {
throw new AssertionError(e);
}
// }
return logFileWriter;
}
public FileWriter getCurrentFileWriter() {
if (!logEachSelect) {
throw new UnsupportedOperationException();
}
if (currentFileWriter == null) {
try {
currentFileWriter = new FileWriter(curFile, false);
} catch (IOException e) {
throw new AssertionError(e);
}
}
return currentFileWriter;
}
public void writeCurrent(StateToReproduce state) {
if (!logEachSelect) {
throw new UnsupportedOperationException();
}
printState(getCurrentFileWriter(), state);
try {
currentFileWriter.flush();
} catch (IOException e) {
e.printStackTrace();
}
}
public void writeCurrent(String input) {
write(databaseProvider.getLoggableFactory().createLoggable(input));
}
public void writeCurrentNoLineBreak(String input) {
write(databaseProvider.getLoggableFactory().createLoggableWithNoLinebreak(input));
}
private void write(Loggable loggable) {
if (!logEachSelect) {
throw new UnsupportedOperationException();
}
try {
getCurrentFileWriter().write(loggable.getLogString());
currentFileWriter.flush();
} catch (IOException e) {
throw new AssertionError();
}
}
public void logReduced(StateToReproduce state) {
FileWriter logFileWriter = getReducedWriter();
printState(logFileWriter, state);
try {
logFileWriter.close();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
public void logException(Throwable reduce, StateToReproduce state) {
Loggable stackTrace = getStackTrace(reduce);
FileWriter logFileWriter2 = getLogFileWriter();
try {
logFileWriter2.write(stackTrace.getLogString());
printState(logFileWriter2, state);
} catch (IOException e) {
throw new AssertionError(e);
} finally {
try {
logFileWriter2.flush();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
private Loggable getStackTrace(Throwable e1) {
return databaseProvider.getLoggableFactory().convertStacktraceToLoggable(e1);
}
private void printState(FileWriter writer, StateToReproduce state) {
StringBuilder sb = new StringBuilder();
sb.append(databaseProvider.getLoggableFactory()
.getInfo(state.getDatabaseName(), state.getDatabaseVersion(), state.getSeedValue()).getLogString());
for (Query> s : state.getStatements()) {
sb.append(s.getLogString());
sb.append('\n');
}
try {
writer.write(sb.toString());
} catch (IOException e) {
throw new AssertionError(e);
}
}
}
public static class QueryManager {
private final GlobalState, ?, C> globalState;
QueryManager(GlobalState, ?, C> globalState) {
this.globalState = globalState;
}
public boolean execute(Query q, String... fills) throws Exception {
globalState.getState().logStatement(q);
boolean success;
success = q.execute(globalState, fills);
Main.nrSuccessfulActions.addAndGet(1);
return success;
}
public SQLancerResultSet executeAndGet(Query q, String... fills) throws Exception {
globalState.getState().logStatement(q);
SQLancerResultSet result;
result = q.executeAndGet(globalState, fills);
Main.nrSuccessfulActions.addAndGet(1);
return result;
}
public void incrementSelectQueryCount() {
Main.nrQueries.addAndGet(1);
}
public void incrementCreateDatabase() {
Main.nrDatabases.addAndGet(1);
}
}
/** Command-line entry point: runs {@link #executeMain(String...)} and exits the JVM with its result. */
public static void main(String[] args) {
    int exitCode = executeMain(args);
    System.exit(exitCode);
}
public static class DBMSExecutor, O extends DBMSSpecificOptions>, C extends SQLancerDBConnection> {
private final DatabaseProvider provider;
private final MainOptions options;
private final O command;
private final String databaseName;
private StateLogger logger;
private StateToReproduce stateToRepro;
private final Randomly r;
public DBMSExecutor(DatabaseProvider provider, MainOptions options, O dbmsSpecificOptions,
String databaseName, Randomly r) {
this.provider = provider;
this.options = options;
this.databaseName = databaseName;
this.command = dbmsSpecificOptions;
this.r = r;
}
private G createGlobalState() {
try {
return provider.getGlobalStateClass().getDeclaredConstructor().newInstance();
} catch (Exception e) {
throw new AssertionError(e);
}
}
public O getCommand() {
return command;
}
public void testConnection() throws Exception {
G state = getInitializedGlobalState(options.getRandomSeed());
try (SQLancerDBConnection con = provider.createDatabase(state)) {
return;
}
}
boolean observedChange;
public void run() throws Exception {
G state = createGlobalState();
stateToRepro = provider.getStateToReproduce(databaseName);
stateToRepro.seedValue = r.getSeed();
state.setState(stateToRepro);
logger = new StateLogger(databaseName, provider, options);
state.setRandomly(r);
state.setDatabaseName(databaseName);
state.setMainOptions(options);
state.setDmbsSpecificOptions(command);
try (C con = provider.createDatabase(state)) {
QueryManager manager = new QueryManager<>(state);
try {
stateToRepro.databaseVersion = con.getDatabaseVersion();
} catch (Exception e) {
// ignore
}
state.setConnection(con);
state.setStateLogger(logger);
state.setManager(manager);
if (options.logEachSelect()) {
logger.writeCurrent(state.getState());
}
try {
provider.generateAndTestDatabase(state);
} catch (FoundBugException e) {
try {
logger.getCurrentFileWriter().close();
logger.currentFileWriter = null;
} catch (IOException e2) {
throw new AssertionError(e2);
}
Reproducer reproducer = e.getReproducer();
G newGlobalState = createGlobalState();
QueryManager newManager = new QueryManager<>(newGlobalState);
newGlobalState.setDatabaseName(databaseName);
newGlobalState.setMainOptions(options);
newGlobalState.setDmbsSpecificOptions(command);
newGlobalState.setStateLogger(new StateLogger(databaseName, provider, options));
newGlobalState.setManager(newManager);
newGlobalState.setState(stateToRepro);
List> knownToReproduceBugStatements = new ArrayList>();
for (Query> stat : state.getState().getStatements()) {
knownToReproduceBugStatements.add((Query) stat);
}
// iterate until fixpoint
do {
observedChange = false;
knownToReproduceBugStatements = tryReduction(state, reproducer, newGlobalState,
knownToReproduceBugStatements, (candidateStatements, i) -> {
candidateStatements.remove((int) i);
return true;
});
} while (observedChange);
for (String s : new String[] { "OR IGNORE", "OR ABORT", "OR ROLLBACK", "OR FAIL", "TEMP",
"TEMPORARY", "UNIQUE", "NOT NULL", "COLLATE BINARY", "COLLATE NOCASE", "COLLATE RTRIM",
"INT", "REAL", "TEXT", "IF NOT EXISTS", "UNINDEXED" }) {
knownToReproduceBugStatements = tryReplaceToken(state, reproducer, newGlobalState,
knownToReproduceBugStatements, " " + s, "");
}
throw e;
}
}
}
private List> tryReplaceToken(G state, Reproducer reproducer, G newGlobalState,
List> knownToReproduceBugStatements, String target, String replaceBy) throws Exception {
do {
observedChange = false;
knownToReproduceBugStatements = tryReduction(state, reproducer, newGlobalState,
knownToReproduceBugStatements, (candidateStatements, i) -> {
Query statement = candidateStatements.get(i);
if (statement.getQueryString().contains(target)) {
candidateStatements.set(i, (Query) new SQLQueryAdapter(
statement.getQueryString().replace(target, replaceBy), true));
return true;
}
return false;
}
);
} while (observedChange);
return knownToReproduceBugStatements;
}
private List> tryReduction(G state, Reproducer reproducer, G newGlobalState,
List> knownToReproduceBugStatements,
BiFunction>, Integer, Boolean> reductionOperation) throws Exception {
for (int i = 0; i < knownToReproduceBugStatements.size(); i++) {
try (C con2 = provider.createDatabase(newGlobalState)) {
newGlobalState.setConnection(con2);
List> candidateStatements = new ArrayList<>(knownToReproduceBugStatements);
if (!reductionOperation.apply(candidateStatements, i)) {
continue;
}
newGlobalState.getState().setStatements(candidateStatements.stream().collect(Collectors.toList()));
for (Query s : candidateStatements) {
try {
s.execute(newGlobalState);
} catch (Throwable ignoredException) {
// ignore
}
}
try {
if (reproducer.bugStillTriggers((SQLite3GlobalState) newGlobalState)) {
observedChange = true;
knownToReproduceBugStatements = candidateStatements;
reproducer.outputHook((SQLite3GlobalState) newGlobalState);
state.getLogger().logReduced(newGlobalState.getState());
}
} catch (Throwable ignoredException) {
}
}
}
return knownToReproduceBugStatements;
}
private G getInitializedGlobalState(long seed) {
G state = createGlobalState();
stateToRepro = provider.getStateToReproduce(databaseName);
stateToRepro.seedValue = seed;
state.setState(stateToRepro);
logger = new StateLogger(databaseName, provider, options);
Randomly r = new Randomly(seed);
state.setRandomly(r);
state.setDatabaseName(databaseName);
state.setMainOptions(options);
state.setDmbsSpecificOptions(command);
return state;
}
public StateLogger getLogger() {
return logger;
}
public StateToReproduce getStateToReproduce() {
return stateToRepro;
}
}
public static class DBMSExecutorFactory, O extends DBMSSpecificOptions>, C extends SQLancerDBConnection> {
private final DatabaseProvider provider;
private final MainOptions options;
private final O command;
public DBMSExecutorFactory(DatabaseProvider provider, MainOptions options) {
this.provider = provider;
this.options = options;
this.command = createCommand();
}
private O createCommand() {
try {
return provider.getOptionClass().getDeclaredConstructor().newInstance();
} catch (Exception e) {
throw new AssertionError(e);
}
}
public O getCommand() {
return command;
}
@SuppressWarnings("unchecked")
public DBMSExecutor getDBMSExecutor(String databaseName, Randomly r) {
try {
return new DBMSExecutor(provider.getClass().getDeclaredConstructor().newInstance(), options,
command, databaseName, r);
} catch (Exception e) {
throw new AssertionError(e);
}
}
public DatabaseProvider getProvider() {
return provider;
}
}
public static int executeMain(String... args) throws AssertionError {
List> providers = getDBMSProviders();
Map> nameToProvider = new HashMap<>();
MainOptions options = new MainOptions();
Builder commandBuilder = JCommander.newBuilder().addObject(options);
for (DatabaseProvider, ?, ?> provider : providers) {
String name = provider.getDBMSName();
if (!name.toLowerCase().equals(name)) {
throw new AssertionError(name + " should be in lowercase!");
}
DBMSExecutorFactory, ?, ?> executorFactory = new DBMSExecutorFactory<>(provider, options);
commandBuilder = commandBuilder.addCommand(name, executorFactory.getCommand());
nameToProvider.put(name, executorFactory);
}
JCommander jc = commandBuilder.programName("SQLancer").build();
jc.parse(args);
if (jc.getParsedCommand() == null || options.isHelp()) {
jc.usage();
return options.getErrorExitCode();
}
Randomly.initialize(options);
if (options.printProgressInformation()) {
startProgressMonitor();
if (options.printProgressSummary()) {
Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
@Override
public void run() {
System.out.println("Overall execution statistics");
System.out.println("============================");
System.out.println(formatInteger(nrQueries.get()) + " queries");
System.out.println(formatInteger(nrDatabases.get()) + " databases");
System.out.println(
formatInteger(nrSuccessfulActions.get()) + " successfully-executed statements");
System.out.println(
formatInteger(nrUnsuccessfulActions.get()) + " unsuccessfuly-executed statements");
}
private String formatInteger(long intValue) {
if (intValue > 1000) {
return String.format("%,9dk", intValue / 1000);
} else {
return String.format("%,10d", intValue);
}
}
}));
}
}
ExecutorService execService = Executors.newFixedThreadPool(options.getNumberConcurrentThreads());
DBMSExecutorFactory, ?, ?> executorFactory = nameToProvider.get(jc.getParsedCommand());
if (options.performConnectionTest()) {
try {
executorFactory.getDBMSExecutor(options.getDatabasePrefix() + "connectiontest", new Randomly())
.testConnection();
} catch (Exception e) {
System.err.println(
"SQLancer failed creating a test database, indicating that SQLancer might have failed connecting to the DBMS. In order to change the username, password, host and port, you can use the --username, --password, --host and --port options.\n\n");
e.printStackTrace();
return options.getErrorExitCode();
}
}
for (int i = 0; i < options.getTotalNumberTries(); i++) {
final String databaseName = options.getDatabasePrefix() + i;
final long seed;
if (options.getRandomSeed() == -1) {
seed = System.currentTimeMillis() + i;
} else {
seed = options.getRandomSeed() + i;
}
execService.execute(new Runnable() {
@Override
public void run() {
Thread.currentThread().setName(databaseName);
runThread(databaseName);
}
private void runThread(final String databaseName) {
Randomly r = new Randomly(seed);
try {
if (options.getMaxGeneratedDatabases() == -1) {
// run without a limit
boolean continueRunning = true;
while (continueRunning) {
continueRunning = run(options, execService, executorFactory, r, databaseName);
}
} else {
for (int i = 0; i < options.getMaxGeneratedDatabases(); i++) {
boolean continueRunning = run(options, execService, executorFactory, r, databaseName);
if (!continueRunning) {
break;
}
}
}
} finally {
threadsShutdown++;
if (threadsShutdown == options.getTotalNumberTries()) {
execService.shutdown();
}
}
}
private boolean run(MainOptions options, ExecutorService execService,
DBMSExecutorFactory, ?, ?> executorFactory, Randomly r, final String databaseName) {
DBMSExecutor, ?, ?> executor = executorFactory.getDBMSExecutor(databaseName, r);
try {
executor.run();
return true;
} catch (IgnoreMeException e) {
return true;
} catch (Throwable reduce) {
reduce.printStackTrace();
executor.getStateToReproduce().exception = reduce.getMessage();
executor.getLogger().logFileWriter = null;
executor.getLogger().logException(reduce, executor.getStateToReproduce());
return false;
} finally {
try {
if (options.logEachSelect()) {
if (executor.getLogger().currentFileWriter != null) {
executor.getLogger().currentFileWriter.close();
}
executor.getLogger().currentFileWriter = null;
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
});
}
try {
if (options.getTimeoutSeconds() == -1) {
execService.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
} else {
execService.awaitTermination(options.getTimeoutSeconds(), TimeUnit.SECONDS);
}
} catch (InterruptedException e) {
e.printStackTrace();
}
return threadsShutdown == 0 ? 0 : options.getErrorExitCode();
}
static List> getDBMSProviders() {
List> providers = new ArrayList<>();
providers.add(new SQLite3Provider());
providers.add(new CockroachDBProvider());
providers.add(new MySQLProvider());
providers.add(new MariaDBProvider());
providers.add(new TiDBProvider());
providers.add(new PostgresProvider());
providers.add(new CitusProvider());
providers.add(new ClickHouseProvider());
providers.add(new DuckDBProvider());
providers.add(new H2Provider());
providers.add(new MongoDBProvider());
providers.add(new CosmosProvider());
providers.add(new ArangoDBProvider());
return providers;
}
/**
 * Starts, at most once per JVM, a single-thread scheduler that prints throughput statistics every five seconds.
 * The first report comes five seconds after the call.
 */
private static synchronized void startProgressMonitor() {
    // The monitor may already be running, e.g. when the main method is called several times in a test
    // (see https://github.com/sqlancer/sqlancer/issues/90).
    if (progressMonitorStarted) {
        return;
    }
    progressMonitorStarted = true;

    final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
    Runnable reporter = new Runnable() {
        private long lastReportMillis = System.currentTimeMillis();
        private long queriesAtLastReport;
        private long databasesAtLastReport;

        @Override
        public void run() {
            long elapsedMillis = System.currentTimeMillis() - lastReportMillis;
            double elapsedSeconds = elapsedMillis / 1000d;
            long queriesNow = nrQueries.get();
            double queriesPerSecond = (queriesNow - queriesAtLastReport) / elapsedSeconds;
            long databasesNow = nrDatabases.get();
            double databasesPerSecond = (databasesNow - databasesAtLastReport) / elapsedSeconds;
            long successPercentage = (long) (100.0 * nrSuccessfulActions.get()
                    / (nrSuccessfulActions.get() + nrUnsuccessfulActions.get()));
            DateFormat timestampFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
            String timestamp = timestampFormat.format(new Date());
            System.out.println(String.format(
                    "[%s] Executed %d queries (%d queries/s; %.2f/s dbs, successful statements: %2d%%). Threads shut down: %d.",
                    timestamp, queriesNow, (int) queriesPerSecond, databasesPerSecond, successPercentage,
                    threadsShutdown));
            lastReportMillis = System.currentTimeMillis();
            queriesAtLastReport = queriesNow;
            databasesAtLastReport = databasesNow;
        }
    };
    scheduler.scheduleAtFixedRate(reporter, 5, 5, TimeUnit.SECONDS);
}
}