From 050bc91bbbb301283f116faa65b49dbc61e569e6 Mon Sep 17 00:00:00 2001
From: Yizhi Liu <javelinjs@gmail.com>
Date: Sun, 21 May 2017 11:37:55 +0800
Subject: [PATCH] add tvm.select (#148)

---
 python/tvm/_ffi/node_generic.py |  2 ++
 python/tvm/api.py               | 17 +++++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/node_generic.py
index 457a05007..7561097bf 100644
--- a/python/tvm/_ffi/node_generic.py
+++ b/python/tvm/_ffi/node_generic.py
@@ -34,6 +34,8 @@ def convert_to_node(value):
     """
     if isinstance(value, _CLASS_NODE_BASE):
         return value
+    elif isinstance(value, bool):
+        return const(value, 'uint1x1')
     elif isinstance(value, Number):
         return const(value)
     elif isinstance(value, string_types):
diff --git a/python/tvm/api.py b/python/tvm/api.py
index b305d0214..2ef18d210 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -463,6 +463,23 @@ def reduce_axis(dom, name="rv"):
     """
     return _IterVar(dom, name, 2)
 
+def select(cond, t, f):
+    """Construct a select branch
+    Parameters
+    ----------
+    cond : Expr
+        The condition
+    t : Expr
+        The result expression if cond is true.
+    f : Expr
+        The result expression if cond is false.
+
+    Returns
+    -------
+    node : Node
+        The tvm.expr.Select node
+    """
+    return _make.Select(convert(cond), convert(t), convert(f))
 
 def comm_reducer(fcombine, fidentity, name="reduce"):
     """Create a commutative reducer for reduction.
-- 
GitLab