#include "include/py_database.h"
#include
#include "common/exception/runtime.h"
#include "extension/extension.h"
#include "include/cached_import/py_cached_import.h"
#include "main/version.h"
#include "pandas/pandas_scan.h"
using namespace lbug::common;
// Registers the `Database` Python class on module `m`: a constructor taking the keyword
// arguments below, typed node-table scan helpers, `close`, and two static version getters.
// NOTE(review): template argument lists (on `py::class_`, `py::init`, and the
// double/float/bool `scanNodeTable` instantiations) are not visible in this copy of the
// file — confirm against the real source before editing.
void PyDatabase::initialize(py::handle& m) {
py::class_(m, "Database")
// Keyword arguments are listed in the same order as PyDatabase::PyDatabase's parameters.
.def(py::init(),
py::arg("database_path"), py::arg("buffer_pool_size") = 0,
py::arg("max_num_threads") = 0, py::arg("compression") = true,
// `-1u` is 0xFFFFFFFF as an unsigned int, not UINT64_MAX.
// NOTE(review): presumably a sentinel that SystemConfig interprets — confirm.
py::arg("read_only") = false, py::arg("max_db_size") = -1u,
// A negative checkpoint_threshold leaves SystemConfig's value untouched (the
// constructor only applies it when >= 0).
py::arg("auto_checkpoint") = true, py::arg("checkpoint_threshold") = -1,
py::arg("throw_on_wal_replay_failure") = true, py::arg("enable_checksums") = true,
py::arg("enable_multi_writes") = false)
// Typed scan helpers: each forwards to scanNodeTable, which fills the caller-provided
// `np_array` with the values of `prop_name` at the node offsets in `indices`.
.def("scan_node_table_as_int64", &PyDatabase::scanNodeTable<:int64_t>,
py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"),
py::arg("num_threads"))
.def("scan_node_table_as_int32", &PyDatabase::scanNodeTable<:int32_t>,
py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"),
py::arg("num_threads"))
.def("scan_node_table_as_int16", &PyDatabase::scanNodeTable<:int16_t>,
py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"),
py::arg("num_threads"))
.def("scan_node_table_as_double", &PyDatabase::scanNodeTable, py::arg("table_name"),
py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads"))
.def("scan_node_table_as_float", &PyDatabase::scanNodeTable, py::arg("table_name"),
py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads"))
.def("scan_node_table_as_bool", &PyDatabase::scanNodeTable, py::arg("table_name"),
py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads"))
// Releases the native database state; see PyDatabase::close.
.def("close", &PyDatabase::close)
.def_static("get_version", &PyDatabase::getVersion)
.def_static("get_storage_version", &PyDatabase::getStorageVersion);
}
/// Returns the library version reported by `Version::getVersion()` as a Python `str`.
py::str PyDatabase::getVersion() {
    const auto versionText = Version::getVersion();
    return py::str(versionText);
}
/// Returns the storage version reported by `Version::getStorageVersion()`.
uint64_t PyDatabase::getStorageVersion() {
    const uint64_t storageVersion = Version::getStorageVersion();
    return storageVersion;
}
// Creates the native database for `databasePath` configured from the given arguments, then:
// registers the pandas scan table function on it, creates its storage driver, and (under the
// GIL) initializes the module-level `lbug::importCache` if it has not been created yet.
//
// @param checkpointThreshold  Applied to SystemConfig only when >= 0; a negative value keeps
//                             SystemConfig's own value.
// NOTE(review): the meaning of 0 for bufferPoolSize / maxNumThreads and of -1u for maxDBSize
// is defined by SystemConfig, which is not visible here — confirm.
// NOTE(review): template arguments on make_shared / make_unique / static_cast / addTableFunc
// are not visible in this copy of the file.
PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize,
uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize,
bool autoCheckpoint, int64_t checkpointThreshold, bool throwOnWalReplayFailure,
bool enableChecksums, bool enableMultiWrites) {
auto systemConfig = SystemConfig(bufferPoolSize, maxNumThreads, compression, readOnly,
maxDBSize, autoCheckpoint);
// Negative thresholds mean "not specified" from the Python side.
if (checkpointThreshold >= 0) {
systemConfig.checkpointThreshold = static_cast(checkpointThreshold);
}
systemConfig.throwOnWalReplayFailure = throwOnWalReplayFailure;
systemConfig.enableChecksums = enableChecksums;
systemConfig.enableMultiWrites = enableMultiWrites;
state = std::make_shared();
state->database = std::make_unique(databasePath, systemConfig);
lbug::extension::ExtensionUtils::addTableFunc<:pandasscanfunction>(*state->database);
state->storageDriver = std::make_unique(state->database.get());
// Hold the GIL while touching lbug::importCache.
// NOTE(review): PythonCachedImport presumably creates Python objects — confirm.
py::gil_scoped_acquire acquire;
// Create the shared import cache only once; later instances reuse it.
if (lbug::importCache.get() == nullptr) {
lbug::importCache = std::make_shared<:pythoncachedimport>();
}
}
/// Tears down the native database through the same path as an explicit `close()`.
PyDatabase::~PyDatabase() {
    this->close();
}
/// Drops this object's reference to the shared database state. After this, scanNodeTable
/// reports "Database is closed."; calling close() again is a no-op.
void PyDatabase::close() {
    state = nullptr;
}
template
void PyDatabase::scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t& indices, py::array_t& result, int numThreads) {
auto indices_buffer_info = indices.request(false);
auto indices_buffer = static_cast(indices_buffer_info.ptr);
auto nodeOffsets = (offset_t*)indices_buffer;
auto result_buffer_info = result.request();
auto result_buffer = (uint8_t*)result_buffer_info.ptr;
auto size = indices.size();
if (state == nullptr) {
throw RuntimeException("Database is closed.");
}
state->storage().scan(tableName, propName, nodeOffsets, size, result_buffer, numThreads);
}