Commit abccd9cd by Tianqi Chen Committed by GitHub

[TOPI] Improve dilate (#330)

parent 9ac46bea
......@@ -2,6 +2,7 @@
"""Dilation operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import util
@tvm.tag_scope(tag="dilation")
......@@ -29,15 +30,21 @@ def dilate(Input, strides):
output_size += (tvm.ir_pass.Simplify((Input.shape[i]-1)*strides[i]+1),)
def _dilate(data, *indices):
not_zero = (indices[0]%strides[0]).equal(0)
index_tuple = ()
not_zero = []
index_tuple = []
for i in range(n):
index_tuple += (indices[i]/strides[i],)
not_zero = tvm.all(not_zero, (indices[i]%strides[i]).equal(0))
return tvm.select(not_zero, data[index_tuple], tvm.const(0.0, data.dtype))
if not util.equal_const_int(strides[i], 1):
index_tuple.append(indices[i]/strides[i])
not_zero.append((indices[i] % strides[i]).equal(0))
else:
index_tuple.append(indices[i])
if not_zero:
not_zero = tvm.all(*not_zero)
return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
return data(*index_tuple)
Output = tvm.compute(
(output_size),
output_size,
lambda *indices: _dilate(Input, *indices),
name='DilatedInput')
......
......@@ -7,21 +7,43 @@ def get_const_int(expr):
Parameters
----------
expr :
expr : tvm.Expr
The input expression.
Returns
-------
out_tuple : tuple of int
out_value : int
The output.
"""
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
expr = tvm.ir_pass.Simplfy(expr)
expr = tvm.ir_pass.Simplify(expr)
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
raise ValueError("Expect value to be constant int")
return expr.value
def equal_const_int(expr, value):
"""Returns if expr equals value.
Parameters
----------
expr : tvm.Expr
The input expression.
Returns
-------
equal : bool
Whether they equals.
"""
if isinstance(expr, int):
return expr == value
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
expr = tvm.ir_pass.Simplify(expr)
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
return False
return expr.value == value
def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment