Commit 825566cc by Tianqi Chen, committed by GitHub

[SCHEDULE] tensorize (#223)

parent 28120f55
/*!
* Copyright (c) 2016 by Contributors
* \file arithmetic.h
* \brief Algebra and set operations.
* \brief Algebra and set operations and simplifications.
*/
#ifndef TVM_ARITHMETIC_H_
#define TVM_ARITHMETIC_H_
......
......@@ -179,7 +179,11 @@ enum IterVarType : int {
/*!
* \brief The loop is parallelized.
*/
kParallelized = 7
kParallelized = 7,
/*!
* \brief Marks boundary of tensorization intrinsic.
*/
kTensorized = 8
};
/*!
......@@ -281,10 +285,27 @@ inline const char* IterVarType2String(IterVarType t) {
case kUnrolled: return "Unrolled";
case kVectorized: return "Vectorized";
case kParallelized: return "Parallelized";
case kTensorized: return "Tensorized";
}
return "Unknown";
}
/*!
* \brief Template function to convert Map to unordered_map.
*  Sometimes useful for API gluing when internals use unordered_map.
* \param dmap The container map
* \return The corresponding unordered_map.
* \tparam K the key type of the Map.
* \tparam V the value type of the Map.
*/
template<typename K, typename V>
inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
std::unordered_map<K, V> ret;
for (auto kv : dmap) {
ret[kv.first] = kv.second;
}
return ret;
}
} // namespace tvm
namespace std {
......
......@@ -9,8 +9,6 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <ir/IREquality.h>
#include <arithmetic/Simplify.h>
#include <tvm/ir_functor.h>
#include <arithmetic/Simplify.h>
#include <unordered_map>
......@@ -33,6 +31,20 @@ inline Stmt Simplify(Stmt a) {
}
/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplified.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt);
/*!
* \brief Simplify by applying canonical form.
* \param expr The expression to be canonically simplified.
* \return Canonicalized expression.
*/
Expr CanonicalSimplify(Expr expr);
/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
......@@ -88,13 +100,6 @@ bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);
Stmt ConvertSSA(Stmt stmt);
/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplified.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
......
......@@ -10,6 +10,7 @@
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"
#include "./tensor_intrin.h"
namespace tvm {
......@@ -160,6 +161,14 @@ class Stage : public NodeRef {
*/
Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Replace the computation of the current stage with tensor intrinsic f.
* \param var The axis that marks the beginning of tensorization.
*  Every operation inside this axis (including the axis itself) is tensorized.
* \param f The tensor compute intrinsic.
* \return reference to self.
*/
Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be unrolled.
* \return reference to self.
......@@ -465,12 +474,18 @@ class IterVarAttrNode : public Node {
Array<Tensor> prefetch_data;
/*! \brief The offset used in each prefetch */
Array<Expr> prefetch_offset;
/*!
* \brief Tensor intrinsic used in tensorization,
* when the axis is marked as Tensorized
*/
TensorIntrin tensor_intrin;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
v->Visit("bind_thread", &bind_thread);
v->Visit("prefetch_data", &prefetch_data);
v->Visit("prefetch_offset", &prefetch_offset);
v->Visit("tensor_intrin", &tensor_intrin);
}
static constexpr const char* _type_key = "IterVarAttr";
......
/*!
* Copyright (c) 2017 by Contributors
* \file tensor_intrin.h
* \brief Tensor intrinsic operations.
*/
#ifndef TVM_TENSOR_INTRIN_H_
#define TVM_TENSOR_INTRIN_H_
#include <string>
#include "./tensor.h"
#include "./buffer.h"
namespace tvm {
// Internal node container of tensor intrinsics.
class TensorIntrinNode;
/*! \brief Tensor intrinsic node. */
class TensorIntrin : public NodeRef {
public:
TensorIntrin() {}
explicit TensorIntrin(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorIntrinNode* operator->() const;
/*! \brief specify container node */
using ContainerType = TensorIntrinNode;
};
/*! \brief Node to represent a Tensor intrinsic operator */
class TensorIntrinNode : public Node {
public:
/*! \brief The name of the intrinsic */
std::string name;
/*! \brief The operation this intrinsic carries out */
Operation op;
/*! \brief List of inputs of the operator, placeholders in post-DFS order */
Array<Tensor> inputs;
/*!
* \brief Symbolic buffers for each input/output tensor.
*  buffers[0:len(inputs)] are the buffers of the inputs.
*  buffers[len(inputs):] are the buffers of the outputs.
*
* \note When a field in a Buffer is a Var, we are flexible with respect
*  to that field and the Var can occur in the body.
*  When it is a constant, only data of that exact shape is accepted.
*/
Array<Buffer> buffers;
/*! \brief The normal statement to execute the intrinsic */
Stmt body;
/*!
* \brief Special statement for a reduction op; can be None.
*   Resets the value of the output buffer to the identity value.
*/
Stmt reduce_init;
/*!
* \brief Special statement for a reduction op; can be None.
*   Reduces the current output buffer with the new result.
*/
Stmt reduce_update;
/*! \brief constructor */
TensorIntrinNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("op", &op);
v->Visit("inputs", &inputs);
v->Visit("buffers", &buffers);
v->Visit("body", &body);
v->Visit("reduce_init", &reduce_init);
v->Visit("reduce_update", &reduce_update);
}
static TensorIntrin make(std::string name,
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Stmt body,
Stmt reduce_init,
Stmt reduce_update);
static constexpr const char* _type_key = "TensorIntrin";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinNode, Node);
};
inline const TensorIntrinNode* TensorIntrin::operator->() const {
return static_cast<const TensorIntrinNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_TENSOR_INTRIN_H_
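The fields above map directly onto the Python-level decl_tensor_intrin API added later in this commit. Below is a minimal, hypothetical sketch of a reduction-style intrinsic whose fcompute returns the (body, reduce_init, reduce_update) triple; the packed-function names "dot", "fill_zero" and "dot_add" are placeholders, not real kernels.

import tvm

def intrin_dot(n):
    # Inputs must be placeholders; the output is described by a ComputeOp.
    x = tvm.placeholder((n,), name='x')
    y = tvm.placeholder((n,), name='y')
    k = tvm.reduce_axis((0, n), name='k')
    z = tvm.compute((1,), lambda i: tvm.sum(x[k] * y[k], axis=k), name='z')

    def intrin_func(ins, outs):
        # ins/outs are the symbolic Buffers stored in TensorIntrinNode::buffers.
        xx, yy = ins
        zz = outs[0]
        body = tvm.call_packed("dot", xx.data, yy.data, zz.data, n)
        reset = tvm.call_packed("fill_zero", zz.data, 1)
        update = tvm.call_packed("dot_add", xx.data, yy.data, zz.data, n)
        # The triple fills body, reduce_init and reduce_update respectively.
        return body, reset, update

    return tvm.decl_tensor_intrin(z.op, intrin_func)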
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: a DSL for tensor kernel compilation"""
"""TVM: Low level DSL/IR stack for tensor computation."""
from __future__ import absolute_import as _abs
from . import tensor
......@@ -23,6 +23,7 @@ from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .node import register_node
from .ndarray import register_extension
from .schedule import create_schedule
......
......@@ -109,9 +109,10 @@ def get_binds(args, binds=None):
args : list of Buffer or Tensor or Var
The argument lists to the function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps a Tensor to a Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument list.
Returns
-------
......@@ -126,14 +127,16 @@ def get_binds(args, binds=None):
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
buf = api.decl_buffer(x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
if x not in binds:
buf = api.decl_buffer(x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor)
binds[x] = buf
arg_list.append(buf)
else:
arg_list.append(binds[x])
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, expr.Var):
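A minimal sketch of the behaviour this hunk enables, mirroring the updated test_llvm_add_pipeline below: a Tensor that already has an entry in binds now reuses that Buffer instead of triggering the old assertion, so the Tensor itself can be listed in the argument list.

import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.create_schedule(C.op)
# Bind A to an explicitly declared buffer; get_binds picks up this entry
# rather than creating a fresh compact buffer for A.
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
stmt = tvm.lower(s, [A, B, C], binds={A: Ab}, simple_mode=True)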
......@@ -161,9 +164,10 @@ def lower(sch,
name : str, optional
The name of result function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps a Tensor to a Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument list.
simple_mode : bool, optional
Whether only output simple and compact statement, this will skip
......
......@@ -57,6 +57,7 @@ class IterVar(NodeBase, _expr.ExprOp):
Unrolled = 5
Vectorized = 6
Parallelized = 7
Tensorized = 8
_tensor.iter_var_cls = IterVar
......@@ -388,6 +389,19 @@ class Stage(NodeBase):
"""
_api_internal._StageVectorize(self, var)
def tensorize(self, var, tensor_intrin):
"""Tensorize the computation enclosed by var with tensor_intrin
Parameters
----------
var : IterVar
The iteration boundary of tensorization.
tensor_intrin : TensorIntrin
The tensor intrinsic used for computation.
"""
_api_internal._StageTensorize(self, var, tensor_intrin)
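A minimal usage sketch of this method, following the pattern of the tensorize tests added in this commit; the intrinsic declaration assumes the decl_tensor_intrin API and a placeholder "vadd" packed function.

import tvm

def intrin_vadd(n):
    vx = tvm.placeholder((n,), name='vx')
    vy = tvm.placeholder((n,), name='vy')
    vz = tvm.compute(vx.shape, lambda i: vx[i] + vy[i], name='vz')
    def intrin_func(ins, outs):
        # "vadd" is a placeholder packed-function name.
        return tvm.call_packed("vadd", ins[0].data, ins[1].data, outs[0].data, n)
    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(vz.op, intrin_func)

m, factor = 128, 16
x = tvm.placeholder((m,), name='x')
y = tvm.placeholder((m,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
s = tvm.create_schedule(z.op)
# Tensorize the inner loop produced by the split with a 16-element vadd.
xo, xi = s[z].split(z.op.axis[0], factor=factor)
s[z].tensorize(xi, intrin_vadd(factor))
print(tvm.lower(s, [x, y, z], simple_mode=True))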
def unroll(self, var):
"""Unroll the iteration.
......
"""Tensor intrinsics"""
from __future__ import absolute_import as _abs
from . import _api_internal
from . import api as _api
from . import expr as _expr
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from .build import BuildConfig
from ._ffi.node import NodeBase, register_node
@register_node
class TensorIntrin(NodeBase):
"""Tensor intrinsic functions for certain computation.
See Also
--------
decl_tensor_intrin: Construct a TensorIntrin
"""
pass
def decl_tensor_intrin(op,
fcompute,
name="tensor_intrin",
binds=None):
"""Declare a tensor intrinsic function.
Parameters
----------
op: Operation
The symbolic description of the intrinsic operation
fcompute: lambda function of (ins, outs) -> stmt
Specifies the IR statement that performs the computation.
See the following note for the function signature of fcompute.
.. note::
**Parameters**
- **ins** (list of :any:`Buffer`) - Placeholder for each input
- **outs** (list of :any:`Buffer`) - Placeholder for each output
**Returns**
- **stmt** (:any:`Stmt`, or tuple of three stmts)
- If a single stmt is returned, it represents the body.
- If a tuple of three stmts is returned, they correspond to body,
reduce_init, reduce_update.
name: str, optional
The name of the intrinsic.
binds: dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps a Tensor to a Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument list.
Returns
-------
intrin: TensorIntrin
A TensorIntrin that can be used in a tensorize schedule.
"""
if not isinstance(op, _tensor.Operation):
raise TypeError("expect Operation")
inputs = op.input_tensors
binds = binds if binds else {}
tensors = [x for x in inputs]
for i in range(op.num_outputs):
tensors.append(op.output(i))
binds_list = []
for t in inputs:
if not isinstance(t.op, _tensor.PlaceholderOp):
raise ValueError("Donot yet support composition op")
cfg = BuildConfig.current
for t in tensors:
buf = (binds[t] if t in binds else
_api.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor))
binds_list.append(buf)
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
if isinstance(body, (_expr.Expr, _stmt.Stmt)):
body = [body]
body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
name, op, inputs, binds_list, *body)
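A short sketch, mirroring the assertions in test_tensor_intrin below, of declaring an intrinsic and inspecting the resulting TensorIntrin node; "vadd" is a placeholder packed-function name.

import tvm

n = 16
x = tvm.placeholder((n,), name='x')
y = tvm.placeholder((n,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

def intrin_func(ins, outs):
    # ins/outs hold the symbolic Buffers declared for the inputs/outputs.
    return tvm.call_packed("vadd", ins[0].data, ins[1].data, outs[0].data, n)

intrin = tvm.decl_tensor_intrin(z.op, intrin_func)
assert intrin.op == z.op
assert intrin.reduce_init is None          # single-stmt body: no reduction parts
assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
assert intrin.buffers[0].shape[0].value == n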
......@@ -164,6 +164,17 @@ TVM_REGISTER_API("_Tensor")
args[3]);
});
TVM_REGISTER_API("_TensorIntrin")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorIntrinNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6]);
});
TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor();
......@@ -335,6 +346,12 @@ TVM_REGISTER_API("_StageVectorize")
.vectorize(args[1]);
});
TVM_REGISTER_API("_StageTensorize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.tensorize(args[1], args[2]);
});
TVM_REGISTER_API("_StageParallel")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
......
......@@ -21,6 +21,15 @@ TVM_REGISTER_API("ir_pass.Simplify")
}
});
TVM_REGISTER_API("ir_pass.CanonicalSimplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
*ret = CanonicalSimplify(args[0].operator Stmt());
} else {
*ret = CanonicalSimplify(args[0].operator Expr());
}
});
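From Python this registration is reachable as tvm.ir_pass.CanonicalSimplify, dispatching on whether the argument is a Stmt or an Expr, as the tensorize tests below use it. A minimal sketch:

import tvm

x = tvm.var('x')
e = tvm.ir_pass.CanonicalSimplify(x + x)                      # Expr overload
s = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(x + x))   # Stmt overload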
TVM_REGISTER_API("ir_pass.Equal")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
......@@ -76,7 +85,6 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
......
......@@ -531,5 +531,8 @@ Stmt CanonicalSimplify(Stmt stmt) {
return arith::Canonical().Simplify(stmt);
}
Expr CanonicalSimplify(Expr expr) {
return arith::Canonical().Simplify(expr);
}
} // namespace ir
} // namespace tvm
......@@ -42,7 +42,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< ", identity_element=" << op->identity_element
<< ")";
});
} // namespace Internal
} // namespace Halide
......@@ -141,5 +140,6 @@ TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
} // namespace ir
} // namespace tvm
......@@ -4,6 +4,7 @@
*/
#include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/tensor_intrin.h>
#include <ir/IR.h>
#include <memory>
......@@ -54,4 +55,28 @@ Tensor Operation::output(size_t i) const {
return Tensor(node);
}
TensorIntrin TensorIntrinNode::make(std::string name,
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Stmt body,
Stmt reduce_init,
Stmt reduce_update) {
auto n = std::make_shared<TensorIntrinNode>();
n->name = std::move(name);
n->op = std::move(op);
n->inputs = std::move(inputs);
n->buffers = std::move(buffers);
n->body = std::move(body);
n->reduce_init = std::move(reduce_init);
n->reduce_update = std::move(reduce_update);
return TensorIntrin(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const TensorIntrinNode *n, IRPrinter *p) {
p->stream << "TensorIntrin(name=" << n->name << ", " << n << ")";
});
TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
} // namespace tvm
......@@ -297,14 +297,62 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
}
}
enum class ComputeType {
kNormal,
kCrossThreadReduction,
kTensorize
};
ComputeType DetectComputeType(const ComputeOpNode* self,
const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0, tensorize = 0;
for (IterVar iv : stage->leaf_iter_vars) {
IterVarAttr attr;
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end()) {
attr = (*it).second;
}
if (attr.defined() && attr->iter_type == kTensorized) {
++tensorize;
}
if (iv->iter_type == kCommReduce) {
if (attr.defined() && attr->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
if (tensorize != 0) {
CHECK(thread_red == 0)
<< "Cannot mix cross thread reduction with Tensorize";
return ComputeType::kTensorize;
}
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
if (thread_red != 0) {
return ComputeType::kCrossThreadReduction;
} else {
return ComputeType::kNormal;
}
}
// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
if (IsCrossThreadReduction(this, stage)) {
ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
} else if (ctype == ComputeType::kTensorize) {
return MakeTensorize(this, stage, dom_map);
} else {
return MakeComputeStmt(this, stage, dom_map);
}
......
......@@ -46,13 +46,6 @@ struct ComputeLoopNest {
};
/*!
* \brief Whether compute op is a cross thread reduction structure.
* \param self The pointer to ComputeOpNode
* \param stage the schedule stage.
*/
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage);
/*!
* \brief Build body of compute for cross thread reduction pattern.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
......@@ -63,6 +56,17 @@ Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
/*!
* \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \return The created statement.
*/
Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
} // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
......@@ -10,29 +10,6 @@
namespace tvm {
using namespace ir;
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0;
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
return thread_red != 0;
}
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
......
......@@ -65,6 +65,7 @@ MakeLoopNest(const Stage& stage,
case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break;
case kDataPar: break;
case kTensorized: break;
default: LOG(FATAL) << "Unknown iter type"
<< it_attr->iter_type
<< " in the iter_var_attrs";
......
......@@ -14,8 +14,8 @@ namespace runtime {
DSLAPI* FindDSLAPI() {
auto* f = Registry::Get("dsl_api.singleton");
if (f == nullptr) {
throw dmlc::Error("TVM runtime only environment, "\
"DSL API is not available");
throw dmlc::Error("TVM runtime only environment,"\
" DSL API is not available");
}
void* ptr = (*f)();
return static_cast<DSLAPI*>(ptr);
......
......@@ -292,7 +292,8 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
return *this;
}
inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
template<typename FUpdate>
inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
......@@ -303,15 +304,29 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type)
} else {
n = std::make_shared<IterVarAttrNode>();
}
n->iter_type = iter_type;
fupdate(n.get());
self->iter_var_attrs.Set(var, IterVarAttr(n));
}
inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) {
n->iter_type = iter_type;
});
}
Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
SetAttrIterType(operator->(), var, kVectorized);
return *this;
}
Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*)
UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
n->iter_type = kTensorized;
n->tensor_intrin = f;
});
return *this;
}
Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttrIterType(operator->(), var, kUnrolled);
return *this;
......
......@@ -296,10 +296,7 @@ class SchedulePostProc : public IRMutator {
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_) {
Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
}
std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
// scan init and scan updates
std::unordered_map<Operation, Operation> scan_init;
for (Stage s : sch->stages) {
......
......@@ -23,7 +23,7 @@ def test_llvm_add_pipeline():
name='A')
binds = {A : Ab}
# BUILD and invoke the kernel.
f = tvm.build(s, [Ab, B, C], "llvm", binds=binds)
f = tvm.build(s, [A, B, C], "llvm", binds=binds)
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
......
......@@ -115,8 +115,33 @@ def test_rfactor():
assert(BF.op.body[0].axis[1].var == ko.var)
assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
def test_tensor_intrin():
n = 16
x = tvm.placeholder((n,), name='x')
y = tvm.placeholder((n,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
def intrin_func(ins, outs):
assert(isinstance(ins[0], tvm.schedule.Buffer))
assert(ins[0].shape[0].value == n)
return tvm.call_packed("vadd", ins[0].data, outs[0].data, ins[0].shape[0])
intrin = tvm.decl_tensor_intrin(z.op, intrin_func)
assert intrin.op == z.op
assert intrin.reduce_init is None
assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
assert(intrin.buffers[0].shape[0].value == n)
m = 32
x = tvm.placeholder((m,), name='x')
y = tvm.placeholder((m,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
s = tvm.create_schedule(z.op)
xo, xi = s[z].split(z.op.axis[0], factor=n)
s[z].tensorize(xi, intrin)
assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)
if __name__ == "__main__":
test_tensor_intrin()
test_rfactor()
test_schedule_create()
test_reorder()
......
import tvm
def intrin_vadd(n):
x = tvm.placeholder((n,), name='vx')
y = tvm.placeholder((n,), name='vy')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
def intrin_func(ins, outs):
xx, yy = ins
zz = outs[0]
return tvm.call_packed("vadd", xx, yy, zz)
with tvm.build_config(offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func)
def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
k = tvm.reduce_axis((0, n), name='k')
z = tvm.compute((m,), lambda i:
tvm.sum(w[i, k] * x[k], axis=k), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype, name="W",
strides=[tvm.var('ldw'), 1])
def intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
body = tvm.call_packed(
"gemv", ww.data, xx.data, zz.data, n, ww.strides[0])
reset = tvm.call_packed(
"fill_zero", outs[0].data, n)
update = tvm.call_packed(
"gemv_add", ww.data, xx.data, zz.data, n, ww.strides[0])
return body, reset, update
with tvm.build_config(data_alignment=16):
return tvm.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
def test_tensorize_vadd():
m = 128
x = tvm.placeholder((m,), name='x')
y = tvm.placeholder((m,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
def check(factor):
s = tvm.create_schedule(z.op)
xo, xi = s[z].split(z.op.axis[0], factor=factor)
vadd = intrin_vadd(factor)
s[z].tensorize(xi, vadd)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[z], dom_map)
assert tvm.ir_pass.Equal(out_dom[z.op.axis[0]].extent, factor)
assert tvm.ir_pass.Equal(out_dom[z.op.axis[0]].min, xo * factor)
assert tvm.ir_pass.Equal(in_dom.items()[0][1][0].extent, factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[z], out_dom, in_dom, vadd)
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
tvm.ir_pass.CanonicalSimplify(vadd.op.body[0]))
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [x, y, z])
check(16)
def test_tensorize_matmul():
n = 1024
m = n
l = n
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute((n, m), lambda i, j:
tvm.sum(B[j, k] * A[i, k], axis=k), name='C')
def check(factor):
s = tvm.create_schedule(C.op)
x, y = C.op.axis
yo, yi = s[C].split(y, factor=factor)
gemv = intrin_gemv(factor, l)
s[C].tensorize(yi, gemv)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
def check_rfactor(factor, rfactor):
s = tvm.create_schedule(C.op)
x, y = C.op.axis
rk = C.op.reduce_axis[0]
yo, yi = s[C].split(y, factor=factor)
ro, ri = s[C].split(rk, factor=rfactor)
s[C].reorder(yo, ro, yi, ri)
gemv = intrin_gemv(factor, rfactor)
s[C].tensorize(yi, gemv)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
check(16)
check_rfactor(16, 16)
if __name__ == "__main__":
test_tensorize_vadd()
test_tensorize_matmul()