Commit eef35a57 by Haichen Shen, committed by Jared Roesch

[Relay][Any] Add shape func for dynamic shape (#3606)

* init shape func in interpreter and vm compiler

* Update interpreter

* fix

* lint

* lint

* fix

* remove hack

* update

* fix

* fix

* update

* address comments & update for shape_of

* fix lint

* update

* fix hybrid

* lint

* fix bug & add take shape func

* lint

* lint

* update

* fix flaky test

* add todo
parent c1c7b9b1
......@@ -704,6 +704,10 @@ class Reduce : public ExprNode {
class Any : public ExprNode {
public:
void VisitAttrs(AttrVisitor* v) final {}
/*! \brief Convert to var. */
Var ToVar() const {
return Variable::make(Int(32), "any_dim");
}
TVM_DLL static Expr make();
......
......@@ -75,6 +75,11 @@ using TOpIsStateful = bool;
using TNonComputational = bool;
/*!
* \brief Mark whether the output shape of the operator depends on the input data.
*/
using TShapeDataDependant = bool;
/*!
* \brief Computation description interface.
*
* \note This function has a special convention
......@@ -186,7 +191,7 @@ using Shape = Array<IndexExpr>;
using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Array<Shape>& out_shapes)>;
const Array<IndexExpr>& out_ndims)>;
} // namespace relay
} // namespace tvm
......
......@@ -25,7 +25,7 @@ import numbers
from enum import Enum
from .util import _internal_assert, _apply_indices
from .util import _internal_assert
from . import calls
from . import util
from .preprocessor import determine_variable_usage
......@@ -35,7 +35,6 @@ from ..container import Array
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
from .. import stmt as _stmt
from .. import make as _make
from .. import api as _api
from .. import ir_pass as _ir_pass
......@@ -43,16 +42,15 @@ from .. import ir_pass as _ir_pass
def concat_list_to_block(lst):
"""Concatenate a list of Python IR nodes to HalideIR Block"""
if not lst:
return util.make_nop()
n = len(lst)
if n == 1:
return lst[0]
body = lst[n - 1]
for i in range(1, n):
stmt = lst[n - 1 - i]
if isinstance(stmt, _stmt.AssertStmt):
body = _make.AssertStmt(stmt.condition, stmt.message, body)
else:
body = _make.Block(stmt, body)
body = _make.Block(stmt, body)
return body
......@@ -100,8 +98,8 @@ class HybridParser(ast.NodeVisitor):
ast.LtE : operator.le,
ast.Eq : operator.eq,
ast.NotEq : operator.ne,
ast.And : _all,
ast.Or : _any,
ast.And : _all,
ast.Or : _any,
}
......@@ -179,6 +177,9 @@ class HybridParser(ast.NodeVisitor):
to_pop = []
for key, val in self.usage.items():
_, level, _ = val
if key not in self.symbols:
# don't realize the symbols that are never visited
continue
if level != node:
continue
_internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
......@@ -363,44 +364,25 @@ class HybridParser(ast.NodeVisitor):
def visit_Attribute(self, node):
_internal_assert(isinstance(node.value, ast.Name), \
"For atrribute access, only both names are supported so far!")
buf = self.visit(node.value)
return getattr(buf, node.attr)
def visit_Subscript(self, node):
args = self.visit(node.slice)
if isinstance(node.value, ast.Name):
if node.value.id in self.closure_vars:
args = ast.literal_eval(str(args))
return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))
buf = self.visit(node.value)
if isinstance(buf, Array):
for i in args:
if isinstance(i, numbers.Integral):
buf = buf[i]
else:
_internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
"All indices are supposed to be constants")
buf = buf[i.value]
return buf
if isinstance(node.ctx, ast.Load):
return _make.Call(buf.dtype, buf.name, args, \
_expr.Call.Halide, buf.op, buf.value_index)
return buf, args
shape = self.visit(node.value)
_internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
args = args[0]
#TODO: maybe support non-constant value later?
_internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
"So far only constant shape access supported!")
return shape[args.value]
arr = self.visit(node.value)
if isinstance(arr, Array):
for i in args:
if isinstance(i, numbers.Integral):
arr = arr[i]
else:
_internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
"All indices are supposed to be constants")
arr = arr[i.value]
return arr
if isinstance(node.ctx, ast.Load):
return _make.Call(arr.dtype, arr.name, args,
_expr.Call.Halide, arr.op, arr.value_index)
return arr, args
def visit_With(self, node):
if sys.version_info[0] < 3:
......@@ -417,7 +399,7 @@ class HybridParser(ast.NodeVisitor):
def visit_If(self, node):
cond = self.visit(node.test)
cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
# Return no IfThenElse if proven
if isinstance(cond, _expr.UIntImm):
......@@ -508,11 +490,11 @@ class HybridParser(ast.NodeVisitor):
_name = node.target.id
if isinstance(for_type, tuple):
low = _ir_pass.Simplify(low)
ext = _ir_pass.Simplify(ext)
low = _ir_pass.CanonicalSimplify(low)
ext = _ir_pass.CanonicalSimplify(ext)
_internal_assert(isinstance(low, _expr.ConstExpr) and
isinstance(ext, _expr.ConstExpr), \
"Const range should start from a const" + \
"Const range should start from a const " + \
"and iterate const times")
low, ext = low.value, ext.value
......
......@@ -101,9 +101,3 @@ def _is_tvm_arg_types(args):
_internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem)))
return False
def _apply_indices(value, indices):
"""Apply multidimensional index"""
if indices:
return _apply_indices(value[indices[0]], indices[1:])
return value
......@@ -177,6 +177,10 @@ class VMCompiler(object):
The VM runtime.
"""
target = _update_target(target)
target_host = None if target_host == "" else target_host
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = tvm.target.create(target_host)
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())
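For context, a minimal end-to-end sketch of the path this method serves, assuming the Python API at this commit (relay.Any() for unknown dimensions and the VMCompiler wrapper above); the import path and variable names are illustrative, not authoritative:
import tvm
from tvm import relay
from tvm.relay.backend import vm as relay_vm   # assumed location of VMCompiler

# An elementwise add whose first input dimension is unknown until runtime.
x = relay.var("x", shape=(relay.Any(), 3), dtype="float32")
y = relay.var("y", shape=(1, 3), dtype="float32")
mod = relay.Module.from_expr(relay.Function([x, y], relay.add(x, y)))

compiler = relay_vm.VMCompiler()
# target_host falls back to "llvm" (or "stackvm") exactly as handled above.
vm_rt = compiler.compile(mod, target="llvm")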
......
......@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, unused-argument
#pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
from .op import register_compute, register_schedule, register_pattern
from .op import register_compute, register_schedule, register_pattern, register_shape_func
from .op import schedule_injective, OpPattern
from ...hybrid import script
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
......@@ -104,3 +105,49 @@ def clip_compute(attrs, inputs, output_type, target):
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise)
# shape func
@script
def _broadcast_shape_func(x, y, ndim):
out = output_tensor((ndim,), "int64")
if len(x.shape) == 0:
for i in const_range(ndim):
out[i] = y[i]
elif len(y.shape) == 0:
for i in const_range(ndim):
out[i] = x[i]
else:
ndim1 = x.shape[0]
ndim2 = y.shape[0]
for i in const_range(1, min(ndim1, ndim2)+1):
if x[ndim1-i] == y[ndim2-i]:
out[ndim-i] = x[ndim1-i]
elif x[ndim1-i] == 1:
out[ndim-i] = y[ndim2-i]
else:
assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % (
x[ndim1-i], y[ndim2-i])
out[ndim-i] = x[ndim1-i]
for i in const_range(min(ndim1, ndim2)+1, ndim+1):
if ndim1 >= ndim2:
out[ndim-i] = x[ndim1-i]
else:
out[ndim-i] = y[ndim2-i]
return out
def broadcast_shape_func(attrs, inputs, out_ndims):
return [_broadcast_shape_func(*inputs, out_ndims[0])]
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
register_shape_func("divide", False, broadcast_shape_func)
register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("equal", False, broadcast_shape_func)
register_shape_func("not_equal", False, broadcast_shape_func)
register_shape_func("less", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)
......@@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument
# pylint: disable=invalid-name,unused-argument, len-as-condition
from __future__ import absolute_import
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ._reduce import _schedule_reduce
from .op import OpPattern
from ...hybrid import script
from ...api import convert
schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective
......@@ -58,3 +61,145 @@ _reg.register_schedule("one_hot", schedule_injective)
# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
# shape func
@script
def _arange_shape_func(start, stop, step):
out = output_tensor((1,), "int64")
out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0])))
return out
@_reg.register_shape_func("arange", True)
def arange_shape_func(attrs, inputs, _):
return [_arange_shape_func(*inputs)]
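The arange rule is just the number of steps in the half-open interval [start, stop); a small worked example in plain Python (not the hybrid script itself):
import math

def arange_size(start, stop, step):
    # same formula as _arange_shape_func: ceil((stop - start) / step)
    return int(math.ceil((stop - start) / step))

assert arange_size(2.0, 10.0, 3.0) == 3   # elements: 2, 5, 8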
@script
def _concatenate_shape_func(inputs, axis):
ndim = inputs[0].shape[0]
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
if i != axis:
out[i] = inputs[0][i]
for j in const_range(1, len(inputs)):
assert out[i] == inputs[j][i], \
"Dims mismatch in the inputs of concatenate."
else:
out[i] = int64(0)
for j in const_range(len(inputs)):
out[i] += inputs[j][i]
return out
@_reg.register_shape_func("concatenate", False)
def concatenate_shape_func(attrs, inputs, _):
axis = get_const_int(attrs.axis)
return [_concatenate_shape_func(inputs, convert(axis))]
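For intuition, the concatenate rule keeps every non-axis dimension (asserting they agree across inputs) and sums the axis dimension; a plain-Python sketch:
def concat_shape(shapes, axis):
    out = list(shapes[0])
    for i in range(len(out)):
        if i != axis:
            assert all(s[i] == out[i] for s in shapes[1:]), \
                "Dims mismatch in the inputs of concatenate."
        else:
            out[i] = sum(s[i] for s in shapes)
    return out

assert concat_shape([(2, 3), (5, 3)], axis=0) == [7, 3]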
@script
def _reshape_shape_func(data_shape, newshape, ndim):
out = output_tensor((ndim,), "int64")
src_idx = 0
dst_idx = 0
infer_idx = -1
copy = False
skip = 0
for i in const_range(len(newshape)):
if skip > 0:
skip -= 1
elif newshape[i] > 0:
out[dst_idx] = int64(newshape[i])
src_idx += 1
dst_idx += 1
elif newshape[i] == 0:
out[dst_idx] = data_shape[src_idx]
src_idx += 1
dst_idx += 1
elif newshape[i] == -1:
assert infer_idx < 0, "One and only one dim can be inferred"
out[dst_idx] = int64(1)
infer_idx = i
dst_idx += 1
elif newshape[i] == -2:
copy = True
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, \
"Not enough dims in input shape for -3"
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
if newshape[i+1] == -1:
assert newshape[i+2] != -1, "Split dims cannot both be -1."
out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2])
out[dst_idx+1] = int64(newshape[i+2])
else:
out[dst_idx] = int64(newshape[i+1])
if newshape[i+2] == -1:
out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1])
else:
out[dst_idx+1] = int64(newshape[i+2])
assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
"Product of split dims doesn't match to input dim"
src_idx += 1
dst_idx += 2
skip = 2
else:
assert False, "Invalid special values in new shape"
if len(data_shape.shape) > 0:
# if data is not constant, we can then handle -1 and -2
if copy:
for i in range(src_idx, data_shape.shape[0]):
out[dst_idx] = data_shape[i]
dst_idx += 1
if infer_idx >= 0:
old_size = int64(1)
for i in const_range(data_shape.shape[0]):
old_size *= data_shape[i]
new_size = int64(1)
for i in const_range(out.shape[0]):
new_size *= out[i]
out[infer_idx] = old_size / new_size
return out
@_reg.register_shape_func("reshape", False)
def reshape_shape_func(attrs, inputs, out_ndims):
newshape = get_const_tuple(attrs.newshape)
return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
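The special values in newshape follow the convention handled above: 0 copies a dimension, -1 is inferred, -2 copies all remaining dimensions, -3 merges two dimensions, and -4 splits one. A few worked examples for data_shape = (2, 3, 4), given only for illustration:
# newshape (6, -1)        -> (6, 4)         # -1 inferred as 24 / 6
# newshape (-3, 4)        -> (6, 4)         # -3 merges 2 and 3
# newshape (0, -2)        -> (2, 3, 4)      # 0 copies dim 0, -2 copies the rest
# newshape (-4, 1, 2, -2) -> (1, 2, 3, 4)   # -4 splits 2 into 1 x 2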
@script
def _take_no_axis_shape_func(indices_shape, out_ndim):
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = indices_shape[i]
return out
@script
def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
out = output_tensor((out_ndim,), "int64")
for i in const_range(axis):
out[i] = data_shape[i]
if len(indices_shape.shape) == 0:
# indices is constant
for i in const_range(axis+1, len(data_shape)):
out[i-1] = data_shape[i]
else:
for i in const_range(len(indices_shape)):
out[axis+i] = indices_shape[i]
for i in const_range(axis+1, len(data_shape)):
out[len(indices_shape)+i-1] = data_shape[i]
return out
@_reg.register_shape_func("take", False)
def take_shape_func(attrs, inputs, out_ndims):
"""
Shape function for take op.
"""
if attrs.axis is None:
return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
else:
axis = get_const_int(attrs.axis)
data_ndim = int(inputs[0].shape[0])
if axis < 0:
axis += data_ndim
assert 0 <= axis < data_ndim
return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
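As a concrete check of the two cases above (illustrative values only):
# take with axis: output rank = data_ndim + indices_ndim - 1
#   data (4, 5, 6), indices (2, 3), axis=1  -> (4, 2, 3, 6)
# take without axis: the output shape is simply the indices shape
#   data (4, 5, 6), indices (7,)            -> (7,)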
......@@ -48,6 +48,22 @@ class Op(Expr):
"""
return _OpGetAttr(self, attr_name)
def set_attr(self, attr_name, value, plevel=10):
"""Set attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
value : object
The attribute value
plevel : int
The priority level
"""
_OpSetAttr(self, attr_name, value, plevel)
def get(op_name):
"""Get the Op for a given name
......@@ -219,6 +235,26 @@ def register_gradient(op_name, fgradient=None, level=10):
"""
return register(op_name, "FPrimalGradient", fgradient, level)
def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
"""Register operator shape function for an op.
Parameters
----------
op_name : str
The name of the op.
data_dependant : bool
Whether the shape function depends on input data.
shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr])
-> shape_tensors: List[Tensor]
The function for computing the dynamic output shapes
level : int
The priority level
"""
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
return register(op_name, "FShapeFunc", shape_func, level)
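A hedged usage sketch of this API for a hypothetical unary op whose output shape equals its input shape; the op name "my_op", the helper names, and the import paths are illustrative (the real registrations added by this change live in _tensor.py and _transform.py):
from tvm.relay.op.op import register_shape_func   # assumed import path
from tvm.hybrid import script

@script
def _identity_shape_func(data_shape):
    ndim = data_shape.shape[0]
    out = output_tensor((ndim,), "int64")
    for i in const_range(ndim):
        out[i] = data_shape[i]
    return out

def my_op_shape_func(attrs, inputs, out_ndims):
    # data_dependant=False, so `inputs` carries shape tensors rather than data.
    return [_identity_shape_func(inputs[0])]

register_shape_func("my_op", False, my_op_shape_func)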
_init_api("relay.op", __name__)
......
......@@ -82,7 +82,25 @@ Operation HybridOpNode::make(std::string name,
}
Array<Tensor> HybridOpNode::InputTensors() const {
return inputs;
// Because input tensors can potentially be inlined into hybrid scripts,
// we need to check which input tensors are actually used in the body.
std::unordered_set<Tensor> orig_inputs;
for (auto t : inputs) {
orig_inputs.insert(t);
}
std::unordered_set<Tensor> visited;
Array<Tensor> curr_inputs;
ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) {
const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (orig_inputs.count(t) && !visited.count(t)) {
curr_inputs.push_back(t);
visited.insert(t);
}
}
});
return curr_inputs;
}
Operation HybridOpNode::ReplaceInputs(
......@@ -111,7 +129,8 @@ void HybridOpNode::PropBoundToInputs(
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet> &dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
auto curr_inputs = InputTensors();
for (Tensor t : curr_inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom &dom = it->second;
......@@ -180,11 +199,10 @@ Stmt HybridOpNode::BuildProvide(
outputs[i]->dtype);
f_push_bind(buffer, stage->op.output(i));
}
for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
inputs[i]->shape,
inputs[i]->dtype);
f_push_bind(buffer, inputs[i]);
auto curr_inputs = InputTensors();
for (int i = static_cast<int>(curr_inputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(curr_inputs[i]->shape, curr_inputs[i]->dtype);
f_push_bind(buffer, curr_inputs[i]);
}
std::unordered_map<Tensor, Tensor> rmap;
......@@ -203,7 +221,7 @@ Stmt HybridOpNode::BuildProvide(
* tensors have the same names as the operation produces them.
* 2. Once OpNode is wrapped up by an Operation node, it is finalized.
* Later access will be from a const OpNode*.
* This is a chiken-egg paradox. It is impossible to put the output
* This is a chicken-egg paradox. It is impossible to put the output
* tensors into the function body without forming the op node. The
* function body is immutable after the node is formed.
*
......
......@@ -41,7 +41,7 @@ namespace ir {
using runtime::StorageRank;
using runtime::StorageScope;
// Find a linear pattern of storage acess
// Find a linear pattern of storage access
// Used for liveness analysis.
// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
// before_scope -> scope_body -> after_scope
......@@ -193,6 +193,10 @@ class LinearAccessPatternFinder final : public IRVisitor {
VisitNewScope(op);
}
void Visit_(const AssertStmt* op) final {
VisitNewScope(op);
}
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer
......
......@@ -35,7 +35,9 @@
#include <limits>
#include <mutex>
#include <functional>
#include <vector>
#include <unordered_map>
#include "../ir/type_functor.h"
#include "compile_engine.h"
namespace tvm {
......@@ -48,6 +50,43 @@ CCacheKey CCacheKeyNode::make(Function source_func, Target target) {
return CCacheKey(n);
}
struct IsDynamicVisitor : public TypeVisitor {
bool is_dyn{false};
void VisitType_(const TensorTypeNode* tt) {
for (auto dim : tt->shape) {
if (dim.as<Any>()) {
is_dyn = true;
break;
}
}
}
};
bool IsDynamic(const Type& ty) {
IsDynamicVisitor v;
v.VisitType(ty);
return v.is_dyn;
}
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// for now, we always use int32 shape when possible
// even if the result of shape inference becomes int64.
Array<IndexExpr> res;
for (IndexExpr val : shape) {
const int64_t* pval = as_const_int(val);
if (pval != nullptr) {
CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
res.push_back(ir::IntImm::make(Int(32), *pval));
} else if (val->is_type<ir::Any>()) {
res.push_back(val.as<ir::Any>()->ToVar());
} else {
res.push_back(val);
}
}
return res;
}
// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter :
......@@ -56,23 +95,6 @@ class ScheduleGetter :
explicit ScheduleGetter(Target target)
: target_(target) {}
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// for now, we always use int32 shape when possible
// even if the result of shape inference becomes int64.
Array<IndexExpr> res;
for (IndexExpr val : shape) {
const int64_t* pval = as_const_int(val);
if (pval != nullptr) {
CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
res.push_back(ir::IntImm::make(Int(32), *pval));
} else {
res.push_back(val);
}
}
return res;
}
std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
static auto fschedule =
Op::GetAttr<FTVMSchedule>("FTVMSchedule");
......@@ -90,6 +112,7 @@ class ScheduleGetter :
const auto* tuple_type = param->type_as<TupleTypeNode>();
for (Type field : tuple_type->fields) {
const auto* ttype = field.as<TensorTypeNode>();
// TODO(@icemelon): Allow recursive tuple
CHECK(ttype != nullptr);
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
......@@ -283,6 +306,255 @@ class ScheduleGetter :
Array<Operation> scalars_;
};
// Creates shape function from functor.
class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
public:
MakeShapeFunc() {}
std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
for (auto param : prim_func->params) {
param_states_[param] = kNoNeed;
Array<tvm::Tensor> data_inputs;
Array<tvm::Tensor> shape_inputs;
auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) {
// Add data placeholder
Shape shape = GetShape(ttype->shape);
tvm::Tensor data_tensor = tvm::placeholder(shape, ttype->dtype);
data_inputs.push_back(data_tensor);
// Add shape placeholder
int64_t ndim = shape.size();
Shape sshape;
if (ndim > 0) {
sshape.push_back(tvm::Integer(ndim));
}
tvm::Tensor shape_tensor = tvm::placeholder(sshape, Int(64));
shape_inputs.push_back(shape_tensor);
};
if (const auto *ttype = param->checked_type().as<TensorTypeNode>()) {
add_placeholder(ttype);
} else {
// flatten tuple of tensor type.
const auto *tuple_type = param->type_as<TupleTypeNode>();
// TODO(@icemelon): Support recursive tuple
CHECK(tuple_type);
for (Type field : tuple_type->fields) {
const auto *ttype = field.as<TensorTypeNode>();
CHECK(ttype);
add_placeholder(ttype);
}
}
param_data_[param] = data_inputs;
param_shapes_[param] = shape_inputs;
}
readable_name_stream_ << "shape_func";
auto cache_node = make_node<CachedFuncNode>();
cache_node->outputs = VisitExpr(prim_func->body);
auto candidate_name = readable_name_stream_.str();
constexpr static size_t kMaxFuncNameLength = 80;
if (candidate_name.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}
cache_node->func_name = candidate_name;
// set inputs
for (auto param : prim_func->params) {
int state = param_states_[param];
cache_node->shape_func_param_states.push_back(IntImm::make(Int(32), state));
if (state & kNeedInputData) {
for (auto t : param_data_[param]) {
cache_node->inputs.push_back(t);
}
}
if (state & kNeedInputShape) {
for (auto t : param_shapes_[param]) {
cache_node->inputs.push_back(t);
}
}
}
CachedFunc cfunc(cache_node);
// generate schedule for shape func
Array<Operation> out_ops;
for (auto t : cache_node->outputs) {
out_ops.push_back(t->op);
}
auto schedule = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(schedule);
for (const auto& scalar : scalars_) {
auto scalar_op = scalar->op;
if (schedule->Contain(scalar_op)) {
schedule[scalar_op].compute_inline();
}
}
return std::make_pair(schedule, cfunc);
}
Array<Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<Tensor> res = ExprFunctor::VisitExpr(expr);
if (expr.as<VarNode>() == nullptr) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
memo_[expr] = res;
}
return res;
}
}
Array<Tensor> VisitExpr_(const VarNode* var_node) final {
auto var = GetRef<Var>(var_node);
auto it = param_states_.find(var);
if (it == param_states_.end()) {
LOG(FATAL) << "Free variable " << var->name_hint();
return {};
} else {
CHECK(data_dependants_.size());
bool data_dependant = data_dependants_.back();
if (data_dependant) {
param_states_[var] |= kNeedInputData;
return param_data_[var];
} else {
param_states_[var] |= kNeedInputShape;
return param_shapes_[var];
}
}
}
Array<Tensor> VisitExpr_(const ConstantNode* op) final {
CHECK(data_dependants_.size());
CHECK(op->is_scalar());
bool data_dependant = data_dependants_.back();
if (data_dependant) {
void* data = op->data->data;
DataType dtype = TVMType2Type(op->data->dtype);
Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
if (dtype == Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::Expr();
}
}, "data_const", topi::kBroadcast);
scalars_.push_back(value);
return {value};
} else {
Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
return make_const(Int(64), 0);
}, "shape_const", topi::kBroadcast);
scalars_.push_back(value);
return {value};
}
}
Array<Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>(
"TShapeDataDependant");
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
CHECK(data_dependants_.empty() || !data_dependants_.back())
<< "Error in op fusion: output of the shape func is fed to a "
<< "data-dependant shape func";
CHECK_GT(fshape_func.count(op), 0)
<< "Internal error, cannot find ShapeFunc for " << op->name;
CHECK_GT(tshape_data_dependant.count(op), 0)
<< "Internal error, cannot find TShapeDataDependant for " << op->name;
data_dependants_.push_back(tshape_data_dependant[op]);
// Visit all inputs
Array<Tensor> inputs;
int count_tuple = 0;
for (Expr arg : call_node->args) {
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
for (Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
}
if (count_tuple) {
CHECK_EQ(call_node->args.size(), 1U)
<< "Only allow function with a single tuple input";
}
// Get output ndims
auto ret_type = call_node->checked_type();
Array<IndexExpr> out_ndims;
if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
out_ndims.push_back(IntImm::make(Int(32), ttype->shape.size()));
} else {
auto rtype = ret_type.as<TupleTypeNode>();
// TODO(@icemelon): Allow recursive tuple
CHECK(rtype);
for (size_t i = 0; i < rtype->fields.size(); ++i) {
auto ttype = rtype->fields[i].as<TensorTypeNode>();
CHECK(ttype);
out_ndims.push_back(IntImm::make(Int(32), ttype->shape.size()));
}
}
// Call shape function
auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
data_dependants_.pop_back();
readable_name_stream_ << "_" << op->name;
return outputs;
}
Array<Tensor> VisitExpr_(const FunctionNode* op) final {
LOG(FATAL) << "Do not support sub function";
return Array<Tensor>();
}
Array<Tensor> VisitExpr_(const LetNode* op) final {
Array<Tensor> val = VisitExpr(op->value);
CHECK(!memo_.count(op->var));
memo_[op->var] = val;
return VisitExpr(op->body);
}
Array<Tensor> VisitExpr_(const TupleNode* op) final {
Array<Tensor> fields;
for (Expr field : op->fields) {
CHECK(field->checked_type().as<TensorTypeNode>())
<< "Only allow Tuple of Tensor";
Array<Tensor> res = VisitExpr(field);
CHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
}
return fields;
}
private:
/*! \brief String stream for function name */
std::ostringstream readable_name_stream_;
/*! \brief Map from parameter to its shape function usage state */
std::unordered_map<Expr, int, NodeHash, NodeEqual> param_states_;
/*! \brief Map from parameter to list of data placeholder */
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> param_shapes_;
/*! \brief Memoized visit result */
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
/*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_;
/*! \brief Scalars used in the shape function */
Array<Tensor> scalars_;
};
class CompileEngineImpl : public CompileEngineNode {
public:
......@@ -304,6 +576,11 @@ class CompileEngineImpl : public CompileEngineNode {
}
return value->packed_func;
}
CachedFunc LowerShapeFunc(const CCacheKey& key) final {
return LowerShapeFuncInternal(key)->cached_func;
}
void Clear() final {
cache_.clear();
}
......@@ -379,6 +656,40 @@ class CompileEngineImpl : public CompileEngineNode {
value->cached_func = CachedFunc(cache_node);
return value;
}
// implement lowered shape func
CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = shape_func_cache_.find(key);
if (it != shape_func_cache_.end()) {
it->second->use_count += 1;
if (it->second->cached_func.defined()) return it->second;
value = it->second;
} else {
value = CCacheValue(make_node<CCacheValueNode>());
value->use_count = 0;
shape_func_cache_[key] = value;
}
// Enforce use the target.
With<Target> target_scope(key->target);
CHECK(!value->cached_func.defined());
auto spair = MakeShapeFunc().Create(key->source_func);
auto cache_node = make_node<CachedFuncNode>(
*(spair.second.operator->()));
cache_node->func_name = GetUniqueName(cache_node->func_name);
cache_node->target = key->target;
Array<Tensor> all_args = cache_node->inputs;
for (Tensor arg : cache_node->outputs) {
all_args.push_back(arg);
}
tvm::BuildConfig bcfg = BuildConfig::Create();
std::unordered_map<Tensor, Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
value->cached_func = CachedFunc(cache_node);
return value;
}
/*!
* \brief Get unique name from name.
* \param name The original name.
......@@ -408,6 +719,8 @@ class CompileEngineImpl : public CompileEngineNode {
std::unordered_map<std::string, int> name_map_;
/*! \brief internal compiler cache */
std::unordered_map<CCacheKey, CCacheValue> cache_;
/*! \brief internal compiler cache for shape funcs */
std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
};
/*! \brief The global compile engine */
......
......@@ -36,6 +36,14 @@
namespace tvm {
namespace relay {
/*! \brief Indicate whether the data, the shape, or both of a parameter are used in the shape func. */
enum ShapeFuncParamState {
kNoNeed = 0,
kNeedInputData = 1,
kNeedInputShape = 2,
kNeedBoth = 3,
};
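These values are bit flags: kNeedBoth == kNeedInputData | kNeedInputShape, and both the interpreter and the VM compiler test the per-parameter entries of shape_func_param_states with bitwise AND. A tiny decoding sketch, written in Python only for brevity:
K_NO_NEED, K_NEED_INPUT_DATA, K_NEED_INPUT_SHAPE = 0, 1, 2

def decode_param_state(state):
    # mirrors the `state & kNeedInputData` / `state & kNeedInputShape` checks
    return bool(state & K_NEED_INPUT_DATA), bool(state & K_NEED_INPUT_SHAPE)

assert decode_param_state(3) == (True, True)   # kNeedBoth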
/*! \brief Node container to represent a cached function. */
struct CachedFuncNode : public Node {
/* \brief compiled target */
......@@ -48,6 +56,8 @@ struct CachedFuncNode : public Node {
tvm::Array<Tensor> outputs;
/*! \brief The lowered functions to support the function. */
tvm::Array<tvm::LoweredFunc> funcs;
/*! \brief Parameter usage states in the shape function. */
tvm::Array<Integer> shape_func_param_states;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("target", &target);
......@@ -55,6 +65,7 @@ struct CachedFuncNode : public Node {
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("funcs", &funcs);
v->Visit("shape_func_param_states", &shape_func_param_states);
}
static constexpr const char* _type_key = "relay.CachedFunc";
......@@ -170,6 +181,12 @@ class CompileEngineNode : public Node {
* \return The result.
*/
virtual PackedFunc JIT(const CCacheKey& key) = 0;
/*!
* \brief Lower the shape function.
* \param key The key to the cached function.
* \return The result.
*/
virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
/*! \brief clear the cache. */
virtual void Clear() = 0;
......@@ -180,7 +197,7 @@ class CompileEngineNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
};
/*! \brier cache entry used in compile engine */
/*! \brief cache entry used in compile engine */
class CompileEngine : public NodeRef {
public:
CompileEngine() {}
......@@ -193,6 +210,13 @@ class CompileEngine : public NodeRef {
TVM_DLL static const CompileEngine& Global();
};
/*!
* \brief Check if the type is dynamic.
* \param ty The type to be checked.
* \return The result.
*/
bool IsDynamic(const Type& ty);
// implementations
inline size_t CCacheKeyNode::Hash() const {
if (hash_ != 0) return hash_;
......
......@@ -280,7 +280,7 @@ class Interpreter :
return TupleValueNode::make(values);
}
// TODO(@jroesch): this doesn't support mututal letrec
// TODO(@jroesch): this doesn't support mutual letrec
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
......@@ -310,7 +310,125 @@ class Interpreter :
return MakeClosure(func);
}
Value InvokePrimitiveOp(Function func,
Array<Shape> ComputeDynamicShape(const Function& func,
const Array<Value>& args) {
auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
auto cfunc = engine_->LowerShapeFunc(key);
size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
std::vector<TVMValue> values(arity);
std::vector<int> codes(arity);
TVMArgsSetter setter(values.data(), codes.data());
std::vector<NDArray> inputs(cfunc->inputs.size());
std::vector<NDArray> outputs(cfunc->outputs.size());
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
auto fset_input = [&](size_t i, Value val, bool need_shape) {
const TensorValueNode* tv = val.as<TensorValueNode>();
CHECK(tv != nullptr) << "expect Tensor argument";
if (need_shape) {
int64_t ndim = tv->data.Shape().size();
NDArray shape_arr;
if (ndim == 0) {
shape_arr = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
} else {
shape_arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
for (auto j = 0; j < ndim; ++j) {
data[j] = tv->data.Shape()[j];
}
}
inputs[i] = shape_arr;
setter(i, shape_arr);
} else {
auto arr = tv->data.CopyTo(cpu_ctx);
inputs[i] = arr;
setter(i, arr);
}
};
size_t arg_counter = 0;
for (size_t i = 0; i < args.size(); ++i) {
auto arg = args[i];
auto param = func->params[i];
int state = cfunc->shape_func_param_states[i]->value;
if (arg.as<TensorValueNode>()) {
if (state & kNeedInputData) {
fset_input(arg_counter++, arg, false);
}
if (state & kNeedInputShape) {
fset_input(arg_counter++, arg, true);
}
} else {
const TupleValueNode* tuple = arg.as<TupleValueNode>();
CHECK(tuple != nullptr);
if (state & kNeedInputData) {
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i], false);
}
}
if (state & kNeedInputShape) {
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i], true);
}
}
}
}
CHECK_EQ(arg_counter, cfunc->inputs.size())
<< "Shape function input sizes mismatch";
auto fset_shape_output = [&](size_t i, Type val_type) {
// TODO(@icemelon): allow recursive tuple
const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
CHECK(rtype != nullptr);
int64_t ndim = rtype->shape.size();
auto arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
outputs[i] = arr;
setter(arg_counter + i, arr);
};
auto ret_type = func->body->checked_type();
size_t out_cnt = 0;
if (auto rtype = ret_type.as<TupleTypeNode>()) {
out_cnt = rtype->fields.size();
for (size_t i = 0; i < out_cnt; ++i) {
fset_shape_output(i, rtype->fields[i]);
}
} else {
out_cnt = 1;
auto tt = Downcast<TensorType>(ret_type);
fset_shape_output(0, tt);
}
CHECK_EQ(cfunc->outputs.size(), out_cnt)
<< "Shape function output sizes mismatch";
PackedFunc shape_func;
TVMRetValue rv;
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target);
shape_func = m.GetFunction(cfunc->func_name);
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
// Get output shapes
Array<Shape> out_shapes;
for (auto out_tensor : outputs) {
int64_t* shape_data = reinterpret_cast<int64_t*>(out_tensor->data);
Shape out_shape;
for (int i = 0; i < out_tensor->shape[0]; ++i) {
out_shape.push_back(tvm::Integer(shape_data[i]));
}
out_shapes.push_back(out_shape);
}
return out_shapes;
}
Value InvokePrimitiveOp(const Function& func,
const Array<Value>& args) {
auto call_node = func->body.as<CallNode>();
......@@ -394,17 +512,46 @@ class Interpreter :
return out_tensor;
};
Array<Shape> out_shapes;
auto ret_type = func->body->checked_type();
bool is_dyn = IsDynamic(func->checked_type());
if (call_node->op == Op::Get("shape_of")) {
// The output shape of shape_of must be static since Relay doesn't support
// dynamic rank tensors.
is_dyn = false;
}
if (is_dyn) {
CHECK(func->IsPrimitive());
out_shapes = ComputeDynamicShape(func, args);
}
PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_));
TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
Array<Value> fields;
for (size_t i = 0; i < rtype->fields.size(); ++i) {
fields.push_back(fset_output(i, rtype->fields[i]));
if (is_dyn) {
auto sh = out_shapes[i];
auto tt = Downcast<TensorType>(rtype->fields[i]);
fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
} else {
fields.push_back(fset_output(i, rtype->fields[i]));
}
}
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return TupleValueNode::make(fields);
} else {
Value out_tensor = fset_output(0, func->body->checked_type());
Value out_tensor;
if (is_dyn) {
CHECK_EQ(out_shapes.size(), 1);
auto sh = out_shapes[0];
auto tt = Downcast<TensorType>(ret_type);
out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
} else {
out_tensor = fset_output(0, ret_type);
}
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return out_tensor;
}
......
......@@ -23,12 +23,15 @@
* \brief A compiler from relay::Module to the VM byte code.
*/
#include <tvm/operation.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <topi/tags.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
......@@ -61,42 +64,6 @@ using namespace relay::transform;
// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
// Compute the constant pool, i.e a mapping from Constant node to constant index.
struct ConstantPool : ExprVisitor {
std::set<GlobalVar> visited;
Module module;
ConstMap const_map;
ConstTensorShapeMap const_tensor_shape_map;
size_t index;
explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {}
void VisitExpr_(const GlobalVarNode* var_node) {
auto gvar = GetRef<GlobalVar>(var_node);
if (visited.find(gvar) == visited.end()) {
visited.insert(gvar);
this->VisitExpr(this->module->Lookup(gvar));
}
}
void VisitExpr_(const ConstantNode* const_node) {
auto konst = GetRef<Constant>(const_node);
auto it = this->const_map.find(konst);
if (it == this->const_map.end()) {
this->const_map.insert({konst, index++});
}
}
};
std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) {
auto cp = ConstantPool(module);
for (auto& func : module->functions) {
cp.VisitExpr(func.first);
}
return std::make_tuple(cp.const_map, cp.const_tensor_shape_map);
}
void InstructionPrint(std::ostream& os, const Instruction& instr);
// Represent a runtime object that's going to be matched by pattern match expressions
......@@ -220,12 +187,13 @@ TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause>
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets)
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
: last_register_(0),
registers_num_(0),
engine_(CompileEngine::Global()),
context_(context),
targets_(targets) {}
targets_(targets),
target_host_(target_host) {}
VMFunction Compile(const GlobalVar& var, const Function& func) {
size_t i = 0;
......@@ -288,10 +256,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
void VisitExpr_(const ConstantNode* const_node) {
auto rconst = GetRef<Constant>(const_node);
auto it = this->context_->const_map.find(rconst);
CHECK(it != this->context_->const_map.end());
Emit(Instruction::LoadConst(it->second, NewRegister()));
size_t konst_idx = context_->constants.size();
context_->constants.push_back(const_node->data);
Emit(Instruction::LoadConst(konst_idx, NewRegister()));
}
void VisitExpr_(const VarNode* var_node) {
......@@ -326,7 +293,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
void VisitExpr_(const LetNode* let_node) {
DLOG(INFO) << let_node->value;
DLOG(INFO) << AsText(let_node->value);
this->VisitExpr(let_node->value);
var_register_map_.insert({let_node->var, this->last_register_});
this->VisitExpr(let_node->body);
......@@ -393,29 +360,206 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
this->last_register_ = true_register;
}
Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
TVMType dltype = Type2TVMType(ttype->dtype);
auto tensor_type = GetRef<TensorType>(ttype);
Index EmitGetShape(const TensorTypeNode* ttype, Index reg) {
bool const_shape = true;
std::vector<int64_t> shape;
for (auto dim : tensor_type->shape) {
shape.push_back(Downcast<tvm::Integer>(dim)->value);
for (auto dim : ttype->shape) {
if (auto kdim = dim.as<IntImm>()) {
shape.push_back(kdim->value);
} else {
const_shape = false;
}
}
if (const_shape) {
int64_t ndim = shape.size();
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
NDArray shape_tensor;
if (ndim == 0) {
shape_tensor = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
} else {
shape_tensor = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
int64_t* dims = reinterpret_cast<int64_t*>(shape_tensor->data);
for (size_t i = 0; i < shape.size(); ++i) {
dims[i] = shape[i];
}
}
size_t konst_idx = context_->constants.size();
context_->constants.push_back(shape_tensor);
Emit(Instruction::LoadConst(konst_idx, NewRegister()));
return last_register_;
}
// For dynamic shape, we need to insert the shape_of op to get its shape at runtime
auto attrs = make_node<ShapeOfAttrs>();
attrs->dtype = Int(64);
static const Op& op = Op::Get("shape_of");
auto input = VarNode::make("input", GetRef<Type>(ttype));
auto expr = CallNode::make(op, {input}, Attrs(attrs), {});
auto func = FunctionNode::make({input}, expr, IncompleteTypeNode::make(Kind::kType), {});
auto mod = ModuleNode::make({}, {});
auto main_gv = GlobalVarNode::make("main");
mod->Add(main_gv, func);
func = mod->Lookup(main_gv);
// shape_of op has to be run on the host target
// TODO(@icemelon9): handle heterogeneous target, such as cuda
auto key = CCacheKeyNode::make(func, target_host_);
auto cfunc = engine_->Lower(key);
auto op_index = -1;
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
}
std::vector<Index> arg_regs{reg};
int64_t ndim = ttype->shape.size();
if (ndim == 0) {
Emit(Instruction::AllocTensor({}, Int(64), NewRegister()));
} else {
Emit(Instruction::AllocTensor({ndim}, Int(64), NewRegister()));
}
Index shape_reg = last_register_;
arg_regs.push_back(shape_reg);
Emit(Instruction::InvokePacked(op_index, 2, 1, arg_regs));
return shape_reg;
}
std::vector<Index> EmitShapeFunc(const Type& ret_type, const Function& func,
const std::vector<Index>& unpacked_arg_regs) {
// Find the mapping from params to registers
int idx = 0;
std::vector<std::vector<Index>> param_regs;
std::vector<std::vector<const TensorTypeNode*>> param_types;
for (auto param : func->params) {
auto ty = param->checked_type();
std::vector<Index> regs;
std::vector<const TensorTypeNode*> types;
if (auto ttype = ty.as<TensorTypeNode>()) {
regs.push_back(unpacked_arg_regs[idx++]);
types.push_back(ttype);
} else if (const auto tuple_ty = ty.as<TupleTypeNode>()) {
for (size_t j = 0; j < tuple_ty->fields.size(); ++j, ++idx) {
regs.push_back(unpacked_arg_regs[idx]);
auto ttype = tuple_ty->fields[j].as<TensorTypeNode>();
CHECK(ttype);
types.push_back(ttype);
}
} else {
LOG(FATAL) << "unsupported parameter type " << ty;
}
param_regs.push_back(regs);
param_types.push_back(types);
}
// Lower shape function
auto key = CCacheKeyNode::make(func, target_host_);
auto cfunc = engine_->LowerShapeFunc(key);
int op_index = -1;
if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
}
// Prepare input and output registers
std::vector<Index> shape_func_args;
std::vector<Index> shape_regs;
for (size_t i = 0; i < func->params.size(); ++i) {
int state = cfunc->shape_func_param_states[i]->value;
if (state & kNeedInputData) {
for (auto reg : param_regs[i]) {
// TODO(@icemelon9): Need to copy data here for heterogeneous exec
shape_func_args.push_back(reg);
}
}
if (state & kNeedInputShape) {
for (size_t j = 0; j < param_regs[i].size(); ++j) {
shape_func_args.push_back(EmitGetShape(param_types[i][j], param_regs[i][j]));
}
}
}
for (auto t : cfunc->outputs) {
int64_t ndim = t->shape[0].as<IntImm>()->value;
Emit(Instruction::AllocTensor({ndim}, t->dtype, NewRegister()));
shape_func_args.push_back(last_register_);
shape_regs.push_back(last_register_);
}
int arity = shape_func_args.size();
int ret_count = shape_regs.size();
Emit(Instruction::InvokePacked(op_index, arity, ret_count, shape_func_args));
// Alloc return tensors given the shape regs
std::vector<DataType> ret_dtypes;
if (const auto* tuple_type = ret_type.as<TupleTypeNode>()) {
for (auto field : tuple_type->fields) {
const TensorTypeNode* tty = field.as<TensorTypeNode>();
CHECK(tty);
ret_dtypes.push_back(tty->dtype);
}
} else {
auto tty = ret_type.as<TensorTypeNode>();
CHECK(tty);
ret_dtypes.push_back(tty->dtype);
}
std::vector<Index> ret_regs;
for (size_t i = 0; i < shape_regs.size(); ++i) {
Emit(Instruction::AllocTensorReg(shape_regs[i], ret_dtypes[i], NewRegister()));
ret_regs.push_back(last_register_);
}
return ret_regs;
}
std::vector<Index> AllocReturnType(const Type& ret_type, const Function& func,
const std::vector<Index>& unpacked_arg_regs) {
auto op = func->body.as<CallNode>()->op;
// 1. If either the func param types or the ret type is dynamic, we need to insert
// a shape function to compute the output shapes at runtime.
// 2. We skip the shape_of function since currently Relay doesn't support
// dynamic rank tensors.
if (op != Op::Get("shape_of") && IsDynamic(func->checked_type())) {
return EmitShapeFunc(ret_type, func, unpacked_arg_regs);
}
return Instruction::AllocTensor(shape, dltype, NewRegister());
std::vector<Index> ret_regs;
auto alloc_tensor = [&](const TensorTypeNode* ttype) {
const TensorType& tensor_type = GetRef<TensorType>(ttype);
std::vector<int64_t> shape;
for (auto dim : tensor_type->shape) {
shape.push_back(Downcast<tvm::Integer>(dim)->value);
}
Emit(Instruction::AllocTensor(shape, Type2TVMType(tensor_type->dtype), NewRegister()));
ret_regs.push_back(last_register_);
};
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
alloc_tensor(ttype);
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
for (auto field : ttype->fields) {
alloc_tensor(field.as<TensorTypeNode>());
}
} else {
LOG(FATAL) << "Unsupported return value type";
}
return ret_regs;
}
void EmitInvokePrimitive(const Function& func,
const std::vector<Index>& args_registers,
const std::vector<Index>& arg_registers,
const Type& ret_type) {
std::vector<Index> unpacked_arg_regs;
std::vector<Instruction> allocs;
// Arity calculation must flatten tuples.
size_t arity = 0;
CHECK_EQ(func->params.size(), args_registers.size());
CHECK_EQ(func->params.size(), arg_registers.size());
for (size_t i = 0; i < func->params.size(); i++) {
auto ty = func->params[i]->checked_type();
if (ty.as<TensorTypeNode>()) {
unpacked_arg_regs.push_back(args_registers[i]);
unpacked_arg_regs.push_back(arg_registers[i]);
arity += 1;
} else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
......@@ -424,7 +568,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
<< "only supports non-nested tuples currently "
<< "found " << field;
auto dst = NewRegister();
Emit(Instruction::GetField(args_registers[i], f, dst));
Emit(Instruction::GetField(arg_registers[i], f, dst));
unpacked_arg_regs.push_back(dst);
}
arity += tuple_ty->fields.size();
......@@ -433,30 +577,11 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
}
size_t return_val_count = 0;
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
// Allocate space for the return tensor.
auto alloc = AllocTensorFromType(ttype);
allocs.push_back(alloc);
return_val_count = 1;
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
std::vector<Index> fields_registers;
for (size_t i = 0; i < ttype->fields.size(); ++i) {
auto f = ttype->fields[i];
auto f_type = f.as<TensorTypeNode>();
allocs.push_back(AllocTensorFromType(f_type));
fields_registers.push_back(allocs.back().dst);
}
return_val_count = ttype->fields.size();
} else {
LOG(FATAL) << "Unsupported return value type";
}
arity += return_val_count;
for (auto& alloc : allocs) {
Emit(alloc);
unpacked_arg_regs.push_back(alloc.dst);
auto ret_regs = AllocReturnType(ret_type, func, unpacked_arg_regs);
size_t return_count = ret_regs.size();
arity += return_count;
for (auto reg : ret_regs) {
unpacked_arg_regs.push_back(reg);
}
// Next generate the invoke instruction.
......@@ -477,22 +602,22 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
op_index = context_->lowered_funcs.size();
context_->lowered_funcs.push_back(cfunc->funcs[0]);
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
}
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
Emit(Instruction::InvokePacked(op_index, arity, return_count, unpacked_arg_regs));
if (return_val_count > 1) {
if (return_count > 1) {
// return value is a tuple, we need to create a tuple
std::vector<Index> fields_registers;
for (size_t i = arity - return_val_count; i < arity; ++i) {
for (size_t i = arity - return_count; i < arity; ++i) {
fields_registers.push_back(unpacked_arg_regs[i]);
}
Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister()));
}
}
......@@ -636,6 +761,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
VMCompilerContext* context_;
/*! \brief Target devices. */
TargetsMap targets_;
/*! \brief Host target. */
Target target_host_;
};
......@@ -676,28 +803,18 @@ void VMCompiler::Compile(const Module& mod_ref,
// in the VMFunction table.
PopulateGlobalMap();
// Next we populate constant map.
auto constant_analysis_result = LayoutConstantPool(context_.module);
context_.const_map = std::get<0>(constant_analysis_result);
context_.const_tensor_shape_map = std::get<1>(constant_analysis_result);
// Next we get ready by allocating space for
// the global state.
vm_->functions.resize(context_.module->functions.size());
vm_->constants.resize(context_.const_map.size() + context_.const_tensor_shape_map.size());
for (auto pair : context_.const_map) {
vm_->constants[pair.second] = Object::Tensor(pair.first->data);
}
for (auto pair : context_.const_tensor_shape_map) {
vm_->constants[pair.second.first] = Object::Tensor(pair.second.second);
}
// Next we get ready by allocating space for
// the global state.
vm_->functions.resize(context_.module->functions.size());
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
VMFunctionCompiler func_compiler(&context_, targets_);
VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar);
......@@ -711,6 +828,11 @@ void VMCompiler::Compile(const Module& mod_ref,
}
#endif // USE_RELAY_DEBUG
// populate constants
for (auto data : context_.constants) {
vm_->constants.push_back(Object::Tensor(data));
}
LibraryCodegen();
for (auto gv : context_.global_map) {
......@@ -721,11 +843,13 @@ void VMCompiler::Compile(const Module& mod_ref,
Module VMCompiler::OptimizeModule(const Module& mod) {
// TODO(@icemelon9): check number of targets and build config, add more optimization passes
transform::Sequential seq({transform::SimplifyInference(),
transform::ToANormalForm(),
transform::InlinePrimitives(),
// TODO(@wweic): FuseOps pass currently doesn't handle Let
// For now, we put FuseOps before ToANormalForm to enable it
transform::FuseOps(),
transform::ToANormalForm(),
transform::LambdaLift(),
transform::InlinePrimitives(),
transform::FuseOps()});
transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
return seq(mod);
......@@ -741,27 +865,36 @@ void VMCompiler::PopulateGlobalMap() {
}
void VMCompiler::LibraryCodegen() {
auto const& lowered_funcs = context_.lowered_funcs;
if (lowered_funcs.size() == 0) {
auto const &cached_funcs = context_.cached_funcs;
if (cached_funcs.size() == 0) {
return;
}
// TODO(@icemelon9): support heterogeneous targets
Target target;
for (auto kv : targets_) {
target = kv.second;
std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
for (auto &cfunc : cached_funcs) {
std::string target_str = cfunc->target->str();
if (tgt_funcs.count(target_str) == 0) {
tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
} else {
tgt_funcs[target_str].push_back(cfunc->funcs[0]);
}
}
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
runtime::Module mod =
(*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target,
target_host_);
Map<Target, Array<LoweredFunc>> funcs;
for (auto &it : tgt_funcs) {
funcs.Set(Target::Create(it.first), it.second);
}
if (const auto *f = runtime::Registry::Get("relay.backend.build")) {
// The target is just a dummy arg because funcs already contains the corresponding targets,
// so the target won't be used in the build function.
runtime::Module mod = (*f)(funcs, Target(), target_host_);
CHECK(mod.operator->());
vm_->lib = mod;
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
size_t primitive_index = 0;
for (auto lfunc : lowered_funcs) {
vm_->primitive_map.insert({lfunc->name, primitive_index++});
for (auto cfunc : cached_funcs) {
vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
}
......
......@@ -72,12 +72,10 @@ struct VMCompilerContext {
TagMap tag_map;
// Map from global var to a unique integer
GlobalMap global_map;
// Map from Const object to its index in const pool
ConstMap const_map;
// Map from Const tensor shape to its index in const pool
ConstTensorShapeMap const_tensor_shape_map;
// List of lowered functions
std::vector<LoweredFunc> lowered_funcs;
// List of constants
std::vector<NDArray> constants;
// List of cached functions
std::vector<CachedFunc> cached_funcs;
// The functions that have been lowered.
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
};
......
......@@ -121,14 +121,25 @@ TVM_REGISTER_API("relay.op._ListOpNames")
TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
TVM_REGISTER_API("relay.op._OpGetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
if (op_map.count(op)) {
*rv = op_map[op];
}
});
.set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
if (op_map.count(op)) {
*rv = op_map[op];
}
});
TVM_REGISTER_API("relay.op._OpSetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
runtime::TVMArgValue value = args[2];
int plevel = args[3];
auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
reg.set_attr(attr_name, value, plevel);
});
TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
......
......@@ -528,7 +528,11 @@ bool ReshapeRel(const Array<Type>& types,
used_input_dims.insert(src_idx);
IndexExpr d2 = data_shape[src_idx++];
used_output_dims.insert(oshape.size());
oshape.push_back(d1 * d2);
if (d1.as<Any>() || d2.as<Any>()) {
oshape.push_back(Any::make());
} else {
oshape.push_back(d1 * d2);
}
} else if (svalue == -4) {
// split the source dim s into two dims
// read the left dim and then the right dim (either can be -1)
......@@ -563,6 +567,8 @@ bool ReshapeRel(const Array<Type>& types,
oshape.push_back(d2);
}
}
} else {
CHECK(false) << "Unsupported special value: " << svalue;
}
}
......@@ -608,7 +614,15 @@ Array<Tensor> ReshapeCompute(const Attrs& attrs,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
return { topi::reshape(inputs[0], out_ttype->shape) };
Array<IndexExpr> newshape;
for (auto val : out_ttype->shape) {
if (val->is_type<ir::Any>()) {
newshape.push_back(val.as<ir::Any>()->ToVar());
} else {
newshape.push_back(val);
}
}
return { topi::reshape(inputs[0], newshape) };
}
Expr MakeReshape(Expr data,
......@@ -1108,7 +1122,8 @@ RELAY_REGISTER_OP("arange")
.set_support_level(3)
.add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective)
// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
// repeat operator
......
......@@ -295,7 +295,9 @@ RELAY_REGISTER_OP("shape_of")
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("ShapeOf", ShapeOfRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kInjective)
// Use kOpaque for the shape_of op for now since it won't be performance critical,
// and it makes things easier for the dynamic shape func
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_support_level(10)
......
......@@ -81,16 +81,17 @@ Type ConcreteBroadcast(const TensorType& t1,
for (; i <= std::min(ndim1, ndim2); ++i) {
IndexExpr s1 = t1->shape[ndim1 - i];
IndexExpr s2 = t2->shape[ndim2 - i];
if (EqualCheck(s1, s2)) {
oshape.push_back(s1);
} else if (EqualConstInt(s1, 1)) {
if (EqualConstInt(s1, 1)) {
oshape.push_back(s2);
} else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1);
} else if (s1.as<Any>() && EqualConstInt(s2, 1)) {
// TODO(@jroesch): we need to come back to this
} else if (s1.as<Any>()) {
// s1 == 1 || s1 == s2
oshape.push_back(s2);
} else if (s2.as<Any>() && EqualConstInt(s1, 1)) {
} else if (s2.as<Any>()) {
// s2 == 1 || s2 == s1
oshape.push_back(s1);
} else if (EqualCheck(s1, s2)) {
oshape.push_back(s1);
} else {
RELAY_ERROR(
......
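To make the reordered checks concrete, here is a pure-Python mirror of the dimension rule above (ANY is just a stand-in for relay's Any dimension; this is an illustration, not TVM API):

    ANY = "?"  # stand-in for an unknown (Any) dimension

    def broadcast_dim(s1, s2):
        # Same order as ConcreteBroadcast after this change: constant-1 cases first,
        # then Any (which may turn out to be 1 or equal to the other side at runtime),
        # and only then the static equality check.
        if s1 == 1:
            return s2
        if s2 == 1:
            return s1
        if s1 == ANY:
            return s2
        if s2 == ANY:
            return s1
        if s1 == s2:
            return s1
        raise ValueError("incompatible broadcast dims: %r vs %r" % (s1, s2))

    assert broadcast_dim(ANY, 3) == 3
    assert broadcast_dim(4, 1) == 4
    assert broadcast_dim(ANY, ANY) == ANY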
......@@ -915,7 +915,7 @@ class FuseMutator : private ExprMutator {
if (it == gmap_.end()) return "";
std::ostringstream os;
auto *group = it->second->FindRoot();
os << "group=" << group;
os << " /* group=" << group << " */";
return os.str();
});
LOG(INFO) << "Dump of group info:\n" << text;
......
......@@ -120,7 +120,7 @@ class ModulePassNode : public PassNode {
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }
PassInfo Info() const override { return pass_info; }
TVM_DLL static ModulePass make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
......@@ -174,7 +174,7 @@ class FunctionPassNode : public PassNode {
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }
PassInfo Info() const override { return pass_info; }
TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
......@@ -220,7 +220,7 @@ class SequentialNode : public PassNode {
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }
PassInfo Info() const override { return pass_info; }
/*!
* \brief Check if a pass is enabled.
......
......@@ -451,11 +451,11 @@ std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") {
void InstructionPrint(std::ostream& os, const Instruction& instr) {
switch (instr.op) {
case Opcode::Move: {
os << "move $" << instr.dst << " $" << instr.from << std::endl;
os << "move $" << instr.dst << " $" << instr.from;
break;
}
case Opcode::Ret: {
os << "ret $" << instr.result << std::endl;
os << "ret $" << instr.result;
break;
}
case Opcode::Fatal: {
......@@ -469,7 +469,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
<< ", out: $"
<< StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size,
instr.output_size, ", $")
<< ")" << std::endl;
<< ")";
break;
}
case Opcode::AllocTensor: {
......@@ -478,71 +478,61 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
instr.alloc_tensor.ndim)
<< "] ";
DLDatatypePrint(os, instr.alloc_tensor.dtype);
os << std::endl;
break;
}
case Opcode::AllocTensorReg: {
os << "alloc_tensor_reg $" << instr.dst << " $"
<< instr.alloc_tensor_reg.shape_register << " ";
DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
os << std::endl;
break;
}
case Opcode::AllocDatatype: {
os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
<< StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"
<< std::endl;
<< StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
break;
}
case Opcode::AllocClosure: {
os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index
<< "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$")
<< ")"
<< std::endl;
<< ")";
break;
}
case Opcode::If: {
os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " "
<< instr.if_op.true_offset << " " << instr.if_op.false_offset
<< std::endl;
<< instr.if_op.true_offset << " " << instr.if_op.false_offset;
break;
}
case Opcode::Invoke: {
os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
<< StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$")
<< ")"
<< std::endl;
<< ")";
break;
}
case Opcode::InvokeClosure: {
os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
<< StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$")
<< ")"
<< std::endl;
<< ")";
break;
}
case Opcode::LoadConst: {
os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"
<< std::endl;
os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
break;
}
case Opcode::LoadConsti: {
os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]"
<< std::endl;
os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]";
break;
}
case Opcode::GetField: {
os << "get_field $" << instr.dst << " $" << instr.object << "["
<< instr.field_index << "]"
<< std::endl;
<< instr.field_index << "]";
break;
}
case Opcode::GetTag: {
os << "get_tag $" << instr.dst << " $" << instr.get_tag.object << std::endl;
os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
break;
}
case Opcode::Goto: {
os << "goto " << instr.pc_offset << std::endl;
os << "goto " << instr.pc_offset;
break;
}
default:
......@@ -559,9 +549,7 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instr) {
void VMFunctionPrint(std::ostream& os, const VMFunction& vm_func) {
os << vm_func.name << ": " << std::endl;
for (size_t i = 0; i < vm_func.instructions.size(); ++i) {
os << i << ": ";
InstructionPrint(os, vm_func.instructions[i]);
os << ";" << std::endl;
os << i << ": " << vm_func.instructions[i] << ";" << std::endl;
}
}
......@@ -801,7 +789,7 @@ void VirtualMachine::RunLoop() {
while (true) {
main_loop:
auto const& instr = this->code[this->pc];
DLOG(INFO) << "Executing(" << pc << "): ";
DLOG(INFO) << "Executing(" << pc << "): " << instr;
#if USE_RELAY_DEBUG
InstructionPrint(std::cout, instr);
#endif // USE_RELAY_DEBUG
......
......@@ -546,6 +546,8 @@ void InjectInline(ScheduleNode* sch) {
std::vector<Array<Expr> > new_body(sch->stages.size());
std::vector<bool> changed(sch->stages.size(), false);
std::vector<Stmt> new_hybrid_body(sch->stages.size());
std::vector<bool> hybrid_changed(sch->stages.size(), false);
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
......@@ -568,6 +570,7 @@ void InjectInline(ScheduleNode* sch) {
for (size_t j = i; j < sch->stages.size(); ++j) {
Stage s = sch->stages[j];
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
if (compute) {
if (!new_body[j].size()) {
new_body[j] = compute->body;
......@@ -606,6 +609,15 @@ void InjectInline(ScheduleNode* sch) {
}
}
}
} else if (hybrid) {
if (!new_hybrid_body[j].defined()) {
new_hybrid_body[j] = hybrid->body;
}
Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
if (!new_stmt.same_as(new_hybrid_body[j])) {
new_hybrid_body[j] = new_stmt;
hybrid_changed[j] = true;
}
}
}
}
......@@ -632,6 +644,17 @@ void InjectInline(ScheduleNode* sch) {
}
s->op = op;
}
} else if (hybrid_changed[i]) {
const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
CHECK(hybrid);
Operation op = HybridOpNode::make(
hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
hybrid->outputs, new_hybrid_body[i]);
op = op->ReplaceInputs(op, repl);
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
}
s->op = op;
} else {
Operation op = s->op->ReplaceInputs(s->op, repl);
if (!op.same_as(s->op)) {
......
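With the hybrid branch above, a compute stage that feeds a hybrid op can now be inlined during scheduling. A rough sketch, assuming the hybrid script frontend of this TVM version accepts the form below (tensor names and sizes are arbitrary):

    import tvm
    from tvm import hybrid

    @hybrid.script
    def add_one(a):
        b = output_tensor((10,), 'float32')
        for i in range(10):
            b[i] = a[i] + 1.0
        return b

    x = tvm.placeholder((10,), name="x")
    y = tvm.compute((10,), lambda i: x[i] * 2.0, name="y")
    z = add_one(y)                 # z.op is a HybridOpNode consuming y
    s = tvm.create_schedule(z.op)
    s[y].compute_inline()          # now folded into the hybrid body by InjectInline
    print(tvm.lower(s, [x, z], simple_mode=True))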
......@@ -18,27 +18,156 @@ import numpy as np
import tvm
from tvm import relay
from tvm.relay import Kind, transform
from tvm.relay.loops import while_loop
from tvm.relay.testing import run_infer_type as infer_type
def int32(val):
return relay.const(val, 'int32')
def any_dims(ndim):
shape = []
for _ in range(ndim):
shape.append(relay.Any())
return tuple(shape)
# TODO(@wweic): Because the VM doesn't support heterogeneous execution, we can only
# test the shape function on CPU.
def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
dtype = 'float32'
x = relay.var('x', shape=x_shape, dtype=dtype)
y = relay.var('y', shape=y_shape, dtype=dtype)
mod = relay.module.Module()
mod["main"] = relay.Function([x, y], op(x, y))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
y_np = np.random.uniform(size=y_np_shape).astype(dtype)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))
def test_any_broadcast():
verify_any_broadcast((relay.Any(),), (3, 2), (1,), (3, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (1, 2), (1, 2), (1, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (1, 2), (3, 2), (1, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (3, 2), (1, 2), (3, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (3, relay.Any()), (1, 2), (3, 1), relay.add, np.add)
# The following currently fail because topi compute treats Any as 1;
# an auto_broadcast buffer will be required to solve the problem.
# TODO(@zhiics): Fix this
# verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
# verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
def test_any_concat():
x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
y = relay.var('y', shape=(1, 2), dtype="float32")
z = relay.op.concatenate([x, y], axis=0)
mod = relay.module.Module()
mod["main"] = relay.Function([x, y], z)
x_np = np.random.uniform(size=(3, 2)).astype('float32')
y_np = np.random.uniform(size=(1, 2)).astype('float32')
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
ref = np.concatenate([x_np, y_np], axis=0)
tvm.testing.assert_allclose(result.asnumpy(), ref)
def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
x = relay.var('x', shape=x_shape, dtype="float32")
y = relay.reshape(x, newshape=newshape)
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
data = np.random.uniform(size=x_np_shape).astype('float32')
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data).asnumpy()
assert result.shape == out_shape
tvm.testing.assert_allclose(result.flatten(), data.flatten())
def test_any_reshape():
verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24))
verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12))
verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4))
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
mod = relay.Module()
data = relay.var('data', shape=data_shape, dtype='float32')
indices = relay.var('indices', shape=indices_shape, dtype='int32')
y = relay.take(data, indices, axis=axis)
mod["main"] = relay.Function([data, indices], y)
data_np = np.random.uniform(size=data_np_shape).astype('float32')
if axis is None:
max_index = data_np.size
else:
max_index = data_np.shape[axis]
indices_np = np.random.randint(max_index, size=indices_np_shape).astype('int32')
ref = np.take(data_np, indices_np, axis=axis)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np, indices_np)
tvm.testing.assert_allclose(result.asnumpy(), ref)
def test_any_take():
verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,))
verify_any_take(any_dims(2), (), 0, (4, 5), ())
verify_any_take(any_dims(2), (), None, (4, 5), ())
verify_any_take(any_dims(3), any_dims(2), 1, (3, 4, 5), (2, 3))
verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))
def test_any_shape_of():
x = relay.var('x', shape=any_dims(2), dtype='float32')
y = relay.shape_of(x)
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
data = np.random.uniform(size=(3, 4)).astype('float32')
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
tvm.testing.assert_allclose(result.asnumpy(), np.array([3,4]).astype("int64"))
x = relay.var('x', shape=any_dims(3), dtype='float32')
y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(1, 'int32'))
mod = relay.module.Module()
mod["main"] = relay.Function([x], y1)
data = np.random.uniform(size=(2, 3, 4)).astype('float32')
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64"))
def test_fused_ops():
x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
y0 = x + relay.const(1.0, 'float32')
y1 = y0 * relay.const(2.0, 'float32')
mod = relay.module.Module()
mod["main"] = relay.Function([x], y1)
data = np.random.uniform(size=(5, 4)).astype('float32')
for kind in ["vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2)
def test_arange_with_dynamic_shape():
m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(0, 'int32'))
y2 = relay.op.arange(y1)
ex = relay.create_executor()
f = relay.Function([x], y2, type_params=[m, n, k])
# TODO(@jroesch): Restore after code generation.
# data = np.random.rand(10, 5, 3).astype('float32')
# result = ex.evaluate(f)(data)
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
def test_dynamic_concat():
y2 = relay.op.arange(y1, dtype="int32")
y3 = y2 + relay.const(1, dtype="int32")
data = np.random.rand(10, 5, 3).astype('float32')
mod = relay.module.Module()
mod["main"] = relay.Function([x], y3, type_params=[m, n, k])
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1)
def test_recursive_concat():
"""
fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
if (%i < 10) {
......@@ -66,26 +195,18 @@ def test_dynamic_concat():
start = relay.var('start', shape=(), dtype='int32')
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))
func = infer_type(func)
# TODO(@jroesch, @haichen): We should restore this code when code generation
# is merged
# ret_shape = func.checked_type.ret_type.shape
# assert len(ret_shape) == 2, "expected 2-dim output"
# assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any())
# import pdb; pdb.set_trace()
# mod = relay.module.Module()
# print(relay.ir_pass.infer_type(func, mod=mod))
# ret = relay.Call(loop, [relay.const(0, 'int32'), init])
# mod[mod.entry_func] = relay.Function([], ret)
# print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod))
# initial = np.array(0.0, dtype='float32').reshape((1,))
# iter_stop = np.array(10, dtype='int32')
# ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
# result = ex.evaluate(mod.entry_func)()
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
def test_dynamic_concat_with_wrong_annotation():
mod = relay.module.Module()
mod["main"] = func
data = np.array(0.0, dtype='int32')
# TODO(@jroesch): After the LambdaLift pass, the TypeInfer pass will fail,
# so currently we cannot run this test case on the VM
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
np.testing.assert_allclose(result.asnumpy(), ref)
def test_recursive_concat_with_wrong_annotation():
"""
v0.0.1
fn (%start: int32) {
......@@ -133,6 +254,12 @@ def test_dynamic_concat_with_wrong_annotation():
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
if __name__ == "__main__":
test_any_broadcast()
test_any_concat()
test_any_reshape()
test_any_take()
test_any_shape_of()
test_fused_ops()
test_arange_with_dynamic_shape()
test_dynamic_concat()
test_dynamic_concat_with_wrong_annotation()
test_recursive_concat()
test_recursive_concat_with_wrong_annotation()
......@@ -104,9 +104,6 @@ def test_serializer():
vm = create_vm(mod)
ser = serializer.Serializer(vm)
stats = ser.stats
assert "scalar" in stats
glbs = ser.globals
assert len(glbs) == 3
assert "f1" in glbs
......@@ -120,8 +117,8 @@ def test_serializer():
code = ser.bytecode
assert "main 5 2 5" in code
assert "f1 3 1 4" in code
assert "f2 3 1 4" in code
assert "f1 2 1 3" in code
assert "f2 2 1 3" in code
code, lib = ser.serialize()
assert isinstance(code, bytearray)
......
......@@ -122,11 +122,13 @@ def test_outer_product():
assert ibody.min.value == 0
assert ibody.extent.name == 'm'
#Check loop body
jbody = ibody.body
jblock = ibody.body
assert isinstance(jblock, tvm.stmt.Block)
jbody = jblock.first
assert isinstance(jbody, tvm.stmt.AssertStmt)
assert isinstance(jbody.message, tvm.expr.StringImm)
assert jbody.message.value == "index out of range!"
jbody = jbody.body
jbody = jblock.rest
assert isinstance(jbody, tvm.stmt.Provide)
assert jbody.func.name == 'c'
assert len(jbody.args) == 2
......
......@@ -52,6 +52,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
tvm::Expr one(1);
int i;
for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
// TODO(@icemelon9): Need to revisit this part
const Variable* var1 = shape1[s1_size - i].as<Variable>();
const Variable* var2 = shape2[s2_size - i].as<Variable>();
bh.all_vars.push_front(tvm::Var());
if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]);
......@@ -64,6 +67,16 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
} else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]);
bh.vars1.push_front(bh.all_vars[0]);
} else if (var1 && var2) {
bh.common_shape.push_front(max(shape1[s1_size - i], shape2[s2_size - i]));
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
} else if (var1) {
bh.common_shape.push_front(shape2[s2_size - i]);
bh.vars2.push_front(bh.all_vars[0]);
} else if (var2) {
bh.common_shape.push_front(shape1[s1_size - i]);
bh.vars1.push_front(bh.all_vars[0]);
} else {
CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
<< " and " << shape2[s2_size - i] << " in: "
......
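The compute-level rule above differs from the type-level one in the fully symbolic case: when both dimensions are variables, the common shape takes max(d1, d2) instead of picking one side. A tiny pure-Python mirror of just the new branches (symbolic dims are represented as strings here; illustration only, not TVM API):

    def common_dim(d1, d2):
        var1, var2 = isinstance(d1, str), isinstance(d2, str)
        if var1 and var2:
            return "max(%s, %s)" % (d1, d2)   # both unknown: be conservative
        if var1:
            return d2                          # trust the static side
        if var2:
            return d1
        raise ValueError("static dims are handled by the earlier branches")

    assert common_dim("m", "n") == "max(m, n)"
    assert common_dim("m", 4) == 4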
......@@ -1148,9 +1148,9 @@ inline Tensor tensordot(const Tensor& A,
return compute(output_shape, func, name, tag);
}
inline Tensor arange(const Expr start,
const Expr stop,
const Expr step,
inline Tensor arange(const Expr& start,
const Expr& stop,
const Expr& step,
Type dtype,
std::string name = "T_arange",
std::string tag = kInjective) {
......
......@@ -82,9 +82,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
rtol = 1e-5
if (kernel > 3):
rtol = 2e-5
rtol = 1e-3
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
......