package sqlancer; import java.io.File; import java.io.FileWriter; import java.io.IOException; import java.io.Writer; import java.nio.file.Files; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; import java.util.stream.Collectors; import com.beust.jcommander.JCommander; import com.beust.jcommander.JCommander.Builder; import sqlancer.FoundBugException.Reproducer; import sqlancer.arangodb.ArangoDBProvider; import sqlancer.citus.CitusProvider; import sqlancer.clickhouse.ClickHouseProvider; import sqlancer.cockroachdb.CockroachDBProvider; import sqlancer.common.log.Loggable; import sqlancer.common.query.Query; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLancerResultSet; import sqlancer.cosmos.CosmosProvider; import sqlancer.duckdb.DuckDBProvider; import sqlancer.h2.H2Provider; import sqlancer.mariadb.MariaDBProvider; import sqlancer.mongodb.MongoDBProvider; import sqlancer.mysql.MySQLProvider; import sqlancer.postgres.PostgresProvider; import sqlancer.sqlite3.SQLite3GlobalState; import sqlancer.sqlite3.SQLite3Provider; import sqlancer.tidb.TiDBProvider; public final class Main { public static final File LOG_DIRECTORY = new File("logs"); public static volatile AtomicLong nrQueries = new AtomicLong(); public static volatile AtomicLong nrDatabases = new AtomicLong(); public static volatile AtomicLong nrSuccessfulActions = new AtomicLong(); public static volatile AtomicLong nrUnsuccessfulActions = new AtomicLong(); static int threadsShutdown; static boolean progressMonitorStarted; static { System.setProperty(org.slf4j.impl.SimpleLogger.DEFAULT_LOG_LEVEL_KEY, "ERROR"); if (!LOG_DIRECTORY.exists()) { LOG_DIRECTORY.mkdir(); } } private Main() { } public static final class StateLogger { private final File loggerFile; private final File reducedFile; private File curFile; private FileWriter logFileWriter; public FileWriter currentFileWriter; private static final List INITIALIZED_PROVIDER_NAMES = new ArrayList<>(); private final boolean logEachSelect; private final DatabaseProvider, ?, ?> databaseProvider; private static final class AlsoWriteToConsoleFileWriter extends FileWriter { AlsoWriteToConsoleFileWriter(File file) throws IOException { super(file); } @Override public Writer append(CharSequence arg0) throws IOException { System.err.println(arg0); return super.append(arg0); } @Override public void write(String str) throws IOException { System.err.println(str); super.write(str); } } public StateLogger(String databaseName, DatabaseProvider, ?, ?> provider, MainOptions options) { File dir = new File(LOG_DIRECTORY, provider.getDBMSName()); if (dir.exists() && !dir.isDirectory()) { throw new AssertionError(dir); } ensureExistsAndIsEmpty(dir, provider); loggerFile = new File(dir, databaseName + ".log"); reducedFile = new File(dir, databaseName + "-reduced.log"); logEachSelect = options.logEachSelect(); if (logEachSelect) { curFile = new File(dir, databaseName + "-cur.log"); } this.databaseProvider = provider; } private void ensureExistsAndIsEmpty(File dir, DatabaseProvider, ?, ?> provider) { if (INITIALIZED_PROVIDER_NAMES.contains(provider.getDBMSName())) { return; } synchronized (INITIALIZED_PROVIDER_NAMES) { if (!dir.exists()) { try { Files.createDirectories(dir.toPath()); } catch (IOException e) { throw new AssertionError(e); } } File[] listFiles = dir.listFiles(); assert listFiles != null : "directory was just created, so it should exist"; for (File file : listFiles) { if (!file.isDirectory()) { file.delete(); } } INITIALIZED_PROVIDER_NAMES.add(provider.getDBMSName()); } } private FileWriter getLogFileWriter() { if (logFileWriter == null) { try { logFileWriter = new AlsoWriteToConsoleFileWriter(loggerFile); } catch (IOException e) { throw new AssertionError(e); } } return logFileWriter; } private FileWriter getReducedWriter() { // if (logFileWriter == null) { try { logFileWriter = new FileWriter(reducedFile); } catch (IOException e) { throw new AssertionError(e); } // } return logFileWriter; } public FileWriter getCurrentFileWriter() { if (!logEachSelect) { throw new UnsupportedOperationException(); } if (currentFileWriter == null) { try { currentFileWriter = new FileWriter(curFile, false); } catch (IOException e) { throw new AssertionError(e); } } return currentFileWriter; } public void writeCurrent(StateToReproduce state) { if (!logEachSelect) { throw new UnsupportedOperationException(); } printState(getCurrentFileWriter(), state); try { currentFileWriter.flush(); } catch (IOException e) { e.printStackTrace(); } } public void writeCurrent(String input) { write(databaseProvider.getLoggableFactory().createLoggable(input)); } public void writeCurrentNoLineBreak(String input) { write(databaseProvider.getLoggableFactory().createLoggableWithNoLinebreak(input)); } private void write(Loggable loggable) { if (!logEachSelect) { throw new UnsupportedOperationException(); } try { getCurrentFileWriter().write(loggable.getLogString()); currentFileWriter.flush(); } catch (IOException e) { throw new AssertionError(); } } public void logReduced(StateToReproduce state) { FileWriter logFileWriter = getReducedWriter(); printState(logFileWriter, state); try { logFileWriter.close(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } public void logException(Throwable reduce, StateToReproduce state) { Loggable stackTrace = getStackTrace(reduce); FileWriter logFileWriter2 = getLogFileWriter(); try { logFileWriter2.write(stackTrace.getLogString()); printState(logFileWriter2, state); } catch (IOException e) { throw new AssertionError(e); } finally { try { logFileWriter2.flush(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } } private Loggable getStackTrace(Throwable e1) { return databaseProvider.getLoggableFactory().convertStacktraceToLoggable(e1); } private void printState(FileWriter writer, StateToReproduce state) { StringBuilder sb = new StringBuilder(); sb.append(databaseProvider.getLoggableFactory() .getInfo(state.getDatabaseName(), state.getDatabaseVersion(), state.getSeedValue()).getLogString()); for (Query> s : state.getStatements()) { sb.append(s.getLogString()); sb.append('\n'); } try { writer.write(sb.toString()); } catch (IOException e) { throw new AssertionError(e); } } } public static class QueryManager { private final GlobalState, ?, C> globalState; QueryManager(GlobalState, ?, C> globalState) { this.globalState = globalState; } public boolean execute(Query q, String... fills) throws Exception { globalState.getState().logStatement(q); boolean success; success = q.execute(globalState, fills); Main.nrSuccessfulActions.addAndGet(1); return success; } public SQLancerResultSet executeAndGet(Query q, String... fills) throws Exception { globalState.getState().logStatement(q); SQLancerResultSet result; result = q.executeAndGet(globalState, fills); Main.nrSuccessfulActions.addAndGet(1); return result; } public void incrementSelectQueryCount() { Main.nrQueries.addAndGet(1); } public void incrementCreateDatabase() { Main.nrDatabases.addAndGet(1); } } public static void main(String[] args) { System.exit(executeMain(args)); } public static class DBMSExecutor, O extends DBMSSpecificOptions>, C extends SQLancerDBConnection> { private final DatabaseProvider provider; private final MainOptions options; private final O command; private final String databaseName; private StateLogger logger; private StateToReproduce stateToRepro; private final Randomly r; public DBMSExecutor(DatabaseProvider provider, MainOptions options, O dbmsSpecificOptions, String databaseName, Randomly r) { this.provider = provider; this.options = options; this.databaseName = databaseName; this.command = dbmsSpecificOptions; this.r = r; } private G createGlobalState() { try { return provider.getGlobalStateClass().getDeclaredConstructor().newInstance(); } catch (Exception e) { throw new AssertionError(e); } } public O getCommand() { return command; } public void testConnection() throws Exception { G state = getInitializedGlobalState(options.getRandomSeed()); try (SQLancerDBConnection con = provider.createDatabase(state)) { return; } } boolean observedChange; public void run() throws Exception { G state = createGlobalState(); stateToRepro = provider.getStateToReproduce(databaseName); stateToRepro.seedValue = r.getSeed(); state.setState(stateToRepro); logger = new StateLogger(databaseName, provider, options); state.setRandomly(r); state.setDatabaseName(databaseName); state.setMainOptions(options); state.setDmbsSpecificOptions(command); try (C con = provider.createDatabase(state)) { QueryManager manager = new QueryManager<>(state); try { stateToRepro.databaseVersion = con.getDatabaseVersion(); } catch (Exception e) { // ignore } state.setConnection(con); state.setStateLogger(logger); state.setManager(manager); if (options.logEachSelect()) { logger.writeCurrent(state.getState()); } try { provider.generateAndTestDatabase(state); } catch (FoundBugException e) { try { logger.getCurrentFileWriter().close(); logger.currentFileWriter = null; } catch (IOException e2) { throw new AssertionError(e2); } Reproducer reproducer = e.getReproducer(); G newGlobalState = createGlobalState(); QueryManager newManager = new QueryManager<>(newGlobalState); newGlobalState.setDatabaseName(databaseName); newGlobalState.setMainOptions(options); newGlobalState.setDmbsSpecificOptions(command); newGlobalState.setStateLogger(new StateLogger(databaseName, provider, options)); newGlobalState.setManager(newManager); newGlobalState.setState(stateToRepro); List> knownToReproduceBugStatements = new ArrayList>(); for (Query> stat : state.getState().getStatements()) { knownToReproduceBugStatements.add((Query) stat); } // iterate until fixpoint do { observedChange = false; knownToReproduceBugStatements = tryReduction(state, reproducer, newGlobalState, knownToReproduceBugStatements, (candidateStatements, i) -> { candidateStatements.remove((int) i); return true; }); } while (observedChange); for (String s : new String[] { "OR IGNORE", "OR ABORT", "OR ROLLBACK", "OR FAIL", "TEMP", "TEMPORARY", "UNIQUE", "NOT NULL", "COLLATE BINARY", "COLLATE NOCASE", "COLLATE RTRIM", "INT", "REAL", "TEXT", "IF NOT EXISTS", "UNINDEXED" }) { knownToReproduceBugStatements = tryReplaceToken(state, reproducer, newGlobalState, knownToReproduceBugStatements, " " + s, ""); } throw e; } } } private List> tryReplaceToken(G state, Reproducer reproducer, G newGlobalState, List> knownToReproduceBugStatements, String target, String replaceBy) throws Exception { do { observedChange = false; knownToReproduceBugStatements = tryReduction(state, reproducer, newGlobalState, knownToReproduceBugStatements, (candidateStatements, i) -> { Query statement = candidateStatements.get(i); if (statement.getQueryString().contains(target)) { candidateStatements.set(i, (Query) new SQLQueryAdapter( statement.getQueryString().replace(target, replaceBy), true)); return true; } return false; } ); } while (observedChange); return knownToReproduceBugStatements; } private List> tryReduction(G state, Reproducer reproducer, G newGlobalState, List> knownToReproduceBugStatements, BiFunction>, Integer, Boolean> reductionOperation) throws Exception { for (int i = 0; i < knownToReproduceBugStatements.size(); i++) { try (C con2 = provider.createDatabase(newGlobalState)) { newGlobalState.setConnection(con2); List> candidateStatements = new ArrayList<>(knownToReproduceBugStatements); if (!reductionOperation.apply(candidateStatements, i)) { continue; } newGlobalState.getState().setStatements(candidateStatements.stream().collect(Collectors.toList())); for (Query s : candidateStatements) { try { s.execute(newGlobalState); } catch (Throwable ignoredException) { // ignore } } try { if (reproducer.bugStillTriggers((SQLite3GlobalState) newGlobalState)) { observedChange = true; knownToReproduceBugStatements = candidateStatements; reproducer.outputHook((SQLite3GlobalState) newGlobalState); state.getLogger().logReduced(newGlobalState.getState()); } } catch (Throwable ignoredException) { } } } return knownToReproduceBugStatements; } private G getInitializedGlobalState(long seed) { G state = createGlobalState(); stateToRepro = provider.getStateToReproduce(databaseName); stateToRepro.seedValue = seed; state.setState(stateToRepro); logger = new StateLogger(databaseName, provider, options); Randomly r = new Randomly(seed); state.setRandomly(r); state.setDatabaseName(databaseName); state.setMainOptions(options); state.setDmbsSpecificOptions(command); return state; } public StateLogger getLogger() { return logger; } public StateToReproduce getStateToReproduce() { return stateToRepro; } } public static class DBMSExecutorFactory, O extends DBMSSpecificOptions>, C extends SQLancerDBConnection> { private final DatabaseProvider provider; private final MainOptions options; private final O command; public DBMSExecutorFactory(DatabaseProvider provider, MainOptions options) { this.provider = provider; this.options = options; this.command = createCommand(); } private O createCommand() { try { return provider.getOptionClass().getDeclaredConstructor().newInstance(); } catch (Exception e) { throw new AssertionError(e); } } public O getCommand() { return command; } @SuppressWarnings("unchecked") public DBMSExecutor getDBMSExecutor(String databaseName, Randomly r) { try { return new DBMSExecutor(provider.getClass().getDeclaredConstructor().newInstance(), options, command, databaseName, r); } catch (Exception e) { throw new AssertionError(e); } } public DatabaseProvider getProvider() { return provider; } } public static int executeMain(String... args) throws AssertionError { List> providers = getDBMSProviders(); Map> nameToProvider = new HashMap<>(); MainOptions options = new MainOptions(); Builder commandBuilder = JCommander.newBuilder().addObject(options); for (DatabaseProvider, ?, ?> provider : providers) { String name = provider.getDBMSName(); if (!name.toLowerCase().equals(name)) { throw new AssertionError(name + " should be in lowercase!"); } DBMSExecutorFactory, ?, ?> executorFactory = new DBMSExecutorFactory<>(provider, options); commandBuilder = commandBuilder.addCommand(name, executorFactory.getCommand()); nameToProvider.put(name, executorFactory); } JCommander jc = commandBuilder.programName("SQLancer").build(); jc.parse(args); if (jc.getParsedCommand() == null || options.isHelp()) { jc.usage(); return options.getErrorExitCode(); } Randomly.initialize(options); if (options.printProgressInformation()) { startProgressMonitor(); if (options.printProgressSummary()) { Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() { @Override public void run() { System.out.println("Overall execution statistics"); System.out.println("============================"); System.out.println(formatInteger(nrQueries.get()) + " queries"); System.out.println(formatInteger(nrDatabases.get()) + " databases"); System.out.println( formatInteger(nrSuccessfulActions.get()) + " successfully-executed statements"); System.out.println( formatInteger(nrUnsuccessfulActions.get()) + " unsuccessfuly-executed statements"); } private String formatInteger(long intValue) { if (intValue > 1000) { return String.format("%,9dk", intValue / 1000); } else { return String.format("%,10d", intValue); } } })); } } ExecutorService execService = Executors.newFixedThreadPool(options.getNumberConcurrentThreads()); DBMSExecutorFactory, ?, ?> executorFactory = nameToProvider.get(jc.getParsedCommand()); if (options.performConnectionTest()) { try { executorFactory.getDBMSExecutor(options.getDatabasePrefix() + "connectiontest", new Randomly()) .testConnection(); } catch (Exception e) { System.err.println( "SQLancer failed creating a test database, indicating that SQLancer might have failed connecting to the DBMS. In order to change the username, password, host and port, you can use the --username, --password, --host and --port options.\n\n"); e.printStackTrace(); return options.getErrorExitCode(); } } for (int i = 0; i < options.getTotalNumberTries(); i++) { final String databaseName = options.getDatabasePrefix() + i; final long seed; if (options.getRandomSeed() == -1) { seed = System.currentTimeMillis() + i; } else { seed = options.getRandomSeed() + i; } execService.execute(new Runnable() { @Override public void run() { Thread.currentThread().setName(databaseName); runThread(databaseName); } private void runThread(final String databaseName) { Randomly r = new Randomly(seed); try { if (options.getMaxGeneratedDatabases() == -1) { // run without a limit boolean continueRunning = true; while (continueRunning) { continueRunning = run(options, execService, executorFactory, r, databaseName); } } else { for (int i = 0; i < options.getMaxGeneratedDatabases(); i++) { boolean continueRunning = run(options, execService, executorFactory, r, databaseName); if (!continueRunning) { break; } } } } finally { threadsShutdown++; if (threadsShutdown == options.getTotalNumberTries()) { execService.shutdown(); } } } private boolean run(MainOptions options, ExecutorService execService, DBMSExecutorFactory, ?, ?> executorFactory, Randomly r, final String databaseName) { DBMSExecutor, ?, ?> executor = executorFactory.getDBMSExecutor(databaseName, r); try { executor.run(); return true; } catch (IgnoreMeException e) { return true; } catch (Throwable reduce) { reduce.printStackTrace(); executor.getStateToReproduce().exception = reduce.getMessage(); executor.getLogger().logFileWriter = null; executor.getLogger().logException(reduce, executor.getStateToReproduce()); return false; } finally { try { if (options.logEachSelect()) { if (executor.getLogger().currentFileWriter != null) { executor.getLogger().currentFileWriter.close(); } executor.getLogger().currentFileWriter = null; } } catch (IOException e) { e.printStackTrace(); } } } }); } try { if (options.getTimeoutSeconds() == -1) { execService.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS); } else { execService.awaitTermination(options.getTimeoutSeconds(), TimeUnit.SECONDS); } } catch (InterruptedException e) { e.printStackTrace(); } return threadsShutdown == 0 ? 0 : options.getErrorExitCode(); } static List> getDBMSProviders() { List> providers = new ArrayList<>(); providers.add(new SQLite3Provider()); providers.add(new CockroachDBProvider()); providers.add(new MySQLProvider()); providers.add(new MariaDBProvider()); providers.add(new TiDBProvider()); providers.add(new PostgresProvider()); providers.add(new CitusProvider()); providers.add(new ClickHouseProvider()); providers.add(new DuckDBProvider()); providers.add(new H2Provider()); providers.add(new MongoDBProvider()); providers.add(new CosmosProvider()); providers.add(new ArangoDBProvider()); return providers; } private static synchronized void startProgressMonitor() { if (progressMonitorStarted) { /* * it might be already started if, for example, the main method is called multiple times in a test (see * https://github.com/sqlancer/sqlancer/issues/90). */ return; } else { progressMonitorStarted = true; } final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); scheduler.scheduleAtFixedRate(new Runnable() { private long timeMillis = System.currentTimeMillis(); private long lastNrQueries; private long lastNrDbs; { timeMillis = System.currentTimeMillis(); } @Override public void run() { long elapsedTimeMillis = System.currentTimeMillis() - timeMillis; long currentNrQueries = nrQueries.get(); long nrCurrentQueries = currentNrQueries - lastNrQueries; double throughput = nrCurrentQueries / (elapsedTimeMillis / 1000d); long currentNrDbs = nrDatabases.get(); long nrCurrentDbs = currentNrDbs - lastNrDbs; double throughputDbs = nrCurrentDbs / (elapsedTimeMillis / 1000d); long successfulStatementsRatio = (long) (100.0 * nrSuccessfulActions.get() / (nrSuccessfulActions.get() + nrUnsuccessfulActions.get())); DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss"); Date date = new Date(); System.out.println(String.format( "[%s] Executed %d queries (%d queries/s; %.2f/s dbs, successful statements: %2d%%). Threads shut down: %d.", dateFormat.format(date), currentNrQueries, (int) throughput, throughputDbs, successfulStatementsRatio, threadsShutdown)); timeMillis = System.currentTimeMillis(); lastNrQueries = currentNrQueries; lastNrDbs = currentNrDbs; } }, 5, 5, TimeUnit.SECONDS); } }