Skip to content
Snippets Groups Projects
Commit adf4bfef authored by tqchen's avatar tqchen
Browse files

Enable attribute key in LetStmt

parent 38f03f1f
No related branches found
No related tags found
No related merge requests found
Subproject commit 24a7c0357a6a8db5db782d320aad7f706ebe8507
Subproject commit 7f1d811972bccc26f651ea2289d88bcadea9fe9f
export CXX=g++
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
......
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import function as _func
from . import make as _make
class Stmt(NodeBase):
def __repr__(self):
return _func.format_str(self)
pass
@register_node
class LetStmt(Stmt):
......
......@@ -56,6 +56,23 @@ TVM_REGISTER_API(_make_Allocate)
args.at(4));
});
TVM_REGISTER_API(_make_LetStmt)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.size() == 3) {
*ret = LetStmt::make(args.at(0),
args.at(1),
args.at(2));
} else {
CHECK_EQ(args.size(), 5);
*ret = LetStmt::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
}
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
......@@ -109,7 +126,6 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store);
......
......@@ -97,6 +97,9 @@ class APIVariantValue {
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
// use dynamic RTTI for safety
CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
<< "wrong type specified";
return T(sptr);
}
inline operator Expr() const {
......
......@@ -18,6 +18,15 @@ def test_ir():
stmt = tvm.make.Evaluate(z)
assert isinstance(stmt, tvm.stmt.Evaluate)
def test_let():
x = tvm.Var('x')
y = tvm.Var('y')
stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1), y, "stride")
assert stmt.attr_of_node == y
print(stmt)
def test_basic():
a = tvm.Var('a')
b = tvm.Var('b')
......@@ -28,10 +37,10 @@ def test_array():
a = tvm.convert([1,2,3])
def test_stmt():
x = tvm.make.Evaluate(0)
tvm.make.For(tvm.Var('i'), 0, 1,
tvm.stmt.For.Serial, 0,
tvm.make.Evaluate(0))
x)
if __name__ == "__main__":
......@@ -40,3 +49,4 @@ if __name__ == "__main__":
test_ir()
test_basic()
test_stmt()
test_let()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment