Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebAssembly] Implement the wide-arithmetic proposal #111598

Merged
merged 4 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -5098,6 +5098,8 @@ def msimd128 : Flag<["-"], "msimd128">, Group<m_wasm_Features_Group>;
def mno_simd128 : Flag<["-"], "mno-simd128">, Group<m_wasm_Features_Group>;
def mtail_call : Flag<["-"], "mtail-call">, Group<m_wasm_Features_Group>;
def mno_tail_call : Flag<["-"], "mno-tail-call">, Group<m_wasm_Features_Group>;
def mwide_arithmetic : Flag<["-"], "mwide-arithmetic">, Group<m_wasm_Features_Group>;
def mno_wide_arithmetic : Flag<["-"], "mno-wide-arithmetic">, Group<m_wasm_Features_Group>;
def mexec_model_EQ : Joined<["-"], "mexec-model=">, Group<m_wasm_Features_Driver_Group>,
Values<"command,reactor">,
HelpText<"Execution model (WebAssembly only)">,
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/Basic/Targets/WebAssembly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ bool WebAssemblyTargetInfo::hasFeature(StringRef Feature) const {
.Case("sign-ext", HasSignExt)
.Case("simd128", SIMDLevel >= SIMD128)
.Case("tail-call", HasTailCall)
.Case("wide-arithmetic", HasWideArithmetic)
.Default(false);
}

Expand Down Expand Up @@ -102,6 +103,8 @@ void WebAssemblyTargetInfo::getTargetDefines(const LangOptions &Opts,
Builder.defineMacro("__wasm_simd128__");
if (HasTailCall)
Builder.defineMacro("__wasm_tail_call__");
if (HasWideArithmetic)
Builder.defineMacro("__wasm_wide_arithmetic__");

Builder.defineMacro("__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1");
Builder.defineMacro("__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2");
Expand Down Expand Up @@ -166,6 +169,7 @@ bool WebAssemblyTargetInfo::initFeatureMap(
Features["multimemory"] = true;
Features["nontrapping-fptoint"] = true;
Features["tail-call"] = true;
Features["wide-arithmetic"] = true;
setSIMDLevel(Features, RelaxedSIMD, true);
};
if (CPU == "generic") {
Expand Down Expand Up @@ -293,6 +297,14 @@ bool WebAssemblyTargetInfo::handleTargetFeatures(
HasTailCall = false;
continue;
}
if (Feature == "+wide-arithmetic") {
HasWideArithmetic = true;
continue;
}
if (Feature == "-wide-arithmetic") {
HasWideArithmetic = false;
continue;
}

Diags.Report(diag::err_opt_not_valid_with_opt)
<< Feature << "-target-feature";
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Basic/Targets/WebAssembly.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {
bool HasReferenceTypes = false;
bool HasSignExt = false;
bool HasTailCall = false;
bool HasWideArithmetic = false;

std::string ABI;

Expand Down
6 changes: 6 additions & 0 deletions clang/test/Driver/wasm-features.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,9 @@

// TAIL-CALL: "-target-feature" "+tail-call"
// NO-TAIL-CALL: "-target-feature" "-tail-call"

// RUN: %clang --target=wasm32-unknown-unknown -### %s -mwide-arithmetic 2>&1 | FileCheck %s -check-prefix=WIDE-ARITH
// RUN: %clang --target=wasm32-unknown-unknown -### %s -mno-wide-arithmetic 2>&1 | FileCheck %s -check-prefix=NO-WIDE-ARITH

// WIDE-ARITH: "-target-feature" "+wide-arithmetic"
// NO-WIDE-ARITH: "-target-feature" "-wide-arithmetic"
12 changes: 12 additions & 0 deletions clang/test/Preprocessor/wasm-target-features.c
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
// MVP-NOT: #define __wasm_sign_ext__ 1{{$}}
// MVP-NOT: #define __wasm_simd128__ 1{{$}}
// MVP-NOT: #define __wasm_tail_call__ 1{{$}}
// MVP-NOT: #define __wasm_wide_arithmetic__ 1{{$}}

// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm32-unknown-unknown -mcpu=generic \
Expand Down Expand Up @@ -184,6 +185,7 @@
// GENERIC-NOT: #define __wasm_relaxed_simd__ 1{{$}}
// GENERIC-NOT: #define __wasm_simd128__ 1{{$}}
// GENERIC-NOT: #define __wasm_tail_call__ 1{{$}}
// GENERIC-NOT: #define __wasm_wide_arithmetic__ 1{{$}}

// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm32-unknown-unknown -mcpu=bleeding-edge \
Expand All @@ -206,6 +208,7 @@
// BLEEDING-EDGE-INCLUDE-DAG: #define __wasm_sign_ext__ 1{{$}}
// BLEEDING-EDGE-INCLUDE-DAG: #define __wasm_simd128__ 1{{$}}
// BLEEDING-EDGE-INCLUDE-DAG: #define __wasm_tail_call__ 1{{$}}
// BLEEDING-EDGE-INCLUDE-DAG: #define __wasm_wide_arithmetic__ 1{{$}}

// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm32-unknown-unknown -mcpu=bleeding-edge -mno-simd128 \
Expand All @@ -215,3 +218,12 @@
// RUN: | FileCheck %s -check-prefix=BLEEDING-EDGE-NO-SIMD128
//
// BLEEDING-EDGE-NO-SIMD128-NOT: #define __wasm_simd128__ 1{{$}}

// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm32-unknown-unknown -mwide-arithmetic \
// RUN: | FileCheck %s -check-prefix=WIDE-ARITHMETIC
// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm64-unknown-unknown -mwide-arithmetic \
// RUN: | FileCheck %s -check-prefix=WIDE-ARITHMETIC
//
// WIDE-ARITHMETIC: #define __wasm_wide_arithmetic__ 1{{$}}
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def FeatureTailCall :
SubtargetFeature<"tail-call", "HasTailCall", "true",
"Enable tail call instructions">;

def FeatureWideArithmetic :
SubtargetFeature<"wide-arithmetic", "HasWideArithmetic", "true",
"Enable wide-arithmetic instructions">;

//===----------------------------------------------------------------------===//
// Architectures.
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISD.def
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ HANDLE_NODETYPE(TRUNC_SAT_ZERO_U)
HANDLE_NODETYPE(DEMOTE_ZERO)
HANDLE_NODETYPE(MEMORY_COPY)
HANDLE_NODETYPE(MEMORY_FILL)
HANDLE_NODETYPE(I64_ADD128)
HANDLE_NODETYPE(I64_SUB128)
HANDLE_NODETYPE(I64_MUL_WIDE_S)
HANDLE_NODETYPE(I64_MUL_WIDE_U)

// Memory intrinsics
HANDLE_MEM_NODETYPE(GLOBAL_GET)
Expand Down
71 changes: 71 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(Op, T, Expand);
}

if (Subtarget->hasWideArithmetic()) {
setOperationAction(ISD::ADD, MVT::i128, Custom);
setOperationAction(ISD::SUB, MVT::i128, Custom);
setOperationAction(ISD::SMUL_LOHI, MVT::i64, Custom);
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Custom);
}

if (Subtarget->hasNontrappingFPToInt())
for (auto Op : {ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT})
for (auto T : {MVT::i32, MVT::i64})
Expand Down Expand Up @@ -1443,6 +1450,10 @@ void WebAssemblyTargetLowering::ReplaceNodeResults(
// Do not add any results, signifying that N should not be custom lowered.
// EXTEND_VECTOR_INREG is implemented for some vectors, but not all.
break;
case ISD::ADD:
case ISD::SUB:
Results.push_back(Replace128Op(N, DAG));
break;
default:
llvm_unreachable(
"ReplaceNodeResults not implemented for this op for WebAssembly!");
Expand Down Expand Up @@ -1519,6 +1530,9 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
return DAG.UnrollVectorOp(Op.getNode());
case ISD::CLEAR_CACHE:
report_fatal_error("llvm.clear_cache is not supported on wasm");
case ISD::SMUL_LOHI:
case ISD::UMUL_LOHI:
return LowerMUL_LOHI(Op, DAG);
}
}

Expand Down Expand Up @@ -1617,6 +1631,63 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
return Op;
}

SDValue WebAssemblyTargetLowering::LowerMUL_LOHI(SDValue Op,
SelectionDAG &DAG) const {
assert(Subtarget->hasWideArithmetic());
assert(Op.getValueType() == MVT::i64);
SDLoc DL(Op);
unsigned Opcode;
switch (Op.getOpcode()) {
case ISD::UMUL_LOHI:
Opcode = WebAssemblyISD::I64_MUL_WIDE_U;
break;
case ISD::SMUL_LOHI:
Opcode = WebAssemblyISD::I64_MUL_WIDE_S;
break;
default:
llvm_unreachable("unexpected opcode");
}
SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);
SDValue Hi =
DAG.getNode(Opcode, DL, DAG.getVTList(MVT::i64, MVT::i64), LHS, RHS);
SDValue Lo(Hi.getNode(), 1);
SDValue Ops[] = {Hi, Lo};
return DAG.getMergeValues(Ops, DL);
}

SDValue WebAssemblyTargetLowering::Replace128Op(SDNode *N,
SelectionDAG &DAG) const {
assert(Subtarget->hasWideArithmetic());
auto ValTy = N->getValueType(0);
assert(ValTy == MVT::i128);
SDLoc DL(N);
unsigned Opcode;
switch (N->getOpcode()) {
case ISD::ADD:
Opcode = WebAssemblyISD::I64_ADD128;
break;
case ISD::SUB:
Opcode = WebAssemblyISD::I64_SUB128;
break;
default:
llvm_unreachable("unexpected opcode");
}
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);

SDValue C0 = DAG.getConstant(0, DL, MVT::i64);
SDValue C1 = DAG.getConstant(1, DL, MVT::i64);
SDValue LHS_0 = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i64, LHS, C0);
SDValue LHS_1 = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i64, LHS, C1);
SDValue RHS_0 = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i64, RHS, C0);
SDValue RHS_1 = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i64, RHS, C1);
SDValue Result_LO = DAG.getNode(Opcode, DL, DAG.getVTList(MVT::i64, MVT::i64),
LHS_0, LHS_1, RHS_0, RHS_1);
SDValue Result_HI(Result_LO.getNode(), 1);
return DAG.getNode(ISD::BUILD_PAIR, DL, N->getVTList(), Result_LO, Result_HI);
}

SDValue WebAssemblyTargetLowering::LowerCopyToReg(SDValue Op,
SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(2);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class WebAssemblyTargetLowering final : public TargetLowering {
SDValue LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerLoad(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerStore(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL_LOHI(SDValue Op, SelectionDAG &DAG) const;
SDValue Replace128Op(SDNode *N, SelectionDAG &DAG) const;

// Custom DAG combine hooks
SDValue
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def HasTailCall :
Predicate<"Subtarget->hasTailCall()">,
AssemblerPredicate<(all_of FeatureTailCall), "tail-call">;

def HasWideArithmetic :
Predicate<"Subtarget->hasWideArithmetic()">,
AssemblerPredicate<(all_of FeatureWideArithmetic), "wide-arithmetic">;

//===----------------------------------------------------------------------===//
// WebAssembly-specific DAG Node Types.
//===----------------------------------------------------------------------===//
Expand Down
45 changes: 45 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrInteger.td
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,48 @@ def : Pat<(select (i32 (seteq I32:$cond, 0)), I32:$lhs, I32:$rhs),
(SELECT_I32 I32:$rhs, I32:$lhs, I32:$cond)>;
def : Pat<(select (i32 (seteq I32:$cond, 0)), I64:$lhs, I64:$rhs),
(SELECT_I64 I64:$rhs, I64:$lhs, I32:$cond)>;

let Predicates = [HasWideArithmetic] in {
defm I64_ADD128 : I<(outs I64:$lo, I64:$hi), (ins I64:$lhs_lo, I64:$lhs_hi, I64:$rhs_lo, I64:$rhs_hi),
(outs), (ins),
[],
"i64.add128\t$lo, $hi, $lhs_lo, $lhs_hi, $rhs_lo, $rhs_hi",
"i64.add128",
0xfc13>;
defm I64_SUB128 : I<(outs I64:$lo, I64:$hi), (ins I64:$lhs_lo, I64:$lhs_hi, I64:$rhs_lo, I64:$rhs_hi),
(outs), (ins),
[],
"i64.sub128\t$lo, $hi, $lhs_lo, $lhs_hi, $rhs_lo, $rhs_hi",
"i64.sub128",
0xfc14>;
defm I64_MUL_WIDE_S : I<(outs I64:$lo, I64:$hi), (ins I64:$lhs, I64:$rhs),
(outs), (ins),
[],
"i64.mul_wide_s\t$lo, $hi, $lhs, $rhs",
"i64.mul_wide_s",
0xfc15>;
defm I64_MUL_WIDE_U : I<(outs I64:$lo, I64:$hi), (ins I64:$lhs, I64:$rhs),
(outs), (ins),
[],
"i64.mul_wide_u\t$lo, $hi, $lhs, $rhs",
"i64.mul_wide_u",
0xfc16>;
} // Predicates = [HasWideArithmetic]

def wasm_binop128_t : SDTypeProfile<2, 4, []>;
def wasm_add128 : SDNode<"WebAssemblyISD::I64_ADD128", wasm_binop128_t>;
def wasm_sub128 : SDNode<"WebAssemblyISD::I64_SUB128", wasm_binop128_t>;

def : Pat<(wasm_add128 I64:$a, I64:$b, I64:$c, I64:$d),
(I64_ADD128 $a, $b, $c, $d)>;
def : Pat<(wasm_sub128 I64:$a, I64:$b, I64:$c, I64:$d),
(I64_SUB128 $a, $b, $c, $d)>;

def wasm_mul_wide_t : SDTypeProfile<2, 2, []>;
def wasm_mul_wide_s : SDNode<"WebAssemblyISD::I64_MUL_WIDE_S", wasm_mul_wide_t>;
def wasm_mul_wide_u : SDNode<"WebAssemblyISD::I64_MUL_WIDE_U", wasm_mul_wide_t>;

def : Pat<(wasm_mul_wide_s I64:$x, I64:$y),
(I64_MUL_WIDE_S $x, $y)>;
def : Pat<(wasm_mul_wide_u I64:$x, I64:$y),
(I64_MUL_WIDE_U $x, $y)>;
2 changes: 2 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo {
bool HasReferenceTypes = false;
bool HasSignExt = false;
bool HasTailCall = false;
bool HasWideArithmetic = false;

/// What processor and OS we're targeting.
Triple TargetTriple;
Expand Down Expand Up @@ -106,6 +107,7 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo {
bool hasSignExt() const { return HasSignExt; }
bool hasSIMD128() const { return SIMDLevel >= SIMD128; }
bool hasTailCall() const { return HasTailCall; }
bool hasWideArithmetic() const { return HasWideArithmetic; }

/// Parses features string setting specified subtarget options. Definition of
/// function is auto generated by tblgen.
Expand Down
Loading
Loading