//===-- FIRToMemRef.cpp - Convert FIR loads and stores to MemRef ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers FIR dialect memory operations to the MemRef dialect.
// In particular it:
//
// - Rewrites `fir.alloca` to `memref.alloca`.
//
// - Rewrites `fir.load` / `fir.store` to `memref.load` / `memref.store`.
//
// - Allows FIR and MemRef to coexist by introducing `fir.convert` at
//   memory-use sites. Memory operations (`memref.load`, `memref.store`,
//   `memref.reinterpret_cast`, etc.) see MemRef-typed values, while the
//   original FIR-typed values remain available for non-memory uses. For
//   example:
//
//     %fir_ref = ... : !fir.ref<!fir.array<?xf32>>
//     %memref = fir.convert %fir_ref
//         : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
//     %val = memref.load %memref[...] : memref<?xf32>
//     fir.call @callee(%fir_ref) : (!fir.ref<!fir.array<?xf32>>) -> ()
//
//   Here the MemRef-typed value is used for `memref.load`, while the
//   original FIR-typed value is preserved for `fir.call`.
//
// - Computes shapes, strides, and indices as needed for slices and shifts
//   and emits `memref.reinterpret_cast` when dynamic layout is required
//   (TODO: use memref.cast instead).
// //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Region.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/IR/Verifier.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "fir-to-memref" using namespace mlir; namespace fir { #define GEN_PASS_DEF_FIRTOMEMREF #include "flang/Optimizer/Transforms/Passes.h.inc" static bool isMarshalLike(Operation *op) { auto convert = dyn_cast_if_present(op); if (!convert) return false; bool resIsMemRef = isa(convert.getType()); bool argIsMemRef = isa(convert.getValue().getType()); 
assert(!(resIsMemRef && argIsMemRef) && "unexpected fir.convert memref -> memref in isMarshalLike"); return resIsMemRef || argIsMemRef; } using MemRefInfo = FailureOr>>; static llvm::cl::opt enableFIRConvertOptimizations( "enable-fir-convert-opts", llvm::cl::desc("enable emilinating redundant fir.convert in FIR-to-MemRef"), llvm::cl::init(false), llvm::cl::Hidden); class FIRToMemRef : public fir::impl::FIRToMemRefBase { public: void runOnOperation() override; private: llvm::SmallSetVector eraseOps; DominanceInfo *domInfo = nullptr; void rewriteAlloca(fir::AllocaOp, PatternRewriter &, FIRToMemRefTypeConverter &); void rewriteLoadOp(fir::LoadOp, PatternRewriter &, FIRToMemRefTypeConverter &); void rewriteStoreOp(fir::StoreOp, PatternRewriter &, FIRToMemRefTypeConverter &); MemRefInfo getMemRefInfo(Value, PatternRewriter &, FIRToMemRefTypeConverter &, Operation *); MemRefInfo convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp, PatternRewriter &, FIRToMemRefTypeConverter &); void replaceFIRMemrefs(Value, Value, PatternRewriter &) const; FailureOr getFIRConvert(Operation *memOp, Operation *memref, PatternRewriter &, FIRToMemRefTypeConverter &); FailureOr> getMemrefIndices(fir::ArrayCoorOp, Operation *, PatternRewriter &, Value, Value) const; bool memrefIsOptional(Operation *) const; Value canonicalizeIndex(Value, PatternRewriter &) const; // Logical section information used by FIRToMemRef. For projected slices, the // descriptor still owns the physical layout, so `sliceVec` intentionally // stays empty while `shapeVec`/`shiftVec` remain available for index math. 
struct SliceInfo { SmallVector shapeVec; SmallVector shiftVec; SmallVector sliceVec; bool hasProjectedSlice = false; }; template void collectSliceInfoFrom(OpTy op, SliceInfo &info) const; void populateShapeAndShift(SmallVectorImpl &shapeVec, SmallVectorImpl &shiftVec, fir::ShapeShiftOp shift) const; void populateShift(SmallVectorImpl &vec, fir::ShiftOp shift) const; void populateShape(SmallVectorImpl &vec, fir::ShapeOp shape) const; static fir::SliceOp getSliceOp(Value sliceVal) { return sliceVal ? sliceVal.getDefiningOp() : fir::SliceOp{}; } static bool hasProjectedSlice(fir::SliceOp sliceOp) { return sliceOp && !sliceOp.getFields().empty(); } unsigned getRankFromEmbox(fir::EmboxOp embox) const { auto memrefType = embox.getMemref().getType(); Type unwrappedType = fir::unwrapRefType(memrefType); if (auto seqType = dyn_cast(unwrappedType)) return seqType.getDimension(); return 0; } bool isCompilerGeneratedAlloca(Operation *op) const; void copyAttribute(Operation *from, Operation *to, llvm::StringRef name) const; Type getBaseType(Type type, bool complexBaseTypes = false) const; bool memrefIsDeviceData(Operation *memref) const; mlir::Attribute findCudaDataAttr(Value val) const; Value materializeBoxAddressIfNeeded(Value basePtr, PatternRewriter &rewriter, Location loc) const; }; void FIRToMemRef::populateShapeAndShift(SmallVectorImpl &shapeVec, SmallVectorImpl &shiftVec, fir::ShapeShiftOp shift) const { for (mlir::OperandRange::iterator i = shift.getPairs().begin(), endIter = shift.getPairs().end(); i != endIter;) { shiftVec.push_back(*i++); shapeVec.push_back(*i++); } } bool FIRToMemRef::isCompilerGeneratedAlloca(Operation *op) const { if (!isa(op)) llvm_unreachable("expected alloca op"); return !op->getAttr("bindc_name") && !op->getAttr("uniq_name"); } void FIRToMemRef::copyAttribute(Operation *from, Operation *to, llvm::StringRef name) const { if (Attribute value = from->getAttr(name)) to->setAttr(name, value); } Type FIRToMemRef::getBaseType(Type type, bool 
complexBaseTypes) const { if (fir::isa_fir_type(type)) { type = fir::getFortranElementType(type); } else if (auto memrefTy = dyn_cast(type)) { type = memrefTy.getElementType(); } if (!complexBaseTypes) if (auto complexTy = dyn_cast(type)) type = complexTy.getElementType(); return type; } bool FIRToMemRef::memrefIsDeviceData(Operation *memref) const { if (isa(memref)) return true; return cuf::hasDeviceDataAttr(memref); } mlir::Attribute FIRToMemRef::findCudaDataAttr(Value val) const { Value currentVal = val; llvm::SmallPtrSet visited; while (currentVal) { Operation *defOp = currentVal.getDefiningOp(); if (!defOp || !visited.insert(defOp).second) break; if (cuf::DataAttributeAttr cudaAttr = cuf::getDataAttr(defOp)) return cudaAttr; // TODO: This is a best-effort backward walk; it is easy to miss attributes // as FIR evolves. Long term, it would be preferable if the necessary // information was carried in the type system (or otherwise made available // without relying on a walk-back through defining ops). 
if (auto reboxOp = dyn_cast(defOp)) { currentVal = reboxOp.getBox(); } else if (auto convertOp = dyn_cast(defOp)) { currentVal = convertOp->getOperand(0); } else if (auto emboxOp = dyn_cast(defOp)) { currentVal = emboxOp.getMemref(); } else if (auto boxAddrOp = dyn_cast(defOp)) { currentVal = boxAddrOp.getVal(); } else if (auto declareOp = dyn_cast(defOp)) { currentVal = declareOp.getMemref(); } else { break; } } return nullptr; } Value FIRToMemRef::materializeBoxAddressIfNeeded(Value basePtr, PatternRewriter &rewriter, Location loc) const { if (!isa(basePtr.getType())) return basePtr; auto boxAddrOp = fir::BoxAddrOp::create(rewriter, loc, basePtr); if (auto cudaAttr = findCudaDataAttr(basePtr)) boxAddrOp->setAttr(cuf::getDataAttrName(), cudaAttr); return boxAddrOp.getResult(); } void FIRToMemRef::populateShift(SmallVectorImpl &vec, fir::ShiftOp shift) const { vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); } void FIRToMemRef::populateShape(SmallVectorImpl &vec, fir::ShapeOp shape) const { vec.append(shape.getExtents().begin(), shape.getExtents().end()); } template void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { Value shapeVal = op.getShape(); if (shapeVal) { Operation *shapeValOp = shapeVal.getDefiningOp(); if (auto shapeOp = dyn_cast(shapeValOp)) { populateShape(info.shapeVec, shapeOp); } else if (auto shapeShiftOp = dyn_cast(shapeValOp)) { populateShapeAndShift(info.shapeVec, info.shiftVec, shapeShiftOp); } else if (auto shiftOp = dyn_cast(shapeValOp)) { populateShift(info.shiftVec, shiftOp); } } if (auto sliceOp = getSliceOp(op.getSlice())) { // A slice path changes the physical projection of the boxed entity (for // example, `complex -> real` for `%re`). Preserve shape/shift for logical // indexing, but do not treat the triplets alone as layout information. 
if (hasProjectedSlice(sliceOp)) { info.hasProjectedSlice = true; } else { auto triples = sliceOp.getTriples(); info.sliceVec.append(triples.begin(), triples.end()); } } } } void FIRToMemRef::rewriteAlloca(fir::AllocaOp firAlloca, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { if (!typeConverter.convertibleType(firAlloca.getInType())) return; if (typeConverter.isEmptyArray(firAlloca.getType())) return; rewriter.setInsertionPointAfter(firAlloca); Type type = firAlloca.getType(); MemRefType memrefTy = typeConverter.convertMemrefType(type); Location loc = firAlloca.getLoc(); SmallVector sizes = firAlloca.getOperands(); std::reverse(sizes.begin(), sizes.end()); auto alloca = memref::AllocaOp::create(rewriter, loc, memrefTy, sizes); copyAttribute(firAlloca, alloca, firAlloca.getBindcNameAttrName()); copyAttribute(firAlloca, alloca, firAlloca.getUniqNameAttrName()); copyAttribute(firAlloca, alloca, cuf::getDataAttrName()); copyAttribute(firAlloca, alloca, acc::getVarNameAttrName()); auto convert = fir::ConvertOp::create(rewriter, loc, type, alloca); rewriter.replaceOp(firAlloca, convert); if (isCompilerGeneratedAlloca(alloca)) { for (Operation *userOp : convert->getUsers()) { if (auto declareOp = dyn_cast(userOp)) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: removing declare for compiler temp:\n"; declareOp->dump()); declareOp->replaceAllUsesWith(convert); eraseOps.insert(userOp); } } } } bool FIRToMemRef::memrefIsOptional(Operation *op) const { if (auto declare = dyn_cast(op)) { if (fir::FortranVariableOpInterface(declare).isOptional()) return true; Value operand = declare.getMemref(); Operation *operandOp = operand.getDefiningOp(); if (operandOp && isa(operandOp)) return true; } for (mlir::Value result : op->getResults()) for (mlir::Operation *userOp : result.getUsers()) if (isa(userOp)) return true; // TODO: If `op` is not a `fir.declare`, OPTIONAL information may still be // present on a related `fir.declare` reached by tracing the address/box // 
through common forwarding ops (e.g. `fir.convert`, `fir.rebox`, // `fir.embox`, `fir.box_addr`), then checking `declare.isOptional()`. Add the // search after FIR improves on it. return false; } static Value castTypeToIndexType(Value originalValue, PatternRewriter &rewriter) { if (originalValue.getType().isIndex()) return originalValue; Type indexType = rewriter.getIndexType(); return arith::IndexCastOp::create(rewriter, originalValue.getLoc(), indexType, originalValue); } static bool shouldUseBoundaryBitcast(mlir::Type fromTy, mlir::Type toTy) { auto isBitcastCompatibleScalarType = [](mlir::Type ty) { return mlir::isa( ty) || (mlir::isa(ty) && mlir::cast(ty).getLen() == fir::CharacterType::singleton()); }; auto getKnownScalarBitWidth = [](mlir::Type ty) -> std::optional { if (auto intTy = mlir::dyn_cast(ty)) return intTy.getWidth(); if (auto floatTy = mlir::dyn_cast(ty)) return floatTy.getWidth(); return std::nullopt; }; if (fromTy == toTy) return false; const bool fromStd = fir::isa_std_type(fromTy); const bool toStd = fir::isa_std_type(toTy); if (fromStd == toStd) return false; if (!isBitcastCompatibleScalarType(fromTy) || !isBitcastCompatibleScalarType(toTy)) return false; auto fromBits = getKnownScalarBitWidth(fromTy); auto toBits = getKnownScalarBitWidth(toTy); if (fromBits && toBits && *fromBits != *toBits) return false; return true; } static mlir::Value createTypeConversion(PatternRewriter &rewriter, mlir::Location loc, mlir::Type toTy, mlir::Value value) { if (shouldUseBoundaryBitcast(value.getType(), toTy)) return fir::BitcastOp::create(rewriter, loc, toTy, value); return fir::ConvertOp::create(rewriter, loc, toTy, value); } FailureOr> FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref, PatternRewriter &rewriter, Value converted, Value one) const { IndexType indexTy = rewriter.getIndexType(); SmallVector indices; Location loc = arrayCoorOp->getLoc(); SliceInfo sliceInfo; int rank = arrayCoorOp.getIndices().size(); 
collectSliceInfoFrom(arrayCoorOp, sliceInfo); // Only collect shape/shift from fir.embox, never from fir.rebox. // // The distinction is a Fortran descriptor invariant: // // | fir.embox | fir.rebox // --------------|------------------------|--------------------------- // base_addr | raw pointer | pre-adjusted by (lb-1)*stride // Index formula | index - lb | index - 1 // Collect shift | yes | no // // fir.embox creates a box from a raw pointer so base_addr is not adjusted; // the shift must be collected so index arithmetic subtracts lb correctly. // fir.rebox re-boxes a live descriptor whose base_addr the runtime has // already adjusted downward by (lb-1)*stride; collecting the shift here // would subtract the lower bound a second time, giving the wrong element. if (auto embox = dyn_cast_or_null(memref)) { collectSliceInfoFrom(embox, sliceInfo); rank = getRankFromEmbox(embox); } // Projected boxed slices leave `sliceVec` empty on purpose: indices are // computed in the logical section coordinate space, while stride/base come // later from the box descriptor. SmallVector &shiftVec = sliceInfo.shiftVec; SmallVector &sliceVec = sliceInfo.sliceVec; SmallVector sliceLbs, sliceStrides; for (size_t i = 0; i < sliceVec.size(); i += 3) { sliceLbs.push_back(castTypeToIndexType(sliceVec[i], rewriter)); sliceStrides.push_back(castTypeToIndexType(sliceVec[i + 2], rewriter)); } const bool isShifted = !shiftVec.empty(); const bool isSliced = !sliceVec.empty(); ValueRange idxs = arrayCoorOp.getIndices(); Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector filledPositions(rank, false); for (int i = 0; i < rank; ++i) { Value step = isSliced ? sliceStrides[i] : one; Operation *stepOp = step.getDefiningOp(); if (stepOp && mlir::isa_and_nonnull(stepOp)) { Value shift = isShifted ? shiftVec[i] : one; Value sliceLb = isSliced ? 
sliceLbs[i] : shift; Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift); indices.push_back(offset); filledPositions[i] = true; } else { indices.push_back(zero); } } int arrayCoorIdx = 0; for (int i = 0; i < rank; ++i) { if (filledPositions[i]) continue; assert((unsigned int)arrayCoorIdx < idxs.size() && "empty dimension should be eliminated\n"); Value index = canonicalizeIndex(idxs[arrayCoorIdx], rewriter); Type cTy = index.getType(); if (!llvm::isa(cTy)) { assert(cTy.isSignlessInteger() && "expected signless integer type"); index = arith::IndexCastOp::create(rewriter, loc, indexTy, index); } Value shift = isShifted ? shiftVec[i] : one; Value stride = isSliced ? sliceStrides[i] : one; Value sliceLb = isSliced ? sliceLbs[i] : shift; // When the array_coor has an explicit slice with a shape_shift (i.e. // non-default lower bounds), the indices are in the shape_shift // coordinate space; subtract the lower bound (shift) to get 0-based // memref indices. Otherwise (the slice comes from an embox, or the // shape has no shift), the indices are 1-based section indices; // subtract 1. bool indicesAreFortran = isShifted && arrayCoorOp.getSlice() != nullptr; Value indexAdjustment = (isSliced && !indicesAreFortran) ? 
arith::ConstantIndexOp::create(rewriter, loc, 1) : shift; Value delta = arith::SubIOp::create(rewriter, loc, index, indexAdjustment); Value scaled = arith::MulIOp::create(rewriter, loc, delta, stride); Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift); Value finalIndex = arith::AddIOp::create(rewriter, loc, scaled, offset); indices[i] = finalIndex; arrayCoorIdx++; } std::reverse(indices.begin(), indices.end()); return indices; } MemRefInfo FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { IndexType indexTy = rewriter.getIndexType(); Value firMemref = arrayCoorOp.getMemref(); if (!typeConverter.convertibleMemrefType(firMemref.getType())) return failure(); if (typeConverter.isEmptyArray(firMemref.getType())) return failure(); Location loc = arrayCoorOp->getLoc(); // Prefer lowering the array-coordinates computation to a memref + indices. // This allows erasing fir.array_coor when it is only used by load/store even // if the base address is a block argument (e.g. region arguments). 
Operation *memref = nullptr; FailureOr converted; if (auto blockArg = dyn_cast(firMemref)) { rewriter.setInsertionPoint(arrayCoorOp); Value basePtr = materializeBoxAddressIfNeeded(blockArg, rewriter, loc); Type memrefTy = typeConverter.convertMemrefType(basePtr.getType()); converted = fir::ConvertOp::create(rewriter, loc, memrefTy, basePtr).getResult(); rewriter.setInsertionPointAfter(arrayCoorOp); } else if ((memref = firMemref.getDefiningOp()) && enableFIRConvertOptimizations && isMarshalLike(memref) && !fir::isa_fir_type(firMemref.getType())) { converted = firMemref; rewriter.setInsertionPoint(arrayCoorOp); } else { Operation *arrayCoorOperation = arrayCoorOp.getOperation(); rewriter.setInsertionPoint(arrayCoorOp); if (memrefIsOptional(memref)) { auto ifOp = arrayCoorOperation->getParentOfType(); if (ifOp) { Operation *condition = ifOp.getCondition().getDefiningOp(); if (condition && isa(condition)) if (condition->getOperand(0) == firMemref) { if (arrayCoorOperation->getParentRegion() == &ifOp.getThenRegion()) rewriter.setInsertionPointToStart( &(ifOp.getThenRegion().front())); else if (arrayCoorOperation->getParentRegion() == &ifOp.getElseRegion()) rewriter.setInsertionPointToStart( &(ifOp.getElseRegion().front())); } } } converted = getFIRConvert(memOp, memref, rewriter, typeConverter); if (failed(converted)) return failure(); rewriter.setInsertionPointAfter(arrayCoorOp); } Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); FailureOr> failureOrIndices = getMemrefIndices(arrayCoorOp, memref, rewriter, *converted, one); if (failed(failureOrIndices)) return failure(); SmallVector indices = *failureOrIndices; if (converted == firMemref) return std::pair{*converted, indices}; Value convertedVal = *converted; MemRefType memRefTy = dyn_cast(convertedVal.getType()); bool isRebox = firMemref.getDefiningOp() != nullptr; bool isDescriptor = mlir::isa(firMemref.getType()) || firMemref.getDefiningOp() != nullptr; // Static shape does not imply contiguous layout 
for descriptor-backed // entities (e.g. boxed array sections with non-unit stride). Keep the // reinterpret-cast path so descriptor strides are preserved. if (memRefTy.hasStaticShape() && !isRebox && !isDescriptor) return std::pair{*converted, indices}; unsigned rank = arrayCoorOp.getIndices().size(); if (auto embox = firMemref.getDefiningOp()) rank = getRankFromEmbox(embox); SmallVector sizes; sizes.reserve(rank); SmallVector strides; strides.reserve(rank); SliceInfo sliceInfo; collectSliceInfoFrom(arrayCoorOp, sliceInfo); Value box = firMemref; if (!isa(firMemref)) { if (auto embox = firMemref.getDefiningOp()) { collectSliceInfoFrom(embox, sliceInfo); } else if (auto rebox = firMemref.getDefiningOp()) { collectSliceInfoFrom(rebox, sliceInfo); } } SmallVector &shapeVec = sliceInfo.shapeVec; if (sliceInfo.hasProjectedSlice || shapeVec.empty()) { // Projected slices carry their physical layout in the descriptor. Rebuild // the MemRef view from box metadata instead of from slice triplets. 
auto boxElementSize = fir::BoxEleSizeOp::create(rewriter, loc, indexTy, box); for (unsigned i = 0; i < rank; ++i) { Value dim = arith::ConstantIndexOp::create(rewriter, loc, rank - i - 1); auto boxDims = fir::BoxDimsOp::create(rewriter, loc, indexTy, indexTy, indexTy, box, dim); Value extent = boxDims->getResult(1); sizes.push_back(castTypeToIndexType(extent, rewriter)); Value byteStride = boxDims->getResult(2); Value div = arith::DivSIOp::create(rewriter, loc, byteStride, boxElementSize); strides.push_back(castTypeToIndexType(div, rewriter)); } } else { Value oneIdx = arith::ConstantIndexOp::create(rewriter, arrayCoorOp->getLoc(), 1); for (unsigned i = rank - 1; i > 0; --i) { Value size = shapeVec[i]; sizes.push_back(castTypeToIndexType(size, rewriter)); Value stride = shapeVec[0]; for (unsigned j = 1; j <= i - 1; ++j) stride = arith::MulIOp::create(rewriter, loc, shapeVec[j], stride); strides.push_back(castTypeToIndexType(stride, rewriter)); } sizes.push_back(castTypeToIndexType(shapeVec[0], rewriter)); strides.push_back(oneIdx); } assert(strides.size() == sizes.size() && sizes.size() == rank); int64_t dynamicOffset = ShapedType::kDynamic; SmallVector dynamicStrides(rank, ShapedType::kDynamic); auto stridedLayout = StridedLayoutAttr::get(convertedVal.getContext(), dynamicOffset, dynamicStrides); SmallVector dynamicShape(rank, ShapedType::kDynamic); memRefTy = MemRefType::get(dynamicShape, memRefTy.getElementType(), stridedLayout); Value offset = arith::ConstantIndexOp::create(rewriter, loc, 0); auto reinterpret = memref::ReinterpretCastOp::create( rewriter, loc, memRefTy, *converted, offset, sizes, strides); Value result = reinterpret->getResult(0); return std::pair{result, indices}; } FailureOr FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { if (enableFIRConvertOptimizations && !op->hasOneUse() && !memrefIsOptional(op)) { for (Operation *userOp : op->getUsers()) { if (auto 
convertOp = dyn_cast(userOp)) { Value converted = convertOp.getResult(); if (!isa(converted.getType())) continue; if (userOp->getParentOp() == memOp->getParentOp() && domInfo->dominates(userOp, memOp)) return converted; } } } assert(op->getNumResults() == 1 && "expecting one result"); Value basePtr = op->getResult(0); MemRefType memrefTy = typeConverter.convertMemrefType(basePtr.getType()); Type baseTy = memrefTy.getElementType(); if (fir::isa_std_type(baseTy) && memrefTy.getRank() == 0) { if (auto convertOp = basePtr.getDefiningOp()) { Value input = convertOp.getOperand(); if (auto alloca = input.getDefiningOp()) { assert(alloca.getType() == memrefTy && "expected same types"); if (isCompilerGeneratedAlloca(alloca)) return alloca.getResult(); } } } const Location loc = op->getLoc(); if (isa(basePtr.getType())) { Operation *baseOp = basePtr.getDefiningOp(); basePtr = materializeBoxAddressIfNeeded(basePtr, rewriter, loc); memrefTy = typeConverter.convertMemrefType(basePtr.getType()); if (baseOp) { auto sameBaseBoxTypes = [&](Type baseType, Type memrefType) -> bool { Type emboxBaseTy = getBaseType(baseType, true); Type emboxMemrefTy = getBaseType(memrefType, true); return emboxBaseTy == emboxMemrefTy; }; if (auto embox = dyn_cast_or_null(baseOp)) { // A projected slice changes the element type of the boxed view. We // can only lower it here when the storage element is complex and // the projection is the real or imaginary part (i.e. %re / %im). For // such cases sizeof(complex) == 2*sizeof(T), so // divsi(byte_stride, elesize) is always an exact integer. // // Derived-type component projections (e.g. a%x, a%y) may produce a // non-integer element-unit stride (e.g. sizeof(T)=24, // sizeof(complex)=16 -> 24/16 = 1 after truncation, which is // wrong). For those, the type-restriction check below fires and we // bail out, leaving the ops for downstream FIR-to-LLVM lowering. 
auto isComplexComponentProjection = [&](fir::EmboxOp embox) -> bool { if (!hasProjectedSlice(getSliceOp(embox.getSlice()))) return false; Type memTy = fir::unwrapRefType(embox.getMemref().getType()); if (auto seqTy = dyn_cast(memTy)) memTy = seqTy.getEleTy(); return mlir::isa(memTy); }; bool projectedSlice = isComplexComponentProjection(embox); if (!projectedSlice && !sameBaseBoxTypes(embox.getType(), embox.getMemref().getType())) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: embox base type and memref type are not " "the same, bailing out of conversion\n"); return failure(); } // Keep `box_addr` on the projected box so the descriptor remains the // source of truth for projected element type and stride. if (!projectedSlice && embox.getSlice() && embox.getSlice().getDefiningOp()) { Type originalType = embox.getMemref().getType(); basePtr = embox.getMemref(); if (typeConverter.convertibleMemrefType(originalType)) { auto convertedMemrefTy = typeConverter.convertMemrefType(originalType); memrefTy = convertedMemrefTy; } else { return failure(); } } } if (auto rebox = dyn_cast(baseOp)) { if (!sameBaseBoxTypes(rebox.getType(), rebox.getBox().getType())) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: rebox base type and box type are not the " "same, bailing out of conversion\n"); return failure(); } Type originalType = rebox.getBox().getType(); if (auto boxTy = dyn_cast(originalType)) originalType = boxTy.getElementType(); if (!typeConverter.convertibleMemrefType(originalType)) { return failure(); } else { auto convertedMemrefTy = typeConverter.convertMemrefType(originalType); memrefTy = convertedMemrefTy; } } } } auto convert = fir::ConvertOp::create(rewriter, loc, memrefTy, basePtr); return convert->getResult(0); } Value FIRToMemRef::canonicalizeIndex(Value index, PatternRewriter &rewriter) const { if (auto blockArg = dyn_cast(index)) return index; Operation *op = index.getDefiningOp(); if (auto constant = dyn_cast(op)) { if (!constant.getType().isIndex()) { Value v = 
arith::ConstantIndexOp::create(rewriter, op->getLoc(), constant.value()); return v; } return constant; } if (auto extsi = dyn_cast(op)) { Value operand = extsi.getOperand(); if (auto indexCast = operand.getDefiningOp()) { Value v = indexCast.getOperand(); return v; } return canonicalizeIndex(operand, rewriter); } if (auto add = dyn_cast(op)) { Value lhs = canonicalizeIndex(add.getLhs(), rewriter); Value rhs = canonicalizeIndex(add.getRhs(), rewriter); if (lhs.getType() == rhs.getType()) return arith::AddIOp::create(rewriter, op->getLoc(), lhs, rhs); } return index; } MemRefInfo FIRToMemRef::getMemRefInfo(Value firMemref, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter, Operation *memOp) { Operation *memrefOp = firMemref.getDefiningOp(); if (!memrefOp) { if (auto blockArg = dyn_cast(firMemref)) { rewriter.setInsertionPoint(memOp); Type memrefTy = typeConverter.convertMemrefType(blockArg.getType()); if (auto mt = dyn_cast(memrefTy)) if (auto inner = llvm::dyn_cast(mt.getElementType())) memrefTy = inner; Value converted = fir::ConvertOp::create(rewriter, blockArg.getLoc(), memrefTy, blockArg); SmallVector indices; return std::pair{converted, indices}; } llvm_unreachable( "FIRToMemRef: expected defining op or block argument for FIR memref"); } if (auto arrayCoorOp = dyn_cast(memrefOp)) { MemRefInfo memrefInfo = convertArrayCoorOp(memOp, arrayCoorOp, rewriter, typeConverter); if (succeeded(memrefInfo)) { for (auto user : memrefOp->getUsers()) { if (!isa(user)) { LLVM_DEBUG( llvm::dbgs() << "FIRToMemRef: array memref used by unsupported op:\n"; firMemref.dump(); user->dump()); return memrefInfo; } } eraseOps.insert(memrefOp); } return memrefInfo; } rewriter.setInsertionPoint(memOp); if (isMarshalLike(memrefOp)) { FailureOr converted = getFIRConvert(memOp, memrefOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: expected FIR memref in convert, bailing " "out:\n"; firMemref.dump()); return failure(); } 
SmallVector indices; return std::pair{*converted, indices}; } if (auto declareOp = dyn_cast(memrefOp)) { FailureOr converted = getFIRConvert(memOp, declareOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: unable to create convert for scalar " "memref:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } if (auto coordinateOp = dyn_cast(memrefOp)) { FailureOr converted = getFIRConvert(memOp, coordinateOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG( llvm::dbgs() << "FIRToMemRef: unable to create convert for derived-type " "memref:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } if (auto convertOp = dyn_cast(memrefOp)) { Type fromTy = convertOp->getOperand(0).getType(); Type toTy = firMemref.getType(); if (isa(fromTy) && isa(toTy)) { FailureOr converted = getFIRConvert(memOp, convertOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG( llvm::dbgs() << "FIRToMemRef: unable to create convert for conversion " "op:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } } if (auto boxAddrOp = dyn_cast(memrefOp)) { FailureOr converted = getFIRConvert(memOp, boxAddrOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: unable to create convert for box_addr " "op:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } if (memrefIsDeviceData(memrefOp)) { FailureOr converted = getFIRConvert(memOp, memrefOp, rewriter, typeConverter); if (failed(converted)) return failure(); SmallVector indices; return std::pair{*converted, indices}; } LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: unable to create convert for memref value:\n"; firMemref.dump()); return failure(); } void FIRToMemRef::replaceFIRMemrefs(Value firMemref, Value converted, PatternRewriter &rewriter) 
const { Operation *op = firMemref.getDefiningOp(); if (op && (isa(op) || isMarshalLike(op))) return; SmallPtrSet worklist; for (auto user : firMemref.getUsers()) { if (isMarshalLike(user) || isa(user)) continue; if (!domInfo->dominates(converted, user)) continue; if (!(isa(user->getParentOp()) || isa(user->getParentOp()))) worklist.insert(user); } Type ty = firMemref.getType(); for (auto op : worklist) { rewriter.setInsertionPoint(op); Location loc = op->getLoc(); Value replaceConvert = fir::ConvertOp::create(rewriter, loc, ty, converted); op->replaceUsesOfWith(firMemref, replaceConvert); } worklist.clear(); for (auto user : firMemref.getUsers()) { if (isMarshalLike(user) || isa(user)) continue; if (isa(user->getParentOp()) || isa(user->getParentOp())) if (domInfo->dominates(converted, user)) worklist.insert(user); } if (worklist.empty()) return; while (!worklist.empty()) { Operation *parentOp = (*worklist.begin())->getParentOp(); Value replaceConvert; SmallVector erase; for (auto op : worklist) { if (op->getParentOp() != parentOp) continue; if (!replaceConvert) { rewriter.setInsertionPoint(parentOp); replaceConvert = fir::ConvertOp::create(rewriter, op->getLoc(), ty, converted); } op->replaceUsesOfWith(firMemref, replaceConvert); erase.push_back(op); } for (auto op : erase) worklist.erase(op); } } void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { Value firMemref = load.getMemref(); if (!typeConverter.convertibleType(firMemref.getType())) return; LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR load:\n"; load.dump(); firMemref.dump()); MemRefInfo memrefInfo = getMemRefInfo(firMemref, rewriter, typeConverter, load.getOperation()); if (failed(memrefInfo)) return; Type originalType = load.getResult().getType(); Value converted = memrefInfo->first; SmallVector indices = memrefInfo->second; LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: convert for FIR load created successfully:\n"; 
converted.dump()); rewriter.setInsertionPointAfter(load); Attribute attr = (load.getOperation())->getAttr("tbaa"); memref::LoadOp loadOp = rewriter.replaceOpWithNewOp(load, converted, indices); if (attr) loadOp.getOperation()->setAttr("tbaa", attr); LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n"; loadOp.dump(); assert(succeeded(verify(loadOp)))); if (loadOp.getType() != originalType) { Value castVal = createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp); loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp()); } if (!isa(originalType)) replaceFIRMemrefs(firMemref, converted, rewriter); } void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { Value firMemref = store.getMemref(); if (!typeConverter.convertibleType(firMemref.getType())) return; LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR store:\n"; store.dump(); firMemref.dump()); MemRefInfo memrefInfo = getMemRefInfo(firMemref, rewriter, typeConverter, store.getOperation()); if (failed(memrefInfo)) return; Value converted = memrefInfo->first; SmallVector indices = memrefInfo->second; LLVM_DEBUG( llvm::dbgs() << "FIRToMemRef: convert for FIR store created successfully:\n"; converted.dump()); Value value = store.getValue(); rewriter.setInsertionPointAfter(store); Type convertedType = typeConverter.convertType(value.getType()); if (convertedType != value.getType()) value = createTypeConversion(rewriter, store.getLoc(), convertedType, value); Attribute attr = (store.getOperation())->getAttr("tbaa"); memref::StoreOp storeOp = rewriter.replaceOpWithNewOp( store, value, converted, indices); if (attr) storeOp.getOperation()->setAttr("tbaa", attr); LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.store op:\n"; storeOp.dump(); assert(succeeded(verify(storeOp)))); bool isLogicalRef = false; if (fir::ReferenceType refTy = llvm::dyn_cast(firMemref.getType())) isLogicalRef = 
llvm::isa(refTy.getEleTy()); if (!isLogicalRef) replaceFIRMemrefs(firMemref, converted, rewriter); } // Lower operand and result type of FIR logical operation to get rid // of bitcast after loads from memref and before store to memref // storing arrays of logicals as integers. template static void rewriteLogicalOperation(Op op, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { mlir::Type oldType = op.getResult().getType(); mlir::Type convertedTy = typeConverter.convertType(oldType); if (convertedTy == oldType) return; rewriter.setInsertionPoint(op); mlir::Location loc = op.getLoc(); // Inserted bitcast from/to logical will be folded with the one created // around load/store. mlir::Value lhs = createTypeConversion(rewriter, loc, convertedTy, op.getLhs()); mlir::Value rhs = createTypeConversion(rewriter, loc, convertedTy, op.getRhs()); auto newOp = Op::create(rewriter, loc, convertedTy, lhs, rhs); mlir::Value result = createTypeConversion(rewriter, loc, oldType, newOp); rewriter.replaceOp(op, result); } void FIRToMemRef::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "Enter FIRToMemRef()\n"); func::FuncOp op = getOperation(); MLIRContext *context = op.getContext(); ModuleOp mod = op->getParentOfType(); FIRToMemRefTypeConverter typeConverter(mod); typeConverter.setConvertComplexTypes(true); PatternRewriter rewriter(context); domInfo = new DominanceInfo(op); op.walk([&](fir::AllocaOp alloca) { rewriteAlloca(alloca, rewriter, typeConverter); }); op.walk([&](Operation *op) { llvm::TypeSwitch(op) .Case([&](auto loadOp) { rewriteLoadOp(loadOp, rewriter, typeConverter); }) .Case([&](auto storeOp) { rewriteStoreOp(storeOp, rewriter, typeConverter); }) .Case( [&](auto logicalOp) { rewriteLogicalOperation(logicalOp, rewriter, typeConverter); }) .Default([](Operation *) {}); }); for (auto eraseOp : eraseOps) rewriter.eraseOp(eraseOp); eraseOps.clear(); if (domInfo) delete domInfo; LLVM_DEBUG(llvm::dbgs() << "After FIRToMemRef()\n"; op.dump(); llvm::dbgs() << 
"Exit FIRToMemRef()\n";); } } // namespace fir