Commit b90620ea by ziheng, committed by Tianqi Chen

[LANG] Generalize compute to tensor region (#1476)

parent 3d62cf7c
Subproject commit 4f0564ec769477c66d480dd966088f172050c874
Subproject commit 946a54012d0c390675ab5b46cd990838d4183d6f
...@@ -108,6 +108,8 @@ class Range : public HalideIR::IR::Range {
TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
};
using Region = Array<Range>;
/*!
* \brief Type of iteration variable.
* Each IterVar have a specific type.
...
...@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
}
/*!
* \return The list of iteration variable at root
* \note root_iter_vars decides the shape of the outputs.
*/
virtual Array<IterVar> root_iter_vars() const = 0;
/*!
...@@ -240,6 +240,74 @@ class TVM_DLL ComputeOpNode : public OperationNode {
};
/*!
* \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
*/
class TensorComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
Array<IterVar> reduce_axis;
/*! \brief number of axes that can be scheduled */
int schedulable_ndim;
/*! \brief TensorIntrin used to compute */
TensorIntrin intrin;
/*! \brief input tensors of intrin */
Array<Tensor> inputs;
/*! \brief region of input tensors */
Array<Region> input_regions;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("schedulable_ndim", &schedulable_ndim);
v->Visit("intrin", &intrin);
v->Visit("inputs", &inputs);
v->Visit("input_regions", &input_regions);
}
static Operation make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<IterVar> reduce_axis,
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions);
static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
};
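The class above is only constructed indirectly, by tvm.compute, when the compute body turns out to be a TensorIntrinCall. A minimal sketch of that path, modelled on the vadd intrinsic used in the tests further down; the names n, vadd, intrin_func, A, B, C are illustrative, not part of this diff:

import tvm

n = 16
x = tvm.placeholder((n,), name='x')
y = tvm.placeholder((n,), name='y')
z = tvm.compute((n,), lambda i: x[i] + y[i], name='z')

def intrin_func(ins, outs):
    # emit an opaque external call that consumes the bound buffer regions
    ib = tvm.ir_builder.create()
    ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                            ins[0].access_ptr('r'),
                            ins[1].access_ptr('r'),
                            outs[0].access_ptr('w')))
    return ib.get()

with tvm.build_config(offset_factor=n):
    vadd = tvm.decl_tensor_intrin(z.op, intrin_func)

A = tvm.placeholder((64, n), name='A')
B = tvm.placeholder((64, n), name='B')
# fcompute returns a TensorIntrinCall, so compute() builds a TensorComputeOp
# with one schedulable axis (i) and one intrinsic-covered axis (ax0).
C = tvm.compute((64, n), lambda i: vadd(A[i, 0:n], B[i, 0:n]), name='C')
assert isinstance(C.op, tvm.tensor.TensorComputeOp)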
/*!
* \brief Symbolic scan.
*/
class ScanOpNode : public OperationNode {
...@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of inputs */
Array<Buffer> input_placeholders;
/*! \brief Symbolic placeholder representation of outputs */
Array<Buffer> output_placeholders;
...
...@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node {
inline const TensorIntrinNode* TensorIntrin::operator->() const {
return static_cast<const TensorIntrinNode*>(node_.get());
}
// Internal node container of tensor intrinsic calling.
class TensorIntrinCallNode;
/*! \brief Tensor intrinsic calling node. */
class TensorIntrinCall : public NodeRef {
public:
TensorIntrinCall() {}
explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorIntrinCallNode* operator->() const;
/*! \brief specify container node */
using ContainerType = TensorIntrinCallNode;
};
class TensorIntrinCallNode : public Node {
public:
/*! \brief the tensor intrinsic */
TensorIntrin intrin;
/*! \brief input tensors of the intrinsic */
Array<Tensor> tensors;
/*! \brief regions of input tensors */
Array<Region> regions;
/*!
* \brief IterVar on each reduction axis, if the
* intrin will use the reduce axis
*/
Array<IterVar> reduce_axis;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
v->Visit("reduce_axis", &reduce_axis);
}
static TensorIntrinCall make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis);
static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
};
inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
return static_cast<const TensorIntrinCallNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_TENSOR_INTRIN_H_
...@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
code = fcompute.__code__
out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
if isinstance(body, _tensor.TensorIntrinCall):
for i, s in enumerate(shape[out_ndim:]):
var_name = "ax" + str(i)
dim_var.append(_IterVar((0, s), var_name, 4))
op_node = _api_internal._TensorComputeOp(name,
tag,
dim_var,
body.reduce_axis,
out_ndim,
body.intrin,
body.tensors,
body.regions)
else:
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
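Two smaller behaviours of the rewritten compute() are easy to miss: float extents are coerced back to int (the Python 3 division case handled by the added shape line), and a body that is a plain expression still takes the original ComputeOp branch. A rough sketch under assumed values:

import tvm

m = 1024
A = tvm.placeholder((64, 16), name='A')
# under Python 3, m / 16 is the float 64.0; the coercion above turns it
# back into an int extent before the shape reaches the C++ side
B = tvm.compute((m / 16, 16), lambda i, j: A[i, j] * 2.0, name='B')
assert isinstance(B.op, tvm.tensor.ComputeOp)   # not a TensorComputeOp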
...@@ -529,14 +548,14 @@ def decl_buffer(shape,
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = var('%s_elem_offset' % name, shape_dtype)
if data is None:
data = var(name, "handle")
return _api_internal._Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor)
def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar
...
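The decl_buffer tweak exists because a shape given as plain Python ints has no .dtype to copy for the implicit elem_offset variable; a small sketch with assumed values:

import tvm

# shape[0] is a Python int, so the new hasattr check falls back to "int32"
buf = tvm.decl_buffer((16,), "float32", name="W", offset_factor=16)
assert buf.elem_offset.dtype == "int32"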
...@@ -30,6 +30,11 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
"""Data content of the tensor."""
return self.tensor.dtype
@register_node
class TensorIntrinCall(NodeBase):
"""Intermediate structure for calling a tensor intrinsic."""
pass
itervar_cls = None
...@@ -106,6 +111,7 @@ class Tensor(NodeBase, _expr.ExprOp):
return "%s.v%d" % (op.name, self.value_index)
class Operation(NodeBase):
"""Represent an operation that generate a tensor"""
...@@ -156,6 +162,12 @@ class ComputeOp(Operation):
@register_node
class TensorComputeOp(Operation):
"""Tensor operation."""
pass
@register_node
class ScanOp(Operation):
"""Scan operation."""
@property
...
...@@ -6,9 +6,25 @@ from . import expr as _expr
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node
def _get_region(tslice):
region = []
for idx in tslice.indices:
if isinstance(idx, slice):
assert idx.step is None
region.append(_api.Range(idx.start, idx.stop))
else:
if isinstance(idx, _schedule.IterVar):
begin = idx.var
else:
begin = idx
region.append(_make.range_by_min_extent(begin, 1))
return region
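_get_region turns each index of a TensorSlice into a Range: an integer (or IterVar) index becomes a unit-extent range, a Python slice becomes Range(start, stop). A quick illustration, calling the module-private helper directly only for demonstration; A and the indices are made up:

import tvm
from tvm.tensor_intrin import _get_region  # private helper, used here only to illustrate

A = tvm.placeholder((64, 16), name='A')
region = _get_region(A[3, 0:16])   # Tensor.__getitem__ yields a TensorSlice
# prints something like [range(min=3, ext=1), range(min=0, ext=16)]
print(region)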
@register_node
class TensorIntrin(NodeBase):
"""Tensor intrinsic functions for certain computation.
...@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase):
--------
decl_tensor_intrin: Construct a TensorIntrin
"""
def __call__(self, *args, **kwargs):
tensors = [x.tensor for x in args]
regions = [_get_region(x) for x in args]
reduce_axis = []
if "reduce_axis" in kwargs:
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
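__call__ does not compute anything; it only packages the intrinsic, the sliced tensors, their regions, and an optional reduce_axis into a TensorIntrinCall for compute() to consume. A sketch using a trivial copy intrinsic; every name here is illustrative:

import tvm

x = tvm.placeholder((16,), name='x')
z = tvm.compute((16,), lambda i: x[i], name='z')

def _copy(ins, outs):
    # opaque packed call standing in for a real hardware intrinsic
    ib = tvm.ir_builder.create()
    ib.emit(tvm.call_packed("copy", ins[0].access_ptr('r'), outs[0].access_ptr('w')))
    return ib.get()

copy = tvm.decl_tensor_intrin(z.op, _copy)

A = tvm.placeholder((4, 16), name='A')
call = copy(A[0, 0:16])   # goes through __call__ above
assert isinstance(call, tvm.tensor.TensorIntrinCall)
# call.intrin, call.tensors, call.regions and call.reduce_axis carry the
# pieces that compute() later hands to TensorComputeOpNode::make.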
def decl_tensor_intrin(op,
fcompute,
...
...@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
args[6]);
});
TVM_REGISTER_API("_TensorIntrinCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorIntrinCallNode::make(args[0],
args[1],
args[2],
args[3]);
});
TVM_REGISTER_API("_TensorEqual") TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor(); *ret = args[0].operator Tensor() == args[1].operator Tensor();
...@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp") ...@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp")
args[7]); args[7]);
}); });
TVM_REGISTER_API("_TensorComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorComputeOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
args[7]);
});
TVM_REGISTER_API("_ExternOp") TVM_REGISTER_API("_ExternOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ExternOpNode::make(args[0], *ret = ExternOpNode::make(args[0],
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
namespace tvm { namespace tvm {
// Tensor
Expr Tensor::operator()(Array<Var> indices) const {
Array<Expr> arr(indices.begin(), indices.end());
return operator()(arr);
...@@ -26,6 +28,15 @@ Expr Tensor::operator()(Array<Expr> indices) const {
return n;
}
Tensor Operation::output(size_t i) const {
auto node = make_node<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
Tensor TensorNode::make(Array<Expr> shape,
Type dtype,
Operation op,
...@@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(TensorNode);
Tensor Operation::output(size_t i) const {
auto node = make_node<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
// TensorIntrin
TensorIntrin TensorIntrinNode::make(std::string name,
Operation op,
...@@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
// TensorIntrinCall
TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis) {
auto n = make_node<TensorIntrinCallNode>();
n->intrin = std::move(intrin);
n->tensors = std::move(tensors);
n->regions = std::move(regions);
n->reduce_axis = std::move(reduce_axis);
return TensorIntrinCall(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode *n, IRPrinter *p) {
p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
});
TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
} // namespace tvm
...@@ -13,6 +13,7 @@
#include "compute_op.h"
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
...@@ -545,4 +546,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) {
v.Run();
}
Stmt TransformUpdate(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n,
Stmt body,
Stmt update) {
Array<Expr> conds;
std::unordered_set<const Variable*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto iit = stage->iter_var_attrs.find(iv);
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (attr->iter_type == kTensorized) {
break;
}
}
if (iv->iter_type == kCommReduce) {
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
conds.push_back(likely(iv->var > vrange->min));
banned.insert(iv->var.get());
}
}
for (const Expr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition "
<< pred << " has a conflict with the reset condition";
}
}
return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
update, body);
}
} // namespace tvm
...@@ -14,7 +14,7 @@
namespace tvm {
// loop nest structure for general compute
// This the loop nest structured used in compute.
// Does not include the loop body.
struct ComputeLoopNest {
// The common number of loops between init and main
...@@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop);
/*!
* \brief Transform the update part when there is no init func in tensorizing
* \param stage The stage for tensorizing.
* \param dom_map The range of each iter var.
* \param n The loop nest structured used in compute.
* \param body The body func in tensorize intrin
* \param update The update func in tensorize intrin
* \return Transformed result.
*/
Stmt TransformUpdate(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n,
Stmt body,
Stmt update);
} // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
...@@ -10,7 +10,6 @@
#include "op_util.h"
#include "compute_op.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
...@@ -323,50 +322,6 @@ void VerifyTensorizeBody(
}
}
/*!
* \brief Transform the update part when there is no init func in tensorizing
* \param stage The stage for tensorizing.
* \param dom_map The range of each iter var.
* \param n The loop nest structured used in compute.
* \param body The body func in tensorize intrin
* \param update The update func in tensorize intrin
* \return Transformed result.
*/
Stmt TransformUpdate(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n,
Stmt body,
Stmt update) {
Array<Expr> conds;
std::unordered_set<const Variable*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto iit = stage->iter_var_attrs.find(iv);
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (attr->iter_type == kTensorized) {
break;
}
}
if (iv->iter_type == kCommReduce) {
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
conds.push_back(likely(iv->var > vrange->min));
banned.insert(iv->var.get());
}
}
for (const Expr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition "
<< pred << " has a conflict with the reset condition";
}
}
return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
update, body);
}
Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
...
...@@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg,
// bind pointer and offset.
if (is_zero(arg->elem_offset)) {
CHECK(is_zero(value->elem_offset))
<< "Trying to bind a Buffer with offset into one without offset"; << "Trying to bind a Buffer with offset into one without offset "
<< " required elem_offset=" << arg->elem_offset
<< ", provided elem_offset=" << value->elem_offset;
}
this->Bind(arg->data, value->data, arg_name + ".data");
...
...@@ -85,6 +85,78 @@ def test_tensor_reduce():
assert(isinstance(C_loaded, tvm.tensor.Tensor))
assert(str(C_loaded) == str(C))
def test_tensor_compute1():
m = 1024
factor = 16
dtype = 'float32'
def intrin_vadd(n):
x = tvm.placeholder((n,))
y = tvm.placeholder((n,))
z = tvm.compute(x.shape, lambda i: x[i] + y[i])
def intrin_func(ins, outs):
ib = tvm.ir_builder.create()
ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
return ib.get()
with tvm.build_config(offset_factor=n):
return tvm.decl_tensor_intrin(z.op, intrin_func)
vadd = intrin_vadd(factor)
A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype)
B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype)
C = tvm.compute((m//factor, factor),
lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body, tvm.stmt.Evaluate)
def test_tensor_compute2():
M = 2048
N = 1024
L = 1024
factor = 16
factor1 = 32
factor2 = 32
dtype = 'float32'
def intrin_gemm(m, n, l):
k = tvm.reduce_axis((0, l))
x = tvm.placeholder((m, l))
y = tvm.placeholder((n, l))
# in theory, no relation
z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k))
def intrin_func(ins, outs):
x_ptr = ins[0].access_ptr("r")
y_ptr = ins[1].access_ptr("r")
z_ptr = outs[0].access_ptr("w")
body = tvm.call_packed(
"gemv", x_ptr, y_ptr, z_ptr, m, n, l)
reset = tvm.call_packed(
"fill_zero", z_ptr, m, n)
update = tvm.call_packed(
"gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
return body, reset, update
with tvm.build_config(offset_factor=n):
return tvm.decl_tensor_intrin(z.op, intrin_func)
vgemm = intrin_gemm(factor1, factor2, factor)
A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype)
B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype)
k = tvm.reduce_axis((0, L//factor), name='k')
C = tvm.compute((M//factor1, N//factor2, factor1, factor2),
lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)
def test_tensor_scan():
m = tvm.var("m")
...@@ -221,6 +293,8 @@ if __name__ == "__main__":
test_conv1d()
test_tensor_slice()
test_tensor()
test_tensor_compute1()
test_tensor_compute2()
test_tensor_reduce()
test_tensor_scan()
test_scan_multi_out()
...
...@@ -276,6 +276,133 @@ def test_schedule_bound_condition():
stmt = tvm.ir_pass.Simplify(stmt)
assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))
def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
k = tvm.reduce_axis((0, n), name='k')
z = tvm.compute((m,), lambda i:
tvm.sum(w[i, k] * x[k], axis=k), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype,
name="W",
offset_factor=16,
strides=[tvm.var('ldw'), 1])
def intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
ww_ptr = ww.access_ptr("r")
xx_ptr = xx.access_ptr("r")
zz_ptr = zz.access_ptr("w")
body = tvm.call_packed(
"gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
reset = tvm.call_packed(
"fill_zero", zz_ptr, n)
update = tvm.call_packed(
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
return body, reset, update
with tvm.build_config(data_alignment=16,
offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
def test_schedule_tensor_compute1():
# basic: split, reorder, tile
M, N, L = 2048, 1024, 512
factor, rfactor = 16, 16
A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A')
B = tvm.placeholder((M, L//rfactor, rfactor), name='B')
k = tvm.reduce_axis((0, L//rfactor), name='k')
gemv = intrin_gemv(factor, rfactor)
C = tvm.compute((N, M//factor, factor),
lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k),
name='C')
s = tvm.create_schedule(C.op)
ai, aj, ax = s[C].op.axis
aio, aii = s[C].split(ai, 16)
s[C].reorder(aio, aj, aii)
aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4)
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def intrin_vadd(n, cache_read=False, cache_write=False):
scope_ubuf = 'local'
dtype = 'float32'
x = tvm.placeholder((n,), dtype=dtype, name='vx')
y = tvm.placeholder((n,), dtype=dtype, name='vy')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
s = tvm.create_schedule(z.op)
def create_buffer(t):
return tvm.decl_buffer(t.shape, t.dtype,
name='W'+t.name,
scope=scope_ubuf,
offset_factor=16)
binds = {}
if cache_read:
binds[x] = create_buffer(x)
binds[y] = create_buffer(y)
if cache_write:
binds[z] = create_buffer(z)
def intrin_func(ins, outs):
ib = tvm.ir_builder.create()
ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
return ib.get()
with tvm.build_config(offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
def test_schedule_tensor_compute2():
# cache_read, cache_write
M = 1024
factor = 16
dtype = 'float32'
scope_ubuf = 'local'
A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
vadd = intrin_vadd(factor, True, True)
C = tvm.compute((M//factor, factor),
lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
s = tvm.create_schedule(C.op)
AL = s.cache_read(A, scope_ubuf, C)
BL = s.cache_read(B, scope_ubuf, C)
CL = s.cache_write(C, scope_ubuf)
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_tensor_compute3():
# compute_at
M = 1024
factor = 16
dtype = 'float32'
A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi")
vadd = intrin_vadd(factor)
C = tvm.compute((M//factor, factor),
lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C')
s = tvm.create_schedule(C.op)
s[Bi].compute_at(s[C], C.op.axis[0])
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__":
test_schedule_middle_cache()
test_inline_multi_reduce()
...@@ -294,3 +421,6 @@ if __name__ == "__main__":
test_schedule2()
test_schedule_cache()
test_schedule_bound_condition()
test_schedule_tensor_compute1()
test_schedule_tensor_compute2()
test_schedule_tensor_compute3()