diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 4147f89d703224ab53808cf52be425459fe59ab5..7f38581a4cdf24684891b27fc2d374212542aac9 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -521,6 +521,7 @@ class TVMRetValue : public TVMPODValue_ { inline const char* TypeCode2Str(int type_code) { switch (type_code) { case kInt: return "int"; + case kUInt: return "uint"; case kFloat: return "float"; case kStr: return "str"; case kBytes: return "bytes"; diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py index 23a74b4a8632257c6dc3b0ee8cd137f849cfb106..97dc12d7ac915da732a0038b4c23a718be61e22e 100644 --- a/python/tvm/_ffi/node.py +++ b/python/tvm/_ffi/node.py @@ -6,7 +6,7 @@ import ctypes import sys from .. import _api_internal from .node_generic import NodeGeneric, convert_to_node, const -from .base import _LIB, check_call, c_str, _FFI_MODE +from .base import _LIB, check_call, c_str, py_str, _FFI_MODE IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError try: diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index 63a109d3c8df171e0bada408ff37081b7047df81..932c3955385f5b26d84f1ae047016e5f93d0bb6d 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -71,6 +71,15 @@ def test_stmt(): tvm.stmt.For.Serial, 0, x) +def test_dir(): + x = tvm.var('x') + dir(x) + +def test_dtype(): + x = tvm.var('x') + assert x.dtype == 'int32' + y = tvm.var('y') + assert (x > y).dtype == 'uint1' if __name__ == "__main__": test_attr() @@ -81,3 +90,5 @@ if __name__ == "__main__": test_basic() test_stmt() test_let() + test_dir() + test_dtype()