diff --git a/src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp b/src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp index 0a99cd267..5afb3b5ee 100644 --- a/src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp +++ b/src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "src/enzyme_ad/jax/Passes/AffineUtils.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" @@ -26,6 +27,13 @@ #include "src/enzyme_ad/jax/Dialect/Ops.h" #include "stablehlo/dialect/StablehloOps.h" +#include +#include +#include +#include +#include +#include +#include namespace mlir { namespace enzyme { @@ -505,6 +513,44 @@ expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, llvm_unreachable("unreachable"); } +/// scope is an operation _in_ the scope we are interested in +bool isSafeToSpeculativelyExecuteAtScope(Operation *scope, Operation *op) { + if (mlir::isPure(op)) + return true; + + MemRefType ty = nullptr; + if (auto read = dyn_cast(op)) + ty = read.getMemRefType(); + if (!ty) + return false; + + IslAnalysis ia; + + isl_set *array = ia.getMemrefShape(ty); + if (!array) + return false; + + isl_map *accessMap = ia.getAccessMap(op); + if (!accessMap) { + isl_set_free(array); + return false; + } + + isl_set *domain = ia.getDomain(scope); + if (!domain) { + isl_set_free(array); + isl_map_free(accessMap); + return false; + } + isl_set *accessed = isl_set_apply(domain, accessMap); + isl_bool inBounds = isl_set_is_subset(accessed, array); + isl_set_free(array); + isl_set_free(accessed); + if (inBounds == isl_bool_error) + return false; + return inBounds; +} + static LogicalResult tryRaisingOpToStableHLO(Operation *op, IRMapping &mapping, OpBuilder &builder, llvm::DenseMap &maps) { @@ -796,9 +842,12 @@ tryRaisingOpToStableHLO(Operation *op, IRMapping &mapping, OpBuilder &builder, if (auto ifOp = dyn_cast(op)) { if (!ifOp.hasElse() || ifOp->getNumResults() == 0 || llvm::any_of(*ifOp.getThenBlock(), - [](Operation &op) { return !mlir::isPure(&op); }) || - llvm::any_of(*ifOp.getElseBlock(), - [](Operation &op) { return !mlir::isPure(&op); })) { + [ifOp](Operation &op) { + return !isSafeToSpeculativelyExecuteAtScope(ifOp, &op); + }) || + llvm::any_of(*ifOp.getElseBlock(), [ifOp](Operation &op) { + return !isSafeToSpeculativelyExecuteAtScope(ifOp, &op); + })) { LLVM_DEBUG(llvm::dbgs() << "cannot raise if yet (non-pure or yielded values): " << *op << "\n"); diff --git a/src/enzyme_ad/jax/Passes/AffineUtils.h b/src/enzyme_ad/jax/Passes/AffineUtils.h new file mode 100644 index 000000000..ce252242e --- /dev/null +++ b/src/enzyme_ad/jax/Passes/AffineUtils.h @@ -0,0 +1,45 @@ +#ifndef ENZYME_JAX_PASSES_AFFINEUTILS_H_ +#define ENZYME_JAX_PASSES_AFFINEUTILS_H_ + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include + +namespace mlir { + +mlir::affine::AffineValueMap getAVM(mlir::Operation *op); + +class IslAnalysis { +public: + std::optional> + getAffExprs(mlir::Operation *op, mlir::affine::AffineValueMap avm); + + std::optional> getAffExprs(mlir::Operation *op); + + isl_map *getAccessMap(mlir::Operation *op); + + isl_set *getDomain(Operation *op); + + isl_set *getMemrefShape(MemRefType ty); + + ~IslAnalysis(); + IslAnalysis(); + +private: + isl_ctx *ctx; + DenseMap vToIdMap; +}; + +template class IslScopeFree { +public: + T obj; + IslScopeFree(T obj) : obj(obj) {} + ~IslScopeFree() { isl_set_free(obj); } +}; + +} // namespace mlir + +#endif // ENZYME_JAX_PASSES_AFFINEUTILS_H_ diff --git a/src/enzyme_ad/jax/Passes/CanonicalizeFor.cpp b/src/enzyme_ad/jax/Passes/CanonicalizeFor.cpp index 0ca7207e4..8ed86b5f4 100644 --- a/src/enzyme_ad/jax/Passes/CanonicalizeFor.cpp +++ b/src/enzyme_ad/jax/Passes/CanonicalizeFor.cpp @@ -1122,7 +1122,7 @@ struct MoveDoWhileToFor : public OpRewritePattern { // Check to see if doBlock just has yield op Block &doBlock = whileOp.getAfter().front(); if (!isa(doBlock.front())) - return failure(); + return rewriter.notifyMatchFailure(whileOp, "non empty then block"); // Before block analysis Block &beforeBlock = whileOp.getBefore().front(); @@ -1142,10 +1142,10 @@ struct MoveDoWhileToFor : public OpRewritePattern { upperBound = cmpOp.getLhs(); compareValue = cmpOp.getRhs(); } else - return failure(); + return rewriter.notifyMatchFailure(whileOp, "cmp against non constant"); } else { // Currently only supporting arith.cmpIOp - return failure(); + return rewriter.notifyMatchFailure(whileOp, "cmp not arith.cmpIOp"); } // Get condition op args and find IV index @@ -1160,7 +1160,7 @@ struct MoveDoWhileToFor : public OpRewritePattern { IVIndex++; } if (!indexFound) - return failure(); + return rewriter.notifyMatchFailure(whileOp, "Did not find index"); // Extract IV and lowerBound based on IVIndex Value IV = beforeBlock.getArgument(IVIndex); diff --git a/src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp b/src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp index d8584d026..cbff52b4a 100644 --- a/src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp +++ b/src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp @@ -1,4 +1,5 @@ +#include "AffineUtils.h" #include "Passes.h" #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" @@ -11,6 +12,7 @@ #include "src/enzyme_ad/jax/Dialect/Ops.h" #include +#include #include #include #include @@ -18,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -591,6 +594,154 @@ struct IslToAffineExprConverter { } }; +namespace mlir { +AffineValueMap getAVM(Operation *op) { + if (auto cop = dyn_cast(op)) + return AffineValueMap(cop.getMap(), cop.getMapOperands(), {}); + else if (auto cop = dyn_cast(op)) + return AffineValueMap(cop.getMap(), cop.getMapOperands(), {}); + else if (auto cop = dyn_cast(op)) + return AffineValueMap(cop.getMap(), cop.getMapOperands(), {}); + else if (auto cop = dyn_cast(op)) + return AffineValueMap(cop.getMap(), cop.getMapOperands(), {}); + llvm_unreachable("Called with non affine op"); +} +} // namespace mlir + +isl_set *IslAnalysis::getMemrefShape(MemRefType ty) { + // TODO we can support params in some cases + if (!ty.hasStaticShape()) + return nullptr; + isl_space *space = isl_space_set_alloc(ctx, 0, ty.getRank()); + isl_multi_aff *ma = + isl_multi_aff_identity_on_domain_space(isl_space_copy(space)); + isl_set *set = isl_set_universe(isl_space_copy(space)); + for (unsigned i = 0; i < ty.getRank(); i++) { + isl_aff *dim = isl_multi_aff_get_at(ma, i); + isl_aff *lb = isl_aff_val_on_domain_space(isl_space_copy(space), + isl_val_int_from_si(ctx, 0)); + isl_aff *ub = isl_aff_val_on_domain_space( + isl_space_copy(space), isl_val_int_from_si(ctx, ty.getDimSize(i))); + + set = isl_set_intersect(set, isl_aff_ge_set(isl_aff_copy(dim), lb)); + set = isl_set_intersect(set, isl_aff_lt_set(isl_aff_copy(dim), ub)); + isl_aff_free(dim); + } + isl_space_free(space); + isl_multi_aff_free(ma); + + return set; +} + +isl_map *IslAnalysis::getAccessMap(mlir::Operation *op) { + auto exprs = getAffExprs(op); + if (!exprs) + return nullptr; + isl_aff_list *list = isl_aff_list_alloc(ctx, exprs->size()); + isl_space *domain = isl_space_domain(isl_aff_get_space((*exprs)[0])); + isl_space *range = isl_space_set_alloc(ctx, 0, exprs->size()); + isl_space *space = isl_space_map_from_domain_and_range(domain, range); + for (auto aff : *exprs) { + assert(isl_space_dim(isl_aff_get_space(aff), isl_dim_param) == 0 && + "only no-parameter aff supported currently"); + list = isl_aff_list_add(list, aff); + } + isl_multi_aff *maff = isl_multi_aff_from_aff_list(space, list); + return isl_map_from_multi_aff(maff); +} + +std::optional> +IslAnalysis::getAffExprs(Operation *op, AffineValueMap avm) { + LLVM_DEBUG(llvm::dbgs() << "Got domain\n"); + auto [domain, cst] = ::getDomain(ctx, op); + LLVM_DEBUG(isl_set_dump(domain)); + LLVM_DEBUG(cst.dump()); + AffineMap map = avm.getAffineMap(); + + LLVM_DEBUG(llvm::dbgs() << "Mapping dims:\n"); + PosMapTy dimPosMap; + PosMapTy dimPosMapReverse; + for (unsigned i = 0; i < cst.getNumDimVars(); i++) { + Value cstVal = cst.getValue(i); + LLVM_DEBUG(llvm::dbgs() << "cstVal " << cstVal << "\n"); + for (unsigned origDim = 0; origDim < map.getNumDims(); origDim++) { + Value dim = avm.getOperand(origDim); + LLVM_DEBUG(llvm::dbgs() << "dim " << dim << "\n"); + if (cstVal == dim) { + LLVM_DEBUG(llvm::dbgs() << origDim << " <--> " << i << "\n"); + dimPosMap[origDim] = i; + dimPosMapReverse[i] = origDim; + break; + } + } + } + + if (avm.getNumSymbols() != 0 || cst.getNumSymbolVars() != 0) { + // TODO While the fact that all dims from the map _must_ appear in the cst, + // this is not the case for symbols. We do not handle that case correctly + // currently, thus we abort early. + domain = isl_set_free(domain); + return std::nullopt; + } + + LLVM_DEBUG(llvm::dbgs() << "Mapping syms:\n"); + PosMapTy symPosMap; + PosMapTy symPosMapReverse; + for (unsigned i = 0; i < cst.getNumSymbolVars(); i++) { + for (unsigned origSym = 0; origSym < map.getNumSymbols(); origSym++) { + Value dim = avm.getOperand(origSym + map.getNumDims()); + if (cst.getValue(i + cst.getNumDimVars()) == dim) { + LLVM_DEBUG(llvm::dbgs() << origSym << " <--> " << i << "\n"); + symPosMap[origSym] = i; + symPosMapReverse[i] = origSym; + break; + } + } + } + + isl_space *space = + isl_space_set_alloc(ctx, cst.getNumSymbolVars(), cst.getNumDimVars()); + for (unsigned i = 0; i < cst.getNumDimVars(); i++) { + isl_id *id = isl_id_alloc(ctx, "dim", (void *)(size_t)(i + 1)); + space = isl_space_set_dim_id(space, isl_dim_set, i, id); + } + unsigned symOffset = cst.getNumDimVars(); + for (unsigned i = 0; i < cst.getNumSymbolVars(); i++) { + isl_id *id = isl_id_alloc(ctx, "sym", (void *)(size_t)(symOffset + i + 1)); + space = isl_space_set_dim_id(space, isl_dim_set, i, id); + } + + isl_local_space *ls = isl_local_space_from_space(isl_space_copy(space)); + space = isl_space_free(space); + AffineExprToIslAffConverter m2i{dimPosMap, symPosMap, ls, ctx}; + SmallVector affVec; + for (unsigned i = 0; i < map.getNumResults(); i++) { + AffineExpr mlirExpr = map.getResult(i); + LLVM_DEBUG(llvm::dbgs() << "Handling AffineExpr\n" << mlirExpr << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Got aff\n"); + isl_aff *aff = m2i.getIslAff(mlirExpr); + affVec.push_back(aff); + } + ls = isl_local_space_free(ls); + domain = isl_set_free(domain); + + return affVec; +} + +isl_set *IslAnalysis::getDomain(Operation *op) { + auto [domain, cst] = ::getDomain(ctx, op); + + return domain; +} + +std::optional> IslAnalysis::getAffExprs(Operation *op) { + return getAffExprs(op, getAVM(op)); +} + +IslAnalysis::IslAnalysis() { ctx = isl_ctx_alloc(); } + +IslAnalysis::~IslAnalysis() { isl_ctx_free(ctx); } + template void handleAffineOp(isl_ctx *ctx, T load) { LLVM_DEBUG(llvm::dbgs() << "Got domain\n"); auto [domain, cst] = getDomain(ctx, load); diff --git a/test/lit_tests/raising/affine_to_stablehlo11.mlir b/test/lit_tests/raising/affine_to_stablehlo11.mlir index a5828ebfe..3f661dc0f 100644 --- a/test/lit_tests/raising/affine_to_stablehlo11.mlir +++ b/test/lit_tests/raising/affine_to_stablehlo11.mlir @@ -1,6 +1,7 @@ // RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(raise-affine-to-stablehlo,enzyme-hlo-opt{max_constant_expansion=0})" | FileCheck %s #set = affine_set<(d0) : (-d0 + 89 >= 0)> +#set2 = affine_set<(d0) : (-d0 + 70 >= 0)> module { func.func private @call__Z31gpu__fill_south_and_north_halo(%arg0: memref<194xf64, 1>) { diff --git a/test/lit_tests/raising/affine_to_stablehlo12.mlir b/test/lit_tests/raising/affine_to_stablehlo12.mlir new file mode 100644 index 000000000..23b24045f --- /dev/null +++ b/test/lit_tests/raising/affine_to_stablehlo12.mlir @@ -0,0 +1,95 @@ +// RUN: enzymexlamlir-opt %s --split-input-file --pass-pipeline="builtin.module(raise-affine-to-stablehlo,enzyme-hlo-opt{max_constant_expansion=0})" | FileCheck %s + +#set = affine_set<(d0) : (-d0 + 89 >= 0)> +#set2 = affine_set<(d0) : (-d0 + 70 >= 0)> + func.func private @if_with_load(%m1: memref<194xf64, 1>, %m2: memref<194xf64, 1>, %m3: memref<194xf64, 1>) { + affine.parallel (%arg1) = (0) to (191) { + affine.if #set2(%arg1) { + %ld = affine.load %m1[%arg1 + 2] : memref<194xf64, 1> + affine.store %ld, %m3[%arg1] : memref<194xf64, 1> + } else { + %ld = affine.load %m2[%arg1 + 3] : memref<194xf64, 1> + affine.store %ld, %m3[%arg1] : memref<194xf64, 1> + } + } + return + } + // ----- +#set = affine_set<(d0) : (-d0 + 89 >= 0)> +#set2 = affine_set<(d0) : (-d0 + 70 >= 0)> + func.func private @if_yield_with_load(%m1: memref<194xf64, 1>, %m2: memref<194xf64, 1>, %m3: memref<194xf64, 1>) { + affine.parallel (%arg1) = (0) to (191) { + %1 = affine.if #set2(%arg1) -> f64 { + %ld = affine.load %m1[%arg1 + 2] : memref<194xf64, 1> + affine.yield %ld : f64 + } else { + %ld = affine.load %m2[%arg1 + 3] : memref<194xf64, 1> + affine.yield %ld : f64 + } + affine.store %1, %m3[%arg1] : memref<194xf64, 1> + } + return + } + // ----- +#set = affine_set<(d0) : (-d0 + 89 >= 0)> +#set2 = affine_set<(d0) : (-d0 + 70 >= 0)> + func.func private @with_load_out_of_bounds(%m1: memref<194xf64, 1>, %m2: memref<194xf64, 1>, %m3: memref<194xf64, 1>) { + affine.parallel (%arg1) = (0) to (193) { + %1 = affine.if #set2(%arg1) -> f64 { + %ld = affine.load %m1[%arg1 + 2] : memref<194xf64, 1> + affine.yield %ld : f64 + } else { + %ld = affine.load %m2[%arg1 + 3] : memref<194xf64, 1> + affine.yield %ld : f64 + } + affine.store %1, %m3[%arg1] : memref<194xf64, 1> + } + return + } + // ----- +#set = affine_set<(d0) : (-d0 + 89 >= 0)> +#set2 = affine_set<(d0) : (-d0 + 70 >= 0)> + func.func private @if_with_multidimload(%m1: memref<100x194xf64, 1>, %m2: memref<100x194xf64, 1>, %m3: memref<100x194xf64, 1>) { + affine.parallel (%a1, %arg1) = (2, 0) to (100, 192) { + %1 = affine.if #set2(%arg1) -> f64 { + %ld = affine.load %m1[%a1 - 2, %arg1 + 2] : memref<100x194xf64, 1> + affine.yield %ld : f64 + } else { + %ld = affine.load %m2[%a1 - 2, %arg1 + 2] : memref<100x194xf64, 1> + affine.yield %ld : f64 + } + affine.store %1, %m3[%a1, %arg1] : memref<100x194xf64, 1> + } + return + } + // ----- +#set = affine_set<(d0) : (-d0 + 89 >= 0)> +#set2 = affine_set<(d0) : (-d0 + 70 >= 0)> + func.func private @if_with_multidimload_out_of_bounds(%m1: memref<100x194xf64, 1>, %m2: memref<100x194xf64, 1>, %m3: memref<100x194xf64, 1>) { + affine.parallel (%a1, %arg1) = (1, 0) to (100, 192) { + %1 = affine.if #set2(%arg1) -> f64 { + %ld = affine.load %m1[%a1 - 2, %arg1 + 2] : memref<100x194xf64, 1> + affine.yield %ld : f64 + } else { + %ld = affine.load %m2[%a1 - 2, %arg1 + 2] : memref<100x194xf64, 1> + affine.yield %ld : f64 + } + affine.store %1, %m3[%a1, %arg1] : memref<100x194xf64, 1> + } + return + } + +// CHECK-LABEL: func.func private @if_with_load( +// CHECK: affine + +// CHECK-LABEL: func.func private @if_yield_with_load_raised( +// CHECK: stablehlo + +// CHECK-LABEL: func.func private @with_load_out_of_bounds( +// CHECK: affine + +// CHECK-LABEL: func.func private @if_with_multidimload_raised( +// CHECK: stablehlo + +// CHECK-LABEL: func.func private @if_with_multidimload_out_of_bounds +// CHECK: affine