Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SelectionDAG] Add PARTIAL_REDUCE_U/SMLA ISD Nodes #125207

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
14 changes: 14 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,20 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,

// PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
// The partial reduction nodes sign or zero extend Input1 and Input2 to the
// element type of Accumulator before multiplying their results.
// This result is concatenated to the Accumulator, and this is then reduced,
// using addition, to the result type.
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
// Input1 and Input2 must be the same type. Accumulator and the output must be
// the same type.
// The number of elements in Input1 and Input2 must be a positive integer
// multiple of the number of elements in the Accumulator / output type.
// Input1 and Input2 must have an element type which is the same as or smaller
// than the element type of the Accumulator and output.
PARTIAL_REDUCE_SMLA,
PARTIAL_REDUCE_UMLA,

// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
// Outputs: output chain, glue
Expand Down
5 changes: 0 additions & 5 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -1607,11 +1607,6 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);

/// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
/// its operands and ReducedTY is the intrinsic's return type.
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2);

/// Expands a node with multiple results to an FP or vector libcall. The
/// libcall is expected to take all the operands of the \p Node followed by
/// output pointers for each of the results. \p CallRetResNo can be optionally
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5539,6 +5539,10 @@ class TargetLowering : public TargetLoweringBase {
/// temporarily, advance store position, before re-loading the final vector.
SDValue expandVECTOR_COMPRESS(SDNode *Node, SelectionDAG &DAG) const;

/// Expands PARTIAL_REDUCE_S/UMLA nodes to a series of simpler operations,
/// consisting of zext/sext, extract_subvector, mul and add operations.
SDValue expandPartialReduceMLA(SDNode *Node, SelectionDAG &DAG) const;

/// Legalize a SETCC or VP_SETCC with given LHS and RHS and condition code CC
/// on the current target. A VP_SETCC will additionally be given a Mask
/// and/or EVL not equal to SDValue().
Expand Down
30 changes: 30 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
Res = PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(N);
break;

case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
break;

case ISD::SIGN_EXTEND:
case ISD::VP_SIGN_EXTEND:
case ISD::ZERO_EXTEND:
Expand Down Expand Up @@ -2099,6 +2104,10 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::VECTOR_FIND_LAST_ACTIVE:
Res = PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(N, OpNo);
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
break;
}

// If the result is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -2881,6 +2890,18 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}

SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
SmallVector<SDValue, 1> NewOps(N->ops());
if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
NewOps[2] = SExtPromotedInteger(N->getOperand(2));
} else {
NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
}
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}

//===----------------------------------------------------------------------===//
// Integer Result Expansion
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -6196,6 +6217,15 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
return DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, SDLoc(N), NVT, N->ops());
}

SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
SDValue ExtAcc = GetPromotedInteger(N->getOperand(0));
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
return DAG.getNode(N->getOpcode(), DL, NVT, ExtAcc, N->getOperand(1),
N->getOperand(2));
}

SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
EVT OutVT = N->getValueType(0);
EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
SDValue PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N);
SDValue PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N);

// Integer Operand Promotion.
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
Expand Down Expand Up @@ -430,6 +431,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);

void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
Expand Down Expand Up @@ -968,6 +970,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N);

// Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
bool SplitVectorOperand(SDNode *N, unsigned OpNo);
Expand Down Expand Up @@ -999,6 +1002,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
SDValue SplitVecOp_VP_CttzElements(SDNode *N);
SDValue SplitVecOp_VECTOR_HISTOGRAM(SDNode *N);
SDValue SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N);

//===--------------------------------------------------------------------===//
// Vector Widening Support: LegalizeVectorTypes.cpp
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECTOR_COMPRESS:
case ISD::SCMP:
case ISD::UCMP:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
break;
case ISD::SMULFIX:
Expand Down Expand Up @@ -1195,6 +1197,10 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VECREDUCE_FMINIMUM:
Results.push_back(TLI.expandVecReduce(Node, DAG));
return;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VECREDUCE_SEQ_FMUL:
Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,10 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::EXPERIMENTAL_VP_REVERSE:
SplitVecRes_VP_REVERSE(N, Lo, Hi);
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N);
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
break;
}

// If Lo/Hi is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -3182,6 +3186,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
}

void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
}

void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {

SDValue Op0Lo, Op0Hi, Op1Lo, Op1Hi;
Expand Down Expand Up @@ -3381,6 +3390,10 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
Res = SplitVecOp_VECTOR_HISTOGRAM(N);
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
break;
}

// If the result is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -4435,6 +4448,12 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
MMO, IndexType);
}

SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
}

//===----------------------------------------------------------------------===//
// Result Vector Widening
//===----------------------------------------------------------------------===//
Expand Down
51 changes: 22 additions & 29 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2474,35 +2474,6 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}

SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2) {
EVT FullTy = Op2.getValueType();

unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;

// Collect all of the subvectors
std::deque<SDValue> Subvectors = {Op1};
for (unsigned I = 0; I < ScaleFactor; I++) {
auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
Subvectors.push_back(
getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
}

// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(
getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
}

assert(Subvectors.size() == 1 &&
"There should only be one subvector after tree flattening");

return Subvectors[0];
}

/// Given a store node \p StoreNode, return true if it is safe to fold that node
/// into \p FPNode, which expands to a library call with output pointers.
static bool canFoldStoreIntoLibCallOutputPointers(StoreSDNode *StoreNode,
Expand Down Expand Up @@ -7883,6 +7854,28 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,

break;
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
[[maybe_unused]] EVT AccVT = N1.getValueType();
[[maybe_unused]] EVT Input1VT = N2.getValueType();
[[maybe_unused]] EVT Input2VT = N3.getValueType();
assert(Input1VT.isVector() && Input1VT == Input2VT &&
"Expected the second and third operands of the PARTIAL_REDUCE_MLA "
"node to have the same type!");
assert(VT.isVector() && VT == AccVT &&
"Expected the first operand of the PARTIAL_REDUCE_MLA node to have "
"the same type as its result!");
assert(Input1VT.getVectorElementCount().hasKnownScalarFactor(
AccVT.getVectorElementCount()) &&
"Expected the element count of the second and third operands of the "
"PARTIAL_REDUCE_MLA node to be a positive integer multiple of the "
"element count of the first operand and the result!");
assert(N2.getScalarValueSizeInBits() <= N1.getScalarValueSizeInBits() &&
"Expected the second and third operands of the PARTIAL_REDUCE_MLA "
"node to have an element type which is the same as or smaller than "
"the element type of the first operand and result!");
break;
}
}

// Memoize node if it doesn't produce a glue result.
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8118,15 +8118,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {

if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
return;
}

setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
getValue(I.getOperand(0)),
getValue(I.getOperand(1))));
SDValue Acc = getValue(I.getOperand(0));
SDValue Input = getValue(I.getOperand(1));
setValue(&I,
DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, Acc.getValueType(), Acc,
Input, DAG.getConstant(1, sdl, Input.getValueType())));
return;
}
case Intrinsic::experimental_cttz_elts: {
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,11 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::VECTOR_FIND_LAST_ACTIVE:
return "find_last_active";

case ISD::PARTIAL_REDUCE_UMLA:
return "partial_reduce_umla";
case ISD::PARTIAL_REDUCE_SMLA:
return "partial_reduce_smla";

// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
case ISD::SDID: \
Expand Down
57 changes: 57 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include <cctype>
#include <deque>
using namespace llvm;

/// NOTE: The TargetMachine owns TLOF.
Expand Down Expand Up @@ -11890,6 +11891,62 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
return DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
}

SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
SelectionDAG &DAG) const {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
SDValue MulLHS = N->getOperand(1);
SDValue MulRHS = N->getOperand(2);
EVT ReducedTy = Acc.getValueType();
EVT FullTy = MulLHS.getValueType();

EVT NewVT =
EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
FullTy.getVectorElementCount());
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND;
EVT MulLHSVT = MulLHS.getValueType();
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
assert(MulLHSVT == MulRHS.getValueType() &&
"The second and third operands of a PARTIAL_REDUCE_MLA node must have "
"the same value type!");
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
EVT ExtVT = MulLHSVT.changeVectorElementType(
Acc.getValueType().getVectorElementType());
JamesChesterman marked this conversation as resolved.
Show resolved Hide resolved
if (ExtVT != FullTy) {
MulLHS = DAG.getNode(ExtOpc, DL, ExtVT, MulLHS);
MulRHS = DAG.getNode(ExtOpc, DL, ExtVT, MulRHS);
}
SDValue Input = MulLHS;
APInt ConstantOne;
if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
!ConstantOne.isOne())
Input = DAG.getNode(ISD::MUL, DL, NewVT, MulLHS, MulRHS);

unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;

// Collect all of the subvectors
std::deque<SDValue> Subvectors = {Acc};
for (unsigned I = 0; I < ScaleFactor; I++) {
auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
{Input, SourceIndex}));
}

// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(
DAG.getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
}

assert(Subvectors.size() == 1 &&
"There should only be one subvector after tree flattening");

return Subvectors[0];
}

bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
SDValue &LHS, SDValue &RHS,
SDValue &CC, SDValue Mask,
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,10 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::GET_FPENV, VT, Expand);
setOperationAction(ISD::SET_FPENV, VT, Expand);
setOperationAction(ISD::RESET_FPENV, VT, Expand);

// PartialReduceMLA operations default to expand.
setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
Expand);
}

// Most targets ignore the @llvm.prefetch intrinsic.
Expand Down
Loading