From b55361b4d5dc5b29a0bfd8d0cfe91de0aa1c789e Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Sat, 25 Nov 2017 10:00:14 -0800 Subject: [PATCH] [PASS] Allow compact checking when strides is available (#669) * [PASS] Allow compact checking when strides is available * remove assert compact --- src/codegen/stack_vm/codegen_stack_vm.cc | 1 + src/pass/arg_binder.cc | 30 +++++++++++++++++------- src/pass/ir_util.cc | 5 ++++ topi/python/topi/nn/conv2d.py | 3 ++- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/codegen/stack_vm/codegen_stack_vm.cc b/src/codegen/stack_vm/codegen_stack_vm.cc index 97a2388f1..5b01dae71 100644 --- a/src/codegen/stack_vm/codegen_stack_vm.cc +++ b/src/codegen/stack_vm/codegen_stack_vm.cc @@ -362,6 +362,7 @@ void CodeGenStackVM::VisitExpr_(const Or *op) { } void CodeGenStackVM::VisitExpr_(const Not* op) { + this->Push(op->a); this->PushOp(StackVM::NOT); } diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index f9969cc5d..20c8593a1 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -136,12 +136,6 @@ inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) { return TVMStructGet(t, arr, 0, kind); } -inline Stmt AssertNull(Var handle, std::string msg) { - return AssertStmt::make(Call::make( - Bool(1), intrinsic::tvm_handle_is_null, - {handle}, Call::PureIntrinsic), msg, Evaluate::make(0)); -} - void ArgBinder::BindDLTensor(const Buffer& buffer, const Expr& device_type, const Expr& device_id, @@ -201,10 +195,30 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides), nop)); if (buffer->strides.size() == 0) { + // Assert the buffer is compact + Type stype = buffer->shape[0].type(); + Expr expect_stride = make_const(stype, 1); + Array<Expr> conds; + for (size_t i = buffer->shape.size(); i != 0; --i) { + size_t k = i - 1; + Expr svalue = cast( + stype, + Load::make(tvm_shape_type, v_strides, + IntImm::make(Int(32), k), const_true(1))); + conds.push_back(expect_stride == svalue); + expect_stride = expect_stride * buffer->shape[k]; + } std::ostringstream stride_err_msg; stride_err_msg << arg_name << ".strides:" - << " expected to be nullptr for contiguous array"; - init_nest_.emplace_back(AssertNull(v_strides, stride_err_msg.str())); + << " expected to be compact array"; + Stmt check = + AssertStmt::make(arith::ComputeReduce<ir::And>(conds), + stride_err_msg.str(), Evaluate::make(0)); + Expr is_null = Call::make( + Bool(1), intrinsic::tvm_handle_is_null, + {v_strides}, Call::PureIntrinsic); + check = IfThenElse::make(Not::make(is_null), check, Stmt()); + init_nest_.emplace_back(Block::make(check, Evaluate::make(0))); } else { for (size_t k = 0; k < buffer->strides.size(); ++k) { std::ostringstream field_name; diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc index 12551947a..579706ca9 100644 --- a/src/pass/ir_util.cc +++ b/src/pass/ir_util.cc @@ -33,6 +33,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) { CHECK(!n->else_case.defined()); n->then_case = body; body = Stmt(n); + } else if (s.as<Block>()) { + auto n = std::make_shared<Block>(*s.as<Block>()); + CHECK(is_no_op(n->rest)); + n->rest = body; + body = Stmt(n); } else if (s.as<AssertStmt>()) { auto n = std::make_shared<AssertStmt>(*s.as<AssertStmt>()); CHECK(is_no_op(n->body)); diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index cc1ee0198..11866aedc 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -225,7 +225,8 @@ def _im2col_pack(data, kernel, stride, padding, out_dtype): wk = tvm.reduce_axis((0, KW), name='wk') conv = tvm.compute(ovshape, lambda n, co, im, vim, vco: \ - tvm.sum(data_vec[n][im][ci][hk][wk][vim] * kernel_vec[co][ci][hk][wk][vco], + tvm.sum(data_vec[n][im][ci][hk][wk][vim].astype(out_dtype) * + kernel_vec[co][ci][hk][wk][vco].astype(out_dtype), axis=[ci, hk, wk]), name='conv') output = tvm.compute(oshape, lambda n, co, h, w: \ -- GitLab