Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PTX: Lower unreachable control flow to avoid bad CFG reconstruction #467

Merged
merged 11 commits into from
Jun 12, 2023
4 changes: 2 additions & 2 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -52,8 +52,8 @@ end
function _precompile_()
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
@assert precompile(Tuple{typeof(GPUCompiler.assign_args!),Expr,Vector{Any}})
@assert precompile(Tuple{typeof(GPUCompiler.hide_trap!),LLVM.Module})
@assert precompile(Tuple{typeof(GPUCompiler.hide_unreachable!),LLVM.Function})
@assert precompile(Tuple{typeof(GPUCompiler.lower_trap!),LLVM.Module})
@assert precompile(Tuple{typeof(GPUCompiler.lower_unreachable!),LLVM.Function})
@assert precompile(Tuple{typeof(GPUCompiler.lower_gc_frame!),LLVM.Function})
@assert precompile(Tuple{typeof(GPUCompiler.lower_throw!),LLVM.Module})
#@assert precompile(Tuple{typeof(GPUCompiler.split_kwargs),Tuple{},Vector{Symbol},Vararg{Vector{Symbol}, N} where N})
276 changes: 135 additions & 141 deletions src/ptx.jl
Original file line number Diff line number Diff line change
@@ -11,25 +11,23 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
# codegen quirks
## can we emit debug info in the PTX assembly?
debuginfo::Bool = false
## do we permit unrachable statements, which often result in divergent control flow?
unreachable::Bool = false
## can exceptions use `exit` (which doesn't kill the GPU), or should they use `trap`?
exitable::Bool = false

# optional properties
minthreads::Union{Nothing,Int,NTuple{<:Any,Int}} = nothing
maxthreads::Union{Nothing,Int,NTuple{<:Any,Int}} = nothing
blocks_per_sm::Union{Nothing,Int} = nothing
maxregs::Union{Nothing,Int} = nothing

# deprecated; remove with next major version
exitable::Union{Nothing,Bool} = nothing
unreachable::Union{Nothing,Bool} = nothing
end

function Base.hash(target::PTXCompilerTarget, h::UInt)
h = hash(target.cap, h)
h = hash(target.ptx, h)

h = hash(target.debuginfo, h)
h = hash(target.unreachable, h)
h = hash(target.exitable, h)

h = hash(target.minthreads, h)
h = hash(target.maxthreads, h)
@@ -92,8 +90,7 @@ isintrinsic(@nospecialize(job::CompilerJob{PTXCompilerTarget}), fn::String) =
# XXX: the debuginfo part should be handled by GPUCompiler as it applies to all back-ends.
runtime_slug(@nospecialize(job::CompilerJob{PTXCompilerTarget})) =
"ptx-sm_$(job.config.target.cap.major)$(job.config.target.cap.minor)" *
"-debuginfo=$(Int(llvm_debug_info(job)))" *
"-exitable=$(job.config.target.exitable)"
"-debuginfo=$(Int(llvm_debug_info(job)))"

function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
mod::LLVM.Module, entry::LLVM.Function)
@@ -132,14 +129,6 @@ function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
end

@dispose pm=ModulePassManager() begin
# hide `unreachable` from LLVM so that it doesn't introduce divergent control flow
if !job.config.target.unreachable
add!(pm, FunctionPass("HideUnreachable", hide_unreachable!))
end

# even if we support `unreachable`, we still prefer `exit` to `trap`
add!(pm, ModulePass("HideTrap", hide_trap!))

# we emit properties (of the device and ptx isa) as private global constants,
# so run the optimizer so that they are inlined before the rest of the optimizer runs.
global_optimizer!(pm)
@@ -188,6 +177,13 @@ function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
mod::LLVM.Module, entry::LLVM.Function)
ctx = context(mod)

@dispose pm=ModulePassManager() begin
add!(pm, ModulePass("LowerTrap", lower_trap!))
add!(pm, FunctionPass("LowerUnreachable", lower_unreachable!))

run!(pm, mod)
end

if job.config.kernel
# add metadata annotations for the assembler to the module

@@ -242,111 +238,29 @@ end

## LLVM passes

# HACK: this pass removes `unreachable` information from LLVM
#
# `ptxas` is buggy and cannot deal with thread-divergent control flow in the presence of
# shared memory (see JuliaGPU/CUDAnative.jl#4). avoid that by rewriting control flow to fall
# through any other block. this is semantically invalid, but the code is unreachable anyhow
# (and we expect it to be preceded by eg. a noreturn function, or a trap).
#
# TODO: can LLVM do this with structured CFGs? It seems to have some support, but seemingly
# only to prevent introducing non-structureness during optimization (ie. the front-end
# is still responsible for generating structured control flow).
function hide_unreachable!(fun::LLVM.Function)
# replace calls to `trap` with inline assembly calling `exit`, which isn't fatal
function lower_trap!(mod::LLVM.Module)
job = current_job::CompilerJob
ctx = context(fun)
ctx = context(mod)
changed = false
@timeit_debug to "hide unreachable" begin
@timeit_debug to "lower trap" begin

# remove `noreturn` attributes
#
# when calling a `noreturn` function, LLVM places an `unreachable` after the call.
# this leads to an early `ret` from the function.
attrs = function_attributes(fun)
delete!(attrs, EnumAttribute("noreturn", 0; ctx))
if haskey(functions(mod), "llvm.trap")
trap = functions(mod)["llvm.trap"]

# build a map of basic block predecessors
predecessors = Dict(bb => Set{LLVM.BasicBlock}() for bb in blocks(fun))
@timeit_debug to "predecessors" for bb in blocks(fun)
insts = instructions(bb)
if !isempty(insts)
inst = last(insts)
if isterminator(inst)
for bb′ in successors(inst)
push!(predecessors[bb′], bb)
end
end
end
end
# inline assembly to exit a thread
exit_ft = LLVM.FunctionType(LLVM.VoidType(ctx))
exit = InlineAsm(exit_ft, "exit;", "", true)

# scan for unreachable terminators and alternative successors
worklist = Pair{LLVM.BasicBlock, Union{Nothing,LLVM.BasicBlock}}[]
@timeit_debug to "find" for bb in blocks(fun)
unreachable = terminator(bb)
if isa(unreachable, LLVM.UnreachableInst)
unsafe_delete!(bb, unreachable)
changed = true

try
terminator(bb)
# the basic-block is still terminated properly, nothing to do
# (this can happen with `ret; unreachable`)
# TODO: `unreachable; unreachable`
catch ex
isa(ex, UndefRefError) || rethrow(ex)
for use in uses(trap)
val = user(use)
if isa(val, LLVM.CallInst)
@dispose builder=IRBuilder(ctx) begin
position!(builder, bb)

# find the strict predecessors to this block
preds = collect(predecessors[bb])

# find a fallthrough block: recursively look at predecessors
# and find a successor that branches to any other block
fallthrough = nothing
while !isempty(preds)
# find an alternative successor
for pred in preds, succ in successors(terminator(pred))
if succ != bb
fallthrough = succ
break
end
end
fallthrough === nothing || break

# recurse upwards
old_preds = copy(preds)
empty!(preds)
for pred in old_preds
append!(preds, predecessors[pred])
end
end
push!(worklist, bb => fallthrough)
end
end
end
end

# apply the pending terminator rewrites
@timeit_debug to "replace" if !isempty(worklist)
let builder = IRBuilder(ctx)
for (bb, fallthrough) in worklist
position!(builder, bb)
if fallthrough !== nothing
br!(builder, fallthrough)
else
# couldn't find any other successor. this happens with functions
# that only contain a single block, or when the block is dead.
ft = function_type(fun)
if return_type(ft) == LLVM.VoidType(ctx)
# even though returning can lead to invalid control flow,
# it mostly happens with functions that just throw,
# and leaving the unreachable there would make the optimizer
# place another after the call.
ret!(builder)
else
unreachable!(builder)
end
position!(builder, val)
call!(builder, exit_ft, exit)
end
unsafe_delete!(LLVM.parent(val), val)
changed = true
end
end
end
@@ -355,41 +269,121 @@ function hide_unreachable!(fun::LLVM.Function)
return changed
end

# HACK: this pass removes calls to `trap` and replaces them with inline assembly
# lower `unreachable` to `exit` so that the emitted PTX has correct control flow
#
# if LLVM knows we're trapping, code is marked `unreachable` (see `hide_unreachable!`).
function hide_trap!(mod::LLVM.Module)
job = current_job::CompilerJob
ctx = context(mod)
changed = false
@timeit_debug to "hide trap" begin
# During back-end compilation, `ptxas` inserts instructions to manage the harware's
# reconvergence stack (SSY and SYNC). In order to do so, it needs to identify
# divergent regions:
#
# entry:
# // start of divergent region
# @%p0 bra cont;
# ...
# bra.uni cont;
# cont:
# // end of divergent region
# bar.sync 0;
#
# Meanwhile, LLVM's branch-folder and block-placement MIR passes will try to optimize
# the block layout, e.g., by placing unlikely blocks at the end of the function:
#
# entry:
# // start of divergent region
# @%p0 bra cont;
# @%p1 bra unlikely;
# bra.uni cont;
# cont:
# // end of divergent region
# bar.sync 0;
# unlikely:
# bra.uni cont;
#
# That is not a problem as long as the unlikely block continunes back into the
# divergent region. Crucially, this is not the case with unreachable control flow:
#
# entry:
# // start of divergent region
# @%p0 bra cont;
# @%p1 bra throw;
# bra.uni cont;
# cont:
# bar.sync 0;
# throw:
# call throw_and_trap();
# // unreachable
# exit:
# // end of divergent region
# ret;
#
# Dynamically, this is fine, because the called function does not return.
# However, `ptxas` does not know that and adds a successor edge to the `exit`
# block, widening the divergence range. In this example, that's not allowed, as
# `bar.sync` cannot be executed divergently on Pascal hardware or earlier.
#
# To avoid these fall-through successors that change the control flow,
# we replace `unreachable` instructions with a call to `exit`. This informs
# `ptxas` that the thread exits, and allows it to correctly construct a CFG,
# and consequently correctly determine the divergence regions as intended.
function lower_unreachable!(f::LLVM.Function)
ctx = context(f)

# TODO:
# - if unreachable blocks have been merged, we still may be jumping from different
# divergent regions, potentially causing the same problem as above:
# entry:
# // start of divergent region 1
# @%p0 bra cont1;
# @%p1 bra throw;
# bra.uni cont1;
# cont1:
# // end of divergent region 1
# bar.sync 0; // is this executed divergently?
# // start of divergent region 2
# @%p2 bra cont2;
# @%p3 bra throw;
# bra.uni cont2;
# cont2:
# // end of divergent region 2
# ...
# throw:
# trap;
# br throw;
# if this is a problem, we probably need to clone blocks with multiple
# predecessors so that there's a unique path from each region of
# divergence to every `unreachable` terminator

# remove `noreturn` attributes, to avoid the (minimal) optimization that
# happens during `prepare_execution!` undoing our work here.
# this shouldn't be needed when we upstream the pass.
attrs = function_attributes(f)
delete!(attrs, EnumAttribute("noreturn", 0; ctx))

# inline assembly to exit a thread, hiding control flow from LLVM
exit_ft = LLVM.FunctionType(LLVM.VoidType(ctx))
exit = if job.config.target.exitable
InlineAsm(exit_ft, "exit;", "", true)
else
InlineAsm(exit_ft, "trap;", "", true)
# find unreachable blocks
unreachable_blocks = BasicBlock[]
for block in blocks(f)
if terminator(block) isa LLVM.UnreachableInst
push!(unreachable_blocks, block)
end
end
isempty(unreachable_blocks) && return false

if haskey(functions(mod), "llvm.trap")
trap = functions(mod)["llvm.trap"]
# inline assembly to exit a thread
exit_ft = LLVM.FunctionType(LLVM.VoidType(ctx))
exit = InlineAsm(exit_ft, "exit;", "", true)

for use in uses(trap)
val = user(use)
if isa(val, LLVM.CallInst)
@dispose builder=IRBuilder(ctx) begin
position!(builder, val)
call!(builder, exit_ft, exit)
end
unsafe_delete!(LLVM.parent(val), val)
changed = true
end
# rewrite the unreachable terminators
@dispose builder=IRBuilder(ctx) begin
entry_block = first(blocks(f))
for block in unreachable_blocks
inst = terminator(block)
@assert inst isa LLVM.UnreachableInst

position!(builder, inst)
call!(builder, exit_ft, exit)
end
end

end
return changed
return true
end

# Replace occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect with an integer.