[SelectionDAG] Add PARTIAL_REDUCE_U/SMLA ISD Nodes #125207
base: main
Conversation
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-selectiondag

Author: James Chesterman (JamesChesterman)

Changes: Add signed and unsigned PARTIAL_REDUCE_MLA ISD nodes. Add a command line argument (new-partial-reduce-lowering) that indicates whether the intrinsic experimental_vector_partial_reduce_add will be transformed into the new ISD node. Lowering with the new ISD nodes will, for now, always be done as an expand.

Patch is 55.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125207.diff

10 Files Affected:
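For orientation before the diff: the sketch below is a scalar reference model (not part of the patch) of what a PARTIAL_REDUCE_*MLA node computes, assuming the subvector-wise accumulation discussed in the review thread further down. The node deliberately leaves the exact lane assignment unspecified, and the function name and use of plain 64-bit integers instead of vectors are illustrative only.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of PARTIAL_REDUCE_*MLA: multiply Input1 and Input2 elementwise,
// then fold the products into the accumulator lanes subvector-wise. Signedness
// only affects how narrower elements would be extended; using int64_t here
// sidesteps that detail. Input1 and Input2 have the same length, which is a
// positive integer multiple of the accumulator's length.
std::vector<int64_t> partialReduceMLA(const std::vector<int64_t> &Acc,
                                      const std::vector<int64_t> &Input1,
                                      const std::vector<int64_t> &Input2) {
  std::vector<int64_t> Out = Acc;
  const std::size_t AccElts = Acc.size();
  for (std::size_t i = 0; i < Input1.size(); ++i)
    // One valid distribution: product i lands in lane i % AccElts. Consumers
    // are expected to be order-insensitive, so other distributions are equally
    // valid.
    Out[i % AccElts] += Input1[i] * Input2[i];
  return Out;
}
```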
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index fd8784a4c10034..3f235ee358e0ed 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,20 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,
+ // PARTIAL_REDUCE_*MLA (Accumulator, Input1, Input2)
+ // Partial reduction nodes. Input1 and Input2 are multiplied together before
+ // being reduced, by addition, down to the number of elements that the
+ // Accumulator's type has.
+ // Input1 and Input2 must be the same type. Accumulator and the output must be
+ // the same type.
+ // The number of elements in Input1 and Input2 must be a positive integer
+ // multiple of the number of elements in the Accumulator / output type.
+ // All operands, as well as the output, must have the same element type.
+ // Operands: Accumulator, Input1, Input2
+ // Outputs: Output
+ PARTIAL_REDUCE_SMLA,
+ PARTIAL_REDUCE_UMLA,
+
// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
// Outputs: output chain, glue
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 461c0c1ead16d2..0fc6f6ccf85bd9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1607,6 +1607,13 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
+ // Expands PARTIAL_REDUCE_S/UMLA nodes.
+ // \p N The PARTIAL_REDUCE_*MLA node to expand. Its operands are the
+ // Accumulator (where the result of the partial reduction is accumulated),
+ // followed by the two inputs that are multiplied together.
+ SDValue expandPartialReduceMLA(SDNode *N);
+
/// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
/// its operands and ReducedTY is the intrinsic's return type.
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 625052be657ca0..3a9518ea569ebc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -159,6 +159,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
Res = PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(N);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
+ break;
+
case ISD::SIGN_EXTEND:
case ISD::VP_SIGN_EXTEND:
case ISD::ZERO_EXTEND:
@@ -2076,6 +2081,10 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::VECTOR_FIND_LAST_ACTIVE:
Res = PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(N, OpNo);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
+ break;
}
// If the result is null, the sub-method took care of registering results etc.
@@ -2824,6 +2833,12 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
+SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDValue Res = DAG.expandPartialReduceMLA(N);
+ ReplaceValueWith(SDValue(N, 0), Res);
+ return SDValue();
+}
+
//===----------------------------------------------------------------------===//
// Integer Result Expansion
//===----------------------------------------------------------------------===//
@@ -6139,6 +6154,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
return DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, SDLoc(N), NVT, N->ops());
}
+SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDValue Res = DAG.expandPartialReduceMLA(N);
+ ReplaceValueWith(SDValue(N, 0), Res);
+ return SDValue();
+}
+
SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
EVT OutVT = N->getValueType(0);
EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index f13f70e66cfaa6..cb9c1b239c0fa9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -379,6 +379,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
SDValue PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N);
+ SDValue PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N);
// Integer Operand Promotion.
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
@@ -430,6 +431,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
+ SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);
void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -968,6 +970,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
+ void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
// Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
bool SplitVectorOperand(SDNode *N, unsigned OpNo);
@@ -999,6 +1002,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
SDValue SplitVecOp_VP_CttzElements(SDNode *N);
SDValue SplitVecOp_VECTOR_HISTOGRAM(SDNode *N);
+ SDValue SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
//===--------------------------------------------------------------------===//
// Vector Widening Support: LegalizeVectorTypes.cpp
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1000235ab4061f..b01470028981e7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1373,6 +1373,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::EXPERIMENTAL_VP_REVERSE:
SplitVecRes_VP_REVERSE(N, Lo, Hi);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ SplitVecRes_PARTIAL_REDUCE_MLA(N);
}
// If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -3182,6 +3185,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
}
+void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDValue Res = DAG.expandPartialReduceMLA(N);
+ ReplaceValueWith(SDValue(N, 0), Res);
+}
+
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
SDValue Op0Lo, Op0Hi, Op1Lo, Op1Hi;
@@ -3381,6 +3389,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
Res = SplitVecOp_VECTOR_HISTOGRAM(N);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
}
// If the result is null, the sub-method took care of registering results etc.
@@ -4435,6 +4446,12 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
MMO, IndexType);
}
+SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDValue Res = DAG.expandPartialReduceMLA(N);
+ ReplaceValueWith(SDValue(N, 0), Res);
+ return SDValue();
+}
+
//===----------------------------------------------------------------------===//
// Result Vector Widening
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b416c0efbbc4fc..7240e4e00dfa07 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2473,6 +2473,23 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
+SDValue SelectionDAG::expandPartialReduceMLA(SDNode *N) {
+ SDLoc DL(N);
+ SDValue Acc = N->getOperand(0);
+ SDValue Input1 = N->getOperand(1);
+ SDValue Input2 = N->getOperand(2);
+
+ EVT FullTy = Input1.getValueType();
+
+ SDValue Input = Input1;
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
+
+ return getPartialReduceAdd(DL, Acc.getValueType(), Acc, Input);
+}
+
SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2) {
EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 428e7a316d247b..144439f136ff16 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -135,6 +135,10 @@ static cl::opt<unsigned> SwitchPeelThreshold(
"switch statement. A value greater than 100 will void this "
"optimization"));
+static cl::opt<bool> NewPartialReduceLowering(
+ "new-partial-reduce-lowering", cl::init(false), cl::ReallyHidden,
+ cl::desc("Use the new method of lowering partial reductions."));
+
// Limit the width of DAG chains. This is important in general to prevent
// DAG-based analysis from blowing up. For example, alias analysis and
// load clustering may not complete in reasonable time. It is difficult to
@@ -8118,6 +8122,29 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
+ if (NewPartialReduceLowering) {
+ SDValue Acc = getValue(I.getOperand(0));
+ EVT AccVT = Acc.getValueType();
+ SDValue Input = getValue(I.getOperand(1));
+ EVT InputVT = Input.getValueType();
+
+ assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
+ "Expected operands to have the same vector element type!");
+ assert(
+ InputVT.getVectorElementCount().getKnownMinValue() %
+ AccVT.getVectorElementCount().getKnownMinValue() ==
+ 0 &&
+ "Expected the element count of the Input operand to be a positive "
+ "integer multiple of the element count of the Accumulator operand!");
+
+ // ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
+ // same if ISD::PARTIAL_REDUCE_SMLA was chosen instead. It should be
+ // changed to its correct signedness when combining or expanding,
+ // according to extends being performed on Input.
+ setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc, Input,
+ DAG.getConstant(1, sdl, InputVT)));
+ return;
+ }
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index f63c8dd3df1c83..a387c10679261b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -570,6 +570,11 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::VECTOR_FIND_LAST_ACTIVE:
return "find_last_active";
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return "partial_reduce_umla";
+ case ISD::PARTIAL_REDUCE_SMLA:
+ return "partial_reduce_smla";
+
// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
case ISD::SDID: \
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 66f83c658ff4f2..16c0001dbdb838 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,12 +1,41 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: udot:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: udot z0.s, z1.b, z2.b
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: udot:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: udot:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -16,10 +45,38 @@ entry:
}
define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: udot_wide:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: udot z0.d, z1.h, z2.h
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: udot_wide:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: udot_wide:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_wide:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -29,10 +86,38 @@ entry:
}
define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: sdot:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sdot z0.s, z1.b, z2.b
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: sdot:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sdot:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -42,10 +127,38 @@ entry:
}
define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: sdot_wide:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sdot z0.d, z1.h, z2.h
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: sdot_wide:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sdot_wide:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_wide:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -82,6 +195,29 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
; CHECK-NOI8MM-NEXT: mla z1.s, p0/m, z7.s, z24.s
; CHECK-NOI8MM-NEXT: add z0.s, z1.s, z0.s
; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot:
+; CHECK-NEWLOWERING: // %bb.0:...
[truncated]
Add signed and unsigned PARTIAL_REDUCE_MLA ISD nodes. Add command line argument (new-partial-reduce-lowering) that indicates whether the intrinsic experimental_vector_partial_ reduce_add will be transformed into the new ISD node. Lowering with the new ISD nodes will, for now, always be done as an expand.
Move the getPartialReduceAdd function around. Make the new codepath work for fixed length NEON vectors too.
a4fb454 to 4cd617a
Thanks, this is looking pretty good! Just a few more minor comments before I'm happy to accept it.
✅ With the latest revision this PR passed the C/C++ code formatter.
LGTM!
Make comment describing the node more broad. Promote both inputs at the same time. Move assert statements to the getNode() function. Make the command line option target-specific.
Improve assert statements. Remove unnecessary variable creation. Change how operation actions are set for the nodes.
Fix wrong signedness in promotion function. Make expand code depend on node signedness not on operand extend.
Was seeing if the operand opcodes were PARTIAL_REDUCE_MLA nodes. Now looking at their parent.
A few final recommendations, but otherwise this looks good to me. I'll admit to not being convinced by the implementation of SplitVecRes_PARTIAL_REDUCE_MLA, but I figure that'll be short-lived, so if you're sure it works then that's good enough for me.
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
                      ? ISD::SIGN_EXTEND
                      : ISD::ZERO_EXTEND;
EVT MulLHSVT = MulLHS.getValueType();
I don't think you need MulLHSVT because it is the same as FullTy? That said, FullTy is a terrible name, perhaps MulVT?
assert(MulLHSVT == MulRHS.getValueType() &&
       "The second and third operands of a PARTIAL_REDUCE_MLA node must have "
       "the same value type!");
You should assume the DAG is well formed and thus this assert is unnecessary. This is the reason behind adding the asserts to getNode(): it is much easier to catch failures during construction, and it means there's no need to pollute the DAG combines with such validation code.
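As a rough illustration of the suggestion (not code from the patch; the helper name is hypothetical and the checks simply mirror the asserts appearing elsewhere in this PR), construction-time validation might look something like:

```cpp
#include "llvm/CodeGen/SelectionDAG.h"
#include <cassert>

// Hypothetical helper showing the kind of checks that could run when a
// PARTIAL_REDUCE_*MLA node is built, so malformed nodes are caught at
// construction time rather than inside later DAG combines.
static void verifyPartialReduceMLA(llvm::EVT ResVT, llvm::SDValue Acc,
                                   llvm::SDValue Input1, llvm::SDValue Input2) {
  using namespace llvm;
  EVT AccVT = Acc.getValueType();
  EVT InputVT = Input1.getValueType();
  assert(ResVT == AccVT &&
         "Accumulator operand and result must have the same type");
  assert(InputVT == Input2.getValueType() &&
         "The two multiplicand operands must have the same type");
  assert(InputVT.getVectorElementType() == AccVT.getVectorElementType() &&
         "All operands must have the same element type");
  assert(InputVT.getVectorElementCount().getKnownMinValue() %
                 AccVT.getVectorElementCount().getKnownMinValue() ==
             0 &&
         "Input element count must be a multiple of the accumulator's");
}
```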
EVT ExtVT = MulLHSVT.changeVectorElementType(
    Acc.getValueType().getVectorElementType());
I don't think you need ExtVT because it is the same as NewVT? Although perhaps rename that to NewMulVT or ExtMulVT?
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
Does just return TLI.expandPartialReduceMLA(N, DAG) work here?
// The partial reduction nodes sign or zero extend Input1 and Input2 to the
// element type of Accumulator before multiplying their results.
// This result is concatenated to the Accumulator, and this is then reduced,
// using addition, to the result type.
This specification is missing a really important detail: specifically, which elements are reduced via addition? There are at least two reasonable choices here: adjacent addition, and subvector addition (i.e. every nth input). Which is it? Looking at the LangRef entry for the associated intrinsic, this detail is missing there as well. The specification needs updating.
Looking at the implementation, it is subvector-wise. This needs to be explicitly stated.
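To make the subvector-wise reading concrete, here is a small standalone example with made-up values (not taken from the patch): with a 4-element accumulator and 8-element inputs, result lane j accumulates the products from input lanes j and j+4.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

int main() {
  std::array<int32_t, 4> Acc = {1, 2, 3, 4};
  std::array<int32_t, 8> In1 = {1, 1, 1, 1, 2, 2, 2, 2};
  std::array<int32_t, 8> In2 = {10, 10, 10, 10, 100, 100, 100, 100};

  std::array<int32_t, 4> Out = Acc;
  for (std::size_t i = 0; i < In1.size(); ++i)
    Out[i % Out.size()] += In1[i] * In2[i]; // lane j gets products j and j+4

  // Out == {1+10+200, 2+10+200, 3+10+200, 4+10+200} == {211, 212, 213, 214}
  return 0;
}
```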
The intrinsic and matching ISD node are deliberate in not defining the order in which the elements are reduced because their output is only expected to feed into themselves or an equivalent vector.reduce operation and thus the ordering does not matter. Essentially, add reductions can now be represented in LLVM as a two phase operation (one in-loop and the other out-of-loop) in the most relaxed way possible.
Then you need to say that clearly in the specification text. At the moment, it looks like merely an unstated assumption, not a specifically reserved behavior. The bit about the result only being valid for use by further reductions (i.e. non-order-preserving operations) is really important to state clearly.
Seriously, I read your specification and did not understand this. I doubt I'm the only one.
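A sketch of the two-phase pattern being described, with a made-up function and an arbitrary lane count: an in-loop partial reduction feeds an order-insensitive final reduction, so any rule for distributing products across accumulator lanes yields the same scalar result.

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Assumes A.size() == B.size().
int64_t dotProduct(const std::vector<int32_t> &A,
                   const std::vector<int32_t> &B) {
  constexpr std::size_t AccLanes = 4; // arbitrary lane count for illustration
  std::vector<int64_t> Partial(AccLanes, 0);

  // Phase 1 (in-loop): partial reduction. Which lane each product lands in is
  // irrelevant to the final answer, which is why the spec can leave it open.
  for (std::size_t i = 0; i < A.size(); ++i)
    Partial[i % AccLanes] += static_cast<int64_t>(A[i]) * B[i];

  // Phase 2 (out-of-loop): the equivalent of a full vector.reduce.add.
  return std::accumulate(Partial.begin(), Partial.end(), int64_t(0));
}
```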
@@ -3182,6 +3186,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
  std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
}

+void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
+  ReplaceValueWith(SDValue(N, 0), Res);
Based on the precedent of SplitVecRes_VECTOR_SPLICE - which is the only example of expanding during SplitVecRes I found - I think you need to manually split the result of the expand into Lo and Hi here.
Oh nice, thanks for finding this.