Commit bb87f044 by 雾雨魔理沙 Committed by Tianqi Chen

add document (#2714)

lint

lint

save

save

add more case

save

error

lint

lint

commit

do

lint

save

fix lint

wrap it back as func

lint

save

remove dead comment

fix style

fix lint

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

address review feedback

pe now handles free vars. as a result, preserving functions is now trivial.

test

add basic test, implement pretty printing for generic function

test

lint

fix segfault

save

save

do

test

fix another error

address comment

commit

save

address review feedback

add test for invalidate, fix error in lookup

rename cont to body

fix error and add regression test

fix error, add test case

Update src/relay/pass/partial_eval.cc

Co-Authored-By: MarisaKirisame <lolisa@marisa.moe>

fix lint

remove extra line

save

save
parent 28f354bf
......@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
......
......@@ -64,7 +64,7 @@
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <string>
#include <vector>
......@@ -344,6 +344,17 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
/*! \brief Get all bound variables from pattern pat.
*
* Bound variables are all variables that are bound by the pattern.
* They only have meaning inside that pattern, and can only be used within it.
*
* \param pat the Pattern.
*
* \return List of bound vars, in the PostDFS order of the pattern.
*/
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
/*! \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
......@@ -431,12 +442,13 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
/*! \brief Remove expressions which do not affect the program result.
*
* It will remove let bindings which are not referenced, and branches that will
* not be entered.
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
* the expression does not depend on a. Another example is `if (true) then 1
* else 2` will be optimized into 1.
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param e the expression to optimize.
*
......@@ -558,6 +570,12 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
*/
TVM_DLL Expr ToGraphNormalForm(const Expr& e);
/*! \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation at compile time as possible.
* It has two benefits: removing runtime overhead, and allowing more optimization (typically fusion).
* As a side effect, code size will explode.
*/
Expr PartialEval(const Expr& e);
} // namespace relay
} // namespace tvm
......
......@@ -89,6 +89,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
* \return The result of the call
*/
virtual R VisitPattern(const Pattern& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
......
......@@ -956,3 +956,20 @@ def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True):
A text representation of `ast`.
"""
return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf)
def partial_evaluate(expr):
"""
Evaluate the static fragment of the code.

Parameters
----------
expr : tvm.relay.Expr
    The input expression.

Returns
-------
expr : tvm.relay.Expr
    The output expression.
"""
return _ir_pass.partial_evaluate(expr)
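A minimal usage sketch of this API (illustrative only; it mirrors the dcpe helper used in the tests further down and assumes the usual relay operator sugar):

import tvm
from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination

# Build a call whose argument is a compile-time constant.
d = relay.Var("d")
double = relay.Function([d], d + d)
call = double(relay.const(4.0))

# Partially evaluate, then clean up the dead bindings it leaves behind.
opt = dead_code_elimination(partial_evaluate(call))
# opt is now alpha-equal to relay.const(8.0).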
......@@ -556,7 +556,7 @@ class Interpreter :
CHECK_NE(cvn->constructor->tag, -1);
if (op->constructor->tag == cvn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK(op->patterns.size() == cvn->fields.size());
CHECK_EQ(op->patterns.size(), cvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
return false;
......
......@@ -43,9 +43,6 @@ Expr ExprMutator::VisitExpr(const Expr& expr) {
}
Expr ExprMutator::VisitExpr_(const VarNode* op) {
// NOTE: var will only be mutated once
// Thanks to the memo and reused during rewriting if necessary.
// It is safe to assume that the
if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) {
......
......@@ -245,6 +245,46 @@ class PrettyPrinter :
return Doc(unique_prefix);
}
Doc Print(Kind k) {
switch (k) {
case kType:
return Doc("Type");
case kShapeVar:
return Doc("Shape");
case kBaseType:
return Doc("BaseType");
case kConstraint:
return Doc("Constraint");
case kAdtHandle:
return Doc("AdtHandle");
case kTypeData:
return Doc("TypeData");
default:
LOG(ERROR) << "Unknown Kind";
throw;
}
}
/*!
* \brief Allocate name to a type variable.
* \param var The input type variable.
* \return The corresponding name.
*/
Doc AllocTypeVar(const TypeVar& var) {
std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
if (memo_type_.count(var)) {
val << "-malformed-ir";
}
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
}
return val;
}
/*!
* \brief Allocate name to a variable.
* \param var The input variable.
......@@ -253,7 +293,7 @@ class PrettyPrinter :
Doc AllocVar(const Var& var) {
std::string name = var->name_hint();
// always make sure first name is alpha
if (name.length() != 0 && !std::isalpha(name[0])) {
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
......@@ -387,12 +427,18 @@ class PrettyPrinter :
}
Doc PrintFunc(const Doc& prefix, const Function& fn) {
// TODO(tqchen, M.K.) support generic function
// Possibly through meta data
CHECK_EQ(fn->type_params.size(), 0U)
<< "generic fn not yet supported";
Doc doc;
doc << prefix << "(";
doc << prefix;
if (fn->type_params.size() > 0) {
doc << "<";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
}
doc << PrintVec(type_params);
doc << ">";
}
doc << "(";
std::vector<Doc> params;
for (Var param : fn->params) {
params.push_back(AllocVar(param));
......@@ -516,6 +562,10 @@ class PrettyPrinter :
return Print(GetRef<NodeRef>(node), true);
}
Doc VisitType_(const TypeVarNode* node) final {
return AllocTypeVar(GetRef<TypeVar>(node));
}
Doc VisitType_(const TensorTypeNode* node) final {
// scalar type
if (node->shape.size() == 0) {
......
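With AllocTypeVar and the type-parameter printing above, functions with type parameters can now be pretty printed instead of tripping the old "generic fn not yet supported" check. A tiny Python sketch (the exact rendered form is indicative only):

import tvm
from tvm import relay

t = relay.TypeVar("t")
x = relay.Var("x", t)
f = relay.Function([x], x, None, [t])
# The printed header now carries the type parameter, e.g. something like fn<%t>(%x: %t) { ... }
print(f.astext())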
......@@ -77,6 +77,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
* \return The result of the call
*/
virtual R VisitType(const Type& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
......
......@@ -35,90 +35,109 @@
namespace tvm {
namespace relay {
bool IsBoolLit(const Expr& e, bool b) {
if (const ConstantNode* c = e.as<ConstantNode>()) {
if (c->is_scalar()) {
auto dt = c->tensor_type()->dtype;
if (dt == Bool()) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(8)) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(16)) {
return *reinterpret_cast<const uint16_t*>(c->data->data) == b;
} else if (dt == UInt(32)) {
return *reinterpret_cast<const uint32_t*>(c->data->data) == b;
} else if (dt == UInt(64)) {
return *reinterpret_cast<const uint64_t*>(c->data->data) == b;
} else if (dt == Int(8)) {
return *reinterpret_cast<const int8_t*>(c->data->data) == b;
} else if (dt == Int(16)) {
return *reinterpret_cast<const int16_t*>(c->data->data) == b;
} else if (dt == Int(32)) {
return *reinterpret_cast<const int32_t*>(c->data->data) == b;
} else if (dt == Int(64)) {
return *reinterpret_cast<const int64_t*>(c->data->data) == b;
}
}
}
return false;
}
// calculate the dependency graph from expression
class CalcDep : private ExprMutator {
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e) {
CalcDep cd;
auto res = cd(e);
GenLet gl(cd.var_map_);
gl(res);
return gl.lets_.Get(res);
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
return el(e);
}
private:
using VarMap = std::unordered_map<Var, Expr, NodeHash, NodeEqual>;
VarMap var_map_;
Expr VisitExpr_(const IfNode* i) final {
auto cond = VisitExpr(i->cond);
if (IsBoolLit(cond, true)) {
return Eliminate(i->true_branch);
} else if (IsBoolLit(cond, false)) {
return Eliminate(i->false_branch);
} else {
return IfNode::make(cond, Eliminate(i->true_branch), Eliminate(i->false_branch));
template<typename X>
using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool count_ = true;
VarSet dead_worklist_;
VarSet current_letrec_;
void LetRec(const std::function<void()>& func, const Var& v) {
current_letrec_.insert(v);
func();
current_letrec_.erase(v);
}
void VisitExpr_(const LetNode* l) final {
if (count_) {
CHECK_EQ(expr_map_.count(l->var), 0);
CHECK_EQ(use_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
use_map_[l->var] = 0;
dead_worklist_.insert(l->var);
LetRec([&]() { VisitExpr(l->value); }, l->var);
}
VisitExpr(l->body);
}
Expr VisitExpr_(const LetNode* l) final {
var_map_[l->var] = Eliminate(l->value);
return VisitExpr(l->body);
void VisitExpr(const Expr& e) final {
ExprFunctor<void(const Expr&)>::VisitExpr(e);
}
Expr VisitExpr_(const FunctionNode* f) final {
return FunctionNode::make(f->params,
Eliminate(f->body),
f->ret_type,
f->type_params);
void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
if (expr_map_.count(var) == 0) {
return;
}
if (current_letrec_.count(var) == 0) {
if (count_) {
use_map_[var] += 1;
dead_worklist_.erase(var);
} else {
CHECK_GT(use_map_[var], 0) << var;
use_map_[var] -= 1;
if (use_map_[var] == 0) {
dead_worklist_.insert(var);
}
}
} else {
letrec_set_.insert(var);
}
}
void Calculate(const Expr& v) {
VisitExpr(v);
count_ = false;
while (!dead_worklist_.empty()) {
Var dead = *(dead_worklist_.begin());
dead_worklist_.erase(dead);
CHECK_EQ(use_map_[dead], 0);
if (expr_map_.count(dead) > 0) {
LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead);
}
}
}
// generate the let list from dependency graph
class GenLet : private ExprVisitor {
class Eliminator : private ExprMutator {
private:
LetList lets_;
VarMap var_map_;
explicit GenLet(const VarMap& var_map) : var_map_(var_map) { }
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
friend CalcDep;
void VisitExpr_(const VarNode* vnode) final {
Var v = GetRef<Var>(vnode);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
Expr expr = it->second;
var_map_.erase(it);
// erase before visit to handle letrec
VisitExpr(expr);
// visit before push back so the dependency of dependency is before the dependency
lets_.Push(v, expr);
bool HasLet(const Var& v) {
return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
}
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
}
Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
}
};
......
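A small Python sketch of the new dead code elimination behaviour (unused let bindings are dropped, bindings used exactly once are inlined); the variable names are illustrative:

import tvm
from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal

a = relay.Var("a")
b = relay.Var("b")
c = relay.Var("c")
d = relay.Var("d")

# `a` is never used and `c` is used exactly once, so both lets go away.
orig = relay.Let(a, b, relay.Let(c, d, c))
assert alpha_equal(dead_code_elimination(orig), d)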
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
*
* \file partial_eval.cc
*
* \brief Perform known computation at compile time.
*
* The partial evaluator tries to do computation at compile time,
* so it can generate code that does less work.
* Additionally, it might open up more opportunities for further optimization,
* since the high level, structural part of the code (closures, references, control flow)
* might get partially evaluated away, and the subsequent optimizations (for example, kernel fusion)
* can reason across that structural code as it gets removed.
* In the extreme case, partial evaluation can even turn the whole program
* into pure first order computation with no control flow.
* In such a case, we can compile the whole computation onto SIMD instructions/GPU/FPGA,
* and get a huge speedup.
*
* It works by making the following modifications to the standard relay interpreter:
*
* 0: The values become partially static values.
* Since we cannot know the value of every term at compile time,
* a term might get partially evaluated to 'Unknown Value'.
* Every partially static value is, hence,
* a static fragment that might not be there (partially static),
* and a dynamic fragment that is semantically equivalent to the original term,
* so the unknown part will be computed at runtime, using the dynamic fragment.
*
* 1: The interpreter holds a LetList, which preserves A Normal Form for the generated code.
* More specifically, we require that every dynamic fragment is an atom.
* This avoids code duplication (which is both inefficient and incorrect), as atoms have constant size
* and allows us to avoid capture-avoiding substitution (as atoms have no binders).
*
* 2: The map of References to partially static values is reified, as described below.
* Instead of a Reference having a mutable field, a Reference only has a unique identifier.
* There is a mutable mapping of ids to partially static values, called the store.
* This allows us to roll back the store:
* when a path may or may not be executed (as in a conditional), we copy the store,
* recurse with the copy, and reinstate the original when the call returns,
* so that the effects of the computation are not preserved.
* We do this in if-else, in pattern matching, and in functions,
* as, when we see a function, we partially evaluate it with all arguments as dynamic,
* to generate an efficient dynamic body for that function.
*
* 3: The generated code reuses bindings (although they are not shadowed),
* so we have to deduplicate them.
*
* 4: In the generated code, multiple VarNodes might have the same Id.
* While this is permitted, most passes use NodeHash for Var,
* and having multiple VarNodes with the same Id breaks them.
* Thus we remap them to a single Id for now.
*
* Also, it will generate lots of dead code,
* so it is a good idea to feed it through the dead code eliminator after partial evaluation.
*
* The partial evaluator makes several assumptions, so there is room for improvement:
*
* 0: The partial evaluator treats global variables as opaque.
* Doing PartialEval at the module level would solve this.
*
* 1: The partial evaluator assumes all functions terminate.
* We need a max_expand parameter that shrinks on every compile time evaluation,
* to make sure PE does not loop forever.
* Additionally, we might add a termination analysis pass that lifts this requirement
* for functions that the analysis finds terminating.
*
* 2: Every time an unknown effect happens, we clear the whole store.
* This is too conservative: if a local reference is created (and does not get passed outside),
* an unknown global function call/global reference write cannot modify it.
* We can pair PE with escape analysis/alias analysis.
*
* 3: We assume all unknown code has effects. Doing effect analysis can make the store more precise.
*
* 4: When doing pattern matching, we can simplify the match even in the dynamic case.
* Right now it is all or nothing: either a complete match, or the original dynamic code.
* Instead, we can produce a match tree, pair it with the data, and evaluate it to a normal form.
* We then can reify the result.
*
* 5: Every time a function is called, its code will get expanded and partially evaluated.
* We can do a binding time analysis to cache the result and avoid re-partial evaluation.
*
* These assumptions do not affect the correctness of the algorithm, however.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include "pass_util.h"
#include "let_list.h"
namespace tvm {
namespace relay {
using namespace runtime;
/*! \brief Hash Var by its id.
* Different VarNodes might have the same vid, and they are considered to be the same var in that case.
* Use VarHash to hash Var by id.
*/
struct VarHash {
size_t operator()(const Var& v) const {
return v->vid.hash();
}
};
/*! \brief Compare Var by its id.
* Different VarNodes might have the same vid, and they are considered to be the same var in that case.
* Use VarEqual to compare Var by id.
*/
struct VarEqual {
bool operator()(const Var& l, const Var& r) const {
return l->vid.get() == r->vid.get();
}
};
/*! \brief The base container type of Relay values. */
class StaticNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};
class Static : public NodeRef {
public:
Static() {}
explicit Static(NodePtr<Node> n) : NodeRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(node_.get());
}
using ContainerType = StaticNode;
};
struct PStaticNode : Node {
Static pstatic; // may be null
Expr dynamic;
PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { }
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
};
RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef);
struct STupleNode : StaticNode {
std::vector<PStatic> fields;
explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) { }
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(STuple, STupleNode, Value);
Static MkSTuple(const std::vector<PStatic>& fields) {
return Static(make_node<STupleNode>(fields));
}
struct STensorNode : StaticNode {
runtime::NDArray data;
explicit STensorNode(const NDArray& data) : data(data) { }
TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value);
Static MkSTensor(const NDArray& data) {
return Static(make_node<STensorNode>(data));
}
struct SConstructorNode : StaticNode {
Constructor constructor;
std::vector<PStatic> fields;
SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) :
constructor(constructor), fields(fields) { }
TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Value);
Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>& fields) {
return Static(make_node<SConstructorNode>(constructor, fields));
}
struct SRefNode : StaticNode {
// we will use the address as the guid for hashing
TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(SRef, SRefNode, Value);
Static MkSRef() {
return Static(make_node<SRefNode>());
}
using Func = std::function<PStatic(const std::vector<PStatic>&,
const Attrs&,
const Array<Type>&,
LetList*)>;
struct SFuncNode : StaticNode {
Func func;
explicit SFuncNode(const Func& func) : func(func) { }
TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(SFunc, SFuncNode, Value);
Static MkSFunc(const Func& func) {
return Static(make_node<SFuncNode>(func));
}
/*!
* \brief A stack frame in the Relay interpreter.
*
* Contains a mapping from relay::Var to relay::Value.
*/
struct Frame {
/*! \brief The set of local variables and arguments for the frame. */
std::unordered_map<Var, PStatic, VarHash, VarEqual> locals;
Frame() = default;
};
class Environment {
public:
Environment() : env_({Frame()}) { }
Environment(const Environment&) = delete;
template<typename T>
T Extend(const std::function<T()>& body) {
FrameContext fc(this);
return body();
}
void Insert(const Var& v, const PStatic& ps) {
CHECK(ps.defined());
env_.back().locals[v] = ps;
}
PStatic Lookup(const Var& v) {
auto rit = env_.rbegin();
while (rit != env_.rend()) {
if (rit->locals.find(v) != rit->locals.end()) {
return rit->locals.find(v)->second;
}
++rit;
}
LOG(FATAL) << "Unknown Variable: " << v;
throw;
}
private:
std::list<Frame> env_;
struct FrameContext {
Environment* env_;
explicit FrameContext(Environment* env) : env_(env) {
env_->env_.push_back(Frame());
}
~FrameContext() {
env_->env_.pop_back();
}
};
};
/*!
* \brief As our store requires rollback, we implement it as a stack of frames.
* Every time we need to copy the store, a new frame is inserted.
* Every time we roll back, a frame is popped.
*/
struct StoreFrame {
std::unordered_map<const SRefNode*, PStatic> store;
/*! \brief On an unknown effect, history_valid is set to false to signal that entries in earlier frames are outdated. */
bool history_valid = true;
explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { }
StoreFrame() = default;
};
class Store {
public:
Store() : store_({StoreFrame()}) { }
Store(const Store&) = delete;
template<typename T>
T Extend(const std::function<T()>& body) {
StoreFrameContext sfc(this);
return body();
}
void Insert(const SRefNode* r, const PStatic& ps) {
store_.back().store[r] = ps;
}
// return null if not found
PStatic Lookup(const SRefNode* r) {
auto rit = store_.rbegin();
while (rit != store_.rend()) {
if (!rit->history_valid) {
return PStatic();
}
if (rit->store.find(r) != rit->store.end()) {
return rit->store.find(r)->second;
}
++rit;
}
return PStatic();
}
void Invalidate() {
store_.back().history_valid = false;
}
private:
std::list<StoreFrame> store_;
struct StoreFrameContext {
Store* store_;
explicit StoreFrameContext(Store* store) : store_(store) {
store_->store_.push_back(StoreFrame());
}
~StoreFrameContext() {
store_->store_.pop_back();
}
};
};
PStatic HasStatic(const Static& stat, const Expr& dynamic) {
return PStatic(make_node<PStaticNode>(stat, dynamic));
}
PStatic NoStatic(const Expr& dynamic) {
return PStatic(make_node<PStaticNode>(dynamic));
}
enum struct MatchStatus {
Match, NoMatch, Unknown
};
bool StatefulOp(const Expr& e) {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
struct StatefulOpVisitor : ExprVisitor {
bool stateful = false;
void VisitExpr_(const OpNode* op) {
stateful = stateful || op_stateful.get(GetRef<Op>(op), false);
}
};
StatefulOpVisitor sov;
sov(e);
return sov.stateful;
}
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
DLContext CPUContext() {
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
return ctx;
}
FInterpreter CPUInterpreter() {
Target target = Target::create("llvm");
// use a fresh build context
// in case we are already in a build context.
BuildConfigContext fresh_build_ctx(build_config());
return CreateInterpreter(Module(nullptr), CPUContext(), target);
}
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public:
PartialEvaluator(const tvm::Array<Var>& free_vars) {
for (const Var& v : free_vars) {
env_.Insert(v, NoStatic(v));
}
}
PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
}
PStatic VisitExpr_(const TupleNode* op, LetList* ll) final {
std::vector<PStatic> value;
tvm::Array<Expr> expr;
for (const Expr& e : op->fields) {
PStatic ps = VisitExpr(e, ll);
value.push_back(ps);
expr.push_back(ps->dynamic);
}
return HasStatic(MkSTuple(value), ll->Push(TupleNode::make(expr)));
}
PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final {
PStatic ps = VisitExpr(op->tuple, ll);
if (ps->pstatic.defined()) {
return Downcast<STuple>(ps->pstatic)->fields[op->index];
} else {
return NoStatic(ll->Push(TupleGetItemNode::make(ps->dynamic, op->index)));
}
}
PStatic VisitExpr_(const VarNode* op, LetList* ll) final {
return env_.Lookup(GetRef<Var>(op));
}
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
return NoStatic(GetRef<Expr>(op));
}
PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
env_.Insert(op->var, VisitExpr(op->value, ll));
return VisitExpr(op->body, ll);
}
PStatic VisitExpr_(const IfNode* op, LetList* ll) final {
PStatic c = VisitExpr(op->cond, ll);
if (c->pstatic.defined()) {
NDArray cpu_array = Downcast<STensor>(c->pstatic)->data.CopyTo(CPUContext());
CHECK_EQ(TVMType2Type(cpu_array->dtype), Bool());
if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
return VisitExpr(op->true_branch, ll);
} else {
return VisitExpr(op->false_branch, ll);
}
} else {
Expr t = store_.Extend<Expr>([&]() {
return LetList::With([&](LetList* ll) {
return VisitExpr(op->true_branch, ll)->dynamic;
});
});
Expr f = store_.Extend<Expr>([&]() {
return LetList::With([&](LetList* ll) {
return VisitExpr(op->false_branch, ll)->dynamic;
});
});
store_.Invalidate();
return NoStatic(ll->Push(IfNode::make(c->dynamic, t, f)));
}
}
PStatic VisitExpr_(const RefCreateNode* op, LetList* ll) final {
PStatic ps = VisitExpr(op->value, ll);
Static r = MkSRef();
store_.Insert(r.as<SRefNode>(), ps);
return HasStatic(r, ll->Push(RefCreateNode::make(ps->dynamic)));
}
PStatic VisitExpr_(const RefWriteNode* op, LetList* ll) final {
PStatic r = VisitExpr(op->ref, ll);
PStatic v = VisitExpr(op->value, ll);
if (r->pstatic.defined()) {
store_.Insert(r->pstatic.as<SRefNode>(), v);
} else {
store_.Invalidate();
}
return HasStatic(MkSTuple({}), ll->Push(RefWriteNode::make(r->dynamic, v->dynamic)));
}
PStatic VisitExpr_(const RefReadNode* op, LetList* ll) final {
PStatic r = VisitExpr(op->ref, ll);
if (r->pstatic.defined()) {
PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>());
if (ret) {
return ret;
}
}
return NoStatic(ll->Push(RefReadNode::make(r->dynamic)));
}
PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
PStatic f = VisitExpr(op->op, ll);
std::vector<PStatic> x;
tvm::Array<Expr> x_dyn;
for (const Expr& e : op->args) {
PStatic ps = VisitExpr(e, ll);
x.push_back(ps);
x_dyn.push_back(ps->dynamic);
}
if (f->pstatic.defined()) {
return Downcast<SFunc>(f->pstatic)->func(x, op->attrs, op->type_args, ll);
} else {
store_.Invalidate();
return NoStatic(ll->Push(CallNode::make(f->dynamic, x_dyn, op->attrs, op->type_args)));
}
}
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
Function func = GetRef<Function>(op);
if (func->IsPrimitive()) {
return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func);
}
std::vector<std::pair<Var, PStatic> > free_vars;
for (const auto& v : FreeVars(GetRef<Expr>(op))) {
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
}
Func f = [=](const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
LetList* ll) {
return env_.Extend<PStatic>([&]() {
CHECK_EQ(pv.size(), func->params.size());
for (size_t i = 0; i < pv.size(); ++i) {
env_.Insert(func->params[i], pv[i]);
}
for (const auto& p : free_vars) {
env_.Insert(p.first, p.second);
}
tvm::Map<TypeVar, Type> subst;
for (size_t i = 0; i < type_args.size(); ++i) {
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], Type());
}
return VisitExpr(TypeSubst(func->body, subst), ll);
});
};
Expr dyn = store_.Extend<Expr>([&]() {
store_.Invalidate();
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
for (const auto& v : func->params) {
pv.push_back(NoStatic(v));
}
tvm::Array<Type> type_args;
for (const auto& tp : func->type_params) {
type_args.push_back(tp);
}
return f(pv, Attrs(), type_args, ll)->dynamic;
}), func->ret_type, func->type_params, func->attrs);
});
return HasStatic(MkSFunc(f), ll->Push(dyn));
}
Expr Reflect(const PStatic& st) {
if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
return ConstantNode::make(op->data);
} else if (const STupleNode* op = st->pstatic.as<STupleNode>()) {
tvm::Array<Expr> fields;
for (const PStatic& field : op->fields) {
fields.push_back(Reflect(field));
}
return TupleNode::make(fields);
} else {
LOG(FATAL) << "Unknown case";
throw;
}
}
PStatic Reify(const Value& v, LetList* ll) const {
if (const TensorValueNode* op = v.as<TensorValueNode>()) {
return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data)));
} else if (const TupleValueNode* op = v.as<TupleValueNode>()) {
std::vector<PStatic> fields;
tvm::Array<Expr> fields_dyn;
for (const Value& field : op->fields) {
PStatic ps = Reify(field, ll);
fields.push_back(ps);
fields_dyn.push_back(ps->dynamic);
}
return HasStatic(MkSTuple(fields), ll->Push(TupleNode::make(fields_dyn)));
} else {
LOG(FATAL) << "Unknown case";
throw;
}
}
// Constant evaluate an expression.
PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
Expr infered = InferType(expr, Module(nullptr));
Expr fused = FuseOps(infered, 0);
Expr fused_infered = InferType(fused, Module(nullptr));
return Reify(executor_(fused_infered), ll);
}
Func ConstEvaluateFunc(const Expr& expr, LetList* ll) {
return [=](const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
LetList* ll) {
tvm::Array<Expr> ns_args;
for (const PStatic& ps : pv) {
ns_args.push_back(ps->dynamic);
}
PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args));
if (StatefulOp(expr)) {
return ns;
}
tvm::Array<Expr> args;
for (const PStatic& ps : pv) {
if (ps->pstatic.defined()) {
args.push_back(Reflect(ps));
} else {
return ns;
}
}
return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
};
}
PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op), ll)), GetRef<Expr>(op));
}
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
Constructor c = GetRef<Constructor>(op);
Func f = [=](const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
LetList* ll) {
tvm::Array<Expr> dyn;
for (const PStatic& ps : pv) {
dyn.push_back(ps->dynamic);
}
return HasStatic(MkSConstructor(c, pv), ll->Push(CallNode::make(c, dyn)));
};
return HasStatic(MkSFunc(f), GetRef<Expr>(op));
}
PStatic VisitExpr_(const MatchNode* op, LetList* ll) final {
PStatic ps = VisitExpr(op->data, ll);
return env_.Extend<PStatic>([&]() {
for (const Clause& c : op->clauses) {
switch (VisitPattern(c->lhs, ps)) {
case MatchStatus::Match:
return VisitExpr(c->rhs, ll);
case MatchStatus::NoMatch:
continue;
case MatchStatus::Unknown:
tvm::Array<Clause> clauses;
for (const Clause& c : op->clauses) {
Expr expr = store_.Extend<Expr>([&]() {
return LetList::With([&](LetList* ll) {
for (const Var& v : BoundVars(c->lhs)) {
env_.Insert(v, NoStatic(v));
}
return VisitExpr(c->rhs, ll)->dynamic;
});
});
clauses.push_back(ClauseNode::make(c->lhs, expr));
}
store_.Invalidate();
return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses)));
}
}
LOG(FATAL) << "No case Match";
throw;
});
}
MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final {
return MatchStatus::Match;
}
MatchStatus VisitPattern_(const PatternVarNode* op, const PStatic& ps) final {
env_.Insert(op->var, ps);
return MatchStatus::Match;
}
MatchStatus VisitPattern_(const PatternConstructorNode* op, const PStatic& ps) final {
if (ps->pstatic.defined()) {
SConstructor scn = Downcast<SConstructor>(ps->pstatic);
CHECK_NE(op->constructor->tag, -1);
CHECK_NE(scn->constructor->tag, -1);
if (op->constructor->tag == scn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK_EQ(op->patterns.size(), scn->fields.size());
MatchStatus current_match_status = MatchStatus::Match;
for (size_t i = 0; i < op->patterns.size(); ++i) {
MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]);
switch (ms) {
case MatchStatus::Match:
continue;
case MatchStatus::NoMatch:
return MatchStatus::NoMatch;
case MatchStatus::Unknown:
current_match_status = MatchStatus::Unknown;
}
}
return current_match_status;
}
return MatchStatus::NoMatch;
} else {
return MatchStatus::Unknown;
}
}
private:
Environment env_;
Store store_;
DLContext context_ = CPUContext();
FInterpreter executor_ = CPUInterpreter();
};
Var DeDupVar(const Var& v) {
return VarNode::make(v->name_hint(), v->type_annotation);
}
TypeVar DeDupTypeVar(const TypeVar& tv) {
return TypeVarNode::make(tv->var->name_hint, tv->kind);
}
/*! \brief Use a fresh Id for every Var to make the result well-formed. */
Expr DeDup(const Expr& e) {
class DeDupMutator : public ExprMutator, public PatternMutator {
public:
Var Fresh(const Var& v) {
Var ret = DeDupVar(v);
rename_[v] = ret;
return ret;
}
Expr VisitExpr(const Expr& e) final {
return ExprMutator::VisitExpr(e);
}
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return rename_.count(v) != 0 ? rename_.at(v) : v;
}
Expr VisitExpr_(const LetNode* op) final {
return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body));
}
Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<Var> params;
for (const Var& param : op->params) {
params.push_back(Fresh(param));
}
return FunctionNode::make(params,
VisitExpr(op->body),
op->ret_type,
op->type_params,
op->attrs);
}
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Var VisitVar(const Var& v) final {
return Fresh(v);
}
private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
};
return DeDupMutator().VisitExpr(e);
}
/*! \brief Remap multiple Vars sharing the same Id into the same Var. */
Expr Remap(const Expr& e) {
class RemapMutator : public ExprMutator, public PatternMutator {
Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
if (remap_.count(v) == 0) {
remap_.insert({v, v});
}
return remap_.at(v);
}
Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
}
private:
std::unordered_map<Var, Var, VarHash, VarEqual> remap_;
};
return RemapMutator().VisitExpr(e);
}
Expr PartialEval(const Expr& e) {
return TransformF([&](const Expr& e) {
return LetList::With([&](LetList* ll) {
PartialEvaluator pe(FreeVars(e));
return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic));
});
}, e);
}
TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PartialEval(args[0]);
});
} // namespace relay
} // namespace tvm
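A small illustrative Python sketch of what this pass does with a statically known conditional, paired with dead code elimination as the header comment recommends (the expected shape of the result assumes the usual relay operator sugar):

import tvm
from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination

x = relay.Var("x")
body = relay.If(relay.const(True), x + x, x * x)
f = relay.Function([x], body)

opt = dead_code_elimination(partial_evaluate(f))
# The condition is static, so only the true branch survives;
# opt behaves like relay.Function([x], x + x).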
......@@ -42,7 +42,6 @@ namespace relay {
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body);
/*!
* \brief Check if expr is positive constant.
* \param expr The expression to be checked.
......@@ -50,7 +49,6 @@ GetExprRefCount(const Expr& body);
*/
bool IsAllPositiveConstant(const Expr& expr);
/*!
* \brief Substitute var with subst.
* \param type The type to be substituted.
......@@ -61,6 +59,15 @@ bool IsAllPositiveConstant(const Expr& expr);
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);
/*!
* \brief Substitute a type var with subst in expr.
* \param expr The expr to be substituted.
* \param tvar The type variable to be substituted.
* \param subst The target of substitution.
* \return The substituted result.
*/
Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst);
/*!
* \brief Substitute type vars in type.
* \param type The type to be substituted.
* \param subst_map The map of substitution.
......@@ -68,6 +75,28 @@ Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);
*/
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);
/*!
* \brief Substitute type vars in expr.
* \param expr The expr to be substituted.
* \param subst_map The map of substitution.
* \return The substituted result.
*/
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
/*!
* \brief Make an arbitrary transformation preserve the outermost function.
* \param func The transformation.
* \param e The expression.
* \return The transformed expression. If e is a function, the result is also a function.
*/
inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) {
if (const FunctionNode* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params, func(f->body), f->ret_type, f->type_params, f->attrs);
} else {
return func(e);
}
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_
......@@ -28,6 +28,7 @@
#include <tvm/relay/expr_functor.h>
#include "let_list.h"
#include "../../common/arena.h"
#include "pass_util.h"
namespace tvm {
namespace relay {
......@@ -481,15 +482,7 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
}
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
if (const auto* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params,
ToANormalFormAux(f->body, m, gv),
f->ret_type,
f->type_params,
f->attrs);
} else {
return ToANormalFormAux(e, m, gv);
}
return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e);
}
Expr ToANormalForm(const Expr& e, const Module& m) {
......
......@@ -27,6 +27,7 @@
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "../ir/type_functor.h"
namespace tvm {
......@@ -171,8 +172,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
return ret;
}
Array<Var> Bound(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> Collect() {
Array<Var> ret;
for (const auto& v : bound_vars_.data) {
ret.push_back(v);
......@@ -180,6 +180,16 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
return ret;
}
Array<Var> Bound(const Expr& expr) {
this->VisitExpr(expr);
return Collect();
}
Array<Var> Bound(const Pattern& pat) {
this->VisitPattern(pat);
return Collect();
}
Array<Var> All(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> ret;
......@@ -256,6 +266,10 @@ tvm::Array<Var> BoundVars(const Expr& expr) {
return VarVisitor().Bound(expr);
}
tvm::Array<Var> BoundVars(const Pattern& pat) {
return VarVisitor().Bound(pat);
}
tvm::Array<Var> AllVars(const Expr& expr) {
return VarVisitor().All(expr);
}
......@@ -267,7 +281,12 @@ TVM_REGISTER_API("relay._ir_pass.free_vars")
TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BoundVars(args[0]);
NodeRef x = args[0];
if (x.as_derived<ExprNode>()) {
*ret = BoundVars(Downcast<Expr>(x));
} else {
*ret = BoundVars(Downcast<Pattern>(x));
}
});
TVM_REGISTER_API("relay._ir_pass.all_vars")
......@@ -388,5 +407,33 @@ bool IsAllPositiveConstant(const Expr& expr) {
}
}
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}}));
}
Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
return TypeSubst(expr, tvm::Map<TypeVar, Type>({{tvar, subst}}));
}
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
return Bind(type, subst_map);
}
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
class TypeSubstMutator : public ExprMutator, public PatternMutator {
public:
explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) { }
Type VisitType(const Type& t) final {
return TypeSubst(t, subst_map_);
}
Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
}
private:
const tvm::Map<TypeVar, Type>& subst_map_;
};
return TypeSubstMutator(subst_map).VisitExpr(expr);
}
} // namespace relay
} // namespace tvm
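A brief Python sketch of the pattern-aware BoundVars added above, reusing the Prelude list constructors that the tests below also use; it assumes the Python-side bound_vars wrapper simply forwards its argument to the registered API:

import tvm
from tvm import relay
from tvm.relay.ir_pass import bound_vars
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)
y = relay.Var("y")
z = relay.Var("z")
pat = relay.PatternConstructor(p.cons, [relay.PatternVar(y), relay.PatternVar(z)])
# With the dispatch above, patterns are accepted as well as expressions.
print(bound_vars(pat))  # expected to contain y and z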
......@@ -48,8 +48,13 @@ def test_let():
def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c))
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), e.d)
def test_chain_unused_let():
......@@ -87,13 +92,6 @@ def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
def test_if():
cond = relay.const(True)
orig = relay.If(cond, e.a, e.b)
y = dead_code_elimination(orig)
assert alpha_equal(y, e.a)
def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
......@@ -102,9 +100,9 @@ def test_tuple_get_item():
if __name__ == "__main__":
test_if()
test_let()
test_used_let()
test_inline()
test_chain_unused_let()
test_recursion()
test_op_let()
......
......@@ -22,9 +22,11 @@ from tvm.relay.prelude import Prelude
import numpy as np
def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))
def test_id():
shape = (10, 10)
dtype = 'float32'
......
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
from tvm.relay import create_executor
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
result = intrp.evaluate(expr)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def dcpe(expr):
return dead_code_elimination(partial_evaluate(expr))
def test_tuple():
t = relay.TypeVar("t")
x = relay.Var("x", t)
body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
f = relay.Function([x], body, None, [t])
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
def test_const_inline():
d = relay.Var("d")
double = relay.Function([d], d + d)
orig = double(relay.const(4.0))
assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
def test_ref():
d = relay.Var("d")
r = relay.Var("r")
x = relay.Var("x")
body = relay.RefRead(r)
body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
body = relay.Let(r, relay.RefCreate(d), body)
square = relay.Function([d], body)
assert alpha_equal(dcpe(square), relay.Function([d], d * d))
def test_ad():
shape = (10, 10)
dtype = "float32"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
f = relay.Function([d], d * d)
g = dcpe(gradient(f))
m = d * d
o = relay.op.ones_like(m)
grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
assert alpha_equal(g, expected)
def test_if_ref():
shape = ()
dtype = "bool"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
r = relay.Var("r")
update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
u = relay.Var("u")
body = relay.If(d, u(), u())
eff = relay.Var("eff")
body = relay.Let(eff, body, relay.RefRead(r))
f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body)))
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor()
f_res = ex.evaluate(f)(relay.const(True))
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
def test_function_invalidate():
shape = ()
dtype = "bool"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
r = relay.Var("r")
fetch = relay.Function([], relay.RefRead(r))
fet = relay.Var("fetch")
fet_obscured = relay.Var("fetch_obscured")
u = relay.Var("u")
body = relay.If(d, fet_obscured(), fet_obscured())
body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body)
body = relay.Let(fet_obscured, relay.If(d, fet, fet), body)
body = relay.Let(fet, fetch, body)
body = relay.Let(r, relay.RefCreate(relay.const(0)), body)
f = relay.Function([d], body)
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor()
f_res = ex.evaluate(f)(relay.const(True))
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
def test_head_cons():
mod = relay.Module()
p = Prelude(mod)
def hd_impl():
a = relay.TypeVar("a")
x = relay.Var("x", p.l(a))
y = relay.Var("y")
z = relay.Var("z")
cons_case = relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(y),
relay.PatternVar(z)]),
y)
return relay.Function([x], relay.Match(x, [cons_case]), a, [a])
t = relay.TypeVar("t")
x = relay.Var("x", t)
hd = relay.Var("hd")
body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
f = relay.Function([x], body, None, [t])
f = infer_type(f, mod=mod)
res = dcpe(f)
assert alpha_equal(res, relay.Function([x], x, t, [t]))
if __name__ == '__main__':
test_tuple()
test_const_inline()
test_ref()
test_ad()
test_if_ref()
test_function_invalidate()
test_head_cons()
......@@ -154,6 +154,7 @@ def test_add():
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()
def test_let():
x = relay.Var("x")
y = relay.Var("y")
......@@ -163,6 +164,17 @@ def test_let():
check_eval(body, 8)
check_eval(to_a_normal_form(body), 8)
def test_function():
x = relay.Var("x")
f = relay.Function([x], x + x)
d = relay.const(4.0, 'float32')
anf_f = to_a_normal_form(f)
assert isinstance(anf_f, relay.Function)
check_eval(f(d), 8)
check_eval(anf_f(d), 8)
if __name__ == '__main__':
test_explicit_bound()
test_order()
......@@ -171,3 +183,4 @@ if __name__ == '__main__':
test_ref()
test_add()
test_let()
test_function()