diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index f0217fc1ec85e52c0337d50e40f140f906423d2f..87b169a1cfbc9ad755161a9fb47999d80e5fb5c5 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -273,6 +273,14 @@ def _lrn(inputs, attrs): new_attrs['size'] = _required_attr(attrs, 'nsize') return _get_nnvm_op(op_name)(*inputs, **new_attrs) +def _ones(_, attrs): + op_name = "ones" + return _get_nnvm_op(op_name)(**attrs) + +def _zeros(_, attrs): + op_name = "zeros" + return _get_nnvm_op(op_name)(**attrs) + _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__', '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__', @@ -281,8 +289,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', 'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add', 'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp', 'flatten', 'log', 'log_softmax', 'max', 'min', 'negative', - 'relu', 'sigmoid', 'slice_like', 'softmax', 'sum', 'tanh', - 'transpose'] + 'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax', + 'sum', 'tanh', 'transpose', 'zeros_like'] _convert_map = { '_copy' : _rename('copy'), @@ -294,6 +302,8 @@ _convert_map = { '_rminus_scalar': _rename('__rsub_scalar__'), '_contrib_MultiBoxPrior' : _rename('multibox_prior'), '_contrib_MultiBoxDetection' : _contrib_multibox_detection, + '_ones' : _ones, + '_zeros' : _zeros, 'Activation' : _activations, 'BatchNorm' : _batch_norm, 'BatchNorm_v1' : _batch_norm, @@ -397,13 +407,14 @@ def _from_mxnet_impl(symbol, graph): if node: return node[output_index] attr = symbol.list_attr() - # op_name = symbol.attr('op_name') + op_name = symbol.attr('op_name') childs = symbol.get_children() if childs is not None: - op_name = symbol.attr('op_name') childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))] childs = [x for y in childs for x in _as_list(y)] # expand group symbol node = _convert_symbol(op_name, childs, attr) + elif op_name != 'null': + node = _convert_symbol(op_name, [], attr) # no input symbol else: op_name = json.loads(symbol.tojson())['nodes'][0]['op'] node = _sym.Variable(name=name, **attr) diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index 653af1a631544f7a9bfef54b0330d39c89460019..dbd93e71049158273f1898a115e8f1cee370b472 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -153,6 +153,28 @@ def test_forward_lrn(): mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5) verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24)) +def test_forward_ones(): + data = mx.sym.var('data') + ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32') + mx_sym = mx.sym.elemwise_add(data, ones) + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) + +def test_forward_zeros(): + data = mx.sym.var('data') + zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32') + mx_sym = mx.sym.elemwise_add(data, zeros) + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) + +def test_forward_ones_like(): + data = mx.sym.var('data') + mx_sym = mx.sym.ones_like(data, dtype='float32') + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) + +def test_forward_zeros_like(): + data = mx.sym.var('data') + mx_sym = mx.sym.zeros_like(data, dtype='float32') + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -168,3 +190,7 @@ if __name__ == '__main__': test_forward_expand_dims() test_forward_pooling() test_forward_lrn() + test_forward_ones() + test_forward_zeros() + test_forward_ones_like() + test_forward_zeros_like()