diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h
index 5408582c8356d6278a72e1349bf14252ce0c4288..b736bd9c06a0d1879e710929321ee685ec170cdb 100644
--- a/include/tvm/relay/attrs/vision.h
+++ b/include/tvm/relay/attrs/vision.h
@@ -40,6 +40,25 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
   }
 };
 
+/*! \brief Attributes used in multibox_transform_loc operators */
+struct MultiBoxTransformLocAttrs
+    : public tvm::AttrsNode<MultiBoxTransformLocAttrs> {
+  bool clip;
+  double threshold;
+  Array<IndexExpr> variances;
+
+  TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs,
+                    "relay.attrs.MultiBoxTransformLocAttrs") {
+    TVM_ATTR_FIELD(clip).set_default(true)
+      .describe("Clip out-of-boundary boxes.");
+    TVM_ATTR_FIELD(threshold).set_default(0.01)
+      .describe("Threshold to be a positive prediction.");
+    TVM_ATTR_FIELD(variances)
+      .set_default(Array<IndexExpr>({0.1f, 0.1f , 0.2f, 0.2f}))
+      .describe("Variances to be decoded from box regression output.");
+  }
+};
+
 /*! \brief Attributes used in non_maximum_suppression operators */
 struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
   double overlap_threshold;
diff --git a/python/tvm/relay/op/vision/multibox.py b/python/tvm/relay/op/vision/multibox.py
index 9b7483eec5abdcee0cbabc8fde2717592a01f658..b04610aaa0809ab5d33049241fcdfb0acb99372e 100644
--- a/python/tvm/relay/op/vision/multibox.py
+++ b/python/tvm/relay/op/vision/multibox.py
@@ -36,3 +36,45 @@ def multibox_prior(data,
     3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
     """
     return _make.multibox_prior(data, sizes, ratios, steps, offsets, clip)
+
+
+def multibox_transform_loc(cls_prob,
+                           loc_pred,
+                           anchor,
+                           clip=True,
+                           threshold=0.01,
+                           variance=(0.1, 0.1, 0.2, 0.2)):
+    """Location transformation for multibox detection
+
+    Parameters
+    ----------
+    cls_prob : tvm.relay.Expr
+        Class probabilities. Must be a 3-D tensor whose last axis has one
+        entry per anchor.
+
+    loc_pred : tvm.relay.Expr
+        Location regression predictions. Must be a 2-D tensor whose second
+        axis has size num_anchors * 4.
+
+    anchor : tvm.relay.Expr
+        Prior anchor boxes. Must be a 3-D tensor whose last two axes are
+        [num_anchors, 4].
+
+    clip : boolean, optional
+        Whether to clip out-of-boundary boxes.
+
+    threshold : double, optional
+        Threshold to be a positive prediction.
+
+    variance : Tuple of float, optional
+        Variances to be decoded from box regression output.
+
+    Returns
+    -------
+    ret : tuple of tvm.relay.Expr
+        Type-checks as a tuple of two tensors: one of shape
+        [cls_prob.shape[0], num_anchors, 6] with the dtype of cls_prob, and
+        an int32 tensor of shape [cls_prob.shape[0]].
+    """
+    return _make.multibox_transform_loc(cls_prob, loc_pred, anchor, clip,
+                                        threshold, variance)
diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc
index e347e544e4f9e56f747fd7ddc833fe7e68de2ed6..55db8862e84963bf6d290e1ee2cc4b0aa66e05f2 100644
--- a/src/relay/op/vision/multibox_op.cc
+++ b/src/relay/op/vision/multibox_op.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright (c) 2018 by Contributors
+ * Copyright (c) 2018 by Contributors
  * \file multibox_op.cc
  * \brief Multibox related operators
  */
@@ -68,5 +68,101 @@ RELAY_REGISTER_OP("vision.multibox_prior")
 .set_support_level(5)
 .add_type_rel("MultiBoxPrior", MultiboxPriorRel);
 
+TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);
+
+/*!
+ * \brief Type relation for vision.multibox_transform_loc.
+ *
+ * types holds the cls_prob, loc_pred and anchor input types followed
+ * by the output type, which is assigned a tuple of two tensor types.
+ *
+ * \return false while any input is not a tensor type, true once the
+ *         output type has been assigned.
+ */
+bool MultiBoxTransformLocRel(const Array<Type>& types, int num_inputs,
+                             const Attrs& attrs, const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+
+  const auto* cls_prob = types[0].as<TensorTypeNode>();
+  const auto* loc_pred = types[1].as<TensorTypeNode>();
+  const auto* anchor = types[2].as<TensorTypeNode>();
+  // An input that is not a tensor type (yet) cannot be checked here;
+  // report the relation as unresolved instead of failing a CHECK.
+  if (cls_prob == nullptr || loc_pred == nullptr || anchor == nullptr) {
+    return false;
+  }
+
+  const auto& cls_shape = cls_prob->shape;
+  const auto& loc_shape = loc_pred->shape;
+  const auto& anchor_shape = anchor->shape;
+
+  CHECK_EQ(cls_shape.size(), 3U)
+    << "The dimension of class probability should be 3, but received "
+    << cls_shape.size();
+  CHECK_EQ(loc_shape.size(), 2U)
+    << "The dimension of location prediction should be 2, but received "
+    << loc_shape.size();
+  CHECK_EQ(anchor_shape.size(), 3U)
+    << "The dimension of anchor should be 3, but received "
+    << anchor_shape.size();
+
+  // cls_prob's last axis and anchor's middle axis both count anchors;
+  // loc_pred carries four values per anchor.
+  CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1]))
+    << "Number of anchors mismatch found";
+  CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1]))
+    << "# anchors mismatch with # loc.";
+  CHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0.";
+  CHECK(reporter->AssertEQ(anchor_shape[2], 4));
+
+  // Output 0 is [cls_shape[0], num_anchors, 6] in cls_prob's dtype;
+  // output 1 is an int32 vector of length cls_shape[0].
+  std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6});
+  std::vector<IndexExpr> oshape1({cls_shape[0]});
+  std::vector<Type> fields;
+  fields.push_back(TensorTypeNode::make(oshape0, cls_prob->dtype));
+  fields.push_back(TensorTypeNode::make(oshape1, Int(32)));
+
+  // assign output type
+  reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(fields)));
+  return true;
+}
+
+/*!
+ * \brief Build a call to vision.multibox_transform_loc.
+ *
+ * clip, threshold and variances are stored in a
+ * MultiBoxTransformLocAttrs node attached to the call.
+ */
+Expr MakeMultiBoxTransformLoc(Expr cls_prob,
+                              Expr loc_pred,
+                              Expr anchor,
+                              bool clip,
+                              double threshold,
+                              Array<IndexExpr> variances) {
+  auto attrs = make_node<MultiBoxTransformLocAttrs>();
+  attrs->clip = clip;
+  attrs->threshold = threshold;
+  attrs->variances = std::move(variances);
+  static const Op& op = Op::Get("vision.multibox_transform_loc");
+  return CallNode::make(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+  runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxTransformLoc, args, rv);
+});
+
+RELAY_REGISTER_OP("vision.multibox_transform_loc")
+.describe(R"doc("Location transformation for multibox detection."
+)doc" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.MultiBoxTransformLocAttrs")
+.set_num_inputs(3)
+.add_argument("cls_prob", "Tensor", "Class probabilities.")
+.add_argument("loc_pred", "Tensor", "Location regression predictions.")
+.add_argument("anchor", "Tensor", "Multibox prior anchor boxes")
+.add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel)
+.set_support_level(5);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 77e3f005dade89fc9cc7520aaa280179e7d66f4c..6bd331b9812043144df680f4a9dbe5aa1eca5740 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -102,8 +102,66 @@ def test_nms():
         (n, num_anchors, 6), "float32")
 
 
+def test_multibox_transform_loc():
+    def test_default_value():
+        # The type relation matches cls_prob's last axis against the anchor
+        # count, so num_classes differs from num_anchors to keep axes apart.
+        num_anchors = 5
+        num_classes = 3
+
+        cls_prob = relay.var(
+            "cls_prob",
+            relay.ty.TensorType((1, num_classes, num_anchors), "float32"))
+        loc_pred = relay.var(
+            "loc_pred", relay.ty.TensorType((1, num_anchors * 4), "float32"))
+        anchors = relay.var(
+            "anchors", relay.ty.TensorType((1, num_anchors, 4), "float32"))
+
+        ret = relay.vision.multibox_transform_loc(
+            cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors)
+        ret = relay.ir_pass.infer_type(ret)
+        ref_type = relay.ty.TupleType(
+            tvm.convert([
+                relay.ty.TensorType((1, num_anchors, 6), "float32"),
+                relay.ty.TensorType((1, ), "int")
+            ]))
+        assert ret.checked_type == ref_type
+
+    def test_threshold():
+        num_anchors = 5
+        num_classes = 3
+        n = tvm.var("n")
+        cls_prob = relay.var(
+            "cls_prob",
+            relay.ty.TensorType((n, num_classes, num_anchors), "float32"))
+        loc_pred = relay.var(
+            "loc_pred", relay.ty.TensorType((n, num_anchors * 4), "float32"))
+        anchors = relay.var(
+            "anchors", relay.ty.TensorType((1, num_anchors, 4), "float32"))
+        threshold = 0.02
+        variance = (0.2, 0.2, 0.3, 0.3)
+
+        ret = relay.vision.multibox_transform_loc(
+            cls_prob=cls_prob,
+            loc_pred=loc_pred,
+            anchor=anchors,
+            threshold=threshold,
+            variance=variance)
+        ret = relay.ir_pass.infer_type(ret)
+        ref_type = relay.ty.TupleType(
+            tvm.convert([
+                relay.ty.TensorType((n, num_anchors, 6), "float32"),
+                relay.ty.TensorType((n, ), "int")
+            ]))
+        assert ret.checked_type == ref_type
+
+    test_default_value()
+    test_threshold()
+
+
 if __name__ == "__main__":
     test_resize_infer_type()
     test_resize()
     test_multibox_prior()
+    test_multibox_transform_loc()
     test_nms()