Commit 108e9f3f by alex-weaver Committed by Tianqi Chen

Update HalideIR submodule to include TVM_STATIC_IR_FUNCTOR_REGISTER (#857)

* Update HalideIR commit to include TVM_STATIC_IR_FUNCTOR_REGISTER

* Fix HalideIR to point to the right commit

* Add missing using to C++ TOPI nn.h

* Update HalideIR to include compiler error fix

* Fixed error where broadcast_to fails if shape is tuple of IntImm

* Change get_const_int to support int as input
parent 05952984
Subproject commit aadbf02d6bd7a545edbf6652494a7b07a97a06c1 Subproject commit 87b089a0ba20f2e8257038ee9211d6816088ce95
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "tvm/tvm.h" #include "tvm/tvm.h"
namespace topi { namespace topi {
using namespace tvm;
namespace detail { namespace detail {
template <typename T> template <typename T>
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .import tag from .import tag
from .util import get_const_tuple, equal_const_int from .util import get_const_tuple, equal_const_int, get_const_int
def _get_bcast_info(original_shape, target_shape): def _get_bcast_info(original_shape, target_shape):
"""Get the broadcasting info. """Get the broadcasting info.
...@@ -106,6 +106,7 @@ def broadcast_to(data, shape): ...@@ -106,6 +106,7 @@ def broadcast_to(data, shape):
indices_tuple.append(0) indices_tuple.append(0)
return data[tuple(indices_tuple)] return data[tuple(indices_tuple)]
original_shape = data.shape original_shape = data.shape
shape = [get_const_int(i) for i in shape]
bcast_info = _get_bcast_info(original_shape=original_shape, target_shape=shape) bcast_info = _get_bcast_info(original_shape=original_shape, target_shape=shape)
ret = tvm.compute(shape, ret = tvm.compute(shape,
lambda *indices: _bcast_to_arg_eval(data, lambda *indices: _bcast_to_arg_eval(data,
......
...@@ -7,7 +7,7 @@ def get_const_int(expr): ...@@ -7,7 +7,7 @@ def get_const_int(expr):
Parameters Parameters
---------- ----------
expr : tvm.Expr expr : tvm.Expr or int
The input expression. The input expression.
Returns Returns
...@@ -15,6 +15,8 @@ def get_const_int(expr): ...@@ -15,6 +15,8 @@ def get_const_int(expr):
out_value : int out_value : int
The output. The output.
""" """
if isinstance(expr, int):
return expr
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
expr = tvm.ir_pass.Simplify(expr) expr = tvm.ir_pass.Simplify(expr)
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment