From 9f441d817c7e3c5eeeb0078c5c60701a6e7ba33b Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Fri, 9 Nov 2018 14:25:53 -0800 Subject: [PATCH] [RELAY] CompileEngine update, nn conv2d, fix dense, pool. (#2082) --- include/tvm/relay/op_attr_types.h | 3 +- python/tvm/relay/op/_tensor.py | 9 +- python/tvm/relay/op/nn/__init__.py | 1 + python/tvm/relay/op/nn/_nn.py | 176 +++++++++++++++- python/tvm/relay/op/op.py | 9 + src/relay/backend/compile_engine.cc | 7 +- src/relay/op/nn/nn.cc | 5 +- src/relay/op/nn/pooling.cc | 81 +++++++- .../python/relay/test_backend_interpreter.py | 22 -- tests/python/relay/test_op_level1.py | 62 +++++- tests/python/relay/test_op_level2.py | 194 +++++++++++++++--- topi/python/topi/util.py | 9 +- 12 files changed, 489 insertions(+), 89 deletions(-) diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 941b32e9d..2c9fa2808 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -72,7 +72,8 @@ using FTVMCompute = runtime::TypedPackedFunc< * \return schedule The computation schedule. */ using FTVMSchedule = runtime::TypedPackedFunc< - Schedule(const Array<Tensor>& outs, + Schedule(const Attrs& attrs, + const Array<Tensor>& outs, const Target& target)>; } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 28d53ec86..7aef4d437 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -2,13 +2,8 @@ """Backend compiler related feature registration""" from __future__ import absolute_import import topi -import topi.cuda -from .op import register_compute, register_schedule, register_pattern, OpPattern - -def schedule_injective(outputs, target): - """Generic schedule for binary broadcast.""" - with target: - return topi.generic.schedule_injective(outputs) +from .op import register_compute, register_schedule, register_pattern +from .op import schedule_injective, OpPattern schedule_broadcast = schedule_injective schedule_elemwise = schedule_injective diff --git a/python/tvm/relay/op/nn/__init__.py b/python/tvm/relay/op/nn/__init__.py index d1818e718..0c2a0a435 100644 --- a/python/tvm/relay/op/nn/__init__.py +++ b/python/tvm/relay/op/nn/__init__.py @@ -2,3 +2,4 @@ """Neural network related operators.""" from __future__ import absolute_import as _abs from .nn import * +from . import _nn diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 4f5dcd4dd..7bc26cdec 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1,16 +1,174 @@ #pylint: disable=invalid-name, unused-argument """Backend compiler related feature registration""" -import tvm import topi -from .. import register +from topi.util import get_const_int, get_const_tuple +from .. import op as reg +from ..op import OpPattern, schedule_injective -def dense_compiler(attrs, inputs, output_type): - assert len(inputs) == 2 +# dense +@reg.register_compute("nn.dense") +def compute_dense(attrs, inputs, out_type, target): + """Compute definition of dense""" return [topi.nn.dense(inputs[0], inputs[1])] -def dense_schedule(outputs, target): - assert len(outputs) == 1 - return tvm.create_schedule(outputs[0].op) +@reg.register_schedule("nn.dense") +def schedule_dense(attrs, outputs, target): + """Schedule definition of dense""" + with target: + return topi.generic.schedule_dense(outputs) -register("nn.dense", "FTVMCompute", dense_compiler) -register("nn.dense", "FTVMSchedule", dense_schedule) +reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) + + +# conv2d +@reg.register_compute("nn.conv2d") +def compute_conv2d(attrs, inputs, out_type, target): + """Compute definition of conv2d""" + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + layout = attrs.data_layout + weight_layout = attrs.weight_layout + out_dtype = attrs.out_dtype + out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "") + else out_dtype) + + assert layout in ["NCHW", "NHWC", "NCHW4c"] + (dilation_h, dilation_w) = dilation + if dilation_h < 1 or dilation_w < 1: + raise ValueError("dilation should be positive value") + + if groups == 1: + out = topi.nn.conv2d( + inputs[0], inputs[1], strides, padding, + dilation, layout, out_dtype=out_dtype) + elif layout == "NCHW" and \ + weight_layout == "OIHW" and \ + get_const_int(inputs[1].shape[0]) == groups and \ + get_const_int(inputs[1].shape[1]) == 1: + out = topi.nn.depthwise_conv2d_nchw( + inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) + elif layout == "NHWC" and \ + kernel_layout == "HWOI" and\ + get_const_int(inputs[1].shape[2]) == groups and \ + get_const_int(inputs[1].shape[3]) == 1: + out = topi.nn.depthwise_conv2d_nhwc( + inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) + else: + raise ValueError("not support arbitrary group number for now") + return [out] + + +@reg.register_schedule("nn.conv2d") +def schedule_conv2d(attrs, outs, target): + """Schedule definition of conv2d""" + groups = attrs.groups + layout = attrs.data_layout + kernel_layout = attrs.weight_layout + with target: + if groups == 1 and layout == "NCHW": + return topi.generic.schedule_conv2d_nchw(outs) + elif groups == 1 and layout == "NCHW4c": + return topi.generic.schedule_conv2d_nchw(outs) + elif groups == 1 and layout == "NHWC": + return topi.generic.schedule_conv2d_nhwc(outs) + elif groups != 1: + if layout == "NCHW": + # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. + return topi.generic.schedule_depthwise_conv2d_nchw(outs) + elif layout == "NHWC" and kernel_layout == "HWOI": + return topi.generic.schedule_depthwise_conv2d_nhwc(outs) + raise ValueError("No compatible schedule") + +reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# conv2d_transpose +@reg.register_compute("nn.conv2d_transpose") +def compute_conv2d_transpose(attrs, inputs, out_dtype, target): + """Compute definition of conv2d_transpose""" + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + layout = attrs.data_layout + out_dtype = attrs.out_dtype + out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "") + else out_dtype) + assert layout == "NCHW", "only support nchw for now" + assert dilation == (1, 1), "not support dilate now" + assert groups == 1, "only support groups == 1 for now" + out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides, padding, out_dtype) + output_padding = get_const_tuple(attrs.output_padding) + out = topi.nn.pad(out, + [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]]) + return [out] + +@reg.register_schedule("nn.conv2d_transpose") +def schedule_conv2d_transpose(attrs, outs, target): + """Schedule definition of conv2d_transpose""" + with target: + return topi.generic.schedule_conv2d_transpose_nchw(outs) + +reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) + +# bias_add +@reg.register_compute("nn.bias_add") +def compute_bias_add(attrs, inputs, out_dtype, target): + """Compute definition of conv2d_transpose""" + axis = attrs.axis + bias = inputs[1] + data_ndim = len(inputs[0].shape) + if axis < 0: + axis = axis + data_ndim + num_newaxis = data_ndim - axis - 1 + + if num_newaxis: + bias = topi.expand_dims(bias, axis=1, num_newaxis=num_newaxis) + return [topi.add(inputs[0], bias)] + +reg.register_schedule("nn.bias_add", schedule_injective) +reg.register_pattern("nn.bias_add", OpPattern.BROADCAST) + + +# max_pool2d +@reg.register_schedule("nn.max_pool2d") +def schedule_max_pool2d(attrs, outs, target): + """Schedule definition of max_pool2d""" + layout = attrs.layout + with target: + return topi.generic.schedule_pool(outs, layout) + +reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# avg_pool2d +@reg.register_schedule("nn.avg_pool2d") +def schedule_avg_pool2d(attrs, outs, target): + """Schedule definition of avg_pool2d""" + layout = attrs.layout + with target: + return topi.generic.schedule_pool(outs, layout) + +reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# global_max_pool2d +@reg.register_schedule("nn.global_max_pool2d") +def schedule_global_max_pool2d(_, outs, target): + """Schedule definition of global_max_pool2d""" + with target: + return topi.generic.schedule_global_pool(outs) + +reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# global_avg_pool2d +@reg.register_schedule("nn.global_avg_pool2d") +def schedule_global_avg_pool2d(_, outs, target): + """Schedule definition of global_avg_pool2d""" + with target: + return topi.generic.schedule_global_pool(outs) + +reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 3bdb5989c..c777a8246 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -1,4 +1,7 @@ +#pylint: disable=unused-argument """The base node types for the Relay language.""" +import topi + from ..._ffi.function import _init_api from ..base import register_relay_node @@ -156,3 +159,9 @@ def _lower(name, schedule, inputs, outputs): @register_func("relay.op.compiler._build") def _build(lowered_funcs): return build(lowered_funcs, target="llvm") + + +def schedule_injective(attrs, outputs, target): + """Generic schedule for binary broadcast.""" + with target: + return topi.generic.schedule_injective(outputs) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index d9385977d..38e3f6c2a 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -89,7 +89,7 @@ class ScheduleGetter : CachedFunc cfunc(cache_node); CHECK(master_op_.defined()); Schedule schedule = fschedule[master_op_]( - cache_node->outputs, target_); + master_attrs_, cache_node->outputs, target_); return std::make_pair(schedule, cfunc); } @@ -145,6 +145,7 @@ class ScheduleGetter : } if (op_pattern >= master_op_patetrn_) { master_op_ = op; + master_attrs_ = call_node->attrs; master_op_patetrn_ = op_pattern; } if (outputs.size() != 1) { @@ -193,6 +194,7 @@ class ScheduleGetter : private: tvm::Target target_; Op master_op_; + Attrs master_attrs_; int master_op_patetrn_{0}; std::ostringstream readable_name_stream_; std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_; @@ -285,6 +287,9 @@ class CompileEngineImpl : public CompileEngineNode { * \return Updated name which is unique. */ std::string GetUniqeName(std::string name) { + for (size_t i = 0; i < name.length(); ++i) { + if (name[i] == '.') name[i] = '_'; + } while (true) { auto it = name_map_.find(name); if (it == name_map_.end()) { diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d141eec3b..fb4c7304a 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -91,16 +91,15 @@ bool DenseRel(const Array<Type>& types, Array<tvm::Expr> oshape = data->shape; if (param->units.defined()) { Array<tvm::Expr> dshape = data->shape; - // validate the weight shape is proper if defined // Assign weight type - Array<IndexExpr> wshape({dshape[dshape.size() - 1], param->units}); + Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]}); reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); oshape.Set((oshape.size() - 1), param->units); } else { if (weight == nullptr) return false; Array<tvm::Expr> wshape = weight->shape; - oshape.Set((oshape.size() - 1), wshape[wshape.size() - 1]); + oshape.Set((oshape.size() - 1), wshape[0]); } // assign output type diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 8c989ac91..0e54564e0 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -4,7 +4,9 @@ * \brief Pooling operators */ #include <tvm/relay/op.h> +#include <tvm/relay/op_attr_types.h> #include <tvm/relay/attrs/nn.h> +#include <topi/nn/pooling.h> #include <vector> #include "layout.h" @@ -14,7 +16,7 @@ namespace relay { TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); -template <typename AttrTtype> +template <typename AttrType> bool Pool2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, @@ -27,7 +29,7 @@ bool Pool2DRel(const Array<Type>& types, CHECK_NE(dshape.size(), 0); CHECK_GE(dshape.size(), 2U) << "Pool2D only support input >= 2-D: input must have height and width"; - const auto param = attrs.as<AttrTtype>(); + const auto param = attrs.as<AttrType>(); CHECK(param != nullptr); Layout layout(param->layout); @@ -88,6 +90,46 @@ Expr MakeMaxPool2D(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } +template<typename AttrType, topi::nn::PoolType mode> +Array<Tensor> Pool2DCompute(const Attrs& attrs, + const Array<Tensor>& inputs, + const Type& out_type, + const Target& target) { + const auto* param = attrs.as<AttrType>(); + CHECK(param != nullptr); + auto pool_size = param->pool_size; + auto strides = param->strides; + auto padding = param->padding; + auto ceil_mode = param->ceil_mode; + Layout layout(param->layout); + CHECK(layout.convertible(Layout("NCHW"))) + << "max_pool2d currently only supports layouts that are convertible from NCHW"; + CHECK_EQ(layout.indexof('h'), -1) << "max_pool2d does not support input split on height"; + CHECK_EQ(layout.indexof('w'), -1) << "max_pool2d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + + if (param->padding.size() == 1) { + padding.push_back(padding[0]); + padding.push_back(padding[0]); + padding.push_back(padding[0]); + } else if (param->padding.size() == 2) { + padding.push_back(padding[0]); + padding.push_back(padding[1]); + } + if (mode == topi::nn::kAvgPool) { + bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad; + return Array<Tensor>{ + topi::nn::pool(inputs[0], pool_size, strides, padding, + mode, ceil_mode, layout.name(), count_include_pad)}; + } else { + return Array<Tensor>{ + topi::nn::pool(inputs[0], pool_size, strides, padding, + mode, ceil_mode, layout.name())}; + } +} TVM_REGISTER_API("relay.op.nn._make.max_pool2d") .set_body([](const TVMArgs& args, TVMRetValue* rv) { @@ -120,7 +162,8 @@ RELAY_REGISTER_OP("nn.max_pool2d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) -.add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>); +.add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>) +.set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>); // AvgPool2D @@ -175,7 +218,8 @@ Average pooling operation for one dimensional data. .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) -.add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>); +.add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>) +.set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<AvgPool2DAttrs, topi::nn::kAvgPool>); // Global Pool TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs); @@ -211,6 +255,29 @@ bool GlobalPool2DRel(const Array<Type>& types, return true; } + +template<topi::nn::PoolType mode> +Array<Tensor> GlobalPool2DCompute(const Attrs& attrs, + const Array<Tensor>& inputs, + const Type& out_type, + const Target& target) { + const auto* param = attrs.as<GlobalPool2DAttrs>(); + CHECK(param != nullptr); + Layout layout(param->layout); + CHECK(layout.convertible(Layout("NCHW"))) + << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; + CHECK_EQ(layout.indexof('h'), -1) + << "global_avg_pool2d does not support input split on height"; + CHECK_EQ(layout.indexof('w'), -1) + << "global_avg_pool2d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + return Array<Tensor>{ + topi::nn::global_pool(inputs[0], mode, layout.name()) }; +} + Expr MakeGlobalAvgPool2D(Expr data, std::string layout) { auto attrs = make_node<GlobalPool2DAttrs>(); @@ -239,7 +306,8 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) -.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel); +.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) +.set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kAvgPool>); // GlobalMaxPool Expr MakeGlobalMaxPool2D(Expr data, @@ -269,7 +337,8 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) -.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel); +.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) +.set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kMaxPool>); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index c9f689f7b..f53f27192 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -55,28 +55,6 @@ def test_mul_param(): check_eval(func, [x_data, y_data], x_data * y_data) -# failing due to numeric issues - -# def test_dense(): -# x = relay.var('x', shape=(10, 10)) -# w = relay.var('w', shape=(10, 10)) -# y = relay.nn.dense(x, w) -# func = relay.Function([x, w], y) -# x_data = np.random.rand(10, 10).astype('float32') -# w_data = np.random.rand(10, 10).astype('float32') -# check_eval(func, [x_data, w_data], x_data @ w_data, rtol=0.1) - -# def test_linear(): -# x = relay.var('x', shape=(10, 10)) -# w = relay.var('w', shape=(10, 10)) -# b = relay.var('b', shape=(10,)) -# y = relay.add(relay.nn.dense(x, w), b) -# func = relay.Function([x, w, b], y) -# x_data = np.random.rand(10, 10).astype('float32') -# w_data = np.random.rand(10, 10).astype('float32') -# b_data = np.random.rand(10).astype('float32') -# check_eval(func, [x_data, w_data, b_data], x_data @ w_data + b_data) - def test_equal(): i = relay.var('i', shape=[], dtype='int32') j = relay.var('i', shape=[], dtype='int32') diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 477207dce..88a7aba59 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -74,6 +74,7 @@ def test_binary_op(): y_data = np.random.rand(5, 10, 5).astype(t2.dtype) ref_res = ref(x_data, y_data) func = relay.Function([x, y], z) + for target, ctx in ctx_list(): # use graph by execuor default for testing, as we need # create function explicitly to avoid constant-folding. @@ -89,12 +90,24 @@ def test_binary_op(): def test_bias_add(): - x = relay.var("x", shape=(10, 2, 3, 4)) + xshape=(10, 2, 3, 4) + bshape=(2,) + dtype="float32" + x = relay.var("x", shape=xshape) bias = relay.var("bias") z = relay.nn.bias_add(x, bias) zz = relay.ir_pass.infer_type(z) assert "axis=" not in zz.astext() - assert zz.args[1].checked_type == relay.TensorType((2,)) + assert zz.args[1].checked_type == relay.TensorType(bshape) + + func = relay.Function([x, bias], z) + x_data = np.random.uniform(size=xshape).astype(dtype) + y_data = np.random.uniform(size=bshape).astype(dtype) + ref_res = x_data + y_data.reshape((2, 1, 1)) + for target, ctx in ctx_list(): + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data, y_data) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) def test_expand_dims_infer_type(): @@ -217,6 +230,50 @@ def test_batch_norm(): ])) +def test_dense(): + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + w = relay.var("w", relay.TensorType((2, w), "float32")) + y = relay.nn.dense(x, w, units=2) + "units=2" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") + + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + wh, ww = tvm.var("wh"), tvm.var("ww") + w = relay.var("w", relay.TensorType((ww, wh), "float32")) + y = relay.nn.dense(x, w) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32") + + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + w = relay.var("w", relay.IncompleteType()) + y = relay.nn.dense(x, w, units=2) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") + + x = relay.var("x", shape=(10, 5)) + w = relay.var("w", shape=(2, 5)) + z = relay.nn.dense(x, w) + + # Check result. + func = relay.Function([x, w], z) + x_data = np.random.rand(10, 5).astype('float32') + w_data = np.random.rand(2, 5).astype('float32') + ref_res = np.dot(x_data, w_data.T) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(x_data, w_data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) + op_res2 = intrp2.evaluate(func)(x_data, w_data) + tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) + + + if __name__ == "__main__": test_bias_add() test_unary_op() @@ -227,3 +284,4 @@ if __name__ == "__main__": test_log_softmax() test_dropout() test_batch_norm() + test_dense() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 9dd249128..7b3a6d3fe 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2,7 +2,9 @@ """ import tvm from tvm import relay - +from tvm.relay.testing import ctx_list +import numpy as np +import topi.testing def test_conv2d_infer_type(): # symbolic in batch dimension @@ -62,6 +64,62 @@ def test_conv2d_infer_type(): (n, h, w, 16), "int32") +def test_conv2d_run(): + def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, + padding=(1, 1), + fref=None, + groups=1, + dilation=(1, 1), + **attrs): + x = relay.var("x", shape=dshape) + w = relay.var("w") + y = relay.nn.conv2d(x, w, + padding=padding, + dilation=dilation, + groups=groups, + **attrs) + func = relay.Function([x, w], y) + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation) + if fref is None: + ref_res = topi.testing.conv2d_nchw_python( + data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding) + else: + ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype)) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data, kernel) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + # depthwise conv2d + dshape = (1, 32, 18, 18) + kshape = (32, 1, 3, 3) + run_test_conv2d("float32", "float32", 1, dshape, kshape, + padding=(1, 1), channels=32, groups=32, kernel_size=(3 ,3), + fref=lambda x, w: topi.testing.depthwise_conv2d_python_nchw( + x, w, (1, 1), "SAME")) + + # normal conv2d + dshape = (1, 3, 224, 224) + kshape = (10, 3, 3, 3) + run_test_conv2d("float32", "float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(3 ,3)) + # mixed precision + run_test_conv2d("int8", "int32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(3 ,3)) + kshape = (10, 3, 1, 3) + # mixed precision. + run_test_conv2d("int8", "int32", 1, dshape, kshape, + padding=(0, 1), channels=10, kernel_size=(1 ,3)) + # dilated conv2d + dshape = (1, 3, 18, 18) + kshape = (10, 3, 3, 3) + run_test_conv2d("float32", "float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(3 ,3), dilation=(3, 3)) + + def test_conv2d_transpose_infer_type(): # symbolic in batch dimension n, c, h, w = tvm.var("n"), 10, 10, 12 @@ -90,6 +148,33 @@ def test_conv2d_transpose_infer_type(): assert yy.checked_type == relay.TensorType( (n, 15, 15, 11), "float32") + +def test_conv2d_transpose_run(): + dshape = (1, 3, 18, 18) + kshape = (3, 10, 3, 3) + oshape = (1, 10, 37, 37) + x = relay.var("x", shape=dshape) + w = relay.var("w") + y = relay.nn.conv2d_transpose(x, w, + channels=10, kernel_size=(3,3), strides=(2,2), + padding=(1,1), output_padding=(2, 2)) + func = relay.Function([x, w], y) + dtype = "float32" + data = np.random.uniform(size=dshape).astype(dtype) + kernel = np.random.uniform(size=kshape).astype(dtype) + c_np = topi.testing.conv2d_transpose_nchw_python( + data, kernel, 2, 1) + d_np = np.zeros(shape=oshape) + d_np[:,:,0:c_np.shape[2],0:c_np.shape[3]] = c_np + ref_res = d_np + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data, kernel) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + + def test_upsampling_infer_type(): n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) @@ -103,15 +188,29 @@ def test_upsampling_infer_type(): yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") -def _test_pool2d_infer_type(opfunc): + +def _test_pool2d(opfunc, reffunc): n, c, h, w = tvm.var("n"), 10, 224, 224 x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) y = opfunc(x, pool_size=(1, 1)) assert "pool_size=" in y.astext() yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.TensorType((n, 10, 224, 224), "float32") + # test execution + dtype = "float32" + dshape = (1, 3, 28, 28) + x = relay.var("x", shape=dshape) + y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = reffunc(data.reshape(1,3,14,2,14,2), axis=(3,5)) + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) -def _test_global_pool2d_infer_type(opfunc): + +def _test_global_pool2d(opfunc, reffunc): n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224 x = relay.var("x", relay.TensorType((n, h, w, c), "float32")) y = opfunc(x, layout="NHWC") @@ -123,12 +222,61 @@ def _test_global_pool2d_infer_type(opfunc): y = opfunc(x) yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 1, 1), "float32") + # test execution + dtype = "float32" + dshape = (1, 1024, 7, 7) + x = relay.var("x", shape=dshape) + y = opfunc(x) + func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = reffunc(data, axis=(2,3), keepdims=True) + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + +def test_pool2d(): + _test_pool2d(relay.nn.max_pool2d, np.max) + _test_pool2d(relay.nn.avg_pool2d, np.mean) + _test_global_pool2d(relay.nn.global_max_pool2d, np.max) + _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean) + + +def test_avg_pool2d_no_count_pad(): + kh, kw = (4, 4) + sh, sw = (2, 2) + ph, pw = (2, 2) + n = 1 + (ic, ih, iw) = (3, 28, 28) + (oc, oh, ow) = (3, 15, 15) + dshape = (n, ic, ih, iw) + x = relay.var("x", shape=dshape) + y = relay.nn.avg_pool2d(x, + pool_size=(kh, kw), + strides=(sw, sw), + padding=(ph, pw), + count_include_pad=False) + func = relay.Function([x], y) + dtype = "float32" + a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype) + pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) + no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) + pad_np[np.ix_(*no_zero)] = a_np + b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) + for i in range(oh): + for j in range(ow): + pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3)) + b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], + axis=(2,3)) / np.maximum(pad_count, 1) + ref_res = np.maximum(b_np, 0.0) + data = a_np + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) -def test_pool2d_infer_type(): - _test_pool2d_infer_type(relay.nn.max_pool2d) - _test_pool2d_infer_type(relay.nn.avg_pool2d) - _test_global_pool2d_infer_type(relay.nn.global_avg_pool2d) - _test_global_pool2d_infer_type(relay.nn.global_avg_pool2d) def test_flatten_infer_type(): d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") @@ -163,30 +311,6 @@ def test_pad_infer_type(): yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32") -def test_dense_infer_type(): - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - w = relay.var("w", relay.TensorType((w, 2), "float32")) - y = relay.nn.dense(x, w, units=2) - "units=2" in y.astext() - yy = relay.ir_pass.infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") - - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - wh, ww = tvm.var("wh"), tvm.var("ww") - w = relay.var("w", relay.TensorType((wh, ww), "float32")) - y = relay.nn.dense(x, w) - yy = relay.ir_pass.infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32") - - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - w = relay.var("w", relay.IncompleteType()) - y = relay.nn.dense(x, w, units=2) - yy = relay.ir_pass.infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") - def test_lrn(): n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") @@ -206,12 +330,14 @@ def test_l2_normalize(): if __name__ == "__main__": + test_pool2d() + test_avg_pool2d_no_count_pad() test_lrn() test_l2_normalize() test_conv2d_infer_type() - test_pool2d_infer_type() test_upsampling_infer_type() test_flatten_infer_type() test_pad_infer_type() test_conv2d_transpose_infer_type() - test_dense_infer_type() + test_conv2d_transpose_run() + test_conv2d_run() diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 71e123e83..de9ff90ae 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -1,8 +1,9 @@ # pylint: disable=invalid-name """Common topi utilities""" from __future__ import absolute_import as _abs -import tvm +from numbers import Integral +import tvm from . import tag def traverse_inline(s, final_op, callback): @@ -68,13 +69,13 @@ def get_const_int(expr): out_value : int The output. """ - if isinstance(expr, int): + if isinstance(expr, Integral): return expr if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): expr = tvm.ir_pass.Simplify(expr) if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): raise ValueError("Expect value to be constant int") - return expr.value + return int(expr.value) def equal_const_int(expr, value): @@ -90,7 +91,7 @@ def equal_const_int(expr, value): equal : bool Whether they equals. """ - if isinstance(expr, int): + if isinstance(expr, Integral): return expr == value if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): expr = tvm.ir_pass.Simplify(expr) -- GitLab