From f397fead73bae20ca9abd2fc69df5723e065acde Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Thu, 24 Jan 2019 12:50:21 +0800
Subject: [PATCH] [RELAY] Fix ops in packed layout (#2472)

* [RELAY] Fix ops in packed layout

* Fix style
---
 src/relay/op/nn/pooling.cc                      | 6 +++++-
 src/relay/op/tensor/transform.cc                | 3 ++-
 tests/python/relay/test_pass_alter_op_layout.py | 4 ++++
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 6cf37668c..8fd33e1f3 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -83,7 +83,11 @@ bool Pool2DRel(const Array<Type>& types,
     return false;
   }
 
-  std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
+  std::vector<IndexExpr> oshape;
+  for (const auto& e : dshape) {
+    oshape.push_back(e);
+  }
+
   if (param->ceil_mode) {
     oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
                     param->strides[0] - 1) / param->strides[0]) + 1;
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 704324533..6d583bfd6 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -76,7 +76,8 @@ RELAY_REGISTER_OP("cast")
 .set_support_level(3)
 .add_type_rel("Cast", CastRel)
 .set_attr<FTVMCompute>("FTVMCompute", CastCompute)
-.set_attr<TOpPattern>("TOpPattern", kElemWise);
+.set_attr<TOpPattern>("TOpPattern", kElemWise)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
 // relay.expand_dims
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 48ab2ba27..975973d2b 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -82,6 +82,8 @@ def test_alter_layout():
         # a useless tuple, which will be eliminated
         y = relay.Tuple([y])[0]
         y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2))
+        y = relay.cast(y, 'int32')
         y = relay.nn.batch_flatten(y)
         y = relay.Function(free_vars(y), y)
         return y
@@ -112,6 +114,8 @@ def test_alter_layout():
         y = relay.add(y, b)
 
         y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c")
+        y = relay.cast(y, 'int32')
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.nn.batch_flatten(y)
         y = relay.Function(free_vars(y), y)
-- 
GitLab