From 050bc91bbbb301283f116faa65b49dbc61e569e6 Mon Sep 17 00:00:00 2001 From: Yizhi Liu <javelinjs@gmail.com> Date: Sun, 21 May 2017 11:37:55 +0800 Subject: [PATCH] add tvm.select (#148) --- python/tvm/_ffi/node_generic.py | 2 ++ python/tvm/api.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/node_generic.py index 457a05007..7561097bf 100644 --- a/python/tvm/_ffi/node_generic.py +++ b/python/tvm/_ffi/node_generic.py @@ -34,6 +34,8 @@ def convert_to_node(value): """ if isinstance(value, _CLASS_NODE_BASE): return value + elif isinstance(value, bool): + return const(value, 'uint1x1') elif isinstance(value, Number): return const(value) elif isinstance(value, string_types): diff --git a/python/tvm/api.py b/python/tvm/api.py index b305d0214..2ef18d210 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -463,6 +463,23 @@ def reduce_axis(dom, name="rv"): """ return _IterVar(dom, name, 2) +def select(cond, t, f): + """Construct a select branch + Parameters + ---------- + cond : Expr + The condition + t : Expr + The result expression if cond is true. + f : Expr + The result expression if cond is false. + + Returns + ------- + node : Node + The tvm.expr.Select node + """ + return _make.Select(convert(cond), convert(t), convert(f)) def comm_reducer(fcombine, fidentity, name="reduce"): """Create a commutative reducer for reduction. -- GitLab