-
Tianqi Chen authoredTianqi Chen authored
ir_deep_compare.cc 15.13 KiB
/*!
* Copyright (c) 2017 by Contributors
* \file ir_deep_compare.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
// Functor bases: each visitor receives the lhs node and the rhs node
// it is being compared against.
using ExprComparator = ExprFunctor<void(const Expr&, const Expr&)>;
using StmtComparator = StmtFunctor<void(const Stmt&, const Stmt&)>;
// Emits the VisitExpr_ overload for a binary operator node type OP.
// Operand `a` is compared first; operand `b` is only compared when the
// `a` operands were found equal.
#define DEFINE_BIOP_EXPR_CMP_(OP)                              \
  void VisitExpr_(const OP* op, const Expr& other) final {     \
    const OP* other_op = other.as<OP>();                       \
    if (CompareExpr(op->a, other_op->a) == 0) {                \
      CompareExpr(op->b, other_op->b);                         \
    }                                                          \
  }
// Deep comparison to check if two IR graph are equivalent
class IRDeepCompare :
public ExprComparator, public StmtComparator {
public:
// Equality comparison
bool Equal(const Stmt& lhs, const Stmt& rhs) {
tie_def_ = true;
VisitStmt(lhs, rhs);
return order_ == 0;
}
bool Equal(const Expr& lhs, const Expr& rhs) {
tie_def_ = true;
VisitExpr(lhs, rhs);
return order_ == 0;
}
int Compare(const Expr& lhs, const Expr& rhs) {
tie_def_ = false;
VisitExpr(lhs, rhs);
return order_;
}
void VisitExpr(const Expr& n, const Expr& other) override {
if (order_ != 0) return;
if (n.same_as(other)) return;
if (CompareValue(n->type_index(), other->type_index()) != 0) return;
if (CompareType(n.type(), other.type()) != 0) return;
ExprComparator::VisitExpr(n, other);
}
void VisitStmt(const Stmt& n, const Stmt& other) override {
if (order_ != 0) return;
if (n.same_as(other)) return;
if (CompareValue(n->type_index(), other->type_index()) != 0) return;
StmtComparator::VisitStmt(n, other);
}
// Stmt
void VisitStmt_(const LetStmt* op, const Stmt& other) final {
const LetStmt* rhs = other.as<LetStmt>();
if (CompareExpr(op->value, rhs->value) != 0) return;
if (tie_def_) {
vmap_[op->var.get()] = rhs->var.get();
} else {
if (CompareExpr(op->var, rhs->var) != 0) return;
}
if (CompareStmt(op->body, rhs->body) != 0) return;
}
void VisitStmt_(const AttrStmt* op, const Stmt& other) final {
const AttrStmt* rhs = other.as<AttrStmt>();
if (CompareString(op->attr_key, rhs->attr_key) != 0) return;
if (CompareNodeRef(op->node, rhs->node) != 0) return;
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
void VisitStmt_(const IfThenElse* op, const Stmt& other) final {
const IfThenElse* rhs = other.as<IfThenElse>();
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareStmt(op->then_case, rhs->then_case) != 0) return;
if (CompareStmt(op->else_case, rhs->else_case) != 0) return;
}
void VisitStmt_(const For* op, const Stmt& other) final {
const For* rhs = other.as<For>();
if (CompareExpr(op->min, rhs->min) != 0) return;
if (CompareExpr(op->extent, rhs->extent) != 0) return;
if (tie_def_) {
vmap_[op->loop_var.get()] = rhs->loop_var.get();
} else {
if (CompareExpr(op->loop_var, rhs->loop_var) != 0) return;
}
if (CompareStmt(op->body, rhs->body) != 0) return;
}
void VisitStmt_(const Allocate* op, const Stmt& other) final {
const Allocate* rhs = other.as<Allocate>();
if (tie_def_) {
vmap_[op->buffer_var.get()] = rhs->buffer_var.get();
} else {
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
}
if (CompareType(op->type, rhs->type) != 0) return;
if (CompareArray(op->extents, rhs->extents) != 0) return;
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
if (CompareExpr(op->new_expr, rhs->new_expr) != 0) return;
if (CompareString(op->free_function, rhs->free_function) != 0) return;
}
void VisitStmt_(const Store* op, const Stmt& other) final {
const Store* rhs = other.as<Store>();
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareExpr(op->index, rhs->index) != 0) return;
if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
}
void VisitStmt_(const Free* op, const Stmt& other) final {
const Free* rhs = other.as<Free>();
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
}
void VisitStmt_(const AssertStmt* op, const Stmt& other) final {
const AssertStmt* rhs = other.as<AssertStmt>();
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareExpr(op->message, rhs->message) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final {
const ProducerConsumer* rhs = other.as<ProducerConsumer>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->is_producer, rhs->is_producer) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
void VisitStmt_(const Provide* op, const Stmt& other) final {
const Provide* rhs = other.as<Provide>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareArray(op->args, rhs->args) != 0) return;
}
void VisitStmt_(const Realize* op, const Stmt& other) final {
const Realize* rhs = other.as<Realize>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
if (CompareType(op->type, rhs->type) != 0) return;
if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
void VisitStmt_(const Prefetch* op, const Stmt& other) final {
const Prefetch* rhs = other.as<Prefetch>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
if (CompareType(op->type, rhs->type) != 0) return;
if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
}
void VisitStmt_(const Block* op, const Stmt& other) final {
const Block* rhs = other.as<Block>();
if (CompareStmt(op->first, rhs->first) != 0) return;
if (CompareStmt(op->rest, rhs->rest) != 0) return;
}
void VisitStmt_(const Evaluate* op, const Stmt& other) final {
const Evaluate* rhs = other.as<Evaluate>();
CompareExpr(op->value, rhs->value);
}
// Exprs
void VisitExpr_(const Variable* op, const Expr& other) final {
const Variable* rhs = other.as<Variable>();
auto it = vmap_.find(op);
if (it != vmap_.end()) op = it->second;
if (op < rhs) {
order_ = -1;
} else if (op > rhs) {
order_ = +1;
}
}
void VisitExpr_(const Load* op, const Expr& other) final {
const Load* rhs = other.as<Load>();
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
if (CompareExpr(op->index, rhs->index) != 0) return;
if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
}
void VisitExpr_(const Let* op, const Expr& other) final {
const Let* rhs = other.as<Let>();
if (tie_def_) {
vmap_[op->var.get()] = rhs->var.get();
} else {
if (CompareExpr(op->var, rhs->var) != 0) return;
}
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareExpr(op->body, rhs->body) != 0) return;
}
void VisitExpr_(const Call* op, const Expr& other) final {
const Call* rhs = other.as<Call>();
if (CompareString(op->name, rhs->name)) return;
if (CompareArray(op->args, rhs->args)) return;
if (CompareValue(op->call_type, rhs->call_type) != 0) return;
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
}
void VisitExpr_(const Reduce *op, const Expr& other) final {
const Reduce* rhs = other.as<Reduce>();
if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return;
if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
for (size_t i = 0; i < op->axis.size(); ++i) {
if (CompareExpr(op->axis[i]->dom->min, rhs->axis[i]->dom->min) != 0) return;
if (CompareExpr(op->axis[i]->dom->extent, rhs->axis[i]->dom->extent) != 0) return;
if (tie_def_) {
vmap_[op->axis[i]->var.get()] = rhs->axis[i]->var.get();
} else {
if (CompareExpr(op->axis[i]->var, rhs->axis[i]->var) != 0) return;
}
}
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareArray(op->source, rhs->source) != 0) return;
}
void VisitExpr_(const IntImm *op, const Expr& other) final {
CompareValue(op->value, other.as<IntImm>()->value);
}
void VisitExpr_(const UIntImm *op, const Expr& other) final {
CompareValue(op->value, other.as<UIntImm>()->value);
}
void VisitExpr_(const FloatImm *op, const Expr& other) final {
CompareValue(op->value, other.as<FloatImm>()->value);
}
void VisitExpr_(const StringImm *op, const Expr& other) final {
CompareString(op->value, other.as<StringImm>()->value);
}
void VisitExpr_(const Cast *op, const Expr& other) final {
CompareExpr(op->value, other.as<Cast>()->value);
}
void VisitExpr_(const Not *op, const Expr& other) final {
CompareExpr(op->a, other.as<Not>()->a);
}
void VisitExpr_(const Select *op, const Expr& other) final {
const Select* rhs = other.as<Select>();
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareExpr(op->true_value, rhs->true_value) != 0) return;
if (CompareExpr(op->false_value, rhs->false_value) != 0) return;
}
void VisitExpr_(const Ramp *op, const Expr& other) final {
const Ramp* rhs = other.as<Ramp>();
if (CompareExpr(op->base, rhs->base) != 0) return;
if (CompareExpr(op->stride, rhs->stride) != 0) return;
if (CompareValue(op->lanes, rhs->lanes) != 0) return;
}
void VisitExpr_(const Broadcast *op, const Expr& other) final {
const Broadcast* rhs = other.as<Broadcast>();
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareValue(op->lanes, rhs->lanes) != 0) return;
}
void VisitExpr_(const Shuffle *op, const Expr& other) final {
const Shuffle* rhs = other.as<Shuffle>();
if (CompareArray(op->vectors, rhs->vectors) != 0) return;
if (CompareArray(op->indices, rhs->indices) != 0) return;
}
DEFINE_BIOP_EXPR_CMP_(Add)
DEFINE_BIOP_EXPR_CMP_(Sub)
DEFINE_BIOP_EXPR_CMP_(Mul)
DEFINE_BIOP_EXPR_CMP_(Div)
DEFINE_BIOP_EXPR_CMP_(Mod)
DEFINE_BIOP_EXPR_CMP_(Min)
DEFINE_BIOP_EXPR_CMP_(Max)
DEFINE_BIOP_EXPR_CMP_(EQ)
DEFINE_BIOP_EXPR_CMP_(NE)
DEFINE_BIOP_EXPR_CMP_(LT)
DEFINE_BIOP_EXPR_CMP_(LE)
DEFINE_BIOP_EXPR_CMP_(GT)
DEFINE_BIOP_EXPR_CMP_(GE)
DEFINE_BIOP_EXPR_CMP_(And)
DEFINE_BIOP_EXPR_CMP_(Or)
private:
int CompareExpr(const Expr& lhs, const Expr& rhs) {
if (order_ != 0) return order_;
if (!lhs.defined() && rhs.defined()) {
order_ = -1; return order_;
}
if (!rhs.defined() && lhs.defined()) {
order_ = +1; return order_;
}
VisitExpr(lhs, rhs);
return order_;
}
int CompareStmt(const Stmt& lhs, const Stmt& rhs) {
if (order_ != 0) return order_;
if (!lhs.defined() && rhs.defined()) {
order_ = -1; return order_;
}
if (!rhs.defined() && lhs.defined()) {
order_ = +1; return order_;
}
VisitStmt(lhs, rhs);
return order_;
}
int CompareArray(const Array<Expr>& lhs, const Array<Expr>& rhs) {
if (order_ != 0) return order_;
if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
for (size_t i = 0; i < lhs.size(); ++i) {
if (CompareExpr(lhs[i], rhs[i]) != 0) return order_;
}
return order_;
}
int CompareRegion(const HalideIR::Internal::Region& lhs,
const HalideIR::Internal::Region& rhs) {
if (order_ != 0) return order_;
if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
for (size_t i = 0; i < lhs.size(); ++i) {
if (CompareExpr(lhs[i]->min, rhs[i]->min) != 0) return order_;
if (CompareExpr(lhs[i]->extent, rhs[i]->extent) != 0) return order_;
}
return order_;
}
int CompareNodeRef(const NodeRef& lhs, const NodeRef& rhs) {
if (order_ != 0) return order_;
if (lhs.get() < rhs.get()) {
order_ = -1; return order_;
}
if (lhs.get() > rhs.get()) {
order_ = +1; return order_;
}
return order_;
}
int CompareType(const Type& lhs, const Type& rhs) {
if (order_ != 0) return order_;
if (lhs == rhs) return order_;
if (CompareValue(lhs.code(), rhs.code()) != 0) return order_;
if (CompareValue(lhs.bits(), rhs.bits()) != 0) return order_;
if (CompareValue(lhs.lanes(), rhs.lanes()) != 0) return order_;
return order_;
}
int CompareString(const std::string& lhs, const std::string& rhs) {
if (order_ != 0) return order_;
order_ = lhs.compare(rhs);
return order_;
}
template<typename T>
int CompareValue(const T& lhs, const T& rhs) {
if (order_ != 0) return order_;
if (lhs < rhs) {
order_ = -1; return order_;
} else if (lhs > rhs) {
order_ = +1; return order_;
}
return order_;
}
int CompareCommReducer(const CommReducer& lhs, const CommReducer& rhs) {
if (order_ != 0) return order_;
if (lhs == rhs) return order_;
if (CompareValue(lhs->lhs.size(), rhs->lhs.size()) != 0) return order_;
if (CompareValue(lhs->rhs.size(), rhs->rhs.size()) != 0) return order_;
IRDeepCompare cmp;
if (tie_def_) {
for (size_t i = 0; i < lhs->lhs.size(); ++i) {
cmp.vmap_[lhs->lhs[i].get()] = rhs->lhs[i].get();
}
for (size_t i = 0; i < lhs->rhs.size(); ++i) {
cmp.vmap_[lhs->rhs[i].get()] = rhs->rhs[i].get();
}
} else {
for (size_t i = 0; i < lhs->lhs.size(); ++i) {
if (CompareExpr(lhs->lhs[i], rhs->lhs[i]) != 0) return order_;
}
for (size_t i = 0; i < lhs->lhs.size(); ++i) {
if (CompareExpr(lhs->rhs[i], rhs->rhs[i]) != 0) return order_;
}
}
order_ = cmp.CompareArray(lhs->result, rhs->result);
return order_;
}
// The order flag, smaller, -1, bigger: +1, equal: 0
int order_{0};
// Whether tie intermediate definitions.
// This allows use to tie definitions of two variables together.
// This enables us to assert equal between (let x in x + 1), (let y in y + 1)
// However, the comparison is no longer in total order.
// Only equality/non-equality information is valid.
bool tie_def_{false};
// varaible remap if any
std::unordered_map<const Variable*, const Variable*> vmap_;
};
bool Equal(const Stmt& lhs, const Stmt& rhs) {
return IRDeepCompare().Equal(lhs, rhs);
}
// Deep structural equality of two expressions, with binder variables tied.
// Two integer constants are compared by value alone as a fast path.
// Undefined expressions are equal only to each other.
bool Equal(const Expr& lhs, const Expr& rhs) {
  const int64_t* lhs_value = as_const_int(lhs);
  const int64_t* rhs_value =
      (lhs_value != nullptr) ? as_const_int(rhs) : nullptr;
  if (rhs_value != nullptr) {
    return *lhs_value == *rhs_value;
  }
  if (!lhs.defined() || !rhs.defined()) {
    return lhs.defined() == rhs.defined();
  }
  // Both sides are defined: fall back to the full structural walk.
  return IRDeepCompare().Equal(lhs, rhs);
}
int Compare(const Expr& lhs, const Expr& rhs) {
return IRDeepCompare().Compare(lhs, rhs);
}
} // namespace ir
} // namespace tvm