Skip to content
Snippets Groups Projects
Commit 9f8fcfc9 authored by Yizhi Liu's avatar Yizhi Liu Committed by Tianqi Chen
Browse files

General Layout Support (#447)

parent fc7e9cd2
No related branches found
No related tags found
No related merge requests found
Showing
with 857 additions and 257 deletions
/*!
* Copyright (c) 2017 by Contributors
* \file contrib_op_param.h
* \brief Additional parameters for compiler optimized operators.
*/
#ifndef NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#define NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#include <dmlc/parameter.h>
#include <string>
namespace nnvm {
namespace compiler {
/*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
std::string src_layout;
std::string dst_layout;
DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout);
DMLC_DECLARE_FIELD(dst_layout);
}
};
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_CONTRIB_OP_PARAM_H_
......@@ -16,6 +16,7 @@
#include <nnvm/graph.h>
#include <vector>
#include <string>
#include "packed_func_ext.h"
namespace nnvm {
namespace compiler {
......@@ -73,19 +74,17 @@ using FTVMSchedule = std::function<
const Array<Tensor>& outs,
const std::string& target)>;
/*! \brief Layout Information about an entry */
using TLayoutInfo = std::string;
/*!
* \brief The producer consumer function of node layout
* \param attrs The attribute of the node.
* \param ilayouts The input layouts that the node request.
* \param olayouts The output layouts that the node produce.
* \return bool The success flag.
* \brief Modify the op node to alter its input layout.
* it is invoked in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos The inferred shape and dtype of the inputs.
*/
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs,
std::vector<TLayoutInfo> *ilayouts,
std::vector<TLayoutInfo> *olayouts)>;
using FTVMAlterOpLayout = std::function<
Symbol(const NodeAttrs& attrs,
const Symbol& inputs,
const Array<Tensor>& tinfos)>;
/*!
* \brief Transform from normal operator to vectorized operator
......
......@@ -11,6 +11,7 @@
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <string>
#include <vector>
#include <unordered_map>
namespace nnvm {
......@@ -52,6 +53,7 @@ template<>
struct extension_class_info<nnvm::compiler::AttrDict> {
static const int code = 18;
};
} // namespace runtime
} // namespace tvm
#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_
......@@ -9,6 +9,7 @@
#include <vector>
#include <string>
#include "./tuple.h"
#include "./layout.h"
namespace nnvm {
......@@ -46,7 +47,7 @@ using ShapeVector = std::vector<TShape>;
* \code
* Graph g = ApplyPass(src_graph, "InferType");
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
* // get shape by entry id
* // get type by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
......@@ -54,6 +55,21 @@ using ShapeVector = std::vector<TShape>;
*/
using DTypeVector = std::vector<int>;
/*!
* \brief The result holder of layout of each NodeEntry in the graph.
* \note Stored under graph.attrs["layout"], provided by Pass "InferType"
*
* \code
* Graph g = ApplyPass(src_graph, "LayoutTransform");
* const LayoutVector& layouts = g.GetAttr<LayoutVector>("layout");
* // get layout by entry id
* int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferLayout
*/
using LayoutVector = std::vector<Layout>;
/*!
* \brief The result holder of device of each operator in the graph.
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
......
/*!
* Copyright (c) 2018 by Contributors
* \file layout.h
* \brief Layout expression.
* The layout is composed of upper cases, lower cases and numbers,
* where upper case indicates a (super-)dimension and
* the corresponding lower case with factor size indicates the split (sub-)dimension.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
*/
#ifndef NNVM_LAYOUT_H_
#define NNVM_LAYOUT_H_
#include <dmlc/parameter.h>
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
namespace nnvm {
class Layout {
public:
using LayoutDim = char;
/*! \brief default constructor */
Layout() : name_("__undef__") {} // NOLINT(*)
/*!
* \brief construct from a string.
* \param layout input in layout convention:
* upper case indicates a dimension and
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
inline Layout(const std::string& layout) { // NOLINT(*)
parse(layout);
}
/*!
* \brief copy constructor from another layout
* \param s the source layout
*/
inline Layout(const Layout& s) { // NOLINT(*)
this->parse(s.name_);
}
/*!
* \brief move constructor from Layout
* \param src the source layout
*/
inline Layout(Layout&& src) { // NOLINT(*)
this->swap(src);
}
/*!
* \brief assignment from another layout.
* \param src source layout
* \return reference of self
*/
inline Layout& operator=(const Layout& src) {
this->parse(src.name_);
return *this;
}
/*!
* \brief assignment from rvalue of another layout.
* \param src source layout
* \return reference of self
*/
inline Layout& operator=(Layout&& src) {
Layout(std::move(src)).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief assignment from string.
* \param src source layout
* \return reference of self
*/
inline Layout& operator=(const std::string& src) {
this->parse(src);
return *this;
}
/*!
* \return whether two layout equals
* \param s the layout to compare against
*/
inline bool operator==(const Layout& s) const {
return name_ == s.name_;
}
/*!
* \return whether two layout not equal
* \param s the layout to compare against
*/
inline bool operator!=(const Layout& s) const {
return !(*this == s);
}
/*!
* \brief Append the current layout by another.
* @param other the layout to be appended
* @return a new layout
*/
inline Layout operator+(const Layout& other) const {
if (!this->defined() && !other.defined()) {
return Layout::Undef();
} else if (!this->defined()) {
return other;
} else if (!other.defined()) {
return *this;
}
return Layout(this->name_ + other.name_);
}
/*!
* \brief Check whether a given dimension is a super-dimension.
* \param dim input dimension
* \return Whether a given dimension is a super-dimension.
*/
static inline bool is_superdim(LayoutDim dim) {
return dim >= 'A' && dim <= 'Z';
}
/*!
* \brief Check whether a given dimension is a sub-dimension.
* \param dim input dimension
* \return Whether a given dimension is a sub-dimension.
*/
static inline bool is_subdim(LayoutDim dim) {
return dim >= 'a' && dim <= 'z';
}
/*!
* \brief Convert a given dimension to super-dimension.
* \param dim input dimension
* \return The converted description.
*/
static inline LayoutDim to_superdim(LayoutDim dim) {
if (is_subdim(dim)) {
return dim - 'a' + 'A';
}
return dim;
}
/*!
* \brief Convert a given dimension to sub-dimension.
* \param dim input dimension
* \return The converted description.
*/
static inline LayoutDim to_subdim(LayoutDim dim) {
if (is_superdim(dim)) {
return dim - 'A' + 'a';
}
return dim;
}
/*!
* \brief Return an undefined layout.
* \return a (global) undefined layout.
*/
static inline const Layout& Undef() {
static Layout undef;
return undef;
}
/*!
* \brief Swap current object with other
* \param other another object to be swapped.
*/
inline void swap(Layout& other) { // NOLINT(*)
std::swap(name_, other.name_);
std::swap(superdim_pos_, other.superdim_pos_);
std::swap(subdim_pos_, other.subdim_pos_);
std::swap(subdim_size_, other.subdim_size_);
std::swap(layout_simplified_, other.layout_simplified_);
}
/*!
* \brief Two layouts are convertible only if
* they have same set of super-dimensions.
* e.g., NCHW, NCHW16c, NHWC are convertible between each other,
* but NCHW, CHW, OIHW are not.
* \param dst the target layout
* \return Whether can be converted to dst layout.
*/
inline bool convertible(const Layout &dst) const {
if (!this->defined() || !dst.defined()) return false;
for (size_t i = 0; i < kUniqueDim; ++i) {
if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
(superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
return false;
}
}
return true;
}
/*!
* \brief Returns a sublayout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param len The length of the sub-layout.
* \return A newly constructed Layout object.
*/
inline Layout sublayout(size_t pos, size_t len) const {
if (pos > ndim()) return Layout::Undef();
if (pos + len > ndim()) len = ndim() - pos;
if (len == 0) return Layout::Undef();
std::ostringstream new_layout;
for (size_t i = pos; i < pos + len; ++i) {
if (is_subdim(layout_simplified_[i])) {
auto block_size = this->subsizeof(layout_simplified_[i]);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified_[i];
}
return Layout(new_layout.str());
}
/*! \return A newly constructed reversed Layout object. */
inline Layout reverse() const {
if (!this->defined()) return Layout::Undef();
std::ostringstream new_layout;
for (int64_t i = this->ndim() - 1; i >= 0; --i) {
if (is_subdim(layout_simplified_[i])) {
auto block_size = this->subsizeof(layout_simplified_[i]);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified_[i];
}
return Layout(new_layout.str());
}
/*!
* \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
* \param dim The source dimension to be split. It must be a super-dimension.
* \param target_pos The target position of the newly split sub-dimension.
* \param size size of the sub-dimension.
* \return A newly constructed Layout object.
*/
inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
CHECK(target_pos <= this->ndim()) << "Invalid split position "
<< target_pos << " for layout " << name_;
CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim
<< " has already been split in "
<< name_;
CHECK(size > 0) << "Invalid split size " << size;
std::ostringstream new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
if (i == target_pos) {
new_layout << size << Layout::to_subdim(dim);
}
if (i == this->ndim()) break;
new_layout << this->at(i);
}
Layout x(new_layout.str());
return x;
}
using iterator = std::vector<LayoutDim>::const_iterator;
using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;
/*! \return begin iterator */
inline iterator begin() const {
return layout_simplified_.begin();
}
/*! \return end iterator */
inline iterator end() const {
return layout_simplified_.end();
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return layout_simplified_.rbegin();
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return layout_simplified_.rend();
}
/*! \return number of dimensions */
inline size_t ndim() const {
return layout_simplified_.size();
}
/*!
* \brief The description of the \p i-th dimension.
* If it is a sub-dimension, the size will be returned as well,
* e.g., 16c. Otherwise a single character is returned, e.g., C.
* \param i The position
* \return the description of the dimension.
*/
inline std::string at(size_t i) const {
CHECK_LT(i, this->ndim()) << "position " << i
<< " exceeds ndim=" << this->ndim();
std::ostringstream repr;
if (is_subdim(layout_simplified_[i])) {
auto factor = subsizeof(layout_simplified_[i]);
CHECK_LT(factor, 0);
repr << factor;
}
repr << layout_simplified_[i];
return repr.str();
}
/*!
* \brief return the index of the input dimension.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param dim the input dimension.
* \return the index or -1 if not found.
*/
inline int32_t indexof(LayoutDim dim) const {
if (!this->defined()) return -1;
else if (is_superdim(dim)) return superdim_pos_[dim - 'A'];
else if (is_subdim(dim)) return subdim_pos_[dim - 'a'];
return -1;
}
/*!
* \param dim the input super-dimension or sub-dimension.
* \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
* or the size of \p dim itself (if \p dim is a sub-dimension).
* Return -1 if \p dim is not in the layout or the layout is undefined.
*/
inline int64_t subsizeof(LayoutDim dim) const {
CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim;
if (!this->defined() || !this->contains(to_subdim(dim))) {
return -1;
}
int idx = to_subdim(dim) - 'a';
return subdim_size_[idx];
}
/*!
* \brief Whether the layout contains a dimension.
* \param dim dimension to be checked.
* \return Whether the layout contains the dimension.
*/
inline bool contains(LayoutDim dim) const {
if (is_superdim(dim)) {
return superdim_pos_[dim-'A'] >= 0;
} else if (is_subdim(dim)) {
return subdim_pos_[dim-'a'] >= 0;
}
return false;
}
inline const LayoutDim operator[](size_t i) const {
return layout_simplified_[i];
}
/*! \return whether the layout is defined */
inline bool defined() const {
return name_ != "__undef__";
}
/*! \return the string description of the layout */
inline const std::string& name() const {
return name_;
}
/*!
* \brief Write layout in JSON format.
* \param writer JSONWriter
*/
inline void Save(dmlc::JSONWriter* writer) const {
writer->Write(name_);
}
/*!
* \brief Load layout from JSON.
* \param reader JSONReader
*/
inline void Load(dmlc::JSONReader* reader) {
std::string tmp;
reader->Read(&tmp);
this->parse(tmp);
}
/*!
* \brief allow output string of layout to ostream
* \param os the output stream
* \param l the layout
* \return the ostream
*/
friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
os << l.name_;
return os;
}
private:
static const uint32_t kUniqueDim = 26;
std::string name_;
int32_t superdim_pos_[kUniqueDim];
int32_t subdim_pos_[kUniqueDim];
int64_t subdim_size_[kUniqueDim];
std::vector<LayoutDim> layout_simplified_;
void parse(const std::string& layout) {
name_ = layout;
std::fill_n(superdim_pos_, kUniqueDim, -1);
std::fill_n(subdim_pos_, kUniqueDim, -1);
std::fill_n(subdim_size_, kUniqueDim, -1);
layout_simplified_.clear();
if (layout == "__undef__") return;
int32_t factor = 0;
uint32_t curr = 0;
for (size_t i = 0; i < layout.size(); ++i) {
const LayoutDim c = layout.at(i);
if (is_superdim(c)) {
int pos = c - 'A';
CHECK_EQ(factor, 0) << "Invalid layout " << layout
<< ": invalid factor size " << factor
<< " before dimension " << c;
CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
superdim_pos_[pos] = curr++;
layout_simplified_.push_back(c);
} else if (is_subdim(c)) {
int pos = c - 'a';
CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
<< factor << " for dimension " << c;
CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
subdim_pos_[pos] = curr++;
subdim_size_[pos] = factor;
layout_simplified_.push_back(c);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << layout;
}
}
CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
for (LayoutDim dim : layout_simplified_) {
CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
<< "Invalid layout " << layout << ": missing axis "
<< static_cast<char>(dim - 'a' + 'A');
}
}
};
} // namespace nnvm
#endif // NNVM_LAYOUT_H_
......@@ -13,6 +13,7 @@
#include "./base.h"
#include "./node.h"
#include "./tuple.h"
#include "./layout.h"
namespace nnvm {
......@@ -176,6 +177,31 @@ using FSetInputVarAttrOnCompose = std::function<void(
NodePtr var,
const int index)>;
/*!
* \brief Inference function of node layout. See \p Layout for layout convention
* \param attrs The attribute of the node.
* \param ilayouts Given the input layouts produced by ancestor nodes,
* it should be filled by layouts that the node requests.
* If the requested layout is different from what ancestor produces,
* a __layout_transform__ operator will be inserted automatically.
* \param last_ilayouts The input layouts requested by the node
* at the last infer pass (if any).
* This can be useful when an operator wants to keep
* the input layout the same as the original one.
* For example, after the pass of AlterOpLayout,
* transpose(input, axis=[1, 2, 3, 0]) may receive an input of NCHW16c layout,
* with which it cannot calculate with axis=[1, 2, 3, 0].
* Last input layouts allow it to know what the layout it originally inferred,
* i.e., the layout in the imported model.
* \param olayouts Inferred output layouts.
* \return success flag.
*/
using FInferLayout = std::function<bool(
const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts)>;
} // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
......@@ -9,23 +9,12 @@
#include <dmlc/base.h>
#include <dmlc/parameter.h>
#include <nnvm/tuple.h>
#include <nnvm/layout.h>
#include <string>
namespace nnvm {
namespace top {
// Layout flag in spatial conv and pooling.
enum LayoutFlag {
kNCHW,
kNHWC,
kCHWN,
kNCW,
kNWC,
kCWN,
kNCDHW,
kNDHWC,
kCDHWN
};
struct DenseParam : public dmlc::Parameter<DenseParam> {
int units;
bool use_bias;
......@@ -130,7 +119,9 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
TShape padding;
TShape dilation;
int groups;
int layout;
std::string layout;
std::string kernel_layout;
std::string out_layout;
bool use_bias;
DMLC_DECLARE_PARAMETER(Conv2DParam) {
......@@ -152,14 +143,19 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
DMLC_DECLARE_FIELD(layout)
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
DMLC_DECLARE_FIELD(out_layout).set_default("__undef__")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
DMLC_DECLARE_FIELD(use_bias).set_default(true)
.describe("Whether the layer uses a bias vector.");
}
......@@ -178,7 +174,8 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> {
TShape output_padding;
TShape dilation;
int groups;
int layout;
std::string layout;
std::string kernel_layout;
bool use_bias;
DMLC_DECLARE_PARAMETER(Conv2DTransposeParam) {
......@@ -202,14 +199,15 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> {
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
DMLC_DECLARE_FIELD(layout)
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
DMLC_DECLARE_FIELD(use_bias).set_default(true)
.describe("Whether the layer uses a bias vector.");
}
......@@ -224,7 +222,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
TShape pool_size;
TShape strides;
TShape padding;
int layout;
std::string layout;
bool ceil_mode;
DMLC_DECLARE_PARAMETER(Pool2DParam) {
......@@ -235,10 +233,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
DMLC_DECLARE_FIELD(layout)
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
......@@ -250,13 +245,10 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
int layout;
std::string layout;
DMLC_DECLARE_PARAMETER(GlobalPool2DParam) {
DMLC_DECLARE_FIELD(layout)
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
......@@ -266,15 +258,13 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale;
int layout;
std::string layout;
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale)
.describe("upsampling scaling factor");
DMLC_DECLARE_FIELD(layout)
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
......@@ -282,6 +272,18 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
}
};
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
std::string src_layout;
std::string dst_layout;
DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout).set_default("__undef__")
.describe("Dimension ordering of data");
DMLC_DECLARE_FIELD(dst_layout).set_default("__undef__")
.describe("Dimension ordering of data.");
}
};
} // namespace top
} // namespace nnvm
......
......@@ -211,12 +211,15 @@ def _init_symbol_module(symbol_class, root_namespace):
op_names.append(py_str(plist[i]))
module_obj = sys.modules["%s.symbol" % root_namespace]
module_obj_contrib = sys.modules["%s.contrib" % root_namespace]
module_internal = sys.modules["%s._symbol_internal" % root_namespace]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = _make_atomic_symbol_function(hdl, name)
if function.__name__.startswith('_'):
if function.__name__.startswith('_contrib_'):
setattr(module_obj_contrib, function.__name__.split('_contrib_')[1], function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
setattr(module_obj, function.__name__, function)
else:
......
......@@ -15,7 +15,8 @@ OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"PrecomputePrune": 2,
"OpFusion": 1,
"FoldScaleAxis": 3
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
}
# List of optimization pass and level when switch on
......@@ -139,7 +140,7 @@ def _update_shape_dtype(shape, dtype, params):
return shape, dtype
def optimize(graph, shape, dtype="float32"):
def optimize(graph, shape, dtype="float32", layout=None):
"""Perform target and parameter invariant graph optimization.
This is an advanced function that usually do not need to be called.
......@@ -157,6 +158,18 @@ def optimize(graph, shape, dtype="float32"):
"""
# pylint: disable=unused-argument
cfg = BuildConfig.current
if cfg.pass_enabled("AlterOpLayout"):
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply(["CorrectLayout"])
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph = graph.apply(["InferShape", "InferType", "AlterOpLayout"])
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply(["CorrectLayout"])
if cfg.pass_enabled("SimplifyInference"):
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply(["InferShape", "SimplifyInference"])
......@@ -167,7 +180,8 @@ def optimize(graph, shape, dtype="float32"):
return graph
def build(graph, target=None, shape=None, dtype="float32", params=None, target_host=None):
def build(graph, target=None, shape=None, dtype="float32",
params=None, target_host=None, layout=None):
"""Build graph into runtime library.
The build function will optimize the graph and do the compilation.
......@@ -204,8 +218,8 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
initialize : bool, optional
Whether to initialize variables in global dict _all_var_init.
layout : dict of str to str or str optional
The input layout
Returns
-------
......@@ -230,6 +244,15 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# correct layout if necessary
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply("CorrectLayout")
index = graph.index
layouts = graph.json_attr("layout")
layout = {x : layouts[index.entry_id(x)] for x in index.input_names}
# Initial pass do shape type inference
ishape, _ = graph_util.infer_shape(graph, **shape)
shape.update(zip(graph.index.input_names, ishape))
......@@ -241,13 +264,14 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
if _all_var_init:
init_var = initialize_variables(shape, dtype)
# Apply optimization
graph = optimize(graph, shape, dtype)
graph = optimize(graph, shape, dtype, layout)
# Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# Operator Fusion and generation
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph._set_json_attr("target", str(target), "str")
if target_host is not None:
......
......@@ -96,11 +96,22 @@ def set_layout_inputs(g, layout):
Returns
-------
g : Graph
The updated graph with updated dtype.
The updated graph with updated layout.
"""
list_shape = [
layout.get(name, "default") for name in g.index.input_names]
g._set_json_attr("layout_inputs", list_shape, 'list_str')
if isinstance(layout, dict):
list_layout = [
layout.get(name, "__undef__") for name in g.index.input_names]
elif isinstance(layout, str):
list_layout = ["__undef__"] * len(g.index.input_names)
list_layout[0] = layout
else:
raise ValueError("Input layout must be str or dict")
last_inferred_layouts = g.json_attr("layout")
if last_inferred_layouts:
input_layout = [last_inferred_layouts[g.index.entry_id(x)] for x in g.index.input_names]
for i, layout_stored in enumerate(input_layout):
list_layout[i] = list_layout[i] if list_layout[i] != '__undef__' else layout_stored
g._set_json_attr("layout_inputs", list_layout, 'list_layout')
return g
_move_out_module = tvm.get_global_func("nnvm.graph._move_module")
......
"""Module space to register contrib functions. Leave empty"""
......@@ -86,6 +86,10 @@ def _conv2d(inputs, attrs):
layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']:
_raise_not_supported('layout: ' + layout, 'conv2d')
if 'kernel_layout' in attrs:
kernel_layout = attrs['kernel_layout']
else:
kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW'
op_name, new_attrs = 'conv2d', {}
new_attrs['channels'] = _required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel
......@@ -94,6 +98,7 @@ def _conv2d(inputs, attrs):
new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['kernel_layout'] = kernel_layout
new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
......@@ -106,6 +111,10 @@ def _conv2d_transpose(inputs, attrs):
layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']:
_raise_not_supported('layout: ' + layout, 'conv2d_transpose')
if 'kernel_layout' in attrs:
kernel_layout = attrs['kernel_layout']
else:
kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW'
op_name, new_attrs = 'conv2d_transpose', {}
new_attrs['channels'] = _required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel
......@@ -115,6 +124,7 @@ def _conv2d_transpose(inputs, attrs):
new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['kernel_layout'] = kernel_layout
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
......@@ -237,7 +247,7 @@ _convert_map = {
'min_axis' : _rename('min'),
'reshape' : _reshape,
'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling
'UpSampling' : _upsampling,
}
def _convert_symbol(op_name, inputs, attrs,
......
......@@ -16,6 +16,7 @@ from . import _base
from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init
from .attribute import AttrScope
from . import _symbol_internal as _internal
from . import contrib
# Use different verison of SymbolBase
# When possible, use cython to speedup part of computation.
......
......@@ -5,7 +5,7 @@ from __future__ import absolute_import
import tvm
import topi
from topi.util import get_const_int
from .tensor import _fschedule_broadcast
from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg
from .registry import OpPattern
......@@ -32,6 +32,11 @@ reg.register_schedule("pad", _fschedule_broadcast)
reg.register_pattern("pad", OpPattern.INJECTIVE)
# layout transform
reg.register_schedule("__layout_transform__", _fschedule_injective)
reg.register_pattern("__layout_transform__", OpPattern.INJECTIVE)
@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
"""Schedule definition of softmax"""
......@@ -108,6 +113,42 @@ def schedule_conv2d(attrs, outs, target):
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# convolution NCHWc
@reg.register_compute("_contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
"""Compute definition of conv2d NCHWc"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
kh, kw = attrs.get_int_tuple('kernel_size')
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
assert dilation == (1, 1), "not support dilate now"
if groups == 1:
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw), strides, padding)
else:
raise ValueError("not support arbitrary group number > 1 for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.broadcast_add(out, bias)
return out
@reg.register_schedule("_contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of conv2d NCHWc"""
groups = attrs.get_int("groups")
kh, kw = attrs.get_int_tuple('kernel_size')
oc = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
with tvm.target.create(target):
if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding, outs)
else:
raise ValueError("not support group number > 1 for now")
reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)
# conv2d_transpose
@reg.register_compute("conv2d_transpose")
......
......@@ -25,6 +25,7 @@ class OpPattern(object):
_register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule")
_register_pattern = tvm.get_global_func("nnvm._register_pattern")
_register_alter_op_layout = tvm.get_global_func("nnvm.compiler._register_alter_op_layout")
def register_compute(op_name, f=None, level=10):
"""Register compute function for operator
......@@ -93,3 +94,29 @@ def register_pattern(op_name, pattern, level=10):
The priority level
"""
_register_pattern(op_name, pattern, level)
def register_alter_op_layout(op_name, f=None, level=10):
"""Register alter layout function for operator
Parameters
----------
op_name : str
The name of operator
f : function
The schedule function
level : int
The priority level
Returns
-------
fregister : function
Register function if f is not specified.
"""
def register(myf):
"""internal register function"""
_register_alter_op_layout(op_name, myf, level)
return myf
return register(f) if f else register
......@@ -294,7 +294,7 @@ int NNSymbolGetNumOutputs(SymbolHandle symbol,
nn_uint *output_count) {
Symbol *s = static_cast<Symbol*>(symbol);
API_BEGIN();
*output_count = static_cast<nn_uint>(s->outputs.size());
*output_count = static_cast<nn_uint>(s->outputs.size());
API_END();
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file alter_op_layout.cc
* \brief Alter the operator layouts. Keep inferred layouts (if any) from previous stages.
* e.g., convolution may calculates faster with NCHW16c layout.
*/
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/layout.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/pass_functions.h>
#include <tvm/tvm.h>
#include <algorithm>
#include <functional>
#include "./compile_engine.h"
#include "./graph_transform.h"
namespace nnvm {
namespace compiler {
namespace {
tvm::Array<tvm::Tensor> GetTensorInfo(const IndexedGraph& idx_graph,
const uint32_t nid,
const ShapeVector& shape_vec,
const DTypeVector& dtype_vec) {
tvm::Array<tvm::Tensor> vec;
for (uint32_t i = 0; i < idx_graph[nid].source->num_outputs(); ++i) {
tvm::Array<tvm::Expr> shape;
for (int64_t x : shape_vec[idx_graph.entry_id(nid, i)]) {
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
shape.push_back(tvm::make_const(tvm::Int(32), x));
}
vec.push_back(tvm::placeholder(
shape, GetTVMType(dtype_vec[idx_graph.entry_id(nid, i)])));
}
return vec;
}
Graph AlterOpLayout(const Graph& src) {
static auto& falter_op_layout =
Op::GetAttr<nnvm::compiler::FTVMAlterOpLayout >("FTVMAlterOpLayout");
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = src.GetAttr<DTypeVector>("dtype");
const IndexedGraph& idx_graph = src.indexed_graph();
std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
std::unordered_map<const Node*, uint32_t> new_nodes;
if (src.HasAttr("layout")) {
// record layouts so that LayoutTransform pass can fix layouts correctly,
// e.g., conv2d can be replaced by some contrib implement
// whose layout is different from the original one
// (which was imported from a model file).
const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
const auto &inode = idx_graph[nid];
if (falter_op_layout.count(inode.source->op())) {
// do not record input layouts of nodes that will be replaced.
continue;
}
std::vector<Layout> in_layout;
for (const auto& e : inode.inputs) {
in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
}
in_layouts_of_node[nid] = in_layout;
std::vector<Layout> out_layout;
for (uint i = 0; i < inode.source->num_outputs(); ++i) {
out_layout.emplace_back(layouts[idx_graph.entry_id(nid, i)]);
}
out_layouts_of_node[nid] = out_layout;
}
}
auto transform = [&](uint32_t nid,
const NodePtr& n,
std::vector<NodeEntry>* ret) {
nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
falter_op_layout.get(n->op(), nullptr);
if (fn_alter_op_layout == nullptr) {
new_nodes[n.get()] = nid;
return false;
}
// construct parameters for registered function
std::vector<Symbol> op_inputs;
tvm::Array<tvm::Tensor> tensor_infos;
CHECK_EQ(n->num_inputs(), idx_graph[nid].inputs.size());
for (uint32_t i = 0; i < n->num_inputs(); ++i) {
const nnvm::NodeEntry& input = n->inputs[i];
// input operator
Symbol op_input;
op_input.outputs.push_back(input);
op_inputs.push_back(op_input);
// input tinfo, extract from the original graph
// because it was where infer_shape & infer_type applied.
tvm::Array<tvm::Tensor> op_output_tinfos =
GetTensorInfo(idx_graph, idx_graph[nid].inputs[i].node_id,
shape_vec, dtype_vec);
tensor_infos.push_back(op_output_tinfos[input.index]);
}
// callback registered function to get a new operator.
auto op = fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos);
*ret = op.outputs;
return true;
};
Graph ret = nnvm::compiler::GraphTransform(src, transform);
if (src.HasAttr("layout")) {
// restore the layouts to return graph
const auto& ret_idx = ret.indexed_graph();
std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
const auto& inode = ret_idx[nid];
if (new_nodes.count(inode.source)) {
const std::vector<Layout>& in_layouts =
in_layouts_of_node[new_nodes[inode.source]];
for (const auto& e : inode.inputs) {
ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index];
}
const std::vector<Layout>& out_layouts =
out_layouts_of_node[new_nodes[inode.source]];
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
}
}
}
// cannot call indexed_graph() before return the origin Graph,
// thus create a new one.
nnvm::Graph new_ret;
new_ret.outputs = ret.outputs;
new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
return new_ret;
}
return ret;
}
// register pass
NNVM_REGISTER_PASS(AlterOpLayout)
.set_body(AlterOpLayout)
.set_change_graph(true);
} // namespace
} // namespace compiler
} // namespace nnvm
......@@ -362,7 +362,7 @@ bool Pool2DBackward(
std::vector<FoldChainInfo>* in_axis) {
using top::Pool2DParam;
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
if (out_info.axis == 1 && param.layout == top::kNCHW) {
if (out_info.axis == 1 && param.layout == "NCHW") {
(*in_axis)[0] = out_info;
}
return false;
......@@ -376,7 +376,7 @@ bool Pool2DForward(
FoldChainInfo* out_info) {
using top::Pool2DParam;
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
if ((*in_info)[0].axis == 1 && param.layout == top::kNCHW) {
if ((*in_info)[0].axis == 1 && param.layout == "NCHW") {
*out_info = (*in_info)[0];
}
return false;
......@@ -467,7 +467,7 @@ bool Conv2DScaleAxisBackward(
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if (out_info.kind != kPending) return false;
// only optimize for nchw for now
if (param.layout == top::kNCHW && out_info.axis == 1) {
if (param.layout == "NCHW" && out_info.axis == 1) {
(*in_axis)[1].kind = kMulConsumer;
(*in_axis)[1].axis = 0;
(*in_axis)[1].source = out_info.source;
......@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward(
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if ((*in_info)[0].kind != kPending) return false;
// only optimize for nchw for now
if (param.layout == top::kNCHW && (*in_info)[0].axis == 1) {
if (param.layout == "NCHW" && (*in_info)[0].axis == 1) {
(*in_info)[1].kind = kMulConsumer;
(*in_info)[1].axis = 1;
(*in_info)[1].source = (*in_info)[0].source;
......
/*!
* Copyright (c) 2017 by Contributors
* \file layout_transform.cc
* \brief Transforms layout.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/contrib_op_param.h>
namespace nnvm {
namespace compiler {
const TLayoutInfo& GetDefaultLayout() {
static TLayoutInfo default_layout = "default";
return default_layout;
}
nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
const std::string& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform");
static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create();
n->attrs.op = trans_op;
n->attrs.name = src + "_to_" + dst + std::to_string(count++);
n->attrs.dict["src_layout"] = src;
n->attrs.dict["dst_layout"] = dst;
n->op()->attr_parser(&(n->attrs));
return n;
}
/*!
* \brief A simple layout transform pass that will
* insert layout transform nodes automatically.
*/
nnvm::Graph LayoutTransform(nnvm::Graph src) {
static auto& op_layout_request =
nnvm::Op::GetAttr<FTVMLayoutRequest>("FTVMLayoutRequest");
static auto& op_vecop =
nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp");
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
const std::vector<TLayoutInfo>& input_layouts =
src.GetAttr<std::vector<TLayoutInfo> >("layout_inputs");
const IndexedGraph& idx = src.indexed_graph();
std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
// use op pattern to decide whether an op is map
auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) {
if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
is_map = false;
break;
}
}
}
return is_map;
};
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *(inode.source);
if (new_node->is_variable()) {
auto input_iter = std::find(
idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
size_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
produce_vec[idx.entry_id(nid, 0)] = input_layouts[input_id];
mirror_vec[nid] = new_node;
continue;
}
if (op_vecop.count(inode.source->op())) {
new_node = op_vecop[inode.source->op()](inode.source);
new_node->inputs.resize(new_node->num_inputs());
}
// set up output and input layouts
std::vector<TLayoutInfo> request_ilayouts(new_node->num_inputs(), GetDefaultLayout());
if (op_layout_request.count(new_node->op())) {
std::vector<TLayoutInfo> produce_olayouts(new_node->num_outputs(), GetDefaultLayout());
CHECK(op_layout_request[new_node->op()](
new_node->attrs, &request_ilayouts, &produce_olayouts))
<< "Layout request fail";
CHECK_EQ(request_ilayouts.size(), new_node->num_inputs());
CHECK_EQ(produce_olayouts.size(), new_node->num_outputs());
for (size_t i = 0; i < new_node->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = produce_olayouts[i];
}
}
bool map_layout = is_map_op(nid);
if (map_layout) {
const TLayoutInfo& layout = produce_vec[idx.entry_id(inode.inputs[0])];
for (const auto& e : inode.inputs) {
if (produce_vec[idx.entry_id(e)] != layout) {
map_layout = false;
break;
}
}
if (map_layout) {
for (size_t i = 0; i < inode.source->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = layout;
}
}
}
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
new_node->inputs[i] =
nnvm::NodeEntry{in, e.index, e.version};
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
TLayoutInfo request = request_ilayouts[i];
if (!map_layout && (produce != request)) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_" + request;
tnode->inputs.emplace_back(new_node->inputs[i]);
new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0};
}
}
mirror_vec[nid] = new_node;
}
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : idx.outputs()) {
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
if (produce != GetDefaultLayout()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, GetDefaultLayout());
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_default";
tnode->inputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
} else {
outputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
}
}
nnvm::Graph ret;
ret.outputs = std::move(outputs);
return ret;
}
} // namespace compiler
} // namespace nnvm
......@@ -8,6 +8,7 @@
#include <nnvm/op.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <nnvm/compiler/op_attr_types.h>
#include <tvm/runtime/c_runtime_api.h>
#include "./node_attr.h"
#include "compile_engine.h"
......@@ -62,6 +63,23 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys")
*rv = keys;
});
TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
.set_body([](TVMArgs args, TVMRetValue *rv) {
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* f = new PackedFunc(args[1].operator PackedFunc());
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
auto fpack = [f](const NodeAttrs& attrs,
const Symbol& inputs,
const Array<Tensor>& tinfos) {
TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, tinfos);
CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
<< " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
<< ") but get code = " << ret.type_code();
return *(static_cast<Symbol*>(ret.value().v_handle));
};
op.set_attr<FTVMAlterOpLayout>("FTVMAlterOpLayout", fpack, args[2]);
});
// custom version of TVM compute
TVM_REGISTER_GLOBAL("nnvm._register_compute")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......@@ -84,7 +102,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute")
TVM_REGISTER_GLOBAL("nnvm._register_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* f = new PackedFunc(args[1].operator PackedFunc());
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
auto fschedule = [f](const NodeAttrs& attrs,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment