Commit fbb472b8 by libing4752 Committed by Tianqi Chen

enhance pragma to support single point copy (#863)

* modified schedule_dataflow_rewrite.cc to fix losing tensor problem

* modified schedule_dataflow_rewrite.cc for lint scan

* modified schedule_dataflow_rewrite.cc for lint scan

* using tensor's value_index to index output of stage op

* repare address offset for different kinds of dtype

* bc

* aaa

* aaaaa

* repare address for different dtypes

* remove nonsense files

* add whitespace of line 581

* use base alloc elem_type

* enhance the testcast of basic buffer is 64bits,32bits,16bits,8bits

* use extends[0]->type() as dtype of offset

* clear program writes

* enhance inject_copy_intin to support of pragma stmt with no loops

* fix cpplint errors

* fix cpplint error of !

* enhance detectLinearEquation to support with no loop vars

* fix cpplint errors
parent 0ca53640
...@@ -123,25 +123,28 @@ class LinearEqDetector ...@@ -123,25 +123,28 @@ class LinearEqDetector
}; };
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) { Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
CHECK_GE(vars.size(), 1U);
Expr base = e; Expr base = e;
Array<Expr> coeff; Array<Expr> coeff;
for (Var v : vars) { if (0 == vars.size()) {
LinearEqEntry ret; coeff.push_back(make_const(Int(32), 1));
if (!LinearEqDetector(v).Detect(base, &ret)) { } else {
return Array<Expr>(); for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
return Array<Expr>();
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
} }
coeff.push_back(ret.coeff);
base = std::move(ret.base);
}
std::unordered_set<const Variable*> vset; std::unordered_set<const Variable*> vset;
for (size_t i = vars.size(); i != 1; --i) { for (size_t i = vars.size(); i != 1; --i) {
vset.insert(vars[i - 1].get()); vset.insert(vars[i - 1].get());
// The previous coeff contains the variable // The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) { if (ExprUseVar(coeff[i - 2], vset)) {
return Array<Expr>(); return Array<Expr>();
}
} }
} }
coeff.push_back(base); coeff.push_back(base);
......
...@@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator { ...@@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator {
private: private:
bool MatchCopyPattern(Stmt stmt, Stmt *out) { bool MatchCopyPattern(Stmt stmt, Stmt *out) {
Stmt body = stmt; Stmt body = stmt;
bool is_single_point_copy = false;
// strip the loops // strip the loops
std::vector<const For*> loops; std::vector<const For*> loops;
...@@ -53,7 +54,10 @@ class CopyIntrinInjector : public IRMutator { ...@@ -53,7 +54,10 @@ class CopyIntrinInjector : public IRMutator {
const Select* select = store->value.as<Select>(); const Select* select = store->value.as<Select>();
const Cast* cast = store->value.as<Cast>(); const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>(); const Load* load = store->value.as<Load>();
if (0 == loops.size()) {
is_single_point_copy = true;
CHECK(select == nullptr);
}
// for now only support true condition matching // for now only support true condition matching
if (select != nullptr) { if (select != nullptr) {
load = select->true_value.as<Load>(); load = select->true_value.as<Load>();
...@@ -74,13 +78,19 @@ class CopyIntrinInjector : public IRMutator { ...@@ -74,13 +78,19 @@ class CopyIntrinInjector : public IRMutator {
arith::DetectLinearEquation(load->index, loop_vars); arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false; if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<Expr> dst_shape; Array<Expr> dst_shape;
for (const For* op : loops) { auto loop_var_size = loop_vars.size();
dst_shape.push_back(op->extent); if (is_single_point_copy) {
loop_var_size = 1;
dst_shape.push_back(make_const(Int(32), 1));
} else {
for (const For* op : loops) {
dst_shape.push_back(op->extent);
}
} }
Array<Expr> src_shape = dst_shape; Array<Expr> src_shape = dst_shape;
Array<Expr> pad_before, pad_after; Array<Expr> pad_before, pad_after;
Expr pad_value; Expr pad_value;
Expr src_elem_offset = load_strides[loop_vars.size()]; Expr src_elem_offset = load_strides[loop_var_size];
if (select != nullptr) { if (select != nullptr) {
Array<Expr> clip_bound = Array<Expr> clip_bound =
arith::DetectClipBound(select->condition, loop_vars); arith::DetectClipBound(select->condition, loop_vars);
...@@ -114,15 +124,15 @@ class CopyIntrinInjector : public IRMutator { ...@@ -114,15 +124,15 @@ class CopyIntrinInjector : public IRMutator {
src_elem_offset = Simplify(src_elem_offset); src_elem_offset = Simplify(src_elem_offset);
} }
CHECK_EQ(load_strides.size(), store_strides.size()); CHECK_EQ(load_strides.size(), store_strides.size());
CHECK_EQ(load_strides.size(), loop_vars.size() + 1); CHECK_EQ(load_strides.size(), loop_var_size + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size()); Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size()); Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
Buffer dst = BufferNode::make( Buffer dst = BufferNode::make(
Var(store->buffer_var.node_), Var(store->buffer_var.node_),
store->value.type(), store->value.type(),
dst_shape, dst_shape,
dst_strides, dst_strides,
store_strides[loop_vars.size()], store_strides[loop_var_size],
store->buffer_var->name_hint, store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()), GetStorageScope(store->buffer_var.get()),
0, 0); 0, 0);
......
...@@ -44,6 +44,25 @@ def test_copy_pad(): ...@@ -44,6 +44,25 @@ def test_copy_pad():
return tvm.make.Evaluate(0) return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def test_single_point_test():
A = tvm.placeholder((1,), name='A')
B = tvm.compute((1,), lambda i:
A[i], name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def assert_expr_equal(a, b): def assert_expr_equal(a, b):
assert tvm.ir_pass.Simplify(a - b).value == 0 assert tvm.ir_pass.Simplify(a - b).value == 0
...@@ -80,3 +99,4 @@ if __name__ == "__main__": ...@@ -80,3 +99,4 @@ if __name__ == "__main__":
test_copy2d() test_copy2d()
test_copy_pad() test_copy_pad()
test_copy_pad_split() test_copy_pad_split()
test_single_point_test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment