Commit b90620ea by ziheng Committed by Tianqi Chen

[LANG] Generalize compute to tensor region (#1476)

parent 3d62cf7c
Subproject commit 4f0564ec769477c66d480dd966088f172050c874 Subproject commit 946a54012d0c390675ab5b46cd990838d4183d6f
...@@ -108,6 +108,8 @@ class Range : public HalideIR::IR::Range { ...@@ -108,6 +108,8 @@ class Range : public HalideIR::IR::Range {
TVM_DLL static Range make_by_min_extent(Expr min, Expr extent); TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
}; };
using Region = Array<Range>;
/*! /*!
* \brief Type of iteration variable. * \brief Type of iteration variable.
* Each IterVar have a specific type. * Each IterVar have a specific type.
......
...@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode { ...@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
} }
/*! /*!
* \return The list of iteration variable at root * \return The list of iteration variable at root
* \note root_iter_vars dedides the shape of the outputs. * \note root_iter_vars decides the shape of the outputs.
*/ */
virtual Array<IterVar> root_iter_vars() const = 0; virtual Array<IterVar> root_iter_vars() const = 0;
/*! /*!
...@@ -240,6 +240,74 @@ class TVM_DLL ComputeOpNode : public OperationNode { ...@@ -240,6 +240,74 @@ class TVM_DLL ComputeOpNode : public OperationNode {
}; };
/*! /*!
* \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
*/
class TensorComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
Array<IterVar> reduce_axis;
/*! \brief number of axes that can be scheduled */
int schedulable_ndim;
/*! \brief TensorIntrin used to compute */
TensorIntrin intrin;
/*! \brief input tensors of intrin */
Array<Tensor> inputs;
/*! \brief region of input tensors */
Array<Region> input_regions;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("schedulable_ndim", &schedulable_ndim);
v->Visit("intrin", &intrin);
v->Visit("inputs", &inputs);
v->Visit("input_regions", &input_regions);
}
static Operation make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<IterVar> reduce_axis,
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions);
static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
};
/*!
* \brief Symbolic scan. * \brief Symbolic scan.
*/ */
class ScanOpNode : public OperationNode { class ScanOpNode : public OperationNode {
...@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode { ...@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode {
public: public:
/*! \brief The input tensors */ /*! \brief The input tensors */
Array<Tensor> inputs; Array<Tensor> inputs;
/*! \brief Symbolic placeholder representationinputs */ /*! \brief Symbolic placeholder representation of inputs */
Array<Buffer> input_placeholders; Array<Buffer> input_placeholders;
/*! \brief Symbolic placeholder representation of outputs */ /*! \brief Symbolic placeholder representation of outputs */
Array<Buffer> output_placeholders; Array<Buffer> output_placeholders;
......
...@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node { ...@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node {
inline const TensorIntrinNode* TensorIntrin::operator->() const { inline const TensorIntrinNode* TensorIntrin::operator->() const {
return static_cast<const TensorIntrinNode*>(node_.get()); return static_cast<const TensorIntrinNode*>(node_.get());
} }
// Internal node container of tensor intrinsic calling.
class TensorIntrinCallNode;
/*! \brief Tensor intrinsic calling node. */
class TensorIntrinCall : public NodeRef {
public:
TensorIntrinCall() {}
explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorIntrinCallNode* operator->() const;
/*! \brief specify container node */
using ContainerType = TensorIntrinCallNode;
};
class TensorIntrinCallNode : public Node {
public:
/*! \brief the tensor intrinsic */
TensorIntrin intrin;
/*! \brief input tensors of the intrinsic */
Array<Tensor> tensors;
/*! \brief regions of input tensors */
Array<Region> regions;
/*!
* \brief IterVar on each reduction axis, if the
* intrin will use the reduce axis
*/
Array<IterVar> reduce_axis;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
v->Visit("reduce_axis", &reduce_axis);
}
static TensorIntrinCall make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis);
static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
};
inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
return static_cast<const TensorIntrinCallNode*>(node_.get());
}
} // namespace tvm } // namespace tvm
#endif // TVM_TENSOR_INTRIN_H_ #endif // TVM_TENSOR_INTRIN_H_
...@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): ...@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
raise ValueError("nested tag is not allowed for now") raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, _expr.Expr) else shape shape = (shape,) if isinstance(shape, _expr.Expr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape) ndim = len(shape)
code = fcompute.__code__ code = fcompute.__code__
if fcompute.__code__.co_argcount == 0: out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)] arg_names = ["i%d" % i for i in range(ndim)]
else: else:
arg_names = code.co_varnames[:code.co_argcount] arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount
if ndim != len(arg_names): if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim) raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)] dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var]) body = fcompute(*[v.var for v in dim_var])
if isinstance(body, _tensor.TensorIntrinCall):
for i, s in enumerate(shape[out_ndim:]):
var_name = "ax" + str(i)
dim_var.append(_IterVar((0, s), var_name, 4))
op_node = _api_internal._TensorComputeOp(name,
tag,
dim_var,
body.reduce_axis,
out_ndim,
body.intrin,
body.tensors,
body.regions)
else:
if not isinstance(body, (list, tuple)): if not isinstance(body, (list, tuple)):
body = [body] body = [body]
body = convert(body) body = convert(body)
op_node = _api_internal._ComputeOp( op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body) name, tag, attrs, dim_var, body)
num = op_node.num_outputs num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num)) outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs return outputs[0] if num == 1 else outputs
...@@ -529,14 +548,14 @@ def decl_buffer(shape, ...@@ -529,14 +548,14 @@ def decl_buffer(shape,
dtype = float32 if dtype is None else dtype dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None: if offset_factor != 0 and elem_offset is None:
elem_offset = var('%s_elem_offset' % name, shape[0].dtype) shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = var('%s_elem_offset' % name, shape_dtype)
if data is None: if data is None:
data = var(name, "handle") data = var(name, "handle")
return _api_internal._Buffer( return _api_internal._Buffer(
data, dtype, shape, strides, elem_offset, name, scope, data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor) data_alignment, offset_factor)
def _IterVar(dom, name, iter_type, thread_tag=''): def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar """Internal function to create IterVar
......
...@@ -30,6 +30,11 @@ class TensorSlice(NodeGeneric, _expr.ExprOp): ...@@ -30,6 +30,11 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
"""Data content of the tensor.""" """Data content of the tensor."""
return self.tensor.dtype return self.tensor.dtype
@register_node
class TensorIntrinCall(NodeBase):
"""Intermediate structure for calling a tensor intrinsic."""
pass
itervar_cls = None itervar_cls = None
...@@ -106,6 +111,7 @@ class Tensor(NodeBase, _expr.ExprOp): ...@@ -106,6 +111,7 @@ class Tensor(NodeBase, _expr.ExprOp):
return "%s.v%d" % (op.name, self.value_index) return "%s.v%d" % (op.name, self.value_index)
class Operation(NodeBase): class Operation(NodeBase):
"""Represent an operation that generate a tensor""" """Represent an operation that generate a tensor"""
...@@ -156,6 +162,12 @@ class ComputeOp(Operation): ...@@ -156,6 +162,12 @@ class ComputeOp(Operation):
@register_node @register_node
class TensorComputeOp(Operation):
"""Tensor operation."""
pass
@register_node
class ScanOp(Operation): class ScanOp(Operation):
"""Scan operation.""" """Scan operation."""
@property @property
......
...@@ -6,9 +6,25 @@ from . import expr as _expr ...@@ -6,9 +6,25 @@ from . import expr as _expr
from . import stmt as _stmt from . import stmt as _stmt
from . import make as _make from . import make as _make
from . import tensor as _tensor from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node from ._ffi.node import NodeBase, register_node
def _get_region(tslice):
region = []
for idx in tslice.indices:
if isinstance(idx, slice):
assert idx.step is None
region.append(_api.Range(idx.start, idx.stop))
else:
if isinstance(idx, _schedule.IterVar):
begin = idx.var
else:
begin = idx
region.append(_make.range_by_min_extent(begin, 1))
return region
@register_node @register_node
class TensorIntrin(NodeBase): class TensorIntrin(NodeBase):
"""Tensor intrinsic functions for certain computation. """Tensor intrinsic functions for certain computation.
...@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase): ...@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase):
-------- --------
decl_tensor_intrin: Construct a TensorIntrin decl_tensor_intrin: Construct a TensorIntrin
""" """
pass def __call__(self, *args, **kwargs):
tensors = [x.tensor for x in args]
regions = [_get_region(x) for x in args]
reduce_axis = []
if "reduce_axis" in kwargs:
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
def decl_tensor_intrin(op, def decl_tensor_intrin(op,
fcompute, fcompute,
......
...@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin") ...@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
args[6]); args[6]);
}); });
TVM_REGISTER_API("_TensorIntrinCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorIntrinCallNode::make(args[0],
args[1],
args[2],
args[3]);
});
TVM_REGISTER_API("_TensorEqual") TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor(); *ret = args[0].operator Tensor() == args[1].operator Tensor();
...@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp") ...@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp")
args[7]); args[7]);
}); });
TVM_REGISTER_API("_TensorComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorComputeOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
args[7]);
});
TVM_REGISTER_API("_ExternOp") TVM_REGISTER_API("_ExternOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ExternOpNode::make(args[0], *ret = ExternOpNode::make(args[0],
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
namespace tvm { namespace tvm {
// Tensor
Expr Tensor::operator()(Array<Var> indices) const { Expr Tensor::operator()(Array<Var> indices) const {
Array<Expr> arr(indices.begin(), indices.end()); Array<Expr> arr(indices.begin(), indices.end());
return operator()(arr); return operator()(arr);
...@@ -26,6 +28,15 @@ Expr Tensor::operator()(Array<Expr> indices) const { ...@@ -26,6 +28,15 @@ Expr Tensor::operator()(Array<Expr> indices) const {
return n; return n;
} }
Tensor Operation::output(size_t i) const {
auto node = make_node<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
Tensor TensorNode::make(Array<Expr> shape, Tensor TensorNode::make(Array<Expr> shape,
Type dtype, Type dtype,
Operation op, Operation op,
...@@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
Tensor Operation::output(size_t i) const {
auto node = make_node<TensorNode>(); // TensorIntrin
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
TensorIntrin TensorIntrinNode::make(std::string name, TensorIntrin TensorIntrinNode::make(std::string name,
Operation op, Operation op,
...@@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_REGISTER_NODE_TYPE(TensorIntrinNode); TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
// TensorIntrinCall
TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis) {
auto n = make_node<TensorIntrinCallNode>();
n->intrin = std::move(intrin);
n->tensors = std::move(tensors);
n->regions = std::move(regions);
n->reduce_axis = std::move(reduce_axis);
return TensorIntrinCall(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode *n, IRPrinter *p) {
p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
});
TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
} // namespace tvm } // namespace tvm
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "compute_op.h" #include "compute_op.h"
#include "op_util.h" #include "op_util.h"
#include "../schedule/message_passing.h" #include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
...@@ -545,4 +546,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) { ...@@ -545,4 +546,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) {
v.Run(); v.Run();
} }
Stmt TransformUpdate(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n,
Stmt body,
Stmt update) {
Array<Expr> conds;
std::unordered_set<const Variable*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto iit = stage->iter_var_attrs.find(iv);
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (attr->iter_type == kTensorized) {
break;
}
}
if (iv->iter_type == kCommReduce) {
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
conds.push_back(likely(iv->var > vrange->min));
banned.insert(iv->var.get());
}
}
for (const Expr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition "
<< pred << " has a conflict with the reset condition";
}
}
return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
update, body);
}
} // namespace tvm } // namespace tvm
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace tvm { namespace tvm {
// loop nest structure for general compute // loop nest structure for general compute
// This the the loop nest structured used in compute. // This the loop nest structured used in compute.
// Does not include the loop body. // Does not include the loop body.
struct ComputeLoopNest { struct ComputeLoopNest {
// The common number of loops between init and main // The common number of loops between init and main
...@@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self, ...@@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop); bool debug_keep_trivial_loop);
/*!
* \brief Transform the update part when there is no init func in tensorizing
* \param stage The stage for tensorizing.
* \param dom_map The range of each iter var.
* \param n The loop nest structured used in compute.
* \param body The body func in tensorize intrin
* \param update The update func in tensorize intrin
* \return Transformed result.
*/
Stmt TransformUpdate(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n,
Stmt body,
Stmt update);
} // namespace tvm } // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_ #endif // TVM_OP_COMPUTE_OP_H_
/*!
* Copyright (c) 2017 by Contributors
* \brief Tensor Compute Op.
* \file tensor_compute_op.cc
*/
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
#include "./compute_op.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
using namespace ir;
// TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const TensorComputeOpNode *op,
IRPrinter *p) {
p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
int TensorComputeOpNode::num_outputs() const {
return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
}
Array<IterVar> TensorComputeOpNode::root_iter_vars() const {
Array<IterVar> ret = axis;
for (IterVar iv : reduce_axis) {
ret.push_back(iv);
}
return ret;
}
Type TensorComputeOpNode::output_dtype(size_t i) const {
return this->intrin->buffers[this->inputs.size() + i]->dtype;
}
Array<Expr> TensorComputeOpNode::output_shape(size_t i) const {
Array<Expr> shape;
for (const auto& ivar : this->axis) {
shape.push_back(ivar->dom->extent);
}
return shape;
}
Operation TensorComputeOpNode::make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<IterVar> reduce_axis,
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions) {
auto n = make_node<TensorComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->axis = std::move(axis);
n->reduce_axis = std::move(reduce_axis);
n->schedulable_ndim = std::move(schedulable_ndim);
n->intrin = std::move(intrin);
n->inputs = std::move(tensors);
n->input_regions = std::move(regions);
return Operation(n);
}
Array<Tensor> TensorComputeOpNode::InputTensors() const {
return inputs;
}
Operation TensorComputeOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_node<TensorComputeOpNode>(*this);
auto intrin = make_node<TensorIntrinNode>(*(this->intrin.operator->()));
intrin->body = op::ReplaceTensor(this->intrin->body, rmap);
if (intrin->reduce_init.defined()) {
intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap);
}
if (intrin->reduce_update.defined()) {
intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap);
}
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
n->inputs.Set(i, rmap.at(t));
}
}
if (intrin->body.same_as(n->intrin->body) &&
intrin->reduce_init.same_as(n->intrin->reduce_init) &&
intrin->reduce_update.same_as(n->intrin->reduce_update) &&
inputs.same_as(n->inputs)) {
return self;
} else {
n->intrin = TensorIntrin(intrin);
return Operation(n);
}
}
void TensorComputeOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (size_t i = 0; i < this->inputs.size(); ++i) {
Tensor t = this->inputs[i];
Region region = input_regions[i];
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom& dom = it->second;
for (size_t j = 0; j < t.ndim(); ++j) {
dom.data[j].emplace_back(EvalSet(region[j], dom_map));
}
}
}
void TensorComputeOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
CHECK(!out_dom_map->count(this->axis[i]));
(*out_dom_map)[this->axis[i]] = r;
}
for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
CHECK(!out_dom_map->count(this->reduce_axis[i]));
(*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
}
}
Stmt TensorComputeOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
HalideIR::Internal::Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i-1);
realize = ir::Realize::make(t->op, t->value_index,
t->dtype, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (int i = 0; i < schedulable_ndim; ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
if (it != stage->iter_var_attrs.end()) {
IterVarAttr attr = (*it).second;
if (attr->dim_align_factor != 0) {
Array<Expr> tuple = {static_cast<int>(i),
attr->dim_align_factor,
attr->dim_align_offset};
realize = ir::AttrStmt::make(
t, ir::attr::buffer_dim_align,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
realize);
}
}
}
}
return realize;
}
ComputeLoopNest MakeLoopNest(
const TensorComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret;
// make main loop nest
ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
debug_keep_trivial_loop);
ret.main_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.main_vmap, false,
std::unordered_set<IterVar>());
for (auto& e : ret.main_predicates) {
e = likely(e);
}
if (stage->store_predicate.defined()) {
ret.main_predicates.push_back(stage->store_predicate);
}
if (self->reduce_axis.size() != 0) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> update_state;
for (IterVar iv : self->reduce_axis) {
update_state[iv] = 2;
}
for (int i = 0; i < self->schedulable_ndim; ++i) {
update_state[self->axis[i]] = 1;
}
// find which iter var is related to reduction and which is related to axis.
schedule::PassDownBitMaskOr(stage, &update_state);
auto leaf_iter_vars = stage->leaf_iter_vars;
// first first loop that is related to reduction.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = update_state.at(iv);
if ((flag & 2) != 0) {
begin_loop = i; break;
}
ret.init_vmap[iv] = ret.main_vmap.at(iv);
}
ret.num_common_loop = begin_loop;
// skip loops that does not relates to axis.
std::unordered_set<IterVar> skip_iter;
for (auto kv : update_state) {
int flag = kv.second;
if ((flag & 1) == 0) skip_iter.insert(kv.first);
}
ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
ret.init_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
} else {
CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
ret.num_common_loop = stage->leaf_iter_vars.size();
}
// copy elison here.
return ret;
}
Stmt TensorComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
// Start bind data.
Stmt nop = Evaluate::make(0);
std::vector<Stmt> input_bind_nest, output_bind_nest;
Array<Tensor> inputs = this->InputTensors();
// input binding
size_t num_inputs = inputs.size();
for (size_t i = 0; i < num_inputs; ++i) {
Tensor tensor = inputs[i];
Region region = this->input_regions[i];
Buffer buffer = this->intrin->buffers[i];
Array<NodeRef> bind_spec{buffer, tensor};
Array<Expr> tuple;
for (size_t i = 0; i < region.size(); ++i) {
tuple.push_back(region[i]->min);
tuple.push_back(region[i]->extent);
}
input_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
// output binding
for (int i = 0; i < this->num_outputs(); ++i) {
Tensor tensor = stage->op.output(i);
Buffer buffer = this->intrin->buffers[num_inputs + i];
Array<NodeRef> bind_spec{buffer, tensor};
Array<Expr> tuple;
for (size_t i = 0; i < this->axis.size(); ++i) {
auto ivar = this->axis[i];
if (i < static_cast<size_t>(this->schedulable_ndim)) {
tuple.push_back(ivar->var);
tuple.push_back(1);
} else {
Range dom = ivar->dom;
tuple.push_back(dom->min);
tuple.push_back(dom->extent);
}
}
output_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
// Check variable remap
std::unordered_map<const Variable*, Expr> vmap;
ir::ArgBinder binder(&vmap);
size_t tloc = stage->leaf_iter_vars.size();
ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop);
if (this->reduce_axis.size() == 0) {
std::vector<std::vector<Stmt> > nest(
n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
nest.emplace_back(op::MakeIfNest(n.main_predicates));
CHECK_EQ(n.init_predicates.size(), 0U);
CHECK(this->intrin->body.defined())
<< "Normal store op for intrin " << this << " is not defined";
Stmt body = MergeNest(output_bind_nest, this->intrin->body);
body = MergeNest(input_bind_nest, body);
body = ir::Substitute(body, vmap);
body = MergeNest(binder.asserts(), body);
body = op::Substitute(body, n.main_vmap);
Stmt ret = MergeNest(nest, body);
return ret;
} else {
// Need to split reduction
CHECK(this->intrin->reduce_update.defined())
<< "Reduction update op is not defined";
// Need init and update steps
CHECK_NE(this->reduce_axis.size(), 0U);
std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
std::vector<std::vector<Stmt> > update_nest(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
if (this->intrin->reduce_init.defined()) {
// init nest
std::vector<std::vector<Stmt> > init_nest(
n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
init = op::Substitute(init, n.init_vmap);
init = MergeNest(init_nest, init);
// The update
Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
update = MergeNest(input_bind_nest, update);
update = ir::Substitute(update, vmap);
update = MergeNest(binder.asserts(), update);
update = op::Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, Block::make(init, update));
} else {
// When init op is not available, use body op for reset in the first iter.
CHECK(this->intrin->body.defined())
<< "Normal body op is not defined";
Stmt update = TransformUpdate(stage, dom_map, n,
this->intrin->body,
this->intrin->reduce_update);
update = MergeNest(output_bind_nest, update);
update = MergeNest(input_bind_nest, update);
update = ir::Substitute(update, vmap);
update = MergeNest(binder.asserts(), update);
update = op::Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, update);
}
}
}
} // namespace tvm
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
#include "op_util.h" #include "op_util.h"
#include "compute_op.h" #include "compute_op.h"
#include "../schedule/message_passing.h" #include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
...@@ -323,50 +322,6 @@ void VerifyTensorizeBody( ...@@ -323,50 +322,6 @@ void VerifyTensorizeBody(
} }
} }
/*!
* \brief Transform the update part when there is no init func in tensorizing
* \param stage The stage for tensorizing.
* \param dom_map The range of each iter var.
* \param n The loop nest structured used in compute.
* \param body The body func in tensorize intrin
* \param update The update func in tensorize intrin
* \return Transformed result.
*/
Stmt TransformUpdate(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n,
Stmt body,
Stmt update) {
Array<Expr> conds;
std::unordered_set<const Variable*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto iit = stage->iter_var_attrs.find(iv);
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (attr->iter_type == kTensorized) {
break;
}
}
if (iv->iter_type == kCommReduce) {
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
conds.push_back(likely(iv->var > vrange->min));
banned.insert(iv->var.get());
}
}
for (const Expr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition "
<< pred << " has a conflict with the reset condition";
}
}
return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
update, body);
}
Stmt MakeTensorize(const ComputeOpNode* self, Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
......
...@@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg, ...@@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg,
// bind pointer and offset. // bind pointer and offset.
if (is_zero(arg->elem_offset)) { if (is_zero(arg->elem_offset)) {
CHECK(is_zero(value->elem_offset)) CHECK(is_zero(value->elem_offset))
<< "Trying to bind a Buffer with offset into one without offset"; << "Trying to bind a Buffer with offset into one without offset "
<< " required elem_offset=" << arg->elem_offset
<< ", provided elem_offset=" << value->elem_offset;
} }
this->Bind(arg->data, value->data, arg_name + ".data"); this->Bind(arg->data, value->data, arg_name + ".data");
......
...@@ -135,29 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor, ...@@ -135,29 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor,
return cache; return cache;
} }
// Cache write and relayout the data according to loop pattern template<typename OpType>
Array<Tensor> CacheWriteWithReLayout(Schedule sch, void PrepareAxisMapping(Stage orig_stage,
const Array<Tensor>& tensor_array, OpType* op,
const std::string& scope) { std::unordered_set<IterVar>* p_red_axis,
size_t tensor_size = tensor_array.size(); Array<IterVar>* p_new_axis,
sch->InvalidateCache(); std::unordered_map<IterVar, Range>* p_dom_map,
Tensor tensor = tensor_array[0]; std::unordered_map<const Variable*, Expr>* p_vsub,
Stage orig_stage = sch[tensor->op]; std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>(); std::vector<Expr>* p_predicates) {
std::unordered_set<IterVar> red_axis; auto& red_axis = *p_red_axis;
for (IterVar iv : compute->reduce_axis) { auto& new_axis = *p_new_axis;
auto& dom_map = *p_dom_map;
auto& vsub = *p_vsub;
auto& vsub2newvar = *p_vsub2newvar;
auto& predicates = *p_predicates;
for (IterVar iv : op->reduce_axis) {
red_axis.insert(iv); red_axis.insert(iv);
} }
std::unordered_map<IterVar, Range> dom_map; for (IterVar iv : op->axis) {
Array<IterVar> new_axis;
for (IterVar iv : compute->axis) {
dom_map[iv] = iv->dom; dom_map[iv] = iv->dom;
} }
schedule::PassDownDomain(orig_stage, &dom_map, true); schedule::PassDownDomain(orig_stage, &dom_map, true);
std::unordered_map<const Variable*, Expr> vsub;
std::unordered_map<const Variable*, Expr> vsub2newvar;
std::vector<Expr> predicates;
{ {
// The source->cache // The source->cache
std::unordered_map<IterVar, Expr> value_map; std::unordered_map<IterVar, Expr> value_map;
...@@ -178,17 +178,85 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch, ...@@ -178,17 +178,85 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
} }
// skip reduction iteration. // skip reduction iteration.
std::unordered_set<IterVar> skip_bound_check; std::unordered_set<IterVar> skip_bound_check;
for (IterVar iv : compute->reduce_axis) { for (IterVar iv : op->reduce_axis) {
skip_bound_check.insert(iv); skip_bound_check.insert(iv);
} }
schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
predicates = schedule::MakeBoundCheck( predicates = schedule::MakeBoundCheck(
orig_stage, dom_map, value_map, true, skip_bound_check); orig_stage, dom_map, value_map, true, skip_bound_check);
// The root axis // The root axis
for (IterVar iv : compute->axis) { for (IterVar iv : op->axis) {
if (value_map.count(iv)) {
vsub[iv->var.get()] = value_map.at(iv); vsub[iv->var.get()] = value_map.at(iv);
} // to handle tensor axis
}
}
}
Array<Tensor> ReplaceOriginalOp(Schedule sch,
Stage orig_stage,
const std::string& scope,
Operation cache_op,
Operation orig_new_op,
size_t tensor_size) {
Array<Tensor> cache_tensor_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
cache_tensor_list.push_back(cache_tensor);
} }
// The replace of the dataflow
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
for (size_t i = 0; i < tensor_size; i++) {
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
}
ReplaceDataFlow(sch->stages, &vmap, &rvmap);
// mutate orig stage
orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
orig_stage->relations = Array<IterVarRelation>();
// create schedule for new cached stage.
ArrayNode* stages = sch->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, orig_stage);
Stage cache_stage = Stage(cache_op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos,
cache_stage.node_);
sch->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
} }
return cache_tensor_list;
}
// Cache write and relayout the data according to loop pattern
Array<Tensor> CacheWriteWithReLayout(Schedule sch,
const Array<Tensor>& tensor_array,
const std::string& scope) {
size_t tensor_size = tensor_array.size();
sch->InvalidateCache();
Tensor tensor = tensor_array[0];
Stage orig_stage = sch[tensor->op];
const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
std::unordered_set<IterVar> red_axis;
Array<IterVar> new_axis;
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<const Variable*, Expr> vsub;
std::unordered_map<const Variable*, Expr> vsub2newvar;
std::vector<Expr> predicates;
PrepareAxisMapping(orig_stage, compute,
&red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
Expr body; Expr body;
Array<Expr> body_list; Array<Expr> body_list;
...@@ -198,7 +266,7 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch, ...@@ -198,7 +266,7 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
body = InjectPredicate(predicates, body); body = InjectPredicate(predicates, body);
body = VarReplacer(vsub2newvar).Mutate(body); body = VarReplacer(vsub2newvar).Mutate(body);
// Reduce nodes in ONE computeOp must be the same except value_index // Reduce nodes in ONE computeOp must be the same except value_index
// This is right only if the oringinal body ensures Reduce nodes are the same // This is right only if the original body ensures Reduce nodes are the same
if (body->is_type<ir::Reduce>()) { if (body->is_type<ir::Reduce>()) {
const ir::Reduce* reduce_body = body.as<ir::Reduce>(); const ir::Reduce* reduce_body = body.as<ir::Reduce>();
if (first_reduce != nullptr) { if (first_reduce != nullptr) {
...@@ -234,48 +302,107 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch, ...@@ -234,48 +302,107 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
Operation cache_op = ComputeOpNode::make( Operation cache_op = ComputeOpNode::make(
compute->name + "." + scope, compute->tag, compute->attrs, compute->name + "." + scope, compute->tag, compute->attrs,
new_axis, body_list); new_axis, body_list);
Array<Tensor> cache_tensor_list;
Array<Expr> cache_expr_list; Array<Expr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) { for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i); Tensor cache_tensor = cache_op.output(i);
cache_tensor_list.push_back(cache_tensor);
cache_expr_list.push_back(cache_tensor(args)); cache_expr_list.push_back(cache_tensor(args));
} }
Operation orig_new_op = ComputeOpNode::make( Operation orig_new_op = ComputeOpNode::make(
compute->name, compute->tag, compute->attrs, compute->name, compute->tag, compute->attrs,
compute->axis, cache_expr_list); compute->axis, cache_expr_list);
// The replace of the dataflow return ReplaceOriginalOp(sch, orig_stage, scope,
std::unordered_map<Tensor, Tensor> vmap; cache_op, orig_new_op, tensor_size);
std::unordered_map<Tensor, Tensor> rvmap; }
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
for (size_t i = 0; i < tensor_size; i++) { // for tensor compute op
vmap[orig_stage->op.output(0)] = orig_new_op.output(0); Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); const Array<Tensor>& tensor_array,
const std::string& scope) {
size_t tensor_size = tensor_array.size();
sch->InvalidateCache();
Tensor tensor = tensor_array[0];
Stage orig_stage = sch[tensor->op];
const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
CHECK_EQ(tensor_op->num_outputs(), 1)
<< "cache write only support single output tensor_compute_op";
std::unordered_set<IterVar> red_axis;
Array<IterVar> new_axis;
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<const Variable*, Expr> vsub;
std::unordered_map<const Variable*, Expr> vsub2newvar;
std::vector<Expr> predicates;
PrepareAxisMapping(orig_stage, tensor_op,
&red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
IterVar iv = tensor_op->axis[i];
IterVar new_iv = IterVarNode::make(
iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
new_axis.push_back(new_iv);
} }
ReplaceDataFlow(sch->stages, &vmap, &rvmap); Array<Region> new_regions;
// mutate orig stage for (Region old_region : tensor_op->input_regions) {
orig_stage->op = orig_new_op; Region region;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); for (Range r : old_region) {
orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
orig_stage->relations = Array<IterVarRelation>(); Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
// create schedule for new cached stage. region.push_back(Range::make_by_min_extent(min, extent));
ArrayNode* stages = sch->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, orig_stage);
Stage cache_stage = Stage(cache_op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos,
cache_stage.node_);
sch->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
} }
return cache_tensor_list; new_regions.push_back(region);
}
Operation cache_op = TensorComputeOpNode::make(
tensor_op->name + "." + scope, tensor_op->tag, new_axis,
tensor_op->reduce_axis, tensor_op->schedulable_ndim,
tensor_op->intrin, tensor_op->inputs, new_regions);
// axis will be used in generating compute op
Array<IterVar> compute_axis = tensor_op->axis;
for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
IterVar iv = tensor_op->axis[i];
IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
compute_axis.Set(i, aiv);
}
// The reader args
Array<Expr> args;
{
// cache->compute
std::unordered_map<IterVar, Expr> value_map;
for (IterVar iv : compute_axis) {
value_map[iv] = iv->var;
}
schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
for (IterVar iv : orig_stage->leaf_iter_vars) {
if (red_axis.count(iv)) continue;
args.push_back(value_map.at(iv));
}
// tensorized region axis
for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
IterVar iv = compute_axis[i];
args.push_back(value_map.at(iv));
}
}
Array<Expr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
cache_expr_list.push_back(cache_tensor(args));
}
Operation orig_new_op = ComputeOpNode::make(
tensor_op->name, tensor_op->tag, {},
compute_axis, cache_expr_list);
return ReplaceOriginalOp(sch, orig_stage, scope,
cache_op, orig_new_op, tensor_size);
} }
Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
const std::string& scope) { const std::string& scope) {
(*this)->InvalidateCache(); (*this)->InvalidateCache();
...@@ -291,23 +418,26 @@ Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, ...@@ -291,23 +418,26 @@ Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
CHECK(orig_stage.same_as(tmp_stage)) CHECK(orig_stage.same_as(tmp_stage))
<< "Input tensor list must be generated by ONE computeOp"; << "Input tensor list must be generated by ONE computeOp";
} }
return CacheWriteWithReLayout(*this, tensor_array, scope); return CacheWriteWithReLayout(*this, tensor_array, scope);
} }
Tensor Schedule::cache_write(const Tensor& tensor, Tensor Schedule::cache_write(const Tensor& tensor,
const std::string& scope) { const std::string& scope) {
// support original compute and tensor compute both
(*this)->InvalidateCache(); (*this)->InvalidateCache();
Stage orig_stage = operator[](tensor->op); const char* type_key = tensor->op->type_key();
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>(); if (!strcmp(type_key, "ComputeOp")) {
CHECK(compute)
<< "cache write only take ComputeOp as writers";
CHECK_EQ(compute->num_outputs(), 1)
<< "cache write only support single output ComputeOp";
return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
} else if (!strcmp(type_key, "TensorComputeOp")) {
return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
} else {
LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
return Tensor();
}
} }
void RebaseNonZeroMinLoop(const Schedule& sch) { void RebaseNonZeroMinLoop(const Schedule& sch) {
std::unordered_map<IterVar, IterVar> rebase_map; std::unordered_map<IterVar, IterVar> rebase_map;
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
......
...@@ -85,6 +85,78 @@ def test_tensor_reduce(): ...@@ -85,6 +85,78 @@ def test_tensor_reduce():
assert(isinstance(C_loaded, tvm.tensor.Tensor)) assert(isinstance(C_loaded, tvm.tensor.Tensor))
assert(str(C_loaded) == str(C)) assert(str(C_loaded) == str(C))
def test_tensor_compute1():
m = 1024
factor = 16
dtype = 'float32'
def intrin_vadd(n):
x = tvm.placeholder((n,))
y = tvm.placeholder((n,))
z = tvm.compute(x.shape, lambda i: x[i] + y[i])
def intrin_func(ins, outs):
ib = tvm.ir_builder.create()
ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
return ib.get()
with tvm.build_config(offset_factor=n):
return tvm.decl_tensor_intrin(z.op, intrin_func)
vadd = intrin_vadd(factor)
A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype)
B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype)
C = tvm.compute((m//factor, factor),
lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body, tvm.stmt.Evaluate)
def test_tensor_compute2():
M = 2048
N = 1024
L = 1024
factor = 16
factor1 = 32
factor2 = 32
dtype = 'float32'
def intrin_gemm(m, n, l):
k = tvm.reduce_axis((0, l))
x = tvm.placeholder((m, l))
y = tvm.placeholder((n, l))
# in theory, no relation
z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k))
def intrin_func(ins, outs):
x_ptr = ins[0].access_ptr("r")
y_ptr = ins[1].access_ptr("r")
z_ptr = outs[0].access_ptr("w")
body = tvm.call_packed(
"gemv", x_ptr, y_ptr, z_ptr, m, n, l)
reset = tvm.call_packed(
"fill_zero", z_ptr, m, n)
update = tvm.call_packed(
"gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
return body, reset, update
with tvm.build_config(offset_factor=n):
return tvm.decl_tensor_intrin(z.op, intrin_func)
vgemm = intrin_gemm(factor1, factor2, factor)
A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype)
B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype)
k = tvm.reduce_axis((0, L//factor), name='k')
C = tvm.compute((M//factor1, N//factor2, factor1, factor2),
lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)
def test_tensor_scan(): def test_tensor_scan():
m = tvm.var("m") m = tvm.var("m")
...@@ -221,6 +293,8 @@ if __name__ == "__main__": ...@@ -221,6 +293,8 @@ if __name__ == "__main__":
test_conv1d() test_conv1d()
test_tensor_slice() test_tensor_slice()
test_tensor() test_tensor()
test_tensor_compute1()
test_tensor_compute2()
test_tensor_reduce() test_tensor_reduce()
test_tensor_scan() test_tensor_scan()
test_scan_multi_out() test_scan_multi_out()
......
...@@ -276,6 +276,133 @@ def test_schedule_bound_condition(): ...@@ -276,6 +276,133 @@ def test_schedule_bound_condition():
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse)) assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))
def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
k = tvm.reduce_axis((0, n), name='k')
z = tvm.compute((m,), lambda i:
tvm.sum(w[i, k] * x[k], axis=k), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype,
name="W",
offset_factor=16,
strides=[tvm.var('ldw'), 1])
def intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
ww_ptr = ww.access_ptr("r")
xx_ptr = xx.access_ptr("r")
zz_ptr = zz.access_ptr("w")
body = tvm.call_packed(
"gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
reset = tvm.call_packed(
"fill_zero", zz_ptr, n)
update = tvm.call_packed(
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
return body, reset, update
with tvm.build_config(data_alignment=16,
offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
def test_schedule_tensor_compute1():
# basic: split, reorder, tile
M, N, L = 2048, 1024, 512
factor, rfactor = 16, 16
A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A')
B = tvm.placeholder((M, L//rfactor, rfactor), name='B')
k = tvm.reduce_axis((0, L//rfactor), name='k')
gemv = intrin_gemv(factor, rfactor)
C = tvm.compute((N, M//factor, factor),
lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k),
name='C')
s = tvm.create_schedule(C.op)
ai, aj, ax = s[C].op.axis
aio, aii = s[C].split(ai, 16)
s[C].reorder(aio, aj, aii)
aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4)
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def intrin_vadd(n, cache_read=False, cache_write=False):
scope_ubuf = 'local'
dtype = 'float32'
x = tvm.placeholder((n,), dtype=dtype, name='vx')
y = tvm.placeholder((n,), dtype=dtype, name='vy')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
s = tvm.create_schedule(z.op)
def create_buffer(t):
return tvm.decl_buffer(t.shape, t.dtype,
name='W'+t.name,
scope=scope_ubuf,
offset_factor=16)
binds = {}
if cache_read:
binds[x] = create_buffer(x)
binds[y] = create_buffer(y)
if cache_write:
binds[z] = create_buffer(z)
def intrin_func(ins, outs):
ib = tvm.ir_builder.create()
ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
return ib.get()
with tvm.build_config(offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
def test_schedule_tensor_compute2():
# cache_read, cache_write
M = 1024
factor = 16
dtype = 'float32'
scope_ubuf = 'local'
A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
vadd = intrin_vadd(factor, True, True)
C = tvm.compute((M//factor, factor),
lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
s = tvm.create_schedule(C.op)
AL = s.cache_read(A, scope_ubuf, C)
BL = s.cache_read(B, scope_ubuf, C)
CL = s.cache_write(C, scope_ubuf)
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_tensor_compute3():
# compute_at
M = 1024
factor = 16
dtype = 'float32'
A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi")
vadd = intrin_vadd(factor)
C = tvm.compute((M//factor, factor),
lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C')
s = tvm.create_schedule(C.op)
s[Bi].compute_at(s[C], C.op.axis[0])
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_middle_cache() test_schedule_middle_cache()
test_inline_multi_reduce() test_inline_multi_reduce()
...@@ -294,3 +421,6 @@ if __name__ == "__main__": ...@@ -294,3 +421,6 @@ if __name__ == "__main__":
test_schedule2() test_schedule2()
test_schedule_cache() test_schedule_cache()
test_schedule_bound_condition() test_schedule_bound_condition()
test_schedule_tensor_compute1()
test_schedule_tensor_compute2()
test_schedule_tensor_compute3()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment