[DAG] Add funnel-shift matchers to SDPatternMatch (Fixes #185880) (#186593)

Add new SelectionDAG pattern matchers for funnel shifts:
- m_FShL and m_FShR as ternary wrappers for ISD::FSHL/ISD::FSHR
- m_FShLLike and m_FShRLike to match:
-- direct FSHL/FSHR nodes
-- ROTL/ROTR equivalents (binding both X and Y to the same rotate operand)
-- OR(SHL(X, C), SRL(Y, BW - C)) forms (including commuted OR)

Also add unit tests covering positive and negative cases for:
- direct funnel-shif matching
- rotate equivalence matching
- OR-based funnel-shift-like patterns

Fixes #185880
This commit is contained in:
Vedant Neve
2026-04-14 13:12:06 +05:30
committed by GitHub
parent c61b070ec2
commit 4b4aa3b791
2 changed files with 177 additions and 0 deletions

View File

@@ -17,6 +17,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/bit.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
@@ -965,6 +966,75 @@ inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
}
template <typename T0_P, typename T1_P, typename T2_P>
inline TernaryOpc_match<T0_P, T1_P, T2_P>
m_FShL(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
return m_TernaryOp(ISD::FSHL, Op0, Op1, Op2);
}
template <typename T0_P, typename T1_P, typename T2_P>
inline TernaryOpc_match<T0_P, T1_P, T2_P>
m_FShR(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
return m_TernaryOp(ISD::FSHR, Op0, Op1, Op2);
}
template <typename T0_P, typename T1_P, typename T2_P, bool Left>
struct FunnelShiftLike_match {
T0_P Op0;
T1_P Op1;
T2_P Op2;
FunnelShiftLike_match(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2)
: Op0(Op0), Op1(Op1), Op2(Op2) {}
static bool hasComplementaryConstantShifts(const APInt &ShlV,
const APInt &SrlV,
unsigned BitWidth) {
unsigned SumWidth = std::max(ShlV.getBitWidth(), SrlV.getBitWidth()) + 1;
unsigned BitWidthBits = llvm::bit_width(BitWidth);
if (BitWidthBits > SumWidth)
return false;
return ShlV.zext(SumWidth) + SrlV.zext(SumWidth) ==
APInt(SumWidth, BitWidth);
}
template <typename MatchContext>
bool matchOperands(const MatchContext &Ctx, SDValue X, SDValue Y, SDValue Z) {
return Op0.match(Ctx, X) && Op1.match(Ctx, Y) && Op2.match(Ctx, Z);
}
template <typename MatchContext>
bool matchShiftOr(const MatchContext &Ctx, SDValue N, unsigned BitWidth);
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx,
Left ? m_FShL(Op0, Op1, Op2) : m_FShR(Op0, Op1, Op2)))
return true;
SDValue X, Z;
if (sd_context_match(N, Ctx,
Left ? m_Rotl(m_Value(X), m_Value(Z))
: m_Rotr(m_Value(X), m_Value(Z))))
return matchOperands(Ctx, X, X, Z);
return matchShiftOr(Ctx, N, N.getValueType().getScalarSizeInBits());
}
};
template <typename T0_P, typename T1_P, typename T2_P>
inline FunnelShiftLike_match<T0_P, T1_P, T2_P, true>
m_FShLLike(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
return FunnelShiftLike_match<T0_P, T1_P, T2_P, true>(Op0, Op1, Op2);
}
template <typename T0_P, typename T1_P, typename T2_P>
inline FunnelShiftLike_match<T0_P, T1_P, T2_P, false>
m_FShRLike(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
return FunnelShiftLike_match<T0_P, T1_P, T2_P, false>(Op0, Op1, Op2);
}
template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS, true> m_Clmul(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS, true>(ISD::CLMUL, L, R);
@@ -1224,6 +1294,22 @@ inline Constant64_match<int64_t> m_ConstInt(int64_t &V) {
return Constant64_match<int64_t>(V);
}
template <typename T0_P, typename T1_P, typename T2_P, bool Left>
template <typename MatchContext>
bool FunnelShiftLike_match<T0_P, T1_P, T2_P, Left>::matchShiftOr(
const MatchContext &Ctx, SDValue N, unsigned BitWidth) {
SDValue X, Y, ShlAmt, SrlAmt;
APInt ShlConst, SrlConst;
if (!sd_context_match(
N, Ctx,
m_Or(m_Shl(m_Value(X), m_Value(ShlAmt, m_ConstInt(ShlConst))),
m_Srl(m_Value(Y), m_Value(SrlAmt, m_ConstInt(SrlConst))))) ||
!hasComplementaryConstantShifts(ShlConst, SrlConst, BitWidth))
return false;
return matchOperands(Ctx, X, Y, Left ? ShlAmt : SrlAmt);
}
struct SpecificInt_match {
APInt IntVal;

View File

@@ -636,6 +636,97 @@ TEST_F(SelectionDAGPatternMatchTest, matchGenericTernaryOp) {
sd_match(FAdd, m_c_TernaryOp(ISD::FMA, m_Value(), m_Value(), m_Value())));
}
TEST_F(SelectionDAGPatternMatchTest, matchFunnelShift) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
Register::index2VirtReg(1), Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
Register::index2VirtReg(2), Int32VT);
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
Register::index2VirtReg(3), Int32VT);
SDValue C7 = DAG->getConstant(7, DL, Int32VT);
SDValue C24 = DAG->getConstant(24, DL, Int32VT);
SDValue C25 = DAG->getConstant(25, DL, Int32VT);
SDValue FShL = DAG->getNode(ISD::FSHL, DL, Int32VT, Op0, Op1, Op2);
SDValue FShR = DAG->getNode(ISD::FSHR, DL, Int32VT, Op0, Op1, Op2);
SDValue Rotl = DAG->getNode(ISD::ROTL, DL, Int32VT, Op0, Op2);
SDValue Rotr = DAG->getNode(ISD::ROTR, DL, Int32VT, Op0, Op2);
SDValue Shl7 = DAG->getNode(ISD::SHL, DL, Int32VT, Op0, C7);
SDValue Srl25 = DAG->getNode(ISD::SRL, DL, Int32VT, Op1, C25);
SDValue Srl24 = DAG->getNode(ISD::SRL, DL, Int32VT, Op1, C24);
SDValue OrFSh = DAG->getNode(ISD::OR, DL, Int32VT, Shl7, Srl25);
SDValue OrFShCommuted = DAG->getNode(ISD::OR, DL, Int32VT, Srl25, Shl7);
SDValue BadOrFSh = DAG->getNode(ISD::OR, DL, Int32VT, Shl7, Srl24);
using namespace SDPatternMatch;
EXPECT_TRUE(sd_match(
FShL, m_FShL(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
EXPECT_TRUE(sd_match(
FShR, m_FShR(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
EXPECT_FALSE(sd_match(FShL, m_FShR(m_Value(), m_Value(), m_Value())));
EXPECT_FALSE(sd_match(FShR, m_FShL(m_Value(), m_Value(), m_Value())));
EXPECT_TRUE(sd_match(
FShL, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
EXPECT_TRUE(sd_match(
FShR, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
EXPECT_FALSE(sd_match(FShL, m_FShRLike(m_Value(), m_Value(), m_Value())));
EXPECT_FALSE(sd_match(FShR, m_FShLLike(m_Value(), m_Value(), m_Value())));
EXPECT_TRUE(sd_match(
Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2))));
EXPECT_TRUE(sd_match(
Rotr, m_FShRLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2))));
EXPECT_FALSE(sd_match(
Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
EXPECT_FALSE(sd_match(
Rotr, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
EXPECT_FALSE(sd_match(Rotl, m_FShRLike(m_Value(), m_Value(), m_Value())));
EXPECT_FALSE(sd_match(Rotr, m_FShLLike(m_Value(), m_Value(), m_Value())));
SDValue A, B, C;
EXPECT_TRUE(sd_match(Rotl, m_FShLLike(m_Value(A), m_Value(B), m_Value(C))));
EXPECT_EQ(A, Op0);
EXPECT_EQ(B, Op0);
EXPECT_EQ(C, Op2);
A = B = C = SDValue();
EXPECT_TRUE(sd_match(Rotr, m_FShRLike(m_Value(A), m_Value(B), m_Value(C))));
EXPECT_EQ(A, Op0);
EXPECT_EQ(B, Op0);
EXPECT_EQ(C, Op2);
EXPECT_TRUE(sd_match(
OrFSh, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(7))));
EXPECT_TRUE(sd_match(
OrFSh, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(25))));
EXPECT_TRUE(
sd_match(OrFShCommuted,
m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(7))));
EXPECT_TRUE(
sd_match(OrFShCommuted, m_FShRLike(m_Specific(Op0), m_Specific(Op1),
m_SpecificInt(25))));
EXPECT_FALSE(sd_match(BadOrFSh, m_FShLLike(m_Value(), m_Value(), m_Value())));
EXPECT_FALSE(sd_match(BadOrFSh, m_FShRLike(m_Value(), m_Value(), m_Value())));
auto Int1024VT = EVT::getIntegerVT(Context, 1024);
auto Int8VT = EVT::getIntegerVT(Context, 8);
SDValue WideOp0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
Register::index2VirtReg(4), Int1024VT);
SDValue WideOp1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
Register::index2VirtReg(5), Int1024VT);
SDValue C0I8 = DAG->getConstant(0, DL, Int8VT);
SDValue WideShl = DAG->getNode(ISD::SHL, DL, Int1024VT, WideOp0, C0I8);
SDValue WideSrl = DAG->getNode(ISD::SRL, DL, Int1024VT, WideOp1, C0I8);
SDValue WideOr = DAG->getNode(ISD::OR, DL, Int1024VT, WideShl, WideSrl);
EXPECT_FALSE(sd_match(WideOr, m_FShLLike(m_Value(), m_Value(), m_Value())));
EXPECT_FALSE(sd_match(WideOr, m_FShRLike(m_Value(), m_Value(), m_Value())));
}
TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);