Skip to content
Snippets Groups Projects
Commit 9595a9c1 authored by tqchen's avatar tqchen
Browse files

Expose array to python

parent de2be97e
No related branches found
No related tags found
No related merge requests found
......@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from .function import *
from ._ctypes._api import register_node
from . import expr
from . import domain
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node("RangeNode")
class Range(NodeBase):
pass
@register_node("ArrayNode")
class Array(NodeBase):
def __getitem__(self, i):
return _function_internal._ArrayGetItem(self, i)
def __len__(self):
return _function_internal._ArraySize(self)
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from .function import binary_op
......@@ -40,6 +41,26 @@ class Expr(NodeBase):
class Var(Expr):
pass
@register_node("IntNode")
class IntExpr(Expr):
pass
@register_node("FloatNode")
class FloatExpr(Expr):
pass
@register_node("UnaryOpNode")
class UnaryOpExpr(Expr):
pass
@register_node("BinaryOpNode")
class BinaryOpExpr(Expr):
pass
@register_node("ReduceNode")
class ReduceExpr(Expr):
pass
@register_node("TensorReadNode")
class TensorReadExpr(Expr):
pass
......@@ -24,6 +24,9 @@ def _symbol(value):
"""Convert a value to expression."""
if isinstance(value, _Number):
return constant(value)
elif isinstance(value, list):
value = [_symbol(x) for x in value]
return _function_internal._Array(*value)
else:
return value
......
......@@ -61,6 +61,41 @@ TVM_REGISTER_API(Range)
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "end of the range");
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle);
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_TensorInput)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Tensor(
......
......@@ -57,7 +57,7 @@ struct APIVariantValue {
return *this;
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
......
......@@ -9,5 +9,15 @@ def test_basic():
assert c.dtype == tvm.int32
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
def test_array():
a = tvm.Var('a')
x = tvm.function._symbol([1,2,a])
print type(x)
print len(x)
print x[4]
if __name__ == "__main__":
test_basic()
test_array()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment