diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index e219b5541bdc390cc42143e69aa384455c080f3d..736b8dad78c7cd3862d3b01f8efa3e47635b0be6 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -648,6 +648,24 @@ T Simplify_(T a, Map<Var, Range> vrange) { Expr Simplify(Expr a, Map<Var, Range> vrange) { + // We should not pass an expression having a non-HalideIR op to + // Halide::Internal::simplify. Reduce op is the only such op at this time + // and it only appears as the top op in an expression. So we strip it + // first and send the sub-expressions to the simplifier. + if (const Reduce* r = a.as<Reduce>()) { + Array<Expr> new_source; + for (auto& e : r->source) { + new_source.push_back(Simplify_(e, vrange)); + } + Expr new_condition = Simplify_(r->condition, vrange); + if (r->source.same_as(new_source) && + r->condition.same_as(new_condition)) { + return a; + } else { + return Reduce::make( + r->combiner, new_source, r->axis, new_condition, r->value_index); + } + } return Simplify_(a, vrange); } diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 243b7931da67d6e0aea52e625f01f9fd40b541a0..b4527f76e8082f1ea4aa171ee5786f7227360250 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -303,8 +303,10 @@ void VerifyTensorizeBody( CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch"; for (size_t i = 0; i < body.size(); ++i) { - Expr lhs = CanonicalSimplify(body[i], compute_intrin_iter_space); - Expr rhs = CanonicalSimplify(intrin_compute->body[i], compute_intrin_iter_space); + Expr lhs = Simplify(body[i], compute_intrin_iter_space); + lhs = CanonicalSimplify(lhs, compute_intrin_iter_space); + Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space); + rhs = CanonicalSimplify(rhs, compute_intrin_iter_space); if (lhs.type() != rhs.type()) { LOG(FATAL) << "Failed to match the data type with TensorIntrin "