Commit b3f3ab55 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] Fix PE (#3482)

parent 287078c3
...@@ -55,7 +55,7 @@ struct Module; ...@@ -55,7 +55,7 @@ struct Module;
* The functional style allows users to construct custom * The functional style allows users to construct custom
* environments easily, for example each thread can store * environments easily, for example each thread can store
* a Module while auto-tuning. * a Module while auto-tuning.
* */ */
class ModuleNode : public RelayNode { class ModuleNode : public RelayNode {
public: public:
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2019 by Contributors
* \file src/tvm/relay/expr_mutator.cc * \file src/tvm/relay/expr_mutator.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST. * \brief A wrapper around ExprFunctor which functionally updates the AST.
* *
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
* the cost of using functional updates. * the cost of using functional updates.
*/ */
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h" #include "type_functor.h"
namespace tvm { namespace tvm {
...@@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit") ...@@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit")
}); });
// Implement bind. // Implement bind.
class ExprBinder : public ExprMutator { class ExprBinder : public ExprMutator, PatternMutator {
public: public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
: args_map_(args_map) { : args_map_(args_map) {
...@@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator { ...@@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator {
} }
} }
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}
Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
}
private: private:
const tvm::Map<Var, Expr>& args_map_; const tvm::Map<Var, Expr>& args_map_;
}; };
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) { Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) { if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).Mutate(func->body); Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params; Array<Var> new_params;
for (Var param : func->params) { for (Var param : func->params) {
if (!args_map.count(param)) { if (!args_map.count(param)) {
...@@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) { ...@@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
func->type_params, func->type_params,
func->attrs); func->attrs);
} else { } else {
return ExprBinder(args_map).Mutate(expr); return ExprBinder(args_map).VisitExpr(expr);
} }
} }
......
...@@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { ...@@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
} }
} }
Type TypeMutator::VisitType(const Type& t) {
return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
}
// Type Mutator. // Type Mutator.
Array<Type> TypeMutator::MutateArray(Array<Type> arr) { Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
// The array will do copy on write // The array will do copy on write
...@@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator { ...@@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator {
}; };
Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) { Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
return type.defined() ? TypeBinder(args_map).VisitType(type) : type; return TypeBinder(args_map).VisitType(type);
} }
} // namespace relay } // namespace relay
......
...@@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> { ...@@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
// Mutator that transform a type to another one. // Mutator that transform a type to another one.
class TypeMutator : public TypeFunctor<Type(const Type& n)> { class TypeMutator : public TypeFunctor<Type(const Type& n)> {
public: public:
Type VisitType(const Type& t) override;
Type VisitType_(const TypeVarNode* op) override; Type VisitType_(const TypeVarNode* op) override;
Type VisitType_(const TensorTypeNode* op) override; Type VisitType_(const TensorTypeNode* op) override;
Type VisitType_(const IncompleteTypeNode* op) override; Type VisitType_(const IncompleteTypeNode* op) override;
......
...@@ -48,7 +48,7 @@ class LetList { ...@@ -48,7 +48,7 @@ class LetList {
public: public:
~LetList() { ~LetList() {
if (lets_.size() > 0 && !used_) { if (lets_.size() > 0 && !used_) {
std::cout << "Warning: letlist not used" << std::endl; LOG(WARNING) << "letlist not used";
} }
} }
/*! /*!
......
...@@ -64,7 +64,7 @@ ...@@ -64,7 +64,7 @@
* 3: The generated code reuses bindings (although they are not shadowed), * 3: The generated code reuses bindings (although they are not shadowed),
* so we have to deduplicate them. * so we have to deduplicate them.
* *
* 4: In the generated code, multiple VarNode might have same Id. * 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id.
* While it is permitted, most pass use NodeHash for Var, * While it is permitted, most pass use NodeHash for Var,
* and having multiple VarNode for same Id break them. * and having multiple VarNode for same Id break them.
* Thus we remap them to a single Id for now. * Thus we remap them to a single Id for now.
...@@ -216,9 +216,9 @@ Static MkSRef() { ...@@ -216,9 +216,9 @@ Static MkSRef() {
} }
using Func = std::function<PStatic(const std::vector<PStatic>&, using Func = std::function<PStatic(const std::vector<PStatic>&,
const Attrs&, const Attrs&,
const Array<Type>&, const Array<Type>&,
LetList*)>; LetList*)>;
struct SFuncNode : StaticNode { struct SFuncNode : StaticNode {
Func func; Func func;
...@@ -256,6 +256,7 @@ class Environment { ...@@ -256,6 +256,7 @@ class Environment {
void Insert(const Var& v, const PStatic& ps) { void Insert(const Var& v, const PStatic& ps) {
CHECK(ps.defined()); CHECK(ps.defined());
CHECK_EQ(env_.back().locals.count(v), 0);
env_.back().locals[v] = ps; env_.back().locals[v] = ps;
} }
...@@ -287,12 +288,17 @@ class Environment { ...@@ -287,12 +288,17 @@ class Environment {
/*! /*!
* \brief As our store require rollback, we implement it as a frame. * \brief As our store require rollback, we implement it as a frame.
* every time we need to copy the store, a new frame is insert. *
* every time we roll back, a frame is popped. * Every time we need to copy the store, a new frame is insert.
* Every time we roll back, a frame is popped.
*/ */
struct StoreFrame { struct StoreFrame {
std::unordered_map<const SRefNode*, PStatic> store; std::unordered_map<const SRefNode*, PStatic> store;
/*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */ /*!
* \brief On unknown effect, history_valid is set to true to signal above frame is outdated.
*
* It only outdate the frame above it, but not the current frame.
*/
bool history_valid = true; bool history_valid = true;
explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { } explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { }
StoreFrame() = default; StoreFrame() = default;
...@@ -310,6 +316,7 @@ class Store { ...@@ -310,6 +316,7 @@ class Store {
} }
void Insert(const SRefNode* r, const PStatic& ps) { void Insert(const SRefNode* r, const PStatic& ps) {
CHECK(r);
store_.back().store[r] = ps; store_.back().store[r] = ps;
} }
...@@ -317,19 +324,21 @@ class Store { ...@@ -317,19 +324,21 @@ class Store {
PStatic Lookup(const SRefNode* r) { PStatic Lookup(const SRefNode* r) {
auto rit = store_.rbegin(); auto rit = store_.rbegin();
while (rit != store_.rend()) { while (rit != store_.rend()) {
if (!rit->history_valid) {
return PStatic();
}
if (rit->store.find(r) != rit->store.end()) { if (rit->store.find(r) != rit->store.end()) {
return rit->store.find(r)->second; return rit->store.find(r)->second;
} }
if (!rit->history_valid) {
return PStatic();
}
++rit; ++rit;
} }
return PStatic(); return PStatic();
} }
void Invalidate() { void Invalidate() {
store_.back().history_valid = false; StoreFrame sf;
sf.history_valid = false;
store_.push_back(sf);
} }
private: private:
...@@ -341,6 +350,10 @@ class Store { ...@@ -341,6 +350,10 @@ class Store {
store_->store_.push_back(StoreFrame()); store_->store_.push_back(StoreFrame());
} }
~StoreFrameContext() { ~StoreFrameContext() {
// push one history valid frame off.
while (!store_->store_.back().history_valid) {
store_->store_.pop_back();
}
store_->store_.pop_back(); store_->store_.pop_back();
} }
}; };
...@@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) { ...@@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) {
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>, class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> { public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public: public:
PartialEvaluator(const tvm::Array<Var>& free_vars, PartialEvaluator(const Module& mod) : mod_(mod) { }
const Module& mod) :
mod_(mod) {
for (const Var& v : free_vars) {
env_.Insert(v, NoStatic(v));
}
}
PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic VisitExpr(const Expr& e, LetList* ll) final {
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll); PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
...@@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return env_.Lookup(GetRef<Var>(op)); return env_.Lookup(GetRef<Var>(op));
} }
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { PStatic VisitGlobalVar(const GlobalVar& gv) {
GlobalVar gv = GetRef<GlobalVar>(op); CHECK(mod_.defined());
if (gv_map_.count(gv) == 0) { if (gv_map_.count(gv) == 0) {
if (mod_.defined()) { Function func = mod_->Lookup(gv);
Function func = mod_->Lookup(gv); InitializeFuncId(func);
InitializeFuncId(func); Func f = VisitFuncStatic(func, gv);
Func f = VisitFuncStatic(func, gv); gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); mod_->Update(gv, func);
mod_->Update(gv, func);
} else {
gv_map_.insert({gv, NoStatic(gv)});
}
} }
return gv_map_.at(gv); return gv_map_.at(gv);
} }
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
return VisitGlobalVar(GetRef<GlobalVar>(op));
}
PStatic VisitExpr_(const LetNode* op, LetList* ll) final { PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
env_.Insert(op->var, VisitExpr(op->value, ll)); env_.Insert(op->var, VisitExpr(op->value, ll));
return VisitExpr(op->body, ll); return VisitExpr(op->body, ll);
...@@ -629,7 +636,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -629,7 +636,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
subst.Set(func->type_params[i], type_args[i]); subst.Set(func->type_params[i], type_args[i]);
} }
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], Type()); subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
} }
std::vector<Time> args_time; std::vector<Time> args_time;
for (const auto& v : pv) { for (const auto& v : pv) {
...@@ -672,22 +679,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -672,22 +679,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}; };
} }
Expr VisitFuncDynamic(const Function& func, const Func& f) { Expr VisitFuncDynamic(const Function& func, const Func& f) {
return store_.Extend<Expr>([&]() { return store_.Extend<Expr>([&]() {
store_.Invalidate(); store_.Invalidate();
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) { return FunctionNode::make(func->params,
std::vector<PStatic> pv; LetList::With([&](LetList* ll) {
for (const auto& v : func->params) { std::vector<PStatic> pv;
pv.push_back(NoStatic(v)); for (const auto& v : func->params) {
} pv.push_back(NoStatic(v));
tvm::Array<Type> type_args; }
for (const auto& tp : func->type_params) { tvm::Array<Type> type_args;
type_args.push_back(tp); for (const auto& tp : func->type_params) {
} type_args.push_back(tp);
return f(pv, Attrs(), type_args, ll)->dynamic; }
}), func->ret_type, func->type_params, func->attrs); return f(pv, Attrs(), type_args, ll)->dynamic;
}); }), func->ret_type, func->type_params, func->attrs);
});
} }
PStatic VisitFunc(const Function& func, LetList* ll) { PStatic VisitFunc(const Function& func, LetList* ll) {
...@@ -1012,17 +1019,14 @@ Expr PostProcess(const Expr& e) { ...@@ -1012,17 +1019,14 @@ Expr PostProcess(const Expr& e) {
Module PartialEval(const Module& m) { Module PartialEval(const Module& m) {
CHECK(m->entry_func.defined()); CHECK(m->entry_func.defined());
auto func = m->Lookup(m->entry_func); relay::partial_eval::PartialEvaluator pe(m);
Expr ret = std::vector<GlobalVar> gvs;
TransformF([&](const Expr& e) { for (const auto& p : m->functions) {
return LetList::With([&](LetList* ll) { gvs.push_back(p.first);
relay::partial_eval::PartialEvaluator pe(FreeVars(e), m); }
pe.InitializeFuncId(e); for (const auto& gv : gvs) {
return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic); pe.VisitGlobalVar(gv);
}); }
}, func);
CHECK(ret->is_type<FunctionNode>());
m->Update(m->entry_func, Downcast<Function>(ret));
return m; return m;
} }
......
...@@ -172,6 +172,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -172,6 +172,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return it->second.checked_type; return it->second.checked_type;
} }
Type ret = this->VisitExpr(expr); Type ret = this->VisitExpr(expr);
CHECK(ret.defined());
KindCheck(ret, mod_); KindCheck(ret, mod_);
ResolvedTypeInfo& rti = type_map_[expr]; ResolvedTypeInfo& rti = type_map_[expr];
rti.checked_type = ret; rti.checked_type = ret;
......
...@@ -425,6 +425,16 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) { ...@@ -425,6 +425,16 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
Var VisitVar(const Var& v) final { Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v)); return Downcast<Var>(VisitExpr(v));
} }
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}
private: private:
const tvm::Map<TypeVar, Type>& subst_map_; const tvm::Map<TypeVar, Type>& subst_map_;
}; };
......
...@@ -307,10 +307,10 @@ def test_double(): ...@@ -307,10 +307,10 @@ def test_double():
if __name__ == '__main__': if __name__ == '__main__':
test_empty_ad() test_ref()
test_tuple() test_tuple()
test_empty_ad()
test_const_inline() test_const_inline()
test_ref()
test_ad() test_ad()
test_if_ref() test_if_ref()
test_function_invalidate() test_function_invalidate()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment