Skip to content
Snippets Groups Projects
Commit dcddd208 authored by tqchen's avatar tqchen
Browse files

finish tensor dom infer

parent 6819145a
No related branches found
No related tags found
No related merge requests found
......@@ -5,4 +5,4 @@ from .op import *
from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
from .domain import RDom, Range
from .domain import RDom, Range, infer_range
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import op as _op
class Range(object):
"""Represent a range in one dimension.
......@@ -10,10 +10,15 @@ class Range(object):
if end is None:
end = begin
begin = _expr.const(0)
self.begin = _expr._symbol(begin)
self.end = _expr._symbol(end)
begin = _expr_util.simplify(_expr._symbol(begin))
end = _expr_util.simplify(_expr._symbol(end))
self.begin = begin
self.end = end
self.extent = _expr_util.simplify(end - begin)
def is_value(self):
return isinstance(self.extent, _expr.ConstExpr) and self.extend.value == 1
def __str__(self):
return "(%s, %s)" % (
_expr_util.format_str(self.begin),
......@@ -22,9 +27,13 @@ class Range(object):
def __repr__(self):
return self.__str__()
class RangeInferError(ValueError):
pass
class RDom(object):
"""reduction Domain
"""
"""Reduction Domain."""
def __init__(self, domain):
if isinstance(domain, Range):
domain = [domain]
......@@ -36,3 +45,63 @@ class RDom(object):
"""Use list of ranges as domain"""
Domain = list
def _combine_range_binary_op(op, lhs, rhs):
if op == _op.add:
return Range(lhs.begin + rhs.begin, lhs.end + rhs.end - 1)
elif op == _op.sub:
return Range(lhs.begin - rhs.end + 1, lhs.end - rhs.begin)
elif op == _op.mul:
v = None
if lhs.is_value():
v = lhs.begin.value
e = rhs
elif rhs.is_value():
v = rhs.begin.value
e = lhs
if v == -1:
return Range(-e.end, -e.begin)
raise InferRangeError("donot know how to infer range for %s" % type(op))
def infer_range(e, range_dict, allow_unbind_var=True):
"""Infer the range of result e given range of variables.
Parameters
----------
expr : Expr
Input expression
range_dict : dict of Var->Range
The variables to be replaced.
allow_unbind_var: bool
Whether allow unbinded variables
"""
def combine_range(e, result_children):
if isinstance(e, _expr.ConstExpr):
return Range(e, e + 1)
elif isinstance(e, _expr.BinaryOpExpr):
return _combine_range_binary_op(e.op, result_children[0], result_children[1])
elif isinstance(e, _expr.Var):
if e in range_dict:
return range_dict[e]
else:
if allow_unbind_var:
return Range(e, e + 1)
else:
raise ValueError("Cannot find var %s in range_dict" % e.name)
else:
raise InferRangeError("cannot infer range for %s" % _expr_util.format_str(e))
return _expr_util.transform(e, combine_range)
def union_range(lhs, rhs):
if lhs is None:
return rhs
if rhs is None:
return lhs
begin = _op.min(lhs.begin, rhs.begin)
end = _op.max(rhs.end, lhs.end)
return Range(begin, end)
......@@ -22,7 +22,6 @@ def canonical_to_expr(c):
else:
return _expr.const(0)
class BinaryOp(object):
"""Base class of binary operator"""
def __call__(self, lhs, rhs):
......@@ -45,7 +44,6 @@ class AddOp(BinaryOp):
lhs[k] = v
return lhs
class SubOp(BinaryOp):
def format_str(self, lhs, rhs):
return '(%s - %s)' % (lhs, rhs)
......
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import domain as _dom
class Tensor(object):
......@@ -39,16 +40,17 @@ class Tensor(object):
"""
if self.inputs is not None:
return self.inputs
self.inputs = []
inputs = []
if self.expr:
def collect(e):
if isinstance(e, _expr.TensorReadExpr):
self.inputs.append(e.tensor)
inputs.append(e.tensor)
_expr_util.visit(self.expr, collect)
self.inputs = set(inputs)
return self.inputs
def infer_input_domains(self, out_domain):
"""Infer the input domains of each domain given output domains
def infer_input_domains(self, out_domain, inputs):
"""Infer the input domains of each domain in given inputs list.
Parameters
----------
......@@ -64,7 +66,26 @@ class Tensor(object):
index_domains = {
self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
}
def collect(e):
if isinstance(e, _expr.TensorReadExpr):
self.inputs.append(e.tensor)
_expr_util.visit(self.expr, collect)
iset = {}
for t in inputs:
assert t in self.input_tensors()
iset[t] = []
def prepare(e):
if isinstance(e, _expr.ReduceExpr):
rd = e.rdom
for i in range(len(rd.domain)):
index_domains[rd.index[i]] = rd.domain[i]
elif isinstance(e, _expr.TensorReadExpr):
if e.tensor in iset:
iset[e.tensor].append(e)
_expr_util.visit(self.expr, prepare)
result = {}
for k, v in iset.items():
dm = [None] * len(v[0].indices)
for e in v:
for i, idx in enumerate(e.indices):
dm[i] = _dom.union_range(
dm[i], _dom.infer_range(idx, index_domains, allow_unbind_var=False))
result[k] = dm
return result
import tvm
def test_range_infer():
x = tvm.Var('x')
y = tvm.Var('y')
t = tvm.Var('t')
z = x + y + t
zr = tvm.infer_range(z, {x: tvm.Range(10, 20), y : tvm.Range(10, 11)})
assert str(zr) == "((t0 + 20), (t0 + 30))"
def test_tensor_dom_infer():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
rd = tvm.RDom(tvm.Range(A.shape[1]))
C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]))
cdom = [tvm.Range(0, 10), tvm.Range(1, 11)]
tdom = C.infer_input_domains(cdom, inputs=[T])[T]
assert str(tdom[0]) == "(0, 10)"
if __name__ == "__main__":
test_range_infer()
test_tensor_dom_infer()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment