Skip to content

Commit 539d600

Browse files
committed
improve OncePer implementation
Address reviewer feedback, add more fixes and more tests
1 parent 49ead81 commit 539d600

File tree

2 files changed

+80
-40
lines changed

2 files changed

+80
-40
lines changed

base/lock.jl

+30-26
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,10 @@ end
578578
export Event
579579
end
580580

581+
const PerStateInitial = 0x00
582+
const PerStateHasrun = 0x01
583+
const PerStateErrored = 0x02
584+
const PerStateConcurrent = 0x03
581585

582586
"""
583587
PerProcess{T}
@@ -614,38 +618,38 @@ mutable struct PerProcess{T, F}
614618
const lock::ReentrantLock
615619

616620
function PerProcess{T,F}(initializer::F) where {T, F}
617-
once = new{T,F}(nothing, 0x00, true, initializer, ReentrantLock())
621+
once = new{T,F}(nothing, PerStateInitial, true, initializer, ReentrantLock())
618622
ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any),
619623
once, :x, nothing)
620624
ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any),
621-
once, :state, 0x00)
625+
once, :state, PerStateInitial)
622626
return once
623627
end
624628
end
625629
PerProcess{T}(initializer::F) where {T, F} = PerProcess{T, F}(initializer)
626630
PerProcess(initializer) = PerProcess{Base.promote_op(initializer), typeof(initializer)}(initializer)
627631
@inline function (once::PerProcess{T})() where T
628632
state = (@atomic :acquire once.state)
629-
if state != 0x01
633+
if state != PerStateHasrun
630634
(@noinline function init_perprocesss(once, state)
631-
state == 0x02 && error("PerProcess initializer failed previously")
635+
state == PerStateErrored && error("PerProcess initializer failed previously")
632636
once.allow_compile_time || __precompile__(false)
633637
lock(once.lock)
634638
try
635639
state = @atomic :monotonic once.state
636-
if state == 0x00
640+
if state == PerStateInitial
637641
once.x = once.initializer()
638-
elseif state == 0x02
642+
elseif state == PerStateErrored
639643
error("PerProcess initializer failed previously")
640-
elseif state != 0x01
644+
elseif state != PerStateHasrun
641645
error("invalid state for PerProcess")
642646
end
643647
catch
644-
state == 0x02 || @atomic :release once.state = 0x02
648+
state == PerStateErrored || @atomic :release once.state = PerStateErrored
645649
unlock(once.lock)
646650
rethrow()
647651
end
648-
state == 0x01 || @atomic :release once.state = 0x01
652+
state == PerStateHasrun || @atomic :release once.state = PerStateHasrun
649653
unlock(once.lock)
650654
nothing
651655
end)(once, state)
@@ -674,7 +678,7 @@ function fill_monotonic!(dest::AtomicMemory, x)
674678
end
675679

676680

677-
# share a lock, since we just need it briefly, so some contention is okay
681+
# share a lock/condition, since we just need it briefly, so some contention is okay
678682
const PerThreadLock = ThreadSynchronizer()
679683
"""
680684
PerThread{T}
@@ -734,22 +738,22 @@ PerThread(initializer) = PerThread{Base.promote_op(initializer), typeof(initiali
734738
ss = @atomic :acquire once.ss
735739
xs = @atomic :monotonic once.xs
736740
# n.b. length(xs) >= length(ss)
737-
if tid > length(ss) || (@atomic :acquire ss[tid]) != 0x01
741+
if tid <= 0 || tid > length(ss) || (@atomic :acquire ss[tid]) != PerStateHasrun
738742
(@noinline function init_perthread(once, tid)
739-
local xs = @atomic :acquire once.xs
740-
local ss = @atomic :monotonic once.ss
743+
local ss = @atomic :acquire once.ss
744+
local xs = @atomic :monotonic once.xs
741745
local len = length(ss)
742746
# slow path to allocate it
743747
nt = Threads.maxthreadid()
744-
0 < tid <= nt || ArgumentError("thread id outside of allocated range")
745-
if tid <= length(ss) && (@atomic :acquire ss[tid]) == 0x02
748+
0 < tid <= nt || throw(ArgumentError("thread id outside of allocated range"))
749+
if tid <= length(ss) && (@atomic :acquire ss[tid]) == PerStateErrored
746750
error("PerThread initializer failed previously")
747751
end
748752
newxs = xs
749753
newss = ss
750754
if tid > len
751755
# attempt to do all allocations outside of PerThreadLock for better scaling
752-
@assert length(xs) == length(ss) "logical constraint violation"
756+
@assert length(xs) >= length(ss) "logical constraint violation"
753757
newxs = typeof(xs)(undef, len + nt)
754758
newss = typeof(ss)(undef, len + nt)
755759
end
@@ -759,30 +763,30 @@ PerThread(initializer) = PerThread{Base.promote_op(initializer), typeof(initiali
759763
ss = @atomic :monotonic once.ss
760764
xs = @atomic :monotonic once.xs
761765
if tid > length(ss)
762-
@assert length(ss) >= len && newxs !== xs && newss != ss "logical constraint violation"
763-
fill_monotonic!(newss, 0x00)
766+
@assert len <= length(ss) <= length(newss) "logical constraint violation"
767+
fill_monotonic!(newss, PerStateInitial)
764768
xs = copyto_monotonic!(newxs, xs)
765769
ss = copyto_monotonic!(newss, ss)
766770
@atomic :release once.xs = xs
767771
@atomic :release once.ss = ss
768772
end
769773
state = @atomic :monotonic ss[tid]
770-
while state == 0x04
774+
while state == PerStateConcurrent
771775
# lost race, wait for notification this is done running elsewhere
772776
wait(PerThreadLock) # wait for initializer to finish without releasing this thread
773777
ss = @atomic :monotonic once.ss
774-
state = @atomic :monotonic ss[tid] == 0x04
778+
state = @atomic :monotonic ss[tid]
775779
end
776-
if state == 0x00
780+
if state == PerStateInitial
777781
# won the race, drop lock in exchange for state, and run user initializer
778-
@atomic :monotonic ss[tid] = 0x04
782+
@atomic :monotonic ss[tid] = PerStateConcurrent
779783
result = try
780784
unlock(PerThreadLock)
781785
once.initializer()
782786
catch
783787
lock(PerThreadLock)
784788
ss = @atomic :monotonic once.ss
785-
@atomic :release ss[tid] = 0x02
789+
@atomic :release ss[tid] = PerStateErrored
786790
notify(PerThreadLock)
787791
rethrow()
788792
end
@@ -791,11 +795,11 @@ PerThread(initializer) = PerThread{Base.promote_op(initializer), typeof(initiali
791795
xs = @atomic :monotonic once.xs
792796
@atomic :release xs[tid] = result
793797
ss = @atomic :monotonic once.ss
794-
@atomic :release ss[tid] = 0x01
798+
@atomic :release ss[tid] = PerStateHasrun
795799
notify(PerThreadLock)
796-
elseif state == 0x02
800+
elseif state == PerStateErrored
797801
error("PerThread initializer failed previously")
798-
elseif state != 0x01
802+
elseif state != PerStateHasrun
799803
error("invalid state for PerThread")
800804
end
801805
finally

test/threads.jl

+50-14
Original file line numberDiff line numberDiff line change
@@ -392,25 +392,45 @@ end
392392

393393
let e = Base.Event(true),
394394
started = Channel{Int16}(Inf),
395+
finish = Channel{Nothing}(Inf),
396+
exiting = Channel{Nothing}(Inf),
397+
starttest2 = Event(),
395398
once = PerThread() do
396399
push!(started, threadid())
397-
wait(e)
400+
take!(finish)
401+
return [nothing]
402+
end
403+
alls = PerThread() do
398404
return [nothing]
399405
end
400406
@test typeof(once) <: PerThread{Vector{Nothing}}
401-
notify(e)
407+
push!(finish, nothing)
408+
@test_throws ArgumentError once[0]
402409
x = once()
403-
@test x === once() === fetch(@async once())
410+
@test_throws ArgumentError once[0]
411+
@test x === once() === fetch(@async once()) === once[threadid()]
404412
@test take!(started) == threadid()
405413
@test isempty(started)
406414
tids = zeros(UInt, 50)
415+
newthreads = zeros(Int16, length(tids))
407416
onces = Vector{Vector{Nothing}}(undef, length(tids))
417+
allonces = Vector{Vector{Vector{Nothing}}}(undef, length(tids))
408418
for i = 1:length(tids)
409419
function cl()
410420
GC.gc(false) # stress test the GC-safepoint mechanics of jl_adopt_thread
411-
local y = once()
412-
onces[i] = y
413-
@test x !== y === once()
421+
try
422+
local y = once()
423+
onces[i] = y
424+
@test x !== y === once() === once[threadid()]
425+
newthreads[i] = threadid()
426+
wait(starttest2)
427+
allonces[i] = Vector{Nothing}[alls[tid] for tid in newthreads]
428+
catch ex
429+
close(started, ErrorException("failed"))
430+
close(finish, ErrorException("failed"))
431+
@lock stderr Base.display_error(current_exceptions())
432+
end
433+
push!(exiting, nothing)
414434
GC.gc(false) # stress test the GC-safepoint mechanics of jl_delete_thread
415435
nothing
416436
end
@@ -431,19 +451,35 @@ let e = Base.Event(true),
431451
err == 0 || Base.uv_error("uv_thread_join", err)
432452
end
433453
end
434-
# let them finish in 5 batches of 10
435-
for i = 1:length(tids) ÷ 10
436-
for i = 1:10
437-
@test take!(started) != threadid()
454+
try
455+
# let them finish in batches of 10
456+
for i = 1:length(tids) ÷ 10
457+
for i = 1:10
458+
newid = take!(started)
459+
@test newid != threadid()
460+
end
461+
for i = 1:10
462+
push!(finish, nothing)
463+
end
438464
end
439-
for i = 1:10
440-
notify(e)
465+
@test isempty(started)
466+
# now run the second part of the test where they all try to access the other threads elements
467+
notify(starttest2)
468+
finally
469+
for _ = 1:length(tids)
470+
# run IO loop until all threads are close to exiting
471+
take!(exiting)
441472
end
473+
waitallthreads(tids)
442474
end
443475
@test isempty(started)
444-
waitallthreads(tids)
445-
@test isempty(started)
476+
@test isempty(finish)
446477
@test length(IdSet{eltype(onces)}(onces)) == length(onces) # make sure every object is unique
478+
allexpected = Vector{Nothing}[alls[tid] for tid in newthreads]
479+
@test length(IdSet{eltype(allexpected)}(allexpected)) == length(allexpected) # make sure every object is unique
480+
@test all(i -> allonces[i] !== allexpected && all(j -> allonces[i][j] === allexpected[j], eachindex(allexpected)), eachindex(allonces)) # make sure every thread saw the same elements
481+
@test_throws ArgumentError once[Threads.maxthreadid() + 1]
482+
@test_throws ArgumentError once[-1]
447483

448484
end
449485
let once = PerThread{Int}(() -> error("expected"))

0 commit comments

Comments
 (0)