Commit 820a8597 by Tianqi Chen Committed by GitHub

[LANG] Introduce Scan, Bugfix Canonical (#43)

parent f8f02829
......@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace attr {
/*!
* \brief Mark scope of iteration variable, used by Schedule.
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
constexpr const char* scope = "scope";
struct TensorKey {
FunctionRef f;
int value_index;
inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace attr {
// The above attr does not pass to ir stage.
/*!
* \brief Mark launching extent of thread, used by device API.
*/
......@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
size_t lhs = k.f.hash();
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_IR_H_
......@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
};
/*!
* \brief Symbolic scan.
*/
class ScanOpNode : public OperationNode {
public:
/*! \brief IterVar to scan over */
IterVar scan_axis;
/*! \brief the initialization tensors */
Array<Tensor> init;
/*! \brief the update function represented by tensor */
Array<Tensor> update;
/*! \brief The placeholder to refer as states in update. */
Array<Tensor> state_placeholder;
/*!
* \brief Spatial axis to indicate spatial dimension of each output.
* They corresponds to flattened spatial axis of the outputs.
*
* [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
* These are auxiliary data structure for storing result of bound inference.
* They do not corresponds to splittable iterations, thus the name comes
* with underscore.
*/
Array<IterVar> spatial_axis_;
/*! \brief constructor */
ScanOpNode() {}
// override behavior.
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("scan_axis", &scan_axis);
v->Visit("init", &init);
v->Visit("update", &update);
v->Visit("state_placeholder", &state_placeholder);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder);
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
......@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scan over scan_axis.
*
* \param scan_axis The iteration representing the scan.
* \param init The intialize tensor of first K steps.
* \param update The update tensor indicated the updated result after each timestamp.
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");
// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var)> f,
......
......@@ -14,6 +14,7 @@ from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
from . import collections as _collections
int32 = "int32"
......@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
shape: Tuple of Expr
The shape of the tensor
fcompute: lambda function of *indices-> value
Specifies the input source expression
......@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
body = convert(body)
op_node = _api_internal._ComputeOp(
name, dim_var, body)
return _api_internal._Tensor(
shape, body.dtype, op_node, 0)
return op_node.output(0)
def scan(axis, init, update, state_placeholder, name="scan"):
"""Construct new tensors by scanning over axis.
Parameters
----------
axis: IterVar
The scanning axis.
init: Tensor or list of Tensor
The initial condition of first init.shape[0] timestamps
update: Tensor or list of Tensor
The update rule of the scan given by symbolic tensor.
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
name: str, optional
The name hint of the tensor
Returns
-------
tensor: tensor.Tensor
The created tensor
Example
-------
# The following code is equivalent to numpy.cumsum
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
"""
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
res = [op.output(i) for i in range(len(update))]
return (res[0] if len(res) == 1 else res)
def Buffer(shape, dtype=None,
......
......@@ -75,11 +75,16 @@ class Operation(NodeBase):
return _api_internal._OpGetOutput(self, index)
@register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
pass
@register_node
class ComputeOp(Operation):
"""Compute operation."""
pass
@register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
class ScanOp(Operation):
"""Scan operation."""
pass
......@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
args[2]);
});
TVM_REGISTER_API(_ScanOp)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ScanOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});
TVM_REGISTER_API(_OpGetOutput)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output(
......
......@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
const ComExpr& sumb,
int bscale) {
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
n->base = suma->base + sumb->base;
n->base = suma->base + sumb->base * bscale;
// merge of suma and sumb;
size_t i = 0, j = 0;
while (i < suma->elem.size() && j < sumb->elem.size()) {
......@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator {
// convert sum to expr
Expr Sum2Expr(const ComExpr& com, Type t) {
Expr vsum;
if (com->base != 0) {
if (com->base > 0) {
vsum = make_const(t, com->base);
}
for (const ComExprEntry& e : com->elem) {
......@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator {
}
}
}
if (com->base < 0) {
if (vsum.defined()) {
vsum = Sub::make(vsum, make_const(t, -com->base));
} else {
vsum = make_const(t, com->base);
}
}
for (const ComExprEntry& e : com->elem) {
if (e.scale < 0) {
Expr v = e.value;
......
......@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string();
}
LOG(INFO) << code;
std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
......
......@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC {
private:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int max_auto_unroll_{8};
int max_auto_unroll_{1025};
};
} // namespace codegen
......
......@@ -5,6 +5,7 @@
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <memory>
namespace tvm {
......@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
// Scan
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
int ScanOpNode::num_outputs() const {
return update.size();
}
Array<IterVar> ScanOpNode::root_iter_vars() const {
return Array<IterVar>{scan_axis};
}
Type ScanOpNode::output_dtype(size_t i) const {
return update[i]->dtype;
}
Array<Expr> ScanOpNode::output_shape(size_t i) const {
CHECK_LT(i, state_placeholder.size());
return state_placeholder[i]->shape;
}
Operation ScanOpNode::make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder) {
auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
<< "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "shate_placeholder.shape[0] need to match"
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder";
CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
<< "The update.ndim need to be state_placeholder.ndim - 1";
for (size_t k = 0; k < update[i].ndim(); ++k) {
CHECK(prove_equal(
update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k + 1;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
}
for (size_t k = 1; k < init[i].ndim(); ++k) {
CHECK(prove_equal(
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}
n->name = name;
n->scan_axis = axis;
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
return Operation(n);
}
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name) {
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
}
return res;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
p->stream << "scan(" << op->name << ", " << op << ")";
});
} // namespace tvm
......@@ -191,20 +191,16 @@ class VTInjector : public IRMutator {
}
// Attribute
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::scope) {
return Mutate(op->body);
Expr value = Mutate(op->value);
if (visit_touched_var_) {
return InjectVTLoop(s, true);
} else {
Expr value = Mutate(op->value);
if (visit_touched_var_) {
return InjectVTLoop(s, true);
Stmt body = Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
Stmt body = Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return AttrStmt::make(op->node, op->type_key, value, body);
}
return AttrStmt::make(op->node, op->type_key, value, body);
}
}
}
......
......@@ -11,40 +11,6 @@
namespace tvm {
namespace ir {
// key of function buffer
struct TensorKey {
FunctionRef f;
int value_index;
inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
size_t lhs = k.f.hash();
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
namespace tvm {
namespace ir {
using Halide::Internal::Region;
using runtime::StorageScope;
using runtime::ThreadScope;
......
......@@ -23,6 +23,10 @@ inline Expr DivCeil(Expr a, Expr b) {
return ir::Simplify((a + b - 1) / b);
}
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
// Downward message passing algorithm on stage schedule s,
// pass the range state down from the root to the leaves
// after this pass, every IterVar in the stage hyper graph will have a range(domain)
......@@ -41,9 +45,18 @@ void PassDown(const Stage& s,
if (r->outer->dom.defined()) {
state[r->outer] = r->outer->dom;
} else {
CHECK(!state.count(r->outer));
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
if (!state.count(r->outer)) {
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
} else {
Expr outer_ext = DivCeil(range_parent->extent, r->factor);
Range outer_rng = state.at(r->outer);
bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match)
<< "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same";
}
}
} else {
CHECK(r->outer->dom.defined());
......@@ -181,6 +194,21 @@ void PassUp(const Stage& s,
}
}
// All the itervars that are needed to output bound of op.
// For most op, it is root_iter_vars
// For Scan, it also contains the additional spatial axis.
Array<IterVar> OutputRelatedIterVars(const Operation& op) {
if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
Array<IterVar> ret{scan->scan_axis};
for (IterVar iv : scan->spatial_axis_) {
ret.push_back(iv);
}
return ret;
} else {
return op->root_iter_vars();
}
}
/*! \brief temporary data structure to store Tensor domain */
struct TensorDom {
......@@ -214,6 +242,34 @@ void BoundProp(const Operation& op,
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
size_t sp_idx = 0;
for (size_t i = 0; i < scan->init.size(); ++i) {
TensorDom* init_dom = nullptr;
TensorDom* update_dom = nullptr;
if (out->count(scan->init[i])) {
init_dom = &out->at(scan->init[i]);
}
if (out->count(scan->update[i])) {
update_dom = &out->at(scan->update[i]);
}
// first dimension, always needed.
if (init_dom) {
init_dom->data[0].push_back(IntSet::range(
Range::make_with_min_extent(0, scan->init[i]->shape[0])));
}
// The update dimensions
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
if (init_dom) {
init_dom->data[k + 1].push_back(dom_map.at(sp_ax->var.get()));
}
if (update_dom) {
update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
}
}
}
} else if (op.as<PlaceholderOpNode>()) {
// do nothing
} else {
......@@ -221,14 +277,49 @@ void BoundProp(const Operation& op,
}
}
void InferOpBound(const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
// Given the bound of output of op
// Pass the bound to the related axis in op.
void GatherOpBound(const ScanOpNode* scan,
const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
CHECK(!rmap->count(scan->scan_axis));
std::vector<Tensor> output(op->num_outputs());
for (size_t i = 0; i < output.size(); ++i) {
output[i] = op.output(i);
}
// Update for time axis.
std::vector<IntSet> time_dom;
for (size_t i = 0; i < output.size(); ++i) {
const TensorDom& d = tmap.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
LOG(INFO) << time_dom.size();
CHECK(!rmap->count(scan->scan_axis));
Range sdom = scan->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
(*rmap)[scan->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
CHECK(!rmap->count(sp_ax));
// In default, we always need all spatial axis
// Unless that axis only refers back to itself as a fixed point.
// TODO(tqchen): Add fix point detection.
(*rmap)[sp_ax] = sp_ax->dom;
}
}
}
void GatherOpBound(const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
if (op.as<ComputeOpNode>()) {
auto root_iter_vars = op->root_iter_vars();
const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) {
Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i]));
......@@ -238,6 +329,8 @@ void InferOpBound(const Operation& op,
CHECK(!rmap->count(compute->reduce_axis[i]));
(*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
}
} else if (op.as<ScanOpNode>()) {
GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
} else if (op.as<PlaceholderOpNode>()) {
// dp nothing
} else {
......@@ -269,8 +362,7 @@ void InferRootBound(const Stage& stage,
std::unordered_map<IterVar, Range>* rmap) {
if (stage->attach_type == kInline) return;
if (stage->attach_type == kRoot || stage->attach_type == kNone) {
auto root_iter_vars = stage->op->root_iter_vars();
for (auto iv : root_iter_vars) {
for (auto iv : OutputRelatedIterVars(stage->op)) {
CHECK(iv->dom.defined());
CHECK(!rmap->count(iv));
(*rmap)[iv] = iv->dom;
......@@ -338,8 +430,13 @@ void InferRootBound(const Stage& stage,
PassUp(parent, *rmap, &up_state);
std::unordered_map<const Variable*, IntSet> dom_map;
for (auto iv : parent->op->root_iter_vars()) {
Range r = up_state.at(iv).cover_range(iv->dom);
for (auto iv : OutputRelatedIterVars(parent->op)) {
Range r;
if (up_state.count(iv)) {
r = up_state.at(iv).cover_range(iv->dom);
} else {
r = iv->dom;
}
if (relax_set.size() != 0) {
dom_map[iv->var.get()] = EvalSet(r, relax_set);
} else {
......@@ -379,13 +476,13 @@ void InferRootBound(const Stage& stage,
CHECK(found)
<< "Invalid Schedule, cannot find the producer " << stage->op
<< " along the loop nest specified by compute_at of consumer " << op;
for (auto iv : op->root_iter_vars()) {
for (auto iv : OutputRelatedIterVars(op)) {
Range r = rmap->at(iv);
dom_map[iv->var.get()] = EvalSet(r, relax_set);
}
BoundProp(op, dom_map, &tmap);
}
InferOpBound(stage->op, tmap, rmap);
GatherOpBound(stage->op, tmap, rmap);
}
FeedGraph CreateFeedGraph(const Schedule& sch) {
......
......@@ -33,20 +33,28 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
if (call != nullptr && call->func.defined()) {
Operation call_op(call->func.node_);
deps.push_back(call_op.output(call->value_index));
if (call_op.defined() && visited.count(call_op.get()) == 0) {
visited.insert(call_op.get());
stack.push_back(call_op);
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps);
} else if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
for (Tensor t : scan->init) {
deps.push_back(t);
}
for (Tensor t : scan->update) {
deps.push_back(t);
}
} else if (op.as<PlaceholderOpNode>()) {
// empty set of deps
rmap.Set(op, deps);
} else {
LOG(FATAL) << "unknown Operation" << op->type_key();
}
rmap.Set(op, deps);
for (Tensor t : deps) {
if (t->op.defined() && visited.count(t->op.get()) == 0) {
visited.insert(t->op.get());
stack.push_back(t->op);
}
}
}
return rmap;
}
......
......@@ -146,6 +146,8 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
StageNode* self = operator->();
CHECK(!self->op.as<ScanOpNode>())
<< "Cannot reorder axis of scan";
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
std::vector<size_t> pos;
......
......@@ -7,7 +7,9 @@
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "./graph.h"
......@@ -18,6 +20,12 @@ namespace schedule {
using namespace arith;
using namespace ir;
// Two private scope marks
namespace attr {
constexpr const char* loop_scope = "loop_scope";
constexpr const char* scan_scope = "scan_scope";
} // namespace attr
/*!
* \brief message passing to find if IterVar is related to reduction.
* \param s The stage to be used.
......@@ -168,7 +176,6 @@ MakeLoopNest(const Stage& sch,
value_map[iv] = iv->var;
continue;
}
Range dom = dom_map.at(iv);
// initialize the offset and loop_level
Var var = iv->var;
......@@ -223,7 +230,7 @@ MakeLoopNest(const Stage& sch,
if (!reduce_init_loop) {
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::scope, iv->var, no_op));
AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
......@@ -307,8 +314,8 @@ Stmt MakeLoop(const Stage& s,
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
// common nest
std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop);
std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop, nest.end());
std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
provide = MergeNest(reduce, provide);
return MergeNest(
common, Block::make(init, provide));
......@@ -340,6 +347,29 @@ Stmt MakeRealize(const ComputeOpNode* op,
bounds, make_const(Bool(1), true), body);
}
Stmt MakeRealize(const ScanOpNode* op,
const Map<IterVar, Range>& dom_map,
const std::vector<Tensor>& tensors,
Stmt body) {
Range sdom = dom_map.at(op->scan_axis);
Range tdom = Range::make_with_min_extent(
0, ir::Simplify(sdom->extent + sdom->min));
size_t sp_idx = 0;
for (size_t i = 0; i < tensors.size(); ++i) {
const Tensor& t = tensors[i];
CHECK_EQ(static_cast<size_t>(t->value_index), i);
Halide::Internal::Region bounds;
bounds.push_back(tdom);
for (size_t k = 0; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = op->spatial_axis_[sp_idx];
bounds.push_back(dom_map.at(sp_ax));
}
body = Realize::make(t->op, t->value_index, t->dtype,
bounds, make_const(Bool(1), true), body);
}
return body;
}
void MakeReduction(const ComputeOpNode* op,
const std::vector<Tensor>& tensors,
......@@ -382,12 +412,18 @@ Stmt MakePipeline(const Stage& s,
Stmt init, provide;
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (compute) {
if (compute->reduce_axis.size() == 0) {
provide = MakeProvide(compute, tensors);
} else {
MakeReduction(compute, tensors, &init, &provide);
}
} else if (scan) {
// Provide is done by the sub operations.
provide = AttrStmt::make(
s->op, attr::scan_scope, scan->scan_axis->var,
Evaluate::make(0));
} else {
LOG(FATAL) << "not supported op " << s->op->type_key();
}
......@@ -396,7 +432,12 @@ Stmt MakePipeline(const Stage& s,
producer = ProducerConsumer::make(s->op, true, producer);
Stmt pipeline = producer;
if (consumer.defined()) {
// check if consumer is nop.
bool is_no_op{false};
const Evaluate* ev = consumer.as<Evaluate>();
if (ev && ev->value.as<IntImm>()) is_no_op = true;
if (consumer.defined() && !is_no_op) {
consumer = ProducerConsumer::make(s->op, false, consumer);
pipeline = Block::make(producer, consumer);
}
......@@ -404,47 +445,103 @@ Stmt MakePipeline(const Stage& s,
if (s->op.as<ComputeOpNode>()) {
pipeline = MakeRealize(s->op.as<ComputeOpNode>(),
dom_map, tensors, pipeline);
} else if (s->op.as<ScanOpNode>()) {
pipeline = MakeRealize(s->op.as<ScanOpNode>(),
dom_map, tensors, pipeline);
} else {
LOG(FATAL) << "not supported op";
return Stmt();
}
// use attribute to mark scope of the operation.
pipeline = AttrStmt::make(
s->op, "realize_scope",
s->op, ir::attr::realize_scope,
StringImm::make(s->scope),
pipeline);
return pipeline;
}
// inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
class InjectAttach : public IRMutator {
public:
InjectRealize(Stage schedule, Map<IterVar, Range> dom_map)
: schedule(schedule), dom_map(dom_map) {}
InjectAttach(const Stage& stage,
const Map<IterVar, Range>& dom_map)
: stage_(stage), dom_map_(dom_map) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == "scope") {
if (op->node == schedule->attach_ivar) {
op->type_key == attr::loop_scope) {
if (op->node == stage_->attach_ivar) {
CHECK(!found_attach);
found_attach = true;
stmt = AttrStmt::make(
op->node, op->type_key, op->value,
MakePipeline(schedule, dom_map,
IRMutator::Mutate(op->body)));
MakePipeline(stage_, dom_map_, op->body));
}
}
return stmt;
}
// whether attach point is found
bool found_attach{false};
private:
// the operations to be carried
Stage schedule;
const Stage& stage_;
// domain map
Map<IterVar, Range> dom_map;
const Map<IterVar, Range>& dom_map_;
};
// inject the operator's realization on the stmt.
class InjectScanStep : public IRMutator {
public:
InjectScanStep(const Stage& stage,
const Operation& scan_op,
const Map<IterVar, Range>& dom_map,
bool is_init)
: stage_(stage), scan_op_(scan_op),
dom_map_(dom_map), is_init_(is_init) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
if (is_init_) {
const ProducerConsumer* op = stmt.as<ProducerConsumer>();
if (op != nullptr &&
op->is_producer &&
op->func.same_as(scan_op_)) {
stmt = ProducerConsumer::make(
op->func, true,
MakePipeline(stage_, dom_map_, op->body));
found_attach = true;
}
} else {
// update
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == attr::scan_scope) {
if (op->node.same_as(scan_op_)) {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->type_key, op->value,
MakePipeline(stage_, dom_map_, op->body));
}
}
}
return stmt;
}
// whether attach point is found
bool found_attach{false};
private:
// the operations to be carried
const Stage& stage_;
const Operation& scan_op_;
// domain map
const Map<IterVar, Range>& dom_map_;
// whether it is init.
bool is_init_;
};
Stmt InjectInline(const Operation op, Stmt body) {
......@@ -459,27 +556,180 @@ Stmt InjectInline(const Operation op, Stmt body) {
return Inline(body, op, args, compute->body);
}
// Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public IRMutator {
public:
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (to_remove_.count(op->func.get())) {
return this->Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
if (!HasSideEffect(op->value)) {
var_value_[op->var.get()] = Mutate(op->value);
return this->Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value;
return this->Mutate(op->body);
} else if (op->type_key == ir::attr::realize_scope) {
if (to_remove_.count(op->node.get())) {
return this->Mutate(op->body);
}
}
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
TensorKey key{op->func, op->value_index};
if (replace_.count(key)) {
return this->Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
TensorKey key{op->func, op->value_index};
auto it = replace_.find(key);
if (it != replace_.end()) {
const Tensor& dst = it->second.first;
Stmt ret = Provide::make(
dst->op, dst->value_index, op->value,
RewriteArgs(it->second.second, op->args));
return IRMutator::Mutate_(ret.as<Provide>(), ret);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
if (op != nullptr && op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
auto it = replace_.find(key);
if (it != replace_.end()) {
const Tensor& dst = it->second.first;
Expr ret = Call::make(
op->type, dst->op->name,
RewriteArgs(it->second.second, op->args),
op->call_type, dst->op, dst->value_index);
return IRMutator::Mutate_(ret.as<Call>(), ret);
}
}
return IRMutator::Mutate_(op, e);
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = var_value_.find(op);
if (it != var_value_.end()) {
return it->second;
} else {
return e;
}
}
void Init(const Schedule& sch) {
for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue;
for (size_t i = 0; i < scan->update.size(); ++i) {
Tensor t = s->op.output(i);
AddReplace(scan->init[i], t, Expr());
AddReplace(scan->update[i], t, scan->scan_axis->var);
AddReplace(scan->state_placeholder[i], t, Expr());
}
}
}
private:
void AddReplace(Tensor src, Tensor dst, Expr head_idx) {
replace_[TensorKey{src->op, src->value_index}]
= std::make_pair(dst, head_idx);
to_remove_.insert(src->op.get());
}
Array<Expr> RewriteArgs(Expr head, Array<Expr> args) {
if (!head.defined()) return args;
Array<Expr> new_args{head};
for (Expr e : args) {
new_args.push_back(e);
}
return new_args;
}
// The scan value
std::unordered_map<const Variable*, Expr> var_value_;
// buffer replacement
std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_;
// replaced functions
std::unordered_set<const Node*> to_remove_;
};
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map) {
Stmt body = Stmt();
// scan init and scan updates
std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach;
for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue;
for (Tensor t : scan->init) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
<< "Scan init tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, true);
}
}
for (Tensor t : scan->update) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
<< "Scan update tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, false);
}
}
}
// reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
// no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue;
if (s->attach_type == kInline) {
if (scan_attach.count(s->op)) {
CHECK(s->attach_type == kNone || s->attach_type == kInline)
<< "Cannot specify compute_at for scan's init/update";
CHECK(body.defined());
const auto& p = scan_attach.at(s->op);
InjectScanStep mu(s, p.first, dom_map, p.second);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init/update";
} else if (s->attach_type == kInline) {
body = InjectInline(s->op, body);
} else if (s->attach_type == kRoot || s-> attach_type == kNone) {
body = MakePipeline(s, dom_map, body);
} else if (s->attach_type == kScope) {
CHECK(body.defined());
InjectRealize mutator(s, dom_map);
InjectAttach mutator(s, dom_map);
body = mutator.Mutate(body);
CHECK(mutator.found_attach)
<< "did not find attachment point";
}
}
return body;
SchedulePostProc post_proc;
post_proc.Init(sch);
return post_proc.Mutate(body);
}
} // namespace schedule
......
import tvm
import numpy as np
def test_scan():
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
# schedule
s = tvm.Schedule(res.op)
num_thread = 256
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_init].split(x, outer=thread_x)
_, x = s[s_update].split(s_update.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[s_update].split(x, outer=thread_x)
# one line to build the function.
def check_device(target):
codes = []
fscan = tvm.build(s, [X, res],
target, record_codes=codes,
name="myscan")
if target == "cuda":
ctx = tvm.gpu(0)
else:
ctx = tvm.cl(0)
if not ctx.enabled:
return
for c in codes[1:]:
print(c)
# launch the kernel.
n = 1024
m = 10
a_np = np.random.uniform(size=(m, n)).astype(res.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros((m, n), dtype=res.dtype), ctx)
fscan(a, b)
np.testing.assert_allclose(
b.asnumpy(), np.cumsum(a_np, axis=0))
tvm.init_opencl()
check_device("cuda")
if __name__ == "__main__":
test_scan()
......@@ -34,6 +34,20 @@ def test_tensor_reduce():
assert(str(C_loaded) == str(C))
def test_tensor_scan():
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), "t")
x = tvm.placeholder((m, n))
s = tvm.placeholder((m, n))
res = tvm.scan(t,
tvm.compute((1, n), lambda _, i: x[0, i]),
tvm.compute((n,), lambda i: s[t-1, i] + x[t, i]),
s)
assert tuple(res.shape) == (m, n)
if __name__ == "__main__":
test_tensor()
test_tensor_reduce()
test_tensor_scan()
......@@ -18,9 +18,15 @@ def test_simplify():
tvm.make.Load(dtype, Ab.data, i + 4) + 1,
(j + 1) * 4 - 4 * j + i),
None)))
print(stmt)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
print(stmt)
def test_basic():
m = tvm.Var('m')
ret = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(m-1))
assert str(ret.value) == "(m - 1)"
if __name__ == "__main__":
test_basic()
test_simplify()
......@@ -6,13 +6,11 @@ def test_schedule0():
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
s = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule1():
m = tvm.Var('m')
......@@ -25,7 +23,7 @@ def test_schedule1():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule2():
m = tvm.Var('m')
......@@ -40,25 +38,45 @@ def test_schedule2():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_scan():
m = tvm.Var("m")
n = tvm.Var("n")
l = tvm.Var("l")
t = tvm.IterVar((1, m), name="t")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: x[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + x[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
assert tuple(res.shape) == (m, n)
s = tvm.Schedule(res.op)
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_auto_inline():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.placeholder((m, n), name='C')
T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1')
T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.placeholder((m, n), name='C')
T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1')
T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
s = tvm.Schedule(T2.op)
tvm.schedule.AutoInlineElemWise(s)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
s = tvm.Schedule(T2.op)
tvm.schedule.AutoInlineElemWise(s)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__":
test_schedule_scan()
test_schedule0()
test_schedule1()
test_schedule2()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment