Commit bb87f044 by 雾雨魔理沙 Committed by Tianqi Chen

add document (#2714)

lint

lint

save

save

add more case

save

error

lint

lint

commit

do

lint

save

fix lint

wrap it back as func

lint

save

remove dead comment

fix style

fix lint

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

address review feedback

pe now handle freevar. as a result preserving function is now trivial.

test

add basic test, implement pretty printing for generic function

test

lint

fix segfault

save

save

do

test

fix another error

address comment

commit

save

address review feedback

add test for invalidate, fix error in lookup

rename cont to boduy

fix error and add regression test

fix error, add test case

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

fix lint

remove extra line

save

save
parent 28f354bf
...@@ -570,8 +570,8 @@ inline const TTypeNode* ExprNode::type_as() const { ...@@ -570,8 +570,8 @@ inline const TTypeNode* ExprNode::type_as() const {
* \return The text representation. * \return The text representation.
*/ */
std::string AsText(const NodeRef& node, std::string AsText(const NodeRef& node,
bool show_meta_data = true, bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr); runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_EXPR_H_ #endif // TVM_RELAY_EXPR_H_
...@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
* \return The result of the call * \return The result of the call
*/ */
virtual R VisitExpr(const Expr& n, Args... args) { virtual R VisitExpr(const Expr& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable(); static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...); return vtable(n, this, std::forward<Args>(args)...);
} }
......
...@@ -64,7 +64,7 @@ ...@@ -64,7 +64,7 @@
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -344,6 +344,17 @@ TVM_DLL bool WellFormed(const Expr& expr); ...@@ -344,6 +344,17 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/ */
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr); TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
/*! \brief Get all bound variables from pattern pat.
*
* Bound variables are all variables that got bound by the pat.
* They only have meaning inside that expr, and can only be used in it.
*
* \param pat the Pattern.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
/*! \brief Get free type parameters from expression expr. /*! \brief Get free type parameters from expression expr.
* *
* Free variables are variables that are not bound by a * Free variables are variables that are not bound by a
...@@ -431,12 +442,13 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod); ...@@ -431,12 +442,13 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
/*! \brief Remove expressions which does not effect the program result. /*! \brief Remove expressions which does not effect the program result.
* *
* It will remove let bindings which are not referenced, and branches that will * It will remove let bindings which are not referenced,
* not be entered. * and inline let bindings that are only used once.
* *
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of * For example, this pass should turn `let a = 1 in 2` into `2`,
* the expression does not depend on a. Another example is `if (true) then 1 * as the value of the expression does not depend on a.
* else 2` will be optimized into 1. *
* As another example, `let a = 1 in a` will be optimized into 1.
* *
* \param e the expression to optimize. * \param e the expression to optimize.
* *
...@@ -558,6 +570,12 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); ...@@ -558,6 +570,12 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
*/ */
TVM_DLL Expr ToGraphNormalForm(const Expr& e); TVM_DLL Expr ToGraphNormalForm(const Expr& e);
/*! \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*/
Expr PartialEval(const Expr& e);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -89,6 +89,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> { ...@@ -89,6 +89,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
* \return The result of the call * \return The result of the call
*/ */
virtual R VisitPattern(const Pattern& n, Args... args) { virtual R VisitPattern(const Pattern& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable(); static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...); return vtable(n, this, std::forward<Args>(args)...);
} }
......
...@@ -956,3 +956,20 @@ def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True): ...@@ -956,3 +956,20 @@ def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True):
A text representation of `ast`. A text representation of `ast`.
""" """
return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf) return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf)
def partial_evaluate(expr):
"""
Evaluate the static fragment of the code.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
expr : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
...@@ -556,7 +556,7 @@ class Interpreter : ...@@ -556,7 +556,7 @@ class Interpreter :
CHECK_NE(cvn->constructor->tag, -1); CHECK_NE(cvn->constructor->tag, -1);
if (op->constructor->tag == cvn->constructor->tag) { if (op->constructor->tag == cvn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken // todo(M.K.): should use ptr equality but it is broken
CHECK(op->patterns.size() == cvn->fields.size()); CHECK_EQ(op->patterns.size(), cvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) { for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], cvn->fields[i])) { if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
return false; return false;
......
...@@ -43,9 +43,6 @@ Expr ExprMutator::VisitExpr(const Expr& expr) { ...@@ -43,9 +43,6 @@ Expr ExprMutator::VisitExpr(const Expr& expr) {
} }
Expr ExprMutator::VisitExpr_(const VarNode* op) { Expr ExprMutator::VisitExpr_(const VarNode* op) {
// NOTE: var will only be mutated once
// Thanks to the memo and reused during rewriting if necessary.
// It is safe to assume that the
if (op->type_annotation.defined()) { if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation); auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) { if (!op->type_annotation.same_as(type)) {
......
...@@ -245,15 +245,55 @@ class PrettyPrinter : ...@@ -245,15 +245,55 @@ class PrettyPrinter :
return Doc(unique_prefix); return Doc(unique_prefix);
} }
Doc Print(Kind k) {
switch (k) {
case kType:
return Doc("Type");
case kShapeVar:
return Doc("Shape");
case kBaseType:
return Doc("BaseType");
case kConstraint:
return Doc("Constraint");
case kAdtHandle:
return Doc("AdtHandle");
case kTypeData:
return Doc("TypeData");
default:
LOG(ERROR) << "Unknown Kind";
throw;
}
}
/*! /*!
* \brief Allocate name to a variable. * \brief Allocate name to a type variable.
* \param var The input variable. * \param var The input type variable.
* \return The corresponding name. * \return The corresponding name.
*/ */
Doc AllocTypeVar(const TypeVar& var) {
std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
if (memo_type_.count(var)) {
val << "-malformed-ir";
}
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
}
return val;
}
/*!
* \brief Allocate name to a variable.
* \param var The input variable.
* \return The corresponding name.
*/
Doc AllocVar(const Var& var) { Doc AllocVar(const Var& var) {
std::string name = var->name_hint(); std::string name = var->name_hint();
// always make sure first name is alpha // always make sure first name is alpha
if (name.length() != 0 && !std::isalpha(name[0])) { if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name; name = "v" + name;
} }
Doc val = GetUniqueName("%" + name); Doc val = GetUniqueName("%" + name);
...@@ -387,12 +427,18 @@ class PrettyPrinter : ...@@ -387,12 +427,18 @@ class PrettyPrinter :
} }
Doc PrintFunc(const Doc& prefix, const Function& fn) { Doc PrintFunc(const Doc& prefix, const Function& fn) {
// TODO(tqchen, M.K.) support generic function
// Possibly through meta data
CHECK_EQ(fn->type_params.size(), 0U)
<< "generic fn not yet supported";
Doc doc; Doc doc;
doc << prefix << "("; doc << prefix;
if (fn->type_params.size() > 0) {
doc << "<";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
}
doc << PrintVec(type_params);
doc << ">";
}
doc << "(";
std::vector<Doc> params; std::vector<Doc> params;
for (Var param : fn->params) { for (Var param : fn->params) {
params.push_back(AllocVar(param)); params.push_back(AllocVar(param));
...@@ -516,6 +562,10 @@ class PrettyPrinter : ...@@ -516,6 +562,10 @@ class PrettyPrinter :
return Print(GetRef<NodeRef>(node), true); return Print(GetRef<NodeRef>(node), true);
} }
Doc VisitType_(const TypeVarNode* node) final {
return AllocTypeVar(GetRef<TypeVar>(node));
}
Doc VisitType_(const TensorTypeNode* node) final { Doc VisitType_(const TensorTypeNode* node) final {
// scalar type // scalar type
if (node->shape.size() == 0) { if (node->shape.size() == 0) {
......
...@@ -77,6 +77,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -77,6 +77,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
* \return The result of the call * \return The result of the call
*/ */
virtual R VisitType(const Type& n, Args... args) { virtual R VisitType(const Type& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable(); static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...); return vtable(n, this, std::forward<Args>(args)...);
} }
......
...@@ -35,90 +35,109 @@ ...@@ -35,90 +35,109 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
bool IsBoolLit(const Expr& e, bool b) {
if (const ConstantNode* c = e.as<ConstantNode>()) {
if (c->is_scalar()) {
auto dt = c->tensor_type()->dtype;
if (dt == Bool()) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(8)) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(16)) {
return *reinterpret_cast<const uint16_t*>(c->data->data) == b;
} else if (dt == UInt(32)) {
return *reinterpret_cast<const uint32_t*>(c->data->data) == b;
} else if (dt == UInt(64)) {
return *reinterpret_cast<const uint64_t*>(c->data->data) == b;
} else if (dt == Int(8)) {
return *reinterpret_cast<const int8_t*>(c->data->data) == b;
} else if (dt == Int(16)) {
return *reinterpret_cast<const int16_t*>(c->data->data) == b;
} else if (dt == Int(32)) {
return *reinterpret_cast<const int32_t*>(c->data->data) == b;
} else if (dt == Int(64)) {
return *reinterpret_cast<const int64_t*>(c->data->data) == b;
}
}
}
return false;
}
// calculate the dependency graph from expression // calculate the dependency graph from expression
class CalcDep : private ExprMutator { class CalcDep : private ExprVisitor {
public: public:
static Expr Eliminate(const Expr& e) { static Expr Eliminate(const Expr& e) {
CalcDep cd; CalcDep cd;
auto res = cd(e); cd.Calculate(e);
GenLet gl(cd.var_map_); Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
gl(res); return el(e);
return gl.lets_.Get(res);
} }
private: private:
using VarMap = std::unordered_map<Var, Expr, NodeHash, NodeEqual>; template<typename X>
VarMap var_map_; using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
Expr VisitExpr_(const IfNode* i) final { VarMap<Expr> expr_map_;
auto cond = VisitExpr(i->cond); VarMap<size_t> use_map_;
if (IsBoolLit(cond, true)) { VarSet letrec_set_;
return Eliminate(i->true_branch); bool count_ = true;
} else if (IsBoolLit(cond, false)) { VarSet dead_worklist_;
return Eliminate(i->false_branch); VarSet current_letrec_;
} else {
return IfNode::make(cond, Eliminate(i->true_branch), Eliminate(i->false_branch)); void LetRec(const std::function<void()>& func, const Var& v) {
current_letrec_.insert(v);
func();
current_letrec_.erase(v);
}
void VisitExpr_(const LetNode* l) final {
if (count_) {
CHECK_EQ(expr_map_.count(l->var), 0);
CHECK_EQ(use_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
use_map_[l->var] = 0;
dead_worklist_.insert(l->var);
LetRec([&]() { VisitExpr(l->value); }, l->var);
} }
VisitExpr(l->body);
} }
Expr VisitExpr_(const LetNode* l) final { void VisitExpr(const Expr& e) final {
var_map_[l->var] = Eliminate(l->value); ExprFunctor<void(const Expr&)>::VisitExpr(e);
return VisitExpr(l->body);
} }
Expr VisitExpr_(const FunctionNode* f) final { void VisitExpr_(const VarNode* v) final {
return FunctionNode::make(f->params, Var var = GetRef<Var>(v);
Eliminate(f->body), if (expr_map_.count(var) == 0) {
f->ret_type, return;
f->type_params); }
if (current_letrec_.count(var) == 0) {
if (count_) {
use_map_[var] += 1;
dead_worklist_.erase(var);
} else {
CHECK_GT(use_map_[var], 0) << var;
use_map_[var] -= 1;
if (use_map_[var] == 0) {
dead_worklist_.insert(var);
}
}
} else {
letrec_set_.insert(var);
}
}
void Calculate(const Expr& v) {
VisitExpr(v);
count_ = false;
while (!dead_worklist_.empty()) {
Var dead = *(dead_worklist_.begin());
dead_worklist_.erase(dead);
CHECK_EQ(use_map_[dead], 0);
if (expr_map_.count(dead) > 0) {
LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead);
}
}
} }
// generate the let list from dependency graph class Eliminator : private ExprMutator {
class GenLet : private ExprVisitor {
private: private:
LetList lets_; VarMap<Expr> expr_map_;
VarMap var_map_; VarMap<size_t> use_map_;
explicit GenLet(const VarMap& var_map) : var_map_(var_map) { } VarSet letrec_set_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
friend CalcDep; friend CalcDep;
void VisitExpr_(const VarNode* vnode) final { bool HasLet(const Var& v) {
Var v = GetRef<Var>(vnode); return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
auto it = var_map_.find(v); }
if (it != var_map_.end()) {
Expr expr = it->second; Expr VisitExpr_(const VarNode* op) final {
var_map_.erase(it); Var v = GetRef<Var>(op);
// erase before visit to handle letrec return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
VisitExpr(expr); }
// visit before push back so the dependency of dependency is before the dependency
lets_.Push(v, expr); Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
} }
} }
}; };
......
...@@ -42,7 +42,6 @@ namespace relay { ...@@ -42,7 +42,6 @@ namespace relay {
std::unordered_map<const Node*, size_t> std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body); GetExprRefCount(const Expr& body);
/*! /*!
* \brief Check if expr is positive constant. * \brief Check if expr is positive constant.
* \param expr The expression to be checked. * \param expr The expression to be checked.
...@@ -50,7 +49,6 @@ GetExprRefCount(const Expr& body); ...@@ -50,7 +49,6 @@ GetExprRefCount(const Expr& body);
*/ */
bool IsAllPositiveConstant(const Expr& expr); bool IsAllPositiveConstant(const Expr& expr);
/*! /*!
* \brief Substitute var with subst. * \brief Substitute var with subst.
* \param type The type to be substituted. * \param type The type to be substituted.
...@@ -61,6 +59,15 @@ bool IsAllPositiveConstant(const Expr& expr); ...@@ -61,6 +59,15 @@ bool IsAllPositiveConstant(const Expr& expr);
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst); Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);
/*! /*!
* \brief Substitute var with subst.
* \param expr The expr to be substituted.
* \param tvar The type variable to be substituted.
* \param subst The target of substitution.
* \return The substituted result.
*/
Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst);
/*!
* \brief Substitute type vars in type. * \brief Substitute type vars in type.
* \param type The type to be substituted. * \param type The type to be substituted.
* \param subst_map The map of substitution. * \param subst_map The map of substitution.
...@@ -68,6 +75,28 @@ Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst); ...@@ -68,6 +75,28 @@ Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);
*/ */
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map); Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);
/*!
* \brief Substitute type vars in type.
* \param expr The expr to be substituted.
* \param subst_map The map of substitution.
* \return The substituted result.
*/
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
/*!
* \brief Make arbitrary transformation preserve the out most function.
* \param func The transformation.
* \param e The expression
* \return the transformed expression. If e is a function the return is also a function.
*/
inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) {
if (const FunctionNode* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params, func(f->body), f->ret_type, f->type_params, f->attrs);
} else {
return func(e);
}
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_ #endif // TVM_RELAY_PASS_PASS_UTIL_H_
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "let_list.h" #include "let_list.h"
#include "../../common/arena.h" #include "../../common/arena.h"
#include "pass_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -481,15 +482,7 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) { ...@@ -481,15 +482,7 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
} }
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) { Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
if (const auto* f = e.as<FunctionNode>()) { return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e);
return FunctionNode::make(f->params,
ToANormalFormAux(f->body, m, gv),
f->ret_type,
f->type_params,
f->attrs);
} else {
return ToANormalFormAux(e, m, gv);
}
} }
Expr ToANormalForm(const Expr& e, const Module& m) { Expr ToANormalForm(const Expr& e, const Module& m) {
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "../ir/type_functor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
...@@ -171,8 +172,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { ...@@ -171,8 +172,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
return ret; return ret;
} }
Array<Var> Bound(const Expr& expr) { Array<Var> Collect() {
this->VisitExpr(expr);
Array<Var> ret; Array<Var> ret;
for (const auto& v : bound_vars_.data) { for (const auto& v : bound_vars_.data) {
ret.push_back(v); ret.push_back(v);
...@@ -180,6 +180,16 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { ...@@ -180,6 +180,16 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
return ret; return ret;
} }
Array<Var> Bound(const Expr& expr) {
this->VisitExpr(expr);
return Collect();
}
Array<Var> Bound(const Pattern& pat) {
this->VisitPattern(pat);
return Collect();
}
Array<Var> All(const Expr& expr) { Array<Var> All(const Expr& expr) {
this->VisitExpr(expr); this->VisitExpr(expr);
Array<Var> ret; Array<Var> ret;
...@@ -256,6 +266,10 @@ tvm::Array<Var> BoundVars(const Expr& expr) { ...@@ -256,6 +266,10 @@ tvm::Array<Var> BoundVars(const Expr& expr) {
return VarVisitor().Bound(expr); return VarVisitor().Bound(expr);
} }
tvm::Array<Var> BoundVars(const Pattern& pat) {
return VarVisitor().Bound(pat);
}
tvm::Array<Var> AllVars(const Expr& expr) { tvm::Array<Var> AllVars(const Expr& expr) {
return VarVisitor().All(expr); return VarVisitor().All(expr);
} }
...@@ -267,7 +281,12 @@ TVM_REGISTER_API("relay._ir_pass.free_vars") ...@@ -267,7 +281,12 @@ TVM_REGISTER_API("relay._ir_pass.free_vars")
TVM_REGISTER_API("relay._ir_pass.bound_vars") TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BoundVars(args[0]); NodeRef x = args[0];
if (x.as_derived<ExprNode>()) {
*ret = BoundVars(Downcast<Expr>(x));
} else {
*ret = BoundVars(Downcast<Pattern>(x));
}
}); });
TVM_REGISTER_API("relay._ir_pass.all_vars") TVM_REGISTER_API("relay._ir_pass.all_vars")
...@@ -388,5 +407,33 @@ bool IsAllPositiveConstant(const Expr& expr) { ...@@ -388,5 +407,33 @@ bool IsAllPositiveConstant(const Expr& expr) {
} }
} }
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}}));
}
Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
return TypeSubst(expr, tvm::Map<TypeVar, Type>({{tvar, subst}}));
}
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
return Bind(type, subst_map);
}
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
class TypeSubstMutator : public ExprMutator, public PatternMutator {
public:
explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) { }
Type VisitType(const Type& t) final {
return TypeSubst(t, subst_map_);
}
Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
}
private:
const tvm::Map<TypeVar, Type>& subst_map_;
};
return TypeSubstMutator(subst_map).VisitExpr(expr);
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -48,8 +48,13 @@ def test_let(): ...@@ -48,8 +48,13 @@ def test_let():
def test_used_let(): def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c))
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c)) assert alpha_equal(dead_code_elimination(orig), e.d)
def test_chain_unused_let(): def test_chain_unused_let():
...@@ -87,13 +92,6 @@ def test_op_let(): ...@@ -87,13 +92,6 @@ def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two)) assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
def test_if():
cond = relay.const(True)
orig = relay.If(cond, e.a, e.b)
y = dead_code_elimination(orig)
assert alpha_equal(y, e.a)
def test_tuple_get_item(): def test_tuple_get_item():
t = relay.Var('t') t = relay.Var('t')
g = relay.TupleGetItem(t, 0) g = relay.TupleGetItem(t, 0)
...@@ -102,9 +100,9 @@ def test_tuple_get_item(): ...@@ -102,9 +100,9 @@ def test_tuple_get_item():
if __name__ == "__main__": if __name__ == "__main__":
test_if()
test_let() test_let()
test_used_let() test_used_let()
test_inline()
test_chain_unused_let() test_chain_unused_let()
test_recursion() test_recursion()
test_op_let() test_op_let()
......
...@@ -22,9 +22,11 @@ from tvm.relay.prelude import Prelude ...@@ -22,9 +22,11 @@ from tvm.relay.prelude import Prelude
import numpy as np import numpy as np
def rand(dtype='float32', *shape): def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype)) return tvm.nd.array(np.random.rand(*shape).astype(dtype))
def test_id(): def test_id():
shape = (10, 10) shape = (10, 10)
dtype = 'float32' dtype = 'float32'
......
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
from tvm.relay import create_executor
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
result = intrp.evaluate(expr)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def dcpe(expr):
return dead_code_elimination(partial_evaluate(expr))
def test_tuple():
t = relay.TypeVar("t")
x = relay.Var("x", t)
body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
f = relay.Function([x], body, None, [t])
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
def test_const_inline():
d = relay.Var("d")
double = relay.Function([d], d + d)
orig = double(relay.const(4.0))
assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
def test_ref():
d = relay.Var("d")
r = relay.Var("r")
x = relay.Var("x")
body = relay.RefRead(r)
body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
body = relay.Let(r, relay.RefCreate(d), body)
square = relay.Function([d], body)
assert alpha_equal(dcpe(square), relay.Function([d], d * d))
def test_ad():
shape = (10, 10)
dtype = "float32"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
f = relay.Function([d], d * d)
g = dcpe(gradient(f))
m = d * d
o = relay.op.ones_like(m)
grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
assert alpha_equal(g, expected)
def test_if_ref():
shape = ()
dtype = "bool"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
r = relay.Var("r")
update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
u = relay.Var("u")
body = relay.If(d, u(), u())
eff = relay.Var("eff")
body = relay.Let(eff, body, relay.RefRead(r))
f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body)))
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor()
f_res = ex.evaluate(f)(relay.const(True))
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
def test_function_invalidate():
shape = ()
dtype = "bool"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
r = relay.Var("r")
fetch = relay.Function([], relay.RefRead(r))
fet = relay.Var("fetch")
fet_obscured = relay.Var("fetch_obscured")
u = relay.Var("u")
body = relay.If(d, fet_obscured(), fet_obscured())
body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body)
body = relay.Let(fet_obscured, relay.If(d, fet, fet), body)
body = relay.Let(fet, fetch, body)
body = relay.Let(r, relay.RefCreate(relay.const(0)), body)
f = relay.Function([d], body)
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor()
f_res = ex.evaluate(f)(relay.const(True))
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
def test_head_cons():
mod = relay.Module()
p = Prelude(mod)
def hd_impl():
a = relay.TypeVar("a")
x = relay.Var("x", p.l(a))
y = relay.Var("y")
z = relay.Var("z")
cons_case = relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(y),
relay.PatternVar(z)]),
y)
return relay.Function([x], relay.Match(x, [cons_case]), a, [a])
t = relay.TypeVar("t")
x = relay.Var("x", t)
hd = relay.Var("hd")
body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
f = relay.Function([x], body, None, [t])
f = infer_type(f, mod=mod)
res = dcpe(f)
assert alpha_equal(res, relay.Function([x], x, t, [t]))
if __name__ == '__main__':
test_tuple()
test_const_inline()
test_ref()
test_ad()
test_if_ref()
test_function_invalidate()
test_head_cons()
...@@ -154,6 +154,7 @@ def test_add(): ...@@ -154,6 +154,7 @@ def test_add():
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext() assert "let" in mod[add].astext()
def test_let(): def test_let():
x = relay.Var("x") x = relay.Var("x")
y = relay.Var("y") y = relay.Var("y")
...@@ -163,6 +164,17 @@ def test_let(): ...@@ -163,6 +164,17 @@ def test_let():
check_eval(body, 8) check_eval(body, 8)
check_eval(to_a_normal_form(body), 8) check_eval(to_a_normal_form(body), 8)
def test_function():
x = relay.Var("x")
f = relay.Function([x], x + x)
d = relay.const(4.0, 'float32')
anf_f = to_a_normal_form(f)
assert isinstance(anf_f, relay.Function)
check_eval(f(d), 8)
check_eval(anf_f(d), 8)
if __name__ == '__main__': if __name__ == '__main__':
test_explicit_bound() test_explicit_bound()
test_order() test_order()
...@@ -171,3 +183,4 @@ if __name__ == '__main__': ...@@ -171,3 +183,4 @@ if __name__ == '__main__':
test_ref() test_ref()
test_add() test_add()
test_let() test_let()
test_function()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment