Skip to content

Commit b46df09

Browse files
vtjnashtkf
andauthored
Fetch thread-local information (ptls) through the current task (JuliaLang#40715)
Enables task-thread migration! Co-authored-by: Takafumi Arakaki <[email protected]>
1 parent b632765 commit b46df09

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+1387
-1157
lines changed

base/gcutils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ end
6565
6666
Immediately run finalizers registered for object `x`.
6767
"""
68-
finalize(@nospecialize(o)) = ccall(:jl_finalize_th, Cvoid, (Ptr{Cvoid}, Any,),
69-
Core.getptls(), o)
68+
finalize(@nospecialize(o)) = ccall(:jl_finalize_th, Cvoid, (Any, Any,),
69+
current_task(), o)
7070

7171
"""
7272
Base.GC

base/task.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -619,19 +619,22 @@ function enq_work(t::Task)
619619
# 1. The Task's stack is currently being used by the scheduler for a certain thread.
620620
# 2. There is only 1 thread.
621621
# 3. The multiq is full (can be fixed by making it growable).
622-
if t.sticky || tid != 0 || Threads.nthreads() == 1
622+
if t.sticky || Threads.nthreads() == 1
623623
if tid == 0
624624
tid = Threads.threadid()
625625
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
626626
end
627627
push!(Workqueues[tid], t)
628628
else
629-
tid = 0
630629
if ccall(:jl_enqueue_task, Cint, (Any,), t) != 0
631630
# if multiq is full, give to a random thread (TODO fix)
632-
tid = mod(time_ns() % Int, Threads.nthreads()) + 1
633-
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
631+
if tid == 0
632+
tid = mod(time_ns() % Int, Threads.nthreads()) + 1
633+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
634+
end
634635
push!(Workqueues[tid], t)
636+
else
637+
tid = 0
635638
end
636639
end
637640
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)

cli/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ include $(JULIAHOME)/Make.inc
66
include $(JULIAHOME)/deps/llvm-ver.make
77

88

9-
HEADERS := $(addprefix $(SRCDIR)/,jl_exports.h loader.h) $(addprefix $(JULIAHOME)/src/,support/platform.h support/dirpath.h jl_exported_data.inc jl_exported_funcs.inc)
9+
HEADERS := $(addprefix $(SRCDIR)/,jl_exports.h loader.h) $(addprefix $(JULIAHOME)/src/,julia_fasttls.h support/platform.h support/dirpath.h jl_exported_data.inc jl_exported_funcs.inc)
1010

1111
LOADER_CFLAGS = $(JCFLAGS) -I$(BUILDROOT)/src -I$(JULIAHOME)/src -I$(JULIAHOME)/src/support -I$(build_includedir) -ffreestanding
1212
LOADER_LDFLAGS = $(JLDFLAGS) -ffreestanding -L$(build_shlibdir) -L$(build_libdir)
@@ -116,7 +116,7 @@ endif
116116
$(build_shlibdir)/libjulia-debug.$(JL_MAJOR_MINOR_SHLIB_EXT): $(LIB_DOBJS) $(SRCDIR)/list_strip_symbols.h | $(build_shlibdir) $(build_libdir)
117117
@$(call PRINT_LINK, $(CC) $(call IMPLIB_FLAGS,$@.tmp) $(LOADER_CFLAGS) -DLIBRARY_EXPORTS -shared $(DEBUGFLAGS) $(LIB_DOBJS) -o $@ \
118118
$(JLIBLDFLAGS) $(LOADER_LDFLAGS) $(RPATH_LIB) $(call SONAME_FLAGS,libjulia-debug.$(JL_MAJOR_SHLIB_EXT)))
119-
@$(INSTALL_NAME_CMD)libjulia-debug.$(SHLIB_EXT) $@.tmp
119+
@$(INSTALL_NAME_CMD)libjulia-debug.$(SHLIB_EXT) $@
120120
ifeq ($(OS), WINNT)
121121
@$(call PRINT_ANALYZE, $(OBJCOPY) $(build_libdir)/$(notdir $@).tmp.a $(STRIP_EXPORTED_FUNCS) $(build_libdir)/$(notdir $@).a && rm $(build_libdir)/$(notdir $@).tmp.a)
122122
endif

cli/loader.h

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
/* Bring in definitions for `_OS_X_`, `PATH_MAX` and `PATHSEPSTRING`, `jl_ptls_t`, etc... */
44
#include "../src/support/platform.h"
55
#include "../src/support/dirpath.h"
6+
#include "../src/julia_fasttls.h"
67

78
#ifdef _OS_WINDOWS_
89
/* We need to reimplement a bunch of standard library stuff on windows,
@@ -43,15 +44,6 @@
4344
#include <dlfcn.h>
4445
#endif
4546

46-
// Borrow definitions from `julia.h`
47-
#if defined(__GNUC__)
48-
# define JL_CONST_FUNC __attribute__((const))
49-
#elif defined(_COMPILER_MICROSOFT_)
50-
# define JL_CONST_FUNC __declspec(noalias)
51-
#else
52-
# define JL_CONST_FUNC
53-
#endif
54-
5547
// Borrow definition from `support/dtypes.h`
5648
#ifdef _OS_WINDOWS_
5749
# ifdef LIBRARY_EXPORTS
@@ -68,12 +60,6 @@
6860
# endif
6961
#define JL_HIDDEN __attribute__ ((visibility("hidden")))
7062
#endif
71-
#ifdef JL_DEBUG_BUILD
72-
#define JL_NAKED __attribute__ ((naked,no_stack_protector))
73-
#else
74-
#define JL_NAKED __attribute__ ((naked))
75-
#endif
76-
7763
/*
7864
* DEP_LIBS is our list of dependent libraries that must be loaded before `libjulia`.
7965
* Note that order matters, as each entry will be opened in-order. We define here a

cli/loader_exe.c

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,14 @@ extern "C" {
1111
#include "loader_win_utils.c"
1212
#endif
1313

14-
/* Define ptls getter, as this cannot be defined within a shared library. */
15-
#if !defined(_OS_WINDOWS_) && !defined(_OS_DARWIN_)
16-
JL_DLLEXPORT JL_CONST_FUNC void * jl_get_ptls_states_static(void)
17-
{
18-
/* Because we can't #include <julia.h> in this file, we define a TLS state object with
19-
* hopefully enough room; at last check, the `jl_tls_states_t` struct was <16KB. */
20-
static __attribute__((tls_model("local-exec"))) __thread char tls_states[32768];
21-
return &tls_states;
22-
}
23-
#endif
14+
JULIA_DEFINE_FAST_TLS
2415

2516
#ifdef _OS_WINDOWS_
2617
int mainCRTStartup(void)
2718
{
2819
int argc;
2920
LPWSTR * wargv = CommandLineToArgv(GetCommandLine(), &argc);
30-
char ** argv = (char **)malloc(sizeof(char *)*(argc+ 1));
21+
char ** argv = (char **)malloc(sizeof(char*) * (argc + 1));
3122
setup_stdio();
3223
#else
3324
int main(int argc, char * argv[])
@@ -36,7 +27,7 @@ int main(int argc, char * argv[])
3627

3728
// Convert Windows wchar_t values to UTF8
3829
#ifdef _OS_WINDOWS_
39-
for (int i=0; i<argc; i++) {
30+
for (int i = 0; i < argc; i++) {
4031
size_t max_arg_len = 4*wcslen(wargv[i]);
4132
argv[i] = (char *)malloc(max_arg_len);
4233
if (!wchar_to_utf8(wargv[i], argv[i], max_arg_len)) {

cli/loader_lib.c

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,18 @@ JL_DLLEXPORT int jl_load_repl(int argc, char * argv[]) {
182182
}
183183
// Next, if we're on Linux/FreeBSD, set up fast TLS.
184184
#if !defined(_OS_WINDOWS_) && !defined(_OS_DARWIN_)
185-
void (*jl_set_ptls_states_getter)(void *) = lookup_symbol(libjulia_internal, "jl_set_ptls_states_getter");
186-
if (jl_set_ptls_states_getter == NULL) {
187-
jl_loader_print_stderr("ERROR: Cannot find jl_set_ptls_states_getter() function within libjulia-internal!\n");
185+
void (*jl_pgcstack_setkey)(void*, void*(*)(void)) = lookup_symbol(libjulia_internal, "jl_pgcstack_setkey");
186+
if (jl_pgcstack_setkey == NULL) {
187+
jl_loader_print_stderr("ERROR: Cannot find jl_pgcstack_setkey() function within libjulia-internal!\n");
188188
exit(1);
189189
}
190-
void * (*fptr)(void) = lookup_symbol(RTLD_DEFAULT, "jl_get_ptls_states_static");
191-
if (fptr == NULL) {
192-
jl_loader_print_stderr("ERROR: Cannot find jl_get_ptls_states_static(), must define this symbol within calling executable!\n");
190+
void *fptr = lookup_symbol(RTLD_DEFAULT, "jl_get_pgcstack_static");
191+
void *(*key)(void) = lookup_symbol(RTLD_DEFAULT, "jl_pgcstack_addr_static");
192+
if (fptr == NULL || key == NULL) {
193+
jl_loader_print_stderr("ERROR: Cannot find jl_get_pgcstack_static(), must define this symbol within calling executable!\n");
193194
exit(1);
194195
}
195-
jl_set_ptls_states_getter((void *)fptr);
196+
jl_pgcstack_setkey(fptr, key);
196197
#endif
197198

198199
// Load the repl entrypoint symbol and jump into it!

src/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ SRCS += $(RUNTIME_SRCS)
8787

8888
# headers are used for dependency tracking, while public headers will be part of the dist
8989
UV_HEADERS :=
90-
HEADERS := $(BUILDDIR)/julia_version.h $(wildcard $(SRCDIR)/support/*.h) $(addprefix $(SRCDIR)/,julia.h julia_assert.h julia_threads.h tls.h locks.h atomics.h julia_internal.h options.h timing.h)
91-
PUBLIC_HEADERS := $(BUILDDIR)/julia_version.h $(wildcard $(SRCDIR)/support/*.h) $(addprefix $(SRCDIR)/,julia.h julia_assert.h julia_threads.h tls.h locks.h atomics.h julia_gcext.h)
90+
HEADERS := $(BUILDDIR)/julia_version.h $(wildcard $(SRCDIR)/support/*.h) $(addprefix $(SRCDIR)/,julia.h julia_assert.h julia_threads.h julia_fasttls.h locks.h atomics.h julia_internal.h options.h timing.h)
91+
PUBLIC_HEADERS := $(BUILDDIR)/julia_version.h $(wildcard $(SRCDIR)/support/*.h) $(addprefix $(SRCDIR)/,julia.h julia_assert.h julia_threads.h julia_fasttls.h locks.h atomics.h julia_gcext.h)
9292
ifeq ($(USE_SYSTEM_LIBUV),0)
9393
UV_HEADERS += uv.h
9494
UV_HEADERS += uv/*.h

src/array.c

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ size_t jl_arr_xtralloc_limit = 0;
7777
static jl_array_t *_new_array_(jl_value_t *atype, uint32_t ndims, size_t *dims,
7878
int8_t isunboxed, int8_t hasptr, int8_t isunion, int8_t zeroinit, int elsz)
7979
{
80-
jl_ptls_t ptls = jl_get_ptls_states();
80+
jl_task_t *ct = jl_current_task;
8181
size_t i, tot, nel=1;
8282
void *data;
8383
jl_array_t *a;
@@ -119,7 +119,7 @@ static jl_array_t *_new_array_(jl_value_t *atype, uint32_t ndims, size_t *dims,
119119
size_t doffs = tsz;
120120
tsz += tot;
121121
tsz = JL_ARRAY_ALIGN(tsz, JL_SMALL_BYTE_ALIGNMENT); // align whole object
122-
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
122+
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
123123
// No allocation or safepoint allowed after this
124124
a->flags.how = 0;
125125
data = (char*)a + doffs;
@@ -129,10 +129,10 @@ static jl_array_t *_new_array_(jl_value_t *atype, uint32_t ndims, size_t *dims,
129129
data = jl_gc_managed_malloc(tot);
130130
// Allocate the Array **after** allocating the data
131131
// to make sure the array is still young
132-
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
132+
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
133133
// No allocation or safepoint allowed after this
134134
a->flags.how = 2;
135-
jl_gc_track_malloced_array(ptls, a);
135+
jl_gc_track_malloced_array(ct->ptls, a);
136136
}
137137
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
138138

@@ -213,7 +213,7 @@ static inline int is_ntuple_long(jl_value_t *v)
213213
JL_DLLEXPORT jl_array_t *jl_reshape_array(jl_value_t *atype, jl_array_t *data,
214214
jl_value_t *_dims)
215215
{
216-
jl_ptls_t ptls = jl_get_ptls_states();
216+
jl_task_t *ct = jl_current_task;
217217
jl_array_t *a;
218218
size_t ndims = jl_nfields(_dims);
219219
assert(is_ntuple_long(_dims));
@@ -222,7 +222,7 @@ JL_DLLEXPORT jl_array_t *jl_reshape_array(jl_value_t *atype, jl_array_t *data,
222222

223223
int ndimwords = jl_array_ndimwords(ndims);
224224
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords * sizeof(size_t) + sizeof(void*), JL_SMALL_BYTE_ALIGNMENT);
225-
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
225+
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
226226
// No allocation or safepoint allowed after this
227227
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
228228
a->flags.ndims = ndims;
@@ -298,12 +298,12 @@ JL_DLLEXPORT jl_array_t *jl_reshape_array(jl_value_t *atype, jl_array_t *data,
298298

299299
JL_DLLEXPORT jl_array_t *jl_string_to_array(jl_value_t *str)
300300
{
301-
jl_ptls_t ptls = jl_get_ptls_states();
301+
jl_task_t *ct = jl_current_task;
302302
jl_array_t *a;
303303

304304
int ndimwords = jl_array_ndimwords(1);
305305
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords*sizeof(size_t) + sizeof(void*), JL_SMALL_BYTE_ALIGNMENT);
306-
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, jl_array_uint8_type);
306+
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, jl_array_uint8_type);
307307
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
308308
a->flags.ndims = 1;
309309
a->offset = 0;
@@ -327,7 +327,7 @@ JL_DLLEXPORT jl_array_t *jl_string_to_array(jl_value_t *str)
327327
JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,
328328
size_t nel, int own_buffer)
329329
{
330-
jl_ptls_t ptls = jl_get_ptls_states();
330+
jl_task_t *ct = jl_current_task;
331331
jl_array_t *a;
332332
jl_value_t *eltype = jl_tparam0(atype);
333333

@@ -350,7 +350,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,
350350

351351
int ndimwords = jl_array_ndimwords(1);
352352
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords*sizeof(size_t), JL_CACHE_BYTE_ALIGNMENT);
353-
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
353+
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
354354
// No allocation or safepoint allowed after this
355355
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
356356
a->data = data;
@@ -365,7 +365,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,
365365
a->flags.isaligned = 0; // TODO: allow passing memalign'd buffers
366366
if (own_buffer) {
367367
a->flags.how = 2;
368-
jl_gc_track_malloced_array(ptls, a);
368+
jl_gc_track_malloced_array(ct->ptls, a);
369369
jl_gc_count_allocd(nel*elsz + (elsz == 1 ? 1 : 0));
370370
}
371371
else {
@@ -381,7 +381,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,
381381
JL_DLLEXPORT jl_array_t *jl_ptr_to_array(jl_value_t *atype, void *data,
382382
jl_value_t *_dims, int own_buffer)
383383
{
384-
jl_ptls_t ptls = jl_get_ptls_states();
384+
jl_task_t *ct = jl_current_task;
385385
size_t nel = 1;
386386
jl_array_t *a;
387387
size_t ndims = jl_nfields(_dims);
@@ -417,7 +417,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array(jl_value_t *atype, void *data,
417417

418418
int ndimwords = jl_array_ndimwords(ndims);
419419
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords*sizeof(size_t), JL_CACHE_BYTE_ALIGNMENT);
420-
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
420+
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
421421
// No allocation or safepoint allowed after this
422422
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
423423
a->data = data;
@@ -433,7 +433,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array(jl_value_t *atype, void *data,
433433
a->flags.isaligned = 0;
434434
if (own_buffer) {
435435
a->flags.how = 2;
436-
jl_gc_track_malloced_array(ptls, a);
436+
jl_gc_track_malloced_array(ct->ptls, a);
437437
jl_gc_count_allocd(nel*elsz + (elsz == 1 ? 1 : 0));
438438
}
439439
else {
@@ -519,7 +519,8 @@ JL_DLLEXPORT jl_value_t *jl_pchar_to_string(const char *str, size_t len)
519519
jl_throw(jl_memory_exception);
520520
if (len == 0)
521521
return jl_an_empty_string;
522-
jl_value_t *s = jl_gc_alloc_(jl_get_ptls_states(), sz, jl_string_type); // force inlining
522+
jl_task_t *ct = jl_current_task;
523+
jl_value_t *s = jl_gc_alloc_(ct->ptls, sz, jl_string_type); // force inlining
523524
*(size_t*)s = len;
524525
memcpy((char*)s + sizeof(size_t), str, len);
525526
((char*)s + sizeof(size_t))[len] = 0;
@@ -533,7 +534,8 @@ JL_DLLEXPORT jl_value_t *jl_alloc_string(size_t len)
533534
jl_throw(jl_memory_exception);
534535
if (len == 0)
535536
return jl_an_empty_string;
536-
jl_value_t *s = jl_gc_alloc_(jl_get_ptls_states(), sz, jl_string_type); // force inlining
537+
jl_task_t *ct = jl_current_task;
538+
jl_value_t *s = jl_gc_alloc_(ct->ptls, sz, jl_string_type); // force inlining
537539
*(size_t*)s = len;
538540
((char*)s + sizeof(size_t))[len] = 0;
539541
return s;
@@ -672,7 +674,7 @@ JL_DLLEXPORT void jl_arrayunset(jl_array_t *a, size_t i)
672674
// the **beginning** of the new buffer.
673675
static int NOINLINE array_resize_buffer(jl_array_t *a, size_t newlen)
674676
{
675-
jl_ptls_t ptls = jl_get_ptls_states();
677+
jl_task_t *ct = jl_current_task;
676678
assert(!a->flags.isshared || a->flags.how == 3);
677679
size_t elsz = a->elsize;
678680
size_t nbytes = newlen * elsz;
@@ -714,12 +716,12 @@ static int NOINLINE array_resize_buffer(jl_array_t *a, size_t newlen)
714716
newbuf = 1;
715717
if (nbytes >= MALLOC_THRESH) {
716718
a->data = jl_gc_managed_malloc(nbytes);
717-
jl_gc_track_malloced_array(ptls, a);
719+
jl_gc_track_malloced_array(ct->ptls, a);
718720
a->flags.how = 2;
719721
a->flags.isaligned = 1;
720722
}
721723
else {
722-
a->data = jl_gc_alloc_buf(ptls, nbytes);
724+
a->data = jl_gc_alloc_buf(ct->ptls, nbytes);
723725
a->flags.how = 1;
724726
jl_gc_wb_buf(a, a->data, nbytes);
725727
}
@@ -1008,8 +1010,9 @@ STATIC_INLINE void jl_array_shrink(jl_array_t *a, size_t dec)
10081010
typetagdata = (char*)malloc_s(a->nrows);
10091011
memcpy(typetagdata, jl_array_typetagdata(a), a->nrows);
10101012
}
1013+
jl_task_t *ct = jl_current_task;
10111014
char *originaldata = (char*) a->data - a->offset * a->elsize;
1012-
char *newdata = (char*)jl_gc_alloc_buf(jl_get_ptls_states(), newbytes);
1015+
char *newdata = (char*)jl_gc_alloc_buf(ct->ptls, newbytes);
10131016
jl_gc_wb_buf(a, newdata, newbytes);
10141017
a->maxsize -= dec;
10151018
if (isbitsunion) {

0 commit comments

Comments
 (0)