diff --git a/HalideIR b/HalideIR index bd94f8c8e41b46ae7ca69a3405aac7463a4e23d5..7906ae1edea96e416e338ea21b8bc248d1d6411c 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit bd94f8c8e41b46ae7ca69a3405aac7463a4e23d5 +Subproject commit 7906ae1edea96e416e338ea21b8bc248d1d6411c diff --git a/include/tvm/ir_node.h b/include/tvm/ir.h similarity index 94% rename from include/tvm/ir_node.h rename to include/tvm/ir.h index 10d099a158ac544e27ec4d0260f7846c7ae7287a..6995016f404cf824ae3ab068a253d3d5910bc861 100644 --- a/include/tvm/ir_node.h +++ b/include/tvm/ir.h @@ -1,10 +1,10 @@ /*! * Copyright (c) 2016 by Contributors - * \file ir_node.h + * \file ir.h * \brief Additional high level nodes in the IR */ -#ifndef TVM_IR_NODE_H_ -#define TVM_IR_NODE_H_ +#ifndef TVM_IR_H_ +#define TVM_IR_H_ #include <ir/Expr.h> #include <ir/IR.h> diff --git a/src/lang/ir_node.cc b/src/lang/ir.cc similarity index 98% rename from src/lang/ir_node.cc rename to src/lang/ir.cc index d704b2ee0402133f34d0c6f1ee5fb01fa5c1ef5c..4307387aaa13c25b2bc4e5dcfca190f74bb52d5e 100644 --- a/src/lang/ir_node.cc +++ b/src/lang/ir.cc @@ -4,7 +4,7 @@ */ #include <tvm/base.h> #include <tvm/expr.h> -#include <tvm/ir_node.h> +#include <tvm/ir.h> #include <ir/IR.h> #include <ir/IRPrinter.h> #include <memory> diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f07ce5a20297bb5477a38441e2f9bea3fe605cb --- /dev/null +++ b/tests/cpp/ir_functor_test.cc @@ -0,0 +1,28 @@ +#include <dmlc/logging.h> +#include <gtest/gtest.h> +#include <tvm/tvm.h> +#include <tvm/ir_node.h> + +TEST(IRF, Basic) { + using namespace Halide::Internal; + using namespace tvm; + Var x("x"); + auto z = x + 1; + + IRFunctor<int(const IRNodeRef& n, int b)> f; + LOG(INFO) << "x"; + f.set_dispatch<Variable>([](const Variable* n, int b) { + return b; + }); + f.set_dispatch<Add>([](const Add* n, int b) { + return b + 2; + }); + CHECK_EQ(f(x, 2), 2); + CHECK_EQ(f(z, 2), 4); +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +}