Unverified commit 4369b7f6 by Tianqi Chen, committed by GitHub

[RELAY][PASS] General OpFusion. (#2090)

parent e470f8ea
@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
   return node;
 }
 
+/*!
+ * \brief Print node as text format.
+ * \param node The node to be printed.
+ * \param annotate An optional callback function for attaching
+ *        additional comment block to an expr.
+ * \return The text representation.
+ */
+std::string RelayPrint(
+    const NodeRef& node,
+    runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_EXPR_H_
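Usage sketch (illustrative, not part of this commit): the annotate callback replaces the default "# ty=..." comment that the text printer emits for each printed expression, as wired up in the text printer changes below. The helper function and lambda here are hypothetical.

    #include <tvm/relay/expr.h>

    // Print a Relay expression, annotating every printed node with its
    // node type key instead of the default type comment.
    std::string PrintWithTypeKey(const tvm::relay::Expr& e) {
      using tvm::relay::Expr;
      tvm::runtime::TypedPackedFunc<std::string(Expr)> annotate(
          [](Expr expr) -> std::string {
            return expr->type_key();
          });
      return tvm::relay::RelayPrint(e, annotate);
    }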
@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
   using TSelf = TypedPackedFunc<R(Args...)>;
   /*! \brief default constructor */
   TypedPackedFunc() {}
+  /*! \brief constructor from null */
+  TypedPackedFunc(std::nullptr_t null) {}  // NOLINT(*)
   /*!
    * \brief construct by wrap a PackedFunc
    *
...
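The new std::nullptr_t constructor is what lets a TypedPackedFunc parameter default to nullptr and be tested against nullptr, as RelayPrint and the text printer below do. A minimal sketch of that pattern, assuming only this header (DumpNode and its callback type are illustrative):

    #include <tvm/runtime/packed_func.h>
    #include <iostream>
    #include <string>

    using FComment = tvm::runtime::TypedPackedFunc<std::string(int)>;

    // An optional typed callback: callers may omit it entirely.
    void DumpNode(int node_id, FComment comment = nullptr) {
      std::cout << "node " << node_id;
      if (comment != nullptr) {  // the same guard the text printer uses
        std::cout << "  # " << comment(node_id);
      }
      std::cout << "\n";
    }

    // DumpNode(3);                                        -> "node 3"
    // DumpNode(3, FComment([](int) { return std::string("visited"); }));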
@@ -22,15 +22,20 @@ def register_relay_node(type_key=None):
 class RelayNode(NodeBase):
-    def astext(self):
+    """Base class of all relay node."""
+
+    def astext(self, annotate=None):
         """Get the text format of the expression.
 
+        Parameters
+        ----------
+        annotate: Optional[relay.Expr->str]
+            Optional annotate function to provide additional
+            information in the comment block.
+
         Returns
         -------
         text : str
            The text format of the expression.
         """
-        return _expr._text_print(self)
+        return _expr.RelayPrint(self, annotate)
 
 @register_relay_node
...
@@ -173,11 +173,13 @@ def build(func,
     else:
         tophub_context = autotvm.util.EmptyContext()
 
+    cfg = BuildConfig.current
     with tophub_context:
         func = optimize(func)
         # Fuse ops before running code gen
         func = ir_pass.infer_type(func)
-        func = ir_pass.fuse_ops(func)
+        func = ir_pass.fuse_ops(func, cfg.opt_level)
         # Graph code generation
         func = ir_pass.infer_type(func)
         graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
...
@@ -6,7 +6,6 @@ from numbers import Number as _Number
 import numpy as _np
 from .base import RelayNode, register_relay_node
 from . import _make
-from . import _expr
 from . import ty as _ty
 from .._ffi import base as _base
 from .. import nd as _nd
@@ -477,7 +476,7 @@ class TupleWrapper(object):
         text : str
             The text format of the tuple expression.
         """
-        return _expr._text_print(self.tuple_value)
+        return self.tuple_value.astext()
 
     def __getitem__(self, index):
         if index >= len(self):
...
@@ -259,7 +259,7 @@ def structural_hash(value):
         raise TypeError(msg)
 
-def fuse_ops(expr):
+def fuse_ops(expr, opt_level=1):
     """Fuse operators in expr together.
 
     Parameters
@@ -267,9 +267,12 @@ def fuse_ops(expr):
     expr : tvm.relay.Expr
         The input expression.
 
+    opt_level : int
+        The level of fuse optimization.
+
     Returns
     -------
     transformed_expr : tvm.relay.Expr
         Transformed expression, containing fused result.
     """
-    return _ir_pass.FuseOps(expr)
+    return _ir_pass.FuseOps(expr, opt_level)
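fuse_ops is a thin wrapper over an FFI entry point. Below is a hedged sketch of the C++ side it dispatches to; the FFI name comes from the wrapper above and the registration style mirrors the RelayPrint registration later in this commit, but the exact declaration in the fusion pass source (not shown here) is assumed.

    #include <tvm/api_registry.h>
    #include <tvm/relay/expr.h>

    namespace tvm {
    namespace relay {

    // Assumed signature of the fusion pass entry point.
    Expr FuseOps(const Expr& expr, int fuse_opt_level);

    TVM_REGISTER_API("relay._ir_pass.FuseOps")
    .set_body_typed<Expr(const Expr&, int)>(FuseOps);

    }  // namespace relay
    }  // namespace tvm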
@@ -38,11 +38,29 @@ class Arena {
   /*!
    * \brief Allocate a space from Arena for type T
    * \tparam T the data type to be allocated
+   * \note The space of T is not initialized.
    */
   template<typename T>
-  T* Alloc() {
+  T* allocate_() {
     return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
   }
+  /*!
+   * \brief Create a new instance of type T.
+   * \param args The constructor argument.
+   * \tparam T the type to be created.
+   * \tparam Args Arguments to the constructor.
+   *
+   * \return The allocated object.
+   * \note The type T must be simple type, or only contain
+   *       memory allocated from the same arena.
+   *       Otherwise the destructor needs to be called explicitly.
+   */
+  template<typename T, typename... Args>
+  T* make(Args&&... args) {
+    T* ptr = allocate_<T>();
+    new (ptr) T(std::forward<Args>(args)...);
+    return ptr;
+  }
 
  private:
  // page size 16 KB
@@ -87,6 +105,44 @@ class Arena {
   }
 };
+
+/*!
+ * \brief Linked list node
+ * \tparam T the content data type
+ */
+template<typename T>
+struct LinkNode {
+  /*! \brief The content value */
+  T value;
+  /*! \brief pointer to the next location */
+  LinkNode<T>* next{nullptr};
+};
+
+/*!
+ * \brief LinkedList structure
+ * \tparam T the content data type
+ * \note This is a simple data structure that can be used together with the arena.
+ * \sa LinkNode
+ */
+template<typename T>
+struct LinkedList {
+  /*! \brief Head pointer */
+  LinkNode<T>* head{nullptr};
+  /*! \brief Tail pointer */
+  LinkNode<T>* tail{nullptr};
+  /*!
+   * \brief Push a new node to the end of the linked list.
+   * \param node The node to be pushed.
+   */
+  void Push(LinkNode<T>* node) {
+    node->next = nullptr;
+    if (this->tail != nullptr) {
+      this->tail->next = node;
+      this->tail = node;
+    } else {
+      head = tail = node;
+    }
+  }
+};
+
 }  // namespace common
 }  // namespace tvm
 #endif  // TVM_COMMON_ARENA_H_
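The two additions compose: Arena::make placement-news an object into arena memory, and LinkNode/LinkedList chain such objects without per-node ownership, which is exactly how TypeSolver wires relations further down in this commit. A small self-contained sketch (the int payload and Demo function are illustrative):

    #include "common/arena.h"  // include path assumed

    using tvm::common::Arena;
    using tvm::common::LinkNode;
    using tvm::common::LinkedList;

    void Demo() {
      Arena arena;
      LinkedList<int> list;
      for (int i = 0; i < 3; ++i) {
        // Constructed in arena memory; no per-node delete is needed because
        // the arena frees everything at once when it is destructed.
        LinkNode<int>* n = arena.make<LinkNode<int> >();
        n->value = i;
        list.Push(n);
      }
      for (LinkNode<int>* n = list.head; n != nullptr; n = n->next) {
        // visits n->value: 0, 1, 2
      }
    }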
@@ -109,6 +109,29 @@ class ScheduleGetter :
     return {};
   }
 
+  Array<Tensor> VisitExpr_(const ConstantNode* op) final {
+    CHECK(op->is_scalar());
+    void* data = op->data->data;
+    DataType dtype = TVMType2Type(op->data->dtype);
+    Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+        if (dtype == Int(32)) {
+          return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+        } else if (dtype == Int(64)) {
+          return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+        } else if (dtype == Float(32)) {
+          return make_const(dtype, static_cast<const float*>(data)[0]);
+        } else if (dtype == Float(64)) {
+          return make_const(dtype, static_cast<const double*>(data)[0]);
+        } else if (dtype == Bool()) {
+          return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+        } else {
+          LOG(FATAL) << "not handled";
+          return tvm::Expr();
+        }
+      });
+    return {value};
+  }
+
   Array<Tensor> VisitExpr_(const CallNode* call_node) final {
     static auto fcompute =
         Op::GetAttr<FTVMCompute>("FTVMCompute");
...
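This visitor is needed because, after fusion, scalar constants appear inside primitive functions (see the fuse_ops tests below, where relay.const lands inside the fused function), so the schedule getter must lower ConstantNode itself rather than receive it as a parameter. The core trick is the empty shape passed to tvm::compute, which yields a 0-dimensional tensor whose body is a compile-time constant. A distilled sketch, where ScalarConst is a hypothetical helper and compute/make_const are the same APIs the visitor uses:

    #include <tvm/operation.h>  // tvm::compute; header for make_const assumed

    tvm::Tensor ScalarConst(float value) {
      return tvm::compute(
          {},  // empty shape: a scalar, 0-d tensor
          [=](const tvm::Array<tvm::Var>&) {
            return tvm::make_const(tvm::Float(32), value);
          },
          "scalar_const");
    }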
@@ -125,6 +125,8 @@ class TextPrinter :
     public TypeFunctor<void (const Type&, std::ostream& os)>,  // NOLINT(*)
     public AttrFunctor<void (const NodeRef&, std::ostream& os)> {  // NOLINT(*)
  public:
+  explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate)
+      : annotate_(annotate) {}
   /*!
    * \brief Print a node to string.
    * \param node The node to be printed.
@@ -279,11 +281,11 @@ class TextPrinter :
   TextValue VisitExpr_(const CallNode* op) final {
     // possibly through meta-data
-    TextValue call_op = GetValue(op->op);
     std::vector<TextValue> args;
     for (Expr arg : op->args) {
       args.emplace_back(GetValue(arg));
     }
+    TextValue call_op = GetValue(op->op);
     TextValue id = this->AllocTempVar();
     this->PrintIndent();
@@ -532,7 +534,9 @@ class TextPrinter :
    */
   void PrintOptionalInfo(const Expr& expr) {
     // additional information in comment.
-    if (expr->checked_type_.defined()) {
+    if (annotate_ != nullptr) {
+      stream_ << " # " << annotate_(expr);
+    } else if (expr->checked_type_.defined()) {
       stream_ << " # ty=";
       this->PrintType(expr->checked_type(), stream_);
     }
@@ -678,7 +682,11 @@ class TextPrinter :
       name = "%" + name;
     }
     TextValue val(GetUniqueName(name));
-    CHECK(!memo_.count(var)) << "Duplicated variable " << var;
+    // still print if the IR is malformed, but show the error.
+    if (memo_.count(var)) {
+      memo_[var] = TextValue(val.name + "-malformed-ir");
+      return memo_[var];
+    }
     memo_[var] = val;
     return val;
   }
@@ -686,6 +693,8 @@ class TextPrinter :
  private:
   class AttrPrinter;
   friend class AttrPrinter;
+  /*! \brief additional comment function */
+  runtime::TypedPackedFunc<std::string(Expr)> annotate_;
   /*! \brief meta data context */
   TextMetaDataContext meta_;
   /*! \brief Check whether scope is still valid */
@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
   os << ", " << meta_.GetMetaNode(attrs);
 }
 
-std::string RelayPrint(const NodeRef& node) {
-  return TextPrinter().Print(node);
+std::string RelayPrint(const NodeRef& node,
+                       runtime::TypedPackedFunc<std::string(Expr)> annotate) {
+  return TextPrinter(annotate).Print(node);
 }
 
-TVM_REGISTER_API("relay._expr._text_print")
-.set_body_typed<std::string(const NodeRef&)>(RelayPrint);
+TVM_REGISTER_API("relay._expr.RelayPrint")
+.set_body_typed<std::string(
+    const NodeRef&,
+    runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
 
 }  // namespace relay
 }  // namespace tvm
@@ -10,6 +10,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/expr_functor.h>
 #include "pattern_util.h"
+#include "pass_util.h"
 #include "../op/nn/layout.h"
 
 namespace tvm {
@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
 //----------------------------------------------
 // Generic Visitors for FScaleAxisBackward
 //----------------------------------------------
-/*!
- * \brief Get reference counter of each internal ExprNode in body.
- * \param body The body expression.
- * \return The reference count mapping.
- */
-std::unordered_map<const Node*, size_t>
-GetExprRefCount(const Expr& body) {
-  class ExprRefCounter : private ExprVisitor {
-   public:
-    std::unordered_map<const Node*, size_t>
-    Get(const Expr& body) {
-      this->VisitExpr(body);
-      return std::move(this->visit_counter_);
-    }
-  };
-  return ExprRefCounter().Get(body);
-}
 
 class BackwardPrep : private ExprVisitor {
  public:
...
/*!
 * Copyright (c) 2018 by Contributors.
 *
 * \file tvm/relay/pass/pass_util.h
 * \brief Utilities for writing passes.
 */
#ifndef TVM_RELAY_PASS_PASS_UTIL_H_
#define TVM_RELAY_PASS_PASS_UTIL_H_

#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
#include <unordered_map>

namespace tvm {
namespace relay {

/*!
 * \brief Get reference counter of each internal ExprNode in body.
 * \param body The body expression.
 * \return The reference count mapping.
 */
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body);

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_PASS_PASS_UTIL_H_
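A usage sketch of the relocated helper: passes ask it how many times an intermediate node is referenced, for example to decide whether a value with a single consumer can be folded into that consumer (the fold-scale-axis pass, which now includes this header, asks exactly this kind of question). CanInline is a hypothetical helper, not part of this commit:

    #include <tvm/relay/expr.h>
    #include "pass_util.h"  // include path assumed

    bool CanInline(const tvm::relay::Expr& body, const tvm::relay::Expr& e) {
      auto ref_count = tvm::relay::GetExprRefCount(body);
      auto it = ref_count.find(e.get());
      // Referenced at most once: safe to fold into the single consumer.
      return it == ref_count.end() || it->second <= 1;
    }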
@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator {
     VarNode* new_var = (
         std::is_base_of<VarNode, T>::value ?
         static_cast<VarNode*>(new_e.node_.get()) : nullptr);
+    FunctionNode* new_fn = (
+        std::is_base_of<FunctionNode, T>::value ?
+        static_cast<FunctionNode*>(new_e.node_.get()) : nullptr);
 
     // check if we need update the new_e
     bool need_update_type = !checked_type.same_as(new_e->checked_type_);
@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator {
         update_missing_type_annotation_ &&
         !new_var->type_annotation.defined());
 
-    if (!need_update_type && !need_update_var && !need_update_call) return new_e;
+    bool need_update_fn = (
+        std::is_base_of<FunctionNode, T>::value &&
+        update_missing_type_annotation_ &&
+        !new_fn->ret_type.defined());
+
+    if (!need_update_type &&
+        !need_update_var &&
+        !need_update_call &&
+        !need_update_fn) {
+      return new_e;
+    }
 
     if (!new_e.node_.unique()) {
       // Copy on write optimization
@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator {
       new_var = (
          std::is_base_of<VarNode, T>::value ?
          static_cast<VarNode*>(new_e.node_.get()) : nullptr);
+      new_fn = (
+         std::is_base_of<FunctionNode, T>::value ?
+         static_cast<FunctionNode*>(new_e.node_.get()) : nullptr);
     }
 
     // attach the information.
@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator {
     if (need_update_var) {
       new_var->type_annotation = checked_type;
     }
+    if (need_update_fn) {
+      auto* fn_type = checked_type.as<FuncTypeNode>();
+      CHECK(fn_type != nullptr);
+      new_fn->ret_type = fn_type->ret_type;
+    }
     return new_e;
   }
...
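The resolver's templated helper uses a compile-time dispatch idiom worth spelling out: because the method is templated over the concrete node type T, std::is_base_of<...>::value lets it obtain a typed pointer only when T really is that kind of node, with no virtual calls or dynamic_cast. A self-contained sketch of the idiom (all names here are hypothetical, not TVM's):

    #include <iostream>
    #include <type_traits>

    struct BaseNode { virtual ~BaseNode() = default; };
    struct VarLike : BaseNode { const char* annotation{nullptr}; };
    struct FnLike : BaseNode { const char* ret_type{nullptr}; };

    template <typename T>
    void AttachType(BaseNode* n, const char* ty) {
      // The condition is a compile-time constant, so only the matching
      // branch can ever execute; the downcasts themselves are always legal.
      VarLike* as_var = std::is_base_of<VarLike, T>::value
                            ? static_cast<VarLike*>(n) : nullptr;
      FnLike* as_fn = std::is_base_of<FnLike, T>::value
                          ? static_cast<FnLike*>(n) : nullptr;
      if (as_var != nullptr && as_var->annotation == nullptr) as_var->annotation = ty;
      if (as_fn != nullptr && as_fn->ret_type == nullptr) as_fn->ret_type = ty;
    }

    int main() {
      VarLike v;
      AttachType<VarLike>(&v, "float32");
      std::cout << v.annotation << "\n";  // prints: float32
    }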
@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
 void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
   if (auto *op = constraint.as<TypeRelationNode>()) {
     // create a new relation node.
-    RelationNode* rnode = make<RelationNode>();
+    RelationNode* rnode = arena_.make<RelationNode>();
     rnode->rel = GetRef<TypeRelation>(op);
     rel_nodes_.push_back(rnode);
     // populate the type information.
     for (size_t i = 0; i < op->args.size(); ++i) {
       // insert link to the type list
-      LinkNode<TypeNode*>* tlink = make<LinkNode<TypeNode*> >();
+      LinkNode<TypeNode*>* tlink = arena_.make<LinkNode<TypeNode*> >();
       TypeNode* tnode = GetTypeNode(op->args[i]);
       tlink->value = tnode;
       rnode->type_list.Push(tlink);
       // insert type->relation node
-      LinkNode<RelationNode*>* rlink = make<LinkNode<RelationNode*> >();
+      LinkNode<RelationNode*>* rlink = arena_.make<LinkNode<RelationNode*> >();
       rlink->value = rnode;
       tnode->rel_list.Push(rlink);
     }
...
@@ -16,6 +16,8 @@
 namespace tvm {
 namespace relay {
 
+using common::LinkNode;
+using common::LinkedList;
+
 /*!
  * \brief Interface of type solver used in type inference.
  *
@@ -70,41 +72,6 @@ class TypeSolver {
   // All the objects in the structure are managed by an arena allocator,
   // which releases the memory upon destruction of the type solver.
   /*!
-   * \brief Link list node
-   * \tparam T the content data type
-   */
-  template<typename T>
-  struct LinkNode {
-    /*! \brief The content value */
-    T value;
-    /*! \brief pointer to the next location */
-    LinkNode<T>* next{nullptr};
-  };
-  /*!
-   * \brief LinkedList structure
-   * \tparam T the content data type
-   */
-  template<typename T>
-  struct LinkedList {
-    /*! \brief Head pointer */
-    LinkNode<T>* head{nullptr};
-    /*! \brief Tail pointer */
-    LinkNode<T>* tail{nullptr};
-    /*!
-     * \brief Push a new node to the end of the linked list.
-     * \param node The node to be pushed.
-     */
-    void Push(LinkNode<T>* node) {
-      node->next = nullptr;
-      if (this->tail != nullptr) {
-        this->tail->next = node;
-        this->tail = node;
-      } else {
-        head = tail = node;
-      }
-    }
-  };
-  /*!
    * \brief type node struct
    * TypeNode implements a union-find data structure (via parent)
    * that can unify the same types to the same resolved_type.
@@ -165,18 +132,6 @@ class TypeSolver {
   /*! \brief Reporter that reports back to self */
   TypeReporter reporter_;
   /*!
-   * \brief Create function to create a new node ptr via arena
-   * \tparam The type parameter
-   * \return The node pointer.
-   */
-  template<typename T>
-  T* make() {
-    T* ptr = arena_.Alloc<T>();
-    // call constructor
-    new (ptr) T();
-    return ptr;
-  }
-  /*!
    * \brief Get the TypeNode that corresponds to t.
    * If it does not exist, create a new one.
    * \return The type node.
@@ -186,7 +141,7 @@ class TypeSolver {
     if (it != tmap_.end()) {
       return it->second->FindRoot();
     } else {
-      TypeNode* n = make<TypeNode>();
+      TypeNode* n = arena_.make<TypeNode>();
       type_nodes_.push_back(n);
       n->resolved_type = t;
       tmap_[t] = n;
...
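The "union-find data structure (via parent)" mentioned in the TypeNode comment above is the heart of unification: every type node points at a parent, and finding the representative compresses paths along the way. A self-contained sketch of the idea (UfNode is illustrative, not the TypeNode in this file):

    #include <cassert>

    struct UfNode {
      UfNode* parent{this};
      // Find the representative, compressing the path behind us.
      UfNode* FindRoot() {
        UfNode* root = this;
        while (root->parent != root) root = root->parent;
        for (UfNode* p = this; p != root;) {
          UfNode* next = p->parent;
          p->parent = root;
          p = next;
        }
        return root;
      }
      // Merge two equivalence classes (no union-by-rank in this sketch).
      void Unify(UfNode* other) {
        UfNode* a = this->FindRoot();
        UfNode* b = other->FindRoot();
        if (a != b) a->parent = b;
      }
    };

    int main() {
      UfNode x, y, z;
      x.Unify(&y);
      y.Unify(&z);
      assert(x.FindRoot() == z.FindRoot());  // all three share one root
    }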
@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
     }
   });
 
+/*!
+ * \brief Get reference counter of each internal ExprNode in body.
+ * \param body The body expression.
+ * \return The reference count mapping.
+ */
+std::unordered_map<const Node*, size_t>
+GetExprRefCount(const Expr& body) {
+  class ExprRefCounter : private ExprVisitor {
+   public:
+    std::unordered_map<const Node*, size_t>
+    Get(const Expr& body) {
+      this->VisitExpr(body);
+      return std::move(this->visit_counter_);
+    }
+  };
+  return ExprRefCounter().Get(body);
+}
+
 }  // namespace relay
 }  // namespace tvm
@@ -33,6 +33,7 @@ def test_env():
     text = env.astext()
     assert "def @myf" in text
     assert "%1 = add(%0, %0) # ty=float32" in text
+    show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
     show(text)
...
@@ -46,6 +46,8 @@ def test_fold_fwd_simple():
     weight = relay.var("weight", type_dict["weight"])
     y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
     y1_expected = expected(x, weight, in_bias, in_scale, channels)
+    y1_folded = relay.ir_pass.infer_type(y1_folded)
+    y1_expected = relay.ir_pass.infer_type(y1_expected)
     assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 2)
@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path():
     type_dict = {x.name_hint:x.checked_type for x in y1.params}
     weight = relay.var("weight", type_dict["weight"])
     y1_expected = expected(x, weight, in_bias, in_scale, channels)
+    y1_folded = relay.ir_pass.infer_type(y1_folded)
+    y1_expected = relay.ir_pass.infer_type(y1_expected)
     assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 3), 3)
@@ -194,6 +198,8 @@ def test_fold_bwd_simple():
     weight = relay.var("weight", type_dict["weight"])
     y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
     y1_expected = expected(x, weight, out_bias, out_scale, channels)
+    y1_folded = relay.ir_pass.infer_type(y1_folded)
+    y1_expected = relay.ir_pass.infer_type(y1_expected)
     assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 8)
@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path():
     weight = relay.var("weight", type_dict["weight"])
     y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
     y1_expected = expected(x, weight, out_bias, out_scale, channels)
+    y1_folded = relay.ir_pass.infer_type(y1_folded)
+    y1_expected = relay.ir_pass.infer_type(y1_expected)
     assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 8)
...
@@ -3,15 +3,103 @@ from tvm import relay
 
 def test_fuse_simple():
     """Simple testcase."""
-    x = relay.var("x", shape=(10, 20))
-    y = relay.add(x, x)
-    z = relay.exp(y)
+    def before():
+        x = relay.var("x", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.exp(y)
+        return relay.Function([x], z)
+
+    def expected():
+        x = relay.var("p", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.exp(y)
+        f1 = relay.Function([x], z)
+        x = relay.var("x", shape=(10, 20))
+        y = relay.Call(f1, [x])
+        return relay.Function([x], y)
+
+    z = before()
     z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
     zz = relay.ir_pass.fuse_ops(zz)
     zz = relay.ir_pass.infer_type(zz)
-    zz.astext()
+    after = relay.ir_pass.infer_type(expected())
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_conv2d_fuse():
+    """Test fusion case of conv2d"""
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        y = relay.nn.conv2d(x, relay.var("w1"),
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            channels=16)
+        # this is the next dominator.
+        y1 = relay.add(relay.const(1, "float32"), y)
+        y = relay.add(y, y1)
+        # second path
+        z2 = relay.nn.conv2d(y, relay.var("w2"),
+                             kernel_size=(1, 1),
+                             padding=(0, 0),
+                             channels=16)
+        z3 = relay.nn.conv2d(y, relay.var("w3"),
+                             kernel_size=(3, 3),
+                             padding=(1, 1),
+                             channels=16)
+        # add can only be fused to z2
+        z = relay.add(z2, z3)
+        return relay.Function(relay.ir_pass.free_vars(z), z)
+
+    def expected(dshape):
+        # segment 1
+        x = relay.var("p0", shape=dshape)
+        w = relay.var("p1")
+        y = relay.nn.conv2d(x, w,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            channels=16)
+        y1 = relay.add(relay.const(1, "float32"), y)
+        y = relay.add(y, y1)
+        f1 = relay.Function([x, w], y)
+        # segment 2
+        x = relay.var("p0", shape=dshape)
+        w = relay.var("p1")
+        z2 = relay.nn.conv2d(x, w,
+                             kernel_size=(3, 3),
+                             padding=(1, 1),
+                             channels=16)
+        f2 = relay.Function([x, w], z2)
+        # segment 3
+        x = relay.var("p0", shape=dshape)
+        w = relay.var("p1")
+        offset = relay.var("p2", shape=dshape)
+        z3 = relay.nn.conv2d(x, w,
+                             kernel_size=(1, 1),
+                             padding=(0, 0),
+                             channels=16)
+        z3 = relay.add(z3, offset)
+        f3 = relay.Function([x, w, offset], z3)
+        # compose
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f1, [x, relay.var("w1")])
+        z2 = relay.Call(f2, [y, relay.var("w3")])
+        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
+        z = z3
+        return relay.Function(relay.ir_pass.free_vars(z), z)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
 
 if __name__ == "__main__":
     test_fuse_simple()
+    test_conv2d_fuse()