diff --git a/CMakeLists.txt b/CMakeLists.txt
index 60826a304491b02ac175cec9847a30a12e7624c4..e072ced6fda19af1df638ee6e100e3683c52b4d6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -193,7 +193,7 @@ if(USE_GRAPH_RUNTIME)
 endif(USE_GRAPH_RUNTIME)
 
 if(USE_LLVM)
-  find_spackage(LLVM CONFIG REQUIRED)
+  find_package(LLVM CONFIG REQUIRED)
   include_directories(${LLVM_INCLUDE_DIRS})
   add_definitions(${LLVM_DEFINITIONS})
   set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h
index f5c6b3d2a05f924fb6d17eed9bc64b787b1bfb74..df4b73bdd6f33326ae7573a27b501be38eddf4d3 100644
--- a/topi/include/topi/nn.h
+++ b/topi/include/topi/nn.h
@@ -47,7 +47,10 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
                         std::string tag = kElementWise) {
   return tvm::compute(
       t->shape,
-      [&](const tvm::Array<tvm::Var>& i) { return tvm::max(t(i), threshold); },
+      [&](const tvm::Array<tvm::Var>& i) {
+        auto threshold_const = tvm::make_const(t->dtype, threshold);
+        return tvm::max(t(i), threshold_const);
+      },
       name, tag);
 }
 
diff --git a/topi/include/topi/nn/flatten.h b/topi/include/topi/nn/flatten.h
index 20a9b2cbfe78717095da2ad200c6f4bf92c588bc..d9577be36e5d89118802fea972b82c0565c7a4b4 100644
--- a/topi/include/topi/nn/flatten.h
+++ b/topi/include/topi/nn/flatten.h
@@ -55,7 +55,7 @@ inline Tensor flatten(const Tensor& x,
     index.push_back(i);
     std::reverse(index.begin(), index.end());
     return x(index);
-  });
+  }, name, tag);
 }
 
 }  // namespace nn
diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py
index 1648de33bb98b793a6324c1b742fcc1ebbdd58af..d7875cdbe52147baa7cb2e711d959bd0b52750c7 100644
--- a/topi/python/topi/__init__.py
+++ b/topi/python/topi/__init__.py
@@ -24,3 +24,4 @@ from . import mali
 from . import testing
 from . import util
 from . import rocm
+from . import cpp
diff --git a/topi/tests/python_cpp/test_topi_basic.py b/topi/tests/python_cpp/test_topi_basic.py
index 1faba588b877448ffc026e1506fb6c2598c41c61..1057f746b004741012d52719ccbc5ab754c852f2 100644
--- a/topi/tests/python_cpp/test_topi_basic.py
+++ b/topi/tests/python_cpp/test_topi_basic.py
@@ -25,7 +25,12 @@ def test_ewise():
     test_apply(topi.cpp.log, "log")
     test_apply(topi.cpp.sqrt, "sqrt")
 
+def test_flatten_tag():
+    A = tvm.placeholder((3, 4), name='A')
+    B = topi.cpp.nn.flatten(A)
+    assert B.op.tag == topi.tag.INJECTIVE
 
 if __name__ == "__main__":
     test_util()
     test_ewise()
+    test_flatten_tag()
diff --git a/topi/tests/python_cpp/test_topi_relu.py b/topi/tests/python_cpp/test_topi_relu.py
index 7322f892551736dd5522aea5f5213992179135c1..f214266351210bcd11b71be64cdebdfc98b25ba6 100644
--- a/topi/tests/python_cpp/test_topi_relu.py
+++ b/topi/tests/python_cpp/test_topi_relu.py
@@ -5,9 +5,10 @@ import tvm
 import topi
 from topi.util import get_const_tuple
 
-def verify_relu(m, n):
-    A = tvm.placeholder((m, n), name='A')
+def verify_relu(m, n, dtype):
+    A = tvm.placeholder((m, n), name='A', dtype=dtype)
     B = topi.cpp.nn.relu(A)
+    assert B.dtype == dtype
 
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = a_np * (a_np > 0)
@@ -51,7 +52,8 @@ def verify_leaky_relu(m, alpha):
 
 
 def test_relu():
-    verify_relu(10, 128)
+    for dtype in ['float32', 'float64', 'int32', 'int16', 'int8', 'int64']:
+        verify_relu(10, 128, dtype)
 
 def test_leaky_relu():
     verify_leaky_relu(100, 0.1)