Commit d1eb1229 by 雾雨魔理沙 Committed by Wuwei Lin

[Relay] Continuation Passing Style (#3456)

* save

add

me find type checker problem

save

save

lint

do

lint

reset ti

add some doc

add failed test case

add recursion for cps

add recursion for cps

fix pytest

lint

save

fix test error

lint

save

fix error

* fix rebase

* fix

* fix test

* lint

* lint

* restore rewriteannotationops

* do
parent 988ea2ac
......@@ -252,17 +252,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
/*!
* \brief Rewrite the annotated program.
*
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
*
* \return The updated program.
*/
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
......
......@@ -405,6 +405,22 @@ TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
TVM_DLL Pass ToANormalForm();
/*!
* \brief Turn an expression into continuation passing style(CPS).
*
* CPS mean that every function will, instead of returning the result directly,
* be passed down an extra function (called the continuation) as argument,
* and pass the result to the continuation instead.
*
* Thus, every function call has to be passed an extra argument
* that represent the rest of the computation (Hence the name of continuation).
*
* Similarly, all other compute will be wrapped and call the continuation as well.
*
* \return the pass.
*/
TVM_DLL Pass ToCPS();
/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
......@@ -586,6 +602,57 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
* \brief Rewrite the annotated program.
*
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
*
* \return The updated program.
*/
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
/*!
* \brief Turn an expression into continuation passing style(CPS).
*
* CPS mean that every function will, instead of returning the result directly,
* be passed down an extra function (called the continuation) as argument,
* and pass the result to the continuation instead.
*
* Thus, every function call has to be passed an extra argument
* that represent the rest of the computation (Hence the name of continuation).
*
* Similarly, all other compute will be wrapped and call the continuation as well.
*
* \param f the function.
* \param mod the module.
*
* \return the converted Function.
*/
TVM_DLL Function ToCPS(const Function& f, const Module& mod);
/*!
* \brief Remove the continuation argument of a CPS function.
*
* Note that this only transform the type back into un-CPS form
* when there is no higher order input/output.
*
* \param f the function.
*
* \return the converted Function.
*/
TVM_DLL Function UnCPS(const Function& f);
/*!
* \brief Deduplicate the bound variables and type variables in the expression.
*
* \param e the expression.
*
* \return the deduplicated expression.
*/
TVM_DLL Expr DeDup(const Expr& e);
} // namespace relay
} // namespace tvm
......
......@@ -17,6 +17,9 @@
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs
import tvm.relay as relay
from tvm.relay import transform
from . import mlp
from . import resnet
from . import dqn
......@@ -32,3 +35,15 @@ from . import yolo_detection
from .config import ctx_list
from .init import create_workload
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr)
mod = opt_pass(mod)
entry = mod[mod.entry_func]
return entry if isinstance(expr, relay.Function) else entry.body
def run_infer_type(expr):
return run_opt_pass(expr, transform.InferType())
......@@ -446,6 +446,20 @@ def ToANormalForm():
return _transform.ToANormalForm()
def ToCPS(expr, mod=None):
"""
Turn expression into continuation passing style(CPS).
Every intermediate compute will be passed to a continuation.
Returns
-------
result: tvm.relay.Pass
The registered pass that transforms an expression into CPS.
"""
return _ir_pass.to_cps(expr, mod)
def EtaExpand():
"""Add abstraction over a function
......@@ -495,14 +509,6 @@ def PartialEvaluate():
expression is provided. Otherwise, it will rely on the pass manager to
carry out transformation.
Parameters
----------
expr : Optional[tvm.relay.Expr]
The input expression.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
ret: tvm.relay.Pass
......@@ -554,6 +560,48 @@ def gradient(expr, mod=None, mode='higher_order'):
raise Exception('unknown mode')
def to_cps(func, mod=None):
"""
Turn expression into CPS expression.
Every intermediate compute will be passed to a continuation.
Parameters
----------
func: tvm.relay.Function
The input function.
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
result: tvm.relay.Function
The output function.
"""
return _transform.to_cps(func, mod)
def un_cps(func):
"""
Turn an cps function into a Function without the continuation argument.
Note that this will not give the exact same interface as before cps:
If the input/output is higher order, they will still be in cps form.
Parameters
----------
func: tvm.relay.Function
The input function
Returns
-------
result: tvm.relay.Function
The output function
"""
return _transform.un_cps(func)
def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyModulePass(ModulePass):
......
......@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file src/tvm/ir/adt.cc
* \brief AST nodes for Relay algebraic data types (ADTs).
*/
......
......@@ -89,8 +89,9 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
}
void ModuleNode::Add(const GlobalVar& var,
const Function& func,
const Function& f,
bool update) {
Function func = Downcast<Function>(DeDup(f));
// Type check the item before we add it to the module.
auto mod = GetRef<Module>(this);
Function checked_func = InferType(func, mod, var);
......
......@@ -645,11 +645,21 @@ class PrettyPrinter :
Doc VisitType_(const FuncTypeNode* node) final {
Doc doc;
doc << "fn ";
if (node->type_params.size() != 0) {
doc << "<";
std::vector<Doc> type_params;
for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param));
}
doc << PrintVec(type_params);
doc << ">";
}
std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) {
arg_types.push_back(Print(arg_type));
}
return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type);
return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type);
}
Doc VisitType_(const RefTypeNode* node) final {
......
......@@ -221,7 +221,7 @@ class TypeBinder : public TypeMutator {
};
Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
return TypeBinder(args_map).VisitType(type);
return type.defined() ? TypeBinder(args_map).VisitType(type) : type;
}
} // namespace relay
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
*
* \file de_duplicate.cc
* \brief Use a fresh Id for every Var to make the result well-formed.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/pattern_functor.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
Expr DeDup(const Expr& e) {
class DeDupMutator : public TypeMutator,
public ExprMutator,
public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
type_rename_[tv] = ret;
return ret;
}
Var Fresh(const Var& v) {
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret;
return ret;
}
Expr VisitExpr(const Expr& e) final {
return ExprMutator::VisitExpr(e);
}
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return rename_.count(v) != 0 ? rename_.at(v) : v;
}
Expr VisitExpr_(const LetNode* op) final {
Var v = Fresh(op->var);
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
}
Type VisitType(const Type& t) final {
return t.defined() ? TypeMutator::VisitType(t) : t;
}
Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<TypeVar> type_params;
for (const TypeVar& type_param : op->type_params) {
type_params.push_back(Fresh(type_param));
}
tvm::Array<Var> params;
for (const Var& param : op->params) {
params.push_back(Fresh(param));
}
return FunctionNode::make(params,
VisitExpr(op->body),
VisitType(op->ret_type),
type_params,
op->attrs);
}
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Pattern VisitPattern_(const PatternVarNode* op) final {
return PatternVarNode::make(Fresh(op->var));
}
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}
Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
}
Var VisitVar(const Var& v) final {
return Fresh(v);
}
private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
};
Expr ret = DeDupMutator().VisitExpr(e);
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
return ret;
}
TVM_REGISTER_API("relay._transform.dedup")
.set_body_typed(DeDup);
} // namespace relay
} // namespace tvm
......@@ -20,7 +20,7 @@
/*!
* Copyright (c) 2019 by Contributors.
* \file tvm/relay/pass/dependency_graph.h
* \brief
* \brief create a dependency graph.
*/
#ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
......
......@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file let_list.h
* \brief LetList record let binding and insert let expression implicitly.
* using it, one can treat AST as value instead of expression,
......@@ -46,6 +46,11 @@ namespace relay {
*/
class LetList {
public:
~LetList() {
if (lets_.size() > 0 && !used_) {
std::cout << "Warning: letlist not used" << std::endl;
}
}
/*!
* \brief insert a binding.
*
......@@ -64,13 +69,13 @@ class LetList {
/*!
* \brief insert a binding.
*
* \param ty the type of the binding.
*
* \param expr the value of the binding.
*
* \param ty the type of the binding.
*
* \return a Var that hold the inserted expr.
*/
Var Push(Type ty, Expr expr) {
Var Push(Expr expr, Type ty) {
return Push(VarNode::make("x", ty), expr);
}
......@@ -82,7 +87,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Expr expr) {
return Push(Type(), expr);
return Push(expr, Type());
}
/*!
......@@ -129,6 +134,12 @@ class LetList {
return ll.Get(f(&ll));
}
static Expr Let(const Expr& e, const std::function<Expr(const Var&)>& f) {
return With([&](LetList* ll) {
return f(ll->Push(e));
});
}
private:
std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
......
......@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
*
* \file partial_eval.cc
*
......@@ -426,8 +426,6 @@ TVM_ADD_FILELINE)
Expr StripWithFuncId(const Expr& e);
Expr DeDup(const Expr& e);
Function AsFunc(const Expr& e) {
if (e.as<FunctionNode>()) {
return Downcast<Function>(e);
......@@ -963,86 +961,6 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
FInterpreter executor_ = CPUInterpreter();
};
/*! \brief Use a fresh Id for every Var to make the result well-formed. */
Expr DeDup(const Expr& e) {
class DeDupMutator : public TypeMutator,
public ExprMutator,
public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
type_rename_[tv] = ret;
return ret;
}
Var Fresh(const Var& v) {
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret;
return ret;
}
Expr VisitExpr(const Expr& e) final {
return ExprMutator::VisitExpr(e);
}
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return rename_.count(v) != 0 ? rename_.at(v) : v;
}
Expr VisitExpr_(const LetNode* op) final {
Var v = Fresh(op->var);
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
}
Type VisitType(const Type& t) final {
return t.defined() ? TypeMutator::VisitType(t) : t;
}
Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<TypeVar> type_params;
for (const TypeVar& type_param : op->type_params) {
type_params.push_back(Fresh(type_param));
}
tvm::Array<Var> params;
for (const Var& param : op->params) {
params.push_back(Fresh(param));
}
return FunctionNode::make(params,
VisitExpr(op->body),
VisitType(op->ret_type),
type_params,
op->attrs);
}
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}
Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
}
Var VisitVar(const Var& v) final {
return Fresh(v);
}
private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
};
Expr ret = DeDupMutator().VisitExpr(e);
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
return ret;
}
/*! \brief Remap multiple Var sharing the same Id into the same Var. */
Expr Remap(const Expr& e) {
class RemapMutator : public ExprMutator, public PatternMutator {
......
......@@ -18,9 +18,9 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
*
* \file to_anf.cc
* \file to_a_normal_form.cc
*
* \brief Turn implicit sharing into observable sharing.
*/
......@@ -72,13 +72,16 @@ Scope LCA(Scope lhs, Scope rhs) {
std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGraph& dg) {
std::unordered_map<DependencyGraph::Node*, Scope> expr_scope;
bool global_scope_used = false;
Scope global_scope = std::make_shared<ScopeNode>();
for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
DependencyGraph::Node* n = *it;
auto iit = n->parents.head;
Scope s;
if (iit == nullptr) {
CHECK(!global_scope_used);
s = global_scope;
global_scope_used = true;
} else {
s = expr_scope.at(iit->value);
iit = iit->next;
......@@ -88,13 +91,10 @@ std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGrap
}
expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
}
CHECK(global_scope_used);
return expr_scope;
}
bool IsPrimitiveFunction(const Expr& e) {
return e.as<FunctionNode>() && Downcast<Function>(e)->IsPrimitive();
}
/* Special care is needed to handle local recursion.
* Fill additionally take a (possibly null) Var argument,
* If it is not null, Fill is required to bind the transformed result to that var.
......@@ -137,22 +137,26 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr VisitExpr(const Expr& e, const Var& v) final {
if (memo.count(e) == 0) {
memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
} else if (v.defined()) {
GetScope(e)->ll->Push(v, memo.at(e));
}
return memo.at(e);
auto ret = memo.at(e);
CHECK(IsAtomic(ret));
return ret;
}
Expr VisitExpr(const Expr& e) {
return this->VisitExpr(e, Var());
}
Expr Atomic(const Expr& orig, const Expr& now, const Var& v) {
return v.defined() ? GetScope(orig)->ll->Push(v, now) : now;
Expr Atomic(const Expr& e, const Var& v) {
return v.defined() ? GetScope(e)->ll->Push(v, e) : e;
}
Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
Var var = v.defined() ?
v :
VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
VarNode::make(std::string("x"), Type());
return GetScope(orig)->ll->Push(var, now);
}
......@@ -205,7 +209,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
Expr e = GetRef<Expr>(f);
Expr ret;
if (IsPrimitiveFunction(e)) {
if (f->IsPrimitive()) {
ret = e;
} else {
ret = FunctionNode::make(f->params,
......@@ -231,22 +235,22 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr VisitExpr_(const VarNode* vn, const Var& v) final {
Expr e = GetRef<Expr>(vn);
return Atomic(e, e, v);
return Atomic(e, v);
}
Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
GlobalVar gv = GetRef<GlobalVar>(gvn);
return Atomic(gv, gv, v);
return Atomic(gv, v);
}
Expr VisitExpr_(const OpNode* op, const Var& v) final {
Expr e = GetRef<Expr>(op);
return Atomic(e, e, v);
return Atomic(e, v);
}
Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
Expr e = GetRef<Expr>(c);
return Atomic(e, e, v);
return Atomic(e, v);
}
Expr VisitExpr_(const MatchNode* m, const Var& v) final {
......@@ -294,11 +298,15 @@ Module ToANormalForm(const Module& m) {
tvm::Map<GlobalVar, Function> updates;
auto funcs = m->functions;
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
Expr ret =
TransformF([&](const Expr& e) {
return ToANormalFormAux(e);
}, it.second);
CHECK_EQ(FreeVars(ret).size(), 0);
CHECK_EQ(FreeVars(ret).size(), 0)
<< AsText(ret)
<< "should not has free vars: "
<< FreeVars(ret);
updates.Set(it.first, Downcast<Function>(ret));
}
......
......@@ -368,10 +368,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// Build a subsitituion map up from the function type and type arguments.
// Eventually allow the type vars to be passed in.
for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
for (size_t i = 0; i < ty_args.size(); ++i) {
subst_map.Set(fn_ty->type_params[i], ty_args[i]);
}
for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) {
subst_map.Set(fn_ty->type_params[i], IncompleteTypeNode::make(Kind::kType));
}
Type ret_type = fn_ty->ret_type;
// If the function type is incomplete, place a new IncompleteType
......@@ -437,13 +441,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
Array<Type> type_args = call->type_args;
if (type_args.size() == 0) {
for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) {
type_args.push_back(IncompleteTypeNode::make(Kind::kType));
}
}
if (type_args.size() != fn_ty_node->type_params.size()) {
if (type_args.size() > fn_ty_node->type_params.size()) {
this->ReportFatalError(GetRef<Call>(call),
RELAY_ERROR("Incorrect number of type args in "
<< call->span << ": "
......
......@@ -17,14 +17,7 @@
import tvm
from tvm import relay
from tvm.relay import transform
def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr)
mod = opt_pass(mod)
entry = mod[mod.entry_func]
return entry if isinstance(expr, relay.Function) else entry.body
from tvm.relay.testing import run_opt_pass
def test_fuse_simple():
......
......@@ -22,15 +22,7 @@ from tvm.relay.analysis import free_vars, free_type_vars
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod[mod.entry_func]
return entry if isinstance(expr, relay.Function) else entry.body
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type
def rand(dtype='float32', *shape):
......
......@@ -186,6 +186,19 @@ def test_function():
check_eval(anf_f(d), 8)
def test_gradient_if():
x = relay.var("a", shape=(1, 16))
y = relay.var("y", shape=(1, 16))
cond = relay.var("cond", shape=(), dtype='uint1')
net = relay.If(cond, x, x)
net = relay.add(x, net)
net = relay.Function([cond,x,y], net)
mod = relay.Module.from_expr(net)
mod = relay.transform.ToANormalForm()(mod)
mod[mod.entry_func] = relay.transform.gradient(mod[mod.entry_func], mode='higher_order')
mod = relay.transform.ToANormalForm()(mod)
if __name__ == '__main__':
test_explicit_bound()
test_order()
......@@ -195,3 +208,4 @@ if __name__ == '__main__':
test_let()
test_nat_add()
test_function()
test_gradient_if()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
from tvm.relay import create_executor
from tvm.relay import Function, transform
def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))
# make sure cps work for recursion.
def test_recursion():
mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
double = relay.Function([x], x + x)
i = relay.var("i", t)
func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
mod[mod.entry_func] = func
mod[mod.entry_func] = to_cps(mod[mod.entry_func], mod=mod)
mod[mod.entry_func] = un_cps(mod[mod.entry_func])
ex = create_executor(mod=mod)
i_nd = rand(dtype, *shape)
forward = ex.evaluate(mod.entry_func)(i_nd)
tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
# This serve as an integration test.
# It test that, given a program with reference,
# cps and pe can completely eliminate the allocation of reference.
def test_cps_pe():
def destroy_ref(x):
x = run_infer_type(x)
x = to_cps(x)
x = run_infer_type(x)
y = un_cps(x)
y = run_infer_type(y)
x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
assert Feature.fRefCreate not in detect_feature(x)
unit = relay.Function([], relay.const(0., dtype='float32'))
f_ref = relay.Var("f_ref")
one = relay.const(1., dtype='float32')
two = relay.const(2., dtype='float32')
cond = relay.var(shape=(), dtype='uint1', name_hint='cond')
true_branch = relay.RefWrite(f_ref, relay.Function([], one))
false_branch = relay.RefWrite(f_ref, relay.Function([], two))
if_expr = relay.If(cond, true_branch, false_branch)
stmt = relay.Let(f_ref, relay.RefCreate(unit),
relay.Let(relay.Var("x"), if_expr,
relay.Call(relay.RefRead(f_ref), [])))
F = relay.Function([cond], stmt)
destroy_ref(F)
G = relay.Function([cond], relay.If(cond, one, two))
G = relay.transform.gradient(G)
destroy_ref(G)
x = relay.var("x", shape=(1, 16))
y = relay.var("y", shape=(1, 16))
z = relay.var("z", shape=(1, 16))
cond = relay.var("cond", shape=(), dtype='uint1')
H = relay.If(cond, x, y)
H = relay.add(H, z)
H = relay.Function([cond,x,y,z], H)
H = relay.transform.gradient(H)
destroy_ref(H)
if __name__ == '__main__':
test_recursion()
test_cps_pe()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment