Unverified commit 4369b7f6 by Tianqi Chen, committed by GitHub

[RELAY][PASS] General OpFusion. (#2090)

parent e470f8ea
......@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
return node;
}
/*!
* \brief Print node as text format.
* \param node The node to be printed.
* \param annotate An optional callback function for attaching
* an additional comment block to an expr.
* \return The text representation.
*/
std::string RelayPrint(
const NodeRef& node,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
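
For illustration, a caller on the C++ side can now pass an annotate callback built from a lambda. A minimal sketch; the helper name and lambda body are hypothetical, not part of this change:

#include <tvm/relay/expr.h>

// Print a Relay expression, annotating every node with its node type key
// instead of the default "ty=" comment.
void PrintWithAnnotation(const tvm::relay::Expr& e) {
  auto annotate = tvm::runtime::TypedPackedFunc<std::string(tvm::relay::Expr)>(
      [](tvm::relay::Expr expr) { return std::string(expr->type_key()); });
  LOG(INFO) << tvm::relay::RelayPrint(e, annotate);
}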
......@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
using TSelf = TypedPackedFunc<R(Args...)>;
/*! \brief default constructor */
TypedPackedFunc() {}
/*! \brief constructor from null */
TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*!
* \brief construct by wrapping a PackedFunc
*
......
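
The new nullptr constructor exists so that a TypedPackedFunc parameter can be given a nullptr default and tested before use, which is exactly how the annotate hook is consumed. A minimal sketch of the pattern (the function name is illustrative only):

// A TypedPackedFunc argument may now default to nullptr; the callee
// checks it before invoking, mirroring TextPrinter::PrintOptionalInfo below.
std::string Describe(
    const tvm::relay::Expr& e,
    tvm::runtime::TypedPackedFunc<std::string(tvm::relay::Expr)> annotate = nullptr) {
  if (annotate != nullptr) {
    return annotate(e);  // user-supplied comment text
  }
  return std::string("<no annotation>");
}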
......@@ -22,15 +22,20 @@ def register_relay_node(type_key=None):
class RelayNode(NodeBase):
    """Base class of all relay node."""

    def astext(self, annotate=None):
        """Get the text format of the expression.

        Parameters
        ----------
        annotate: Optional[relay.Expr->str]
            Optional annotate function to provide additional
            information in the comment block.

        Returns
        -------
        text : str
            The text format of the expression.
        """
        return _expr.RelayPrint(self, annotate)
@register_relay_node
......
......@@ -173,11 +173,13 @@ def build(func,
else:
tophub_context = autotvm.util.EmptyContext()
cfg = BuildConfig.current
with tophub_context:
func = optimize(func)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
......
......@@ -6,7 +6,6 @@ from numbers import Number as _Number
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
......@@ -477,7 +476,7 @@ class TupleWrapper(object):
text : str
The text format of the tuple expression.
"""
return self.tuple_value.astext()
def __getitem__(self, index):
if index >= len(self):
......
......@@ -259,7 +259,7 @@ def structural_hash(value):
raise TypeError(msg)
def fuse_ops(expr, opt_level=1):
"""Fuse operators in expr together.
Parameters
......@@ -267,9 +267,12 @@ def fuse_ops(expr):
expr : tvm.relay.Expr
The input expression.
opt_level : int
The level of fuse optimization.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr, opt_level)
......@@ -38,11 +38,29 @@ class Arena {
/*!
* \brief Allocate a space from the Arena for type T.
* \tparam T the data type to be allocated.
* \note The space of T is not initialized.
*/
template<typename T>
T* allocate_() {
return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
}
/*!
* \brief Create a new instance of type T.
* \param args The constructor arguments.
* \tparam T the type to be created.
* \tparam Args Types of the constructor arguments.
*
* \return The allocated object.
* \note The type T must be a simple type, or only contain
* memory allocated from the same arena.
* Otherwise the destructor needs to be called explicitly.
*/
template<typename T, typename... Args>
T* make(Args&&... args) {
T* ptr = allocate_<T>();
new (ptr) T(std::forward<Args>(args)...);
return ptr;
}
private:
// page size 16 KB
......@@ -87,6 +105,44 @@ class Arena {
}
};
/*!
* \brief Linked list node
* \tparam T the content data type
*/
template<typename T>
struct LinkNode {
/*! \brief The content value */
T value;
/*! \brief pointer to the next location */
LinkNode<T>* next{nullptr};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
* \note This is a simple data structure that can be used together with the arena.
* \sa LinkNode
*/
template<typename T>
struct LinkedList {
/*! \brief Head pointer */
LinkNode<T>* head{nullptr};
/*! \brief Tail pointer */
LinkNode<T>* tail{nullptr};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void Push(LinkNode<T>* node) {
node->next = nullptr;
if (this->tail != nullptr) {
this->tail->next = node;
this->tail = node;
} else {
head = tail = node;
}
}
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_ARENA_H_
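
Putting the two pieces together, typical usage of the arena-backed list looks like this. A self-contained sketch; the include path and function name are assumed, not part of the patch:

#include "common/arena.h"

// Build a small intrusive list whose nodes live entirely in the arena;
// nothing is freed per-node -- all memory is reclaimed when `arena` dies.
void BuildList() {
  tvm::common::Arena arena;
  tvm::common::LinkedList<int> list;
  for (int i = 0; i < 3; ++i) {
    auto* node = arena.make<tvm::common::LinkNode<int> >();
    node->value = i;
    list.Push(node);
  }
  for (auto* p = list.head; p != nullptr; p = p->next) {
    LOG(INFO) << p->value;
  }
}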
......@@ -109,6 +109,29 @@ class ScheduleGetter :
return {};
}
Array<Tensor> VisitExpr_(const ConstantNode* op) final {
CHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = TVMType2Type(op->data->dtype);
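// Wrap the scalar in a zero-rank compute so that downstream operators
// in the fused function can consume the constant as an ordinary Tensor.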
Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
if (dtype == Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::Expr();
}
});
return {value};
}
Array<Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fcompute =
Op::GetAttr<FTVMCompute>("FTVMCompute");
......
......@@ -125,6 +125,8 @@ class TextPrinter :
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public:
explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate)
: annotate_(annotate) {}
/*!
* \brief Print a node to string.
* \param node The node to be printed.
......@@ -279,11 +281,11 @@ class TextPrinter :
TextValue VisitExpr_(const CallNode* op) final {
// possibly through meta-data
std::vector<TextValue> args;
for (Expr arg : op->args) {
args.emplace_back(GetValue(arg));
}
TextValue call_op = GetValue(op->op);
TextValue id = this->AllocTempVar();
this->PrintIndent();
......@@ -532,7 +534,9 @@ class TextPrinter :
*/
void PrintOptionalInfo(const Expr& expr) {
// additional information in comment.
if (annotate_ != nullptr) {
stream_ << " # " << annotate_(expr);
} else if (expr->checked_type_.defined()) {
stream_ << " # ty=";
this->PrintType(expr->checked_type(), stream_);
}
......@@ -678,7 +682,10 @@ class TextPrinter :
name = "%" + name;
}
TextValue val(GetUniqueName(name));
// still print if the IR is malformed, but flag the duplicated variable.
if (memo_.count(var)) {
  val = TextValue(val.name + "-malformed-ir");
}
memo_[var] = val;
return val;
}
......@@ -686,6 +693,8 @@ class TextPrinter :
private:
class AttrPrinter;
friend class AttrPrinter;
/*! \brief additional comment function */
runtime::TypedPackedFunc<std::string(Expr)> annotate_;
/*! \brief meta data context */
TextMetaDataContext meta_;
/*! \brief Check whether scope is still valid */
......@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
os << ", " << meta_.GetMetaNode(attrs);
}
std::string RelayPrint(const NodeRef& node,
                       runtime::TypedPackedFunc<std::string(Expr)> annotate) {
  return TextPrinter(annotate).Print(node);
}
TVM_REGISTER_API("relay._expr._text_print")
.set_body_typed<std::string(const NodeRef&)>(RelayPrint);
TVM_REGISTER_API("relay._expr.RelayPrint")
.set_body_typed<std::string(
const NodeRef&,
runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
} // namespace relay
} // namespace tvm
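
Since the printer is exposed through the global function registry, it can also be reached without the header. A sketch assuming the standard Registry API; PrintViaRegistry is a hypothetical helper:

#include <tvm/runtime/registry.h>

std::string PrintViaRegistry(const tvm::NodeRef& node) {
  const tvm::runtime::PackedFunc* print_fn =
      tvm::runtime::Registry::Get("relay._expr.RelayPrint");
  CHECK(print_fn != nullptr) << "relay._expr.RelayPrint is not registered";
  // Pass a null annotate callback to fall back to the default ty= comments.
  tvm::runtime::TypedPackedFunc<std::string(tvm::relay::Expr)> annotate(nullptr);
  return (*print_fn)(node, annotate);
}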
......@@ -10,6 +10,7 @@
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "../op/nn/layout.h"
namespace tvm {
......@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------
class BackwardPrep : private ExprVisitor {
public:
......
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/pass_util.h
* \brief Utilities for writing passes.
*/
#ifndef TVM_RELAY_PASS_PASS_UTIL_H_
#define TVM_RELAY_PASS_PASS_UTIL_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
namespace tvm {
namespace relay {
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_
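
Moving the helper into a shared header lets other passes reuse it. An illustrative sketch of the intended pattern (UsedExactlyOnce is a hypothetical helper, not part of this patch):

#include "pass_util.h"

// A value consumed exactly once can usually be fused or inlined into its
// sole consumer without duplicating computation.
bool UsedExactlyOnce(const tvm::relay::Expr& body, const tvm::relay::Expr& value) {
  auto ref_count = tvm::relay::GetExprRefCount(body);
  auto it = ref_count.find(value.get());
  return it != ref_count.end() && it->second == 1;
}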
......@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator {
VarNode* new_var = (
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
FunctionNode* new_fn = (
std::is_base_of<FunctionNode, T>::value ?
static_cast<FunctionNode*>(new_e.node_.get()) : nullptr);
// check if we need update the new_e
bool need_update_type = !checked_type.same_as(new_e->checked_type_);
......@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator {
update_missing_type_annotation_ &&
!new_var->type_annotation.defined());
bool need_update_fn = (
std::is_base_of<FunctionNode, T>::value &&
update_missing_type_annotation_ &&
!new_fn->ret_type.defined());
if (!need_update_type &&
!need_update_var &&
!need_update_call &&
!need_update_fn) {
return new_e;
}
if (!new_e.node_.unique()) {
// Copy on write optimization
......@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator {
new_var = (
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
new_fn = (
std::is_base_of<FunctionNode, T>::value ?
static_cast<FunctionNode*>(new_e.node_.get()) : nullptr);
}
// attach the information.
......@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator {
if (need_update_var) {
new_var->type_annotation = checked_type;
}
if (need_update_fn) {
auto* fn_type = checked_type.as<FuncTypeNode>();
CHECK(fn_type != nullptr);
new_fn->ret_type = fn_type->ret_type;
}
return new_e;
}
......
......@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
if (auto *op = constraint.as<TypeRelationNode>()) {
// create a new relation node.
RelationNode* rnode = arena_.make<RelationNode>();
rnode->rel = GetRef<TypeRelation>(op);
rel_nodes_.push_back(rnode);
// populate the type information.
for (size_t i = 0; i < op->args.size(); ++i) {
// insert link to the type list
LinkNode<TypeNode*>* tlink = arena_.make<LinkNode<TypeNode*> >();
TypeNode* tnode = GetTypeNode(op->args[i]);
tlink->value = tnode;
rnode->type_list.Push(tlink);
// insert type->relation node
LinkNode<RelationNode*>* rlink = arena_.make<LinkNode<RelationNode*> >();
rlink->value = rnode;
tnode->rel_list.Push(rlink);
}
......
......@@ -16,6 +16,8 @@
namespace tvm {
namespace relay {
using common::LinkNode;
using common::LinkedList;
/*!
* \brief Interface of type solver used in type inference.
*
......@@ -70,41 +72,6 @@ class TypeSolver {
// All the objects in the structure are managed by an arena allocator
// which releases the memory upon destruction of the type solver.
/*!
* \brief type node struct
* TypeNode implements a union-find data structure (via parent)
* that unifies equivalent types to the same resolved_type.
......@@ -165,18 +132,6 @@ class TypeSolver {
/*! \brief Reporter that reports back to self */
TypeReporter reporter_;
/*!
* \brief Get the TypeNode that corresponds to t.
* If it does not exist, create a new one.
* \return The type node.
......@@ -186,7 +141,7 @@ class TypeSolver {
if (it != tmap_.end()) {
return it->second->FindRoot();
} else {
TypeNode* n = arena_.make<TypeNode>();
type_nodes_.push_back(n);
n->resolved_type = t;
tmap_[t] = n;
......
......@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
}
});
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map<const Node*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
};
return ExprRefCounter().Get(body);
}
} // namespace relay
} // namespace tvm
......@@ -33,6 +33,7 @@ def test_env():
text = env.astext()
assert "def @myf" in text
assert "%1 = add(%0, %0) # ty=float32" in text
show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
show(text)
......
......@@ -46,6 +46,8 @@ def test_fold_fwd_simple():
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 2)
......@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path():
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 3), 3)
......@@ -194,6 +198,8 @@ def test_fold_bwd_simple():
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
......@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path():
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
......
......@@ -3,15 +3,103 @@ from tvm import relay
def test_fuse_simple():
    """Simple testcase."""
    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        return relay.Function([x], z)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        f1 = relay.Function([x], z)
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    zz = relay.ir_pass.infer_type(zz)
    zz.astext()
    after = relay.ir_pass.infer_type(expected())
    assert relay.ir_pass.alpha_equal(zz, after)
def test_conv2d_fuse():
    """Test fusion case of conv2d"""
    def before(dshape):
        x = relay.var("x", shape=dshape)
        y = relay.nn.conv2d(x, relay.var("w1"),
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        # this is the next dominator.
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        # second path
        z2 = relay.nn.conv2d(y, relay.var("w2"),
                             kernel_size=(1, 1),
                             padding=(0, 0),
                             channels=16)
        z3 = relay.nn.conv2d(y, relay.var("w3"),
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             channels=16)
        # the add can only be fused into one of the conv segments
        z = relay.add(z2, z3)
        return relay.Function(relay.ir_pass.free_vars(z), z)

    def expected(dshape):
        # segment 1: conv2d(3x3) plus the two adds
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        y = relay.nn.conv2d(x, w,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        f1 = relay.Function([x, w], y)
        # segment 2: the second 3x3 conv2d
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        z2 = relay.nn.conv2d(x, w,
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             channels=16)
        f2 = relay.Function([x, w], z2)
        # segment 3: the 1x1 conv2d fused with the final add
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        offset = relay.var("p2", shape=dshape)
        z3 = relay.nn.conv2d(x, w,
                             kernel_size=(1, 1),
                             padding=(0, 0),
                             channels=16)
        z3 = relay.add(z3, offset)
        f3 = relay.Function([x, w, offset], z3)
        # compose
        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x, relay.var("w1")])
        z2 = relay.Call(f2, [y, relay.var("w3")])
        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
        z = z3
        return relay.Function(relay.ir_pass.free_vars(z), z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    zz = relay.ir_pass.infer_type(zz)
    after = relay.ir_pass.infer_type(expected(dshape))
    assert relay.ir_pass.alpha_equal(zz, after)
if __name__ == "__main__":
    test_fuse_simple()
    test_conv2d_fuse()