Unverified Commit d9dc65f8 by Tianqi Chen Committed by GitHub

[BUILD] Simplify after bind device type (#2670)

parent 97be70a0
...@@ -176,13 +176,13 @@ class DeviceTypeBinder: public IRMutator { ...@@ -176,13 +176,13 @@ class DeviceTypeBinder: public IRMutator {
explicit DeviceTypeBinder(int device_type) explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {} : device_type_(device_type) {}
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::device_context_type) { if (op->attr_key == attr::device_context_type) {
if (const Variable* var = op->value.as<Variable>()) { if (const Variable* var = op->value.as<Variable>()) {
std::unordered_map<const Variable*, Expr> dmap; var_ = var;
Expr value = make_const(op->value.type(), device_type_); Expr value = make_const(op->value.type(), device_type_);
dmap[var] = value; Stmt body = IRMutator::Mutate_(op, s);
Stmt body = Substitute(s, dmap); var_ = nullptr;
std::ostringstream os; std::ostringstream os;
os << "device_type need to be " << device_type_; os << "device_type need to be " << device_type_;
return AssertStmt::make(op->value == value, os.str(), body); return AssertStmt::make(op->value == value, os.str(), body);
...@@ -191,7 +191,40 @@ class DeviceTypeBinder: public IRMutator { ...@@ -191,7 +191,40 @@ class DeviceTypeBinder: public IRMutator {
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
} }
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
// eager simplify if guard.
Stmt res = IRMutator::Mutate_(op, s);
op = res.as<IfThenElse>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
return Evaluate::make(0);
}
if (is_one(op->condition)) {
return op->then_case;
}
return res;
}
Expr Mutate_(const NE* op, const Expr& e) final {
// eager check NE for device check
Expr res = IRMutator::Mutate_(op, e);
op = res.as<NE>();
if (ir::Equal(op->a, op->b)) {
return make_const(op->type, false);
}
return res;
}
Expr Mutate_(const Variable* op, const Expr& e) final {
if (op == var_) {
return make_const(op->type, device_type_);
} else {
return e;
}
}
public: public:
const Variable* var_{nullptr};
int device_type_; int device_type_;
}; };
......
...@@ -11,10 +11,7 @@ def test_add(): ...@@ -11,10 +11,7 @@ def test_add():
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
def check_c(): def check_c():
f1 = tvm.lower(s, [A, B, C], name="fadd") mhost = tvm.build(s, [A, B, C], "c", name="fadd")
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(f1)]
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
mhost = tvm.codegen.build_module(fsplits[0], "c")
temp = util.tempdir() temp = util.tempdir()
path_dso = temp.relpath("temp.so") path_dso = temp.relpath("temp.so")
mhost.export_library(path_dso) mhost.export_library(path_dso)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment