Skip to content
Snippets Groups Projects
Commit 6819145a authored by tqchen's avatar tqchen
Browse files

checkin domain

parent bda95817
No related branches found
No related tags found
No related merge requests found
......@@ -5,3 +5,4 @@ from .op import *
from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
from .domain import RDom, Range
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
class Range(object):
"""Represent a range in one dimension.
"""
def __init__(self, begin, end=None):
if end is None:
end = begin
begin = _expr.const(0)
self.begin = _expr._symbol(begin)
self.end = _expr._symbol(end)
self.extent = _expr_util.simplify(end - begin)
def __str__(self):
return "(%s, %s)" % (
_expr_util.format_str(self.begin),
_expr_util.format_str(self.end))
def __repr__(self):
return self.__str__()
class RDom(object):
"""reduction Domain
"""
def __init__(self, domain):
if isinstance(domain, Range):
domain = [domain]
self.index = []
self.domain = domain
for i in range(len(domain)):
self.index.append(_expr.Var("rd_index_%d_" % i))
"""Use list of ranges as domain"""
Domain = list
......@@ -108,7 +108,27 @@ class UnaryOpExpr(Expr):
self.src = _symbol(src)
def children(self):
return (self.src)
return (self.src,)
class ReduceExpr(Expr):
def __init__(self, op, src, rdom):
self.op = op
self.src = src
self.rdom = rdom
def children(self):
return (self.src,)
class TensorReadExpr(Expr):
"""Tensor read expression, tensor[indices]"""
def __init__(self, tensor, indices):
self.tensor = tensor
self.indices = indices
def children(self):
return self.indices
def const(value):
......
......@@ -2,7 +2,6 @@
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import op as _op
from . import tensor as _tensor
def expr_with_new_children(e, children):
"""Returns same expr as e but with new children
......@@ -50,10 +49,27 @@ def transform(e, f):
result : return value of f
The final result of transformation.
"""
assert isinstance(e, _expr.Expr)
if not isinstance(e, _expr.Expr):
raise TypeError("Cannot handle type %s" % type(e))
return f(e , [transform(c, f) for c in e.children()])
def visit(e, f):
"""Apply f to each element of e
Parameters
----------
e : Expr
The input expression.
f : function with signiture (e)
"""
assert isinstance(e, _expr.Expr)
for c in e.children():
visit(c, f)
f(e)
def format_str(expr):
"""change expression to string.
......@@ -76,12 +92,15 @@ def format_str(expr):
return str(e.value)
elif isinstance(e, _expr.Var):
return e.name
elif isinstance(e, _tensor.TensorReadExpr):
elif isinstance(e, _expr.TensorReadExpr):
return "%s(%s)" % (e.tensor.name, ','.join(result_children))
elif isinstance(e, _expr.ReduceExpr):
return e.op.format_reduce_str(result_children[0], e.rdom.domain)
else:
raise TypeError("Do not know how to handle type " + str(type(e)))
return transform(expr, make_str)
def simplify(expr):
"""simplify expression
......
......@@ -22,15 +22,20 @@ def canonical_to_expr(c):
else:
return _expr.const(0)
class BinaryOp(object):
"""Base class of binary operator"""
def __call__(self, lhs, rhs):
return _expr.BinaryOpExpr(self, lhs, rhs)
class AddOp(BinaryOp):
def format_str(self, lhs, rhs):
return '(%s + %s)' % (lhs, rhs)
def format_reduce_str(self, src, rd):
return "reduce_sum(%s, rdom=%s)" % (src, str(rd))
def canonical(self, lhs, rhs):
lhs = lhs.copy()
for k, v in rhs.items():
......@@ -40,6 +45,7 @@ class AddOp(BinaryOp):
lhs[k] = v
return lhs
class SubOp(BinaryOp):
def format_str(self, lhs, rhs):
return '(%s - %s)' % (lhs, rhs)
......@@ -53,6 +59,7 @@ class SubOp(BinaryOp):
lhs[k] = -v
return lhs
class MulOp(BinaryOp):
def format_str(self, lhs, rhs):
return '(%s * %s)' % (lhs, rhs)
......@@ -72,6 +79,7 @@ class MulOp(BinaryOp):
return rhs
return {elhs * erhs: 1}
class DivOp(BinaryOp):
def format_str(self, lhs, rhs):
return '(%s / %s)' % (lhs, rhs)
......@@ -86,6 +94,7 @@ class DivOp(BinaryOp):
elhs = canonical_to_expr(lhs)
return {elhs / erhs: 1}
class MaxOp(BinaryOp):
def format_str(self, lhs, rhs):
return 'max(%s, %s)' % (lhs, rhs)
......@@ -97,6 +106,7 @@ class MaxOp(BinaryOp):
return lhs if ediff.value >= 0 else rhs
return {MaxOp()(lhs, rhs): 1}
class MinOp(BinaryOp):
def format_str(self, lhs, rhs):
return 'min(%s, %s)' % (lhs, rhs)
......@@ -120,3 +130,16 @@ _expr.__addop__ = add
_expr.__subop__ = sub
_expr.__mulop__ = mul
_expr.__divop__ = div
def reduce_sum(expr, rdom):
return _expr.ReduceExpr(add, expr, rdom)
def reduce_prod(expr, rdom):
return _expr.ReduceExpr(mul, expr, rdom)
def reduce_min(expr, rdom):
return _expr.ReduceExpr(min, expr, rdom)
def reduce_max(expr, rdom):
return _expr.ReduceExpr(max, expr, rdom)
from __future__ import absolute_import as _abs
from . import expr as _expr
class TensorReadExpr(_expr.Expr):
def __init__(self, tensor, indices):
self.tensor = tensor
self.indices = indices
def children(self):
return self.indices
from . import expr_util as _expr_util
class Tensor(object):
def __init__(self, ndim, fcompute=None, name=None):
def __init__(self, ndim, fcompute=None, name=None, shape=None):
self.ndim = ndim
if fcompute:
arg_names = fcompute.func_code.co_varnames
assert(len(arg_names) == ndim)
self.dim_index = [_expr.Var(n) for n in arg_names]
self.expr = fcompute(*self.dim_index)
if shape is None:
raise ValueError("argument shape need to be given for intermediate tensor")
self.shape = shape
else:
self.expr = None
self.dim_index = None
shape_name = '_shape'
if name: shape_name = name + shape_name
self.shape = tuple(_expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
self.shape = shape if shape else tuple(
_expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
self.name = name if name else "TensorObj"
self.inputs = None
def __call__(self, *indices):
if len(indices) != self.ndim:
raise ValueError("Need to provide %d index in tensor slice" % self.ndim)
return TensorReadExpr(self, indices)
return _expr.TensorReadExpr(self, indices)
def input_tensors(self):
"""List of input tensors to this tensor.
Returns
-------
inputs : list of input tensors
"""
if self.inputs is not None:
return self.inputs
self.inputs = []
if self.expr:
def collect(e):
if isinstance(e, _expr.TensorReadExpr):
self.inputs.append(e.tensor)
_expr_util.visit(self.expr, collect)
return self.inputs
def infer_input_domains(self, out_domain):
"""Infer the input domains of each domain given output domains
Parameters
----------
out_domain : list of Range
Domain of each dimension.
Returns
-------
in_domains: dict Tensor->Domain
"""
assert self.expr
assert len(out_domain) == len(self.dim_index)
index_domains = {
self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
}
def collect(e):
if isinstance(e, _expr.TensorReadExpr):
self.inputs.append(e.tensor)
_expr_util.visit(self.expr, collect)
......@@ -3,8 +3,27 @@ import tvm
def test_tensor():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k))
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
print(tvm.format_str(T.expr))
def test_tensor_inputs():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
assert(T.input_tensors() == [A, B])
def test_tensor_reduce():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
rd = tvm.RDom(tvm.Range(A.shape[1]))
C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]))
print(tvm.format_str(C.expr))
if __name__ == "__main__":
test_tensor()
test_tensor_inputs()
test_tensor_reduce()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment