diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp index 9cc36312d361..9868584a4b69 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp @@ -24,15 +24,17 @@ // ---------------- // 1. Compute constructs: acc.parallel, acc.serial, and acc.kernels are // replaced by acc.kernel_environment containing a single acc.compute_region. -// Launch arguments (num_gangs, num_workers, vector_length) become -// acc.par_width ops (each result is `index`) and are passed as -// compute_region launch operands (still required to be acc.par_width -// results by the compute_region verifier). +// For acc.parallel / acc.kernels, launch arguments (num_gangs, num_workers, +// vector_length) become acc.par_width ops (each result is `index`) and are +// passed as compute_region launch operands. Compute regions with +// num_gangs(1), num_workers(1), and vector_length(1) and acc serial use a +// single sequential acc.par_width launch operand. // // 2. acc.loop: Converted according to context and attributes: // - Unstructured: body wrapped in scf.execute_region. -// - Sequential (serial region or seq clause): scf.parallel with -// par_dims = sequential. +// - Sequential (serial region, seq clause, or compute region with +// num_gangs(1), num_workers(1), and vector_length(1)): +// scf.parallel with par_dims = sequential. // - Auto (in parallel/kernels): scf.for with collapse when // multi-dimensional. // - Orphan (not inside a compute construct): scf.for, no collapse. @@ -56,6 +58,7 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { namespace acc { @@ -82,24 +85,34 @@ static Value stripIndexCasts(Value val) { return val; } -/// A parallel construct is "effectively serial" when it specifies -/// num_gangs(1), num_workers(1), and vector_length(1). This matches -/// the semantics of acc.serial but expressed through acc.parallel. -static bool isEffectivelySerial(ParallelOp op) { +template +static bool isGangWorkerVectorAllOne(ComputeOpT op) { auto numGangs = op.getNumGangsValues(); - if (numGangs.size() != 1) + if (numGangs.empty()) return false; + for (Value gangSize : numGangs) { + if (!isConstantIntValue(stripIndexCasts(gangSize), 1)) + return false; + } Value numWorkers = op.getNumWorkersValue(); if (!numWorkers) return false; Value vectorLength = op.getVectorLengthValue(); if (!vectorLength) return false; - return isConstantIntValue(stripIndexCasts(numGangs.front()), 1) && - isConstantIntValue(stripIndexCasts(numWorkers), 1) && + return isConstantIntValue(stripIndexCasts(numWorkers), 1) && isConstantIntValue(stripIndexCasts(vectorLength), 1); } +/// A compute construct is "effectively serial" when it specifies +/// num_gangs(1), num_workers(1), and vector_length(1). This is because +/// these are the only parallelism dimensions expressible from OpenACC spec +/// point-of-view and is consistent with how `serial` semantics are defined. +template +static bool isEffectivelySerial(ComputeOpT op) { + return isGangWorkerVectorAllOne(op); +} + static bool isOpInComputeRegion(Operation *op) { Region *region = op->getBlock()->getParent(); return getEnclosingComputeOp(*region) != nullptr; @@ -108,10 +121,12 @@ static bool isOpInComputeRegion(Operation *op) { static bool isOpInSerialRegion(Operation *op) { if (auto parallelOp = op->getParentOfType()) return isEffectivelySerial(parallelOp); - if (auto computeRegion = op->getParentOfType()) - return computeRegion.isEffectivelySerial(); + if (auto kernelsOp = op->getParentOfType()) + return isEffectivelySerial(kernelsOp); if (op->getParentOfType()) return true; + if (auto computeRegion = op->getParentOfType()) + return computeRegion.isEffectivelySerial(); if (auto funcOp = op->getParentOfType()) { if (isSpecializedAccRoutine(funcOp)) { auto attr = funcOp->getAttrOfType( @@ -194,61 +209,67 @@ getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy, return parDims; } -/// Create acc.par_width operations from gang/worker/vector values of a -/// compute construct. Queries the device-type-specific values first, falling -/// back to the default (DeviceType::None) values. +/// Build `acc.compute_region` launch operands: one sequential `acc.par_width` +/// for `acc.serial`, for `acc.parallel` / `acc.kernels` when every num_gangs +/// operand and num_workers / vector_length are the constant 1, and otherwise +/// `acc.par_width` from gang/worker/vector (device-type operands first, then +/// default DeviceType::None). template static SmallVector assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType, RewriterBase &rewriter, const ACCToGPUMappingPolicy &policy) { - SmallVector values; auto *ctx = rewriter.getContext(); - auto indexTy = rewriter.getIndexType(); auto loc = computeOp->getLoc(); - auto numGangs = computeOp.getNumGangsValues(deviceType); - if (numGangs.empty()) - numGangs = computeOp.getNumGangsValues(); - for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) { - auto gangLevel = getGangParLevel(gangDimIdx + 1); - values.push_back( - ParWidthOp::create(rewriter, loc, - getValueOrCreateCastToIndexLike( - rewriter, gangSize.getLoc(), indexTy, gangSize), - policy.gangDim(ctx, gangLevel))); - } + if constexpr (std::is_same_v) { + return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))}; + } else if constexpr (llvm::is_one_of::value) { + if (isEffectivelySerial(computeOp)) + return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))}; - Value numWorkers = computeOp.getNumWorkersValue(deviceType); - if (!numWorkers) - numWorkers = computeOp.getNumWorkersValue(); - if (numWorkers) { - values.push_back(ParWidthOp::create( - rewriter, loc, - getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(), indexTy, - numWorkers), - policy.workerDim(ctx))); - } + SmallVector values; + auto indexTy = rewriter.getIndexType(); - Value vectorLength = computeOp.getVectorLengthValue(deviceType); - if (!vectorLength) - vectorLength = computeOp.getVectorLengthValue(); - if (vectorLength) { - values.push_back(ParWidthOp::create( - rewriter, loc, - getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(), - indexTy, vectorLength), - policy.vectorDim(ctx))); - } - return values; -} + auto numGangs = computeOp.getNumGangsValues(deviceType); + if (numGangs.empty()) + numGangs = computeOp.getNumGangsValues(); + for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) { + auto gangLevel = getGangParLevel(gangDimIdx + 1); + values.push_back(ParWidthOp::create( + rewriter, loc, + getValueOrCreateCastToIndexLike(rewriter, gangSize.getLoc(), indexTy, + gangSize), + policy.gangDim(ctx, gangLevel))); + } -/// SerialOp has no gang/worker/vector clauses. -template <> -SmallVector -assignKnownLaunchArgs(SerialOp, DeviceType, RewriterBase &, - const ACCToGPUMappingPolicy &) { - return {}; + Value numWorkers = computeOp.getNumWorkersValue(deviceType); + if (!numWorkers) + numWorkers = computeOp.getNumWorkersValue(); + if (numWorkers) { + values.push_back(ParWidthOp::create( + rewriter, loc, + getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(), + indexTy, numWorkers), + policy.workerDim(ctx))); + } + + Value vectorLength = computeOp.getVectorLengthValue(deviceType); + if (!vectorLength) + vectorLength = computeOp.getVectorLengthValue(); + if (vectorLength) { + values.push_back(ParWidthOp::create( + rewriter, loc, + getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(), + indexTy, vectorLength), + policy.vectorDim(ctx))); + } + return values; + } else { + llvm_unreachable("assignKnownLaunchArgs: expected parallel, kernels, or " + "serial"); + } } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir b/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir index ee177aaf6e7a..c2049dab676e 100644 --- a/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir +++ b/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir @@ -64,8 +64,8 @@ func.func @serial_loop(%buf: memref<4xi32>) { %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32> // CHECK-NOT: acc.serial // CHECK: acc.kernel_environment - // CHECK-NOT: acc.par_width - // CHECK: acc.compute_region + // CHECK: acc.par_width {par_dim = #acc.par_dim} + // CHECK: acc.compute_region launch( // CHECK: scf.parallel // CHECK: acc.par_dims = #acc acc.serial dataOperands(%dev : memref<4xi32>) { @@ -117,7 +117,9 @@ func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>) %c42 = arith.constant 42 : i32 %dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32> // CHECK: acc.kernel_environment - // CHECK: acc.compute_region ins({{.*}}) : (memref<1xi32>) { + // CHECK: acc.par_width {par_dim = #acc.par_dim} + // CHECK: acc.compute_region launch( + // CHECK-SAME: ins({{.*}}) : (memref<1xi32>) { // CHECK-DAG: arith.constant 42 : i32 // CHECK-DAG: arith.constant 0 : index // CHECK: memref.store @@ -129,3 +131,142 @@ func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>) acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>) return } + +// ----- + +// acc.parallel with num_gangs(1), num_workers(1), and vector_length(1) is +// treated like acc.serial: sequential acc.par_width launch args and sequential +// par_dims on lowered loops. + +// CHECK-LABEL: func.func @parallel_unit_launch_serial_loops +func.func @parallel_unit_launch_serial_loops(%buf: memref<4xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c1_i32 = arith.constant 1 : i32 + + %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32> + // CHECK-NOT: acc.parallel + // CHECK: acc.kernel_environment + // CHECK: acc.par_width {par_dim = #acc.par_dim} + // CHECK: acc.compute_region launch( + // CHECK: scf.parallel + // CHECK: acc.par_dims = #acc + acc.parallel num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) { + acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) { + %vi = arith.index_cast %i : index to i32 + memref.store %vi, %dev[%i] : memref<4xi32> + acc.yield + } attributes {independent = [#acc.device_type]} + acc.yield + } + acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>) + return +} + +// ----- + +// acc.kernels with num_gangs(1), num_workers(1), and vector_length(1) is +// treated like acc.serial: sequential acc.par_width launch args and sequential +// par_dims on lowered loops. + +// CHECK-LABEL: func.func @kernels_unit_launch_serial_loops +func.func @kernels_unit_launch_serial_loops(%buf: memref<4xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c1_i32 = arith.constant 1 : i32 + + %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32> + // CHECK-NOT: acc.kernels + // CHECK: acc.kernel_environment + // CHECK: acc.par_width {par_dim = #acc.par_dim} + // CHECK: acc.compute_region launch( + // CHECK: scf.parallel + // CHECK: acc.par_dims = #acc + acc.kernels num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) { + acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) { + %vi = arith.index_cast %i : index to i32 + memref.store %vi, %dev[%i] : memref<4xi32> + acc.yield + } attributes {independent = [#acc.device_type]} + acc.terminator + } + acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>) + return +} + +// ----- + +// CHECK-LABEL: func.func @parallel_vector_length32_independent +func.func @parallel_vector_length32_independent(%buf: memref<4xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + + %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32> + // CHECK-NOT: acc.par_dims = #acc + // CHECK: acc.par_dims = #acc + acc.parallel num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c32_i32 : i32) dataOperands(%dev : memref<4xi32>) { + acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) { + %vi = arith.index_cast %i : index to i32 + memref.store %vi, %dev[%i] : memref<4xi32> + acc.yield + } attributes {independent = [#acc.device_type], vector = [#acc.device_type]} + acc.yield + } + acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>) + return +} + +// ----- + +// CHECK-LABEL: func.func @kernels_num_gangs4_independent +func.func @kernels_num_gangs4_independent(%buf: memref<4xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c1_i32 = arith.constant 1 : i32 + %c4_i32 = arith.constant 4 : i32 + + %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32> + // CHECK-NOT: acc.par_dims = #acc + // CHECK: acc.par_dims = #acc + acc.kernels num_gangs({%c4_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) { + acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) { + %vi = arith.index_cast %i : index to i32 + memref.store %vi, %dev[%i] : memref<4xi32> + acc.yield + } attributes {independent = [#acc.device_type], vector = [#acc.device_type]} + acc.terminator + } + acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>) + return +} + +// ----- + +// CHECK-LABEL: func.func @parallel_num_gangs_1_2_independent +func.func @parallel_num_gangs_1_2_independent(%buf: memref<4xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + + %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32> + // CHECK-NOT: acc.par_dims = #acc + // CHECK: acc.par_dims = #acc + acc.parallel num_gangs({%c1_i32 : i32, %c2_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) { + acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) { + %vi = arith.index_cast %i : index to i32 + memref.store %vi, %dev[%i] : memref<4xi32> + acc.yield + } attributes {independent = [#acc.device_type], vector = [#acc.device_type]} + acc.yield + } + acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>) + return +} diff --git a/mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir b/mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir index bd2f006396c6..4032fed217b6 100644 --- a/mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir +++ b/mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir @@ -92,8 +92,8 @@ func.func @serial_loop_normalized(%buf: memref<1xi32>) { %dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32> // CHECK-NOT: acc.serial // CHECK: acc.kernel_environment - // CHECK-NOT: acc.par_width - // CHECK: acc.compute_region + // CHECK: acc.par_width {par_dim = #acc.par_dim} + // CHECK: acc.compute_region launch( // CHECK: scf.parallel // CHECK-DAG: arith.muli // CHECK-DAG: arith.addi