Splitting out some work from #178454. This covers the enums for the early exit loop type (none, readonly, readwrite) and for the style used (readonly with multiple exit blocks, or masking with the last iteration done in scalar code). It also changes the early exit recipe detection to suit moving the transform that handles early exit readwrite loops earlier in the VPlan pipeline.
166 lines
6.2 KiB
C++
166 lines
6.2 KiB
C++
//===- llvm/unittests/Transforms/Vectorize/VPlanUncountableExitTest.cpp ---===//
|
|
//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "../lib/Transforms/Vectorize/VPRecipeBuilder.h"
|
|
#include "../lib/Transforms/Vectorize/VPlan.h"
|
|
#include "../lib/Transforms/Vectorize/VPlanPatternMatch.h"
|
|
#include "../lib/Transforms/Vectorize/VPlanUtils.h"
|
|
#include "VPlanTestBase.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
namespace llvm {
|
|
|
|
namespace {
|
|
// Fixture for the uncountable-exit tests in this file. It adds no members of
// its own; the TEST_F bodies below call parseModule() and buildVPlan0(), which
// are presumably inherited from VPlanTestIRBase (see VPlanTestBase.h).
class VPUncountableExitTest : public VPlanTestIRBase {};
|
|
using namespace VPlanPatternMatch;
|
|
|
|
/// Test helper that folds the early exit of \p Plan into the latch terminator.
///
/// The early exit condition is wrapped in a VPInstruction::MaskedCond in the
/// early-exiting block, reduced with VPInstruction::AnyOf in the latch, and
/// OR'ed with an equality compare of the two operands of the original latch
/// terminator. The latch then ends in a single BranchOnCond on that combined
/// value, and the edge from the early-exiting block to its exit block is
/// removed.
///
/// \param Plan the plan to rewrite in place. It must contain an exit block
///             with a predecessor other than the middle block, and that
///             predecessor must terminate in a BranchOnCond.
static void combineExitConditions(VPlan &Plan) {
  // Walk backwards from the scalar header: the middle block is the first
  // predecessor of the scalar header's single predecessor, and the latch is
  // the middle block's single predecessor.
  auto *MiddleVPBB = cast<VPBasicBlock>(
      Plan.getScalarHeader()->getSinglePredecessor()->getPredecessors()[0]);
  auto *LatchVPBB = cast<VPBasicBlock>(MiddleVPBB->getSinglePredecessor());

  // Find the single early exit: a non-middle predecessor of an exit block.
  VPBasicBlock *EarlyExitingVPBB = nullptr;
  VPIRBasicBlock *EarlyExitVPBB = nullptr;
  for (VPIRBasicBlock *ExitBlock : Plan.getExitBlocks()) {
    for (VPBlockBase *Pred : ExitBlock->getPredecessors()) {
      if (Pred != MiddleVPBB) {
        EarlyExitingVPBB = cast<VPBasicBlock>(Pred);
        EarlyExitVPBB = ExitBlock;
      }
    }
  }
  assert(EarlyExitingVPBB && "must have an early exit");

  // Wrap the early exit condition in a MaskedCond.
  VPValue *Cond;
  [[maybe_unused]] bool Matched =
      match(EarlyExitingVPBB->getTerminator(), m_BranchOnCond(m_VPValue(Cond)));
  assert(Matched && "Terminator must be BranchOnCond");
  VPBuilder EarlyExitBuilder(EarlyExitingVPBB->getTerminator());
  // If the exit block is not the first successor of the branch, the branch
  // condition is negated so that Cond describes "leave the loop".
  if (EarlyExitingVPBB->getSuccessors()[0] != EarlyExitVPBB)
    Cond = EarlyExitBuilder.createNot(Cond);
  auto *MaskedCond =
      EarlyExitBuilder.createNaryOp(VPInstruction::MaskedCond, {Cond});

  // Combine the early exit with the latch exit on the latch terminator.
  VPBuilder Builder(LatchVPBB->getTerminator());
  auto *IsAnyExitTaken =
      Builder.createNaryOp(VPInstruction::AnyOf, {MaskedCond});
  auto *LatchBranch = cast<VPInstruction>(LatchVPBB->getTerminator());
  // Re-express the old latch terminator as an explicit compare of its two
  // operands (presumably a count-style branch; confirm against buildVPlan0).
  auto *IsLatchExitTaken = Builder.createICmp(
      CmpInst::ICMP_EQ, LatchBranch->getOperand(0), LatchBranch->getOperand(1));
  LatchBranch->eraseFromParent();
  // The old terminator is gone, so append the new recipes at the end of the
  // latch block.
  Builder.setInsertPoint(LatchVPBB);
  Builder.createNaryOp(VPInstruction::BranchOnCond,
                       {Builder.createOr(IsAnyExitTaken, IsLatchExitTaken)});

  // Disconnect the early exit edge.
  EarlyExitingVPBB->getTerminator()->eraseFromParent();
  VPBlockUtils::disconnectBlocks(EarlyExitingVPBB, EarlyExitVPBB);
}
|
|
|
|
// Builds a plan for a loop with one data-dependent exit and one counted exit,
// folds the early exit into the latch via combineExitConditions(), and checks
// that vputils::getRecipesForUncountableExit() reports a condition together
// with the expected number of collected recipes and GEPs.
TEST_F(VPUncountableExitTest, FindUncountableExitRecipes) {
  const char *IRText =
      "target datalayout = "
      "\"e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-"
      "f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:"
      "32:64-S128\"\n"
      "define void @f(ptr dereferenceable(40) align 2 %array, "
      "ptr dereferenceable(40) align 2 %pred) {\n"
      "entry:\n"
      " br label %for.body\n"
      "for.body:\n"
      " %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.inc ]\n"
      " %st.addr = getelementptr inbounds i16, ptr %array, i64 %iv\n"
      " %data = load i16, ptr %st.addr, align 2\n"
      " %inc = add nsw i16 %data, 1\n"
      " store i16 %inc, ptr %st.addr, align 2\n"
      " %uncountable.addr = getelementptr inbounds nuw i16, ptr %pred, i64 "
      "%iv\n"
      " %uncountable.val = load i16, ptr %uncountable.addr, align 2\n"
      " %uncountable.cond = icmp sgt i16 %uncountable.val, 500\n"
      " br i1 %uncountable.cond, label %exit, label %for.inc\n"
      "for.inc:\n"
      " %iv.next = add nuw nsw i64 %iv, 1\n"
      " %countable.cond = icmp eq i64 %iv.next, 20\n"
      " br i1 %countable.cond, label %exit, label %for.body\n"
      "exit:\n"
      " ret void\n"
      "}\n";

  Module &Mod = parseModule(IRText);
  Function *Fn = Mod.getFunction("f");
  BasicBlock *HeaderBB = Fn->getEntryBlock().getSingleSuccessor();

  // Build the initial plan and merge the early exit into the latch branch
  // before querying it.
  VPlanPtr InitialPlan = buildVPlan0(HeaderBB);
  combineExitConditions(*InitialPlan);

  // Locate the latch by walking back from the scalar header.
  auto *ScalarHeaderPred =
      InitialPlan->getScalarHeader()->getSinglePredecessor();
  auto *MiddleBlock =
      cast<VPBasicBlock>(ScalarHeaderPred->getPredecessors()[0]);
  auto *LatchBlock = cast<VPBasicBlock>(MiddleBlock->getSinglePredecessor());

  SmallVector<VPInstruction *> ExitRecipes;
  SmallVector<VPInstruction *> ExitGEPs;
  std::optional<VPValue *> ExitCond =
      vputils::getRecipesForUncountableExit(ExitRecipes, ExitGEPs, LatchBlock);

  ASSERT_TRUE(ExitCond.has_value());
  ASSERT_EQ(ExitGEPs.size(), 1ull);
  ASSERT_EQ(ExitRecipes.size(), 4ull);
}
|
|
|
|
// Builds a plan for a loop whose only exit is the counted latch exit and
// checks that vputils::getRecipesForUncountableExit() reports no condition
// and leaves both output vectors empty.
TEST_F(VPUncountableExitTest, NoUncountableExit) {
  const char *IRText =
      "define void @f(ptr %array, ptr %pred) {\n"
      "entry:\n"
      " br label %for.body\n"
      "for.body:\n"
      " %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]\n"
      " %st.addr = getelementptr inbounds i16, ptr %array, i64 %iv\n"
      " %data = load i16, ptr %st.addr, align 2\n"
      " %inc = add nsw i16 %data, 1\n"
      " store i16 %inc, ptr %st.addr, align 2\n"
      " %iv.next = add nuw nsw i64 %iv, 1\n"
      " %countable.cond = icmp eq i64 %iv.next, 20\n"
      " br i1 %countable.cond, label %exit, label %for.body\n"
      "exit:\n"
      " ret void\n"
      "}\n";

  Module &Mod = parseModule(IRText);
  Function *Fn = Mod.getFunction("f");
  BasicBlock *HeaderBB = Fn->getEntryBlock().getSingleSuccessor();

  // The plan is queried as built; no exit conditions are combined here.
  auto InitialPlan = buildVPlan0(HeaderBB);

  // Locate the latch by walking back from the scalar header.
  auto *ScalarHeaderPred =
      InitialPlan->getScalarHeader()->getSinglePredecessor();
  auto *MiddleBlock =
      cast<VPBasicBlock>(ScalarHeaderPred->getPredecessors()[0]);
  auto *LatchBlock = cast<VPBasicBlock>(MiddleBlock->getSinglePredecessor());

  SmallVector<VPInstruction *> ExitRecipes;
  SmallVector<VPInstruction *> ExitGEPs;
  std::optional<VPValue *> ExitCond =
      vputils::getRecipesForUncountableExit(ExitRecipes, ExitGEPs, LatchBlock);

  ASSERT_FALSE(ExitCond.has_value());
  ASSERT_EQ(ExitGEPs.size(), 0ull);
  ASSERT_EQ(ExitRecipes.size(), 0ull);
}
|
|
|
|
} // namespace
|
|
} // namespace llvm
|