[mlir][acc] Capture explicit serial semantics for compute regions (#195158)
This PR improves robustness in capturing when user's intent is to treat OpenACC region as sequential. It does so in the following ways: - Ensure that `seq` acc.par_width is explicitly used when region is serial. Previously it was not assigning any acc.par_width which causes ambiguities because that way it is indistinguishable whether a region is explicitly serial vs whether the region needs implicitly assigned parallelism. - Treas `acc parallel` and `acc kernels` with `num_gangs(1)` `num_workers(1)` `vector_length(1)` exactly the same as `acc serial`. This is because these are all parallelism dimensions expressible with OpenACC clauses and being all set to 1 makes the semantics consistent with those defined for `acc serial`.
This commit is contained in:
@@ -24,15 +24,17 @@
|
|||||||
// ----------------
|
// ----------------
|
||||||
// 1. Compute constructs: acc.parallel, acc.serial, and acc.kernels are
|
// 1. Compute constructs: acc.parallel, acc.serial, and acc.kernels are
|
||||||
// replaced by acc.kernel_environment containing a single acc.compute_region.
|
// replaced by acc.kernel_environment containing a single acc.compute_region.
|
||||||
// Launch arguments (num_gangs, num_workers, vector_length) become
|
// For acc.parallel / acc.kernels, launch arguments (num_gangs, num_workers,
|
||||||
// acc.par_width ops (each result is `index`) and are passed as
|
// vector_length) become acc.par_width ops (each result is `index`) and are
|
||||||
// compute_region launch operands (still required to be acc.par_width
|
// passed as compute_region launch operands. Compute regions with
|
||||||
// results by the compute_region verifier).
|
// num_gangs(1), num_workers(1), and vector_length(1) and acc serial use a
|
||||||
|
// single sequential acc.par_width launch operand.
|
||||||
//
|
//
|
||||||
// 2. acc.loop: Converted according to context and attributes:
|
// 2. acc.loop: Converted according to context and attributes:
|
||||||
// - Unstructured: body wrapped in scf.execute_region.
|
// - Unstructured: body wrapped in scf.execute_region.
|
||||||
// - Sequential (serial region or seq clause): scf.parallel with
|
// - Sequential (serial region, seq clause, or compute region with
|
||||||
// par_dims = sequential.
|
// num_gangs(1), num_workers(1), and vector_length(1)):
|
||||||
|
// scf.parallel with par_dims = sequential.
|
||||||
// - Auto (in parallel/kernels): scf.for with collapse when
|
// - Auto (in parallel/kernels): scf.for with collapse when
|
||||||
// multi-dimensional.
|
// multi-dimensional.
|
||||||
// - Orphan (not inside a compute construct): scf.for, no collapse.
|
// - Orphan (not inside a compute construct): scf.for, no collapse.
|
||||||
@@ -56,6 +58,7 @@
|
|||||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/RegionUtils.h"
|
#include "mlir/Transforms/RegionUtils.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace acc {
|
namespace acc {
|
||||||
@@ -82,24 +85,34 @@ static Value stripIndexCasts(Value val) {
|
|||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A parallel construct is "effectively serial" when it specifies
|
template <typename ComputeOpT>
|
||||||
/// num_gangs(1), num_workers(1), and vector_length(1). This matches
|
static bool isGangWorkerVectorAllOne(ComputeOpT op) {
|
||||||
/// the semantics of acc.serial but expressed through acc.parallel.
|
|
||||||
static bool isEffectivelySerial(ParallelOp op) {
|
|
||||||
auto numGangs = op.getNumGangsValues();
|
auto numGangs = op.getNumGangsValues();
|
||||||
if (numGangs.size() != 1)
|
if (numGangs.empty())
|
||||||
return false;
|
return false;
|
||||||
|
for (Value gangSize : numGangs) {
|
||||||
|
if (!isConstantIntValue(stripIndexCasts(gangSize), 1))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
Value numWorkers = op.getNumWorkersValue();
|
Value numWorkers = op.getNumWorkersValue();
|
||||||
if (!numWorkers)
|
if (!numWorkers)
|
||||||
return false;
|
return false;
|
||||||
Value vectorLength = op.getVectorLengthValue();
|
Value vectorLength = op.getVectorLengthValue();
|
||||||
if (!vectorLength)
|
if (!vectorLength)
|
||||||
return false;
|
return false;
|
||||||
return isConstantIntValue(stripIndexCasts(numGangs.front()), 1) &&
|
return isConstantIntValue(stripIndexCasts(numWorkers), 1) &&
|
||||||
isConstantIntValue(stripIndexCasts(numWorkers), 1) &&
|
|
||||||
isConstantIntValue(stripIndexCasts(vectorLength), 1);
|
isConstantIntValue(stripIndexCasts(vectorLength), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A compute construct is "effectively serial" when it specifies
|
||||||
|
/// num_gangs(1), num_workers(1), and vector_length(1). This is because
|
||||||
|
/// these are the only parallelism dimensions expressible from OpenACC spec
|
||||||
|
/// point-of-view and is consistent with how `serial` semantics are defined.
|
||||||
|
template <typename ComputeOpT>
|
||||||
|
static bool isEffectivelySerial(ComputeOpT op) {
|
||||||
|
return isGangWorkerVectorAllOne(op);
|
||||||
|
}
|
||||||
|
|
||||||
static bool isOpInComputeRegion(Operation *op) {
|
static bool isOpInComputeRegion(Operation *op) {
|
||||||
Region *region = op->getBlock()->getParent();
|
Region *region = op->getBlock()->getParent();
|
||||||
return getEnclosingComputeOp(*region) != nullptr;
|
return getEnclosingComputeOp(*region) != nullptr;
|
||||||
@@ -108,10 +121,12 @@ static bool isOpInComputeRegion(Operation *op) {
|
|||||||
static bool isOpInSerialRegion(Operation *op) {
|
static bool isOpInSerialRegion(Operation *op) {
|
||||||
if (auto parallelOp = op->getParentOfType<ParallelOp>())
|
if (auto parallelOp = op->getParentOfType<ParallelOp>())
|
||||||
return isEffectivelySerial(parallelOp);
|
return isEffectivelySerial(parallelOp);
|
||||||
if (auto computeRegion = op->getParentOfType<ComputeRegionOp>())
|
if (auto kernelsOp = op->getParentOfType<KernelsOp>())
|
||||||
return computeRegion.isEffectivelySerial();
|
return isEffectivelySerial(kernelsOp);
|
||||||
if (op->getParentOfType<SerialOp>())
|
if (op->getParentOfType<SerialOp>())
|
||||||
return true;
|
return true;
|
||||||
|
if (auto computeRegion = op->getParentOfType<ComputeRegionOp>())
|
||||||
|
return computeRegion.isEffectivelySerial();
|
||||||
if (auto funcOp = op->getParentOfType<FunctionOpInterface>()) {
|
if (auto funcOp = op->getParentOfType<FunctionOpInterface>()) {
|
||||||
if (isSpecializedAccRoutine(funcOp)) {
|
if (isSpecializedAccRoutine(funcOp)) {
|
||||||
auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
|
auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
|
||||||
@@ -194,61 +209,67 @@ getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
|
|||||||
return parDims;
|
return parDims;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create acc.par_width operations from gang/worker/vector values of a
|
/// Build `acc.compute_region` launch operands: one sequential `acc.par_width`
|
||||||
/// compute construct. Queries the device-type-specific values first, falling
|
/// for `acc.serial`, for `acc.parallel` / `acc.kernels` when every num_gangs
|
||||||
/// back to the default (DeviceType::None) values.
|
/// operand and num_workers / vector_length are the constant 1, and otherwise
|
||||||
|
/// `acc.par_width` from gang/worker/vector (device-type operands first, then
|
||||||
|
/// default DeviceType::None).
|
||||||
template <typename ComputeConstructT>
|
template <typename ComputeConstructT>
|
||||||
static SmallVector<Value>
|
static SmallVector<Value>
|
||||||
assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
|
assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
|
||||||
RewriterBase &rewriter,
|
RewriterBase &rewriter,
|
||||||
const ACCToGPUMappingPolicy &policy) {
|
const ACCToGPUMappingPolicy &policy) {
|
||||||
SmallVector<Value> values;
|
|
||||||
auto *ctx = rewriter.getContext();
|
auto *ctx = rewriter.getContext();
|
||||||
auto indexTy = rewriter.getIndexType();
|
|
||||||
auto loc = computeOp->getLoc();
|
auto loc = computeOp->getLoc();
|
||||||
|
|
||||||
auto numGangs = computeOp.getNumGangsValues(deviceType);
|
if constexpr (std::is_same_v<ComputeConstructT, SerialOp>) {
|
||||||
if (numGangs.empty())
|
return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
|
||||||
numGangs = computeOp.getNumGangsValues();
|
} else if constexpr (llvm::is_one_of<ComputeConstructT, ParallelOp,
|
||||||
for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
|
KernelsOp>::value) {
|
||||||
auto gangLevel = getGangParLevel(gangDimIdx + 1);
|
if (isEffectivelySerial(computeOp))
|
||||||
values.push_back(
|
return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
|
||||||
ParWidthOp::create(rewriter, loc,
|
|
||||||
getValueOrCreateCastToIndexLike(
|
|
||||||
rewriter, gangSize.getLoc(), indexTy, gangSize),
|
|
||||||
policy.gangDim(ctx, gangLevel)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Value numWorkers = computeOp.getNumWorkersValue(deviceType);
|
SmallVector<Value> values;
|
||||||
if (!numWorkers)
|
auto indexTy = rewriter.getIndexType();
|
||||||
numWorkers = computeOp.getNumWorkersValue();
|
|
||||||
if (numWorkers) {
|
|
||||||
values.push_back(ParWidthOp::create(
|
|
||||||
rewriter, loc,
|
|
||||||
getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(), indexTy,
|
|
||||||
numWorkers),
|
|
||||||
policy.workerDim(ctx)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Value vectorLength = computeOp.getVectorLengthValue(deviceType);
|
auto numGangs = computeOp.getNumGangsValues(deviceType);
|
||||||
if (!vectorLength)
|
if (numGangs.empty())
|
||||||
vectorLength = computeOp.getVectorLengthValue();
|
numGangs = computeOp.getNumGangsValues();
|
||||||
if (vectorLength) {
|
for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
|
||||||
values.push_back(ParWidthOp::create(
|
auto gangLevel = getGangParLevel(gangDimIdx + 1);
|
||||||
rewriter, loc,
|
values.push_back(ParWidthOp::create(
|
||||||
getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(),
|
rewriter, loc,
|
||||||
indexTy, vectorLength),
|
getValueOrCreateCastToIndexLike(rewriter, gangSize.getLoc(), indexTy,
|
||||||
policy.vectorDim(ctx)));
|
gangSize),
|
||||||
}
|
policy.gangDim(ctx, gangLevel)));
|
||||||
return values;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// SerialOp has no gang/worker/vector clauses.
|
Value numWorkers = computeOp.getNumWorkersValue(deviceType);
|
||||||
template <>
|
if (!numWorkers)
|
||||||
SmallVector<Value>
|
numWorkers = computeOp.getNumWorkersValue();
|
||||||
assignKnownLaunchArgs<SerialOp>(SerialOp, DeviceType, RewriterBase &,
|
if (numWorkers) {
|
||||||
const ACCToGPUMappingPolicy &) {
|
values.push_back(ParWidthOp::create(
|
||||||
return {};
|
rewriter, loc,
|
||||||
|
getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(),
|
||||||
|
indexTy, numWorkers),
|
||||||
|
policy.workerDim(ctx)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value vectorLength = computeOp.getVectorLengthValue(deviceType);
|
||||||
|
if (!vectorLength)
|
||||||
|
vectorLength = computeOp.getVectorLengthValue();
|
||||||
|
if (vectorLength) {
|
||||||
|
values.push_back(ParWidthOp::create(
|
||||||
|
rewriter, loc,
|
||||||
|
getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(),
|
||||||
|
indexTy, vectorLength),
|
||||||
|
policy.vectorDim(ctx)));
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("assignKnownLaunchArgs: expected parallel, kernels, or "
|
||||||
|
"serial");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|||||||
@@ -64,8 +64,8 @@ func.func @serial_loop(%buf: memref<4xi32>) {
|
|||||||
%dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
|
%dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
|
||||||
// CHECK-NOT: acc.serial
|
// CHECK-NOT: acc.serial
|
||||||
// CHECK: acc.kernel_environment
|
// CHECK: acc.kernel_environment
|
||||||
// CHECK-NOT: acc.par_width
|
// CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
|
||||||
// CHECK: acc.compute_region
|
// CHECK: acc.compute_region launch(
|
||||||
// CHECK: scf.parallel
|
// CHECK: scf.parallel
|
||||||
// CHECK: acc.par_dims = #acc<par_dims[sequential]>
|
// CHECK: acc.par_dims = #acc<par_dims[sequential]>
|
||||||
acc.serial dataOperands(%dev : memref<4xi32>) {
|
acc.serial dataOperands(%dev : memref<4xi32>) {
|
||||||
@@ -117,7 +117,9 @@ func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>)
|
|||||||
%c42 = arith.constant 42 : i32
|
%c42 = arith.constant 42 : i32
|
||||||
%dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
|
%dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
|
||||||
// CHECK: acc.kernel_environment
|
// CHECK: acc.kernel_environment
|
||||||
// CHECK: acc.compute_region ins({{.*}}) : (memref<1xi32>) {
|
// CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
|
||||||
|
// CHECK: acc.compute_region launch(
|
||||||
|
// CHECK-SAME: ins({{.*}}) : (memref<1xi32>) {
|
||||||
// CHECK-DAG: arith.constant 42 : i32
|
// CHECK-DAG: arith.constant 42 : i32
|
||||||
// CHECK-DAG: arith.constant 0 : index
|
// CHECK-DAG: arith.constant 0 : index
|
||||||
// CHECK: memref.store
|
// CHECK: memref.store
|
||||||
@@ -129,3 +131,142 @@ func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>)
|
|||||||
acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>)
|
acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// acc.parallel with num_gangs(1), num_workers(1), and vector_length(1) is
|
||||||
|
// treated like acc.serial: sequential acc.par_width launch args and sequential
|
||||||
|
// par_dims on lowered loops.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @parallel_unit_launch_serial_loops
|
||||||
|
func.func @parallel_unit_launch_serial_loops(%buf: memref<4xi32>) {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c4 = arith.constant 4 : index
|
||||||
|
%c1_i32 = arith.constant 1 : i32
|
||||||
|
|
||||||
|
%dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
|
||||||
|
// CHECK-NOT: acc.parallel
|
||||||
|
// CHECK: acc.kernel_environment
|
||||||
|
// CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
|
||||||
|
// CHECK: acc.compute_region launch(
|
||||||
|
// CHECK: scf.parallel
|
||||||
|
// CHECK: acc.par_dims = #acc<par_dims[sequential]>
|
||||||
|
acc.parallel num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) {
|
||||||
|
acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
|
||||||
|
%vi = arith.index_cast %i : index to i32
|
||||||
|
memref.store %vi, %dev[%i] : memref<4xi32>
|
||||||
|
acc.yield
|
||||||
|
} attributes {independent = [#acc.device_type<none>]}
|
||||||
|
acc.yield
|
||||||
|
}
|
||||||
|
acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// acc.kernels with num_gangs(1), num_workers(1), and vector_length(1) is
|
||||||
|
// treated like acc.serial: sequential acc.par_width launch args and sequential
|
||||||
|
// par_dims on lowered loops.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @kernels_unit_launch_serial_loops
|
||||||
|
func.func @kernels_unit_launch_serial_loops(%buf: memref<4xi32>) {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c4 = arith.constant 4 : index
|
||||||
|
%c1_i32 = arith.constant 1 : i32
|
||||||
|
|
||||||
|
%dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
|
||||||
|
// CHECK-NOT: acc.kernels
|
||||||
|
// CHECK: acc.kernel_environment
|
||||||
|
// CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
|
||||||
|
// CHECK: acc.compute_region launch(
|
||||||
|
// CHECK: scf.parallel
|
||||||
|
// CHECK: acc.par_dims = #acc<par_dims[sequential]>
|
||||||
|
acc.kernels num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) {
|
||||||
|
acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
|
||||||
|
%vi = arith.index_cast %i : index to i32
|
||||||
|
memref.store %vi, %dev[%i] : memref<4xi32>
|
||||||
|
acc.yield
|
||||||
|
} attributes {independent = [#acc.device_type<none>]}
|
||||||
|
acc.terminator
|
||||||
|
}
|
||||||
|
acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @parallel_vector_length32_independent
|
||||||
|
func.func @parallel_vector_length32_independent(%buf: memref<4xi32>) {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c4 = arith.constant 4 : index
|
||||||
|
%c1_i32 = arith.constant 1 : i32
|
||||||
|
%c32_i32 = arith.constant 32 : i32
|
||||||
|
|
||||||
|
%dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
|
||||||
|
// CHECK-NOT: acc.par_dims = #acc<par_dims[sequential]>
|
||||||
|
// CHECK: acc.par_dims = #acc<par_dims[thread_x]>
|
||||||
|
acc.parallel num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c32_i32 : i32) dataOperands(%dev : memref<4xi32>) {
|
||||||
|
acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
|
||||||
|
%vi = arith.index_cast %i : index to i32
|
||||||
|
memref.store %vi, %dev[%i] : memref<4xi32>
|
||||||
|
acc.yield
|
||||||
|
} attributes {independent = [#acc.device_type<none>], vector = [#acc.device_type<none>]}
|
||||||
|
acc.yield
|
||||||
|
}
|
||||||
|
acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @kernels_num_gangs4_independent
|
||||||
|
func.func @kernels_num_gangs4_independent(%buf: memref<4xi32>) {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c4 = arith.constant 4 : index
|
||||||
|
%c1_i32 = arith.constant 1 : i32
|
||||||
|
%c4_i32 = arith.constant 4 : i32
|
||||||
|
|
||||||
|
%dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
|
||||||
|
// CHECK-NOT: acc.par_dims = #acc<par_dims[sequential]>
|
||||||
|
// CHECK: acc.par_dims = #acc<par_dims[thread_x]>
|
||||||
|
acc.kernels num_gangs({%c4_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) {
|
||||||
|
acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
|
||||||
|
%vi = arith.index_cast %i : index to i32
|
||||||
|
memref.store %vi, %dev[%i] : memref<4xi32>
|
||||||
|
acc.yield
|
||||||
|
} attributes {independent = [#acc.device_type<none>], vector = [#acc.device_type<none>]}
|
||||||
|
acc.terminator
|
||||||
|
}
|
||||||
|
acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @parallel_num_gangs_1_2_independent
|
||||||
|
func.func @parallel_num_gangs_1_2_independent(%buf: memref<4xi32>) {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c4 = arith.constant 4 : index
|
||||||
|
%c1_i32 = arith.constant 1 : i32
|
||||||
|
%c2_i32 = arith.constant 2 : i32
|
||||||
|
|
||||||
|
%dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
|
||||||
|
// CHECK-NOT: acc.par_dims = #acc<par_dims[sequential]>
|
||||||
|
// CHECK: acc.par_dims = #acc<par_dims[thread_x]>
|
||||||
|
acc.parallel num_gangs({%c1_i32 : i32, %c2_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) {
|
||||||
|
acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
|
||||||
|
%vi = arith.index_cast %i : index to i32
|
||||||
|
memref.store %vi, %dev[%i] : memref<4xi32>
|
||||||
|
acc.yield
|
||||||
|
} attributes {independent = [#acc.device_type<none>], vector = [#acc.device_type<none>]}
|
||||||
|
acc.yield
|
||||||
|
}
|
||||||
|
acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -92,8 +92,8 @@ func.func @serial_loop_normalized(%buf: memref<1xi32>) {
|
|||||||
%dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
|
%dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
|
||||||
// CHECK-NOT: acc.serial
|
// CHECK-NOT: acc.serial
|
||||||
// CHECK: acc.kernel_environment
|
// CHECK: acc.kernel_environment
|
||||||
// CHECK-NOT: acc.par_width
|
// CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
|
||||||
// CHECK: acc.compute_region
|
// CHECK: acc.compute_region launch(
|
||||||
// CHECK: scf.parallel
|
// CHECK: scf.parallel
|
||||||
// CHECK-DAG: arith.muli
|
// CHECK-DAG: arith.muli
|
||||||
// CHECK-DAG: arith.addi
|
// CHECK-DAG: arith.addi
|
||||||
|
|||||||
Reference in New Issue
Block a user