@@ -39,14 +39,22 @@ function _nthreads_in_pool(tpid::Int8)
39
39
return Int (unsafe_load (p, tpid + 1 ))
40
40
end
41
41
42
+ function _tpid_to_sym (tpid:: Int8 )
43
+ return tpid == 0 ? :interactive : :default
44
+ end
45
+
46
+ function _sym_to_tpid (tp:: Symbol )
47
+ return tp === :interactive ? Int8 (0 ) : Int8 (1 )
48
+ end
49
+
42
50
"""
43
51
Threads.threadpool(tid = threadid()) -> Symbol
44
52
45
53
Returns the specified thread's threadpool; either `:default` or `:interactive`.
46
54
"""
47
55
function threadpool (tid = threadid ())
48
56
tpid = ccall (:jl_threadpoolid , Int8, (Int16,), tid- 1 )
49
- return tpid == 0 ? :default : :interactive
57
+ return _tpid_to_sym ( tpid)
50
58
end
51
59
52
60
"""
@@ -67,24 +75,34 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
67
75
[`Distributed`](@ref man-distributed) standard library.
68
76
"""
69
77
function threadpoolsize (pool:: Symbol = :default )
70
- if pool === :default
71
- tpid = Int8 (0 )
72
- elseif pool === :interactive
73
- tpid = Int8 (1 )
78
+ if pool === :default || pool === :interactive
79
+ tpid = _sym_to_tpid (pool)
74
80
else
75
81
error (" invalid threadpool specified" )
76
82
end
77
83
return _nthreads_in_pool (tpid)
78
84
end
79
85
86
+ function threadpoolids (pool:: Symbol )
87
+ ni = _nthreads_in_pool (Int8 (0 ))
88
+ if pool === :interactive
89
+ return collect (1 : ni)
90
+ elseif pool === :default
91
+ return collect (ni+ 1 : ni+ _nthreads_in_pool (Int8 (1 )))
92
+ else
93
+ error (" invalid threadpool specified" )
94
+ end
95
+ end
96
+
80
97
function threading_run (fun, static)
81
98
ccall (:jl_enter_threaded_region , Cvoid, ())
82
99
n = threadpoolsize ()
100
+ tid_offset = threadpoolsize (:interactive )
83
101
tasks = Vector {Task} (undef, n)
84
102
for i = 1 : n
85
103
t = Task (() -> fun (i)) # pass in tid
86
104
t. sticky = static
87
- static && ccall (:jl_set_task_tid , Cint, (Any, Cint), t, i- 1 )
105
+ static && ccall (:jl_set_task_tid , Cint, (Any, Cint), t, tid_offset + i- 1 )
88
106
tasks[i] = t
89
107
schedule (t)
90
108
end
@@ -287,6 +305,15 @@ macro threads(args...)
287
305
return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
288
306
end
289
307
308
+ function _spawn_set_thrpool (t:: Task , tp:: Symbol )
309
+ tpid = _sym_to_tpid (tp)
310
+ if _nthreads_in_pool (tpid) == 0
311
+ tpid = _sym_to_tpid (:default )
312
+ end
313
+ ccall (:jl_set_task_threadpoolid , Cint, (Any, Int8), t, tpid)
314
+ nothing
315
+ end
316
+
290
317
"""
291
318
Threads.@spawn [:default|:interactive] expr
292
319
@@ -315,7 +342,7 @@ the variable's value in the current task.
315
342
A threadpool may be specified as of Julia 1.9.
316
343
"""
317
344
macro spawn (args... )
318
- tpid = Int8 ( 0 )
345
+ tp = :default
319
346
na = length (args)
320
347
if na == 2
321
348
ttype, ex = args
@@ -325,9 +352,9 @@ macro spawn(args...)
325
352
# TODO : allow unquoted symbols
326
353
ttype = nothing
327
354
end
328
- if ttype === :interactive
329
- tpid = Int8 ( 1 )
330
- elseif ttype != = :default
355
+ if ttype === :interactive || ttype === :default
356
+ tp = ttype
357
+ else
331
358
throw (ArgumentError (" unsupported threadpool in @spawn: $ttype " ))
332
359
end
333
360
elseif na == 1
@@ -344,11 +371,7 @@ macro spawn(args...)
344
371
let $ (letargs... )
345
372
local task = Task ($ thunk)
346
373
task. sticky = false
347
- local tpid_actual = $ tpid
348
- if _nthreads_in_pool (tpid_actual) == 0
349
- tpid_actual = Int8 (0 )
350
- end
351
- ccall (:jl_set_task_threadpoolid , Cint, (Any, Int8), task, tpid_actual)
374
+ _spawn_set_thrpool (task, $ (QuoteNode (tp)))
352
375
if $ (Expr (:islocal , var))
353
376
put! ($ var, task)
354
377
end
0 commit comments