From 4b4aa3b7911dcf5cda2cd166e39b8d5d97c20183 Mon Sep 17 00:00:00 2001 From: Vedant Neve Date: Tue, 14 Apr 2026 13:12:06 +0530 Subject: [PATCH] [DAG] Add funnel-shift matchers to SDPatternMatch (Fixes #185880) (#186593) Add new SelectionDAG pattern matchers for funnel shifts: - m_FShL and m_FShR as ternary wrappers for ISD::FSHL/ISD::FSHR - m_FShLLike and m_FShRLike to match: -- direct FSHL/FSHR nodes -- ROTL/ROTR equivalents (binding both X and Y to the same rotate operand) -- OR(SHL(X, C), SRL(Y, BW - C)) forms (including commuted OR) Also add unit tests covering positive and negative cases for: - direct funnel-shif matching - rotate equivalence matching - OR-based funnel-shift-like patterns Fixes #185880 --- llvm/include/llvm/CodeGen/SDPatternMatch.h | 86 ++++++++++++++++++ .../CodeGen/SelectionDAGPatternMatchTest.cpp | 91 +++++++++++++++++++ 2 files changed, 177 insertions(+) diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index 6c356d6ca37a..899243691810 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -17,6 +17,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/bit.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/TargetLowering.h" @@ -965,6 +966,75 @@ inline BinaryOpc_match m_Rotr(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::ROTR, L, R); } +template +inline TernaryOpc_match +m_FShL(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return m_TernaryOp(ISD::FSHL, Op0, Op1, Op2); +} + +template +inline TernaryOpc_match +m_FShR(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return m_TernaryOp(ISD::FSHR, Op0, Op1, Op2); +} + +template +struct FunnelShiftLike_match { + T0_P Op0; + T1_P Op1; + T2_P Op2; + + FunnelShiftLike_match(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) + : Op0(Op0), Op1(Op1), Op2(Op2) {} + + static bool hasComplementaryConstantShifts(const APInt &ShlV, + const APInt &SrlV, + unsigned BitWidth) { + unsigned SumWidth = std::max(ShlV.getBitWidth(), SrlV.getBitWidth()) + 1; + unsigned BitWidthBits = llvm::bit_width(BitWidth); + if (BitWidthBits > SumWidth) + return false; + + return ShlV.zext(SumWidth) + SrlV.zext(SumWidth) == + APInt(SumWidth, BitWidth); + } + + template + bool matchOperands(const MatchContext &Ctx, SDValue X, SDValue Y, SDValue Z) { + return Op0.match(Ctx, X) && Op1.match(Ctx, Y) && Op2.match(Ctx, Z); + } + + template + bool matchShiftOr(const MatchContext &Ctx, SDValue N, unsigned BitWidth); + + template + bool match(const MatchContext &Ctx, SDValue N) { + if (sd_context_match(N, Ctx, + Left ? m_FShL(Op0, Op1, Op2) : m_FShR(Op0, Op1, Op2))) + return true; + + SDValue X, Z; + if (sd_context_match(N, Ctx, + Left ? m_Rotl(m_Value(X), m_Value(Z)) + : m_Rotr(m_Value(X), m_Value(Z)))) + return matchOperands(Ctx, X, X, Z); + + return matchShiftOr(Ctx, N, N.getValueType().getScalarSizeInBits()); + } +}; + +template +inline FunnelShiftLike_match +m_FShLLike(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return FunnelShiftLike_match(Op0, Op1, Op2); +} + +template +inline FunnelShiftLike_match +m_FShRLike(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return FunnelShiftLike_match(Op0, Op1, Op2); +} + template inline BinaryOpc_match m_Clmul(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::CLMUL, L, R); @@ -1224,6 +1294,22 @@ inline Constant64_match m_ConstInt(int64_t &V) { return Constant64_match(V); } +template +template +bool FunnelShiftLike_match::matchShiftOr( + const MatchContext &Ctx, SDValue N, unsigned BitWidth) { + SDValue X, Y, ShlAmt, SrlAmt; + APInt ShlConst, SrlConst; + if (!sd_context_match( + N, Ctx, + m_Or(m_Shl(m_Value(X), m_Value(ShlAmt, m_ConstInt(ShlConst))), + m_Srl(m_Value(Y), m_Value(SrlAmt, m_ConstInt(SrlConst))))) || + !hasComplementaryConstantShifts(ShlConst, SrlConst, BitWidth)) + return false; + + return matchOperands(Ctx, X, Y, Left ? ShlAmt : SrlAmt); +} + struct SpecificInt_match { APInt IntVal; diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 810363c8f638..1073d67ff68e 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -636,6 +636,97 @@ TEST_F(SelectionDAGPatternMatchTest, matchGenericTernaryOp) { sd_match(FAdd, m_c_TernaryOp(ISD::FMA, m_Value(), m_Value(), m_Value()))); } +TEST_F(SelectionDAGPatternMatchTest, matchFunnelShift) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, + Register::index2VirtReg(1), Int32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, + Register::index2VirtReg(2), Int32VT); + SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, + Register::index2VirtReg(3), Int32VT); + SDValue C7 = DAG->getConstant(7, DL, Int32VT); + SDValue C24 = DAG->getConstant(24, DL, Int32VT); + SDValue C25 = DAG->getConstant(25, DL, Int32VT); + + SDValue FShL = DAG->getNode(ISD::FSHL, DL, Int32VT, Op0, Op1, Op2); + SDValue FShR = DAG->getNode(ISD::FSHR, DL, Int32VT, Op0, Op1, Op2); + SDValue Rotl = DAG->getNode(ISD::ROTL, DL, Int32VT, Op0, Op2); + SDValue Rotr = DAG->getNode(ISD::ROTR, DL, Int32VT, Op0, Op2); + + SDValue Shl7 = DAG->getNode(ISD::SHL, DL, Int32VT, Op0, C7); + SDValue Srl25 = DAG->getNode(ISD::SRL, DL, Int32VT, Op1, C25); + SDValue Srl24 = DAG->getNode(ISD::SRL, DL, Int32VT, Op1, C24); + SDValue OrFSh = DAG->getNode(ISD::OR, DL, Int32VT, Shl7, Srl25); + SDValue OrFShCommuted = DAG->getNode(ISD::OR, DL, Int32VT, Srl25, Shl7); + SDValue BadOrFSh = DAG->getNode(ISD::OR, DL, Int32VT, Shl7, Srl24); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match( + FShL, m_FShL(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2)))); + EXPECT_TRUE(sd_match( + FShR, m_FShR(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2)))); + EXPECT_FALSE(sd_match(FShL, m_FShR(m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(FShR, m_FShL(m_Value(), m_Value(), m_Value()))); + + EXPECT_TRUE(sd_match( + FShL, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2)))); + EXPECT_TRUE(sd_match( + FShR, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2)))); + EXPECT_FALSE(sd_match(FShL, m_FShRLike(m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(FShR, m_FShLLike(m_Value(), m_Value(), m_Value()))); + + EXPECT_TRUE(sd_match( + Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2)))); + EXPECT_TRUE(sd_match( + Rotr, m_FShRLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2)))); + EXPECT_FALSE(sd_match( + Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2)))); + EXPECT_FALSE(sd_match( + Rotr, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2)))); + EXPECT_FALSE(sd_match(Rotl, m_FShRLike(m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(Rotr, m_FShLLike(m_Value(), m_Value(), m_Value()))); + + SDValue A, B, C; + EXPECT_TRUE(sd_match(Rotl, m_FShLLike(m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_EQ(A, Op0); + EXPECT_EQ(B, Op0); + EXPECT_EQ(C, Op2); + + A = B = C = SDValue(); + EXPECT_TRUE(sd_match(Rotr, m_FShRLike(m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_EQ(A, Op0); + EXPECT_EQ(B, Op0); + EXPECT_EQ(C, Op2); + + EXPECT_TRUE(sd_match( + OrFSh, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(7)))); + EXPECT_TRUE(sd_match( + OrFSh, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(25)))); + EXPECT_TRUE( + sd_match(OrFShCommuted, + m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(7)))); + EXPECT_TRUE( + sd_match(OrFShCommuted, m_FShRLike(m_Specific(Op0), m_Specific(Op1), + m_SpecificInt(25)))); + EXPECT_FALSE(sd_match(BadOrFSh, m_FShLLike(m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(BadOrFSh, m_FShRLike(m_Value(), m_Value(), m_Value()))); + + auto Int1024VT = EVT::getIntegerVT(Context, 1024); + auto Int8VT = EVT::getIntegerVT(Context, 8); + SDValue WideOp0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, + Register::index2VirtReg(4), Int1024VT); + SDValue WideOp1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, + Register::index2VirtReg(5), Int1024VT); + SDValue C0I8 = DAG->getConstant(0, DL, Int8VT); + SDValue WideShl = DAG->getNode(ISD::SHL, DL, Int1024VT, WideOp0, C0I8); + SDValue WideSrl = DAG->getNode(ISD::SRL, DL, Int1024VT, WideOp1, C0I8); + SDValue WideOr = DAG->getNode(ISD::OR, DL, Int1024VT, WideShl, WideSrl); + EXPECT_FALSE(sd_match(WideOr, m_FShLLike(m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(WideOr, m_FShRLike(m_Value(), m_Value(), m_Value()))); +} + TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) { SDLoc DL; auto Int32VT = EVT::getIntegerVT(Context, 32);