From 7cc92ace2c37e4dbfc3a2015d695468481a7ba57 Mon Sep 17 00:00:00 2001
From: ziheng <ziheng@apache.org>
Date: Thu, 22 Jun 2017 12:47:41 -0700
Subject: [PATCH] [LANG] Expose tvm.cast (#195)

* [LANG] Expose tvm.cast

* Update

* Add unittest
---
 python/tvm/api.py                         | 17 +++++++++++++++++
 python/tvm/expr.py                        | 14 ++++++++++++++
 src/api/api_ir.cc                         |  1 +
 tests/python/unittest/test_lang_tensor.py |  2 ++
 4 files changed, 34 insertions(+)

diff --git a/python/tvm/api.py b/python/tvm/api.py
index a48741208..dea0252c6 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -467,6 +467,23 @@ def reduce_axis(dom, name="rv"):
     """
     return _IterVar(dom, name, 2)
 
+def cast(dtype, expr):
+    """Cast an expression to other type
+    Parameters
+    ----------
+    dtype : str, optional
+        The type of new expression
+    expr : Expr
+        The expression
+
+    Returns
+    -------
+    expr : Expr
+        Expression with new type
+    """
+    return _make.Cast(dtype, expr)
+
+
 def select(cond, t, f):
     """Construct a select branch
     Parameters
diff --git a/python/tvm/expr.py b/python/tvm/expr.py
index 49aeaebe4..42227bdea 100644
--- a/python/tvm/expr.py
+++ b/python/tvm/expr.py
@@ -89,6 +89,20 @@ class ExprOp(object):
         """
         return _make.EQ(self, other)
 
+    def astype(self, dtype):
+        """Cast the expression to other type
+        Parameters
+        ----------
+        dtype : str, optional
+            The type of new expression
+
+        Returns
+        -------
+        expr : Expr
+            Expression with new type
+        """
+        return _make.Cast(dtype, self)
+
 
 class Expr(NodeBase, ExprOp):
     """Base class of all tvm Expressions"""
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
index f66652b99..700418c4b 100644
--- a/src/api/api_ir.cc
+++ b/src/api/api_ir.cc
@@ -139,6 +139,7 @@ REGISTER_MAKE_BINARY_OP(Or);
 REGISTER_MAKE1(Not);
 REGISTER_MAKE3(Select);
 REGISTER_MAKE3(Ramp);
+REGISTER_MAKE2(Cast);
 REGISTER_MAKE2(Broadcast);
 REGISTER_MAKE3(Let);
 REGISTER_MAKE3(LetStmt);
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index 73d889f1b..371435351 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -16,6 +16,8 @@ def test_tensor():
     assert(T.op.output(0).__hash__() == T.__hash__())
     d = {T.op.output(0) : 1}
     assert(d[T] == 1)
+    assert(tvm.cast('float16', T[0][0][0]).dtype == 'float16')
+    assert(T[0][0][0].astype('float16').dtype == 'float16')
 
 
 def test_conv1d():
-- 
GitLab