[MLIR][NVVM] SpecialRegister & PureSpecialRegister take result type (#195030)

Use concrete `I32` (default) and `I64` (clock64, globaltimer) instead of
generic `LLVM_Type` for special-register op results. The dialect
verifier now rejects mismatches up-front, and the Python op-binding
generator emits the inferred-result form, so callers can write
`nvvm.ThreadIdXOp()` with no arguments. Strict tightening: no valid
existing IR is rejected.
This commit is contained in:
Bastian Hagedorn
2026-04-30 14:22:03 +02:00
committed by GitHub
parent 875d2c9fbc
commit 44753d8646
3 changed files with 35 additions and 3 deletions

View File

@@ -309,12 +309,15 @@ class NVVM_SingleResultIntrinsicOp<string mnemonic, list<Trait> traits = [], str
// Base class for special-register reads that always carry the `Pure` trait:
// `Pure` is appended to whatever traits the caller supplies.
// NOTE(review): the trailing `1` passed to NVVM_IntrOp is presumably the
// number of results — confirm against the NVVM_IntrOp definition.
class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
// No operands.
let arguments = (ins);
// The result type is fixed to I32; this class has no template parameter to
// override it.
let results = (outs I32:$res);
// Textual form: attribute dictionary, then `:` followed by the result type.
let assemblyFormat = "attr-dict `:` type($res)";
}
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
// Base class for special-register reads that do not get `Pure` added: the
// caller's traits are forwarded to NVVM_IntrOp unchanged.
// `resultType` selects the type constraint of the single result and defaults
// to I32. It comes before `traits` in the template argument list, so a
// definition that passes traits positionally must pass the result type too.
class NVVM_SpecialRegisterOp<string mnemonic, Type resultType = I32,
list<Trait> traits = []> :
NVVM_IntrOp<mnemonic, traits, 1> {
// No operands.
let arguments = (ins);
// Single result `$res`, constrained to `resultType`.
let results = (outs resultType:$res);
// Textual form: attribute dictionary, then `:` followed by the result type.
let assemblyFormat = "attr-dict `:` type($res)";
}
@@ -421,8 +424,8 @@ def NVVM_AggrSmemSize : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.aggr.smem.s
//===----------------------------------------------------------------------===//
// Clock registers
// `clock` passes no result type, so it uses the class default (I32).
def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;
def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">;
def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
// `clock64` and `globaltimer` pass I64 explicitly as the result type.
def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64", I64>;
def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer", I64>;
// `globaltimer.lo` passes no result type, so it also uses the class default.
def NVVM_GlobalTimerLoOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer.lo">;
//===----------------------------------------------------------------------===//

View File

@@ -2115,3 +2115,19 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
%0 = llvm.ptrtoaddr %arg0 : !llvm.ptr to i64
}
}
// -----
// Negative test: spelling the tid.x result as i64 must be diagnosed, because
// the op's result is constrained to a 32-bit signless integer.
func.func @nvvm_read_sreg_tid_x_wrong_type() {
// expected-error@+1 {{'nvvm.read.ptx.sreg.tid.x' op result #0 must be 32-bit signless integer, but got 'i64'}}
%0 = nvvm.read.ptx.sreg.tid.x : i64
return
}
// -----
// Negative test: spelling the clock64 result as i32 must be diagnosed, because
// the op's result is constrained to a 64-bit signless integer.
func.func @nvvm_read_sreg_clock64_wrong_type() {
// expected-error@+1 {{'nvvm.read.ptx.sreg.clock64' op result #0 must be 64-bit signless integer, but got 'i32'}}
%0 = nvvm.read.ptx.sreg.clock64 : i32
return
}

View File

@@ -377,3 +377,16 @@ def test_reductions():
# CHECK: %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32
# CHECK: return
# CHECK: }
# CHECK-LABEL: TEST: testSpecialRegisterInferredResults
# Builds four special-register ops with no constructor arguments and verifies
# the printed result type of each: i32 for tid.x and clock, i64 for clock64
# and globaltimer. No result type is passed by the caller.
# NOTE(review): the decorator presumably runs the function inside a module and
# prints it for the pattern matching below — confirm against its definition.
@constructAndPrintInModule
def testSpecialRegisterInferredResults():
# CHECK: %{{.*}} = nvvm.read.ptx.sreg.tid.x : i32
nvvm.ThreadIdXOp()
# CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock : i32
nvvm.ClockOp()
# CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock64 : i64
nvvm.Clock64Op()
# CHECK: %{{.*}} = nvvm.read.ptx.sreg.globaltimer : i64
nvvm.GlobalTimerOp()