diff --git a/src/enzyme_ad/jax/Passes/AffineCFG.cpp b/src/enzyme_ad/jax/Passes/AffineCFG.cpp index 5934c915f..02876c4d2 100644 --- a/src/enzyme_ad/jax/Passes/AffineCFG.cpp +++ b/src/enzyme_ad/jax/Passes/AffineCFG.cpp @@ -1933,14 +1933,1112 @@ struct ParallelOpRaising : public OpRewritePattern { } }; +static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, + Region ®ion, ValueRange blockArgs = {}) { + assert(llvm::hasSingleElement(region) && "expected single-region block"); + Block *block = ®ion.front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.inlineBlockBefore(block, op, blockArgs); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); +} + +struct AffineIfSimplification : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineIfOp op, + PatternRewriter &rewriter) const override { + SmallVector todo; + SmallVector eqFlags; + bool knownFalse = false; + bool removed = false; + for (auto cst : llvm::enumerate(op.getIntegerSet().getConstraints())) { + auto opd = cst.value().dyn_cast(); + if (!opd) { + if (op.getIntegerSet().isEq(cst.index())) { + if (auto bop = cst.value().dyn_cast()) { + if (bop.getKind() == AffineExprKind::Mul && + bop.getRHS().getKind() == AffineExprKind::Constant) { + removed = true; + if (bop.getRHS().cast().getValue() != 0) { + todo.push_back(bop.getLHS()); + eqFlags.push_back(op.getIntegerSet().isEq(cst.index())); + } + continue; + } + if (bop.getKind() == AffineExprKind::Add && + valueCmp(Cmp::GE, bop, op.getIntegerSet().getNumDims(), + op.getOperands(), 0)) { + todo.push_back(bop.getLHS()); + eqFlags.push_back(op.getIntegerSet().isEq(cst.index())); + todo.push_back(bop.getRHS()); + eqFlags.push_back(op.getIntegerSet().isEq(cst.index())); + removed = true; + continue; + } + } + } + + bool canRemove = false; + for (auto paren = op->getParentOfType(); paren; + paren = paren->getParentOfType()) { + for (auto cst2 : paren.getIntegerSet().getConstraints()) { + if (paren.getElseRegion().isAncestor(op->getParentRegion())) + continue; + if (cst2 == cst.value() && + paren.getIntegerSet().getNumDims() == + op.getIntegerSet().getNumDims() && + paren.getIntegerSet().getNumSymbols() == + op.getIntegerSet().getNumSymbols() && + llvm::all_of(llvm::zip(paren.getOperands(), op.getOperands()), + [](std::tuple p) { + return std::get<0>(p) == std::get<1>(p); + })) { + canRemove = true; + break; + } + } + if (canRemove) + break; + } + //// expr -1 >= 0 => expr > 0 + if (!op.getIntegerSet().isEq(cst.index())) { + auto expr = cst.value() + 1; + for (auto paren = op->getParentOfType(); + paren; + paren = paren->getParentOfType()) { + if (canRemove) + break; + for (auto tup : llvm::enumerate(paren.getSteps())) { + bool found = false; + for (auto ub : paren.getUpperBoundMap(tup.index()).getResults()) { + if (auto exprS = expr.dyn_cast()) { + if (auto ubS = ub.dyn_cast()) { + if (op.getOperands()[exprS.getPosition() + + op.getIntegerSet().getNumDims()] == + paren.getUpperBoundsOperands()[ubS.getPosition() + + paren.getUpperBoundsMap() + .getNumDims()]) { + + found = true; + break; + } + } + } + } + if (!found) + continue; + + if (!valueCmp(Cmp::GE, paren.getIVs()[tup.index()], 0)) + continue; + + canRemove = true; + break; + } + } + if (auto bop = cst.value().dyn_cast()) { + if (bop.getKind() == AffineExprKind::Add) { + } + } + } + if (canRemove) { + removed = true; + continue; + } + + todo.push_back(cst.value()); + eqFlags.push_back(op.getIntegerSet().isEq(cst.index())); + continue; + } + removed = true; + + if (op.getIntegerSet().isEq(cst.index())) { + if (opd.getValue() != 0) { + knownFalse = true; + break; + } + } + if (!(opd.getValue() >= 0)) { + knownFalse = true; + break; + } + } + + if (knownFalse) { + todo.clear(); + } + + if (todo.size() == 0) { + + if (!knownFalse) + replaceOpWithRegion(rewriter, op, op.getThenRegion()); + else if (!op.getElseRegion().empty()) + replaceOpWithRegion(rewriter, op, op.getElseRegion()); + else + rewriter.eraseOp(op); + + return success(); + } + + if (!removed) + return failure(); + + auto iset = + IntegerSet::get(op.getIntegerSet().getNumDims(), + op.getIntegerSet().getNumSymbols(), todo, eqFlags); + + auto newIf = rewriter.create( + op.getLoc(), op.getResultTypes(), iset, op.getOperands(), + /*hasElse*/ true); + rewriter.eraseBlock(newIf.getThenBlock()); + rewriter.eraseBlock(newIf.getElseBlock()); + rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), + newIf.getThenRegion().begin()); + rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), + newIf.getElseRegion().begin()); + rewriter.replaceOp(op, newIf.getResults()); + return success(); + } +}; + +struct CombineAffineIfs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineIfOp nextIf, + PatternRewriter &rewriter) const override { + Block *parent = nextIf->getBlock(); + if (nextIf == &parent->front()) + return failure(); + + auto prevIf = dyn_cast(nextIf->getPrevNode()); + if (!prevIf) + return failure(); + + // Determine the logical then/else blocks when prevIf's + // condition is used. Null means the block does not exist + // in that case (e.g. empty else). If neither of these + // are set, the two conditions cannot be compared. + Block *nextThen = nullptr; + Block *nextElse = nullptr; + + if (nextIf.getIntegerSet() == prevIf.getIntegerSet() && + llvm::all_of(llvm::zip(nextIf.getOperands(), prevIf.getOperands()), + [](std::tuple p) { + return std::get<0>(p) == std::get<1>(p); + })) { + nextThen = nextIf.getThenBlock(); + if (!nextIf.getElseRegion().empty()) + nextElse = nextIf.getElseBlock(); + } + + if (!nextThen && !nextElse) + return failure(); + + SmallVector prevElseYielded; + if (!prevIf.getElseRegion().empty()) + prevElseYielded = + cast(prevIf.getElseBlock()->getTerminator()) + .getOperands(); + // Replace all uses of return values of op within nextIf with the + // corresponding yields + for (auto it : llvm::zip( + prevIf.getResults(), + cast(prevIf.getThenBlock()->getTerminator()) + .getOperands(), + prevElseYielded)) + for (OpOperand &use : + llvm::make_early_inc_range(std::get<0>(it).getUses())) { + if (nextThen && nextThen->getParent()->isAncestor( + use.getOwner()->getParentRegion())) { + rewriter.startOpModification(use.getOwner()); + use.set(std::get<1>(it)); + rewriter.finalizeOpModification(use.getOwner()); + } else if (nextElse && nextElse->getParent()->isAncestor( + use.getOwner()->getParentRegion())) { + rewriter.startOpModification(use.getOwner()); + use.set(std::get<2>(it)); + rewriter.finalizeOpModification(use.getOwner()); + } + } + + SmallVector mergedTypes(prevIf.getResultTypes()); + llvm::append_range(mergedTypes, nextIf.getResultTypes()); + + affine::AffineIfOp combinedIf = rewriter.create( + nextIf.getLoc(), mergedTypes, prevIf.getIntegerSet(), + prevIf.getOperands(), /*hasElse=*/true); + rewriter.eraseBlock(&combinedIf.getThenRegion().back()); + rewriter.eraseBlock(&combinedIf.getElseRegion().back()); + + rewriter.inlineRegionBefore(prevIf.getThenRegion(), + combinedIf.getThenRegion(), + combinedIf.getThenRegion().begin()); + + if (nextThen) { + affine::AffineYieldOp thenYield = cast( + combinedIf.getThenBlock()->getTerminator()); + affine::AffineYieldOp thenYield2 = + cast(nextThen->getTerminator()); + rewriter.mergeBlocks(nextThen, combinedIf.getThenBlock()); + rewriter.setInsertionPointToEnd(combinedIf.getThenBlock()); + + SmallVector mergedYields(thenYield.getOperands()); + llvm::append_range(mergedYields, thenYield2.getOperands()); + rewriter.create(thenYield2.getLoc(), mergedYields); + rewriter.eraseOp(thenYield); + rewriter.eraseOp(thenYield2); + } + + rewriter.inlineRegionBefore(prevIf.getElseRegion(), + combinedIf.getElseRegion(), + combinedIf.getElseRegion().begin()); + + if (nextElse) { + if (combinedIf.getElseRegion().empty()) { + rewriter.inlineRegionBefore(*nextElse->getParent(), + combinedIf.getElseRegion(), + combinedIf.getElseRegion().begin()); + } else { + affine::AffineYieldOp elseYield = cast( + combinedIf.getElseBlock()->getTerminator()); + affine::AffineYieldOp elseYield2 = + cast(nextElse->getTerminator()); + rewriter.mergeBlocks(nextElse, combinedIf.getElseBlock()); + + rewriter.setInsertionPointToEnd(combinedIf.getElseBlock()); + + SmallVector mergedElseYields(elseYield.getOperands()); + llvm::append_range(mergedElseYields, elseYield2.getOperands()); + + rewriter.create(elseYield2.getLoc(), + mergedElseYields); + rewriter.eraseOp(elseYield); + rewriter.eraseOp(elseYield2); + } + } + + SmallVector prevValues; + SmallVector nextValues; + for (const auto &pair : llvm::enumerate(combinedIf.getResults())) { + if (pair.index() < prevIf.getNumResults()) + prevValues.push_back(pair.value()); + else + nextValues.push_back(pair.value()); + } + rewriter.replaceOp(prevIf, prevValues); + rewriter.replaceOp(nextIf, nextValues); + return success(); + } +}; + +struct MergeNestedAffineParallelLoops + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineParallelOp op, + PatternRewriter &rewriter) const override { + Block &outerBody = op.getRegion().getBlocks().front(); + if (!llvm::hasSingleElement(outerBody.without_terminator())) + return failure(); + + auto innerOp = dyn_cast(outerBody.front()); + if (!innerOp) + return failure(); + + for (auto val : outerBody.getArguments()) + if (llvm::is_contained(innerOp.getLowerBoundsOperands(), val) || + llvm::is_contained(innerOp.getUpperBoundsOperands(), val)) + return failure(); + + // Reductions are not supported yet. + if (!op.getReductions().empty() || !innerOp.getReductions().empty()) + return failure(); + + SmallVector newTypes(op.getResultTypes()); + for (auto T : innerOp.getResultTypes()) + newTypes.push_back(T); + + ArrayRef reductions; + SmallVector lbounds; + SmallVector ubounds; + SmallVector lboundValues; + SmallVector uboundValues; + + for (size_t i = 0; i < op.getLowerBoundsMap().getNumDims(); i++) + lboundValues.push_back(op.getLowerBoundsOperands()[i]); + + for (size_t i = 0; i < op.getUpperBoundsMap().getNumDims(); i++) + uboundValues.push_back(op.getUpperBoundsOperands()[i]); + + for (size_t i = 0; i < innerOp.getLowerBoundsMap().getNumDims(); i++) + lboundValues.push_back(innerOp.getLowerBoundsOperands()[i]); + + for (size_t i = 0; i < innerOp.getUpperBoundsMap().getNumDims(); i++) + uboundValues.push_back(innerOp.getUpperBoundsOperands()[i]); + + for (size_t i = 0; i < op.getLowerBoundsMap().getNumSymbols(); i++) + lboundValues.push_back( + op.getLowerBoundsOperands()[i + op.getLowerBoundsMap().getNumDims()]); + + for (size_t i = 0; i < op.getUpperBoundsMap().getNumSymbols(); i++) + uboundValues.push_back( + op.getUpperBoundsOperands()[i + op.getUpperBoundsMap().getNumDims()]); + + for (size_t i = 0; i < innerOp.getLowerBoundsMap().getNumSymbols(); i++) + lboundValues.push_back( + innerOp.getLowerBoundsOperands()[i + innerOp.getLowerBoundsMap() + .getNumDims()]); + + for (size_t i = 0; i < innerOp.getUpperBoundsMap().getNumSymbols(); i++) + uboundValues.push_back( + innerOp.getUpperBoundsOperands()[i + innerOp.getUpperBoundsMap() + .getNumDims()]); + + for (auto e : op.getLowerBoundsMap().getResults()) { + lbounds.push_back(e); + } + + for (auto e : op.getUpperBoundsMap().getResults()) { + ubounds.push_back(e); + } + + for (auto e : innerOp.getLowerBoundsMap() + .shiftDims(op.getLowerBoundsMap().getNumDims()) + .shiftSymbols(op.getLowerBoundsMap().getNumSymbols()) + .getResults()) { + lbounds.push_back(e); + } + + for (auto e : innerOp.getUpperBoundsMap() + .shiftDims(op.getUpperBoundsMap().getNumDims()) + .shiftSymbols(op.getUpperBoundsMap().getNumSymbols()) + .getResults()) { + ubounds.push_back(e); + } + + SmallVector operands = lboundValues; + operands.append(uboundValues); + + SmallVector lboundGroup; + SmallVector uboundGroup; + for (auto U : op.getLowerBoundsGroups()) + lboundGroup.push_back(U.getZExtValue()); + for (auto U : innerOp.getLowerBoundsGroups()) + lboundGroup.push_back(U.getZExtValue()); + for (auto U : op.getUpperBoundsGroups()) + uboundGroup.push_back(U.getZExtValue()); + for (auto U : innerOp.getUpperBoundsGroups()) + uboundGroup.push_back(U.getZExtValue()); + + SmallVector steps; + for (auto U : op.getSteps()) + steps.push_back(U); + for (auto U : innerOp.getSteps()) + steps.push_back(U); + + affine::AffineParallelOp affineLoop = + rewriter.create( + op.getLoc(), newTypes, rewriter.getArrayAttr(reductions), + AffineMapAttr::get( + AffineMap::get(op.getLowerBoundsMap().getNumDims() + + innerOp.getLowerBoundsMap().getNumDims(), + op.getLowerBoundsMap().getNumSymbols() + + innerOp.getLowerBoundsMap().getNumSymbols(), + lbounds, op.getContext())), + rewriter.getI32TensorAttr(lboundGroup), + AffineMapAttr::get( + AffineMap::get(op.getUpperBoundsMap().getNumDims() + + innerOp.getUpperBoundsMap().getNumDims(), + op.getUpperBoundsMap().getNumSymbols() + + innerOp.getUpperBoundsMap().getNumSymbols(), + ubounds, op.getContext())), + rewriter.getI32TensorAttr(uboundGroup), + rewriter.getI64ArrayAttr(steps), operands); + + rewriter.inlineRegionBefore(op.getRegion(), affineLoop.getRegion(), + affineLoop.getRegion().begin()); + auto yld = affineLoop.getBody()->getTerminator(); + rewriter.eraseOp(innerOp.getBody()->getTerminator()); + SmallVector post; + for (auto v : innerOp.getIVs()) { + post.push_back( + affineLoop.getBody()->addArgument(v.getType(), v.getLoc())); + } + rewriter.inlineBlockBefore(innerOp.getBody(), yld, post); + return success(); + } +}; + +struct PrepMergeNestedAffineParallelLoops + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineParallelOp oop, + PatternRewriter &rewriter) const override { + Block &outerBody = oop.getRegion().getBlocks().front(); + affine::AffineParallelOp innerOp = nullptr; + SmallVector toMove; + for (auto &op : outerBody) { + if (auto innerOp2 = dyn_cast(&op)) { + if (innerOp) + return failure(); + if (!isa(innerOp2->getNextNode())) { + return failure(); + } + innerOp = innerOp2; + continue; + } + if (isMemoryEffectFree(&op)) { + if (!isa(&op)) + toMove.push_back(&op); + continue; + } + + return failure(); + } + + if (!innerOp || !toMove.size()) { + return failure(); + } + + IRMapping map; + rewriter.setInsertionPointToStart(innerOp.getBody()); + for (auto o : toMove) { + rewriter.replaceOp(o, rewriter.clone(*o)->getResults()); + } + return success(); + } +}; + +struct MergeNestedAffineParallelIf + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineParallelOp op, + PatternRewriter &rewriter) const override { + Block &outerBody = op.getRegion().getBlocks().front(); + + affine::AffineIfOp innerOp = nullptr; + for (auto &op : outerBody) { + if (auto innerOp2 = dyn_cast(&op)) { + if (innerOp) + return failure(); + if (!isa(innerOp2->getNextNode())) { + return failure(); + } + innerOp = innerOp2; + continue; + } + if (!isReadOnly(&op)) + return failure(); + } + + if (!innerOp) + return failure(); + + // Reductions are not supported yet. + if (!op.getReductions().empty()) + return failure(); + + if (innerOp.hasElse()) + return failure(); + + SmallVector lboundGroup; + SmallVector uboundGroup; + for (auto U : op.getLowerBoundsGroups()) + lboundGroup.push_back(U.getZExtValue()); + for (auto U : op.getUpperBoundsGroups()) + uboundGroup.push_back(U.getZExtValue()); + + SmallVector lbounds; + SmallVector ubounds; + + for (auto e : op.getLowerBoundsMap().getResults()) { + lbounds.push_back(e); + } + + for (auto e : op.getUpperBoundsMap().getResults()) { + ubounds.push_back(e); + } + + bool changed = false; + SmallVector remaining; + SmallVector isEq; + for (auto cst : llvm::enumerate(innerOp.getIntegerSet().getConstraints())) { + if (innerOp.getIntegerSet().isEq(cst.index())) { + remaining.push_back(cst.value()); + isEq.push_back(innerOp.getIntegerSet().isEq(cst.index())); + continue; + } + + auto getIndUsage = [&op](AffineExpr cst, ValueRange operands, + std::map &indUsage, + bool &legal, + bool *failure = nullptr) -> AffineExpr { + AffineExpr rhs = getAffineConstantExpr(0, cst.getContext()); + SmallVector todo = {cst}; + legal = true; + while (todo.size()) { + auto cur = todo.back(); + todo.pop_back(); + if (cur.isa() || cur.isa()) { + rhs = rhs + cur; + continue; + } + if (auto dim = cur.dyn_cast()) { + auto ival = dyn_cast(operands[dim.getPosition()]); + if (!ival || ival.getOwner()->getParentOp() != op) { + rhs = rhs + dim; + if (failure) + *failure = true; + continue; + } + if (indUsage.find(ival.getArgNumber()) != indUsage.end()) { + legal = false; + continue; + } + indUsage[ival.getArgNumber()] = + getAffineConstantExpr(1, op.getContext()); + continue; + } + if (auto bop = cur.dyn_cast()) { + if (bop.getKind() == AffineExprKind::Add) { + todo.push_back(bop.getLHS()); + todo.push_back(bop.getRHS()); + continue; + } + if (bop.getKind() == AffineExprKind::Mul) { + if (!(bop.getRHS().isa() || + bop.getRHS().isa())) { + legal = false; + continue; + } + + if (auto dim = bop.getLHS().dyn_cast()) { + auto ival = + dyn_cast(operands[dim.getPosition()]); + if (!ival || ival.getOwner()->getParentOp() != op) { + rhs = rhs + bop; + // While legal, this may run before parallel merging + // and prevent parallel fusion + legal = false; + if (failure) + *failure = true; + continue; + } + if (indUsage.find(ival.getArgNumber()) != indUsage.end()) { + legal = false; + continue; + } + indUsage[ival.getArgNumber()] = bop.getRHS(); + continue; + } + } + } + if (failure) + *failure = true; + legal = false; + break; + } + return rhs; + }; + + bool legal; + std::map indUsage; + bool failureV = false; + AffineExpr rhs = getIndUsage(cst.value(), innerOp.getOperands(), indUsage, + legal, &failureV); + if (failureV) + return failure(); + + if (!legal || indUsage.size() != 1) { + remaining.push_back(cst.value()); + isEq.push_back(innerOp.getIntegerSet().isEq(cst.index())); + continue; + } + auto pair = *indUsage.begin(); + auto affCst = pair.second.dyn_cast(); + if (!affCst) { + remaining.push_back(cst.value()); + isEq.push_back(innerOp.getIntegerSet().isEq(cst.index())); + continue; + } + + // currently aff * idx + stuff >= 0 + // currently aff * idx >= -stuff + // idx >= (-stuff).floorDiv(aff) OR idx <= ... + + if (affCst.getValue() < 0) + rhs = rhs.floorDiv(-affCst.getValue()) + 1; + else { + remaining.push_back(cst.value()); + isEq.push_back(innerOp.getIntegerSet().isEq(cst.index())); + continue; + } + + changed = true; + + size_t off = 0; + for (size_t i = 0; i < pair.first; i++) + off += uboundGroup[i]; + + if (auto newCst = rhs.dyn_cast()) { + bool seen = false; + for (size_t i = 0; i < uboundGroup[pair.first]; i++) { + if (auto oldCst = ubounds[i].dyn_cast()) { + seen = true; + if (newCst.getValue() < oldCst.getValue()) + ubounds[i] = rhs; + } + } + if (seen) + continue; + } + ubounds.insert(ubounds.begin() + off, + rhs.shiftDims(innerOp.getIntegerSet().getNumDims(), + op.getUpperBoundsMap().getNumDims()) + .shiftSymbols(innerOp.getIntegerSet().getNumSymbols(), + op.getUpperBoundsMap().getNumSymbols())); + + uboundGroup[pair.first]++; + } + + if (!changed) + return failure(); + + SmallVector lboundValues; + SmallVector uboundValues; + + for (size_t i = 0; i < op.getLowerBoundsMap().getNumDims(); i++) + lboundValues.push_back(op.getLowerBoundsOperands()[i]); + + for (size_t i = 0; i < op.getUpperBoundsMap().getNumDims(); i++) + uboundValues.push_back(op.getUpperBoundsOperands()[i]); + + for (size_t i = 0; i < innerOp.getIntegerSet().getNumDims(); i++) + uboundValues.push_back(innerOp.getOperands()[i]); + + for (size_t i = 0; i < op.getLowerBoundsMap().getNumSymbols(); i++) + lboundValues.push_back( + op.getLowerBoundsOperands()[i + op.getLowerBoundsMap().getNumDims()]); + + for (size_t i = 0; i < op.getUpperBoundsMap().getNumSymbols(); i++) + uboundValues.push_back( + op.getUpperBoundsOperands()[i + op.getUpperBoundsMap().getNumDims()]); + + for (size_t i = 0; i < innerOp.getIntegerSet().getNumSymbols(); i++) + uboundValues.push_back( + innerOp.getOperands()[i + innerOp.getIntegerSet().getNumDims()]); + + SmallVector operands = lboundValues; + operands.append(uboundValues); + + ArrayRef reductions; + + affine::AffineParallelOp affineLoop = + rewriter.create( + op.getLoc(), op.getResultTypes(), rewriter.getArrayAttr(reductions), + AffineMapAttr::get( + AffineMap::get(op.getLowerBoundsMap().getNumDims(), + op.getLowerBoundsMap().getNumSymbols(), lbounds, + op.getContext())), + rewriter.getI32TensorAttr(lboundGroup), + AffineMapAttr::get( + AffineMap::get(op.getUpperBoundsMap().getNumDims() + + innerOp.getIntegerSet().getNumDims(), + op.getUpperBoundsMap().getNumSymbols() + + innerOp.getIntegerSet().getNumSymbols(), + ubounds, op.getContext())), + rewriter.getI32TensorAttr(uboundGroup), op.getStepsAttr(), + operands); + + rewriter.inlineRegionBefore(op.getRegion(), affineLoop.getRegion(), + affineLoop.getRegion().begin()); + + rewriter.setInsertionPoint(innerOp); + + if (remaining.empty()) { + auto yld = + cast(innerOp.getThenBlock()->getTerminator()); + SmallVector toRet(yld.getOperands()); + rewriter.eraseOp(yld); + rewriter.inlineBlockBefore(innerOp.getThenBlock(), innerOp); + rewriter.replaceOp(innerOp, toRet); + } else { + affine::AffineIfOp newIf = rewriter.create( + innerOp.getLoc(), innerOp.getResultTypes(), + IntegerSet::get(innerOp.getIntegerSet().getNumDims(), + innerOp.getIntegerSet().getNumSymbols(), remaining, + isEq), + innerOp.getOperands(), /*hasElse*/ false); + + rewriter.eraseBlock(newIf.getThenBlock()); + + rewriter.inlineRegionBefore(innerOp.getThenRegion(), + newIf.getThenRegion(), + newIf.getThenRegion().begin()); + rewriter.inlineRegionBefore(innerOp.getElseRegion(), + newIf.getElseRegion(), + newIf.getElseRegion().begin()); + + rewriter.replaceOp(innerOp, newIf->getResults()); + rewriter.replaceOp(op, affineLoop->getResults()); + } + return success(); + } +}; + +struct MergeParallelInductions + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineParallelOp op, + PatternRewriter &rewriter) const override { + // Reductions are not supported yet. + if (!op.getReductions().empty()) + return failure(); + + auto getIndUsage = [&op](AffineExpr cst, ValueRange operands, + std::map &indUsage, + bool &legal) -> AffineExpr { + AffineExpr rhs = getAffineConstantExpr(0, cst.getContext()); + SmallVector todo = {cst}; + legal = true; + while (todo.size()) { + auto cur = todo.back(); + todo.pop_back(); + if (cur.isa() || cur.isa()) { + rhs = rhs + cur; + continue; + } + if (auto dim = cur.dyn_cast()) { + auto ival = dyn_cast(operands[dim.getPosition()]); + if (!ival || ival.getOwner()->getParentOp() != op) { + rhs = rhs + dim; + continue; + } + if (indUsage.find(ival.getArgNumber()) != indUsage.end()) { + legal = false; + continue; + } + indUsage[ival.getArgNumber()] = + getAffineConstantExpr(1, op.getContext()); + continue; + } + if (auto bop = cur.dyn_cast()) { + if (bop.getKind() == AffineExprKind::Add) { + todo.push_back(bop.getLHS()); + todo.push_back(bop.getRHS()); + continue; + } + if (bop.getKind() == AffineExprKind::Mul) { + if (!(bop.getRHS().isa() || + bop.getRHS().isa())) { + legal = false; + continue; + } + + if (auto dim = bop.getLHS().dyn_cast()) { + auto ival = dyn_cast(operands[dim.getPosition()]); + if (!ival || ival.getOwner()->getParentOp() != op) { + rhs = rhs + bop; + continue; + } + if (indUsage.find(ival.getArgNumber()) != indUsage.end()) { + legal = false; + continue; + } + indUsage[ival.getArgNumber()] = bop.getRHS(); + continue; + } + } + } + legal = false; + break; + } + return rhs; + }; + + // TODO check all users are affine sums like this. + std::map idxCasts; + SetVector affineMapUsers; + SmallVector legality; + SmallVector fixedUpperBounds; + for (auto iv : op.getIVs()) { + bool legal = true; + Operation *idxCst = nullptr; + + for (auto lb : op.getLowerBoundMap(iv.getArgNumber()).getResults()) { + if (auto cst = lb.dyn_cast()) { + if (cst.getValue() != 0) { + legal = false; + break; + } + } else { + legal = false; + break; + } + } + bool seenub = false; + for (auto ub : op.getUpperBoundMap(iv.getArgNumber()).getResults()) { + if (seenub) { + legal = false; + break; + } + seenub = true; + if (auto cst = ub.dyn_cast()) { + fixedUpperBounds.push_back(ValueOrInt(cst.getValue())); + } else if (auto dim = ub.dyn_cast()) { + fixedUpperBounds.push_back( + ValueOrInt(op.getUpperBoundsOperands()[dim.getPosition()])); + } else if (auto sym = ub.dyn_cast()) { + fixedUpperBounds.push_back(ValueOrInt( + op.getUpperBoundsOperands()[op.getUpperBoundsMap().getNumDims() + + sym.getPosition()])); + } else { + legal = false; + fixedUpperBounds.push_back(ValueOrInt(0)); + } + } + + SmallVector affineMapUsers_t; + for (auto U : iv.getUsers()) { + SmallVector exprs; + ValueRange operands; + if (auto AL = dyn_cast(U)) { + for (auto E : AL.getAffineMap().getResults()) + exprs.push_back(E); + operands = AL.getMapOperands(); + affineMapUsers_t.push_back(U); + } else if (auto AS = dyn_cast(U)) { + if (AS.getValue() == iv) + legal = false; + for (auto E : AS.getAffineMap().getResults()) + exprs.push_back(E); + operands = AS.getMapOperands(); + affineMapUsers_t.push_back(U); + } else if (auto AI = dyn_cast(U)) { + for (auto E : AI.getIntegerSet().getConstraints()) + exprs.push_back(E); + operands = AI.getOperands(); + affineMapUsers_t.push_back(U); + } else if (auto idx = dyn_cast(U)) { + if (idxCst) { + legal = false; + break; + } else + idxCst = idx; + } else if (auto idx = dyn_cast(U)) { + if (idxCst) { + legal = false; + break; + } else + idxCst = idx; + } else { + legal = false; + break; + } + for (auto expr : exprs) { + bool flegal = true; + std::map indUsage; + getIndUsage(expr, operands, indUsage, flegal); + if (!flegal || indUsage.size() < 2) { + legal = false; + break; + } + } + } + legality.push_back(legal); + if (legal) { + for (auto o : affineMapUsers_t) { + affineMapUsers.insert(o); + } + if (idxCst) + idxCasts[iv.getArgNumber()] = idxCst; + } + } + for (auto tup : llvm::zip(op.getIVs(), legality)) { + if (!std::get<1>(tup)) + for (auto U : std::get<0>(tup).getUsers()) + if (affineMapUsers.count(U)) + affineMapUsers.remove(U); + } + for (auto U : affineMapUsers) { + SmallVector exprs; + ValueRange operands; + size_t numDim; + if (auto AL = dyn_cast(U)) { + for (auto E : AL.getAffineMap().getResults()) + exprs.push_back(E); + operands = AL.getMapOperands(); + numDim = AL.getAffineMap().getNumDims(); + } else if (auto AS = dyn_cast(U)) { + for (auto E : AS.getAffineMap().getResults()) + exprs.push_back(E); + operands = AS.getMapOperands(); + numDim = AS.getAffineMap().getNumDims(); + } else if (auto AI = dyn_cast(U)) { + for (auto E : AI.getIntegerSet().getConstraints()) + exprs.push_back(E); + operands = AI.getOperands(); + numDim = AI.getIntegerSet().getNumDims(); + } else { + llvm_unreachable("Unknown affine use type"); + } + + for (auto expr : exprs) { + bool flegal; + std::map indUsage; + getIndUsage(expr, operands, indUsage, flegal); + + for (auto pair1 : indUsage) { + for (auto pair2 : indUsage) { + if (pair1.first == pair2.first) + continue; + if (auto cst = pair1.second.dyn_cast()) { + if (cst.getValue() == -1) { + pair2.second = -pair2.second; + pair1.second = -pair1.second; + } else if (cst.getValue() != 1) + continue; + } else + continue; + + if (!valueCmp(Cmp::EQ, pair2.second, numDim, operands, + fixedUpperBounds[pair1.first])) + continue; + + if (idxCasts.count(pair1.first) != idxCasts.count(pair2.first)) + continue; + + bool legalPair = true; + for (auto U : affineMapUsers) { + if (!legalPair) + break; + SmallVector exprs; + ValueRange operands; + if (auto AL = dyn_cast(U)) { + for (auto E : AL.getAffineMap().getResults()) + exprs.push_back(E); + operands = AL.getMapOperands(); + } else if (auto AS = dyn_cast(U)) { + for (auto E : AS.getAffineMap().getResults()) + exprs.push_back(E); + operands = AS.getMapOperands(); + } else if (auto AI = dyn_cast(U)) { + for (auto E : AI.getIntegerSet().getConstraints()) + exprs.push_back(E); + operands = AI.getOperands(); + } else { + llvm_unreachable("Unknown affine use type"); + } + + for (auto expr : exprs) { + if (!legalPair) + break; + bool sublegal; + std::map subIndUsage; + getIndUsage(expr, operands, subIndUsage, sublegal); + auto find1 = subIndUsage.find(pair1.first); + auto find2 = subIndUsage.find(pair2.first); + + if (find1 == subIndUsage.end() && find2 == subIndUsage.end()) + continue; + if (find1 == subIndUsage.end() || find2 == subIndUsage.end()) { + legalPair = false; + break; + } + if (find1->second * pair2.second != + find2->second * pair1.second) { + legalPair = false; + break; + } + } + } + + if (idxCasts.count(pair1.first)) { + Value val = idxCasts[pair1.first]->getResult(0); + if (!val.hasOneUse()) + continue; + AddIOp add = dyn_cast(*val.user_begin()); + if (!add) + continue; + Value other = (add.getLhs() == val) ? add.getRhs() : add.getLhs(); + + MulIOp mul = other.getDefiningOp(); + if (mul.getLhs() == idxCasts[pair2.first]->getResult(0)) { + if (!valueCmp(Cmp::EQ, mul.getRhs(), + fixedUpperBounds[pair1.first])) + continue; + } else { + if (mul.getRhs() != idxCasts[pair2.first]->getResult(0)) + continue; + if (!valueCmp(Cmp::EQ, mul.getLhs(), + fixedUpperBounds[pair1.first])) + continue; + } + if (!mul->getResult(0).hasOneUse()) + continue; + if (!idxCasts[pair2.first]->getResult(0).hasOneUse()) + continue; + } + + SmallVector uboundGroup; + for (auto U : op.getUpperBoundsGroups()) + uboundGroup.push_back(U.getZExtValue()); + + SmallVector ubounds; + + for (auto e : op.getUpperBoundsMap().getResults()) { + ubounds.push_back(e); + } + + size_t off1 = 0; + for (size_t i = 0; i < pair1.first; i++) + off1 += uboundGroup[i]; + size_t off2 = 0; + for (size_t i = 0; i < pair2.first; i++) + off2 += uboundGroup[i]; + + ubounds[off1] = ubounds[off1] * ubounds[off2]; + ubounds[off2] = getAffineConstantExpr(1, op.getContext()); + + affine::AffineParallelOp affineLoop = + rewriter.create( + op.getLoc(), op.getResultTypes(), op.getReductionsAttr(), + op.getLowerBoundsMapAttr(), op.getLowerBoundsGroupsAttr(), + AffineMapAttr::get( + AffineMap::get(op.getUpperBoundsMap().getNumDims(), + op.getUpperBoundsMap().getNumSymbols(), + ubounds, op.getContext())), + op.getUpperBoundsGroupsAttr(), op.getStepsAttr(), + op.getOperands()); + + rewriter.inlineRegionBefore(op.getRegion(), affineLoop.getRegion(), + affineLoop.getRegion().begin()); + return success(); + } + } + } + } + return failure(); + } +}; + void AffineCFGPass::runOnOperation() { mlir::RewritePatternSet rpl(getOperation()->getContext()); + mlir::enzyme::addSingleIter(rpl, getOperation()->getContext()); rpl.add, CanonicalizeIndexCast, /* IndexCastMovement,*/ AffineFixup, AffineFixup, CanonicalizIfBounds, MoveStoreToAffine, MoveIfToAffine, MoveLoadToAffine, + AffineIfSimplification, CombineAffineIfs, + MergeNestedAffineParallelLoops, PrepMergeNestedAffineParallelLoops, + MergeNestedAffineParallelIf, MergeParallelInductions, + CanonicalieForBounds>(getOperation()->getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(rpl), config); diff --git a/src/enzyme_ad/jax/Passes/CanonicalizeLoops.cpp b/src/enzyme_ad/jax/Passes/CanonicalizeLoops.cpp index ac20b7bb0..9ac119d4d 100644 --- a/src/enzyme_ad/jax/Passes/CanonicalizeLoops.cpp +++ b/src/enzyme_ad/jax/Passes/CanonicalizeLoops.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -34,6 +35,7 @@ namespace enzyme { } // namespace mlir using namespace mlir; +using namespace mlir::arith; using namespace mlir::affine; using namespace mlir::dataflow; using namespace mlir::enzyme; @@ -962,3 +964,8 @@ struct CanonicalizeLoopsPass } }; } // namespace + +void mlir::enzyme::addSingleIter(RewritePatternSet &patterns, + MLIRContext *ctx) { + patterns.add(ctx); +} diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index fa55fd511..26e3891e3 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -31,6 +31,7 @@ namespace enzyme { void populateLibDeviceFuncsToOpsPatterns(MLIRContext *context, RewritePatternSet &patterns); +void addSingleIter(mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); } // namespace enzyme namespace cf { diff --git a/test/lit_tests/constgather.mlir b/test/lit_tests/constgather.mlir new file mode 100644 index 000000000..1fe24f75a --- /dev/null +++ b/test/lit_tests/constgather.mlir @@ -0,0 +1,14 @@ +// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s + +module { + func.func @main(%arg0: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) { + %0 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xcomplex> + %1 = "stablehlo.real"(%0) : (tensor<10xcomplex>) -> tensor<10xf32> + %2 = "stablehlo.real"(%arg0) : (tensor<10xf32>) -> tensor<10xf32> + return %1, %2 : tensor<10xf32>, tensor<10xf32> + } +} + +// CHECK: func.func @main(%arg0: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) { +// CHECK-NEXT: return %arg0, %arg0 : tensor<10xf32>, tensor<10xf32> +// CHECK-NEXT: } diff --git a/test/lit_tests/raising/affinecfg.mlir b/test/lit_tests/raising/affinecfg.mlir index fd336c3f8..6dfa04a66 100644 --- a/test/lit_tests/raising/affinecfg.mlir +++ b/test/lit_tests/raising/affinecfg.mlir @@ -152,13 +152,10 @@ module { } } -// CHECK: #set = affine_set<(d0)[s0] : (-d0 + s0 - 1 >= 0)> // CHECK: func.func @c(%[[arg0:.+]]: memref, %[[arg1]]: i64) { // CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i64 to index -// CHECK-NEXT: affine.parallel (%[[arg2:.+]], %[[arg3:.+]]) = (0, 0) to (42, 512) { -// CHECK-NEXT: affine.if #set(%[[arg2]])[%[[V0]]] { +// CHECK-NEXT: affine.parallel (%[[arg2:.+]], %[[arg3:.+]]) = (0, 0) to (min(symbol(%[[V0]]), 42), 512) { // CHECK-NEXT: "test.something"() : () -> () -// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } diff --git a/test/lit_tests/raising/affparmerge.mlir b/test/lit_tests/raising/affparmerge.mlir new file mode 100644 index 000000000..dca645420 --- /dev/null +++ b/test/lit_tests/raising/affparmerge.mlir @@ -0,0 +1,26 @@ +// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(affine-cfg)" --split-input-file --allow-unregistered-dialect %s | FileCheck %s + +module { + func.func @f(%636: index, %603: memref) { + %c512_i32 = arith.constant 512 : i32 + affine.parallel (%arg7, %arg8) = (0, 0) to (symbol(%636), 512) { + %706 = arith.index_cast %arg7 : index to i32 + %707 = arith.muli %706, %c512_i32 : i32 + %708 = arith.index_cast %arg8 : index to i32 + %709 = arith.addi %707, %708 : i32 + %712 = arith.sitofp %709 : i32 to f64 + affine.store %712, %603[%arg8 + %arg7 * 512] : memref + } + return + } + +} + +// CHECK: func.func @f(%[[arg0:.+]]: index, %[[arg1:.+]]: memref) { +// CHECK-NEXT: affine.parallel (%[[arg2:.+]]) = (0) to (symbol(%[[arg0]]) * 512) { +// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg2]] : index to i32 +// CHECK-NEXT: %[[V1:.+]] = arith.sitofp %[[V0]] : i32 to f64 +// CHECK-NEXT: affine.store %[[V1]], %arg1[%[[arg2]]] : memref +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } diff --git a/test/lit_tests/raising/cpu.mlir b/test/lit_tests/raising/cpu.mlir index 0bbba1081..4969cf186 100644 --- a/test/lit_tests/raising/cpu.mlir +++ b/test/lit_tests/raising/cpu.mlir @@ -31,16 +31,12 @@ module { } // CHECK: func.func private @kern$par0(%arg0: !llvm.ptr<1>) { -// CHECK-NEXT: affine.parallel (%arg1, %arg2, %arg3, %arg4, %arg5, %arg6) = (0, 0, 0, 0, 0, 0) to (1, 1, 1, 1, 1, 40) { -// CHECK-NEXT: affine.if #set(%arg4) { -// CHECK-NEXT: llvm.call fastcc @throw_boundserror_2676() : () -> () -// CHECK-NEXT: } else { +// CHECK-NEXT: affine.parallel (%arg1) = (0) to (40) { // CHECK-NEXT: %0 = "enzymexla.pointer2memref"(%arg0) : (!llvm.ptr<1>) -> memref -// CHECK-NEXT: %1 = affine.load %0[%arg4] {alignment = 1 : i64, ordering = 0 : i64} : memref +// CHECK-NEXT: %1 = affine.load %0[0] {alignment = 1 : i64, ordering = 0 : i64} : memref // CHECK-NEXT: %2 = arith.muli %1, %1 : i64 // CHECK-NEXT: %3 = "enzymexla.pointer2memref"(%arg0) : (!llvm.ptr<1>) -> memref -// CHECK-NEXT: affine.store %2, %3[%arg4] {alignment = 1 : i64, ordering = 0 : i64} : memref -// CHECK-NEXT: } +// CHECK-NEXT: affine.store %2, %3[0] {alignment = 1 : i64, ordering = 0 : i64} : memref // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: }