[SLP] fix fast-math-flag propagation on FP reductions
As shown in the test diffs, we could miscompile by propagating fast-math flags that did not exist in the original code. The flag requirements for fmin/fmax reductions will be fixed in a follow-up patch.
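For context, the fix replaces the IRBuilder's unconditional 'fast' flags with the intersection of the flags found on the reduction ops. Below is a minimal standalone sketch of that intersection using the same FastMathFlags calls the patch uses (set() to start from all flags, then operator&= per op); the flag values and the main() harness are hypothetical, not part of the patch:

    // Sketch: intersect fast-math flags the way the patch does over
    // ReductionOps. Any flag missing from even one op is cleared.
    #include "llvm/IR/Operator.h"
    #include "llvm/Support/raw_ostream.h"

    using namespace llvm;

    int main() {
      // Hypothetical flags from a first fadd: reassoc ninf nsz contract.
      FastMathFlags A;
      A.setAllowReassoc();
      A.setNoInfs();
      A.setNoSignedZeros();
      A.setAllowContract();

      // Hypothetical flags from a second fadd: reassoc ninf nsz nnan.
      FastMathFlags B;
      B.setAllowReassoc();
      B.setNoInfs();
      B.setNoSignedZeros();
      B.setNoNaNs();

      // Start from all flags set, then AND in each op's flags.
      FastMathFlags RdxFMF;
      RdxFMF.set();
      RdxFMF &= A;
      RdxFMF &= B;

      // Only 'reassoc ninf nsz' survive the intersection.
      outs() << "reassoc=" << RdxFMF.allowReassoc()
             << " ninf=" << RdxFMF.noInfs()
             << " nsz=" << RdxFMF.noSignedZeros()
             << " nnan=" << RdxFMF.noNaNs()
             << " contract=" << RdxFMF.allowContract() << "\n";
      return 0;
    }

The conservative starting point (all flags set) matters: intersecting downward can only drop flags, so the vectorized reduction never claims a relaxation that some scalar op did not permit.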
@@ -6820,12 +6820,18 @@ public:
     if (NumReducedVals < 4)
       return false;
 
-    // FIXME: Fast-math-flags should be set based on the instructions in the
-    //        reduction (not all of 'fast' are required).
+    // Intersect the fast-math-flags from all reduction operations.
+    FastMathFlags RdxFMF;
+    RdxFMF.set();
+    for (ReductionOpsType &RdxOp : ReductionOps) {
+      for (Value *RdxVal : RdxOp) {
+        if (auto *FPMO = dyn_cast<FPMathOperator>(RdxVal))
+          RdxFMF &= FPMO->getFastMathFlags();
+      }
+    }
+
     IRBuilder<> Builder(cast<Instruction>(ReductionRoot));
-    FastMathFlags Unsafe;
-    Unsafe.setFast();
-    Builder.setFastMathFlags(Unsafe);
+    Builder.setFastMathFlags(RdxFMF);
 
     BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
     // The same extra argument may be used several times, so log each attempt
@@ -7071,9 +7077,6 @@ private:
     assert(isPowerOf2_32(ReduxWidth) &&
            "We only handle power-of-two reductions for now");
 
-    // FIXME: The builder should use an FMF guard. It should not be hard-coded
-    //        to 'fast'.
-    assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF");
     return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
                                        ReductionOps.back());
   }

@@ -1766,7 +1766,6 @@ bb.1:
   ret void
 }
 
-; FIXME: This is a miscompile.
 ; The FMF on the reduction should match the incoming insts.
 
 define float @fadd_v4f32_fmf(float* %p) {
@@ -1776,7 +1775,7 @@ define float @fadd_v4f32_fmf(float* %p) {
 ; CHECK-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; CHECK-NEXT:    ret float [[TMP3]]
 ;
 ; STORE-LABEL: @fadd_v4f32_fmf(
@@ -1785,7 +1784,7 @@ define float @fadd_v4f32_fmf(float* %p) {
 ; STORE-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; STORE-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; STORE-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; STORE-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; STORE-NEXT:    [[TMP3:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; STORE-NEXT:    ret float [[TMP3]]
 ;
 %p1 = getelementptr inbounds float, float* %p, i64 1
@@ -1801,6 +1800,10 @@ define float @fadd_v4f32_fmf(float* %p) {
   ret float %add3
 }
 
+; The minimal FMF for fadd reduction are "reassoc nsz".
+; Only the common FMF of all operations in the reduction propagate to the result.
+; In this example, "contract nnan arcp" are dropped, but "ninf" transfers with the required flags.
+
 define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; CHECK-LABEL: @fadd_v4f32_fmf_intersect(
 ; CHECK-NEXT:    [[P1:%.*]] = getelementptr inbounds float, float* [[P:%.*]], i64 1
@@ -1808,7 +1811,7 @@ define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; CHECK-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call reassoc ninf nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; CHECK-NEXT:    ret float [[TMP3]]
 ;
 ; STORE-LABEL: @fadd_v4f32_fmf_intersect(
@@ -1817,7 +1820,7 @@ define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; STORE-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; STORE-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; STORE-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; STORE-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; STORE-NEXT:    [[TMP3:%.*]] = call reassoc ninf nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; STORE-NEXT:    ret float [[TMP3]]
 ;
 %p1 = getelementptr inbounds float, float* %p, i64 1
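To make the intersection rule in the new test comments concrete, here is a hand-written scalar chain (hypothetical, not copied from the test file). The three fadds agree only on 'reassoc ninf nsz', so after this patch the vectorized reduction call should carry exactly that intersection rather than 'fast':

    ; Hypothetical input: 'contract', 'nnan', and 'arcp' each appear on only
    ; one fadd, so they are dropped; 'reassoc ninf nsz' are common to all
    ; three and propagate to the reduction intrinsic.
    define float @intersect_example(float %a, float %b, float %c, float %d) {
      %add1 = fadd reassoc ninf nsz contract float %a, %b
      %add2 = fadd reassoc ninf nsz nnan float %add1, %c
      %add3 = fadd reassoc ninf nsz arcp float %add2, %d
      ret float %add3
    }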