Skip to content
Snippets Groups Projects
graph_pass.cc 17.80 KiB
/*!
 *  Copyright (c) 2017 by Contributors
 * \file Additional optimization pass of NNVM.
 */
#include <dmlc/json.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <tvm/operation.h>
#include <tvm/lowered_func.h>
#include "./op_attr_types.h"

namespace tvm {
namespace contrib {

using nnvm::any;
using nnvm::IndexedGraph;

// Fusion decision recorded for a single graph node.
enum class FuseRule {
  // No decision recorded yet. The enumerator keeps its historical spelling
  // because other code refers to it by this name.
  kUknown = 0,
  // Presumably: the node is fused into its parent -- confirm against users.
  kFuseToParent = 1,
  // Presumably: the node's result is materialized instead of fused -- confirm.
  kRealize = 2
};

/*!
 * \brief Translate an integer type flag into a DLDataType.
 *
 *  Only flag 0 is recognized; it maps to 32-bit float. Any other flag is
 *  reported through LOG(FATAL).
 *
 * \param type_flag The integer type code to translate.
 * \return The DLDataType obtained from Type2TVMType(Float(32)).
 */
DLDataType GetDLType(int type_flag) {
  const bool is_known_flag = (type_flag == 0);
  if (!is_known_flag) {
    LOG(FATAL) << "unknown type_flag=" << type_flag;
  }
  // Reached for flag 0; also the fallback value should LOG(FATAL) ever
  // return control to the caller (it is not expected to).
  return Type2TVMType(Float(32));
}

// Partition the graph into segments
// Each segment will be compiled into one operator.
// Also need to mark the property of each segment.
nnvm::Graph GraphPartition(nnvm::Graph g) {
  // setup ref counter
  const IndexedGraph& idx = g.indexed_graph();
  // Get attributes from the graph
  const ShapeVector& shape_vec = g.GetAttr<ShapeVector>("shape");
  const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype");
  // Transform to dltype
  // In the future, directly do type inference in dltype.
  DLTypeVector dltype_vec = DLTypeVector(dtype_vec.size());
  for (size_t i = 0; i < dtype_vec.size(); ++i) {
    dltype_vec[i] = GetDLType(dtype_vec[i]);
  }

  // Reference counter of each op node
  // For now, always store result when an op is referred more than once.
  std::vector<uint32_t> ref_count(idx.num_nodes(), 0);
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    for (const auto& e : inode.inputs) {
      ++ref_count[e.node_id];
    }
  }
  // Pattern of the subgraph
  std::vector<TOpPattern> pattern_vec(idx.num_nodes(),  kExtern);
  // Whether node can be fused to parent.
  std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
  // Operator pattern
  static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");

  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) {
      fuse_vec[nid] = FuseRule::kRealize; continue;