Commit 5293c6bf by 雾雨魔理沙 Committed by Haichen Shen

[Relay] use unordered_map instead of map in ANF (#3024)

parent 8d3b392d
......@@ -34,7 +34,9 @@
namespace tvm {
namespace relay {
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);
struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;
......@@ -104,7 +106,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
const Module& m,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* gv) {
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
Fill fi(m, dg, node_scope, gv);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
}
......@@ -113,13 +115,13 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Module mod_;
const DependencyGraph& dg_;
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
std::set<GlobalVar>* visited_;
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;
Fill(Module mod,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* visited) :
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
mod_(mod),
dg_(dg),
node_scope_(node_scope),
......@@ -273,7 +275,9 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
};
Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalFormAux(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
/* When you lift a lambda, what is inside is also being lift.
*
* So we must determine the scope of the lambda before determining the scope of it's body.
......@@ -299,12 +303,14 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
}
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e);
}
Expr ToANormalForm(const Expr& e, const Module& m) {
std::set<GlobalVar> gv;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
return ToANormalForm(e, m, &gv);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment