diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 4351fea6b45934e8d51a97c6c21fb5a8586cbdfb..0da9b81269aa5dc69fa8ca453c1a39a33df957c6 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -236,18 +236,14 @@ class GraphRuntimeCodegen(ExprFunctor): self.lowered_funcs.add(loweredf) inputs = [] - tuple_arg_count = 0 + # flatten tuple in the call. for arg in call.args: + res = self.visit(arg) if isinstance(arg.checked_type, TupleType): - tuple_arg_count += 1 - inputs.append(self.visit(arg)) - # We need to specially handle tuple inputs and - # tuple output cases. - # Tuple input function(e.g. concat) - if tuple_arg_count: - assert len(call.args) == 1 - assert isinstance(inputs[0], tuple) - inputs = list(inputs[0]) + assert isinstance(res, tuple) + inputs += res + else: + inputs.append(res) inputs = [x.to_json() for x in inputs] op_name = cached_func.func_name diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 9d1bd0deffa9270a77fbb5295981a138b349f5cb..b0b1e700987c2fdae0c2f3c8406187e3a55481ea 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -589,11 +589,11 @@ def from_mxnet(symbol, shape, dtype = _update_shape_dtype(shape, dtype, params) sym = _from_mxnet_impl(symbol, shape, dtype) elif isinstance(symbol, mx.gluon.HybridBlock): - if args_params is not None or aux_params is not None: + if arg_params is not None or aux_params is not None: raise ValueError("arg_params and aux_params ae not used when importing HybridBlock") params = {} for k, v in symbol.collect_params().items(): - params[k] = tvm.nd.array(v.data().asnumpy()) + params[k] = _nd.array(v.data().asnumpy()) data = mx.sym.Variable("data") sym = symbol(data) shape, dtype = _update_shape_dtype(shape, dtype, params) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 4832a195f9e8dad3733e5f43e83226e7cd3a34cb..774e091baefce27b8dc935567c3458070670f8ce 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -5,223 +5,37 @@ import topi from .op import register_compute, register_schedule, register_pattern from .op import schedule_injective, OpPattern + schedule_broadcast = schedule_injective schedule_elemwise = schedule_injective -# log -@register_compute("log") -def log_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.log(inputs[0])] - register_schedule("log", schedule_broadcast) - -# exp -@register_compute("exp") -def exp_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.exp(inputs[0])] - register_schedule("exp", schedule_broadcast) - -# sqrt -@register_compute("sqrt") -def sqrt_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.sqrt(inputs[0])] - register_schedule("sqrt", schedule_broadcast) - -# sigmoid -@register_compute("sigmoid") -def sigmoid_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.sigmoid(inputs[0])] - register_schedule("sigmoid", schedule_broadcast) - -# floor -@register_compute("floor") -def floor_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.floor(inputs[0])] - register_schedule("floor", schedule_broadcast) - -# ceil -@register_compute("ceil") -def ceil_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.ceil(inputs[0])] - register_schedule("ceil", schedule_broadcast) - -# trunc -@register_compute("trunc") -def trunc_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.trunc(inputs[0])] - register_schedule("trunc", schedule_broadcast) - -# round -@register_compute("round") -def round_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.round(inputs[0])] - register_schedule("round", schedule_broadcast) - -# abs -@register_compute("abs") -def abs_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.abs(inputs[0])] - register_schedule("abs", schedule_broadcast) - -# tanh -@register_compute("tanh") -def tanh_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.tanh(inputs[0])] - register_schedule("tanh", schedule_broadcast) - -# negative -@register_compute("negative") -def negative_compute(attrs, inputs, output_type, target): - assert len(inputs) == 1 - return [topi.negative(inputs[0])] - register_schedule("negative", schedule_broadcast) -# add -@register_compute("add") -def add_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.add(inputs[0], inputs[1])] - -register_schedule("add", schedule_injective) - -# subtract -@register_compute("subtract") -def subtract_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.subtract(inputs[0], inputs[1])] - +register_schedule("add", schedule_broadcast) register_schedule("subtract", schedule_broadcast) - -# multiply -@register_compute("multiply") -def multiply_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.multiply(inputs[0], inputs[1])] - register_schedule("multiply", schedule_broadcast) - -# divide -@register_compute("divide") -def divide_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.divide(inputs[0], inputs[1])] - register_schedule("divide", schedule_broadcast) - -# power -@register_compute("power") -def power_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.power(inputs[0], inputs[1])] - register_schedule("power", schedule_injective) - -# mod -@register_compute("mod") -def mod_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.mod(inputs[0], inputs[1])] - register_schedule("mod", schedule_broadcast) - -# equal -@register_compute("equal") -def equal_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.equal(inputs[0], inputs[1])] - register_schedule("equal", schedule_broadcast) - -# not_equal -@register_compute("not_equal") -def not_equal_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.not_equal(inputs[0], inputs[1])] - register_schedule("not_equal", schedule_broadcast) - -# less -@register_compute("less") -def less_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.less(inputs[0], inputs[1])] - register_schedule("less", schedule_broadcast) - -# less equal -@register_compute("less_equal") -def less_equal_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.less_equal(inputs[0], inputs[1])] - register_schedule("less_equal", schedule_broadcast) - -# greater -@register_compute("greater") -def greater_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.greater(inputs[0], inputs[1])] - register_schedule("greater", schedule_broadcast) - -# greater equal -@register_compute("greater_equal") -def greater_equal_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.greater_equal(inputs[0], inputs[1])] - register_schedule("greater_equal", schedule_broadcast) - -# maximum -@register_compute("maximum") -def maximum_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.maximum(inputs[0], inputs[1])] - register_schedule("maximum_compute", schedule_injective) - -# minimum -@register_compute("minimum") -def minimum_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.minimum(inputs[0], inputs[1])] - register_schedule("minimum", schedule_injective) - -# right shift -@register_compute("right_shift") -def right_shift_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.right_shift(inputs[0], inputs[1])] - register_schedule("right_shift", schedule_injective) - -# left shift -@register_compute("left_shift") -def left_shift_compute(attrs, inputs, output_type, target): - assert len(inputs) == 2 - return [topi.left_shift(inputs[0], inputs[1])] - register_schedule("left_shift", schedule_injective) # zeros @@ -273,5 +87,4 @@ def concatenate_compute(attrs, inputs, output_type, target): return [topi.concatenate(inputs, axis=attrs.axis)] register_schedule("concatenate", schedule_injective) -# TODO(tqchen): renable concat as injective -register_pattern("concatenate", OpPattern.OPAQUE) +register_pattern("concatenate", OpPattern.INJECTIVE) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index b10e9f2e2ea3faee6875cdc184c02e4103281335..8cb1279a1435c398061433bc2c366b8f8460d32a 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -56,30 +56,26 @@ class ScheduleGetter : Op::GetAttr<FTVMSchedule>("FTVMSchedule"); auto cache_node = make_node<CachedFuncNode>(); cache_node->target = target_; - - if (prim_func->params.size() == 1 && - prim_func->params[0]->checked_type().as<TupleTypeNode>()) { - // Handle tuple input type by flattening them. - // This is the current calling convention of tuple input. + for (Var param : prim_func->params) { Array<tvm::Tensor> inputs; - for (Type field : prim_func->params[0]->type_as<TupleTypeNode>()->fields) { - const auto* ttype = field.as<TensorTypeNode>(); - CHECK(ttype != nullptr); + if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) { tvm::Tensor tensor = tvm::placeholder( GetShape(ttype->shape), ttype->dtype); cache_node->inputs.push_back(tensor); inputs.push_back(tensor); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as<TupleTypeNode>(); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as<TensorTypeNode>(); + CHECK(ttype != nullptr); + tvm::Tensor tensor = tvm::placeholder( + GetShape(ttype->shape), ttype->dtype); + cache_node->inputs.push_back(tensor); + inputs.push_back(tensor); + } } - memo_[prim_func->params[0]] = inputs; - - } else { - for (Var param : prim_func->params) { - const auto* ttype = param->type_as<TensorTypeNode>(); - tvm::Tensor tensor = tvm::placeholder( - GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - memo_[param] = Array<Tensor>({tensor}); - } + memo_[param] = inputs; } readable_name_stream_ << "fused"; cache_node->outputs = this->VisitExpr(prim_func->body); @@ -161,8 +157,9 @@ class ScheduleGetter : int op_pattern = fpattern[op]; if (op_pattern >= kCommReduce) { - CHECK(!master_op_.defined()) - << "Two complicated op in a primitive function"; + CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce) + << "Two complicated op in a primitive function " + << " master=" << master_op_ << " current=" << op; } if (op_pattern >= master_op_patetrn_) { master_op_ = op; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index db96a3ad4de1638bf9fe69abe40ab27994afc1c8..5bef4a22f371531560705a930cf42f26ccc6b1e5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -212,7 +212,7 @@ class Interpreter : // Marshal the arguments. // Handle tuple input/output by flattening them. size_t arg_len = 0; - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < args.size(); ++i) { if (args[i].as<TensorValueNode>()) { ++arg_len; } else { @@ -242,22 +242,19 @@ class Interpreter : << context_ << ", but get " << arg_ctx; }; - if (func->params.size() == 1 && - func->params[0]->checked_type().as<TupleTypeNode>()) { - // handle tuple input. - const TupleValueNode* tuple = args[0].as<TupleValueNode>(); - CHECK(tuple); - for (size_t i = 0; i < tuple->fields.size(); ++i) { - fset_input(i, tuple->fields[i]); - } - } else { - CHECK_EQ(num_inputs, args.size()); - // Decide the target context. - // Primitive functions always sit in the same context. - for (size_t i = 0; i < args.size(); i++) { - fset_input(i, args[i]); + int arg_counter = 0; + for (Value arg : args) { + if (arg.as<TensorValueNode>()) { + fset_input(arg_counter++, arg); + } else { + const TupleValueNode* tuple = arg.as<TupleValueNode>(); + CHECK(tuple != nullptr); + for (size_t i = 0; i < tuple->fields.size(); ++i) { + fset_input(arg_counter++, tuple->fields[i]); + } } } + // TVM's calling convention is that the final argument is the output // buffer. To preserve the illusion of being a functional language // we need to allocate space for the output buffer based on the diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc index 171824fcd3ae98399bef5cf3955356a550a0b032..3f28bd52cd4ba97c9d63141a5029208e66d93016 100644 --- a/src/relay/op/tensor/binary.cc +++ b/src/relay/op/tensor/binary.cc @@ -5,54 +5,75 @@ */ #include <tvm/relay/expr.h> #include <tvm/relay/op.h> +#include <topi/broadcast.h> #include "../type_relations.h" #include "../op_common.h" namespace tvm { namespace relay { +#define RELAY_BINARY_COMPUTE(FTOPI) \ + [] (const Attrs& attrs, \ + const Array<Tensor>& inputs, \ + const Type& out_type, \ + const Target& target) -> Array<Tensor> { \ + CHECK_EQ(inputs.size(), 2U); \ + return {FTOPI(inputs[0], inputs[1])}; \ + } \ + + // Addition RELAY_REGISTER_BINARY_OP("relay.op._make.", "add") .describe("Elementwise add with with broadcasting") -.set_support_level(1); +.set_support_level(1) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add)); // Subtraction RELAY_REGISTER_BINARY_OP("relay.op._make.", "subtract") .describe("Elementwise substract with broadcasting") -.set_support_level(1); +.set_support_level(1) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract)); // Right shift RELAY_REGISTER_BINARY_OP("relay.op._make.", "right_shift") .describe("Elementwise right shift with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift)); RELAY_REGISTER_BINARY_OP("relay.op._make.", "left_shift") .describe("Elementwise left shift with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift)); RELAY_REGISTER_BINARY_OP("relay.op._make.", "maximum") .describe("Elementwise maximum of two tensors with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum)); RELAY_REGISTER_BINARY_OP("relay.op._make.", "minimum") .describe("Elementwise minimum of two tensors with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum)); RELAY_REGISTER_BINARY_OP("relay.op._make.", "divide") .describe("Elementwise divide with broadcasting") -.set_support_level(1); +.set_support_level(1) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide)); RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply") .describe("Elementwise multiply with broadcasting") -.set_support_level(1); +.set_support_level(1) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); RELAY_REGISTER_BINARY_OP("relay.op._make.", "power") .describe("Elementwise power with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod") .describe("Elementwise mod with broadcasting") -.set_support_level(1); +.set_support_level(1) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); // Comparisons #define RELAY_REGISTER_CMP_OP(OpName) \ @@ -70,22 +91,38 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod") RELAY_REGISTER_CMP_OP("equal") .describe("Elementwise equal compare with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal)); + + RELAY_REGISTER_CMP_OP("not_equal") .describe("Elementwise not equal with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal)); + + RELAY_REGISTER_CMP_OP("less") .describe("Elementwise less than with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less)); + + RELAY_REGISTER_CMP_OP("less_equal") .describe("Elementwise less than or equal compare with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal)); + + RELAY_REGISTER_CMP_OP("greater") .describe("Elementwise greater than compare with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater)); + + RELAY_REGISTER_CMP_OP("greater_equal") .describe("Elementwise greater than or equal compare with broadcasting") -.set_support_level(4); +.set_support_level(4) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal)); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 22f97e8f0d543e3fae956a263cc04ddf05ab3d71..6c94fe2adcc24c0b6a0d03f0b36fee2908e11d58 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -5,12 +5,21 @@ */ #include <tvm/relay/expr.h> #include <tvm/relay/op.h> +#include <topi/elemwise.h> #include "../type_relations.h" #include "../op_common.h" namespace tvm { namespace relay { +#define RELAY_UNARY_COMPUTE(FTOPI) \ + [] (const Attrs& attrs, \ + const Array<Tensor>& inputs, \ + const Type& out_type, \ + const Target& target) -> Array<Tensor> { \ + return {FTOPI(inputs[0])}; \ + } \ + RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") .describe(R"code(Returns the log input array, computed element-wise. @@ -20,7 +29,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") .describe(R"code(Returns the exp input array, computed element-wise. @@ -30,7 +41,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt") @@ -41,7 +53,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like") .describe(R"code(Returns an array of zeros, with same type and shape as the input. @@ -49,6 +63,7 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like") .set_support_level(1) .add_type_rel("Identity", IdentityRel); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "ones_like") .describe(R"code(Returns an array of ones, with same type and shape as the input. )code" TVM_ADD_FILELINE) @@ -63,13 +78,17 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy") .describe(R"code(Copy a tensor. )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); + // Clip struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { @@ -107,7 +126,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor") .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil") .describe(R"code(Returns the ceil of input array, computed element-wise. @@ -117,7 +138,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc") .describe(R"code(Returns the trunc of input array, computed element-wise. @@ -127,7 +150,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "round") .describe(R"code(Returns the round of input array, computed element-wise. @@ -137,7 +162,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs") .describe(R"code(Returns the abs of input array, computed element-wise. @@ -147,7 +174,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh") .describe(R"code(Returns the tanh of input array, computed element-wise. @@ -157,7 +186,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); + RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative") .describe(R"code(Returns the numeric negative of input array, computed element-wise. @@ -167,7 +198,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative") )code" TVM_ADD_FILELINE) .set_support_level(3) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index d28aa0a56941dd036518f91473b2331e32e03985..6a1662b65170214b303fe30c2d3707ac1f60fbc1 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -188,20 +188,22 @@ def test_concatenate(): x = relay.var("x", shape=(10, 5)) y = relay.var("y", shape=(10, 5)) + t = relay.var("z", shape=()) z = relay.concatenate((x, y), axis=1) - + z = relay.add(z, t) # Check result. - func = relay.Function([x, y], z) + func = relay.Function([x, y, t], z) x_data = np.random.rand(10, 5).astype('float32') y_data = np.random.rand(10, 5).astype('float32') - ref_res = np.concatenate((x_data, y_data), axis=1) + t_data = np.random.uniform(size=()).astype('float32') + ref_res = np.concatenate((x_data, y_data), axis=1) + t_data for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(x_data, y_data) + op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01) - op_res2 = intrp2.evaluate(func)(x_data, y_data) + op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01) def test_dropout(): @@ -306,11 +308,11 @@ def test_dense(): if __name__ == "__main__": + test_concatenate() test_bias_add() test_unary_op() test_binary_op() test_expand_dims_infer_type() - test_concatenate() test_expand_dims() test_softmax() test_log_softmax()