diff --git a/HalideIR b/HalideIR index f519848d972c67971b4cbf8c34070d5a5e3ede0d..cf6090aeaeb782d1daff54b0ca5c2c281d7008db 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit f519848d972c67971b4cbf8c34070d5a5e3ede0d +Subproject commit cf6090aeaeb782d1daff54b0ca5c2c281d7008db diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h index d6e9910ab1ee004401fc4212f63df5528f28796b..1532872397c3ceb7608612c45fe3a145879620cb 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/api_registry.h @@ -57,7 +57,7 @@ class EnvFuncNode : public Node { class EnvFunc : public NodeRef { public: EnvFunc() {} - explicit EnvFunc(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { return static_cast<EnvFuncNode*>(node_.get()); @@ -105,7 +105,7 @@ class TypedEnvFunc<R(Args...)> : public NodeRef { /*! \brief short hand for this function type */ using TSelf = TypedEnvFunc<R(Args...)>; TypedEnvFunc() {} - explicit TypedEnvFunc(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit TypedEnvFunc(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 54875bbbf474e02c4e999d98d32487cfd5676e38..fe0405264c516aeaf258142d6344849d07e9729b 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -38,7 +38,7 @@ class IntSet : public NodeRef { /*! \brief constructor */ IntSet() {} // constructor from not container. - explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit IntSet(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 7cd77a92d0dd7e9fd4f254429b6bc913acf343da..7071dad072146f6b0904b5ee73dff5e9c6bb25a3 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -136,7 +136,7 @@ class Attrs : public NodeRef { // normal constructor Attrs() {} // construct from shared ptr. - explicit Attrs(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit Attrs(NodePtr<Node> n) : NodeRef(n) {} /*! \return The attribute node */ const BaseAttrsNode* operator->() const { @@ -442,7 +442,7 @@ class AttrDocEntry { public: using TSelf = AttrDocEntry; - explicit AttrDocEntry(std::shared_ptr<AttrFieldInfoNode> info) + explicit AttrDocEntry(NodePtr<AttrFieldInfoNode> info) : info_(info) { } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { @@ -466,15 +466,15 @@ class AttrDocEntry { } private: - std::shared_ptr<AttrFieldInfoNode> info_; + NodePtr<AttrFieldInfoNode> info_; }; class AttrDocVisitor { public: template<typename T> AttrDocEntry operator()(const char* key, T* v) { - std::shared_ptr<AttrFieldInfoNode> info - = std::make_shared<AttrFieldInfoNode>(); + NodePtr<AttrFieldInfoNode> info + = make_node<AttrFieldInfoNode>(); info->name = key; info->type_info = TypeName<T>::value; fields_.push_back(AttrFieldInfo(info)); diff --git a/include/tvm/base.h b/include/tvm/base.h index c2d796b6002c00fce829caa0310babf8335201da..7104688aa169750f07b5e5c2cd5365a515f3c86c 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -8,7 +8,7 @@ #include <dmlc/logging.h> #include <dmlc/registry.h> -#include <tvm/node.h> +#include <tvm/node/node.h> #include <string> #include <memory> #include <functional> @@ -25,7 +25,7 @@ using ::tvm::AttrVisitor; class TypeName : public ::tvm::NodeRef { \ public: \ TypeName() {} \ - explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} \ + explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \ const NodeName* operator->() const { \ return static_cast<const NodeName*>(node_.get()); \ } \ @@ -48,7 +48,7 @@ std::string SaveJSON(const NodeRef& node); * * \return The shared_ptr of the Node. */ -std::shared_ptr<Node> LoadJSON_(std::string json_str); +NodePtr<Node> LoadJSON_(std::string json_str); /*! * \brief Load the node from json string. @@ -85,7 +85,7 @@ struct NodeFactoryReg { * If this is not empty then FGlobalKey * \return The created function. */ - using FCreate = std::function<std::shared_ptr<Node>(const std::string& global_key)>; + using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>; /*! * \brief Global key function, only needed by global objects. * \param node The node pointer. @@ -123,7 +123,7 @@ struct NodeFactoryReg { #define TVM_REGISTER_NODE_TYPE(TypeName) \ static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \ - .set_creator([](const std::string&) { return std::make_shared<TypeName>(); }) + .set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); }) #define TVM_STRINGIZE_DETAIL(x) #x diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 0f591299718e3482c7ea7f52c52509db40cde59f..5901a27fe1cee8ca4263357efc0dc930bda833f8 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -6,11 +6,11 @@ #ifndef TVM_BUFFER_H_ #define TVM_BUFFER_H_ -#include <tvm/container.h> #include <string> #include "base.h" #include "expr.h" +#include "node/container.h" namespace tvm { @@ -31,7 +31,7 @@ enum class AccessMask : int { class Buffer : public NodeRef { public: Buffer() {} - explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit Buffer(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief Return a new buffer that is equivalent with current one * but always add stride field. diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 5dc8320414104e0c60b5033dd95aef2a304f862d..7aafad4216e159a62a29227964eed23c98c92381 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -69,7 +69,7 @@ class TargetNode : public Node { class Target : public NodeRef { public: Target() {} - explicit Target(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit Target(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief Create a Target given a string @@ -241,7 +241,7 @@ class BuildConfigNode : public Node { class BuildConfig : public ::tvm::NodeRef { public: BuildConfig() {} - explicit BuildConfig(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} + explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} const BuildConfigNode* operator->() const { return static_cast<const BuildConfigNode*>(node_.get()); @@ -335,7 +335,7 @@ class GenericFuncNode; class GenericFunc : public NodeRef { public: GenericFunc() {} - explicit GenericFunc(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit GenericFunc(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief Set the default function implementaiton. diff --git a/include/tvm/channel.h b/include/tvm/channel.h index 28d9b5f7ce4abb2de9c4b0c742dfba9d3e710e98..051b57a194c46c98e123de6e813fd73e795024dd 100644 --- a/include/tvm/channel.h +++ b/include/tvm/channel.h @@ -17,7 +17,7 @@ class Channel : public NodeRef { public: /*! \brief default constructor */ Channel() {} - explicit Channel(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit Channel(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/expr.h b/include/tvm/expr.h index fb2233dacb69e37cea8f53833468211bccba39d3..a199d656caf87899d75666e5861114a0bfbff45c 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -76,7 +76,7 @@ class Var : public HalideIR::VarExpr { public: EXPORT explicit Var(const std::string& name_hint = "v", Type t = Int(32)) : VarExpr(name_hint, t) {} - explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {} + explicit Var(NodePtr<Node> n) : VarExpr(n) {} explicit Var(VarExpr v) : VarExpr(v) {} /*! * \brief Make a new copy of var with same type, append suffix @@ -107,7 +107,7 @@ class Range : public HalideIR::IR::Range { public: /*! \brief constructor */ Range() {} - explicit Range(std::shared_ptr<Node> n) : HalideIR::IR::Range(n) {} + explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {} /*! * \brief constructor by begin and end * \param begin The begin of the range. @@ -197,7 +197,7 @@ class IterVar : public NodeRef { // construct a new iter var without a domain IterVar() {} // construct from shared ptr. - explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit IterVar(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/ir.h b/include/tvm/ir.h index f73533439dbaa5c9ddcfefdcb13e04d2939efc61..b75d75c18182cf3dea9a779ad5709f3c42e8fe63 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -28,7 +28,7 @@ struct CommReducerNode; struct CommReducer : public NodeRef { CommReducer() {} - explicit CommReducer(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit CommReducer(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index a9845fdfc898c970eeb29d65118520f1afce61b7..85d2de75dd99db1aa45689d95704c0b60aa10996 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -6,7 +6,7 @@ #ifndef TVM_IR_FUNCTOR_EXT_H_ #define TVM_IR_FUNCTOR_EXT_H_ -#include <tvm/ir_functor.h> +#include "node/ir_functor.h" #include "ir.h" namespace tvm { diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 35c82e9f16c1c54bca422cfd47c1f470a48c8771..6b391caf4b5fd3a03e23c771e7c0c9489424e82a 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -6,10 +6,10 @@ #ifndef TVM_IR_MUTATOR_H_ #define TVM_IR_MUTATOR_H_ -#include <tvm/ir_functor.h> #include <unordered_map> #include "expr.h" #include "ir.h" +#include "node/ir_functor.h" namespace tvm { namespace ir { diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index cf20dfa1e9f373c13bbf15ffa2db261898ec3126..ab42cfc9625fc87b4076573df64369eed8851990 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -9,7 +9,6 @@ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ -#include <tvm/ir_functor.h> #include <arithmetic/Simplify.h> #include <unordered_map> #include <vector> diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index 4b2887b288859588e8522ff24695b90985a848e1..265ec0e56efbde10e360ba156bb7512436580652 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -6,8 +6,8 @@ #ifndef TVM_IR_VISITOR_H_ #define TVM_IR_VISITOR_H_ -#include <tvm/ir_functor.h> #include "ir.h" +#include "node/ir_functor.h" namespace tvm { namespace ir { diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index acb9813339f89da7f32b340a645a14e9b5cea944..8bd2b1ba84cf421219a79e3e8af34fc51cca0444 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -7,13 +7,13 @@ #ifndef TVM_LOWERED_FUNC_H_ #define TVM_LOWERED_FUNC_H_ -#include <tvm/container.h> #include <ir/FunctionBase.h> #include <string> #include "base.h" #include "expr.h" #include "tensor.h" +#include "node/container.h" namespace tvm { @@ -27,7 +27,7 @@ class LoweredFuncNode; class LoweredFunc : public FunctionRef { public: LoweredFunc() {} - explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {} + explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h new file mode 100644 index 0000000000000000000000000000000000000000..43adae27671c5756180b28c3ae733a249e7ad32b --- /dev/null +++ b/include/tvm/node/container.h @@ -0,0 +1,586 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/node/container.h + * \brief Array/Map container in the DSL graph. + */ +#ifndef TVM_NODE_CONTAINER_H_ +#define TVM_NODE_CONTAINER_H_ + +#include <type_traits> +#include <vector> +#include <initializer_list> +#include <unordered_map> +#include <utility> +#include <string> +#include "node.h" +#include "memory.h" + +namespace tvm { + +/*! \brief array node content in array */ +class ArrayNode : public Node { + public: + /*! \brief the data content */ + std::vector<NodePtr<Node> > data; + + void VisitAttrs(AttrVisitor* visitor) final { + // Visitor to array have no effect. + } + + static constexpr const char* _type_key = "Array"; + TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node); +}; + +/*! \brief map node content */ +class MapNode : public Node { + public: + void VisitAttrs(AttrVisitor* visitor) final { + // Visitor to map have no effect. + } + // hash function + struct Hash { + size_t operator()(const NodePtr<Node>& n) const { + return std::hash<Node*>()(n.get()); + } + }; + // comparator + struct Equal { + bool operator()( + const NodePtr<Node>& a, + const NodePtr<Node>& b) const { + return a.get() == b.get(); + } + }; + + /*! \brief The corresponding conatiner type */ + using ContainerType = std::unordered_map< + NodePtr<Node>, + NodePtr<Node>, + Hash, Equal>; + + /*! \brief the data content */ + ContainerType data; + + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node); +}; + + +/*! \brief specialized map node with string as key */ +class StrMapNode : public Node { + public: + void VisitAttrs(AttrVisitor* visitor) final { + // Visitor to map have no effect. + } + /*! \brief The corresponding conatiner type */ + using ContainerType = std::unordered_map< + std::string, + NodePtr<Node> >; + + /*! \brief the data content */ + ContainerType data; + + static constexpr const char* _type_key = "StrMap"; + TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node); +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template<typename Converter, + typename TIter> +class IterAdapter { + public: + explicit IterAdapter(TIter iter) : iter_(iter) {} + inline IterAdapter& operator++() { // NOLINT(*) + ++iter_; + return *this; + } + inline IterAdapter& operator++(int) { // NOLINT(*) + ++iter_; + return *this; + } + inline IterAdapter operator+(int offset) const { // NOLINT(*) + return IterAdapter(iter_ + offset); + } + inline bool operator==(IterAdapter other) const { + return iter_ == other.iter_; + } + inline bool operator!=(IterAdapter other) const { + return !(*this == other); + } + inline const typename Converter::ResultType operator*() const { + return Converter::convert(*iter_); + } + + private: + TIter iter_; +}; + +/*! + * \brief Array container of NodeRef in DSL graph. + * Array implements copy on write semantics, which means array is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const acces, use Set to mutate the content. + * \tparam T The content NodeRef type. + */ +template<typename T, + typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type > +class Array : public NodeRef { + public: + /*! + * \brief default constructor + */ + Array() { + node_ = make_node<ArrayNode>(); + } + /*! + * \brief move constructor + * \param other source + */ + Array(Array<T> && other) { // NOLINT(*) + node_ = std::move(other.node_); + } + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array<T> &other) { // NOLINT(*) + node_ = other.node_; + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(NodePtr<Node> n) : NodeRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template<typename IterType> + Array(IterType begin, IterType end) { + assign(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Array(std::initializer_list<T> init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector<T>& init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array<T>& operator=(Array<T> && other) { + node_ = std::move(other.node_); + return *this; + } + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array<T>& operator=(const Array<T> & other) { + node_ = other.node_; + return *this; + } + /*! + * \brief reset the array to content from iterator. + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template<typename IterType> + void assign(IterType begin, IterType end) { + auto n = make_node<ArrayNode>(); + for (IterType it = begin; it != end; ++it) { + n->data.push_back((*it).node_); + } + node_ = std::move(n); + } + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + inline const T operator[](size_t i) const { + return T(static_cast<const ArrayNode*>(node_.get())->data[i]); + } + /*! \return The size of the array */ + inline size_t size() const { + if (node_.get() == nullptr) return 0; + return static_cast<const ArrayNode*>(node_.get())->data.size(); + } + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + inline ArrayNode* CopyOnWrite() { + if (node_.get() == nullptr || !node_.unique()) { + NodePtr<ArrayNode> n = make_node<ArrayNode>(); + n->data = static_cast<ArrayNode*>(node_.get())->data; + NodePtr<Node>(std::move(n)).swap(node_); + } + return static_cast<ArrayNode*>(node_.get()); + } + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + inline void push_back(const T& item) { + ArrayNode* n = this->CopyOnWrite(); + n->data.push_back(item.node_); + } + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + inline void Set(size_t i, const T& value) { + ArrayNode* n = this->CopyOnWrite(); + n->data[i] = value.node_; + } + /*! \return whether array is empty */ + inline bool empty() const { + return size() == 0; + } + /*! \brief specify container node */ + using ContainerType = ArrayNode; + + struct Ptr2NodeRef { + using ResultType = T; + static inline T convert(const NodePtr<Node>& n) { + return T(n); + } + }; + using iterator = IterAdapter<Ptr2NodeRef, + std::vector<NodePtr<Node> >::const_iterator>; + + using reverse_iterator = IterAdapter< + Ptr2NodeRef, + std::vector<NodePtr<Node> >::const_reverse_iterator>; + + /*! \return begin iterator */ + inline iterator begin() const { + return iterator(static_cast<const ArrayNode*>(node_.get())->data.begin()); + } + /*! \return end iterator */ + inline iterator end() const { + return iterator(static_cast<const ArrayNode*>(node_.get())->data.end()); + } + /*! \return rbegin iterator */ + inline reverse_iterator rbegin() const { + return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rbegin()); + } + /*! \return rend iterator */ + inline reverse_iterator rend() const { + return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rend()); + } +}; + +/*! + * \brief Map container of NodeRef->NodeRef in DSL graph. + * Map implements copy on write semantics, which means map is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const acces, use Set to mutate the content. + * \tparam K The key NodeRef type. + * \tparam V The value NodeRef type. + */ +template<typename K, + typename V, + typename = typename std::enable_if< + std::is_base_of<NodeRef, K>::value || + std::is_base_of<std::string, K>::value >::type, + typename = typename std::enable_if<std::is_base_of<NodeRef, V>::value>::type> +class Map : public NodeRef { + public: + /*! + * \brief default constructor + */ + Map() { + node_ = make_node<MapNode>(); + } + /*! + * \brief move constructor + * \param other source + */ + Map(Map<K, V> && other) { // NOLINT(*) + node_ = std::move(other.node_); + } + /*! + * \brief copy constructor + * \param other source + */ + Map(const Map<K, V> &other) { // NOLINT(*) + node_ = other.node_; + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Map(NodePtr<Node> n) : NodeRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template<typename IterType> + Map(IterType begin, IterType end) { + assign(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief constructor from vector + * \param init The vector + */ + template<typename Hash, typename Equal> + Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map<K, V>& operator=(Map<K, V> && other) { + node_ = std::move(other.node_); + return *this; + } + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map<K, V>& operator=(const Map<K, V> & other) { + node_ = other.node_; + return *this; + } + /*! + * \brief reset the array to content from iterator. + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template<typename IterType> + void assign(IterType begin, IterType end) { + NodePtr<MapNode> n = make_node<MapNode>(); + for (IterType i = begin; i != end; ++i) { + n->data.emplace(std::make_pair(i->first.node_, + i->second.node_)); + } + node_ = std::move(n); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + inline const V operator[](const K& key) const { + return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_)); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + inline const V at(const K& key) const { + return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_)); + } + /*! \return The size of the array */ + inline size_t size() const { + if (node_.get() == nullptr) return 0; + return static_cast<const MapNode*>(node_.get())->data.size(); + } + /*! \return The size of the array */ + inline size_t count(const K& key) const { + if (node_.get() == nullptr) return 0; + return static_cast<const MapNode*>(node_.get())->data.count(key.node_); + } + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + inline MapNode* CopyOnWrite() { + if (node_.get() == nullptr || !node_.unique()) { + NodePtr<MapNode> n = make_node<MapNode>(); + n->data = static_cast<const MapNode*>(node_.get())->data; + NodePtr<Node>(std::move(n)).swap(node_); + } + return static_cast<MapNode*>(node_.get()); + } + /*! + * \brief set the Map. + * \param key The index key. + * \param value The value to be setted. + */ + inline void Set(const K& key, const V& value) { + MapNode* n = this->CopyOnWrite(); + n->data[key.node_] = value.node_; + } + + /*! \return whether array is empty */ + inline bool empty() const { + return size() == 0; + } + /*! \brief specify container node */ + using ContainerType = MapNode; + + struct Ptr2NodeRef { + using ResultType = std::pair<K, V>; + static inline ResultType convert(const std::pair< + NodePtr<Node>, + NodePtr<Node> >& n) { + return std::make_pair(K(n.first), V(n.second)); + } + }; + + using iterator = IterAdapter< + Ptr2NodeRef, MapNode::ContainerType::const_iterator>; + + /*! \return begin iterator */ + inline iterator begin() const { + return iterator(static_cast<const MapNode*>(node_.get())->data.begin()); + } + /*! \return end iterator */ + inline iterator end() const { + return iterator(static_cast<const MapNode*>(node_.get())->data.end()); + } + /*! \return begin iterator */ + inline iterator find(const K& key) const { + return iterator(static_cast<const MapNode*>(node_.get())->data.find(key.node_)); + } +}; + +// specialize of string map +template<typename V, typename T1, typename T2> +class Map<std::string, V, T1, T2> : public NodeRef { + public: + // for code reuse + Map() { + node_ = make_node<StrMapNode>(); + } + Map(Map<std::string, V> && other) { // NOLINT(*) + node_ = std::move(other.node_); + } + Map(const Map<std::string, V> &other) { // NOLINT(*) + node_ = other.node_; + } + explicit Map(NodePtr<Node> n) : NodeRef(n) {} + template<typename IterType> + Map(IterType begin, IterType end) { + assign(begin, end); + } + Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + + template<typename Hash, typename Equal> + Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + Map<std::string, V>& operator=(Map<std::string, V> && other) { + node_ = std::move(other.node_); + return *this; + } + Map<std::string, V>& operator=(const Map<std::string, V> & other) { + node_ = other.node_; + return *this; + } + template<typename IterType> + void assign(IterType begin, IterType end) { + auto n = make_node<StrMapNode>(); + for (IterType i = begin; i != end; ++i) { + n->data.emplace(std::make_pair(i->first, + i->second.node_)); + } + node_ = std::move(n); + } + inline const V operator[](const std::string& key) const { + return V(static_cast<const StrMapNode*>(node_.get())->data.at(key)); + } + inline const V at(const std::string& key) const { + return V(static_cast<const StrMapNode*>(node_.get())->data.at(key)); + } + inline size_t size() const { + if (node_.get() == nullptr) return 0; + return static_cast<const StrMapNode*>(node_.get())->data.size(); + } + inline size_t count(const std::string& key) const { + if (node_.get() == nullptr) return 0; + return static_cast<const StrMapNode*>(node_.get())->data.count(key); + } + inline StrMapNode* CopyOnWrite() { + if (node_.get() == nullptr || !node_.unique()) { + NodePtr<StrMapNode> n = make_node<StrMapNode>(); + n->data = static_cast<const StrMapNode*>(node_.get())->data; + NodePtr<Node>(std::move(n)).swap(node_); + } + return static_cast<StrMapNode*>(node_.get()); + } + inline void Set(const std::string& key, const V& value) { + StrMapNode* n = this->CopyOnWrite(); + n->data[key] = value.node_; + } + inline bool empty() const { + return size() == 0; + } + using ContainerType = StrMapNode; + + struct Ptr2NodeRef { + using ResultType = std::pair<std::string, V>; + static inline ResultType convert(const std::pair< + std::string, + NodePtr<Node> >& n) { + return std::make_pair(n.first, V(n.second)); + } + }; + + using iterator = IterAdapter< + Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>; + + /*! \return begin iterator */ + inline iterator begin() const { + return iterator(static_cast<const StrMapNode*>(node_.get())->data.begin()); + } + /*! \return end iterator */ + inline iterator end() const { + return iterator(static_cast<const StrMapNode*>(node_.get())->data.end()); + } + /*! \return begin iterator */ + inline iterator find(const std::string& key) const { + return iterator(static_cast<const StrMapNode*>(node_.get())->data.find(key)); + } +}; + +} // namespace tvm +#endif // TVM_NODE_CONTAINER_H_ diff --git a/include/tvm/node/ir_functor.h b/include/tvm/node/ir_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..293bec75bbf517758e1a4c6953a19700fccee2ce --- /dev/null +++ b/include/tvm/node/ir_functor.h @@ -0,0 +1,254 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/node/ir_functor.h + * \brief Defines the IRFunctor data structures. + */ +#ifndef TVM_NODE_IR_FUNCTOR_H_ +#define TVM_NODE_IR_FUNCTOR_H_ + +#include <dmlc/logging.h> +#include <string> +#include <vector> +#include <type_traits> +#include <functional> +#include "node.h" +#include "../runtime/registry.h" + +namespace tvm { +/*! + * \brief A dynamical dispatched functor on NodeRef in the first argument. + * + * \code + * IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr; + * tostr.set_dispatch<Add>([](const Add* op, std::string prefix) { + * return prefix + "Add"; + * }); + * tostr.set_dispatch<IntImm>([](const IntImm* op) { + * return prefix + "IntImm" + * }); + * + * Expr x = make_const(1); + * Expr y = x + x; + * // dispatch to IntImm, outputs "MyIntImm" + * LOG(INFO) << tostr(x, "My"); + * // dispatch to IntImm, outputs "MyAdd" + * LOG(INFO) << tostr(y, "My"); + * \endcode + * + * \tparam FType function signiture + * This type if only defined for FType with function signiture + */ +template<typename FType> +class IRFunctor; + +template<typename R, typename ...Args> +class IRFunctor<R(const NodeRef& n, Args...)> { + private: + using Function = std::function<R (const NodeRef&n, Args...)>; + using TSelf = IRFunctor<R (const NodeRef& n, Args...)>; + /*! \brief internal function table */ + std::vector<Function> func_; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! + * \brief Whether the functor can dispatch the corresponding Node + * \param n The node to be dispatched + * \return Whether dispatching function is registered for n's type. + */ + inline bool can_dispatch(const NodeRef& n) const { + uint32_t type_index = n.type_index(); + return type_index < func_.size() && func_[type_index] != nullptr; + } + /*! + * \brief invoke the functor , dispatch on type of n + * \param n The Node argument + * \param args The additional arguments + * \return The result. + */ + inline R operator()(const NodeRef& n, Args... args) const { + uint32_t type_index = n.type_index(); + CHECK(type_index < func_.size() && + func_[type_index] != nullptr) + << "IRFunctor calls un-registered function on type " + << Node::TypeIndex2Key(type_index); + return func_[type_index](n, std::forward<Args>(args)...); + } + /*! + * \brief set the dispacher for type TNode + * \param f The function to be set. + * \tparam TNode the type of Node to be dispatched. + * \return reference to self. + */ + template<typename TNode> + inline TSelf& set_dispatch(Function f) { // NOLINT(*) + uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + if (func_.size() <= tindex) { + func_.resize(tindex + 1, nullptr); + } + CHECK(func_[tindex] == nullptr) + << "Dispatch for " << Node::TypeIndex2Key(tindex) + << " is already set"; + func_[tindex] = f; + return *this; + } + /*! + * \brief set the dispacher for type TNode + * This allows f to used detailed const Node pointer to replace NodeRef + * + * \param f The function to be set. + * \tparam TNode the type of Node to be dispatched. + * \return reference to self. + */ + template<typename TNode> + inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*) + Function fun = [f](const NodeRef& n, Args... args) { + return f(static_cast<const TNode*>(n.node_.get()), + std::forward<Args>(args)...); + }; + return this->set_dispatch<TNode>(fun); + } + /*! + * \brief unset the dispacher for type TNode + * + * \tparam TNode the type of Node to be dispatched. + * \return reference to self. + */ + template<typename TNode> + inline TSelf& clear_dispatch() { // NOLINT(*) + uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; + func_[tindex] = nullptr; + return *this; + } +}; + +#define TVM_REGISTER_VAR_DEF(ClsName) \ + static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName + +/*! + * \brief Useful macro to set IRFunctor dispatch in a global static field. + * + * \code + * // Use IRFunctor to implement IRPrinter similar to Visitor Pattern. + * // vtable allows easy patch in of new Node types, without changing + * // interface of IRPrinter. + * + * class IRPrinter { + * public: + * std::ostream& stream; + * // the dispatch function. + * void print(Expr e) { + * const static FType& f = *vtable(); + * f(e, this); + * } + * + * using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>; + * // function to return global function table + * static FType& vtable(); + * }; + * + * // in cpp/cc file + * IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*0 + * static FType inst; return inst; + * } + * + * TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) + * .set_dispatch<Add>([](const Add* n, IRPrinter* p) { + * p->print(n->a); + * p->stream << '+' + * p->print(n->b); + * }); + * + * + * \endcode + * + * \param ClsName The name of the class + * \param FField The static function that returns a singleton of IRFunctor. + */ +#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ + TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ + ClsName::FField() + + /*! + * \brief A container for a list of callbacks. All callbacks are invoked when + * the object is destructed. + */ +class IRFunctorCleanList { + public: + ~IRFunctorCleanList() { + for (auto &f : clean_items) { + f(); + } + } + + void append(std::function<void()> func) { + clean_items.push_back(func); + } + + private: + std::vector< std::function<void()> > clean_items; +}; + +/*! +* \brief A wrapper around IRFunctor that will record calls to set_dispatch +* and make a corresponding call to clear_dispatch when the last copy of +* the IRFunctorStaticRegistry is destructed. When assigned to a static variable, +* this can be used by NNVM and other libraries to unregister callbacks when +* the library is unloaded. This prevents crashes when the underlying IRFunctor +* is destructed as it will no longer contain std::function instances allocated +* by a library that has been unloaded. +*/ +template<typename FType> +class IRFunctorStaticRegistry; + +template<typename R, typename ...Args> +class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> { + private: + IRFunctor<R(const NodeRef& n, Args...)> *irf_; + std::shared_ptr<IRFunctorCleanList> free_list; + + using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>; + + public: + IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) { + irf_ = irf; + free_list = std::make_shared<IRFunctorCleanList>(); + } + + template<typename TNode> + inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*) + irf_->template set_dispatch<TNode>(f); + auto irf_copy = irf_; + free_list.get()->append([irf_copy] { + irf_copy->template clear_dispatch<TNode>(); + }); + return *this; + } +}; + +/*! +* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows +* the compiler to deduce the template types. +*/ +template<typename R, typename ...Args> +IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry( + IRFunctor<R(const NodeRef& n, Args...)> *irf) { + return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf); +} + +#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ + static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName + +/*! +* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry. +* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of +* TVM_STATIC_IR_FUNCTOR. +*/ +#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \ + TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ + MakeIRFunctorStaticRegistry(&ClsName::FField()) + +} // namespace tvm +#endif // TVM_NODE_IR_FUNCTOR_H_ diff --git a/include/tvm/node/memory.h b/include/tvm/node/memory.h new file mode 100644 index 0000000000000000000000000000000000000000..c0f791eb597b4e1bdf95c4b3b77a7ad2b20fe741 --- /dev/null +++ b/include/tvm/node/memory.h @@ -0,0 +1,59 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/node/memory.h + * \brief Node memory management. + */ +#ifndef TVM_NODE_MEMORY_H_ +#define TVM_NODE_MEMORY_H_ + +#include "node.h" + +namespace tvm { +/*! + * \brief Allocate a node object. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The NodePtr to the allocated object. + */ +template<typename T, typename... Args> +inline NodePtr<T> make_node(Args&&... args); + +// Detail implementations after this +// +// The current design allows swapping the +// allocator pattern when necessary. +// +// Possible future allocator optimizations: +// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) +// - Thread-local object pools: one pool per size and alignment requirement. +// - Can specialize by type of object to give the specific allocator to each object. +// +template<typename T> +class SimpleNodeAllocator { + public: + template<typename... Args> + static T* New(Args&&... args) { + return new T(std::forward<Args>(args)...); + } + static NodeBase::FDeleter Deleter() { + return Deleter_; + } + + private: + static void Deleter_(NodeBase* ptr) { + delete static_cast<T*>(ptr); + } +}; + +template<typename T, typename... Args> +inline NodePtr<T> make_node(Args&&... args) { + using Allocator = SimpleNodeAllocator<T>; + static_assert(std::is_base_of<NodeBase, T>::value, + "make_node can only be used to create NodeBase"); + T* node = Allocator::New(std::forward<Args>(args)...); + node->deleter_ = Allocator::Deleter(); + return NodePtr<T>(node); +} + +} // namespace tvm +#endif // TVM_NODE_MEMORY_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h new file mode 100644 index 0000000000000000000000000000000000000000..d726b1dab66061c1156418ca4c42089fc6002609 --- /dev/null +++ b/include/tvm/node/node.h @@ -0,0 +1,295 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/node/node.h + * \brief Node system data structure. + */ +#ifndef TVM_NODE_NODE_H_ +#define TVM_NODE_NODE_H_ + +#include <string> +#include <vector> +#include <type_traits> +#include "base/Type.h" +#include "../runtime/node_base.h" +#include "../runtime/c_runtime_api.h" + +namespace tvm { +using HalideIR::Type; +// forward declaration +class Node; +class NodeRef; + +namespace runtime { +// forward declaration +class NDArray; +} // namespace runtime + +/*! + * \brief Visitor class to each node content. + * The content is going to be called for each field. + */ +class TVM_DLL AttrVisitor { + public: +//! \cond Doxygen_Suppress + virtual void Visit(const char* key, double* value) = 0; + virtual void Visit(const char* key, int64_t* value) = 0; + virtual void Visit(const char* key, uint64_t* value) = 0; + virtual void Visit(const char* key, int* value) = 0; + virtual void Visit(const char* key, bool* value) = 0; + virtual void Visit(const char* key, std::string* value) = 0; + virtual void Visit(const char* key, void** value) = 0; + virtual void Visit(const char* key, Type* value) = 0; + virtual void Visit(const char* key, NodeRef* value) = 0; + virtual void Visit(const char* key, runtime::NDArray* value) = 0; + template<typename ENum, + typename = typename std::enable_if<std::is_enum<ENum>::value>::type> + void Visit(const char* key, ENum* ptr) { + static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value, + "declare enum to be enum int to use visitor"); + this->Visit(key, reinterpret_cast<int*>(ptr)); + } +//! \endcond +}; + +/*! + * \brief base class of node container in DSL AST. + * All object's internal is stored as std::shared_ptr<Node> + */ +class TVM_DLL Node : public NodeBase { + public: + /*! \brief virtual destructor */ + virtual ~Node() {} + /*! \return The unique type key of the node */ + virtual const char* type_key() const = 0; + /*! + * \brief Apply visitor to each field of the Node + * Visitor could mutate the content of the node. + * override if Node contains attribute fields. + * \param visitor The visitor + */ + virtual void VisitAttrs(AttrVisitor* visitor) {} + /*! \return the type index of the node */ + virtual const uint32_t type_index() const = 0; + /*! + * \brief Whether this node derives from node with type_index=tid. + * Implemented by TVM_DECLARE_NODE_TYPE_INFO + * + * \param tid The type index. + * \return the check result. + */ + virtual const bool _DerivedFrom(uint32_t tid) const; + /*! + * \brief get a runtime unique type index given a type key + * \param type_key Type key of a type. + * \return the corresponding type index. + */ + static uint32_t TypeKey2Index(const char* type_key); + /*! + * \brief get type key from type index. + * \param index The type index + * \return the corresponding type key. + */ + static const char* TypeIndex2Key(uint32_t index); + /*! + * \return whether the type is derived from + */ + template<typename T> + inline bool derived_from() const; + /*! + * \return whether the node is of type T + * \tparam The type to be checked. + */ + template<typename T> + inline bool is_type() const; + /*! + * \brief Get a NodeRef that holds reference to this Node. + * \return the NodeRef + */ + inline NodeRef GetNodeRef() const; + // node ref can see this + friend class NodeRef; + static constexpr const char* _type_key = "Node"; +}; + +/*! \brief Base class of all node reference object */ +class NodeRef { + public: + /*! \brief type indicate the container type */ + using ContainerType = Node; + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + inline bool operator==(const NodeRef& other) const; + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + inline bool same_as(const NodeRef& other) const; + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + inline bool operator<(const NodeRef& other) const; + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + inline bool operator!=(const NodeRef& other) const; + /*! \return the hash function for NodeRef */ + inline size_t hash() const; + /*! \return whether the expression is null */ + inline bool defined() const; + /*! \return the internal type index of IRNode */ + inline uint32_t type_index() const; + /*! \return the internal node pointer */ + inline const Node* get() const; + /*! \return the internal node pointer */ + inline const Node* operator->() const; + /*! + * \brief Downcast this ir node to its actual type (e.g. Add, or + * Select). This returns nullptr if the node is not of the requested + * type. Example usage: + * + * if (const Add *add = node->as<Add>()) { + * // This is an add node + * } + * \tparam T the target type, must be subtype of IRNode + */ + template<typename T> + inline const T *as() const; + /*! + * \brief A more powerful version of as that also works with + * intermediate base types. + * \tparam T the target type, must be subtype of IRNode + */ + template<typename T> + inline const T *as_derived() const; + /*! \brief default constructor */ + NodeRef() = default; + explicit NodeRef(NodePtr<Node> node) : node_(node) {} + /*! \brief the internal node object, do not touch */ + NodePtr<Node> node_; +}; + +/*! + * \brief helper macro to declare type information in a base node. + */ +#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ + const bool _DerivedFrom(uint32_t tid) const override { \ + static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ + if (tidx == tid) return true; \ + return Parent::_DerivedFrom(tid); \ + } + +/*! + * \brief helper macro to declare type information in a terminal node + */ +#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ + const char* type_key() const final { \ + return TypeName::_type_key; \ + } \ + const uint32_t type_index() const final { \ + static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ + return tidx; \ + } \ + const bool _DerivedFrom(uint32_t tid) const final { \ + static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ + if (tidx == tid) return true; \ + return Parent::_DerivedFrom(tid); \ + } + +// implementations of inline functions after this +template<typename T> +inline bool Node::is_type() const { + // use static field so query only happens once. + static uint32_t type_id = Node::TypeKey2Index(T::_type_key); + return type_id == this->type_index(); +} + +template<typename T> +inline bool Node::derived_from() const { + // use static field so query only happens once. + static uint32_t type_id = Node::TypeKey2Index(T::_type_key); + return this->_DerivedFrom(type_id); +} + +inline NodeRef Node::GetNodeRef() const { + return NodeRef(NodePtr<Node>(const_cast<Node*>(this))); +} + +inline const Node* NodeRef::get() const { + return node_.get(); +} + +inline const Node* NodeRef::operator->() const { + return node_.get(); +} + +inline bool NodeRef::defined() const { + return node_.get() != nullptr; +} + +inline bool NodeRef::operator==(const NodeRef& other) const { + return node_.get() == other.node_.get(); +} + +inline bool NodeRef::same_as(const NodeRef& other) const { + return node_.get() == other.node_.get(); +} + +inline bool NodeRef::operator<(const NodeRef& other) const { + return node_.get() < other.node_.get(); +} + +inline bool NodeRef::operator!=(const NodeRef& other) const { + return node_.get() != other.node_.get(); +} + +inline size_t NodeRef::hash() const { + return std::hash<Node*>()(node_.get()); +} + +inline uint32_t NodeRef::type_index() const { + CHECK(node_.get() != nullptr) + << "null type"; + return get()->type_index(); +} + +template<typename T> +inline const T* NodeRef::as() const { + const Node* ptr = static_cast<const Node*>(get()); + if (ptr && ptr->is_type<T>()) { + return static_cast<const T*>(ptr); + } + return nullptr; +} + +template<typename T> +inline const T* NodeRef::as_derived() const { + const Node* ptr = static_cast<const Node*>(get()); + if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) { + return static_cast<const T*>(ptr); + } + return nullptr; +} + +/*! \brief The hash function for nodes */ +struct NodeHash { + size_t operator()(const NodeRef& a) const { + return a.hash(); + } +}; + +/*! \brief The equal comparator for nodes */ +struct NodeEqual { + bool operator()(const NodeRef& a, const NodeRef& b) const { + return a.get() == b.get(); + } +}; +} // namespace tvm +#endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 78351e094e695dfa93bf25909f3bfd090f36c781..8528eeaa5fa3c22857f913c7e233e5a9db44a096 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -116,7 +116,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); + NodePtr<Node>& sptr = *ptr<NodePtr<Node> >(); CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get())) << "Expected type " << NodeTypeName<TNodeRef>() << " but get " << sptr->type_key(); @@ -132,7 +132,7 @@ inline TVMArgValue::operator HalideIR::Expr() const { return Expr(static_cast<float>(value_.v_float64)); } TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); + NodePtr<Node>& sptr = *ptr<NodePtr<Node> >(); if (sptr->is_type<IterVarNode>()) { return IterVar(sptr)->var; } @@ -145,27 +145,27 @@ inline TVMArgValue::operator HalideIR::Expr() const { return Expr(sptr); } -inline std::shared_ptr<Node>& TVMArgValue::node_sptr() { +inline NodePtr<Node>& TVMArgValue::node_sptr() { TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - return *ptr<std::shared_ptr<Node> >(); + return *ptr<NodePtr<Node> >(); } template<typename TNodeRef, typename> inline bool TVMArgValue::IsNodeType() const { TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - std::shared_ptr<Node>& sptr = - *ptr<std::shared_ptr<Node> >(); + NodePtr<Node>& sptr = + *ptr<NodePtr<Node> >(); return NodeTypeChecker<TNodeRef>::Check(sptr.get()); } // extensions for TVMRetValue inline TVMRetValue& TVMRetValue::operator=( - const std::shared_ptr<Node>& other) { + const NodePtr<Node>& other) { if (other.get() == nullptr) { SwitchToPOD(kNull); } else { - SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other); + SwitchToClass<NodePtr<Node> >(kNodeHandle, other); } return *this; } @@ -174,7 +174,7 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { if (!other.defined()) { SwitchToPOD(kNull); } else { - SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_); + SwitchToClass<NodePtr<Node> >(kNodeHandle, other.node_); } return *this; } @@ -186,7 +186,7 @@ inline TNodeRef TVMRetValue::AsNodeRef() const { "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); + NodePtr<Node>& sptr = *ptr<NodePtr<Node> >(); CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get())) << "Expected type " << NodeTypeName<TNodeRef>() << " but get " << sptr->type_key(); @@ -195,7 +195,7 @@ inline TNodeRef TVMRetValue::AsNodeRef() const { inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) if (other.defined()) { - values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_)); + values_[i].v_handle = const_cast<NodePtr<Node>*>(&(other.node_)); type_codes_[i] = kNodeHandle; } else { type_codes_[i] = kNull; diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 7c66d2c2de43df5ae2c934179c6ed79ca6d7b07b..ecf45353af67bc4d07bd031b49555370c4828719 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -8,7 +8,7 @@ #include <tvm/api_registry.h> #include <tvm/ir.h> -#include <tvm/node.h> +#include <tvm/node/node.h> #include <string> #include <vector> @@ -55,16 +55,16 @@ using NodeEqual = ::tvm::NodeEqual; * \param NodeName The internal container name. * \param NodeRefBase The base type. */ -#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ - class TypeName : public NodeRefBase { \ - public: \ - TypeName() {} \ - explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \ - const NodeName* operator->() const { \ - return static_cast<const NodeName*>(node_.get()); \ - } \ - operator bool() { return this->defined(); } \ - using ContainerType = NodeName; \ +#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ + class TypeName : public NodeRefBase { \ + public: \ + TypeName() {} \ + explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRefBase(n) {} \ + const NodeName* operator->() const { \ + return static_cast<const NodeName*>(node_.get()); \ + } \ + operator bool() { return this->defined(); } \ + using ContainerType = NodeName; \ }; /*! @@ -82,8 +82,6 @@ class SourceNameNode : public Node { // override attr visitor void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); } - TVM_DLL static SourceName make(std::string name); - static constexpr const char* _type_key = "relay.SourceName"; TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); }; @@ -98,7 +96,7 @@ class SourceName : public NodeRef { SourceName() {} /*! \brief constructor from node pointer */ - explicit SourceName(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit SourceName(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -109,9 +107,9 @@ class SourceName : public NodeRef { * \brief Get an SourceName for a given operator name. * Will raise an error if the source name has not been registered. * \param name Name of the operator. - * \return Reference to a SourceName valid throughout program lifetime. + * \return SourceName valid throughout program lifetime. */ - TVM_DLL static const SourceName& Get(const std::string& name); + TVM_DLL static SourceName Get(const std::string& name); /*! \brief specify container node */ using ContainerType = SourceNameNode; @@ -176,7 +174,7 @@ template <typename RefType, typename NodeType> RefType GetRef(const NodeType* ptr) { static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value, "Can only cast to the ref of same container type"); - return RefType(const_cast<NodeType*>(ptr)->shared_from_this()); + return RefType(std::move(ptr->GetNodeRef().node_)); } // TODO(@tqchen, @jroesch): can we move these semantics to HalideIR diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 7e07dc01eab487450a67dab81b0f199e4d471d7f..46cedf12b816d0dfb66947d0e6316b3431e6b189 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -98,15 +98,15 @@ class EnvironmentNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); private: - /*! \brief A map from string names to global variables that - * ensures global uniqueness. + /*! \brief A map from string names to global variables that + * ensures global uniqueness. */ tvm::Map<std::string, GlobalVar> global_map_; }; struct Environment : public NodeRef { Environment() {} - explicit Environment(std::shared_ptr<tvm::Node> p) : NodeRef(p) {} + explicit Environment(NodePtr<tvm::Node> p) : NodeRef(p) {} inline EnvironmentNode* operator->() const { return static_cast<EnvironmentNode*>(node_.get()); diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 8ad0537ad68bdcfe598abde1d2874a469d2bc4dd..27bb464b98a3d2524d9fa6b8953429da313a8177 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -7,7 +7,7 @@ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ -#include <tvm/ir_functor.h> +#include <tvm/node/ir_functor.h> #include <string> #include "./expr.h" #include "./op.h" @@ -19,7 +19,7 @@ namespace relay { * \brief A dynamical functor that dispatches on in the first Expr argument. * You can use this as a more powerful Visitor, since it allows you to * define function signatures of Visit Function. - * + * * \sa tvm/ir_functor.h * * \tparam FType function signiture @@ -30,7 +30,7 @@ template <typename FType> class ExprFunctor; // functions to be overriden. -#define EXPR_FUNCTOR_DEFAULT \ +#define EXPR_FUNCTOR_DEFAULT \ { return VisitExprDefault_(op, std::forward<Args>(args)...); } #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ @@ -152,12 +152,12 @@ class ExprMutator Expr VisitExpr_(const CallNode* call_node, const Expr& e) override; Expr VisitExpr_(const LetNode* op, const Expr& e) override; Expr VisitExpr_(const IfNode* op, const Expr& e) override; - /*! \brief Used to visit the types inside of expressions. - * + /*! \brief Used to visit the types inside of expressions. + * * Can be overloaded to transform the types in arbitrary * ways, one way would be to define a sub-class of type * visitor for types which transform them appropriately. - */ + */ virtual Type VisitType(const Type& t); private: diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 49661fec57311015e9c208ee633266d939d5e474..9f4e7be08a8c7c2d6eeb93e2fe478dcc195326e0 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -90,7 +90,7 @@ class Op : public relay::Expr { /*! \brief default constructor */ Op() {} /*! \brief constructor from node pointer */ - explicit Op(std::shared_ptr<Node> n) : Expr(n) {} + explicit Op(NodePtr<Node> n) : Expr(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -149,9 +149,9 @@ class OpRegistry { const std::string& description); /*! * \brief Attach the type function corresponding to the return type. - * \param rel_name The type relation name to register. + * \param rel_name The type relation name to register. * \param type_rel_func The backing relation function which can solve an arbitrary - * relation on variables. + * relation on variables. * \return reference to self. */ inline OpRegistry& add_type_rel( @@ -338,7 +338,7 @@ inline OpRegistry& OpRegistry::describe( inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type, const std::string& description) { - std::shared_ptr<AttrFieldInfoNode> n = std::make_shared<AttrFieldInfoNode>(); + auto n = make_node<AttrFieldInfoNode>(); n->name = name; n->type_info = type; n->description = description; diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 44030ad8d97f09058bf0f827857ea6fa769a61a4..f972eb85b0413b8071abd710c09c8e99a2c80667 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -8,7 +8,7 @@ #include <tvm/api_registry.h> #include <tvm/ir.h> -#include <tvm/node.h> +#include <tvm/node/node.h> #include <string> #include "./base.h" @@ -37,7 +37,7 @@ class TypeNode : public RelayNode { class Type : public NodeRef { public: Type() {} - explicit Type(std::shared_ptr<tvm::Node> p) : NodeRef(p) {} + explicit Type(NodePtr<tvm::Node> p) : NodeRef(p) {} using ContainerType = TypeNode; }; diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index a3359289e2618635c60f47f1dd8da203faac8890..313e0a5c3da862a7c01d6e9fc3e821b760a0ac47 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -263,12 +263,16 @@ struct NDArray::Container { // the usages of functions are documented in place. inline NDArray::NDArray(Container* data) : data_(data) { - data_->IncRef(); + if (data != nullptr) { + data_->IncRef(); + } } inline NDArray::NDArray(const NDArray& other) : data_(other.data_) { - data_->IncRef(); + if (data_ != nullptr) { + data_->IncRef(); + } } inline void NDArray::reset() { diff --git a/include/tvm/runtime/node_base.h b/include/tvm/runtime/node_base.h new file mode 100644 index 0000000000000000000000000000000000000000..bc62ac460cffdbe9d451a43290986adb234597af --- /dev/null +++ b/include/tvm/runtime/node_base.h @@ -0,0 +1,241 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/runtime/node_base.h + * \brief Base data structure for Node. + * + * \note Node is not a runtime feature. + * This file only exposes the signature of NodePtr for PackedFunc. + */ +#ifndef TVM_RUNTIME_NODE_BASE_H_ +#define TVM_RUNTIME_NODE_BASE_H_ + +#include <utility> +#include <atomic> + +namespace tvm { + +// forward declarations +template<typename T> +class NodePtr; +class Node; +class NodeRef; + +/*! + * \brief Base class of Node for runtime destructor purposes. + * + * Node is a reference counted object which is used to construct AST. + * Each node is backed by a custom deleter, which deletes the object. + * Do not call create raw Node pointer, always use tvm::make_node. + * + * \note In most cases, please inheritate tvm::Node. + * \sa Node, NodePtr, make_node + */ +class NodeBase { + public: + /*! + * \brief type of NodeBase deleter + * \param self pointer to the NodeBase. + */ + typedef void (*FDeleter)(NodeBase* self); + + protected: + // default constructor and copy constructor + NodeBase() {} + // override the copy and assign constructors to do nothing. + // This is to make sure only contents, but not deleter and ref_counter + // are copied when a child class copies itself. + NodeBase(const NodeBase& other) { // NOLINT(*) + } + NodeBase(NodeBase&& other) { // NOLINT(*) + } + NodeBase& operator=(const NodeBase& other) { //NOLINT(*) + return *this; + } + NodeBase& operator=(NodeBase&& other) { //NOLINT(*) + return *this; + } + + private: + /*! \brief Internal reference counter */ + std::atomic<int> ref_counter_{0}; + /*! + * \brief deleter of this object to enable customized allocation. + * If the deleter is nullptr, no deletion will be performed. + * The creator of the Node must always set the deleter field properly. + */ + FDeleter deleter_ = nullptr; + // reference counting functions + void IncRef() { + ref_counter_.fetch_add(1, std::memory_order_relaxed); + } + void DecRef() { + if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { + std::atomic_thread_fence(std::memory_order_acquire); + if (this->deleter_ != nullptr) { + (*this->deleter_)(this); + } + } + } + int use_count() const { + return ref_counter_.load(std::memory_order_relaxed); + } + // friend declaration + template<typename> + friend class NodePtr; + template<typename Y, typename... Args> + friend NodePtr<Y> make_node(Args&&...); +}; + +/*! + * \brief Smart pointer for Node containers, + * must be subclass of NodeBase + * \tparam T the content data type. + */ +template<typename T> +class NodePtr { + public: + /*! \brief default constructor */ + NodePtr() {} + /*! \brief default constructor */ + NodePtr(std::nullptr_t) {} // NOLINT(*) + /*! + * \brief copy constructor + * \param other The value to be moved + */ + NodePtr(const NodePtr<T>& other) // NOLINT(*) + : NodePtr(other.data_) { + } + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template<typename Y> + NodePtr(const NodePtr<Y>& other) // NOLINT(*) + : NodePtr(other.data_) { + static_assert(std::is_base_of<T, Y>::value, + "can only assign of child class NodePtr to parent"); + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + NodePtr(NodePtr<T>&& other) // NOLINT(*) + : data_(other.data_) { + other.data_ = nullptr; + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + template<typename Y> + NodePtr(NodePtr<Y>&& other) // NOLINT(*) + : data_(other.data_) { + static_assert(std::is_base_of<T, Y>::value, + "can only assign of child class NodePtr to parent"); + other.data_ = nullptr; + } + /*! \brief destructor */ + ~NodePtr() { + this->reset(); + } + /*! + * \brief Swap this array with another NDArray + * \param other The other NDArray + */ + void swap(NodePtr<T>& other) { // NOLINT(*) + std::swap(data_, other.data_); + } + /*! + * \return Get the content of the pointer + */ + T* get() const { + return static_cast<T*>(data_); + } + /*! + * \return The pointer + */ + T* operator->() const { + return get(); + } + /*! + * \return The reference + */ + T& operator*() const { // NOLINT(*) + return *get(); + } + /*! + * \brief copy assignmemt + * \param other The value to be assigned. + * \return reference to self. + */ + NodePtr<T>& operator=(const NodePtr<T>& other) { // NOLINT(*) + // takes in plane operator to enable copy elison. + // copy-and-swap idiom + NodePtr(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief move assignmemt + * \param other The value to be assigned. + * \return reference to self. + */ + NodePtr<T>& operator=(NodePtr<T>&& other) { // NOLINT(*) + // copy-and-swap idiom + NodePtr(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + /*! \brief reset the content of ptr to be nullptr */ + void reset() { + if (data_ != nullptr) { + data_->DecRef(); + data_ = nullptr; + } + } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { + return data_ != nullptr ? data_->use_count() : 0; + } + /*! \return whether the reference is unique */ + bool unique() const { + return data_ != nullptr && data_->use_count() == 1; + } + /*! \return Whether two NodePtr do not equals each other */ + bool operator==(const NodePtr<T>& other) const { + return data_ == other.data_; + } + /*! \return Whether two NodePtr equals each other */ + bool operator!=(const NodePtr<T>& other) const { + return data_ != other.data_; + } + /*! \return Whether the pointer is nullptr */ + bool operator==(std::nullptr_t null) const { + return data_ == nullptr; + } + /*! \return Whether the pointer is not nullptr */ + bool operator!=(std::nullptr_t null) const { + return data_ != nullptr; + } + + private: + /*! \brief internal pointer field */ + NodeBase* data_{nullptr}; + /*! + * \brief constructor from NodeBase + * \param data The node base pointer + */ + explicit NodePtr(NodeBase* data) + : data_(data) { + if (data != nullptr) { + data_->IncRef(); + } + } + // friend declaration + friend class Node; + template<typename> + friend class NodePtr; + template<typename Y, typename... Args> + friend NodePtr<Y> make_node(Args&&...); +}; +} // namespace tvm + +#endif // TVM_RUNTIME_NODE_BASE_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index d1206a8a34f4b082507493fe0280bf44fff210e7..401b0bbb97ed483a858bced0dd4478e5d6d343b5 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -17,6 +17,7 @@ #include "c_runtime_api.h" #include "module.h" #include "ndarray.h" +#include "node_base.h" namespace HalideIR { // Forward declare type for extensions @@ -31,12 +32,6 @@ struct Expr; #endif namespace tvm { -// Forward declare NodeRef and Node for extensions. -// This header works fine without depend on NodeRef -// as long as it is not used. -class Node; -class NodeRef; - namespace runtime { // forward declarations class TVMArgs; @@ -549,7 +544,7 @@ class TVMArgValue : public TVMPODValue_ { inline operator HalideIR::Type() const; inline operator HalideIR::Expr() const; // get internal node ptr, if it is node - inline std::shared_ptr<Node>& node_sptr(); + inline NodePtr<Node>& node_sptr(); }; /*! @@ -745,7 +740,7 @@ class TVMRetValue : public TVMPODValue_ { template<typename TNodeRef> inline TNodeRef AsNodeRef() const; inline TVMRetValue& operator=(const NodeRef& other); - inline TVMRetValue& operator=(const std::shared_ptr<Node>& other); + inline TVMRetValue& operator=(const NodePtr<Node>& other); // type related inline operator HalideIR::Type() const; inline TVMRetValue& operator=(const HalideIR::Type& other); @@ -775,8 +770,8 @@ class TVMRetValue : public TVMPODValue_ { break; } case kNodeHandle: { - SwitchToClass<std::shared_ptr<Node> >( - kNodeHandle, *other.template ptr<std::shared_ptr<Node> >()); + SwitchToClass<NodePtr<Node> >( + kNodeHandle, *other.template ptr<NodePtr<Node> >()); break; } default: { @@ -821,7 +816,7 @@ class TVMRetValue : public TVMPODValue_ { case kStr: delete ptr<std::string>(); break; case kFuncHandle: delete ptr<PackedFunc>(); break; case kModuleHandle: delete ptr<Module>(); break; - case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break; + case kNodeHandle: delete ptr<NodePtr<Node> >(); break; case kNDArrayContainer: { static_cast<NDArray::Container*>(value_.v_handle)->DecRef(); break; diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index b72eb7105faaf2093027af578e0b227070e4953e..af72f315329141aef2a6eab252428a52a1906f2f 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -36,7 +36,7 @@ enum AttachType : int { class Stage : public NodeRef { public: Stage() {} - explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit Stage(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief create a new schedule for op. * \param op The operator in the schedule @@ -260,7 +260,7 @@ class Stage : public NodeRef { class Schedule : public NodeRef { public: Schedule() {} - explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit Schedule(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -383,7 +383,7 @@ class Schedule : public NodeRef { class IterVarRelation : public NodeRef { public: IterVarRelation() {} - explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit IterVarRelation(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -397,7 +397,7 @@ class IterVarRelation : public NodeRef { class IterVarAttr : public NodeRef { public: IterVarAttr() {} - explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit IterVarAttr(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index ddccfce2fefbf952136668910d26ca2b2cd6d148..48d959301e638ac84ced38744146b6cb3c6adc08 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -6,7 +6,6 @@ #ifndef TVM_TENSOR_H_ #define TVM_TENSOR_H_ -#include <tvm/container.h> #include <ir/FunctionBase.h> #include <string> #include <vector> @@ -15,6 +14,7 @@ #include "base.h" #include "expr.h" #include "arithmetic.h" +#include "node/container.h" namespace tvm { @@ -33,7 +33,7 @@ class Tensor : public NodeRef { public: /*! \brief default constructor, used internally */ Tensor() {} - explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit Tensor(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -118,7 +118,7 @@ class Operation : public FunctionRef { public: /*! \brief default constructor */ Operation() {} - explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {} + explicit Operation(NodePtr<Node> n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index fa8c895ccb08cbfb689bc184c8da57d50e39f948..944498d1e61525d70651d341ae5f778a48572a07 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -19,7 +19,7 @@ class TensorIntrinNode; class TensorIntrin : public NodeRef { public: TensorIntrin() {} - explicit TensorIntrin(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit TensorIntrin(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index b9b27621840c5ae4434bb5915e7669c25d1562fe..6df70b53ccae4f0bc4a5f8fd658fc866e30a2ef1 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -94,7 +94,7 @@ class CompileEngine { return it->second->graph_func; } GraphFunc f = DoLower(key->graph, key->inputs, key->target, master_idx); - std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>(); + auto n = tvm::make_node<GraphCacheEntryNode>(); n->graph_func = f; n->use_count = 1; n->master_idx = master_idx; @@ -107,8 +107,7 @@ class CompileEngine { Array<NodeRef> items; for (auto& kv : cache_) { items.push_back(kv.first); - std::shared_ptr<GraphCacheEntryNode> n = - std::make_shared<GraphCacheEntryNode>(*(kv.second.operator->())); + auto n = tvm::make_node<GraphCacheEntryNode>(*(kv.second.operator->())); items.push_back(GraphCacheEntry(n)); } return items; @@ -126,7 +125,7 @@ class CompileEngine { // Set the given function on given graph key. void Set(const GraphKey& key, GraphFunc func) { std::lock_guard<std::mutex> lock(mutex_); - std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>(); + auto n = tvm::make_node<GraphCacheEntryNode>(); n->graph_func = func; n->use_count = 1; cache_[key] = GraphCacheEntry(n); @@ -265,7 +264,7 @@ class CompileEngine { graph, inputs, target, master_idx, &readable_name, &outputs); - std::shared_ptr<GraphFuncNode> gf = std::make_shared<GraphFuncNode>(); + auto gf = tvm::make_node<GraphFuncNode>(); gf->target = target; gf->func_name = GetUniqeName(readable_name); gf->inputs = inputs; diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h index 7696b3b5f4eb1421a2a6a538bdd2c2a4d21bcdf0..23e5e1d1a49c859db57c434a401a81799ee3c0c3 100644 --- a/nnvm/src/compiler/compile_engine.h +++ b/nnvm/src/compiler/compile_engine.h @@ -71,7 +71,7 @@ struct GraphCacheEntryNode : public tvm::Node { class GraphCacheEntry : public ::tvm::NodeRef { public: GraphCacheEntry() {} - explicit GraphCacheEntry(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} + explicit GraphCacheEntry(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} GraphCacheEntryNode* operator->() { return static_cast<GraphCacheEntryNode*>(node_.get()); } diff --git a/nnvm/src/compiler/graph_hash.cc b/nnvm/src/compiler/graph_hash.cc index ca68727ea0674a304c8656d927271609a625974d..f14a60e80d8c77261df57c7911d7d4eb06e2a1b8 100644 --- a/nnvm/src/compiler/graph_hash.cc +++ b/nnvm/src/compiler/graph_hash.cc @@ -74,8 +74,7 @@ bool GraphKeyEqual::Equal(const GraphKey& a, GraphKey GraphKeyNode::make(Graph graph, tvm::Array<Tensor> inputs, std::string target) { - std::shared_ptr<GraphKeyNode> n - = std::make_shared<GraphKeyNode>(); + auto n = tvm::make_node<GraphKeyNode>(); n->graph = std::move(graph); n->inputs = inputs; n->target = std::move(target); diff --git a/nnvm/src/compiler/graph_runtime.cc b/nnvm/src/compiler/graph_runtime.cc index c680e82dd93624e6b9753cd92f393fff46c1c1a5..e4865df3f9f0d673da55df70d5d0bf123b44798c 100644 --- a/nnvm/src/compiler/graph_runtime.cc +++ b/nnvm/src/compiler/graph_runtime.cc @@ -91,8 +91,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict") for (size_t i = 0; i < size; ++i) { tvm::runtime::NDArray temp; temp.Load(strm); - std::shared_ptr<NDArrayWrapperNode> n - = std::make_shared<NDArrayWrapperNode>(); + auto n = tvm::make_node<NDArrayWrapperNode>(); n->name = std::move(names[i]); n->array = temp; ret.push_back(NDArrayWrapper(n)); diff --git a/nnvm/src/compiler/graph_runtime.h b/nnvm/src/compiler/graph_runtime.h index 272e2be7f251e637a12f198ce0913320a71f545b..e5ba3681d2bfb8d0ee6e83216c9cf0b00aea2ab1 100644 --- a/nnvm/src/compiler/graph_runtime.h +++ b/nnvm/src/compiler/graph_runtime.h @@ -9,6 +9,7 @@ #include <nnvm/graph.h> #include <tvm/base.h> #include <tvm/expr.h> +#include <tvm/node/memory.h> #include <tvm/packed_func_ext.h> #include <tvm/runtime/ndarray.h> #include <vector> diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc index 64846fc8e2472ab49f72c60051515e39572ea622..1a19feabfe8a324f1b04c8e8f91e9829d50bab8f 100644 --- a/nnvm/src/compiler/packed_func_ext.cc +++ b/nnvm/src/compiler/packed_func_ext.cc @@ -96,7 +96,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") const Array<Tensor>& out_info) -> Array<Tensor> { TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info); - if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) { + if ((*ret.ptr<::tvm::NodePtr<tvm::Node> >())->derived_from<tvm::TensorNode>()) { return {ret.operator Tensor()}; } else { return ret; diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 8c55684ed851af3b863f188d1992a17a59a4b946..8ca49f19baecc8c7de067ec29a41b255f87151ca 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -45,11 +45,11 @@ TVM_REGISTER_API("_str") TVM_REGISTER_API("_Array") .set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector<std::shared_ptr<Node> > data; + std::vector<NodePtr<Node> > data; for (int i = 0; i < args.size(); ++i) { data.push_back(args[i].node_sptr()); } - auto node = std::make_shared<ArrayNode>(); + auto node = make_node<ArrayNode>(); node->data = std::move(data); *ret = node; }); @@ -87,7 +87,7 @@ TVM_REGISTER_API("_Map") data.emplace(std::make_pair(args[i].operator std::string(), args[i + 1].node_sptr())); } - auto node = std::make_shared<StrMapNode>(); + auto node = make_node<StrMapNode>(); node->data = std::move(data); *ret = node; } else { @@ -101,7 +101,7 @@ TVM_REGISTER_API("_Map") data.emplace(std::make_pair(args[i].node_sptr(), args[i + 1].node_sptr())); } - auto node = std::make_shared<MapNode>(); + auto node = make_node<MapNode>(); node->data = std::move(data); *ret = node; } @@ -163,7 +163,7 @@ TVM_REGISTER_API("_MapItems") auto& sptr = args[0].node_sptr(); if (sptr->is_type<MapNode>()) { auto* n = static_cast<const MapNode*>(sptr.get()); - auto rkvs = std::make_shared<ArrayNode>(); + auto rkvs = make_node<ArrayNode>(); for (const auto& kv : n->data) { rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.second); @@ -171,7 +171,7 @@ TVM_REGISTER_API("_MapItems") *ret = rkvs; } else { auto* n = static_cast<const StrMapNode*>(sptr.get()); - auto rkvs = std::make_shared<ArrayNode>(); + auto rkvs = make_node<ArrayNode>(); for (const auto& kv : n->data) { rkvs->data.push_back(ir::StringImm::make(kv.first).node_); rkvs->data.push_back(kv.second); diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 9157e62fda8a3d12e13a59231d3d2750095046ba..1c2c294a5f3065d51fc150e6569c7708fa92a0fa 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -28,7 +28,7 @@ struct TVMAPIThreadLocalEntry { /*! \brief Thread local store that can be used to hold return values. */ typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore; -using TVMAPINode = std::shared_ptr<Node>; +using TVMAPINode = NodePtr<Node>; struct APIAttrGetter : public AttrVisitor { std::string skey; diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 99f9f0c073c3f25df415af1fd1d4b4f38450a19a..0fa7b846cf7e8acf84825277a7be641a273231c5 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -48,7 +48,7 @@ struct ComExprEntry { }; // canonical expression for communicative expression. -struct ComExprNode { +struct ComExprNode : public NodeBase { // base constant value. int64_t base{0}; // The values to be sumed. @@ -60,7 +60,7 @@ struct ComExpr { public: // constructor ComExpr() {} - explicit ComExpr(std::shared_ptr<ComExprNode> ptr) : ptr_(ptr) {} + explicit ComExpr(NodePtr<ComExprNode> ptr) : ptr_(ptr) {} // get member ComExprNode* operator->() const { return ptr_.get(); @@ -106,7 +106,7 @@ struct ComExpr { } private: - std::shared_ptr<ComExprNode> ptr_; + NodePtr<ComExprNode> ptr_; }; // binary comparison op. @@ -173,7 +173,7 @@ class Canonical::Internal : public IRMutator { if (sum.defined()) return sum; const int64_t *v1 = as_const_int(value); const uint64_t *v2 = as_const_uint(value); - std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); + auto n = make_node<ComExprNode>(); if (v1) { n->base = *v1; } else if (v2) { @@ -471,8 +471,8 @@ class Canonical::Internal : public IRMutator { Type type = coeff.type(); int64_t value = GetConstIntValue(coeff); if (value < 0) return {}; - std::shared_ptr<ComExprNode> xnode = std::make_shared<ComExprNode>(); - std::shared_ptr<ComExprNode> ynode = std::make_shared<ComExprNode>(); + auto xnode = make_node<ComExprNode>(); + auto ynode = make_node<ComExprNode>(); if (a->base % value == 0) { xnode->base = a->base; } else { @@ -507,7 +507,7 @@ class Canonical::Internal : public IRMutator { std::vector<ComExpr> pair = TryLinearEquation(a, v); if (pair.size() == 0) { int64_t value = GetConstIntValue(v); - std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); + auto n = make_node<ComExprNode>(); n->base = a->base % value; for (auto e : a->elem) { if (e.scale % value == 0) continue; @@ -554,8 +554,7 @@ class Canonical::Internal : public IRMutator { if (value == 0) { return make_zero(v.type()); } - std::shared_ptr<ComExprNode> vsum = - std::make_shared<ComExprNode>(*a.operator->()); + auto vsum = make_node<ComExprNode>(*a.operator->()); vsum->base *= value; for (auto& e : vsum->elem) { e.scale *= value; @@ -576,7 +575,7 @@ class Canonical::Internal : public IRMutator { ComExpr SumAdd_(const ComExpr& suma, const ComExpr& sumb, int bscale) { - std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); + auto n = make_node<ComExprNode>(); n->base = suma->base + sumb->base * bscale; // merge of suma and sumb; size_t i = 0, j = 0; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index c1b68fddd0e994e18d207dd910490a77f568a214..78c592471a1a677d9b82ba6402fc6fee562b9189 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -329,7 +329,7 @@ inline IntSet AsStrideSet(IntSet a) { if (a.as<StrideSet>()) return a; const IntervalSet* s = a.as<IntervalSet>(); CHECK(s->i.is_bounded()); - std::shared_ptr<StrideSet> n = std::make_shared<StrideSet>(); + NodePtr<StrideSet> n = make_node<StrideSet>(); n->base = s->i; return IntSet(n); } @@ -348,7 +348,7 @@ inline IntSet CombineSets<Add>(IntSet a, IntSet b) { b = AsStrideSet(b); const StrideSet* a_stride = a.as<StrideSet>(); const StrideSet* b_stride = b.as<StrideSet>(); - auto n = std::make_shared<StrideSet>(*a_stride); + auto n = make_node<StrideSet>(*a_stride); for (size_t i = 0; i < b_stride->extents.size(); ++i) { n->extents.push_back(b_stride->extents[i]); n->strides.push_back(b_stride->strides[i]); diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h index 9284e6e016e03b6b5e8c51d71a627fd7295ac8ce..e28fe2a9d9584913a76f5538b4d49f2e77b17115 100644 --- a/src/arithmetic/int_set_internal.h +++ b/src/arithmetic/int_set_internal.h @@ -21,14 +21,14 @@ struct IntervalSet : public IntSetNode { Interval i; static IntSet make(Interval i) { - std::shared_ptr<IntervalSet> n = - std::make_shared<IntervalSet>(); + NodePtr<IntervalSet> n = + make_node<IntervalSet>(); n->i = i; return IntSet(n); } static IntSet make(Expr min, Expr max) { - std::shared_ptr<IntervalSet> n = - std::make_shared<IntervalSet>(); + NodePtr<IntervalSet> n = + make_node<IntervalSet>(); n->i.min = min; n->i.max = max; return IntSet(n); diff --git a/src/arithmetic/modular.cc b/src/arithmetic/modular.cc index 1c03d0f97485287456532645943350be96a21efb..d79300eb7782709f153a26728c84930615f73e85 100644 --- a/src/arithmetic/modular.cc +++ b/src/arithmetic/modular.cc @@ -159,7 +159,7 @@ IntSet EvalModular(const Expr& e, CHECK(m) << "Need to pass ModularSet for Modular Analysis"; mmap[kv.first.get()] = m->e; } - std::shared_ptr<ModularSet> n = std::make_shared<ModularSet>(); + NodePtr<ModularSet> n = make_node<ModularSet>(); n->e = ModularEvaluator(mmap)(e); return IntSet(n); } diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index f35b09d1dfe621a5686fc8a8f8d967a5426fd759..5c0a5e07cd2a71d97ded0408d13c321a6846aed1 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -32,7 +32,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) */ Target CreateTarget(const std::string& target_name, const std::vector<std::string>& options) { - auto target = Target(std::make_shared<TargetNode>()); + auto target = Target(make_node<TargetNode>()); auto t = static_cast<TargetNode*>(target.node_.get()); t->target_name = target_name; @@ -475,7 +475,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs, } BuildConfig build_config() { - return BuildConfig(std::make_shared<BuildConfigNode>()); + return BuildConfig(make_node<BuildConfigNode>()); } /*! \brief Entry to hold the BuildConfig context stack. */ @@ -533,7 +533,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); struct GenericFunc::Manager { - std::unordered_map<std::string, std::shared_ptr<Node> > fmap; + std::unordered_map<std::string, NodePtr<Node> > fmap; // mutex std::mutex mutex; @@ -551,7 +551,7 @@ GenericFunc GenericFunc::Get(const std::string& name) { std::lock_guard<std::mutex>(m->mutex); auto it = m->fmap.find(name); if (it == m->fmap.end()) { - auto f = std::make_shared<GenericFuncNode>(); + auto f = make_node<GenericFuncNode>(); f->name_ = name; m->fmap[name] = f; return GenericFunc(f); @@ -669,7 +669,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo") TVM_REGISTER_API("_GenericFuncCreate") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = GenericFunc(std::make_shared<GenericFuncNode>()); + *ret = GenericFunc(make_node<GenericFuncNode>()); }); TVM_REGISTER_API("_GenericFuncGetGlobal") diff --git a/src/codegen/verilog/verilog_ir.cc b/src/codegen/verilog/verilog_ir.cc index b7576c83dfa8cd39ef5cf61af08c28186d63f3e1..dea8ebaebb8de1e56882fee21291624819c72d94 100644 --- a/src/codegen/verilog/verilog_ir.cc +++ b/src/codegen/verilog/verilog_ir.cc @@ -17,14 +17,14 @@ using namespace ir; ControlSignal ControlSignalNode::make( ControlSignalType type, int advance_size) { - auto n = std::make_shared<ControlSignalNode>(); + auto n = make_node<ControlSignalNode>(); n->ctrl_type = type; n->advance_size = advance_size; return ControlSignal(n); } StageInput StageInputNode::make(Var var, StageInputType input_type) { - std::shared_ptr<StageInputNode> n = std::make_shared<StageInputNode>(); + NodePtr<StageInputNode> n = make_node<StageInputNode>(); n->var = var; n->input_type = input_type; return StageInput(n); @@ -81,7 +81,7 @@ class PipelineExtractor: public IRVisitor { arg_handle_[arg.get()] = arg; } } - pipeline_ = std::make_shared<PipelineNode>(); + pipeline_ = make_node<PipelineNode>(); this->Visit(f->body); // setup channels for (const auto &kv : cmap_) { @@ -113,7 +113,7 @@ class PipelineExtractor: public IRVisitor { if (cb.node != nullptr) { CHECK(cb.node->channel.same_as(ch)); } else { - cb.node = std::make_shared<ChannelBlockNode>(); + cb.node = make_node<ChannelBlockNode>(); cb.node->channel = ch; } if (op->attr_key == attr::channel_read_scope) { @@ -167,8 +167,8 @@ class PipelineExtractor: public IRVisitor { // The replace logic StageInputReplacer repl(var_info_); // Setup the compute block. - std::shared_ptr<ComputeBlockNode> compute = - std::make_shared<ComputeBlockNode>(); + NodePtr<ComputeBlockNode> compute = + make_node<ComputeBlockNode>(); compute->loop = Array<Stmt>(loop_); // setup the advance triggers for (const auto& e : trigger_) { @@ -180,8 +180,8 @@ class PipelineExtractor: public IRVisitor { } else { ch = Channel(attr->node.node_); } - std::shared_ptr<SignalTriggerNode> trigger - = std::make_shared<SignalTriggerNode>(); + NodePtr<SignalTriggerNode> trigger + = make_node<SignalTriggerNode>(); trigger->channel_var = ch->handle_var; // predicate for the trigger Expr predicate = const_true(); @@ -249,7 +249,7 @@ class PipelineExtractor: public IRVisitor { CHECK(!cmap_.count(var)) << "Multiple access to the same handle"; ChannelEntry& cb = cmap_[var]; - cb.node = std::make_shared<ChannelBlockNode>(); + cb.node = make_node<ChannelBlockNode>(); cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype); return cb.node->channel; } @@ -257,7 +257,7 @@ class PipelineExtractor: public IRVisitor { private: // The channel information. struct ChannelEntry { - std::shared_ptr<ChannelBlockNode> node; + NodePtr<ChannelBlockNode> node; int read_ref_count{0}; int write_ref_count{0}; }; @@ -276,7 +276,7 @@ class PipelineExtractor: public IRVisitor { // The argument handle map std::unordered_map<const Variable*, Var> arg_handle_; // The result block. - std::shared_ptr<PipelineNode> pipeline_; + NodePtr<PipelineNode> pipeline_; }; Pipeline MakePipeline(LoweredFunc f) { diff --git a/src/codegen/verilog/vpi_session.cc b/src/codegen/verilog/vpi_session.cc index ac2861e8f74f29ea78f7461908b5b30897ca6903..36c08cac3f84871b96893e3a7c62bd2a5f9ae257 100644 --- a/src/codegen/verilog/vpi_session.cc +++ b/src/codegen/verilog/vpi_session.cc @@ -50,7 +50,7 @@ inline VPIHandleNode* VPIHandle::get() const { VPIHandle VPIHandleCreate( const std::shared_ptr<VPISessionEntry>& sess, VPIRawHandle handle) { - std::shared_ptr<VPIHandleNode> n = std::make_shared<VPIHandleNode>(); + auto n = make_node<VPIHandleNode>(); n->sess = sess; n->handle = handle; return VPIHandle(n); @@ -102,7 +102,7 @@ int VPIGetIntProp(VPIHandleNode* h, int code) { } VPISession VPISession::make(int h_pipe_read, int h_pipe_write) { - std::shared_ptr<VPISessionNode> n = std::make_shared<VPISessionNode>(); + auto n = make_node<VPISessionNode>(); n->sess = std::make_shared<VPISessionEntry>(h_pipe_read, h_pipe_write); n->sess->in_control = true; VPISession sess(n); diff --git a/src/codegen/verilog/vpi_session.h b/src/codegen/verilog/vpi_session.h index 88a7f2f1906ec638d284a6cbcb56e4b4057f9831..9fab0f1739950fd093a4dc0a52d2724894f35eb5 100644 --- a/src/codegen/verilog/vpi_session.h +++ b/src/codegen/verilog/vpi_session.h @@ -27,7 +27,7 @@ using runtime::PackedFunc; class VPISession : public NodeRef { public: VPISession() {} - explicit VPISession(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit VPISession(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief Get handle by name. * \param name The name of the handle. @@ -63,7 +63,7 @@ class VPISession : public NodeRef { class VPIHandle : public NodeRef { public: VPIHandle() {} - explicit VPIHandle(std::shared_ptr<Node> n) : NodeRef(n) {} + explicit VPIHandle(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief Get handle by name. * \param name The name of the handle. diff --git a/src/lang/api_registry.cc b/src/lang/api_registry.cc index 466ee1d3dd68d1d5ebcbc42c0ea324fedd382b3f..c9f84092f5dabd6f87188bab03043af3a6a2ceba 100644 --- a/src/lang/api_registry.cc +++ b/src/lang/api_registry.cc @@ -11,10 +11,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "EnvFunc(" << op->name << ")"; }); -std::shared_ptr<EnvFuncNode> CreateEnvNode(const std::string& name) { +NodePtr<EnvFuncNode> CreateEnvNode(const std::string& name) { auto* f = runtime::Registry::Get(name); CHECK(f != nullptr) << "Cannot find global function \'" << name << '\''; - std::shared_ptr<EnvFuncNode> n = std::make_shared<EnvFuncNode>(); + NodePtr<EnvFuncNode> n = make_node<EnvFuncNode>(); n->func = *f; n->name = name; return n; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 360c5b2e983342b4262eb36141638ad371249cda..12ebbff4be742ea2a5513ee48acb4c35d272723f 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -30,7 +30,7 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { } Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) { - std::shared_ptr<DictAttrsNode> n = std::make_shared<DictAttrsNode>(); + NodePtr<DictAttrsNode> n = make_node<DictAttrsNode>(); n->dict = std::move(dict); return Attrs(n); } diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 3f23c2d480bffcb40009764f8d57e7778f62acb8..cb3194f8eb1d8a922fb51c1076b93428a76a7826 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -289,7 +289,7 @@ Buffer Buffer::MakeStrideView() const { if ((*this)->strides.size() != 0) return *this; if ((*this)->shape.size() == 0) return *this; std::vector<Expr> temp; - auto n = std::make_shared<BufferNode>(*operator->()); + auto n = make_node<BufferNode>(*operator->()); Expr acc = make_const(n->DefaultIndexType(), 1); for (size_t i = n->shape.size(); i != 0 ; --i) { temp.push_back(acc); @@ -373,7 +373,7 @@ Buffer BufferNode::make(Var data, std::string scope, int data_alignment, int offset_factor) { - auto n = std::make_shared<BufferNode>(); + auto n = make_node<BufferNode>(); n->data = std::move(data); n->dtype = dtype; n->shape = std::move(shape); diff --git a/src/lang/channel.cc b/src/lang/channel.cc index dd850becf95698873694ed3b71bb2b1ebd309719..dcc44a0d061179ccbc08a6dfb0db43cd0a75feec 100644 --- a/src/lang/channel.cc +++ b/src/lang/channel.cc @@ -7,7 +7,7 @@ namespace tvm { Channel ChannelNode::make(Var handle_var, Type dtype) { - auto n = std::make_shared<ChannelNode>(); + auto n = make_node<ChannelNode>(); n->handle_var = handle_var; n->dtype = dtype; return Channel(n); diff --git a/src/lang/expr.cc b/src/lang/expr.cc index c2dab10c26d56686910cf4da145e679104113a07..062ea9217e63a0c6e0fc855c18429256a2b9add1 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -13,18 +13,18 @@ namespace tvm { using HalideIR::IR::RangeNode; Range::Range(Expr begin, Expr end) - : Range(std::make_shared<RangeNode>( + : Range(make_node<RangeNode>( begin, is_zero(begin) ? end : (end - begin))) { } Range Range::make_by_min_extent(Expr min, Expr extent) { - return Range(std::make_shared<HalideIR::IR::RangeNode>(min, extent)); + return Range(make_node<HalideIR::IR::RangeNode>(min, extent)); } IterVar IterVarNode::make(Range dom, Var var, IterVarType t, std::string thread_tag) { - std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>(); + NodePtr<IterVarNode> n = make_node<IterVarNode>(); n->dom = dom; n->var = var; n->iter_type = t; diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 1e0a6e5065f435a9800ebcedf033886c8ca36d0c..875258540584cb65849006db077f89ca5930529d 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -52,7 +52,7 @@ CommReducer CommReducerNode::make(Array<Var> lhs, Array<Var> rhs, Array<Expr> result, Array<Expr> identity_element) { - auto node = std::make_shared<CommReducerNode>(); + auto node = make_node<CommReducerNode>(); node->lhs = lhs; node->rhs = rhs; node->result = result; @@ -83,7 +83,7 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source, if (!condition.defined()) { condition = const_true(); } - auto n = std::make_shared<Reduce>(); + auto n = make_node<Reduce>(); CHECK(source.defined()); for (size_t i = 0; i < axis.size(); ++i) { CHECK(axis[i].defined()); diff --git a/src/lang/node.cc b/src/lang/node.cc new file mode 100644 index 0000000000000000000000000000000000000000..f7043eaf7b2afdf08bd1a120143216cc2b2c070b --- /dev/null +++ b/src/lang/node.cc @@ -0,0 +1,58 @@ +/*! + * Copyright (c) 2018 by Contributors + * Implementation of IR Node API + * \file node.cc + */ +#include <tvm/node/node.h> +#include <memory> +#include <atomic> +#include <mutex> +#include <unordered_map> + +namespace tvm { + +namespace { +// single manager of operator information. +struct TypeManager { + // mutex to avoid registration from multiple threads. + // recursive is needed for trigger(which calls UpdateAttrMap) + std::mutex mutex; + std::atomic<uint32_t> type_counter{0}; + std::unordered_map<std::string, uint32_t> key2index; + std::vector<std::string> index2key; + // get singleton of the + static TypeManager* Global() { + static TypeManager inst; + return &inst; + } +}; +} // namespace + +const bool Node::_DerivedFrom(uint32_t tid) const { + static uint32_t tindex = TypeKey2Index(Node::_type_key); + return tid == tindex; +} + +// this is slow, usually caller always hold the result in a static variable. +uint32_t Node::TypeKey2Index(const char* key) { + TypeManager *t = TypeManager::Global(); + std::lock_guard<std::mutex>(t->mutex); + std::string skey = key; + auto it = t->key2index.find(skey); + if (it != t->key2index.end()) { + return it->second; + } + uint32_t tid = ++(t->type_counter); + t->key2index[skey] = tid; + t->index2key.push_back(skey); + return tid; +} + +const char* Node::TypeIndex2Key(uint32_t index) { + TypeManager *t = TypeManager::Global(); + std::lock_guard<std::mutex>(t->mutex); + internal_assert(index != 0); + return t->index2key.at(index - 1).c_str(); +} + +} // namespace tvm diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index a33594107a69dee22bbb32564117c63109cced8d..497ec24f4129f9af5dc92eb074b86204ef91b0c7 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -6,7 +6,7 @@ #include <tvm/base.h> #include <tvm/expr.h> #include <tvm/attrs.h> -#include <tvm/container.h> +#include <tvm/node/container.h> #include <tvm/packed_func_ext.h> #include <tvm/runtime/ndarray.h> #include <dmlc/json.h> @@ -248,7 +248,7 @@ class JSONAttrGetter : public AttrVisitor { class JSONAttrSetter : public AttrVisitor { public: - const std::vector<std::shared_ptr<Node> >* node_list_; + const std::vector<NodePtr<Node> >* node_list_; const std::vector<runtime::NDArray>* tensor_list_; JSONNode* node_; @@ -401,13 +401,13 @@ std::string SaveJSON(const NodeRef& n) { return os.str(); } -std::shared_ptr<Node> LoadJSON_(std::string json_str) { +NodePtr<Node> LoadJSON_(std::string json_str) { std::istringstream is(json_str); dmlc::JSONReader reader(&is); JSONGraph jgraph; // load in json graph. jgraph.Load(&reader); - std::vector<std::shared_ptr<Node> > nodes; + std::vector<NodePtr<Node> > nodes; std::vector<runtime::NDArray> tensors; // load in tensors for (const std::string& blob : jgraph.b64ndarrays) { @@ -427,7 +427,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) { << "Node type \'" << jnode.type_key << "\' is not registered in TVM"; nodes.emplace_back(f->fcreator(jnode.global_key)); } else { - nodes.emplace_back(std::shared_ptr<Node>()); + nodes.emplace_back(NodePtr<Node>()); } } CHECK_EQ(nodes.size(), jgraph.nodes.size()); @@ -526,7 +526,7 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) { TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); CHECK(f->fglobal_key == nullptr) << "Cannot make node type \'" << type_key << "\' with global_key."; - std::shared_ptr<Node> n = f->fcreator(empty_str); + NodePtr<Node> n = f->fcreator(empty_str); if (n->derived_from<BaseAttrsNode>()) { static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs); } else { diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 5db4f45e799f9cbba6fd30d585fa419f8bff1e2d..4f9c3e9d1782d066ae24d705fca0537e0c56109b 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -30,7 +30,7 @@ Tensor TensorNode::make(Array<Expr> shape, Type dtype, Operation op, int value_index) { - auto n = std::make_shared<TensorNode>(); + auto n = make_node<TensorNode>(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; @@ -47,7 +47,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(TensorNode); Tensor Operation::output(size_t i) const { - auto node = std::make_shared<TensorNode>(); + auto node = make_node<TensorNode>(); node->op = *this; node->value_index = i; node->dtype = (*this)->output_dtype(i); @@ -62,7 +62,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, Stmt body, Stmt reduce_init, Stmt reduce_update) { - auto n = std::make_shared<TensorIntrinNode>(); + auto n = make_node<TensorIntrinNode>(); n->name = std::move(name); n->op = std::move(op); n->inputs = std::move(inputs); diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 267a25ff372b8e24b4127167f22bdd278891d02c..6100c957e4739059cc2aeb126036205f9bf09d38 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -69,7 +69,7 @@ Tensor compute(Array<Expr> shape, std::string name, std::string tag, Map<std::string, NodeRef> attrs) { - auto op_node = std::make_shared<ComputeOpNode>(); + auto op_node = make_node<ComputeOpNode>(); // compute dimension. size_t ndim = shape.size(); std::vector<IterVar> axis; @@ -91,7 +91,7 @@ Array<Tensor> compute(Array<Expr> shape, std::string name, std::string tag, Map<std::string, NodeRef> attrs) { - auto op_node = std::make_shared<ComputeOpNode>(); + auto op_node = make_node<ComputeOpNode>(); // compute dimension. size_t ndim = shape.size(); std::vector<IterVar> axis; @@ -117,7 +117,7 @@ Operation ComputeOpNode::make(std::string name, Map<std::string, NodeRef> attrs, Array<IterVar> axis, Array<Expr> body) { - auto n = std::make_shared<ComputeOpNode>(); + auto n = make_node<ComputeOpNode>(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -163,7 +163,7 @@ Operation ComputeOpNode::ReplaceInputs( if (!new_reduce.same_as(this->body[0])) { const ir::Reduce* r = new_reduce.as<ir::Reduce>(); for (size_t k = 0; k < this->body.size(); ++k) { - std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r); + auto n = make_node<ir::Reduce>(*r); n->value_index = static_cast<int>(k); n->type = r->source[k].type(); arr.push_back(Expr(n)); diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index 86c1d5e74527dc4b4b81ec55e888c581434e84f6..952e52a852bdc77b282ee66ce5ca248508ab2a7e 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -43,7 +43,7 @@ Operation ExternOpNode::make(std::string name, Array<Buffer> input_placeholders, Array<Buffer> output_placeholders, Stmt body) { - auto n = std::make_shared<ExternOpNode>(); + auto n = make_node<ExternOpNode>(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -68,7 +68,7 @@ Operation ExternOpNode::ReplaceInputs( const Operation& self, const std::unordered_map<Tensor, Tensor>& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = std::make_shared<ExternOpNode>(*this); + auto n = make_node<ExternOpNode>(*this); n->body = op::ReplaceTensor(this->body, rmap); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index a2cd0eb2d81fe91ba5778eeda27adcf99e6d4e77..fcd5993dafa52a107b0a783d95825af594b683c7 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -36,7 +36,7 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const { Operation PlaceholderOpNode::make(std::string name, Array<Expr> shape, Type dtype) { - auto n = std::make_shared<PlaceholderOpNode>(); + auto n = make_node<PlaceholderOpNode>(); n->name = name; n->shape = shape; n->dtype = dtype; diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index d03601709ab4ebca2c2509b4678aa988f3c6163f..60369aaabb33f1b5b7b0a22203efe1a7d6deefb3 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -51,7 +51,7 @@ Operation ScanOpNode::make(std::string name, Array<Tensor> update, Array<Tensor> state_placeholder, Array<Tensor> inputs) { - auto n = std::make_shared<ScanOpNode>(); + auto n = make_node<ScanOpNode>(); CHECK_EQ(init.size(), update.size()); CHECK_EQ(init.size(), state_placeholder.size()); @@ -135,7 +135,7 @@ Operation ScanOpNode::ReplaceInputs( const Operation& self, const std::unordered_map<Tensor, Tensor>& rmap) const { CHECK_EQ(self.operator->(), this); - std::shared_ptr<ScanOpNode> n = std::make_shared<ScanOpNode>(*this); + auto n = make_node<ScanOpNode>(*this); for (size_t i = 0; i < n->init.size(); ++i) { if (rmap.count(n->init[i])) { n->init.Set(i, rmap.at(n->init[i])); diff --git a/src/pass/combine_context_call.cc b/src/pass/combine_context_call.cc index dff91e6690f2e9eb3c0b020de1388fe42989c346..d60256bcfcf02654c1783a63cd5c4cad78309fcd 100644 --- a/src/pass/combine_context_call.cc +++ b/src/pass/combine_context_call.cc @@ -90,7 +90,7 @@ class ContextCallCombiner final : public IRMutator { }; LoweredFunc CombineContextCall(LoweredFunc f) { - auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); n->body = ContextCallCombiner().Combine(n->body); return LoweredFunc(n); } diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc index d06839beca33e4229e066be4369f044fffa774da..89426f982ba889c7090cad535b952b3b59312f28 100644 --- a/src/pass/ir_util.cc +++ b/src/pass/ir_util.cc @@ -13,38 +13,38 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) { for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { Stmt s = *ri; if (s.as<For>()) { - auto n = std::make_shared<For>(*s.as<For>()); + auto n = make_node<For>(*s.as<For>()); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (s.as<LetStmt>()) { - auto n = std::make_shared<LetStmt>(*s.as<LetStmt>()); + auto n = make_node<LetStmt>(*s.as<LetStmt>()); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (s.as<AttrStmt>()) { - auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>()); + auto n = make_node<AttrStmt>(*s.as<AttrStmt>()); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (s.as<IfThenElse>()) { - auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>()); + auto n = make_node<IfThenElse>(*s.as<IfThenElse>()); CHECK(is_no_op(n->then_case)); CHECK(!n->else_case.defined()); n->then_case = body; body = Stmt(n); } else if (s.as<Block>()) { - auto n = std::make_shared<Block>(*s.as<Block>()); + auto n = make_node<Block>(*s.as<Block>()); CHECK(is_no_op(n->rest)); n->rest = body; body = Stmt(n); } else if (s.as<AssertStmt>()) { - auto n = std::make_shared<AssertStmt>(*s.as<AssertStmt>()); + auto n = make_node<AssertStmt>(*s.as<AssertStmt>()); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (s.as<Allocate>()) { - auto n = std::make_shared<Allocate>(*s.as<Allocate>()); + auto n = make_node<Allocate>(*s.as<Allocate>()); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index b38051326d1d1dff27b4d8057a954a091bc21fc7..1a9caf4b591ec2dff7368b3d06204cf0dabc4f54 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -104,7 +104,7 @@ class IntrinInjecter : public IRMutator { LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target) { - auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); n->body = IntrinInjecter(target).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 4d7f086d05346ec5ab8f34337db82cde749e76f6..2f700ed9112d44d658bbd00e621cdb0c63205d74 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -317,7 +317,7 @@ class ThreadAllreduceBuilder final : public IRMutator { LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size) { CHECK_NE(f->func_type, kHostFunc); - auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index 46686a65803aa08d151fe336290c957e4b7f1de1..cf3d9f7eeeb17307384a35d44c3370abade2ed4e 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -288,7 +288,7 @@ class BuiltinLower : public IRMutator { }; LoweredFunc LowerTVMBuiltin(LoweredFunc f) { - auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); n->body = BuiltinLower().Build(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 85ae365f2a82f94f8b946973a58bbf7609841c92..01ab2b51752e12decec415a766bd64600362a15e 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -93,7 +93,7 @@ class WarpStoreCoeffFinder : private IRVisitor { arith::DetectLinearEquation(index, {warp_index_}); CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; - int coeff; + int coeff = 0; Expr mcoeff = ir::Simplify(m[0]); CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0) @@ -317,7 +317,7 @@ class WarpMemoryRewriter : private IRMutator { LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size) { CHECK_EQ(f->func_type, kDeviceFunc); - auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body); return LoweredFunc(n); } diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 8113c58f3f789005622d2c3c832bdc1fab266344..41f92ad240850d024e726145ff240f0350ad5cdc 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -132,7 +132,7 @@ LoweredFunc MakeAPI(Stmt body, } } - std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>(); + NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>(); n->name = name; n->args = args; n->handle_data_type = binder.def_handle_dtype(); @@ -197,7 +197,7 @@ class DeviceTypeBinder: public IRMutator { LoweredFunc BindDeviceType(LoweredFunc f, int device_type) { - auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); n->body = DeviceTypeBinder(device_type).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/remap_thread_axis.cc b/src/pass/remap_thread_axis.cc index 94e4819a1d71b7c768a0f964ce89fdf89d7a4a1f..08a62b25e2c443f4fe07248acfd459c2c89d6a61 100644 --- a/src/pass/remap_thread_axis.cc +++ b/src/pass/remap_thread_axis.cc @@ -67,7 +67,7 @@ RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) { } CHECK_EQ(f->func_type, kDeviceFunc); - auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); // replace the thread axis for (size_t i = 0; i < n->thread_axis.size(); ++i) { auto it = tmap.find(n->thread_axis[i]->thread_tag); diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index c7b20e137638e0454d3f445fa7e6c3a10e68a375..112c2c173df12a42725486a56a046261be3fe261 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -165,8 +165,8 @@ class HostDeviceSplitter : public IRMutator { handle_data_type_[kv.first.get()] = kv.second; } name_ = f->name; - std::shared_ptr<LoweredFuncNode> n = - std::make_shared<LoweredFuncNode>(*f.operator->()); + NodePtr<LoweredFuncNode> n = + make_node<LoweredFuncNode>(*f.operator->()); n->body = this->Mutate(f->body); n->func_type = kHostFunc; Array<LoweredFunc> ret{LoweredFunc(n)}; @@ -180,7 +180,7 @@ class HostDeviceSplitter : public IRMutator { Stmt SplitDeviceFunc(Stmt body) { std::ostringstream os; os << name_ << "_kernel" << device_funcs_.size(); - std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>(); + NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>(); // isolate the device function. IRUseDefAnalysis m; m.visit_thread_extent_ = false; diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 58b62f291d397bc65abaa7412cb0dfd0c9e49f4d..2bab21d857376fefc080662a69bae5b74c942529 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -950,8 +950,7 @@ class VectorAllocRewriter : public IRMutator { LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { - std::shared_ptr<LoweredFuncNode> n = - std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); VectorAllocRewriter rewriter; n->body = rewriter.Mutate(n->body); for (Var arg : f->args) { diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 43f3b94d114f0515dcf60ac20c0a26ce62914973..6f7fc886fd8cbfb1a42e2ab4429743d397f417a1 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -329,7 +329,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { CHECK_NE(f->func_type, kHostFunc); - auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); + auto n = make_node<LoweredFuncNode>(*f.operator->()); n->body = ThreadSync(f->body, storage_scope); return LoweredFunc(n); } diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 7e7fb71f6d6c1fa374b5c15419c2af7a52edf08b..97ac9e52a4c2d2624b1868d7f76cd23be06a612d 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -12,50 +12,39 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; -SourceName SourceNameNode::make(std::string name) { - std::shared_ptr<SourceNameNode> n = std::make_shared<SourceNameNode>(); - n->name = std::move(name); - return SourceName(n); -} - -std::shared_ptr<SourceNameNode> CreateSourceName(const std::string& name) { - SourceName sn = SourceName::Get(name); - CHECK(!sn.defined()) << "Cannot find source name \'" << name << '\''; - std::shared_ptr<Node> node = sn.node_; - return std::dynamic_pointer_cast<SourceNameNode>(node); -} - -const SourceName& SourceName::Get(const std::string& name) { - static std::unordered_map<std::string, SourceName> source_map; +NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) { + // always return pointer as the reference can change as map re-allocate. + // or use another level of indirection by creating a unique_ptr + static std::unordered_map<std::string, NodePtr<SourceNameNode> > source_map; auto sn = source_map.find(name); if (sn == source_map.end()) { - auto source_name = SourceNameNode::make(name); - source_map.insert({name, source_name}); - return source_map.at(name); + NodePtr<SourceNameNode> n = make_node<SourceNameNode>(); + n->name = std::move(name); + source_map[name] = n; + return n; } else { return sn->second; } } -TVM_REGISTER_API("relay._make.SourceName") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { - *ret = SourceNameNode::make(args[0]); - }); +SourceName SourceName::Get(const std::string& name) { + return SourceName(GetSourceNameNode(name)); +} TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) { - p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; - }); +.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) { + p->stream << "SourceName(" << node->name << ", " << node << ")"; + }); TVM_REGISTER_NODE_TYPE(SourceNameNode) -.set_creator(CreateSourceName) +.set_creator(GetSourceNameNode) .set_global_key([](const Node* n) { return static_cast<const SourceNameNode*>(n)->name; }); Span SpanNode::make(SourceName source, int lineno, int col_offset) { - std::shared_ptr<SpanNode> n = std::make_shared<SpanNode>(); + auto n = make_node<SpanNode>(); n->source = std::move(source); n->lineno = lineno; n->col_offset = col_offset; diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 47c9789ab5ae4081a1e6d9deab3e81a5bb99bd1a..16b0314507cfda2436931a999cb1c70ea321b56c 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -15,7 +15,7 @@ using tvm::IRPrinter; using namespace runtime; Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) { - std::shared_ptr<EnvironmentNode> n = std::make_shared<EnvironmentNode>(); + auto n = make_node<EnvironmentNode>(); n->functions = std::move(global_funcs); return Environment(n); } @@ -31,20 +31,22 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { } } -/*! \brief Add a new item to the global environment +/*! + * \brief Add a new item to the global environment * \note if the update flag is not set adding a duplicate * definition will trigger an exception, otherwise we will * update the definition if and only if it is type compatible. */ -void EnvironmentNode::Add(const GlobalVar &var, const Function &func, +void EnvironmentNode::Add(const GlobalVar &var, + const Function &func, bool update) { // Type check the item before we add it to the environment. - auto env = GetRef<Environment>(this); + auto env = relay::GetRef<Environment>(this); Expr checked_expr = InferType(env, var, func); if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) { - auto checked_func = GetRef<Function>(func_node); + auto checked_func = relay::GetRef<Function>(func_node); auto type = checked_func->checked_type(); CHECK(IsFullyResolved(type)); @@ -100,46 +102,46 @@ void EnvironmentNode::Merge(const Environment &env) { } TVM_REGISTER_API("relay._make.Environment") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = EnvironmentNode::make(args[0]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = EnvironmentNode::make(args[0]); + }); TVM_REGISTER_API("relay._env.Environment_Add") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - env->Add(args[1], args[2], false); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Add(args[1], args[2], false); + }); TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - *ret = env->GetGlobalVar(args[1]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + *ret = env->GetGlobalVar(args[1]); + }); TVM_REGISTER_API("relay._env.Environment_Lookup") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - GlobalVar var = args[1]; - *ret = env->Lookup(var); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + GlobalVar var = args[1]; + *ret = env->Lookup(var); + }); TVM_REGISTER_API("relay._env.Environment_Lookup_str") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - std::string var_name = args[1]; - auto var = env->GetGlobalVar(var_name); - *ret = env->Lookup(var); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string var_name = args[1]; + auto var = env->GetGlobalVar(var_name); + *ret = env->Lookup(var); + }); TVM_REGISTER_API("relay._env.Environment_Merge") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - env->Merge(args[1]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Merge(args[1]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch<EnvironmentNode>([](const EnvironmentNode *node, - tvm::IRPrinter *p) { +.set_dispatch<EnvironmentNode>( + [](const EnvironmentNode *node, tvm::IRPrinter *p) { p->stream << "EnvironmentNode( " << node->functions << ")"; }); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index f4363f5312c4df465c90fb922f3cd726d4b5b642..241ccc0b85c38173e7f3af4118a756f1934dcf8c 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -3,7 +3,6 @@ * \file src/tvm/ir/expr.cc * \brief The expression AST nodes of Relay. */ -#include <tvm/ir_functor.h> #include <tvm/relay/expr.h> namespace tvm { @@ -13,21 +12,20 @@ using tvm::IRPrinter; using namespace tvm::runtime; Constant ConstantNode::make(runtime::NDArray data) { - std::shared_ptr<ConstantNode> n = std::make_shared<ConstantNode>(); + NodePtr<ConstantNode> n = make_node<ConstantNode>(); n->data = std::move(data); return Constant(n); } TVM_REGISTER_API("relay._make.Constant") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ConstantNode::make(args[0]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ConstantNode::make(args[0]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch<ConstantNode>([](const ConstantNode *node, - tvm::IRPrinter *p) { - p->stream << "ConstantNode(TODO)"; - }); +.set_dispatch<ConstantNode>([](const ConstantNode *node, tvm::IRPrinter *p) { + p->stream << "Constant(TODO)"; + }); TensorType ConstantNode::tensor_type() const { auto dtype = TVMType2Type(data->dtype); @@ -41,57 +39,55 @@ TensorType ConstantNode::tensor_type() const { } Tuple TupleNode::make(tvm::Array<relay::Expr> fields) { - std::shared_ptr<TupleNode> n = std::make_shared<TupleNode>(); + NodePtr<TupleNode> n = make_node<TupleNode>(); n->fields = std::move(fields); return Tuple(n); } TVM_REGISTER_API("relay._make.Tuple") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TupleNode::make(args[0]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleNode::make(args[0]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) { - p->stream << "TupleNode(" << node->fields << ")"; - }); +.set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) { + p->stream << "Tuple(" << node->fields << ")"; + }); Var VarNode::make(std::string name_hint) { - std::shared_ptr<VarNode> n = std::make_shared<VarNode>(); + NodePtr<VarNode> n = make_node<VarNode>(); n->name_hint = std::move(name_hint); return Var(n); } TVM_REGISTER_API("relay._make.Var") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = VarNode::make(args[0]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = VarNode::make(args[0]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch<VarNode>([](const VarNode *node, - tvm::IRPrinter *p) { - p->stream << "VarNode(" << node->name_hint << ")"; - }); +.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) { + p->stream << "Var(" << node->name_hint << ")"; + }); GlobalVar GlobalVarNode::make(std::string name_hint) { - std::shared_ptr<GlobalVarNode> n = std::make_shared<GlobalVarNode>(); + NodePtr<GlobalVarNode> n = make_node<GlobalVarNode>(); n->name_hint = std::move(name_hint); return GlobalVar(n); } TVM_REGISTER_API("relay._make.GlobalVar") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = GlobalVarNode::make(args[0]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = GlobalVarNode::make(args[0]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, - tvm::IRPrinter *p) { - p->stream << "GlobalVarNode(" << node->name_hint << ")"; - }); +.set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, tvm::IRPrinter *p) { + p->stream << "GlobalVar(" << node->name_hint << ")"; + }); Param ParamNode::make(Var var, Type type) { - std::shared_ptr<ParamNode> n = std::make_shared<ParamNode>(); + NodePtr<ParamNode> n = make_node<ParamNode>(); n->var = std::move(var); n->type = std::move(type); return Param(n); @@ -104,12 +100,12 @@ TVM_REGISTER_API("relay._make.Param") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) { - p->stream << "ParamNode(" << node->var << ", " << node->type << ")"; + p->stream << "Param(" << node->var << ", " << node->type << ")"; }); Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body, tvm::Array<TypeParam> type_params) { - std::shared_ptr<FunctionNode> n = std::make_shared<FunctionNode>(); + NodePtr<FunctionNode> n = make_node<FunctionNode>(); n->params = std::move(params); n->ret_type = std::move(ret_type); n->body = std::move(body); @@ -140,7 +136,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args) { - std::shared_ptr<CallNode> n = std::make_shared<CallNode>(); + NodePtr<CallNode> n = make_node<CallNode>(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); @@ -160,7 +156,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); Let LetNode::make(Var var, Expr value, Expr body, Type value_type) { - std::shared_ptr<LetNode> n = std::make_shared<LetNode>(); + NodePtr<LetNode> n = make_node<LetNode>(); n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); @@ -180,7 +176,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { - std::shared_ptr<IfNode> n = std::make_shared<IfNode>(); + NodePtr<IfNode> n = make_node<IfNode>(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index d1a9dd072d313af890dcf7cad5d1bfd8048f322f..4826aed54ba574aeec20f2582d6a4c8bf1d13ae8 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -51,7 +51,7 @@ const Op& Op::Get(const std::string& name) { OpRegistry::OpRegistry() { OpManager* mgr = OpManager::Global(); - std::shared_ptr<OpNode> n = std::make_shared<OpNode>(); + NodePtr<OpNode> n = make_node<OpNode>(); n->index_ = mgr->op_counter++; op_ = Op(n); } @@ -90,14 +90,14 @@ void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, // Frontend APIs TVM_REGISTER_API("relay.op._ListOpNames") - .set_body_typed<Array<tvm::Expr>()>([]() { - Array<tvm::Expr> ret; - for (const std::string& name : - dmlc::Registry<OpRegistry>::ListAllNames()) { - ret.push_back(tvm::Expr(name)); - } - return ret; - }); +.set_body_typed<Array<tvm::Expr>()>([]() { + Array<tvm::Expr> ret; + for (const std::string& name : + dmlc::Registry<OpRegistry>::ListAllNames()) { + ret.push_back(tvm::Expr(name)); + } + return ret; + }); TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get); @@ -138,11 +138,10 @@ TVM_REGISTER_API("relay.op._Register") } }); -std::shared_ptr<OpNode> CreateOp(const std::string& name) { +NodePtr<Node> CreateOp(const std::string& name) { auto op = Op::Get(name); CHECK(!op.defined()) << "Cannot find op \'" << name << '\''; - std::shared_ptr<Node> node = op.node_; - return std::dynamic_pointer_cast<OpNode>(node); + return op.node_; } TVM_REGISTER_NODE_TYPE(OpNode) diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index c13fea26dacd2767622f9fa6cb84426d4eb0df44..fce01390fa948f48679c37f1b9c492ce1b79dbca 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -3,7 +3,6 @@ * \file src/tvm/ir/type.cc * \brief The type system AST nodes of Relay. */ -#include <tvm/ir_functor.h> #include <tvm/relay/type.h> namespace tvm { @@ -13,7 +12,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; TensorType TensorTypeNode::make(Array<ShapeExpr> shape, DataType dtype) { - std::shared_ptr<TensorTypeNode> n = std::make_shared<TensorTypeNode>(); + NodePtr<TensorTypeNode> n = make_node<TensorTypeNode>(); n->shape = std::move(shape); n->dtype = std::move(dtype); return TensorType(n); @@ -36,7 +35,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { - std::shared_ptr<TypeParamNode> n = std::make_shared<TypeParamNode>(); + NodePtr<TypeParamNode> n = make_node<TypeParamNode>(); n->var = tvm::Var(name); n->kind = std::move(kind); return TypeParam(n); @@ -59,7 +58,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, Type ret_type, tvm::Array<TypeParam> type_params, tvm::Array<TypeConstraint> type_constraints) { - std::shared_ptr<FuncTypeNode> n = std::make_shared<FuncTypeNode>(); + NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); @@ -81,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array<Type> args) { - std::shared_ptr<TypeRelationNode> n = std::make_shared<TypeRelationNode>(); + NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>(); n->name = std::move(name); n->func_ = std::move(func); n->args = std::move(args); @@ -101,7 +100,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TupleType TupleTypeNode::make(Array<Type> fields) { - std::shared_ptr<TupleTypeNode> n = std::make_shared<TupleTypeNode>(); + NodePtr<TupleTypeNode> n = make_node<TupleTypeNode>(); n->fields = std::move(fields); return TupleType(n); } diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 522eb93483fb817f6b53da5d6a4e7f880f6a9d25..91d2d582211084541e1cf8538f3b7360c96c9ef6 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -10,10 +10,9 @@ * * For example tensors are not allowed to contain functions in Relay. * - * We check this by ensuring the `dtype` field of a Tensor always + * We check this by ensuring the `dtype` field of a Tensor always * contains a data type such as `int`, `float`, `uint`. */ -#include <tvm/ir_functor.h> #include <tvm/relay/pass.h> #include "./type_visitor.h" diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h index 339552108af4864a94333fbd4cd841d3c7d5e65d..cccde62625ea2de7b24163eabd48889b800aaff8 100644 --- a/src/relay/pass/type_functor.h +++ b/src/relay/pass/type_functor.h @@ -6,7 +6,7 @@ #ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #define TVM_RELAY_PASS_TYPE_FUNCTOR_H_ -#include <tvm/ir_functor.h> +#include <tvm/node/ir_functor.h> #include <tvm/relay/expr.h> #include "./incomplete_type.h" diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index f4f6d82eb5e1e042ec20a36b45bfaeb18c797b20..deed982acbc61f52b0d948e9b9309cd96c1e0eb0 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -137,7 +137,7 @@ class TypeInferencer : private ExprFunctor<CheckedExpr(const Expr&)> { void Solve(TypeRelationData& ty_rel); /*! \brief Attempt to solve all pending relations. - * + * * If the solver */ SolverResult Solve(std::vector<TypeRelationData>& rels); @@ -607,8 +607,7 @@ TVM_REGISTER_API("relay._ir_pass._get_checked_type") /* Incomplete Type */ IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { - std::shared_ptr<IncompleteTypeNode> n = - std::make_shared<IncompleteTypeNode>(); + auto n = make_node<IncompleteTypeNode>(); n->kind = std::move(kind); return IncompleteType(n); } diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index b0ed71d1791139df6ea1d7529d413e3236addbf6..67cc58ffc0a32feb8efe46000e8a75e96a7a2bbf 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -21,7 +21,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; UnionFind UnionFindNode::make(tvm::Map<IncompleteType, Type> uf_map) { - std::shared_ptr<UnionFindNode> n = std::make_shared<UnionFindNode>(); + auto n = make_node<UnionFindNode>(); n->uf_map = uf_map; return UnionFind(n); } @@ -130,7 +130,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TypeUnifier TypeUnifierNode::make(UnionFind union_find) { - std::shared_ptr<TypeUnifierNode> n = std::make_shared<TypeUnifierNode>(); + auto n = make_node<TypeUnifierNode>(); n->union_find = union_find; return TypeUnifier(n); } diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index 4e939cc26bcae16306726fc0f996bc0bc5d1c1fe..feda644cdd1da56038c2743b59c60dd1f41fec76 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -67,7 +67,7 @@ class UnionFindNode : public Node { class UnionFind : public NodeRef { public: UnionFind() {} - explicit UnionFind(std::shared_ptr<tvm::Node> p) : NodeRef(p) {} + explicit UnionFind(NodePtr<tvm::Node> p) : NodeRef(p) {} // The union find structure is mutable so we do not use the standard macros // and expose the pointer via `->`. @@ -126,7 +126,7 @@ class TypeUnifierNode : public Node, class TypeUnifier : public NodeRef { public: TypeUnifier() {} - explicit TypeUnifier(std::shared_ptr<tvm::Node> p) : NodeRef(p) {} + explicit TypeUnifier(NodePtr<tvm::Node> p) : NodeRef(p) {} // no const so that unifier can be mutable as a member of typechecker inline TypeUnifierNode* operator->() const { diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index fa26aea51a2bc150bfc8c569a89d57c84c861b51..8591c77bd7ccca71826394a010f72617c0946528 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -46,7 +46,7 @@ Expr InjectPredicate(const Array<Expr>& predicates, if (predicates.size() == 0) return body; const Reduce* reduce = body.as<Reduce>(); if (reduce) { - std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce); + auto n = make_node<Reduce>(*reduce); n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr()); return Expr(n); } @@ -400,7 +400,7 @@ void InjectInline(ScheduleNode* sch) { CHECK_EQ(new_body[j].size(), r->source.size()); CHECK(r != nullptr); for (size_t k = 0; k < new_body[j].size(); ++k) { - std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r); + auto n = make_node<ir::Reduce>(*r); n->value_index = static_cast<int>(k); n->type = r->source[k].type(); new_body[j].Set(k, Expr(n)); @@ -520,11 +520,11 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const int factor_axis_pos = \ factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis; CHECK_LE(factor_axis_pos, compute_op->axis.size()); - auto n = std::make_shared<ComputeOpNode>(); + auto n = make_node<ComputeOpNode>(); n->name = compute_op->name + ".rf"; { // axis relacement. - auto iv_node = std::make_shared<IterVarNode>(); + auto iv_node = make_node<IterVarNode>(); iv_node->dom = dom_map.at(axis); CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0"; @@ -565,7 +565,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, for (IterVar iv : reduce_stage->leaf_iter_vars) { if (touch_map.count(iv) && !iv.same_as(axis)) { CHECK_EQ(iv->iter_type, kCommReduce); - auto ncpy = std::make_shared<IterVarNode>(*iv.operator->()); + auto ncpy = make_node<IterVarNode>(*iv.operator->()); ncpy->dom = dom_map.at(iv); n->reduce_axis.push_back(IterVar(ncpy)); } diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 1490c85ff7865c43f7676118b8da4d8435096f32..d503e978887e9991debf7a61f376ce79a8ebee9a 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -70,7 +70,7 @@ void Split(StageNode* self, } // namespace Stage::Stage(Operation op) { - auto n = std::make_shared<StageNode>(); + auto n = make_node<StageNode>(); n->op = op; n->origin_op = op; n->all_iter_vars = op->root_iter_vars(); @@ -164,16 +164,16 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) FindLeafVar(all_vars, leaf_vars, ivar); auto it = self->iter_var_attrs.find(ivar); - std::shared_ptr<IterVarAttrNode> n; + NodePtr<IterVarAttrNode> n; if (it != self->iter_var_attrs.end()) { - n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->()); + n = make_node<IterVarAttrNode>(*(*it).second.operator->()); if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) { LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread; } } else { - n = std::make_shared<IterVarAttrNode>(); + n = make_node<IterVarAttrNode>(); } n->bind_thread = thread_ivar; self->iter_var_attrs.Set(ivar, IterVarAttr(n)); @@ -188,7 +188,7 @@ Stage& Stage::env_threads(Array<IterVar> threads) { << "Already set env_threads"; ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); - std::vector<std::shared_ptr<Node> > temp; + std::vector<NodePtr<Node> > temp; for (IterVar iv : threads) { temp.push_back(iv.node_); } @@ -303,7 +303,7 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*) for (size_t i = 0; i < order.size(); ++i) { pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i])); } - std::vector<std::shared_ptr<Node> > temp; + std::vector<NodePtr<Node> > temp; for (size_t i = 0; i < pos.size(); ++i) { temp.emplace_back(leaf_vars->data[pos[i]]); } @@ -335,11 +335,11 @@ inline void UpdateIterVarAttr(StageNode* self, FindLeafVar(all_vars, leaf_vars, var); } auto it = self->iter_var_attrs.find(var); - std::shared_ptr<IterVarAttrNode> n; + NodePtr<IterVarAttrNode> n; if (it != self->iter_var_attrs.end()) { - n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->()); + n = make_node<IterVarAttrNode>(*(*it).second.operator->()); } else { - n = std::make_shared<IterVarAttrNode>(); + n = make_node<IterVarAttrNode>(); } fupdate(n.get()); self->iter_var_attrs.Set(var, IterVarAttr(n)); @@ -397,11 +397,11 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) { ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); FindLeafVar(all_vars, leaf_vars, var); auto it = self->iter_var_attrs.find(var); - std::shared_ptr<IterVarAttrNode> n; + NodePtr<IterVarAttrNode> n; if (it != self->iter_var_attrs.end()) { - n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->()); + n = make_node<IterVarAttrNode>(*(*it).second.operator->()); } else { - n = std::make_shared<IterVarAttrNode>(); + n = make_node<IterVarAttrNode>(); } n->prefetch_data.push_back(tensor); n->prefetch_offset.push_back(offset); @@ -468,8 +468,8 @@ Stage& Stage::opengl() { } Stage CopyStage(const Stage& s) { - std::shared_ptr<StageNode> n = - std::make_shared<StageNode>(*s.operator->()); + NodePtr<StageNode> n = + make_node<StageNode>(*s.operator->()); return Stage(n); } @@ -477,7 +477,7 @@ Schedule Schedule::copy() const { // map of stages. const ScheduleNode* self = operator->(); std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap; - std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>(); + NodePtr<ScheduleNode> n = make_node<ScheduleNode>(); n->outputs = self->outputs; // Copy the stages. for (Stage s : self->stages) { @@ -599,7 +599,7 @@ Stage Schedule::create_group(const Array<Tensor>& outputs, } } // Create the new group stage. - Stage gstage(std::make_shared<StageNode>()); + Stage gstage(make_node<StageNode>()); gstage->group = parent_group; if (parent_group.defined()) { ++parent_group->num_child_stages; @@ -687,7 +687,7 @@ void ScheduleNode::InitCache() { } Schedule ScheduleNode::make(Array<Operation> ops) { - auto n = std::make_shared<ScheduleNode>(); + auto n = make_node<ScheduleNode>(); Schedule sch(n); n->outputs = ops; auto g = schedule::CreateReadGraph(n->outputs); @@ -731,7 +731,7 @@ IterVarRelation SplitNode::make(IterVar parent, IterVar inner, Expr factor, Expr nparts) { - auto n = std::make_shared<SplitNode>(); + auto n = make_node<SplitNode>(); n->parent = parent; n->outer = outer; n->inner = inner; @@ -742,7 +742,7 @@ IterVarRelation SplitNode::make(IterVar parent, IterVarRelation FuseNode::make( IterVar outer, IterVar inner, IterVar fused) { - auto n = std::make_shared<FuseNode>(); + auto n = make_node<FuseNode>(); n->outer = outer; n->inner = inner; n->fused = fused; @@ -750,14 +750,14 @@ IterVarRelation FuseNode::make( } IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { - auto n = std::make_shared<RebaseNode>(); + auto n = make_node<RebaseNode>(); n->parent = parent; n->rebased = rebased; return IterVarRelation(n); } IterVarRelation SingletonNode::make(IterVar iter) { - auto n = std::make_shared<SingletonNode>(); + auto n = make_node<SingletonNode>(); n->iter = iter; return IterVarRelation(n); } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index f87924d846198fada676708decd662cd035841e5..db140f240344f46baa20d8d11fb191d889b1cf90 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -1,7 +1,7 @@ #include <dmlc/logging.h> #include <gtest/gtest.h> #include <tvm/tvm.h> -#include <tvm/ir_functor.h> +#include <tvm/node/ir_functor.h> #include <tvm/ir_functor_ext.h> TEST(IRF, Basic) { diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 7dcd5c921905cf28e39722bb6fd9d13fde05bbec..818376717176beb2b90e640219443e6df874fbad 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -10,6 +10,7 @@ make cython3 || exit -1 # Test extern package package cd apps/extension +rm -rf lib make || exit -1 cd ../.. python -m nose -v apps/extension/tests || exit -1