Skip to content
Snippets Groups Projects
Commit 079e2307 authored by masahi's avatar masahi Committed by Tianqi Chen
Browse files

simplify expr in get_const_tuple (#795)

* fix upsampling output shape

* simplify expr in get_const_tuple
parent ebf4e5a3
No related branches found
No related tags found
No related merge requests found
"""TVM operator upsampling compute.""" """TVM operator upsampling compute."""
from __future__ import absolute_import from __future__ import absolute_import
import tvm import tvm
from .. import util
def upsampling(data, scale): def upsampling(data, scale):
...@@ -21,8 +22,8 @@ def upsampling(data, scale): ...@@ -21,8 +22,8 @@ def upsampling(data, scale):
4-D with shape [batch, channel, in_height*scale, in_width*scale] 4-D with shape [batch, channel, in_height*scale, in_width*scale]
""" """
batch, channel, height, width = data.shape batch, channel, height, width = data.shape
out_height = height * scale out_height = util.simplify(height * scale)
out_width = width * scale out_width = util.simplify(width * scale)
return tvm.compute((batch, channel, out_height, out_width), \ return tvm.compute((batch, channel, out_height, out_width), \
lambda n, c, h, w: data[n, c, h/scale, w/scale]) lambda n, c, h, w: data[n, c, h/scale, w/scale])
...@@ -59,9 +59,8 @@ def get_const_tuple(in_tuple): ...@@ -59,9 +59,8 @@ def get_const_tuple(in_tuple):
""" """
out_tuple = () out_tuple = ()
for elem in in_tuple: for elem in in_tuple:
if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): value = get_const_int(elem)
raise ValueError("Element of input tuple should be const int") out_tuple = out_tuple + (value, )
out_tuple = out_tuple + (elem.value, )
return out_tuple return out_tuple
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment