From f2b913925dae089160e5f20eaeaa5c8ef3eff2e8 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 3 Dec 2017 22:38:28 -0800
Subject: [PATCH] Support rank-0 tensor (#687)

* Support rank-0 tensor

* fix lint
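
A rank-0 (scalar) tensor is declared with an empty shape tuple and can now
be used directly inside expressions, since Tensor derives from ExprOp.
Below is a minimal usage sketch adapted from the test_rank_zero tests added
in this patch; it assumes the llvm target is enabled, and the size n = 16
is arbitrary:

    import numpy as np
    import tvm

    n = 16
    A = tvm.placeholder((n,), name='A')        # rank-1 input
    scale = tvm.placeholder((), name='scale')  # rank-0 (scalar) input
    k = tvm.reduce_axis((0, n), name='k')
    # The rank-0 tensors `scale` and `C` are used directly as expressions.
    C = tvm.compute((), lambda: tvm.sum(A[k] * scale, axis=k), name='C')
    D = tvm.compute((), lambda: C + 1, name='D')

    s = tvm.create_schedule(D.op)
    f = tvm.build(s, [A, scale, D], "llvm")
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
    sc = tvm.nd.array(np.random.uniform(size=()).astype(scale.dtype), ctx)
    d = tvm.nd.empty((), D.dtype, ctx)         # rank-0 output NDArray
    f(a, sc, d)
    np.testing.assert_allclose(
        d.asnumpy(), np.sum(a.asnumpy()) * sc.asnumpy() + 1)

Note that == between two rank-0 tensors now raises a ValueError because the
intent is ambiguous; use Tensor.equal for expression content equality, or
Tensor.same_as for exact reference comparison.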
---
 include/tvm/buffer.h                          |  5 +++
 include/tvm/packed_func_ext.h                 |  4 ++
 include/tvm/tensor.h                          |  2 +-
 python/tvm/_ffi/ndarray.py                    |  7 ++--
 python/tvm/api.py                             |  1 -
 python/tvm/tensor.py                          |  8 +++-
 src/arithmetic/compute_expr.h                 | 12 ++++--
 src/lang/buffer.cc                            | 37 ++++++++++---------
 src/pass/arg_binder.cc                        | 20 +++++-----
 src/pass/inject_double_buffer.cc              |  3 +-
 src/pass/inject_virtual_thread.cc             |  3 +-
 src/pass/storage_flatten.cc                   | 15 ++++++--
 src/pass/storage_rewrite.cc                   |  2 +-
 src/runtime/c_runtime_api.cc                  | 10 +++--
 src/runtime/graph/graph_runtime.cc            |  6 ++-
 src/schedule/schedule_dataflow_rewrite.cc     |  6 +--
 tests/python/unittest/test_codegen_device.py  |  9 +++--
 tests/python/unittest/test_codegen_llvm.py    | 25 +++++++++++++
 tests/python/unittest/test_lang_tensor.py     | 13 +++++++
 .../unittest/test_runtime_packed_func.py      |  8 ++++
 topi/python/topi/nn/dense.py                  |  4 +-
 21 files changed, 143 insertions(+), 57 deletions(-)

diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h
index ad4872b8e..f2790f6df 100644
--- a/include/tvm/buffer.h
+++ b/include/tvm/buffer.h
@@ -124,6 +124,11 @@ class BufferNode : public Node {
     v->Visit("offset_factor", &offset_factor);
   }
 
+  /*! \return preferred index type for this buffer node */
+  Type DefaultIndexType() const {
+    return shape.size() != 0 ? shape[0].type() : Int(32);
+  }
+
   // User can specify data_alignment and offset_factor to be 0
   // A default value will be picked.
   TVM_DLL static Buffer make(Var ptr,
diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h
index 5242a0576..542de6a36 100644
--- a/include/tvm/packed_func_ext.h
+++ b/include/tvm/packed_func_ext.h
@@ -14,6 +14,7 @@
 
 #include "./base.h"
 #include "./expr.h"
+#include "./tensor.h"
 #include "./runtime/packed_func.h"
 
 namespace tvm {
@@ -116,6 +117,9 @@ inline TVMArgValue::operator Halide::Expr() const {
   if (sptr->is_type<IterVarNode>()) {
     return IterVar(sptr)->var;
   }
+  if (sptr->is_type<TensorNode>()) {
+    return Tensor(sptr)();
+  }
   CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
       << "Expected type " << NodeTypeName<Expr>()
       << " but get " << sptr->type_key();
diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h
index a6613a4dc..4f46d86e9 100644
--- a/include/tvm/tensor.h
+++ b/include/tvm/tensor.h
@@ -188,7 +188,7 @@ inline bool Tensor::operator==(const Tensor& other) const {
 #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)                              \
   inline Expr operator Op (const Tensor::Slice& a) {                    \
     return Op a.operator Expr() ;                                       \
-  }
+  }                                                                     \
 
 #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)                             \
   template<typename T>                                                  \
diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py
index b0dfd0f73..135701a80 100644
--- a/python/tvm/_ffi/ndarray.py
+++ b/python/tvm/_ffi/ndarray.py
@@ -177,13 +177,14 @@ class NDArrayBase(_NDArrayBase):
             shape = shape + (t.lanes,)
             t.lanes = 1
             dtype = str(t)
-        source_array = np.ascontiguousarray(source_array, dtype=dtype)
+
         if source_array.shape != shape:
             raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
                 source_array.shape, shape))
+        source_array = np.ascontiguousarray(source_array, dtype=dtype)
         assert source_array.flags['C_CONTIGUOUS']
         data = source_array.ctypes.data_as(ctypes.c_void_p)
-        nbytes = ctypes.c_size_t(np.prod(source_array.shape) * source_array.dtype.itemsize)
+        nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
         return self
 
@@ -212,7 +213,7 @@ class NDArrayBase(_NDArrayBase):
         np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags['C_CONTIGUOUS']
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
-        nbytes = ctypes.c_size_t(np.prod(np_arr.shape) * np_arr.dtype.itemsize)
+        nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
         return np_arr
 
diff --git a/python/tvm/api.py b/python/tvm/api.py
index dfe6e4cf7..08b3d95dc 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -462,7 +462,6 @@ def decl_buffer(shape,
         elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
     if data is None:
         data = var(name, "handle")
-
     return _api_internal._Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
         data_alignment, offset_factor)
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index 98a142e8c..f169ff1b6 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -32,7 +32,7 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
 itervar_cls = None
 
 @register_node
-class Tensor(NodeBase):
+class Tensor(NodeBase, _expr.ExprOp):
     """Tensor object, to construct, see function.Tensor"""
     def __call__(self, *indices):
         ndim = self.ndim
@@ -60,7 +60,13 @@ class Tensor(NodeBase):
 
     def __eq__(self, other):
         if not isinstance(other, Tensor):
+            if isinstance(other, _expr.ExprOp):
+                return _expr.EqualOp(self, other)
             return False
+        if self.ndim == 0 and other.ndim == 0:
+            raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
+                             "use Tensor.equal for content expression equvalence, "
+                             "use Tensor.same_as for exact reference comparison")
         return _api_internal._TensorEqual(self, other)
 
     @property
diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h
index 18ae8530f..994bcb13e 100644
--- a/src/arithmetic/compute_expr.h
+++ b/src/arithmetic/compute_expr.h
@@ -33,11 +33,14 @@ inline Expr ComputeExpr(Expr lhs, Expr rhs) {
 /*!
  * \brief Compute an reduction with Op
  * \param values The input values.
+ * \param empty_value The value to return if values is empty; can be Expr(),
+ *        which will cause an error to be raised.
  * \tparam Op The computation operator
  * \return The result.
  */
 template<typename Op>
-inline Expr ComputeReduce(const Array<Expr>& values);
+inline Expr ComputeReduce(
+    const Array<Expr>& values, Expr empty_value);
 
 template<typename T>
 inline bool GetConst(Expr e, T* out);
@@ -139,8 +142,11 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
 }
 
 template<typename Op>
-inline Expr ComputeReduce(const Array<Expr>& values) {
-  CHECK_NE(values.size(), 0U);
+inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
+  if (values.size() == 0U) {
+    CHECK(empty_value.defined());
+    return empty_value;
+  }
   Expr res = values[0];
   for (size_t i = 1; i < values.size(); ++i) {
     res = ComputeExpr<Op>(res, values[i]);
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index 5cf7ddef3..af76dcc94 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -11,15 +11,6 @@
 
 namespace tvm {
 
-Array<Expr> GetStrides(Array<Expr> shape) {
-  CHECK_NE(shape.size(), 0U);
-  std::vector<Expr> vec{make_const(shape[0].type(), 1)};
-  for (size_t i = shape.size() - 1; i != 0; --i) {
-    vec.push_back(shape[i - 1] * vec.back());
-  }
-  return Array<Expr>(vec.rbegin(), vec.rend());
-}
-
 Array<Expr> SimplifyArray(Array<Expr> array) {
   for (size_t i = 0; i < array.size(); ++i) {
     array.Set(i, ir::Simplify(array[i]));
@@ -235,10 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
   Expr base = n->elem_offset;
   if (n->strides.size() == 0) {
     CHECK_EQ(n->shape.size(), index.size());
-    if (is_zero(base)) {
-      base = index[0];
-    } else {
-      base = base + index[0];
+    if (n->shape.size() != 0) {
+      if (is_zero(base)) {
+        base = index[0];
+      } else {
+        base = base + index[0];
+      }
     }
     base = MergeMulMod(base);
     for (size_t i = 1; i < index.size(); ++i) {
@@ -294,9 +287,10 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
 
 Buffer Buffer::MakeStrideView() const {
   if ((*this)->strides.size() != 0) return *this;
+  if ((*this)->shape.size() == 0) return *this;
   std::vector<Expr> temp;
   auto n = std::make_shared<BufferNode>(*operator->());
-  Expr acc = make_const(n->shape[0].type(), 1);
+  Expr acc = make_const(n->DefaultIndexType(), 1);
   for (size_t i = n->shape.size(); i != 0 ; --i) {
     temp.push_back(acc);
     acc = acc * n->shape[i - 1];
@@ -344,9 +338,16 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
 Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const {
   const BufferNode* self = operator->();
   Expr e_dtype;
-  Expr extent = (self->strides.size() == self->shape.size() ?
-                 arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]):
-                 arith::ComputeReduce<ir::Mul>(self->shape));
+  Expr extent;
+  if (self->shape.size() == 0) {
+    extent = make_const(self->DefaultIndexType(), 1);
+  } else if (self->strides.size() == self->shape.size()) {
+    int highest_dim = 0;
+    extent = arith::ComputeExpr<ir::Mul>(
+        self->strides[highest_dim], self->shape[highest_dim]);
+  } else {
+    extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
+  }
   Expr elem_offset = self->elem_offset;
   if (content_lanes > 1) {
     e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
@@ -383,7 +384,7 @@ Buffer BufferNode::make(Var data,
   }
   n->scope = std::move(scope);
   if (!elem_offset.defined()) {
-    elem_offset = make_const(n->shape[0].type(), 0);
+    elem_offset = make_const(n->DefaultIndexType(), 0);
   }
   if (data_alignment <= 0) {
     data_alignment = runtime::kAllocAlignment;
diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc
index 20c8593a1..cdd344670 100644
--- a/src/pass/arg_binder.cc
+++ b/src/pass/arg_binder.cc
@@ -196,7 +196,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       nop));
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
-    Type stype = buffer->shape[0].type();
+    Type stype = buffer->DefaultIndexType();
     Expr expect_stride = make_const(stype, 1);
     Array<Expr> conds;
     for (size_t i = buffer->shape.size(); i != 0; --i) {
@@ -211,14 +211,16 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     std::ostringstream stride_err_msg;
     stride_err_msg << arg_name << ".strides:"
                    << " expected to be compact array";
-    Stmt check =
-        AssertStmt::make(arith::ComputeReduce<ir::And>(conds),
-                         stride_err_msg.str(), Evaluate::make(0));
-    Expr is_null = Call::make(
-        Bool(1), intrinsic::tvm_handle_is_null,
-        {v_strides}, Call::PureIntrinsic);
-    check = IfThenElse::make(Not::make(is_null), check, Stmt());
-    init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
+    if (conds.size() != 0) {
+      Stmt check =
+          AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
+                           stride_err_msg.str(), Evaluate::make(0));
+      Expr is_null = Call::make(
+          Bool(1), intrinsic::tvm_handle_is_null,
+          {v_strides}, Call::PureIntrinsic);
+      check = IfThenElse::make(Not::make(is_null), check, Stmt());
+      init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
+    }
   } else {
     for (size_t k = 0; k < buffer->strides.size(); ++k) {
       std::ostringstream field_name;
diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc
index e9bd8594a..03ffdb01e 100644
--- a/src/pass/inject_double_buffer.cc
+++ b/src/pass/inject_double_buffer.cc
@@ -81,7 +81,8 @@ class DoubleBufferInjector : public IRMutator {
   Stmt Mutate_(const Allocate* op, const Stmt& s) final {
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
-      it->second.stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes();
+      it->second.stride = arith::ComputeReduce<Mul>(
+          op->extents, Expr()) * op->type.lanes();
       Stmt stmt = IRMutator::Mutate_(op, s);
       op = stmt.as<Allocate>();
       Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc
index 28e90ec48..bcf0e3d9f 100644
--- a/src/pass/inject_virtual_thread.cc
+++ b/src/pass/inject_virtual_thread.cc
@@ -376,7 +376,8 @@ class VTInjector : public IRMutator {
     // always rewrite if not allow sharing.
     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
       // place v on highest dimension.
-      Expr stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes();
+      Expr stride = arith::ComputeReduce<Mul>(
+          op->extents, Expr()) * op->type.lanes();
       Array<Expr> other;
       other.push_back(make_const(op->extents[0].type(), num_threads_));
       for (Expr e : extents) {
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index f1aee504f..46bed1fc9 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -147,10 +147,11 @@ class StorageFlattener : public IRMutator {
         }
       }
       Array<Expr> strides;
-      if (dim_align_.count(key) != 0) {
+      if (dim_align_.count(key) != 0 && shape.size() != 0) {
         std::vector<Expr> rstrides;
         const std::vector<DimAlignInfo>& avec = dim_align_[key];
-        Expr stride = make_const(shape[0].type(), 1);
+        int first_dim = 0;
+        Expr stride = make_const(shape[first_dim].type(), 1);
         for (size_t i = shape.size(); i != 0; --i) {
           size_t dim = i - 1;
           if (dim < avec.size() && avec[dim].align_factor != 0) {
@@ -164,6 +165,7 @@ class StorageFlattener : public IRMutator {
         }
         strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
       }
+
       e.buffer = BufferNode::make(
           Var(key.GetName(), Handle()),
           op->type, shape, strides, Expr(),
@@ -176,13 +178,18 @@ class StorageFlattener : public IRMutator {
       Stmt ret;
 
       if (strides.size() != 0) {
+        int first_dim = 0;
         ret = Allocate::make(
             e.buffer->data, e.buffer->dtype,
-            {arith::ComputeExpr<Mul>(e.buffer->strides[0], e.buffer->shape[0])},
+            {arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
             make_const(Bool(e.buffer->dtype.lanes()), true), body);
       } else {
+        shape = e.buffer->shape;
+        if (shape.size() == 0) {
+          shape.push_back(make_const(Int(32), 1));
+        }
         ret = Allocate::make(
-            e.buffer->data, e.buffer->dtype, e.buffer->shape,
+            e.buffer->data, e.buffer->dtype, shape,
             make_const(Bool(e.buffer->dtype.lanes()), true), body);
       }
       ret = AttrStmt::make(
diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 2f3616017..9d47a64f8 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -405,7 +405,7 @@ class StoragePlanRewriter : public IRMutator {
           // Build a merged allocation
           Expr combo_size;
           for (const Allocate* op : e->allocs) {
-            Expr sz = arith::ComputeReduce<Mul>(op->extents);
+            Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
             if (alloc_type.lanes() != op->type.lanes()) {
               sz = (sz * make_const(sz.type(), op->type.lanes()) +
                     make_const(sz.type(), alloc_type.lanes() - 1)) /
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index f036dccc3..dd8f80bcd 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -352,9 +352,13 @@ int TVMArrayAlloc(const tvm_index_t* shape,
   arr->dtype.code = static_cast<uint8_t>(dtype_code);
   arr->dtype.bits = static_cast<uint8_t>(dtype_bits);
   arr->dtype.lanes = static_cast<uint16_t>(dtype_lanes);
-  tvm_index_t* shape_copy = new tvm_index_t[ndim];
-  std::copy(shape, shape + ndim, shape_copy);
-  arr->shape = shape_copy;
+  if (ndim != 0) {
+    tvm_index_t* shape_copy = new tvm_index_t[ndim];
+    std::copy(shape, shape + ndim, shape_copy);
+    arr->shape = shape_copy;
+  } else {
+    arr->shape = nullptr;
+  }
   // ctx
   arr->ctx.device_type = static_cast<DLDeviceType>(device_type);
   arr->ctx.device_id = device_id;
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index d3f849d74..ed833d408 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -370,8 +370,10 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
   CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
       << "Invalid DLTensor file format";
   std::vector<int64_t> shape(tensor.ndim);
-  CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
-      << "Invalid DLTensor file format";
+  if (tensor.ndim != 0) {
+    CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
+        << "Invalid DLTensor file format";
+  }
   CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
   CHECK(tensor.dtype.bits == dst->dtype.bits &&
         tensor.dtype.code == dst->dtype.code &&
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index a8dc4edf5..d1a69ecf0 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -47,10 +47,10 @@ Expr InjectPredicate(const Array<Expr>& predicates,
   const Reduce* reduce = body.as<Reduce>();
   if (reduce) {
     std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce);
-    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates);
+    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
     return Expr(n);
   }
-  return Select::make(arith::ComputeReduce<ir::And>(predicates),
+  return Select::make(arith::ComputeReduce<ir::And>(predicates, Expr()),
                       body,
                       make_zero(body.type()));
 }
@@ -467,7 +467,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   const Reduce* reduce = compute_op->body[idx].as<Reduce>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   predicates.push_back(reduce->condition);
-  Expr predicate = arith::ComputeReduce<ir::And>(predicates);
+  Expr predicate = arith::ComputeReduce<ir::And>(predicates, Expr());
 
   std::unordered_map<const Variable*, Expr> vsub;
 
diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py
index bbdd65e4b..56e3fc819 100644
--- a/tests/python/unittest/test_codegen_device.py
+++ b/tests/python/unittest/test_codegen_device.py
@@ -5,8 +5,8 @@ import numpy as np
 def test_add_pipeline():
     n = tvm.var('n')
     A = tvm.placeholder((n,), name='A')
-    B = tvm.placeholder((n,), name='B')
-    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    B = tvm.placeholder((), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(), name='C')
     D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
     s = tvm.create_schedule(D.op)
 
@@ -48,7 +48,7 @@ def test_add_pipeline():
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
         d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
         f(a, b, d)
         np.testing.assert_allclose(
@@ -72,7 +72,7 @@ def test_add_pipeline():
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
         d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
         f(a, b, d)
         np.testing.assert_allclose(
@@ -84,5 +84,6 @@ def test_add_pipeline():
     check_target("nvptx", host="llvm")
     check_target("rocm", host="llvm")
 
+
 if __name__ == "__main__":
     test_add_pipeline()
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
index 0db06b934..24996c842 100644
--- a/tests/python/unittest/test_codegen_llvm.py
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -273,7 +273,32 @@ def test_llvm_bool():
     check_llvm(64)
 
 
+def test_rank_zero():
+    def check_llvm(n):
+        if not tvm.module.enabled("llvm"):
+            return
+        A = tvm.placeholder((n, ), name='A')
+        scale = tvm.placeholder((), name='scale')
+        k = tvm.reduce_axis((0, n), name="k")
+        C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C")
+        D = tvm.compute((), lambda : C + 1)
+        s = tvm.create_schedule(D.op)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, scale, D], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
+        sc = tvm.nd.array(
+            np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
+        d = tvm.nd.empty((), D.dtype, ctx)
+        f(a, sc, d)
+        d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
+        np.testing.assert_allclose(d.asnumpy(), d_np)
+    check_llvm(64)
+
+
 if __name__ == "__main__":
+    test_rank_zero()
     test_llvm_bool()
     test_llvm_persist_parallel()
     test_llvm_select()
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index 6f151749c..1d8603dfc 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -19,6 +19,17 @@ def test_tensor():
     assert(T[0][0][0].astype('float16').dtype == 'float16')
 
 
+def test_rank_zero():
+    m = tvm.var('m')
+    A = tvm.placeholder((m,), name='A')
+    scale = tvm.placeholder((), name='s')
+    k = tvm.reduce_axis((0, m), name="k")
+    T = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k))
+    print(T)
+    print(T.op.body)
+    assert(tuple(T.shape) == ())
+
+
 def test_conv1d():
     n = tvm.var('n')
     A = tvm.placeholder((n+2), name='A')
@@ -173,7 +184,9 @@ def test_tensor_inputs():
     y = tvm.compute(x.shape, lambda i: x[i] + x[i])
     assert tuple(y.op.input_tensors) == (x,)
 
+
 if __name__ == "__main__":
+    test_rank_zero()
     test_tensor_inputs()
     test_tensor_reduce_multi_axis()
     test_conv1d()
diff --git a/tests/python/unittest/test_runtime_packed_func.py b/tests/python/unittest/test_runtime_packed_func.py
index 44b450b23..279172555 100644
--- a/tests/python/unittest/test_runtime_packed_func.py
+++ b/tests/python/unittest/test_runtime_packed_func.py
@@ -63,7 +63,15 @@ def test_byte_array():
     f(a)
 
 
+def test_empty_array():
+    def myfunc(ss):
+        assert tuple(ss) == ()
+    x = tvm.convert(())
+    tvm.convert(myfunc)(x)
+
+
 if __name__ == "__main__":
+    test_empty_array()
     test_get_global()
     test_get_callback_with_node()
     test_convert()
diff --git a/topi/python/topi/nn/dense.py b/topi/python/topi/nn/dense.py
index caa736a41..333692614 100644
--- a/topi/python/topi/nn/dense.py
+++ b/topi/python/topi/nn/dense.py
@@ -25,7 +25,7 @@ def dense(data, weight, bias=None):
     """
     assert len(data.shape) == 2 and len(weight.shape) == 2, \
         "only support 2-dim dense"
-    if bias:
+    if bias is not None:
         assert len(bias.shape) == 1
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
@@ -33,7 +33,7 @@ def dense(data, weight, bias=None):
     matmul = tvm.compute((batch, out_dim), \
                          lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
                          tag='dense')
-    if bias:
+    if bias is not None:
         matmul = tvm.compute((batch, out_dim), \
                              lambda i, j: matmul[i, j] + bias[j], \
                              tag=tag.BROADCAST)
-- 
GitLab