Commit 5293c6bf by 雾雨魔理沙 Committed by Haichen Shen

[Relay] use unordered_map instead of map in ANF (#3024)

parent 8d3b392d
...@@ -34,7 +34,9 @@ ...@@ -34,7 +34,9 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv); Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);
struct ScopeNode; struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>; using Scope = std::shared_ptr<ScopeNode>;
...@@ -104,7 +106,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -104,7 +106,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
const Module& m, const Module& m,
const DependencyGraph& dg, const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope, std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* gv) { std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
Fill fi(m, dg, node_scope, gv); Fill fi(m, dg, node_scope, gv);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
} }
...@@ -113,13 +115,13 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -113,13 +115,13 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Module mod_; Module mod_;
const DependencyGraph& dg_; const DependencyGraph& dg_;
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_; std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
std::set<GlobalVar>* visited_; std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo; std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;
Fill(Module mod, Fill(Module mod,
const DependencyGraph& dg, const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope, std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* visited) : std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
mod_(mod), mod_(mod),
dg_(dg), dg_(dg),
node_scope_(node_scope), node_scope_(node_scope),
...@@ -273,7 +275,9 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -273,7 +275,9 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
} }
}; };
Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) { Expr ToANormalFormAux(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
/* When you lift a lambda, what is inside is also being lift. /* When you lift a lambda, what is inside is also being lift.
* *
* So we must determine the scope of the lambda before determining the scope of it's body. * So we must determine the scope of the lambda before determining the scope of it's body.
...@@ -299,12 +303,14 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) { ...@@ -299,12 +303,14 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
return Fill::ToANormalForm(e, m, dg, &node_scope, gv); return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
} }
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) { Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e); return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e);
} }
Expr ToANormalForm(const Expr& e, const Module& m) { Expr ToANormalForm(const Expr& e, const Module& m) {
std::set<GlobalVar> gv; std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
return ToANormalForm(e, m, &gv); return ToANormalForm(e, m, &gv);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment