// Python-facing wrapper around the database object (py:: is presumably pybind11 -- the alias
// itself is declared outside this file).
//
// NOTE(review): this copy of the file appears to have lost its line breaks and the contents of
// its angle brackets (bare `#include`, `py::class_(`, `py::init()`, `static_cast(`,
// `std::make_shared()`, `template void`, `scanNodeTable<:int64_t>`). Confirm against the
// upstream file before building; the code below is left untouched.
//
// PyDatabase::initialize(m)
//   Registers the class on `m` under the Python name "Database":
//     - a constructor taking database_path plus keyword arguments with the defaults listed in
//       the py::arg(...) chain (buffer_pool_size, max_num_threads, compression, read_only,
//       max_db_size, auto_checkpoint, checkpoint_threshold, throw_on_wal_replay_failure,
//       enable_checksums, enable_multi_writes);
//     - scan_node_table_as_{int64,int32,int16,double,float,bool}, each bound to a
//       scanNodeTable instantiation and taking (table_name, prop_name, indices, np_array,
//       num_threads);
//     - close;
//     - the static methods get_version and get_storage_version.
//
// PyDatabase::getVersion() / getStorageVersion()
//   Return Version::getVersion() wrapped in a py::str, and Version::getStorageVersion(),
//   respectively.
//
// PyDatabase::PyDatabase(...)
//   Builds a SystemConfig from the first six settings, then:
//     - overrides systemConfig.checkpointThreshold only when checkpointThreshold >= 0, so the
//       binding's default of -1 leaves whatever value SystemConfig already holds;
//     - copies throwOnWalReplayFailure, enableChecksums and enableMultiWrites into the config;
//     - creates `state`, opens the database at databasePath with that config, registers the
//       pandas scan table function on it, and creates the storage driver from the database
//       pointer;
//     - finally takes py::gil_scoped_acquire and creates lbug::importCache if it is still null.
//
// PyDatabase::~PyDatabase() / close()
//   The destructor calls close(); close() resets `state`. Once `state` is null,
//   scanNodeTable throws RuntimeException("Database is closed.").
//
// PyDatabase::scanNodeTable(tableName, propName, indices, result, numThreads)
//   Requests the buffers of `indices` and `result`, reinterprets the indices data pointer as
//   offset_t*, and forwards (tableName, propName, offsets, indices.size(), result bytes,
//   numThreads) to state->storage().scan(...).
//   The `state == nullptr` check runs after both buffers have been requested.
//   NOTE(review): nothing visible here verifies that `result` has room for indices.size()
//   elements -- presumably the Python caller guarantees it; confirm, since scan() writes
//   through the raw result pointer.
#include "include/py_database.h" #include #include "common/exception/runtime.h" #include "extension/extension.h" #include "include/cached_import/py_cached_import.h" #include "main/version.h" #include "pandas/pandas_scan.h" using namespace lbug::common; void PyDatabase::initialize(py::handle& m) { py::class_(m, "Database") .def(py::init(), py::arg("database_path"), py::arg("buffer_pool_size") = 0, py::arg("max_num_threads") = 0, py::arg("compression") = true, py::arg("read_only") = false, py::arg("max_db_size") = -1u, py::arg("auto_checkpoint") = true, py::arg("checkpoint_threshold") = -1, py::arg("throw_on_wal_replay_failure") = true, py::arg("enable_checksums") = true, py::arg("enable_multi_writes") = false) .def("scan_node_table_as_int64", &PyDatabase::scanNodeTable<:int64_t>, py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads")) .def("scan_node_table_as_int32", &PyDatabase::scanNodeTable<:int32_t>, py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads")) .def("scan_node_table_as_int16", &PyDatabase::scanNodeTable<:int16_t>, py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads")) .def("scan_node_table_as_double", &PyDatabase::scanNodeTable, py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads")) .def("scan_node_table_as_float", &PyDatabase::scanNodeTable, py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads")) .def("scan_node_table_as_bool", &PyDatabase::scanNodeTable, py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads")) .def("close", &PyDatabase::close) .def_static("get_version", &PyDatabase::getVersion) .def_static("get_storage_version", &PyDatabase::getStorageVersion); } py::str PyDatabase::getVersion() { return 
py::str(Version::getVersion()); } uint64_t PyDatabase::getStorageVersion() { return Version::getStorageVersion(); } PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize, uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize, bool autoCheckpoint, int64_t checkpointThreshold, bool throwOnWalReplayFailure, bool enableChecksums, bool enableMultiWrites) { auto systemConfig = SystemConfig(bufferPoolSize, maxNumThreads, compression, readOnly, maxDBSize, autoCheckpoint); if (checkpointThreshold >= 0) { systemConfig.checkpointThreshold = static_cast(checkpointThreshold); } systemConfig.throwOnWalReplayFailure = throwOnWalReplayFailure; systemConfig.enableChecksums = enableChecksums; systemConfig.enableMultiWrites = enableMultiWrites; state = std::make_shared(); state->database = std::make_unique(databasePath, systemConfig); lbug::extension::ExtensionUtils::addTableFunc<:pandasscanfunction>(*state->database); state->storageDriver = std::make_unique(state->database.get()); py::gil_scoped_acquire acquire; if (lbug::importCache.get() == nullptr) { lbug::importCache = std::make_shared<:pythoncachedimport>(); } } PyDatabase::~PyDatabase() { close(); } void PyDatabase::close() { state.reset(); } template void PyDatabase::scanNodeTable(const std::string& tableName, const std::string& propName, const py::array_t& indices, py::array_t& result, int numThreads) { auto indices_buffer_info = indices.request(false); auto indices_buffer = static_cast(indices_buffer_info.ptr); auto nodeOffsets = (offset_t*)indices_buffer; auto result_buffer_info = result.request(); auto result_buffer = (uint8_t*)result_buffer_info.ptr; auto size = indices.size(); if (state == nullptr) { throw RuntimeException("Database is closed."); } state->storage().scan(tableName, propName, nodeOffsets, size, result_buffer, numThreads); }