Commit 820a8597 by Tianqi Chen Committed by GitHub

[LANG] Introduce Scan, Bugfix Canonical (#43)

parent f8f02829
...@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> { ...@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min"; static constexpr const char* Min = "Min";
}; };
/*! \brief namespace of possible attributes in AttrStmt.type_key */
namespace attr {
/*! /*!
* \brief Mark scope of iteration variable, used by Schedule. * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/ */
constexpr const char* scope = "scope"; struct TensorKey {
FunctionRef f;
int value_index;
inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};
/*! \brief namespace of possible attributes in AttrStmt.type_key */
namespace attr {
// The above attr does not pass to ir stage.
/*! /*!
* \brief Mark launching extent of thread, used by device API. * \brief Mark launching extent of thread, used by device API.
*/ */
...@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate; ...@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
size_t lhs = k.f.hash();
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_IR_H_ #endif // TVM_IR_H_
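The std::hash<TensorKey> specialization introduced above combines the function hash with the output index using the boost-style hash_combine constant 0x9e3779b9. A minimal Python sketch of that mixing step, purely for illustration (the names are made up and the arithmetic is truncated to 64 bits):

```python
def hash_combine(lhs, rhs):
    # Mirrors lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2), kept to 64 bits.
    mask = (1 << 64) - 1
    return (lhs ^ ((rhs + 0x9e3779b9 + ((lhs << 6) & mask) + (lhs >> 2)) & mask)) & mask

# Combine a (hypothetical) function hash with a value_index, as std::hash<TensorKey> does.
func_hash = hash("myfunc") & ((1 << 64) - 1)
print(hex(hash_combine(func_hash, 0)))
print(hex(hash_combine(func_hash, 1)))  # different output index -> different key hash
```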
...@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode { ...@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode); TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
}; };
/*!
* \brief Symbolic scan.
*/
class ScanOpNode : public OperationNode {
public:
/*! \brief IterVar to scan over */
IterVar scan_axis;
/*! \brief the initialization tensors */
Array<Tensor> init;
/*! \brief the update function represented by tensor */
Array<Tensor> update;
/*! \brief The placeholder to refer as states in update. */
Array<Tensor> state_placeholder;
/*!
* \brief Spatial axes indicating the spatial dimensions of each output.
* They correspond to the flattened spatial axes of the outputs:
*
* [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
* These are auxiliary data structures for storing the result of bound inference.
* They do not correspond to splittable iterations, thus the name comes
* with an underscore.
*/
Array<IterVar> spatial_axis_;
/*! \brief constructor */
ScanOpNode() {}
// override behavior.
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("scan_axis", &scan_axis);
v->Visit("init", &init);
v->Visit("update", &update);
v->Visit("state_placeholder", &state_placeholder);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder);
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
};
/*! \brief The compute function to specify the input source of a Tensor */ /*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>; using FCompute = std::function<Expr (const Array<Var>& i)>;
...@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape, ...@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
*/ */
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"); Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scanning over scan_axis.
*
* \param scan_axis The iteration representing the scan.
* \param init The initialization tensors of the first K time steps.
* \param update The update tensors indicating the updated result after each time step.
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");
// same as compute, specialized for different fcompute function // same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape, inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var)> f, std::function<Expr(Var)> f,
......
...@@ -14,6 +14,7 @@ from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func ...@@ -14,6 +14,7 @@ from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal from . import _api_internal
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
from . import tensor as _tensor
from . import collections as _collections from . import collections as _collections
int32 = "int32" int32 = "int32"
...@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"): ...@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
shape: Tuple of Expr shape: Tuple of Expr
The shape of the tensor The shape of the tensor
fcompute: lambda function of *indices-> value fcompute: lambda function of *indices-> value
Specifies the input source expression Specifies the input source expression
...@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"): ...@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
body = convert(body) body = convert(body)
op_node = _api_internal._ComputeOp( op_node = _api_internal._ComputeOp(
name, dim_var, body) name, dim_var, body)
return _api_internal._Tensor( return op_node.output(0)
shape, body.dtype, op_node, 0)
def scan(axis, init, update, state_placeholder, name="scan"):
"""Construct new tensors by scanning over axis.
Parameters
----------
axis: IterVar
The scanning axis.
init: Tensor or list of Tensor
The initial condition of the first init.shape[0] time steps
update: Tensor or list of Tensor
The update rule of the scan, given as a symbolic tensor.
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
name: str, optional
The name hint of the tensor
Returns
-------
tensor: tensor.Tensor or list of Tensors
The created tensor or list of tensors
Example
-------
# The following code is equivalent to numpy.cumsum
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
"""
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
res = [op.output(i) for i in range(len(update))]
return (res[0] if len(res) == 1 else res)
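The docstring example corresponds to a simple recurrence: state row 0 comes from s_init and each later row t applies s_update to the previous state row. For reference, a plain NumPy sketch of that recurrence (not the code TVM generates):

```python
import numpy as np

def scan_reference(X):
    """NumPy analogue of the cumsum scan in the docstring above."""
    m, n = X.shape
    state = np.empty((m, n), dtype=X.dtype)
    state[0] = X[0]                      # s_init: first time step copied from X
    for t in range(1, m):                # scan axis runs over time steps 1..m-1
        state[t] = state[t - 1] + X[t]   # s_update: reads the previous state row
    return state

X = np.random.uniform(size=(10, 4)).astype("float32")
np.testing.assert_allclose(scan_reference(X), np.cumsum(X, axis=0), rtol=1e-5)
```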
def Buffer(shape, dtype=None, def Buffer(shape, dtype=None,
......
...@@ -75,11 +75,16 @@ class Operation(NodeBase): ...@@ -75,11 +75,16 @@ class Operation(NodeBase):
return _api_internal._OpGetOutput(self, index) return _api_internal._OpGetOutput(self, index)
@register_node @register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
pass
@register_node
class ComputeOp(Operation): class ComputeOp(Operation):
"""Compute operation.""" """Compute operation."""
pass pass
@register_node @register_node
class PlaceholderOp(Operation): class ScanOp(Operation):
"""Placeholder operation.""" """Scan operation."""
pass pass
...@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp) ...@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
args[2]); args[2]);
}); });
TVM_REGISTER_API(_ScanOp)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ScanOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});
TVM_REGISTER_API(_OpGetOutput) TVM_REGISTER_API(_OpGetOutput)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output( *ret = args[0].operator Operation().output(
......
...@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator { ...@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
const ComExpr& sumb, const ComExpr& sumb,
int bscale) { int bscale) {
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
n->base = suma->base + sumb->base; n->base = suma->base + sumb->base * bscale;
// merge of suma and sumb; // merge of suma and sumb;
size_t i = 0, j = 0; size_t i = 0, j = 0;
while (i < suma->elem.size() && j < sumb->elem.size()) { while (i < suma->elem.size() && j < sumb->elem.size()) {
...@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator { ...@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator {
// convert sum to expr // convert sum to expr
Expr Sum2Expr(const ComExpr& com, Type t) { Expr Sum2Expr(const ComExpr& com, Type t) {
Expr vsum; Expr vsum;
if (com->base != 0) { if (com->base > 0) {
vsum = make_const(t, com->base); vsum = make_const(t, com->base);
} }
for (const ComExprEntry& e : com->elem) { for (const ComExprEntry& e : com->elem) {
...@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator { ...@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator {
} }
} }
} }
if (com->base < 0) {
if (vsum.defined()) {
vsum = Sub::make(vsum, make_const(t, -com->base));
} else {
vsum = make_const(t, com->base);
}
}
for (const ComExprEntry& e : com->elem) { for (const ComExprEntry& e : com->elem) {
if (e.scale < 0) { if (e.scale < 0) {
Expr v = e.value; Expr v = e.value;
......
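The Sum2Expr change emits a positive constant base up front but folds a negative base in as a trailing subtraction, so the canonical form prints as (m - 1) rather than (-1 + m); the new test_basic case later in this diff checks exactly that. A rough Python sketch of the emission order, with expression terms reduced to plain strings for illustration:

```python
def sum_to_expr(base, terms):
    """Rebuild an expression string from a canonical (base, [(name, scale)]) form."""
    expr = str(base) if base > 0 else None
    for name, scale in terms:
        if scale > 0:
            piece = name if scale == 1 else "%s*%d" % (name, scale)
            expr = piece if expr is None else "(%s + %s)" % (expr, piece)
    if base < 0:
        # Emit the negative base as a subtraction at the end: (m - 1), not (-1 + m).
        expr = str(base) if expr is None else "(%s - %d)" % (expr, -base)
    for name, scale in terms:
        if scale < 0:
            piece = name if scale == -1 else "%s*%d" % (name, -scale)
            expr = "-" + piece if expr is None else "(%s - %s)" % (expr, piece)
    return expr

print(sum_to_expr(-1, [("m", 1)]))  # (m - 1)
```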
...@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) { ...@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc"); const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string(); code = f(code).operator std::string();
} }
LOG(INFO) << code;
std::string ptx; std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) { if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile"); const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
......
...@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC { ...@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC {
private: private:
// magic number to add pragma unroll to it. // magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls. // used to generate code that is compact but still unrolls.
int max_auto_unroll_{8}; int max_auto_unroll_{1025};
}; };
} // namespace codegen } // namespace codegen
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <memory> #include <memory>
namespace tvm { namespace tvm {
...@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ComputeOpNode); TVM_REGISTER_NODE_TYPE(ComputeOpNode);
// Scan
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
int ScanOpNode::num_outputs() const {
return update.size();
}
Array<IterVar> ScanOpNode::root_iter_vars() const {
return Array<IterVar>{scan_axis};
}
Type ScanOpNode::output_dtype(size_t i) const {
return update[i]->dtype;
}
Array<Expr> ScanOpNode::output_shape(size_t i) const {
CHECK_LT(i, state_placeholder.size());
return state_placeholder[i]->shape;
}
Operation ScanOpNode::make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder) {
auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
<< "init.shape[0] needs to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "state_placeholder.shape[0] needs to match"
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init needs to match state_placeholder";
CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
<< "update.ndim needs to be state_placeholder.ndim - 1";
for (size_t k = 0; k < update[i].ndim(); ++k) {
CHECK(prove_equal(
update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k + 1;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
}
for (size_t k = 1; k < init[i].ndim(); ++k) {
CHECK(prove_equal(
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}
n->name = name;
n->scan_axis = axis;
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
return Operation(n);
}
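ScanOpNode::make ties the shapes of init, update and state_placeholder together: init covers the first scan_axis.dom.min time steps, the state buffer spans [0, min + extent) along time, and update describes a single time step with one fewer dimension than the state. A small Python restatement of those checks with concrete integers (a sketch, not the TVM implementation):

```python
def check_scan_shapes(axis_min, axis_extent, init_shape, update_shape, state_shape):
    # init covers the first `axis_min` time steps of the state buffer.
    assert init_shape[0] == axis_min, "init.shape[0] needs to match scan_axis.dom.min"
    # state holds every time step the scan touches: [0, min + extent).
    assert state_shape[0] == axis_min + axis_extent
    # update describes a single time step, so it has one fewer dimension than state.
    assert len(update_shape) + 1 == len(state_shape)
    assert len(init_shape) == len(state_shape)
    # Spatial dimensions must agree across init, update and state.
    assert tuple(update_shape) == tuple(state_shape[1:])
    assert tuple(init_shape[1:]) == tuple(state_shape[1:])

# The cumsum example: scan over time steps [1, m), init (1, n), update (n,), state (m, n).
m, n = 10, 4
check_scan_shapes(1, m - 1, (1, n), (n,), (m, n))
```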
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name) {
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
}
return res;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
p->stream << "scan(" << op->name << ", " << op << ")";
});
} // namespace tvm } // namespace tvm
...@@ -191,9 +191,6 @@ class VTInjector : public IRMutator { ...@@ -191,9 +191,6 @@ class VTInjector : public IRMutator {
} }
// Attribute // Attribute
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::scope) {
return Mutate(op->body);
} else {
Expr value = Mutate(op->value); Expr value = Mutate(op->value);
if (visit_touched_var_) { if (visit_touched_var_) {
return InjectVTLoop(s, true); return InjectVTLoop(s, true);
...@@ -207,7 +204,6 @@ class VTInjector : public IRMutator { ...@@ -207,7 +204,6 @@ class VTInjector : public IRMutator {
} }
} }
} }
}
// LetStmt // LetStmt
Stmt Mutate_(const LetStmt* op, const Stmt& s) final { Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
......
...@@ -11,40 +11,6 @@ ...@@ -11,40 +11,6 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
// key of function buffer
struct TensorKey {
FunctionRef f;
int value_index;
inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
size_t lhs = k.f.hash();
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
namespace tvm {
namespace ir {
using Halide::Internal::Region; using Halide::Internal::Region;
using runtime::StorageScope; using runtime::StorageScope;
using runtime::ThreadScope; using runtime::ThreadScope;
......
...@@ -23,6 +23,10 @@ inline Expr DivCeil(Expr a, Expr b) { ...@@ -23,6 +23,10 @@ inline Expr DivCeil(Expr a, Expr b) {
return ir::Simplify((a + b - 1) / b); return ir::Simplify((a + b - 1) / b);
} }
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
// Downward message passing algorithm on stage schedule s, // Downward message passing algorithm on stage schedule s,
// pass the range state down from the root to the leaves // pass the range state down from the root to the leaves
// after this pass, every IterVar in the stage hyper graph will have a range(domain) // after this pass, every IterVar in the stage hyper graph will have a range(domain)
...@@ -41,9 +45,18 @@ void PassDown(const Stage& s, ...@@ -41,9 +45,18 @@ void PassDown(const Stage& s,
if (r->outer->dom.defined()) { if (r->outer->dom.defined()) {
state[r->outer] = r->outer->dom; state[r->outer] = r->outer->dom;
} else { } else {
CHECK(!state.count(r->outer)); if (!state.count(r->outer)) {
state[r->outer] = Range::make_with_min_extent( state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor)); 0, DivCeil(range_parent->extent, r->factor));
} else {
Expr outer_ext = DivCeil(range_parent->extent, r->factor);
Range outer_rng = state.at(r->outer);
bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match)
<< "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same";
}
} }
} else { } else {
CHECK(r->outer->dom.defined()); CHECK(r->outer->dom.defined());
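The relaxed PassDown check lets the same IterVar serve as the outer variable of two different splits, as long as both splits infer an identical range. A tiny Python sketch of the ceil-division extent computation and the consistency test (names are illustrative):

```python
def div_ceil(a, b):
    return (a + b - 1) // b

def pass_down_split(state, outer, parent_extent, factor):
    outer_extent = div_ceil(parent_extent, factor)
    if outer not in state:
        state[outer] = (0, outer_extent)      # (min, extent)
    else:
        prev_min, prev_extent = state[outer]
        # Reusing one IterVar as the outer of two splits is only allowed
        # if both splits infer exactly the same range.
        assert prev_min == 0 and prev_extent == outer_extent, \
            "IterVar is used in two places as outer scope with mismatched extents"

state = {}
pass_down_split(state, "block_x", parent_extent=1000, factor=256)
pass_down_split(state, "block_x", parent_extent=1000, factor=256)  # same extent (4): OK
print(state["block_x"])
```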
...@@ -181,6 +194,21 @@ void PassUp(const Stage& s, ...@@ -181,6 +194,21 @@ void PassUp(const Stage& s,
} }
} }
// All the IterVars that are needed to compute the output bound of op.
// For most ops, these are the root_iter_vars.
// For Scan, they also include the additional spatial axes.
Array<IterVar> OutputRelatedIterVars(const Operation& op) {
if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
Array<IterVar> ret{scan->scan_axis};
for (IterVar iv : scan->spatial_axis_) {
ret.push_back(iv);
}
return ret;
} else {
return op->root_iter_vars();
}
}
/*! \brief temporary data structure to store Tensor domain */ /*! \brief temporary data structure to store Tensor domain */
struct TensorDom { struct TensorDom {
...@@ -214,6 +242,34 @@ void BoundProp(const Operation& op, ...@@ -214,6 +242,34 @@ void BoundProp(const Operation& op,
} }
}; };
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit); ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
size_t sp_idx = 0;
for (size_t i = 0; i < scan->init.size(); ++i) {
TensorDom* init_dom = nullptr;
TensorDom* update_dom = nullptr;
if (out->count(scan->init[i])) {
init_dom = &out->at(scan->init[i]);
}
if (out->count(scan->update[i])) {
update_dom = &out->at(scan->update[i]);
}
// first dimension, always needed.
if (init_dom) {
init_dom->data[0].push_back(IntSet::range(
Range::make_with_min_extent(0, scan->init[i]->shape[0])));
}
// The update dimensions
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
if (init_dom) {
init_dom->data[k + 1].push_back(dom_map.at(sp_ax->var.get()));
}
if (update_dom) {
update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
}
}
}
} else if (op.as<PlaceholderOpNode>()) { } else if (op.as<PlaceholderOpNode>()) {
// do nothing // do nothing
} else { } else {
...@@ -221,14 +277,49 @@ void BoundProp(const Operation& op, ...@@ -221,14 +277,49 @@ void BoundProp(const Operation& op,
} }
} }
void InferOpBound(const Operation& op, // Given the bound of output of op
// Pass the bound to the related axis in op.
void GatherOpBound(const ScanOpNode* scan,
const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
CHECK(!rmap->count(scan->scan_axis));
std::vector<Tensor> output(op->num_outputs());
for (size_t i = 0; i < output.size(); ++i) {
output[i] = op.output(i);
}
// Update for time axis.
std::vector<IntSet> time_dom;
for (size_t i = 0; i < output.size(); ++i) {
const TensorDom& d = tmap.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
LOG(INFO) << time_dom.size();
CHECK(!rmap->count(scan->scan_axis));
Range sdom = scan->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
(*rmap)[scan->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
CHECK(!rmap->count(sp_ax));
// By default, we always need all spatial axes
// Unless that axis only refers back to itself as a fixed point.
// TODO(tqchen): Add fix point detection.
(*rmap)[sp_ax] = sp_ax->dom;
}
}
}
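For the time axis, GatherOpBound unions every consumer's requested time range, covers it with one contiguous range, and keeps the scan's own minimum as the start while extending the extent to reach the last requested step. A small arithmetic sketch with made-up numbers (simplifying the IntSet machinery to plain min/max):

```python
def gather_scan_time_axis(sdom_min, required_ranges):
    """required_ranges: list of (min, extent) half-open time ranges demanded by consumers."""
    # Union of all requested time indices, covered by one contiguous range.
    lo = min(r_min for r_min, _ in required_ranges)
    hi = max(r_min + r_ext for r_min, r_ext in required_ranges)
    r_min, r_ext = lo, hi - lo
    # Keep the scan's original start, extend the extent to reach the last requested step.
    return sdom_min, r_ext + r_min - sdom_min

# Consumers need time steps [0, 6) and [3, 8); the scan axis itself starts at 1.
print(gather_scan_time_axis(1, [(0, 6), (3, 8)]))  # (1, 7) -> covers steps 1..7
```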
void GatherOpBound(const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap, const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) { std::unordered_map<IterVar, Range>* rmap) {
if (op.as<ComputeOpNode>()) { if (op.as<ComputeOpNode>()) {
auto root_iter_vars = op->root_iter_vars();
const ComputeOpNode* compute = op.as<ComputeOpNode>(); const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0)); const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) { for (size_t i = 0; i < compute->axis.size(); ++i) {
Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom); Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i])); CHECK(!rmap->count(compute->axis[i]));
...@@ -238,6 +329,8 @@ void InferOpBound(const Operation& op, ...@@ -238,6 +329,8 @@ void InferOpBound(const Operation& op,
CHECK(!rmap->count(compute->reduce_axis[i])); CHECK(!rmap->count(compute->reduce_axis[i]));
(*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom; (*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
} }
} else if (op.as<ScanOpNode>()) {
GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
} else if (op.as<PlaceholderOpNode>()) { } else if (op.as<PlaceholderOpNode>()) {
// dp nothing // dp nothing
} else { } else {
...@@ -269,8 +362,7 @@ void InferRootBound(const Stage& stage, ...@@ -269,8 +362,7 @@ void InferRootBound(const Stage& stage,
std::unordered_map<IterVar, Range>* rmap) { std::unordered_map<IterVar, Range>* rmap) {
if (stage->attach_type == kInline) return; if (stage->attach_type == kInline) return;
if (stage->attach_type == kRoot || stage->attach_type == kNone) { if (stage->attach_type == kRoot || stage->attach_type == kNone) {
auto root_iter_vars = stage->op->root_iter_vars(); for (auto iv : OutputRelatedIterVars(stage->op)) {
for (auto iv : root_iter_vars) {
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
CHECK(!rmap->count(iv)); CHECK(!rmap->count(iv));
(*rmap)[iv] = iv->dom; (*rmap)[iv] = iv->dom;
...@@ -338,8 +430,13 @@ void InferRootBound(const Stage& stage, ...@@ -338,8 +430,13 @@ void InferRootBound(const Stage& stage,
PassUp(parent, *rmap, &up_state); PassUp(parent, *rmap, &up_state);
std::unordered_map<const Variable*, IntSet> dom_map; std::unordered_map<const Variable*, IntSet> dom_map;
for (auto iv : parent->op->root_iter_vars()) { for (auto iv : OutputRelatedIterVars(parent->op)) {
Range r = up_state.at(iv).cover_range(iv->dom); Range r;
if (up_state.count(iv)) {
r = up_state.at(iv).cover_range(iv->dom);
} else {
r = iv->dom;
}
if (relax_set.size() != 0) { if (relax_set.size() != 0) {
dom_map[iv->var.get()] = EvalSet(r, relax_set); dom_map[iv->var.get()] = EvalSet(r, relax_set);
} else { } else {
...@@ -379,13 +476,13 @@ void InferRootBound(const Stage& stage, ...@@ -379,13 +476,13 @@ void InferRootBound(const Stage& stage,
CHECK(found) CHECK(found)
<< "Invalid Schedule, cannot find the producer " << stage->op << "Invalid Schedule, cannot find the producer " << stage->op
<< " along the loop nest specified by compute_at of consumer " << op; << " along the loop nest specified by compute_at of consumer " << op;
for (auto iv : op->root_iter_vars()) { for (auto iv : OutputRelatedIterVars(op)) {
Range r = rmap->at(iv); Range r = rmap->at(iv);
dom_map[iv->var.get()] = EvalSet(r, relax_set); dom_map[iv->var.get()] = EvalSet(r, relax_set);
} }
BoundProp(op, dom_map, &tmap); BoundProp(op, dom_map, &tmap);
} }
InferOpBound(stage->op, tmap, rmap); GatherOpBound(stage->op, tmap, rmap);
} }
FeedGraph CreateFeedGraph(const Schedule& sch) { FeedGraph CreateFeedGraph(const Schedule& sch) {
......
...@@ -33,20 +33,28 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) { ...@@ -33,20 +33,28 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
Operation call_op(call->func.node_); Operation call_op(call->func.node_);
deps.push_back(call_op.output(call->value_index)); deps.push_back(call_op.output(call->value_index));
if (call_op.defined() && visited.count(call_op.get()) == 0) {
visited.insert(call_op.get());
stack.push_back(call_op);
}
} }
}; };
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit); ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps); } else if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
for (Tensor t : scan->init) {
deps.push_back(t);
}
for (Tensor t : scan->update) {
deps.push_back(t);
}
} else if (op.as<PlaceholderOpNode>()) { } else if (op.as<PlaceholderOpNode>()) {
// empty set of deps
rmap.Set(op, deps);
} else { } else {
LOG(FATAL) << "unknown Operation" << op->type_key(); LOG(FATAL) << "unknown Operation" << op->type_key();
} }
rmap.Set(op, deps);
for (Tensor t : deps) {
if (t->op.defined() && visited.count(t->op.get()) == 0) {
visited.insert(t->op.get());
stack.push_back(t->op);
}
}
} }
return rmap; return rmap;
} }
......
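With scan handled, CreateReadGraph now records dependencies uniformly and pushes the producer of every dependent tensor onto the DFS stack at most once. A generic Python sketch of that traversal, with ops reduced to plain strings (stand-ins, not TVM types):

```python
def create_read_graph(roots, deps_of):
    """deps_of(op) -> list of ops this op reads from; returns an op -> deps map."""
    rmap = {}
    visited = set(roots)
    stack = list(roots)
    while stack:
        op = stack.pop()
        deps = deps_of(op)
        rmap[op] = deps
        for dep in deps:
            if dep not in visited:      # visit each producing op only once
                visited.add(dep)
                stack.append(dep)
    return rmap

# scan reads from its init and update ops; those read from the placeholders.
graph = {"scan": ["s_init", "s_update"], "s_init": ["X"], "s_update": ["s_state", "X"],
         "X": [], "s_state": []}
print(create_read_graph(["scan"], lambda op: graph[op]))
```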
...@@ -146,6 +146,8 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT ...@@ -146,6 +146,8 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*) Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
StageNode* self = operator->(); StageNode* self = operator->();
CHECK(!self->op.as<ScanOpNode>())
<< "Cannot reorder axis of scan";
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
std::vector<size_t> pos; std::vector<size_t> pos;
......
...@@ -7,7 +7,9 @@ ...@@ -7,7 +7,9 @@
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h> #include <tvm/schedule_pass.h>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "../pass/ir_util.h" #include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h" #include "../arithmetic/compute_expr.h"
#include "./graph.h" #include "./graph.h"
...@@ -18,6 +20,12 @@ namespace schedule { ...@@ -18,6 +20,12 @@ namespace schedule {
using namespace arith; using namespace arith;
using namespace ir; using namespace ir;
// Two private scope marks
namespace attr {
constexpr const char* loop_scope = "loop_scope";
constexpr const char* scan_scope = "scan_scope";
} // namespace attr
/*! /*!
* \brief message passing to find if IterVar is related to reduction. * \brief message passing to find if IterVar is related to reduction.
* \param s The stage to be used. * \param s The stage to be used.
...@@ -168,7 +176,6 @@ MakeLoopNest(const Stage& sch, ...@@ -168,7 +176,6 @@ MakeLoopNest(const Stage& sch,
value_map[iv] = iv->var; value_map[iv] = iv->var;
continue; continue;
} }
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
// initialize the offset and loop_level // initialize the offset and loop_level
Var var = iv->var; Var var = iv->var;
...@@ -223,7 +230,7 @@ MakeLoopNest(const Stage& sch, ...@@ -223,7 +230,7 @@ MakeLoopNest(const Stage& sch,
if (!reduce_init_loop) { if (!reduce_init_loop) {
// annotate the extent of the IterVar // annotate the extent of the IterVar
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::scope, iv->var, no_op)); AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
} }
} }
// message passing to get offset of root iter vars. // message passing to get offset of root iter vars.
...@@ -307,8 +314,8 @@ Stmt MakeLoop(const Stage& s, ...@@ -307,8 +314,8 @@ Stmt MakeLoop(const Stage& s,
init = Substitute(init, init_value_map); init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init); init = MergeNest(init_nest, init);
// common nest // common nest
std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop); std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop, nest.end()); std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
provide = MergeNest(reduce, provide); provide = MergeNest(reduce, provide);
return MergeNest( return MergeNest(
common, Block::make(init, provide)); common, Block::make(init, provide));
...@@ -340,6 +347,29 @@ Stmt MakeRealize(const ComputeOpNode* op, ...@@ -340,6 +347,29 @@ Stmt MakeRealize(const ComputeOpNode* op,
bounds, make_const(Bool(1), true), body); bounds, make_const(Bool(1), true), body);
} }
Stmt MakeRealize(const ScanOpNode* op,
const Map<IterVar, Range>& dom_map,
const std::vector<Tensor>& tensors,
Stmt body) {
Range sdom = dom_map.at(op->scan_axis);
Range tdom = Range::make_with_min_extent(
0, ir::Simplify(sdom->extent + sdom->min));
size_t sp_idx = 0;
for (size_t i = 0; i < tensors.size(); ++i) {
const Tensor& t = tensors[i];
CHECK_EQ(static_cast<size_t>(t->value_index), i);
Halide::Internal::Region bounds;
bounds.push_back(tdom);
for (size_t k = 0; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = op->spatial_axis_[sp_idx];
bounds.push_back(dom_map.at(sp_ax));
}
body = Realize::make(t->op, t->value_index, t->dtype,
bounds, make_const(Bool(1), true), body);
}
return body;
}
void MakeReduction(const ComputeOpNode* op, void MakeReduction(const ComputeOpNode* op,
const std::vector<Tensor>& tensors, const std::vector<Tensor>& tensors,
...@@ -382,12 +412,18 @@ Stmt MakePipeline(const Stage& s, ...@@ -382,12 +412,18 @@ Stmt MakePipeline(const Stage& s,
Stmt init, provide; Stmt init, provide;
const ComputeOpNode* compute = s->op.as<ComputeOpNode>(); const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (compute) { if (compute) {
if (compute->reduce_axis.size() == 0) { if (compute->reduce_axis.size() == 0) {
provide = MakeProvide(compute, tensors); provide = MakeProvide(compute, tensors);
} else { } else {
MakeReduction(compute, tensors, &init, &provide); MakeReduction(compute, tensors, &init, &provide);
} }
} else if (scan) {
// Provide is done by the sub operations.
provide = AttrStmt::make(
s->op, attr::scan_scope, scan->scan_axis->var,
Evaluate::make(0));
} else { } else {
LOG(FATAL) << "not supported op " << s->op->type_key(); LOG(FATAL) << "not supported op " << s->op->type_key();
} }
...@@ -396,7 +432,12 @@ Stmt MakePipeline(const Stage& s, ...@@ -396,7 +432,12 @@ Stmt MakePipeline(const Stage& s,
producer = ProducerConsumer::make(s->op, true, producer); producer = ProducerConsumer::make(s->op, true, producer);
Stmt pipeline = producer; Stmt pipeline = producer;
if (consumer.defined()) { // check if consumer is nop.
bool is_no_op{false};
const Evaluate* ev = consumer.as<Evaluate>();
if (ev && ev->value.as<IntImm>()) is_no_op = true;
if (consumer.defined() && !is_no_op) {
consumer = ProducerConsumer::make(s->op, false, consumer); consumer = ProducerConsumer::make(s->op, false, consumer);
pipeline = Block::make(producer, consumer); pipeline = Block::make(producer, consumer);
} }
...@@ -404,47 +445,103 @@ Stmt MakePipeline(const Stage& s, ...@@ -404,47 +445,103 @@ Stmt MakePipeline(const Stage& s,
if (s->op.as<ComputeOpNode>()) { if (s->op.as<ComputeOpNode>()) {
pipeline = MakeRealize(s->op.as<ComputeOpNode>(), pipeline = MakeRealize(s->op.as<ComputeOpNode>(),
dom_map, tensors, pipeline); dom_map, tensors, pipeline);
} else if (s->op.as<ScanOpNode>()) {
pipeline = MakeRealize(s->op.as<ScanOpNode>(),
dom_map, tensors, pipeline);
} else { } else {
LOG(FATAL) << "not supported op"; LOG(FATAL) << "not supported op";
return Stmt();
} }
// use attribute to mark scope of the operation. // use attribute to mark scope of the operation.
pipeline = AttrStmt::make( pipeline = AttrStmt::make(
s->op, "realize_scope", s->op, ir::attr::realize_scope,
StringImm::make(s->scope), StringImm::make(s->scope),
pipeline); pipeline);
return pipeline; return pipeline;
} }
// inject the operator's realization on the stmt. // inject the operator's realization on the stmt.
class InjectRealize : public IRMutator { class InjectAttach : public IRMutator {
public: public:
InjectRealize(Stage schedule, Map<IterVar, Range> dom_map) InjectAttach(const Stage& stage,
: schedule(schedule), dom_map(dom_map) {} const Map<IterVar, Range>& dom_map)
: stage_(stage), dom_map_(dom_map) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined()); CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt); stmt = IRMutator::Mutate(stmt);
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr && if (op != nullptr &&
op->type_key == "scope") { op->type_key == attr::loop_scope) {
if (op->node == schedule->attach_ivar) { if (op->node == stage_->attach_ivar) {
CHECK(!found_attach); CHECK(!found_attach);
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->type_key, op->value, op->node, op->type_key, op->value,
MakePipeline(schedule, dom_map, MakePipeline(stage_, dom_map_, op->body));
IRMutator::Mutate(op->body)));
} }
} }
return stmt; return stmt;
} }
// whether attach point is found
bool found_attach{false};
private:
// the operations to be carried // the operations to be carried
Stage schedule; const Stage& stage_;
// domain map // domain map
Map<IterVar, Range> dom_map; const Map<IterVar, Range>& dom_map_;
};
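InjectAttach walks the statement produced so far, finds the loop_scope AttrStmt that marks the consumer's attach IterVar, and splices the producer's pipeline into its body. A loose Python sketch of that marker-based injection over a nested dict standing in for the AST (all node shapes here are invented for illustration):

```python
def inject_at_marker(stmt, marker, make_pipeline):
    """Recursively rebuild stmt, replacing the body of the node tagged `marker`."""
    if not isinstance(stmt, dict):
        return stmt, False
    body, found = inject_at_marker(stmt.get("body"), marker, make_pipeline)
    new_stmt = dict(stmt, body=body)
    if stmt.get("attr") == "loop_scope" and stmt.get("node") == marker:
        assert not found, "attach point should be unique"
        new_stmt["body"] = make_pipeline(body)
        found = True
    return new_stmt, found

loops = {"attr": "loop_scope", "node": "i.outer",
         "body": {"attr": "loop_scope", "node": "i.inner", "body": "compute(B)"}}
out, found = inject_at_marker(loops, "i.outer", lambda b: {"producer": "compute(A)", "body": b})
assert found
print(out)
```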
// inject the operator's realization on the stmt.
class InjectScanStep : public IRMutator {
public:
InjectScanStep(const Stage& stage,
const Operation& scan_op,
const Map<IterVar, Range>& dom_map,
bool is_init)
: stage_(stage), scan_op_(scan_op),
dom_map_(dom_map), is_init_(is_init) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
if (is_init_) {
const ProducerConsumer* op = stmt.as<ProducerConsumer>();
if (op != nullptr &&
op->is_producer &&
op->func.same_as(scan_op_)) {
stmt = ProducerConsumer::make(
op->func, true,
MakePipeline(stage_, dom_map_, op->body));
found_attach = true;
}
} else {
// update
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == attr::scan_scope) {
if (op->node.same_as(scan_op_)) {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->type_key, op->value,
MakePipeline(stage_, dom_map_, op->body));
}
}
}
return stmt;
}
// whether attach point is found // whether attach point is found
bool found_attach{false}; bool found_attach{false};
private:
// the operations to be carried
const Stage& stage_;
const Operation& scan_op_;
// domain map
const Map<IterVar, Range>& dom_map_;
// whether it is init.
bool is_init_;
}; };
Stmt InjectInline(const Operation op, Stmt body) { Stmt InjectInline(const Operation op, Stmt body) {
...@@ -459,27 +556,180 @@ Stmt InjectInline(const Operation op, Stmt body) { ...@@ -459,27 +556,180 @@ Stmt InjectInline(const Operation op, Stmt body) {
return Inline(body, op, args, compute->body); return Inline(body, op, args, compute->body);
} }
// Post-processing of the scheduled ops:
// replace references to the scan's init and update tensors by the scan's output buffer.
class SchedulePostProc : public IRMutator {
public:
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (to_remove_.count(op->func.get())) {
return this->Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
if (!HasSideEffect(op->value)) {
var_value_[op->var.get()] = Mutate(op->value);
return this->Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value;
return this->Mutate(op->body);
} else if (op->type_key == ir::attr::realize_scope) {
if (to_remove_.count(op->node.get())) {
return this->Mutate(op->body);
}
}
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
TensorKey key{op->func, op->value_index};
if (replace_.count(key)) {
return this->Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
TensorKey key{op->func, op->value_index};
auto it = replace_.find(key);
if (it != replace_.end()) {
const Tensor& dst = it->second.first;
Stmt ret = Provide::make(
dst->op, dst->value_index, op->value,
RewriteArgs(it->second.second, op->args));
return IRMutator::Mutate_(ret.as<Provide>(), ret);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
if (op != nullptr && op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
auto it = replace_.find(key);
if (it != replace_.end()) {
const Tensor& dst = it->second.first;
Expr ret = Call::make(
op->type, dst->op->name,
RewriteArgs(it->second.second, op->args),
op->call_type, dst->op, dst->value_index);
return IRMutator::Mutate_(ret.as<Call>(), ret);
}
}
return IRMutator::Mutate_(op, e);
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = var_value_.find(op);
if (it != var_value_.end()) {
return it->second;
} else {
return e;
}
}
void Init(const Schedule& sch) {
for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue;
for (size_t i = 0; i < scan->update.size(); ++i) {
Tensor t = s->op.output(i);
AddReplace(scan->init[i], t, Expr());
AddReplace(scan->update[i], t, scan->scan_axis->var);
AddReplace(scan->state_placeholder[i], t, Expr());
}
}
}
private:
void AddReplace(Tensor src, Tensor dst, Expr head_idx) {
replace_[TensorKey{src->op, src->value_index}]
= std::make_pair(dst, head_idx);
to_remove_.insert(src->op.get());
}
Array<Expr> RewriteArgs(Expr head, Array<Expr> args) {
if (!head.defined()) return args;
Array<Expr> new_args{head};
for (Expr e : args) {
new_args.push_back(e);
}
return new_args;
}
// Replacement values for variables (e.g. the scan axis variable)
std::unordered_map<const Variable*, Expr> var_value_;
// buffer replacement
std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_;
// replaced functions
std::unordered_set<const Node*> to_remove_;
};
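SchedulePostProc redirects every reference to a scan's init, update and state placeholder tensors to the scan output itself; accesses to the update tensor additionally get the scan axis prepended as the leading time index. A small Python sketch of that argument rewrite (tensor names and keys are placeholders):

```python
def rewrite_args(head, args):
    # For update tensors, head is the scan axis variable; init/state pass None (no extra index).
    return list(args) if head is None else [head] + list(args)

replace = {
    ("s_init", 0):   ("scan_out", None),   # init[t0, i]   -> scan_out[t0, i]
    ("s_update", 0): ("scan_out", "t"),    # update[i]     -> scan_out[t, i]
    ("s_state", 0):  ("scan_out", None),   # state[t-1, i] -> scan_out[t-1, i]
}

def rewrite_access(tensor_key, args):
    if tensor_key in replace:
        dst, head = replace[tensor_key]
        return dst, rewrite_args(head, args)
    return tensor_key[0], list(args)

print(rewrite_access(("s_update", 0), ["i"]))        # ('scan_out', ['t', 'i'])
print(rewrite_access(("s_state", 0), ["t-1", "i"]))  # ('scan_out', ['t-1', 'i'])
```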
Stmt ScheduleOps( Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map) { Schedule sch, Map<IterVar, Range> dom_map) {
Stmt body = Stmt(); Stmt body = Stmt();
// scan init and scan updates
std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach;
for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue;
for (Tensor t : scan->init) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
<< "Scan init tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, true);
}
}
for (Tensor t : scan->update) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
<< "Scan update tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, false);
}
}
}
// reverse the post DFS order. // reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1]; Stage s = sch->stages[i - 1];
// no need to specify place holder op. // no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue; if (s->op.as<PlaceholderOpNode>()) continue;
if (s->attach_type == kInline) { if (scan_attach.count(s->op)) {
CHECK(s->attach_type == kNone || s->attach_type == kInline)
<< "Cannot specify compute_at for scan's init/update";
CHECK(body.defined());
const auto& p = scan_attach.at(s->op);
InjectScanStep mu(s, p.first, dom_map, p.second);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init/update";
} else if (s->attach_type == kInline) {
body = InjectInline(s->op, body); body = InjectInline(s->op, body);
} else if (s->attach_type == kRoot || s-> attach_type == kNone) { } else if (s->attach_type == kRoot || s-> attach_type == kNone) {
body = MakePipeline(s, dom_map, body); body = MakePipeline(s, dom_map, body);
} else if (s->attach_type == kScope) { } else if (s->attach_type == kScope) {
CHECK(body.defined()); CHECK(body.defined());
InjectRealize mutator(s, dom_map); InjectAttach mutator(s, dom_map);
body = mutator.Mutate(body); body = mutator.Mutate(body);
CHECK(mutator.found_attach) CHECK(mutator.found_attach)
<< "did not find attachment point"; << "did not find attachment point";
} }
} }
return body; SchedulePostProc post_proc;
post_proc.Init(sch);
return post_proc.Mutate(body);
} }
} // namespace schedule } // namespace schedule
......
import tvm
import numpy as np
def test_scan():
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
# schedule
s = tvm.Schedule(res.op)
num_thread = 256
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_init].split(x, outer=thread_x)
_, x = s[s_update].split(s_update.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[s_update].split(x, outer=thread_x)
# one line to build the function.
def check_device(target):
codes = []
fscan = tvm.build(s, [X, res],
target, record_codes=codes,
name="myscan")
if target == "cuda":
ctx = tvm.gpu(0)
else:
ctx = tvm.cl(0)
if not ctx.enabled:
return
for c in codes[1:]:
print(c)
# launch the kernel.
n = 1024
m = 10
a_np = np.random.uniform(size=(m, n)).astype(res.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros((m, n), dtype=res.dtype), ctx)
fscan(a, b)
np.testing.assert_allclose(
b.asnumpy(), np.cumsum(a_np, axis=0))
tvm.init_opencl()
check_device("cuda")
if __name__ == "__main__":
test_scan()
...@@ -34,6 +34,20 @@ def test_tensor_reduce(): ...@@ -34,6 +34,20 @@ def test_tensor_reduce():
assert(str(C_loaded) == str(C)) assert(str(C_loaded) == str(C))
def test_tensor_scan():
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), "t")
x = tvm.placeholder((m, n))
s = tvm.placeholder((m, n))
res = tvm.scan(t,
tvm.compute((1, n), lambda _, i: x[0, i]),
tvm.compute((n,), lambda i: s[t-1, i] + x[t, i]),
s)
assert tuple(res.shape) == (m, n)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor() test_tensor()
test_tensor_reduce() test_tensor_reduce()
test_tensor_scan()
...@@ -18,9 +18,15 @@ def test_simplify(): ...@@ -18,9 +18,15 @@ def test_simplify():
tvm.make.Load(dtype, Ab.data, i + 4) + 1, tvm.make.Load(dtype, Ab.data, i + 4) + 1,
(j + 1) * 4 - 4 * j + i), (j + 1) * 4 - 4 * j + i),
None))) None)))
print(stmt)
stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.CanonicalSimplify(stmt)
print(stmt)
def test_basic():
m = tvm.Var('m')
ret = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(m-1))
assert str(ret.value) == "(m - 1)"
if __name__ == "__main__": if __name__ == "__main__":
test_basic()
test_simplify() test_simplify()
...@@ -6,13 +6,11 @@ def test_schedule0(): ...@@ -6,13 +6,11 @@ def test_schedule0():
l = tvm.Var('l') l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
s = tvm.Schedule(A1.op) s = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule1(): def test_schedule1():
m = tvm.Var('m') m = tvm.Var('m')
...@@ -25,7 +23,7 @@ def test_schedule1(): ...@@ -25,7 +23,7 @@ def test_schedule1():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule2(): def test_schedule2():
m = tvm.Var('m') m = tvm.Var('m')
...@@ -40,8 +38,28 @@ def test_schedule2(): ...@@ -40,8 +38,28 @@ def test_schedule2():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_scan():
m = tvm.Var("m")
n = tvm.Var("n")
l = tvm.Var("l")
t = tvm.IterVar((1, m), name="t")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: x[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + x[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
assert tuple(res.shape) == (m, n)
s = tvm.Schedule(res.op)
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt) print(stmt)
def test_auto_inline(): def test_auto_inline():
m = tvm.Var('m') m = tvm.Var('m')
n = tvm.Var('n') n = tvm.Var('n')
...@@ -55,10 +73,10 @@ def test_auto_inline(): ...@@ -55,10 +73,10 @@ def test_auto_inline():
tvm.schedule.AutoInlineElemWise(s) tvm.schedule.AutoInlineElemWise(s)
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_scan()
test_schedule0() test_schedule0()
test_schedule1() test_schedule1()
test_schedule2() test_schedule2()
......