Commit 581509ab by Tianqi Chen Committed by GitHub

[PASS] copy intrin (#536)

* [PASS] copy intrin

* update comment thanks to derisavi
parent 33a80e46
...@@ -159,8 +159,8 @@ struct IntSetNode : public Node { ...@@ -159,8 +159,8 @@ struct IntSetNode : public Node {
}; };
/*! /*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^n var[i] * coeff[i] + coeff[n] * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff and base are invariant of var. * Where coeff[i] and base are invariant of var[j] for all i and j.
* *
* \param e The expression to be detected. * \param e The expression to be detected.
* \param vars List of variables to be used in detection. * \param vars List of variables to be used in detection.
......
...@@ -128,7 +128,7 @@ class BufferNode : public Node { ...@@ -128,7 +128,7 @@ class BufferNode : public Node {
Type dtype, Type dtype,
Array<Expr> shape, Array<Expr> shape,
Array<Expr> strides, Array<Expr> strides,
Expr byte_offset, Expr elem_offset,
std::string name, std::string name,
std::string scope, std::string scope,
int data_alignment, int data_alignment,
......
...@@ -240,6 +240,24 @@ Stmt InjectPrefetch(Stmt stmt); ...@@ -240,6 +240,24 @@ Stmt InjectPrefetch(Stmt stmt);
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
/*! /*!
* \brief Inject copy intrinsics with optional pad.
*
* \param stmt The statment to be transformed.
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
* Stmt fintrin(Buffer src,
* Buffer dst,
* Array<Expr> pad_before,
* Array<Expr> pad_after,
* Expr pad_value)
* \return Transformed stmt.
*/
Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const runtime::PackedFunc& fintrin);
/*!
* \brief Rewrite storage allocation pattern. * \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope. * Moves the allocation to outer most possible scope.
* Trying to share space between allocations to make * Trying to share space between allocations to make
......
...@@ -171,8 +171,12 @@ inline TNodeRef TVMRetValue::AsNodeRef() const { ...@@ -171,8 +171,12 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
} }
inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_)); values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle; type_codes_[i] = kNodeHandle;
} else {
type_codes_[i] = kNull;
}
} }
// type related stuffs // type related stuffs
......
...@@ -92,6 +92,7 @@ REGISTER_PASS3(StorageFlatten); ...@@ -92,6 +92,7 @@ REGISTER_PASS3(StorageFlatten);
REGISTER_PASS4(IRTransform); REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop); REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop); REGISTER_PASS4(UnrollLoop);
REGISTER_PASS3(InjectCopyIntrin);
REGISTER_PASS2(ThreadSync); REGISTER_PASS2(ThreadSync);
REGISTER_PASS5(MakeAPI); REGISTER_PASS5(MakeAPI);
REGISTER_PASS2(BindDeviceType); REGISTER_PASS2(BindDeviceType);
......
...@@ -307,7 +307,13 @@ class Canonical::Internal : public IRMutator { ...@@ -307,7 +307,13 @@ class Canonical::Internal : public IRMutator {
if (!op->is_pure()) { if (!op->is_pure()) {
stack_.back().has_side_effect = true; stack_.back().has_side_effect = true;
} }
return IRMutator::Mutate_(op, e); Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
} else {
return expr;
}
} }
// For // For
Stmt Mutate_(const For* op, const Stmt& s) { Stmt Mutate_(const For* op, const Stmt& s) {
...@@ -320,6 +326,13 @@ class Canonical::Internal : public IRMutator { ...@@ -320,6 +326,13 @@ class Canonical::Internal : public IRMutator {
--level_counter_; --level_counter_;
return stmt; return stmt;
} }
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<IfThenElse>();
if (is_one(op->condition)) return op->then_case;
return stmt;
}
// AttrStmt // AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) { Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
......
/*!
* Copyright (c) 2017 by Contributors
* \brief Replace certain copy with copy intrinsics.
* \file copy_intrin_rewrite.cc
*/
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
using runtime::PackedFunc;
class CopyIntrinInjector : public IRMutator {
public:
CopyIntrinInjector(const std::string& pragma_key,
const PackedFunc& flower_copy_fromto)
: pragma_key_(pragma_key),
flower_copy_fromto_(flower_copy_fromto) {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = op->value.as<StringImm>()->value;
} else if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == pragma_key_) {
Stmt ret;
CHECK(MatchCopyPattern(op->body, &ret))
<< "Cannot match copy pattern of " << op->body;
return ret;
}
}
return IRMutator::Mutate_(op, s);
}
private:
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
Stmt body = stmt;
// strip the loops
std::vector<const For*> loops;
while (const For* op = body.as<For>()) {
if (!is_zero(op->min)) return false;
loops.push_back(op);
body = op->body;
}
const Store* store = body.as<Store>();
if (store == nullptr) return false;
const Select* select = store->value.as<Select>();
const Load* load = store->value.as<Load>();
// for now only support true condition matching
if (select != nullptr) {
load = select->true_value.as<Load>();
}
if (load == nullptr) return false;
if (load->type.lanes() != 1) return false;
Array<Var> loop_vars;
for (const For* op : loops) {
loop_vars.push_back(Var(op->loop_var.node_));
}
Array<Expr> store_strides =
arith::DetectLinearEquation(store->index, loop_vars);
Array<Expr> load_strides =
arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<Expr> dst_shape;
for (const For* op : loops) {
dst_shape.push_back(op->extent);
}
Array<Expr> src_shape = dst_shape;
Array<Expr> pad_before, pad_after;
Expr pad_value;
Expr src_elem_offset = load_strides[loop_vars.size()];
if (select != nullptr) {
Array<Expr> clip_bound =
arith::DetectClipBound(select->condition, loop_vars);
pad_value = select->false_value;
if (clip_bound.size() == 0) return false;
CHECK_EQ(src_shape.size(), loop_vars.size());
CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
for (size_t i = 0; i < src_shape.size(); ++i) {
Expr min_value = clip_bound[2 * i];
Expr max_value = clip_bound[2 * i + 1];
Type t = loop_vars[i].type();
Expr svalue = src_shape[i];
if (min_value.defined()) {
Expr pbefore = Simplify(Max::make(min_value, make_zero(t)));
src_elem_offset = src_elem_offset + pbefore * load_strides[i];
svalue = svalue - pbefore;
pad_before.push_back(pbefore);
} else {
pad_before.push_back(make_zero(t));
}
if (max_value.defined()) {
Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1),
make_zero(t)));
svalue = svalue - pafter;
pad_after.push_back(pafter);
} else {
pad_after.push_back(make_zero(t));
}
src_shape.Set(i, Simplify(svalue));
}
src_elem_offset = Simplify(src_elem_offset);
}
CHECK_EQ(load_strides.size(), store_strides.size());
CHECK_EQ(load_strides.size(), loop_vars.size() + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size());
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size());
Buffer dst = BufferNode::make(
Var(store->buffer_var.node_),
load->type,
dst_shape,
dst_strides,
store_strides[loop_vars.size()],
store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()),
0, 0);
Buffer src = BufferNode::make(
Var(load->buffer_var.node_),
load->type,
src_shape,
src_strides,
src_elem_offset,
load->buffer_var->name_hint,
GetStorageScope(load->buffer_var.get()),
0, 0);
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
CHECK(out->defined()) << "flower function did not return correct stmt";
return true;
}
// Get storage scope
std::string GetStorageScope(const Variable* var) const {
auto it = storage_scope_.find(var);
if (it != storage_scope_.end()) {
return it->second;
} else {
return "";
}
}
// pragma key
const std::string& pragma_key_;
// function to lower copy intrinsics.
const PackedFunc& flower_copy_fromto_;
// Storage scope
std::unordered_map<const Variable*, std::string> storage_scope_;
};
Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const PackedFunc& flower_copy_fromto) {
return CopyIntrinInjector(pragma_key, flower_copy_fromto)
.Mutate(stmt);
}
} // namespace ir
} // namespace tvm
import tvm
def test_copy2d():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m, l), lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
def cb(src, dst, pad_before, pad_after, pad_value):
assert dst.strides[0] == l
assert dst.strides[1].value == 1
assert src.strides[0] == l
assert tuple(src.shape) == (m, l)
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def test_copy_pad():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m + 2, l), lambda i, j:
tvm.select(tvm.all(i >= 1, i < m + 1),
A[i - 1, j], 1.0), name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
assert pad_before[0].value == 1
assert pad_before[1].value == 0
assert pad_after[0].value == 1
assert pad_after[1].value == 0
assert pad_value.value == 1.0
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def assert_expr_equal(a, b):
assert tvm.ir_pass.Simplify(a - b).value == 0
def test_copy_pad_split():
m = 4 * 3
A = tvm.placeholder((m, ), name="A")
Apad = tvm.compute((m + 2,), lambda i:
tvm.select(tvm.all(i >= 1, i <= m),
A[i - 1], 0.0), "Apad")
B = tvm.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)
s[Apad].compute_at(s[B], xo)
s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)
rpad_before = tvm.max(1 - xo * 4, 0)
rpad_after = tvm.max(xo * 4 - 7, 0)
assert_expr_equal(pad_before[0], rpad_before)
assert_expr_equal(pad_after[0], rpad_after)
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
if __name__ == "__main__":
test_copy2d()
test_copy_pad()
test_copy_pad_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment