diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc
index 785b486ddc06f171055cb6858f8c6d21d8e01e59..6acf4e65b1acbd2160b9689b3ff6f1f36220aa42 100644
--- a/src/relay/pass/simplify_inference.cc
+++ b/src/relay/pass/simplify_inference.cc
@@ -15,7 +15,8 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
                             Expr gamma,
                             Expr beta,
                             Expr moving_mean,
-                            Expr moving_var) {
+                            Expr moving_var,
+                            Type tdata) {
   const auto param = attrs.as<BatchNormAttrs>();
   Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
   Expr var_add_eps = Add(moving_var, epsilon);
@@ -32,9 +33,11 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
   }
 
   int axis = param->axis;
-  const auto* tdata = data->type_as<TensorTypeNode>();
-  scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis});
-  shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis});
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
+  auto ndim = ttype->shape.size();
+  scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
+  shift = ExpandBiasToMatchAxis(shift, ndim, {axis});
 
   Expr out = Multiply(data, scale);
   out = Add(out, shift);
@@ -54,14 +57,26 @@ class InferenceSimplifier : public ExprMutator {
     }
     if (const auto* call = new_n->tuple.as<CallNode>()) {
       if (call->op.same_as(batch_norm)) {
-        return BatchNormToInferUnpack(call->attrs,
-          call->args[0], call->args[1], call->args[2], call->args[3], call->args[4]);
+        return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
+                                      call->args[3], call->args[4], ty_map_.at(call->args[0]));
       } else if (call->op.same_as(dropout)) {
         return call->args[0];
       }
     }
     return new_e;
   }
+
+  Expr VisitExpr_(const CallNode* n) {
+    static const Op& batch_norm = Op::Get("nn.batch_norm");
+    auto new_n = ExprMutator::VisitExpr_(n);
+    if (n->op.same_as(batch_norm)) {
+      ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
+    }
+    return new_n;
+  }
+
+ private:
+  std::unordered_map<Expr, Type, NodeHash, NodeEqual> ty_map_;
 };
 
 Expr SimplifyInference(const Expr& e) {
diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py
index 9830b83dc6e59752338054f83c418a1309f33436..7585a88063abab8639e720347649c3d31362853c 100644
--- a/tests/python/relay/test_pass_simplify_inference.py
+++ b/tests/python/relay/test_pass_simplify_inference.py
@@ -30,12 +30,12 @@ def test_simplify_batchnorm():
             y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
                 gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
             y1 = rly.nn.dropout(y1)
-            y1 = rly.ir_pass.infer_type(y1)
-            y1 = simplify_inference(y1)
-
             y2 = simple_bn(y2 + rly.const(1, 'float32'),
                 gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis, shape=ttype1.shape)
 
+        y1 = rly.ir_pass.infer_type(y1)
+        y1 = simplify_inference(y1)
+
         assert rly.ir_pass.graph_equal(y1, y2)
 
     check(2, 1, 1)
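
Background sketch (not part of the patch): `BatchNormToInferUnpack` used to call `data->type_as<TensorTypeNode>()`, which fails when `data` is an expression freshly produced by the mutator and therefore carries no `checked_type`. The change records the original argument's type in `ty_map_` inside `VisitExpr_(const CallNode*)` and passes it in explicitly, so the pass stays a single traversal without re-running type inference. The Python below, assuming the same `tvm.relay` / `ir_pass` API used by the test file above, stacks two `batch_norm` calls so the inner result reaches the outer one as a rewritten, untyped expression, which is exactly the case the type lookup covers.

# Minimal usage sketch; names and imports mirror the test file, not the patch itself.
import tvm.relay as rly
from tvm.relay.ir_pass import infer_type, simplify_inference

dtype = 'float32'
x = rly.var("x", rly.TensorType((10, 10), dtype))
gamma = rly.var("gamma", rly.TensorType((10,), dtype))
beta = rly.var("beta", rly.TensorType((10,), dtype))
moving_mean = rly.var("moving_mean", rly.TensorType((10,), dtype))
moving_var = rly.var("moving_var", rly.TensorType((10,), dtype))

y = x
for _ in range(2):
    # The second iteration feeds the result of the first batch_norm back in
    # before any type inference has run on it.
    y, _, _ = rly.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=1)
    y = rly.nn.dropout(y)

y = infer_type(y)          # infer types once, over the whole graph
y = simplify_inference(y)  # unpack batch_norm into scale/shift, drop dropout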