#include "include/py_query_result_converter.h"

#include "cached_import/py_cached_import.h"
#include "common/types/value/value.h"
#include "include/py_query_result.h"

using namespace lbug::common;
using namespace lbug;

// Allocates the two NumPy arrays that back one result column:
//   - `data`: `numFlatTuple` elements of the dtype chosen by convertToArrayType(type)
//   - `mask`: `numFlatTuple` booleans, one per element, filled in by appendElement()
// `dataBuffer` caches the raw pointer into `data` so appendElement() can write
// elements directly.
NPArrayWrapper::NPArrayWrapper(const LogicalType& type, uint64_t numFlatTuple)
    : type{type.copy()}, numElements{0} {
    data = py::array(convertToArrayType(type), numFlatTuple);
    dataBuffer = (uint8_t*)data.mutable_data();
    mask = py::array(py::dtype("bool"), numFlatTuple);
}

// Writes `value` into the next free slot (index `numElements`) and advances the
// cursor.
//
// The mask entry is always written and is set to `value->isNull()`; the data slot
// is only written for non-null values. The C++ type used to read the value and the
// pointer type used to address the slot must agree with the dtype returned by
// convertToArrayType() for the same logical type.
//
// NOTE(review): there is no bounds check against the size the arrays were
// allocated with; callers must not append more elements than `numFlatTuple`.
void NPArrayWrapper::appendElement(Value* value) {
    ((uint8_t*)mask.mutable_data())[numElements] = value->isNull();
    if (!value->isNull()) {
        switch (type.getLogicalTypeID()) {
        case LogicalTypeID::BOOL: {
            ((uint8_t*)dataBuffer)[numElements] = value->getValue<bool>();
        } break;
        case LogicalTypeID::INT128: {
            // INT128 columns use a float64 array (see convertToArrayType), so the
            // value is cast into a double slot.
            // NOTE(review): the result of tryCast is not inspected here.
            Int128_t::tryCast(value->getValue<int128_t>(), ((double*)dataBuffer)[numElements]);
        } break;
        case LogicalTypeID::SERIAL:
        case LogicalTypeID::INT64: {
            ((int64_t*)dataBuffer)[numElements] = value->getValue<int64_t>();
        } break;
        case LogicalTypeID::INT32: {
            ((int32_t*)dataBuffer)[numElements] = value->getValue<int32_t>();
        } break;
        case LogicalTypeID::INT16: {
            ((int16_t*)dataBuffer)[numElements] = value->getValue<int16_t>();
        } break;
        case LogicalTypeID::INT8: {
            ((int8_t*)dataBuffer)[numElements] = value->getValue<int8_t>();
        } break;
        case LogicalTypeID::UINT64: {
            ((uint64_t*)dataBuffer)[numElements] = value->getValue<uint64_t>();
        } break;
        case LogicalTypeID::UINT32: {
            ((uint32_t*)dataBuffer)[numElements] = value->getValue<uint32_t>();
        } break;
        case LogicalTypeID::UINT16: {
            ((uint16_t*)dataBuffer)[numElements] = value->getValue<uint16_t>();
        } break;
        case LogicalTypeID::UINT8: {
            ((uint8_t*)dataBuffer)[numElements] = value->getValue<uint8_t>();
        } break;
        case LogicalTypeID::DOUBLE: {
            ((double*)dataBuffer)[numElements] = value->getValue<double>();
        } break;
        case LogicalTypeID::FLOAT: {
            ((float*)dataBuffer)[numElements] = value->getValue<float>();
        } break;
        case LogicalTypeID::DATE: {
            // The column dtype is datetime64[us]; scale epoch nanoseconds down to
            // microseconds.
            ((int64_t*)dataBuffer)[numElements] =
                Date::getEpochNanoSeconds(value->getValue<date_t>()) / Interval::NANOS_PER_MICRO;
        } break;
        // The timestamp variants store their raw `.value` unchanged; the unit is
        // carried by the datetime64[...] dtype chosen in convertToArrayType().
        case LogicalTypeID::TIMESTAMP: {
            ((int64_t*)dataBuffer)[numElements] = value->getValue<timestamp_t>().value;
        } break;
        case LogicalTypeID::TIMESTAMP_TZ: {
            ((int64_t*)dataBuffer)[numElements] = value->getValue<timestamp_tz_t>().value;
        } break;
        case LogicalTypeID::TIMESTAMP_NS: {
            ((int64_t*)dataBuffer)[numElements] = value->getValue<timestamp_ns_t>().value;
        } break;
        case LogicalTypeID::TIMESTAMP_MS: {
            ((int64_t*)dataBuffer)[numElements] = value->getValue<timestamp_ms_t>().value;
        } break;
        case LogicalTypeID::TIMESTAMP_SEC: {
            ((int64_t*)dataBuffer)[numElements] = value->getValue<timestamp_sec_t>().value;
        } break;
        case LogicalTypeID::INTERVAL: {
            // The column dtype is timedelta64[ns].
            ((int64_t*)dataBuffer)[numElements] =
                Interval::getNanoseconds(value->getValue<interval_t>());
        } break;
        // Everything below lives in an "object" dtype array: the slot is addressed
        // as a pybind11 object wrapper and assigned the converted Python value.
        case LogicalTypeID::STRING: {
            auto val = value->getValue<std::string>();
            py::str result(val);
            ((py::str*)dataBuffer)[numElements] = result;
        } break;
        case LogicalTypeID::BLOB: {
            ((py::bytes*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value);
        } break;
        case LogicalTypeID::DECIMAL:
        case LogicalTypeID::UUID:
        case LogicalTypeID::UNION:
        case LogicalTypeID::MAP:
        case LogicalTypeID::STRUCT:
        case LogicalTypeID::NODE:
        case LogicalTypeID::REL: {
            ((py::object*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value);
        } break;
        case LogicalTypeID::ARRAY:
        case LogicalTypeID::LIST: {
            ((py::list*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value);
        } break;
        case LogicalTypeID::RECURSIVE_REL: {
            ((py::dict*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value);
        } break;
        default: {
            UNREACHABLE_CODE;
        }
        }
    }
    numElements++;
}

// Maps a logical type to the NumPy dtype used for its column array.
//
// Fixed-width numeric types map to the matching NumPy scalar dtype, INT128 is
// represented as float64, date/timestamp types map to datetime64 with the unit
// shown below, INTERVAL maps to timedelta64[ns], and all remaining supported types
// are stored as Python objects. Any other logical type hits UNREACHABLE_CODE.
py::dtype NPArrayWrapper::convertToArrayType(const LogicalType& type) {
    std::string dtype;
    switch (type.getLogicalTypeID()) {
    case LogicalTypeID::INT128: {
        dtype = "float64";
    } break;
    case LogicalTypeID::SERIAL:
    case LogicalTypeID::INT64: {
        dtype = "int64";
    } break;
    case LogicalTypeID::INT32: {
        dtype = "int32";
    } break;
    case LogicalTypeID::INT16: {
        dtype = "int16";
    } break;
    case LogicalTypeID::INT8: {
        dtype = "int8";
    } break;
    case LogicalTypeID::UINT64: {
        dtype = "uint64";
    } break;
    case LogicalTypeID::UINT32: {
        dtype = "uint32";
    } break;
    case LogicalTypeID::UINT16: {
        dtype = "uint16";
    } break;
    case LogicalTypeID::UINT8: {
        dtype = "uint8";
    } break;
    case LogicalTypeID::DOUBLE: {
        dtype = "float64";
    } break;
    case LogicalTypeID::FLOAT: {
        dtype = "float32";
    } break;
    case LogicalTypeID::BOOL: {
        dtype = "bool";
    } break;
    case LogicalTypeID::DATE:
    case LogicalTypeID::TIMESTAMP_TZ:
    case LogicalTypeID::TIMESTAMP: {
        dtype = "datetime64[us]";
    } break;
    case LogicalTypeID::TIMESTAMP_NS: {
        dtype = "datetime64[ns]";
    } break;
    case LogicalTypeID::TIMESTAMP_MS: {
        dtype = "datetime64[ms]";
    } break;
    case LogicalTypeID::TIMESTAMP_SEC: {
        dtype = "datetime64[s]";
    } break;
    case LogicalTypeID::INTERVAL: {
        dtype = "timedelta64[ns]";
    } break;
    case LogicalTypeID::DECIMAL:
    case LogicalTypeID::UNION:
    case LogicalTypeID::BLOB:
    case LogicalTypeID::UUID:
    case LogicalTypeID::STRUCT:
    case LogicalTypeID::NODE:
    case LogicalTypeID::REL:
    case LogicalTypeID::LIST:
    case LogicalTypeID::ARRAY:
    case LogicalTypeID::STRING:
    case LogicalTypeID::MAP:
    case LogicalTypeID::RECURSIVE_REL: {
        dtype = "object";
    } break;
    default: {
        UNREACHABLE_CODE;
    }
    }
    return py::dtype(dtype);
}

// Creates one NPArrayWrapper per result column, each sized to hold
// `queryResult->getNumTuples()` elements. The query result is not iterated here.
QueryResultConverter::QueryResultConverter(QueryResult* queryResult) : queryResult{queryResult} {
    for (auto& type : queryResult->getColumnDataTypes()) {
        columns.emplace_back(std::make_unique<NPArrayWrapper>(type, queryResult->getNumTuples()));
    }
}

// Materialises the whole query result as a pandas DataFrame.
//
// The result iterator is reset and every tuple is appended column by column. Each
// column is then wrapped as numpy.ma.masked_array(data, mask) and stored in a dict
// keyed by column name, which is handed to pandas.DataFrame.from_dict.
//
// NOTE(review): the column wrappers keep their write cursor between calls, and
// appendElement() does not bounds-check, so calling toDF() more than once on the
// same converter would write past the allocated arrays. Use a fresh converter per
// conversion.
py::object QueryResultConverter::toDF() {
    queryResult->resetIterator();
    while (queryResult->hasNext()) {
        auto flatTuple = queryResult->getNext();
        for (auto i = 0u; i < columns.size(); i++) {
            columns[i]->appendElement(flatTuple->getValue(i));
        }
    }
    py::dict result;
    auto colNames = queryResult->getColumnNames();
    auto maskedArray = importCache->numpyma.masked_array();
    auto fromDict = importCache->pandas.DataFrame.from_dict();
    for (auto i = 0u; i < colNames.size(); i++) {
        result[colNames[i].c_str()] = maskedArray(columns[i]->data, columns[i]->mask);
    }
    return fromDict(result);
}