Commit 079e2307 by masahi Committed by Tianqi Chen

simplify expr in get_const_tuple (#795)

* fix upsampling output shape

* simplify expr in get_const_tuple
parent ebf4e5a3
"""TVM operator upsampling compute."""
from __future__ import absolute_import
import tvm
from .. import util
def upsampling(data, scale):
......@@ -21,8 +22,8 @@ def upsampling(data, scale):
4-D with shape [batch, channel, in_height*scale, in_width*scale]
"""
batch, channel, height, width = data.shape
out_height = height * scale
out_width = width * scale
out_height = util.simplify(height * scale)
out_width = util.simplify(width * scale)
return tvm.compute((batch, channel, out_height, out_width), \
lambda n, c, h, w: data[n, c, h/scale, w/scale])
......@@ -59,9 +59,8 @@ def get_const_tuple(in_tuple):
"""
out_tuple = ()
for elem in in_tuple:
if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
raise ValueError("Element of input tuple should be const int")
out_tuple = out_tuple + (elem.value, )
value = get_const_int(elem)
out_tuple = out_tuple + (value, )
return out_tuple
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment