[MLIR][NVVM] SpecialRegister & PureSpecialRegister take a result type (#195030)
Use concrete `I32` (default) and `I64` (clock64, globaltimer) instead of the generic `LLVM_Type` for special-register op results. The dialect verifier now rejects mismatches up front, and the Python op-binding generator emits the inferred-result form, so callers can write `nvvm.ThreadIdXOp()` with no arguments. This is a strict tightening: no valid existing IR is rejected.
This commit is contained in:
@@ -309,12 +309,15 @@ class NVVM_SingleResultIntrinsicOp<string mnemonic, list<Trait> traits = [], str
|
||||
class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
|
||||
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
|
||||
let arguments = (ins);
|
||||
let results = (outs I32:$res);
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
|
||||
class NVVM_SpecialRegisterOp<string mnemonic, Type resultType = I32,
|
||||
list<Trait> traits = []> :
|
||||
NVVM_IntrOp<mnemonic, traits, 1> {
|
||||
let arguments = (ins);
|
||||
let results = (outs resultType:$res);
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
@@ -421,8 +424,8 @@ def NVVM_AggrSmemSize : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.aggr.smem.s
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Clock registers
|
||||
def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;
|
||||
def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">;
|
||||
def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
|
||||
def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64", I64>;
|
||||
def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer", I64>;
|
||||
def NVVM_GlobalTimerLoOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer.lo">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -2115,3 +2115,19 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
|
||||
%0 = llvm.ptrtoaddr %arg0 : !llvm.ptr to i64
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @nvvm_read_sreg_tid_x_wrong_type() {
|
||||
// expected-error@+1 {{'nvvm.read.ptx.sreg.tid.x' op result #0 must be 32-bit signless integer, but got 'i64'}}
|
||||
%0 = nvvm.read.ptx.sreg.tid.x : i64
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @nvvm_read_sreg_clock64_wrong_type() {
|
||||
// expected-error@+1 {{'nvvm.read.ptx.sreg.clock64' op result #0 must be 64-bit signless integer, but got 'i32'}}
|
||||
%0 = nvvm.read.ptx.sreg.clock64 : i32
|
||||
return
|
||||
}
|
||||
|
||||
@@ -377,3 +377,16 @@ def test_reductions():
|
||||
# CHECK: %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32
|
||||
# CHECK: return
|
||||
# CHECK: }
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testSpecialRegisterInferredResults
|
||||
@constructAndPrintInModule
|
||||
def testSpecialRegisterInferredResults():
|
||||
# CHECK: %{{.*}} = nvvm.read.ptx.sreg.tid.x : i32
|
||||
nvvm.ThreadIdXOp()
|
||||
# CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock : i32
|
||||
nvvm.ClockOp()
|
||||
# CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock64 : i64
|
||||
nvvm.Clock64Op()
|
||||
# CHECK: %{{.*}} = nvvm.read.ptx.sreg.globaltimer : i64
|
||||
nvvm.GlobalTimerOp()
|
||||
|
||||
Reference in New Issue
Block a user