diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 07f5fa5f8bc8172496dceb9fcaed7fffeb3a19b2..79274a228e8b78908c3baf294fcfe20cec5bb258 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -2,3 +2,5 @@ from __future__ import absolute_import as _abs from .op import * +from .expr import Var, const +from .expr_util import * diff --git a/python/tvm/expr.py b/python/tvm/expr.py index b6bac8808bb710edca671c431583262b25b4ac24..bfdb07dc19b64db7bbf1c02e6b17374dd30b7032 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -2,13 +2,16 @@ from __future__ import absolute_import as _abs from numbers import Number as _Number from . import op as _op - +from . import var_name as _name class Expr(object): - """Base class of expression.""" + """Base class of expression. + + Expression object should be in general immutable. + """ def children(self): - """All expr must define this. + """get children of this expression. Returns ------- @@ -60,6 +63,21 @@ def _symbol(value): raise TypeError("type %s not supported" % str(type(other))) +class Var(Expr): + """Variable, is a symbolic placeholder. + + Each variable is uniquely identified by its address + Note that name alone is not able to uniquely identify the var. + + Parameters + ---------- + name : str + optional name to the var. + """ + def __init__(self, name=None): + self.name = name if name else _name.NameManager.current.get(name) + + class ConstExpr(Expr): """Constant expression.""" def __init__(self, value): @@ -77,7 +95,6 @@ class BinaryOpExpr(Expr): def children(self): return (self.lhs, self.rhs) - _op.binary_op_cls = BinaryOpExpr class UnaryOpExpr(Expr): @@ -88,3 +105,8 @@ class UnaryOpExpr(Expr): def children(self): return (self.src) + + +def const(value): + """Return a constant value""" + return ConstExpr(value) diff --git a/python/tvm/expr_util.py b/python/tvm/expr_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b79d1d0998580fcae6c8935aa211b3f24814031f --- /dev/null +++ b/python/tvm/expr_util.py @@ -0,0 +1,101 @@ +"""Utilities to manipulate expression""" +from __future__ import absolute_import as _abs +from . import expr as _expr + +def expr_with_new_children(e, children): + """Returns same expr as e but with new children + + A shallow copy of e will happen if children differs from current children + + Parameters + ---------- + e : Expr + The input expression + + children : list of Expr + The new children + + Returns + ------- + new_e : Expr + Expression with the new children + """ + if children: + if isinstance(e, _expr.BinaryOpExpr): + return (e if children[0] == e.lhs and children[1] == e.rhs + else _expr.BinaryOpExpr(e.op, children[0], children[1])) + elif isinstance(e, _expr.UnaryOpExpr): + return e if children[0] == e.src else _expr.UnaryOpExpr(e.op, children[0]) + else: + raise TypeError("donnot know how to handle Expr %s" % type(e)) + else: + return e + + +def transform(e, f): + """Apply f recursively to e and collect the resulr + + Parameters + ---------- + e : Expr + The input expression. + + f : function with signiture (e, ret_children) + ret_children is the result of transform from children + + Returns + ------- + result : return value of f + The final result of transformation. + """ + return f(e , [transform(c, f) for c in e.children()]) + + +def format_str(expr): + """change expression to string. + + Parameters + ---------- + expr : Expr + Input expression + + Returns + ------- + s : str + The string representation of expr + """ + def make_str(e, result_children): + if isinstance(e, _expr.BinaryOpExpr): + return e.op.format_str(result_children[0], result_children[1]) + elif isinstance(e, _expr.UnaryOpExpr): + return e.op.format_str(result_children[0]) + elif isinstance(e, _expr.ConstExpr): + return str(e.value) + elif isinstance(e, _expr.Var): + return e.name + else: + raise TypeError("Do not know how to handle type " + str(type(e))) + return transform(expr, make_str) + + +def bind(expr, update_dict): + """Replace the variable in e by specification from kwarg + + Parameters + ---------- + expr : Expr + Input expression + + update_dict : dict of Var->Expr + The variables to be replaced. + + Examples + -------- + eout = bind(e, update_dict={v1: (x+1)} ) + """ + def replace(e, result_children): + if isinstance(e, _expr.Var) and e in update_dict: + return update_dict[e] + else: + return expr_with_new_children(e, result_children) + return transform(expr, replace) diff --git a/python/tvm/op.py b/python/tvm/op.py index 238552d6e53273c9bd3188837ca8eb4476cd3907..464c5a10bb265eb543beb1674299653d18c52e85 100644 --- a/python/tvm/op.py +++ b/python/tvm/op.py @@ -8,16 +8,20 @@ class BinaryOp(object): return _binary_op_cls(self, lhs, rhs) class AddOp(BinaryOp): - pass + def format_str(self, lhs, rhs): + return '(%s + %s)' % (lhs, rhs) class SubOp(BinaryOp): - pass + def format_str(self, lhs, rhs): + return '(%s - %s)' % (lhs, rhs) class MulOp(BinaryOp): - pass + def format_str(self, lhs, rhs): + return '(%s * %s)' % (lhs, rhs) class DivOp(BinaryOp): - pass + def format_str(self, lhs, rhs): + return '(%s / %s)' % (lhs, rhs) add = AddOp() diff --git a/python/tvm/var.py b/python/tvm/var.py deleted file mode 100644 index 1988e4306be5d031e324d5636f741ca817321c5f..0000000000000000000000000000000000000000 --- a/python/tvm/var.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import absolute_import as _abs -from .expr import Expr - -class Var(Expr): - """Variables""" - def __init__(self, name, expr=None): - self.name = name - self.expr = expr - - def assign(self, expr): - self.expr = expr - - def children(self): - if self.expr is None: - return () - return self.expr.children() - - def same_as(self, other): - return (self.name == other.name) diff --git a/python/tvm/var_name.py b/python/tvm/var_name.py new file mode 100644 index 0000000000000000000000000000000000000000..57044e83b25d07b205f1407d575a97187857ac21 --- /dev/null +++ b/python/tvm/var_name.py @@ -0,0 +1,51 @@ +"""Name manager to make sure name is unique.""" +from __future__ import absolute_import as _abs + +class NameManager(object): + """NameManager to do automatic naming. + + User can also inherit this object to change naming behavior. + """ + current = None + + def __init__(self): + self._counter = {} + self._old_manager = None + + def get(self, hint): + """Get the canonical name for a symbol. + + This is default implementation. + When user specified a name, + the user specified name will be used. + + When user did not, we will automatically generate a + name based on hint string. + + Parameters + ---------- + hint : str + A hint string, which can be used to generate name. + + Returns + ------- + full_name : str + A canonical name for the user. + """ + if hint not in self._counter: + self._counter[hint] = 0 + name = '%s%d' % (hint, self._counter[hint]) + self._counter[hint] += 1 + return name + + def __enter__(self): + self._old_manager = NameManager.current + NameManager.current = self + return self + + def __exit__(self, ptype, value, trace): + assert self._old_manager + NameManager.current = self._old_manager + +# initialize the default name manager +NameManager.current = NameManager() diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index fe40810e926c3d46cc8ada92c014855a497925ce..c76520ea1bc41e7db54123c27d2f026954c7f9d8 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -1,9 +1,20 @@ import tvm from tvm import expr -def test_const(): - x = expr.ConstExpr(1) - x + 1 - print x +def test_bind(): + x = tvm.Var('x') + y = x + 1 + z = tvm.bind(y, {x: tvm.const(10) + 9}) + assert tvm.format_str(z) == '((10 + 9) + 1)' -test_const() + +def test_basic(): + a= tvm.Var('a') + b = tvm.Var('b') + c = a + b + assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name) + + +if __name__ == "__main__": + test_basic() + test_bind()