diff --git a/nnvm/docs/top.rst b/nnvm/docs/top.rst index c34ae5b54b07cef30ec4f85dfa381b1d532e03c2..11d0f6093f83411f7da4dc3273163e5b3fcdb909 100644 --- a/nnvm/docs/top.rst +++ b/nnvm/docs/top.rst @@ -28,7 +28,6 @@ This level enables fully connected multi-layer perceptron. :nosignatures: nnvm.symbol.dense - nnvm.symbol.matmul nnvm.symbol.relu nnvm.symbol.tanh nnvm.symbol.sigmoid @@ -40,12 +39,6 @@ This level enables fully connected multi-layer perceptron. nnvm.symbol.elemwise_mul nnvm.symbol.elemwise_div nnvm.symbol.elemwise_sum - nnvm.symbol.full - nnvm.symbol.full_like - nnvm.symbol.ones - nnvm.symbol.ones_like - nnvm.symbol.zeros - nnvm.symbol.zeros_like nnvm.symbol.flatten nnvm.symbol.concatenate nnvm.symbol.expand_dims @@ -57,7 +50,6 @@ This level enables fully connected multi-layer perceptron. nnvm.symbol.log_softmax nnvm.symbol.pad nnvm.symbol.block_grad - nnvm.symbol.indicator **Level 2: Convolutions** @@ -81,8 +73,6 @@ This level enables typical convnet models. :nosignatures: nnvm.symbol.reshape - nnvm.symbol.reshape_like - nnvm.symbol.expand_like nnvm.symbol.copy nnvm.symbol.negative nnvm.symbol.leaky_relu @@ -109,11 +99,21 @@ This level enables typical convnet models. nnvm.symbol.broadcast_sub nnvm.symbol.broadcast_mul nnvm.symbol.broadcast_div + nnvm.symbol.clip + nnvm.symbol.greater + nnvm.symbol.less + nnvm.symbol.expand_like + nnvm.symbol.reshape_like + nnvm.symbol.full + nnvm.symbol.full_like + nnvm.symbol.ones + nnvm.symbol.ones_like + nnvm.symbol.zeros + nnvm.symbol.zeros_like Detailed Definitions -------------------- .. autofunction:: nnvm.symbol.dense -.. autofunction:: nnvm.symbol.matmul .. autofunction:: nnvm.symbol.relu .. autofunction:: nnvm.symbol.tanh .. autofunction:: nnvm.symbol.sigmoid @@ -125,12 +125,6 @@ Detailed Definitions .. autofunction:: nnvm.symbol.elemwise_mul .. autofunction:: nnvm.symbol.elemwise_div .. autofunction:: nnvm.symbol.elemwise_sum -.. autofunction:: nnvm.symbol.full -.. autofunction:: nnvm.symbol.full_like -.. autofunction:: nnvm.symbol.ones -.. autofunction:: nnvm.symbol.ones_like -.. autofunction:: nnvm.symbol.zeros -.. autofunction:: nnvm.symbol.zeros_like .. autofunction:: nnvm.symbol.flatten .. autofunction:: nnvm.symbol.concatenate .. autofunction:: nnvm.symbol.expand_dims @@ -142,7 +136,6 @@ Detailed Definitions .. autofunction:: nnvm.symbol.log_softmax .. autofunction:: nnvm.symbol.pad .. autofunction:: nnvm.symbol.block_grad -.. autofunction:: nnvm.symbol.indicator .. autofunction:: nnvm.symbol.conv2d .. autofunction:: nnvm.symbol.conv2d_transpose @@ -152,8 +145,6 @@ Detailed Definitions .. autofunction:: nnvm.symbol.global_avg_pool2d .. autofunction:: nnvm.symbol.reshape -.. autofunction:: nnvm.symbol.reshape_like -.. autofunction:: nnvm.symbol.expand_like .. autofunction:: nnvm.symbol.copy .. autofunction:: nnvm.symbol.negative .. autofunction:: nnvm.symbol.leaky_relu @@ -175,3 +166,14 @@ Detailed Definitions .. autofunction:: nnvm.symbol.broadcast_sub .. autofunction:: nnvm.symbol.broadcast_mul .. autofunction:: nnvm.symbol.broadcast_div +.. autofunction:: nnvm.symbol.clip +.. autofunction:: nnvm.symbol.greater +.. autofunction:: nnvm.symbol.less +.. autofunction:: nnvm.symbol.expand_like +.. autofunction:: nnvm.symbol.reshape_like +.. autofunction:: nnvm.symbol.full +.. autofunction:: nnvm.symbol.full_like +.. autofunction:: nnvm.symbol.ones +.. autofunction:: nnvm.symbol.ones_like +.. autofunction:: nnvm.symbol.zeros +.. 
autofunction:: nnvm.symbol.zeros_like
diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h
index 7f50e49b996cad63706792af236ae7dcca783595..00bad8245713c44367ddf05bcf4da26df92a45d3 100644
--- a/nnvm/include/nnvm/top/tensor.h
+++ b/nnvm/include/nnvm/top/tensor.h
@@ -241,6 +241,16 @@ struct MatMulParam : public dmlc::Parameter<MatMulParam> {
   }
 };
 
+struct ClipParam : public dmlc::Parameter<ClipParam> {
+  double a_min, a_max;
+  DMLC_DECLARE_PARAMETER(ClipParam) {
+    DMLC_DECLARE_FIELD(a_min)
+      .describe("Minimum value such that values smaller than this will be clipped.");
+    DMLC_DECLARE_FIELD(a_max)
+      .describe("Maximum value such that values larger than this will be clipped.");
+  }
+};
+
 }  // namespace top
 }  // namespace nnvm
diff --git a/nnvm/python/nnvm/_base.py b/nnvm/python/nnvm/_base.py
index d01bdcaeb5b1a05277b2ae67bc75b30a87514d75..29390a2201bf7b711afe4ea419bfef1f89f93eaf 100644
--- a/nnvm/python/nnvm/_base.py
+++ b/nnvm/python/nnvm/_base.py
@@ -54,6 +54,9 @@ OpHandle = ctypes.c_void_p
 SymbolHandle = ctypes.c_void_p
 GraphHandle = ctypes.c_void_p
 
+# Global dict of str to symbol/ndarray recording variable initial values
+_all_var_init = {}
+
 #----------------------------
 # helper function definition
 #----------------------------
diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py
index bce142d21b630706be24cc97913e79c1157549ae..d97a8784dcc275129c4249a4c83dd1183c8be6da 100644
--- a/nnvm/python/nnvm/compiler/build_module.py
+++ b/nnvm/python/nnvm/compiler/build_module.py
@@ -4,9 +4,12 @@ from __future__ import absolute_import as _abs
 import logging
 import tvm
+
 from tvm.contrib import graph_runtime
 from . import graph_attr, graph_util
 from .. import graph as _graph
+from .. import symbol as sym
+from .._base import _all_var_init
 
 OPT_PASS_LEVEL = {
     "SimplifyInference": 0,
@@ -201,6 +204,9 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
         By default, llvm is used if it is enabled,
         otherwise a stackvm intepreter is used.
 
+    initialize : bool, optional
+        Whether to initialize variables recorded in the global dict _all_var_init.
+
     Returns
     -------
     graph : Graph
@@ -230,6 +236,10 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
     if not isinstance(dtype, str):
         idtype, _ = graph_util.infer_dtype(graph, **dtype)
         dtype.update(zip(graph.index.input_names, idtype))
+    # Initialize all variables specified in _all_var_init
+    init_var = {}
+    if _all_var_init:
+        init_var = initialize_variables(shape, dtype)
     # Apply optimization
     graph = optimize(graph, shape, dtype)
     # Precompute prune
@@ -250,6 +260,11 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
     with target:
         graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
     libmod = graph_attr._move_out_module(graph, "module")
+    # Write variable initial values into params
+    if init_var:
+        if params is None:
+            params = {}
+        params.update(init_var)
     return graph, libmod, params
 
@@ -329,3 +344,45 @@ def precompute_prune(graph, params):
     with tvm.build_config(auto_unroll_max_step=0):
         out_arrs = _run_graph(pre_graph, params)
     return graph, dict(zip(out_names, out_arrs))
+
+
+def initialize_variables(ishape, idtype):
+    """Initialize variables stored in the _all_var_init dictionary.
+ + Parameters + ---------- + ishape : dict of str to tuple of int + The input shape to the graph + + idtype : str or dict of str to str + The input types to the graph + + Returns + ------- + init_var : dict of str to tvm.ndarray + """ + symbol_init_dict = {} + const_init_dict = {} + init_var = {} + for key, value in _all_var_init.items(): + if isinstance(value, sym.Symbol): + symbol_init_dict[key] = value + else: + const_init_dict[key] = tvm.nd.array(value) + # Make sure variables are initialized only once. + _all_var_init.clear() + if symbol_init_dict: + # Create dummy params to run initialization graph + params = {} + for name, shape in ishape.items(): + dtype = idtype if isinstance(idtype, str) else idtype[name] + params[name] = tvm.nd.empty(shape, dtype, ctx=tvm.cpu()) + init_group_sym = sym.Group(symbol_init_dict.values()) + graph = _graph.create(init_group_sym) + with tvm.build_config(auto_unroll_max_step=0): + init_values = _run_graph(graph, params) + init_var.update(dict(zip(symbol_init_dict.keys(), init_values))) + init_var.update(const_init_dict) + for name, data in init_var.items(): + ishape[name] = data.shape + return init_var diff --git a/nnvm/python/nnvm/compiler/lr_scheduler.py b/nnvm/python/nnvm/compiler/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..791925e74960f58469b4123e94a72729f3fb8a98 --- /dev/null +++ b/nnvm/python/nnvm/compiler/lr_scheduler.py @@ -0,0 +1,58 @@ +# pylint: disable=too-few-public-methods, no-member +"""API for scheduling learning rate.""" +from .. import symbol as sym + +class LRScheduler(object): + """Base class of a learning rate scheduler. + + A scheduler returns a new learning rate based on the number of updates that have + been performed. + + Parameters + ---------- + base_lr : float, optional + The initial learning rate. + """ + def __init__(self, base_lr=0.01, name='LRScheduler'): + self.name = name + self.base_lr = base_lr + + def __call__(self, num_update): + """Return a new learning rate based on number of updates. + + Parameters + ---------- + num_update: nnvm Symbol + the number of updates applied to weight. + """ + raise NotImplementedError("__call__ method must be overridden.") + +class FactorScheduler(LRScheduler): + """Reduce the learning rate by a factor for every *n* steps. + + It returns a new learning rate by:: + + base_lr * pow(factor, num_update/step) + + Parameters + ---------- + step : int + Changes the learning rate for every n updates. + factor : float, optional + The factor to change the learning rate. + stop_factor_lr : float, optional + Stop updating the learning rate if it is less than this value. 
+ """ + def __init__(self, step, factor=1, stop_factor_lr=1e-8, name='FactorScheduler', **kwargs): + super(FactorScheduler, self).__init__(name=name, **kwargs) + if step < 1: + raise ValueError("Schedule step must be greater or equal than 1 round") + if factor > 1.0: + raise ValueError("Factor must be no more than 1 to make lr reduce") + self.step = step + self.factor = factor + self.stop_factor_lr = stop_factor_lr + + def __call__(self, num_update): + updated_lr = self.base_lr * self.factor ** (num_update / self.step) + return sym.clip(updated_lr, a_min=self.stop_factor_lr, a_max=self.base_lr) diff --git a/nnvm/python/nnvm/compiler/optimizer.py b/nnvm/python/nnvm/compiler/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..bcf7498528af203a89d56825e9441b9acd0b217e --- /dev/null +++ b/nnvm/python/nnvm/compiler/optimizer.py @@ -0,0 +1,131 @@ +# pylint: disable=invalid-name, no-member, too-few-public-methods, too-many-arguments, too-many-locals, protected-access +"""Optimizer API""" +from . import graph_util +from .. import symbol as sym + +class Optimizer(object): + """Base class inherited by all optimizers. + + Parameters + ---------- + learning_rate : float, optional + The initial learning rate. + + lr_scheduler : LRScheduler, optional + The learning rate scheduler. + + rescale_grad : float, optional + Multiply the gradient with `rescale_grad` before updating. Often + choose to be ``1.0/batch_size``. + + clip_gradient : float, optional + Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``. + + wd : float, optional + The weight decay (or L2 regularization) coefficient. Modifies objective + by adding a penalty for having large weights. + + name : string, optional + The name of optimizer. + """ + def __init__(self, learning_rate=0.01, lr_scheduler=None, + rescale_grad=1, clip_gradient=None, wd=0, name="Optimizer"): + self.name = name + self.lr = learning_rate + self.lr_scheduler = lr_scheduler + self.rescale_grad = rescale_grad + self.clip_gradient = clip_gradient + self.wd = wd + init_update_t = sym.Variable(name+'_t', init=sym.zeros(shape=(1,), dtype="int32")) + self.update_t = sym._assign(init_update_t, init_update_t + 1) + + def minimize(self, obj, var=None): + """Minimize given obj symbol respect to var. If var is not set, all input + variables of obj will be used. + + Parameters + ---------- + obj : nnvm Symbol or list of nnvm Symbols + Symbols to be minimized. + var : nnvm Symbol or list of nnvm Symbols, optional + Symbols the gradient respect to. + + Returns + ------- + group_sym : nnvm Symbol + Group symbol represents update symbols. + """ + raise NotImplementedError() + + def _get_lr(self): + """Gets the learning rate with learning rate scheduler. + + Returns + ------- + lr : float + Learning rate. 
+ """ + if self.lr_scheduler is not None: + lr = self.lr_scheduler(self.update_t) + else: + lr = self.lr + return lr + + +class SGD(Optimizer): + """The SGD optimizer + """ + def __init__(self, name='SGD', **kwargs): + super(SGD, self).__init__(name=name, **kwargs) + + def minimize(self, obj, var=None): + variables = var or obj.list_input_variables() + if not isinstance(variables, list): + variables = [variables] + grads = graph_util.gradients(obj, variables) + updates = [] + lr_t = self._get_lr() + for v, g in zip(variables, grads): + g = self.rescale_grad * g + if self.clip_gradient is not None: + g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient) + updates.append(sym._assign(v, v - lr_t * (g + self.wd * v))) + return sym.Group(updates) + + +class Adam(Optimizer): + """The Adam optimizer. + + This class implements the optimizer described in *Adam: A Method for + Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980. + """ + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, + epsilon=1e-8, name='Adam', **kwargs): + super(Adam, self).__init__(learning_rate=learning_rate, name=name, **kwargs) + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.m = [] + self.v = [] + + def minimize(self, obj, var=None): + variables = var or obj.list_input_variables() + if not isinstance(variables, list): + variables = [variables] + grads = graph_util.gradients(obj, variables) + updates = [] + for i, v in enumerate(variables): + self.m.append(sym.Variable(self.name + '_m' + str(i), init=sym.zeros_like(v))) + self.v.append(sym.Variable(self.name + '_v' + str(i), init=sym.zeros_like(v))) + rate = sym.sqrt(1 - self.beta2 ** self.update_t) / (1 - self.beta1 ** self.update_t) + lr_t = self._get_lr() * rate + for variable, g, m, v in zip(variables, grads, self.m, self.v): + g = self.rescale_grad * g + if self.clip_gradient is not None: + g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient) + update_m = sym._assign(m, self.beta1 * m + (1 - self.beta1) * g) + update_v = sym._assign(v, self.beta2 * v + (1 - self.beta2) * g * g) + update_var = sym._assign(variable, variable - lr_t * (update_m / (sym.sqrt(update_v) \ + + self.epsilon) + self.wd * variable)) + updates.append(update_var) + return sym.Group(updates) diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index 0ac51f23fe9428a8ac6531bc18cc8e9fe60e3373..8b390e2cb72f6a59b7a4efb8e49f9ae6fcf1b3b9 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -1,4 +1,4 @@ -# pylint: disable=invalid-name, unused-import +# pylint: disable=invalid-name, unused-import, protected-access """Symbolic graph construction API. This namespace contains most of the registered operators. @@ -8,10 +8,12 @@ from __future__ import absolute_import as _abs import sys as _sys import os as _os import ctypes as _ctypes - from numbers import Number as _Number + +import numpy as np + from . import _base -from ._base import _LIB, check_call as _check_call, _FFI_MODE +from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init from .attribute import AttrScope from . import _symbol_internal as _internal @@ -309,13 +311,19 @@ class Symbol(SymbolBase): self.handle, deps.handle)) -def Variable(name, **kwargs): +def Variable(name, init=None, **kwargs): """Create a symbolic variable with specified name. Parameters ---------- name : str Name of the variable. + init : Symbol or numpy.ndarray + Symbol or numpy ndarray of initial value for the variable. 
+ Note that for symbolic initialization value, it must be able + to be defined through InferShape, such as sym.zeros_like(v), + in which v is an input or parameter. Otherwise, pass a numpy + ndarray instead. kwargs : dict of string -> string Additional attributes to set on the variable. @@ -333,6 +341,11 @@ def Variable(name, **kwargs): attr = AttrScope.current.get(kwargs) if attr: ret._set_attr(**attr) + if init is not None: + if not isinstance(init, (Symbol, np.ndarray)): + raise TypeError('Expect a Symbol or numpy ndarray' + 'for variable `init`') + _all_var_init[name] = init return ret diff --git a/nnvm/python/nnvm/top/attr_dict.py b/nnvm/python/nnvm/top/attr_dict.py index 453c2971f73f33cebe55944cf8a67cd07e58fc21..a913a92552b28f2dc737eb16c60a9b725ad573af 100644 --- a/nnvm/python/nnvm/top/attr_dict.py +++ b/nnvm/python/nnvm/top/attr_dict.py @@ -123,6 +123,21 @@ class AttrDict(object): else: raise ValueError("Wrong bool format for key %s" % key) + def get_string(self, key): + """Get string from attr dict + + Parameters + ---------- + key : str + The attr key + + Returns + ------- + value : str + The result value + """ + return self[key] + def __repr__(self): return str({k : self[k] for k in self.keys()}) diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py index 4f9d3ba9f36ffba717aabe795e5b192ff6a9a815..1e8688f9f2e69afb396f23db2d5e84e88f48dab8 100644 --- a/nnvm/python/nnvm/top/tensor.py +++ b/nnvm/python/nnvm/top/tensor.py @@ -143,3 +143,95 @@ reg.register_schedule("broadcast_div", _fschedule_broadcast) # broadcast_to reg.register_pattern("broadcast_to", OpPattern.BROADCAST) reg.register_schedule("broadcast_to", _fschedule_broadcast) + +# clip +reg.register_pattern("clip", OpPattern.ELEMWISE) +reg.register_schedule("clip", _fschedule_elemwise) + +# elemwise sum +@reg.register_compute("elemwise_sum") +def compute_elemwise_sum(attrs, inputs, _): + """Compute definition of elemwise sum""" + num_args = attrs.get_int("num_args") + assert num_args == len(inputs), "Number of tensors does not match num_args." 
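As a rough usage sketch of the new ``init`` argument (editorial example, not part of the patch; the variable names and the llvm target are assumptions), the initial values recorded in ``_all_var_init`` are evaluated inside ``nnvm.compiler.build`` and handed back through the returned ``params``::

    import numpy as np
    import nnvm.symbol as sym
    import nnvm.compiler

    x = sym.Variable("x")                                           # ordinary input
    w = sym.Variable("w", init=np.zeros((1, 16), dtype="float32"))  # constant initial value
    m = sym.Variable("m", init=sym.zeros_like(x))                   # symbolic initial value
    y = sym.elemwise_add(sym.elemwise_mul(x, w), m)

    # build() evaluates the recorded initializers and returns them in params
    graph, lib, params = nnvm.compiler.build(y, "llvm", shape={"x": (1, 16)})
    assert set(params) == {"w", "m"}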
+ return topi.tensor.elemwise_sum(inputs, num_args) +reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE) +reg.register_schedule("elemwise_sum", _fschedule_elemwise) + +# full +@reg.register_compute("full") +def compute_full(attrs, inputs, _): + """Compute definition of full""" + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_string("dtype") + fill_value = attrs.get_float("fill_value") + return topi.tensor.full(shape, dtype, fill_value) +reg.register_pattern("full", OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_schedule("full", _fschedule_elemwise) + +# full_like +@reg.register_compute("full_like") +def compute_full_like(attrs, inputs, _): + """Compute definition of full_like""" + fill_value = attrs.get_float("fill_value") + return topi.tensor.full_like(inputs[0], fill_value) +reg.register_pattern("full_like", OpPattern.ELEMWISE) +reg.register_schedule("full_like", _fschedule_elemwise) + +# zeros +@reg.register_compute("zeros") +def compute_zeros(attrs, inputs, _): + """Compute definition of zeros""" + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_string("dtype") + return topi.tensor.full(shape, dtype, 0) +reg.register_pattern("zeros", OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_schedule("zeros", _fschedule_elemwise) + +# zeros_like +@reg.register_compute("zeros_like") +def compute_zeros_like(_, inputs, out_info): + """Compute definition of zeros_like""" + return topi.tensor.full_like(inputs[0], 0) +reg.register_pattern("zeros_like", OpPattern.ELEMWISE) +reg.register_schedule("zeros_like", _fschedule_elemwise) + +# ones +@reg.register_compute("ones") +def compute_ones(attrs, inputs, _): + """Compute definition of ones""" + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_string("dtype") + #tvm.tensor.Tensor() + return topi.tensor.full(shape, dtype, 1) +reg.register_pattern("ones", OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_schedule("ones", _fschedule_elemwise) + +# ones_like +@reg.register_compute("ones_like") +def compute_ones_like(_, inputs, out_info): + """Compute definition of ones_like""" + return topi.tensor.full_like(inputs[0], 1) +reg.register_pattern("ones_like", OpPattern.ELEMWISE) +reg.register_schedule("ones_like", _fschedule_elemwise) + +# greater +@reg.register_compute("greater") +def compute_greater(_, inputs, out_info): + """Compute definition of greater""" + return topi.tensor.greater(inputs[0], inputs[1], 'float32') +reg.register_pattern("greater", OpPattern.ELEMWISE) +reg.register_schedule("greater", _fschedule_elemwise) + +# less +@reg.register_compute("less") +def compute_less(_, inputs, out_info): + """Compute definition of less""" + return topi.tensor.less(inputs[0], inputs[1], 'float32') +reg.register_pattern("less", OpPattern.ELEMWISE) +reg.register_schedule("less", _fschedule_elemwise) + +# block_grad +reg.register_compute("block_grad", _compute_unary(topi.identity)) +reg.register_pattern("block_grad", OpPattern.ELEMWISE) +reg.register_schedule("block_grad", _fschedule_elemwise) diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index ec1596e9599bb99256d913aa167c6d05793e305c..c3ceb68682ee5a4f7dd38357663ef6c28e2379a3 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -2,6 +2,7 @@ """Tensor transformation ops""" from __future__ import absolute_import +import topi from .tensor import _fschedule_broadcast, _fschedule_injective from . 
import registry as reg from .registry import OpPattern @@ -10,6 +11,32 @@ from .registry import OpPattern reg.register_pattern("expand_dims", OpPattern.BROADCAST) reg.register_schedule("expand_dims", _fschedule_broadcast) +# expand_like +@reg.register_compute("expand_like") +def compute_expand_like(attrs, inputs, _): + """Compute definition of expand_like""" + exclude = attrs.get_bool("exclude") + axis = attrs.get_int_tuple("axis") + if exclude: + exclude_axis = (axis,) if isinstance(axis, int) else axis + axis = [] + for item in range(len(inputs[1].shape)): + if item not in exclude_axis: + axis.append(item) + axis = tuple(axis) + + return topi.transform.expand_like(inputs[0], inputs[1], axis) +reg.register_pattern("expand_like", OpPattern.BROADCAST) +reg.register_schedule("expand_like", _fschedule_broadcast) + +# reshape_like +@reg.register_compute("reshape_like") +def compute_reshape_like(attrs, inputs, out_info): + """Compute definition of reshape_like""" + return topi.reshape(inputs[0], inputs[1].shape) +reg.register_pattern("reshape_like", OpPattern.INJECTIVE) +reg.register_schedule("reshape_like", _fschedule_injective) + # transpose reg.register_pattern("transpose", OpPattern.INJECTIVE) reg.register_schedule("transpose", _fschedule_injective) diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index 79b4c8587a645da0ce287ba6d033fa17771c7f68..5ac4b7662ff5a155a11cf825d2c5e92359ce1251 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -130,15 +130,14 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu) "FGradient", [](const NodePtr& n, const std::vector<NodeEntry>& ograds) { // y = relu(x) - // grad = indicator(x > 0) - NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero", + // grad = indicator(x > 0) * ograd + NodeEntry sub0 = MakeNode("zeros_like", n->attrs.name + "_sub0", {n->inputs[0]}); + NodeEntry sub1 = MakeNode("greater", n->attrs.name + "_sub1", + {n->inputs[0], sub0}, {{"exclude", "true"}}); return std::vector<NodeEntry>{ - MakeNode("elemwise_mul", n->attrs.name + "_grad", { - ograds[0], - MakeNode("greater", n->attrs.name + "_grad_mask", - {n->inputs[0], zero}, {{"exclude", "true"}}) - }) + MakeNode("elemwise_mul", n->attrs.name + "_grad", + {ograds[0], sub1}) }; }) .set_support_level(1); @@ -358,23 +357,21 @@ NNVM_REGISTER_OP(log_softmax) // grad_x = sum(grad_x, keepdim, axis) // grad_x = neg grad_x // grad_x = grad_x + ones_like(grad_x) - // grad_x = expand_dims(grad_x, axis) const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed); NodeEntry output = NodeEntry{n, 0, 0}; NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output}); NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0}, {{"axis", std::to_string(param.axis)}, {"keepdims", "true"}}); - NodeEntry sub2 = MakeNode("negative", n->attrs.name + "_grad_sub2", {sub1}); - NodeEntry sub3 = MakeNode("ones_like", n->attrs.name + "_grad_sub3", {sub2}); - NodeEntry sub4 = MakeNode("elemwise_add", n->attrs.name + "_grad_sub4", {sub2, sub3}); + NodeEntry sub2 = MakeNode("full_like", n->attrs.name + "_grad_sub2", {n->inputs[0]}, + {{"fill_value", "-1"}}); + NodeEntry sub3 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub3", {sub1, sub2}); return std::vector<NodeEntry> { - MakeNode("expand_like", n->attrs.name + "_grad", {sub4, output}, - {{"axis", std::to_string(param.axis)}}) + MakeNode("elemwise_add", n->attrs.name + "_grad", {sub3, ograds[0]}) }; }) .set_support_level(1); -// leaky_rlu +// leaky_relu 
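For reference, a small NumPy model of the gradients constructed in the relu and leaky_relu FGradient rewrites (illustrative only; ``alpha`` stands in for ``LeakyReLUParam.alpha``)::

    import numpy as np

    x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype="float32")
    ograd = np.ones_like(x)
    alpha = 0.3   # example value for LeakyReLUParam.alpha

    # relu:       grad = indicator(x > 0) * ograd
    relu_grad = (x > 0).astype("float32") * ograd
    # leaky_relu: grad = (indicator(x > 0) + alpha * indicator(x < 0)) * ograd
    leaky_grad = ((x > 0) + alpha * (x < 0)).astype("float32") * ograd

    print(relu_grad)    # [0.  0.  0.  1.  1.]
    print(leaky_grad)   # [0.3 0.3 0.  1.  1.]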
DMLC_REGISTER_PARAMETER(LeakyReLUParam); NNVM_REGISTER_OP(leaky_relu) @@ -407,14 +404,15 @@ NNVM_REGISTER_OP(leaky_relu) NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero", {n->inputs[0]}); NodeEntry sub0 = MakeNode("greater", n->attrs.name + "_pos_grad", - {n->inputs[0], zero}, {{"exclude", "true"}}); + {n->inputs[0], zero}); NodeEntry sub1 = MakeNode("less", n->attrs.name + "_neg_grad", - {n->inputs[0], zero}, {{"exclude", "true"}}); + {n->inputs[0], zero}); NodeEntry sub2 = MakeNode("__mul_scalar__", n->attrs.name + "_neg_mul_2", {sub1}, {{"scalar", std::to_string(param.alpha)}}); + NodeEntry sub3 = MakeNode("elemwise_add", n->attrs.name + "_sub3", {sub0, sub2}); return std::vector<NodeEntry>{ - MakeNode("elemwise_add", n->attrs.name + "_add_grad", {sub0, sub2}) + MakeNode("elemwise_mul", n->attrs.name + "_grad", {ograds[0], sub3}) }; }) .set_support_level(1); diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc index d23d9c4f0a3f7ea9d701c26ba2f41147a22b00da..87fbf5823d9b2f19df805cbc0a4ff4f330953a77 100644 --- a/nnvm/src/top/tensor/elemwise.cc +++ b/nnvm/src/top/tensor/elemwise.cc @@ -190,7 +190,10 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add) // y = n0 + n1 // grad_0 = grad_y // grad_1 = grad_y - return std::vector<NodeEntry>{ograds[0], ograds[0]}; + return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0", + {ograds[0]}), + MakeNode("copy", n->attrs.name + "_grad_0", + {ograds[0]}) }; }); NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub) @@ -311,7 +314,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy) const std::vector<NodeEntry>& ograds){ // y = copy(n0) // grad_0 = grad_y - return std::vector<NodeEntry>{ograds[0]}; + return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0", + {ograds[0]}) }; }); DMLC_REGISTER_PARAMETER(InitOpParam); @@ -329,7 +333,7 @@ NNVM_REGISTER_INIT_OP(full) .add_arguments(InitOpWithScalarParam::__FIELDS__()) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>) -.set_support_level(1); +.set_support_level(4); NNVM_REGISTER_INIT_OP(zeros) .describe(R"code(Fill target with zeros @@ -341,7 +345,7 @@ NNVM_REGISTER_INIT_OP(zeros) .add_arguments(InitOpParam::__FIELDS__()) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) -.set_support_level(1); +.set_support_level(4); NNVM_REGISTER_INIT_OP(ones) .describe(R"code(Fill target with ones @@ -353,7 +357,7 @@ NNVM_REGISTER_INIT_OP(ones) .add_arguments(InitOpParam::__FIELDS__()) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) -.set_support_level(1); +.set_support_level(4); // full_like NNVM_REGISTER_INIT_LIKE_OP(full_like) @@ -364,21 +368,21 @@ as the input array .add_arguments(FillValueParam::__FIELDS__()) .set_attr_parser(ParamParser<FillValueParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FillValueParam>) -.set_support_level(1); +.set_support_level(4); NNVM_REGISTER_INIT_LIKE_OP(zeros_like) .describe(R"code(Return an array of zeros with the same shape and type as the input array. )code") -.set_support_level(1); +.set_support_level(4); NNVM_REGISTER_INIT_LIKE_OP(ones_like) .describe(R"code(Return an array of ones with the same shape and type as the input array. 
)code") -.set_support_level(1); +.set_support_level(4); // unary scalar op DMLC_REGISTER_PARAMETER(ScalarParam); @@ -415,7 +419,8 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__) .set_attr<FGradient>( "FGradient", [](const NodePtr& n, const std::vector<NodeEntry>& ograds){ - return std::vector<NodeEntry>{ograds[0]}; + return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0", + {ograds[0]}) }; }); NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__) @@ -601,10 +606,11 @@ NNVM_REGISTER_ELEMWISE_REDUCE_OP(elemwise_sum) CHECK_EQ(ograds.size(), 1); std::vector<NodeEntry> ret; for (size_t i = 0; i < n->inputs.size(); i++) { - ret.push_back(ograds[0]); + ret.push_back(MakeNode("copy", n->attrs.name + "_grad_0", {ograds[0]})); } return ret; - }); + }) +.set_support_level(4); NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad) .describe(R"code(Blocks gradient computation for input. @@ -614,7 +620,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad) "FInplaceIdentity", [](const NodeAttrs& attrs){ return std::vector<bool>{true}; }) -.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes); +.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) +.set_support_level(4); DMLC_REGISTER_PARAMETER(IndicatorParam); @@ -628,7 +635,7 @@ with 1.0 if (left > right), otherwise 0.0 element-wise. .add_argument("rhs", "Tensor", "Second input") .set_num_inputs(2) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) -.set_support_level(1); +.set_support_level(4); NNVM_REGISTER_INDICATOR_OP(less) @@ -640,7 +647,7 @@ with 1.0 if (left < right), otherwise 0.0 element-wise. .add_argument("rhs", "Tensor", "Second input") .set_num_inputs(2) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) -.set_support_level(1); +.set_support_level(4); NNVM_REGISTER_INDICATOR_OP(_max_mask) .describe(R"code(Function that returns a mask tensor @@ -668,5 +675,73 @@ with 1.0 if the value is minimum over given axes, otherwise 0.0 element-wise. .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_support_level(1); + +DMLC_REGISTER_PARAMETER(ClipParam); + +NNVM_REGISTER_OP(clip) +.describe(R"doc(Clips (limits) the values in an array. +Given an interval, values outside the interval are clipped to the interval edges. +Clipping ``x`` between `a_min` and `a_x` would be:: + clip(x, a_min, a_max) = max(min(x, a_max), a_min)) +Example:: + x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + clip(x,1,8) = [ 1., 1., 2., 3., 4., 5., 6., 7., 8., 8.] 
+)doc" NNVM_ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser<ClipParam>) +.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>) +.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) +.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) +.set_attr<FTVMCompute>( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array<Tensor>& inputs, + const Array<Tensor>& out_info) { + const ClipParam params = get<ClipParam>(attrs.parsed); + return Array<Tensor>{ + topi::clip(inputs[0], tvm::make_const(tvm::Float(32), params.a_min), + tvm::make_const(tvm::Float(32), params.a_max)) }; + }) +.add_argument("data", "NDArray-or-Symbol", "Input array.") +.add_arguments(ClipParam::__FIELDS__()) +.set_attr<nnvm::FGradient>( + "FGradient", [](const NodePtr& n, + const std::vector<NodeEntry>& ograds){ + // y = clip(x, a_min, a_max) + // min_mask = greater_equal(x, a_min*ones_like(x)) + // => ones_like(x) - less(x, a_min) + // max_mask = less_equal(x, a_max*ones_like(x)) + // => ones_like(x) - greater(x, a_max) + // grad_x = min_mask * max_mask * grad_y + CHECK_EQ(ograds.size(), 1); + + NodeEntry sub0 = MakeNode("ones_like", n->attrs.name + "_grad_sub_0", + {n->inputs[0]}); + // min_mask + NodeEntry sub1 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_1", + {sub0}, {{"scalar", n->attrs.dict["a_min"]}}); + NodeEntry sub2 = MakeNode("less", n->attrs.name + "_grad_sub_2", + {n->inputs[0], sub1}); + NodeEntry sub3 = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub_3", + {sub0, sub2}); + + // max_mask + NodeEntry sub4 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_4", + {sub0}, {{"scalar", n->attrs.dict["a_max"]}}); + NodeEntry sub5 = MakeNode("greater", n->attrs.name + "_grad_sub_5", + {n->inputs[0], sub4}); + NodeEntry sub6 = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub_6", + {sub0, sub5}); + + // min_mask * max_mask + NodeEntry sub7 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_7", + {sub3, sub6}); + return std::vector<NodeEntry>{ + MakeNode("elemwise_mul", n->attrs.name + "_grad", + {sub7, ograds[0]}) + }; + }) +.set_support_level(4); + } // namespace top } // namespace nnvm diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc index 8eac2449b27187eae64e39f15e8984b2eb4cb804..84a7dd0f0e12f6de7245cdf0824e65714e46ea97 100644 --- a/nnvm/src/top/tensor/reduce.cc +++ b/nnvm/src/top/tensor/reduce.cc @@ -137,7 +137,20 @@ Example:: const Array<Tensor>& inputs, const Array<Tensor>& out_info) { const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed); - auto axis = ShapeToArray(param.axis); + Array<Expr> axis; + if (param.exclude) { + std::set<dim_t> exclude_axis; + for (dim_t i = 0; i < param.axis.ndim(); ++i) { + exclude_axis.insert(param.axis[i]); + } + for (dim_t i = 0; i < inputs[0].ndim(); ++i) { + if (exclude_axis.count(i) == 0) { + axis.push_back(make_const(Int(32), i)); + } + } + } else { + axis = ShapeToArray(param.axis); + } return Array<Tensor>{ topi::sum(inputs[0], axis, param.keepdims) }; }) @@ -150,7 +163,6 @@ Example:: MakeNode("expand_like", n->attrs.name + "_grad", {ograds[0], n->inputs[0]}, {{"axis", axis.str()}, - {"keepdims", std::to_string(param.keepdims)}, {"exclude", std::to_string(param.exclude)}}) }; }); diff --git a/nnvm/src/top/tensor/state_op.cc b/nnvm/src/top/tensor/state_op.cc index f275cf309511b79c8a74b4ed1eadad001136697b..ebce07696fe4420de91c6bfe59b057621706d833 100644 --- a/nnvm/src/top/tensor/state_op.cc +++ b/nnvm/src/top/tensor/state_op.cc @@ -48,6 +48,15 @@ 
This is an experimental operator. .set_attr<FInplaceOption>( "FInplaceOption", [](const NodeAttrs& attrs) { return std::vector<std::pair<int, int> >{{1, 0}}; +}) +.set_attr<FGradient>( + "FGradient", [](const NodePtr& n, + const std::vector<NodeEntry>& ograds){ + return std::vector<NodeEntry>{ + MakeNode("zeros_like", n->attrs.name + "_zero_grad", + {n->inputs[0]}), + ograds[0] + }; }); } // namespace top diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index dcdedbb867f90cfe9dd7a918093870f63d9455fa..2457747341c933f7c7037c04b158748a44ec2589 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -229,29 +229,24 @@ will return a new array with shape ``(2,5,3,4)``. NNVM_REGISTER_OP(expand_like) .describe(R"code(Expand an input array with the shape of second array. - This operation can always be composed of unsqueezing and expanding dims. - Examples:: input = [ 12. 19. 27.] input.shape = (3,) - new_shape_array = [[[1,2],[2,3],[1,3]], [[1,4],[4,3],[5,2]], [[7,1],[7,2],[7,3]]] new_shape_array.shape = (3, 3, 2) - expand_like(input, [1,2], new_shape_array) = [[[12,12],[12,12],[12,12]], [[19,19],[19,19],[19,19]], [[27,27],[27,27],[27,27]]] - )code" NNVM_ADD_FILELINE) .add_argument("input", "Tensor", "Source input") .add_argument("shape_like", "Tensor", "Input with new shape") -.add_arguments(ReduceParam::__FIELDS__()) -.set_attr_parser(ParamParser<ReduceParam>) -.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) +.add_arguments(IndicatorParam::__FIELDS__()) +.set_attr_parser(ParamParser<IndicatorParam>) +.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<IndicatorParam>) .set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) .set_num_inputs(2) @@ -259,7 +254,7 @@ Examples:: .set_attr<FGradient>( "FGradient", [](const NodePtr& n, const std::vector<NodeEntry>& ograds) { - const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed); + const IndicatorParam& param = nnvm::get<IndicatorParam>(n->attrs.parsed); std::ostringstream axis; axis << param.axis; @@ -267,11 +262,11 @@ Examples:: MakeNode("sum", n->attrs.name + "_grad", {ograds[0]}, {{"axis", axis.str()}, - {"keepdims", std::to_string(param.keepdims)}, - {"exclude", std::to_string(param.exclude)}}) + {"exclude", std::to_string(param.exclude)}}), + MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]}) }; -}) -.set_support_level(1); + }) + .set_support_level(4); // split DMLC_REGISTER_PARAMETER(SplitParam); @@ -564,13 +559,10 @@ The significance of each is explained below: NNVM_REGISTER_OP(reshape_like) .describe(R"code(Reshapes the input array by the size of another array. - For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes the input array into an output array with the same shape as the second input array. - .. note:: Sizes for both array should be compatible. 
- )code" NNVM_ADD_FILELINE) .add_argument("data", "Tensor", "Input data.") .add_argument("shape_like", "Tensor", "Input data.") @@ -589,10 +581,12 @@ the input array into an output array with the same shape as the second input arr .set_attr<FGradient>( "FGradient", [](const NodePtr& n, const std::vector<NodeEntry>& ograds) { - return MakeGradNode("reshape_like", n, - {ograds[0], n->inputs[0]}); + return std::vector<NodeEntry>{ + MakeNode("reshape_like", n->attrs.name + "_grad", {ograds[0], n->inputs[0]}), + MakeNode("zeros_like", n->attrs.name + "_zero_grad", { n->inputs[1]}) + }; }) -.set_support_level(3); +.set_support_level(4); // squeeze DMLC_REGISTER_PARAMETER(SqueezeParam); @@ -680,7 +674,8 @@ Examples:: "FGradient", [](const NodePtr& n, const std::vector<NodeEntry>& ograds) { return std::vector<NodeEntry>{ - MakeNode("reshape_like", n->attrs.name + "_grad", {n->inputs[0]}) + MakeNode("reshape_like", n->attrs.name + "_grad", + {ograds[0], n->inputs[0]}) }; }) .set_support_level(1); diff --git a/nnvm/tests/python/compiler/test_optimizer.py b/nnvm/tests/python/compiler/test_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fd620271d8615cd1d11ed92e713a91502f10d70a --- /dev/null +++ b/nnvm/tests/python/compiler/test_optimizer.py @@ -0,0 +1,118 @@ +import numpy as np +import tvm +import nnvm +import nnvm.compiler.optimizer as optimizer +import nnvm.compiler.lr_scheduler as lr_scheduler + +from nnvm.testing.config import ctx_list +from tvm.contrib import graph_runtime + + +def helper(symbol, inputs, params, update_func, run_times, target, ctx, dtype="float32"): + ishapes = {} + np_inputs = {} + params_dict = {} + for (name, shape, s) in inputs: + ishapes.update({name: shape}) + np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)}) + for (name, shape, s) in params: + np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)}) + params_dict.update({name: np_inputs[name]}) + + graph, lib, rt_params = nnvm.compiler.build(symbol, target, shape=ishapes) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**np_inputs) + m.set_input(**rt_params) + for _ in range(run_times): + m.run() + y_np = update_func(**np_inputs) + out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + + +def test_sgd(): + for target, ctx in ctx_list(): + data = nnvm.sym.Variable("data") + weight = nnvm.sym.Variable("weight") + out = nnvm.sym.elemwise_mul(data, weight ** 2) + + dshape = (1, 2, 3) + wshape = dshape + + base_lr = 0.1 + lr_factor = 0.5 + rescale_grad = 0.2 + wd = 0.1 + clip_gradient = 0.25 + + scheduler = lr_scheduler.FactorScheduler(base_lr=base_lr, step=1, factor=lr_factor) + opt = optimizer.SGD(learning_rate=base_lr, lr_scheduler=scheduler, + rescale_grad=rescale_grad, clip_gradient=clip_gradient, + wd=wd) + opt_sym = opt.minimize(out, var=weight) + + inputs = [("data", dshape, data)] + params = [("weight", wshape, weight)] + + def update_func(data, weight): + gradient_0 = data * 2 * weight * rescale_grad + gradient_0 = np.clip(gradient_0, -clip_gradient, clip_gradient) + weight_0 = weight - base_lr * lr_factor * (gradient_0 + wd * weight) + gradient_1 = data * 2 * weight_0 * rescale_grad + gradient_1 = np.clip(gradient_1, -clip_gradient, clip_gradient) + weight_1 = weight_0 - base_lr * (lr_factor ** 2) * (gradient_1 + wd * weight_0) + return weight_1 + + helper(opt_sym, inputs, params, update_func, 2, target, ctx) + + + +def test_adam(): + for target, ctx in ctx_list(): + 
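As a worked example of the learning-rate schedule these tests rely on (a plain-Python mirror of ``FactorScheduler.__call__`` with the test's settings, illustrative only)::

    # lr(t) = clip(base_lr * factor ** (t / step), a_min=stop_factor_lr, a_max=base_lr)
    base_lr, factor, step, stop_factor_lr = 0.1, 0.5, 1, 1e-8
    for t in range(1, 4):
        lr = min(max(base_lr * factor ** (t / step), stop_factor_lr), base_lr)
        print(t, lr)    # 1 0.05, 2 0.025, 3 0.0125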
data = nnvm.sym.Variable("data") + weight = nnvm.sym.Variable("weight") + out = nnvm.sym.elemwise_mul(data, weight ** 2) + + dshape = (1, 2, 3) + wshape = dshape + + base_lr = 0.1 + beta1 = 0.9 + beta2 = 0.999 + epsilon = 1e-8 + lr_factor = 0.5 + rescale_grad = 0.2 + wd = 0.1 + clip_gradient = 0.25 + + scheduler = lr_scheduler.FactorScheduler(base_lr=base_lr, step=1, factor=lr_factor) + opt = optimizer.Adam(learning_rate=base_lr, beta1=beta1, beta2=beta2, epsilon=epsilon, + lr_scheduler=scheduler, rescale_grad=rescale_grad, + clip_gradient=clip_gradient, wd=wd) + opt_sym = opt.minimize(out, var=weight) + + inputs = [("data", dshape, data)] + params = [("weight", wshape, weight)] + + def update_func(data, weight): + rate_0 = np.sqrt(1 - beta2) / (1 - beta1) + lr_0 = base_lr * lr_factor * rate_0 + gradient_0 = data * 2 * weight * rescale_grad + gradient_0 = np.clip(gradient_0, -clip_gradient, clip_gradient) + m_0 = (1 - beta1) * gradient_0 + v_0 = (1 - beta2) * (gradient_0 ** 2) + weight_0 = weight - lr_0 * (m_0 / (np.sqrt(v_0) + epsilon) + wd * weight) + rate_1 = np.sqrt(1 - beta2 ** 2) / (1 - beta1 ** 2) + lr_1 = base_lr * (lr_factor ** 2) * rate_1 + gradient_1 = data * 2 * weight_0 * rescale_grad + gradient_1 = np.clip(gradient_1, -clip_gradient, clip_gradient) + m_1 = beta1 * m_0 + (1 - beta1) * gradient_1 + v_1 = beta2 * v_0 + (1 - beta2) * (gradient_1 ** 2) + weight_1 = weight_0 - lr_1 * (m_1 / (np.sqrt(v_1) + epsilon) + wd * weight_0) + return weight_1 + + helper(opt_sym, inputs, params, update_func, 2, target, ctx) + +if __name__ == "__main__": + test_sgd() + test_adam() diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index abe30d1ec8d026c4e84891caf65ac322453f57c1..480aa271af8f0f317593a1dc5c836038f66f23b9 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -8,15 +8,14 @@ from nnvm.testing.config import ctx_list def helper(symbol, inputs, dtype, - np_forward, np_backward=None): + np_forward, np_backward=None, need_input=True, need_head_grads=True): ishapes = {} input_syms = [] np_inputs = {} - for (k, v) in inputs.items(): - ishapes.update({k: v[0]}) - np_inputs.update({k: np.random.uniform(size=v[0]).astype(dtype)}) - if len(v) > 1: - input_syms.append(v[1]) + for (name, shape, s) in inputs: + ishapes.update({name: shape}) + np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)}) + input_syms.append(s) for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes) @@ -25,23 +24,26 @@ def helper(symbol, inputs, dtype, y_np = np_forward(**np_inputs) out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype)) np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) - # backward if np_backward: graph._set_symbol_list_attr("grad_ys", symbol) - for x in input_syms: - graph._set_symbol_list_attr("grad_xs", x) - graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads")) + graph._set_symbol_list_attr("grad_xs", input_syms) + graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads", shape=y_np.shape)) graph = graph.apply("Gradient") ishapes.update({"head_grads": y_np.shape}) graph, lib, _ = nnvm.compiler.build(graph, target, ishapes) m = graph_runtime.create(graph, lib, ctx) head_grads = np.random.uniform(size=y_np.shape).astype(dtype) - y_np = head_grads * np_backward(**np_inputs) - m.run(head_grads=head_grads, **np_inputs) - out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype)) - - 
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + y_np = np_backward(head_grads=head_grads, **np_inputs) + b_inputs = {} + if need_input: + b_inputs.update(np_inputs) + if need_head_grads: + b_inputs.update({"head_grads":head_grads}) + m.run(**b_inputs) + for i in range(len(y_np)): + out = m.get_output(i, tvm.nd.empty(y_np[i].shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), y_np[i], atol=1e-5, rtol=1e-5) def test_relu(): @@ -52,10 +54,15 @@ def test_relu(): x = (x < 0) * x * 0.3 + (x > 0) * x - 0.2 return (x > 0) * x + def backward(head_grads, x): + sub = (x < 0) * x * 0.3 + (x > 0) * x - 0.2 + return [(sub > 0).astype("float") * \ + ((x > 0).astype("float") + 0.3 * (x < 0).astype("float")) * head_grads] + dtype = "float32" dshape = (1, 3, 32, 32) - inputs = {'x': (dshape, x)} - helper(y, inputs, dtype, forward) + inputs = [('x', dshape, x)] + helper(y, inputs, dtype, forward, backward) def test_sym_scalar_pow(): @@ -66,12 +73,12 @@ def test_sym_scalar_pow(): def forward(x): return x**scalar - def backward(x): - return scalar * x**(scalar - 1) + def backward(head_grads, x): + return [scalar * x**(scalar - 1) * head_grads] dtype = "float32" dshape = (1, 3, 32, 32) - inputs = {'x': (dshape, x)} + inputs = [('x', dshape, x)] helper(y, inputs, dtype, forward, backward) @@ -83,12 +90,12 @@ def test_scalar_sym_pow(): def forward(x): return scalar**x - def backward(x): - return np.log(scalar) * scalar**x + def backward(head_grads, x): + return [np.log(scalar) * scalar**x * head_grads] dtype = "float32" dshape = (1, 3, 32, 32) - inputs = {'x': (dshape, x)} + inputs = [('x', dshape, x)] helper(y, inputs, dtype, forward, backward) @@ -99,12 +106,12 @@ def test_exp(): def forward(x): return np.exp(x) - def backward(x): - return np.exp(x) + def backward(head_grads, x): + return [np.exp(x) * head_grads] dtype = "float32" dshape = (1, 3, 32, 32) - inputs = {'x': (dshape, x)} + inputs = [('x', dshape, x)] helper(y, inputs, dtype, forward, backward) @@ -115,12 +122,12 @@ def test_log(): def forward(x): return np.log(x) - def backward(x): - return 1. / x + def backward(head_grads, x): + return [1. 
/ x * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -131,13 +138,13 @@ def test_tanh():
     def forward(x):
         return np.sinh(x) / np.cosh(x)
 
-    def backward(x):
+    def backward(head_grads, x):
         y_np = forward(x)
-        return (1 - y_np**2)
+        return [(1 - y_np**2) * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -148,13 +155,13 @@ def test_sigmoid():
     def forward(x):
         return 1.0 / (1.0 + np.exp(-x))
 
-    def backward(x):
+    def backward(head_grads, x):
         y_np = forward(x)
-        return y_np *(1 - y_np)
+        return [y_np * (1 - y_np) * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -165,10 +172,15 @@ def test_softmax():
     def forward(x):
         return topi.testing.softmax_python(x)
 
+    def backward(head_grads, x):
+        y = topi.testing.softmax_python(x)
+        grad = y * (head_grads - np.sum(y * head_grads, axis=1, keepdims=True))
+        return [grad]
+
     dtype = "float32"
     dshape = (10, 1000)
-    inputs = {'x': (dshape, x)}
-    helper(y, inputs, dtype, forward)
+    inputs = [('x', dshape, x)]
+    helper(y, inputs, dtype, forward, backward)
 
 
 def test_log_softmax():
@@ -178,26 +190,32 @@
     def forward(x):
         return topi.testing.log_softmax_python(x)
 
+    def backward(head_grads, x):
+        y = topi.testing.log_softmax_python(x)
+        grad = head_grads - np.sum(y * head_grads, axis=1, keepdims=True)
+        return [grad]
+
     dtype = "float32"
     dshape = (10, 1000)
-    inputs = {'x': (dshape, x)}
-    helper(y, inputs, dtype, forward)
+    inputs = [('x', dshape, x)]
+    helper(y, inputs, dtype, forward, backward)
 
 
 def test_dense():
-    x = sym.Variable("x")
-    y = sym.dense(x, units=3, name="dense")
+    x = sym.Variable("x", shape=(10, 100))
+    w = sym.Variable("dense_weight", shape=(3, 100))
+    b = sym.Variable("dense_bias", shape=(3,))
+    y = sym.dense(x, w, b, use_bias=True, units=3, name="dense")
     y = sym.flatten(y)
 
     def forward(x, dense_weight, dense_bias):
         return np.dot(x, dense_weight.T) + dense_bias
-
     dtype = "float32"
-    inputs = {
-        'x': ((10, 100), x),
-        'dense_weight': ((3, 100),),
-        'dense_bias': ((3,),)
-    }
+    inputs = [
+        ('x', (10, 100), x),
+        ('dense_weight', (3, 100), w),
+        ('dense_bias', (3,), b)
+    ]
     helper(y, inputs, dtype, forward)
 
 
@@ -215,13 +233,13 @@ def test_batchnorm():
         return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta
 
     dtype = "float32"
-    inputs = {
-        'x': ((10, 20), x),
-        'gamma': ((20,),),
-        'beta': ((20,),),
-        'moving_mean': ((20,),),
-        'moving_var': ((20,),)
-    }
+    inputs = [
+        ('x', (10, 20), x),
+        ('gamma', (20,), gamma),
+        ('beta', (20,), beta),
+        ('moving_mean', (20,), moving_mean),
+        ('moving_var', (20,), moving_var)
+    ]
     helper(y, inputs, dtype, forward)
 
 
@@ -283,9 +301,12 @@ def verify_squeeze(dshape, axis):
     def forward(x):
         return np.squeeze(x, axis=axis) + 1
 
+    def backward(head_grads, x):
+        return [np.reshape(head_grads, x.shape)]
+
     dtype = "float32"
-    inputs = {'x': (dshape, x)}
-    helper(y, inputs, dtype, forward)
+    inputs = [('x', dshape, x)]
+    helper(y, inputs, dtype, forward, backward)
 
 
 def test_squeeze():
@@ -304,7 +325,7 @@ def test_pad():
                       mode='constant', constant_values=1.)
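A quick NumPy check of the softmax backward rule used in the tests above (illustrative only)::

    import numpy as np

    # dx = y * (dy - sum(y * dy, axis=1, keepdims=True)), with y = softmax(x)
    x = np.random.randn(2, 5).astype("float32")
    dy = np.random.randn(2, 5).astype("float32")
    y = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
    dx = y * (dy - np.sum(y * dy, axis=1, keepdims=True))
    # each row of dx sums to ~0 because the softmax outputs sum to one
    print(np.allclose(dx.sum(axis=1), 0.0, atol=1e-6))   # True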
dtype = "float32" - inputs = {'x': ((1, 3, 28, 28), x)} + inputs = [('x', (1, 3, 28, 28), x)] helper(y, inputs, dtype, forward) diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index ad09d73ec28dadeba199eed8c087ea32b9d35572..c6e8620fc972ddac6d9aa19410de290bcbf81284 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -6,6 +6,46 @@ import nnvm.symbol as sym import nnvm.compiler from nnvm.testing.config import ctx_list + +def helper(symbol, inputs, dtype, + np_forward, np_backward=None, need_input=True, need_head_grads=True): + ishapes = {} + input_syms = [] + np_inputs = {} + for (name, shape, s) in inputs: + ishapes.update({name: shape}) + np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)}) + input_syms.append(s) + + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes) + m = graph_runtime.create(graph, lib, ctx) + m.run(**np_inputs) + y_np = np_forward(**np_inputs) + out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + # backward + if np_backward: + graph._set_symbol_list_attr("grad_ys", symbol) + graph._set_symbol_list_attr("grad_xs", input_syms) + graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads", shape=y_np.shape)) + graph = graph.apply("Gradient") + ishapes.update({"head_grads": y_np.shape}) + graph, lib, _ = nnvm.compiler.build(graph, target, ishapes) + m = graph_runtime.create(graph, lib, ctx) + head_grads = np.random.uniform(size=y_np.shape).astype(dtype) + y_np = np_backward(head_grads=head_grads, **np_inputs) + b_inputs = {} + if need_input: + b_inputs.update(np_inputs) + if need_head_grads: + b_inputs.update({"head_grads":head_grads}) + m.run(**b_inputs) + for i in range(len(y_np)): + out = m.get_output(i, tvm.nd.empty(y_np[i].shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), y_np[i], atol=1e-5, rtol=1e-5) + + def verify_transpose(dshape, axes): x = sym.Variable("x") if axes: @@ -66,13 +106,245 @@ def verify_reshape(dshape, oshape): out = m.get_output(0, tvm.nd.empty(out_np.shape)) np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) + def test_reshape(): verify_reshape((2, 3, 4), (-1, 2, 1)) verify_reshape((2, 3, 4), (8, 3)) verify_reshape((4, 7), (2, 7, 2)) + +def test_clip(): + x = sym.Variable("x") + a_min=0.2 + a_max=0.75 + y = sym.clip(x, a_min=a_min, a_max=a_max) + + def forward(x): + return np.clip(x, a_min=a_min, a_max=a_max) + + def backward(head_grads, x): + mask1 = np.greater_equal(x, a_min).astype("float") + mask2 = np.less_equal(x, a_max).astype("float") + return [head_grads * mask1 * mask2] + + + dtype = "float32" + inputs = [('x', (3, 4, 5), x)] + helper(y, inputs, dtype, forward, backward) + + +def test_greater(): + l = sym.Variable("l") + r = sym.Variable("r") + y = sym.greater(l, r) + + def forward(l, r): + return np.greater(l, r).astype("float32") + + def backward(head_grads, l, r): + return [np.zeros_like(l)] + + + dtype = "float32" + inputs = [('l', (3, 4, 5), l), + ('r', (3, 4, 5), r)] + helper(y, inputs, dtype, forward, backward, need_head_grads=False) + + +def test_less(): + l = sym.Variable("l") + r = sym.Variable("r") + y = sym.less(l, r) + + def forward(l, r): + return np.less(l, r).astype("float32") + + def backward(head_grads, l, r): + return [np.zeros_like(l)] + + + dtype = "float32" + inputs = [('l', (3, 4, 5), l), + ('r', (3, 4, 5), r)] + 
helper(y, inputs, dtype, forward, backward, need_head_grads=False) + + +def test_reshape_like(): + x = sym.Variable("x") + y = sym.Variable("y") + z = sym.reshape_like(x, y) + + def forward(x, y): + return np.reshape(x, y.shape) + + def backward(head_grads, x, y): + return [np.reshape(head_grads, x.shape), + np.zeros_like(y)] + + + dtype = "float32" + inputs = [('x', (3, 4, 5), x), + ('y', (5, 4, 3), y)] + helper(z, inputs, dtype, forward, backward) + + +def verify_expand_like(in_shape, out_shape, axis, exclude): + x = sym.Variable("x") + y = sym.Variable("y") + z = sym.expand_like(x, y, axis=axis, exclude=exclude) + + def forward(x, y): + odim = len(out_shape) + real_axis = [i if i >= 0 else i + odim for i in axis] + real_axis = sorted(real_axis) + if exclude: + real_axis = list(set(range(odim)) - set(real_axis)) + for i in real_axis: + x = np.expand_dims(x, i).astype(x.dtype) + for i in real_axis: + x = np.concatenate([x]*out_shape[i], axis=i).astype(x.dtype) + + return x + + def backward(head_grads, x, y): + odim = len(out_shape) + real_axis = [i if i >= 0 else i + odim for i in axis] + real_axis = sorted(real_axis) + if exclude: + real_axis = list(set(range(odim)) - set(real_axis)) + return [np.sum(head_grads, axis=tuple(real_axis)), + np.zeros_like(y)] + + + dtype = "float32" + inputs = [('x', in_shape, x), + ('y', out_shape, y)] + helper(z, inputs, dtype, forward, backward, need_input=False) + + +def test_expand_like(): + verify_expand_like((3,), (3, 2), [1], False) + verify_expand_like((2,), (2, 3), [1], False) + verify_expand_like((3, 4), (3, 5, 4), [1], False) + verify_expand_like((5, 7), (5, 6, 7, 8), [0, 2], True) + + +def verify_elemwise_sum(num_args): + s = [sym.Variable("input" + str(i)) for i in range(num_args)] + y = sym.elemwise_sum(*s, num_args=num_args) + + def forward(**inputs): + return np.sum(np.array(list(inputs.values())), axis=0) + + def backward(head_grads, **inputs): + return [head_grads] * num_args + + dtype = "float32" + inputs = [("input" + str(i), (3, 4, 5), s[i]) + for i in range(num_args)] + helper(y, inputs, dtype, forward, backward, need_input=False) + + +def test_elemwise_sum(): + verify_elemwise_sum(1) + verify_elemwise_sum(5) + verify_elemwise_sum(7) + + +def test_block_grad(): + x = sym.Variable("x") + y = sym.block_grad(x) + + def forward(x): + return x + + def backward(head_grads, x): + return [np.zeros_like(head_grads)] + + + dtype = "float32" + inputs = [('x', (3, 4, 5), x)] + helper(y, inputs, dtype, forward, backward, need_head_grads=False) + + +def test_full(): + shape = (3, 4, 5) + value = 7 + dtype = "float32" + for target, ctx in ctx_list(): + data = sym.Variable("data", dtype=dtype) + # full_like + s = sym.full_like(data=data, fill_value=value, name="s") + graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(data=np.random.uniform(size=shape).astype(dtype)) + out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype)) + np.testing.assert_allclose( + out.asnumpy(), + np.full(shape, fill_value=value, dtype=dtype), + atol=1e-5, rtol=1e-5) + # ones_like + s = sym.ones_like(data=data, fill_value=value, name="s") + graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(data=np.random.uniform(size=shape).astype(dtype)) + out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype)) + np.testing.assert_allclose( + out.asnumpy(), + np.full(shape, fill_value=1, dtype=dtype), + atol=1e-5, rtol=1e-5) + # zeros_like + s = 
sym.zeros_like(data=data, fill_value=value, name="s") + graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(data=np.random.uniform(size=shape).astype(dtype)) + out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype)) + np.testing.assert_allclose( + out.asnumpy(), + np.full(shape, fill_value=0, dtype=dtype), + atol=1e-5, rtol=1e-5) + # full + s = sym.full(shape=shape, dtype=dtype, fill_value=value, name="s") + graph, lib, _ = nnvm.compiler.build(s, target) + m = graph_runtime.create(graph, lib, ctx) + m.run() + out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype)) + np.testing.assert_allclose( + out.asnumpy(), + np.full(shape, fill_value=value, dtype=dtype), + atol=1e-5, rtol=1e-5) + # ones + s = sym.ones(shape=shape, dtype=dtype, name="s") + graph, lib, _ = nnvm.compiler.build(s, target) + m = graph_runtime.create(graph, lib, ctx) + m.run() + out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype)) + np.testing.assert_allclose( + out.asnumpy(), + np.full(shape, fill_value=1, dtype=dtype), + atol=1e-5, rtol=1e-5) + # zeros + s = sym.zeros(shape=shape, dtype=dtype, name="s") + graph, lib, _ = nnvm.compiler.build(s, target) + m = graph_runtime.create(graph, lib, ctx) + m.run() + out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype)) + np.testing.assert_allclose( + out.asnumpy(), + np.full(shape, fill_value=0, dtype=dtype), + atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_reshape() test_reduce() test_tranpose() + test_clip() + test_greater() + test_less() + test_reshape_like() + test_expand_like() + test_elemwise_sum() + test_block_grad() + test_full() print(nnvm.compiler.engine.dump())
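Finally, a minimal end-to-end sketch of the new ``clip`` operator, mirroring ``test_clip`` above (assumes an llvm-enabled TVM build)::

    import numpy as np
    import tvm
    import nnvm.symbol as sym
    import nnvm.compiler
    from tvm.contrib import graph_runtime

    x = sym.Variable("x")
    y = sym.clip(x, a_min=0.2, a_max=0.75)
    shape = (3, 4, 5)
    graph, lib, _ = nnvm.compiler.build(y, "llvm", shape={"x": shape})
    m = graph_runtime.create(graph, lib, tvm.cpu())
    data = np.random.uniform(size=shape).astype("float32")
    m.run(x=data)
    out = m.get_output(0, tvm.nd.empty(shape)).asnumpy()
    np.testing.assert_allclose(out, np.clip(data, 0.2, 0.75), rtol=1e-5)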