/*!
 * Copyright (c) 2017 by Contributors
 *
 * \brief Lift a specified AttrStmt scope to an outer level
 *  when the body contains the same scope.
 * \file lift_attr_scope.cc
 */
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
namespace tvm {
namespace ir {
// NOTE: this optimization can only be applied
// to a few specified attr keys
class AttrScopeLifter : public IRMutator {
public:
explicit AttrScopeLifter(std::string attr_key)
: attr_key_(attr_key) {}
Stmt Lift(Stmt stmt) {
stmt = Mutate(stmt);
if (attr_node_.defined()) {
stmt = AttrStmt::make(
attr_node_, attr_key_, attr_value_, stmt);
}
return stmt;
}
// do not go beyond
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
if (attr_node_.defined()) {
Stmt body = AttrStmt::make(
attr_node_, attr_key_, attr_value_, op->body);
// undefine them
attr_node_ = NodeRef();
attr_value_ = Expr();
return Allocate::make(
op->buffer_var, op->type,
op->extents, op->condition, body,
op->new_expr, op->free_function);
} else {
return stmt;
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr_key_) {
attr_node_ = op->node;
attr_value_ = op->value;
return op->body;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Block* op, const Stmt& s) final {
std::vector<Stmt> seq;
FlattenSeq(op->first, &seq);
FlattenSeq(op->rest, &seq);
seq = MutateSeq(seq);
if (seq.size() == 2 &&
seq[0].same_as(op->first) &&
seq[1].same_as(op->rest)) {
return s;
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
if (!op->else_case.defined()) {
return IRMutator::Mutate_(op, s);
}
Stmt then_case = this->Mutate(op->then_case);
NodeRef first_node;
Expr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->Mutate(op->else_case);
if (attr_node_.defined() &&
attr_value_.defined() &&
first_node.defined() &&
first_value.defined() &&
attr_node_.same_as(first_node) &&
ValueSame(attr_value_, first_value)) {
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
} else {
then_case = AttrStmt::make(
first_node, attr_key_, first_value, then_case);
}
if (attr_node_.defined()) {
else_case = AttrStmt::make(
attr_node_, attr_key_, attr_value_, else_case);
// undefine them
attr_node_ = NodeRef();
attr_value_ = Expr();
}
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
}
}
private:
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
void FlattenSeq(Stmt s, std::vector<Stmt>* res) {
if (const Block* op = s.as<Block>()) {
FlattenSeq(op->first, res);
FlattenSeq(op->rest, res);
} else if (const ProducerConsumer* op = s.as<ProducerConsumer>()) {
if (!op->is_producer) {
FlattenSeq(op->body, res);
} else {
res->emplace_back(s);
}
} else {
res->emplace_back(s);
}
}
std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) {
std::vector<Stmt> res_seq;
NodeRef curr_node;
Expr curr_value;
Stmt curr_stmt;
for (const Stmt & stmt : seq) {
attr_node_ = NodeRef();
attr_value_ = Expr();
Stmt rest = this->Mutate(stmt);
if (attr_node_.defined() &&
attr_value_.defined() &&
curr_node.defined() &&
curr_value.defined() &&
attr_node_.same_as(curr_node) &&
ValueSame(attr_value_, curr_value)) {
curr_stmt = Block::make(curr_stmt, rest);
} else {
if (curr_stmt.defined()) {
if (curr_node.defined()) {
curr_stmt = AttrStmt::make(
curr_node, attr_key_, curr_value, curr_stmt);
}
res_seq.push_back(curr_stmt);
}
curr_stmt = rest;
curr_node = attr_node_;
curr_value = attr_value_;
}
}
if (curr_stmt.defined()) {
// keep attr_node_, attr_node_
if (res_seq.size() == 0) {
return {curr_stmt};
}
if (curr_node.defined()) {
curr_stmt = AttrStmt::make(
curr_node, attr_key_, curr_value, curr_stmt);
}
res_seq.push_back(curr_stmt);
// reset
attr_node_ = NodeRef();
attr_value_ = Expr();
}
return res_seq;
}
// value comparison that also compares content of int constant
static bool ValueSame(const Expr& a, const Expr& b) {
if (a.same_as(b)) return true;
if (a->type_key() != b->type_key()) return false;
if (a.type() != b.type()) return false;
if (const IntImm* op = a.as<IntImm>()) {
return op->value == b.as<IntImm>()->value;
}
if (const UIntImm* op = a.as<UIntImm>()) {
return op->value == b.as<UIntImm>()->value;
}
return false;
}
std::string attr_key_;
NodeRef attr_node_;
Expr attr_value_;
};
/*!
 * \brief Lift AttrStmt scopes keyed by attr_key outward using a fresh
 *  AttrScopeLifter.
 * \param stmt The statement to transform.
 * \param attr_key The attribute key whose scopes may be lifted.
 * \return The transformed statement.
 */
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
  AttrScopeLifter lifter(attr_key);
  Stmt lifted = lifter.Lift(stmt);
  return lifted;
}
} // namespace ir
} // namespace tvm