Skip to content
Snippets Groups Projects
Commit bda95817 authored by tqchen's avatar tqchen
Browse files

checkin tensor

parent fc4ba796
No related branches found
No related tags found
No related merge requests found
......@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from .op import *
from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
......@@ -79,7 +79,7 @@ class Var(Expr):
optional name to the var.
"""
def __init__(self, name=None):
if name is None: name = 'i'
if name is None: name = 'index'
self.name = _name.NameManager.current.get(name)
......@@ -100,7 +100,7 @@ class BinaryOpExpr(Expr):
def children(self):
return (self.lhs, self.rhs)
class UnaryOpExpr(Expr):
"""Unary operator expression."""
def __init__(self, op, src):
......
......@@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import op as _op
from . import tensor as _tensor
def expr_with_new_children(e, children):
"""Returns same expr as e but with new children
......@@ -75,6 +76,8 @@ def format_str(expr):
return str(e.value)
elif isinstance(e, _expr.Var):
return e.name
elif isinstance(e, _tensor.TensorReadExpr):
return "%s(%s)" % (e.tensor.name, ','.join(result_children))
else:
raise TypeError("Do not know how to handle type " + str(type(e)))
return transform(expr, make_str)
......
from __future__ import absolute_import as _abs
from . import expr as _expr
class TensorReadExpr(_expr.Expr):
def __init__(self, tensor, indices):
self.tensor = tensor
self.indices = indices
def children(self):
return self.indices
class Tensor(object):
def __init__(self, ndim, fcompute=None, name=None):
self.ndim = ndim
if fcompute:
arg_names = fcompute.func_code.co_varnames
assert(len(arg_names) == ndim)
self.dim_index = [_expr.Var(n) for n in arg_names]
self.expr = fcompute(*self.dim_index)
else:
self.expr = None
self.dim_index = None
shape_name = '_shape'
if name: shape_name = name + shape_name
self.shape = tuple(_expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
self.name = name if name else "TensorObj"
def __call__(self, *indices):
if len(indices) != self.ndim:
raise ValueError("Need to provide %d index in tensor slice" % self.ndim)
return TensorReadExpr(self, indices)
import tvm
from tvm import expr
def test_bind():
x = tvm.Var('x')
......
import tvm
def test_tensor():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k))
print(tvm.format_str(T.expr))
if __name__ == "__main__":
test_tensor()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment