Commit eef35a57 by Haichen Shen Committed by Jared Roesch

[Relay][Any] Add shape func for dynamic shape (#3606)

* init shape func in interpreter and vm compiler

* Update interpreter

* fix

* lint

* lint

* fix

* remove hack

* update

* fix

* fix

* update

* address comments & update for shape_of

* fix lint

* update

* fix hybrid

* lint

* fix bug & add take shape func

* lint

* lint

* update

* fix flaky test

* add todo
parent c1c7b9b1
...@@ -704,6 +704,10 @@ class Reduce : public ExprNode { ...@@ -704,6 +704,10 @@ class Reduce : public ExprNode {
class Any : public ExprNode { class Any : public ExprNode {
public: public:
void VisitAttrs(AttrVisitor* v) final {} void VisitAttrs(AttrVisitor* v) final {}
/*! \brief Convert to var. */
Var ToVar() const {
return Variable::make(Int(32), "any_dim");
}
TVM_DLL static Expr make(); TVM_DLL static Expr make();
......
...@@ -75,6 +75,11 @@ using TOpIsStateful = bool; ...@@ -75,6 +75,11 @@ using TOpIsStateful = bool;
using TNonComputational = bool; using TNonComputational = bool;
/*! /*!
* \brief Mark the operator whether output shape is data dependant.
*/
using TShapeDataDependant = bool;
/*!
* \brief Computation description interface. * \brief Computation description interface.
* *
* \note This function have a special convention * \note This function have a special convention
...@@ -186,7 +191,7 @@ using Shape = Array<IndexExpr>; ...@@ -186,7 +191,7 @@ using Shape = Array<IndexExpr>;
using FShapeFunc = runtime::TypedPackedFunc< using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs, Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Shape>& out_shapes)>; const Array<IndexExpr>& out_ndims)>;
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -25,7 +25,7 @@ import numbers ...@@ -25,7 +25,7 @@ import numbers
from enum import Enum from enum import Enum
from .util import _internal_assert, _apply_indices from .util import _internal_assert
from . import calls from . import calls
from . import util from . import util
from .preprocessor import determine_variable_usage from .preprocessor import determine_variable_usage
...@@ -35,7 +35,6 @@ from ..container import Array ...@@ -35,7 +35,6 @@ from ..container import Array
from ..tensor import Tensor, Operation from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal from .. import _api_internal as _tvm_internal
from .. import expr as _expr from .. import expr as _expr
from .. import stmt as _stmt
from .. import make as _make from .. import make as _make
from .. import api as _api from .. import api as _api
from .. import ir_pass as _ir_pass from .. import ir_pass as _ir_pass
...@@ -43,15 +42,14 @@ from .. import ir_pass as _ir_pass ...@@ -43,15 +42,14 @@ from .. import ir_pass as _ir_pass
def concat_list_to_block(lst): def concat_list_to_block(lst):
"""Concatenate a list of Python IR nodes to HalideIR Block""" """Concatenate a list of Python IR nodes to HalideIR Block"""
if not lst:
return util.make_nop()
n = len(lst) n = len(lst)
if n == 1: if n == 1:
return lst[0] return lst[0]
body = lst[n - 1] body = lst[n - 1]
for i in range(1, n): for i in range(1, n):
stmt = lst[n - 1 - i] stmt = lst[n - 1 - i]
if isinstance(stmt, _stmt.AssertStmt):
body = _make.AssertStmt(stmt.condition, stmt.message, body)
else:
body = _make.Block(stmt, body) body = _make.Block(stmt, body)
return body return body
...@@ -179,6 +177,9 @@ class HybridParser(ast.NodeVisitor): ...@@ -179,6 +177,9 @@ class HybridParser(ast.NodeVisitor):
to_pop = [] to_pop = []
for key, val in self.usage.items(): for key, val in self.usage.items():
_, level, _ = val _, level, _ = val
if key not in self.symbols:
# don't realize the symbols that are never visited
continue
if level != node: if level != node:
continue continue
_internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key) _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
...@@ -363,44 +364,25 @@ class HybridParser(ast.NodeVisitor): ...@@ -363,44 +364,25 @@ class HybridParser(ast.NodeVisitor):
def visit_Attribute(self, node): def visit_Attribute(self, node):
_internal_assert(isinstance(node.value, ast.Name), \
"For atrribute access, only both names are supported so far!")
buf = self.visit(node.value) buf = self.visit(node.value)
return getattr(buf, node.attr) return getattr(buf, node.attr)
def visit_Subscript(self, node): def visit_Subscript(self, node):
args = self.visit(node.slice) args = self.visit(node.slice)
if isinstance(node.value, ast.Name): arr = self.visit(node.value)
if node.value.id in self.closure_vars: if isinstance(arr, Array):
args = ast.literal_eval(str(args))
return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))
buf = self.visit(node.value)
if isinstance(buf, Array):
for i in args: for i in args:
if isinstance(i, numbers.Integral): if isinstance(i, numbers.Integral):
buf = buf[i] arr = arr[i]
else: else:
_internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
"All indices are supposed to be constants") "All indices are supposed to be constants")
buf = buf[i.value] arr = arr[i.value]
return arr
return buf
if isinstance(node.ctx, ast.Load): if isinstance(node.ctx, ast.Load):
return _make.Call(buf.dtype, buf.name, args, \ return _make.Call(arr.dtype, arr.name, args,
_expr.Call.Halide, buf.op, buf.value_index) _expr.Call.Halide, arr.op, arr.value_index)
return arr, args
return buf, args
shape = self.visit(node.value)
_internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
args = args[0]
#TODO: maybe support non-constant value later?
_internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
"So far only constant shape access supported!")
return shape[args.value]
def visit_With(self, node): def visit_With(self, node):
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
...@@ -417,7 +399,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -417,7 +399,7 @@ class HybridParser(ast.NodeVisitor):
def visit_If(self, node): def visit_If(self, node):
cond = self.visit(node.test) cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
# Return no IfThenElse if proven # Return no IfThenElse if proven
if isinstance(cond, _expr.UIntImm): if isinstance(cond, _expr.UIntImm):
...@@ -508,11 +490,11 @@ class HybridParser(ast.NodeVisitor): ...@@ -508,11 +490,11 @@ class HybridParser(ast.NodeVisitor):
_name = node.target.id _name = node.target.id
if isinstance(for_type, tuple): if isinstance(for_type, tuple):
low = _ir_pass.Simplify(low) low = _ir_pass.CanonicalSimplify(low)
ext = _ir_pass.Simplify(ext) ext = _ir_pass.CanonicalSimplify(ext)
_internal_assert(isinstance(low, _expr.ConstExpr) and _internal_assert(isinstance(low, _expr.ConstExpr) and
isinstance(ext, _expr.ConstExpr), \ isinstance(ext, _expr.ConstExpr), \
"Const range should start from a const" + \ "Const range should start from a const " + \
"and iterate const times") "and iterate const times")
low, ext = low.value, ext.value low, ext = low.value, ext.value
......
...@@ -101,9 +101,3 @@ def _is_tvm_arg_types(args): ...@@ -101,9 +101,3 @@ def _is_tvm_arg_types(args):
_internal_assert(isinstance(elem, np_arg_types), \ _internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem))) "Expect a numpy type but %s get!" % str(type(elem)))
return False return False
def _apply_indices(value, indices):
"""Apply multidimensional index"""
if indices:
return _apply_indices(value[indices[0]], indices[1:])
return value
...@@ -177,6 +177,10 @@ class VMCompiler(object): ...@@ -177,6 +177,10 @@ class VMCompiler(object):
The VM runtime. The VM runtime.
""" """
target = _update_target(target) target = _update_target(target)
target_host = None if target_host == "" else target_host
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = tvm.target.create(target_host)
self._compile(mod, target, target_host) self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm()) return VirtualMachine(self._get_vm())
......
...@@ -14,12 +14,13 @@ ...@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
#pylint: disable=invalid-name, unused-argument #pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
from __future__ import absolute_import from __future__ import absolute_import
import topi import topi
from .op import register_compute, register_schedule, register_pattern from .op import register_compute, register_schedule, register_pattern, register_shape_func
from .op import schedule_injective, OpPattern from .op import schedule_injective, OpPattern
from ...hybrid import script
schedule_broadcast = schedule_injective schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective schedule_elemwise = schedule_injective
...@@ -104,3 +105,49 @@ def clip_compute(attrs, inputs, output_type, target): ...@@ -104,3 +105,49 @@ def clip_compute(attrs, inputs, output_type, target):
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)] return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise) register_schedule("clip", schedule_elemwise)
# shape func
@script
def _broadcast_shape_func(x, y, ndim):
out = output_tensor((ndim,), "int64")
if len(x.shape) == 0:
for i in const_range(ndim):
out[i] = y[i]
elif len(y.shape) == 0:
for i in const_range(ndim):
out[i] = x[i]
else:
ndim1 = x.shape[0]
ndim2 = y.shape[0]
for i in const_range(1, min(ndim1, ndim2)+1):
if x[ndim1-i] == y[ndim2-i]:
out[ndim-i] = x[ndim1-i]
elif x[ndim1-i] == 1:
out[ndim-i] = y[ndim2-i]
else:
assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % (
x[ndim1-i], y[ndim2-i])
out[ndim-i] = x[ndim1-i]
for i in const_range(min(ndim1, ndim2)+1, ndim+1):
if ndim1 >= ndim2:
out[ndim-i] = x[ndim1-i]
else:
out[ndim-i] = y[ndim2-i]
return out
def broadcast_shape_func(attrs, inputs, out_ndims):
return [_broadcast_shape_func(*inputs, out_ndims[0])]
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
register_shape_func("divide", False, broadcast_shape_func)
register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("equal", False, broadcast_shape_func)
register_shape_func("not_equal", False, broadcast_shape_func)
register_shape_func("less", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument # pylint: disable=invalid-name,unused-argument, len-as-condition
from __future__ import absolute_import from __future__ import absolute_import
from topi.util import get_const_int, get_const_tuple
from . import op as _reg from . import op as _reg
from ._reduce import _schedule_reduce from ._reduce import _schedule_reduce
from .op import OpPattern from .op import OpPattern
from ...hybrid import script
from ...api import convert
schedule_injective = _reg.schedule_injective schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective
...@@ -58,3 +61,145 @@ _reg.register_schedule("one_hot", schedule_injective) ...@@ -58,3 +61,145 @@ _reg.register_schedule("one_hot", schedule_injective)
# layout_transform # layout_transform
_reg.register_schedule("layout_transform", schedule_injective) _reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE) _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
# shape func
@script
def _arange_shape_func(start, stop, step):
out = output_tensor((1,), "int64")
out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0])))
return out
@_reg.register_shape_func("arange", True)
def arange_shape_func(attrs, inputs, _):
return [_arange_shape_func(*inputs)]
@script
def _concatenate_shape_func(inputs, axis):
ndim = inputs[0].shape[0]
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
if i != axis:
out[i] = inputs[0][i]
for j in const_range(1, len(inputs)):
assert out[i] == inputs[j][i], \
"Dims mismatch in the inputs of concatenate."
else:
out[i] = int64(0)
for j in const_range(len(inputs)):
out[i] += inputs[j][i]
return out
@_reg.register_shape_func("concatenate", False)
def concatenate_shape_func(attrs, inputs, _):
axis = get_const_int(attrs.axis)
return [_concatenate_shape_func(inputs, convert(axis))]
@script
def _reshape_shape_func(data_shape, newshape, ndim):
out = output_tensor((ndim,), "int64")
src_idx = 0
dst_idx = 0
infer_idx = -1
copy = False
skip = 0
for i in const_range(len(newshape)):
if skip > 0:
skip -= 1
elif newshape[i] > 0:
out[dst_idx] = int64(newshape[i])
src_idx += 1
dst_idx += 1
elif newshape[i] == 0:
out[dst_idx] = data_shape[src_idx]
src_idx += 1
dst_idx += 1
elif newshape[i] == -1:
assert infer_idx < 0, "One and only one dim can be inferred"
out[dst_idx] = int64(1)
infer_idx = i
dst_idx += 1
elif newshape[i] == -2:
copy = True
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, \
"Not enough dims in input shape for -3"
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
if newshape[i+1] == -1:
assert newshape[i+2] != -1, "Split dims cannot both be -1."
out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2])
out[dst_idx+1] = int64(newshape[i+2])
else:
out[dst_idx] = int64(newshape[i+1])
if newshape[i+2] == -1:
out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1])
else:
out[dst_idx+1] = int64(newshape[i+2])
assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
"Product of split dims doesn't match to input dim"
src_idx += 1
dst_idx += 2
skip = 2
else:
assert False, "Invalid special values in new shape"
if len(data_shape.shape) > 0:
# if data is not constant, we can then handle -1 and -2
if copy:
for i in range(src_idx, data_shape.shape[0]):
out[dst_idx] = data_shape[i]
dst_idx += 1
if infer_idx >= 0:
old_size = int64(1)
for i in const_range(data_shape.shape[0]):
old_size *= data_shape[i]
new_size = int64(1)
for i in const_range(out.shape[0]):
new_size *= out[i]
out[infer_idx] = old_size / new_size
return out
@_reg.register_shape_func("reshape", False)
def reshape_shape_func(attrs, inputs, out_ndims):
newshape = get_const_tuple(attrs.newshape)
return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
@script
def _take_no_axis_shape_func(indices_shape, out_ndim):
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = indices_shape[i]
return out
@script
def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
out = output_tensor((out_ndim,), "int64")
for i in const_range(axis):
out[i] = data_shape[i]
if len(indices_shape.shape) == 0:
# indices is constant
for i in const_range(axis+1, len(data_shape)):
out[i-1] = data_shape[i]
else:
for i in const_range(len(indices_shape)):
out[axis+i] = indices_shape[i]
for i in const_range(axis+1, len(data_shape)):
out[len(indices_shape)+i-1] = data_shape[i]
return out
@_reg.register_shape_func("take", False)
def take_shape_func(attrs, inputs, out_ndims):
"""
Shape function for take op.
"""
if attrs.axis is None:
return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
else:
axis = get_const_int(attrs.axis)
data_ndim = int(inputs[0].shape[0])
if axis < 0:
axis += data_ndim
assert 0 <= axis < data_ndim
return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
...@@ -48,6 +48,22 @@ class Op(Expr): ...@@ -48,6 +48,22 @@ class Op(Expr):
""" """
return _OpGetAttr(self, attr_name) return _OpGetAttr(self, attr_name)
def set_attr(self, attr_name, value, plevel=10):
"""Set attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
value : object
The attribute value
plevel : int
The priority level
"""
_OpSetAttr(self, attr_name, value, plevel)
def get(op_name): def get(op_name):
"""Get the Op for a given name """Get the Op for a given name
...@@ -219,6 +235,26 @@ def register_gradient(op_name, fgradient=None, level=10): ...@@ -219,6 +235,26 @@ def register_gradient(op_name, fgradient=None, level=10):
""" """
return register(op_name, "FPrimalGradient", fgradient, level) return register(op_name, "FPrimalGradient", fgradient, level)
def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
"""Register operator shape function for an op.
Parameters
----------
op_name : str
The name of the op.
data_dependant : bool
Whether the shape function depends on input data.
shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr])
-> shape_tensors: List<Tensor>
The function for computing the dynamic output shapes
level : int
The priority level
"""
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
return register(op_name, "FShapeFunc", shape_func, level)
_init_api("relay.op", __name__) _init_api("relay.op", __name__)
......
...@@ -82,7 +82,25 @@ Operation HybridOpNode::make(std::string name, ...@@ -82,7 +82,25 @@ Operation HybridOpNode::make(std::string name,
} }
Array<Tensor> HybridOpNode::InputTensors() const { Array<Tensor> HybridOpNode::InputTensors() const {
return inputs; // Because input tensors could be potentially inlined into hybrid scripts,
// we need to check if all input tensors are used in the body.
std::unordered_set<Tensor> orig_inputs;
for (auto t : inputs) {
orig_inputs.insert(t);
}
std::unordered_set<Tensor> visited;
Array<Tensor> curr_inputs;
ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) {
const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (orig_inputs.count(t) && !visited.count(t)) {
curr_inputs.push_back(t);
visited.insert(t);
}
}
});
return curr_inputs;
} }
Operation HybridOpNode::ReplaceInputs( Operation HybridOpNode::ReplaceInputs(
...@@ -111,7 +129,8 @@ void HybridOpNode::PropBoundToInputs( ...@@ -111,7 +129,8 @@ void HybridOpNode::PropBoundToInputs(
arith::Analyzer* analyzer, arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet> &dom_map, const std::unordered_map<const Variable*, IntSet> &dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const { std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) { auto curr_inputs = InputTensors();
for (Tensor t : curr_inputs) {
auto it = out_dom_map->find(t); auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue; if (it == out_dom_map->end()) continue;
TensorDom &dom = it->second; TensorDom &dom = it->second;
...@@ -180,11 +199,10 @@ Stmt HybridOpNode::BuildProvide( ...@@ -180,11 +199,10 @@ Stmt HybridOpNode::BuildProvide(
outputs[i]->dtype); outputs[i]->dtype);
f_push_bind(buffer, stage->op.output(i)); f_push_bind(buffer, stage->op.output(i));
} }
for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) { auto curr_inputs = InputTensors();
Buffer buffer = decl_buffer( for (int i = static_cast<int>(curr_inputs.size()) - 1; i >= 0; --i) {
inputs[i]->shape, Buffer buffer = decl_buffer(curr_inputs[i]->shape, curr_inputs[i]->dtype);
inputs[i]->dtype); f_push_bind(buffer, curr_inputs[i]);
f_push_bind(buffer, inputs[i]);
} }
std::unordered_map<Tensor, Tensor> rmap; std::unordered_map<Tensor, Tensor> rmap;
...@@ -203,7 +221,7 @@ Stmt HybridOpNode::BuildProvide( ...@@ -203,7 +221,7 @@ Stmt HybridOpNode::BuildProvide(
* tensors have the same names as the operation produces them. * tensors have the same names as the operation produces them.
* 2. Once OpNode is wrapped up by an Operation node, it is finalized. * 2. Once OpNode is wrapped up by an Operation node, it is finalized.
* Later access will be from a const OpNode*. * Later access will be from a const OpNode*.
* This is a chiken-egg paradox. It is impossible to put the output * This is a chicken-egg paradox. It is impossible to put the output
* tensors into the function body without forming the op node. The * tensors into the function body without forming the op node. The
* function body is immutable after the node is formed. * function body is immutable after the node is formed.
* *
......
...@@ -41,7 +41,7 @@ namespace ir { ...@@ -41,7 +41,7 @@ namespace ir {
using runtime::StorageRank; using runtime::StorageRank;
using runtime::StorageScope; using runtime::StorageScope;
// Find a linear pattern of storage acess // Find a linear pattern of storage access
// Used for liveness analysis. // Used for liveness analysis.
// Composite scopes(loop/thread_launch/IfThen) is represented by two points: // Composite scopes(loop/thread_launch/IfThen) is represented by two points:
// before_scope -> scope_body -> after_scope // before_scope -> scope_body -> after_scope
...@@ -193,6 +193,10 @@ class LinearAccessPatternFinder final : public IRVisitor { ...@@ -193,6 +193,10 @@ class LinearAccessPatternFinder final : public IRVisitor {
VisitNewScope(op); VisitNewScope(op);
} }
void Visit_(const AssertStmt* op) final {
VisitNewScope(op);
}
// linearized access sequence. // linearized access sequence.
std::vector<StmtEntry> linear_seq_; std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer // The storage scope of each buffer
......
...@@ -36,6 +36,14 @@ ...@@ -36,6 +36,14 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */
enum ShapeFuncParamState {
kNoNeed = 0,
kNeedInputData = 1,
kNeedInputShape = 2,
kNeedBoth = 3,
};
/*! \brief Node container to represent a cached function. */ /*! \brief Node container to represent a cached function. */
struct CachedFuncNode : public Node { struct CachedFuncNode : public Node {
/* \brief compiled target */ /* \brief compiled target */
...@@ -48,6 +56,8 @@ struct CachedFuncNode : public Node { ...@@ -48,6 +56,8 @@ struct CachedFuncNode : public Node {
tvm::Array<Tensor> outputs; tvm::Array<Tensor> outputs;
/*! \brief The lowered functions to support the function. */ /*! \brief The lowered functions to support the function. */
tvm::Array<tvm::LoweredFunc> funcs; tvm::Array<tvm::LoweredFunc> funcs;
/*! \brief Parameter usage states in the shape function. */
tvm::Array<Integer> shape_func_param_states;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("target", &target); v->Visit("target", &target);
...@@ -55,6 +65,7 @@ struct CachedFuncNode : public Node { ...@@ -55,6 +65,7 @@ struct CachedFuncNode : public Node {
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs); v->Visit("outputs", &outputs);
v->Visit("funcs", &funcs); v->Visit("funcs", &funcs);
v->Visit("shape_func_param_states", &shape_func_param_states);
} }
static constexpr const char* _type_key = "relay.CachedFunc"; static constexpr const char* _type_key = "relay.CachedFunc";
...@@ -170,6 +181,12 @@ class CompileEngineNode : public Node { ...@@ -170,6 +181,12 @@ class CompileEngineNode : public Node {
* \return The result. * \return The result.
*/ */
virtual PackedFunc JIT(const CCacheKey& key) = 0; virtual PackedFunc JIT(const CCacheKey& key) = 0;
/*!
* \brief Lower the shape function.
* \param key The key to the cached function.
* \return The result.
*/
virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
/*! \brief clear the cache. */ /*! \brief clear the cache. */
virtual void Clear() = 0; virtual void Clear() = 0;
...@@ -180,7 +197,7 @@ class CompileEngineNode : public Node { ...@@ -180,7 +197,7 @@ class CompileEngineNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node); TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
}; };
/*! \brier cache entry used in compile engine */ /*! \brief cache entry used in compile engine */
class CompileEngine : public NodeRef { class CompileEngine : public NodeRef {
public: public:
CompileEngine() {} CompileEngine() {}
...@@ -193,6 +210,13 @@ class CompileEngine : public NodeRef { ...@@ -193,6 +210,13 @@ class CompileEngine : public NodeRef {
TVM_DLL static const CompileEngine& Global(); TVM_DLL static const CompileEngine& Global();
}; };
/*!
* \brief Check if the type is dynamic.
* \param ty The type to be checked.
* \return The result.
*/
bool IsDynamic(const Type& ty);
// implementations // implementations
inline size_t CCacheKeyNode::Hash() const { inline size_t CCacheKeyNode::Hash() const {
if (hash_ != 0) return hash_; if (hash_ != 0) return hash_;
......
...@@ -280,7 +280,7 @@ class Interpreter : ...@@ -280,7 +280,7 @@ class Interpreter :
return TupleValueNode::make(values); return TupleValueNode::make(values);
} }
// TODO(@jroesch): this doesn't support mututal letrec // TODO(@jroesch): this doesn't support mutual letrec
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) { inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod; tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func); Array<Var> free_vars = FreeVars(func);
...@@ -310,7 +310,125 @@ class Interpreter : ...@@ -310,7 +310,125 @@ class Interpreter :
return MakeClosure(func); return MakeClosure(func);
} }
Value InvokePrimitiveOp(Function func, Array<Shape> ComputeDynamicShape(const Function& func,
const Array<Value>& args) {
auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
auto cfunc = engine_->LowerShapeFunc(key);
size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
std::vector<TVMValue> values(arity);
std::vector<int> codes(arity);
TVMArgsSetter setter(values.data(), codes.data());
std::vector<NDArray> inputs(cfunc->inputs.size());
std::vector<NDArray> outputs(cfunc->outputs.size());
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
auto fset_input = [&](size_t i, Value val, bool need_shape) {
const TensorValueNode* tv = val.as<TensorValueNode>();
CHECK(tv != nullptr) << "expect Tensor argument";
if (need_shape) {
int64_t ndim = tv->data.Shape().size();
NDArray shape_arr;
if (ndim == 0) {
shape_arr = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
} else {
shape_arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
for (auto j = 0; j < ndim; ++j) {
data[j] = tv->data.Shape()[j];
}
}
inputs[i] = shape_arr;
setter(i, shape_arr);
} else {
auto arr = tv->data.CopyTo(cpu_ctx);
inputs[i] = arr;
setter(i, arr);
}
};
size_t arg_counter = 0;
for (size_t i = 0; i < args.size(); ++i) {
auto arg = args[i];
auto param = func->params[i];
int state = cfunc->shape_func_param_states[i]->value;
if (arg.as<TensorValueNode>()) {
if (state & kNeedInputData) {
fset_input(arg_counter++, arg, false);
}
if (state & kNeedInputShape) {
fset_input(arg_counter++, arg, true);
}
} else {
const TupleValueNode* tuple = arg.as<TupleValueNode>();
CHECK(tuple != nullptr);
if (state & kNeedInputData) {
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i], false);
}
}
if (state & kNeedInputShape) {
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i], true);
}
}
}
}
CHECK_EQ(arg_counter, cfunc->inputs.size())
<< "Shape function input sizes mismatch";
auto fset_shape_output = [&](size_t i, Type val_type) {
// TODO(@icemelon): allow recursive tuple
const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
CHECK(rtype != nullptr);
int64_t ndim = rtype->shape.size();
auto arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
outputs[i] = arr;
setter(arg_counter + i, arr);
};
auto ret_type = func->body->checked_type();
size_t out_cnt = 0;
if (auto rtype = ret_type.as<TupleTypeNode>()) {
out_cnt = rtype->fields.size();
for (size_t i = 0; i < out_cnt; ++i) {
fset_shape_output(i, rtype->fields[i]);
}
} else {
out_cnt = 1;
auto tt = Downcast<TensorType>(ret_type);
fset_shape_output(0, tt);
}
CHECK_EQ(cfunc->outputs.size(), out_cnt)
<< "Shape function output sizes mismatch";
PackedFunc shape_func;
TVMRetValue rv;
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target);
shape_func = m.GetFunction(cfunc->func_name);
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
// Get output shapes
Array<Shape> out_shapes;
for (auto out_tensor : outputs) {
int64_t* shape_data = reinterpret_cast<int64_t*>(out_tensor->data);
Shape out_shape;
for (int i = 0; i < out_tensor->shape[0]; ++i) {
out_shape.push_back(tvm::Integer(shape_data[i]));
}
out_shapes.push_back(out_shape);
}
return out_shapes;
}
Value InvokePrimitiveOp(const Function& func,
const Array<Value>& args) { const Array<Value>& args) {
auto call_node = func->body.as<CallNode>(); auto call_node = func->body.as<CallNode>();
...@@ -394,17 +512,46 @@ class Interpreter : ...@@ -394,17 +512,46 @@ class Interpreter :
return out_tensor; return out_tensor;
}; };
Array<Shape> out_shapes;
auto ret_type = func->body->checked_type();
bool is_dyn = IsDynamic(func->checked_type());
if (call_node->op == Op::Get("shape_of")) {
// The output shape of shape_of must be static since Relay doesn't support
// dynamic rank tensors.
is_dyn = false;
}
if (is_dyn) {
CHECK(func->IsPrimitive());
out_shapes = ComputeDynamicShape(func, args);
}
PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_)); PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_));
TVMRetValue rv; TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) { if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
Array<Value> fields; Array<Value> fields;
for (size_t i = 0; i < rtype->fields.size(); ++i) { for (size_t i = 0; i < rtype->fields.size(); ++i) {
if (is_dyn) {
auto sh = out_shapes[i];
auto tt = Downcast<TensorType>(rtype->fields[i]);
fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
} else {
fields.push_back(fset_output(i, rtype->fields[i])); fields.push_back(fset_output(i, rtype->fields[i]));
} }
}
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return TupleValueNode::make(fields); return TupleValueNode::make(fields);
} else { } else {
Value out_tensor = fset_output(0, func->body->checked_type()); Value out_tensor;
if (is_dyn) {
CHECK_EQ(out_shapes.size(), 1);
auto sh = out_shapes[0];
auto tt = Downcast<TensorType>(ret_type);
out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
} else {
out_tensor = fset_output(0, ret_type);
}
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return out_tensor; return out_tensor;
} }
......
...@@ -72,12 +72,10 @@ struct VMCompilerContext { ...@@ -72,12 +72,10 @@ struct VMCompilerContext {
TagMap tag_map; TagMap tag_map;
// Map from global var to a unique integer // Map from global var to a unique integer
GlobalMap global_map; GlobalMap global_map;
// Map from Const object to its index in const pool // List of constants
ConstMap const_map; std::vector<NDArray> constants;
// Map from Const tensor shape to its index in const pool // List of cached functions
ConstTensorShapeMap const_tensor_shape_map; std::vector<CachedFunc> cached_funcs;
// List of lowered functions
std::vector<LoweredFunc> lowered_funcs;
// The functions that have been lowered. // The functions that have been lowered.
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs; std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
}; };
......
...@@ -121,7 +121,7 @@ TVM_REGISTER_API("relay.op._ListOpNames") ...@@ -121,7 +121,7 @@ TVM_REGISTER_API("relay.op._ListOpNames")
TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get); TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
TVM_REGISTER_API("relay.op._OpGetAttr") TVM_REGISTER_API("relay.op._OpGetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0]; Op op = args[0];
std::string attr_name = args[1]; std::string attr_name = args[1];
auto op_map = Op::GetAttr<TVMRetValue>(attr_name); auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
...@@ -130,6 +130,17 @@ TVM_REGISTER_API("relay.op._OpGetAttr") ...@@ -130,6 +130,17 @@ TVM_REGISTER_API("relay.op._OpGetAttr")
} }
}); });
TVM_REGISTER_API("relay.op._OpSetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
runtime::TVMArgValue value = args[2];
int plevel = args[3];
auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
reg.set_attr(attr_name, value, plevel);
});
TVM_REGISTER_API("relay.op._Register") TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0]; std::string op_name = args[0];
......
...@@ -528,7 +528,11 @@ bool ReshapeRel(const Array<Type>& types, ...@@ -528,7 +528,11 @@ bool ReshapeRel(const Array<Type>& types,
used_input_dims.insert(src_idx); used_input_dims.insert(src_idx);
IndexExpr d2 = data_shape[src_idx++]; IndexExpr d2 = data_shape[src_idx++];
used_output_dims.insert(oshape.size()); used_output_dims.insert(oshape.size());
if (d1.as<Any>() || d2.as<Any>()) {
oshape.push_back(Any::make());
} else {
oshape.push_back(d1 * d2); oshape.push_back(d1 * d2);
}
} else if (svalue == -4) { } else if (svalue == -4) {
// split the source dim s into two dims // split the source dim s into two dims
// read the left dim and then the right dim (either can be -1) // read the left dim and then the right dim (either can be -1)
...@@ -563,6 +567,8 @@ bool ReshapeRel(const Array<Type>& types, ...@@ -563,6 +567,8 @@ bool ReshapeRel(const Array<Type>& types,
oshape.push_back(d2); oshape.push_back(d2);
} }
} }
} else {
CHECK(false) << "Unsupported special value: " << svalue;
} }
} }
...@@ -608,7 +614,15 @@ Array<Tensor> ReshapeCompute(const Attrs& attrs, ...@@ -608,7 +614,15 @@ Array<Tensor> ReshapeCompute(const Attrs& attrs,
const Target& target) { const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>(); const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr); CHECK(out_ttype != nullptr);
return { topi::reshape(inputs[0], out_ttype->shape) }; Array<IndexExpr> newshape;
for (auto val : out_ttype->shape) {
if (val->is_type<ir::Any>()) {
newshape.push_back(val.as<ir::Any>()->ToVar());
} else {
newshape.push_back(val);
}
}
return { topi::reshape(inputs[0], newshape) };
} }
Expr MakeReshape(Expr data, Expr MakeReshape(Expr data,
...@@ -1108,7 +1122,8 @@ RELAY_REGISTER_OP("arange") ...@@ -1108,7 +1122,8 @@ RELAY_REGISTER_OP("arange")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Arange", ArangeRel) .add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute) .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective) // TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions); .set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
// repeat operator // repeat operator
......
...@@ -295,7 +295,9 @@ RELAY_REGISTER_OP("shape_of") ...@@ -295,7 +295,9 @@ RELAY_REGISTER_OP("shape_of")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("ShapeOf", ShapeOfRel) .add_type_rel("ShapeOf", ShapeOfRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false) .set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kInjective) // Use kOpaque for shape_of op for now since it won't be performance critic,
// and it makes things easier for dynamic shape func
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout) ElemwiseArbitraryLayout)
.set_support_level(10) .set_support_level(10)
......
...@@ -81,16 +81,17 @@ Type ConcreteBroadcast(const TensorType& t1, ...@@ -81,16 +81,17 @@ Type ConcreteBroadcast(const TensorType& t1,
for (; i <= std::min(ndim1, ndim2); ++i) { for (; i <= std::min(ndim1, ndim2); ++i) {
IndexExpr s1 = t1->shape[ndim1 - i]; IndexExpr s1 = t1->shape[ndim1 - i];
IndexExpr s2 = t2->shape[ndim2 - i]; IndexExpr s2 = t2->shape[ndim2 - i];
if (EqualCheck(s1, s2)) { if (EqualConstInt(s1, 1)) {
oshape.push_back(s1);
} else if (EqualConstInt(s1, 1)) {
oshape.push_back(s2); oshape.push_back(s2);
} else if (EqualConstInt(s2, 1)) { } else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1); oshape.push_back(s1);
} else if (s1.as<Any>() && EqualConstInt(s2, 1)) { } else if (s1.as<Any>()) {
// TODO(@jroesch): we need to come back to this // s1 == 1 || s1 == s2
oshape.push_back(s2); oshape.push_back(s2);
} else if (s2.as<Any>() && EqualConstInt(s1, 1)) { } else if (s2.as<Any>()) {
// s2 == 1 || s2 == s1
oshape.push_back(s1);
} else if (EqualCheck(s1, s2)) {
oshape.push_back(s1); oshape.push_back(s1);
} else { } else {
RELAY_ERROR( RELAY_ERROR(
......
...@@ -915,7 +915,7 @@ class FuseMutator : private ExprMutator { ...@@ -915,7 +915,7 @@ class FuseMutator : private ExprMutator {
if (it == gmap_.end()) return ""; if (it == gmap_.end()) return "";
std::ostringstream os; std::ostringstream os;
auto *group = it->second->FindRoot(); auto *group = it->second->FindRoot();
os << "group=" << group; os << " /* group=" << group << " */";
return os.str(); return os.str();
}); });
LOG(INFO) << "Dump of group info:\n" << text; LOG(INFO) << "Dump of group info:\n" << text;
......
...@@ -120,7 +120,7 @@ class ModulePassNode : public PassNode { ...@@ -120,7 +120,7 @@ class ModulePassNode : public PassNode {
/*! /*!
* \brief Get the pass information/meta data. * \brief Get the pass information/meta data.
*/ */
PassInfo Info() const { return pass_info; } PassInfo Info() const override { return pass_info; }
TVM_DLL static ModulePass make( TVM_DLL static ModulePass make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func, runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
...@@ -174,7 +174,7 @@ class FunctionPassNode : public PassNode { ...@@ -174,7 +174,7 @@ class FunctionPassNode : public PassNode {
/*! /*!
* \brief Get the pass information/meta data. * \brief Get the pass information/meta data.
*/ */
PassInfo Info() const { return pass_info; } PassInfo Info() const override { return pass_info; }
TVM_DLL static FunctionPass make( TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func, runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
...@@ -220,7 +220,7 @@ class SequentialNode : public PassNode { ...@@ -220,7 +220,7 @@ class SequentialNode : public PassNode {
/*! /*!
* \brief Get the pass information/meta data. * \brief Get the pass information/meta data.
*/ */
PassInfo Info() const { return pass_info; } PassInfo Info() const override { return pass_info; }
/*! /*!
* \brief Check if a pass is enabled. * \brief Check if a pass is enabled.
......
...@@ -451,11 +451,11 @@ std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") { ...@@ -451,11 +451,11 @@ std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") {
void InstructionPrint(std::ostream& os, const Instruction& instr) { void InstructionPrint(std::ostream& os, const Instruction& instr) {
switch (instr.op) { switch (instr.op) {
case Opcode::Move: { case Opcode::Move: {
os << "move $" << instr.dst << " $" << instr.from << std::endl; os << "move $" << instr.dst << " $" << instr.from;
break; break;
} }
case Opcode::Ret: { case Opcode::Ret: {
os << "ret $" << instr.result << std::endl; os << "ret $" << instr.result;
break; break;
} }
case Opcode::Fatal: { case Opcode::Fatal: {
...@@ -469,7 +469,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -469,7 +469,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
<< ", out: $" << ", out: $"
<< StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size, << StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size,
instr.output_size, ", $") instr.output_size, ", $")
<< ")" << std::endl; << ")";
break; break;
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
...@@ -478,71 +478,61 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -478,71 +478,61 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
instr.alloc_tensor.ndim) instr.alloc_tensor.ndim)
<< "] "; << "] ";
DLDatatypePrint(os, instr.alloc_tensor.dtype); DLDatatypePrint(os, instr.alloc_tensor.dtype);
os << std::endl;
break; break;
} }
case Opcode::AllocTensorReg: { case Opcode::AllocTensorReg: {
os << "alloc_tensor_reg $" << instr.dst << " $" os << "alloc_tensor_reg $" << instr.dst << " $"
<< instr.alloc_tensor_reg.shape_register << " "; << instr.alloc_tensor_reg.shape_register << " ";
DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
os << std::endl;
break; break;
} }
case Opcode::AllocDatatype: { case Opcode::AllocDatatype: {
os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$" os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
<< StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]" << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
<< std::endl;
break; break;
} }
case Opcode::AllocClosure: { case Opcode::AllocClosure: {
os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index
<< "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$") << "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$")
<< ")" << ")";
<< std::endl;
break; break;
} }
case Opcode::If: { case Opcode::If: {
os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " " os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " "
<< instr.if_op.true_offset << " " << instr.if_op.false_offset << instr.if_op.true_offset << " " << instr.if_op.false_offset;
<< std::endl;
break; break;
} }
case Opcode::Invoke: { case Opcode::Invoke: {
os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($" os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
<< StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$") << StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$")
<< ")" << ")";
<< std::endl;
break; break;
} }
case Opcode::InvokeClosure: { case Opcode::InvokeClosure: {
os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($" os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
<< StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$") << StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$")
<< ")" << ")";
<< std::endl;
break; break;
} }
case Opcode::LoadConst: { case Opcode::LoadConst: {
os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]" os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
<< std::endl;
break; break;
} }
case Opcode::LoadConsti: { case Opcode::LoadConsti: {
os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]" os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]";
<< std::endl;
break; break;
} }
case Opcode::GetField: { case Opcode::GetField: {
os << "get_field $" << instr.dst << " $" << instr.object << "[" os << "get_field $" << instr.dst << " $" << instr.object << "["
<< instr.field_index << "]" << instr.field_index << "]";
<< std::endl;
break; break;
} }
case Opcode::GetTag: { case Opcode::GetTag: {
os << "get_tag $" << instr.dst << " $" << instr.get_tag.object << std::endl; os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
break; break;
} }
case Opcode::Goto: { case Opcode::Goto: {
os << "goto " << instr.pc_offset << std::endl; os << "goto " << instr.pc_offset;
break; break;
} }
default: default:
...@@ -559,9 +549,7 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instr) { ...@@ -559,9 +549,7 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instr) {
void VMFunctionPrint(std::ostream& os, const VMFunction& vm_func) { void VMFunctionPrint(std::ostream& os, const VMFunction& vm_func) {
os << vm_func.name << ": " << std::endl; os << vm_func.name << ": " << std::endl;
for (size_t i = 0; i < vm_func.instructions.size(); ++i) { for (size_t i = 0; i < vm_func.instructions.size(); ++i) {
os << i << ": "; os << i << ": " << vm_func.instructions[i] << ";" << std::endl;
InstructionPrint(os, vm_func.instructions[i]);
os << ";" << std::endl;
} }
} }
...@@ -801,7 +789,7 @@ void VirtualMachine::RunLoop() { ...@@ -801,7 +789,7 @@ void VirtualMachine::RunLoop() {
while (true) { while (true) {
main_loop: main_loop:
auto const& instr = this->code[this->pc]; auto const& instr = this->code[this->pc];
DLOG(INFO) << "Executing(" << pc << "): "; DLOG(INFO) << "Executing(" << pc << "): " << instr;
#if USE_RELAY_DEBUG #if USE_RELAY_DEBUG
InstructionPrint(std::cout, instr); InstructionPrint(std::cout, instr);
#endif // USE_RELAY_DEBUG #endif // USE_RELAY_DEBUG
......
...@@ -546,6 +546,8 @@ void InjectInline(ScheduleNode* sch) { ...@@ -546,6 +546,8 @@ void InjectInline(ScheduleNode* sch) {
std::vector<Array<Expr> > new_body(sch->stages.size()); std::vector<Array<Expr> > new_body(sch->stages.size());
std::vector<bool> changed(sch->stages.size(), false); std::vector<bool> changed(sch->stages.size(), false);
std::vector<Stmt> new_hybrid_body(sch->stages.size());
std::vector<bool> hybrid_changed(sch->stages.size(), false);
// inline all the ops // inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1]; Stage stage = sch->stages[i - 1];
...@@ -568,6 +570,7 @@ void InjectInline(ScheduleNode* sch) { ...@@ -568,6 +570,7 @@ void InjectInline(ScheduleNode* sch) {
for (size_t j = i; j < sch->stages.size(); ++j) { for (size_t j = i; j < sch->stages.size(); ++j) {
Stage s = sch->stages[j]; Stage s = sch->stages[j];
const ComputeOpNode* compute = s->op.as<ComputeOpNode>(); const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
if (compute) { if (compute) {
if (!new_body[j].size()) { if (!new_body[j].size()) {
new_body[j] = compute->body; new_body[j] = compute->body;
...@@ -606,6 +609,15 @@ void InjectInline(ScheduleNode* sch) { ...@@ -606,6 +609,15 @@ void InjectInline(ScheduleNode* sch) {
} }
} }
} }
} else if (hybrid) {
if (!new_hybrid_body[j].defined()) {
new_hybrid_body[j] = hybrid->body;
}
Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
if (!new_stmt.same_as(new_hybrid_body[j])) {
new_hybrid_body[j] = new_stmt;
hybrid_changed[j] = true;
}
} }
} }
} }
...@@ -632,6 +644,17 @@ void InjectInline(ScheduleNode* sch) { ...@@ -632,6 +644,17 @@ void InjectInline(ScheduleNode* sch) {
} }
s->op = op; s->op = op;
} }
} else if (hybrid_changed[i]) {
const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
CHECK(hybrid);
Operation op = HybridOpNode::make(
hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
hybrid->outputs, new_hybrid_body[i]);
op = op->ReplaceInputs(op, repl);
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
}
s->op = op;
} else { } else {
Operation op = s->op->ReplaceInputs(s->op, repl); Operation op = s->op->ReplaceInputs(s->op, repl);
if (!op.same_as(s->op)) { if (!op.same_as(s->op)) {
......
...@@ -18,27 +18,156 @@ import numpy as np ...@@ -18,27 +18,156 @@ import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay import Kind, transform
from tvm.relay.loops import while_loop from tvm.relay.loops import while_loop
from tvm.relay.testing import run_infer_type as infer_type from tvm.relay.testing import run_infer_type as infer_type
def int32(val): def int32(val):
return relay.const(val, 'int32') return relay.const(val, 'int32')
def any_dims(ndim):
shape = []
for _ in range(ndim):
shape.append(relay.Any())
return tuple(shape)
# TODO(@wweic): because vm doesn't support heterogeneous exec, we can only test
# shape function on CPU.
def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
dtype = 'float32'
x = relay.var('x', shape=x_shape, dtype=dtype)
y = relay.var('y', shape=y_shape, dtype=dtype)
mod = relay.module.Module()
mod["main"] = relay.Function([x, y], op(x, y))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
y_np = np.random.uniform(size=y_np_shape).astype(dtype)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))
def test_any_broadcast():
verify_any_broadcast((relay.Any(),), (3, 2), (1,), (3, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (1, 2), (1, 2), (1, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (1, 2), (3, 2), (1, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (3, 2), (1, 2), (3, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (3, relay.Any()), (1, 2), (3, 1), relay.add, np.add)
# The following currently fail because topi compute treats Any as 1
# will requires auto_broadcast buffer to solve the problem
# TODO(@zhiics): Fix this
# verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
# verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
def test_any_concat():
x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
y = relay.var('y', shape=(1, 2), dtype="float32")
z = relay.op.concatenate([x, y], axis=0)
mod = relay.module.Module()
mod["main"] = relay.Function([x, y], z)
x_np = np.random.uniform(size=(3, 2)).astype('float32')
y_np = np.random.uniform(size=(1, 2)).astype('float32')
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
ref = np.concatenate([x_np, y_np], axis=0)
tvm.testing.assert_allclose(result.asnumpy(), ref)
def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
x = relay.var('x', shape=x_shape, dtype="float32")
y = relay.reshape(x, newshape=newshape)
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
data = np.random.uniform(size=x_np_shape).astype('float32')
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data).asnumpy()
assert result.shape == out_shape
tvm.testing.assert_allclose(result.flatten(), data.flatten())
def test_any_reshape():
verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24))
verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12))
verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4))
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
mod = relay.Module()
data = relay.var('data', shape=data_shape, dtype='float32')
indices = relay.var('indices', shape=indices_shape, dtype='int32')
y = relay.take(data, indices, axis=axis)
mod["main"] = relay.Function([data, indices], y)
data_np = np.random.uniform(size=data_np_shape).astype('float32')
if axis is None:
max_index = data_np.size
else:
max_index = data_np.shape[axis]
indices_np = np.random.randint(max_index, size=indices_np_shape).astype('int32')
ref = np.take(data_np, indices_np, axis=axis)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np, indices_np)
tvm.testing.assert_allclose(result.asnumpy(), ref)
def test_any_take():
verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,))
verify_any_take(any_dims(2), (), 0, (4, 5), ())
verify_any_take(any_dims(2), (), None, (4, 5), ())
verify_any_take(any_dims(3), any_dims(2), 1, (3, 4, 5), (2, 3))
verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))
def test_any_shape_of():
x = relay.var('x', shape=any_dims(2), dtype='float32')
y = relay.shape_of(x)
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
data = np.random.uniform(size=(3, 4)).astype('float32')
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
tvm.testing.assert_allclose(result.asnumpy(), np.array([3,4]).astype("int64"))
x = relay.var('x', shape=any_dims(3), dtype='float32')
y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(1, 'int32'))
mod = relay.module.Module()
mod["main"] = relay.Function([x], y1)
data = np.random.uniform(size=(2, 3, 4)).astype('float32')
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64"))
def test_fused_ops():
x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
y0 = x + relay.const(1.0, 'float32')
y1 = y0 * relay.const(2.0, 'float32')
mod = relay.module.Module()
mod["main"] = relay.Function([x], y1)
data = np.random.uniform(size=(5, 4)).astype('float32')
for kind in ["vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2)
def test_arange_with_dynamic_shape(): def test_arange_with_dynamic_shape():
m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k') m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32') x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
y0 = relay.shape_of(x) y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(0, 'int32')) y1 = relay.take(y0, relay.const(0, 'int32'))
y2 = relay.op.arange(y1) y2 = relay.op.arange(y1, dtype="int32")
ex = relay.create_executor() y3 = y2 + relay.const(1, dtype="int32")
f = relay.Function([x], y2, type_params=[m, n, k]) data = np.random.rand(10, 5, 3).astype('float32')
# TODO(@jroesch): Restore after code generation. mod = relay.module.Module()
# data = np.random.rand(10, 5, 3).astype('float32') mod["main"] = relay.Function([x], y3, type_params=[m, n, k])
# result = ex.evaluate(f)(data) for kind in ["debug", "vm"]:
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10))) ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
def test_dynamic_concat(): tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1)
def test_recursive_concat():
""" """
fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) { fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
if (%i < 10) { if (%i < 10) {
...@@ -66,26 +195,18 @@ def test_dynamic_concat(): ...@@ -66,26 +195,18 @@ def test_dynamic_concat():
start = relay.var('start', shape=(), dtype='int32') start = relay.var('start', shape=(), dtype='int32')
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1)) func = relay.Function([start], relay.TupleGetItem(body, 1))
func = infer_type(func) mod = relay.module.Module()
# TODO(@jroesch, @haichen): We should restore this code when codegeneration mod["main"] = func
# is merged data = np.array(0.0, dtype='int32')
# ret_shape = func.checked_type.ret_type.shape # TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail
# assert len(ret_shape) == 2, "expected 2-dim output" # so currently we cannot run this test case on VM
# assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any()) for kind in ["debug"]:
# import pdb; pdb.set_trace() ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
# mod = relay.module.Module() result = ex.evaluate()(data)
# print(relay.ir_pass.infer_type(func, mod=mod)) ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
# ret = relay.Call(loop, [relay.const(0, 'int32'), init]) np.testing.assert_allclose(result.asnumpy(), ref)
# mod[mod.entry_func] = relay.Function([], ret)
# print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod)) def test_recursive_concat_with_wrong_annotation():
# initial = np.array(0.0, dtype='float32').reshape((1,))
# iter_stop = np.array(10, dtype='int32')
# ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
# result = ex.evaluate(mod.entry_func)()
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
def test_dynamic_concat_with_wrong_annotation():
""" """
v0.0.1 v0.0.1
fn (%start: int32) { fn (%start: int32) {
...@@ -133,6 +254,12 @@ def test_dynamic_concat_with_wrong_annotation(): ...@@ -133,6 +254,12 @@ def test_dynamic_concat_with_wrong_annotation():
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e) assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
if __name__ == "__main__": if __name__ == "__main__":
test_any_broadcast()
test_any_concat()
test_any_reshape()
test_any_take()
test_any_shape_of()
test_fused_ops()
test_arange_with_dynamic_shape() test_arange_with_dynamic_shape()
test_dynamic_concat() test_recursive_concat()
test_dynamic_concat_with_wrong_annotation() test_recursive_concat_with_wrong_annotation()
...@@ -104,9 +104,6 @@ def test_serializer(): ...@@ -104,9 +104,6 @@ def test_serializer():
vm = create_vm(mod) vm = create_vm(mod)
ser = serializer.Serializer(vm) ser = serializer.Serializer(vm)
stats = ser.stats
assert "scalar" in stats
glbs = ser.globals glbs = ser.globals
assert len(glbs) == 3 assert len(glbs) == 3
assert "f1" in glbs assert "f1" in glbs
...@@ -120,8 +117,8 @@ def test_serializer(): ...@@ -120,8 +117,8 @@ def test_serializer():
code = ser.bytecode code = ser.bytecode
assert "main 5 2 5" in code assert "main 5 2 5" in code
assert "f1 3 1 4" in code assert "f1 2 1 3" in code
assert "f2 3 1 4" in code assert "f2 2 1 3" in code
code, lib = ser.serialize() code, lib = ser.serialize()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
......
...@@ -122,11 +122,13 @@ def test_outer_product(): ...@@ -122,11 +122,13 @@ def test_outer_product():
assert ibody.min.value == 0 assert ibody.min.value == 0
assert ibody.extent.name == 'm' assert ibody.extent.name == 'm'
#Check loop body #Check loop body
jbody = ibody.body jblock = ibody.body
assert isinstance(jblock, tvm.stmt.Block)
jbody = jblock.first
assert isinstance(jbody, tvm.stmt.AssertStmt) assert isinstance(jbody, tvm.stmt.AssertStmt)
assert isinstance(jbody.message, tvm.expr.StringImm) assert isinstance(jbody.message, tvm.expr.StringImm)
assert jbody.message.value == "index out of range!" assert jbody.message.value == "index out of range!"
jbody = jbody.body jbody = jblock.rest
assert isinstance(jbody, tvm.stmt.Provide) assert isinstance(jbody, tvm.stmt.Provide)
assert jbody.func.name == 'c' assert jbody.func.name == 'c'
assert len(jbody.args) == 2 assert len(jbody.args) == 2
......
...@@ -52,6 +52,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1, ...@@ -52,6 +52,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
tvm::Expr one(1); tvm::Expr one(1);
int i; int i;
for (i = 1; i <= std::min(s1_size, s2_size); ++i) { for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
// TODO(@icemelon9): Need to revisit this part
const Variable* var1 = shape1[s1_size - i].as<Variable>();
const Variable* var2 = shape2[s2_size - i].as<Variable>();
bh.all_vars.push_front(tvm::Var()); bh.all_vars.push_front(tvm::Var());
if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) { if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]); bh.common_shape.push_front(shape1[s1_size - i]);
...@@ -64,6 +67,16 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1, ...@@ -64,6 +67,16 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
} else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) { } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]); bh.common_shape.push_front(shape1[s1_size - i]);
bh.vars1.push_front(bh.all_vars[0]); bh.vars1.push_front(bh.all_vars[0]);
} else if (var1 && var2) {
bh.common_shape.push_front(max(shape1[s1_size - i], shape2[s2_size - i]));
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
} else if (var1) {
bh.common_shape.push_front(shape2[s2_size - i]);
bh.vars2.push_front(bh.all_vars[0]);
} else if (var2) {
bh.common_shape.push_front(shape1[s1_size - i]);
bh.vars1.push_front(bh.all_vars[0]);
} else { } else {
CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
<< " and " << shape2[s2_size - i] << " in: " << " and " << shape2[s2_size - i] << " in: "
......
...@@ -1148,9 +1148,9 @@ inline Tensor tensordot(const Tensor& A, ...@@ -1148,9 +1148,9 @@ inline Tensor tensordot(const Tensor& A,
return compute(output_shape, func, name, tag); return compute(output_shape, func, name, tag);
} }
inline Tensor arange(const Expr start, inline Tensor arange(const Expr& start,
const Expr stop, const Expr& stop,
const Expr step, const Expr& step,
Type dtype, Type dtype,
std::string name = "T_arange", std::string name = "T_arange",
std::string tag = kInjective) { std::string tag = kInjective) {
......
...@@ -82,9 +82,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -82,9 +82,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c) func(a, w, c)
rtol = 1e-5 rtol = 1e-3
if (kernel > 3):
rtol = 2e-5
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment