Commit a3776ba5 by Tianqi Chen Committed by GitHub

[PASS][PRAGMA] Allow pragma debug_skip_region to skip region of computation (#318)

parent 79e482bc
......@@ -207,9 +207,10 @@ def lower(sch,
for f in cfg.add_lower_pass:
stmt = f(stmt)
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if simple_mode:
return stmt
stmt = ir_pass.LowerStorageAccessInfo(stmt)
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
......@@ -524,6 +524,11 @@ class Stage(NodeBase):
Most pragmas are advanced/experimental features
and may subject to change. List of supported pragmas:
- **debug_skip_region**
Force skip the region marked by the axis and turn it into no-op.
This is useful for debug purposes.
- **parallel_launch_point**
Specify to launch parallel threads outside the
......
......@@ -12,7 +12,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "mem-info("
<< "unit_bits=" << op->unit_bits << ", "
<< "max_num_bits=" << op->max_num_bits << ", "
<< "max_simd_bits=" << op->max_simd_bits << ")";
<< "max_simd_bits=" << op->max_simd_bits << ", "
<< "head_address=" << op->head_address << ")";
});
TVM_REGISTER_NODE_TYPE(MemoryInfoNode);
......
......@@ -204,10 +204,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
void PlanReadBarrier(Stmt stmt) {
read_barrier_ = true;
this->Visit(stmt);
PlanReadBarrier(scope_.back(), nullptr);
}
void PlanWriteBarrier(Stmt stmt) {
read_barrier_ = false;
this->Visit(stmt);
PlanWriteBarrier(scope_.back(), nullptr);
}
std::unordered_map<const Node*, std::vector<Stmt> > barrier_before_;
......@@ -245,7 +247,6 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
write_set.erase(it);
}
};
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
for (const AccessEntry& acc : s.access) {
......@@ -291,7 +292,6 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
const StmtEntry& s = seq[i - 1];
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && acc.type == kWrite) {
CHECK_NE(i, seq.size());
fupdate(i, acc);
write_seq.push_back(acc);
}
......
......@@ -20,6 +20,12 @@ class NoOpRemover : public IRMutator {
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == "debug_skip_region") {
return MakeEvaluate(0);
}
}
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
......
......@@ -30,8 +30,8 @@ def test_make_sum():
B = tvm.compute((2,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
json_str = tvm.save_json(B)
BB = tvm.load_json(json_str)
assert B.op.body[0].combiner.handle.value != 0
assert BB.op.body[0].combiner.handle.value != 0
assert B.op.body[0].combiner is not None
assert BB.op.body[0].combiner is not None
if __name__ == "__main__":
test_make_node()
......
......@@ -29,10 +29,6 @@ def test_storage_sync():
def test_coproc_sync():
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
@tvm.register_func("tvm.info.mem.global.cache")
def meminfo_cache():
return tvm.make.node(
......@@ -41,6 +37,9 @@ def test_coproc_sync():
max_simd_bits=32,
max_num_bits=128,
head_address=tvm.call_extern("handle", "global_cache"))
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, n, name="i") as i:
A[i] = A[i] + 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment