diff --git a/HalideIR b/HalideIR index 79a09d0fd60ae7fb6917a647832664212f7cc844..2a1001108b9112c4e594c456ffd364b57db10b6b 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 79a09d0fd60ae7fb6917a647832664212f7cc844 +Subproject commit 2a1001108b9112c4e594c456ffd364b57db10b6b diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 4165baff532e9498a7c37b91f19aa98ac12a6e3e..ca2333fe103f109dc4fa5d722b1749568170777a 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -1,11 +1,6 @@ -"""Init proptype of the TVM""" +"""C++ backend related python scripts""" from __future__ import absolute_import as _abs -from .op import * -from .expr import Var, const -from .expr_util import * -from .tensor import Tensor -from .domain import Range, RDom, infer_range -from .split import Split -from .buffer import Scope, Buffer -from .schedule import Schedule +from .function import * +from ._ctypes._api import register_node +from . import expr diff --git a/python/tvm/cpp/_base.py b/python/tvm/_base.py similarity index 100% rename from python/tvm/cpp/_base.py rename to python/tvm/_base.py diff --git a/python/tvm/cpp/_ctypes/__init__.py b/python/tvm/_ctypes/__init__.py similarity index 100% rename from python/tvm/cpp/_ctypes/__init__.py rename to python/tvm/_ctypes/__init__.py diff --git a/python/tvm/cpp/_ctypes/_api.py b/python/tvm/_ctypes/_api.py similarity index 100% rename from python/tvm/cpp/_ctypes/_api.py rename to python/tvm/_ctypes/_api.py diff --git a/python/tvm/cpp/_function_internal.py b/python/tvm/_function_internal.py similarity index 100% rename from python/tvm/cpp/_function_internal.py rename to python/tvm/_function_internal.py diff --git a/python/tvm/buffer.py b/python/tvm/buffer.py deleted file mode 100644 index 2921dcb408eefa5cd31e35867683c703db950056..0000000000000000000000000000000000000000 --- a/python/tvm/buffer.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import absolute_import as _abs -from . import expr as _expr -from . import expr_util as _expr_util -from . import var_name as _name - - -def enum(*sequential, **named): - enums = dict(zip(sequential, range(len(sequential))), **named) - return type('Enum', (), enums) - - -"""Scope defines the scope of a buffer - -Types ------ -Thread : thread private buffer (registers) -Shared : shared buffer within a thread block (shared memory) -Global : buffer in the global GPU RAM -""" -Scope = enum('Thread', 'Shared', 'Global') - - -class Buffer(object): - def __init__(self, scope, name=None): - self.scope = scope - buf_name = 'Buffer_' - if name: buf_name += name - self.name = _name.NameManager.current.get(buf_name) - self.shape = [] - self.offset_index = [] - - def reshape(self, domain): - for r in domain: - self.shape.append(r.extent) - self.offset_index.append(r.begin) - - def __call__(self, *global_index): - if len(global_index) != len(self.shape): - raise ValueError("Need to provide %d index in buffer slice" % len(self.shape)) - stride = [1] - for i in reversed(range(1, len(self.shape))): - stride.insert(0, self.shape[i] * stride[0]) - local_index = [] - for i in range(0, len(global_index)): - local_index.append(global_index[i] - self.offset_index[i]) - index = local_index[0] * stride[0] - for i in range(1, len(local_index)): - index = index + local_index[i] * stride[i] - index = _expr_util.simplify(index) - return _expr.TensorRefExpr(self, [index]) - - -class BufferManager(object): - def __init__(self): - self._buffer_map = {} - self._old_manager = None - - def get(self, tensor): - if tensor in self._buffer_map: - return self._buffer_map[tensor] - return None - - def bind(self, tensor, buf): - self._buffer_map[tensor] = buf - - def __enter__(self): - self._old_manager = BufferManager.current - BufferManager.current = self - return self - - def __exit__(self, ptype, value, trace): - assert self._old_manager - BufferManager.current = self._old_manager - -# initialize the default buffer manager -BufferManager.current = BufferManager() diff --git a/python/tvm/codegen.py b/python/tvm/codegen.py deleted file mode 100644 index c5d6d44f7c58825fa2de18f22341095d98cea4f3..0000000000000000000000000000000000000000 --- a/python/tvm/codegen.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import absolute_import as _abs -from . import buffer as _buffer -from . import expr as _expr -from . import expr_util as _expr_util - -def gen_code(expr): - """change expression to string. - - Parameters - ---------- - expr : Expr - Input expression - - Returns - ------- - s : str - The string representation of expr - """ - def make_str(e, result_children): - if isinstance(e, _expr.BinaryOpExpr): - return e.op.format_str(result_children[0], result_children[1]) - elif isinstance(e, _expr.UnaryOpExpr): - return e.op.format_str(result_children[0]) - elif isinstance(e, _expr.ConstExpr): - return str(e.value) - elif isinstance(e, _expr.Var): - return e.name - elif isinstance(e, _expr.TensorRefExpr): - buf = _buffer.BufferManager.current.get(e.tensor) - if buf: - return _expr_util.format_str(buf(*e.indices)) - return _expr_util.format_str(e.tensor(*e.indices, flatten=True)) - elif isinstance(e, _expr.ReduceExpr): - return e.op.format_reduce_stmt_str(result_children[0]) - else: - raise TypeError("Do not know how to handle type " + str(type(e))) - return _expr_util.transform(expr, make_str) - diff --git a/python/tvm/cpp/__init__.py b/python/tvm/cpp/__init__.py deleted file mode 100644 index c77d66fa4236095aa0589645fd8b1f058cf8cdde..0000000000000000000000000000000000000000 --- a/python/tvm/cpp/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""C++ backend related python scripts""" -from __future__ import absolute_import as _abs - -from .function import * -from ._ctypes._api import register_node -from . import expr -from . import domain diff --git a/python/tvm/cpp/domain.py b/python/tvm/cpp/domain.py deleted file mode 100644 index 9ebf105e9e5624d0cd6605418d93ba60073fd0fa..0000000000000000000000000000000000000000 --- a/python/tvm/cpp/domain.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import absolute_import as _abs -from ._ctypes._api import NodeBase, register_node -from . import _function_internal - -@register_node("RangeNode") -class Range(NodeBase): - pass - - -@register_node("ArrayNode") -class Array(NodeBase): - def __getitem__(self, i): - return _function_internal._ArrayGetItem(self, i) - - def __len__(self): - return _function_internal._ArraySize(self) diff --git a/python/tvm/cpp/expr.py b/python/tvm/cpp/expr.py deleted file mode 100644 index 0dd19853db9bec9b5c14a1573b056f08eb7d01df..0000000000000000000000000000000000000000 --- a/python/tvm/cpp/expr.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import absolute_import as _abs -from ._ctypes._api import NodeBase, register_node -from .function import binary_op - -class Expr(NodeBase): - def __add__(self, other): - return binary_op('+', self, other) - - def __radd__(self, other): - return self.__add__(other) - - def __sub__(self, other): - return binary_op('-', self, other) - - def __rsub__(self, other): - return binary_op('-', other, self) - - def __mul__(self, other): - return binary_op('*', self, other) - - def __rmul__(self, other): - return binary_op('*', other, self) - - def __div__(self, other): - return binary_op('/', self, other) - - def __rdiv__(self, other): - return binary_op('/', other, self) - - def __truediv__(self, other): - return self.__div__(other) - - def __rtruediv__(self, other): - return self.__rdiv__(other) - - def __neg__(self): - return self.__mul__(-1) - - -@register_node("VarNode") -class Var(Expr): - pass - -@register_node("IntNode") -class IntExpr(Expr): - pass - -@register_node("FloatNode") -class FloatExpr(Expr): - pass - -@register_node("UnaryOpNode") -class UnaryOpExpr(Expr): - pass - -@register_node("BinaryOpNode") -class BinaryOpExpr(Expr): - pass - -@register_node("ReduceNode") -class ReduceExpr(Expr): - pass - -@register_node("TensorReadNode") -class TensorReadExpr(Expr): - pass diff --git a/python/tvm/cpp/function.py b/python/tvm/cpp/function.py deleted file mode 100644 index 723f2d6123d15f4bb0ee649bbdc680bffb3efc63..0000000000000000000000000000000000000000 --- a/python/tvm/cpp/function.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import absolute_import as _abs -from numbers import Number as _Number -from ._ctypes._api import _init_function_module -from .import _function_internal - -int32 = 1 -float32 = 2 - -def Var(name="tindex", dtype=int32): - """Create a new variable with specified name and dtype - - Parameters - ---------- - name : str - The name - - dtype : int - The data type - """ - return _function_internal._Var(name, dtype) - - -def _symbol(value): - """Convert a value to expression.""" - if isinstance(value, _Number): - return constant(value) - elif isinstance(value, list): - value = [_symbol(x) for x in value] - return _function_internal._Array(*value) - else: - return value - - -def max(lhs, rhs): - """Max of two expressions - - Parameters - ---------- - lhs : Expr/number - The left operand - - rhs : Expr/number - The right operand - """ - return binary_op("max", lhs, rhs) - - -def min(lhs, rhs): - """Min of two expressions - - Parameters - ---------- - lhs : Expr/number - The left operand - - rhs : Expr/number - The right operand - """ - return binary_op("max", lhs, rhs) - - -_init_function_module("tvm.cpp") diff --git a/python/tvm/domain.py b/python/tvm/domain.py deleted file mode 100644 index cd9815b4daa028a73c0a1ad1e4d7ab137a3a2748..0000000000000000000000000000000000000000 --- a/python/tvm/domain.py +++ /dev/null @@ -1,107 +0,0 @@ -from __future__ import absolute_import as _abs -from . import expr as _expr -from . import expr_util as _expr_util -from . import op as _op - -class Range(object): - """Represent a range in one dimension. - """ - def __init__(self, begin, end=None): - if end is None: - end = begin - begin = _expr.const(0) - begin = _expr_util.simplify(_expr._symbol(begin)) - end = _expr_util.simplify(_expr._symbol(end)) - self.begin = begin - self.end = end - self.extent = _expr_util.simplify(end - begin) - - def is_value(self): - return isinstance(self.extent, _expr.ConstExpr) and self.extent.value == 1 - - def __str__(self): - return "(%s, %s)" % ( - _expr_util.format_str(self.begin), - _expr_util.format_str(self.end)) - - def __repr__(self): - return self.__str__() - - -class RangeInferError(ValueError): - pass - - -class RDom(object): - """Reduction Domain.""" - def __init__(self, domain): - if isinstance(domain, Range): - domain = [domain] - self.index = [] - self.domain = domain - for i in range(len(domain)): - self.index.append(_expr.Var("rd_index_%d_" % i)) - - -"""Use list of ranges as domain""" -Domain = list - - -def _combine_range_binary_op(op, lhs, rhs): - if op == _op.add: - return Range(lhs.begin + rhs.begin, lhs.end + rhs.end - 1) - elif op == _op.sub: - return Range(lhs.begin - rhs.end + 1, lhs.end - rhs.begin) - elif op == _op.mul: - v = None - if lhs.is_value(): - v = lhs.begin.value - e = rhs - elif rhs.is_value(): - v = rhs.begin.value - e = lhs - if v == -1: - return Range(-e.end, -e.begin) - raise InferRangeError("donot know how to infer range for %s" % type(op)) - - -def infer_range(e, range_dict, allow_unbind_var=True): - """Infer the range of result e given range of variables. - - Parameters - ---------- - expr : Expr - Input expression - - range_dict : dict of Var->Range - The variables to be replaced. - - allow_unbind_var: bool - Whether allow unbinded variables - """ - def combine_range(e, result_children): - if isinstance(e, _expr.ConstExpr): - return Range(e, e + 1) - elif isinstance(e, _expr.BinaryOpExpr): - return _combine_range_binary_op(e.op, result_children[0], result_children[1]) - elif isinstance(e, _expr.Var): - if e in range_dict: - return range_dict[e] - else: - if allow_unbind_var: - return Range(e, e + 1) - else: - raise ValueError("Cannot find var %s in range_dict" % e.name) - else: - raise InferRangeError("cannot infer range for %s" % _expr_util.format_str(e)) - return _expr_util.transform(e, combine_range) - - -def union_range(lhs, rhs): - if lhs is None: - return rhs - if rhs is None: - return lhs - begin = _op.min(lhs.begin, rhs.begin) - end = _op.max(rhs.end, lhs.end) - return Range(begin, end) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 90ee0f1030ab873d2b91e122843d3347d838f8ed..5835640799f2d4394d0c4a7e1e88e295a856986a 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -1,51 +1,34 @@ -"""Base class of symbolic expression""" from __future__ import absolute_import as _abs -from numbers import Number as _Number -from . import var_name as _name +from ._ctypes._api import NodeBase, register_node +from . import function as _func -__addop__ = None -__subop__ = None -__mulop__ = None -__divop__ = None - -class Expr(object): - """Base class of expression. - - Expression object should be in general immutable. - """ - - def children(self): - """get children of this expression. - - Returns - ------- - children : generator of children - """ - return () +class Expr(NodeBase): + def __repr__(self): + return _func.format_str(self) def __add__(self, other): - return BinaryOpExpr(__addop__, self, other) + return binary_op('+', self, other) def __radd__(self, other): return self.__add__(other) def __sub__(self, other): - return BinaryOpExpr(__subop__, self, other) + return binary_op('-', self, other) def __rsub__(self, other): - return BinaryOpExpr(__subop__, other, self) + return binary_op('-', other, self) def __mul__(self, other): - return BinaryOpExpr(__mulop__, self, other) + return binary_op('*', self, other) def __rmul__(self, other): - return BinaryOpExpr(__mulop__, other, self) + return binary_op('*', other, self) def __div__(self, other): - return BinaryOpExpr(__divop__, self, other) + return binary_op('/', self, other) def __rdiv__(self, other): - return BinaryOpExpr(__divop__, other, self) + return binary_op('/', other, self) def __truediv__(self, other): return self.__div__(other) @@ -57,80 +40,14 @@ class Expr(object): return self.__mul__(-1) -def _symbol(value): - """Convert a value to expression.""" - if isinstance(value, Expr): - return value - elif isinstance(value, _Number): - return ConstExpr(value) - else: - raise TypeError("type %s not supported" % str(type(other))) - - -class Var(Expr): - """Variable, is a symbolic placeholder. - - Each variable is uniquely identified by its address - Note that name alone is not able to uniquely identify the var. - - Parameters - ---------- - name : str - optional name to the var. - """ - def __init__(self, name=None): - if name is None: name = 'index' - self.name = _name.NameManager.current.get(name) - - -class ConstExpr(Expr): - """Constant expression.""" - def __init__(self, value): - assert isinstance(value, _Number) - self.value = value - - -class BinaryOpExpr(Expr): - """Binary operator expression.""" - def __init__(self, op, lhs, rhs): - self.op = op - self.lhs = _symbol(lhs) - self.rhs = _symbol(rhs) - - def children(self): - return (self.lhs, self.rhs) - - -class UnaryOpExpr(Expr): - """Unary operator expression.""" - def __init__(self, op, src): - self.op = op - self.src = _symbol(src) - - def children(self): - return (self.src,) - - -class ReduceExpr(Expr): - def __init__(self, op, src, rdom): - self.op = op - self.src = src - self.rdom = rdom - - def children(self): - return (self.src,) - - -class TensorRefExpr(Expr): - """Tensor reference expression, tensor[indices]""" - def __init__(self, tensor, indices): - self.tensor = tensor - self.indices = indices - - def children(self): - return self.indices +@register_node("IntImm") +class IntImm(Expr): + pass +@register_node("UIntImm") +class UIntImm(Expr): + pass -def const(value): - """Return a constant value""" - return ConstExpr(value) +@register_node("FloatImm") +class FloatImm(Expr): + pass diff --git a/python/tvm/expr_util.py b/python/tvm/expr_util.py deleted file mode 100644 index dd7b68c4abc49d89ebd9de7d8f2d9b9dd45a54ab..0000000000000000000000000000000000000000 --- a/python/tvm/expr_util.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Utilities to manipulate expression""" -from __future__ import absolute_import as _abs -from . import expr as _expr -from . import op as _op - -def expr_with_new_children(e, children): - """Returns same expr as e but with new children - - A shallow copy of e will happen if children differs from current children - - Parameters - ---------- - e : Expr - The input expression - - children : list of Expr - The new children - - Returns - ------- - new_e : Expr - Expression with the new children - """ - if children: - if isinstance(e, _expr.BinaryOpExpr): - return (e if children[0] == e.lhs and children[1] == e.rhs - else _expr.BinaryOpExpr(e.op, children[0], children[1])) - elif isinstance(e, _expr.UnaryOpExpr): - return e if children[0] == e.src else _expr.UnaryOpExpr(e.op, children[0]) - elif isinstance(e, _expr.TensorRefExpr): - return e if children == e.indices else _expr.TensorRefExpr(e.tensor, children) - elif isinstance(e, _expr.ReduceExpr): - return e if children[0] == e.src else _expr.ReduceExpr(e.op, children[0], e.rdom) - else: - raise TypeError("do not know how to handle Expr %s" % type(e)) - else: - return e - - -def transform(e, f): - """Apply f recursively to e and collect the resulr - - Parameters - ---------- - e : Expr - The input expression. - - f : function with signiture (e, ret_children) - ret_children is the result of transform from children - - Returns - ------- - result : return value of f - The final result of transformation. - """ - if not isinstance(e, _expr.Expr): - raise TypeError("Cannot handle type %s" % type(e)) - return f(e , [transform(c, f) for c in e.children()]) - - -def visit(e, f): - """Apply f to each element of e - - Parameters - ---------- - e : Expr - The input expression. - - f : function with signiture (e) - """ - assert isinstance(e, _expr.Expr) - for c in e.children(): - visit(c, f) - f(e) - - -def format_str(expr): - """change expression to string. - - Parameters - ---------- - expr : Expr - Input expression - - Returns - ------- - s : str - The string representation of expr - """ - def make_str(e, result_children): - if isinstance(e, _expr.BinaryOpExpr): - return e.op.format_str(result_children[0], result_children[1]) - elif isinstance(e, _expr.UnaryOpExpr): - return e.op.format_str(result_children[0]) - elif isinstance(e, _expr.ConstExpr): - return str(e.value) - elif isinstance(e, _expr.Var): - return e.name - elif isinstance(e, _expr.TensorRefExpr): - return "%s[%s]" % (e.tensor.name, ','.join(result_children)) - elif isinstance(e, _expr.ReduceExpr): - return e.op.format_reduce_str(result_children[0], e.rdom.domain) - else: - raise TypeError("Do not know how to handle type " + str(type(e))) - return transform(expr, make_str) - - -def simplify(expr): - """simplify expression - - Parameters - ---------- - expr : Expr - Input expression - - Returns - ------- - e : Expr - Simplified expression - """ - def canonical(e, result_children): - if isinstance(e, _expr.BinaryOpExpr): - return e.op.canonical(result_children[0], result_children[1]) - elif isinstance(e, _expr.UnaryOpExpr): - return e.op.canonical(result_children[0]) - elif isinstance(e, _expr.ConstExpr): - return {_op.const_canonical_key: e.value} - elif isinstance(e, _expr.Var): - return {e: 1} - else: - raise TypeError("Do not know how to handle type " + str(type(e))) - return _op.canonical_to_expr(transform(expr, canonical)) - - -def bind(expr, update_dict): - """Replace the variable in e by specification from kwarg - - Parameters - ---------- - expr : Expr - Input expression - - update_dict : dict of Var->Expr - The variables to be replaced. - - Examples - -------- - eout = bind(e, update_dict={v1: (x+1)} ) - """ - def replace(e, result_children): - if isinstance(e, _expr.Var) and e in update_dict: - return update_dict[e] - else: - return expr_with_new_children(e, result_children) - return transform(expr, replace) diff --git a/python/tvm/function.py b/python/tvm/function.py new file mode 100644 index 0000000000000000000000000000000000000000..c0fec46883bca6e3f133c41ff7be92ad18a02c2d --- /dev/null +++ b/python/tvm/function.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import as _abs +from numbers import Number as _Number, Integral as _Integral +from ._ctypes._api import _init_function_module +from .import _function_internal + +int32 = "int32" +float32 = "float32" + +def const(value, dtype=None): + if dtype is None: + if isinstance(value, _Integral): + dtype = 'int32' + else: + dtype = 'float32' + return _function_internal._const(value, dtype) + + +def _symbol(value): + """Convert a value to expression.""" + if isinstance(value, _Number): + return const(value) + elif isinstance(value, list): + value = [_symbol(x) for x in value] + return _function_internal._Array(*value) + else: + return value + +_init_function_module("tvm") diff --git a/python/tvm/cpp/libinfo.py b/python/tvm/libinfo.py similarity index 96% rename from python/tvm/cpp/libinfo.py rename to python/tvm/libinfo.py index 01a5f8f56f3a66337eb33f072a98a9682e4a4878..43679dd73194e72995808e7f0fedd0b24af3e7ee 100644 --- a/python/tvm/cpp/libinfo.py +++ b/python/tvm/libinfo.py @@ -13,7 +13,7 @@ def find_lib_path(): List of all found path to the libraries """ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - api_path = os.path.join(curr_path, '../../../lib/') + api_path = os.path.join(curr_path, '../../lib/') cmake_build_path = os.path.join(curr_path, '../../build/Release/') dll_path = [curr_path, api_path, cmake_build_path] if os.name == 'nt': diff --git a/python/tvm/op.py b/python/tvm/op.py deleted file mode 100644 index fa40c3ca17a09d2ab91308b333d9bdb03a0616fc..0000000000000000000000000000000000000000 --- a/python/tvm/op.py +++ /dev/null @@ -1,154 +0,0 @@ -from __future__ import absolute_import as _abs -from . import expr as _expr - -const_canonical_key = '__constant__' - -def canonical_to_expr(c): - elements = [] - for k, v in sorted(c.items()): - if k == const_canonical_key and v != 0: - elements.append(_expr.const(v)) - elif v == 0: - continue - elif v == 1: - elements.append(k) - else: - elements.append(k * v) - if elements: - expr = elements[0] - for i in range(1, len(elements)): - expr = expr + elements[i] - return expr - else: - return _expr.const(0) - -class BinaryOp(object): - """Base class of binary operator""" - def __call__(self, lhs, rhs): - return _expr.BinaryOpExpr(self, lhs, rhs) - - -class AddOp(BinaryOp): - def format_str(self, lhs, rhs): - return '(%s + %s)' % (lhs, rhs) - - def format_reduce_str(self, src, rd): - return "reduce_sum(%s, rdom=%s)" % (src, str(rd)) - - def format_reduce_stmt_str(self, src): - # a temporary hack for now - return "+ %s" % (src) - - def canonical(self, lhs, rhs): - lhs = lhs.copy() - for k, v in rhs.items(): - if k in lhs: - lhs[k] += v - else: - lhs[k] = v - return lhs - -class SubOp(BinaryOp): - def format_str(self, lhs, rhs): - return '(%s - %s)' % (lhs, rhs) - - def canonical(self, lhs, rhs): - lhs = lhs.copy() - for k, v in rhs.items(): - if k in lhs: - lhs[k] -= v - else: - lhs[k] = -v - return lhs - - -class MulOp(BinaryOp): - def format_str(self, lhs, rhs): - return '(%s * %s)' % (lhs, rhs) - - def canonical(self, lhs, rhs): - elhs = canonical_to_expr(lhs) - erhs = canonical_to_expr(rhs) - if isinstance(erhs, _expr.ConstExpr): - lhs = lhs.copy() - for k, v in lhs.items(): - lhs[k] *= erhs.value - return lhs - if isinstance(elhs, _expr.ConstExpr): - rhs = rhs.copy() - for k, v in rhs.items(): - rhs[k] *= elhs.value - return rhs - return {elhs * erhs: 1} - - -class DivOp(BinaryOp): - def format_str(self, lhs, rhs): - return '(%s / %s)' % (lhs, rhs) - - def canonical(self, lhs, rhs): - erhs = canonical_to_expr(rhs) - if isinstance(erhs, _expr.ConstExpr): - lhs = lhs.copy() - remove = [] - for k, v in lhs.items(): - if k == const_canonical_key: - lhs[k] = v / erhs.value - else: - lhs[k / erhs] = 1 - remove.append(k) - for k in remove: - del lhs[k] - return lhs - elhs = canonical_to_expr(lhs) - return {elhs / erhs: 1} - - -class MaxOp(BinaryOp): - def format_str(self, lhs, rhs): - return 'max(%s, %s)' % (lhs, rhs) - - def canonical(self, lhs, rhs): - diff = SubOp().canonical(lhs, rhs) - ediff = canonical_to_expr(diff) - if isinstance(ediff, _expr.ConstExpr): - return lhs if ediff.value >= 0 else rhs - return {MaxOp()(lhs, rhs): 1} - - -class MinOp(BinaryOp): - def format_str(self, lhs, rhs): - return 'min(%s, %s)' % (lhs, rhs) - - def canonical(self, lhs, rhs): - diff = SubOp().canonical(lhs, rhs) - ediff = canonical_to_expr(diff) - if isinstance(ediff, _expr.ConstExpr): - return rhs if ediff.value >= 0 else lhs - return {MinOp()(lhs, rhs): 1} - - -add = AddOp() -sub = SubOp() -mul = MulOp() -div = DivOp() -max = MaxOp() -min = MinOp() - -_expr.__addop__ = add -_expr.__subop__ = sub -_expr.__mulop__ = mul -_expr.__divop__ = div - - -def reduce_sum(expr, rdom): - return _expr.ReduceExpr(add, expr, rdom) - -def reduce_prod(expr, rdom): - return _expr.ReduceExpr(mul, expr, rdom) - -def reduce_min(expr, rdom): - return _expr.ReduceExpr(min, expr, rdom) - -def reduce_max(expr, rdom): - return _expr.ReduceExpr(max, expr, rdom) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py deleted file mode 100644 index d7a23c775315a27539ba5fafe141e55c011cae07..0000000000000000000000000000000000000000 --- a/python/tvm/schedule.py +++ /dev/null @@ -1,127 +0,0 @@ -from __future__ import absolute_import as _abs -from . import domain as _dom -from . import expr as _expr -from . import expr_util as _expr_util -from . import split as _split -from . import buffer as _buffer -from . import codegen as _gen - -start_point_key = '__start__' -TAB = ' ' - -class Schedule(object): - """SUnit defines the compute schedule of a tensor - - Parameters - ---------- - tensor: tensor - """ - def __init__(self, tensor, buffer=None): - self.tensor = tensor - self.buffer = buffer - self.parent = None - #self.children = [] - self.splits = [] - self.split_attach = {start_point_key: []} - self.implicit_splits = [_split.Split(i, 1) for i in range(tensor.ndim)] - if isinstance(tensor.expr, _expr.ReduceExpr): - for i in range(len(tensor.expr.rdom.domain)): - self.implicit_splits.append(_split.Split(i, 1, rdom=True)) - - def add_split(self, split): - self.splits.append(split) - self.split_attach[split] = [] - - def set_buffer(self, buf): - self.buffer = buf - - def attach(self, split, other): - other.parent = self - if split is None: - self.split_attach[start_point_key].append(other) - else: - self.split_attach[split].append(other) - - def infer_inner_domain(self, domain): - for split in self.splits: - domain = split.infer_inner_domain(domain) - return domain - - def realize(self, domain=None, indent=''): - - def realize_attach(lst): - attach_tensors = [sch.tensor for sch in lst] - attach_domains = self.tensor.infer_input_domains(domain, attach_tensors, red_domain=red_domain) - for sch in lst: - body.extend(sch.realize(attach_domains[sch.tensor], indent)) - - # init domain and red_domain - if domain is None: - domain = self.tensor.domain - red_domain = self.tensor.expr.rdom.domain if isinstance(self.tensor.expr, _expr.ReduceExpr) else None - - # init buffer shape - if self.buffer: - if self.buffer.scope == _buffer.Scope.Global: - self.buffer.reshape(self.tensor.domain) - else: - # don't handle shared buffer for now - self.buffer.reshape(domain) - _buffer.BufferManager.current.bind(self.tensor, self.buffer) - - body = [] - - if self.split_attach[start_point_key]: - realize_attach(self.split_attach[start_point_key]) - - # add loop conditions for splits - for split in self.splits: - if split.rdom: - red_domain = split.generate_loop_condition(red_domain, body, indent) - else: - domain = split.generate_loop_condition(domain, body, indent) - indent += TAB - if self.split_attach[split]: - realize_attach(self.split_attach[split]) - - # add implicit loop conditions - for split in self.implicit_splits: - if split.rdom: - red_domain = split.generate_loop_condition(red_domain, body, indent) - else: - domain = split.generate_loop_condition(domain, body, indent) - indent += TAB - - # add loop body - expr = self.tensor.expr - global_index = [r.begin for r in domain] - global_rdom_index = [r.begin for r in red_domain] if red_domain else [] - if expr is None: - if self.buffer: - lhs = self.buffer(*global_index) - rhs = self.tensor(*global_index, flatten=True) - body.append('%s%s = %s;' % (indent, _expr_util.format_str(lhs), _expr_util.format_str(rhs))) - else: - if self.buffer: - lhs = self.buffer(*global_index) - else: - lhs = self.tensor(*global_index, flatten=True) - - bind_dict = {} - for i in range(self.tensor.ndim): - bind_dict[self.tensor.dim_index[i]] = global_index[i] - if isinstance(expr, _expr.ReduceExpr): - for i in range(len(expr.rdom.domain)): - bind_dict[expr.rdom.index[i]] = global_rdom_index[i] - rhs = _expr_util.bind(expr, bind_dict) - body.append('%s%s = %s;' % (indent, _expr_util.format_str(lhs), _gen.gen_code(rhs))) - - # add right brackets - for split in self.implicit_splits: - indent = indent[:-len(TAB)] - body.append('%s}' % indent) - for split in self.splits: - indent = indent[:-len(TAB)] - body.append('%s}' % indent) - - return body diff --git a/python/tvm/split.py b/python/tvm/split.py deleted file mode 100644 index 2434724c5a9ca41df857dc3111438d2833977e1a..0000000000000000000000000000000000000000 --- a/python/tvm/split.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import absolute_import as _abs -from . import expr as _expr -from . import expr_util as _expr_util -from . import domain as _dom -from . import tensor as _tensor - - -class Split(object): - def __init__(self, dim, factor, name=None, rdom=False): - self.dim = dim - self.factor = factor - self.rdom = rdom - if name is None: - name = 'loop_index_%d_' % dim - self.loop_index = _expr.Var(name) - - def infer_inner_domain(self, out_domain): - assert self.dim < len(out_domain) - inner_domain = out_domain[:] - dim_out_range = out_domain[self.dim] - dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor - inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor) - return inner_domain - - def generate_loop_condition(self, out_domain, body, indent): - assert self.dim < len(out_domain) - loop_range = _dom.Range(out_domain[self.dim].extent / self.factor) - stmt = '%sfor (int %s = 0; %s < %s; %s += 1) {' % ( - indent, - self.loop_index.name, - self.loop_index.name, - _expr_util.format_str(loop_range.end), - self.loop_index.name) - body.append(stmt) - return self.infer_inner_domain(out_domain) diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py deleted file mode 100644 index 14b33c685c5fdae35d96681f71ace2fce4a201a5..0000000000000000000000000000000000000000 --- a/python/tvm/tensor.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import absolute_import as _abs -from . import expr as _expr -from . import expr_util as _expr_util -from . import domain as _dom -from . import var_name as _name - - -class Tensor(object): - def __init__(self, ndim, fcompute=None, name=None, shape=None): - self.ndim = ndim - if fcompute: - arg_names = fcompute.func_code.co_varnames - assert(len(arg_names) == ndim) - self.dim_index = [_expr.Var(n) for n in arg_names] - self.expr = fcompute(*self.dim_index) - if shape is None: - raise ValueError("argument shape need to be given for intermediate tensor") - self.shape = shape - else: - self.expr = None - self.dim_index = None - shape_name = '_shape' - if name: shape_name = name + shape_name - self.shape = shape if shape else tuple( - _expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim)) - - self.name = name if name else _name.NameManager.current.get("TensorObj") - self.inputs = None - - def __call__(self, *indices, **option): - if len(indices) != self.ndim: - raise ValueError("Need to provide %d index in tensor slice" % self.ndim) - if 'flatten' in option and option['flatten']: - stride = [1] - for i in reversed(range(1, len(indices))): - stride.insert(0, self.shape[i] * stride[0]) - index = indices[0] * stride[0] - for i in range(1, len(indices)): - index = index + indices[i] * stride[i] - index = _expr_util.simplify(index) - return _expr.TensorRefExpr(self, [index]) - return _expr.TensorRefExpr(self, indices) - - @property - def domain(self): - return _dom.Domain([_dom.Range(self.shape[i]) for i in range(self.ndim)]) - - def input_tensors(self): - """List of input tensors to this tensor. - - Returns - ------- - inputs : list of input tensors - """ - if self.inputs is not None: - return self.inputs - inputs = [] - if self.expr: - def collect(e): - if isinstance(e, _expr.TensorRefExpr): - inputs.append(e.tensor) - _expr_util.visit(self.expr, collect) - self.inputs = set(inputs) - return self.inputs - - def infer_input_domains(self, out_domain, inputs, red_domain=None): - """Infer the input domains of each domain in given inputs list. - - Parameters - ---------- - out_domain : list of Range - Domain of each dimension. - - red_domain : list of Range - Domain of reduction variables, if this tensor - this can only be specified if - self.expr finishes with an ReduceExpr, and we can schedule - over the last reduction that creates this tensor. - - Returns - ------- - in_domains: dict Tensor->Domain - """ - assert self.expr - assert len(out_domain) == len(self.dim_index) - index_domains = { - self.dim_index[i] : out_domain[i] for i in range(len(out_domain)) - } - - begin_expr = self.expr - if red_domain: - if not isinstance(self.expr, _expr.ReduceExpr): - raise ValueError("red_domain must work with tensor that stores a reduction") - rdom = self.expr.rdom - begin_expr = self.expr.src - assert len(red_domain) == len(rdom.index) - for i in range(len(red_domain)): - index_domains[rdom.index[i]] = red_domain[i] - - iset = {} - for t in inputs: - assert t in self.input_tensors() - iset[t] = [] - - def prepare(e): - if isinstance(e, _expr.ReduceExpr): - rd = e.rdom - for i in range(len(rd.domain)): - index_domains[rd.index[i]] = rd.domain[i] - elif isinstance(e, _expr.TensorRefExpr): - if e.tensor in iset: - iset[e.tensor].append(e) - _expr_util.visit(begin_expr, prepare) - result = {} - for k, v in iset.items(): - dm = [None] * len(v[0].indices) - for e in v: - for i, idx in enumerate(e.indices): - dm[i] = _dom.union_range( - dm[i], _dom.infer_range(idx, index_domains, allow_unbind_var=False)) - result[k] = dm - return result - - @property - def is_rtensor(self): - """Whether this tensor is a result of reduction. - - Returns - ------- - is_rtensor : Whether the tensor is RTensor - """ - return self.expr and isinstance(self.expr, _expr.ReduceExpr) diff --git a/python/tvm/var_name.py b/python/tvm/var_name.py deleted file mode 100644 index 57044e83b25d07b205f1407d575a97187857ac21..0000000000000000000000000000000000000000 --- a/python/tvm/var_name.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Name manager to make sure name is unique.""" -from __future__ import absolute_import as _abs - -class NameManager(object): - """NameManager to do automatic naming. - - User can also inherit this object to change naming behavior. - """ - current = None - - def __init__(self): - self._counter = {} - self._old_manager = None - - def get(self, hint): - """Get the canonical name for a symbol. - - This is default implementation. - When user specified a name, - the user specified name will be used. - - When user did not, we will automatically generate a - name based on hint string. - - Parameters - ---------- - hint : str - A hint string, which can be used to generate name. - - Returns - ------- - full_name : str - A canonical name for the user. - """ - if hint not in self._counter: - self._counter[hint] = 0 - name = '%s%d' % (hint, self._counter[hint]) - self._counter[hint] += 1 - return name - - def __enter__(self): - self._old_manager = NameManager.current - NameManager.current = self - return self - - def __exit__(self, ptype, value, trace): - assert self._old_manager - NameManager.current = self._old_manager - -# initialize the default name manager -NameManager.current = NameManager() diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index dd7c60da638c3f87e96a168605b5a2f047264aa1..7e9e32b336d0f9bc38ef02609a986a55a062e4e2 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -27,36 +27,6 @@ struct TVMAPIThreadLocalEntry { inline void SetReturn(ArgVariant* ret_val, int* ret_typeid); }; -namespace tvm { -inline std::string Type2String(const Type& t) { - std::ostringstream os; - os << t; - return os.str(); -} - -inline Type String2Type(std::string s) { - std::istringstream is(s); - halide_type_code_t code; - if (s.substr(0, 3) == "int") { - code = Type::Int; s = s.substr(3); - } else if (s.substr(0, 4) == "uint") { - code = Type::UInt; s = s.substr(4); - } else if (s.substr(0, 5) == "float") { - code = Type::Float; s = s.substr(5); - } else if (s.substr(0, 5) == "float") { - code = Type::Float; s = s.substr(5); - } else { - LOG(FATAL) << "unknown type " << s; - } - int bits, lanes = 0; - if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) { - LOG(FATAL) << "unknown type " << s; - } - return Type(code, bits, lanes); -} - -} - using namespace tvm; /*! \brief Thread local store that can be used to hold return values. */ @@ -86,7 +56,7 @@ struct APIAttrGetter : public AttrVisitor { if (skey == key) *ret = static_cast<int64_t>(value[0]); } void Visit(const char* key, Type* value) final { - if (skey == key) *ret = Type2String(value[0]); + if (skey == key) *ret = value[0]; } void Visit(const char* key, std::string* value) final { if (skey == key) *ret = value[0]; diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc index bff849de15936c4828d3d3a920cf6038387bf6c6..072892c730f386c7e82dd3352570cffbee6952a6 100644 --- a/src/c_api/c_api_function.cc +++ b/src/c_api/c_api_function.cc @@ -4,6 +4,7 @@ * \file c_api_impl.cc */ #include <tvm/expr.h> +#include <ir/IROperator.h> #include "./c_api_registry.h" namespace dmlc { @@ -12,7 +13,36 @@ DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg); namespace tvm { +using namespace Halide::Internal; + using ArgStack = const std::vector<APIVariantValue>; using RetValue = APIVariantValue; +TVM_REGISTER_API(_const) +.set_body([](const ArgStack& args, RetValue *ret) { + if (args.at(0).type_id == kLong) { + *ret = make_const(args.at(1), args.at(0).operator int64_t()); + } else if (args.at(0).type_id == kDouble) { + *ret = make_const(args.at(1), args.at(0).operator double()); + } else { + LOG(FATAL) << "only accept int or float"; + } + }) +.add_argument("src", "Number", "source number") +.add_argument("dtype", "str", "data type"); + +TVM_REGISTER_API(format_str) +.set_body([](const ArgStack& args, RetValue *ret) { + CHECK(args.at(0).type_id == kNodeHandle); + std::ostringstream os; + auto& sptr = args.at(0).sptr; + if (dynamic_cast<const BaseExprNode*>(sptr.get())) { + os << args.at(0).operator Expr(); + } else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) { + os << args.at(0).operator Stmt(); + } + *ret = os.str(); + }) +.add_argument("expr", "Node", "expression to be printed"); + } // namespace tvm diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 01d59b8d9c7851a5ed0c873f96ed8fe6321548af..c215c24732f06abb3847efc881853de20d2234b9 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -16,6 +16,33 @@ namespace tvm { +inline std::string Type2String(const Type& t) { + std::ostringstream os; + os << t; + return os.str(); +} + +inline Type String2Type(std::string s) { + std::istringstream is(s); + halide_type_code_t code; + if (s.substr(0, 3) == "int") { + code = Type::Int; s = s.substr(3); + } else if (s.substr(0, 4) == "uint") { + code = Type::UInt; s = s.substr(4); + } else if (s.substr(0, 5) == "float") { + code = Type::Float; s = s.substr(5); + } else if (s.substr(0, 5) == "float") { + code = Type::Float; s = s.substr(5); + } else { + LOG(FATAL) << "unknown type " << s; + } + int bits, lanes = 1; + if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) { + LOG(FATAL) << "unknown type " << s; + } + return Type(code, bits, lanes); +} + /*! \brief Variant container for API calls */ struct APIVariantValue { /*! \brief the type id */ @@ -57,6 +84,9 @@ struct APIVariantValue { this->sptr = ref.node_; return *this; } + inline APIVariantValue& operator=(const Type& value) { + return operator=(Type2String(value)); + } template<typename T, typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type> inline operator T() const { @@ -89,6 +119,9 @@ struct APIVariantValue { CHECK_EQ(type_id, kStr); return str; } + inline operator Type() const { + return String2Type(operator std::string()); + } }; // common defintiion of API function. diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index f674627f982ac40fdbd145411fb8ab4cd0785bed..4b5f7ee112af1e366245d98bf382178ec97c91c0 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -1,31 +1,9 @@ import tvm -def test_bind(): - x = tvm.Var('x') - y = x + 1 - z = tvm.bind(y, {x: tvm.const(10) + 9}) - assert tvm.format_str(z) == '((10 + 9) + 1)' - - -def test_basic(): - a = tvm.Var('a') - b = tvm.Var('b') - c = a + b - assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name) - -def test_simplify(): - a = tvm.Var('a') - b = tvm.Var('b') - e1 = a * (2 + 1) + b * 1 - e2 = a * (2 + 1) - b * 1 - e3 = tvm.max(a * 3.3 + 5, 3 + 3.3 * a) - e4 = a - a - assert tvm.format_str(tvm.simplify(e1)) == '((%s * 3) + %s)' % (a.name, b.name) - assert tvm.format_str(tvm.simplify(e2)) == '((%s * 3) + (%s * -1))' % (a.name, b.name) - assert tvm.format_str(tvm.simplify(e3)) == '((%s * 3.3) + 5)' % (a.name) - assert tvm.format_str(tvm.simplify(e4)) == '0' +def test_const(): + x = tvm.const(1) + assert x.type == 'int32' + assert isinstance(x, tvm.expr.IntImm) if __name__ == "__main__": - test_basic() - test_bind() - test_simplify() + test_const() diff --git a/tests/python/test_buffer.py b/tests/python/test_buffer.py deleted file mode 100644 index 26dc560500acf7baa7d4bd6b7060c86765b9880c..0000000000000000000000000000000000000000 --- a/tests/python/test_buffer.py +++ /dev/null @@ -1,14 +0,0 @@ -import tvm - -def test_buffer(): - buf = tvm.Buffer(tvm.Scope.Thread) - shape = [32, 16] - domain = [tvm.Range(v) for v in shape] - buf.reshape(domain) - x = tvm.Var('x') - y = tvm.Var('y') - assert tvm.format_str(buf(y, x)) == '%s[(%s + (%s * %s))]' % (buf.name, x.name, y.name, shape[1]) - - -if __name__ == '__main__': - test_buffer() diff --git a/tests/python/test_cpp.py b/tests/python/test_cpp.py deleted file mode 100644 index c43c1bbc16001a020b4789ca5063a9cccd8e3cde..0000000000000000000000000000000000000000 --- a/tests/python/test_cpp.py +++ /dev/null @@ -1,41 +0,0 @@ -from tvm import cpp as tvm - - -def test_basic(): - a = tvm.Var('a') - b = tvm.Var('b') - c = a + b - assert a == c.lhs - assert c.dtype == tvm.int32 - assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name) - - -def test_array(): - a = tvm.Var('a') - x = tvm.function._symbol([1,2,a]) - - -def assert_equal(x, y): - z = tvm.simplify(x - y) - assert isinstance(z, tvm.expr.IntExpr) - assert z.value == 0 - - -def test_simplify(): - a = tvm.Var('a') - b = tvm.Var('b') - e1 = a * (2 + 1) + b * 1 - e2 = a * (2 + 1) - b * 1 - e3 = tvm.max(a * 3 + 5, 3 + 3 * a) - e4 = a - a - - assert_equal(e1, a * 3 + b) - assert_equal(e2, a * 3 - b) - assert_equal(e3, a * 3 + 5) - assert_equal(e4, 0) - - -if __name__ == "__main__": - test_basic() - test_array() - test_simplify()