diff --git a/python/tvm/api.py b/python/tvm/api.py index a487412087eeb65da0be56ccefba4b005a4b0bda..dea0252c61096de907f8deab669c3d03c69782bb 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -467,6 +467,23 @@ def reduce_axis(dom, name="rv"): """ return _IterVar(dom, name, 2) +def cast(dtype, expr): + """Cast an expression to other type + Parameters + ---------- + dtype : str, optional + The type of new expression + expr : Expr + The expression + + Returns + ------- + expr : Expr + Expression with new type + """ + return _make.Cast(dtype, expr) + + def select(cond, t, f): """Construct a select branch Parameters diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 49aeaebe430375fa63dfcf9af0b9c7d0b4ad59df..42227bdea44769ccc83c0593ebe236d76dadf6e7 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -89,6 +89,20 @@ class ExprOp(object): """ return _make.EQ(self, other) + def astype(self, dtype): + """Cast the expression to other type + Parameters + ---------- + dtype : str, optional + The type of new expression + + Returns + ------- + expr : Expr + Expression with new type + """ + return _make.Cast(dtype, self) + class Expr(NodeBase, ExprOp): """Base class of all tvm Expressions""" diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index f66652b99157d5a29eaade5f7e8caef46f576a27..700418c4ba1d99ce5a9006b8778c27db2462bc93 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -139,6 +139,7 @@ REGISTER_MAKE_BINARY_OP(Or); REGISTER_MAKE1(Not); REGISTER_MAKE3(Select); REGISTER_MAKE3(Ramp); +REGISTER_MAKE2(Cast); REGISTER_MAKE2(Broadcast); REGISTER_MAKE3(Let); REGISTER_MAKE3(LetStmt); diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 73d889f1b0a341db822e14320991eae39349f314..371435351799743cbd0d2ebd75abcf3d417dc496 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -16,6 +16,8 @@ def test_tensor(): assert(T.op.output(0).__hash__() == T.__hash__()) d = {T.op.output(0) : 1} assert(d[T] == 1) + assert(tvm.cast('float16', T[0][0][0]).dtype == 'float16') + assert(T[0][0][0].astype('float16').dtype == 'float16') def test_conv1d():