Skip to content

Commit c374808

Browse files
committed
make thread 1 interactive when there is an interactive pool, so it can run the event loop
1 parent 5b49c03 commit c374808

File tree

4 files changed

+48
-26
lines changed

4 files changed

+48
-26
lines changed

base/task.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ istaskfailed(t::Task) = (load_state_acquire(t) === task_state_failed)
253253
Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1)
254254
function Threads.threadpool(t::Task)
255255
tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), t)
256-
return tpid == 0 ? :default : :interactive
256+
return Threads._tpid_to_sym(tpid)
257257
end
258258

259259
task_result(t::Task) = t.result
@@ -786,7 +786,7 @@ function enq_work(t::Task)
786786
if Threads.threadpoolsize(tp) == 1
787787
# There's only one thread in the task's assigned thread pool;
788788
# use its work queue.
789-
tid = (tp === :default) ? 1 : Threads.threadpoolsize(:default)+1
789+
tid = (tp === :interactive) ? 1 : Threads.threadpoolsize(:interactive)+1
790790
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
791791
push!(workqueue_for(tid), t)
792792
else

base/threadingconstructs.jl

+38-15
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,22 @@ function _nthreads_in_pool(tpid::Int8)
3939
return Int(unsafe_load(p, tpid + 1))
4040
end
4141

42+
function _tpid_to_sym(tpid::Int8)
43+
return tpid == 0 ? :interactive : :default
44+
end
45+
46+
function _sym_to_tpid(tp::Symbol)
47+
return tp === :interactive ? Int8(0) : Int8(1)
48+
end
49+
4250
"""
4351
Threads.threadpool(tid = threadid()) -> Symbol
4452
4553
Returns the specified thread's threadpool; either `:default` or `:interactive`.
4654
"""
4755
function threadpool(tid = threadid())
4856
tpid = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1)
49-
return tpid == 0 ? :default : :interactive
57+
return _tpid_to_sym(tpid)
5058
end
5159

5260
"""
@@ -67,24 +75,34 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
6775
[`Distributed`](@ref man-distributed) standard library.
6876
"""
6977
function threadpoolsize(pool::Symbol = :default)
70-
if pool === :default
71-
tpid = Int8(0)
72-
elseif pool === :interactive
73-
tpid = Int8(1)
78+
if pool === :default || pool === :interactive
79+
tpid = _sym_to_tpid(pool)
7480
else
7581
error("invalid threadpool specified")
7682
end
7783
return _nthreads_in_pool(tpid)
7884
end
7985

86+
function threadpoolids(pool::Symbol)
87+
ni = _nthreads_in_pool(Int8(0))
88+
if pool === :interactive
89+
return collect(1:ni)
90+
elseif pool === :default
91+
return collect(ni+1:ni+_nthreads_in_pool(Int8(1)))
92+
else
93+
error("invalid threadpool specified")
94+
end
95+
end
96+
8097
function threading_run(fun, static)
8198
ccall(:jl_enter_threaded_region, Cvoid, ())
8299
n = threadpoolsize()
100+
tid_offset = threadpoolsize(:interactive)
83101
tasks = Vector{Task}(undef, n)
84102
for i = 1:n
85103
t = Task(() -> fun(i)) # pass in tid
86104
t.sticky = static
87-
static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i-1)
105+
static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid_offset + i-1)
88106
tasks[i] = t
89107
schedule(t)
90108
end
@@ -287,6 +305,15 @@ macro threads(args...)
287305
return _threadsfor(ex.args[1], ex.args[2], sched)
288306
end
289307

308+
function _spawn_set_thrpool(t::Task, tp::Symbol)
309+
tpid = _sym_to_tpid(tp)
310+
if _nthreads_in_pool(tpid) == 0
311+
tpid = _sym_to_tpid(:default)
312+
end
313+
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid)
314+
nothing
315+
end
316+
290317
"""
291318
Threads.@spawn [:default|:interactive] expr
292319
@@ -315,7 +342,7 @@ the variable's value in the current task.
315342
A threadpool may be specified as of Julia 1.9.
316343
"""
317344
macro spawn(args...)
318-
tpid = Int8(0)
345+
tp = :default
319346
na = length(args)
320347
if na == 2
321348
ttype, ex = args
@@ -325,9 +352,9 @@ macro spawn(args...)
325352
# TODO: allow unquoted symbols
326353
ttype = nothing
327354
end
328-
if ttype === :interactive
329-
tpid = Int8(1)
330-
elseif ttype !== :default
355+
if ttype === :interactive || ttype === :default
356+
tp = ttype
357+
else
331358
throw(ArgumentError("unsupported threadpool in @spawn: $ttype"))
332359
end
333360
elseif na == 1
@@ -344,11 +371,7 @@ macro spawn(args...)
344371
let $(letargs...)
345372
local task = Task($thunk)
346373
task.sticky = false
347-
local tpid_actual = $tpid
348-
if _nthreads_in_pool(tpid_actual) == 0
349-
tpid_actual = Int8(0)
350-
end
351-
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, tpid_actual)
374+
_spawn_set_thrpool(task, $(QuoteNode(tp)))
352375
if $(Expr(:islocal, var))
353376
put!($var, task)
354377
end

src/threading.c

+4-7
Original file line numberDiff line numberDiff line change
@@ -600,17 +600,16 @@ void jl_init_threading(void)
600600
// specified on the command line (and so are in `jl_options`) or by the
601601
// environment variable. Set the globals `jl_n_threadpools`, `jl_n_threads`
602602
// and `jl_n_threads_per_pool`.
603-
jl_n_threadpools = 1;
603+
jl_n_threadpools = 2;
604604
int16_t nthreads = JULIA_NUM_THREADS;
605605
int16_t nthreadsi = 0;
606606
char *endptr, *endptri;
607607

608608
if (jl_options.nthreads != 0) { // --threads specified
609-
jl_n_threadpools = jl_options.nthreadpools;
610609
nthreads = jl_options.nthreads_per_pool[0];
611610
if (nthreads < 0)
612611
nthreads = jl_effective_threads();
613-
if (jl_n_threadpools == 2)
612+
if (jl_options.nthreadpools == 2)
614613
nthreadsi = jl_options.nthreads_per_pool[1];
615614
}
616615
else if ((cp = getenv(NUM_THREADS_NAME))) { // ENV[NUM_THREADS_NAME] specified
@@ -635,15 +634,13 @@ void jl_init_threading(void)
635634
if (errno != 0 || endptri == cp || nthreadsi < 0)
636635
nthreadsi = 0;
637636
}
638-
if (nthreadsi > 0)
639-
jl_n_threadpools++;
640637
}
641638
}
642639

643640
jl_all_tls_states_size = nthreads + nthreadsi;
644641
jl_n_threads_per_pool = (int*)malloc_s(2 * sizeof(int));
645-
jl_n_threads_per_pool[0] = nthreads;
646-
jl_n_threads_per_pool[1] = nthreadsi;
642+
jl_n_threads_per_pool[0] = nthreadsi;
643+
jl_n_threads_per_pool[1] = nthreads;
647644

648645
jl_atomic_store_release(&jl_all_tls_states, (jl_ptls_t*)calloc(jl_all_tls_states_size, sizeof(jl_ptls_t)));
649646
jl_atomic_store_release(&jl_n_threads, jl_all_tls_states_size);

test/threadpool_use.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ using Test
44
using Base.Threads
55

66
@test nthreadpools() == 2
7-
@test threadpool() === :default
8-
@test threadpool(2) === :interactive
7+
@test threadpool() === :interactive
8+
@test threadpool(2) === :default
99
@test fetch(Threads.@spawn Threads.threadpool()) === :default
1010
@test fetch(Threads.@spawn :default Threads.threadpool()) === :default
1111
@test fetch(Threads.@spawn :interactive Threads.threadpool()) === :interactive
12+
@test Threads.threadpoolids(:interactive) == [1]
13+
@test Threads.threadpoolids(:default) == [2]

0 commit comments

Comments
 (0)