Files
llvm-project/llvm/unittests/Transforms/Vectorize/VPlanUncountableExitTest.cpp
Graham Hunter deaef1c1b7 [LV] Adjust exit recipe detection to run on early vplan (#183318)
Splitting out some work from #178454; this covers the enums for
early exit loop type (none, readonly, readwrite) and the style
used (readonly with multiple exit blocks, or masking with the
last iteration done in scalar code), along with changing the early
exit recipe detection to suit moving the transform for handling
early exit readwrite loops earlier in the vplan pipeline.
2026-04-02 17:25:35 +01:00

166 lines
6.2 KiB
C++

//===- llvm/unittests/Transforms/Vectorize/VPlanUncountableExitTest.cpp ---===//
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "../lib/Transforms/Vectorize/VPRecipeBuilder.h"
#include "../lib/Transforms/Vectorize/VPlan.h"
#include "../lib/Transforms/Vectorize/VPlanPatternMatch.h"
#include "../lib/Transforms/Vectorize/VPlanUtils.h"
#include "VPlanTestBase.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
namespace llvm {
namespace {
/// Fixture for the uncountable-exit tests below; it adds nothing on top of
/// VPlanTestIRBase.
struct VPUncountableExitTest : VPlanTestIRBase {};
using namespace VPlanPatternMatch;
/// Test helper that folds an early exit of \p Plan into the latch terminator.
///
/// The helper looks for an exit-block predecessor other than the middle block,
/// wraps its branch condition in a MaskedCond, and replaces the latch
/// terminator with a BranchOnCond on
///   AnyOf(MaskedCond) | (LatchOp0 == LatchOp1).
/// The edge from the early-exiting block to its exit block is then removed,
/// together with the early-exiting block's terminator.
///
/// The middle block is taken to be the first predecessor of the scalar
/// header's single predecessor, and the latch the middle block's single
/// predecessor. The latch terminator must have at least two operands, since
/// operands 0 and 1 are compared for equality.
static void combineExitConditions(VPlan &Plan) {
  auto *MiddleVPBB = cast<VPBasicBlock>(
      Plan.getScalarHeader()->getSinglePredecessor()->getPredecessors()[0]);
  auto *LatchVPBB = cast<VPBasicBlock>(MiddleVPBB->getSinglePredecessor());

  // Find the single early exit: a non-middle predecessor of an exit block.
  // Note: if several such edges exist, the last one visited is used.
  VPBasicBlock *EarlyExitingVPBB = nullptr;
  VPIRBasicBlock *EarlyExitVPBB = nullptr;
  for (VPIRBasicBlock *ExitBlock : Plan.getExitBlocks()) {
    for (VPBlockBase *Pred : ExitBlock->getPredecessors()) {
      if (Pred != MiddleVPBB) {
        EarlyExitingVPBB = cast<VPBasicBlock>(Pred);
        EarlyExitVPBB = ExitBlock;
      }
    }
  }
  assert(EarlyExitingVPBB && "must have an early exit");

  // Wrap the early exit condition in a MaskedCond.
  VPValue *Cond;
  [[maybe_unused]] bool Matched =
      match(EarlyExitingVPBB->getTerminator(), m_BranchOnCond(m_VPValue(Cond)));
  assert(Matched && "Terminator must be BranchOnCond");
  VPBuilder EarlyExitBuilder(EarlyExitingVPBB->getTerminator());
  // Negate the condition when the exit block is not the first successor, so
  // that a true value always means "leave the loop".
  if (EarlyExitingVPBB->getSuccessors()[0] != EarlyExitVPBB)
    Cond = EarlyExitBuilder.createNot(Cond);
  auto *MaskedCond =
      EarlyExitBuilder.createNaryOp(VPInstruction::MaskedCond, {Cond});

  // Combine the early exit with the latch exit on the latch terminator.
  VPBuilder Builder(LatchVPBB->getTerminator());
  auto *IsAnyExitTaken =
      Builder.createNaryOp(VPInstruction::AnyOf, {MaskedCond});
  auto *LatchBranch = cast<VPInstruction>(LatchVPBB->getTerminator());
  auto *IsLatchExitTaken = Builder.createICmp(
      CmpInst::ICMP_EQ, LatchBranch->getOperand(0), LatchBranch->getOperand(1));
  LatchBranch->eraseFromParent();
  // The old terminator is gone, so append the new branch to the latch.
  Builder.setInsertPoint(LatchVPBB);
  Builder.createNaryOp(VPInstruction::BranchOnCond,
                       {Builder.createOr(IsAnyExitTaken, IsLatchExitTaken)});

  // Disconnect the early exit edge.
  EarlyExitingVPBB->getTerminator()->eraseFromParent();
  VPBlockUtils::disconnectBlocks(EarlyExitingVPBB, EarlyExitVPBB);
}
// After the exits have been combined on the latch, the uncountable-exit query
// must yield a condition together with one GEP and four recipes.
TEST_F(VPUncountableExitTest, FindUncountableExitRecipes) {
  // Loop that increments array[iv] and leaves early once pred[iv] > 500, with
  // a second, countable exit when iv.next reaches 20.
  const char *IR =
      "target datalayout = "
      "\"e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-"
      "f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:"
      "32:64-S128\"\n"
      "define void @f(ptr dereferenceable(40) align 2 %array, "
      "ptr dereferenceable(40) align 2 %pred) {\n"
      "entry:\n"
      "  br label %for.body\n"
      "for.body:\n"
      "  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.inc ]\n"
      "  %st.addr = getelementptr inbounds i16, ptr %array, i64 %iv\n"
      "  %data = load i16, ptr %st.addr, align 2\n"
      "  %inc = add nsw i16 %data, 1\n"
      "  store i16 %inc, ptr %st.addr, align 2\n"
      "  %uncountable.addr = getelementptr inbounds nuw i16, ptr %pred, i64 "
      "%iv\n"
      "  %uncountable.val = load i16, ptr %uncountable.addr, align 2\n"
      "  %uncountable.cond = icmp sgt i16 %uncountable.val, 500\n"
      "  br i1 %uncountable.cond, label %exit, label %for.inc\n"
      "for.inc:\n"
      "  %iv.next = add nuw nsw i64 %iv, 1\n"
      "  %countable.cond = icmp eq i64 %iv.next, 20\n"
      "  br i1 %countable.cond, label %exit, label %for.body\n"
      "exit:\n"
      "  ret void\n"
      "}\n";

  Module &Mod = parseModule(IR);
  Function *Fn = Mod.getFunction("f");
  BasicBlock *Header = Fn->getEntryBlock().getSingleSuccessor();
  VPlanPtr Plan = buildVPlan0(Header);
  combineExitConditions(*Plan);

  // Walk back from the scalar header to the middle block, then to the latch.
  auto *HeaderPred = Plan->getScalarHeader()->getSinglePredecessor();
  auto *Middle = cast<VPBasicBlock>(HeaderPred->getPredecessors()[0]);
  auto *Latch = cast<VPBasicBlock>(Middle->getSinglePredecessor());

  SmallVector<VPInstruction *> ExitRecipes;
  SmallVector<VPInstruction *> ExitGEPs;
  std::optional<VPValue *> ExitCond =
      vputils::getRecipesForUncountableExit(ExitRecipes, ExitGEPs, Latch);

  ASSERT_TRUE(ExitCond.has_value());
  ASSERT_EQ(ExitGEPs.size(), 1ull);
  ASSERT_EQ(ExitRecipes.size(), 4ull);
}
// A loop whose only exit is the countable latch exit must yield no condition
// and leave both output vectors empty.
TEST_F(VPUncountableExitTest, NoUncountableExit) {
  // Single-block loop that increments array[iv] and exits when iv.next == 20.
  const char *IR =
      "define void @f(ptr %array, ptr %pred) {\n"
      "entry:\n"
      "  br label %for.body\n"
      "for.body:\n"
      "  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]\n"
      "  %st.addr = getelementptr inbounds i16, ptr %array, i64 %iv\n"
      "  %data = load i16, ptr %st.addr, align 2\n"
      "  %inc = add nsw i16 %data, 1\n"
      "  store i16 %inc, ptr %st.addr, align 2\n"
      "  %iv.next = add nuw nsw i64 %iv, 1\n"
      "  %countable.cond = icmp eq i64 %iv.next, 20\n"
      "  br i1 %countable.cond, label %exit, label %for.body\n"
      "exit:\n"
      "  ret void\n"
      "}\n";

  Module &Mod = parseModule(IR);
  Function *Fn = Mod.getFunction("f");
  BasicBlock *Header = Fn->getEntryBlock().getSingleSuccessor();
  auto Plan = buildVPlan0(Header);

  // Walk back from the scalar header to the middle block, then to the latch.
  auto *HeaderPred = Plan->getScalarHeader()->getSinglePredecessor();
  auto *Middle = cast<VPBasicBlock>(HeaderPred->getPredecessors()[0]);
  auto *Latch = cast<VPBasicBlock>(Middle->getSinglePredecessor());

  SmallVector<VPInstruction *> ExitRecipes;
  SmallVector<VPInstruction *> ExitGEPs;
  std::optional<VPValue *> ExitCond =
      vputils::getRecipesForUncountableExit(ExitRecipes, ExitGEPs, Latch);

  ASSERT_FALSE(ExitCond.has_value());
  ASSERT_EQ(ExitGEPs.size(), 0ull);
  ASSERT_EQ(ExitRecipes.size(), 0ull);
}
} // namespace
} // namespace llvm