Skip to content
Snippets Groups Projects
Commit 1a18f08e authored by tqchen's avatar tqchen
Browse files

Fold RTensor into tensor

parent dcddd208
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ class Tensor(object):
self.name = name if name else "TensorObj"
self.inputs = None
self.rdom = None
def __call__(self, *indices):
if len(indices) != self.ndim:
......@@ -49,7 +50,7 @@ class Tensor(object):
self.inputs = set(inputs)
return self.inputs
def infer_input_domains(self, out_domain, inputs):
def infer_input_domains(self, out_domain, inputs, red_domain=None):
"""Infer the input domains of each domain in given inputs list.
Parameters
......@@ -57,6 +58,12 @@ class Tensor(object):
out_domain : list of Range
Domain of each dimension.
red_domain : list of Range
Domain of reduction variables, if this tensor
this can only be specified if
self.expr finishes with an ReduceExpr, and we can schedule
over the last reduction that creates this tensor.
Returns
-------
in_domains: dict Tensor->Domain
......@@ -66,6 +73,17 @@ class Tensor(object):
index_domains = {
self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
}
begin_expr = self.expr
if red_domain:
if not isinstance(self.expr, _expr.ReduceExpr):
raise ValueError("red_domain must work with tensor that stores a reduction")
rdom = self.expr.rdom
begin_expr = self.expr.src
assert len(red_domain) == len(rdom.index)
for i in range(len(red_domain)):
index_domains[rdom.index[i]] = red_domain[i]
iset = {}
for t in inputs:
assert t in self.input_tensors()
......@@ -79,7 +97,7 @@ class Tensor(object):
elif isinstance(e, _expr.TensorReadExpr):
if e.tensor in iset:
iset[e.tensor].append(e)
_expr_util.visit(self.expr, prepare)
_expr_util.visit(begin_expr, prepare)
result = {}
for k, v in iset.items():
dm = [None] * len(v[0].indices)
......@@ -89,3 +107,13 @@ class Tensor(object):
dm[i], _dom.infer_range(idx, index_domains, allow_unbind_var=False))
result[k] = dm
return result
@property
def is_rtensor(self):
"""Whether this tensor is a result of reduction.
Returns
-------
is_rtensor : Whether the tensor is RTensor
"""
return self.expr and isinstance(self.expr, _expr.ReduceExpr)
......@@ -11,14 +11,16 @@ def test_range_infer():
def test_tensor_dom_infer():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
rd = tvm.RDom(tvm.Range(A.shape[1]))
C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
T = tvm.Tensor(2, lambda i, j:
tvm.reduce_sum(A(i, rd.index[0]) * B(j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]))
C = tvm.Tensor(2, lambda i, j: T(i,j),
shape=(A.shape[0], B.shape[0]))
cdom = [tvm.Range(0, 10), tvm.Range(1, 11)]
tdom = C.infer_input_domains(cdom, inputs=[T])[T]
assert T.is_rtensor
assert str(tdom[0]) == "(0, 10)"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment