Skip to content
Snippets Groups Projects
Commit 062bb853 authored by tqchen's avatar tqchen
Browse files

Add in Array, fix most of IR

parent 622cee7a
No related branches found
No related tags found
No related merge requests found
Subproject commit 872099363b9f16a6cd4a4e8e46b9bd8dd1b861e9
Subproject commit 9070ac3697931ef5aeb8c373c23b2e8a2fec4627
......@@ -6,3 +6,4 @@ from ._ctypes._api import register_node
from . import expr
from . import stmt
from . import make
from . import domain
......@@ -5,7 +5,7 @@ from __future__ import absolute_import as _abs
import ctypes
import sys
from numbers import Number as Number
from numbers import Number, Integral
from .._base import _LIB
from .._base import c_str, py_str, string_types
......@@ -93,6 +93,27 @@ class NodeBase(object):
names.append(py_str(plist[i]))
return names
def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
return _function_internal._const(value, dtype)
def convert(value):
"""Convert a value to expression."""
if isinstance(value, Number):
return const(value)
elif isinstance(value, list):
value = [convert(x) for x in value]
return _function_internal._Array(*value)
else:
if not isinstance(value, NodeBase):
raise ValueError("don't know how to handle type %s" % type(value))
def _push_arg(arg):
a = ArgVariant()
......@@ -147,9 +168,16 @@ def _make_function(handle, name):
doc_str = doc_str % (desc, param_str)
arg_names = [py_str(arg_names[i]) for i in range(num_args.value)]
def func(*args, **kwargs):
def func(*args):
"""TVM function"""
for arg in args:
cargs = []
for x in args:
if isinstance(x, list):
cargs.append(convert(x))
else:
cargs.append(x)
for arg in cargs:
_push_arg(arg)
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
......
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node
class Array(NodeBase):
def __getitem__(self, i):
if i >= len(self):
raise IndexError("array index out ot range")
return _function_internal._ArrayGetItem(self, i)
def __len__(self):
return _function_internal._ArraySize(self)
def __repr__(self):
return '[' + (','.join(str(x) for x in self)) + ']'
......@@ -52,6 +52,10 @@ class CmpExpr(Expr):
class LogicalExpr(Expr):
pass
@register_node("Variable")
class Var(Expr):
pass
@register_node
class FloatImm(ConstExpr):
pass
......
......@@ -8,6 +8,7 @@ int32 = "int32"
float32 = "float32"
def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
if isinstance(value, _Integral):
dtype = 'int32'
......@@ -16,12 +17,26 @@ def const(value, dtype=None):
return _function_internal._const(value, dtype)
def _symbol(value):
def Var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : int
The data type
"""
return _function_internal._Var(name, dtype)
def convert(value):
"""Convert a value to expression."""
if isinstance(value, _Number):
return const(value)
elif isinstance(value, list):
value = [_symbol(x) for x in value]
value = [convert(x) for x in value]
return _function_internal._Array(*value)
else:
return value
......
......@@ -21,6 +21,10 @@ class ProducerConsumer(Stmt):
@register_node
class For(Stmt):
Serial = 0
Parallel = 1
Vectorized = 2
Unrolled = 3
pass
@register_node
......
......@@ -40,9 +40,46 @@ TVM_REGISTER_API(format_str)
os << args.at(0).operator Expr();
} else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) {
os << args.at(0).operator Stmt();
} else {
LOG(FATAL) << "don't know how to print input NodeBaseType";
}
*ret = os.str();
})
.add_argument("expr", "Node", "expression to be printed");
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle);
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
} // namespace tvm
......@@ -14,6 +14,30 @@ using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Variable::make(args.at(1), args.at(0));
});
TVM_REGISTER_API(_make_For)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = For::make(args.at(0),
args.at(1),
args.at(2),
static_cast<ForType>(args.at(3).operator int()),
static_cast<Halide::DeviceAPI>(args.at(4).operator int()),
args.at(5));
});
TVM_REGISTER_API(_make_Allocate)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Allocate::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
......@@ -33,7 +57,7 @@ using RetValue = APIVariantValue;
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
Expr a = args.at(0), b = args.at(1); \
......@@ -67,13 +91,12 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
// TODO(tqchen) Call;
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
// TODO(tqchen) For;
REGISTER_MAKE3(Store);
// TODO(tqchen) Provide;
// TODO(tqchen) Allocate;
REGISTER_MAKE3(Provide);
REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block);
......
......@@ -96,8 +96,10 @@ struct APIVariantValue {
}
inline operator Expr() const {
if (type_id == kNull) return Expr();
if (type_id == kLong) return Expr(operator int64_t());
if (type_id == kDouble) return Expr(operator double());
if (type_id == kLong) return Expr(operator int());
if (type_id == kDouble) {
return Expr(static_cast<float>(operator double()));
}
CHECK_EQ(type_id, kNodeHandle);
return Expr(sptr);
}
......
......@@ -19,7 +19,26 @@ def test_ir():
assert isinstance(stmt, tvm.stmt.Evaluate)
print tvm.format_str(stmt)
def test_basic():
a = tvm.Var('a')
b = tvm.Var('b')
c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
def test_array():
a = tvm.convert([1,2,3])
def test_stmt():
print tvm.make.Provide('a', [1,2,3], [1,2,3])
print tvm.make.For('a', 0, 1,
tvm.stmt.For.Serial, 0,
tvm.make.Evaluate(0))
if __name__ == "__main__":
test_const()
test_make()
test_ir()
test_basic()
test_stmt()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment