Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow raising ifs with speculatively executable loads #443

Merged
merged 7 commits into from
Mar 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

#include "src/enzyme_ad/jax/Passes/AffineUtils.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
Expand All @@ -26,6 +27,13 @@

#include "src/enzyme_ad/jax/Dialect/Ops.h"
#include "stablehlo/dialect/StablehloOps.h"
#include <isl/ctx.h>
#include <isl/ilp.h>
#include <isl/map.h>
#include <isl/set.h>
#include <isl/space.h>
#include <isl/val.h>
#include <optional>

namespace mlir {
namespace enzyme {
Expand Down Expand Up @@ -505,6 +513,44 @@ expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
llvm_unreachable("unreachable");
}

/// Decides whether `op` may be executed speculatively.
///
/// `scope` is an operation _in_ the scope we are interested in; its domain is
/// the set of points over which the access of `op` is checked.
///
/// Returns true when `op` is pure, or when `op` is an affine read whose
/// accessed indices, taken over the domain of `scope`, are all contained in
/// the (static) shape of the memref being read. Returns false in every other
/// case, including when any of the isl objects could not be built or the
/// subset query itself reports an error.
bool isSafeToSpeculativelyExecuteAtScope(Operation *scope, Operation *op) {
  // Pure operations need no further analysis.
  if (mlir::isPure(op))
    return true;

  // Apart from pure ops, only affine reads are candidates.
  auto readOp = dyn_cast<affine::AffineReadOpInterface>(op);
  if (!readOp)
    return false;
  MemRefType memrefTy = readOp.getMemRefType();
  if (!memrefTy)
    return false;

  IslAnalysis analysis;

  // Set of valid indices into the memref.
  isl_set *bounds = analysis.getMemrefShape(memrefTy);
  if (!bounds)
    return false;

  // Relation from points of the domain to accessed indices.
  isl_map *access = analysis.getAccessMap(op);
  if (!access) {
    isl_set_free(bounds);
    return false;
  }

  isl_set *scopeDomain = analysis.getDomain(scope);
  if (!scopeDomain) {
    isl_set_free(bounds);
    isl_map_free(access);
    return false;
  }

  // isl_set_apply takes ownership of both of its arguments, so only the
  // resulting set and `bounds` remain to be released below.
  isl_set *touched = isl_set_apply(scopeDomain, access);
  const bool provenInBounds =
      isl_set_is_subset(touched, bounds) == isl_bool_true;
  isl_set_free(touched);
  isl_set_free(bounds);

  // isl_bool_error and isl_bool_false both mean "not proven safe".
  return provenInBounds;
}

static LogicalResult
tryRaisingOpToStableHLO(Operation *op, IRMapping &mapping, OpBuilder &builder,
llvm::DenseMap<Value, affine::AffineValueMap> &maps) {
Expand Down Expand Up @@ -796,9 +842,12 @@ tryRaisingOpToStableHLO(Operation *op, IRMapping &mapping, OpBuilder &builder,
if (auto ifOp = dyn_cast<affine::AffineIfOp>(op)) {
if (!ifOp.hasElse() || ifOp->getNumResults() == 0 ||
llvm::any_of(*ifOp.getThenBlock(),
[](Operation &op) { return !mlir::isPure(&op); }) ||
llvm::any_of(*ifOp.getElseBlock(),
[](Operation &op) { return !mlir::isPure(&op); })) {
[ifOp](Operation &op) {
return !isSafeToSpeculativelyExecuteAtScope(ifOp, &op);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little worried about correctness here.

Specifically, consider an `if` nested inside another `if`.

In that case we want to use the scope of the outermost `if` (assuming that the outer `if` recursively calls this utility).

}) ||
llvm::any_of(*ifOp.getElseBlock(), [ifOp](Operation &op) {
return !isSafeToSpeculativelyExecuteAtScope(ifOp, &op);
})) {
LLVM_DEBUG(llvm::dbgs()
<< "cannot raise if yet (non-pure or yielded values): " << *op
<< "\n");
Expand Down
45 changes: 45 additions & 0 deletions src/enzyme_ad/jax/Passes/AffineUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#ifndef ENZYME_JAX_PASSES_AFFINEUTILS_H_
#define ENZYME_JAX_PASSES_AFFINEUTILS_H_

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "llvm/ADT/SmallVector.h"

#include <isl/aff.h>
#include <isl/set.h>

namespace mlir {

mlir::affine::AffineValueMap getAVM(mlir::Operation *op);

/// Helper that owns an isl context and builds isl objects (affine
/// expressions, access maps, domains, memref shapes) from MLIR entities.
///
/// The context is allocated by the constructor and released by the
/// destructor. The raw isl pointers returned by the member functions are not
/// tracked by this class; the caller is responsible for releasing them, and
/// must do so before the IslAnalysis object that produced them is destroyed.
class IslAnalysis {
public:
  /// Builds one isl_aff per result of the affine map held by `avm`, or
  /// returns std::nullopt when the expressions cannot be built.
  std::optional<llvm::SmallVector<isl_aff *>>
  getAffExprs(mlir::Operation *op, mlir::affine::AffineValueMap avm);

  /// Same as above, using the affine value map obtained from getAVM(op).
  std::optional<llvm::SmallVector<isl_aff *>> getAffExprs(mlir::Operation *op);

  /// Builds the access relation of `op`, or returns nullptr on failure.
  isl_map *getAccessMap(mlir::Operation *op);

  /// Returns the domain computed for `op` as an isl set.
  isl_set *getDomain(Operation *op);

  /// Returns the set of valid indices of a statically shaped memref type, or
  /// nullptr when the shape is not static.
  isl_set *getMemrefShape(MemRefType ty);

  ~IslAnalysis();
  IslAnalysis();

  // The destructor frees `ctx`, so an implicit copy would make two objects
  // free the same isl context. Forbid copying (this also disables moving).
  IslAnalysis(const IslAnalysis &) = delete;
  IslAnalysis &operator=(const IslAnalysis &) = delete;

private:
  /// isl context owned by this object.
  isl_ctx *ctx;
  DenseMap<Value, isl_id *> vToIdMap;
};

/// Scope guard that hands the wrapped object to isl_set_free when it goes
/// out of scope.
///
/// NOTE(review): the destructor always calls isl_set_free regardless of `T`,
/// so `T` is presumably meant to be `isl_set *` — confirm before wrapping any
/// other isl object type.
template <typename T> class IslScopeFree {
public:
  /// The guarded object; released by the destructor.
  T obj;
  IslScopeFree(T obj) : obj(obj) {}
  ~IslScopeFree() { isl_set_free(obj); }

  // A copy would hold the same `obj` and release it a second time on
  // destruction, so copying is not allowed.
  IslScopeFree(const IslScopeFree &) = delete;
  IslScopeFree &operator=(const IslScopeFree &) = delete;
};

} // namespace mlir

#endif // ENZYME_JAX_PASSES_AFFINEUTILS_H_
8 changes: 4 additions & 4 deletions src/enzyme_ad/jax/Passes/CanonicalizeFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ struct MoveDoWhileToFor : public OpRewritePattern<WhileOp> {
// Check to see if doBlock just has yield op
Block &doBlock = whileOp.getAfter().front();
if (!isa<scf::YieldOp>(doBlock.front()))
return failure();
return rewriter.notifyMatchFailure(whileOp, "non empty then block");

// Before block analysis
Block &beforeBlock = whileOp.getBefore().front();
Expand All @@ -1142,10 +1142,10 @@ struct MoveDoWhileToFor : public OpRewritePattern<WhileOp> {
upperBound = cmpOp.getLhs();
compareValue = cmpOp.getRhs();
} else
return failure();
return rewriter.notifyMatchFailure(whileOp, "cmp against non constant");
} else {
// Currently only supporting arith.cmpIOp
return failure();
return rewriter.notifyMatchFailure(whileOp, "cmp not arith.cmpIOp");
}

// Get condition op args and find IV index
Expand All @@ -1160,7 +1160,7 @@ struct MoveDoWhileToFor : public OpRewritePattern<WhileOp> {
IVIndex++;
}
if (!indexFound)
return failure();
return rewriter.notifyMatchFailure(whileOp, "Did not find index");

// Extract IV and lowerBound based on IVIndex
Value IV = beforeBlock.getArgument(IVIndex);
Expand Down
151 changes: 151 additions & 0 deletions src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

#include "AffineUtils.h"
#include "Passes.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
Expand All @@ -11,13 +12,15 @@
#include "src/enzyme_ad/jax/Dialect/Ops.h"

#include <isl/aff.h>
#include <isl/aff_type.h>
#include <isl/ast.h>
#include <isl/ast_build.h>
#include <isl/constraint.h>
#include <isl/ctx.h>
#include <isl/id.h>
#include <isl/local_space.h>
#include <isl/map.h>
#include <isl/map_type.h>
#include <isl/mat.h>
#include <isl/set.h>
#include <isl/space.h>
Expand Down Expand Up @@ -591,6 +594,154 @@ struct IslToAffineExprConverter {
}
};

namespace mlir {
/// Returns the affine value map (access map plus its operands, no result
/// values) of an affine load, store, vector load or vector store.
///
/// Calling this with any other kind of operation is a programming error and
/// hits llvm_unreachable.
AffineValueMap getAVM(Operation *op) {
  // All supported ops expose the same getMap()/getMapOperands() accessors,
  // so one generic helper builds the result for each of them.
  auto makeValueMap = [](auto accessOp) {
    return AffineValueMap(accessOp.getMap(), accessOp.getMapOperands(), {});
  };

  if (auto loadOp = dyn_cast<AffineLoadOp>(op))
    return makeValueMap(loadOp);
  if (auto storeOp = dyn_cast<AffineStoreOp>(op))
    return makeValueMap(storeOp);
  if (auto vectorLoadOp = dyn_cast<AffineVectorLoadOp>(op))
    return makeValueMap(vectorLoadOp);
  if (auto vectorStoreOp = dyn_cast<AffineVectorStoreOp>(op))
    return makeValueMap(vectorStoreOp);

  llvm_unreachable("Called with non affine op");
}
} // namespace mlir

/// Builds the set of valid index tuples of `ty`: for every dimension `d`,
/// `0 <= index_d < ty.getDimSize(d)`.
///
/// Returns nullptr when `ty` does not have a static shape. Otherwise the
/// caller receives ownership of the returned set.
isl_set *IslAnalysis::getMemrefShape(MemRefType ty) {
  // TODO we can support params in some cases
  if (!ty.hasStaticShape())
    return nullptr;

  const unsigned rank = ty.getRank();

  // Parameter-free set space with one dimension per memref dimension.
  isl_space *shapeSpace = isl_space_set_alloc(ctx, 0, rank);
  // identity[d] is the affine expression selecting the d-th index.
  isl_multi_aff *identity =
      isl_multi_aff_identity_on_domain_space(isl_space_copy(shapeSpace));
  // Start unconstrained and narrow the set down one dimension at a time.
  isl_set *box = isl_set_universe(isl_space_copy(shapeSpace));

  for (unsigned d = 0; d < rank; ++d) {
    isl_aff *zero = isl_aff_val_on_domain_space(isl_space_copy(shapeSpace),
                                                isl_val_int_from_si(ctx, 0));
    isl_aff *extent = isl_aff_val_on_domain_space(
        isl_space_copy(shapeSpace),
        isl_val_int_from_si(ctx, ty.getDimSize(d)));
    isl_aff *index = isl_multi_aff_get_at(identity, d);

    // index >= 0; `index` is still needed afterwards, so pass a copy.
    box = isl_set_intersect(box, isl_aff_ge_set(isl_aff_copy(index), zero));
    // index < extent; this is the last use of `index`, so ownership is handed
    // over directly instead of copying and freeing.
    box = isl_set_intersect(box, isl_aff_lt_set(index, extent));
  }

  isl_multi_aff_free(identity);
  isl_space_free(shapeSpace);

  return box;
}

/// Builds the access relation of `op`: a map from the space the affine
/// expressions of `op` are defined on to a tuple with one output dimension
/// per affine expression.
///
/// Returns nullptr when the affine expressions could not be computed or when
/// there are none. Otherwise the caller receives ownership of the returned
/// map.
isl_map *IslAnalysis::getAccessMap(mlir::Operation *op) {
  auto exprs = getAffExprs(op);
  if (!exprs)
    return nullptr;
  // The domain space below is taken from the first expression, so an empty
  // list (an affine map without results) cannot be handled here; report
  // failure instead of reading past the end of the vector.
  if (exprs->empty())
    return nullptr;

  isl_aff_list *list = isl_aff_list_alloc(ctx, exprs->size());
  isl_space *domain = isl_space_domain(isl_aff_get_space(exprs->front()));
  isl_space *range = isl_space_set_alloc(ctx, 0, exprs->size());
  isl_space *space = isl_space_map_from_domain_and_range(domain, range);
  for (isl_aff *aff : *exprs) {
    // Query the dimension on the aff itself: isl_aff_get_space would hand
    // back a new space reference that nobody releases.
    assert(isl_aff_dim(aff, isl_dim_param) == 0 &&
           "only no-parameter aff supported currently");
    // The list takes ownership of `aff`.
    list = isl_aff_list_add(list, aff);
  }
  // Both `space` and `list` are consumed here, and `maff` by the conversion.
  isl_multi_aff *maff = isl_multi_aff_from_aff_list(space, list);
  return isl_map_from_multi_aff(maff);
}

/// Translates every result expression of the affine map in `avm` into an
/// isl_aff expressed over the variables of the domain computed for `op`.
///
/// Returns std::nullopt when either `avm` or the domain constraints involve
/// symbols (not supported yet). Otherwise returns one isl_aff per map result,
/// in result order; the caller receives ownership of the returned affs.
std::optional<SmallVector<isl_aff *>>
IslAnalysis::getAffExprs(Operation *op, AffineValueMap avm) {
  LLVM_DEBUG(llvm::dbgs() << "Got domain\n");
  auto [domain, cst] = ::getDomain(ctx, op);
  LLVM_DEBUG(isl_set_dump(domain));
  LLVM_DEBUG(cst.dump());
  AffineMap map = avm.getAffineMap();

  // Match each dimension variable of the constraint system with the map
  // dimension whose operand is the same SSA value.
  LLVM_DEBUG(llvm::dbgs() << "Mapping dims:\n");
  PosMapTy dimPosMap;
  PosMapTy dimPosMapReverse;
  for (unsigned i = 0; i < cst.getNumDimVars(); i++) {
    Value cstVal = cst.getValue(i);
    LLVM_DEBUG(llvm::dbgs() << "cstVal " << cstVal << "\n");
    for (unsigned origDim = 0; origDim < map.getNumDims(); origDim++) {
      Value dim = avm.getOperand(origDim);
      LLVM_DEBUG(llvm::dbgs() << "dim " << dim << "\n");
      if (cstVal == dim) {
        LLVM_DEBUG(llvm::dbgs() << origDim << " <--> " << i << "\n");
        dimPosMap[origDim] = i;
        dimPosMapReverse[i] = origDim;
        break;
      }
    }
  }

  if (avm.getNumSymbols() != 0 || cst.getNumSymbolVars() != 0) {
    // TODO While the fact that all dims from the map _must_ appear in the cst,
    // this is not the case for symbols. We do not handle that case correctly
    // currently, thus we abort early.
    domain = isl_set_free(domain);
    return std::nullopt;
  }

  // Same matching for symbols; map symbol operands follow the dim operands.
  // With the early return above this loop currently has nothing to iterate.
  LLVM_DEBUG(llvm::dbgs() << "Mapping syms:\n");
  PosMapTy symPosMap;
  PosMapTy symPosMapReverse;
  for (unsigned i = 0; i < cst.getNumSymbolVars(); i++) {
    for (unsigned origSym = 0; origSym < map.getNumSymbols(); origSym++) {
      Value dim = avm.getOperand(origSym + map.getNumDims());
      if (cst.getValue(i + cst.getNumDimVars()) == dim) {
        LLVM_DEBUG(llvm::dbgs() << origSym << " <--> " << i << "\n");
        symPosMap[origSym] = i;
        symPosMapReverse[i] = origSym;
        break;
      }
    }
  }

  // Space with one parameter per symbol variable and one set dimension per
  // dimension variable; every position gets an id carrying its 1-based index.
  isl_space *space =
      isl_space_set_alloc(ctx, cst.getNumSymbolVars(), cst.getNumDimVars());
  for (unsigned i = 0; i < cst.getNumDimVars(); i++) {
    isl_id *id = isl_id_alloc(ctx, "dim", (void *)(size_t)(i + 1));
    space = isl_space_set_dim_id(space, isl_dim_set, i, id);
  }
  unsigned symOffset = cst.getNumDimVars();
  for (unsigned i = 0; i < cst.getNumSymbolVars(); i++) {
    isl_id *id = isl_id_alloc(ctx, "sym", (void *)(size_t)(symOffset + i + 1));
    // Symbols are the parameters of the space; naming them through
    // isl_dim_set would overwrite the ids of the set dimensions instead.
    space = isl_space_set_dim_id(space, isl_dim_param, i, id);
  }

  isl_local_space *ls = isl_local_space_from_space(isl_space_copy(space));
  space = isl_space_free(space);
  AffineExprToIslAffConverter m2i{dimPosMap, symPosMap, ls, ctx};
  SmallVector<isl_aff *> affVec;
  for (unsigned i = 0; i < map.getNumResults(); i++) {
    AffineExpr mlirExpr = map.getResult(i);
    LLVM_DEBUG(llvm::dbgs() << "Handling AffineExpr\n" << mlirExpr << "\n");
    LLVM_DEBUG(llvm::dbgs() << "Got aff\n");
    isl_aff *aff = m2i.getIslAff(mlirExpr);
    affVec.push_back(aff);
  }
  // The domain was only needed for its variable ordering (and debug output).
  ls = isl_local_space_free(ls);
  domain = isl_set_free(domain);

  return affVec;
}

/// Returns the isl set computed for `op` by the file-level getDomain helper,
/// using this object's isl context. The caller receives the set as returned
/// by that helper.
isl_set *IslAnalysis::getDomain(Operation *op) {
  // The helper yields the isl set together with a constraint system; only
  // the set is of interest here, the constraint system is simply dropped.
  auto domainAndConstraints = ::getDomain(ctx, op);
  auto &[iterationSet, constraints] = domainAndConstraints;

  return iterationSet;
}

/// Convenience overload: derives the affine value map from `op` itself via
/// getAVM and forwards to the two-argument overload.
std::optional<SmallVector<isl_aff *>> IslAnalysis::getAffExprs(Operation *op) {
  AffineValueMap accessValueMap = getAVM(op);
  return getAffExprs(op, std::move(accessValueMap));
}

/// Allocates the isl context used by every query made through this object.
IslAnalysis::IslAnalysis() : ctx(isl_ctx_alloc()) {}

/// Releases the isl context allocated by the constructor.
IslAnalysis::~IslAnalysis() {
  isl_ctx_free(ctx);
}

template <typename T> void handleAffineOp(isl_ctx *ctx, T load) {
LLVM_DEBUG(llvm::dbgs() << "Got domain\n");
auto [domain, cst] = getDomain(ctx, load);
Expand Down
1 change: 1 addition & 0 deletions test/lit_tests/raising/affine_to_stablehlo11.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(raise-affine-to-stablehlo,enzyme-hlo-opt{max_constant_expansion=0})" | FileCheck %s

#set = affine_set<(d0) : (-d0 + 89 >= 0)>
#set2 = affine_set<(d0) : (-d0 + 70 >= 0)>

module {
func.func private @call__Z31gpu__fill_south_and_north_halo(%arg0: memref<194xf64, 1>) {
Expand Down
Loading