Skip to content
Snippets Groups Projects
Commit 77345051 authored by tqchen's avatar tqchen
Browse files

Move Var back to Expr, add format str test

parent f5b8196d
No related branches found
No related tags found
No related merge requests found
......@@ -2,3 +2,5 @@
from __future__ import absolute_import as _abs
from .op import *
from .expr import Var, const
from .expr_util import *
......@@ -2,13 +2,16 @@
from __future__ import absolute_import as _abs
from numbers import Number as _Number
from . import op as _op
from . import var_name as _name
class Expr(object):
"""Base class of expression."""
"""Base class of expression.
Expression object should be in general immutable.
"""
def children(self):
"""All expr must define this.
"""get children of this expression.
Returns
-------
......@@ -60,6 +63,21 @@ def _symbol(value):
raise TypeError("type %s not supported" % str(type(other)))
class Var(Expr):
"""Variable, is a symbolic placeholder.
Each variable is uniquely identified by its address
Note that name alone is not able to uniquely identify the var.
Parameters
----------
name : str
optional name to the var.
"""
def __init__(self, name=None):
self.name = name if name else _name.NameManager.current.get(name)
class ConstExpr(Expr):
"""Constant expression."""
def __init__(self, value):
......@@ -77,7 +95,6 @@ class BinaryOpExpr(Expr):
def children(self):
return (self.lhs, self.rhs)
_op.binary_op_cls = BinaryOpExpr
class UnaryOpExpr(Expr):
......@@ -88,3 +105,8 @@ class UnaryOpExpr(Expr):
def children(self):
return (self.src)
def const(value):
"""Return a constant value"""
return ConstExpr(value)
"""Utilities to manipulate expression"""
from __future__ import absolute_import as _abs
from . import expr as _expr
def expr_with_new_children(e, children):
"""Returns same expr as e but with new children
A shallow copy of e will happen if children differs from current children
Parameters
----------
e : Expr
The input expression
children : list of Expr
The new children
Returns
-------
new_e : Expr
Expression with the new children
"""
if children:
if isinstance(e, _expr.BinaryOpExpr):
return (e if children[0] == e.lhs and children[1] == e.rhs
else _expr.BinaryOpExpr(e.op, children[0], children[1]))
elif isinstance(e, _expr.UnaryOpExpr):
return e if children[0] == e.src else _expr.UnaryOpExpr(e.op, children[0])
else:
raise TypeError("donnot know how to handle Expr %s" % type(e))
else:
return e
def transform(e, f):
"""Apply f recursively to e and collect the resulr
Parameters
----------
e : Expr
The input expression.
f : function with signiture (e, ret_children)
ret_children is the result of transform from children
Returns
-------
result : return value of f
The final result of transformation.
"""
return f(e , [transform(c, f) for c in e.children()])
def format_str(expr):
"""change expression to string.
Parameters
----------
expr : Expr
Input expression
Returns
-------
s : str
The string representation of expr
"""
def make_str(e, result_children):
if isinstance(e, _expr.BinaryOpExpr):
return e.op.format_str(result_children[0], result_children[1])
elif isinstance(e, _expr.UnaryOpExpr):
return e.op.format_str(result_children[0])
elif isinstance(e, _expr.ConstExpr):
return str(e.value)
elif isinstance(e, _expr.Var):
return e.name
else:
raise TypeError("Do not know how to handle type " + str(type(e)))
return transform(expr, make_str)
def bind(expr, update_dict):
"""Replace the variable in e by specification from kwarg
Parameters
----------
expr : Expr
Input expression
update_dict : dict of Var->Expr
The variables to be replaced.
Examples
--------
eout = bind(e, update_dict={v1: (x+1)} )
"""
def replace(e, result_children):
if isinstance(e, _expr.Var) and e in update_dict:
return update_dict[e]
else:
return expr_with_new_children(e, result_children)
return transform(expr, replace)
......@@ -8,16 +8,20 @@ class BinaryOp(object):
return _binary_op_cls(self, lhs, rhs)
class AddOp(BinaryOp):
pass
def format_str(self, lhs, rhs):
return '(%s + %s)' % (lhs, rhs)
class SubOp(BinaryOp):
pass
def format_str(self, lhs, rhs):
return '(%s - %s)' % (lhs, rhs)
class MulOp(BinaryOp):
pass
def format_str(self, lhs, rhs):
return '(%s * %s)' % (lhs, rhs)
class DivOp(BinaryOp):
pass
def format_str(self, lhs, rhs):
return '(%s / %s)' % (lhs, rhs)
add = AddOp()
......
from __future__ import absolute_import as _abs
from .expr import Expr
class Var(Expr):
"""Variables"""
def __init__(self, name, expr=None):
self.name = name
self.expr = expr
def assign(self, expr):
self.expr = expr
def children(self):
if self.expr is None:
return ()
return self.expr.children()
def same_as(self, other):
return (self.name == other.name)
"""Name manager to make sure name is unique."""
from __future__ import absolute_import as _abs
class NameManager(object):
"""NameManager to do automatic naming.
User can also inherit this object to change naming behavior.
"""
current = None
def __init__(self):
self._counter = {}
self._old_manager = None
def get(self, hint):
"""Get the canonical name for a symbol.
This is default implementation.
When user specified a name,
the user specified name will be used.
When user did not, we will automatically generate a
name based on hint string.
Parameters
----------
hint : str
A hint string, which can be used to generate name.
Returns
-------
full_name : str
A canonical name for the user.
"""
if hint not in self._counter:
self._counter[hint] = 0
name = '%s%d' % (hint, self._counter[hint])
self._counter[hint] += 1
return name
def __enter__(self):
self._old_manager = NameManager.current
NameManager.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_manager
NameManager.current = self._old_manager
# initialize the default name manager
NameManager.current = NameManager()
import tvm
from tvm import expr
def test_const():
x = expr.ConstExpr(1)
x + 1
print x
def test_bind():
x = tvm.Var('x')
y = x + 1
z = tvm.bind(y, {x: tvm.const(10) + 9})
assert tvm.format_str(z) == '((10 + 9) + 1)'
test_const()
def test_basic():
a= tvm.Var('a')
b = tvm.Var('b')
c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
if __name__ == "__main__":
test_basic()
test_bind()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment