Commit c8654e2a by 雾雨魔理沙 Committed by Thierry Moreau

[Relay] Partial Evaluator do concatenate, and has better termination checker for scalar. (#3703)

* save

lint some

lint

lint

add charrnn

save

save

save

remove debug

remove debug

remove space

refactor

save

rewrite dce

* reset files

* join -> meet

* lint

* address review comment

* wordsmith
parent ee74d00e
......@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The scope builder interface """
"""The scope builder interface."""
from __future__ import absolute_import
from . import expr as _expr
......
......@@ -419,8 +419,8 @@ class AlphaEqualHandler:
bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
if (const LetNode* rhs = other.as<LetNode>()) {
if (!ExprEqual(lhs->value, rhs->value)) return false;
if (!MergeVarDecl(lhs->var, rhs->var)) return false;
if (!ExprEqual(lhs->value, rhs->value)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
......
......@@ -36,121 +36,94 @@
namespace tvm {
namespace relay {
template<typename X>
using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
class CalcDep;
class FindDef : private ExprVisitor {
private:
VarMap<Expr> expr_map_;
void VisitExpr_(const LetNode* l) final {
CHECK_EQ(expr_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
VisitExpr(l->value);
VisitExpr(l->body);
}
friend CalcDep;
};
class Eliminator : private ExprMutator {
private:
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { }
friend CalcDep;
bool HasLet(const Var& v) {
switch (use_map_[v]) {
case 0:
return false;
case 1:
return !inline_once_;
default:
return true;
}
}
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
}
Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
}
};
// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
FindDef fd;
fd(e);
CalcDep cd(fd.expr_map_);
cd(e);
Eliminator el(fd.expr_map_, cd.use_map_, inline_once);
return el(e);
}
private:
template<typename X>
using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool count_ = true;
VarSet dead_worklist_;
VarSet current_letrec_;
void LetRec(const std::function<void()>& func, const Var& v) {
current_letrec_.insert(v);
func();
current_letrec_.erase(v);
void VisitExpr(const Expr& e) final {
return ExprFunctor<void(const Expr& e)>::VisitExpr(e);
}
void VisitExpr_(const LetNode* l) final {
if (count_) {
CHECK_EQ(expr_map_.count(l->var), 0);
CHECK_EQ(use_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
use_map_[l->var] = 0;
dead_worklist_.insert(l->var);
LetRec([&]() { VisitExpr(l->value); }, l->var);
}
VisitExpr(l->body);
}
void VisitExpr(const Expr& e) final {
ExprFunctor<void(const Expr&)>::VisitExpr(e);
}
void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
if (expr_map_.count(var) == 0) {
return;
}
if (current_letrec_.count(var) == 0) {
if (count_) {
use_map_[var] += 1;
dead_worklist_.erase(var);
} else {
CHECK_GT(use_map_[var], 0) << var;
use_map_[var] -= 1;
if (use_map_[var] == 0) {
dead_worklist_.insert(var);
}
}
} else {
letrec_set_.insert(var);
++use_map_[var];
if (use_map_[var] == 1 && expr_map_.count(var) > 0) {
VisitExpr(expr_map_[var]);
}
}
void Calculate(const Expr& v) {
VisitExpr(v);
count_ = false;
while (!dead_worklist_.empty()) {
Var dead = *(dead_worklist_.begin());
dead_worklist_.erase(dead);
CHECK_EQ(use_map_[dead], 0);
if (expr_map_.count(dead) > 0) {
LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead);
}
}
}
class Eliminator : private ExprMutator {
private:
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep;
bool HasLet(const Var& v) {
switch (use_map_[v]) {
case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
}
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
}
Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
}
};
};
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
......
......@@ -128,11 +128,11 @@ struct VarEqual {
Expr PostProcess(const Expr&);
/*! \brief The base container type of Relay values. */
/*! \brief A StaticNode contains some static data that the Partial Evaluator can use. */
class StaticNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Static";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
TVM_DECLARE_BASE_NODE_INFO(StaticNode, RelayNode);
};
class Static : public NodeRef {
......@@ -174,7 +174,7 @@ struct STupleNode : StaticNode {
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(STuple, STupleNode, Value);
RELAY_DEFINE_NODE_REF(STuple, STupleNode, Static);
Static MkSTuple(const std::vector<PStatic>& fields) {
return Static(make_node<STupleNode>(fields));
......@@ -187,7 +187,7 @@ struct STensorNode : StaticNode {
TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value);
RELAY_DEFINE_NODE_REF(STensor, STensorNode, Static);
Static MkSTensor(const NDArray& data) {
return Static(make_node<STensorNode>(data));
......@@ -202,7 +202,7 @@ struct SConstructorNode : StaticNode {
TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Value);
RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Static);
Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>& fields) {
return Static(make_node<SConstructorNode>(constructor, fields));
......@@ -214,13 +214,14 @@ struct SRefNode : StaticNode {
TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(SRef, SRefNode, Value);
RELAY_DEFINE_NODE_REF(SRef, SRefNode, Static);
Static MkSRef() {
return Static(make_node<SRefNode>());
}
using Func = std::function<PStatic(const std::vector<PStatic>&,
using Func = std::function<PStatic(const PStatic&,
const std::vector<PStatic>&,
const Attrs&,
const Array<Type>&,
LetList*)>;
......@@ -232,12 +233,145 @@ struct SFuncNode : StaticNode {
TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(SFunc, SFuncNode, Value);
RELAY_DEFINE_NODE_REF(SFunc, SFuncNode, Static);
Static MkSFunc(const Func& func) {
return Static(make_node<SFuncNode>(func));
}
class FuelNode;
/*! \brief A meet-semilattice with finite descending chain.
* It means that we can meet two element to get an element,
* and for every element, there is only a finite amount of meet before getting back the same element.
*
* Every time we recurse, we do a meet and require that progress must be made.
* This ensures we do not recurse infinitely in the Partial Evaluator.
*/
class Fuel : public NodeRef {
public:
Fuel() {}
explicit Fuel(NodePtr<Node> n) : NodeRef(n) {}
const FuelNode* operator->() const;
using ContainerType = FuelNode;
};
class FuelNode : public RelayNode {
public:
// Please implement one of the following function or there will be infinite loop.
/*! \brief return the new Fuel, and whether progress is made.
*
* Note that progress is not symmetric - it only measure progress for (*this).
*
* Thus, if the generated is smaller then the argument of Meet,
* and the generated is not smaller then (*this),
* progress should be false.
*/
virtual std::tuple<Fuel, bool> Meet(const Fuel& f) const {
bool progress = false;
auto ret = Meet(f, &progress);
return std::make_tuple(ret, progress);
}
/*! \brief return the new Fuel, and write (*progress | is progress made) to *progress. */
virtual Fuel Meet(const Fuel& f, bool* progress) const {
CHECK(progress);
auto ret = Meet(f);
*progress |= std::get<1>(ret);
return std::get<0>(ret);
}
static constexpr const char* _type_key = "relay.Fuel";
TVM_DECLARE_BASE_NODE_INFO(FuelNode, RelayNode);
};
const FuelNode* Fuel::operator->() const {
return static_cast<const FuelNode*>(node_.get());
}
Fuel MkFSeq(const std::vector<Fuel>& fuels);
struct FSeqNode : FuelNode {
std::vector<Fuel> fuels;
Fuel Meet(const Fuel& f, bool* progress) const final {
auto x = f.as<FSeqNode>();
CHECK(x);
CHECK_EQ(fuels.size(), x->fuels.size());
std::vector<Fuel> new_fuels;
for (size_t i = 0; i < fuels.size(); ++i) {
new_fuels.push_back(fuels[i]->Meet(x->fuels[i], progress));
}
return MkFSeq(new_fuels);
}
explicit FSeqNode(const std::vector<Fuel>& fuels) : fuels(fuels) { }
static constexpr const char* _type_key = "relay.FSeq";
TVM_DECLARE_NODE_TYPE_INFO(FSeqNode, FuelNode);
};
RELAY_DEFINE_NODE_REF(FSeq, FSeqNode, Fuel);
Fuel MkFSeq(const std::vector<Fuel>& fuels) {
return Fuel(make_node<FSeqNode>(fuels));
}
Fuel MkFTime(Time time);
struct FTimeNode : FuelNode {
Time time;
std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
auto x = f.as<FTimeNode>();
CHECK(x);
Time new_time = std::min(time, x->time);
return std::make_tuple(MkFTime(new_time), new_time < time);
}
explicit FTimeNode(Time time) : time(time) { }
static constexpr const char* _type_key = "relay.FTime";
TVM_DECLARE_NODE_TYPE_INFO(FTimeNode, FuelNode);
};
RELAY_DEFINE_NODE_REF(FTime, FTimeNode, Fuel);
Fuel MkFTime(Time time) {
return Fuel(make_node<FTimeNode>(time));
}
Fuel MkFTValue(size_t tvalue);
/*! \brief If the pstatic is hold a positive integer scalar, that number, else 0. */
struct FTValueNode : FuelNode {
size_t tvalue;
std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
auto x = f.as<FTValueNode>();
CHECK(x);
size_t new_tvalue = std::min(tvalue, x->tvalue);
return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue);
}
explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { }
static constexpr const char* _type_key = "relay.FTValue";
TVM_DECLARE_NODE_TYPE_INFO(FTValueNode, FuelNode);
};
RELAY_DEFINE_NODE_REF(FTValue, FTValueNode, Fuel);
Fuel MkFTValue(size_t tvalue) {
return Fuel(make_node<FTValueNode>(tvalue));
}
/*! \brief Initially every element has Fuel of FTop. It is the largest element.
*
* Note that it is illegal to has FTop inside some other Fuel -
* doing so break the finite descending chain property.
*/
struct FTopNode : FuelNode {
std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
return std::make_tuple(f, !f.as<FTopNode>());
}
static constexpr const char* _type_key = "relay.FTop";
TVM_DECLARE_NODE_TYPE_INFO(FTopNode, FuelNode);
};
RELAY_DEFINE_NODE_REF(FTop, FTopNode, Fuel);
Fuel MkFTop() {
return Fuel(make_node<FTopNode>());
}
/*!
* \brief A stack frame in the Relay interpreter.
*
......@@ -469,6 +603,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return ret;
}
PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
if (auto* op = e.as<CallNode>()) {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0], ll, name);
}
}
PStatic ret = e.as<FunctionNode>() ?
VisitFunc(Downcast<Function>(e), ll, name) :
VisitExpr(e, ll);
CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
return ret;
}
PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
}
......@@ -504,7 +652,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
InitializeFuncId(func);
Func f = VisitFuncStatic(func, gv);
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
mod_->Update(gv, func);
}
return gv_map_.at(gv);
......@@ -515,7 +663,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
env_.Insert(op->var, VisitExpr(op->value, ll));
env_.Insert(op->var, VisitExpr(op->value, ll, op->var));
return VisitExpr(op->body, ll);
}
......@@ -588,34 +736,53 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
x_dyn.push_back(ps->dynamic);
}
if (f->pstatic.defined()) {
return Downcast<SFunc>(f->pstatic)->func(x, op->attrs, op->type_args, ll);
return Downcast<SFunc>(f->pstatic)->func(f, x, op->attrs, op->type_args, ll);
} else {
store_.Invalidate();
return NoStatic(ll->Push(CallNode::make(f->dynamic, x_dyn, op->attrs, op->type_args)));
}
}
struct TimeFrame {
struct FuelFrame {
PartialEvaluator* pe_;
FuncId fid_;
std::vector<Time> old_time;
bool has_old_time;
TimeFrame(PartialEvaluator* pe,
Fuel old_fuel;
FuelFrame(PartialEvaluator* pe,
FuncId fid,
const std::vector<Time>& args_time) : pe_(pe), fid_(fid) {
has_old_time = pe_->time_map_.count(fid_) > 0;
old_time = pe_->time_map_[fid_];
pe_->time_map_[fid_] = args_time;
const Fuel& new_fuel) : pe_(pe), fid_(fid) {
CHECK_GT(pe_->fuel_map_.count(fid_), 0);
old_fuel = pe_->fuel_map_[fid_];
pe_->fuel_map_[fid_] = new_fuel;
}
~TimeFrame() {
if (has_old_time) {
pe_->time_map_[fid_] = old_time;
} else {
pe_->time_map_.erase(fid_);
}
~FuelFrame() {
pe_->fuel_map_[fid_] = old_fuel;
}
};
size_t GetFTValue(const PStatic& ps) {
if (ps->pstatic.defined()) {
if (auto* st = ps->pstatic.as<STensorNode>()) {
if (st->data.Shape().empty()) {
NDArray cpu_array = st->data.CopyTo(CPUContext());
DataType dtype = TVMType2Type(cpu_array->dtype);
if (dtype == Int(32)) {
return std::max<int32_t>(0, *static_cast<const int32_t*>(cpu_array->data));
} else if (dtype == Int(64)) {
return std::max<int64_t>(0, *static_cast<const int64_t*>(cpu_array->data));
}
}
}
}
return 0;
}
Fuel GetFuel(const PStatic& ps) {
std::vector<Fuel> fuels;
fuels.push_back(MkFTime(ps->created_time));
fuels.push_back(MkFTValue(GetFTValue(ps)));
return MkFSeq(fuels);
}
Func VisitFuncStatic(const Function& func, const Expr& var) {
CHECK(IsAtomic(var));
if (func->IsPrimitive()) {
......@@ -623,14 +790,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
std::vector<std::pair<Var, PStatic> > free_vars;
for (const auto& v : FreeVars(func)) {
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
if (v != var) {
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
}
}
return [=](const std::vector<PStatic>& pv,
return [=](const PStatic& self,
const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
LetList* ll) {
return env_.Extend<PStatic>([&]() {
CHECK_EQ(pv.size(), func->params.size());
if (var.as<VarNode>()) {
env_.Insert(Downcast<Var>(var), self);
}
for (size_t i = 0; i < pv.size(); ++i) {
env_.Insert(func->params[i], pv[i]);
}
......@@ -644,48 +817,31 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
std::vector<Time> args_time;
std::vector<Fuel> args_fuel;
for (const auto& v : pv) {
args_time.push_back(v->created_time);
args_fuel.push_back(GetFuel(v));
}
CHECK_GT(func_map_.count(func), 0);
FuncId fid = func_map_.at(func);
auto recurse = [&]() {
TimeFrame tf(this, fid, args_time);
if (fuel_map_.count(fid) == 0) {
fuel_map_.insert({fid, MkFTop()});
}
auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
if (std::get<1>(meet_res)) {
FuelFrame tf(this, fid, std::get<0>(meet_res));
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
};
if (time_map_.count(fid) == 0) {
return recurse();
} else {
/* We check to see that at least one argument decrease
* with respect to all previous invocation.
* The depth of the recursion is bounded by
* the sum of the time of all argument at the first call.
*/
bool can_recurse = false;
std::vector<Time>& min_time = time_map_.at(fid);
CHECK_EQ(args_time.size(), min_time.size());
for (size_t i = 0; i < args_time.size(); ++i) {
if (args_time[i] < min_time[i]) {
can_recurse = true;
}
args_time[i] = std::min(args_time[i], min_time[i]);
}
if (can_recurse) {
return recurse();
} else {
std::vector<Expr> dyn;
for (const auto& v : pv) {
dyn.push_back(v->dynamic);
}
return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
std::vector<Expr> dyn;
for (const auto& v : pv) {
dyn.push_back(v->dynamic);
}
return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
}
});
};
}
Expr VisitFuncDynamic(const Function& func, const Func& f) {
Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
return store_.Extend<Expr>([&]() {
store_.Invalidate();
return FunctionNode::make(func->params,
......@@ -698,19 +854,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const auto& tp : func->type_params) {
type_args.push_back(tp);
}
return f(pv, Attrs(), type_args, ll)->dynamic;
return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
}), func->ret_type, func->type_params, func->attrs);
});
}
PStatic VisitFunc(const Function& func, LetList* ll) {
Var v = VarNode::make("x", Type());
Func f = VisitFuncStatic(func, v);
PStatic VisitFunc(const Function& func,
LetList* ll,
const Var& name = VarNode::make("x", Type())) {
Func f = VisitFuncStatic(func, name);
Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
// TODO(@M.K.): we seems to reduce landin knot into letrec.
// restore letrec support across whole relay.
return HasStatic(MkSFunc(f),
ll->Push(v, VisitFuncDynamic(u_func, f)));
ll->Push(name, VisitFuncDynamic(u_func, f, name)));
}
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
......@@ -771,7 +928,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
Func ConstEvaluateFunc(const Expr& expr) {
CHECK_EQ(FreeVars(expr).size(), 0);
return [=](const std::vector<PStatic>& pv,
return [=](const PStatic& self,
const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
LetList* ll) {
......@@ -804,7 +962,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
Constructor c = GetRef<Constructor>(op);
Func f = [=](const std::vector<PStatic>& pv,
Func f = [=](const PStatic& self,
const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
LetList* ll) {
......@@ -967,17 +1126,17 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
* We have finitely many FunctionIds.
* Each FunctionId maps to a class of semantically equivalent function (ignoring type),
* as both TypeSubst and DeDup create semantically equivalent function.
* We partially map each FunctionId to a std::vector<Time>,
* denoting the minimal TimeFrame of each argument of the function.
* We partially map each FunctionId to a Fuel.
* Every time we try to inline a Function,
* we make sure it either does not have a vector<Time>, which means this is the initial call,
* or some argument has a lesser time, which means some earlier argument is passed in.
* In any case, we remap the mapping to a minimal vector<Time> across all previous invocations
* we make sure it either does not have a Fuel,
* or we meet the existing fuel with the fuel calculated from the argument.
* If no progress is made, we do not inline.
* In both case, we remap the mapping to the new Fuel
* when we PE inside the Function body.
* Termination is guaranteed because the creation time of at least one argument will decrease every call.
* Termination is guaranteed because Fuel is finitely descending - there can only be so many meet.
*/
std::unordered_map<Function, FuncId, NodeHash, NodeEqual> func_map_;
std::unordered_map<FuncId, std::vector<Time> > time_map_;
std::unordered_map<FuncId, Fuel> fuel_map_;
Store store_;
DLContext context_ = CPUContext();
FInterpreter executor_ = CPUInterpreter();
......
......@@ -68,7 +68,7 @@ class GNF : public ExprMutator {
}
Expr VisitExpr_(const LetNode* ln) override {
var_map_.insert(std::pair<Var, Expr>(ln->var, VisitExpr(WrapRec(ln->var, ln->value))));
var_map_.insert(std::pair<Var, Expr>(ln->var, WrapRec(ln->var, VisitExpr(ln->value))));
return VisitExpr(ln->body);
}
};
......
......@@ -19,7 +19,7 @@ from nose.tools import nottest
import tvm
from tvm import relay
from tvm.relay import Function, transform
from tvm.relay.analysis import alpha_equal, graph_equal, free_vars
from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal
from tvm.relay.op import log, add, equal, subtract
......@@ -65,11 +65,10 @@ def test_used_let():
expected = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
@nottest
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
orig = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
assert_alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
def test_chain_unused_let():
......@@ -78,6 +77,17 @@ def test_chain_unused_let():
assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
def use_f(func):
f = relay.Var("f")
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, relay.const(0)),
data,
relay.Call(f, [subtract(n, relay.const(1)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, [])
return relay.Let(f, value, func(f))
# make sure we dont infinite loop
def test_recursion():
"""
......@@ -91,21 +101,15 @@ def test_recursion():
}
f(2, 10000);
"""
f = relay.Var("f")
f1 = relay.Var("f1")
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, relay.const(0)),
data,
relay.Call(f1, [subtract(n, relay.const(1)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
dced = run_opt_pass(orig, transform.DeadCodeElimination())
orig = run_opt_pass(orig, transform.InferType())
assert graph_equal(dced, orig)
dced = run_opt_pass(relay.Let(f, value, e.three),
transform.DeadCodeElimination())
assert_alpha_equal(dced, orig)
def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three)
dced_f = lambda f: x
dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
assert alpha_equal(dced, e.three)
......@@ -133,5 +137,6 @@ if __name__ == "__main__":
test_inline()
test_chain_unused_let()
test_recursion()
test_recursion_dead()
test_op_let()
test_tuple_get_item()
......@@ -123,7 +123,7 @@ def test_ad():
body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body))
expected = run_opt_pass(expected, transform.InferType())
assert alpha_equal(g, expected)
assert_alpha_equal(g, expected)
def test_if_ref():
......@@ -311,7 +311,16 @@ def test_concat():
x = Var("x", t)
y = Var("x", t)
orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
assert_alpha_equal(orig, dcpe(orig))
assert_alpha_equal(dcpe(orig), orig)
def test_triangle():
t = relay.TensorType([], "int32")
x = Var("x", t)
f_var = Var("f")
f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x - const(1))))
orig = run_infer_type(Let(f_var, f, f_var(const(10))))
assert_alpha_equal(dcpe(orig), const(55))
if __name__ == '__main__':
......@@ -332,3 +341,4 @@ if __name__ == '__main__':
test_global_match_nat_id()
test_match_nat_id()
test_concat()
test_triangle()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment