[HLSL] Allow 1x1 matrices to be splatted like scalars (#188119)
Fixes #186859 by allowing 1x1 matrices to be splatted like the scalar and vec1 cases. Assisted-by: GitHub Copilot (powered by Claude Opus 4.6)
This commit is contained in:
@@ -2939,11 +2939,17 @@ bool CastOperation::CheckHLSLCStyleCast(CheckedConversionKind CCK) {
|
||||
if (Self.HLSL().CanPerformAggregateSplatCast(SrcExpr.get(), DestType)) {
|
||||
SrcExpr = Self.DefaultLvalueConversion(SrcExpr.get());
|
||||
const VectorType *VT = SrcTy->getAs<VectorType>();
|
||||
const ConstantMatrixType *MT = SrcTy->getAs<ConstantMatrixType>();
|
||||
// change splat from vec1 case to splat from scalar
|
||||
if (VT && VT->getNumElements() == 1)
|
||||
SrcExpr = Self.ImpCastExprToType(
|
||||
SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
|
||||
SrcExpr.get()->getValueKind(), nullptr, CCK);
|
||||
// change splat from 1x1 matrix case to splat from scalar
|
||||
else if (MT && MT->getNumElementsFlattened() == 1)
|
||||
SrcExpr = Self.ImpCastExprToType(
|
||||
SrcExpr.get(), MT->getElementType(), CK_HLSLMatrixTruncation,
|
||||
SrcExpr.get()->getValueKind(), nullptr, CCK);
|
||||
// Inserting a scalar cast here allows for a simplified codegen in
|
||||
// the case the destTy is a vector
|
||||
if (const VectorType *DVT = DestType->getAs<VectorType>())
|
||||
|
||||
@@ -4644,8 +4644,8 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
|
||||
}
|
||||
|
||||
// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
|
||||
// Src is a scalar or a vector of length 1
|
||||
// Or if Dest is a vector and Src is a vector of length 1
|
||||
// Src is a scalar, a vector of length 1, or a 1x1 matrix
|
||||
// Or if Dest is a vector and Src is a vector of length 1 or a 1x1 matrix
|
||||
bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
|
||||
|
||||
QualType SrcTy = Src->getType();
|
||||
@@ -4656,13 +4656,18 @@ bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
|
||||
return false;
|
||||
|
||||
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
|
||||
const ConstantMatrixType *SrcMatTy = SrcTy->getAs<ConstantMatrixType>();
|
||||
|
||||
// Src isn't a scalar or a vector of length 1
|
||||
if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
|
||||
// Src isn't a scalar, a vector of length 1, or a 1x1 matrix
|
||||
if (!SrcTy->isScalarType() &&
|
||||
!(SrcVecTy && SrcVecTy->getNumElements() == 1) &&
|
||||
!(SrcMatTy && SrcMatTy->getNumElementsFlattened() == 1))
|
||||
return false;
|
||||
|
||||
if (SrcVecTy)
|
||||
SrcTy = SrcVecTy->getElementType();
|
||||
else if (SrcMatTy)
|
||||
SrcTy = SrcMatTy->getElementType();
|
||||
|
||||
llvm::SmallVector<QualType> DestTypes;
|
||||
BuildFlattenedTypeList(DestTy, DestTypes);
|
||||
|
||||
@@ -84,6 +84,39 @@ export void call5() {
|
||||
S s = (S)A;
|
||||
}
|
||||
|
||||
// vector splat from 1x1 matrix
|
||||
// CHECK-LABEL: define void {{.*}}call9
|
||||
// CHECK: [[M:%.*]] = alloca [1 x <1 x float>], align 4
|
||||
// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 4
|
||||
// CHECK-NEXT: store <1 x float> {{.*}}, ptr [[M]], align 4
|
||||
// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[M]], align 4
|
||||
// CHECK-NEXT: [[ML:%.*]] = extractelement <1 x float> [[L]], i32 0
|
||||
// CHECK-NEXT: [[C:%.*]] = fptosi float [[ML]] to i32
|
||||
// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
|
||||
// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
|
||||
// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 4
|
||||
export void call9() {
|
||||
float1x1 M = {1.0};
|
||||
int4 A = (int4)M;
|
||||
}
|
||||
|
||||
// struct splat from 1x1 matrix
|
||||
// CHECK-LABEL: define void {{.*}}call10
|
||||
// CHECK: [[M:%.*]] = alloca [1 x <1 x i32>], align 4
|
||||
// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 1
|
||||
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[M]], align 4
|
||||
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[M]], align 4
|
||||
// CHECK-NEXT: [[ML:%.*]] = extractelement <1 x i32> [[L]], i32 0
|
||||
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
|
||||
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
|
||||
// CHECK-NEXT: store i32 [[ML]], ptr [[G1]], align 4
|
||||
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[ML]] to float
|
||||
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
|
||||
export void call10() {
|
||||
int1x1 M = {1};
|
||||
S s = (S)M;
|
||||
}
|
||||
|
||||
struct BFields {
|
||||
double DF;
|
||||
int E: 15;
|
||||
|
||||
Reference in New Issue
Block a user