diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 6563903a9cbf5e89e87ebae3d2a7e7bf6a3a7bb8..8566404561b2277747ec9a237b82e60d228bd818 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -61,6 +61,7 @@ This level enables typical convnet models. tvm.relay.less tvm.relay.less_equal tvm.relay.maximum + tvm.relay.minimum **Level 5: Vision/Image Operators** @@ -89,4 +90,5 @@ Level 4 Definitions .. autofunction:: tvm.relay.greater_equal .. autofunction:: tvm.relay.less .. autofunction:: tvm.relay.less_equal -.. autofunction:: tvm.relay.maximum \ No newline at end of file +.. autofunction:: tvm.relay.maximum +.. autofunction:: tvm.relay.minimum diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index d6e4f32ae553c47f6bf19d5d4c624a129c7c6a4f..859bfdc267999921427c9b49a0b1689a799dd14f 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -246,6 +246,24 @@ def maximum(lhs, rhs): return _make.maximum(lhs, rhs) +def minimum(lhs, rhs): + """Minimum with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.minimum(lhs, rhs) + + def right_shift(lhs, rhs): """Right shift with numpy-style broadcasting. diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc index 7f8f6884c597b70e9c769ca3d01d248680cb5432..11175f21573d9bacfd5bb1c1baefc35e02778649 100644 --- a/src/relay/op/tensor/binary.cc +++ b/src/relay/op/tensor/binary.cc @@ -45,6 +45,10 @@ RELAY_REGISTER_BINARY_OP("maximum") .describe("Elementwise maximum of two tensors with broadcasting") .set_support_level(4); +RELAY_REGISTER_BINARY_OP("minimum") +.describe("Elementwise minimum of two tensors with broadcasting") +.set_support_level(4); + // Comparisons #define RELAY_REGISTER_CMP_OP(OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \ diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index a4b8cebd297d6227d2fc4dca9c093476e7f68574..72876780f944ca5e8cf12a89742bacf8bc736f3b 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -23,7 +23,8 @@ def test_cmp_type(): def test_binary_broadcast(): for op in [relay.right_shift, relay.left_shift, - relay.maximum]: + relay.maximum, + relay.minimum]: ib = relay.ir_builder.IRBuilder() x = ib.param("x", relay.TensorType((10, 4), "int32")) y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))