-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for the LLVM SPIR-V back-end. #665
Conversation
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/spirv.jl b/src/spirv.jl
index f819fdb..1bc0917 100644
--- a/src/spirv.jl
+++ b/src/spirv.jl
@@ -5,14 +5,20 @@
# https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/master/docs/SPIRVRepresentationInLLVM.rst
const SPIRV_LLVM_Backend_jll =
- LazyModule("SPIRV_LLVM_Backend_jll",
- UUID("4376b9bf-cff8-51b6-bb48-39421dff0d0c"))
+ LazyModule(
+ "SPIRV_LLVM_Backend_jll",
+ UUID("4376b9bf-cff8-51b6-bb48-39421dff0d0c")
+)
const SPIRV_LLVM_Translator_unified_jll =
- LazyModule("SPIRV_LLVM_Translator_unified_jll",
- UUID("85f0d8ed-5b39-5caa-b1ae-7472de402361"))
+ LazyModule(
+ "SPIRV_LLVM_Translator_unified_jll",
+ UUID("85f0d8ed-5b39-5caa-b1ae-7472de402361")
+)
const SPIRV_Tools_jll =
- LazyModule("SPIRV_Tools_jll",
- UUID("6ac6d60f-d740-5983-97d7-a4482c0689f4"))
+ LazyModule(
+ "SPIRV_Tools_jll",
+ UUID("6ac6d60f-d740-5983-97d7-a4482c0689f4")
+)
## target
@@ -20,7 +26,7 @@ const SPIRV_Tools_jll =
export SPIRVCompilerTarget
Base.@kwdef struct SPIRVCompilerTarget <: AbstractCompilerTarget
- version::Union{Nothing,VersionNumber} = nothing
+ version::Union{Nothing, VersionNumber} = nothing
extensions::Vector{String} = []
supports_fp16::Bool = true
supports_fp64::Bool = true
@@ -33,14 +39,14 @@ end
function llvm_triple(target::SPIRVCompilerTarget)
if target.backend == :llvm
- architecture = Int===Int64 ? "spirv64" : "spirv32" # could also be "spirv" for logical addressing
+ architecture = Int === Int64 ? "spirv64" : "spirv32" # could also be "spirv" for logical addressing
subarchitecture = target.version === nothing ? "" : "v$(target.version.major).$(target.version.minor)"
vendor = "unknown" # could also be AMD
os = "unknown"
environment = "unknown"
return "$architecture$subarchitecture-$vendor-$os-$environment"
elseif target.backend == :khronos
- return Int===Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown"
+ return Int === Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown"
end
end
@@ -117,20 +123,20 @@ end
# translate to SPIR-V
input = tempname(cleanup=false) * ".bc"
- translated = tempname(cleanup=false) * ".spv"
+ translated = tempname(cleanup = false) * ".spv"
write(input, mod)
if job.config.target.backend === :llvm
cmd = `$(SPIRV_LLVM_Backend_jll.llc()) $input -filetype=obj -o $translated`
if !isempty(job.config.target.extensions)
- str = join(map(ext->"+$ext", job.config.target.extensions), ",")
+ str = join(map(ext -> "+$ext", job.config.target.extensions), ",")
cmd = `$(cmd) -spirv-ext=$str`
end
elseif job.config.target.backend === :khronos
cmd = `$(SPIRV_LLVM_Translator_unified_jll.llvm_spirv()) -o $translated $input --spirv-debug-info-version=ocl-100`
if !isempty(job.config.target.extensions)
- str = join(map(ext->"+$ext", job.config.target.extensions), ",")
+ str = join(map(ext -> "+$ext", job.config.target.extensions), ",")
cmd = `$(cmd) --spirv-ext=$str`
end
@@ -140,29 +146,35 @@ end
end
proc = run(ignorestatus(cmd))
if !success(proc)
- error("""Failed to translate LLVM code to SPIR-V.
- If you think this is a bug, please file an issue and attach $(input).""")
+ error(
+ """Failed to translate LLVM code to SPIR-V.
+ If you think this is a bug, please file an issue and attach $(input)."""
+ )
end
# validate
if job.config.target.validate
- cmd = `$(SPIRV_Tools_jll.spirv_val()) $translated`
- proc = run(ignorestatus(cmd))
- if !success(proc)
- error("""Failed to validate generated SPIR-V.
- If you think this is a bug, please file an issue and attach $(input) and $(translated).""")
- end
+ cmd = `$(SPIRV_Tools_jll.spirv_val()) $translated`
+ proc = run(ignorestatus(cmd))
+ if !success(proc)
+ error(
+ """Failed to validate generated SPIR-V.
+ If you think this is a bug, please file an issue and attach $(input) and $(translated)."""
+ )
+ end
end
# optimize
optimized = tempname(cleanup=false) * ".spv"
if job.config.target.optimize
cmd = `$(SPIRV_Tools_jll.spirv_opt()) -O --skip-validation $translated -o $optimized`
- proc = run(ignorestatus(cmd))
- if !success(proc)
- error("""Failed to optimize generated SPIR-V.
- If you think this is a bug, please file an issue and attach $(input) and $(translated).""")
- end
+ proc = run(ignorestatus(cmd))
+ if !success(proc)
+ error(
+ """Failed to optimize generated SPIR-V.
+ If you think this is a bug, please file an issue and attach $(input) and $(translated)."""
+ )
+ end
else
cp(translated, optimized)
end
diff --git a/test/helpers/spirv.jl b/test/helpers/spirv.jl
index 761cd49..0f8e9f3 100644
--- a/test/helpers/spirv.jl
+++ b/test/helpers/spirv.jl
@@ -8,11 +8,14 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,CompilerParams}) = TestRuntime
function create_job(@nospecialize(func), @nospecialize(types);
kernel::Bool=false, always_inline=false,
- supports_fp16=true, supports_fp64=true,
- backend::Symbol, kwargs...)
+ supports_fp16 = true, supports_fp64 = true,
+ backend::Symbol, kwargs...
+ )
source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter())
- target = SPIRVCompilerTarget(; backend, validate=true, optimize=true,
- supports_fp16, supports_fp64)
+ target = SPIRVCompilerTarget(;
+ backend, validate = true, optimize = true,
+ supports_fp16, supports_fp64
+ )
params = CompilerParams()
config = CompilerConfig(target, params; kernel, always_inline)
CompilerJob(source, config), kwargs
diff --git a/test/spirv.jl b/test/spirv.jl
index 2d7fb84..d0d2db8 100644
--- a/test/spirv.jl
+++ b/test/spirv.jl
@@ -6,11 +6,13 @@ for backend in (:khronos, :llvm)
@testset "calling convention" begin
kernel() = return
- ir = sprint(io->SPIRV.code_llvm(io, kernel, Tuple{}; backend, dump_module=true))
+ ir = sprint(io -> SPIRV.code_llvm(io, kernel, Tuple{}; backend, dump_module = true))
@test !occursin("spir_kernel", ir)
ir = sprint(io->SPIRV.code_llvm(io, kernel, Tuple{};
- backend, dump_module=true, kernel=true))
+ backend, dump_module = true, kernel = true
+ )
+ )
@test occursin("spir_kernel", ir)
end
@@ -20,12 +22,16 @@ end
kernel(x) = return
end
- ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}}; backend))
+ ir = sprint(io -> SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}}; backend))
@test occursin(r"@\w*kernel\w*\(({ i64 }|\[1 x i64\])\*", ir) ||
occursin(r"@\w*kernel\w*\(ptr", ir)
- ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}};
- backend, kernel=true))
+ ir = sprint(
+ io -> SPIRV.code_llvm(
+ io, mod.kernel, Tuple{Tuple{Int}};
+ backend, kernel = true
+ )
+ )
@test occursin(r"@\w*kernel\w*\(.*{ ({ i64 }|\[1 x i64\]) }\*.+byval", ir) ||
occursin(r"@\w*kernel\w*\(ptr byval", ir)
end
@@ -33,7 +39,7 @@ end
@testset "byval bug" begin
# byval added alwaysinline, which could conflict with noinline and fail verification
@noinline kernel() = return
- SPIRV.code_llvm(devnull, kernel, Tuple{}; backend, kernel=true)
+ SPIRV.code_llvm(devnull, kernel, Tuple{}; backend, kernel = true)
@test "We did not crash!" != ""
end
end
@@ -47,21 +53,35 @@ end
end
end
- ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float16}, Float16};
- backend, validate=true))
+ ir = sprint(
+ io -> SPIRV.code_llvm(
+ io, mod.kernel, Tuple{Ptr{Float16}, Float16};
+ backend, validate = true
+ )
+ )
@test occursin("store half", ir)
- ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float32}, Float32};
- backend, validate=true))
+ ir = sprint(
+ io -> SPIRV.code_llvm(
+ io, mod.kernel, Tuple{Ptr{Float32}, Float32};
+ backend, validate = true
+ )
+ )
@test occursin("store float", ir)
- ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float64}, Float64};
- backend, validate=true))
+ ir = sprint(
+ io -> SPIRV.code_llvm(
+ io, mod.kernel, Tuple{Ptr{Float64}, Float64};
+ backend, validate = true
+ )
+ )
@test occursin("store double", ir)
@test_throws_message(InvalidIRError,
SPIRV.code_llvm(devnull, mod.kernel, Tuple{Ptr{Float16}, Float16};
- backend, supports_fp16=false, validate=true)) do msg
+ backend, supports_fp16 = false, validate = true
+ )
+ ) do msg
occursin("unsupported use of half value", msg) &&
occursin("[1] unsafe_store!", msg) &&
occursin("[2] kernel", msg)
@@ -69,7 +89,9 @@ end
@test_throws_message(InvalidIRError,
SPIRV.code_llvm(devnull, mod.kernel, Tuple{Ptr{Float64}, Float64};
- backend, supports_fp64=false, validate=true)) do msg
+ backend, supports_fp64 = false, validate = true
+ )
+ ) do msg
occursin("unsupported use of double value", msg) &&
occursin("[1] unsafe_store!", msg) &&
occursin("[2] kernel", msg)
@@ -88,7 +110,7 @@ end
return
end
- asm = sprint(io->SPIRV.code_native(io, kernel, Tuple{Bool}; backend, kernel=true))
+ asm = sprint(io -> SPIRV.code_native(io, kernel, Tuple{Bool}; backend, kernel = true))
@test occursin(r"OpFunctionCall %void %(julia|j)_error", asm)
end
|
Nightly failures unrelated. |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #665 +/- ##
==========================================
+ Coverage 71.86% 71.89% +0.03%
==========================================
Files 24 24
Lines 3330 3359 +29
==========================================
+ Hits 2393 2415 +22
- Misses 937 944 +7 ☔ View full report in Codecov by Sentry. |
This adds a temporary
backend
flag to the SPIRVTarget constructor, automatically set based on whether the Translator or LLVM JLL is available in the current environment. It's a non-breaking change; we'll only have to tag a breaking release when removing support for the translator.Demo:
Back-ends will need some work to switch to this alternative back-end, e.g., probably reimplement atomics based on LLVM atomics. That said, I think this back-end should have several advantages. The code is more likely to pass validation (and thus be usable with tools like
spirv-reduce
), and the generated IL doesn't contain a wrapper function hopefully allowing us to compile back to OpenCL C code for JuliaGPU/OpenCL.jl#234