Tianqi Chen authoredTianqi Chen authored
graph_pass.cc 17.80 KiB
/*!
 * Copyright (c) 2017 by Contributors
 * \file Additional optimization pass of NNVM.
 */
#include <dmlc/json.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <tvm/operation.h>
#include <tvm/lowered_func.h>
#include "./op_attr_types.h"
namespace tvm {
namespace contrib {
using nnvm::any;
using nnvm::IndexedGraph;
// The single fuse rule.
enum class FuseRule {
// Convert an integer type flag into a DLDataType.
// Only flag 0 is recognized; it maps to Type2TVMType(Float(32)).
// Any other flag is reported through LOG(FATAL) together with the offending value.
DLDataType GetDLType(int type_flag) {
if (type_flag == 0) return Type2TVMType(Float(32));
LOG(FATAL) << "unknown type_flag=" << type_flag;
// NOTE(review): this return follows the fatal log; presumably it only exists so
// every path yields a value -- confirm that LOG(FATAL) never returns here.
return Type2TVMType(Float(32));
// Partition the graph into segments
// Each segment will be compiled into one operator.
// We also need to mark the properties of each segment.
nnvm::Graph GraphPartition(nnvm::Graph g) {
// setup ref counter
const IndexedGraph& idx = g.indexed_graph();
// Get attributes from the graph
const ShapeVector& shape_vec = g.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype");
// Transform to dltype
// In the future, directly do type inference in dltype.
DLTypeVector dltype_vec = DLTypeVector(dtype_vec.size());
for (size_t i = 0; i < dtype_vec.size(); ++i) {
dltype_vec[i] = GetDLType(dtype_vec[i]);
// Reference counter of each op node
// For now, always store the result when an op is referenced more than once.
std::vector<uint32_t> ref_count(idx.num_nodes(), 0);
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
for (const auto& e : inode.inputs) {
// Pattern of the subgraph
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
// Whether node can be fused to parent.
std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
// Operator pattern
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) {
fuse_vec[nid] = FuseRule::kRealize; continue;