diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 33d84cecec6aeb6f55ae4c0e5d723bb2343a41cf..51d916ca488d41dd6eb156fe3031a383fc87def0 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -35,6 +35,7 @@ #include <string> #include "ir.h" #include "base.h" +#include "expr.h" #include "packed_func_ext.h" namespace tvm { @@ -73,7 +74,6 @@ inline Type NullValue<Type>() { return Type(Type::Handle, 0, 0); } - /*! \brief Error thrown during attribute checking. */ struct AttrError : public dmlc::Error { /*! diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 7fdca7f6af8e05554315e46eba9506a771a2ad7a..37b122ae5b03403422bac013449e5e7ac26ea66a 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -29,6 +29,7 @@ using HalideIR::VarExpr; using HalideIR::IR::RangeNode; using HalideIR::IR::FunctionRef; using HalideIR::IR::FunctionBaseNode; +using HalideIR::Internal::IntImm; using HalideIR::Internal::Stmt; using HalideIR::Internal::IRPrinter; using HalideIR::Internal::Variable; @@ -83,6 +84,51 @@ class Var : public HalideIR::VarExpr { }; +/*! + * \brief Container of constant ineteger (IntImm). + * + * This is used to store and automate type check + * attributes that must be constant integer. + */ +class Integer : public Expr { + public: + Integer() : Expr() {} + /*! + * \brief constructor from node. + */ + explicit Integer(NodePtr<Node> node) : Expr(node) {} + /*! + * \brief Construct integer from int value. + */ + Integer(int value) : Expr(value) {} // NOLINT(*) + /*! + * \brief Assign an expression to integer. + * \param other another expression. + */ + Integer& operator=(const Integer& other) { + node_ = other.node_; + return *this; + } + /*! + * \brief Get pointer to the internal value. + * \return the content of the integer. + */ + const IntImm* operator->() const { + return static_cast<const IntImm*>(node_.get()); + } + /*! + * \brief convert to int64_t + */ + operator int64_t() const { + CHECK(node_ != nullptr) + << " Trying get reference a null Integer"; + return (*this)->value; + } + /*! \brief type indicate the container type */ + using ContainerType = IntImm; +}; + + /*! \brief container class of iteration variable. */ class IterVarNode; diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 0491f3057815f774e97979e1d3f15059229fb1d7..c5a83608c61740d1fbadefef97ed8421957a0002 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -10,6 +10,7 @@ #include <sstream> #include <string> #include <memory> +#include <limits> #include <type_traits> #include "base.h" @@ -126,6 +127,8 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { inline TVMArgValue::operator HalideIR::Expr() const { if (type_code_ == kNull) return Expr(); if (type_code_ == kDLInt) { + CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); + CHECK_GE(value_.v_int64, std::numeric_limits<int>::min()); return Expr(static_cast<int>(value_.v_int64)); } if (type_code_ == kDLFloat) { @@ -145,6 +148,20 @@ inline TVMArgValue::operator HalideIR::Expr() const { return Expr(sptr); } +inline TVMArgValue::operator tvm::Integer() const { + if (type_code_ == kNull) return Integer(); + if (type_code_ == kDLInt) { + CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); + CHECK_GE(value_.v_int64, std::numeric_limits<int>::min()); + return Integer(static_cast<int>(value_.v_int64)); + } + NodePtr<Node>& sptr = *ptr<NodePtr<Node> >(); + CHECK(NodeTypeChecker<Integer>::Check(sptr.get())) + << "Expected type " << NodeTypeName<Expr>() + << " but get " << sptr->type_key(); + return Integer(sptr); +} + inline NodePtr<Node>& TVMArgValue::node_sptr() { TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); return *ptr<NodePtr<Node> >(); diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index eb044ccb29fd7cbbb452a398c435b8e75108c25f..34bd5eb93312431d719d0861a52ff306d57e98da 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -317,7 +317,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> { /*! \brief Attributes for LRN operator */ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> { IndexExpr size; - IndexExpr axis; + int axis; double bias; double alpha; double beta; @@ -340,7 +340,7 @@ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> { /*! \brief Attributes for L2Normalize operator */ struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> { double eps; - Array<IndexExpr> axis; + Array<Integer> axis; TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") { TVM_ATTR_FIELD(eps) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 1941e045ed8d46d43b6c8107411485d03c3c6635..b0150c4ac3d9d824f2699b608c45974841796d0d 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -53,7 +53,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> { /*! \brief Attributes used in transpose operators */ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> { - Array<IndexExpr> axes; + Array<Integer> axes; TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") { TVM_ATTR_FIELD(axes) .describe("The target axes order, reverse order if not specified."); @@ -70,10 +70,10 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> { }; // struct ReshapeAttrs struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> { - IndexExpr axis; + Integer axis; TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>()) + TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>()) .describe("The axis over which to select values."); } }; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index a8fa096e51c4df8e3d2ee75bab109c62f5d1d664..c306f8d15160af6f0e5668bd64dca8685fcc17dd 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -32,6 +32,9 @@ struct Expr; #endif namespace tvm { +// forward declarations +class Integer; + namespace runtime { // forward declarations class TVMArgs; @@ -559,6 +562,7 @@ class TVMArgValue : public TVMPODValue_ { inline bool IsNodeType() const; inline operator HalideIR::Type() const; inline operator HalideIR::Expr() const; + inline operator tvm::Integer() const; // get internal node ptr, if it is node inline NodePtr<Node>& node_sptr(); }; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 8459a99cde23639a94deeb9b62dbb56f1bc5aba8..d38c5a0ebe0d0beeb288983e93e5fdb2ab157677 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -317,7 +317,7 @@ TVM_REGISTER_NODE_TYPE(LRNAttrs); Expr MakeLRN(Expr data, IndexExpr size, - IndexExpr axis, + int axis, double alpha, double beta, double bias) { @@ -337,7 +337,7 @@ TVM_REGISTER_API("relay.op.nn._make.lrn") }); RELAY_REGISTER_OP("nn.lrn") - .describe(R"code(LRN layer. +.describe(R"code(LRN layer. Normalize the input in a local region across or within feature maps. Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta, @@ -362,7 +362,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); Expr MakeL2Normalize(Expr data, double eps, - Array<IndexExpr> axis) { + Array<Integer> axis) { auto attrs = make_node<L2NormalizeAttrs>(); attrs->eps = eps; attrs->axis = std::move(axis); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index bab875fd190ec6b9d766801a6c8bccd7fc6cc013..29dff1e4ba27251cd3334d65a1970d761bf2bdf4 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -218,24 +218,23 @@ bool TransposeRel(const Array<Type>& types, } const auto* param = attrs.as<TransposeAttrs>(); const int ndim = data->shape.size(); - const Array<IndexExpr>& axes = param->axes; + const Array<Integer>& axes = param->axes; // check dimension match - CHECK(axes.empty() || static_cast<int>(axes.size()) == ndim) + CHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim) << "Dimension mismatch: axes has " << axes.size() << " elements" << ", but data.ndim = " << ndim; // construct int_axes std::vector<int> int_axes; int_axes.reserve(ndim); - if (axes.empty()) { + // used not defined to check if it is None. + if (!axes.defined()) { for (int i = ndim - 1; i >= 0; --i) { int_axes.push_back(i); } } else { std::vector<int> axis_used(ndim, 0); - for (const IndexExpr& e : axes) { - const int64_t *axis_ptr = as_const_int(e); - CHECK(axis_ptr != nullptr); - int axis = *axis_ptr; + for (const Integer& e : axes) { + int64_t axis = e; // sanity check for axis and ndim CHECK(-ndim <= axis && axis < ndim) << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" @@ -245,7 +244,7 @@ bool TransposeRel(const Array<Type>& types, // sanity check for duplication CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; axis_used[axis] = 1; - int_axes.push_back(axis); + int_axes.push_back(static_cast<int>(axis)); } } std::vector<IndexExpr> oshape; @@ -258,7 +257,7 @@ bool TransposeRel(const Array<Type>& types, } Expr MakeTranspose(Expr data, - Array<IndexExpr> axes) { + Array<Integer> axes) { auto attrs = make_node<TransposeAttrs>(); attrs->axes = std::move(axes); static const Op& op = Op::Get("transpose"); @@ -401,7 +400,7 @@ bool TakeRel(const Array<Type>& types, std::vector<IndexExpr> oshape; const auto ndim_data = static_cast<int>(data->shape.size()); const auto ndim_indices = static_cast<int>(indices->shape.size()); - auto axis = (*as_const_int(param->axis)); + int axis = static_cast<int>(param->axis->value); if (axis < 0) axis += ndim_data; CHECK_LE(axis, ndim_data) << "axis should be with in data shape" @@ -424,9 +423,9 @@ bool TakeRel(const Array<Type>& types, Expr MakeTake(Expr data, Expr indices, - IndexExpr axis) { + Integer axis) { auto attrs = make_node<TakeAttrs>(); - attrs->axis = axis; + attrs->axis = std::move(axis); static const Op& op = Op::Get("take"); return CallNode::make(op, {data, indices}, Attrs(attrs), {}); }