From 10d9da486518a0128f17156caf023c5339b96087 Mon Sep 17 00:00:00 2001 From: Salem Derisavi <33945117+derisavi-huawei@users.noreply.github.com> Date: Thu, 30 Nov 2017 18:44:08 -0500 Subject: [PATCH] Consider variable range information during simplification of tensorize expressions (#674) --- src/arithmetic/canonical.cc | 2 ++ src/op/tensorize.cc | 22 +++++++++++++++------- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index e7f9da1b4..24369db02 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -5,6 +5,7 @@ */ #include <tvm/ir_mutator.h> #include <tvm/arithmetic.h> +#include <tvm/ir_pass.h> #include "./canonical.h" #include "./compute_expr.h" #include "arithmetic/Simplify.h" @@ -612,6 +613,7 @@ void Canonical::SetRange(Var v, Range r, int level) { } // namespace arith namespace ir { + Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) { return arith::Canonical(vrange).Simplify(stmt); } diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 9715fcbab..243b7931d 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -187,7 +187,8 @@ class TensorIntrinMatcher final : public IRMutator { const Stage& stage, const std::unordered_map<IterVar, Range>& out_dom, const std::unordered_map<Tensor, Array<Range> >& in_region, - const TensorIntrin& intrin) { + const TensorIntrin& intrin, + Map<Var, Range>* compute_intrin_iter_space) { CHECK(self == stage->op.get()); // input remap. Array<Tensor> inputs = self->InputTensors(); @@ -232,6 +233,7 @@ class TensorIntrinMatcher final : public IRMutator { Range r = out_dom.at(iv); var_remap_[iv->var.get()] = target_iv->var + r->min; axis_remap_[iv] = target_iv; + compute_intrin_iter_space->Set(target_iv->var, target_iv->dom); } // Remap reduction axis CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size()) @@ -251,6 +253,7 @@ class TensorIntrinMatcher final : public IRMutator { Range r = out_dom.at(iv); var_remap_[iv->var.get()] = target_iv->var + r->min; axis_remap_[iv] = target_iv; + compute_intrin_iter_space->Set(target_iv->var, target_iv->dom); } } @@ -275,9 +278,10 @@ Array<Expr> MatchTensorizeBody( const Stage& stage, const std::unordered_map<IterVar, Range>& out_dom, const std::unordered_map<Tensor, Array<Range> >& in_region, - const TensorIntrin& intrin) { + const TensorIntrin& intrin, + Map<Var, Range>* compute_intrin_iter_space) { TensorIntrinMatcher matcher; - matcher.Init(self, stage, out_dom, in_region, intrin); + matcher.Init(self, stage, out_dom, in_region, intrin, compute_intrin_iter_space); Array<Expr> ret; for (Expr expr : self->body) { ret.push_back(matcher.Mutate(expr)); @@ -291,14 +295,16 @@ void VerifyTensorizeBody( const std::unordered_map<IterVar, Range>& out_dom, const std::unordered_map<Tensor, Array<Range> >& in_region, const TensorIntrin& intrin) { - Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin); + Map<Var, Range> compute_intrin_iter_space; + Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin, + &compute_intrin_iter_space); const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>(); CHECK(intrin_compute) << "Only support compute intrinsic for now"; CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch"; for (size_t i = 0; i < body.size(); ++i) { - Expr lhs = CanonicalSimplify(body[i]); - Expr rhs = CanonicalSimplify(intrin_compute->body[i]); + Expr lhs = CanonicalSimplify(body[i], compute_intrin_iter_space); + Expr rhs = CanonicalSimplify(intrin_compute->body[i], compute_intrin_iter_space); if (lhs.type() != rhs.type()) { LOG(FATAL) << "Failed to match the data type with TensorIntrin " @@ -459,11 +465,13 @@ TVM_REGISTER_API("test.op.MatchTensorizeBody") Map<IterVar, Range> out_dom = args[1]; Map<Tensor, Array<Range> > in_region = args[2]; TensorIntrin intrin = args[3]; + Map<Var, Range> vrange; CHECK(stage->op.as<ComputeOpNode>()); *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(), stage, as_unordered_map(out_dom), as_unordered_map(in_region), - intrin); + intrin, + &vrange); }); } // namespace tvm -- GitLab