Unverified Commit d9dc65f8 by Tianqi Chen Committed by GitHub

[BUILD] Simplify after bind device type (#2670)

parent 97be70a0
......@@ -176,13 +176,13 @@ class DeviceTypeBinder: public IRMutator {
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::device_context_type) {
if (const Variable* var = op->value.as<Variable>()) {
std::unordered_map<const Variable*, Expr> dmap;
var_ = var;
Expr value = make_const(op->value.type(), device_type_);
dmap[var] = value;
Stmt body = Substitute(s, dmap);
Stmt body = IRMutator::Mutate_(op, s);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmt::make(op->value == value, os.str(), body);
......@@ -191,7 +191,40 @@ class DeviceTypeBinder: public IRMutator {
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
// eager simplify if guard.
Stmt res = IRMutator::Mutate_(op, s);
op = res.as<IfThenElse>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
return Evaluate::make(0);
}
if (is_one(op->condition)) {
return op->then_case;
}
return res;
}
Expr Mutate_(const NE* op, const Expr& e) final {
// eager check NE for device check
Expr res = IRMutator::Mutate_(op, e);
op = res.as<NE>();
if (ir::Equal(op->a, op->b)) {
return make_const(op->type, false);
}
return res;
}
Expr Mutate_(const Variable* op, const Expr& e) final {
if (op == var_) {
return make_const(op->type, device_type_);
} else {
return e;
}
}
public:
const Variable* var_{nullptr};
int device_type_;
};
......
......@@ -11,10 +11,7 @@ def test_add():
s = tvm.create_schedule(C.op)
def check_c():
f1 = tvm.lower(s, [A, B, C], name="fadd")
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(f1)]
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
mhost = tvm.codegen.build_module(fsplits[0], "c")
mhost = tvm.build(s, [A, B, C], "c", name="fadd")
temp = util.tempdir()
path_dso = temp.relpath("temp.so")
mhost.export_library(path_dso)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment