This patch adds support for `float8_e3m4` and `float8_e4m3` in `np_to_memref.py` by adding the appropriate ctypes structures. Additionally changes minimum numpy version to 2.1.0 and uses a single ml_dtypes version of 0.5.0.
This commit is contained in:
@@ -312,7 +312,7 @@ junitparser==3.2.0 \
|
|||||||
--hash=sha256:b05e89c27e7b74b3c563a078d6e055d95cf397444f8f689b0ca616ebda0b3c65 \
|
--hash=sha256:b05e89c27e7b74b3c563a078d6e055d95cf397444f8f689b0ca616ebda0b3c65 \
|
||||||
--hash=sha256:e14fdc0a999edfc15889b637390e8ef6ca09a49532416d3bd562857d42d4b96d
|
--hash=sha256:e14fdc0a999edfc15889b637390e8ef6ca09a49532416d3bd562857d42d4b96d
|
||||||
# via -r .ci/requirements.txt
|
# via -r .ci/requirements.txt
|
||||||
ml-dtypes==0.5.1 ; python_version < "3.13" \
|
ml-dtypes==0.5.1 \
|
||||||
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
|
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
|
||||||
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
|
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
|
||||||
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
|
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
|
||||||
@@ -342,52 +342,60 @@ nanobind==2.9.2 \
|
|||||||
--hash=sha256:c37957ffd5eac7eda349cff3622ecd32e5ee1244ecc912c99b5bc8188bafd16e \
|
--hash=sha256:c37957ffd5eac7eda349cff3622ecd32e5ee1244ecc912c99b5bc8188bafd16e \
|
||||||
--hash=sha256:e7608472de99d375759814cab3e2c94aba3f9ec80e62cfef8ced495ca5c27d6e
|
--hash=sha256:e7608472de99d375759814cab3e2c94aba3f9ec80e62cfef8ced495ca5c27d6e
|
||||||
# via -r mlir/python/requirements.txt
|
# via -r mlir/python/requirements.txt
|
||||||
numpy==2.0.2 \
|
numpy==2.1.2 \
|
||||||
--hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \
|
--hash=sha256:05b2d4e667895cc55e3ff2b56077e4c8a5604361fc21a042845ea3ad67465aa8 \
|
||||||
--hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \
|
--hash=sha256:12edb90831ff481f7ef5f6bc6431a9d74dc0e5ff401559a71e5e4611d4f2d466 \
|
||||||
--hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \
|
--hash=sha256:13311c2db4c5f7609b462bc0f43d3c465424d25c626d95040f073e30f7570e35 \
|
||||||
--hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \
|
--hash=sha256:13532a088217fa624c99b843eeb54640de23b3414b14aa66d023805eb731066c \
|
||||||
--hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \
|
--hash=sha256:13602b3174432a35b16c4cfb5de9a12d229727c3dd47a6ce35111f2ebdf66ff4 \
|
||||||
--hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \
|
--hash=sha256:1600068c262af1ca9580a527d43dc9d959b0b1d8e56f8a05d830eea39b7c8af6 \
|
||||||
--hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \
|
--hash=sha256:1b8cde4f11f0a975d1fd59373b32e2f5a562ade7cde4f85b7137f3de8fbb29a0 \
|
||||||
--hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \
|
--hash=sha256:1c193d0b0238638e6fc5f10f1b074a6993cb13b0b431f64079a509d63d3aa8b7 \
|
||||||
--hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \
|
--hash=sha256:1ebec5fd716c5a5b3d8dfcc439be82a8407b7b24b230d0ad28a81b61c2f4659a \
|
||||||
--hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \
|
--hash=sha256:242b39d00e4944431a3cd2db2f5377e15b5785920421993770cddb89992c3f3a \
|
||||||
--hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \
|
--hash=sha256:259ec80d54999cc34cd1eb8ded513cb053c3bf4829152a2e00de2371bd406f5e \
|
||||||
--hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \
|
--hash=sha256:2abbf905a0b568706391ec6fa15161fad0fb5d8b68d73c461b3c1bab6064dd62 \
|
||||||
--hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \
|
--hash=sha256:2cbba4b30bf31ddbe97f1c7205ef976909a93a66bb1583e983adbd155ba72ac2 \
|
||||||
--hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \
|
--hash=sha256:2ffef621c14ebb0188a8633348504a35c13680d6da93ab5cb86f4e54b7e922b5 \
|
||||||
--hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \
|
--hash=sha256:30d53720b726ec36a7f88dc873f0eec8447fbc93d93a8f079dfac2629598d6ee \
|
||||||
--hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \
|
--hash=sha256:32e16a03138cabe0cb28e1007ee82264296ac0983714094380b408097a418cfe \
|
||||||
--hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \
|
--hash=sha256:43cca367bf94a14aca50b89e9bc2061683116cfe864e56740e083392f533ce7a \
|
||||||
--hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \
|
--hash=sha256:456e3b11cb79ac9946c822a56346ec80275eaf2950314b249b512896c0d2505e \
|
||||||
--hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \
|
--hash=sha256:4d6ec0d4222e8ffdab1744da2560f07856421b367928026fb540e1945f2eeeaf \
|
||||||
--hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \
|
--hash=sha256:5006b13a06e0b38d561fab5ccc37581f23c9511879be7693bd33c7cd15ca227c \
|
||||||
--hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \
|
--hash=sha256:675c741d4739af2dc20cd6c6a5c4b7355c728167845e3c6b0e824e4e5d36a6c3 \
|
||||||
--hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \
|
--hash=sha256:6cdb606a7478f9ad91c6283e238544451e3a95f30fb5467fbf715964341a8a86 \
|
||||||
--hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \
|
--hash=sha256:6d95f286b8244b3649b477ac066c6906fbb2905f8ac19b170e2175d3d799f4df \
|
||||||
--hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \
|
--hash=sha256:76322dcdb16fccf2ac56f99048af32259dcc488d9b7e25b51e5eca5147a3fb98 \
|
||||||
--hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \
|
--hash=sha256:7c1c60328bd964b53f8b835df69ae8198659e2b9302ff9ebb7de4e5a5994db3d \
|
||||||
--hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \
|
--hash=sha256:860ec6e63e2c5c2ee5e9121808145c7bf86c96cca9ad396c0bd3e0f2798ccbe2 \
|
||||||
--hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \
|
--hash=sha256:8e00ea6fc82e8a804433d3e9cedaa1051a1422cb6e443011590c14d2dea59146 \
|
||||||
--hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \
|
--hash=sha256:9c6c754df29ce6a89ed23afb25550d1c2d5fdb9901d9c67a16e0b16eaf7e2550 \
|
||||||
--hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \
|
--hash=sha256:a26ae94658d3ba3781d5e103ac07a876b3e9b29db53f68ed7df432fd033358a8 \
|
||||||
--hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \
|
--hash=sha256:a65acfdb9c6ebb8368490dbafe83c03c7e277b37e6857f0caeadbbc56e12f4fb \
|
||||||
--hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \
|
--hash=sha256:a7d80b2e904faa63068ead63107189164ca443b42dd1930299e0d1cb041cec2e \
|
||||||
--hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \
|
--hash=sha256:a84498e0d0a1174f2b3ed769b67b656aa5460c92c9554039e11f20a05650f00d \
|
||||||
--hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \
|
--hash=sha256:ab4754d432e3ac42d33a269c8567413bdb541689b02d93788af4131018cbf366 \
|
||||||
--hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \
|
--hash=sha256:ad369ed238b1959dfbade9018a740fb9392c5ac4f9b5173f420bd4f37ba1f7a0 \
|
||||||
--hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \
|
--hash=sha256:b1d0fcae4f0949f215d4632be684a539859b295e2d0cb14f78ec231915d644db \
|
||||||
--hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \
|
--hash=sha256:b42a1a511c81cc78cbc4539675713bbcf9d9c3913386243ceff0e9429ca892fe \
|
||||||
--hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \
|
--hash=sha256:bd33f82e95ba7ad632bc57837ee99dba3d7e006536200c4e9124089e1bf42426 \
|
||||||
--hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \
|
--hash=sha256:bdd407c40483463898b84490770199d5714dcc9dd9b792f6c6caccc523c00952 \
|
||||||
--hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \
|
--hash=sha256:c6eef7a2dbd0abfb0d9eaf78b73017dbfd0b54051102ff4e6a7b2980d5ac1a03 \
|
||||||
--hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \
|
--hash=sha256:c82af4b2ddd2ee72d1fc0c6695048d457e00b3582ccde72d8a1c991b808bb20f \
|
||||||
--hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \
|
--hash=sha256:d666cb72687559689e9906197e3bec7b736764df6a2e58ee265e360663e9baf7 \
|
||||||
--hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \
|
--hash=sha256:d7bf0a4f9f15b32b5ba53147369e94296f5fffb783db5aacc1be15b4bf72f43b \
|
||||||
--hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \
|
--hash=sha256:d82075752f40c0ddf57e6e02673a17f6cb0f8eb3f587f63ca1eaab5594da5b17 \
|
||||||
--hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \
|
--hash=sha256:da65fb46d4cbb75cb417cddf6ba5e7582eb7bb0b47db4b99c9fe5787ce5d91f5 \
|
||||||
--hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd
|
--hash=sha256:e2b49c3c0804e8ecb05d59af8386ec2f74877f7ca8fd9c1e00be2672e4d399b1 \
|
||||||
|
--hash=sha256:e585c8ae871fd38ac50598f4763d73ec5497b0de9a0ab4ef5b69f01c6a046142 \
|
||||||
|
--hash=sha256:e8d3ca0a72dd8846eb6f7dfe8f19088060fcb76931ed592d29128e0219652884 \
|
||||||
|
--hash=sha256:ef444c57d664d35cac4e18c298c47d7b504c66b17c2ea91312e979fcfbdfb08a \
|
||||||
|
--hash=sha256:f1eb068ead09f4994dec71c24b2844f1e4e4e013b9629f812f292f04bd1510d9 \
|
||||||
|
--hash=sha256:f2ded8d9b6f68cc26f8425eda5d3877b47343e68ca23d0d0846f4d312ecaa445 \
|
||||||
|
--hash=sha256:f751ed0a2f250541e19dfca9f1eafa31a392c71c832b6bb9e113b10d050cb0f1 \
|
||||||
|
--hash=sha256:faa88bc527d0f097abdc2c663cddf37c05a1c2f113716601555249805cf573f1 \
|
||||||
|
--hash=sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648
|
||||||
# via
|
# via
|
||||||
# -r mlir/python/requirements.txt
|
# -r mlir/python/requirements.txt
|
||||||
# ml-dtypes
|
# ml-dtypes
|
||||||
|
|||||||
@@ -37,12 +37,25 @@ class BF16(ctypes.Structure):
|
|||||||
|
|
||||||
_fields_ = [("bf16", ctypes.c_int16)]
|
_fields_ = [("bf16", ctypes.c_int16)]
|
||||||
|
|
||||||
|
|
||||||
class F8E5M2(ctypes.Structure):
|
class F8E5M2(ctypes.Structure):
|
||||||
"""A ctype representation for MLIR's Float8E5M2."""
|
"""A ctype representation for MLIR's Float8E5M2."""
|
||||||
|
|
||||||
_fields_ = [("f8E5M2", ctypes.c_int8)]
|
_fields_ = [("f8E5M2", ctypes.c_int8)]
|
||||||
|
|
||||||
|
|
||||||
|
class F8E3M4(ctypes.Structure):
|
||||||
|
"""A ctype representation for MLIR's Float8E3M4."""
|
||||||
|
|
||||||
|
_fields_ = [("f8E3M4", ctypes.c_int8)]
|
||||||
|
|
||||||
|
|
||||||
|
class F8E4M3(ctypes.Structure):
|
||||||
|
"""A ctype representation for MLIR's Float8E4M3."""
|
||||||
|
|
||||||
|
_fields_ = [("f8E4M3", ctypes.c_int8)]
|
||||||
|
|
||||||
|
|
||||||
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
|
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
|
||||||
def as_ctype(dtp):
|
def as_ctype(dtp):
|
||||||
"""Converts dtype to ctype."""
|
"""Converts dtype to ctype."""
|
||||||
@@ -56,6 +69,10 @@ def as_ctype(dtp):
|
|||||||
return BF16
|
return BF16
|
||||||
if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
|
if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
|
||||||
return F8E5M2
|
return F8E5M2
|
||||||
|
if ml_dtypes is not None and dtp == ml_dtypes.float8_e3m4:
|
||||||
|
return F8E3M4
|
||||||
|
if ml_dtypes is not None and dtp == ml_dtypes.float8_e4m3:
|
||||||
|
return F8E4M3
|
||||||
return np.ctypeslib.as_ctypes_type(dtp)
|
return np.ctypeslib.as_ctypes_type(dtp)
|
||||||
|
|
||||||
|
|
||||||
@@ -68,15 +85,17 @@ def to_numpy(array):
|
|||||||
if array.dtype == F16:
|
if array.dtype == F16:
|
||||||
return array.view("float16")
|
return array.view("float16")
|
||||||
assert not (
|
assert not (
|
||||||
array.dtype == BF16 and ml_dtypes is None
|
array.dtype in (BF16, F8E5M2, F8E3M4, F8E4M3) and ml_dtypes is None
|
||||||
), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
|
), f"{array.dtype=} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
|
||||||
if array.dtype == BF16:
|
if array.dtype == BF16:
|
||||||
return array.view("bfloat16")
|
return array.view("bfloat16")
|
||||||
assert not (
|
|
||||||
array.dtype == F8E5M2 and ml_dtypes is None
|
|
||||||
), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
|
|
||||||
if array.dtype == F8E5M2:
|
if array.dtype == F8E5M2:
|
||||||
return array.view("float8_e5m2")
|
return array.view("float8_e5m2")
|
||||||
|
if array.dtype == F8E3M4:
|
||||||
|
return array.view("float8_e3m4")
|
||||||
|
if array.dtype == F8E4M3:
|
||||||
|
return array.view("float8_e4m3")
|
||||||
|
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,5 @@ nanobind>=2.9, <3.0
|
|||||||
PyYAML>=5.4.0, <=6.0.1
|
PyYAML>=5.4.0, <=6.0.1
|
||||||
typing_extensions>=4.12.2
|
typing_extensions>=4.12.2
|
||||||
# RUN dependencies
|
# RUN dependencies
|
||||||
numpy>=1.19.5, <=2.1.2
|
numpy>=2.1.0, <=2.1.2
|
||||||
ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
|
ml_dtypes>=0.5.0, <=0.6.0
|
||||||
ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from mlir.execution_engine import *
|
|||||||
from mlir.runtime import *
|
from mlir.runtime import *
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ml_dtypes import bfloat16, float8_e5m2
|
from ml_dtypes import bfloat16, float8_e5m2, float8_e3m4, float8_e4m3
|
||||||
|
|
||||||
HAS_ML_DTYPES = True
|
HAS_ML_DTYPES = True
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
@@ -623,6 +623,90 @@ else:
|
|||||||
log("TEST: testF8E5M2Memref")
|
log("TEST: testF8E5M2Memref")
|
||||||
|
|
||||||
|
|
||||||
|
# Test f8E3M4 memrefs
|
||||||
|
# CHECK-LABEL: TEST: testF8E3M4Memref
|
||||||
|
def testF8E3M4Memref():
|
||||||
|
with Context():
|
||||||
|
module = Module.parse(
|
||||||
|
"""
|
||||||
|
module {
|
||||||
|
func.func @main(%arg0: memref<1xf8E3M4>,
|
||||||
|
%arg1: memref<1xf8E3M4>) attributes { llvm.emit_c_interface } {
|
||||||
|
%0 = arith.constant 0 : index
|
||||||
|
%1 = memref.load %arg0[%0] : memref<1xf8E3M4>
|
||||||
|
memref.store %1, %arg1[%0] : memref<1xf8E3M4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} """
|
||||||
|
)
|
||||||
|
|
||||||
|
arg1 = np.array([0.5]).astype(float8_e3m4)
|
||||||
|
arg2 = np.array([0.0]).astype(float8_e3m4)
|
||||||
|
|
||||||
|
arg1_memref_ptr = ctypes.pointer(
|
||||||
|
ctypes.pointer(get_ranked_memref_descriptor(arg1))
|
||||||
|
)
|
||||||
|
arg2_memref_ptr = ctypes.pointer(
|
||||||
|
ctypes.pointer(get_ranked_memref_descriptor(arg2))
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_engine = ExecutionEngine(lowerToLLVM(module))
|
||||||
|
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
|
||||||
|
|
||||||
|
# test to-numpy utility
|
||||||
|
x = ranked_memref_to_numpy(arg2_memref_ptr[0])
|
||||||
|
assert len(x) == 1
|
||||||
|
assert x[0] == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
if HAS_ML_DTYPES:
|
||||||
|
run(testF8E3M4Memref)
|
||||||
|
else:
|
||||||
|
log("TEST: testF8E3M4Memref")
|
||||||
|
|
||||||
|
|
||||||
|
# Test f8E4M3 memrefs
|
||||||
|
# CHECK-LABEL: TEST: testF8E4M3Memref
|
||||||
|
def testF8E4M3Memref():
|
||||||
|
with Context():
|
||||||
|
module = Module.parse(
|
||||||
|
"""
|
||||||
|
module {
|
||||||
|
func.func @main(%arg0: memref<1xf8E4M3>,
|
||||||
|
%arg1: memref<1xf8E4M3>) attributes { llvm.emit_c_interface } {
|
||||||
|
%0 = arith.constant 0 : index
|
||||||
|
%1 = memref.load %arg0[%0] : memref<1xf8E4M3>
|
||||||
|
memref.store %1, %arg1[%0] : memref<1xf8E4M3>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} """
|
||||||
|
)
|
||||||
|
|
||||||
|
arg1 = np.array([0.5]).astype(float8_e4m3)
|
||||||
|
arg2 = np.array([0.0]).astype(float8_e4m3)
|
||||||
|
|
||||||
|
arg1_memref_ptr = ctypes.pointer(
|
||||||
|
ctypes.pointer(get_ranked_memref_descriptor(arg1))
|
||||||
|
)
|
||||||
|
arg2_memref_ptr = ctypes.pointer(
|
||||||
|
ctypes.pointer(get_ranked_memref_descriptor(arg2))
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_engine = ExecutionEngine(lowerToLLVM(module))
|
||||||
|
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
|
||||||
|
|
||||||
|
# test to-numpy utility
|
||||||
|
x = ranked_memref_to_numpy(arg2_memref_ptr[0])
|
||||||
|
assert len(x) == 1
|
||||||
|
assert x[0] == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
if HAS_ML_DTYPES:
|
||||||
|
run(testF8E4M3Memref)
|
||||||
|
else:
|
||||||
|
log("TEST: testF8E4M3Memref")
|
||||||
|
|
||||||
|
|
||||||
# Test addition of two 2d_memref
|
# Test addition of two 2d_memref
|
||||||
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
|
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
|
||||||
def testDynamicMemrefAdd2D():
|
def testDynamicMemrefAdd2D():
|
||||||
|
|||||||
Reference in New Issue
Block a user