Skip to content
Snippets Groups Projects
Commit 70d93028 authored by tqchen's avatar tqchen
Browse files

Keep up with changes of NodeRef

parent 2fc12dcd
No related branches found
No related tags found
No related merge requests found
Subproject commit bf96f8af0dfd1f79d258c7c1506f9ded932b94a9
Subproject commit eb2f7d604a611318fc685172847bcf5ba2fcf835
......@@ -95,13 +95,13 @@ class RDomainNode : public Node {
RDomainNode(Array<Var> index, Domain domain)
: index(index), domain(domain) {
}
const char* type_key() const override {
return "RDomain";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("index", &index);
v->Visit("domain", &domain);
}
static constexpr const char* _type_key = "RDomain";
TVM_DECLARE_NODE_TYPE_INFO(RDomainNode);
};
inline const RDomainNode* RDomain::operator->() const {
......
......@@ -6,7 +6,7 @@
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h>
#include <tvm/ir_functor.h>
#include <unordered_map>
#include "./expr.h"
......@@ -16,7 +16,7 @@ namespace ir {
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new IRNode.
* This enables easy extensions of possible new Node.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
......@@ -44,9 +44,9 @@ class IRMutator {
/*! \brief destructor */
virtual ~IRMutator() {}
/*! \brief functor type of expr mutation */
using FMutateExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRMutator*)>;
using FMutateExpr = IRFunctor<Expr(const NodeRef&, const Expr&, IRMutator*)>;
/*! \brief functor type of stmt mutation */
using FMutateStmt = IRFunctor<Stmt(const IRNodeRef&, const Stmt&, IRMutator*)>;
using FMutateStmt = IRFunctor<Stmt(const NodeRef&, const Stmt&, IRMutator*)>;
/*! \return internal vtable of expr */
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
......
......@@ -9,7 +9,7 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_node.h>
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <vector>
#include "./expr.h"
......
......@@ -15,7 +15,7 @@ namespace ir {
* \brief a base class for visitor to iterative traverse the IR
*
* This IRVisitor is implemented via IRFunctor
* This enables extensions of possible new IRNode.
* This enables extensions of possible new Node.
*
* \sa IRFunctor, PostOrderVisit
*/
......@@ -24,14 +24,14 @@ class IRVisitor {
/*!
* \brief recursively visit an IR node
*/
virtual void Visit(const IRNodeRef& node) {
virtual void Visit(const NodeRef& node) {
static const FVisit& f = vtable();
if (node.defined()) f(node, this);
}
/*! \brief destructor */
virtual ~IRVisitor() {}
/*! \brief functor type of visitor */
using FVisit = IRFunctor<void(const IRNodeRef&, IRVisitor*)>;
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
};
......@@ -42,7 +42,7 @@ class IRVisitor {
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit);
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
} // namespace ir
} // namespace tvm
......
......@@ -23,9 +23,6 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */
ComputeOpNode() {}
const char* type_key() const final {
return "ComputeOp";
}
size_t num_outputs() const final {
return 1;
}
......@@ -43,6 +40,9 @@ class ComputeOpNode : public OperationNode {
std::string name,
Array<Var> dim_var,
Expr body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
};
......
......@@ -62,6 +62,10 @@ class ScheduleNode : public Node {
const char* type_key() const final {
return "Schedule";
}
const uint32_t type_index() const final {
static uint32_t tidx = TypeKey2Index(type_key());
return tidx;
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
......
......@@ -46,14 +46,15 @@ class DimSplitNode : public SplitNode {
Expr factor;
/*! \brief constructor */
DimSplitNode() {}
const char* type_key() const final {
return "DimSplit";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("factor", &factor);
}
static Split make(Var var, Expr factor);
static constexpr const char* _type_key = "DimSplit";
TVM_DECLARE_NODE_TYPE_INFO(DimSplitNode);
};
// Implementations of inline functions
......
......@@ -104,9 +104,7 @@ class TensorNode : public FunctionBaseNode {
int value_index{0};
/*! \brief constructor */
TensorNode() {}
const char* type_key() const final {
return "Tensor";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("name", &name);
......@@ -125,6 +123,9 @@ class TensorNode : public FunctionBaseNode {
Type dtype,
Operation op,
int value_index);
static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_NODE_TYPE_INFO(TensorNode);
};
/*!
......
......@@ -9,5 +9,6 @@
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"
#include "./operation.h"
#endif // TVM_TVM_H_
......@@ -26,9 +26,9 @@ TVM_REGISTER_API(_format_str)
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
auto& sptr = args.at(0).sptr;
if (sptr->is_type<TensorNode>()) {
if (dynamic_cast<const TensorNode*>(sptr.get())) {
os << args.at(0).operator Tensor();
} else if (sptr->is_type<RDomainNode>()) {
} else if (dynamic_cast<const RDomainNode*>(sptr.get())) {
os << args.at(0).operator RDomain();
} else if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
os << args.at(0).operator Expr();
......
......@@ -22,7 +22,7 @@ namespace {
using namespace Halide::Internal;
// const expr
inline Expr ReturnSelfExpr(const IRNodeRef&, const Expr& e, IRMutator*) {
inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) {
return e;
}
......
......@@ -12,9 +12,9 @@ namespace {
// visitor to implement apply
class IRApplyVisit : public IRVisitor {
public:
explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {}
explicit IRApplyVisit(std::function<void(const NodeRef&)> f) : f_(f) {}
void Visit(const IRNodeRef& node) final {
void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
IRVisitor::Visit(node);
......@@ -22,13 +22,13 @@ class IRApplyVisit : public IRVisitor {
}
private:
std::function<void(const IRNodeRef&)> f_;
std::function<void(const NodeRef&)> f_;
std::unordered_set<const Node*> visited_;
};
} // namespace
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
......@@ -42,7 +42,7 @@ namespace {
using namespace Halide::Internal;
void NoOp(const IRNodeRef& n, IRVisitor* v) {
void NoOp(const NodeRef& n, IRVisitor* v) {
}
inline void VisitArray(Array<Expr> arr, IRVisitor* v) {
......
......@@ -5,21 +5,37 @@
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "./scope.h"
namespace tvm {
namespace ir {
namespace {
Stmt MakeCompute(const ComputeOpNode* op, const Array<Split>& splits) {
Tensor output;
std::vector<Expr> args(op->dim_var.size());
for (size_t i = 0; i < args.size(); ++i) {
args[i] = op->dim_var[i];
/*!
* \brief make nest loops given list of stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body The inner-most body of the loop
*/
Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
while (!nest.empty()) {
Stmt s = std::move(nest.back()); nest.pop_back();
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
n->body = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
}
Array<Expr> values{op->body};
Stmt stmt = Provide::make(output, values, args);
// add splits from ousside most to outsidemost to innermost
return stmt;
return body;
}
......
/*!
* Copyright (c) 2016 by Contributors
* \file scope.h
* \brief attribute scope data structure,
* defines attributes on current domain
*/
#ifndef TVM_PASS_SCOPE_H_
#define TVM_PASS_SCOPE_H_
#include <tvm/ir.h>
#include <unordered_map>
#include <vector>
#include <string>
namespace tvm {
namespace ir {
/*!
* \brief Attribute scope of Nodes in the IR.
* \tparam ValueType The value of of the scope.
*/
template<typename K, typename V>
class Scope {
public:
/*!
* \brief Push value to scope
* \param key the key to be pushed.
* \param v The value to be pushed.
*/
inline void Push(const K& key, V v) {
data_[key].emplace_back(v);
}
/*!
* \brief Pop value from scope.
* \param key the key to be poped
*/
inline void Pop(const K& key) {
auto& v = data_[key];
CHECK_NE(v.size(), 0);
v.pop_back();
}
/*!
* \brief Get value from the scope
* \param key the key to fetch.
* \return The value to be fetched.
*/
inline V operator[](const K& key) const {
const auto it = data_.find(key);
CHECK(it != data_.end() && it->second.size() != 0)
<< "cannot find value in scope";
return it->second.back();
}
private:
std::unordered_map<K, std::vector<V> > data_;
};
/*! \brief Attribute key for specific attribute */
struct AttrKey {
/*! \brief The node of the attribute */
NodeRef node;
/*! \brief The type key of the attribute. */
std::string type_key;
// overload operator ==
inline bool operator==(const AttrKey& other) const {
return node == other.node && type_key == other.type_key;
}
};
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::AttrKey> {
std::size_t operator()(const ::tvm::ir::AttrKey& k) const {
size_t lhs = k.node.hash();
size_t rhs = std::hash<std::string>()(k.type_key);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_PASS_SCOPE_H_
......@@ -17,7 +17,7 @@ namespace {
// global functor to get var definition from
struct FGetVarDef {
using FType = IRFunctor<VarExpr (const IRNodeRef&)>;
using FType = IRFunctor<VarExpr (const NodeRef&)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
}
......@@ -37,8 +37,8 @@ TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
});
struct FSetVarDef {
using FTypeExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
using FTypeStmt = IRFunctor<Stmt (const IRNodeRef&, VarExpr)>;
using FTypeExpr = IRFunctor<Expr (const NodeRef&, VarExpr)>;
using FTypeStmt = IRFunctor<Stmt (const NodeRef&, VarExpr)>;
static FTypeExpr& vtable_expr() { // NOLINT(*)
static FTypeExpr inst; return inst;
}
......@@ -69,7 +69,7 @@ class IRVerifySSA : public IRVisitor {
public:
bool is_ssa{true};
void Visit(const IRNodeRef& n) final {
void Visit(const NodeRef& n) final {
if (!is_ssa) return;
static auto& fget_var_def = FGetVarDef::vtable();
if (fget_var_def.can_dispatch(n)) {
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_node.h>
#include <tvm/ir_functor.h>
TEST(IRF, Basic) {
using namespace Halide::Internal;
......@@ -9,7 +9,7 @@ TEST(IRF, Basic) {
Var x("x");
auto z = x + 1;
IRFunctor<int(const IRNodeRef& n, int b)> f;
IRFunctor<int(const NodeRef& n, int b)> f;
LOG(INFO) << "x";
f.set_dispatch<Variable>([](const Variable* n, int b) {
return b;
......
......@@ -11,7 +11,7 @@ TEST(IRVisitor, CountVar) {
Var x("x"), y;
auto z = x + 1 + y + y;
ir::PostOrderVisit(z, [&n_var](const IRNodeRef& n) {
ir::PostOrderVisit(z, [&n_var](const NodeRef& n) {
if (n.as<Variable>()) ++n_var;
});
CHECK_EQ(n_var, 2);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment