Commit 108e9f3f by alex-weaver Committed by Tianqi Chen

Update HalideIR submodule to include TVM_STATIC_IR_FUNCTOR_REGISTER (#857)

* Update HalideIR commit to include TVM_STATIC_IR_FUNCTOR_REGISTER

* Fix HalideIR to point to the right commit

* Add missing using to C++ TOPI nn.h

* Update HalideIR to include compiler error fix

* Fixed error where broadcast_to fails if shape is tuple of IntImm

* Change get_const_int to support int as input
parent 05952984
Subproject commit aadbf02d6bd7a545edbf6652494a7b07a97a06c1
Subproject commit 87b089a0ba20f2e8257038ee9211d6816088ce95
......@@ -15,6 +15,7 @@
#include "tvm/tvm.h"
namespace topi {
using namespace tvm;
namespace detail {
template <typename T>
......
......@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from .import tag
from .util import get_const_tuple, equal_const_int
from .util import get_const_tuple, equal_const_int, get_const_int
def _get_bcast_info(original_shape, target_shape):
"""Get the broadcasting info.
......@@ -106,6 +106,7 @@ def broadcast_to(data, shape):
indices_tuple.append(0)
return data[tuple(indices_tuple)]
original_shape = data.shape
shape = [get_const_int(i) for i in shape]
bcast_info = _get_bcast_info(original_shape=original_shape, target_shape=shape)
ret = tvm.compute(shape,
lambda *indices: _bcast_to_arg_eval(data,
......
......@@ -7,7 +7,7 @@ def get_const_int(expr):
Parameters
----------
expr : tvm.Expr
expr : tvm.Expr or int
The input expression.
Returns
......@@ -15,6 +15,8 @@ def get_const_int(expr):
out_value : int
The output.
"""
if isinstance(expr, int):
return expr
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
expr = tvm.ir_pass.Simplify(expr)
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment