Unverified Commit 4369b7f6 by Tianqi Chen Committed by GitHub

[RELAY][PASS] General OpFusion. (#2090)

parent e470f8ea
...@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const { ...@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
return node; return node;
} }
/*!
* \brief Print node as text format.
* \param node The node to be printed.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string RelayPrint(
const NodeRef& node,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_EXPR_H_ #endif // TVM_RELAY_EXPR_H_
...@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> { ...@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
using TSelf = TypedPackedFunc<R(Args...)>; using TSelf = TypedPackedFunc<R(Args...)>;
/*! \brief default constructor */ /*! \brief default constructor */
TypedPackedFunc() {} TypedPackedFunc() {}
/*! \brief constructor from null */
TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*! /*!
* \brief construct by wrap a PackedFunc * \brief construct by wrap a PackedFunc
* *
......
...@@ -22,15 +22,20 @@ def register_relay_node(type_key=None): ...@@ -22,15 +22,20 @@ def register_relay_node(type_key=None):
class RelayNode(NodeBase): class RelayNode(NodeBase):
def astext(self): """Base class of all relay node."""
def astext(self, annotate=None):
"""Get the text format of the expression. """Get the text format of the expression.
Returns Returns
------- -------
text : str text : str
The text format of the expression. The text format of the expression.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
""" """
return _expr._text_print(self) return _expr.RelayPrint(self, annotate)
@register_relay_node @register_relay_node
......
...@@ -173,11 +173,13 @@ def build(func, ...@@ -173,11 +173,13 @@ def build(func,
else: else:
tophub_context = autotvm.util.EmptyContext() tophub_context = autotvm.util.EmptyContext()
cfg = BuildConfig.current
with tophub_context: with tophub_context:
func = optimize(func) func = optimize(func)
# Fuse ops before running code gen # Fuse ops before running code gen
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func) func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation # Graph code generation
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target) graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
......
...@@ -6,7 +6,6 @@ from numbers import Number as _Number ...@@ -6,7 +6,6 @@ from numbers import Number as _Number
import numpy as _np import numpy as _np
from .base import RelayNode, register_relay_node from .base import RelayNode, register_relay_node
from . import _make from . import _make
from . import _expr
from . import ty as _ty from . import ty as _ty
from .._ffi import base as _base from .._ffi import base as _base
from .. import nd as _nd from .. import nd as _nd
...@@ -477,7 +476,7 @@ class TupleWrapper(object): ...@@ -477,7 +476,7 @@ class TupleWrapper(object):
text : str text : str
The text format of the tuple expression. The text format of the tuple expression.
""" """
return _expr._text_print(self.tuple_value) return self.tuple_value.astext()
def __getitem__(self, index): def __getitem__(self, index):
if index >= len(self): if index >= len(self):
......
...@@ -259,7 +259,7 @@ def structural_hash(value): ...@@ -259,7 +259,7 @@ def structural_hash(value):
raise TypeError(msg) raise TypeError(msg)
def fuse_ops(expr): def fuse_ops(expr, opt_level=1):
"""Fuse operators in expr together. """Fuse operators in expr together.
Parameters Parameters
...@@ -267,9 +267,12 @@ def fuse_ops(expr): ...@@ -267,9 +267,12 @@ def fuse_ops(expr):
expr : tvm.relay.Expr expr : tvm.relay.Expr
The input expression. The input expression.
opt_level : int
The level of fuse optimization.
Returns Returns
------- -------
transformed_expr : tvm.relay.Expr transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result. Transformed expression, containing fused result.
""" """
return _ir_pass.FuseOps(expr) return _ir_pass.FuseOps(expr, opt_level)
...@@ -38,11 +38,29 @@ class Arena { ...@@ -38,11 +38,29 @@ class Arena {
/*! /*!
* \brief Allocate a space from Arena for type T * \brief Allocate a space from Arena for type T
* \param T the data type to be allocated * \param T the data type to be allocated
* \note The space of T is not initialized.
*/ */
template<typename T> template<typename T>
T* Alloc() { T* allocate_() {
return static_cast<T*>(Alloc(sizeof(T), alignof(T))); return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
} }
/*!
* \brief Create a new instance of type T.
* \param args The constructor argument.
* \tparam T the type to be created.
* \tparam Args Arguments to the constructor.
*
* \return The allocated object.
* \note The type T must be simple type, or only contain
* memory allocated from the same arena.
* Otherwise the destructor needs to be called explicitly.
*/
template<typename T, typename... Args>
T* make(Args&&... args) {
T* ptr = allocate_<T>();
new (ptr) T(std::forward<Args>(args)...);
return ptr;
}
private: private:
// page size 16 KB // page size 16 KB
...@@ -87,6 +105,44 @@ class Arena { ...@@ -87,6 +105,44 @@ class Arena {
} }
}; };
/*!
* \brief Link list node
* \tparam T the content data type
*/
template<typename T>
struct LinkNode {
/*! \brief The content value */
T value;
/*! \brief pointer to the next location */
LinkNode<T>* next{nullptr};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
* \note This is a simple data structure that can be used together with the arena.
* \sa LinkNode
*/
template<typename T>
struct LinkedList {
/*! \brief Head pointer */
LinkNode<T>* head{nullptr};
/*! \brief Tail pointer */
LinkNode<T>* tail{nullptr};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void Push(LinkNode<T>* node) {
node->next = nullptr;
if (this->tail != nullptr) {
this->tail->next = node;
this->tail = node;
} else {
head = tail = node;
}
}
};
} // namespace common } // namespace common
} // namespace tvm } // namespace tvm
#endif // TVM_COMMON_ARENA_H_ #endif // TVM_COMMON_ARENA_H_
...@@ -109,6 +109,29 @@ class ScheduleGetter : ...@@ -109,6 +109,29 @@ class ScheduleGetter :
return {}; return {};
} }
Array<Tensor> VisitExpr_(const ConstantNode* op) final {
CHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = TVMType2Type(op->data->dtype);
Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
if (dtype == Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::Expr();
}
});
return {value};
}
Array<Tensor> VisitExpr_(const CallNode* call_node) final { Array<Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fcompute = static auto fcompute =
Op::GetAttr<FTVMCompute>("FTVMCompute"); Op::GetAttr<FTVMCompute>("FTVMCompute");
......
...@@ -125,6 +125,8 @@ class TextPrinter : ...@@ -125,6 +125,8 @@ class TextPrinter :
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*) public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*) public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public: public:
explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate)
: annotate_(annotate) {}
/*! /*!
* \brief Print a node to string. * \brief Print a node to string.
* \param node. * \param node.
...@@ -279,11 +281,11 @@ class TextPrinter : ...@@ -279,11 +281,11 @@ class TextPrinter :
TextValue VisitExpr_(const CallNode* op) final { TextValue VisitExpr_(const CallNode* op) final {
// possibly through meta-data // possibly through meta-data
TextValue call_op = GetValue(op->op);
std::vector<TextValue> args; std::vector<TextValue> args;
for (Expr arg : op->args) { for (Expr arg : op->args) {
args.emplace_back(GetValue(arg)); args.emplace_back(GetValue(arg));
} }
TextValue call_op = GetValue(op->op);
TextValue id = this->AllocTempVar(); TextValue id = this->AllocTempVar();
this->PrintIndent(); this->PrintIndent();
...@@ -532,7 +534,9 @@ class TextPrinter : ...@@ -532,7 +534,9 @@ class TextPrinter :
*/ */
void PrintOptionalInfo(const Expr& expr) { void PrintOptionalInfo(const Expr& expr) {
// additional information in comment. // additional information in comment.
if (expr->checked_type_.defined()) { if (annotate_ != nullptr) {
stream_ << " # " << annotate_(expr);
} else if (expr->checked_type_.defined()) {
stream_ << " # ty="; stream_ << " # ty=";
this->PrintType(expr->checked_type(), stream_); this->PrintType(expr->checked_type(), stream_);
} }
...@@ -678,7 +682,10 @@ class TextPrinter : ...@@ -678,7 +682,10 @@ class TextPrinter :
name = "%" + name; name = "%" + name;
} }
TextValue val(GetUniqueName(name)); TextValue val(GetUniqueName(name));
CHECK(!memo_.count(var)) << "Duplicated variable " << var; // still print if ir is malformed, but show the error.
if (memo_.count(var)) {
memo_[var] = TextValue(val.name + "-malformed-ir");
}
memo_[var] = val; memo_[var] = val;
return val; return val;
} }
...@@ -686,6 +693,8 @@ class TextPrinter : ...@@ -686,6 +693,8 @@ class TextPrinter :
private: private:
class AttrPrinter; class AttrPrinter;
friend class AttrPrinter; friend class AttrPrinter;
/*! \brief additional comment function */
runtime::TypedPackedFunc<std::string(Expr)> annotate_;
/*! \brief meta data context */ /*! \brief meta data context */
TextMetaDataContext meta_; TextMetaDataContext meta_;
/*! \brief Check whether scope is still valid */ /*! \brief Check whether scope is still valid */
...@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op, ...@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
os << ", " << meta_.GetMetaNode(attrs); os << ", " << meta_.GetMetaNode(attrs);
} }
std::string RelayPrint(const NodeRef& node) { std::string RelayPrint(const NodeRef& node,
return TextPrinter().Print(node); runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return TextPrinter(annotate).Print(node);
} }
TVM_REGISTER_API("relay._expr._text_print") TVM_REGISTER_API("relay._expr.RelayPrint")
.set_body_typed<std::string(const NodeRef&)>(RelayPrint); .set_body_typed<std::string(
const NodeRef&,
runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "pattern_util.h" #include "pattern_util.h"
#include "pass_util.h"
#include "../op/nn/layout.h" #include "../op/nn/layout.h"
namespace tvm { namespace tvm {
...@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc< ...@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
//---------------------------------------------- //----------------------------------------------
// Generic Visitors for FScaleAxisBackward // Generic Visitors for FScaleAxisBackward
//---------------------------------------------- //----------------------------------------------
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map<const Node*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
};
return ExprRefCounter().Get(body);
}
class BackwardPrep : private ExprVisitor { class BackwardPrep : private ExprVisitor {
public: public:
......
...@@ -9,13 +9,686 @@ ...@@ -9,13 +9,686 @@
#include <tvm/ir_operator.h> #include <tvm/ir_operator.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include "../../common/arena.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// Simple fuser that only makes each operator function as primitive. /*
class SimpleFuser : public ExprMutator { Note on Fusing algorithm:
The main challenge of genenral fusor is to handle possible diamond shape branches,
in the following graph, conv2d can be fused to elemwise add.
conv2d
/ | \
/ | \
op op op
\ | /
\ | /
elemwise add
|
However, at the point of conv2d we do not necessarily know that all its future path
will merge at the elemwise add. The new fusor algorithm applies post-dominator analysis.
The immediate post-dominator of a node defined by the closest node where all the future path goes into.
In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm is as follows:
- Construct a DAG of dataflow graph for dominator analysis
- Construct a post-dominator tree which gives immediate post dominator of each node.
- Run fusion algorithm with the given post-dominator information.
Note that, because we run analysis on a DAG, we use a single pass post-dominator
tree construction algorithm via LCA, which is simpler than the full version that handles cycles.
The fusion algorithm traverses from each node and checks if it can be fused to its
immediate post dominator. It has to check the following things:
- CheckPath: check all the path between a node and its immediate post-dominator
satiesfies the fuse condition.
- Note that these intermediate node can already be fused with another nodes, the algorithm
will still run correctly.
- CommitFuse: mark all the nodes between source and post-dominator as the same group.
- We use an Union-Find data structure to manage the groups.
*/
using common::LinkNode;
using common::LinkedList;
/*!
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
*
* This data structure only captures the dataflow fragement and
* could ignore blocks like let by simply ordering each dataflow block
* and mark the output node as extern_ref;
*/
class IndexedForwardGraph {
public:
struct Node;
/*!
* The forward edge in the dataflow graph.
*/
struct Edge {
/*! \brief The corresponding node */
Node* node{nullptr};
/*! \brief The respective pattern of this op */
OpPatternKind pattern{kOpaque};
};
/*! \brief A node in the graph. */
struct Node {
/*! \brief weak reference to the corresponding edge. */
const tvm::Node* ref{nullptr};
/*! \brief The index of the node in topological order. */
size_t index{0};
/*! \brief Whether this node is referenced by external source */
bool extern_ref{false};
/*! \brief The general pattern in the node */
OpPatternKind pattern{kOpaque};
/*! \brief The outputs of the node. */
LinkedList<Edge> outputs;
};
/*! \brief The node map that maps node to graph */
std::unordered_map<const tvm::Node*, Node*> node_map;
/*! \brief All the nodes in post DFS order */
std::vector<Node*> post_dfs_order;
/*! \brief Dump the graph into string. */
void DebugDump() {
std::ostringstream os;
for (size_t i = 0; i < post_dfs_order.size(); ++i) {
Node* node = post_dfs_order[i];
os << "node[" << i << "], "
<< GetRef<NodeRef>(node->ref)
<< " outputs=[";
for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
os << link->value.node->index << ", ";
}
os << "]\n";
}
LOG(INFO) << os.str();
}
/*!
* \brief create a indexed forward graph.
* \param arena The arena used for data allocation.
* \param body The body of the expression to create a graph.
*/
static IndexedForwardGraph Create(common::Arena* arena, const Expr& body);
private:
class Creator;
};
// Creator of post dominator tree of the dataflow
class IndexedForwardGraph::Creator : private ExprVisitor {
public:
explicit Creator(common::Arena* arena)
: arena_(arena) {}
IndexedForwardGraph Prepare(const Expr& body) {
this->Update(body, nullptr, kOpaque);
this->VisitExpr(body);
return std::move(graph_);
}
private:
/*! \brief allocator of all the internal node object */
common::Arena* arena_;
// The output.
IndexedForwardGraph graph_;
// attribute equal comparator
AttrsEqual attr_equal_;
// Update the message stored at the node.
void Update(const Expr& node,
IndexedForwardGraph::Node* parent,
OpPatternKind pattern) {
const tvm::Node* key = node.get();
IndexedForwardGraph::Node* current;
auto it = graph_.node_map.find(key);
if (it != graph_.node_map.end()) {
current = it->second;
} else {
current = arena_->make<IndexedForwardGraph::Node>();
graph_.node_map[key] = current;
}
if (parent != nullptr) {
auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge> >();
link->value.node = parent;
link->value.pattern = pattern;
current->outputs.Push(link);
} else {
current->extern_ref = true;
}
}
void AddNode(const tvm::Node* key) {
auto it = graph_.node_map.find(key);
CHECK(it != graph_.node_map.end())
<< "Cannot find node " << GetRef<NodeRef>(key);
IndexedForwardGraph::Node* node = it->second;
CHECK(node->ref == nullptr);
node->ref = key;
node->index = graph_.post_dfs_order.size();
graph_.post_dfs_order.push_back(node);
}
// Post order tree
void VisitExpr_(const FunctionNode* op) {
for (auto param : op->params) {
this->Update(param, nullptr, kOpaque);
}
this->Update(op->body, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const ConstantNode* op) {
this->AddNode(op);
Node* node = graph_.node_map.at(op);
DataType dtype = TVMType2Type(op->data->dtype);
// This rule must be consistent with code generator.
bool is_simple_const = (
dtype == Int(32) ||
dtype == Int(64) ||
dtype == Float(32) ||
dtype == Float(64) ||
dtype == Bool());
if (op->is_scalar() && is_simple_const) {
node->pattern = kElemWise;
} else {
// for now, mark non-scalar constant
// as opaque, we will not choose to fuse it.
node->pattern = kOpaque;
}
}
void VisitExpr_(const CallNode* call) {
CHECK(graph_.node_map.count(call));
Node* node = graph_.node_map.at(call);
static auto fpattern =
Op::GetAttr<TOpPattern>("TOpPattern");
// setup pattern.
OpPatternKind op_pattern = kOpaque;
if (const OpNode* opnode = call->op.as<OpNode>()) {
op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
}
node->pattern = op_pattern;
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
const auto* arg_type =
call->args[i]->checked_type().as<TensorTypeNode>();
// specifically check if result type
OpPatternKind edge_pattern = op_pattern;
if (edge_pattern == kBroadcast &&
arg_type != nullptr &&
rtype != nullptr &&
attr_equal_(rtype->shape, arg_type->shape)) {
edge_pattern = kElemWise;
}
this->Update(call->args[i], node, edge_pattern);
}
ExprVisitor::VisitExpr_(call);
this->AddNode(call);
}
void VisitExpr_(const TupleNode* op) {
for (const Expr& field : op->fields) {
this->Update(field, nullptr, kOpaque);
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const TupleGetItemNode* op) {
CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op);
this->Update(op->tuple, node, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const VarNode* op) {
this->AddNode(op);
}
void VisitExpr_(const LetNode* op) {
// do not fuse through let.
this->Update(op->var, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
this->Update(op->body, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const IfNode* op) {
// do not fuse through if.
this->Update(op->cond, nullptr, kOpaque);
this->Update(op->true_branch, nullptr, kOpaque);
this->Update(op->false_branch, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
};
IndexedForwardGraph IndexedForwardGraph::Create(
common::Arena* arena, const Expr& body) {
return Creator(arena).Prepare(body);
}
/*!
* \brief Dominator tree that represent domination or
* post domination relation of the node.
*/
class DominatorTree {
public: public:
/*!
* \brief A node in the dominator tree.
*/
struct Node {
/*! \brief The node in the tree */
IndexedForwardGraph::Node* gnode{nullptr};
/*! \brief parent of the tree */
Node* parent{nullptr};
/*! \brief current depth*/
int depth{0};
/*! \brief aggregated pattern to parent */
OpPatternKind pattern{kOpaque};
};
// index -> node.
std::vector<Node*> nodes;
/*!
* \brief compute a post dominator relation for a given dataflow graph.
* \param arena The arena used for node allocation.
* \param graph The graph to be analyze.
* \return The dominator tree of the graph.
* \note This algorithm makes use of the fact that graph is DAG,
* and runs a single pass algorithm via LCA.
*/
static DominatorTree PostDom(common::Arena* arena,
const IndexedForwardGraph& graph);
private:
// Combine pattern together.
static OpPatternKind CombinePattern(
OpPatternKind lhs, OpPatternKind rhs) {
if (lhs > rhs) return lhs;
return rhs;
}
/*!
* \brief Find the least common acenstor of the two nodes.
* \param lhs The left node.
* \param rhs The right node.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common acenstor of thw two.
*/
static Node* LeastCommonAcenstor(
Node* lhs,
Node* rhs,
OpPatternKind* edge_pattern) {
while (lhs != rhs) {
if (lhs == nullptr) return nullptr;
if (rhs == nullptr) return nullptr;
if (lhs->depth < rhs->depth) {
edge_pattern[0] = CombinePattern(
edge_pattern[0], rhs->pattern);
rhs = rhs->parent;
} else if (rhs->depth < lhs->depth) {
edge_pattern[0] = CombinePattern(
edge_pattern[0], lhs->pattern);
lhs = lhs->parent;
} else {
lhs = lhs->parent;
rhs = rhs->parent;
edge_pattern[0] = CombinePattern(
edge_pattern[0], lhs->pattern);
edge_pattern[0] = CombinePattern(
edge_pattern[0], rhs->pattern);
}
}
return lhs;
}
};
DominatorTree DominatorTree::PostDom(common::Arena* arena,
const IndexedForwardGraph& graph) {
DominatorTree tree;
tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
// reverse topo order
for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
size_t index = i - 1;
Node* tnode = arena->make<Node>();
auto* gnode = graph.post_dfs_order[index];
tnode->gnode = gnode;
if (gnode->extern_ref) {
tnode->depth = 1;
tnode->parent = nullptr;
tnode->pattern = kOpaque;
} else {
// find the LCAs of all outputs.
OpPatternKind pattern = kElemWise;
Node* parent = nullptr;
for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
size_t oindex = link->value.node->index;
CHECK_LT(oindex, tree.nodes.size());
Node* onode = tree.nodes[oindex];
CHECK(onode != nullptr);
if (parent != nullptr) {
parent = LeastCommonAcenstor(parent, onode, &pattern);
} else {
parent = onode;
}
pattern = CombinePattern(pattern, link->value.pattern);
}
CHECK(parent != nullptr);
tnode->depth = parent->depth + 1;
tnode->parent = parent;
tnode->pattern = pattern;
}
tree.nodes[index] = tnode;
}
return tree;
}
/*!
* \brief A partition of the graph marked by union find data structure.
*/
class GraphPartitioner {
public:
explicit GraphPartitioner(common::Arena* arena, int opt_level)
: arena_(arena), opt_level_(opt_level) {}
/*!
* \brief Group as a union find data structure.
*/
struct Group {
/*! \brief The parent in the union find data structure. */
Group* parent{nullptr};
/*! \brief The pattern of the group */
OpPatternKind pattern;
/*! \brief reference to the root node. */
const tvm::Node* root_ref{nullptr};
/*!
* \brief Reference to the master node,
* this field is not nullptr only if pattern is kOutEWiseFusable.
*/
const tvm::Node* master_ref{nullptr};
/*!
* \brief Find the group root, perform path compression
* \return The root type node.
*/
Group* FindRoot() {
// fast path
if (this->parent == nullptr) return this;
// slow path with path compression.
Group* root = this;
while (root->parent != nullptr) {
root = root->parent;
}
for (Group* p = this; p != root;) {
Group* parent = p->parent;
p->parent = root;
p = parent;
}
return root;
}
};
/*!
* \brief Partition a graph.
* \return group assignments of each node.
*/
std::vector<Group*> Partition(const IndexedForwardGraph& graph);
private:
/*! \brief The internal arena for temporary space. */
common::Arena* arena_;
/*! \brief optimization level for fuse operation. */
int opt_level_;
/*! \brief The internal groups. */
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
std::unordered_set<IndexedForwardGraph::Node*> visited_;
// Internal implelementation of CheckPath
template<typename F>
bool CheckPath_(IndexedForwardGraph::Node* src,
IndexedForwardGraph::Node* sink,
F fcond) {
if (visited_.count(src)) return true;
visited_.insert(src);
Group* gnode = groups_[src->index];
CHECK(gnode != nullptr);
gnode = gnode->FindRoot();
if (!fcond(gnode->pattern, src == sink)) return false;
if (src == sink) return true;
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
if (!CheckPath_(link->value.node, sink, fcond)) return false;
}
return true;
}
/*!
* \brief Check all the node between src and sink satisfies fcond.
*
* src and sink are not checked.
*
* \param src The source node.
* \param sink The termination node.
* \param fcond The condition to be checked.
* \tparam F the condition function.
* \note sink must be a post-dominator of src.
*/
template<typename F>
bool CheckPath(IndexedForwardGraph::Node* src,
IndexedForwardGraph::Node* sink,
F fcond) {
CHECK(!src->extern_ref);
visited_.clear();
CHECK(src != sink);
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
if (!CheckPath_(link->value.node, sink, fcond)) return false;
}
return true;
}
// Combine two patterns together.
static OpPatternKind CombinePattern(
OpPatternKind lhs, OpPatternKind rhs) {
if (lhs > kBroadcast && rhs > kBroadcast) {
LOG(FATAL) << "Cannot merge two complex group together";
}
if (lhs > rhs) return lhs;
return rhs;
}
/*!
* \brief Merge the child group to the parent.
* \param child The child group.
* \param parent The parent group.
*/
void MergeFromTo(Group* child, Group* parent) {
child = child->FindRoot();
parent = parent->FindRoot();
if (child == parent) return;
child->parent = parent;
// update master ref and pattern
if (child->master_ref != nullptr) {
CHECK(parent->master_ref == nullptr);
parent->master_ref = child->master_ref;
parent->pattern = CombinePattern(
child->pattern, parent->pattern);
}
}
// Internal implelementation of CommitFuse
void CommitFuse_(IndexedForwardGraph::Node* src,
IndexedForwardGraph::Node* sink,
Group* target) {
if (src == sink) return;
if (visited_.count(src)) return;
visited_.insert(src);
Group* gnode = groups_[src->index];
CHECK(gnode != nullptr);
// merge the current group to the parent if possible.
MergeFromTo(gnode, target);
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
CommitFuse_(link->value.node, sink, target);;
}
}
/*!
* \brief Commit fusion operation.
* \param src The source node.
* \param sink The termination node.
* \tparam group the group to be committed.
* \note sink must be a post-dominator of src.
*/
void CommitFuse(IndexedForwardGraph::Node* src,
IndexedForwardGraph::Node* sink) {
Group* target = groups_[sink->index];
visited_.clear();
CHECK(src != sink);
CommitFuse_(src, sink, target);
}
// Initialize the groups.
void InitGroups(const IndexedForwardGraph& graph) {
groups_.resize(graph.post_dfs_order.size());
for (size_t nid = 0; nid < groups_.size(); ++nid) {
const auto* graph_node = graph.post_dfs_order[nid];
auto* group_node = arena_->make<Group>();
group_node->pattern = graph_node->pattern;
group_node->root_ref = graph_node->ref;
// set master ref if necessary.
if (group_node->pattern == kOutEWiseFusable) {
group_node->master_ref = graph_node->ref;
}
groups_[nid] = group_node;
}
}
// execute the fusion algorithm.
void RunFuse(const IndexedForwardGraph& graph,
const DominatorTree& post_dom_tree,
int phase) {
for (size_t nid = 0; nid < groups_.size(); ++nid) {
// the group of current node has been specified already.
auto* graph_node = graph.post_dfs_order[nid];
auto* dom_node = post_dom_tree.nodes[nid];
Group* group_node = groups_[nid];
CHECK(group_node != nullptr);
// no actions for opaque nodes
if (group_node->pattern == kOpaque) continue;
// no actions needed if the current node have no dominator
if (dom_node->parent == nullptr) continue;
CHECK(!graph_node->extern_ref);
// Skip if current node is already fused to the parent.
size_t dom_parent_gindex = dom_node->parent->gnode->index;
if (groups_[dom_parent_gindex] != nullptr &&
group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
continue;
}
// Try to fuse current node to its post-dominator.
if (group_node->pattern == kOutEWiseFusable) {
if (phase != 0) continue;
// Path for OutEWiseFusable: conv2d
// Check if the dominator relation is elemwise.
if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
CHECK(dom_node->parent->gnode != nullptr);
// The fuse can be executed if all the intermediate ops are still broadcast.
auto fcond = [](OpPatternKind kind, bool is_sink) {
return kind <= kBroadcast;
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
} else if (group_node->pattern <= kBroadcast) {
// The fuse can be executed if all the intermediate ops are still broadcast.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
return kind <= kBroadcast;
} else {
return (kind <= kBroadcast ||
kind == kCommReduce ||
kind == kOutEWiseFusable);
}
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
} else if (group_node->pattern == kInjective) {
// defer injective fusion to second phase.
// so conv2d always finishes fusing.
if (phase != 1) continue;
// Check if all path are injective.
auto fcond = [](OpPatternKind kind, bool is_sink) {
return kind <= kInjective;
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
} else {
// do nothing.
CHECK(group_node->pattern == kCommReduce);
}
}
}
};
std::vector<GraphPartitioner::Group*>
GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
this->InitGroups(graph);
if (opt_level_ == 0) return std::move(groups_);
// get post dominator tree
auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
// run fusion algorithm.
for (int phase = 0; phase < 2; ++phase) {
this->RunFuse(graph, post_dom_tree, phase);
}
return std::move(groups_);
}
class FuseMutator : private ExprMutator {
public:
// Run the transform
Expr Transform(const Expr& body, int fuse_opt_level) {
// setup the group map.
auto graph = IndexedForwardGraph::Create(&arena_, body);
auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(
graph);
for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
CHECK(graph.post_dfs_order[nid]->ref != nullptr);
gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
}
// The following line can be used for debug.
// this->DebugDumpGroup(body);
return this->Mutate(body);
}
private:
/*! \brief Temporary information from each group. */
struct GroupInfo {
public:
// The parameters of the function.
Array<Var> params;
// The arguments to call the functions.
Array<Expr> arguments;
// Get a new parameter or allocate an old one
Var GetOrAllocParam(const Expr& expr, const Type& type) {
// run linear scan as most fused groups contain only a few inputs.
for (size_t i = 0; i < arguments.size(); ++i) {
if (expr.same_as(arguments[i])) return params[i];
}
// create a new parameter.
std::ostringstream os;
os << "p" << params.size();
auto var = VarNode::make(os.str(), type);
params.push_back(var);
arguments.push_back(expr);
return var;
}
};
/*! \brief Internal arena. */
common::Arena arena_;
/*! \brief The group assignment map. */
std::unordered_map<const Node*, GraphPartitioner::Group*> gmap_;
/* \brief Internal group information map. */
std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;
// Skip primitive function. // Skip primitive function.
Expr VisitExpr_(const FunctionNode* fn_node) { Expr VisitExpr_(const FunctionNode* fn_node) {
NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive"); NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive");
...@@ -26,48 +699,74 @@ class SimpleFuser : public ExprMutator { ...@@ -26,48 +699,74 @@ class SimpleFuser : public ExprMutator {
return ExprMutator::VisitExpr_(fn_node); return ExprMutator::VisitExpr_(fn_node);
} }
} }
// Transform calls.
Expr VisitExpr_(const CallNode* call) { Expr VisitExpr_(const CallNode* call) {
if (call->op.as<OpNode>()) { if (call->op.as<OpNode>()) {
// Placeholder fusion algorithm which abstracts // If it is a primitive op call
// single definitions into functions only. // then we must have a group assignment for it already.
Array<Var> params; CHECK(gmap_.count(call));
Array<Expr> inner_args; auto* ret_group = gmap_.at(call)->FindRoot();
Array<Expr> args; Array<Expr> new_args;
int param_number = 0;
for (auto arg : call->args) { for (auto arg : call->args) {
std::ostringstream os;
os << "p" << param_number++;
auto type = arg->checked_type(); auto type = arg->checked_type();
auto var = VarNode::make(os.str(), type); CHECK(gmap_.count(arg.get()))
params.push_back(var); << "cannot find group of " << arg;
inner_args.push_back(var); auto* arg_group = gmap_.at(arg.get())->FindRoot();
args.push_back(this->Mutate(arg)); Expr new_arg = this->Mutate(arg);
if (ret_group != arg_group) {
Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type);
new_args.push_back(param);
} else {
new_args.push_back(new_arg);
}
}
auto new_call = CallNode::make(
call->op, new_args, call->attrs, call->type_args);
if (ret_group->root_ref == call) {
// This is the root of the group
// create the new call node.
const GroupInfo& ginfo = ginfo_[ret_group];
auto func = FunctionNode::make(
ginfo.params, new_call, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
return CallNode::make(func, ginfo.arguments, Attrs());
} else {
// This is an intermediate node of a fused function
// simply return the new call.
return new_call;
} }
auto body = CallNode::make(call->op, inner_args, call->attrs);
auto func = FunctionNode::make(
params, body, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
return CallNode::make(func, args, Attrs());
} else { } else {
return ExprMutator::VisitExpr_(call); return ExprMutator::VisitExpr_(call);
} }
} }
// Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) {
std::string text = RelayPrint(body, [this](const Expr& expr) -> std::string {
auto it = gmap_.find(expr.get());
if (it == gmap_.end()) return "";
std::ostringstream os;
auto *group = it->second->FindRoot();
os << "group=" << group;
return os.str();
});
LOG(INFO) << "Dump of group info:\n" << text;
}
}; };
Expr FuseOps(const Expr& expr) { Expr FuseOps(const Expr& expr, int fuse_opt_level) {
// First we convert all chains of fusable ops into // First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive // abstracted functions which we mark as primtive
// then we convert these primtive functions into // then we convert these primtive functions into
// new operators. // new operators.
return SimpleFuser().Mutate(expr); return FuseMutator().Transform(expr, fuse_opt_level);
} }
TVM_REGISTER_API("relay._ir_pass.FuseOps") TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[0]); *ret = FuseOps(args[0], args[1]);
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/pass_util.h
* \brief Utilities for writing
*/
#ifndef TVM_RELAY_PASS_PASS_UTIL_H_
#define TVM_RELAY_PASS_PASS_UTIL_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
namespace tvm {
namespace relay {
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_
...@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator {
VarNode* new_var =( VarNode* new_var =(
std::is_base_of<VarNode, T>::value ? std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr); static_cast<VarNode*>(new_e.node_.get()) : nullptr);
FunctionNode* new_fn =(
std::is_base_of<FunctionNode, T>::value ?
static_cast<FunctionNode*>(new_e.node_.get()) : nullptr);
// check if we need update the new_e // check if we need update the new_e
bool need_update_type = !checked_type.same_as(new_e->checked_type_); bool need_update_type = !checked_type.same_as(new_e->checked_type_);
...@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator {
update_missing_type_annotation_ && update_missing_type_annotation_ &&
!new_var->type_annotation.defined()); !new_var->type_annotation.defined());
if (!need_update_type && !need_update_var && !need_update_call) return new_e; bool need_update_fn = (
std::is_base_of<FunctionNode, T>::value &&
update_missing_type_annotation_ &&
!new_fn->ret_type.defined());
if (!need_update_type &&
!need_update_var &&
!need_update_call &&
!need_update_fn) {
return new_e;
}
if (!new_e.node_.unique()) { if (!new_e.node_.unique()) {
// Copy on write optimization // Copy on write optimization
...@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator {
new_var = ( new_var = (
std::is_base_of<VarNode, T>::value ? std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr); static_cast<VarNode*>(new_e.node_.get()) : nullptr);
new_fn = (
std::is_base_of<FunctionNode, T>::value ?
static_cast<FunctionNode*>(new_e.node_.get()) : nullptr);
} }
// attach the information. // attach the information.
...@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator {
if (need_update_var) { if (need_update_var) {
new_var->type_annotation = checked_type; new_var->type_annotation = checked_type;
} }
if (need_update_fn) {
auto* fn_type = checked_type.as<FuncTypeNode>();
CHECK(fn_type != nullptr);
new_fn->ret_type = fn_type->ret_type;
}
return new_e; return new_e;
} }
......
...@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) { ...@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
void TypeSolver::AddConstraint(const TypeConstraint& constraint) { void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
if (auto *op = constraint.as<TypeRelationNode>()) { if (auto *op = constraint.as<TypeRelationNode>()) {
// create a new relation node. // create a new relation node.
RelationNode* rnode = make<RelationNode>(); RelationNode* rnode = arena_.make<RelationNode>();
rnode->rel = GetRef<TypeRelation>(op); rnode->rel = GetRef<TypeRelation>(op);
rel_nodes_.push_back(rnode); rel_nodes_.push_back(rnode);
// populate the type information. // populate the type information.
for (size_t i = 0; i < op->args.size(); ++i) { for (size_t i = 0; i < op->args.size(); ++i) {
// insert link to the type list // insert link to the type list
LinkNode<TypeNode*>* tlink = make<LinkNode<TypeNode*> >(); LinkNode<TypeNode*>* tlink = arena_.make<LinkNode<TypeNode*> >();
TypeNode* tnode = GetTypeNode(op->args[i]); TypeNode* tnode = GetTypeNode(op->args[i]);
tlink->value = tnode; tlink->value = tnode;
rnode->type_list.Push(tlink); rnode->type_list.Push(tlink);
// insert type->relation node // insert type->relation node
LinkNode<RelationNode*>* rlink = make<LinkNode<RelationNode*> >(); LinkNode<RelationNode*>* rlink = arena_.make<LinkNode<RelationNode*> >();
rlink->value = rnode; rlink->value = rnode;
tnode->rel_list.Push(rlink); tnode->rel_list.Push(rlink);
} }
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using common::LinkNode;
using common::LinkedList;
/*! /*!
* \brief Interface of type solver used in type inference. * \brief Interface of type solver used in type inference.
* *
...@@ -70,41 +72,6 @@ class TypeSolver { ...@@ -70,41 +72,6 @@ class TypeSolver {
// All the object in the structure is managed by a arena allocator // All the object in the structure is managed by a arena allocator
// which releases the memory upon distruction of the type solver. // which releases the memory upon distruction of the type solver.
/*! /*!
* \brief Link list node
* \tparam T the content data type
*/
template<typename T>
struct LinkNode {
/*! \brief The content value */
T value;
/*! \brief pointer to the next location */
LinkNode<T>* next{nullptr};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
*/
template<typename T>
struct LinkedList {
/*! \brief Head pointer */
LinkNode<T>* head{nullptr};
/*! \brief Tail pointer */
LinkNode<T>* tail{nullptr};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void Push(LinkNode<T>* node) {
node->next = nullptr;
if (this->tail != nullptr) {
this->tail->next = node;
this->tail = node;
} else {
head = tail = node;
}
}
};
/*!
* \brief type node struct * \brief type node struct
* TypeNode implements a union-find data structure(via parent) * TypeNode implements a union-find data structure(via parent)
* that can unifies the same types to the name resolved_type. * that can unifies the same types to the name resolved_type.
...@@ -165,18 +132,6 @@ class TypeSolver { ...@@ -165,18 +132,6 @@ class TypeSolver {
/*! \brief Reporter that reports back to self */ /*! \brief Reporter that reports back to self */
TypeReporter reporter_; TypeReporter reporter_;
/*! /*!
* \brief Create function to create a new node ptr via arena
* \tparam The type parameter
* \return The node pointer.
*/
template<typename T>
T* make() {
T* ptr = arena_.Alloc<T>();
// call constructor
new (ptr) T();
return ptr;
}
/*!
* \brief GetTypeNode that is corresponds to t. * \brief GetTypeNode that is corresponds to t.
* if it do not exist, create a new one. * if it do not exist, create a new one.
* \return The type node. * \return The type node.
...@@ -186,7 +141,7 @@ class TypeSolver { ...@@ -186,7 +141,7 @@ class TypeSolver {
if (it != tmap_.end()) { if (it != tmap_.end()) {
return it->second->FindRoot(); return it->second->FindRoot();
} else { } else {
TypeNode* n = make<TypeNode>(); TypeNode* n = arena_.make<TypeNode>();
type_nodes_.push_back(n); type_nodes_.push_back(n);
n->resolved_type = t; n->resolved_type = t;
tmap_[t] = n; tmap_[t] = n;
......
...@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars") ...@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
} }
}); });
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map<const Node*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
};
return ExprRefCounter().Get(body);
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -33,6 +33,7 @@ def test_env(): ...@@ -33,6 +33,7 @@ def test_env():
text = env.astext() text = env.astext()
assert "def @myf" in text assert "def @myf" in text
assert "%1 = add(%0, %0) # ty=float32" in text assert "%1 = add(%0, %0) # ty=float32" in text
show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
show(text) show(text)
......
...@@ -46,6 +46,8 @@ def test_fold_fwd_simple(): ...@@ -46,6 +46,8 @@ def test_fold_fwd_simple():
weight = relay.var("weight", type_dict["weight"]) weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_bias, in_scale, channels) y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 2) check((2, 4, 10, 10), 2)
...@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path(): ...@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path():
type_dict = {x.name_hint:x.checked_type for x in y1.params} type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"]) weight = relay.var("weight", type_dict["weight"])
y1_expected = expected(x, weight, in_bias, in_scale, channels) y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 3), 3) check((2, 4, 10, 3), 3)
...@@ -194,6 +198,8 @@ def test_fold_bwd_simple(): ...@@ -194,6 +198,8 @@ def test_fold_bwd_simple():
weight = relay.var("weight", type_dict["weight"]) weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_bias, out_scale, channels) y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8) check((2, 4, 10, 10), 8)
...@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path(): ...@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path():
weight = relay.var("weight", type_dict["weight"]) weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_bias, out_scale, channels) y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8) check((2, 4, 10, 10), 8)
......
...@@ -3,15 +3,103 @@ from tvm import relay ...@@ -3,15 +3,103 @@ from tvm import relay
def test_fuse_simple(): def test_fuse_simple():
"""Simple testcase.""" """Simple testcase."""
x = relay.var("x", shape=(10, 20)) def before():
y = relay.add(x, x) x = relay.var("x", shape=(10, 20))
z = relay.exp(y) y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
return relay.Function([x], z)
def expected():
x = relay.var("p", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
f1 = relay.Function([x], z)
x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x])
return relay.Function([x], y)
z = before()
z = relay.ir_pass.infer_type(z) z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z) zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
zz = relay.ir_pass.fuse_ops(zz) zz = relay.ir_pass.fuse_ops(zz)
zz = relay.ir_pass.infer_type(zz) zz = relay.ir_pass.infer_type(zz)
zz.astext() after = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(zz, after)
def test_conv2d_fuse():
"""Test fusion case of conv2d"""
def before(dshape):
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"),
kernel_size=(3, 3),
padding=(1, 1),
channels=16)
# this is the next dominator.
y1 = relay.add(relay.const(1, "float32"), y)
y = relay.add(y, y1)
# second path
z2 = relay.nn.conv2d(y, relay.var("w2"),
kernel_size=(1, 1),
padding=(0,0),
channels=16)
z3 = relay.nn.conv2d(y, relay.var("w3"),
kernel_size=(3, 3),
padding=(1,1),
channels=16)
# add can only be fused to z1
z = relay.add(z2, z3)
return relay.Function(relay.ir_pass.free_vars(z), z)
def expected(dshape):
# segment 1
x = relay.var("p0", shape=dshape)
w = relay.var("p1")
y = relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=16)
y1 = relay.add(relay.const(1, "float32"), y)
y = relay.add(y, y1)
f1 = relay.Function([x, w], y)
# segment 2
x = relay.var("p0", shape=dshape)
w = relay.var("p1")
z2 = relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1,1),
channels=16)
f2 = relay.Function([x, w], z2)
# segment 3
x = relay.var("p0", shape=dshape)
w = relay.var("p1")
offset = relay.var("p2", shape=dshape)
z3 = relay.nn.conv2d(x, w,
kernel_size=(1, 1),
padding=(0, 0),
channels=16)
z3 = relay.add(z3, offset)
f3 = relay.Function([x, w, offset], z3)
# compose
x = relay.var("x", shape=dshape)
y = relay.Call(f1, [x, relay.var("w1")])
z2 = relay.Call(f2, [y, relay.var("w3")])
z3 = relay.Call(f3, [y, relay.var("w2"), z2])
z = z3
return relay.Function(relay.ir_pass.free_vars(z), z)
dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
if __name__ == "__main__": if __name__ == "__main__":
test_fuse_simple() test_fuse_simple()
test_conv2d_fuse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment