diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc
index 84f3cb5bba6ffb7ec277eb8647773df319350e2e..33ac6a94ecf72b81ab6d4044e02a74b4704de067 100644
--- a/src/pass/lower_intrin.cc
+++ b/src/pass/lower_intrin.cc
@@ -34,22 +34,50 @@ class IntrinInjecter : public IRMutator {
   }
 
   Expr Mutate_(const Add* op, const Expr& e) final {
-    if (fma_ == nullptr || !op->type.is_float()) {
-      return IRMutator::Mutate_(op, e);
-    }
     if (const Mul* mb = op->b.as<Mul>()) {
-      Expr r = (*fma_)(Call::make(
-          op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
-      if (r.defined()) return this->Mutate(r);
+      return MakeFMA(mb->a, mb->b, op->a, op, e);
     } else if (const Mul* ma = op->a.as<Mul>()) {
+      return MakeFMA(ma->a, ma->b, op->b, op, e);
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+
+ private:
+  Expr SwapBroadcastCast(const Expr& e) {
+    // Try to change broadcast(cast(x)) to cast(broadcast(x))
+    // For some targets, LLVM will generate more efficient FMA
+    // instruction with the latter. For example, vmla vs. vmlal
+    // on ARM.
+    if (const Broadcast* bcast = e.as<Broadcast>()) {
+      if (const Cast* cast = bcast->value.as<Cast>()) {
+        if (cast->type.bits() == cast->value.type().bits() * 2) {
+          Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
+          return Cast::make(bcast->type, new_bcast);
+        }
+      }
+    }
+    return e;
+  }
+
+  Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
+               const Add* op, const Expr& e) {
+    // emit fma instruction: a * b + c
+    Expr lhs = SwapBroadcastCast(a);
+    Expr rhs = SwapBroadcastCast(b);
+
+    if (fma_ != nullptr && op->type.is_float()) {
       Expr r = (*fma_)(Call::make(
-          op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic));
+          op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
       if (r.defined()) return this->Mutate(r);
+    } else {
+      if (!lhs.same_as(a) || !rhs.same_as(b)) {
+        Expr mul = this->Mutate(Mul::make(lhs, rhs));
+        return Add::make(mul, this->Mutate(c));
+      }
     }
     return IRMutator::Mutate_(op, e);
   }
 
- private:
   Expr ApplyPattern(const std::string& name, const Expr& e) {
     for (size_t i = 0; i < patterns_.size(); ++i) {
       std::string& p = patterns_[i];