Commit 6ab05082 authored by Siva, committed by Tianqi Chen

[RELAY] Filter PlaceholderOp from schedule. (#2412)

parent b692289e
@@ -83,11 +83,20 @@ class ScheduleGetter :
     cache_node->func_name = readable_name_stream_.str();
     CachedFunc cfunc(cache_node);
     CHECK(master_op_.defined());
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non-PlaceholderOp outputs.
+    tvm::Array<Tensor> tensor_outs;
+    for (const auto& tensor : cache_node->outputs) {
+      if (!tensor->op.as<PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
     Schedule schedule;
     // No need to register schedule for device copy op.
     if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
       schedule =
-          fschedule[master_op_](master_attrs_, cache_node->outputs, target_);
+          fschedule[master_op_](master_attrs_, tensor_outs, target_);
       for (const auto& scalar : scalars_) {
         schedule[scalar].compute_inline();
       }
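The filtering above is easiest to illustrate from Python. Below is a minimal sketch, assuming the current tvm.te API names (te.placeholder, te.PlaceholderOp, te.create_schedule); the committed code is the C++ above, and this snippet only mirrors its intent: an output whose op is a placeholder is just an input passed through untouched, and schedule functions expect real compute outputs, so placeholders are dropped before the schedule is built.

import tvm
from tvm import te

x = te.placeholder((2, 3), name="x")   # an input; its op is a PlaceholderOp
y = te.compute((2, 3), lambda i, j: x[i, j] * 2.0, name="y")

# Mirror ScheduleGetter: keep only non-placeholder outputs, since schedule
# functions expect compute outputs and should not be handed bare placeholders.
outs = [x, y]
tensor_outs = [t for t in outs if not isinstance(t.op, te.PlaceholderOp)]
s = te.create_schedule([t.op for t in tensor_outs])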
@@ -33,6 +33,16 @@ def test_compile_engine():
         y.asnumpy(), x.asnumpy() * 3)
     engine.dump()
 
+def test_compile_placeholder_bypass():
+    engine = relay.backend.compile_engine.get()
+    x = relay.var("x", shape=(2, 3))
+    y = relay.var("y", shape=(2, 3))
+    z = relay.var("z", shape=(2, 3))
+    result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)])
+    func = relay.Function(relay.ir_pass.free_vars(result), result)
+    with relay.build_config(opt_level=0):
+        graph, lib, params = relay.build(func, 'llvm')
+
 if __name__ == "__main__":
     test_compile_engine()
+    test_compile_placeholder_bypass()
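The new test exercises exactly this case: the tuple result contains the bare input x, so after fusion one lowered output is the input placeholder itself, which should not be handed to the schedule function. A hedged usage sketch for the artifacts the test builds, assuming the graph_runtime API of the same era; the helper name run_placeholder_bypass is hypothetical:

import numpy as np
import tvm
from tvm.contrib import graph_runtime

def run_placeholder_bypass(graph, lib, params):
    # Hypothetical helper: run the module built by the test above on CPU.
    m = graph_runtime.create(graph, lib, tvm.cpu(0))
    x = np.random.rand(2, 3).astype("float32")
    y = np.random.rand(2, 3).astype("float32")
    z = np.random.rand(2, 3).astype("float32")
    m.set_input(x=x, y=y, z=z)
    m.run()
    # The first tuple field should be the input x passed through unchanged.
    np.testing.assert_allclose(m.get_output(0).asnumpy(), x)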