Skip to content

Commit e5a215b

Browse files
committed
Reassociate: add global reassociation algorithm
This PR pulls the upstream change, Reassociate: add global reassociation algorithm (llvm/llvm-project@b8a330c), into DXC with miminal changes. For the code below: foo = (a * b) * c bar = (a * d) * c As the upstream change states, it can identify the a*c is a common factor and redundant. This is part 1 of the fix for #6593.
1 parent 14c4407 commit e5a215b

File tree

4 files changed

+623
-2
lines changed

4 files changed

+623
-2
lines changed

lib/Transforms/Scalar/Reassociate.cpp

+122-2
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
//
2121
//===----------------------------------------------------------------------===//
2222

23-
#include "llvm/Transforms/Scalar.h"
2423
#include "llvm/ADT/DenseMap.h"
2524
#include "llvm/ADT/PostOrderIterator.h"
2625
#include "llvm/ADT/STLExtras.h"
2726
#include "llvm/ADT/SetVector.h"
27+
#include "llvm/ADT/SmallSet.h"
2828
#include "llvm/ADT/Statistic.h"
2929
#include "llvm/IR/CFG.h"
3030
#include "llvm/IR/Constants.h"
@@ -37,6 +37,7 @@
3737
#include "llvm/Pass.h"
3838
#include "llvm/Support/Debug.h"
3939
#include "llvm/Support/raw_ostream.h"
40+
#include "llvm/Transforms/Scalar.h"
4041
#include "llvm/Transforms/Utils/Local.h"
4142
#include <algorithm>
4243
using namespace llvm;
@@ -161,6 +162,13 @@ namespace {
161162
DenseMap<BasicBlock*, unsigned> RankMap;
162163
DenseMap<AssertingVH<Value>, unsigned> ValueRankMap;
163164
SetVector<AssertingVH<Instruction> > RedoInsts;
165+
166+
// Arbitrary, but prevents quadratic behavior.
167+
static const unsigned GlobalReassociateLimit = 10;
168+
static const unsigned NumBinaryOps =
169+
Instruction::BinaryOpsEnd - Instruction::BinaryOpsBegin;
170+
DenseMap<std::pair<Value *, Value *>, unsigned> PairMap[NumBinaryOps];
171+
164172
bool MadeChange;
165173
public:
166174
static char ID; // Pass identification, replacement for typeid
@@ -196,6 +204,7 @@ namespace {
196204
void EraseInst(Instruction *I);
197205
void OptimizeInst(Instruction *I);
198206
Instruction *canonicalizeNegConstExpr(Instruction *I);
207+
void BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT);
199208
};
200209
}
201210

@@ -2234,18 +2243,127 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) {
22342243
return;
22352244
}
22362245

2246+
if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) {
2247+
// Find the pair with the highest count in the pairmap and move it to the
2248+
// back of the list so that it can later be CSE'd.
2249+
// example:
2250+
// a*b*c*d*e
2251+
// if c*e is the most "popular" pair, we can express this as
2252+
// (((c*e)*d)*b)*a
2253+
unsigned Max = 1;
2254+
unsigned BestRank = 0;
2255+
std::pair<unsigned, unsigned> BestPair;
2256+
unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin;
2257+
for (unsigned i = 0; i < Ops.size() - 1; ++i)
2258+
for (unsigned j = i + 1; j < Ops.size(); ++j) {
2259+
unsigned Score = 0;
2260+
Value *Op0 = Ops[i].Op;
2261+
Value *Op1 = Ops[j].Op;
2262+
if (std::less<Value *>()(Op1, Op0))
2263+
std::swap(Op0, Op1);
2264+
auto it = PairMap[Idx].find({Op0, Op1});
2265+
if (it != PairMap[Idx].end())
2266+
Score += it->second;
2267+
2268+
unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank);
2269+
if (Score > Max || (Score == Max && MaxRank < BestRank)) {
2270+
BestPair = {i, j};
2271+
Max = Score;
2272+
BestRank = MaxRank;
2273+
}
2274+
}
2275+
if (Max > 1) {
2276+
auto Op0 = Ops[BestPair.first];
2277+
auto Op1 = Ops[BestPair.second];
2278+
Ops.erase(&Ops[BestPair.second]);
2279+
Ops.erase(&Ops[BestPair.first]);
2280+
Ops.push_back(Op0);
2281+
Ops.push_back(Op1);
2282+
}
2283+
}
22372284
// Now that we ordered and optimized the expressions, splat them back into
22382285
// the expression tree, removing any unneeded nodes.
22392286
RewriteExprTree(I, Ops);
22402287
}
22412288

2289+
void Reassociate::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) {
2290+
// Make a "pairmap" of how often each operand pair occurs.
2291+
for (BasicBlock *BI : RPOT) {
2292+
for (Instruction &I : *BI) {
2293+
if (!I.isAssociative())
2294+
continue;
2295+
2296+
// Ignore nodes that aren't at the root of trees.
2297+
if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode())
2298+
continue;
2299+
2300+
// Collect all operands in a single reassociable expression.
2301+
// Since Reassociate has already been run once, we can assume things
2302+
// are already canonical according to Reassociation's regime.
2303+
SmallVector<Value *, 8> Worklist = {I.getOperand(0), I.getOperand(1)};
2304+
SmallVector<Value *, 8> Ops;
2305+
while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) {
2306+
Value *Op = Worklist.pop_back_val();
2307+
Instruction *OpI = dyn_cast<Instruction>(Op);
2308+
if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) {
2309+
Ops.push_back(Op);
2310+
continue;
2311+
}
2312+
// Be paranoid about self-referencing expressions in unreachable code.
2313+
if (OpI->getOperand(0) != OpI)
2314+
Worklist.push_back(OpI->getOperand(0));
2315+
if (OpI->getOperand(1) != OpI)
2316+
Worklist.push_back(OpI->getOperand(1));
2317+
}
2318+
// Skip extremely long expressions.
2319+
if (Ops.size() > GlobalReassociateLimit)
2320+
continue;
2321+
2322+
// Add all pairwise combinations of operands to the pair map.
2323+
unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin;
2324+
SmallSet<std::pair<Value *, Value *>, 32> Visited;
2325+
for (unsigned i = 0; i < Ops.size() - 1; ++i) {
2326+
for (unsigned j = i + 1; j < Ops.size(); ++j) {
2327+
// Canonicalize operand orderings.
2328+
Value *Op0 = Ops[i];
2329+
Value *Op1 = Ops[j];
2330+
if (std::less<Value *>()(Op1, Op0))
2331+
std::swap(Op0, Op1);
2332+
if (!Visited.insert({Op0, Op1}).second)
2333+
continue;
2334+
auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1});
2335+
if (!res.second)
2336+
++res.first->second;
2337+
}
2338+
}
2339+
}
2340+
}
2341+
}
2342+
22422343
bool Reassociate::runOnFunction(Function &F) {
22432344
if (skipOptnoneFunction(F))
22442345
return false;
22452346

22462347
// Calculate the rank map for F
22472348
BuildRankMap(F);
22482349

2350+
// Build the pair map before running reassociate.
2351+
// Technically this would be more accurate if we did it after one round
2352+
// of reassociation, but in practice it doesn't seem to help much on
2353+
// real-world code, so don't waste the compile time running reassociate
2354+
// twice.
2355+
// If a user wants, they could expicitly run reassociate twice in their
2356+
// pass pipeline for further potential gains.
2357+
// It might also be possible to update the pair map during runtime, but the
2358+
// overhead of that may be large if there's many reassociable chains.
2359+
// TODO: RPOT
2360+
// Get the functions basic blocks in Reverse Post Order. This order is used by
2361+
// BuildRankMap to pre calculate ranks correctly. It also excludes dead basic
2362+
// blocks (it has been seen that the analysis in this pass could hang when
2363+
// analysing dead basic blocks).
2364+
ReversePostOrderTraversal<Function *> RPOT(&F);
2365+
BuildPairMap(RPOT);
2366+
22492367
MadeChange = false;
22502368
for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
22512369
// Optimize every instruction in the basic block.
@@ -2268,9 +2386,11 @@ bool Reassociate::runOnFunction(Function &F) {
22682386
}
22692387
}
22702388

2271-
// We are done with the rank map.
2389+
// We are done with the rank map and pair map.
22722390
RankMap.clear();
22732391
ValueRankMap.clear();
2392+
for (auto &Entry : PairMap)
2393+
Entry.clear();
22742394

22752395
return MadeChange;
22762396
}

0 commit comments

Comments
 (0)