package sqlancer.presto; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.List; import java.util.Objects; import com.google.auto.service.AutoService; import sqlancer.AbstractAction; import sqlancer.DatabaseProvider; import sqlancer.IgnoreMeException; import sqlancer.MainOptions; import sqlancer.Randomly; import sqlancer.SQLConnection; import sqlancer.SQLProviderAdapter; import sqlancer.StatementExecutor; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLQueryProvider; import sqlancer.presto.gen.PrestoInsertGenerator; import sqlancer.presto.gen.PrestoTableGenerator; @AutoService(DatabaseProvider.class) public class PrestoProvider extends SQLProviderAdapter { public PrestoProvider() { super(PrestoGlobalState.class, PrestoOptions.class); } // TODO : check actions based on connector // returns number of actions private static int mapActions(PrestoGlobalState globalState, Action a) { Randomly r = globalState.getRandomly(); if (Objects.requireNonNull(a) == Action.INSERT) { return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); // case UPDATE: // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumUpdates + 1); // case EXPLAIN: // return r.getInteger(0, 2); // case DELETE: // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumDeletes + 1); // case CREATE_VIEW: // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumViews + 1); } throw new AssertionError(a); } @Override public void generateDatabase(PrestoGlobalState globalState) throws Exception { for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { boolean success; do { SQLQueryAdapter qt = new PrestoTableGenerator().getQuery(globalState); success = globalState.executeStatement(qt); } while (!success); } if (globalState.getSchema().getDatabaseTables().isEmpty()) { throw new IgnoreMeException(); // TODO } StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), PrestoProvider::mapActions, (q) -> { if (globalState.getSchema().getDatabaseTables().isEmpty()) { throw new IgnoreMeException(); } }); se.executeStatements(); } @Override public SQLConnection createDatabase(PrestoGlobalState globalState) throws SQLException { String username = globalState.getOptions().getUserName(); String password = globalState.getOptions().getPassword(); boolean useSSl = true; if (globalState.getOptions().isDefaultUsername() && globalState.getOptions().isDefaultPassword()) { username = "presto"; password = null; useSSl = false; } String host = globalState.getOptions().getHost(); int port = globalState.getOptions().getPort(); if (host == null) { host = PrestoOptions.DEFAULT_HOST; } if (port == MainOptions.NO_SET_PORT) { port = PrestoOptions.DEFAULT_PORT; } String catalogName = globalState.getDbmsSpecificOptions().catalog; String databaseName = globalState.getDatabaseName(); String url = String.format("jdbc:presto://%s:%d/%s?SSL=%b", host, port, catalogName, useSSl); Connection con = DriverManager.getConnection(url, username, password); List schemaNames = getSchemaNames(con, catalogName, databaseName); dropExistingTables(con, catalogName, databaseName, schemaNames); dropSchema(globalState, con, catalogName, databaseName); createSchema(globalState, con, catalogName, databaseName); useSchema(globalState, con, catalogName, databaseName); return new SQLConnection(con); } private static void useSchema(PrestoGlobalState globalState, Connection con, String catalogName, String databaseName) throws SQLException { globalState.getState().logStatement("USE " + catalogName + "." + databaseName); try (Statement s = con.createStatement()) { s.execute("USE " + catalogName + "." + databaseName); } } private static void createSchema(PrestoGlobalState globalState, Connection con, String catalogName, String databaseName) throws SQLException { globalState.getState().logStatement("CREATE SCHEMA IF NOT EXISTS " + catalogName + "." + databaseName); try (Statement s = con.createStatement()) { s.execute("CREATE SCHEMA IF NOT EXISTS " + catalogName + "." + databaseName); } } private static void dropSchema(PrestoGlobalState globalState, Connection con, String catalogName, String databaseName) throws SQLException { globalState.getState().logStatement("DROP SCHEMA IF EXISTS " + catalogName + "." + databaseName); try (Statement s = con.createStatement()) { s.execute("DROP SCHEMA IF EXISTS " + catalogName + "." + databaseName); } } private static List getSchemaNames(Connection con, String catalogName, String databaseName) throws SQLException { List schemaNames = new ArrayList<>(); final String showSchemasSql = "SHOW SCHEMAS FROM " + catalogName + " LIKE '" + databaseName + "'"; try (Statement s = con.createStatement()) { try (ResultSet rs = s.executeQuery(showSchemasSql)) { while (rs.next()) { schemaNames.add(rs.getString("Schema")); } } } return schemaNames; } private static void dropExistingTables(Connection con, String catalogName, String databaseName, List schemaNames) throws SQLException { if (!schemaNames.isEmpty()) { List tableNames = new ArrayList<>(); try (Statement s = con.createStatement()) { try (ResultSet rs = s.executeQuery("SHOW TABLES FROM " + catalogName + "." + databaseName)) { while (rs.next()) { tableNames.add(rs.getString("Table")); } } } try (Statement s = con.createStatement()) { for (String tableName : tableNames) { s.execute("DROP TABLE IF EXISTS " + catalogName + "." + databaseName + "." + tableName); } } } } @Override public String getDBMSName() { return "presto"; } public enum Action implements AbstractAction { // SHOW_TABLES((g) -> new SQLQueryAdapter("SHOW TABLES", new ExpectedErrors(), false, false)), // INSERT(PrestoInsertGenerator::getQuery); // TODO : check actions based on connector // DELETE(PrestoDeleteGenerator::generate), // // UPDATE(PrestoUpdateGenerator::getQuery), // // CREATE_VIEW(PrestoViewGenerator::generate), // // EXPLAIN((g) -> { // ExpectedErrors errors = new ExpectedErrors(); // PrestoErrors.addExpressionErrors(errors); // PrestoErrors.addGroupByErrors(errors); // return new SQLQueryAdapter( // "EXPLAIN " + PrestoToStringVisitor // .asString(PrestoRandomQuerySynthesizer.generateSelect(g, Randomly.smallNumber() + 1)), // errors); // }); private final SQLQueryProvider sqlQueryProvider; Action(SQLQueryProvider sqlQueryProvider) { this.sqlQueryProvider = sqlQueryProvider; } @Override public SQLQueryAdapter getQuery(PrestoGlobalState state) throws Exception { return sqlQueryProvider.getQuery(state); } } }