package sqlancer; import java.io.IOException; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.function.UnaryOperator; import java.util.stream.Collectors; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLancerResultSet; public final class ComparatorHelper { private ComparatorHelper() { } public static boolean isEqualDouble(String first, String second) { try { double val = Double.parseDouble(first); double secVal = Double.parseDouble(second); return equals(val, secVal); } catch (Exception e) { return false; } } static boolean equals(double a, double b) { if (a == b) { return true; } // If the difference is less than epsilon, treat as equal. return Math.abs(a - b) < 0.001 * Math.max(Math.abs(a), Math.abs(b)) + 0.001; } public static List getResultSetFirstColumnAsString(String queryString, ExpectedErrors errors, SQLGlobalState, ?> state) throws SQLException { if (state.getOptions().logEachSelect()) { // TODO: refactor me state.getLogger().writeCurrent(queryString); try { state.getLogger().getCurrentFileWriter().flush(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } boolean canonicalizeString = state.getOptions().canonicalizeSqlString(); SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors, true, canonicalizeString); List resultSet = new ArrayList<>(); SQLancerResultSet result = null; try { result = q.executeAndGet(state); if (result == null) { throw new IgnoreMeException(); } while (result.next()) { String resultTemp = result.getString(1); if (resultTemp != null) { resultTemp = resultTemp.replaceAll("[\\.]0+$", ""); // Remove the trailing zeros as many DBMS treat // it as non-bugs } resultSet.add(resultTemp); } } catch (Exception e) { if (e instanceof IgnoreMeException) { throw e; } if (e.getMessage() == null) { throw new AssertionError(queryString, e); } if (errors.errorIsExpected(e.getMessage())) { throw new IgnoreMeException(); } throw new AssertionError(queryString, e); } finally { if (result != null && !result.isClosed()) { result.close(); } } return resultSet; } public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, String originalQueryString, List combinedString, SQLGlobalState, ?> state) { if (resultSet.size() != secondResultSet.size()) { String queryFormatString = "-- %s;" + System.lineSeparator() + "-- cardinality: %d" + System.lineSeparator(); String firstQueryString = String.format(queryFormatString, originalQueryString, resultSet.size()); String combinedQueryString = String.join(";", combinedString); String secondQueryString = String.format(queryFormatString, combinedQueryString, secondResultSet.size()); state.getState().getLocalState() .log(String.format("%s" + System.lineSeparator() + "%s", firstQueryString, secondQueryString)); String assertionMessage = String.format( "The size of the result sets mismatch (%d and %d)!" + System.lineSeparator() + "First query: \"%s\", whose cardinality is: %d" + System.lineSeparator() + "Second query:\"%s\", whose cardinality is: %d", resultSet.size(), secondResultSet.size(), originalQueryString, resultSet.size(), combinedQueryString, secondResultSet.size()); throw new AssertionError(assertionMessage); } Set firstHashSet = new HashSet<>(resultSet); Set secondHashSet = new HashSet<>(secondResultSet); boolean validateResultSizeOnly = state.getOptions().validateResultSizeOnly(); if (!validateResultSizeOnly && !firstHashSet.equals(secondHashSet)) { Set firstResultSetMisses = new HashSet<>(firstHashSet); firstResultSetMisses.removeAll(secondHashSet); Set secondResultSetMisses = new HashSet<>(secondHashSet); secondResultSetMisses.removeAll(firstHashSet); String queryFormatString = "-- Query: \"%s\"; It misses: \"%s\""; String firstQueryString = String.format(queryFormatString, originalQueryString, firstResultSetMisses); String secondQueryString = String.format(queryFormatString, String.join(";", combinedString), secondResultSetMisses); // update the SELECT queries to be logged at the bottom of the error log file state.getState().getLocalState() .log(String.format("%s" + System.lineSeparator() + "%s", firstQueryString, secondQueryString)); String assertionMessage = String.format("The content of the result sets mismatch!" + System.lineSeparator() + "First query : \"%s\"" + System.lineSeparator() + "Second query: \"%s\"", originalQueryString, secondQueryString); throw new AssertionError(assertionMessage); } } public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, String originalQueryString, List combinedString, SQLGlobalState, ?> state, UnaryOperator canonicalizationRule) { // Overloaded version of assumeResultSetsAreEqual that takes a canonicalization function which is applied to // both result sets before their comparison. List canonicalizedResultSet = resultSet.stream().map(canonicalizationRule).collect(Collectors.toList()); List canonicalizedSecondResultSet = secondResultSet.stream().map(canonicalizationRule) .collect(Collectors.toList()); assumeResultSetsAreEqual(canonicalizedResultSet, canonicalizedSecondResultSet, originalQueryString, combinedString, state); } public static List getCombinedResultSet(String firstQueryString, String secondQueryString, String thirdQueryString, List combinedString, boolean asUnion, SQLGlobalState, ?> state, ExpectedErrors errors) throws SQLException { List secondResultSet; if (asUnion) { String unionString = firstQueryString + " UNION ALL " + secondQueryString + " UNION ALL " + thirdQueryString; combinedString.add(unionString); secondResultSet = getResultSetFirstColumnAsString(unionString, errors, state); } else { secondResultSet = new ArrayList<>(); secondResultSet.addAll(getResultSetFirstColumnAsString(firstQueryString, errors, state)); secondResultSet.addAll(getResultSetFirstColumnAsString(secondQueryString, errors, state)); secondResultSet.addAll(getResultSetFirstColumnAsString(thirdQueryString, errors, state)); combinedString.add(firstQueryString); combinedString.add(secondQueryString); combinedString.add(thirdQueryString); } return secondResultSet; } public static List getCombinedResultSetNoDuplicates(String firstQueryString, String secondQueryString, String thirdQueryString, List combinedString, boolean asUnion, SQLGlobalState, ?> state, ExpectedErrors errors) throws SQLException { String unionString; if (asUnion) { unionString = firstQueryString + " UNION " + secondQueryString + " UNION " + thirdQueryString; } else { unionString = "SELECT DISTINCT * FROM (" + firstQueryString + " UNION ALL " + secondQueryString + " UNION ALL " + thirdQueryString + ")"; } List secondResultSet; combinedString.add(unionString); secondResultSet = getResultSetFirstColumnAsString(unionString, errors, state); return secondResultSet; } public static String canonicalizeResultValue(String value) { if (value == null) { return value; } switch (value) { case "-0.0": return "0.0"; case "-0": return "0"; default: } return value; } }