diff --git a/HalideIR b/HalideIR index 24a7c0357a6a8db5db782d320aad7f706ebe8507..7f1d811972bccc26f651ea2289d88bcadea9fe9f 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 24a7c0357a6a8db5db782d320aad7f706ebe8507 +Subproject commit 7f1d811972bccc26f651ea2289d88bcadea9fe9f diff --git a/Makefile b/Makefile index d2f8bd71b4b783f77137d07488ec8616f7035269..7daddbd955af578fc83fa9a7d488c1edcaa52f1b 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,3 @@ -export CXX=g++ export LDFLAGS = -pthread -lm export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\ -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index b2e92f20b45552a1d972bced0c3f6aeedf8e69cb..cc22e168782a8d04d168fb6aab49ff6350c78f94 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -1,11 +1,9 @@ from __future__ import absolute_import as _abs from ._ctypes._api import NodeBase, register_node -from . import function as _func from . import make as _make class Stmt(NodeBase): - def __repr__(self): - return _func.format_str(self) + pass @register_node class LetStmt(Stmt): diff --git a/src/c_api/c_api_ir.cc b/src/c_api/c_api_ir.cc index 291e600952414082206d7c823d794203b03fc36a..81f20c15db96b685d17d33185c675865b7d639b7 100644 --- a/src/c_api/c_api_ir.cc +++ b/src/c_api/c_api_ir.cc @@ -56,6 +56,23 @@ TVM_REGISTER_API(_make_Allocate) args.at(4)); }); +TVM_REGISTER_API(_make_LetStmt) +.set_body([](const ArgStack& args, RetValue *ret) { + + if (args.size() == 3) { + *ret = LetStmt::make(args.at(0), + args.at(1), + args.at(2)); + } else { + CHECK_EQ(args.size(), 5); + *ret = LetStmt::make(args.at(0), + args.at(1), + args.at(2), + args.at(3), + args.at(4)); + } + }); + // make from two arguments #define REGISTER_MAKE1(Node) \ TVM_REGISTER_API(_make_## Node) \ @@ -109,7 +126,6 @@ REGISTER_MAKE3(Select); REGISTER_MAKE3(Ramp); REGISTER_MAKE2(Broadcast); REGISTER_MAKE3(Let); -REGISTER_MAKE3(LetStmt); REGISTER_MAKE2(AssertStmt); REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(Store); diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 6271d773e8e6ac4c3874da42dfd550a885a963ce..97a7e0c14f4feb132ae17eb6369af32a01ad4511 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -97,6 +97,9 @@ class APIVariantValue { inline operator T() const { if (type_id == kNull) return T(); CHECK_EQ(type_id, kNodeHandle); + // use dynamic RTTI for safety + CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get())) + << "wrong type specified"; return T(sptr); } inline operator Expr() const { diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index e89479c8f141716a5bb2be28632420d85d9a631f..4ae1ad35f60a5c018b066837b446cfdc8e38f3b3 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -18,6 +18,15 @@ def test_ir(): stmt = tvm.make.Evaluate(z) assert isinstance(stmt, tvm.stmt.Evaluate) +def test_let(): + x = tvm.Var('x') + y = tvm.Var('y') + stmt = tvm.make.LetStmt( + x, 10, tvm.make.Evaluate(x + 1), y, "stride") + assert stmt.attr_of_node == y + print(stmt) + + def test_basic(): a = tvm.Var('a') b = tvm.Var('b') @@ -28,10 +37,10 @@ def test_array(): a = tvm.convert([1,2,3]) def test_stmt(): + x = tvm.make.Evaluate(0) tvm.make.For(tvm.Var('i'), 0, 1, tvm.stmt.For.Serial, 0, - tvm.make.Evaluate(0)) - + x) if __name__ == "__main__": @@ -40,3 +49,4 @@ if __name__ == "__main__": test_ir() test_basic() test_stmt() + test_let()