//===- llvm/unittests/Transforms/Vectorize/VPlanUncountableExitTest.cpp ---===// // // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "../lib/Transforms/Vectorize/VPRecipeBuilder.h" #include "../lib/Transforms/Vectorize/VPlan.h" #include "../lib/Transforms/Vectorize/VPlanPatternMatch.h" #include "../lib/Transforms/Vectorize/VPlanUtils.h" #include "VPlanTestBase.h" #include "llvm/ADT/SmallVector.h" #include "gtest/gtest.h" namespace llvm { namespace { class VPUncountableExitTest : public VPlanTestIRBase {}; using namespace VPlanPatternMatch; static void combineExitConditions(VPlan &Plan) { struct EarlyExitInfo { VPBasicBlock *EarlyExitingVPBB; VPIRBasicBlock *EarlyExitVPBB; VPValue *CondToExit; }; auto *MiddleVPBB = cast( Plan.getScalarHeader()->getSinglePredecessor()->getPredecessors()[0]); auto *LatchVPBB = cast(MiddleVPBB->getSinglePredecessor()); // Find the single early exit: a non-middle predecessor of an exit block. VPBasicBlock *EarlyExitingVPBB = nullptr; VPIRBasicBlock *EarlyExitVPBB = nullptr; for (VPIRBasicBlock *ExitBlock : Plan.getExitBlocks()) { for (VPBlockBase *Pred : ExitBlock->getPredecessors()) { if (Pred != MiddleVPBB) { EarlyExitingVPBB = cast(Pred); EarlyExitVPBB = ExitBlock; } } } assert(EarlyExitingVPBB && "must have an early exit"); // Wrap the early exit condition in a MaskedCond. VPValue *Cond; [[maybe_unused]] bool Matched = match(EarlyExitingVPBB->getTerminator(), m_BranchOnCond(m_VPValue(Cond))); assert(Matched && "Terminator must be BranchOnCond"); VPBuilder EarlyExitBuilder(EarlyExitingVPBB->getTerminator()); if (EarlyExitingVPBB->getSuccessors()[0] != EarlyExitVPBB) Cond = EarlyExitBuilder.createNot(Cond); auto *MaskedCond = EarlyExitBuilder.createNaryOp(VPInstruction::MaskedCond, {Cond}); // Combine the early exit with the latch exit on the latch terminator. VPBuilder Builder(LatchVPBB->getTerminator()); auto *IsAnyExitTaken = Builder.createNaryOp(VPInstruction::AnyOf, {MaskedCond}); auto *LatchBranch = cast(LatchVPBB->getTerminator()); auto *IsLatchExitTaken = Builder.createICmp( CmpInst::ICMP_EQ, LatchBranch->getOperand(0), LatchBranch->getOperand(1)); LatchBranch->eraseFromParent(); Builder.setInsertPoint(LatchVPBB); Builder.createNaryOp(VPInstruction::BranchOnCond, {Builder.createOr(IsAnyExitTaken, IsLatchExitTaken)}); // Disconnect the early exit edge. EarlyExitingVPBB->getTerminator()->eraseFromParent(); VPBlockUtils::disconnectBlocks(EarlyExitingVPBB, EarlyExitVPBB); } TEST_F(VPUncountableExitTest, FindUncountableExitRecipes) { const char *ModuleString = "target datalayout = " "\"e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-" "f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:" "32:64-S128\"\n" "define void @f(ptr dereferenceable(40) align 2 %array, " "ptr dereferenceable(40) align 2 %pred) {\n" "entry:\n" " br label %for.body\n" "for.body:\n" " %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.inc ]\n" " %st.addr = getelementptr inbounds i16, ptr %array, i64 %iv\n" " %data = load i16, ptr %st.addr, align 2\n" " %inc = add nsw i16 %data, 1\n" " store i16 %inc, ptr %st.addr, align 2\n" " %uncountable.addr = getelementptr inbounds nuw i16, ptr %pred, i64 " "%iv\n" " %uncountable.val = load i16, ptr %uncountable.addr, align 2\n" " %uncountable.cond = icmp sgt i16 %uncountable.val, 500\n" " br i1 %uncountable.cond, label %exit, label %for.inc\n" "for.inc:\n" " %iv.next = add nuw nsw i64 %iv, 1\n" " %countable.cond = icmp eq i64 %iv.next, 20\n" " br i1 %countable.cond, label %exit, label %for.body\n" "exit:\n" " ret void\n" "}\n"; Module &M = parseModule(ModuleString); Function *F = M.getFunction("f"); BasicBlock *LoopHeader = F->getEntryBlock().getSingleSuccessor(); VPlanPtr Plan = buildVPlan0(LoopHeader); combineExitConditions(*Plan); SmallVector Recipes; SmallVector GEPs; auto *MiddleVPBB = cast( Plan->getScalarHeader()->getSinglePredecessor()->getPredecessors()[0]); auto *LatchVPBB = cast(MiddleVPBB->getSinglePredecessor()); std::optional UncountableCondition = vputils::getRecipesForUncountableExit(Recipes, GEPs, LatchVPBB); ASSERT_TRUE(UncountableCondition.has_value()); ASSERT_EQ(GEPs.size(), 1ull); ASSERT_EQ(Recipes.size(), 4ull); } TEST_F(VPUncountableExitTest, NoUncountableExit) { const char *ModuleString = "define void @f(ptr %array, ptr %pred) {\n" "entry:\n" " br label %for.body\n" "for.body:\n" " %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]\n" " %st.addr = getelementptr inbounds i16, ptr %array, i64 %iv\n" " %data = load i16, ptr %st.addr, align 2\n" " %inc = add nsw i16 %data, 1\n" " store i16 %inc, ptr %st.addr, align 2\n" " %iv.next = add nuw nsw i64 %iv, 1\n" " %countable.cond = icmp eq i64 %iv.next, 20\n" " br i1 %countable.cond, label %exit, label %for.body\n" "exit:\n" " ret void\n" "}\n"; Module &M = parseModule(ModuleString); Function *F = M.getFunction("f"); BasicBlock *LoopHeader = F->getEntryBlock().getSingleSuccessor(); auto Plan = buildVPlan0(LoopHeader); SmallVector Recipes; SmallVector GEPs; auto *MiddleVPBB = cast( Plan->getScalarHeader()->getSinglePredecessor()->getPredecessors()[0]); auto *LatchVPBB = cast(MiddleVPBB->getSinglePredecessor()); std::optional UncountableCondition = vputils::getRecipesForUncountableExit(Recipes, GEPs, LatchVPBB); ASSERT_FALSE(UncountableCondition.has_value()); ASSERT_EQ(GEPs.size(), 0ull); ASSERT_EQ(Recipes.size(), 0ull); } } // namespace } // namespace llvm