Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for the LLVM SPIR-V back-end. #665

Merged
merged 1 commit into from
Feb 13, 2025
Merged

Conversation

maleadt
Copy link
Member

@maleadt maleadt commented Feb 13, 2025

This adds a temporary backend flag to the SPIRVTarget constructor, automatically set based on whether the Translator or LLVM JLL is available in the current environment. It's a non-breaking change; we'll only have to tag a breaking release when removing support for the translator.

Demo:

julia> main()
target = SPIRVCompilerTarget(; backend = :khronos) = SPIRVCompilerTarget(nothing, String[], true, true, :khronos)
; SPIR-V
; Version: 1.0
; Generator: Khronos LLVM/SPIR-V Translator; 14
; Bound: 9
; Schema: 0
               OpCapability Addresses
               OpCapability Linkage
               OpCapability Kernel
          %1 = OpExtInstImport "OpenCL.std"
               OpMemoryModel Physical64 OpenCL
               OpEntryPoint Kernel %6 "_Z6kernel"
               OpSource OpenCL_C 200000
               OpName %_Z6kernel "_Z6kernel"
               OpName %conversion "conversion"
               OpDecorate %_Z6kernel LinkageAttributes "_Z6kernel" Export
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
  %_Z6kernel = OpFunction %void None %3
 %conversion = OpLabel
               OpReturn
               OpFunctionEnd
          %6 = OpFunction %void None %3
          %7 = OpLabel
          %8 = OpFunctionCall %void %_Z6kernel
               OpReturn
               OpFunctionEnd


julia> main()
target = SPIRVCompilerTarget(; backend = :llvm) = SPIRVCompilerTarget(nothing, String[], true, true, :llvm)
; SPIR-V
; Version: 1.4
; Generator: LLVM LLVM SPIR-V Backend; 19
; Bound: 11
; Schema: 0
               OpCapability Kernel
               OpCapability Addresses
          %1 = OpExtInstImport "OpenCL.std"
               OpMemoryModel Physical64 OpenCL
               OpEntryPoint Kernel %_Z6kernel "_Z6kernel"
               OpExecutionMode %_Z6kernel ContractionOff
               OpSource OpenCL_C 200000
               OpName %_Z6kernel "_Z6kernel"
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
  %_Z6kernel = OpFunction %void None %3
          %5 = OpLabel
               OpReturn
               OpFunctionEnd

Back-ends will need some work to switch to this alternative back-end, e.g., probably reimplement atomics based on LLVM atomics. That said, I think this back-end should have several advantages. The code is more likely to pass validation (and thus be usable with tools like spirv-reduce), and the generated IL doesn't contain a wrapper function hopefully allowing us to compile back to OpenCL C code for JuliaGPU/OpenCL.jl#234

@maleadt maleadt added the spirv Stuff about the SPIR-V back-end. label Feb 13, 2025
Copy link
Contributor

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/spirv.jl b/src/spirv.jl
index f819fdb..1bc0917 100644
--- a/src/spirv.jl
+++ b/src/spirv.jl
@@ -5,14 +5,20 @@
 # https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/master/docs/SPIRVRepresentationInLLVM.rst
 
 const SPIRV_LLVM_Backend_jll =
-    LazyModule("SPIRV_LLVM_Backend_jll",
-               UUID("4376b9bf-cff8-51b6-bb48-39421dff0d0c"))
+    LazyModule(
+    "SPIRV_LLVM_Backend_jll",
+    UUID("4376b9bf-cff8-51b6-bb48-39421dff0d0c")
+)
 const SPIRV_LLVM_Translator_unified_jll =
-    LazyModule("SPIRV_LLVM_Translator_unified_jll",
-               UUID("85f0d8ed-5b39-5caa-b1ae-7472de402361"))
+    LazyModule(
+    "SPIRV_LLVM_Translator_unified_jll",
+    UUID("85f0d8ed-5b39-5caa-b1ae-7472de402361")
+)
 const SPIRV_Tools_jll =
-    LazyModule("SPIRV_Tools_jll",
-               UUID("6ac6d60f-d740-5983-97d7-a4482c0689f4"))
+    LazyModule(
+    "SPIRV_Tools_jll",
+    UUID("6ac6d60f-d740-5983-97d7-a4482c0689f4")
+)
 
 
 ## target
@@ -20,7 +26,7 @@ const SPIRV_Tools_jll =
 export SPIRVCompilerTarget
 
 Base.@kwdef struct SPIRVCompilerTarget <: AbstractCompilerTarget
-    version::Union{Nothing,VersionNumber} = nothing
+    version::Union{Nothing, VersionNumber} = nothing
     extensions::Vector{String} = []
     supports_fp16::Bool = true
     supports_fp64::Bool = true
@@ -33,14 +39,14 @@ end
 
 function llvm_triple(target::SPIRVCompilerTarget)
     if target.backend == :llvm
-        architecture = Int===Int64 ? "spirv64" : "spirv32"  # could also be "spirv" for logical addressing
+        architecture = Int === Int64 ? "spirv64" : "spirv32"  # could also be "spirv" for logical addressing
         subarchitecture = target.version === nothing ? "" : "v$(target.version.major).$(target.version.minor)"
         vendor = "unknown"  # could also be AMD
         os = "unknown"
         environment = "unknown"
         return "$architecture$subarchitecture-$vendor-$os-$environment"
     elseif target.backend == :khronos
-        return Int===Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown"
+        return Int === Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown"
     end
 end
 
@@ -117,20 +123,20 @@ end
 
     # translate to SPIR-V
     input = tempname(cleanup=false) * ".bc"
-    translated = tempname(cleanup=false) * ".spv"
+    translated = tempname(cleanup = false) * ".spv"
     write(input, mod)
     if job.config.target.backend === :llvm
         cmd = `$(SPIRV_LLVM_Backend_jll.llc()) $input -filetype=obj -o $translated`
 
         if !isempty(job.config.target.extensions)
-            str = join(map(ext->"+$ext", job.config.target.extensions), ",")
+            str = join(map(ext -> "+$ext", job.config.target.extensions), ",")
             cmd = `$(cmd) -spirv-ext=$str`
         end
     elseif job.config.target.backend === :khronos
         cmd = `$(SPIRV_LLVM_Translator_unified_jll.llvm_spirv()) -o $translated $input --spirv-debug-info-version=ocl-100`
 
         if !isempty(job.config.target.extensions)
-            str = join(map(ext->"+$ext", job.config.target.extensions), ",")
+            str = join(map(ext -> "+$ext", job.config.target.extensions), ",")
             cmd = `$(cmd) --spirv-ext=$str`
         end
 
@@ -140,29 +146,35 @@ end
     end
     proc = run(ignorestatus(cmd))
     if !success(proc)
-        error("""Failed to translate LLVM code to SPIR-V.
-                 If you think this is a bug, please file an issue and attach $(input).""")
+        error(
+            """Failed to translate LLVM code to SPIR-V.
+            If you think this is a bug, please file an issue and attach $(input)."""
+        )
     end
 
     # validate
     if job.config.target.validate
-       cmd = `$(SPIRV_Tools_jll.spirv_val()) $translated`
-       proc = run(ignorestatus(cmd))
-       if !success(proc)
-           error("""Failed to validate generated SPIR-V.
-                    If you think this is a bug, please file an issue and attach $(input) and $(translated).""")
-       end
+        cmd = `$(SPIRV_Tools_jll.spirv_val()) $translated`
+        proc = run(ignorestatus(cmd))
+        if !success(proc)
+            error(
+                """Failed to validate generated SPIR-V.
+                If you think this is a bug, please file an issue and attach $(input) and $(translated)."""
+            )
+        end
     end
 
     # optimize
     optimized = tempname(cleanup=false) * ".spv"
     if job.config.target.optimize
         cmd = `$(SPIRV_Tools_jll.spirv_opt()) -O --skip-validation $translated -o $optimized`
-       proc = run(ignorestatus(cmd))
-       if !success(proc)
-           error("""Failed to optimize generated SPIR-V.
-                    If you think this is a bug, please file an issue and attach $(input) and $(translated).""")
-       end
+        proc = run(ignorestatus(cmd))
+        if !success(proc)
+            error(
+                """Failed to optimize generated SPIR-V.
+                If you think this is a bug, please file an issue and attach $(input) and $(translated)."""
+            )
+        end
     else
         cp(translated, optimized)
     end
diff --git a/test/helpers/spirv.jl b/test/helpers/spirv.jl
index 761cd49..0f8e9f3 100644
--- a/test/helpers/spirv.jl
+++ b/test/helpers/spirv.jl
@@ -8,11 +8,14 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,CompilerParams}) = TestRuntime
 
 function create_job(@nospecialize(func), @nospecialize(types);
                    kernel::Bool=false, always_inline=false,
-                   supports_fp16=true, supports_fp64=true,
-                   backend::Symbol, kwargs...)
+        supports_fp16 = true, supports_fp64 = true,
+        backend::Symbol, kwargs...
+    )
     source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter())
-    target = SPIRVCompilerTarget(; backend, validate=true, optimize=true,
-                                   supports_fp16, supports_fp64)
+    target = SPIRVCompilerTarget(;
+        backend, validate = true, optimize = true,
+        supports_fp16, supports_fp64
+    )
     params = CompilerParams()
     config = CompilerConfig(target, params; kernel, always_inline)
     CompilerJob(source, config), kwargs
diff --git a/test/spirv.jl b/test/spirv.jl
index 2d7fb84..d0d2db8 100644
--- a/test/spirv.jl
+++ b/test/spirv.jl
@@ -6,11 +6,13 @@ for backend in (:khronos, :llvm)
 @testset "calling convention" begin
     kernel() = return
 
-    ir = sprint(io->SPIRV.code_llvm(io, kernel, Tuple{}; backend, dump_module=true))
+                ir = sprint(io -> SPIRV.code_llvm(io, kernel, Tuple{}; backend, dump_module = true))
     @test !occursin("spir_kernel", ir)
 
     ir = sprint(io->SPIRV.code_llvm(io, kernel, Tuple{};
-                                    backend, dump_module=true, kernel=true))
+                        backend, dump_module = true, kernel = true
+                    )
+                )
     @test occursin("spir_kernel", ir)
 end
 
@@ -20,12 +22,16 @@ end
         kernel(x) = return
     end
 
-    ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}}; backend))
+                ir = sprint(io -> SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}}; backend))
     @test occursin(r"@\w*kernel\w*\(({ i64 }|\[1 x i64\])\*", ir) ||
           occursin(r"@\w*kernel\w*\(ptr", ir)
 
-    ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}};
-                                    backend, kernel=true))
+                ir = sprint(
+                    io -> SPIRV.code_llvm(
+                        io, mod.kernel, Tuple{Tuple{Int}};
+                        backend, kernel = true
+                    )
+                )
     @test occursin(r"@\w*kernel\w*\(.*{ ({ i64 }|\[1 x i64\]) }\*.+byval", ir) ||
           occursin(r"@\w*kernel\w*\(ptr byval", ir)
 end
@@ -33,7 +39,7 @@ end
 @testset "byval bug" begin
     # byval added alwaysinline, which could conflict with noinline and fail verification
     @noinline kernel() = return
-    SPIRV.code_llvm(devnull, kernel, Tuple{}; backend, kernel=true)
+                SPIRV.code_llvm(devnull, kernel, Tuple{}; backend, kernel = true)
     @test "We did not crash!" != ""
 end
 end
@@ -47,21 +53,35 @@ end
         end
     end
 
-    ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float16}, Float16};
-                                    backend, validate=true))
+            ir = sprint(
+                io -> SPIRV.code_llvm(
+                    io, mod.kernel, Tuple{Ptr{Float16}, Float16};
+                    backend, validate = true
+                )
+            )
     @test occursin("store half", ir)
 
-    ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float32}, Float32};
-                                    backend, validate=true))
+            ir = sprint(
+                io -> SPIRV.code_llvm(
+                    io, mod.kernel, Tuple{Ptr{Float32}, Float32};
+                    backend, validate = true
+                )
+            )
     @test occursin("store float", ir)
 
-    ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float64}, Float64};
-                                    backend, validate=true))
+            ir = sprint(
+                io -> SPIRV.code_llvm(
+                    io, mod.kernel, Tuple{Ptr{Float64}, Float64};
+                    backend, validate = true
+                )
+            )
     @test occursin("store double", ir)
 
     @test_throws_message(InvalidIRError,
                          SPIRV.code_llvm(devnull, mod.kernel, Tuple{Ptr{Float16}, Float16};
-                                         backend, supports_fp16=false, validate=true)) do msg
+                    backend, supports_fp16 = false, validate = true
+                )
+            ) do msg
         occursin("unsupported use of half value", msg) &&
         occursin("[1] unsafe_store!", msg) &&
         occursin("[2] kernel", msg)
@@ -69,7 +89,9 @@ end
 
     @test_throws_message(InvalidIRError,
                          SPIRV.code_llvm(devnull, mod.kernel, Tuple{Ptr{Float64}, Float64};
-                                         backend, supports_fp64=false, validate=true)) do msg
+                    backend, supports_fp64 = false, validate = true
+                )
+            ) do msg
         occursin("unsupported use of double value", msg) &&
         occursin("[1] unsafe_store!", msg) &&
         occursin("[2] kernel", msg)
@@ -88,7 +110,7 @@ end
         return
     end
 
-    asm = sprint(io->SPIRV.code_native(io, kernel, Tuple{Bool}; backend, kernel=true))
+            asm = sprint(io -> SPIRV.code_native(io, kernel, Tuple{Bool}; backend, kernel = true))
     @test occursin(r"OpFunctionCall %void %(julia|j)_error", asm)
 end
 

@maleadt
Copy link
Member Author

maleadt commented Feb 13, 2025

Nightly failures unrelated.

@maleadt maleadt merged commit 7b208f5 into master Feb 13, 2025
15 of 19 checks passed
@maleadt maleadt deleted the tb/llvm_spirv_backend branch February 13, 2025 07:06
Copy link

codecov bot commented Feb 13, 2025

Codecov Report

Attention: Patch coverage is 75.00000% with 10 lines in your changes missing coverage. Please review.

Project coverage is 71.89%. Comparing base (870fa83) to head (1526029).
Report is 3 commits behind head on master.

Files with missing lines Patch % Lines
src/spirv.jl 76.92% 9 Missing ⚠️
src/utils.jl 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #665      +/-   ##
==========================================
+ Coverage   71.86%   71.89%   +0.03%     
==========================================
  Files          24       24              
  Lines        3330     3359      +29     
==========================================
+ Hits         2393     2415      +22     
- Misses        937      944       +7     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
spirv Stuff about the SPIR-V back-end.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant