Commit bb87f044 by 雾雨魔理沙 Committed by Tianqi Chen

add document (#2714)

lint

lint

save

save

add more case

save

error

lint

lint

commit

do

lint

save

fix lint

wrap it back as func

lint

save

remove dead comment

fix style

fix lint

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

address review feedback

pe now handle freevar. as a result preserving function is now trivial.

test

add basic test, implement pretty printing for generic function

test

lint

fix segfault

save

save

do

test

fix another error

address comment

commit

save

address review feedback

add test for invalidate, fix error in lookup

rename cont to boduy

fix error and add regression test

fix error, add test case

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

fix lint

remove extra line

save

save
parent 28f354bf
......@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
......
......@@ -64,7 +64,7 @@
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <string>
#include <vector>
......@@ -344,6 +344,17 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
/*! \brief Get all bound variables from pattern pat.
*
* Bound variables are all variables that got bound by the pat.
* They only have meaning inside that expr, and can only be used in it.
*
* \param pat the Pattern.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
/*! \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
......@@ -431,12 +442,13 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced, and branches that will
* not be entered.
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
* the expression does not depend on a. Another example is `if (true) then 1
* else 2` will be optimized into 1.
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param e the expression to optimize.
*
......@@ -558,6 +570,12 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
*/
TVM_DLL Expr ToGraphNormalForm(const Expr& e);
/*! \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*/
Expr PartialEval(const Expr& e);
} // namespace relay
} // namespace tvm
......
......@@ -89,6 +89,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
* \return The result of the call
*/
virtual R VisitPattern(const Pattern& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
......
......@@ -956,3 +956,20 @@ def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True):
A text representation of `ast`.
"""
return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf)
def partial_evaluate(expr):
"""
Evaluate the static fragment of the code.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
expr : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
......@@ -556,7 +556,7 @@ class Interpreter :
CHECK_NE(cvn->constructor->tag, -1);
if (op->constructor->tag == cvn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK(op->patterns.size() == cvn->fields.size());
CHECK_EQ(op->patterns.size(), cvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
return false;
......
......@@ -43,9 +43,6 @@ Expr ExprMutator::VisitExpr(const Expr& expr) {
}
Expr ExprMutator::VisitExpr_(const VarNode* op) {
// NOTE: var will only be mutated once
// Thanks to the memo and reused during rewriting if necessary.
// It is safe to assume that the
if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) {
......
......@@ -245,6 +245,46 @@ class PrettyPrinter :
return Doc(unique_prefix);
}
Doc Print(Kind k) {
switch (k) {
case kType:
return Doc("Type");
case kShapeVar:
return Doc("Shape");
case kBaseType:
return Doc("BaseType");
case kConstraint:
return Doc("Constraint");
case kAdtHandle:
return Doc("AdtHandle");
case kTypeData:
return Doc("TypeData");
default:
LOG(ERROR) << "Unknown Kind";
throw;
}
}
/*!
* \brief Allocate name to a type variable.
* \param var The input type variable.
* \return The corresponding name.
*/
Doc AllocTypeVar(const TypeVar& var) {
std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
if (memo_type_.count(var)) {
val << "-malformed-ir";
}
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
}
return val;
}
/*!
* \brief Allocate name to a variable.
* \param var The input variable.
......@@ -253,7 +293,7 @@ class PrettyPrinter :
Doc AllocVar(const Var& var) {
std::string name = var->name_hint();
// always make sure first name is alpha
if (name.length() != 0 && !std::isalpha(name[0])) {
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
......@@ -387,12 +427,18 @@ class PrettyPrinter :
}
Doc PrintFunc(const Doc& prefix, const Function& fn) {
// TODO(tqchen, M.K.) support generic function
// Possibly through meta data
CHECK_EQ(fn->type_params.size(), 0U)
<< "generic fn not yet supported";
Doc doc;
doc << prefix << "(";
doc << prefix;
if (fn->type_params.size() > 0) {
doc << "<";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
}
doc << PrintVec(type_params);
doc << ">";
}
doc << "(";
std::vector<Doc> params;
for (Var param : fn->params) {
params.push_back(AllocVar(param));
......@@ -516,6 +562,10 @@ class PrettyPrinter :
return Print(GetRef<NodeRef>(node), true);
}
Doc VisitType_(const TypeVarNode* node) final {
return AllocTypeVar(GetRef<TypeVar>(node));
}
Doc VisitType_(const TensorTypeNode* node) final {
// scalar type
if (node->shape.size() == 0) {
......
......@@ -77,6 +77,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
* \return The result of the call
*/
virtual R VisitType(const Type& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
......
......@@ -35,90 +35,109 @@
namespace tvm {
namespace relay {
bool IsBoolLit(const Expr& e, bool b) {
if (const ConstantNode* c = e.as<ConstantNode>()) {
if (c->is_scalar()) {
auto dt = c->tensor_type()->dtype;
if (dt == Bool()) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(8)) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(16)) {
return *reinterpret_cast<const uint16_t*>(c->data->data) == b;
} else if (dt == UInt(32)) {
return *reinterpret_cast<const uint32_t*>(c->data->data) == b;
} else if (dt == UInt(64)) {
return *reinterpret_cast<const uint64_t*>(c->data->data) == b;
} else if (dt == Int(8)) {
return *reinterpret_cast<const int8_t*>(c->data->data) == b;
} else if (dt == Int(16)) {
return *reinterpret_cast<const int16_t*>(c->data->data) == b;
} else if (dt == Int(32)) {
return *reinterpret_cast<const int32_t*>(c->data->data) == b;
} else if (dt == Int(64)) {
return *reinterpret_cast<const int64_t*>(c->data->data) == b;
}
}
}
return false;
}
// calculate the dependency graph from expression
class CalcDep : private ExprMutator {
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e) {
CalcDep cd;
auto res = cd(e);
GenLet gl(cd.var_map_);
gl(res);
return gl.lets_.Get(res);
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
return el(e);
}
private:
using VarMap = std::unordered_map<Var, Expr, NodeHash, NodeEqual>;
VarMap var_map_;
Expr VisitExpr_(const IfNode* i) final {
auto cond = VisitExpr(i->cond);
if (IsBoolLit(cond, true)) {
return Eliminate(i->true_branch);
} else if (IsBoolLit(cond, false)) {
return Eliminate(i->false_branch);
} else {
return IfNode::make(cond, Eliminate(i->true_branch), Eliminate(i->false_branch));
template<typename X>
using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool count_ = true;
VarSet dead_worklist_;
VarSet current_letrec_;
void LetRec(const std::function<void()>& func, const Var& v) {
current_letrec_.insert(v);
func();
current_letrec_.erase(v);
}
void VisitExpr_(const LetNode* l) final {
if (count_) {
CHECK_EQ(expr_map_.count(l->var), 0);
CHECK_EQ(use_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
use_map_[l->var] = 0;
dead_worklist_.insert(l->var);
LetRec([&]() { VisitExpr(l->value); }, l->var);
}
VisitExpr(l->body);
}
Expr VisitExpr_(const LetNode* l) final {
var_map_[l->var] = Eliminate(l->value);
return VisitExpr(l->body);
void VisitExpr(const Expr& e) final {
ExprFunctor<void(const Expr&)>::VisitExpr(e);
}
Expr VisitExpr_(const FunctionNode* f) final {
return FunctionNode::make(f->params,
Eliminate(f->body),
f->ret_type,
f->type_params);
void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
if (expr_map_.count(var) == 0) {
return;
}
if (current_letrec_.count(var) == 0) {
if (count_) {
use_map_[var] += 1;
dead_worklist_.erase(var);
} else {
CHECK_GT(use_map_[var], 0) << var;
use_map_[var] -= 1;
if (use_map_[var] == 0) {
dead_worklist_.insert(var);
}
}
} else {
letrec_set_.insert(var);
}
}
void Calculate(const Expr& v) {
VisitExpr(v);
count_ = false;
while (!dead_worklist_.empty()) {
Var dead = *(dead_worklist_.begin());
dead_worklist_.erase(dead);
CHECK_EQ(use_map_[dead], 0);
if (expr_map_.count(dead) > 0) {
LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead);
}
}
}
// generate the let list from dependency graph
class GenLet : private ExprVisitor {
class Eliminator : private ExprMutator {
private:
LetList lets_;
VarMap var_map_;
explicit GenLet(const VarMap& var_map) : var_map_(var_map) { }
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
friend CalcDep;
void VisitExpr_(const VarNode* vnode) final {
Var v = GetRef<Var>(vnode);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
Expr expr = it->second;
var_map_.erase(it);
// erase before visit to handle letrec
VisitExpr(expr);
// visit before push back so the dependency of dependency is before the dependency
lets_.Push(v, expr);
bool HasLet(const Var& v) {
return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
}
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
}
Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
}
};
......
......@@ -42,7 +42,6 @@ namespace relay {
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body);
/*!
* \brief Check if expr is positive constant.
* \param expr The expression to be checked.
......@@ -50,7 +49,6 @@ GetExprRefCount(const Expr& body);
*/
bool IsAllPositiveConstant(const Expr& expr);
/*!
* \brief Substitute var with subst.
* \param type The type to be substituted.
......@@ -61,6 +59,15 @@ bool IsAllPositiveConstant(const Expr& expr);
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);
/*!
* \brief Substitute var with subst.
* \param expr The expr to be substituted.
* \param tvar The type variable to be substituted.
* \param subst The target of substitution.
* \return The substituted result.
*/
Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst);
/*!
* \brief Substitute type vars in type.
* \param type The type to be substituted.
* \param subst_map The map of substitution.
......@@ -68,6 +75,28 @@ Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);
*/
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);
/*!
* \brief Substitute type vars in type.
* \param expr The expr to be substituted.
* \param subst_map The map of substitution.
* \return The substituted result.
*/
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
/*!
* \brief Make arbitrary transformation preserve the out most function.
* \param func The transformation.
* \param e The expression
* \return the transformed expression. If e is a function the return is also a function.
*/
inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) {
if (const FunctionNode* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params, func(f->body), f->ret_type, f->type_params, f->attrs);
} else {
return func(e);
}
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_
......@@ -28,6 +28,7 @@
#include <tvm/relay/expr_functor.h>
#include "let_list.h"
#include "../../common/arena.h"
#include "pass_util.h"
namespace tvm {
namespace relay {
......@@ -481,15 +482,7 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
}
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
if (const auto* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params,
ToANormalFormAux(f->body, m, gv),
f->ret_type,
f->type_params,
f->attrs);
} else {
return ToANormalFormAux(e, m, gv);
}
return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e);
}
Expr ToANormalForm(const Expr& e, const Module& m) {
......
......@@ -27,6 +27,7 @@
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "../ir/type_functor.h"
namespace tvm {
......@@ -171,8 +172,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
return ret;
}
Array<Var> Bound(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> Collect() {
Array<Var> ret;
for (const auto& v : bound_vars_.data) {
ret.push_back(v);
......@@ -180,6 +180,16 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
return ret;
}
Array<Var> Bound(const Expr& expr) {
this->VisitExpr(expr);
return Collect();
}
Array<Var> Bound(const Pattern& pat) {
this->VisitPattern(pat);
return Collect();
}
Array<Var> All(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> ret;
......@@ -256,6 +266,10 @@ tvm::Array<Var> BoundVars(const Expr& expr) {
return VarVisitor().Bound(expr);
}
tvm::Array<Var> BoundVars(const Pattern& pat) {
return VarVisitor().Bound(pat);
}
tvm::Array<Var> AllVars(const Expr& expr) {
return VarVisitor().All(expr);
}
......@@ -267,7 +281,12 @@ TVM_REGISTER_API("relay._ir_pass.free_vars")
TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BoundVars(args[0]);
NodeRef x = args[0];
if (x.as_derived<ExprNode>()) {
*ret = BoundVars(Downcast<Expr>(x));
} else {
*ret = BoundVars(Downcast<Pattern>(x));
}
});
TVM_REGISTER_API("relay._ir_pass.all_vars")
......@@ -388,5 +407,33 @@ bool IsAllPositiveConstant(const Expr& expr) {
}
}
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}}));
}
Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
return TypeSubst(expr, tvm::Map<TypeVar, Type>({{tvar, subst}}));
}
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
return Bind(type, subst_map);
}
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
class TypeSubstMutator : public ExprMutator, public PatternMutator {
public:
explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) { }
Type VisitType(const Type& t) final {
return TypeSubst(t, subst_map_);
}
Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
}
private:
const tvm::Map<TypeVar, Type>& subst_map_;
};
return TypeSubstMutator(subst_map).VisitExpr(expr);
}
} // namespace relay
} // namespace tvm
......@@ -48,8 +48,13 @@ def test_let():
def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c))
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), e.d)
def test_chain_unused_let():
......@@ -87,13 +92,6 @@ def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
def test_if():
cond = relay.const(True)
orig = relay.If(cond, e.a, e.b)
y = dead_code_elimination(orig)
assert alpha_equal(y, e.a)
def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
......@@ -102,9 +100,9 @@ def test_tuple_get_item():
if __name__ == "__main__":
test_if()
test_let()
test_used_let()
test_inline()
test_chain_unused_let()
test_recursion()
test_op_let()
......
......@@ -22,9 +22,11 @@ from tvm.relay.prelude import Prelude
import numpy as np
def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))
def test_id():
shape = (10, 10)
dtype = 'float32'
......
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
from tvm.relay import create_executor
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
result = intrp.evaluate(expr)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def dcpe(expr):
return dead_code_elimination(partial_evaluate(expr))
def test_tuple():
t = relay.TypeVar("t")
x = relay.Var("x", t)
body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
f = relay.Function([x], body, None, [t])
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
def test_const_inline():
d = relay.Var("d")
double = relay.Function([d], d + d)
orig = double(relay.const(4.0))
assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
def test_ref():
d = relay.Var("d")
r = relay.Var("r")
x = relay.Var("x")
body = relay.RefRead(r)
body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
body = relay.Let(r, relay.RefCreate(d), body)
square = relay.Function([d], body)
assert alpha_equal(dcpe(square), relay.Function([d], d * d))
def test_ad():
shape = (10, 10)
dtype = "float32"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
f = relay.Function([d], d * d)
g = dcpe(gradient(f))
m = d * d
o = relay.op.ones_like(m)
grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
assert alpha_equal(g, expected)
def test_if_ref():
shape = ()
dtype = "bool"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
r = relay.Var("r")
update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
u = relay.Var("u")
body = relay.If(d, u(), u())
eff = relay.Var("eff")
body = relay.Let(eff, body, relay.RefRead(r))
f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body)))
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor()
f_res = ex.evaluate(f)(relay.const(True))
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
def test_function_invalidate():
shape = ()
dtype = "bool"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
r = relay.Var("r")
fetch = relay.Function([], relay.RefRead(r))
fet = relay.Var("fetch")
fet_obscured = relay.Var("fetch_obscured")
u = relay.Var("u")
body = relay.If(d, fet_obscured(), fet_obscured())
body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body)
body = relay.Let(fet_obscured, relay.If(d, fet, fet), body)
body = relay.Let(fet, fetch, body)
body = relay.Let(r, relay.RefCreate(relay.const(0)), body)
f = relay.Function([d], body)
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor()
f_res = ex.evaluate(f)(relay.const(True))
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
def test_head_cons():
mod = relay.Module()
p = Prelude(mod)
def hd_impl():
a = relay.TypeVar("a")
x = relay.Var("x", p.l(a))
y = relay.Var("y")
z = relay.Var("z")
cons_case = relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(y),
relay.PatternVar(z)]),
y)
return relay.Function([x], relay.Match(x, [cons_case]), a, [a])
t = relay.TypeVar("t")
x = relay.Var("x", t)
hd = relay.Var("hd")
body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
f = relay.Function([x], body, None, [t])
f = infer_type(f, mod=mod)
res = dcpe(f)
assert alpha_equal(res, relay.Function([x], x, t, [t]))
if __name__ == '__main__':
test_tuple()
test_const_inline()
test_ref()
test_ad()
test_if_ref()
test_function_invalidate()
test_head_cons()
......@@ -154,6 +154,7 @@ def test_add():
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()
def test_let():
x = relay.Var("x")
y = relay.Var("y")
......@@ -163,6 +164,17 @@ def test_let():
check_eval(body, 8)
check_eval(to_a_normal_form(body), 8)
def test_function():
x = relay.Var("x")
f = relay.Function([x], x + x)
d = relay.const(4.0, 'float32')
anf_f = to_a_normal_form(f)
assert isinstance(anf_f, relay.Function)
check_eval(f(d), 8)
check_eval(anf_f(d), 8)
if __name__ == '__main__':
test_explicit_bound()
test_order()
......@@ -171,3 +183,4 @@ if __name__ == '__main__':
test_ref()
test_add()
test_let()
test_function()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment