Commit b076cad5 by Wuwei Lin Committed by ziheng

[RELAY] Inline scalar compute (#2335)

parent 848dc39a
...@@ -84,6 +84,9 @@ class ScheduleGetter : ...@@ -84,6 +84,9 @@ class ScheduleGetter :
CHECK(master_op_.defined()); CHECK(master_op_.defined());
Schedule schedule = fschedule[master_op_]( Schedule schedule = fschedule[master_op_](
master_attrs_, cache_node->outputs, target_); master_attrs_, cache_node->outputs, target_);
for (const auto& scalar : scalars_) {
schedule[scalar].compute_inline();
}
return std::make_pair(schedule, cfunc); return std::make_pair(schedule, cfunc);
} }
...@@ -123,6 +126,7 @@ class ScheduleGetter : ...@@ -123,6 +126,7 @@ class ScheduleGetter :
return tvm::Expr(); return tvm::Expr();
} }
}); });
scalars_.push_back(value->op);
return {value}; return {value};
} }
...@@ -216,6 +220,7 @@ class ScheduleGetter : ...@@ -216,6 +220,7 @@ class ScheduleGetter :
int master_op_pattern_{0}; int master_op_pattern_{0};
std::ostringstream readable_name_stream_; std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
Array<Operation> scalars_;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment