From 7cc92ace2c37e4dbfc3a2015d695468481a7ba57 Mon Sep 17 00:00:00 2001 From: ziheng <ziheng@apache.org> Date: Thu, 22 Jun 2017 12:47:41 -0700 Subject: [PATCH] [LANG] Expose tvm.cast (#195) * [LANG] Expose tvm.cast * Update * Add unittest --- python/tvm/api.py | 17 +++++++++++++++++ python/tvm/expr.py | 14 ++++++++++++++ src/api/api_ir.cc | 1 + tests/python/unittest/test_lang_tensor.py | 2 ++ 4 files changed, 34 insertions(+) diff --git a/python/tvm/api.py b/python/tvm/api.py index a48741208..dea0252c6 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -467,6 +467,23 @@ def reduce_axis(dom, name="rv"): """ return _IterVar(dom, name, 2) +def cast(dtype, expr): + """Cast an expression to other type + Parameters + ---------- + dtype : str, optional + The type of new expression + expr : Expr + The expression + + Returns + ------- + expr : Expr + Expression with new type + """ + return _make.Cast(dtype, expr) + + def select(cond, t, f): """Construct a select branch Parameters diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 49aeaebe4..42227bdea 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -89,6 +89,20 @@ class ExprOp(object): """ return _make.EQ(self, other) + def astype(self, dtype): + """Cast the expression to other type + Parameters + ---------- + dtype : str, optional + The type of new expression + + Returns + ------- + expr : Expr + Expression with new type + """ + return _make.Cast(dtype, self) + class Expr(NodeBase, ExprOp): """Base class of all tvm Expressions""" diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index f66652b99..700418c4b 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -139,6 +139,7 @@ REGISTER_MAKE_BINARY_OP(Or); REGISTER_MAKE1(Not); REGISTER_MAKE3(Select); REGISTER_MAKE3(Ramp); +REGISTER_MAKE2(Cast); REGISTER_MAKE2(Broadcast); REGISTER_MAKE3(Let); REGISTER_MAKE3(LetStmt); diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 73d889f1b..371435351 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -16,6 +16,8 @@ def test_tensor(): assert(T.op.output(0).__hash__() == T.__hash__()) d = {T.op.output(0) : 1} assert(d[T] == 1) + assert(tvm.cast('float16', T[0][0][0]).dtype == 'float16') + assert(T[0][0][0].astype('float16').dtype == 'float16') def test_conv1d(): -- GitLab