Skip to content

Commit

Permalink
Make error handling of relic errors thread-safe (Chia-Network#46)
Browse files Browse the repository at this point in the history
* Support initialization callback for thread-local contexts in relic

* Make the library thread safe

1. Compile relic with MULTI=PTHREAD to enable thread-local contexts
2. Use core_set_thread_initializer to register an initializer for the thread-local
   context. Whenever core_get() is called now, the context will be initialized
   automatically for the current thread (if not already initialized).
3. Remove BLS::Clean. Was never called and is also not useful anymore due
   each thread having its own context now.
4. Remove AssertInitialized and all calls to it. This is not needed anymore
   as initialization happens automatically now.
5. Add simple test case to assert that each thread has it's own context
  • Loading branch information
codablock authored and mariano54 committed Nov 2, 2018
1 parent bb30628 commit c001d9d
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 83 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ set(CHECK "off" CACHE STRING "")
set(VERBS "off" CACHE STRING "")
set(ALLOC "AUTO" CACHE STRING "")
set(SHLIB "OFF" CACHE STRING "")
set(MULTI "PTHREAD" CACHE STRING "")

set(FP_PRIME 381 CACHE INTEGER "")

Expand Down
11 changes: 11 additions & 0 deletions contrib/relic/include/relic_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,4 +448,15 @@ ctx_t *core_get(void);
*/
void core_set(ctx_t *ctx);

#if MULTI != RELIC_NONE
/**
* Set an initializer function which is called when the context
* is uninitialized. This function is called for every thread.
*
* @param[in] init function to call when the current context is not initialized
* @param[in] init_ptr a pointer which is passed to the initialized
*/
void core_set_thread_initializer(void(*init)(void *init_ptr), void* init_ptr);
#endif

#endif /* !RELIC_CORE_H */
21 changes: 21 additions & 0 deletions contrib/relic/src/relic_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ thread ctx_t first_ctx;
*/
thread ctx_t *core_ctx = NULL;

#if MULTI != RELIC_NONE
/*
* Initializer function to call for every thread's context
*/
void (*core_thread_initializer)(void* init_ptr) = NULL;
void* core_init_ptr = NULL;
#endif

#if MULTI == OPENMP
#pragma omp threadprivate(first_ctx, core_ctx)
#endif
Expand Down Expand Up @@ -164,9 +172,22 @@ int core_clean(void) {
}

ctx_t *core_get(void) {
#if MULTI != RELIC_NONE
if (core_ctx == NULL && core_thread_initializer != NULL) {
core_thread_initializer(core_init_ptr);
}
#endif

return core_ctx;
}

void core_set(ctx_t *ctx) {
core_ctx = ctx;
}

#if MULTI != RELIC_NONE
void core_set_thread_initializer(void(*init)(void *init_ptr), void* init_ptr) {
core_thread_initializer = init;
core_init_ptr = init_ptr;
}
#endif
5 changes: 1 addition & 4 deletions python-bindings/pythonbindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,7 @@ PYBIND11_MODULE(blspy, m) {
py::class_<BLS>(m, "BLS")
.def_property_readonly_static("MESSAGE_HASH_LEN", [](py::object self) {
return BLS::MESSAGE_HASH_LEN;
})
.def("init", &BLS::Init)
.def("assert_initialized", &BLS::AssertInitialized)
.def("clean", &BLS::Clean);
});

py::class_<Util>(m, "Util")
.def("hash256", [](const py::bytes &message) {
Expand Down
37 changes: 15 additions & 22 deletions src/bls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,47 +28,40 @@ bool BLSInitResult = BLS::Init();
Util::SecureAllocCallback Util::secureAllocCallback;
Util::SecureFreeCallback Util::secureFreeCallback;

bool BLS::Init() {
if (ALLOC != AUTO) {
std::cout << "Must have ALLOC == AUTO";
return false;
}
static void relic_core_initializer(void* ptr) {
core_init();
if (err_get_code() != STS_OK) {
std::cout << "core_init() failed";
return false;
// this will most likely crash the application...but there isn't much we can do
throw std::string("core_init() failed");
}

const int r = ep_param_set_any_pairf();
if (r != STS_OK) {
std::cout << "ep_param_set_any_pairf() failed";
return false;
// this will most likely crash the application...but there isn't much we can do
throw std::string("ep_param_set_any_pairf() failed");
}
}

bool BLS::Init() {
if (ALLOC != AUTO) {
std::cout << "Must have ALLOC == AUTO";
throw std::string("Must have ALLOC == AUTO");
}
#if BLSALLOC_SODIUM
if (sodium_init() < 0) {
std::cout << "libsodium init failed";
return false;
throw std::string("libsodium init failed");
}
SetSecureAllocator(libsodium::sodium_malloc, libsodium::sodium_free);
#else
SetSecureAllocator(malloc, free);
#endif
return true;
}

void BLS::AssertInitialized() {
if (!core_get()) {
throw std::string("Library not initialized properly. Call BLS::Init()");
}
#if BLSALLOC_SODIUM
if (sodium_init() < 0) {
throw std::string("Libsodium initialization failed.");
}
#endif
}
core_set_thread_initializer(relic_core_initializer, nullptr);

void BLS::Clean() {
core_clean();
return true;
}

void BLS::SetSecureAllocator(Util::SecureAllocCallback allocCb, Util::SecureFreeCallback freeCb) {
Expand Down
6 changes: 1 addition & 5 deletions src/bls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@ class BLS {
static const char GROUP_ORDER[];
static const size_t MESSAGE_HASH_LEN = 32;

// Initializes the BLS library manually
// Initializes the BLS library (called automatically)
static bool Init();
// Asserts the BLS library is initialized
static void AssertInitialized();
// Cleans the BLS library
static void Clean();

static void SetSecureAllocator(Util::SecureAllocCallback allocCb, Util::SecureFreeCallback freeCb);

Expand Down
2 changes: 0 additions & 2 deletions src/chaincode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
namespace bls {

ChainCode ChainCode::FromBytes(const uint8_t* bytes) {
BLS::AssertInitialized();
ChainCode c = ChainCode();
bn_new(c.chainCode);
bn_read_bin(c.chainCode, bytes, ChainCode::CHAIN_CODE_SIZE);
return c;
}

ChainCode::ChainCode(const ChainCode &cc) {
BLS::AssertInitialized();
uint8_t bytes[ChainCode::CHAIN_CODE_SIZE];
cc.Serialize(bytes);
bn_new(chainCode);
Expand Down
8 changes: 0 additions & 8 deletions src/extendedprivatekey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ namespace bls {

ExtendedPrivateKey ExtendedPrivateKey::FromSeed(const uint8_t* seed,
size_t seedLen) {
BLS::AssertInitialized();

// "BLS HD seed" in ascii
const uint8_t prefix[] = {66, 76, 83, 32, 72, 68, 32, 115, 101, 101, 100};

Expand Down Expand Up @@ -63,7 +61,6 @@ ExtendedPrivateKey ExtendedPrivateKey::FromSeed(const uint8_t* seed,
}

ExtendedPrivateKey ExtendedPrivateKey::FromBytes(const uint8_t* serialized) {
BLS::AssertInitialized();
uint32_t version = Util::FourBytesToInt(serialized);
uint32_t depth = serialized[4];
uint32_t parentFingerprint = Util::FourBytesToInt(serialized + 5);
Expand All @@ -78,7 +75,6 @@ ExtendedPrivateKey ExtendedPrivateKey::FromBytes(const uint8_t* serialized) {
}

ExtendedPrivateKey ExtendedPrivateKey::PrivateChild(uint32_t i) const {
BLS::AssertInitialized();
if (depth >= 255) {
throw std::string("Cannot go further than 255 levels");
}
Expand Down Expand Up @@ -157,7 +153,6 @@ PrivateKey ExtendedPrivateKey::GetPrivateKey() const {
}

PublicKey ExtendedPrivateKey::GetPublicKey() const {
BLS::AssertInitialized();
return sk.GetPublicKey();
}

Expand All @@ -166,7 +161,6 @@ ChainCode ExtendedPrivateKey::GetChainCode() const {
}

ExtendedPublicKey ExtendedPrivateKey::GetExtendedPublicKey() const {
BLS::AssertInitialized();
uint8_t buffer[ExtendedPublicKey::EXTENDED_PUBLIC_KEY_SIZE];
Util::IntToFourBytes(buffer, version);
buffer[4] = depth;
Expand All @@ -181,7 +175,6 @@ ExtendedPublicKey ExtendedPrivateKey::GetExtendedPublicKey() const {

// Comparator implementation.
bool operator==(ExtendedPrivateKey const &a, ExtendedPrivateKey const &b) {
BLS::AssertInitialized();
return (a.GetPrivateKey() == b.GetPrivateKey() &&
a.GetChainCode() == b.GetChainCode());
}
Expand All @@ -191,7 +184,6 @@ bool operator!=(ExtendedPrivateKey const&a, ExtendedPrivateKey const&b) {
}

void ExtendedPrivateKey::Serialize(uint8_t *buffer) const {
BLS::AssertInitialized();
Util::IntToFourBytes(buffer, version);
buffer[4] = depth;
Util::IntToFourBytes(buffer + 5, parentFingerprint);
Expand Down
5 changes: 0 additions & 5 deletions src/extendedpublickey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ namespace bls {

ExtendedPublicKey ExtendedPublicKey::FromBytes(
const uint8_t* serialized) {
BLS::AssertInitialized();
uint32_t version = Util::FourBytesToInt(serialized);
uint32_t depth = serialized[4];
uint32_t parentFingerprint = Util::FourBytesToInt(serialized + 5);
Expand All @@ -36,7 +35,6 @@ ExtendedPublicKey ExtendedPublicKey::FromBytes(
}

ExtendedPublicKey ExtendedPublicKey::PublicChild(uint32_t i) const {
BLS::AssertInitialized();
// Hardened children have i >= 2^31. Non-hardened have i < 2^31
uint32_t cmp = (1 << 31);
if (i >= cmp) {
Expand Down Expand Up @@ -109,7 +107,6 @@ PublicKey ExtendedPublicKey::GetPublicKey() const {

// Comparator implementation.
bool operator==(ExtendedPublicKey const &a, ExtendedPublicKey const &b) {
BLS::AssertInitialized();
return (a.GetPublicKey() == b.GetPublicKey() &&
a.GetChainCode() == b.GetChainCode());
}
Expand All @@ -119,12 +116,10 @@ bool operator!=(ExtendedPublicKey const&a, ExtendedPublicKey const&b) {
}

std::ostream &operator<<(std::ostream &os, ExtendedPublicKey const &a) {
BLS::AssertInitialized();
return os << a.GetPublicKey() << a.GetChainCode();
}

void ExtendedPublicKey::Serialize(uint8_t *buffer) const {
BLS::AssertInitialized();
Util::IntToFourBytes(buffer, version);
buffer[4] = depth;
Util::IntToFourBytes(buffer + 5, parentFingerprint);
Expand Down
17 changes: 0 additions & 17 deletions src/privatekey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
#include "privatekey.hpp"
namespace bls {
PrivateKey PrivateKey::FromSeed(const uint8_t* seed, size_t seedLen) {
BLS::AssertInitialized();

// "BLS private key seed" in ascii
const uint8_t hmacKey[] = {66, 76, 83, 32, 112, 114, 105, 118, 97, 116, 101,
32, 107, 101, 121, 32, 115, 101, 101, 100};
Expand Down Expand Up @@ -54,7 +52,6 @@ PrivateKey PrivateKey::FromSeed(const uint8_t* seed, size_t seedLen) {

// Construct a private key from a bytearray.
PrivateKey PrivateKey::FromBytes(const uint8_t* bytes, bool modOrder) {
BLS::AssertInitialized();
PrivateKey k;
k.AllocateKeyData();
bn_read_bin(*k.keydata, bytes, PrivateKey::PRIVATE_KEY_SIZE);
Expand All @@ -73,23 +70,19 @@ PrivateKey PrivateKey::FromBytes(const uint8_t* bytes, bool modOrder) {

// Construct a private key from another private key.
PrivateKey::PrivateKey(const PrivateKey &privateKey) {
BLS::AssertInitialized();
AllocateKeyData();
bn_copy(*keydata, *privateKey.keydata);
}

PrivateKey::PrivateKey(PrivateKey&& k) {
BLS::AssertInitialized();
std::swap(keydata, k.keydata);
}

PrivateKey::~PrivateKey() {
BLS::AssertInitialized();
Util::SecFree(keydata);
}

PublicKey PrivateKey::GetPublicKey() const {
BLS::AssertInitialized();
g1_t *q = Util::SecAlloc<g1_t>(1);
g1_mul_gen(*q, *keydata);

Expand Down Expand Up @@ -181,25 +174,21 @@ PrivateKey PrivateKey::Mul(const bn_t n) const {
}

bool operator==(const PrivateKey& a, const PrivateKey& b) {
BLS::AssertInitialized();
return bn_cmp(*a.keydata, *b.keydata) == CMP_EQ;
}

bool operator!=(const PrivateKey& a, const PrivateKey& b) {
BLS::AssertInitialized();
return !(a == b);
}

PrivateKey& PrivateKey::operator=(const PrivateKey &rhs) {
BLS::AssertInitialized();
Util::SecFree(keydata);
AllocateKeyData();
bn_copy(*keydata, *rhs.keydata);
return *this;
}

void PrivateKey::Serialize(uint8_t* buffer) const {
BLS::AssertInitialized();
bn_write_bin(buffer, PrivateKey::PRIVATE_KEY_SIZE, *keydata);
}

Expand All @@ -210,14 +199,12 @@ std::vector<uint8_t> PrivateKey::Serialize() const {
}

InsecureSignature PrivateKey::SignInsecure(const uint8_t *msg, size_t len) const {
BLS::AssertInitialized();
uint8_t messageHash[BLS::MESSAGE_HASH_LEN];
Util::Hash256(messageHash, msg, len);
return SignInsecurePrehashed(messageHash);
}

InsecureSignature PrivateKey::SignInsecurePrehashed(const uint8_t *messageHash) const {
BLS::AssertInitialized();
g2_t sig, point;

g2_map(point, messageHash, BLS::MESSAGE_HASH_LEN, 0);
Expand All @@ -227,15 +214,12 @@ InsecureSignature PrivateKey::SignInsecurePrehashed(const uint8_t *messageHash)
}

Signature PrivateKey::Sign(const uint8_t *msg, size_t len) const {
BLS::AssertInitialized();
uint8_t messageHash[BLS::MESSAGE_HASH_LEN];
Util::Hash256(messageHash, msg, len);
return SignPrehashed(messageHash);
}

Signature PrivateKey::SignPrehashed(const uint8_t *messageHash) const {
BLS::AssertInitialized();

InsecureSignature insecureSig = SignInsecurePrehashed(messageHash);
Signature ret = Signature::FromInsecureSig(insecureSig);

Expand All @@ -246,7 +230,6 @@ Signature PrivateKey::SignPrehashed(const uint8_t *messageHash) const {
}

void PrivateKey::AllocateKeyData() {
BLS::AssertInitialized();
keydata = Util::SecAlloc<bn_t>(1);
bn_new(*keydata); // Freed in destructor
bn_zero(*keydata);
Expand Down
Loading

0 comments on commit c001d9d

Please sign in to comment.