OpenXLA-specific changes
Aliia Khasanova authored and khasanovaa committed Feb 6, 2025
1 parent 99cff45 commit aedc018
Showing 58 changed files with 3,633 additions and 982 deletions.
891 changes: 891 additions & 0 deletions BUILD

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions include/triton/Conversion/MLIRTypes.h
@@ -28,15 +28,16 @@ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }

inline bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
- type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
- type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
- type.isFloat8E5M2FNUZ();
+ type.isBF16() ||
+ llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
+ mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
+ mlir::Float8E5M2FNUZType>(type);
}

inline bool isFloat8(Type type) {
- return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
- type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
- type.isFloat8E5M2FNUZ();
+ return llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
+ mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
+ mlir::Float8E5M2FNUZType>(type);
}

inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
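
Note on the pattern used in this and several later hunks: the removed Type::isFloat8E...() helpers are replaced with a single variadic llvm::isa<> check, which is true when the type is an instance of any of the listed classes. A minimal compilable sketch of the equivalence, assuming an MLIR build is available (the isAnyFloat8 helper and main() below are illustrative, not part of the commit):

#include <cassert>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/Casting.h"

// Illustrative stand-in for the new isFloat8(): variadic llvm::isa<> is
// shorthand for a chain of single-type isa<> checks joined with ||.
static bool isAnyFloat8(mlir::Type type) {
  return llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
                   mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
                   mlir::Float8E5M2FNUZType>(type);
}

int main() {
  mlir::MLIRContext ctx;
  assert(isAnyFloat8(mlir::Float8E5M2Type::get(&ctx)));
  assert(!isAnyFloat8(mlir::Float32Type::get(&ctx)));
  return 0;
}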
7 changes: 6 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -1105,7 +1105,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
MutableOperandRange getArgOperandsMutable() {
return getOperandsMutable();
}

+ Attribute removeArgAttrsAttr() { return nullptr; }
+ Attribute removeResAttrsAttr() { return nullptr; }
+ ArrayAttr getArgAttrsAttr() { return nullptr; }
+ ArrayAttr getResAttrsAttr() { return nullptr; }
+ void setArgAttrsAttr(ArrayAttr) { return; }
+ void setResAttrsAttr(ArrayAttr) { return; }
}];

let assemblyFormat = [{
6 changes: 6 additions & 0 deletions lib/Analysis/Allocation.cpp
@@ -123,6 +123,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

std::tie(scratchConfig.inVec, scratchConfig.outVec) =
getScratchCvtInOutVecLengths(srcTy, dstTy);
+ // We can't write a longer vector than the shape of shared memory.
+ // This shape might be smaller than the tensor shape in case we decided to
+ // do the conversion in multiple iterations.
+ unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
+ scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
+ scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);

// No padding is required if the tensor is 1-D, or if all dimensions except
// the first accessed dimension have a size of 1.
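
A toy illustration of the clamp added above, with made-up numbers (none of these values come from the commit): if a conversion is split into iterations so that the shared-memory tile repShape is {32, 8} with order = {1, 0} (dimension 1 fastest-varying), the contiguous extent is 8 elements, so vector widths of 16 get clamped to 8.

#include <algorithm>
#include <array>
#include <cassert>

int main() {
  // Hypothetical scratch configuration, for illustration only.
  std::array<unsigned, 2> repShape = {32, 8}; // shared-memory tile shape
  std::array<unsigned, 2> order = {1, 0};     // order[0] = fastest-varying dim
  unsigned inVec = 16, outVec = 16;           // vector widths chosen earlier

  // Same clamp as in getScratchConfigForCvt: never vectorize wider than the
  // contiguous extent of the shared-memory tile.
  unsigned contiguousShapeDim = repShape[order[0]];
  inVec = std::min(inVec, contiguousShapeDim);
  outVec = std::min(outVec, contiguousShapeDim);
  assert(inVec == 8 && outVec == 8);
  return 0;
}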
2 changes: 1 addition & 1 deletion lib/Analysis/AxisInfo.cpp
@@ -935,7 +935,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
lhsDivisibility = 1;
}
- return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+ return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
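
Why the cast matters: the literal 1 is a 32-bit int, so for shift values of 31 or more, 1 << shift is undefined behavior (and in practice can wrap to a nonsensical value) before the division into the 64-bit lhsDivisibility; promoting the literal to int64_t keeps the shift in 64-bit arithmetic. A standalone sketch (values are made up for illustration):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  int64_t lhsDivisibility = int64_t(1) << 40; // large power-of-two divisibility
  int64_t shift = 35;

  // Buggy form (commented out): the literal 1 is a 32-bit int, so
  // (1 << shift) with shift >= 31 is undefined behavior.
  // int64_t bad = std::max<int64_t>(1, lhsDivisibility / (1 << shift));

  // Fixed form: the shift happens in 64-bit arithmetic.
  int64_t good = std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
  std::cout << good << "\n"; // prints 32 (2^40 / 2^35)
  return 0;
}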
9 changes: 5 additions & 4 deletions lib/Analysis/Utility.cpp
@@ -750,14 +750,14 @@ bool supportMMA(triton::DotOp op, int version) {
return false;
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0 &&
- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+ (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy) ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
return false;
}
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
if (op.getMaxNumImpreciseAcc() < 32 &&
- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+ (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy)) &&
cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
return false;
}
@@ -778,8 +778,9 @@ bool supportMMA(Value value, int version) {
cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
// FP8 is not natively supported on all mma versions but it can always be
// promoted to fp16 therefore we can always support it.
- bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
- elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+ bool isFP8 =
+ llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
+ mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(elemTy);
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);
12 changes: 12 additions & 0 deletions lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining arguments that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining uses of values that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return {};
2 changes: 1 addition & 1 deletion lib/Dialect/Triton/IR/Ops.cpp
@@ -899,7 +899,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
13 changes: 9 additions & 4 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -151,6 +151,11 @@ struct CanonicalizeConvertFromAlloc
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
+ // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+ // to SharedEncoding, so we want to keep this layout conversion.
+ if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+ convert.getSrc().getType().getEncoding()))
+ return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
@@ -213,13 +218,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
- if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+ if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();

// for hopper MMAv3
- if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+ if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {
50 changes: 41 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -21,8 +21,6 @@ namespace mlir {
namespace triton {
namespace gpu {

- namespace {

// Get the highest version supported for the hardware and the dot.
static int getMMAVersionSafe(int computeCapability, DotOp op) {
// List supported mma version in order of preference.
@@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
return 0;
}

- SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
- int numWarps) {
+ SmallVector<unsigned>
+ warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto rank = shape.size();
// Early exit for batched matmul
if (rank == 3)
@@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
}

SmallVector<unsigned, 2>
- warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+ warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
const SmallVector<unsigned, 3> &instrShape) {
SetVector<Operation *> slices;
- mlir::getForwardSlice(dotOp.getResult(), &slices);
+ mlir::getForwardSlice(dotOp->getResult(0), &slices);
// Contains a chained dot. We prefer to assign warps to one axis
// to facilitate use cases like flash attention, allowing reductions within
// the same warp.
@@ -170,11 +168,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);

+ // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+ // to SharedEncoding.
+ if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+ argType.getEncoding())) {
+ // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+ // then pass it to the LocalAllocOp.
+ auto newArgType = RankedTensorType::get(
+ argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+ auto dotOperandToBlockedCvt =
+ rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+ return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+ dotOperandToBlockedCvt);
+ }

return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

SmallVector<unsigned, 3>
- getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+ getWarpsPerTile(Operation* dotOp, const ArrayRef<int64_t> shape, int version,
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
switch (version) {
case 2:
@@ -188,6 +201,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
}

static bool bwdFilter(Operation *op) {
+ // Dot operand layout assignment to predicates is not currently supported
+ // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+ // condition limits visibility of the original bit-width so that predicates
+ // are not considered, and hence kwidth can never be 32.
+ if (isa<arith::UIToFPOp>(op)) {
+ Type srcType = getElementTypeOrSelf(op->getOperand(0));
+ if (srcType.isInteger(1))
+ return false;
+ }

return op->getNumOperands() == 1 &&
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
@@ -207,7 +230,7 @@ static bool bwdFilter(Operation *op) {
// result, kwidth can be the bitwidth of the lower precision primitive.
// Conversely, in the downcasting scenario, no reordering is performed,
// making it directly use the lower precision primitive.
- static int computeOrigBitWidth(Value x) {
+ int computeOrigBitWidth(Value x) {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
@@ -227,6 +250,9 @@ static int computeOrigBitWidth(Value x) {
}
return origBitWidth;
}
+ // Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
+ // extension.
+ namespace {

class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
int computeCapability;
@@ -632,7 +658,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
NvidiaMmaEncodingAttr mmaLayout =
dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
if (mmaLayout) {
- bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+ bool isNativeFP8 =
+ llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(AElType);
// promote operands for sm < 89 since fp8 mma is not natively supported
// promote operands for sm >= 90 when mma is not v3
if (!isNativeFP8 ||
@@ -1018,6 +1045,11 @@ class TritonGPUAccelerateMatmulPass
}
};

+ Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+ int opIdx, bool allowTranspose) {
+ return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+ }

} // namespace gpu
} // namespace triton
} // namespace mlir
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -285,6 +285,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
if (!foundInitializer)
return failure();

+ rewriter.setInsertionPointAfter(src);
SmallVector<ConvertLayoutOp> newOperands;
for (auto operand : src->getOperands()) {
// We checked earlier that all operands are ranked tensors.
@@ -132,6 +132,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,

Value zero = builder.createWithStage<arith::ConstantIntOp>(
forOp.getLoc(), stage, clusterId, 0, 32);

// Replace the load with insert/extract slice.
builder.setInsertionPoint(loadOp);
Location loc = loadOp.getLoc();
@@ -527,7 +528,8 @@ assignMemoryLayouts(scf::ForOp &forOp,

bool isTMALoad = isa<tt::ExperimentalDescriptorLoadOp,
tt::ExperimentalDescriptorGatherOp>(op);
- loadsToPipeline.insert(&op);
+ // TODO: b/381421713 - Uncomment this once pipelining is fixed.
+ // loadsToPipeline.insert(&op);
LoadInfo loadInfo;
for (auto use : users) {
if (use->hasTrait<OpTrait::DotLike>()) {
@@ -566,6 +568,11 @@
getBlockedEncoding(loadOp, axisInfoAnalysis);
}
}

+ // TODO: b/381421713 - Remove this once pipelining is fixed.
+ if (!loadInfo.sharedEncoding) continue;
+ loadsToPipeline.insert(&op);

loadToInfo[&op] = loadInfo;
}
// Make sure all loads in loadsToPipeline are in loadToInfo.
@@ -255,7 +255,7 @@ mlir::triton::maybeGetStageCluster(Operation *op) {
}
std::pair<int, int> mlir::triton::getStageCluster(Operation *op) {
auto res = maybeGetStageCluster(op);
- assert(res.has_value() || "Operation is missing stage & cluster attribute");
+ assert(res.has_value() && "Operation is missing stage & cluster attribute");
return *res;
}
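
The || to && change above fixes a classic assert idiom bug: a string literal converts to a non-null pointer, so assert(cond || "msg") is always true and can never fire, while assert(cond && "msg") fires when cond is false and still surfaces the message in the failure output. A minimal standalone illustration (not part of the commit):

#include <cassert>
#include <optional>

int main() {
  std::optional<int> res; // empty, like a missing stage & cluster attribute

  // Always passes, even though res is empty: the string literal is a non-null
  // pointer, so the whole condition is truthy and the check is defeated.
  assert(res.has_value() || "Operation is missing stage & cluster attribute");

  // Correct idiom: aborts with the message when res is empty.
  assert(res.has_value() && "Operation is missing stage & cluster attribute");
  return 0;
}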

26 changes: 24 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
// opIdx: 0 => a, 1 => b
auto type = cast<triton::gpu::MemDescType>(v.getType());
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
- SmallVector<int64_t> offset{0, 0};
+ SmallVector<int64_t> offset(shape.size(), 0);
Type elementType = type.getElementType();

// k => (prefetchWidth, k - prefetchWidth)
@@ -141,8 +141,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
type.getMutableMemory(), type.getAllocShape()),
v, offsetsVal);

+ // We need to assign kwidth to zero in the case where the parent layout is
+ // Blocked, otherwise the verifier emits a failure. The parent layout is
+ // Blocked only when Tensor Cores are disabled.
+ int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+ ? 0
+ : prefetchWidth / 8;
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
- builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+ builder.getContext(), opIdx, dotEncoding, kwidth);
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
@@ -191,6 +197,22 @@ LogicalResult Prefetcher::initialize() {
break;
if (!op->getResult(0).hasOneUse())
break;
+ // Similar to issues faced in HoistLayoutConversion pattern in
+ // OptimizeDotOperands.cpp, we can't propagate through type casts from
+ // predicates as they aren't supported in Triton when encoded with dot_op
+ // layout.
+ if (isa<arith::UIToFPOp>(op)) {
+ Type srcType = getElementTypeOrSelf(op->getOperand(0));
+ if (srcType.isInteger(1))
+ break;
+ }
+ // Propagation through ExpandDims is currently not supported. This blindly
+ // replaces the encoding with dot encoding, but ExpandDims requires a
+ // SliceEncoding. This could be rewritten to support it somehow, but I
+ // don't think it's trivial & it's currently crashing.
+ if (isa<ExpandDimsOp>(op)) {
+ break;
+ }
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
foundConvertFromShared = true;