Commit d546bb77 by Salem Derisavi Committed by ziheng

Defined a common base class for TensorComputeOp and ComputeOp (#2587)

* Defined a common base class for TensorComputeOp and ComputeOp

* Made changes requested by @ZihengJiang

* added a testcase to assert that `tensorize` does not have any effect on TensorComputeOp ops.
parent 3e5a172d
......@@ -184,22 +184,45 @@ class PlaceholderOpNode : public OperationNode {
/*!
* \brief A compute op that computes a tensor on a certain domain.
* This is the base class for ComputeOp (operating on a scalar at a time) and
* TensorComputeOp (operating on a TensorSlice at a time)
*/
class TVM_DLL ComputeOpNode : public OperationNode {
class TVM_DLL BaseComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
// override functions
Array<IterVar> root_iter_vars() const final;
Array<Expr> output_shape(size_t idx) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
virtual size_t num_schedulable_dims() const = 0;
static constexpr const char* _type_key = "BaseComputeOp";
TVM_DECLARE_BASE_NODE_INFO(BaseComputeOpNode, OperationNode);
};
/*!
* \brief A compute op that computes a tensor on a certain domain.
*/
class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
public:
/*! \brief the compute expression */
Array<Expr> body;
/*! \brief constructor */
ComputeOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
......@@ -208,18 +231,11 @@ class TVM_DLL ComputeOpNode : public OperationNode {
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......@@ -236,18 +252,14 @@ class TVM_DLL ComputeOpNode : public OperationNode {
Array<Expr> body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, BaseComputeOpNode);
};
/*!
* \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
*/
class TensorComputeOpNode : public OperationNode {
class TensorComputeOpNode : public BaseComputeOpNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
Array<IterVar> reduce_axis;
/*! \brief number of axes that can be scheduled */
int schedulable_ndim;
/*! \brief TensorIntrin used to compute */
......@@ -260,9 +272,7 @@ class TensorComputeOpNode : public OperationNode {
TensorComputeOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
......@@ -271,18 +281,11 @@ class TensorComputeOpNode : public OperationNode {
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
......@@ -304,7 +307,7 @@ class TensorComputeOpNode : public OperationNode {
Array<Region> regions);
static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);
};
/*!
......
......@@ -146,7 +146,7 @@ class PlaceholderOp(Operation):
@register_node
class ComputeOp(Operation):
class BaseComputeOp(Operation):
"""Compute operation."""
@property
def axis(self):
......@@ -160,7 +160,13 @@ class ComputeOp(Operation):
@register_node
class TensorComputeOp(Operation):
class ComputeOp(BaseComputeOp):
"""Scalar operation."""
pass
@register_node
class TensorComputeOp(BaseComputeOp):
"""Tensor operation."""
......
......@@ -40,7 +40,7 @@ int ComputeOpNode::num_outputs() const {
return body.size();
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
if (reduce_axis.size() == 0) return axis;
Array<IterVar> ret = axis;
for (IterVar iv : reduce_axis) {
......@@ -54,15 +54,15 @@ Type ComputeOpNode::output_dtype(size_t idx) const {
return body[idx].type();
}
Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
CHECK_LT(idx, num_outputs());
// for now, all outputs of ComputeOp have the same shape
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
// for now, all outputs of a BaseComputeOp have the same shape
Array<Expr> shape;
for (const auto& ivar : this->axis) {
const Range& r = ivar->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
return shape;
}
Tensor compute(Array<Expr> shape,
......@@ -208,7 +208,7 @@ void ComputeOpNode::PropBoundToInputs(
for (auto& e : body) ir::PostOrderVisit(e, fvisit);
}
void ComputeOpNode::GatherBound(
void BaseComputeOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
......@@ -225,22 +225,22 @@ void ComputeOpNode::GatherBound(
}
}
Stmt ComputeOpNode::BuildRealize(
Stmt BaseComputeOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& realize_body) const {
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
HalideIR::Internal::Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
Stmt realize = realize_body;
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i-1);
realize = ir::Realize::make(t->op, t->value_index,
t->dtype, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < this->axis.size(); ++i) {
for (size_t i = 0; i < num_schedulable_dims(); ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
if (it != stage->iter_var_attrs.end()) {
IterVarAttr attr = (*it).second;
......@@ -259,6 +259,10 @@ Stmt ComputeOpNode::BuildRealize(
return realize;
}
size_t ComputeOpNode::num_schedulable_dims() const {
return axis.size();
}
// Build a reduction body.
void MakeReduction(const ComputeOpNode* op,
const Array<Tensor>& tensors,
......@@ -414,7 +418,7 @@ Stmt ComputeOpNode::BuildProvide(
}
ComputeLoopNest ComputeLoopNest::make(
const ComputeOpNode* self,
const BaseComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
......@@ -440,8 +444,8 @@ ComputeLoopNest ComputeLoopNest::make(
for (IterVar iv : self->reduce_axis) {
update_state[iv] = 2;
}
for (IterVar iv : self->axis) {
update_state[iv] = 1;
for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
update_state[self->axis[i]] = 1;
}
// find which iter var is related to reduction and which is related to axis.
schedule::PassDownBitMaskOr(stage, &update_state);
......
......@@ -41,7 +41,7 @@ struct ComputeLoopNest {
* \return The constructed loop nest
*/
static ComputeLoopNest make(
const ComputeOpNode* self,
const BaseComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop);
......
......@@ -28,27 +28,10 @@ int TensorComputeOpNode::num_outputs() const {
return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
}
/*!
 * \brief Collect the root iteration variables of this op.
 * \return The entries of `axis` followed by the entries of `reduce_axis`.
 */
Array<IterVar> TensorComputeOpNode::root_iter_vars() const {
  Array<IterVar> all_vars = this->axis;
  for (size_t k = 0; k < this->reduce_axis.size(); ++k) {
    all_vars.push_back(this->reduce_axis[k]);
  }
  return all_vars;
}
/*!
 * \brief Data type of the i-th output.
 * \param i Index of the output.
 * \return The dtype of the intrinsic buffer at position inputs.size() + i.
 */
Type TensorComputeOpNode::output_dtype(size_t i) const {
  // The i-th output is looked up past the first inputs.size() buffers.
  const auto buffer_index = this->inputs.size() + i;
  return this->intrin->buffers[buffer_index]->dtype;
}
/*!
 * \brief Shape of an output tensor.
 * \param i Index of the output; not consulted, the same shape is returned
 *  for every index.
 * \return The extent of each axis domain, in axis order.
 */
Array<Expr> TensorComputeOpNode::output_shape(size_t i) const {
  Array<Expr> extents;
  for (size_t dim = 0; dim < this->axis.size(); ++dim) {
    extents.push_back(this->axis[dim]->dom->extent);
  }
  return extents;
}
Operation TensorComputeOpNode::make(std::string name,
std::string tag,
Array<IterVar> axis,
......@@ -121,123 +104,10 @@ void TensorComputeOpNode::PropBoundToInputs(
}
}
void TensorComputeOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
CHECK(!out_dom_map->count(this->axis[i]));
(*out_dom_map)[this->axis[i]] = r;
}
for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
CHECK(!out_dom_map->count(this->reduce_axis[i]));
(*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
}
}
/*!
 * \brief Wrap `body` in one Realize node per output of this op, adding a
 *  buffer_dim_align attribute for each schedulable axis whose IterVarAttr
 *  has a non-zero dim_align_factor.
 * \param stage The stage being built; its op must be this node.
 * \param realize_map Range to realize for each axis IterVar.
 * \param body The statement to wrap.
 * \return The wrapped statement.
 */
Stmt TensorComputeOpNode::BuildRealize(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(stage->op.get(), this);
  // All outputs share the same realize region: one range per axis.
  HalideIR::Internal::Region bounds;
  for (size_t k = 0; k < this->axis.size(); ++k) {
    bounds.push_back(realize_map.at(this->axis[k]));
  }
  Stmt wrapped = body;
  // Walk the outputs from last to first, so output 0 is wrapped last.
  for (int out_idx = this->num_outputs() - 1; out_idx >= 0; --out_idx) {
    Tensor t = stage->op.output(out_idx);
    wrapped = ir::Realize::make(t->op, t->value_index,
                                t->dtype, bounds, const_true(), wrapped);
    // alignment requirement, only useful for compute
    for (int dim = 0; dim < schedulable_ndim; ++dim) {
      auto it = stage->iter_var_attrs.find(this->axis[dim]);
      if (it == stage->iter_var_attrs.end()) continue;
      IterVarAttr attr = (*it).second;
      if (attr->dim_align_factor == 0) continue;
      Array<Expr> tuple = {static_cast<int>(dim),
                           attr->dim_align_factor,
                           attr->dim_align_offset};
      wrapped = ir::AttrStmt::make(
          t, ir::attr::buffer_dim_align,
          Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
          wrapped);
    }
  }
  return wrapped;
}
ComputeLoopNest MakeLoopNest(
const TensorComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret;
// make main loop nest
ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
debug_keep_trivial_loop);
ret.main_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.main_vmap, false,
std::unordered_set<IterVar>());
for (auto& e : ret.main_predicates) {
e = likely(e);
}
if (stage->store_predicate.defined()) {
ret.main_predicates.push_back(stage->store_predicate);
}
if (self->reduce_axis.size() != 0) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> update_state;
for (IterVar iv : self->reduce_axis) {
update_state[iv] = 2;
}
for (int i = 0; i < self->schedulable_ndim; ++i) {
update_state[self->axis[i]] = 1;
}
// find which iter var is related to reduction and which is related to axis.
schedule::PassDownBitMaskOr(stage, &update_state);
auto leaf_iter_vars = stage->leaf_iter_vars;
// find the first loop that is related to reduction.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = update_state.at(iv);
if ((flag & 2) != 0) {
begin_loop = i; break;
}
ret.init_vmap[iv] = ret.main_vmap.at(iv);
}
ret.num_common_loop = begin_loop;
// skip loops that are related to reduction and are unrelated to axis.
std::unordered_set<IterVar> skip_iter;
for (auto kv : update_state) {
int flag = kv.second;
if (flag == 2) skip_iter.insert(kv.first);
}
ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
ret.init_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
} else {
CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
ret.num_common_loop = stage->leaf_iter_vars.size();
}
// copy elision here.
return ret;
size_t TensorComputeOpNode::num_schedulable_dims() const {
return schedulable_ndim;
}
Stmt TensorComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
......@@ -296,7 +166,7 @@ Stmt TensorComputeOpNode::BuildProvide(
ir::ArgBinder binder(&vmap);
size_t tloc = stage->leaf_iter_vars.size();
ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop);
ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
if (this->reduce_axis.size() == 0) {
std::vector<std::vector<Stmt> > nest(
......
......@@ -229,7 +229,85 @@ def test_tensorize_op():
s = s.normalize()
tvm.lower(s, [A, B])
# Checks that calling tensorize on a stage whose op is a TensorComputeOp
# leaves the generated loop nest unchanged.
def test_tensorize_tensor_compute_op():
    # "multivadd" is an intrinsic whose declared pattern is a loop over
    # another intrinsic, "vadd".
    def intrin_multivadd(n):
        n_a = tvm.var("n_a")
        Ab = tvm.decl_buffer((n, ), tvm.float32, strides=[n_a])

        n_b = tvm.var("n_b")
        Bb = tvm.decl_buffer((n, ), tvm.float32, strides=[n_b])

        n_c = tvm.var("n_c")
        Cb = tvm.decl_buffer((n, ), tvm.float32, strides=[n_c])

        z = tvm.compute(
            (n,),
            lambda i: tvm.call_extern(
                "float32", 'vadd',
                Ab.access_ptr("w", offset=n_a*i),
                Bb.access_ptr("r", offset=n_b*i),
                Cb.access_ptr("r", offset=n_c*i)))

        # Replacement for the pattern: a bare packed call to "multivadd".
        # Passing it the right parameters is still an open question.
        def emit_multivadd(ins, outs):
            return tvm.call_packed("multivadd")

        with tvm.build_config():
            return tvm.decl_tensor_intrin(z.op, emit_multivadd, name="multivadd")

    # "vadd" adds two length-n float32 vectors through an extern call.
    def intrin_vadd(n):
        dtype = 'float32'
        x = tvm.placeholder((n,), dtype=dtype, name='vx')
        y = tvm.placeholder((n,), dtype=dtype, name='vy')
        z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
        sched = tvm.create_schedule(z.op)  # not used afterwards

        def bind_buffer(t):
            return tvm.decl_buffer(t.shape, t.dtype,
                                   name='W'+t.name,
                                   offset_factor=16)

        def emit_vadd(ins, outs):
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("float32", 'vadd',
                                    ins[0].access_ptr("r"), ins[1].access_ptr('r'),
                                    outs[0].access_ptr('wr')))
            return ib.get()

        with tvm.build_config(offset_factor=16):
            bindings = {x: bind_buffer(x),
                        y: bind_buffer(y),
                        z: bind_buffer(z)}
            return tvm.decl_tensor_intrin(z.op, emit_vadd, binds=bindings)

    M = 1024
    factor = 16
    dtype = 'float32'
    shape = (M//factor, factor)

    A = tvm.placeholder(shape, name="A", dtype=dtype)
    B = tvm.placeholder(shape, name="B", dtype=dtype)

    # C applies the vadd intrinsic to one row slice at a time.
    vadd = intrin_vadd(factor)
    C = tvm.compute(shape,
                    lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')

    sched = tvm.create_schedule(C.op)
    multivadd = intrin_multivadd(64)
    sched[C].tensorize(C.op.axis[0], multivadd)
    sched = sched.normalize()
    dom_map = tvm.schedule.InferBound(sched)
    stmt = tvm.schedule.ScheduleOps(sched, dom_map)

    # The loop we asked to tensorize is still present in the generated
    # statement, i.e. tensorize had no effect.
    innermost = stmt.body.body.body
    assert isinstance(innermost, tvm.stmt.For)
    assert innermost.loop_var.name == C.op.axis[0].var.name
if __name__ == "__main__":
test_tensorize_vadd()
test_tensorize_matmul()
test_tensorize_op()
test_tensorize_tensor_compute_op()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment