Commit fbb472b8 by libing4752 Committed by Tianqi Chen

enhance pragma to support single point copy (#863)

* modified schedule_dataflow_rewrite.cc to fix losing tensor problem

* modified schedule_dataflow_rewrite.cc for lint scan

* modified schedule_dataflow_rewrite.cc for lint scan

* using tensor's value_index to index output of stage op

* repare address offset for different kinds of dtype

* bc

* aaa

* aaaaa

* repare address for different dtypes

* remove nonsense files

* add whitespace of line 581

* use base alloc elem_type

* enhance the testcast of basic buffer is 64bits,32bits,16bits,8bits

* use extends[0]->type() as dtype of offset

* clear program writes

* enhance inject_copy_intin to support of pragma stmt with no loops

* fix cpplint errors

* fix cpplint error of !

* enhance detectLinearEquation to support with no loop vars

* fix cpplint errors
parent 0ca53640
......@@ -123,10 +123,12 @@ class LinearEqDetector
};
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
CHECK_GE(vars.size(), 1U);
Expr base = e;
Array<Expr> coeff;
if (0 == vars.size()) {
coeff.push_back(make_const(Int(32), 1));
} else {
for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
......@@ -144,6 +146,7 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
return Array<Expr>();
}
}
}
coeff.push_back(base);
return coeff;
}
......
......@@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator {
private:
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
Stmt body = stmt;
bool is_single_point_copy = false;
// strip the loops
std::vector<const For*> loops;
......@@ -53,7 +54,10 @@ class CopyIntrinInjector : public IRMutator {
const Select* select = store->value.as<Select>();
const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>();
if (0 == loops.size()) {
is_single_point_copy = true;
CHECK(select == nullptr);
}
// for now only support true condition matching
if (select != nullptr) {
load = select->true_value.as<Load>();
......@@ -74,13 +78,19 @@ class CopyIntrinInjector : public IRMutator {
arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<Expr> dst_shape;
auto loop_var_size = loop_vars.size();
if (is_single_point_copy) {
loop_var_size = 1;
dst_shape.push_back(make_const(Int(32), 1));
} else {
for (const For* op : loops) {
dst_shape.push_back(op->extent);
}
}
Array<Expr> src_shape = dst_shape;
Array<Expr> pad_before, pad_after;
Expr pad_value;
Expr src_elem_offset = load_strides[loop_vars.size()];
Expr src_elem_offset = load_strides[loop_var_size];
if (select != nullptr) {
Array<Expr> clip_bound =
arith::DetectClipBound(select->condition, loop_vars);
......@@ -114,15 +124,15 @@ class CopyIntrinInjector : public IRMutator {
src_elem_offset = Simplify(src_elem_offset);
}
CHECK_EQ(load_strides.size(), store_strides.size());
CHECK_EQ(load_strides.size(), loop_vars.size() + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size());
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size());
CHECK_EQ(load_strides.size(), loop_var_size + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
Buffer dst = BufferNode::make(
Var(store->buffer_var.node_),
store->value.type(),
dst_shape,
dst_strides,
store_strides[loop_vars.size()],
store_strides[loop_var_size],
store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()),
0, 0);
......
......@@ -44,6 +44,25 @@ def test_copy_pad():
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def test_single_point_test():
A = tvm.placeholder((1,), name='A')
B = tvm.compute((1,), lambda i:
A[i], name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def assert_expr_equal(a, b):
assert tvm.ir_pass.Simplify(a - b).value == 0
......@@ -80,3 +99,4 @@ if __name__ == "__main__":
test_copy2d()
test_copy_pad()
test_copy_pad_split()
test_single_point_test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment