diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 01b1890225316b3fdef95f14763652659e693bbf..934ce0c5ec9f21725a08cb84de8cf9415c669ea4 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -182,6 +182,13 @@ constexpr const char* prefetch_scope = "prefetch_scope";
 constexpr const char* scan_update_scope = "scan_update_scope";
 /*! \brief Mark of scan init scope */
 constexpr const char* scan_init_scope = "scan_init_scope";
+/*!
+ * \brief Mark alignment of buffer dimension
+ *  stmt.node is Tensor
+ *  stmt.value is tvm_tuple(dim, align, offset)
+ *  This gives hint to require stride of dim to be k * align + offset.
+ */
+constexpr const char* buffer_dim_align = "buffer_dim_align";
 /*!
  * \brief Bind the buffer specification to the region of the op
  * When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index 9d535e336cc13181864134322b3e92cd4d4052e4..f44cfe895e6dd4f67663534d99877439935ca0a2 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -104,13 +104,13 @@ class OperationNode : public FunctionBaseNode {
   /*!
    * \brief Build the Realize statement that realizes
    *   the op's output tensors.
-   * \param self The reference to self.
+   * \param stage the op's stage.
    * \param realize_map The realization domain map of the operators.
    * \param body The body that is going to get
    * \return A realization statement that wraps body.
    */
   virtual Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const = 0;
   /*!
@@ -155,7 +155,7 @@ class PlaceholderOpNode : public OperationNode {
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
@@ -206,7 +206,7 @@ class ComputeOpNode : public OperationNode {
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
@@ -277,7 +277,7 @@ class ScanOpNode : public OperationNode {
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
@@ -340,7 +340,7 @@ class ExternOpNode : public OperationNode {
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 38b939d51aca57bc9a0baead3d4a7e0149c1b96b..0f846381314ffe4e8440d366e21feb434cef3c31 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -197,6 +197,17 @@ class Stage : public NodeRef {
    * \return reference to self
    */
   Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
+  /*!
+   * \brief Set alignment requirement for specific dimension.
+   *
+   *  Such that stride[axis] == k * factor + offset for some k.
+   *
+   * \param axis The dimension to be specified for alignment.
+   * \param factor The factor multiple of alignment
+   * \param offset The required offset factor.
+   * \return reference to self
+   */
+  Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
   /*!
    * \brief whether the stage has been scheduled.
    * \return whether the stage has been scheduled.
    */
@@ -496,6 +507,10 @@ class IterVarAttrNode : public Node {
    * when the axis is marked as Tensorized
    */
   TensorIntrin tensor_intrin;
+  /*! \brief Alignment factor of buffer dimension */
+  int dim_align_factor{0};
+  /*! \brief Alignment offset of buffer dimension */
+  int dim_align_offset{0};
   /*!
    * \brief Additional pragmas, array of StringImm
    */
@@ -507,6 +522,8 @@ class IterVarAttrNode : public Node {
     v->Visit("prefetch_data", &prefetch_data);
     v->Visit("prefetch_offset", &prefetch_offset);
     v->Visit("tensor_intrin", &tensor_intrin);
+    v->Visit("dim_align_factor", &dim_align_factor);
+    v->Visit("dim_align_offset", &dim_align_offset);
     v->Visit("pragmas", &pragmas);
   }

diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index c17622bcb22f0d79ae40f4cdfb383776bd49676a..6dca90a7fddb4b66a7a61acb46b79a438e07ba51 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -569,4 +569,24 @@ class Stage(NodeBase):
         """
         _api_internal._StagePrefetch(self, tensor, var, offset)

+    def storage_align(self, axis, factor, offset):
+        """Set alignment requirement for specific axis
+
+        This ensures that stride[axis] == k * factor + offset for some k.
+        This is useful to set a memory layout that gives a more friendly
+        memory access pattern. For example, we can set alignment to be
+        factor=2, offset=1 to avoid bank conflict for thread access on
+        higher dimension in GPU shared memory.
+
+        Parameters
+        ----------
+        axis : IterVar
+            The axis dimension to be aligned.
+        factor : int
+            The factor in alignment specification.
+        offset : int
+            The offset in the alignment specification.
+ """ + _api_internal._StageStorageAlign(self, axis, factor, offset) + _init_api("tvm.schedule") diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 990ad674aefcb0244311d40cb8b2f4054564bea2..dda1bc5f56c5bdaea567a84b1ad6d940635b5b5e 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -388,6 +388,12 @@ TVM_REGISTER_API("_StagePrefetch") .prefetch(args[1], args[2], args[3]); }); +TVM_REGISTER_API("_StageStorageAlign") + .set_body([](TVMArgs args, TVMRetValue *ret) { + args[0].operator Stage() + .storage_align(args[1], args[2], args[3]); + }); + TVM_REGISTER_API("_ScheduleNormalize") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Schedule() diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 06f7cf46a63f69962b0c970b32f95686388b5ded..c7e1b54a44522c7c9c0176fdcd8a54fe37fda55d 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -198,19 +198,35 @@ void ComputeOpNode::GatherBound( } Stmt ComputeOpNode::BuildRealize( - const Operation& self, + const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map, const Stmt& realize_body) const { - CHECK_EQ(self.operator->(), this); + CHECK_EQ(stage->op.get(), this); Halide::Internal::Region bounds; for (IterVar iv : this->axis) { bounds.push_back(realize_map.at(iv)); } Stmt realize = realize_body; - for (int i = self->num_outputs(); i > 0; --i) { - Tensor t = self.output(i-1); + for (int i = this->num_outputs(); i > 0; --i) { + Tensor t = stage->op.output(i-1); realize = ir::Realize::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize); + // alignment requirement, only useful for compute + for (size_t i = 0; i < this->axis.size(); ++i) { + auto it = stage->iter_var_attrs.find(this->axis[i]); + if (it != stage->iter_var_attrs.end()) { + IterVarAttr attr = (*it).second; + if (attr->dim_align_factor != 0) { + Array<Expr> tuple = {static_cast<int>(i), + attr->dim_align_factor, + attr->dim_align_offset}; + realize = ir::AttrStmt::make( + t, ir::attr::buffer_dim_align, + Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), + realize); + } + } + } } return realize; } @@ -304,7 +320,7 @@ enum class ComputeType { }; ComputeType DetectComputeType(const ComputeOpNode* self, - const Stage& stage) { + const Stage& stage) { // Verify correctness of leaf nest. int normal_red = 0, thread_red = 0, tensorize = 0; diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index cebc7ae727ae1929b500f6efeda22e5d9acb7a2c..68a51df32616268e116fa65f37796497a18377be 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -106,13 +106,13 @@ void ExternOpNode::GatherBound( } Stmt ExternOpNode::BuildRealize( - const Operation& self, + const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body) const { - CHECK_EQ(self.operator->(), this); + CHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { - Tensor t = self.output(k); + Tensor t = stage->op.output(k); Halide::Internal::Region bounds; for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back( diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 46cbd9d7e76a5712f01294e8c35d3c17e5f52c3b..fd8bd168db6baaca4f17f196098044e71b1f881b 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2017 by Contributors + * Copyright (c) 2017 by5A Contributors * \brief Utility to make loop nest. 
  * \file op_util.cc
  */
diff --git a/src/op/op_util.h b/src/op/op_util.h
index 419035b672b022dbd3c1e9753deaf5e4d7313d3f..165113863a758dc12285ea3d2178d26493453298 100644
--- a/src/op/op_util.h
+++ b/src/op/op_util.h
@@ -52,6 +52,7 @@ MakeBoundCheck(const Stage& stage,
                bool skip_ivar_domain,
                const std::unordered_set<IterVar>& skip_iter,
                const std::unordered_map<IterVar, Expr>& value_map);
+
 /*!
  * \brief Create a nest of if checking the predicates.
  *
diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc
index 450d69fc1ddf5237e05eeb63fe61e6ab096e699e..4e9d1d094d7444e5c9c88f574c07f6f50214a337 100644
--- a/src/op/placeholder_op.cc
+++ b/src/op/placeholder_op.cc
@@ -70,7 +70,7 @@ void PlaceholderOpNode::GatherBound(
 }

 Stmt PlaceholderOpNode::BuildRealize(
-    const Operation& self,
+    const Stage& stage,
     const std::unordered_map<IterVar, Range>& realize_map,
     const Stmt& body) const {
   return body;
diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc
index 617625c2e690692c4c6440d019d52a47e250b80f..f03eb95f128fa1fb0619a2a84ba249330084a918 100644
--- a/src/op/scan_op.cc
+++ b/src/op/scan_op.cc
@@ -226,17 +226,17 @@ void ScanOpNode::GatherBound(
 }

 Stmt ScanOpNode::BuildRealize(
-    const Operation& self,
+    const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map,
     const Stmt& body) const {
-  CHECK_EQ(self.operator->(), this);
+  CHECK_EQ(stage->op.get(), this);
   Range sdom = dom_map.at(this->scan_axis);
   Range tdom = Range::make_by_min_extent(
       0, ir::Simplify(sdom->extent + sdom->min));
   Stmt ret = body;
   size_t sp_idx = 0;
   for (size_t i = 0; i < update.size(); ++i) {
-    Tensor t = self.output(i);
+    Tensor t = stage->op.output(i);
     CHECK_EQ(static_cast<size_t>(t->value_index), i);
     Halide::Internal::Region bounds;
     bounds.push_back(tdom);
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 543286efc1fa251644a29eed50ddc36c2b5d8d42..551442486e9589d5485b4cd61228adfd9aaf1d81 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -62,6 +62,19 @@ class StorageFlattener : public IRMutator {
       return stmt;
     } else if (op->attr_key == attr::buffer_bind_scope) {
       return HandleBufferBindScope(op);
+    } else if (op->attr_key == attr::buffer_dim_align) {
+      Tensor tensor(op->node.node_);
+      const Call* tuple = op->value.as<Call>();
+      CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
+      TensorKey key{tensor->op, tensor->value_index};
+      auto& vinfo = dim_align_[key];
+      int dim = tuple->args[0].as<IntImm>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
+      return this->Mutate(op->body);
     }
     return IRMutator::Mutate_(op, s);
   }
@@ -116,20 +129,45 @@ class StorageFlattener : public IRMutator {
           align = (info->max_simd_bits + op->type.bits() - 1) / op->type.bits();
         }
       }
-
+      Array<Expr> strides;
+      if (dim_align_.count(key) != 0) {
+        std::vector<Expr> rstrides;
+        const std::vector<DimAlignInfo>& avec = dim_align_[key];
+        Expr stride = make_const(shape[0].type(), 1);
+        for (size_t i = shape.size(); i != 0; --i) {
+          size_t dim = i - 1;
+          if (dim < avec.size() && avec[dim].align_factor != 0) {
+            Expr factor = make_const(stride.type(), avec[dim].align_factor);
+            Expr offset = make_const(stride.type(), avec[dim].align_offset);
+            stride = stride + (factor + offset - stride % factor) % factor;
+            stride = ir::Simplify(stride);
+          }
+          rstrides.push_back(stride);
+          stride = arith::ComputeExpr<Mul>(stride, shape[dim]);
+        }
+        strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
+      }
       e.buffer = BufferNode::make(
           Var(key.GetName(), Handle()),
-          op->type, shape,
-          Array<Expr>(), Expr(),
+          op->type, shape, strides, Expr(),
           key.GetName(), skey.to_string(),
           align, 0);
+
       buf_map_[key] = e;
       Stmt body = this->Mutate(op->body);
       buf_map_[key].released = true;
+      Stmt ret;
-      Stmt ret = Allocate::make(
-          e.buffer->data, e.buffer->dtype, e.buffer->shape,
-          make_const(Bool(e.buffer->dtype.lanes()), true), body);
+      if (strides.size() != 0) {
+        ret = Allocate::make(
+            e.buffer->data, e.buffer->dtype,
+            {arith::ComputeExpr<Mul>(e.buffer->strides[0], e.buffer->shape[0])},
+            make_const(Bool(e.buffer->dtype.lanes()), true), body);
+      } else {
+        ret = Allocate::make(
+            e.buffer->data, e.buffer->dtype, e.buffer->shape,
+            make_const(Bool(e.buffer->dtype.lanes()), true), body);
+      }
       ret = AttrStmt::make(
           e.buffer->data, attr::storage_scope,
           StringImm::make(e.buffer->scope), ret);
@@ -283,7 +321,11 @@ class StorageFlattener : public IRMutator {
     }
     return body;
   }
-
+  // Alignment requirement of a buffer dimension
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
   // The buffer entry in the flatten map
   struct BufferEntry {
     // the buffer of storage
@@ -294,7 +336,6 @@ class StorageFlattener : public IRMutator {
     bool external{false};
     // Whether we are out of allocation bounds and buffer get released.
     bool released{false};
-    // TODO(tqchen) allow permutation and inference of index dimension.
     // relative index
     inline Array<Expr> RelIndex(Array<Expr> args) const {
       if (bounds.size() != 0) {
@@ -314,6 +355,9 @@ class StorageFlattener : public IRMutator {
   std::unordered_map<const Variable*, Expr> var_remap_;
   // Buffer map
   std::unordered_map<TensorKey, BufferEntry> buf_map_;
+  // Dimension alignment
+  std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
+
   // Storage scope
   std::unordered_map<const Node*, std::string> storage_scope_;
   // The current thread scope.
   std::vector<ThreadScope> curr_thread_scope_;
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index 0e78c327e6b1b7359fcf0f418e95b5db564e7898..c07f60dccceb95d46ccf03638784357d0ce71b24 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -296,10 +296,15 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
 }

 template<typename FUpdate>
-inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate) {
-  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-  FindLeafVar(all_vars, leaf_vars, var);
+inline void UpdateIterVarAttr(StageNode* self,
+                              IterVar var,
+                              FUpdate fupdate,
+                              bool need_leaf = true) {
+  if (need_leaf) {
+    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+    FindLeafVar(all_vars, leaf_vars, var);
+  }
   auto it = self->iter_var_attrs.find(var);
   std::shared_ptr<IterVarAttrNode> n;
   if (it != self->iter_var_attrs.end()) {
@@ -371,6 +376,15 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
   return *this;
 }

+Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
+  StageNode *self = operator->();
+  UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) {
+      n->dim_align_factor = factor;
+      n->dim_align_offset = offset;
+    }, false);
+  return *this;
+}
+
 Stage CopyStage(const Stage& s) {
   std::shared_ptr<StageNode> n =
       std::make_shared<StageNode>(*s.operator->());
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index b49b86bbe074a103a5b216ecd314d2bbea699e42..724672d2a9519114fe940347c2744cc5a42916a3 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -33,7 +33,7 @@ Stmt MakePipeline(const Stage& s,
     consumer = ProducerConsumer::make(s->op, false, consumer);
     pipeline = Block::make(producer, consumer);
   }
-  pipeline = s->op->BuildRealize(s->op, dom_map, pipeline);
+  pipeline = s->op->BuildRealize(s, dom_map, pipeline);
   // use attribute to mark scope of the operation.
   pipeline = AttrStmt::make(
       s->op, ir::attr::realize_scope,
@@ -194,6 +194,18 @@ class SchedulePostProc : public IRMutator {
           return this->Mutate(op->body);
         }
       }
+    } else if (op->attr_key == ir::attr::buffer_dim_align) {
+      Tensor tensor(op->node.node_);
+      auto it = replace_op_.find(tensor->op.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          return AttrStmt::make(
+              it->second.output(tensor->value_index),
+              op->attr_key, op->value, Mutate(op->body));
+        } else {
+          return this->Mutate(op->body);
+        }
+      }
     }
     return IRMutator::Mutate_(op, s);
   }
diff --git a/tests/python/unittest/test_pass_storage_flatten.py b/tests/python/unittest/test_pass_storage_flatten.py
index b28a1475929667cc6c88dac1e4702135b5acf946..4e2feed23eff0b5006fb0890636e3b0abd58dced 100644
--- a/tests/python/unittest/test_pass_storage_flatten.py
+++ b/tests/python/unittest/test_pass_storage_flatten.py
@@ -32,6 +32,27 @@ def test_flatten_prefetch():
     assert isinstance(stmt.body, tvm.stmt.For)
     assert stmt.body.extent.value == 2

+
+def test_flatten_storage_align():
+    m = 8
+    l = 16
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+
+    s = tvm.create_schedule(A2.op)
+    s[A1].storage_align(A1.op.axis[0], 2, 1)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert(stmt.body.extents[0].value == 17 * 8)
+
+
 if __name__ == "__main__":
+    test_flatten_storage_align()
     test_flatten2()
     test_flatten_prefetch()
diff --git a/tutorials/deployment/cross_compilation_and_rpc.py b/tutorials/deployment/cross_compilation_and_rpc.py
index d8c5fceb2441e29af10b50a7da56bd3abc45c818..1763b153ba8374b2e4868d2b9b18d26df39e5876 100644
--- a/tutorials/deployment/cross_compilation_and_rpc.py
+++ b/tutorials/deployment/cross_compilation_and_rpc.py
@@ -28,14 +28,13 @@ from tvm.contrib import rpc, util
 # local machine, we need build runtime on remote device.
 #
 # To get started, clone tvm repo from github. It is important to clone
-# the submodules along, with --recursive option (Assuming you are in
-# your home directory):
+# the submodules along, with --recursive option (Assuming you are in
+# your home directory):
 #
 # .. code-block:: bash
-#
+#
 #   git clone --recursive https://github.com/dmlc/tvm
-#
-######################################################################
+#
 # .. note::
 #
 #   Usually device has limited resources and we only need to build
@@ -51,14 +50,13 @@
 #
 # Also make sure that you have set :code:`USE_RPC=1` in your
 # :code:`config.mk`. We don't need LLVM when building runtime, so
-# :code:`LLVM_CONFIG = llvm-config` in :code:`config.mk`is commented
+# :code:`LLVM_CONFIG = llvm-config` in :code:`config.mk` is commented
 # out by default. After that, build runtime!
 #
 # .. code-block:: bash
 #
 #   make runtime
 #
-######################################################################
 # After success of buildind runtime, we need set environment varibles
 # in :code:`~/.bashrc` file of yourself account or :code:`/etc/profile`
 # of system enviroment variables. Assuming your TVM directory is in
@@ -95,7 +93,7 @@
 # successful to start RPC server on your device.
 #
 # .. code-block:: bash
-#
+#
 #   Loading runtime library /home/YOURNAME/code/tvm/lib/libtvm_runtime.so... exec only
 #   INFO:root:RPCServer: bind to 0.0.0.0:9090
 #
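
For reference, a minimal usage sketch (not part of the patch): it mirrors test_flatten_storage_align above and only relies on the API calls exercised in that test. With factor=2 and offset=1 on axis 0 of A1, the natural innermost stride of 16 is padded up to 17 (the next value of the form k * 2 + 1), so the flattened A1 allocation becomes 17 * 8 = 136 elements instead of 16 * 8 = 128.

import tvm

m, l = 8, 16
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

s = tvm.create_schedule(A2.op)
# Require stride[0] of A1's buffer to satisfy stride = k * 2 + 1 for some k.
s[A1].storage_align(A1.op.axis[0], 2, 1)

# Lower far enough to see the padded allocation, as the new unit test does.
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
stmt = tvm.ir_pass.Simplify(stmt)
# The allocate for A1 now has extent 17 * 8 rather than 16 * 8.
print(stmt)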