Commit 4578048c by Tianqi Chen Committed by GitHub

[PASS] IRTransform to enable IR pass prototyping in Python (#401)

parent 8ef26606
......@@ -102,6 +102,25 @@ class IRMutator {
virtual Expr Mutate_(const Shuffle* op, const Expr& e);
};
/*!
 * \brief Recursively traverse the IR and transform it, calling preorder before
 *  and postorder after the recursive mutation of each node.
 *
 * \param node The IR to be transformed.
 * \param preorder The function called before recursive mutation.
 *  If preorder returns None, the transform proceeds to the recursive call.
 *  If preorder returns a Stmt/Expr that is not None, the transformer simply
 *  returns it and does not recurse any further into that node.
 * \param postorder The function called after recursive mutation.
 *  The result of the recursive mutation is passed to postorder for further
 *  mutation. If postorder returns None, the recursively mutated node is kept.
 * \param only_enable List of StringImm type keys.
 *  If it is empty, preorder/postorder are called on every IR node.
 *  If it is not empty, preorder/postorder are only called when the IR node's
 *  type key is in the list; other nodes are still mutated recursively.
 * \return The transformed statement.
 */
Stmt IRTransform(const Stmt& node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
......@@ -7,6 +7,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/api_registry.h>
namespace tvm {
......@@ -88,6 +89,7 @@ REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(RewriteUnsafeSelect);
REGISTER_PASS4(Inline);
REGISTER_PASS3(StorageFlatten);
// NOTE(review): the numeric suffix presumably matches the pass's argument
// count (IRTransform is declared with four parameters) — confirm against the
// REGISTER_PASS macro definitions.
REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(ThreadSync);
......
......@@ -110,7 +110,7 @@ class DSLAPIImpl : public DSLAPI {
*out_index = static_cast<int>(Node::TypeKey2Index(type_key));
}
// Stores the type index of the node behind `handle` into *out_index.
// The handle is cast to TVMAPINode*, dereferenced, and the result of its
// type_index() call is cast to int.
// NOTE(review): the `int* out_index) const final {` line appears twice below;
// this looks like before/after residue of a whitespace-only diff hunk —
// confirm against the real file, which should contain it once.
void NodeGetTypeIndex(NodeHandle handle,
int* out_index) const final {
int* out_index) const final {
*out_index = static_cast<int>(
(*static_cast<TVMAPINode*>(handle))->type_index());
}
......
......@@ -11,7 +11,6 @@
#include "./ir_util.h"
#include "./storage_access.h"
namespace tvm {
namespace ir {
......
......@@ -4,11 +4,65 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/packed_func_ext.h>
#include <stdexcept>
#include "./ir_util.h"
namespace tvm {
namespace ir {
/*!
 * \brief IRMutator that routes visited nodes through optional pre-order and
 *  post-order callbacks.
 *
 *  The callbacks and the type-index filter are held by reference, so the
 *  objects handed to the constructor must outlive this transformer.
 */
class IRTransformer final : public IRMutator {
 public:
  /*!
   * \param f_preorder Callback run on a node before its children are mutated.
   * \param f_postorder Callback run on the node produced by recursive mutation.
   * \param only_enable Type indices the callbacks are restricted to;
   *  an empty set means no restriction.
   */
  IRTransformer(const runtime::PackedFunc& f_preorder,
                const runtime::PackedFunc& f_postorder,
                const std::unordered_set<uint32_t>& only_enable)
      : f_preorder_(f_preorder),
        f_postorder_(f_postorder),
        only_enable_(only_enable) {}

  Expr Mutate(Expr expr) final {
    return MutateInternal<Expr>(expr);
  }

  Stmt Mutate(Stmt stmt) final {
    return MutateInternal<Stmt>(stmt);
  }

 private:
  /*!
   * \brief Shared implementation behind both Mutate overloads.
   * \tparam T Either Stmt or Expr.
   * \param node The node being visited.
   * \return The replacement chosen by a callback, or the recursively
   *  mutated node when no callback supplies a defined replacement.
   */
  template<typename T>
  T MutateInternal(T node) {
    // A non-empty filter that does not list this node's type index means the
    // callbacks are bypassed; the node is still mutated recursively.
    const bool filtered_out =
        !only_enable_.empty() && only_enable_.count(node->type_index()) == 0;
    if (filtered_out) {
      return IRMutator::Mutate(node);
    }
    // A defined pre-order result replaces the node without visiting children.
    if (f_preorder_ != nullptr) {
      T replaced = f_preorder_(node);
      if (replaced.defined()) return replaced;
    }
    T mutated = IRMutator::Mutate(node);
    // A defined post-order result replaces the recursively mutated node.
    if (f_postorder_ != nullptr) {
      T rewritten = f_postorder_(mutated);
      if (rewritten.defined()) return rewritten;
    }
    return mutated;
  }

  // User callbacks (borrowed).
  const runtime::PackedFunc& f_preorder_;
  const runtime::PackedFunc& f_postorder_;
  // Type indices for which the callbacks are enabled (borrowed).
  const std::unordered_set<uint32_t>& only_enable_;
};
/*!
 * \brief Transform a statement by running user callbacks around the
 *  recursive mutation of its nodes.
 *
 * \param ir_node The statement to transform.
 * \param f_preorder Callback invoked on a node before its children are
 *  mutated; a defined result replaces the node and stops recursion into it.
 * \param f_postorder Callback invoked on the node produced by recursive
 *  mutation; a defined result replaces it.
 * \param only_enable StringImm type keys. When non-empty, the callbacks run
 *  only on nodes whose type index corresponds to one of the keys.
 * \return The transformed statement.
 * \throws std::runtime_error if an entry of only_enable is not a StringImm.
 */
Stmt IRTransform(const Stmt& ir_node,
                 const runtime::PackedFunc& f_preorder,
                 const runtime::PackedFunc& f_postorder,
                 const Array<Expr>& only_enable) {
  std::unordered_set<uint32_t> only_type_index;
  for (Expr s : only_enable) {
    // as<StringImm>() gives a null pointer for any other kind of node, so
    // reject such entries explicitly instead of dereferencing null.
    const StringImm* type_key = s.as<StringImm>();
    if (type_key == nullptr) {
      throw std::runtime_error(
          "IRTransform: every entry of only_enable must be a StringImm type key");
    }
    only_type_index.insert(Node::TypeKey2Index(type_key->value.c_str()));
  }
  return IRTransformer(f_preorder, f_postorder, only_type_index)
      .Mutate(ir_node);
}
// Returns a reference to the single function-local static FMutateExpr
// instance, which is created on the first call.
IRMutator::FMutateExpr& IRMutator::vtable_expr() {  // NOLINT(*)
  static FMutateExpr table;
  return table;
}
......
......@@ -14,5 +14,11 @@ def test_basic():
assert m[1].value == 5
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0
m = tvm.arith.DetectLinearEquation(a * b + 7, a)
assert m[1] == b
m = tvm.arith.DetectLinearEquation(b * 7, a)
assert m[1].value == 0
if __name__ == "__main__":
test_basic()
import tvm
def test_ir_transform():
    """Check IRTransform with pre/post-order callbacks limited to Call nodes.

    Builds a two-level loop nest that emits TestB(TestA(...)) followed by
    TestC(TestA(...)), runs IRTransform over it, and verifies that the TestA
    call under TestB was rewritten to a TestB call by the post-order callback
    and that the TestC call was replaced by the constant 0 by the pre-order
    callback.
    """
    builder = tvm.ir_builder.create()
    extent = tvm.var("n")
    with builder.for_range(0, extent, name="i") as outer:
        with builder.for_range(0, 10, name="j") as inner:
            test_a = tvm.call_extern("int32", "TestA", outer * 3 + inner * 1)
            builder.emit(tvm.call_extern("int32", "TestB", test_a))
            builder.emit(tvm.call_extern("int32", "TestC", test_a))
    loop_nest = builder.get()

    def replace_before_descent(op):
        # Returning a non-None value replaces the node; None lets the
        # transform continue into the node's children.
        return tvm.const(0, "int32") if op.name == "TestC" else None

    def rewrite_after_descent(op):
        # The ["Call"] filter passed below restricts the callbacks to Call nodes.
        assert isinstance(op, tvm.expr.Call)
        if op.name != "TestA":
            return op
        return tvm.call_extern("int32", "TestB", op.args[0] + 1)

    transformed = tvm.ir_pass.IRTransform(
        loop_nest, replace_before_descent, rewrite_after_descent, ["Call"])
    emitted = tvm.make.stmt_list(transformed.body.body)
    assert emitted[0].value.args[0].name == "TestB"
    assert emitted[1].value.value == 0
if __name__ == "__main__":
test_ir_transform()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment