Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SelectionDAG] Add PARTIAL_REDUCE_U/SMLA ISD Nodes #125207

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

JamesChesterman
Copy link
Contributor

Add signed and unsigned PARTIAL_REDUCE_MLA ISD nodes. Add command line argument (new-partial-reduce-lowering) that indicates whether the intrinsic experimental_vector_partial_ reduce_add will be transformed into the new ISD node. Lowering with the new ISD nodes will, for now, always be done as an expand.

@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2025

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-selectiondag

Author: James Chesterman (JamesChesterman)

Changes

Add signed and unsigned PARTIAL_REDUCE_MLA ISD nodes. Add command line argument (new-partial-reduce-lowering) that indicates whether the intrinsic experimental_vector_partial_ reduce_add will be transformed into the new ISD node. Lowering with the new ISD nodes will, for now, always be done as an expand.


Patch is 55.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125207.diff

10 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+14)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+7)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+21)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+17)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+17)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+27)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+5)
  • (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+635-74)
  • (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll (+49)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index fd8784a4c10034..3f235ee358e0ed 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,20 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
+  // PARTIAL_REDUCE_*MLA (Accumulator, Input1, Input2)
+  // Partial reduction nodes. Input1 and Input2 are multiplied together before
+  // being reduced, by addition to the number of elements that Accumulator's
+  // type has.
+  // Input1 and Input2 must be the same type. Accumulator and the output must be
+  // the same type.
+  // The number of elements in Input1 and Input2 must be a positive integer
+  // multiple of the number of elements in the Accumulator / output type.
+  // All operands, as well as the output, must have the same element type.
+  // Operands: Accumulator, Input1, Input2
+  // Outputs: Output
+  PARTIAL_REDUCE_SMLA,
+  PARTIAL_REDUCE_UMLA,
+
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
   // Outputs: output chain, glue
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 461c0c1ead16d2..0fc6f6ccf85bd9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1607,6 +1607,13 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  // Expands PARTIAL_REDUCE_S/UMLA nodes
+  // \p Acc Accumulator for where the result is stored for the partial reduction
+  // operation.
+  // \p Input1 First input for the partial reduction operation
+  // \p Input2 Second input for the partial reduction operation
+  SDValue expandPartialReduceMLA(SDNode *N);
+
   /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
   /// its operands and ReducedTY is the intrinsic's return type.
   SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 625052be657ca0..3a9518ea569ebc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -159,6 +159,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
     Res = PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(N);
     break;
 
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
+    break;
+
   case ISD::SIGN_EXTEND:
   case ISD::VP_SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
@@ -2076,6 +2081,10 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::VECTOR_FIND_LAST_ACTIVE:
     Res = PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(N, OpNo);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
+    break;
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -2824,6 +2833,12 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = DAG.expandPartialReduceMLA(N);
+  ReplaceValueWith(SDValue(N, 0), Res);
+  return SDValue();
+}
+
 //===----------------------------------------------------------------------===//
 //  Integer Result Expansion
 //===----------------------------------------------------------------------===//
@@ -6139,6 +6154,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
   return DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, SDLoc(N), NVT, N->ops());
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = DAG.expandPartialReduceMLA(N);
+  ReplaceValueWith(SDValue(N, 0), Res);
+  return SDValue();
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
   EVT OutVT = N->getValueType(0);
   EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index f13f70e66cfaa6..cb9c1b239c0fa9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -379,6 +379,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
   SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
   SDValue PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N);
+  SDValue PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N);
 
   // Integer Operand Promotion.
   bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
@@ -430,6 +431,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
+  SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);
 
   void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
   void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -968,6 +970,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
 
   // Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
   bool SplitVectorOperand(SDNode *N, unsigned OpNo);
@@ -999,6 +1002,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
   SDValue SplitVecOp_VP_CttzElements(SDNode *N);
   SDValue SplitVecOp_VECTOR_HISTOGRAM(SDNode *N);
+  SDValue SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
 
   //===--------------------------------------------------------------------===//
   // Vector Widening Support: LegalizeVectorTypes.cpp
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1000235ab4061f..b01470028981e7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1373,6 +1373,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::EXPERIMENTAL_VP_REVERSE:
     SplitVecRes_VP_REVERSE(N, Lo, Hi);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    SplitVecRes_PARTIAL_REDUCE_MLA(N);
   }
 
   // If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -3182,6 +3185,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
   std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
 }
 
+void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = DAG.expandPartialReduceMLA(N);
+  ReplaceValueWith(SDValue(N, 0), Res);
+}
+
 void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
 
   SDValue Op0Lo, Op0Hi, Op1Lo, Op1Hi;
@@ -3381,6 +3389,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     Res = SplitVecOp_VECTOR_HISTOGRAM(N);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -4435,6 +4446,12 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
                                 MMO, IndexType);
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = DAG.expandPartialReduceMLA(N);
+  ReplaceValueWith(SDValue(N, 0), Res);
+  return SDValue();
+}
+
 //===----------------------------------------------------------------------===//
 //  Result Vector Widening
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b416c0efbbc4fc..7240e4e00dfa07 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2473,6 +2473,23 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::expandPartialReduceMLA(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Input1 = N->getOperand(1);
+  SDValue Input2 = N->getOperand(2);
+
+  EVT FullTy = Input1.getValueType();
+
+  SDValue Input = Input1;
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
+
+  return getPartialReduceAdd(DL, Acc.getValueType(), Acc, Input);
+}
+
 SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
                                           SDValue Op2) {
   EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 428e7a316d247b..144439f136ff16 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -135,6 +135,10 @@ static cl::opt<unsigned> SwitchPeelThreshold(
              "switch statement. A value greater than 100 will void this "
              "optimization"));
 
+static cl::opt<bool> NewPartialReduceLowering(
+    "new-partial-reduce-lowering", cl::init(false), cl::ReallyHidden,
+    cl::desc("Use the new method of lowering partial reductions."));
+
 // Limit the width of DAG chains. This is important in general to prevent
 // DAG-based analysis from blowing up. For example, alias analysis and
 // load clustering may not complete in reasonable time. It is difficult to
@@ -8118,6 +8122,29 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+    if (NewPartialReduceLowering) {
+      SDValue Acc = getValue(I.getOperand(0));
+      EVT AccVT = Acc.getValueType();
+      SDValue Input = getValue(I.getOperand(1));
+      EVT InputVT = Input.getValueType();
+
+      assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
+             "Expected operands to have the same vector element type!");
+      assert(
+          InputVT.getVectorElementCount().getKnownMinValue() %
+                  AccVT.getVectorElementCount().getKnownMinValue() ==
+              0 &&
+          "Expected the element count of the Input operand to be a positive "
+          "integer multiple of the element count of the Accumulator operand!");
+
+      // ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
+      // same if ISD::PARTIAL_REDUCE_SMLA was chosen instead. It should be
+      // changed to its correct signedness when combining or expanding,
+      // according to extends being performed on Input.
+      setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc, Input,
+                               DAG.getConstant(1, sdl, InputVT)));
+      return;
+    }
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
       visitTargetIntrinsic(I, Intrinsic);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index f63c8dd3df1c83..a387c10679261b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -570,6 +570,11 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::VECTOR_FIND_LAST_ACTIVE:
     return "find_last_active";
 
+  case ISD::PARTIAL_REDUCE_UMLA:
+    return "partial_reduce_umla";
+  case ISD::PARTIAL_REDUCE_SMLA:
+    return "partial_reduce_smla";
+
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
   case ISD::SDID:                                                              \
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 66f83c658ff4f2..16c0001dbdb838 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,12 +1,41 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: udot:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    udot z0.s, z1.b, z2.b
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: udot:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: udot:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -16,10 +45,38 @@ entry:
 }
 
 define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: udot_wide:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    udot z0.d, z1.h, z2.h
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: udot_wide:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: udot_wide:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_wide:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -29,10 +86,38 @@ entry:
 }
 
 define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: sdot:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: sdot:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sdot:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -42,10 +127,38 @@ entry:
 }
 
 define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: sdot_wide:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: sdot_wide:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sdot_wide:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_wide:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -82,6 +195,29 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ; CHECK-NOI8MM-NEXT:    mla z1.s, p0/m, z7.s, z24.s
 ; CHECK-NOI8MM-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot:
+; CHECK-NEWLOWERING:       // %bb.0:...
[truncated]

llvm/include/llvm/CodeGen/ISDOpcodes.h Outdated Show resolved Hide resolved
llvm/include/llvm/CodeGen/ISDOpcodes.h Outdated Show resolved Hide resolved
llvm/include/llvm/CodeGen/SelectionDAG.h Outdated Show resolved Hide resolved
llvm/include/llvm/CodeGen/SelectionDAG.h Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Outdated Show resolved Hide resolved
Add signed and unsigned PARTIAL_REDUCE_MLA ISD nodes.
Add command line argument (new-partial-reduce-lowering) that
indicates whether the intrinsic experimental_vector_partial_
reduce_add will be transformed into the new ISD node.
Lowering with the new ISD nodes will, for now, always be done as
an expand.
Move the getPartialReduceAdd function around.
Make the new codepath work for fixed length NEON vectors too.
llvm/include/llvm/CodeGen/TargetLowering.h Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp Outdated Show resolved Hide resolved
Copy link
Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this is looking pretty good! Just a few more minor comments before I'm happy to accept it.

llvm/include/llvm/CodeGen/ISDOpcodes.h Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp Outdated Show resolved Hide resolved
llvm/include/llvm/CodeGen/TargetLowering.h Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp Outdated Show resolved Hide resolved
Copy link

github-actions bot commented Feb 5, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp Outdated Show resolved Hide resolved
llvm/include/llvm/CodeGen/ISDOpcodes.h Outdated Show resolved Hide resolved
llvm/include/llvm/CodeGen/ISDOpcodes.h Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Outdated Show resolved Hide resolved
Make comment describing the node more broad.
Promote both inputs at the same time.
Move assert statements to the getNode() function.
Make the command line option target-specific.
llvm/include/llvm/CodeGen/ISDOpcodes.h Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp Outdated Show resolved Hide resolved
Improve assert statements.
Remove unnecessary variable creation.
Change how operation actions are set for the nodes.
Fix wrong signedness in promotion function.
Make expand code depend on node signedness not on operand extend.
Was seeing if the operand opcodes were PARTIAL_REDUCE_MLA nodes.
Now looking at their parent.
Copy link
Collaborator

@paulwalker-arm paulwalker-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few final recommendations but otherwise this looks good to me. I'll admit to not be convinced about the implementation of SplitVecRes_PARTIAL_REDUCE_MLA but figure that'll be short lived and so if you're sure it works then that's good enough for me.

unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND;
EVT MulLHSVT = MulLHS.getValueType();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need MulLHSVT because it is the same as FullTy? That said, FullTy is a terrible name, perhaps MulVT?

Comment on lines +11910 to +11912
assert(MulLHSVT == MulRHS.getValueType() &&
"The second and third operands of a PARTIAL_REDUCE_MLA node must have "
"the same value type!");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should assume the DAG is well formed and thus this assert is unnecessary. This is the reason behind adding the asserts to getNode() because it is much easier to catch failures during construction and means there's no need to pollute the DAG combines with such validation code.

Comment on lines +11913 to +11914
EVT ExtVT = MulLHSVT.changeVectorElementType(
Acc.getValueType().getVectorElementType());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need ExtVT because it is the same as NewVT? Although perhaps rename that to NewMulVT or ExtMulVT?

Comment on lines +4452 to +4454
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does just return TLI.expandPartialReduceMLA(N, DAG) work here?

// The partial reduction nodes sign or zero extend Input1 and Input2 to the
// element type of Accumulator before multiplying their results.
// This result is concatenated to the Accumulator, and this is then reduced,
// using addition, to the result type.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This specification is missing a really important detail. Specifically, which elements are reduced via addition? There's at least two reasonable choices here: adjacent, and subvector addition (i.e. every nth input). Which is it? Looking at the LangRef the associated intrinsic, this detail is missing there as well. The specification needs updating.

Looking at the implementation it is subvector-wise. This needs explicitly stated.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intrinsic and matching ISD node are deliberate in not defining the order in which the elements are reduced because their output is only expected to feed into themselves or an equivalent vector.reduce operation and thus the ordering does not matter. Essentially, add reductions can now be represented in LLVM as a two phase operation (one in-loop and the other out-of-loop) in the most relaxed way possible.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then you need to say that clearly in the specification text. At the moment, it looks like merely an unstated assumption, not a specifically reserved behavior. The bit about the result only being valid for use by further reductions (i.e. non-order preserving operations) is really important to clearly say.

Seriously, I read your specification and didn't not understand this. I doubt I'm the only one.

@@ -3182,6 +3186,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
}

void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the precedent of SplitVecRes_VECTOR_SPLICE - which is the only example of expanding during SplitVecRes I found - I think you need to manually split the result of the expand into Lo and Hi here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice, thanks for finding this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants