diff --git a/lib/Dialect/RTG/Transforms/CMakeLists.txt b/lib/Dialect/RTG/Transforms/CMakeLists.txt index 50f47352aefa..1111b8aa84c6 100644 --- a/lib/Dialect/RTG/Transforms/CMakeLists.txt +++ b/lib/Dialect/RTG/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_circt_dialect_library(CIRCTRTGTransforms LINK_LIBS PRIVATE CIRCTRTGDialect MLIRIndexDialect + MLIRSCFDialect MLIRIR MLIRPass ) diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 9fc9374ab807..86c79fe7c703 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -19,6 +19,7 @@ #include "circt/Support/Namespace.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" @@ -422,39 +423,105 @@ namespace { /// Construct an SSA value from a given elaborated value. class Materializer { public: + Materializer(OpBuilder builder) : builder(builder) {} + + /// Materialize IR representing the provided `ElaboratorValue` and return the + /// `Value` or a null value on failure. 
Value materialize(ElaboratorValue *val, Location loc, std::queue &elabRequests, function_ref emitError) { - assert(block && "must call reset before calling this function"); - auto iter = materializedValues.find(val); if (iter != materializedValues.end()) return iter->second; LLVM_DEBUG(llvm::dbgs() << "Materializing " << *val << "\n\n"); - OpBuilder builder(block, insertionPoint); return TypeSwitch(val) .Case([&](auto val) { - return visit(val, builder, loc, elabRequests, emitError); - }) + SequenceValue>( + [&](auto val) { return visit(val, loc, elabRequests, emitError); }) .Default([](auto val) { assert(false && "all cases must be covered above"); return Value(); }); } - Materializer &reset(Block *block) { - materializedValues.clear(); - integerValues.clear(); - this->block = block; - insertionPoint = block->begin(); - return *this; + /// If `op` is not in the same region as the materializer insertion point, a + /// clone is created at the materializer's insertion point by also + /// materializing the `ElaboratorValue`s for each operand just before it. + /// Otherwise, all operations after the materializer's insertion point are + /// deleted until `op` is reached. An error is returned if the operation is + /// before the insertion point. + LogicalResult materialize(Operation *op, + DenseMap &state, + std::queue &elabRequests) { + if (op->getNumRegions() > 0) + return op->emitOpError("ops with nested regions must be elaborated away"); + + // We don't support opaque values. If there is an SSA value that has a + // use-site it needs an equivalent ElaboratorValue representation. + // NOTE: We could support cases where there is initially a use-site but that + // op is guaranteed to be deleted during elaboration. Or the use-sites are + // replaced with freshly materialized values from the ElaboratorValue. But + // then, why can't we delete the value defining op? 
+ for (auto res : op->getResults()) + if (!res.use_empty()) + return op->emitOpError( + "ops with results that have uses are not supported"); + + if (op->getParentRegion() == builder.getBlock()->getParent()) { + // We are doing in-place materialization, so mark all ops deleted until we + // reach the one to be materialized and modify it in-place. + auto ip = builder.getInsertionPoint(); + while (ip != builder.getBlock()->end() && &*ip != op) { + LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n"); + toDelete.push_back(&*ip); + + builder.setInsertionPointAfter(&*ip); + ip = builder.getInsertionPoint(); + } + + if (ip == builder.getBlock()->end()) + return op->emitError("operation did not occur after the current " + "materializer insertion point"); + + LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n"); + } else { + LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n"); + op = builder.clone(*op); + builder.setInsertionPoint(op); + } + + for (auto &operand : op->getOpOperands()) { + auto emitError = [&]() { + auto diag = op->emitError(); + diag.attachNote(op->getLoc()) + << "while materializing value for operand#" + << operand.getOperandNumber(); + return diag; + }; + + Value val = materialize(state.at(operand.get()), op->getLoc(), + elabRequests, emitError); + if (!val) + return failure(); + + operand.set(val); + } + + builder.setInsertionPointAfter(op); + return success(); + } + + /// Should be called once the `Region` is successfully materialized. No calls + /// to `materialize` should happen after this anymore. 
+ void finalize() { + for (auto *op : llvm::reverse(toDelete)) + op->erase(); } private: - Value visit(AttributeValue *val, OpBuilder &builder, Location loc, + Value visit(AttributeValue *val, Location loc, std::queue &elabRequests, function_ref emitError) { auto attr = val->getAttr(); @@ -485,7 +552,7 @@ class Materializer { return res; } - Value visit(IndexValue *val, OpBuilder &builder, Location loc, + Value visit(IndexValue *val, Location loc, std::queue &elabRequests, function_ref emitError) { Value res = builder.create(loc, val->getIndex()); @@ -493,7 +560,7 @@ class Materializer { return res; } - Value visit(BoolValue *val, OpBuilder &builder, Location loc, + Value visit(BoolValue *val, Location loc, std::queue &elabRequests, function_ref emitError) { Value res = builder.create(loc, val->getBool()); @@ -501,7 +568,7 @@ class Materializer { return res; } - Value visit(SetValue *val, OpBuilder &builder, Location loc, + Value visit(SetValue *val, Location loc, std::queue &elabRequests, function_ref emitError) { SmallVector elements; @@ -519,7 +586,7 @@ class Materializer { return res; } - Value visit(BagValue *val, OpBuilder &builder, Location loc, + Value visit(BagValue *val, Location loc, std::queue &elabRequests, function_ref emitError) { SmallVector values, weights; @@ -550,7 +617,7 @@ class Materializer { return res; } - Value visit(SequenceValue *val, OpBuilder &builder, Location loc, + Value visit(SequenceValue *val, Location loc, std::queue &elabRequests, function_ref emitError) { elabRequests.push(val); @@ -566,16 +633,42 @@ class Materializer { DenseMap materializedValues; DenseMap integerValues; - /// Cache the builders to continue insertions at their current insertion point + /// Cache the builder to continue insertions at their current insertion point /// for the reason stated above. 
- Block *block; - Block::iterator insertionPoint; + OpBuilder builder; + + SmallVector toDelete; }; /// Used to signal to the elaboration driver whether the operation should be /// removed. enum class DeletionKind { Keep, Delete }; +/// Elaborator state that should be shared by all elaborator instances. +struct ElaboratorSharedState { + ElaboratorSharedState(SymbolTable &table, unsigned seed) + : table(table), rng(seed) {} + + SymbolTable &table; + std::mt19937 rng; + Namespace names; + + // A map used to intern elaborator values. We do this such that we can + // compare pointers when, e.g., computing set differences, uniquing the + // elements in a set, etc. Otherwise, we'd need to do a deep value comparison + // in those situations. + // Use a pointer as the key with custom MapInfo because of object slicing when + // inserting an object of a derived class of ElaboratorValue. + // The custom MapInfo makes sure that we do a value comparison instead of + // comparing the pointers. + DenseMap, InternMapInfo> + interned; + + /// The worklist used to keep track of the test and sequence operations to + /// make sure they are processed top-down (BFS traversal). + std::queue worklist; +}; + /// Interprets the IR to perform and lower the represented randomizations. class Elaborator : public RTGOpVisitor> { public: @@ -583,7 +676,8 @@ class Elaborator : public RTGOpVisitor> { using RTGBase::visitOp; using RTGBase::visitRegisterOp; - Elaborator(SymbolTable &table, std::mt19937 &rng) : rng(rng), table(table) {} + Elaborator(ElaboratorSharedState &sharedState, Materializer &materializer) + : sharedState(sharedState), materializer(materializer) {} /// Helper to perform internalization and keep track of interpreted value for /// the given SSA value. 
@@ -592,7 +686,7 @@ class Elaborator : public RTGOpVisitor> { // TODO: this isn't the most efficient way to internalize auto ptr = std::make_unique(std::forward(args)...); auto *e = ptr.get(); - auto [iter, _] = interned.insert({e, std::move(ptr)}); + auto [iter, _] = sharedState.interned.insert({e, std::move(ptr)}); state[val] = iter->second.get(); } @@ -617,7 +711,7 @@ class Elaborator : public RTGOpVisitor> { args.push_back(state.at(arg)); auto familyName = op.getSequenceAttr(); - auto name = names.newName(familyName.getValue()); + auto name = sharedState.names.newName(familyName.getValue()); internalizeResult(op.getResult(), name, familyName, std::move(args)); return DeletionKind::Delete; @@ -646,7 +740,8 @@ class Elaborator : public RTGOpVisitor> { std::mt19937 customRng(intAttr.getInt()); selected = getUniformlyInRange(customRng, 0, set->getSet().size() - 1); } else { - selected = getUniformlyInRange(rng, 0, set->getSet().size() - 1); + selected = + getUniformlyInRange(sharedState.rng, 0, set->getSet().size() - 1); } state[op.getResult()] = set->getSet()[selected]; @@ -709,7 +804,7 @@ class Elaborator : public RTGOpVisitor> { prefixSum.push_back({val, accumulator}); } - auto customRng = rng; + auto customRng = sharedState.rng; if (auto intAttr = op->getAttrOfType("rtg.elaboration_custom_seed")) { customRng = std::mt19937(intAttr.getInt()); @@ -771,6 +866,72 @@ class Elaborator : public RTGOpVisitor> { return DeletionKind::Delete; } + FailureOr visitOp(scf::IfOp op) { + bool cond = cast(state.at(op.getCondition()))->getBool(); + auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion(); + if (toElaborate.empty()) + return DeletionKind::Delete; + + // Just reuse this elaborator for the nested region because we need access + // to the elaborated values outside the nested region (since it is not + // isolated from above) and we want to materialize the region inline, thus + // don't need a new materializer instance. 
+ if (failed(elaborate(toElaborate))) + return failure(); + + // Map the results of the 'scf.if' to the yielded values. + for (auto [res, out] : + llvm::zip(op.getResults(), + toElaborate.front().getTerminator()->getOperands())) + state[res] = state.at(out); + + return DeletionKind::Delete; + } + + FailureOr visitOp(scf::ForOp op) { + auto *lowerBound = dyn_cast(state.at(op.getLowerBound())); + auto *step = dyn_cast(state.at(op.getStep())); + auto *upperBound = dyn_cast(state.at(op.getUpperBound())); + + if (!lowerBound || !step || !upperBound) + return op->emitOpError("can only elaborate index type iterator"); + + // Prepare for first iteration by assigning the nested regions block + // arguments. We can just reuse this elaborator because we need access to + // values elaborated in the parent region anyway and materialize everything + // inline (i.e., don't need a new materializer). + state[op.getInductionVar()] = lowerBound; + for (auto [iterArg, initArg] : + llvm::zip(op.getRegionIterArgs(), op.getInitArgs())) + state[iterArg] = state.at(initArg); + + // This loop performs the actual 'scf.for' loop iterations. + for (size_t i = lowerBound->getIndex(); i < upperBound->getIndex(); + i += step->getIndex()) { + if (failed(elaborate(op.getBodyRegion()))) + return failure(); + + // Prepare for the next iteration by updating the mapping of the nested + // regions block arguments + internalizeResult(op.getInductionVar(), i + step->getIndex()); + for (auto [iterArg, prevIterArg] : + llvm::zip(op.getRegionIterArgs(), + op.getBody()->getTerminator()->getOperands())) + state[iterArg] = state.at(prevIterArg); + } + + // Transfer the previously yielded values to the for loop result values. 
+ for (auto [res, iterArg] : + llvm::zip(op->getResults(), op.getRegionIterArgs())) + state[res] = state.at(iterArg); + + return DeletionKind::Delete; + } + + FailureOr visitOp(scf::YieldOp op) { + return DeletionKind::Delete; + } + FailureOr visitOp(index::AddOp op) { size_t lhs = cast(state.at(op.getLhs()))->getIndex(); size_t rhs = cast(state.at(op.getRhs()))->getIndex(); @@ -832,120 +993,38 @@ class Elaborator : public RTGOpVisitor> { } return TypeSwitch>(op) - .Case([&](auto op) { return visitOp(op); }) + .Case< + // Index ops + index::AddOp, index::CmpOp, + // SCF ops + scf::IfOp, scf::ForOp, scf::YieldOp>( + [&](auto op) { return visitOp(op); }) .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); }); } - LogicalResult elaborate(SequenceOp family, SequenceOp dest, - ArrayRef args) { - LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating " << family.getOperationName() - << " @" << family.getSymName() << " into @" - << dest.getSymName() << "\n\n"); - - // Reduce max memory consumption and make sure the values cannot be accessed - // anymore because we deleted the ops above. Clearing should lead to better - // performance than having them as a local here and pass via function - // argument. 
- state.clear(); - materializer.reset(dest.getBody()); - IRMapping mapping; + // NOLINTNEXTLINE(misc-no-recursion) + LogicalResult elaborate(Region &region, + ArrayRef regionArguments = {}) { + if (region.getBlocks().size() > 1) + return region.getParentOp()->emitOpError( + "regions with more than one block are not supported"); for (auto [arg, elabArg] : - llvm::zip(family.getBody()->getArguments(), args)) + llvm::zip(region.getArguments(), regionArguments)) state[arg] = elabArg; - for (auto &op : *family.getBody()) { - if (op.getNumRegions() != 0) - return op.emitOpError("nested regions not supported"); - - auto result = dispatchOpVisitor(&op); - if (failed(result)) - return failure(); - - if (*result == DeletionKind::Keep) { - for (auto &operand : op.getOpOperands()) { - if (mapping.contains(operand.get())) - continue; - - auto emitError = [&]() { - auto diag = op.emitError(); - diag.attachNote(op.getLoc()) - << "while materializing value for operand#" - << operand.getOperandNumber(); - return diag; - }; - Value val = materializer.materialize( - state.at(operand.get()), op.getLoc(), worklist, emitError); - if (!val) - return failure(); - - mapping.map(operand.get(), val); - } - - OpBuilder builder = OpBuilder::atBlockEnd(dest.getBody()); - builder.clone(op, mapping); - } - - LLVM_DEBUG({ - llvm::dbgs() << "Elaborating " << op << " to\n["; - - llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) { - if (state.contains(res)) - llvm::dbgs() << *state.at(res); - else - llvm::dbgs() << "unknown"; - }); - - llvm::dbgs() << "]\n\n"; - }); - } - - return success(); - } - - template - LogicalResult elaborateInPlace(OpTy op) { - LLVM_DEBUG(llvm::dbgs() - << "\n=== Elaborating (in place) " << op.getOperationName() - << " @" << op.getSymName() << "\n\n"); - - // Reduce max memory consumption and make sure the values cannot be accessed - // anymore because we deleted the ops above. 
Clearing should lead to better - // performance than having them as a local here and pass via function - // argument. - state.clear(); - materializer.reset(op.getBody()); - - SmallVector toDelete; - for (auto &op : *op.getBody()) { - if (op.getNumRegions() != 0) - return op.emitOpError("nested regions not supported"); - + Block *block = &region.front(); + for (auto &op : *block) { auto result = dispatchOpVisitor(&op); if (failed(result)) return failure(); - if (*result == DeletionKind::Keep) { - for (auto &operand : op.getOpOperands()) { - auto emitError = [&]() { - auto diag = op.emitError(); - diag.attachNote(op.getLoc()) - << "while materializing value for operand#" - << operand.getOperandNumber(); - return diag; - }; - Value val = materializer.materialize( - state.at(operand.get()), op.getLoc(), worklist, emitError); - if (!val) - return failure(); - operand.set(val); - } - } else { // DeletionKind::Delete - toDelete.push_back(&op); - } + if (*result == DeletionKind::Keep) + if (failed(materializer.materialize(&op, state, sharedState.worklist))) + return failure(); LLVM_DEBUG({ - llvm::dbgs() << "Elaborating " << op << " to\n["; + llvm::dbgs() << "Elaborated " << op << " to\n["; llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) { if (state.contains(res)) @@ -958,118 +1037,19 @@ class Elaborator : public RTGOpVisitor> { }); } - for (auto *op : llvm::reverse(toDelete)) - op->erase(); - - return success(); - } - - LogicalResult inlineSequences(TestOp testOp) { - OpBuilder builder(testOp); - for (auto iter = testOp.getBody()->begin(); - iter != testOp.getBody()->end();) { - auto invokeOp = dyn_cast(&*iter); - if (!invokeOp) { - ++iter; - continue; - } - - auto seqClosureOp = - invokeOp.getSequence().getDefiningOp(); - if (!seqClosureOp) - return invokeOp->emitError( - "sequence operand not directly defined by sequence_closure op"); - - auto seqOp = table.lookup(seqClosureOp.getSequenceAttr()); - - builder.setInsertionPointAfter(invokeOp); - IRMapping 
mapping; - for (auto &op : *seqOp.getBody()) - builder.clone(op, mapping); - - (iter++)->erase(); - - if (seqClosureOp->use_empty()) - seqClosureOp->erase(); - } - - return success(); - } - - LogicalResult elaborateModule(ModuleOp moduleOp) { - // Update the name cache - names.clear(); - names.add(moduleOp); - - // Initialize the worklist with the test ops since they cannot be placed by - // other ops. - for (auto testOp : moduleOp.getOps()) - if (failed(elaborateInPlace(testOp))) - return failure(); - - // Do top-down BFS traversal such that elaborating a sequence further down - // does not fix the outcome for multiple placements. - while (!worklist.empty()) { - auto *curr = worklist.front(); - worklist.pop(); - - if (table.lookup(curr->getName())) - continue; - - auto familyOp = table.lookup(curr->getFamilyName()); - // TODO: use 'elaborateInPlace' and don't clone if this is the only - // remaining reference to this sequence - OpBuilder builder(familyOp); - auto seqOp = builder.cloneWithoutRegions(familyOp); - seqOp.getBodyRegion().emplaceBlock(); - seqOp.setSymName(curr->getName()); - table.insert(seqOp); - assert(seqOp.getSymName() == curr->getName() && - "should not have been renamed"); - - if (failed(elaborate(familyOp, seqOp, curr->getArgs()))) - return failure(); - } - - // Inline all sequences and remove the operations that place the sequences. - for (auto testOp : moduleOp.getOps()) - if (failed(inlineSequences(testOp))) - return failure(); - - // Remove all sequences since they are not accessible from the outside and - // are not needed anymore since we fully inlined them. - for (auto seqOp : llvm::make_early_inc_range(moduleOp.getOps())) - seqOp->erase(); - return success(); } private: - std::mt19937 rng; - SymbolTable &table; - Namespace names; + // State to be shared between all elaborator instances. 
+ ElaboratorSharedState &sharedState; - /// The worklist used to keep track of the test and sequence operations to - /// make sure they are processed top-down (BFS traversal). - std::queue worklist; - - // A map used to intern elaborator values. We do this such that we can - // compare pointers when, e.g., computing set differences, uniquing the - // elements in a set, etc. Otherwise, we'd need to do a deep value comparison - // in those situations. - // Use a pointer as the key with custom MapInfo because of object slicing when - // inserting an object of a derived class of ElaboratorValue. - // The custom MapInfo makes sure that we do a value comparison instead of - // comparing the pointers. - DenseMap, InternMapInfo> - interned; + // Allows us to materialize ElaboratorValues to the IR operations necessary to + // obtain an SSA value representing that elaborated value. + Materializer &materializer; // A map from SSA values to a pointer of an interned elaborator value. DenseMap state; - - // Allows us to materialize ElaboratorValues to the IR operations necessary to - // obtain an SSA value representing that elaborated value. 
- Materializer materializer; }; } // namespace @@ -1084,6 +1064,8 @@ struct ElaborationPass void runOnOperation() override; void cloneTargetsIntoTests(SymbolTable &table); + LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table); + LogicalResult inlineSequences(TestOp testOp, SymbolTable &table); }; } // namespace @@ -1093,9 +1075,7 @@ void ElaborationPass::runOnOperation() { cloneTargetsIntoTests(table); - std::mt19937 rng(seed); - Elaborator elaborator(table, rng); - if (failed(elaborator.elaborateModule(moduleOp))) + if (failed(elaborateModule(moduleOp, table))) return signalPassFailure(); } @@ -1143,3 +1123,101 @@ void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) { if (!test.getTarget().getEntries().empty()) test->erase(); } + +LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp, + SymbolTable &table) { + ElaboratorSharedState state(table, seed); + + // Update the name cache + state.names.add(moduleOp); + + // Initialize the worklist with the test ops since they cannot be placed by + // other ops. + for (auto testOp : moduleOp.getOps()) { + LLVM_DEBUG(llvm::dbgs() + << "\n=== Elaborating test @" << testOp.getSymName() << "\n\n"); + Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody())); + Elaborator elaborator(state, materializer); + if (failed(elaborator.elaborate(testOp.getBodyRegion()))) + return failure(); + + materializer.finalize(); + } + + // Do top-down BFS traversal such that elaborating a sequence further down + // does not fix the outcome for multiple placements. 
+ while (!state.worklist.empty()) { + auto *curr = state.worklist.front(); + state.worklist.pop(); + + if (table.lookup(curr->getName())) + continue; + + auto familyOp = table.lookup(curr->getFamilyName()); + // TODO: don't clone if this is the only remaining reference to this + // sequence + OpBuilder builder(familyOp); + auto seqOp = builder.cloneWithoutRegions(familyOp); + seqOp.getBodyRegion().emplaceBlock(); + seqOp.setSymName(curr->getName()); + table.insert(seqOp); + assert(seqOp.getSymName() == curr->getName() && + "should not have been renamed"); + + LLVM_DEBUG(llvm::dbgs() + << "\n=== Elaborating sequence family @" << familyOp.getSymName() + << " into @" << seqOp.getSymName() << "\n\n"); + + Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody())); + Elaborator elaborator(state, materializer); + if (failed(elaborator.elaborate(familyOp.getBodyRegion(), curr->getArgs()))) + return failure(); + + materializer.finalize(); + } + + // Inline all sequences and remove the operations that place the sequences. + for (auto testOp : moduleOp.getOps()) + if (failed(inlineSequences(testOp, table))) + return failure(); + + // Remove all sequences since they are not accessible from the outside and + // are not needed anymore since we fully inlined them. 
+ for (auto seqOp : llvm::make_early_inc_range(moduleOp.getOps())) + seqOp->erase(); + + return success(); +} + +LogicalResult ElaborationPass::inlineSequences(TestOp testOp, + SymbolTable &table) { + OpBuilder builder(testOp); + for (auto iter = testOp.getBody()->begin(); + iter != testOp.getBody()->end();) { + auto invokeOp = dyn_cast(&*iter); + if (!invokeOp) { + ++iter; + continue; + } + + auto seqClosureOp = + invokeOp.getSequence().getDefiningOp(); + if (!seqClosureOp) + return invokeOp->emitError( + "sequence operand not directly defined by sequence_closure op"); + + auto seqOp = table.lookup(seqClosureOp.getSequenceAttr()); + + builder.setInsertionPointAfter(invokeOp); + IRMapping mapping; + for (auto &op : *seqOp.getBody()) + builder.clone(op, mapping); + + (iter++)->erase(); + + if (seqClosureOp->use_empty()) + seqClosureOp->erase(); + } + + return success(); +} diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index d67218d482cf..0be78b2ac6dc 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -179,18 +179,18 @@ rtg.test @sequenceClosureFixesRandomization : !rtg.dict<> { // CHECK-LABLE: @indexOps rtg.test @indexOps : !rtg.dict<> { // CHECK: [[C:%.+]] = index.constant 2 - // CHECK: [[T:%.+]] = index.bool.constant true - // CHECK: [[F:%.+]] = index.bool.constant false %0 = index.constant 1 // CHECK: func.call @dummy2([[C]]) %1 = index.add %0, %0 func.call @dummy2(%1) : (index) -> () + // CHECK: [[T:%.+]] = index.bool.constant true // CHECK: func.call @dummy5([[T]]) %2 = index.cmp eq(%0, %0) func.call @dummy5(%2) : (i1) -> () + // CHECK: [[F:%.+]] = index.bool.constant false // CHECK: func.call @dummy5([[F]]) %3 = index.cmp ne(%0, %0) func.call @dummy5(%3) : (i1) -> () @@ -216,10 +216,94 @@ rtg.test @indexOps : !rtg.dict<> { func.call @dummy5(%8) : (i1) -> () } +// CHECK-LABEL: @scfIf +rtg.test @scfIf : !rtg.dict<> { + %0 = index.bool.constant true 
+ %1 = index.bool.constant false + + // Don't elaborate body + scf.if %1 { + func.call @dummy5(%0) : (i1) -> () + scf.yield + } + + // Test nested ifs + // CHECK-NEXT: [[T:%.+]] = index.bool.constant true + // CHECK-NEXT: func.call @dummy5([[T]]) + // CHECK-NEXT: [[F:%.+]] = index.bool.constant false + // CHECK-NEXT: func.call @dummy5([[F]]) + scf.if %0 { + scf.if %0 { + scf.if %0 { + func.call @dummy5(%0) : (i1) -> () + scf.yield + } + scf.yield + } + scf.if %0 { + func.call @dummy5(%1) : (i1) -> () + scf.yield + } + scf.yield + } + + // Return values + // CHECK-NEXT: [[C1:%.+]] = index.constant 1 + // CHECK-NEXT: func.call @dummy2([[C1]]) + %2 = scf.if %0 -> index { + %3 = index.constant 1 + scf.yield %3 : index + } else { + %3 = index.constant 2 + scf.yield %3 : index + } + func.call @dummy2(%2) : (index) -> () +} + +// CHECK-LABEL: @scfFor +rtg.test @scfFor : !rtg.dict<> { + // CHECK-NEXT: [[C0:%.+]] = index.constant 0 + // CHECK-NEXT: func.call @dummy2([[C0]]) + // CHECK-NEXT: [[C1:%.+]] = index.constant 1 + // CHECK-NEXT: func.call @dummy2([[C1]]) + // CHECK-NEXT: [[C2:%.+]] = index.constant 2 + // CHECK-NEXT: func.call @dummy2([[C2]]) + // CHECK-NEXT: func.call @dummy2([[C2]]) + // CHECK-NEXT: [[C4:%.+]] = index.constant 4 + // CHECK-NEXT: func.call @dummy2([[C4]]) + // CHECK-NEXT: [[C3:%.+]] = index.constant 3 + // CHECK-NEXT: func.call @dummy2([[C3]]) + // CHECK-NEXT: func.call @dummy2([[C3]]) + // CHECK-NEXT: func.call @dummy2([[C0]]) + // CHECK-NEXT: } + + %0 = index.constant 0 + %1 = index.constant 2 + %2 = index.constant 5 + %3 = index.constant 1 + // Three iterations + %4 = scf.for %i = %0 to %2 step %1 iter_args(%a = %0) -> (index) { + %5 = index.add %a, %3 + func.call @dummy2(%i) : (index) -> () + func.call @dummy2(%5) : (index) -> () + scf.yield %5 : index + } + func.call @dummy2(%4) : (index) -> () + + // Zero iterations + %5 = scf.for %i = %0 to %0 step %1 iter_args(%a = %0) -> (index) { + %6 = index.add %a, %3 + func.call @dummy2(%a) : (index) 
-> () + func.call @dummy2(%6) : (index) -> () + scf.yield %6 : index + } + func.call @dummy2(%5) : (index) -> () +} + // ----- rtg.test @nestedRegionsNotSupported : !rtg.dict<> { - // expected-error @below {{nested regions not supported}} + // expected-error @below {{ops with nested regions must be elaborated away}} scf.execute_region { scf.yield } }