diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 29ffd079284da4c1b681389bd363f302caf65f8d..143aceeab6e32082986bbf60b6254a9803b81969 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -6,3 +6,4 @@ from .expr import Var, const from .expr_util import * from .tensor import Tensor from .domain import RDom, Range, infer_range +from .split import Split diff --git a/python/tvm/domain.py b/python/tvm/domain.py index 7b615ebb41420aca66abbade3708e4a623f65dc4..cd9815b4daa028a73c0a1ad1e4d7ab137a3a2748 100644 --- a/python/tvm/domain.py +++ b/python/tvm/domain.py @@ -17,7 +17,7 @@ class Range(object): self.extent = _expr_util.simplify(end - begin) def is_value(self): - return isinstance(self.extent, _expr.ConstExpr) and self.extend.value == 1 + return isinstance(self.extent, _expr.ConstExpr) and self.extent.value == 1 def __str__(self): return "(%s, %s)" % ( diff --git a/python/tvm/op.py b/python/tvm/op.py index bc782d229556c932bffb12b0cc35910737f42c74..1ea5d592674698dc3b8ca932a1e278a0b7b5e4ce 100644 --- a/python/tvm/op.py +++ b/python/tvm/op.py @@ -6,7 +6,7 @@ constant_canonical_key = '__constant__' def canonical_to_expr(c): elements = [] for k, v in sorted(c.items()): - if k == constant_canonical_key: + if k == constant_canonical_key and v != 0: elements.append(_expr.const(v)) elif v == 0: continue @@ -87,7 +87,7 @@ class DivOp(BinaryOp): if isinstance(erhs, _expr.ConstExpr): lhs = lhs.copy() for k, v in lhs.items(): - lhs[k] /= erhs.value + lhs[k] /= float(erhs.value) return lhs elhs = canonical_to_expr(lhs) return {elhs / erhs: 1} diff --git a/python/tvm/split.py b/python/tvm/split.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd070f6a56661a72f8790e02079b91035462e8c --- /dev/null +++ b/python/tvm/split.py @@ -0,0 +1,22 @@ +from __future__ import absolute_import as _abs +from . import expr as _expr +from . import domain as _dom +from . import tensor as _tensor + + +class Split(object): + def __init__(self, dim, factor): + self.dim = dim + self.factor = factor + self.loop_index = _expr.Var('loop_index_%d_' % dim) + + def infer_inner_domain(self, domain): + if isinstance(domain, _dom.RDom): + domain = domain.domain + assert self.dim < len(domain) + inner_domain = domain[:] + dim_out_range = domain[self.dim] + dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor + inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor) + return inner_domain + diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index 453dad12bf6f4a7b25369f0d4f837e55bc6dc7ad..67d3a3aa75906d47cf2c3f03847e04d7c21a7831 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -25,7 +25,6 @@ class Tensor(object): self.name = name if name else "TensorObj" self.inputs = None - self.rdom = None def __call__(self, *indices): if len(indices) != self.ndim: diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 53ff569eaf33bbc3ed96416a9308c6b32e552537..f674627f982ac40fdbd145411fb8ab4cd0785bed 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -26,6 +26,6 @@ def test_simplify(): assert tvm.format_str(tvm.simplify(e4)) == '0' if __name__ == "__main__": - test_simplify() test_basic() test_bind() + test_simplify() diff --git a/tests/python/test_split.py b/tests/python/test_split.py new file mode 100644 index 0000000000000000000000000000000000000000..652d3974d547dd56cb216f35b5970b1b56c6cc61 --- /dev/null +++ b/tests/python/test_split.py @@ -0,0 +1,24 @@ +import tvm + +def test_split_dom_infer(): + A = tvm.Tensor(2, name='A') + rd = tvm.RDom(tvm.Range(A.shape[1])) + split1 = tvm.Split(0, 64) + split2 = tvm.Split(1, 64) + split3 = tvm.Split(0, 8) + dom = [tvm.Range(A.shape[0]), tvm.Range(A.shape[1])] + dom1 = split1.infer_inner_domain(dom) + dom2 = split2.infer_inner_domain(dom1) + dom3 = split3.infer_inner_domain(dom2) + dom4 = split3.infer_inner_domain(rd) + i1 = split1.loop_index.name + i2 = split2.loop_index.name + i3 = split3.loop_index.name + assert str(dom1) == "[((%s * 64), ((%s * 64) + 64)), (0, A_shape_1_0)]" % (i1, i1) + assert str(dom2) == "[((%s * 64), ((%s * 64) + 64)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i1, i2, i2) + assert str(dom3) == "[(((%s * 64) + (%s * 8)), (((%s * 64) + (%s * 8)) + 8)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i3, i1, i3, i2, i2) + assert str(dom4) == "[((%s * 8), ((%s * 8) + 8))]" % (i3, i3) + + +if __name__ == "__main__": + test_split_dom_infer()