Skip to content
Snippets Groups Projects
Commit f03483bf authored by Haichen Shen's avatar Haichen Shen
Browse files

checked split

parent 1a18f08e
No related branches found
No related tags found
No related merge requests found
......@@ -6,3 +6,4 @@ from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
from .domain import RDom, Range, infer_range
from .split import Split
......@@ -17,7 +17,7 @@ class Range(object):
self.extent = _expr_util.simplify(end - begin)
def is_value(self):
return isinstance(self.extent, _expr.ConstExpr) and self.extend.value == 1
return isinstance(self.extent, _expr.ConstExpr) and self.extent.value == 1
def __str__(self):
return "(%s, %s)" % (
......
......@@ -6,7 +6,7 @@ constant_canonical_key = '__constant__'
def canonical_to_expr(c):
elements = []
for k, v in sorted(c.items()):
if k == constant_canonical_key:
if k == constant_canonical_key and v != 0:
elements.append(_expr.const(v))
elif v == 0:
continue
......@@ -87,7 +87,7 @@ class DivOp(BinaryOp):
if isinstance(erhs, _expr.ConstExpr):
lhs = lhs.copy()
for k, v in lhs.items():
lhs[k] /= erhs.value
lhs[k] /= float(erhs.value)
return lhs
elhs = canonical_to_expr(lhs)
return {elhs / erhs: 1}
......
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import domain as _dom
from . import tensor as _tensor
class Split(object):
def __init__(self, dim, factor):
self.dim = dim
self.factor = factor
self.loop_index = _expr.Var('loop_index_%d_' % dim)
def infer_inner_domain(self, domain):
if isinstance(domain, _dom.RDom):
domain = domain.domain
assert self.dim < len(domain)
inner_domain = domain[:]
dim_out_range = domain[self.dim]
dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor
inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor)
return inner_domain
......@@ -25,7 +25,6 @@ class Tensor(object):
self.name = name if name else "TensorObj"
self.inputs = None
self.rdom = None
def __call__(self, *indices):
if len(indices) != self.ndim:
......
......@@ -26,6 +26,6 @@ def test_simplify():
assert tvm.format_str(tvm.simplify(e4)) == '0'
if __name__ == "__main__":
test_simplify()
test_basic()
test_bind()
test_simplify()
import tvm
def test_split_dom_infer():
A = tvm.Tensor(2, name='A')
rd = tvm.RDom(tvm.Range(A.shape[1]))
split1 = tvm.Split(0, 64)
split2 = tvm.Split(1, 64)
split3 = tvm.Split(0, 8)
dom = [tvm.Range(A.shape[0]), tvm.Range(A.shape[1])]
dom1 = split1.infer_inner_domain(dom)
dom2 = split2.infer_inner_domain(dom1)
dom3 = split3.infer_inner_domain(dom2)
dom4 = split3.infer_inner_domain(rd)
i1 = split1.loop_index.name
i2 = split2.loop_index.name
i3 = split3.loop_index.name
assert str(dom1) == "[((%s * 64), ((%s * 64) + 64)), (0, A_shape_1_0)]" % (i1, i1)
assert str(dom2) == "[((%s * 64), ((%s * 64) + 64)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i1, i2, i2)
assert str(dom3) == "[(((%s * 64) + (%s * 8)), (((%s * 64) + (%s * 8)) + 8)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i3, i1, i3, i2, i2)
assert str(dom4) == "[((%s * 8), ((%s * 8) + 8))]" % (i3, i3)
if __name__ == "__main__":
test_split_dom_infer()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment