diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 425a072631a65fba4e30b3dc026ff18c295f6eaa..316514801fd6fc969bd678788f73875a4f699698 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -515,6 +515,35 @@ def ones_like(data): """ return _make.ones_like(data) + +def clip(a, a_min, a_max): + """Clip the elements in `a` between `a_min` and `a_max`. + `a_min` and `a_max` are cast to `a`'s dtype. + + Parameters + ---------- + a : relay.Expr + The input tensor. + a_min : float + The clip minimum. + a_max : float + The clip maximum. + + Returns + ------- + result : relay.Expr + `a` with elements clipped between `a_min` and `a_max`. + + Examples + -------- + .. code:: python + x = relay.Constant(tvm.nd.array([0, 1, 5, 3, 4, 2])) + relay.clip(x, 1., 4.) + # [1, 1, 4, 3, 4, 2] + """ + return _make.clip(a, a_min, a_max) + + def concatenate(data, axis): """Concatenate the input tensors along the given axis. diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 0ebb5f721d341b4c61d3d6065fd31df009cbf2d0..ef051e96453807d7cfdecf193a7ad268f24f922d 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -87,6 +87,37 @@ RELAY_REGISTER_UNARY_OP("copy") .set_support_level(3) .add_type_rel("Identity", IdentityRel); +// Clip +struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { + double a_min; + double a_max; + + TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { + TVM_ATTR_FIELD(a_min) + .describe("The minimum clip value."); + TVM_ATTR_FIELD(a_max) + .describe("The maximum clip value."); + } +}; + +TVM_REGISTER_API("relay.op._make.clip") + .set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) { + auto attrs = make_node<ClipAttrs>(); + attrs->a_min = a_min; + attrs->a_max = a_max; + static const Op& op = Op::Get("clip"); + return CallNode::make(op, {a}, Attrs(attrs), {}); + }); + +RELAY_REGISTER_OP("clip") + .describe(R"code(Clip tensor values. + This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. + )code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Clip", IdentityRel); + RELAY_REGISTER_UNARY_OP("floor") .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) @@ -153,6 +184,5 @@ RELAY_REGISTER_UNARY_OP("negative") .set_support_level(3) .add_type_rel("Identity", IdentityRel); - } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 4dfa7b563b82cd1b9d1ce745289372bbb1be3aa5..c6b83b39c27645cdf8d657d7f5f9e9ca17c8116f 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -19,6 +19,18 @@ def test_unary_identity(): ftype = func.checked_type assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32") + +def test_clip_type(): + ib = relay.ir_builder.IRBuilder() + a = ib.param("a", relay.TensorType((10, 4), "float32")) + with ib.function(a) as func: + ib.ret(relay.clip(a.var, 1., 4.)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((10, 4), "float32") + + def test_copy_infer_type(): ib = relay.ir_builder.IRBuilder() n, t, d = tvm.var("n"), tvm.var("t"), 100 @@ -57,6 +69,7 @@ def test_reshape_infer_type(): assert ftype.ret_type == relay.ty.TensorType( (n, t, 2000), "float32") + def assert_has_type(expr, typ, env=Environment({})): checked_expr = infer_type(env, expr) checked_type = checked_expr.checked_type @@ -78,9 +91,11 @@ def test_single_op(): tvm.relay.round, tvm.relay.abs, tvm.relay.negative]: check_single_op(opfunc) + if __name__ == "__main__": test_single_op() test_unary_identity() + test_clip_type() test_copy_infer_type() test_transpose_infer_type() test_reshape_infer_type()