diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 84f3cb5bba6ffb7ec277eb8647773df319350e2e..33ac6a94ecf72b81ab6d4044e02a74b4704de067 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -34,22 +34,50 @@ class IntrinInjecter : public IRMutator { } Expr Mutate_(const Add* op, const Expr& e) final { - if (fma_ == nullptr || !op->type.is_float()) { - return IRMutator::Mutate_(op, e); - } if (const Mul* mb = op->b.as<Mul>()) { - Expr r = (*fma_)(Call::make( - op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic)); - if (r.defined()) return this->Mutate(r); + return MakeFMA(mb->a, mb->b, op->a, op, e); } else if (const Mul* ma = op->a.as<Mul>()) { + return MakeFMA(ma->a, ma->b, op->b, op, e); + } + return IRMutator::Mutate_(op, e); + } + + private: + Expr SwapBroadcastCast(const Expr& e) { + // Try to change broadcast(cast(x)) to cast(broadcast(x)) + // For some targets, LLVM will generate more efficient FMA + // instruction with the latter. For example, vmla vs. vmlal + // on ARM. + if (const Broadcast* bcast = e.as<Broadcast>()) { + if (const Cast* cast = bcast->value.as<Cast>()) { + if (cast->type.bits() == cast->value.type().bits() * 2) { + Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); + return Cast::make(bcast->type, new_bcast); + } + } + } + return e; + } + + Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c, + const Add* op, const Expr& e) { + // emit fma instruction: a * b + c + Expr lhs = SwapBroadcastCast(a); + Expr rhs = SwapBroadcastCast(b); + + if (fma_ != nullptr && op->type.is_float()) { Expr r = (*fma_)(Call::make( - op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic)); + op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic)); if (r.defined()) return this->Mutate(r); + } else { + if (!lhs.same_as(a) || !rhs.same_as(b)) { + Expr mul = this->Mutate(Mul::make(lhs, rhs)); + return Add::make(mul, this->Mutate(c)); + } } return IRMutator::Mutate_(op, e); } - private: Expr ApplyPattern(const std::string& name, const Expr& e) { for (size_t i = 0; i < patterns_.size(); ++i) { std::string& p = patterns_[i];