diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index f4a65023ee53aba0873f60aca6d89297633ce16d..5c3ab8b1ffdabdfa4a4a250bdf6e3c4bd4ef847d 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -124,6 +124,7 @@ This level enables additional math and transform operators.
    tvm.relay.mean
    tvm.relay.prod
    tvm.relay.strided_slice
+   tvm.relay.broadcast_to
 
 
 **Level 5: Vision/Image Operators**
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index c1e71e9133eaf1636777ec49adefb18b419fb2a7..085a8ceed5d1d26e77e25537fa65e939137d7edf 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -11,6 +11,7 @@ schedule_broadcast = _reg.schedule_injective
 
 _reg.register_schedule("collapse_sum_like", _schedule_reduce)
+_reg.register_schedule("broadcast_to", schedule_broadcast)
 _reg.register_schedule("broadcast_to_like", schedule_broadcast)
 _reg.register_schedule("expand_dims", schedule_broadcast)
 _reg.register_schedule("squeeze", schedule_injective)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index f536e75fd9b401d901ef3084013ecf9d8d88a3ff..2791eaf7d9dbe9aea4b283d004102024deafcc5e 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -267,6 +267,24 @@ def where(condition, x, y):
     """
     return _make.where(condition, x, y)
 
+def broadcast_to(data, shape):
+    """Return a tensor with the same dtype as the input,
+    broadcast to the provided shape.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input tensor.
+
+    shape : shape
+        The shape to broadcast to.
+
+    Returns
+    -------
+    result : relay.Expr
+        The resulting tensor.
+    """
+    return _make.broadcast_to(data, shape)
 
 def broadcast_to_like(data, broadcast_type):
     """Return a scalar value array with the same shape and type
     as the input array.
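
Note on the Python surface added above: the wrapper just forwards to the C++ `_make.broadcast_to` packed function registered below. A minimal usage sketch (mirroring the level-10 test at the end of this patch; `relay.ir_pass.infer_type` is the type-inference entry point that test already uses):

    import tvm.relay as relay

    # Declare a (4, 1, 6) float32 input and broadcast it to (3, 4, 5, 6).
    x = relay.Var("x", relay.ty.TensorType((4, 1, 6), "float32"))
    z = relay.broadcast_to(x, shape=(3, 4, 5, 6))

    # Type inference derives the output type from the shape attribute.
    zz = relay.ir_pass.infer_type(z)
    assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "float32")
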
diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 6233e6d51776bd8bf06fd4075fc8ac5647415e8e..6cf37668cab59ed2f3d3f7f399e2e069b1e29e62 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -258,8 +258,7 @@ bool GlobalPool2DRel(const Array<Type>& types,
                      const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
-
-  CHECK(data != nullptr);
+  if (data == nullptr) { return false; }
   const auto dshape = data->shape;
   CHECK_NE(dshape.size(), 0);
   CHECK_GE(dshape.size(), 2U)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index fcf7f6fe329911a40c5694f9b7f11b30fec084e8..eb8b4f13fb3fd9f601d2ecd41694d97504688374 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1084,6 +1084,52 @@ RELAY_REGISTER_OP("collapse_sum_like")
 .set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
+// BroadCastTo: <A> -> B where B is the target shape carried in the attrs
+bool BroadCastToRel(const Array<Type>& types,
+                    int num_inputs,
+                    const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  auto ioattrs = attrs.as<InitOpAttrs>();
+  CHECK(ioattrs);
+  auto intt = types[0].as<TensorTypeNode>();
+  if (intt == nullptr) { return false; }
+  auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
+  reporter->Assign(types[1], type);
+  return true;
+}
+
+Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
+  static const Op& op = Op::Get("broadcast_to");
+  auto attrs = make_node<InitOpAttrs>();
+  attrs->shape = std::move(shape);
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+Array<Tensor> BroadCastToCompute(const Attrs& attrs,
+                                 const Array<Tensor>& inputs,
+                                 const Type& out_type,
+                                 const Target& target) {
+  auto ioattrs = attrs.as<InitOpAttrs>();
+  CHECK(ioattrs != nullptr);
+  return { topi::broadcast_to(inputs[0], ioattrs->shape) };
+}
+
+TVM_REGISTER_API("relay.op._make.broadcast_to")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeBroadCastTo, args, rv);
+  });
+
+RELAY_REGISTER_OP("broadcast_to")
+.describe(R"code(Broadcast the first input to match the shape argument.
+)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(4) +.add_type_rel("BroadCastTo", BroadCastToRel) +.set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute) +.set_attr<TOpPattern>("TOpPattern", kBroadcast); + // BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B bool BroadCastToLikeRel(const Array<Type>& types, int num_inputs, diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 5d65691a2ad59ef5ef11b7decb36b30d7ebd9d4a..2c0ed73a753595fd0bd1b7393d40dc9dc6d6e6fc 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -25,6 +25,24 @@ def test_collapse_sum_like(): op_res = intrp.evaluate(func)(x, y) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) +def test_broadcast_to(): + shape = (4, 1, 6) + shape_like = (3, 4, 5, 6) + dtype = "float32" + x = relay.Var("x", relay.ty.TensorType(shape , dtype)) + z = relay.broadcast_to(x, shape=shape_like) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.ty.TensorType(shape_like, dtype) + + func = relay.Function([x], z) + x = np.random.uniform(size=shape).astype(dtype) + ref_res = np.broadcast_to(x, shape_like) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + def test_broadcast_to_like(): shape = (4, 1, 6) shape_like = (3, 4, 5, 6)