Commit 825566cc by Tianqi Chen, committed via GitHub

[SCHEDULE] tensorize (#223)

parent 28120f55
/*!
* Copyright (c) 2016 by Contributors
* \file arithmetic.h
* \brief Algebra and set operations.
* \brief Algebra and set operations and simplifications.
*/
#ifndef TVM_ARITHMETIC_H_
#define TVM_ARITHMETIC_H_
......
......@@ -179,7 +179,11 @@ enum IterVarType : int {
/*!
* \brief The loop is parallelized.
*/
kParallelized = 7
kParallelized = 7,
/*!
* \brief Marks boundary of tensorization intrinsic.
*/
kTensorized = 8
};
/*!
......@@ -281,10 +285,27 @@ inline const char* IterVarType2String(IterVarType t) {
case kUnrolled: return "Unrolled";
case kVectorized: return "Vectorized";
case kParallelized: return "Parallelized";
case kTensorized: return "Tensorized";
}
return "Unknown";
}
/*!
* \brief Template function to convert Map to unordered_map.
* Sometimes useful for API gluing when internals use unordered_map.
* \param dmap The container map
* \return The corresponding unordered_map.
* \tparam K the key of the Map.
* \tparam V the value of the Map.
*/
template<typename K, typename V>
inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
std::unordered_map<K, V> ret;
for (auto kv : dmap) {
ret[kv.first] = kv.second;
}
return ret;
}
} // namespace tvm
namespace std {
......
......@@ -9,8 +9,6 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <ir/IREquality.h>
#include <arithmetic/Simplify.h>
#include <tvm/ir_functor.h>
#include <arithmetic/Simplify.h>
#include <unordered_map>
......@@ -33,6 +31,20 @@ inline Stmt Simplify(Stmt a) {
}
/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplified.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt);
/*!
* \brief Simplify by applying canonical form.
* \param expr The expression to be canonically simplified.
* \return Canonicalized expression.
*/
Expr CanonicalSimplify(Expr expr);
/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
......@@ -88,13 +100,6 @@ bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);
Stmt ConvertSSA(Stmt stmt);
/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplified.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
......
......@@ -10,6 +10,7 @@
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"
#include "./tensor_intrin.h"
namespace tvm {
......@@ -160,6 +161,14 @@ class Stage : public NodeRef {
*/
Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Replace computation of the current stage by tensor intrinsic f.
* \param var The axis that marks the beginning of tensorization.
* Every operation inside the axis (including the axis itself) is tensorized.
* \param f The tensor compute intrinsic.
* \return reference to self.
*/
Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be unrolled.
* \return reference to self.
......@@ -465,12 +474,18 @@ class IterVarAttrNode : public Node {
Array<Tensor> prefetch_data;
/*! \brief The offset used in each prefetch */
Array<Expr> prefetch_offset;
/*!
* \brief Tensor intrinsic used in tensorization,
* when the axis is marked as Tensorized
*/
TensorIntrin tensor_intrin;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
v->Visit("bind_thread", &bind_thread);
v->Visit("prefetch_data", &prefetch_data);
v->Visit("prefetch_offset", &prefetch_offset);
v->Visit("tensor_intrin", &tensor_intrin);
}
static constexpr const char* _type_key = "IterVarAttr";
......
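For orientation, a minimal sketch (not part of this commit's diff) of how the new tensor_intrin field becomes visible from Python once an axis xi of stage s[z] has been tensorized with an intrinsic intrin; it mirrors the unit test added later in this commit:

attrs = s[z].iter_var_attrs[xi]   # IterVarAttr populated by Stage::tensorize
assert attrs.iter_type == tvm.schedule.IterVar.Tensorized
assert attrs.tensor_intrin == intrin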
/*!
* Copyright (c) 2017 by Contributors
* \file tensor_intrin.h
* \brief Tensor intrinsic operations.
*/
#ifndef TVM_TENSOR_INTRIN_H_
#define TVM_TENSOR_INTRIN_H_
#include <string>
#include "./tensor.h"
#include "./buffer.h"
namespace tvm {
// Internal node container of tensor intrinsics.
class TensorIntrinNode;
/*! \brief Tensor intrinsic node. */
class TensorIntrin : public NodeRef {
public:
TensorIntrin() {}
explicit TensorIntrin(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorIntrinNode* operator->() const;
/*! \brief specify container node */
using ContainerType = TensorIntrinNode;
};
/*! \brief Node to represent a Tensor intrinsic operator */
class TensorIntrinNode : public Node {
public:
/*! \brief The name of the intrinsic */
std::string name;
/*! \brief The operation this intrinsic carries out */
Operation op;
/*! \brief List of inputs of the operator, placeholders in post-DFS order */
Array<Tensor> inputs;
/*!
* \brief Symbolic buffers of each input/output tensor.
* buffers[0:len(inputs)] are the buffers of the inputs,
* buffers[len(inputs):] are the buffers of the outputs.
*
* \note When a field in a Buffer is a Var, we are flexible with
* respect to that field and the Var can occur in the body.
* When it is a constant, only data of exactly that shape is accepted.
*/
Array<Buffer> buffers;
/*! \brief The normal statement to execute the intrinsic */
Stmt body;
/*!
* \brief Special statement for reduction ops; can be None.
* Resets the value of the output buffer to the identity value.
*/
Stmt reduce_init;
/*!
* \brief Special statement for reduction ops; can be None.
* Reduce: performs a reduction of the current output buffer with the result.
*/
Stmt reduce_update;
/*! \brief constructor */
TensorIntrinNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("op", &op);
v->Visit("inputs", &inputs);
v->Visit("buffers", &buffers);
v->Visit("body", &body);
v->Visit("reduce_init", &reduce_init);
v->Visit("reduce_update", &reduce_update);
}
static TensorIntrin make(std::string name,
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Stmt body,
Stmt reduce_init,
Stmt reduce_update);
static constexpr const char* _type_key = "TensorIntrin";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinNode, Node);
};
inline const TensorIntrinNode* TensorIntrin::operator->() const {
return static_cast<const TensorIntrinNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_TENSOR_INTRIN_H_
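As a usage sketch (not part of the diff), the fields above map onto the Python-side declaration as follows: inputs and buffers come from the placeholders of op, while body, reduce_init and reduce_update come from the statements returned by fcompute. The sketch mirrors the intrin_gemv helper in the tests of this commit; the packed function names "gemv", "fill_zero" and "gemv_add" are placeholders for runtime-provided functions.

import tvm

def intrin_gemv(m, n):
    w = tvm.placeholder((m, n), name='w')
    x = tvm.placeholder((n,), name='x')
    k = tvm.reduce_axis((0, n), name='k')
    z = tvm.compute((m,), lambda i: tvm.sum(w[i, k] * x[k], axis=k), name='z')
    # Symbolic buffer with a free stride var: the layout stays flexible (see note above).
    Wb = tvm.decl_buffer(w.shape, w.dtype, name="W", strides=[tvm.var('ldw'), 1])
    def intrin_func(ins, outs):
        ww, xx = ins   # buffers[0:len(inputs)]
        zz = outs[0]   # buffers[len(inputs):]
        body = tvm.call_packed("gemv", ww.data, xx.data, zz.data, n, ww.strides[0])
        reset = tvm.call_packed("fill_zero", zz.data, n)   # becomes reduce_init
        update = tvm.call_packed("gemv_add", ww.data, xx.data, zz.data, n, ww.strides[0])  # becomes reduce_update
        return body, reset, update
    return tvm.decl_tensor_intrin(z.op, intrin_func, binds={w: Wb})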
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: a DSL for tensor kernel compilation"""
"""TVM: Low level DSL/IR stack for tensor computation."""
from __future__ import absolute_import as _abs
from . import tensor
......@@ -23,6 +23,7 @@ from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .node import register_node
from .ndarray import register_extension
from .schedule import create_schedule
......
......@@ -109,9 +109,10 @@ def get_binds(args, binds=None):
args : list of Buffer or Tensor or Var
The argument list to the function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to the Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
Returns
-------
......@@ -126,14 +127,16 @@ def get_binds(args, binds=None):
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
buf = api.decl_buffer(x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
if x not in binds:
buf = api.decl_buffer(x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor)
binds[x] = buf
arg_list.append(buf)
else:
arg_list.append(binds[x])
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, expr.Var):
......@@ -161,9 +164,10 @@ def lower(sch,
name : str, optional
The name of result function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to the Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
simple_mode : bool, optional
Whether only output simple and compact statement, this will skip
......
......@@ -57,6 +57,7 @@ class IterVar(NodeBase, _expr.ExprOp):
Unrolled = 5
Vectorized = 6
Parallelized = 7
Tensorized = 8
_tensor.iter_var_cls = IterVar
......@@ -388,6 +389,19 @@ class Stage(NodeBase):
"""
_api_internal._StageVectorize(self, var)
def tensorize(self, var, tensor_intrin):
"""Tensorize the computation enclosed by var with tensor_intrin
Parameters
----------
var : IterVar
The iteration boundary of tensorization.
tensor_intrin : TensorIntrin
The tensor intrinsic used for computation.
"""
_api_internal._StageTensorize(self, var, tensor_intrin)
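A minimal usage sketch (not part of the diff), assuming z is an elementwise compute op and intrin was declared with decl_tensor_intrin over a computation of extent factor, as in the tests added by this commit:

s = tvm.create_schedule(z.op)
xo, xi = s[z].split(z.op.axis[0], factor=factor)
s[z].tensorize(xi, intrin)   # everything at and below xi is replaced by the intrinsic
tvm.lower(s, [x, y, z])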
def unroll(self, var):
"""Unroll the iteration.
......
"""Tensor intrinsics"""
from __future__ import absolute_import as _abs
from . import _api_internal
from . import api as _api
from . import expr as _expr
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from .build import BuildConfig
from ._ffi.node import NodeBase, register_node
@register_node
class TensorIntrin(NodeBase):
"""Tensor intrinsic functions for certain computation.
See Also
--------
decl_tensor_intrin: Construct a TensorIntrin
"""
pass
def decl_tensor_intrin(op,
fcompute,
name="tensor_intrin",
binds=None):
"""Declare a tensor intrinsic function.
Parameters
----------
op: Operation
The symbolic description of the intrinsic operation
fcompute: lambda function of inputs, outputs -> stmt
Specifies the IR statement that performs the computation.
See the following note for the function signature of fcompute.
.. note::
**Parameters**
- **ins** (list of :any:`Buffer`) - Placeholder for each input
- **outs** (list of :any:`Buffer`) - Placeholder for each output
**Returns**
- **stmt** (:any:`Stmt`, or tuple of three stmts)
- If a single stmt is returned, it represents the body
- If a tuple of three stmts is returned, they correspond to body,
reduce_init, and reduce_update
name: str, optional
The name of the intrinsic.
binds: dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to the Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
Returns
-------
intrin: TensorIntrin
A TensorIntrin that can be used in a tensorize schedule.
"""
if not isinstance(op, _tensor.Operation):
raise TypeError("expect Operation")
inputs = op.input_tensors
binds = binds if binds else {}
tensors = [x for x in inputs]
for i in range(op.num_outputs):
tensors.append(op.output(i))
binds_list = []
for t in inputs:
if not isinstance(t.op, _tensor.PlaceholderOp):
raise ValueError("Donot yet support composition op")
cfg = BuildConfig.current
for t in tensors:
buf = (binds[t] if t in binds else
_api.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor))
binds_list.append(buf)
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
if isinstance(body, (_expr.Expr, _stmt.Stmt)):
body = [body]
body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
name, op, inputs, binds_list, *body)
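For reference, a minimal declaration sketch (mirroring test_tensor_intrin further below; the packed function name "vadd" is a placeholder for a runtime-provided function):

n = 16
x = tvm.placeholder((n,), name='x')
y = tvm.placeholder((n,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

def intrin_func(ins, outs):
    # ins/outs are the symbolic Buffers declared (or taken from binds) for z.op
    return tvm.call_packed("vadd", ins[0].data, outs[0].data, ins[0].shape[0])

intrin = tvm.decl_tensor_intrin(z.op, intrin_func)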
......@@ -164,6 +164,17 @@ TVM_REGISTER_API("_Tensor")
args[3]);
});
TVM_REGISTER_API("_TensorIntrin")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorIntrinNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6]);
});
TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor();
......@@ -335,6 +346,12 @@ TVM_REGISTER_API("_StageVectorize")
.vectorize(args[1]);
});
TVM_REGISTER_API("_StageTensorize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.tensorize(args[1], args[2]);
});
TVM_REGISTER_API("_StageParallel")
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
......
......@@ -21,6 +21,15 @@ TVM_REGISTER_API("ir_pass.Simplify")
}
});
TVM_REGISTER_API("ir_pass.CanonicalSimplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
*ret = CanonicalSimplify(args[0].operator Stmt());
} else {
*ret = CanonicalSimplify(args[0].operator Expr());
}
});
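A quick sketch of the Python entry point registered above (assuming the usual tvm.var helper); the same packed function dispatches on whether it receives a Stmt or an Expr:

a = tvm.var('a')
e = tvm.ir_pass.CanonicalSimplify((a + 1) - 1)   # Expr overload; a Stmt argument is dispatched the same way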
TVM_REGISTER_API("ir_pass.Equal")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
......@@ -76,7 +85,6 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
......
......@@ -531,5 +531,8 @@ Stmt CanonicalSimplify(Stmt stmt) {
return arith::Canonical().Simplify(stmt);
}
Expr CanonicalSimplify(Expr expr) {
return arith::Canonical().Simplify(expr);
}
} // namespace ir
} // namespace tvm
......@@ -42,7 +42,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< ", identity_element=" << op->identity_element
<< ")";
});
} // namespace Internal
} // namespace Halide
......@@ -141,5 +140,6 @@ TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
} // namespace ir
} // namespace tvm
......@@ -4,6 +4,7 @@
*/
#include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/tensor_intrin.h>
#include <ir/IR.h>
#include <memory>
......@@ -54,4 +55,28 @@ Tensor Operation::output(size_t i) const {
return Tensor(node);
}
TensorIntrin TensorIntrinNode::make(std::string name,
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Stmt body,
Stmt reduce_init,
Stmt reduce_update) {
auto n = std::make_shared<TensorIntrinNode>();
n->name = std::move(name);
n->op = std::move(op);
n->inputs = std::move(inputs);
n->buffers = std::move(buffers);
n->body = std::move(body);
n->reduce_init = std::move(reduce_init);
n->reduce_update = std::move(reduce_update);
return TensorIntrin(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const TensorIntrinNode *n, IRPrinter *p) {
p->stream << "TensorIntrin(name=" << n->name << ", " << n << ")";
});
TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
} // namespace tvm
......@@ -297,14 +297,62 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
}
}
enum class ComputeType {
kNormal,
kCrossThreadReduction,
kTensorize
};
ComputeType DetectComputeType(const ComputeOpNode* self,
const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0, tensorize = 0;
for (IterVar iv : stage->leaf_iter_vars) {
IterVarAttr attr;
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end()) {
attr = (*it).second;
}
if (attr.defined() && attr->iter_type == kTensorized) {
++tensorize;
}
if (iv->iter_type == kCommReduce) {
if (attr.defined() && attr->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
if (tensorize != 0) {
CHECK(thread_red == 0)
<< "Cannot mix cross thread reduction with Tensorize";
return ComputeType::kTensorize;
}
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
if (thread_red != 0) {
return ComputeType::kCrossThreadReduction;
} else {
return ComputeType::kNormal;
}
}
// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
if (IsCrossThreadReduction(this, stage)) {
ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
} else if (ctype == ComputeType::kTensorize) {
return MakeTensorize(this, stage, dom_map);
} else {
return MakeComputeStmt(this, stage, dom_map);
}
......
......@@ -46,13 +46,6 @@ struct ComputeLoopNest {
};
/*!
* \brief Whether compute op is a cross thread reduction structure.
* \param self The pointer to ComputeOpNode
* \param stage the schedule stage.
*/
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage);
/*!
* \brief Build body of compute for cross thread reduction pattern.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
......@@ -63,6 +56,17 @@ Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
/*!
* \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \return The created statement.
*/
Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
} // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
......@@ -10,29 +10,6 @@
namespace tvm {
using namespace ir;
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0;
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
return thread_red != 0;
}
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
......
......@@ -65,6 +65,7 @@ MakeLoopNest(const Stage& stage,
case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break;
case kDataPar: break;
case kTensorized: break;
default: LOG(FATAL) << "Unknown iter type"
<< it_attr->iter_type
<< " in the iter_var_attrs";
......
/*!
* Copyright (c) 2017 by Contributors
* \brief Logic related to tensorize, used by ComputeOpNode.
* \file tensorize.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include "./op_util.h"
#include "./compute_op.h"
#include "../schedule/message_passing.h"
namespace tvm {
using namespace ir;
using namespace op;
// Detect the region of input and output to be tensorized.
// out_dom: the domain of root iter vars in output op
// in_region: region of each input tensor.
// return The location of the tensorized scope start.
size_t InferTensorizeRegion(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Range>* out_dom,
std::unordered_map<Tensor, Array<Range> >* in_region) {
// Get the bound of the tensorized scope.
bool found_point = false;
size_t loc_scope = 0;
std::unordered_map<IterVar, IntSet> up_state;
// Loop over the leaves
for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = stage->leaf_iter_vars[i - 1];
CHECK(iv->iter_type == kDataPar ||
iv->iter_type == kCommReduce);
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min);
} else if (found_point) {
CHECK(is_zero(vrange->min));
up_state[iv] = IntSet::single_point(iv->var);
} else {
up_state[iv] = IntSet::range(vrange);
}
auto iit = stage->iter_var_attrs.find(iv);
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (!found_point) {
CHECK(!attr->bind_thread.defined())
<< "Donot allow thread in tensorize scope";
}
if (attr->iter_type == kTensorized) {
CHECK(!found_point) << "Donot allow two tensorized point";
found_point = true;
loc_scope = i - 1;
}
}
}
CHECK(found_point);
// Get domain of the tensorized scope.
schedule::PassUpDomain(stage, dom_map, &up_state);
// Get domains of inputs
std::unordered_map<Tensor, TensorDom> in_dom;
std::unordered_map<const Variable*, IntSet> temp_dmap;
Array<Tensor> inputs = self->InputTensors();
for (Tensor t : inputs) {
in_dom.emplace(t, TensorDom(t.ndim()));
}
for (IterVar iv : self->root_iter_vars()) {
IntSet iset = up_state.at(iv);
(*out_dom)[iv] = iset.cover_range(dom_map.at(iv));
temp_dmap[iv->var.get()] = iset;
}
// Input domains
self->PropBoundToInputs(stage->op, temp_dmap, &in_dom);
Range none;
for (const auto& kv : in_dom) {
Array<Range> vec;
const Tensor& t = kv.first;
for (int i = 0; i < t.ndim(); ++i) {
Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
vec.push_back(std::move(r));
}
(*in_region)[t] = std::move(vec);
}
return loc_scope;
}
void VerifyTensorizeLoopNest(const ComputeOpNode* self,
const Stage& stage,
const ComputeLoopNest& n,
size_t tloc) {
// Verification step.
std::unordered_set<const Variable*> banned;
CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
n.init_nest.size() == 0);
auto f_push_banned = [&banned](const Stmt& s) {
if (const For* op = s.as<For>()) {
banned.insert(op->loop_var.get());
} else if (const AttrStmt* op = s.as<AttrStmt>()) {
if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
banned.insert(iv->var.get());
}
} else if (const LetStmt* op = s.as<LetStmt>()) {
banned.insert(op->var.get());
}
};
for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) {
for (const Stmt& s : n.main_nest[i + 1]) {
f_push_banned(s);
}
if (n.init_nest.size() != 0) {
for (const Stmt& s : n.init_nest[i + 1]) {
f_push_banned(s);
}
}
}
for (const Expr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize failed, split condition "
<< pred << " relies on var defined inside tensorize scope";
}
}
for (const Expr& pred : n.init_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize failed, split condition "
<< pred << " relies on var defined inside tensorize scope";
}
}
}
// Remap the tensor placeholders and indices, and inline things.
class TensorIntrinMatcher final : public IRMutator {
public:
Expr Mutate_(const Call* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
if (op->call_type == Call::Halide) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = in_remap_.find(t);
if (it != in_remap_.end()) {
const InputEntry& e = it->second;
CHECK_EQ(op->args.size(), e.region.size());
Array<Expr> args;
for (size_t i = e.start; i < e.region.size(); ++i) {
args.push_back(op->args[i] - e.region[i]->min);
}
return Call::make(
op->type, e.tensor->op->name, args,
op->call_type, e.tensor->op, e.tensor->value_index);
}
}
return expr;
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
} else {
return e;
}
}
Expr Mutate_(const Reduce* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Reduce>();
Array<IterVar> axis = op->axis;
for (size_t i = 0; i < op->axis.size(); ++i) {
auto it = axis_remap_.find(op->axis[i]);
if (it != axis_remap_.end()) {
axis.Set(i, it->second);
}
}
if (!axis.same_as(op->axis)) {
return Reduce::make(
op->combiner, op->source, axis, op->condition, op->value_index);
} else {
return e;
}
}
void Init(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
// input remap.
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
InputEntry e;
e.tensor = intrin->inputs[i];
e.region = Array<Range>(in_region.at(inputs[i]));
CHECK_GE(e.region.size(), e.tensor.ndim());
// Enable fuzzy matching, to match [1, n, m] to [n, m]
e.start = e.region.size() - e.tensor.ndim();
for (size_t i = 0; i < e.start; ++i) {
CHECK(is_one(e.region[i]->extent))
<< "Tensorize: Input dimension mismatch with tensor intrin "
<< " expected shape=" << e.tensor->shape
<< ", given region=" << e.region;
}
in_remap_[inputs[i]] = e;
}
// output remap
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
CHECK(intrin_compute) << "Only support compute intrinsic for now";
CHECK_GE(self->axis.size(), intrin_compute->axis.size())
<< "Tensorize: Output mismatch with tensor intrin ";
// Enable fuzzy matching, to match [1, n, m] to [n, m]
size_t axis_start = self->axis.size() - intrin_compute->axis.size();
for (size_t i = 0; i < axis_start; ++i) {
Range r = out_dom.at(self->axis[i]);
CHECK(is_one(r->extent))
<< "Tensorize: Output mismatch with tensor intrin "
<< " intrin-dim=" << intrin_compute->axis.size()
<< ", tensorize-dim=" << self->axis.size();
}
// Assume we tensorize at region axis i [min, min + extent)
// The corresponding intrinsic axis is j [0, extent)
// Remap index i to j + min
for (size_t i = axis_start; i < self->axis.size(); ++i) {
IterVar iv = self->axis[i];
IterVar target_iv = intrin_compute->axis[i - axis_start];
Range r = out_dom.at(iv);
var_remap_[iv->var.get()] = target_iv->var + r->min;
axis_remap_[iv] = target_iv;
}
// Remap reduction axis
CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
<< "Tensorize: Reduction dimension mismatch with tensor intrin";
axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
for (size_t i = 0; i < axis_start; ++i) {
Range r = out_dom.at(self->reduce_axis[i]);
CHECK(is_one(r->extent))
<< "Tensorize: Reduction mismatch with tensor intrin "
<< " intrin-dim=" << intrin_compute->reduce_axis.size()
<< ", tensorize-dim=" << self->reduce_axis.size();
}
for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
IterVar iv = self->reduce_axis[i];
IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
Range r = out_dom.at(iv);
var_remap_[iv->var.get()] = target_iv->var + r->min;
axis_remap_[iv] = target_iv;
}
}
private:
// Input entry
struct InputEntry {
Tensor tensor;
size_t start;
Array<Range> region;
};
// input data remap
std::unordered_map<Tensor, InputEntry> in_remap_;
// variable remap.
std::unordered_map<const Variable*, Expr> var_remap_;
// IterVar remap.
std::unordered_map<IterVar, IterVar> axis_remap_;
};
// Try to match tensor dataflow of the stage with the intrinsic
Array<Expr> MatchTensorizeBody(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
TensorIntrinMatcher matcher;
matcher.Init(self, stage, out_dom, in_region, intrin);
Array<Expr> ret;
for (Expr expr : self->body) {
ret.push_back(matcher.Mutate(expr));
}
return ret;
}
void VerifyTensorizeBody(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin);
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
CHECK(intrin_compute) << "Only support compute intrinsic for now";
CHECK_EQ(body.size(), intrin_compute->body.size())
<< "Tensorize failed: body size mismatch";
for (size_t i = 0; i < body.size(); ++i) {
Expr lhs = CanonicalSimplify(body[i]);
Expr rhs = CanonicalSimplify(intrin_compute->body[i]);
CHECK(Equal(lhs, rhs))
<< "Failed to match the compute with TensorIntrin declaration "
<< " provided:" << lhs
<< ", intrin:" << rhs;
}
}
Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
std::unordered_map<IterVar, Range> out_dom;
std::unordered_map<Tensor, Array<Range> > in_region;
size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
TensorIntrin intrin = stage->iter_var_attrs.at(
stage->leaf_iter_vars[tloc])->tensor_intrin;
CHECK(intrin.defined());
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map);
VerifyTensorizeLoopNest(self, stage, n, tloc);
VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
// Start binding data.
Stmt nop = Evaluate::make(0);
std::vector<Stmt> bind_nest;
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size())
<< "Tensorize failed: input size mismatch ";
// input binding
for (int i = 0; i < intrin->inputs.size(); ++i) {
Tensor tensor = inputs[i];
Buffer buffer = intrin->buffers[i];
Array<NodeRef> bind_spec{buffer, tensor};
auto it = in_region.find(tensor);
CHECK(it != in_region.end());
const Array<Range>& region = it->second;
Array<Expr> tuple;
for (const Range r : region) {
tuple.push_back(r->min);
tuple.push_back(r->extent);
}
bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
// output binding
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
CHECK(intrin_compute) << "Only support compute intrinsic for now";
CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
CHECK_EQ(intrin_compute->body.size(), self->body.size());
Array<Expr> tuple;
for (IterVar iv : self->axis) {
auto it = out_dom.find(iv);
CHECK(it != out_dom.end());
tuple.push_back(it->second->min);
tuple.push_back(it->second->extent);
}
for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) {
Tensor tensor = stage->op.output(i - intrin->inputs.size());
Buffer buffer = intrin->buffers[i];
Array<NodeRef> bind_spec{buffer, tensor};
bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
// Check variable remap
std::unordered_map<const Variable*, Expr> vmap;
ir::ArgBinder binder(&vmap);
CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
<< "Tensorization fail: reduction axis size do not match";
size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
for (size_t i = 0; i < start; ++i) {
IterVar iv = self->reduce_axis[i];
auto it = out_dom.find(iv);
CHECK(it != out_dom.end());
CHECK(is_one(it->second->extent))
<< "Tensorization fail: reduction axis size do not match";
}
for (size_t i = start; i < self->reduce_axis.size(); ++i) {
IterVar iv = self->reduce_axis[i];
IterVar target = intrin_compute->reduce_axis[i - start];
auto it = out_dom.find(iv);
CHECK(it != out_dom.end());
binder.Bind(target->dom->min, make_const(iv->dom->min.type(), 0),
"tensir_intrin.reduction.min");
binder.Bind(target->dom->extent, it->second->extent,
"tensir_intrin.reduction.extent");
}
if (tloc <= n.num_common_loop) {
// No need to split reduction
std::vector<std::vector<Stmt> > nest(
n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
nest.emplace_back(op::MakeIfNest(n.main_predicates));
CHECK_EQ(n.init_predicates.size(), 0U);
CHECK(intrin->body.defined())
<< "Normal store op for intrin " << intrin << " is not defined";
Stmt body = ir::MergeNest(bind_nest, intrin->body);
body = Substitute(body, vmap);
body = ir::MergeNest(binder.asserts(), body);
body = Substitute(body, n.main_vmap);
return ir::MergeNest(nest, body);
} else {
// Need to split reduction
CHECK(intrin->reduce_init.defined())
<< "Reduction init op for intrin " << intrin << " is not defined";
CHECK(intrin->reduce_update.defined())
<< "Reduction update op for intrin " << intrin << " is not defined";
// Need init and update steps
CHECK_NE(self->reduce_axis.size(), 0U);
std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
// init nest
std::vector<std::vector<Stmt> > init_nest(
n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
Stmt init = MergeNest(bind_nest, intrin->reduce_init);
init = Substitute(init, n.init_vmap);
init = MergeNest(init_nest, init);
// The update
std::vector<std::vector<Stmt> > update_nest(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
Stmt update = MergeNest(bind_nest, intrin->reduce_update);
update = Substitute(update, vmap);
update = MergeNest(binder.asserts(), update);
update = Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, Block::make(init, update));
}
}
// Register functions for unittests
TVM_REGISTER_API("test.op.InferTensorizeRegion")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Stage stage = args[0];
Map<IterVar, Range> dmap = args[1];
std::unordered_map<IterVar, Range> out_dom;
std::unordered_map<Tensor, Array<Range> > in_region;
CHECK(stage->op.as<ComputeOpNode>());
InferTensorizeRegion(stage->op.as<ComputeOpNode>(),
stage,
as_unordered_map(dmap),
&out_dom, &in_region);
*ret = Array<NodeRef>{Map<IterVar, Range>(out_dom),
Map<Tensor, Array<Range> >(in_region)};
});
TVM_REGISTER_API("test.op.MatchTensorizeBody")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Stage stage = args[0];
Map<IterVar, Range> out_dom = args[1];
Map<Tensor, Array<Range> > in_region = args[2];
TensorIntrin intrin = args[3];
CHECK(stage->op.as<ComputeOpNode>());
*ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
stage,
as_unordered_map(out_dom),
as_unordered_map(in_region),
intrin);
});
} // namespace tvm
......@@ -14,8 +14,8 @@ namespace runtime {
DSLAPI* FindDSLAPI() {
auto* f = Registry::Get("dsl_api.singleton");
if (f == nullptr) {
throw dmlc::Error("TVM runtime only environment, "\
"DSL API is not available");
throw dmlc::Error("TVM runtime only environment,"\
" DSL API is not available");
}
void* ptr = (*f)();
return static_cast<DSLAPI*>(ptr);
......
......@@ -292,7 +292,8 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
return *this;
}
inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
template<typename FUpdate>
inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
......@@ -303,15 +304,29 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type)
} else {
n = std::make_shared<IterVarAttrNode>();
}
n->iter_type = iter_type;
fupdate(n.get());
self->iter_var_attrs.Set(var, IterVarAttr(n));
}
inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) {
n->iter_type = iter_type;
});
}
Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
SetAttrIterType(operator->(), var, kVectorized);
return *this;
}
Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*)
UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
n->iter_type = kTensorized;
n->tensor_intrin = f;
});
return *this;
}
Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttrIterType(operator->(), var, kUnrolled);
return *this;
......
......@@ -296,10 +296,7 @@ class SchedulePostProc : public IRMutator {
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_) {
Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
}
std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
// scan init and scan updates
std::unordered_map<Operation, Operation> scan_init;
for (Stage s : sch->stages) {
......
......@@ -23,7 +23,7 @@ def test_llvm_add_pipeline():
name='A')
binds = {A : Ab}
# BUILD and invoke the kernel.
f = tvm.build(s, [Ab, B, C], "llvm", binds=binds)
f = tvm.build(s, [A, B, C], "llvm", binds=binds)
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
......
......@@ -115,8 +115,33 @@ def test_rfactor():
assert(BF.op.body[0].axis[1].var == ko.var)
assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
def test_tensor_intrin():
n = 16
x = tvm.placeholder((n,), name='x')
y = tvm.placeholder((n,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
def intrin_func(ins, outs):
assert(isinstance(ins[0], tvm.schedule.Buffer))
assert(ins[0].shape[0].value == n)
return tvm.call_packed("vadd", ins[0].data, outs[0].data, ins[0].shape[0])
intrin = tvm.decl_tensor_intrin(z.op, intrin_func)
assert intrin.op == z.op
assert intrin.reduce_init is None
assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
assert(intrin.buffers[0].shape[0].value == n)
m = 32
x = tvm.placeholder((m,), name='x')
y = tvm.placeholder((m,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
s = tvm.create_schedule(z.op)
xo, xi = s[z].split(z.op.axis[0], factor=n)
s[z].tensorize(xi, intrin)
assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)
if __name__ == "__main__":
test_tensor_intrin()
test_rfactor()
test_schedule_create()
test_reorder()
......
import tvm
def intrin_vadd(n):
x = tvm.placeholder((n,), name='vx')
y = tvm.placeholder((n,), name='vy')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
def intrin_func(ins, outs):
xx, yy = ins
zz = outs[0]
return tvm.call_packed("vadd", xx, yy, zz)
with tvm.build_config(offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func)
def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
k = tvm.reduce_axis((0, n), name='k')
z = tvm.compute((m,), lambda i:
tvm.sum(w[i, k] * x[k], axis=k), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype, name="W",
strides=[tvm.var('ldw'), 1])
def intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
body = tvm.call_packed(
"gemv", ww.data, xx.data, zz.data, n, ww.strides[0])
reset = tvm.call_packed(
"fill_zero", outs[0].data, n)
update = tvm.call_packed(
"gemv_add", ww.data, xx.data, zz.data, n, ww.strides[0])
return body, reset, update
with tvm.build_config(data_alignment=16):
return tvm.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
def test_tensorize_vadd():
m = 128
x = tvm.placeholder((m,), name='x')
y = tvm.placeholder((m,), name='y')
z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
def check(factor):
s = tvm.create_schedule(z.op)
xo, xi = s[z].split(z.op.axis[0], factor=factor)
vadd = intrin_vadd(factor)
s[z].tensorize(xi, vadd)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[z], dom_map)
assert tvm.ir_pass.Equal(out_dom[z.op.axis[0]].extent, factor)
assert tvm.ir_pass.Equal(out_dom[z.op.axis[0]].min, xo * factor)
assert tvm.ir_pass.Equal(in_dom.items()[0][1][0].extent, factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[z], out_dom, in_dom, vadd)
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
tvm.ir_pass.CanonicalSimplify(vadd.op.body[0]))
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [x, y, z])
check(16)
def test_tensorize_matmul():
n = 1024
m = n
l = n
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute((n, m), lambda i, j:
tvm.sum(B[j, k] * A[i, k], axis=k), name='C')
def check(factor):
s = tvm.create_schedule(C.op)
x, y = C.op.axis
yo, yi = s[C].split(y, factor=factor)
gemv = intrin_gemv(factor, l)
s[C].tensorize(yi, gemv)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
def check_rfactor(factor, rfactor):
s = tvm.create_schedule(C.op)
x, y = C.op.axis
rk = C.op.reduce_axis[0]
yo, yi = s[C].split(y, factor=factor)
ro, ri = s[C].split(rk, factor=rfactor)
s[C].reorder(yo, ro, yi, ri)
gemv = intrin_gemv(factor, rfactor)
s[C].tensorize(yi, gemv)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
check(16)
check_rfactor(16, 16)
if __name__ == "__main__":
test_tensorize_vadd()
test_tensorize_matmul()