[SLP] Support for horizontal min/max reduction.
Summary: The SLP vectorizer supports horizontal reductions for Add/FAdd binary operations. This patch adds support for horizontal min/max reductions. The function getReductionCost() is split into getArithmeticReductionCost() for binary operation reductions and getMinMaxReductionCost() for min/max reductions. This patch fixes PR26956. Reviewers: spatel, mkuper, hfinkel, RKSimon Subscribers: llvm-commits Differential Revision: https://reviews.llvm.org/D27846 llvm-svn: 314101
This commit is contained in:
@@ -4647,16 +4647,24 @@ namespace {
|
||||
/// *p =
|
||||
///
|
||||
class HorizontalReduction {
|
||||
SmallVector<Value *, 16> ReductionOps;
|
||||
using ReductionOpsType = SmallVector<Value *, 16>;
|
||||
using ReductionOpsListType = SmallVector<ReductionOpsType, 2>;
|
||||
ReductionOpsListType ReductionOps;
|
||||
SmallVector<Value *, 32> ReducedVals;
|
||||
// Use map vector to make stable output.
|
||||
MapVector<Instruction *, Value *> ExtraArgs;
|
||||
|
||||
/// Kind of the reduction data.
|
||||
enum ReductionKind {
|
||||
RK_None, /// Not a reduction.
|
||||
RK_Arithmetic, /// Binary reduction data.
|
||||
RK_Min, /// Minimum reduction data.
|
||||
RK_UMin, /// Unsigned minimum reduction data.
|
||||
RK_Max, /// Maximum reduction data.
|
||||
RK_UMax, /// Unsigned maximum reduction data.
|
||||
};
|
||||
/// Contains info about operation, like its opcode, left and right operands.
|
||||
struct OperationData {
|
||||
/// true if the operation is a reduced value, false if reduction operation.
|
||||
bool IsReducedValue = false;
|
||||
|
||||
class OperationData {
|
||||
/// Opcode of the instruction.
|
||||
unsigned Opcode = 0;
|
||||
|
||||
@@ -4665,12 +4673,52 @@ class HorizontalReduction {
|
||||
|
||||
/// Right operand of the reduction operation.
|
||||
Value *RHS = nullptr;
|
||||
/// Kind of the reduction operation.
|
||||
ReductionKind Kind = RK_None;
|
||||
/// True if float point min/max reduction has no NaNs.
|
||||
bool NoNaN = false;
|
||||
|
||||
/// Checks if the reduction operation can be vectorized.
|
||||
bool isVectorizable() const {
|
||||
return LHS && RHS &&
|
||||
// We currently only support adds.
|
||||
(Opcode == Instruction::Add || Opcode == Instruction::FAdd);
|
||||
// We currently only support adds && min/max reductions.
|
||||
((Kind == RK_Arithmetic &&
|
||||
(Opcode == Instruction::Add || Opcode == Instruction::FAdd)) ||
|
||||
((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
|
||||
(Kind == RK_Min || Kind == RK_Max)) ||
|
||||
(Opcode == Instruction::ICmp &&
|
||||
(Kind == RK_UMin || Kind == RK_UMax)));
|
||||
}
|
||||
|
||||
/// Creates reduction operation with the current opcode.
|
||||
Value *createOp(IRBuilder<> &Builder, const Twine &Name) const {
|
||||
assert(isVectorizable() &&
|
||||
"Expected add|fadd or min/max reduction operation.");
|
||||
Value *Cmp;
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
|
||||
Name);
|
||||
case RK_Min:
|
||||
Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSLT(LHS, RHS)
|
||||
: Builder.CreateFCmpOLT(LHS, RHS);
|
||||
break;
|
||||
case RK_Max:
|
||||
Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSGT(LHS, RHS)
|
||||
: Builder.CreateFCmpOGT(LHS, RHS);
|
||||
break;
|
||||
case RK_UMin:
|
||||
assert(Opcode == Instruction::ICmp && "Expected integer types.");
|
||||
Cmp = Builder.CreateICmpULT(LHS, RHS);
|
||||
break;
|
||||
case RK_UMax:
|
||||
assert(Opcode == Instruction::ICmp && "Expected integer types.");
|
||||
Cmp = Builder.CreateICmpUGT(LHS, RHS);
|
||||
break;
|
||||
case RK_None:
|
||||
llvm_unreachable("Unknown reduction operation.");
|
||||
}
|
||||
return Builder.CreateSelect(Cmp, LHS, RHS, Name);
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -4678,43 +4726,156 @@ class HorizontalReduction {
|
||||
|
||||
/// Construction for reduced values. They are identified by opcode only and
|
||||
/// don't have associated LHS/RHS values.
|
||||
explicit OperationData(Value *V) : IsReducedValue(true) {
|
||||
explicit OperationData(Value *V) : Kind(RK_None) {
|
||||
if (auto *I = dyn_cast<Instruction>(V))
|
||||
Opcode = I->getOpcode();
|
||||
}
|
||||
|
||||
/// Constructor for binary reduction operations with opcode and its left and
|
||||
/// Constructor for reduction operations with opcode and its left and
|
||||
/// right operands.
|
||||
OperationData(unsigned Opcode, Value *LHS, Value *RHS)
|
||||
: Opcode(Opcode), LHS(LHS), RHS(RHS) {}
|
||||
|
||||
OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind,
|
||||
bool NoNaN = false)
|
||||
: Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind), NoNaN(NoNaN) {
|
||||
assert(Kind != RK_None && "One of the reduction operations is expected.");
|
||||
}
|
||||
explicit operator bool() const { return Opcode; }
|
||||
|
||||
/// Get the index of the first operand.
|
||||
unsigned getFirstOperandIndex() const {
|
||||
assert(!!*this && "The opcode is not set.");
|
||||
switch (Kind) {
|
||||
case RK_Min:
|
||||
case RK_UMin:
|
||||
case RK_Max:
|
||||
case RK_UMax:
|
||||
return 1;
|
||||
case RK_Arithmetic:
|
||||
case RK_None:
|
||||
break;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Total number of operands in the reduction operation.
|
||||
unsigned getNumberOfOperands() const {
|
||||
assert(!IsReducedValue && !!*this && LHS && RHS &&
|
||||
assert(Kind != RK_None && !!*this && LHS && RHS &&
|
||||
"Expected reduction operation.");
|
||||
return 2;
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
return 2;
|
||||
case RK_Min:
|
||||
case RK_UMin:
|
||||
case RK_Max:
|
||||
case RK_UMax:
|
||||
return 3;
|
||||
case RK_None:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Reduction kind is not set");
|
||||
}
|
||||
|
||||
/// Expected number of uses for reduction operations/reduced values.
|
||||
unsigned getRequiredNumberOfUses() const {
|
||||
assert(!IsReducedValue && !!*this && LHS && RHS &&
|
||||
/// Checks if the operation has the same parent as \p P.
|
||||
bool hasSameParent(Instruction *I, Value *P, bool IsRedOp) const {
|
||||
assert(Kind != RK_None && !!*this && LHS && RHS &&
|
||||
"Expected reduction operation.");
|
||||
return 1;
|
||||
if (!IsRedOp)
|
||||
return I->getParent() == P;
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
// Arithmetic reduction operation must be used once only.
|
||||
return I->getParent() == P;
|
||||
case RK_Min:
|
||||
case RK_UMin:
|
||||
case RK_Max:
|
||||
case RK_UMax: {
|
||||
// SelectInst must be used twice while the condition op must have single
|
||||
// use only.
|
||||
auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
|
||||
return I->getParent() == P && Cmp && Cmp->getParent() == P;
|
||||
}
|
||||
case RK_None:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Reduction kind is not set");
|
||||
}
|
||||
/// Expected number of uses for reduction operations/reduced values.
|
||||
bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const {
|
||||
assert(Kind != RK_None && !!*this && LHS && RHS &&
|
||||
"Expected reduction operation.");
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
return I->hasOneUse();
|
||||
case RK_Min:
|
||||
case RK_UMin:
|
||||
case RK_Max:
|
||||
case RK_UMax:
|
||||
return I->hasNUses(2) &&
|
||||
(!IsReductionOp ||
|
||||
cast<SelectInst>(I)->getCondition()->hasOneUse());
|
||||
case RK_None:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Reduction kind is not set");
|
||||
}
|
||||
|
||||
/// Initializes the list of reduction operations.
|
||||
void initReductionOps(ReductionOpsListType &ReductionOps) {
|
||||
assert(Kind != RK_None && !!*this && LHS && RHS &&
|
||||
"Expected reduction operation.");
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
ReductionOps.assign(1, ReductionOpsType());
|
||||
break;
|
||||
case RK_Min:
|
||||
case RK_UMin:
|
||||
case RK_Max:
|
||||
case RK_UMax:
|
||||
ReductionOps.assign(2, ReductionOpsType());
|
||||
break;
|
||||
case RK_None:
|
||||
llvm_unreachable("Reduction kind is not set");
|
||||
}
|
||||
}
|
||||
/// Add all reduction operations for the reduction instruction \p I.
|
||||
void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) {
|
||||
assert(Kind != RK_None && !!*this && LHS && RHS &&
|
||||
"Expected reduction operation.");
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
ReductionOps[0].emplace_back(I);
|
||||
break;
|
||||
case RK_Min:
|
||||
case RK_UMin:
|
||||
case RK_Max:
|
||||
case RK_UMax:
|
||||
ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
|
||||
ReductionOps[1].emplace_back(I);
|
||||
break;
|
||||
case RK_None:
|
||||
llvm_unreachable("Reduction kind is not set");
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if instruction is associative and can be vectorized.
|
||||
bool isAssociative(Instruction *I) const {
|
||||
assert(!IsReducedValue && *this && LHS && RHS &&
|
||||
assert(Kind != RK_None && *this && LHS && RHS &&
|
||||
"Expected reduction operation.");
|
||||
return I->isAssociative();
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
return I->isAssociative();
|
||||
case RK_Min:
|
||||
case RK_Max:
|
||||
return Opcode == Instruction::ICmp ||
|
||||
cast<Instruction>(I->getOperand(0))->hasUnsafeAlgebra();
|
||||
case RK_UMin:
|
||||
case RK_UMax:
|
||||
assert(Opcode == Instruction::ICmp &&
|
||||
"Only integer compare operation is expected.");
|
||||
return true;
|
||||
case RK_None:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Reduction kind is not set");
|
||||
}
|
||||
|
||||
/// Checks if the reduction operation can be vectorized.
|
||||
@@ -4725,18 +4886,17 @@ class HorizontalReduction {
|
||||
/// Checks if two operation data are both a reduction op or both a reduced
|
||||
/// value.
|
||||
bool operator==(const OperationData &OD) {
|
||||
assert(((IsReducedValue != OD.IsReducedValue) ||
|
||||
((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
|
||||
assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
|
||||
"One of the comparing operations is incorrect.");
|
||||
return this == &OD ||
|
||||
(IsReducedValue == OD.IsReducedValue && Opcode == OD.Opcode);
|
||||
return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode);
|
||||
}
|
||||
bool operator!=(const OperationData &OD) { return !(*this == OD); }
|
||||
void clear() {
|
||||
IsReducedValue = false;
|
||||
Opcode = 0;
|
||||
LHS = nullptr;
|
||||
RHS = nullptr;
|
||||
Kind = RK_None;
|
||||
NoNaN = false;
|
||||
}
|
||||
|
||||
/// Get the opcode of the reduction operation.
|
||||
@@ -4745,16 +4905,99 @@ class HorizontalReduction {
|
||||
return Opcode;
|
||||
}
|
||||
|
||||
/// Get kind of reduction data.
|
||||
ReductionKind getKind() const { return Kind; }
|
||||
Value *getLHS() const { return LHS; }
|
||||
Value *getRHS() const { return RHS; }
|
||||
Type *getConditionType() const {
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
return nullptr;
|
||||
case RK_Min:
|
||||
case RK_Max:
|
||||
case RK_UMin:
|
||||
case RK_UMax:
|
||||
return CmpInst::makeCmpResultType(LHS->getType());
|
||||
case RK_None:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Reduction kind is not set");
|
||||
}
|
||||
|
||||
/// Creates reduction operation with the current opcode.
|
||||
Value *createOp(IRBuilder<> &Builder, const Twine &Name = "") const {
|
||||
assert(!IsReducedValue &&
|
||||
(Opcode == Instruction::FAdd || Opcode == Instruction::Add) &&
|
||||
"Expected add|fadd reduction operation.");
|
||||
return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
|
||||
Name);
|
||||
/// Creates reduction operation with the current opcode with the IR flags
|
||||
/// from \p ReductionOps.
|
||||
Value *createOp(IRBuilder<> &Builder, const Twine &Name,
|
||||
const ReductionOpsListType &ReductionOps) const {
|
||||
assert(isVectorizable() &&
|
||||
"Expected add|fadd or min/max reduction operation.");
|
||||
auto *Op = createOp(Builder, Name);
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
propagateIRFlags(Op, ReductionOps[0]);
|
||||
return Op;
|
||||
case RK_Min:
|
||||
case RK_Max:
|
||||
case RK_UMin:
|
||||
case RK_UMax:
|
||||
propagateIRFlags(cast<SelectInst>(Op)->getCondition(), ReductionOps[0]);
|
||||
propagateIRFlags(Op, ReductionOps[1]);
|
||||
return Op;
|
||||
case RK_None:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Unknown reduction operation.");
|
||||
}
|
||||
/// Creates reduction operation with the current opcode with the IR flags
|
||||
/// from \p I.
|
||||
Value *createOp(IRBuilder<> &Builder, const Twine &Name,
|
||||
Instruction *I) const {
|
||||
assert(isVectorizable() &&
|
||||
"Expected add|fadd or min/max reduction operation.");
|
||||
auto *Op = createOp(Builder, Name);
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
propagateIRFlags(Op, I);
|
||||
return Op;
|
||||
case RK_Min:
|
||||
case RK_Max:
|
||||
case RK_UMin:
|
||||
case RK_UMax:
|
||||
propagateIRFlags(cast<SelectInst>(Op)->getCondition(),
|
||||
cast<SelectInst>(I)->getCondition());
|
||||
propagateIRFlags(Op, I);
|
||||
return Op;
|
||||
case RK_None:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Unknown reduction operation.");
|
||||
}
|
||||
|
||||
TargetTransformInfo::ReductionFlags getFlags() const {
|
||||
TargetTransformInfo::ReductionFlags Flags;
|
||||
Flags.NoNaN = NoNaN;
|
||||
switch (Kind) {
|
||||
case RK_Arithmetic:
|
||||
break;
|
||||
case RK_Min:
|
||||
Flags.IsSigned = Opcode == Instruction::ICmp;
|
||||
Flags.IsMaxOp = false;
|
||||
break;
|
||||
case RK_Max:
|
||||
Flags.IsSigned = Opcode == Instruction::ICmp;
|
||||
Flags.IsMaxOp = true;
|
||||
break;
|
||||
case RK_UMin:
|
||||
Flags.IsSigned = false;
|
||||
Flags.IsMaxOp = false;
|
||||
break;
|
||||
case RK_UMax:
|
||||
Flags.IsSigned = false;
|
||||
Flags.IsMaxOp = true;
|
||||
break;
|
||||
case RK_None:
|
||||
llvm_unreachable("Reduction kind is not set");
|
||||
}
|
||||
return Flags;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4796,8 +5039,32 @@ class HorizontalReduction {
|
||||
|
||||
Value *LHS;
|
||||
Value *RHS;
|
||||
if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V))
|
||||
return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS);
|
||||
if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) {
|
||||
return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS,
|
||||
RK_Arithmetic);
|
||||
}
|
||||
if (auto *Select = dyn_cast<SelectInst>(V)) {
|
||||
// Look for a min/max pattern.
|
||||
if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
|
||||
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
|
||||
} else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
|
||||
return OperationData(Instruction::ICmp, LHS, RHS, RK_Min);
|
||||
} else if (m_OrdFMin(m_Value(LHS), m_Value(RHS)).match(Select) ||
|
||||
m_UnordFMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
|
||||
return OperationData(
|
||||
Instruction::FCmp, LHS, RHS, RK_Min,
|
||||
cast<Instruction>(Select->getCondition())->hasNoNaNs());
|
||||
} else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
|
||||
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
|
||||
} else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
|
||||
return OperationData(Instruction::ICmp, LHS, RHS, RK_Max);
|
||||
} else if (m_OrdFMax(m_Value(LHS), m_Value(RHS)).match(Select) ||
|
||||
m_UnordFMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
|
||||
return OperationData(
|
||||
Instruction::FCmp, LHS, RHS, RK_Max,
|
||||
cast<Instruction>(Select->getCondition())->hasNoNaNs());
|
||||
}
|
||||
}
|
||||
return OperationData(V);
|
||||
}
|
||||
|
||||
@@ -4840,7 +5107,7 @@ public:
|
||||
// trees containing only binary operators.
|
||||
SmallVector<std::pair<Instruction *, unsigned>, 32> Stack;
|
||||
Stack.push_back(std::make_pair(B, ReductionData.getFirstOperandIndex()));
|
||||
const unsigned NUses = ReductionData.getRequiredNumberOfUses();
|
||||
ReductionData.initReductionOps(ReductionOps);
|
||||
while (!Stack.empty()) {
|
||||
Instruction *TreeN = Stack.back().first;
|
||||
unsigned EdgeToVist = Stack.back().second++;
|
||||
@@ -4866,7 +5133,7 @@ public:
|
||||
markExtraArg(Stack[Stack.size() - 2], TreeN);
|
||||
ExtraArgs.erase(TreeN);
|
||||
} else
|
||||
ReductionOps.push_back(TreeN);
|
||||
ReductionData.addReductionOps(TreeN, ReductionOps);
|
||||
}
|
||||
// Retract.
|
||||
Stack.pop_back();
|
||||
@@ -4884,8 +5151,10 @@ public:
|
||||
// reduced value class.
|
||||
if (I && (!ReducedValueData || OpData == ReducedValueData ||
|
||||
OpData == ReductionData)) {
|
||||
const bool IsReductionOperation = OpData == ReductionData;
|
||||
// Only handle trees in the current basic block.
|
||||
if (I->getParent() != B->getParent()) {
|
||||
if (!ReductionData.hasSameParent(I, B->getParent(),
|
||||
IsReductionOperation)) {
|
||||
// I is an extra argument for TreeN (its parent operation).
|
||||
markExtraArg(Stack.back(), I);
|
||||
continue;
|
||||
@@ -4893,13 +5162,15 @@ public:
|
||||
|
||||
// Each tree node needs to have minimal number of users except for the
|
||||
// ultimate reduction.
|
||||
if (!I->hasNUses(NUses) && I != B) {
|
||||
if (!ReductionData.hasRequiredNumberOfUses(I,
|
||||
OpData == ReductionData) &&
|
||||
I != B) {
|
||||
// I is an extra argument for TreeN (its parent operation).
|
||||
markExtraArg(Stack.back(), I);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (OpData == ReductionData) {
|
||||
if (IsReductionOperation) {
|
||||
// We need to be able to reassociate the reduction operations.
|
||||
if (!OpData.isAssociative(I)) {
|
||||
// I is an extra argument for TreeN (its parent operation).
|
||||
@@ -4953,12 +5224,15 @@ public:
|
||||
// to use it.
|
||||
for (auto &Pair : ExtraArgs)
|
||||
ExternallyUsedValues[Pair.second].push_back(Pair.first);
|
||||
SmallVector<Value *, 16> IgnoreList;
|
||||
for (auto &V : ReductionOps)
|
||||
IgnoreList.append(V.begin(), V.end());
|
||||
while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
|
||||
auto VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
|
||||
V.buildTree(VL, ExternallyUsedValues, ReductionOps);
|
||||
V.buildTree(VL, ExternallyUsedValues, IgnoreList);
|
||||
if (V.shouldReorder()) {
|
||||
SmallVector<Value *, 8> Reversed(VL.rbegin(), VL.rend());
|
||||
V.buildTree(Reversed, ExternallyUsedValues, ReductionOps);
|
||||
V.buildTree(Reversed, ExternallyUsedValues, IgnoreList);
|
||||
}
|
||||
if (V.isTreeTinyAndNotFullyVectorizable())
|
||||
break;
|
||||
@@ -4986,13 +5260,14 @@ public:
|
||||
|
||||
// Emit a reduction.
|
||||
Value *ReducedSubTree =
|
||||
emitReduction(VectorizedRoot, Builder, ReduxWidth, ReductionOps, TTI);
|
||||
emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
|
||||
if (VectorizedTree) {
|
||||
Builder.SetCurrentDebugLocation(Loc);
|
||||
OperationData VectReductionData(ReductionData.getOpcode(),
|
||||
VectorizedTree, ReducedSubTree);
|
||||
VectorizedTree = VectReductionData.createOp(Builder, "bin.rdx");
|
||||
propagateIRFlags(VectorizedTree, ReductionOps);
|
||||
VectorizedTree, ReducedSubTree,
|
||||
ReductionData.getKind());
|
||||
VectorizedTree =
|
||||
VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
|
||||
} else
|
||||
VectorizedTree = ReducedSubTree;
|
||||
i += ReduxWidth;
|
||||
@@ -5005,9 +5280,9 @@ public:
|
||||
auto *I = cast<Instruction>(ReducedVals[i]);
|
||||
Builder.SetCurrentDebugLocation(I->getDebugLoc());
|
||||
OperationData VectReductionData(ReductionData.getOpcode(),
|
||||
VectorizedTree, I);
|
||||
VectorizedTree = VectReductionData.createOp(Builder);
|
||||
propagateIRFlags(VectorizedTree, ReductionOps);
|
||||
VectorizedTree, I,
|
||||
ReductionData.getKind());
|
||||
VectorizedTree = VectReductionData.createOp(Builder, "", ReductionOps);
|
||||
}
|
||||
for (auto &Pair : ExternallyUsedValues) {
|
||||
assert(!Pair.second.empty() &&
|
||||
@@ -5016,9 +5291,9 @@ public:
|
||||
for (auto *I : Pair.second) {
|
||||
Builder.SetCurrentDebugLocation(I->getDebugLoc());
|
||||
OperationData VectReductionData(ReductionData.getOpcode(),
|
||||
VectorizedTree, Pair.first);
|
||||
VectorizedTree = VectReductionData.createOp(Builder, "bin.extra");
|
||||
propagateIRFlags(VectorizedTree, I);
|
||||
VectorizedTree, Pair.first,
|
||||
ReductionData.getKind());
|
||||
VectorizedTree = VectReductionData.createOp(Builder, "op.extra", I);
|
||||
}
|
||||
}
|
||||
// Update users.
|
||||
@@ -5038,19 +5313,58 @@ private:
|
||||
Type *ScalarTy = FirstReducedVal->getType();
|
||||
Type *VecTy = VectorType::get(ScalarTy, ReduxWidth);
|
||||
|
||||
int PairwiseRdxCost =
|
||||
TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
|
||||
/*IsPairwiseForm=*/true);
|
||||
int SplittingRdxCost =
|
||||
TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
|
||||
/*IsPairwiseForm=*/false);
|
||||
int PairwiseRdxCost;
|
||||
int SplittingRdxCost;
|
||||
switch (ReductionData.getKind()) {
|
||||
case RK_Arithmetic:
|
||||
PairwiseRdxCost =
|
||||
TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
|
||||
/*IsPairwiseForm=*/true);
|
||||
SplittingRdxCost =
|
||||
TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
|
||||
/*IsPairwiseForm=*/false);
|
||||
break;
|
||||
case RK_Min:
|
||||
case RK_Max:
|
||||
case RK_UMin:
|
||||
case RK_UMax: {
|
||||
Type *VecCondTy = CmpInst::makeCmpResultType(VecTy);
|
||||
bool IsUnsigned = ReductionData.getKind() == RK_UMin ||
|
||||
ReductionData.getKind() == RK_UMax;
|
||||
PairwiseRdxCost =
|
||||
TTI->getMinMaxReductionCost(VecTy, VecCondTy,
|
||||
/*IsPairwiseForm=*/true, IsUnsigned);
|
||||
SplittingRdxCost =
|
||||
TTI->getMinMaxReductionCost(VecTy, VecCondTy,
|
||||
/*IsPairwiseForm=*/false, IsUnsigned);
|
||||
break;
|
||||
}
|
||||
case RK_None:
|
||||
llvm_unreachable("Expected arithmetic or min/max reduction operation");
|
||||
}
|
||||
|
||||
IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost;
|
||||
int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost;
|
||||
|
||||
int ScalarReduxCost =
|
||||
(ReduxWidth - 1) *
|
||||
TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy);
|
||||
int ScalarReduxCost;
|
||||
switch (ReductionData.getKind()) {
|
||||
case RK_Arithmetic:
|
||||
ScalarReduxCost =
|
||||
TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy);
|
||||
break;
|
||||
case RK_Min:
|
||||
case RK_Max:
|
||||
case RK_UMin:
|
||||
case RK_UMax:
|
||||
ScalarReduxCost =
|
||||
TTI->getCmpSelInstrCost(ReductionData.getOpcode(), ScalarTy) +
|
||||
TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
|
||||
CmpInst::makeCmpResultType(ScalarTy));
|
||||
break;
|
||||
case RK_None:
|
||||
llvm_unreachable("Expected arithmetic or min/max reduction operation");
|
||||
}
|
||||
ScalarReduxCost *= (ReduxWidth - 1);
|
||||
|
||||
DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost
|
||||
<< " for reduction that starts with " << *FirstReducedVal
|
||||
@@ -5063,8 +5377,7 @@ private:
|
||||
|
||||
/// \brief Emit a horizontal reduction of the vectorized value.
|
||||
Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder,
|
||||
unsigned ReduxWidth, ArrayRef<Value *> RedOps,
|
||||
const TargetTransformInfo *TTI) {
|
||||
unsigned ReduxWidth, const TargetTransformInfo *TTI) {
|
||||
assert(VectorizedValue && "Need to have a vectorized tree node");
|
||||
assert(isPowerOf2_32(ReduxWidth) &&
|
||||
"We only handle power-of-two reductions for now");
|
||||
@@ -5072,7 +5385,7 @@ private:
|
||||
if (!IsPairwiseReduction)
|
||||
return createSimpleTargetReduction(
|
||||
Builder, TTI, ReductionData.getOpcode(), VectorizedValue,
|
||||
TargetTransformInfo::ReductionFlags(), RedOps);
|
||||
ReductionData.getFlags(), ReductionOps.back());
|
||||
|
||||
Value *TmpVec = VectorizedValue;
|
||||
for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) {
|
||||
@@ -5087,9 +5400,8 @@ private:
|
||||
TmpVec, UndefValue::get(TmpVec->getType()), (RightMask),
|
||||
"rdx.shuf.r");
|
||||
OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf,
|
||||
RightShuf);
|
||||
TmpVec = VectReductionData.createOp(Builder, "bin.rdx");
|
||||
propagateIRFlags(TmpVec, RedOps);
|
||||
RightShuf, ReductionData.getKind());
|
||||
TmpVec = VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
|
||||
}
|
||||
|
||||
// The result is in the first element of the vector.
|
||||
@@ -5249,9 +5561,11 @@ static bool tryToVectorizeHorReductionOrInstOperands(
|
||||
auto *Inst = dyn_cast<Instruction>(V);
|
||||
if (!Inst)
|
||||
continue;
|
||||
if (auto *BI = dyn_cast<BinaryOperator>(Inst)) {
|
||||
auto *BI = dyn_cast<BinaryOperator>(Inst);
|
||||
auto *SI = dyn_cast<SelectInst>(Inst);
|
||||
if (BI || SI) {
|
||||
HorizontalReduction HorRdx;
|
||||
if (HorRdx.matchAssociativeReduction(P, BI)) {
|
||||
if (HorRdx.matchAssociativeReduction(P, Inst)) {
|
||||
if (HorRdx.tryToReduce(R, TTI)) {
|
||||
Res = true;
|
||||
// Set P to nullptr to avoid re-analysis of phi node in
|
||||
@@ -5260,7 +5574,7 @@ static bool tryToVectorizeHorReductionOrInstOperands(
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (P) {
|
||||
if (P && BI) {
|
||||
Inst = dyn_cast<Instruction>(BI->getOperand(0));
|
||||
if (Inst == P)
|
||||
Inst = dyn_cast<Instruction>(BI->getOperand(1));
|
||||
|
||||
Reference in New Issue
Block a user