Commit a3776ba5 by Tianqi Chen Committed by GitHub

[PASS][PRAGMA] Allow pragma debug_skip_region to skip region of computation (#318)

parent 79e482bc
...@@ -207,9 +207,10 @@ def lower(sch, ...@@ -207,9 +207,10 @@ def lower(sch,
for f in cfg.add_lower_pass: for f in cfg.add_lower_pass:
stmt = f(stmt) stmt = f(stmt)
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if simple_mode: if simple_mode:
return stmt return stmt
stmt = ir_pass.LowerStorageAccessInfo(stmt)
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
...@@ -524,6 +524,11 @@ class Stage(NodeBase): ...@@ -524,6 +524,11 @@ class Stage(NodeBase):
Most pragmas are advanced/experimental features Most pragmas are advanced/experimental features
and may subject to change. List of supported pragmas: and may subject to change. List of supported pragmas:
- **debug_skip_region**
Force skip the region marked by the axis and turn it into no-op.
This is useful for debug purposes.
- **parallel_launch_point** - **parallel_launch_point**
Specify to launch parallel threads outside the Specify to launch parallel threads outside the
......
...@@ -12,7 +12,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -12,7 +12,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "mem-info(" p->stream << "mem-info("
<< "unit_bits=" << op->unit_bits << ", " << "unit_bits=" << op->unit_bits << ", "
<< "max_num_bits=" << op->max_num_bits << ", " << "max_num_bits=" << op->max_num_bits << ", "
<< "max_simd_bits=" << op->max_simd_bits << ")"; << "max_simd_bits=" << op->max_simd_bits << ", "
<< "head_address=" << op->head_address << ")";
}); });
TVM_REGISTER_NODE_TYPE(MemoryInfoNode); TVM_REGISTER_NODE_TYPE(MemoryInfoNode);
......
...@@ -204,10 +204,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { ...@@ -204,10 +204,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
void PlanReadBarrier(Stmt stmt) { void PlanReadBarrier(Stmt stmt) {
read_barrier_ = true; read_barrier_ = true;
this->Visit(stmt); this->Visit(stmt);
PlanReadBarrier(scope_.back(), nullptr);
} }
void PlanWriteBarrier(Stmt stmt) { void PlanWriteBarrier(Stmt stmt) {
read_barrier_ = false; read_barrier_ = false;
this->Visit(stmt); this->Visit(stmt);
PlanWriteBarrier(scope_.back(), nullptr);
} }
std::unordered_map<const Node*, std::vector<Stmt> > barrier_before_; std::unordered_map<const Node*, std::vector<Stmt> > barrier_before_;
...@@ -245,7 +247,6 @@ class CoProcBarrierDetector : public StorageAccessVisitor { ...@@ -245,7 +247,6 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
write_set.erase(it); write_set.erase(it);
} }
}; };
for (size_t i = 0; i < seq.size(); ++i) { for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i]; const StmtEntry& s = seq[i];
for (const AccessEntry& acc : s.access) { for (const AccessEntry& acc : s.access) {
...@@ -291,7 +292,6 @@ class CoProcBarrierDetector : public StorageAccessVisitor { ...@@ -291,7 +292,6 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
const StmtEntry& s = seq[i - 1]; const StmtEntry& s = seq[i - 1];
for (const AccessEntry& acc : s.access) { for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && acc.type == kWrite) { if (acc.threads.size() == 0 && acc.type == kWrite) {
CHECK_NE(i, seq.size());
fupdate(i, acc); fupdate(i, acc);
write_seq.push_back(acc); write_seq.push_back(acc);
} }
......
...@@ -20,6 +20,12 @@ class NoOpRemover : public IRMutator { ...@@ -20,6 +20,12 @@ class NoOpRemover : public IRMutator {
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == "debug_skip_region") {
return MakeEvaluate(0);
}
}
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>(); op = stmt.as<AttrStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
......
...@@ -30,8 +30,8 @@ def test_make_sum(): ...@@ -30,8 +30,8 @@ def test_make_sum():
B = tvm.compute((2,), lambda i: tvm.sum(A[i, k], axis=k), name="B") B = tvm.compute((2,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
json_str = tvm.save_json(B) json_str = tvm.save_json(B)
BB = tvm.load_json(json_str) BB = tvm.load_json(json_str)
assert B.op.body[0].combiner.handle.value != 0 assert B.op.body[0].combiner is not None
assert BB.op.body[0].combiner.handle.value != 0 assert BB.op.body[0].combiner is not None
if __name__ == "__main__": if __name__ == "__main__":
test_make_node() test_make_node()
......
...@@ -29,10 +29,6 @@ def test_storage_sync(): ...@@ -29,10 +29,6 @@ def test_storage_sync():
def test_coproc_sync(): def test_coproc_sync():
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
@tvm.register_func("tvm.info.mem.global.cache") @tvm.register_func("tvm.info.mem.global.cache")
def meminfo_cache(): def meminfo_cache():
return tvm.make.node( return tvm.make.node(
...@@ -41,6 +37,9 @@ def test_coproc_sync(): ...@@ -41,6 +37,9 @@ def test_coproc_sync():
max_simd_bits=32, max_simd_bits=32,
max_num_bits=128, max_num_bits=128,
head_address=tvm.call_extern("handle", "global_cache")) head_address=tvm.call_extern("handle", "global_cache"))
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", 128, name="A", scope="global.cache") A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
A[i] = A[i] + 1 A[i] = A[i] + 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment