Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the destruction of interruptible token registry #1229

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
863ead0
Wrap the registry in a shared pointer and access it using weak pointers
achirkin Feb 2, 2023
456e8e2
Merge remote-tracking branch 'rapidsai/branch-23.02' into fix-interru…
achirkin Feb 3, 2023
ddd5a61
Revert #1224
achirkin Feb 3, 2023
5d15f66
Only use the mutex if the registry still exists
achirkin Feb 3, 2023
da2fe6a
Merge branch 'branch-23.02' into fix-interruptible-destruction
cjnolet Feb 3, 2023
d7cabcb
Merge branch 'branch-23.04' into fix-interruptible-destruction
cjnolet Feb 3, 2023
8778b30
Merge branch 'branch-23.04' into fix-interruptible-destruction
cjnolet Feb 4, 2023
eb92501
Merge branch 'branch-23.04' into fix-interruptible-destruction
cjnolet Feb 7, 2023
61b66a2
Merge remote-tracking branch 'rapidsai/branch-23.04' into fix-interru…
achirkin Feb 8, 2023
1a416fa
Put both the map and the mutex into one shared_ptr and make sure to o…
achirkin Feb 8, 2023
00cccb0
Merge branch 'branch-23.04' into fix-interruptible-destruction
cjnolet Feb 8, 2023
8583e4f
Fix compile time explosion for minkowski distance (#1254)
ahendriksen Feb 9, 2023
4977c30
Merge remote-tracking branch 'rapidsai/branch-23.04' into fix-interru…
achirkin Feb 9, 2023
6230b28
Merge remote-tracking branch 'rapidsai/branch-23.04' into fix-interru…
achirkin Feb 13, 2023
221cc54
Merge remote-tracking branch 'rapidsai/branch-23.04' into fix-interru…
achirkin Feb 13, 2023
4873727
Refactor the token deleter from the lambda to a custom type
achirkin Feb 13, 2023
add6a43
Merge remote-tracking branch 'rapidsai/branch-23.04' into fix-interru…
achirkin Feb 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 51 additions & 20 deletions cpp/include/raft/core/interruptible.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -179,9 +179,44 @@ class interruptible {

private:
/** Global registry of thread-local cancellation stores. */
static inline std::unordered_map<std::thread::id, std::weak_ptr<interruptible>> registry_;
/** Protect the access to the registry. */
static inline std::mutex mutex_;
using registry_t =
std::tuple<std::mutex, std::unordered_map<std::thread::id, std::weak_ptr<interruptible>>>;

/**
* The registry "garbage collector": a custom deleter for the interruptible tokens that removes
* the token from the registry, if the registry still exists.
*/
struct registry_gc_t {
std::weak_ptr<registry_t> weak_registry;
std::thread::id thread_id;

inline void operator()(interruptible* thread_store) const noexcept
{
// the deleter kicks in at thread/program exit; in some cases, the registry_ (static variable)
// may have been destructed by this point of time.
// Hence, we use a weak pointer to check if the registry still exists.
auto registry = weak_registry.lock();
if (registry) {
std::lock_guard<std::mutex> guard_erase(std::get<0>(*registry));
auto& map = std::get<1>(*registry);
auto found = map.find(thread_id);
if (found != map.end()) {
auto stored = found->second.lock();
// thread_store is not moveable, thus retains its original location.
// Not equal pointers below imply the new store has been already placed
// in the registry by the same std::thread::id
if (!stored || stored.get() == thread_store) { map.erase(found); }
}
}
delete thread_store;
}
};

/**
* The registry itself is stored in the static memory, in a shared pointer.
* This is to safely access it from the destructors of the thread-local tokens.
*/
static inline std::shared_ptr<registry_t> registry_{new registry_t{}};

/**
* Create a new interruptible token or get an existing from the global registry_.
Expand All @@ -201,26 +236,22 @@ class interruptible {
template <bool Claim>
static auto get_token_impl(std::thread::id thread_id) -> std::shared_ptr<interruptible>
{
std::lock_guard<std::mutex> guard_get(mutex_);
// the following constructs an empty shared_ptr if the key does not exist.
auto& weak_store = registry_[thread_id];
// Make a local copy of the shared pointer to make sure the registry is not destroyed,
// if, for any reason, this function is called at program exit.
std::shared_ptr<registry_t> shared_registry = registry_;
// If the registry is not available, create a lone token that cannot be accessed from
// the outside of the thread.
if (!shared_registry) { return std::shared_ptr<interruptible>{new interruptible()}; }
// Otherwise, proceed with the normal logic
std::lock_guard<std::mutex> guard_get(std::get<0>(*shared_registry));
// the following two lines construct an empty shared_ptr if the key does not exist.
auto& weak_store = std::get<1>(*shared_registry)[thread_id];
auto thread_store = weak_store.lock();
if (!thread_store || (Claim && thread_store->claimed_)) {
// Create a new thread_store in two cases:
// 1. It does not exist in the map yet
// 2. The previous store in the map has not yet been deleted
thread_store.reset(new interruptible(), [thread_id](auto ts) {
std::lock_guard<std::mutex> guard_erase(mutex_);
auto found = registry_.find(thread_id);
if (found != registry_.end()) {
auto stored = found->second.lock();
// thread_store is not moveable, thus retains its original location.
// Not equal pointers below imply the new store has been already placed
// in the registry_ by the same std::thread::id
if (!stored || stored.get() == ts) { registry_.erase(found); }
}
delete ts;
});
thread_store.reset(new interruptible(), registry_gc_t{shared_registry, thread_id});
std::weak_ptr<interruptible>(thread_store).swap(weak_store);
}
// The thread_store is "claimed" by the thread
Expand Down Expand Up @@ -268,4 +299,4 @@ class interruptible {

} // namespace raft

#endif
#endif
7 changes: 2 additions & 5 deletions cpp/include/raft/core/resource/cuda_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_v
*/
inline void sync_stream(const resources& res, rmm::cuda_stream_view stream)
{
// TODO: Fix interruptible segfault:
// https://github.com/rapidsai/raft/issues/1225
// interruptible::synchronize(stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
interruptible::synchronize(stream);
}

/**
Expand All @@ -106,4 +103,4 @@ inline void sync_stream(const resources& res) { sync_stream(res, get_cuda_stream
* @}
*/

} // namespace raft::resource
} // namespace raft::resource