Skip to content
Snippets Groups Projects
Commit 162ed02c authored by tqchen's avatar tqchen
Browse files

Add new functor

parent 0153649e
No related branches found
No related tags found
No related merge requests found
Subproject commit bd94f8c8e41b46ae7ca69a3405aac7463a4e23d5
Subproject commit 7906ae1edea96e416e338ea21b8bc248d1d6411c
/*!
* Copyright (c) 2016 by Contributors
* \file ir_node.h
* \file ir.h
* \brief Additional high level nodes in the IR
*/
#ifndef TVM_IR_NODE_H_
#define TVM_IR_NODE_H_
#ifndef TVM_IR_H_
#define TVM_IR_H_
#include <ir/Expr.h>
#include <ir/IR.h>
......
......@@ -4,7 +4,7 @@
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir_node.h>
#include <tvm/ir.h>
#include <ir/IR.h>
#include <ir/IRPrinter.h>
#include <memory>
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_node.h>
TEST(IRF, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x");
auto z = x + 1;
IRFunctor<int(const IRNodeRef& n, int b)> f;
LOG(INFO) << "x";
f.set_dispatch<Variable>([](const Variable* n, int b) {
return b;
});
f.set_dispatch<Add>([](const Add* n, int b) {
return b + 2;
});
CHECK_EQ(f(x, 2), 2);
CHECK_EQ(f(z, 2), 4);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment