Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port some intrinsic lowering code from LLVM/C++. #452

Merged
merged 1 commit into from
May 19, 2023

Conversation

maleadt
Copy link
Member

@maleadt maleadt commented May 19, 2023

I'm hoping to make the LLVM back-end more generic, focussing on only the IR downgrade part (which could be used by other back-ends too, e.g., to target NVVM).

This PR also includes a fix, now properly attaching AIR intrinsic metadata where we didn't before. As a result, it'll break some broken tests in Metal.jl, which I'll fix after this is merged.

@maleadt maleadt added the metal Stuff about the Apple metal back-end. label May 19, 2023
@maleadt maleadt force-pushed the tb/metal_lower_intrinsics branch from 6f686c0 to ebcb8c0 Compare May 19, 2023 12:48
Comment on lines +879 to +988
Float64
else
error("Unsupported maximum/minimum type: $typ")
end

# create a function that performs the IEEE-compliant operation.
# normally we'd do this inline, but LLVM.jl doesn't have BB split functionality.
new_intr_fn = "air.minimum.f$(8*sizeof(jltyp))"
if haskey(functions(mod), new_intr_fn)
new_intr = functions(mod)[new_intr_fn]
else
new_intr = LLVM.Function(mod, new_intr_fn, LLVM.FunctionType(typ, parameters(call_ft)))
push!(function_attributes(new_intr), EnumAttribute("alwaysinline"; ctx))

arg0, arg1 = parameters(new_intr)
@assert value_type(arg0) == value_type(arg1)

bb_check_arg0 = BasicBlock(new_intr, "check_arg0"; ctx)
bb_nan_arg0 = BasicBlock(new_intr, "nan_arg0"; ctx)
bb_check_arg1 = BasicBlock(new_intr, "check_arg1"; ctx)
bb_nan_arg1 = BasicBlock(new_intr, "nan_arg1"; ctx)
bb_check_zero = BasicBlock(new_intr, "check_zero"; ctx)
bb_compare_zero = BasicBlock(new_intr, "compare_zero"; ctx)
bb_fallback = BasicBlock(new_intr, "fallback"; ctx)

@dispose builder=IRBuilder(ctx) begin
# first, check if either argument is NaN, and return it if so

position!(builder, bb_check_arg0)
arg0_nan = fcmp!(builder, LLVM.API.LLVMRealUNO, arg0, arg0)
br!(builder, arg0_nan, bb_nan_arg0, bb_check_arg1)

position!(builder, bb_nan_arg0)
ret!(builder, arg0)

position!(builder, bb_check_arg1)
arg1_nan = fcmp!(builder, LLVM.API.LLVMRealUNO, arg1, arg1)
br!(builder, arg1_nan, bb_nan_arg1, bb_check_zero)

position!(builder, bb_nan_arg1)
ret!(builder, arg1)

# then, check if both arguments are zero and have a mismatching sign.
# if so, return in accordance to the intrinsic (minimum or maximum)

position!(builder, bb_check_zero)

typ′ = LLVM.IntType(8*sizeof(jltyp); ctx)
arg0′ = bitcast!(builder, arg0, typ′)
arg1′ = bitcast!(builder, arg1, typ′)

arg0_zero = fcmp!(builder, LLVM.API.LLVMRealUEQ, arg0,
LLVM.ConstantFP(typ, zero(jltyp)))
arg1_zero = fcmp!(builder, LLVM.API.LLVMRealUEQ, arg1,
LLVM.ConstantFP(typ, zero(jltyp)))
args_zero = and!(builder, arg0_zero, arg1_zero)
arg0_sign = and!(builder, arg0′, LLVM.ConstantInt(typ′, Base.sign_mask(jltyp)))
arg1_sign = and!(builder, arg1′, LLVM.ConstantInt(typ′, Base.sign_mask(jltyp)))
sign_mismatch = icmp!(builder, LLVM.API.LLVMIntNE, arg0_sign, arg1_sign)
relevant_zero = and!(builder, args_zero, sign_mismatch)
br!(builder, relevant_zero, bb_compare_zero, bb_fallback)

position!(builder, bb_compare_zero)
arg0_negative = icmp!(builder, LLVM.API.LLVMIntNE, arg0_sign,
LLVM.ConstantInt(typ′, 0))
val = if intr == LLVM.Intrinsic("llvm.minimum")
select!(builder, arg0_negative, arg0, arg1)
else
select!(builder, arg0_negative, arg1, arg0)
end
ret!(builder, val)

# finally, it's safe to use the existing minnum/maxnum intrinsics

position!(builder, bb_fallback)
fallback_intr_fn = if intr == LLVM.Intrinsic("llvm.minimum")
"air.fmin.f$(8*sizeof(jltyp))"
else
"air.fmax.f$(8*sizeof(jltyp))"
end
fallback_intr = if haskey(functions(mod), fallback_intr_fn)
functions(mod)[fallback_intr_fn]
else
LLVM.Function(mod, fallback_intr_fn, LLVM.FunctionType(typ, parameters(call_ft)))
end
val = call!(builder, fallback_intr, collect(parameters(new_intr)))
ret!(builder, val)
end
end

@dispose builder=IRBuilder(ctx) begin
position!(builder, call)
debuglocation!(builder, call)

new_value = call!(builder, new_intr, arguments(call))
replace_uses!(call, new_value)
unsafe_delete!(bb, call)
changed = true
end
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gbaraldi Look what you made me do.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry :(. It's very annoying that backends don't know how to lower this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem 🙂 But yeah it is annoying that you can't reasonably expect fairly basic things to be available across all back-ends; it should be easy enough for LLVM to provide a fallback lowering...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

X86 not having a lowering is telling though.

@maleadt maleadt merged commit 9351e5f into master May 19, 2023
@maleadt maleadt deleted the tb/metal_lower_intrinsics branch May 19, 2023 13:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
metal Stuff about the Apple metal back-end.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants