From f2b913925dae089160e5f20eaeaa5c8ef3eff2e8 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 3 Dec 2017 22:38:28 -0800
Subject: [PATCH] Support rank-0 tensor (#687)

* Support rank-0 tensor

* fix lint
---
 include/tvm/buffer.h                          |  5 +++
 include/tvm/packed_func_ext.h                 |  4 ++
 include/tvm/tensor.h                          |  2 +-
 python/tvm/_ffi/ndarray.py                    |  7 ++--
 python/tvm/api.py                             |  1 -
 python/tvm/tensor.py                          |  8 +++-
 src/arithmetic/compute_expr.h                 | 12 ++++--
 src/lang/buffer.cc                            | 37 ++++++++++---------
 src/pass/arg_binder.cc                        | 20 +++++-----
 src/pass/inject_double_buffer.cc              |  3 +-
 src/pass/inject_virtual_thread.cc             |  3 +-
 src/pass/storage_flatten.cc                   | 15 ++++++--
 src/pass/storage_rewrite.cc                   |  2 +-
 src/runtime/c_runtime_api.cc                  | 10 +++--
 src/runtime/graph/graph_runtime.cc            |  6 ++-
 src/schedule/schedule_dataflow_rewrite.cc     |  6 +--
 tests/python/unittest/test_codegen_device.py  |  9 +++--
 tests/python/unittest/test_codegen_llvm.py    | 25 +++++++++++++
 tests/python/unittest/test_lang_tensor.py     | 13 +++++++
 .../unittest/test_runtime_packed_func.py      |  8 ++++
 topi/python/topi/nn/dense.py                  |  4 +-
 21 files changed, 143 insertions(+), 57 deletions(-)

diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h
index ad4872b8e..f2790f6df 100644
--- a/include/tvm/buffer.h
+++ b/include/tvm/buffer.h
@@ -124,6 +124,11 @@ class BufferNode : public Node {
     v->Visit("offset_factor", &offset_factor);
   }
 
+  /*! \return preferred index type for this buffer node */
+  Type DefaultIndexType() const {
+    return shape.size() != 0 ? shape[0].type() : Int(32);
+  }
+
   // User can specify data_alignment and offset_factor to be 0
   // A default value will be picked.
   TVM_DLL static Buffer make(Var ptr,
diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h
index 5242a0576..542de6a36 100644
--- a/include/tvm/packed_func_ext.h
+++ b/include/tvm/packed_func_ext.h
@@ -14,6 +14,7 @@
 
 #include "./base.h"
 #include "./expr.h"
+#include "./tensor.h"
 #include "./runtime/packed_func.h"
 
 namespace tvm {
@@ -116,6 +117,9 @@ inline TVMArgValue::operator Halide::Expr() const {
   if (sptr->is_type<IterVarNode>()) {
     return IterVar(sptr)->var;
   }
+  if (sptr->is_type<TensorNode>()) {
+    return Tensor(sptr)();
+  }
   CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
       << "Expected type " << NodeTypeName<Expr>()
       << " but get " << sptr->type_key();
diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h
index a6613a4dc..4f46d86e9 100644
--- a/include/tvm/tensor.h
+++ b/include/tvm/tensor.h
@@ -188,7 +188,7 @@ inline bool Tensor::operator==(const Tensor& other) const {
 #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)            \
   inline Expr operator Op (const Tensor::Slice& a) {  \
     return Op a.operator Expr() ;                     \
-  }
+  }                                                   \
 
 #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)           \
   template<typename T>                                \
diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py
index b0dfd0f73..135701a80 100644
--- a/python/tvm/_ffi/ndarray.py
+++ b/python/tvm/_ffi/ndarray.py
@@ -177,13 +177,14 @@ class NDArrayBase(_NDArrayBase):
             shape = shape + (t.lanes,)
             t.lanes = 1
         dtype = str(t)
-        source_array = np.ascontiguousarray(source_array, dtype=dtype)
+
         if source_array.shape != shape:
             raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
                 source_array.shape, shape))
+        source_array = np.ascontiguousarray(source_array, dtype=dtype)
         assert source_array.flags['C_CONTIGUOUS']
         data = source_array.ctypes.data_as(ctypes.c_void_p)
-        nbytes = ctypes.c_size_t(np.prod(source_array.shape) * source_array.dtype.itemsize)
+        nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
         return self
 
@@ -212,7 +213,7 @@ class NDArrayBase(_NDArrayBase):
         np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags['C_CONTIGUOUS']
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
-        nbytes = ctypes.c_size_t(np.prod(np_arr.shape) * np_arr.dtype.itemsize)
+        nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
         return np_arr
 
diff --git a/python/tvm/api.py b/python/tvm/api.py
index dfe6e4cf7..08b3d95dc 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -462,7 +462,6 @@ def decl_buffer(shape,
         elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
     if data is None:
         data = var(name, "handle")
-
     return _api_internal._Buffer(
         data, dtype, shape, strides, elem_offset,
         name, scope, data_alignment, offset_factor)
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index 98a142e8c..f169ff1b6 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -32,7 +32,7 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
     itervar_cls = None
 
 @register_node
-class Tensor(NodeBase):
+class Tensor(NodeBase, _expr.ExprOp):
     """Tensor object, to construct, see function.Tensor"""
     def __call__(self, *indices):
         ndim = self.ndim
@@ -60,7 +60,13 @@ class Tensor(NodeBase):
 
     def __eq__(self, other):
         if not isinstance(other, Tensor):
+            if isinstance(other, _expr.ExprOp):
+                return _expr.EqualOp(self, other)
             return False
+        if self.ndim == 0 and other.ndim == 0:
+            raise ValueError("Equal == comparison among rank-0 tensors is ambiguous, "
+                             "use Tensor.equal for content expression equivalence, "
+                             "use Tensor.same_as for exact reference comparison")
         return _api_internal._TensorEqual(self, other)
 
     @property
diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h
index 18ae8530f..994bcb13e 100644
--- a/src/arithmetic/compute_expr.h
+++ b/src/arithmetic/compute_expr.h
@@ -33,11 +33,14 @@ inline Expr ComputeExpr(Expr lhs, Expr rhs) {
 /*!
  * \brief Compute an reduction with Op
  * \param values The input values.
+ * \param empty_value The value to return if values is empty; can be Expr(),
+ *        which will cause an error to be raised.
 * \tparam Op The computation operator
 * \return The result.
 */
 template<typename Op>
-inline Expr ComputeReduce(const Array<Expr>& values);
+inline Expr ComputeReduce(
+    const Array<Expr>& values, Expr empty_value);
 
 template<typename T>
 inline bool GetConst(Expr e, T* out);
@@ -139,8 +142,11 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
 }
 
 template<typename Op>
-inline Expr ComputeReduce(const Array<Expr>& values) {
-  CHECK_NE(values.size(), 0U);
+inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
+  if (values.size() == 0U) {
+    CHECK(empty_value.defined());
+    return empty_value;
+  }
   Expr res = values[0];
   for (size_t i = 1; i < values.size(); ++i) {
     res = ComputeExpr<Op>(res, values[i]);
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index 5cf7ddef3..af76dcc94 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -11,15 +11,6 @@
 
 namespace tvm {
 
-Array<Expr> GetStrides(Array<Expr> shape) {
-  CHECK_NE(shape.size(), 0U);
-  std::vector<Expr> vec{make_const(shape[0].type(), 1)};
-  for (size_t i = shape.size() - 1; i != 0; --i) {
-    vec.push_back(shape[i - 1] * vec.back());
-  }
-  return Array<Expr>(vec.rbegin(), vec.rend());
-}
-
 Array<Expr> SimplifyArray(Array<Expr> array) {
   for (size_t i = 0; i < array.size(); ++i) {
     array.Set(i, ir::Simplify(array[i]));
@@ -235,10 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
   Expr base = n->elem_offset;
   if (n->strides.size() == 0) {
     CHECK_EQ(n->shape.size(), index.size());
-    if (is_zero(base)) {
-      base = index[0];
-    } else {
-      base = base + index[0];
+    if (n->shape.size() != 0) {
+      if (is_zero(base)) {
+        base = index[0];
+      } else {
+        base = base + index[0];
+      }
     }
     base = MergeMulMod(base);
     for (size_t i = 1; i < index.size(); ++i) {
@@ -294,9 +287,10 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
 
 Buffer Buffer::MakeStrideView() const {
   if ((*this)->strides.size() != 0) return *this;
+  if ((*this)->shape.size() == 0) return *this;
   std::vector<Expr> temp;
   auto n = std::make_shared<BufferNode>(*operator->());
-  Expr acc = make_const(n->shape[0].type(), 1);
+  Expr acc = make_const(n->DefaultIndexType(), 1);
   for (size_t i = n->shape.size(); i != 0 ; --i) {
     temp.push_back(acc);
     acc = acc * n->shape[i - 1];
@@ -344,9 +338,16 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
 Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const {
   const BufferNode* self = operator->();
   Expr e_dtype;
-  Expr extent = (self->strides.size() == self->shape.size() ?
-      arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]):
-      arith::ComputeReduce<ir::Mul>(self->shape));
+  Expr extent;
+  if (self->shape.size() == 0) {
+    extent = make_const(self->DefaultIndexType(), 1);
+  } else if (self->strides.size() == self->shape.size()) {
+    int highest_dim = 0;
+    extent = arith::ComputeExpr<ir::Mul>(
+        self->strides[highest_dim], self->shape[highest_dim]);
+  } else {
+    extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
+  }
   Expr elem_offset = self->elem_offset;
   if (content_lanes > 1) {
     e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
@@ -383,7 +384,7 @@ Buffer BufferNode::make(Var data,
   }
   n->scope = std::move(scope);
   if (!elem_offset.defined()) {
-    elem_offset = make_const(n->shape[0].type(), 0);
+    elem_offset = make_const(n->DefaultIndexType(), 0);
   }
   if (data_alignment <= 0) {
     data_alignment = runtime::kAllocAlignment;
diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc
index 20c8593a1..cdd344670 100644
--- a/src/pass/arg_binder.cc
+++ b/src/pass/arg_binder.cc
@@ -196,7 +196,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
                         nop));
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
-    Type stype = buffer->shape[0].type();
+    Type stype = buffer->DefaultIndexType();
     Expr expect_stride = make_const(stype, 1);
     Array<Expr> conds;
     for (size_t i = buffer->shape.size(); i != 0; --i) {
@@ -211,14 +211,16 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     std::ostringstream stride_err_msg;
     stride_err_msg << arg_name << ".strides:"
                    << " expected to be compact array";
-    Stmt check =
-        AssertStmt::make(arith::ComputeReduce<ir::And>(conds),
-                         stride_err_msg.str(), Evaluate::make(0));
-    Expr is_null = Call::make(
-        Bool(1), intrinsic::tvm_handle_is_null,
-        {v_strides}, Call::PureIntrinsic);
-    check = IfThenElse::make(Not::make(is_null), check, Stmt());
-    init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
+    if (conds.size() != 0) {
+      Stmt check =
+          AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
+                           stride_err_msg.str(), Evaluate::make(0));
+      Expr is_null = Call::make(
+          Bool(1), intrinsic::tvm_handle_is_null,
+          {v_strides}, Call::PureIntrinsic);
+      check = IfThenElse::make(Not::make(is_null), check, Stmt());
+      init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
+    }
   } else {
     for (size_t k = 0; k < buffer->strides.size(); ++k) {
       std::ostringstream field_name;
diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc
index e9bd8594a..03ffdb01e 100644
--- a/src/pass/inject_double_buffer.cc
+++ b/src/pass/inject_double_buffer.cc
@@ -81,7 +81,8 @@ class DoubleBufferInjector : public IRMutator {
   Stmt Mutate_(const Allocate* op, const Stmt& s) final {
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
-      it->second.stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes();
+      it->second.stride = arith::ComputeReduce<Mul>
+          (op->extents, Expr()) * op->type.lanes();
       Stmt stmt = IRMutator::Mutate_(op, s);
       op = stmt.as<Allocate>();
       Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc
index 28e90ec48..bcf0e3d9f 100644
--- a/src/pass/inject_virtual_thread.cc
+++ b/src/pass/inject_virtual_thread.cc
@@ -376,7 +376,8 @@ class VTInjector : public IRMutator {
     // always rewrite if not allow sharing.
     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
       // place v on highest dimension.
-      Expr stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes();
+      Expr stride = arith::ComputeReduce<Mul>(
+          op->extents, Expr()) * op->type.lanes();
       Array<Expr> other;
       other.push_back(make_const(op->extents[0].type(), num_threads_));
       for (Expr e : extents) {
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index f1aee504f..46bed1fc9 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -147,10 +147,11 @@ class StorageFlattener : public IRMutator {
       }
     }
     Array<Expr> strides;
-    if (dim_align_.count(key) != 0) {
+    if (dim_align_.count(key) != 0 && shape.size() != 0) {
      std::vector<Expr> rstrides;
       const std::vector<DimAlignInfo>& avec = dim_align_[key];
-      Expr stride = make_const(shape[0].type(), 1);
+      int first_dim = 0;
+      Expr stride = make_const(shape[first_dim].type(), 1);
       for (size_t i = shape.size(); i != 0; --i) {
         size_t dim = i - 1;
         if (dim < avec.size() && avec[dim].align_factor != 0) {
@@ -164,6 +165,7 @@ class StorageFlattener : public IRMutator {
       }
       strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
     }
+
     e.buffer = BufferNode::make(
         Var(key.GetName(), Handle()),
         op->type, shape, strides, Expr(),
@@ -176,13 +178,18 @@ class StorageFlattener : public IRMutator {
 
       Stmt ret;
       if (strides.size() != 0) {
+        int first_dim = 0;
         ret = Allocate::make(
             e.buffer->data, e.buffer->dtype,
-            {arith::ComputeExpr<Mul>(e.buffer->strides[0], e.buffer->shape[0])},
+            {arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
             make_const(Bool(e.buffer->dtype.lanes()), true), body);
       } else {
+        shape = e.buffer->shape;
+        if (shape.size() == 0) {
+          shape.push_back(make_const(Int(32), 1));
+        }
         ret = Allocate::make(
-            e.buffer->data, e.buffer->dtype, e.buffer->shape,
+            e.buffer->data, e.buffer->dtype, shape,
             make_const(Bool(e.buffer->dtype.lanes()), true), body);
       }
       ret = AttrStmt::make(
diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 2f3616017..9d47a64f8 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -405,7 +405,7 @@ class StoragePlanRewriter : public IRMutator {
       // Build a merged allocation
       Expr combo_size;
       for (const Allocate* op : e->allocs) {
-        Expr sz = arith::ComputeReduce<Mul>(op->extents);
+        Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
         if (alloc_type.lanes() != op->type.lanes()) {
           sz = (sz * make_const(sz.type(), op->type.lanes()) +
                 make_const(sz.type(), alloc_type.lanes() - 1)) /
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index f036dccc3..dd8f80bcd 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -352,9 +352,13 @@ int TVMArrayAlloc(const tvm_index_t* shape,
   arr->dtype.code = static_cast<uint8_t>(dtype_code);
   arr->dtype.bits = static_cast<uint8_t>(dtype_bits);
   arr->dtype.lanes = static_cast<uint16_t>(dtype_lanes);
-  tvm_index_t* shape_copy = new tvm_index_t[ndim];
-  std::copy(shape, shape + ndim, shape_copy);
-  arr->shape = shape_copy;
+  if (ndim != 0) {
+    tvm_index_t* shape_copy = new tvm_index_t[ndim];
+    std::copy(shape, shape + ndim, shape_copy);
+    arr->shape = shape_copy;
+  } else {
+    arr->shape = nullptr;
+  }
   // ctx
   arr->ctx.device_type = static_cast<DLDeviceType>(device_type);
   arr->ctx.device_id = device_id;
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index d3f849d74..ed833d408 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -370,8 +370,10 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
   CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
       << "Invalid DLTensor file format";
   std::vector<int64_t> shape(tensor.ndim);
-  CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
-      << "Invalid DLTensor file format";
+  if (tensor.ndim != 0) {
+    CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
+        << "Invalid DLTensor file format";
+  }
   CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
   CHECK(tensor.dtype.bits == dst->dtype.bits &&
         tensor.dtype.code == dst->dtype.code &&
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index a8dc4edf5..d1a69ecf0 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -47,10 +47,10 @@ Expr InjectPredicate(const Array<Expr>& predicates,
   const Reduce* reduce = body.as<Reduce>();
   if (reduce) {
     std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce);
-    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates);
+    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
     return Expr(n);
   }
-  return Select::make(arith::ComputeReduce<ir::And>(predicates),
+  return Select::make(arith::ComputeReduce<ir::And>(predicates, Expr()),
                       body, make_zero(body.type()));
 }
 
@@ -467,7 +467,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   const Reduce* reduce = compute_op->body[idx].as<Reduce>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   predicates.push_back(reduce->condition);
-  Expr predicate = arith::ComputeReduce<ir::And>(predicates);
+  Expr predicate = arith::ComputeReduce<ir::And>(predicates, Expr());
 
   std::unordered_map<const Variable*, Expr> vsub;
 
diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py
index bbdd65e4b..56e3fc819 100644
--- a/tests/python/unittest/test_codegen_device.py
+++ b/tests/python/unittest/test_codegen_device.py
@@ -5,8 +5,8 @@ import numpy as np
 def test_add_pipeline():
     n = tvm.var('n')
     A = tvm.placeholder((n,), name='A')
-    B = tvm.placeholder((n,), name='B')
-    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    B = tvm.placeholder((), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(), name='C')
     D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
     s = tvm.create_schedule(D.op)
 
@@ -48,7 +48,7 @@ def test_add_pipeline():
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
         d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
         f(a, b, d)
         np.testing.assert_allclose(
@@ -72,7 +72,7 @@ def test_add_pipeline():
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
         d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
         f(a, b, d)
         np.testing.assert_allclose(
@@ -84,5 +84,6 @@ def test_add_pipeline():
     check_target("nvptx", host="llvm")
     check_target("rocm", host="llvm")
 
+
 if __name__ == "__main__":
     test_add_pipeline()
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
index 0db06b934..24996c842 100644
--- a/tests/python/unittest/test_codegen_llvm.py
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -273,7 +273,32 @@ def test_llvm_bool():
     check_llvm(64)
 
 
+def test_rank_zero():
+    def check_llvm(n):
+        if not tvm.module.enabled("llvm"):
+            return
+        A = tvm.placeholder((n, ), name='A')
+        scale = tvm.placeholder((), name='scale')
+        k = tvm.reduce_axis((0, n), name="k")
+        C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C")
+        D = tvm.compute((), lambda : C + 1)
+        s = tvm.create_schedule(D.op)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, scale, D], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
+        sc = tvm.nd.array(
+            np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
+        d = tvm.nd.empty((), D.dtype, ctx)
+        f(a, sc, d)
+        d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
+        np.testing.assert_allclose(d.asnumpy(), d_np)
+    check_llvm(64)
+
+
 if __name__ == "__main__":
+    test_rank_zero()
     test_llvm_bool()
     test_llvm_persist_parallel()
     test_llvm_select()
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index 6f151749c..1d8603dfc 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -19,6 +19,17 @@ def test_tensor():
     assert(T[0][0][0].astype('float16').dtype == 'float16')
 
 
+def test_rank_zero():
+    m = tvm.var('m')
+    A = tvm.placeholder((m,), name='A')
+    scale = tvm.placeholder((), name='s')
+    k = tvm.reduce_axis((0, m), name="k")
+    T = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k))
+    print(T)
+    print(T.op.body)
+    assert(tuple(T.shape) == ())
+
+
 def test_conv1d():
     n = tvm.var('n')
     A = tvm.placeholder((n+2), name='A')
@@ -173,7 +184,9 @@ def test_tensor_inputs():
     y = tvm.compute(x.shape, lambda i: x[i] + x[i])
     assert tuple(y.op.input_tensors) == (x,)
 
+
 if __name__ == "__main__":
+    test_rank_zero()
     test_tensor_inputs()
     test_tensor_reduce_multi_axis()
     test_conv1d()
diff --git a/tests/python/unittest/test_runtime_packed_func.py b/tests/python/unittest/test_runtime_packed_func.py
index 44b450b23..279172555 100644
--- a/tests/python/unittest/test_runtime_packed_func.py
+++ b/tests/python/unittest/test_runtime_packed_func.py
@@ -63,7 +63,15 @@ def test_byte_array():
     f(a)
 
 
+def test_empty_array():
+    def myfunc(ss):
+        assert tuple(ss) == ()
+    x = tvm.convert(())
+    tvm.convert(myfunc)(x)
+
+
 if __name__ == "__main__":
+    test_empty_array()
     test_get_global()
     test_get_callback_with_node()
     test_convert()
diff --git a/topi/python/topi/nn/dense.py b/topi/python/topi/nn/dense.py
index caa736a41..333692614 100644
--- a/topi/python/topi/nn/dense.py
+++ b/topi/python/topi/nn/dense.py
@@ -25,7 +25,7 @@ def dense(data, weight, bias=None):
     """
     assert len(data.shape) == 2 and len(weight.shape) == 2, \
         "only support 2-dim dense"
-    if bias:
+    if bias is not None:
         assert len(bias.shape) == 1
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
@@ -33,7 +33,7 @@ def dense(data, weight, bias=None):
     matmul = tvm.compute((batch, out_dim), \
                          lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
                          tag='dense')
-    if bias:
+    if bias is not None:
         matmul = tvm.compute((batch, out_dim), \
                              lambda i, j: matmul[i, j] + bias[j], \
                              tag=tag.BROADCAST)
--
GitLab
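
Usage sketch (not part of the patch; placed after the signature so it is ignored by git am): a minimal example of what rank-0 tensor support enables, distilled from the test_rank_zero tests added above. The variable names and the constant 2.0 are illustrative only; all API calls appear in the patch's own tests.

    import numpy as np
    import tvm

    n = 64
    A = tvm.placeholder((n,), name='A')
    # A rank-0 (scalar) input is declared with an empty shape tuple.
    scale = tvm.placeholder((), name='scale')
    k = tvm.reduce_axis((0, n), name='k')
    # A rank-0 output: reduce the scaled vector down to a scalar.
    # Tensor now derives from ExprOp, so `scale` can appear directly
    # in the expression without an explicit `scale()` call.
    C = tvm.compute((), lambda: tvm.sum(A[k] * scale, axis=k), name='C')
    s = tvm.create_schedule(C.op)
    if tvm.module.enabled("llvm"):
        f = tvm.build(s, [A, scale, C], "llvm")
        ctx = tvm.cpu(0)
        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
        # Rank-0 NDArrays round-trip through copyfrom/asnumpy as well.
        sc = tvm.nd.array(np.array(2.0, dtype=scale.dtype), ctx)
        c = tvm.nd.empty((), C.dtype, ctx)
        f(a, sc, c)
        np.testing.assert_allclose(
            c.asnumpy(), np.sum(a.asnumpy()) * 2.0, rtol=1e-5)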