Unverified Commit 2a5656bf by Lianmin Zheng Committed by GitHub

[Relay] Alter Op Layout (#2150)

* [RELAY] Finish alter op pass

* [RELAY] AlterOpLayout Pass

* fix broadcast operators

* fix broadcast operators

* fix broadcast operators

* Support concatenate

* address comments

* address comments

* add comments

* rebase
parent 4bf1fd8c
Subproject commit e4a4c02764d37c9c3db0d64c4996651a3ef9513c Subproject commit a08e26e5a97f4ef4d566a42f6c78704b3f9c7b8a
...@@ -105,6 +105,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> { ...@@ -105,6 +105,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
int groups; int groups;
std::string data_layout; std::string data_layout;
std::string weight_layout; std::string weight_layout;
std::string out_layout;
DataType out_dtype; DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") { TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") {
...@@ -139,6 +140,10 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> { ...@@ -139,6 +140,10 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively."); "dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(out_dtype) TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>()) .set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting"); .describe("Output data type, set to explicit type under mixed precision setting");
......
...@@ -164,6 +164,19 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { ...@@ -164,6 +164,19 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
} }
}; };
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
std::string dst_layout;
TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(src_layout)
.describe("The source layout of the tensor. (e.g. NCHW)");
TVM_ATTR_FIELD(dst_layout)
.describe("The destination layout of the tensor. (e.g. NCHW16c)");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_ #endif // TVM_RELAY_ATTRS_TRANSFORM_H_
...@@ -459,7 +459,7 @@ inline const TTypeNode* ExprNode::type_as() const { ...@@ -459,7 +459,7 @@ inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value, static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type"); "TType must be a special case of type");
CHECK(checked_type_.defined()) CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed"; << "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>(); const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr) CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key << "Expected type to be " << TTypeNode::_type_key
......
...@@ -87,6 +87,21 @@ using FTVMSchedule = runtime::TypedPackedFunc< ...@@ -87,6 +87,21 @@ using FTVMSchedule = runtime::TypedPackedFunc<
const Target& target)>; const Target& target)>;
/*! /*!
* \brief Alternate the layout of operators or replace the
* operator with other expressions. This function will be invoked
* in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos)>;
/*!
* \brief Forward rewriting rule for a specific op. * \brief Forward rewriting rule for a specific op.
* *
* \param ref_call The reference old call type to be rewritten. * \param ref_call The reference old call type to be rewritten.
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
...@@ -173,6 +174,21 @@ Expr ForwardRewrite(const Expr& expr, ...@@ -173,6 +174,21 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr, std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
* \param expr The expression.
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
* \return The rewritten expression.
*/
Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*! \brief A hashing structure in the style of std::hash. */ /*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash { struct StructuralHash {
/*! \brief Hash a Relay type. /*! \brief Hash a Relay type.
......
...@@ -13,6 +13,7 @@ from . import container ...@@ -13,6 +13,7 @@ from . import container
from . import schedule from . import schedule
from . import module from . import module
from . import node from . import node
from . import attrs
from . import ir_builder from . import ir_builder
from . import target from . import target
from . import generic from . import generic
......
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
from ._ffi.node import NodeBase, register_node as _register_tvm_node
from ._ffi.function import _init_api
from . import _api_internal
@_register_tvm_node
class Attrs(NodeBase):
"""Attribute node, which is mainly use for defining attributes of relay operators.
Used by function registered in python side, such as compute, schedule and alter_layout.
Attrs is passed as the first argument to these functions.
"""
def list_field_info(self):
""" Get fields information
Returns
-------
infos: list of AttrFieldInfo
List of field information
"""
return _api_internal._AttrsListFieldInfo(self)
def keys(self):
"""Get list of names in the attribute.
Returns
-------
keys : list of str
List of keys
"""
fields = self.list_field_info()
for field in fields:
yield field.name
def __getitem__(self, item):
return self.__getattr__(item)
_init_api("tvm.attrs")
...@@ -21,6 +21,20 @@ def register_relay_node(type_key=None): ...@@ -21,6 +21,20 @@ def register_relay_node(type_key=None):
return _register_tvm_node(type_key) return _register_tvm_node(type_key)
def register_relay_attr_node(type_key=None):
"""register relay attribute node
Parameters
----------
type_key : str or cls
The type key of the node
"""
if not isinstance(type_key, str):
return _register_tvm_node(
"relay.attrs." + type_key.__name__)(type_key)
return _register_tvm_node(type_key)
class RelayNode(NodeBase): class RelayNode(NodeBase):
"""Base class of all relay node.""" """Base class of all relay node."""
def astext(self, show_meta_data=True, annotate=None): def astext(self, show_meta_data=True, annotate=None):
......
...@@ -17,6 +17,7 @@ OPT_PASS_LEVEL = { ...@@ -17,6 +17,7 @@ OPT_PASS_LEVEL = {
"FoldConstant": 2, "FoldConstant": 2,
"CombineParallelConv2D": 3, "CombineParallelConv2D": 3,
"FoldScaleAxis": 3, "FoldScaleAxis": 3,
"AlterOpLayout": 3,
} }
class BuildConfig(object): class BuildConfig(object):
...@@ -157,6 +158,13 @@ def optimize(func, params=None): ...@@ -157,6 +158,13 @@ def optimize(func, params=None):
if cfg.pass_enabled("FoldConstant"): if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func) func = ir_pass.fold_constant(func)
if cfg.pass_enabled("AlterOpLayout"):
func = ir_pass.infer_type(func)
func = ir_pass.canonicalize_ops(func)
func = ir_pass.infer_type(func)
func = ir_pass.alter_op_layout(func)
return func return func
......
...@@ -191,6 +191,23 @@ def simplify_inference(expr): ...@@ -191,6 +191,23 @@ def simplify_inference(expr):
return _ir_pass.simplify_inference(expr) return _ir_pass.simplify_inference(expr)
def canonicalize_ops(expr):
""" Canonicalize special operators to basic operators.
This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.)
Parameters
----------
e: tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)
def dead_code_elimination(expr): def dead_code_elimination(expr):
""" Remove expressions which does not effect the program result (dead code). """ Remove expressions which does not effect the program result (dead code).
...@@ -321,3 +338,22 @@ def combine_parallel_conv2d(expr): ...@@ -321,3 +338,22 @@ def combine_parallel_conv2d(expr):
Transformed expression Transformed expression
""" """
return _ir_pass.CombineParallelConv2D(expr) return _ir_pass.CombineParallelConv2D(expr)
def alter_op_layout(expr):
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
This pass can be used for computing convolution in custom layouts or
other general weight pre-transformation.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression with alternated layout.
"""
return _ir_pass.AlterOpLayout(expr)
#pylint: disable=wildcard-import, redefined-builtin #pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators.""" """Relay core operators."""
# operator defs # operator defs
from .op import get, register, register_schedule, register_compute, Op from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \
Op
# Operators # Operators
from .reduce import * from .reduce import *
...@@ -10,6 +11,7 @@ from .transform import * ...@@ -10,6 +11,7 @@ from .transform import *
from . import nn from . import nn
from . import image from . import image
from . import vision from . import vision
from . import op_attrs
# operator registry # operator registry
from . import _tensor from . import _tensor
......
...@@ -80,12 +80,3 @@ def clip_compute(attrs, inputs, output_type, target): ...@@ -80,12 +80,3 @@ def clip_compute(attrs, inputs, output_type, target):
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)] return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise) register_schedule("clip", schedule_elemwise)
register_pattern("clip", OpPattern.ELEMWISE)
# concatenate
@register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]
register_schedule("concatenate", schedule_injective)
register_pattern("concatenate", OpPattern.INJECTIVE)
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
# pylint: disable=invalid-name # pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import from __future__ import absolute_import
import topi
from . import op as _reg from . import op as _reg
from ._reduce import _schedule_reduce from ._reduce import _schedule_reduce
from .op import schedule_injective, OpPattern
schedule_injective = _reg.schedule_injective schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective
...@@ -15,10 +17,22 @@ _reg.register_schedule("reshape", schedule_injective) ...@@ -15,10 +17,22 @@ _reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective) _reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective) _reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective) _reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("cast", schedule_broadcast) _reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective) _reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective) _reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast) _reg.register_schedule("where", schedule_broadcast)
# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
# concatenate
@_reg.register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]
_reg.register_schedule("concatenate", schedule_injective)
_reg.register_pattern("concatenate", OpPattern.INJECTIVE)
...@@ -107,7 +107,7 @@ def register_schedule(op_name, schedule=None, level=10): ...@@ -107,7 +107,7 @@ def register_schedule(op_name, schedule=None, level=10):
op_name : str op_name : str
The name of the op. The name of the op.
schedule : function schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule
The schedule function. The schedule function.
level : int level : int
...@@ -124,7 +124,8 @@ def register_compute(op_name, compute=None, level=10): ...@@ -124,7 +124,8 @@ def register_compute(op_name, compute=None, level=10):
op_name : str op_name : str
The name of the op. The name of the op.
compute : function compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target)
-> List[Tensor]
The compute function. The compute function.
level : int level : int
...@@ -133,6 +134,23 @@ def register_compute(op_name, compute=None, level=10): ...@@ -133,6 +134,23 @@ def register_compute(op_name, compute=None, level=10):
return register(op_name, "FTVMCompute", compute, level) return register(op_name, "FTVMCompute", compute, level)
def register_alter_op_layout(op_name, alter_layout=None, level=10):
"""Register alter op layout function for an op
Parameters
----------
op_name : str
The name of the operator
alter_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
The function for changing the layout or replacing the operator
level : int
The priority level
"""
return register(op_name, "FTVMAlterOpLayout", alter_layout, level)
def register_pattern(op_name, pattern, level=10): def register_pattern(op_name, pattern, level=10):
"""Register operator pattern for an op. """Register operator pattern for an op.
......
"""The attributes node used for Relay operators"""
from ...attrs import Attrs
from ..base import register_relay_attr_node
@register_relay_attr_node
class Conv2DAttrs(Attrs):
"""Attribute of a Convolution Operator"""
pass
@register_relay_attr_node
class GlobalPool2DAttrs(Attrs):
"""Attribute of a Global 2D Pooling Operator"""
pass
...@@ -387,3 +387,25 @@ def slice_like(data, shape_like, axes=None): ...@@ -387,3 +387,25 @@ def slice_like(data, shape_like, axes=None):
The computed result. The computed result.
""" """
return _make.slice_like(data, shape_like, axes) return _make.slice_like(data, shape_like, axes)
def layout_transform(data, src_layout, dst_layout):
"""Transform the layout of a tensor
Parameters
----------
data : relay.Expr
The source tensor to be transformed
src_layout: str
The source layout. (e.g NCHW)
dst_layout: str
The destination layout. (e.g. NCHW16c)
Returns
-------
ret : relay.Expr
The transformed tensor.
"""
return _make.layout_transform(data, src_layout, dst_layout)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file attrs.cc * \file attrs.cc
*/ */
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <tvm/api_registry.h>
#include "attr_functor.h" #include "attr_functor.h"
namespace tvm { namespace tvm {
...@@ -321,4 +322,9 @@ bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const { ...@@ -321,4 +322,9 @@ bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const {
return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict); return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
} }
TVM_REGISTER_API("_AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Attrs()->ListFieldInfo();
});
} // namespace tvm } // namespace tvm
...@@ -185,7 +185,7 @@ class Layout : public NodeRef { ...@@ -185,7 +185,7 @@ class Layout : public NodeRef {
CHECK_GT(block_size, 0); CHECK_GT(block_size, 0);
new_layout << block_size; new_layout << block_size;
} }
new_layout << layout_simplified[i]->value; new_layout << static_cast<char>(layout_simplified[i]->value);
} }
return Layout(new_layout.str()); return Layout(new_layout.str());
} }
...@@ -241,6 +241,16 @@ class Layout : public NodeRef { ...@@ -241,6 +241,16 @@ class Layout : public NodeRef {
return operator->()->layout_simplified.size(); return operator->()->layout_simplified.size();
} }
/*! \return number of super dimensions */
size_t ndim_super() const {
size_t ct = 0;
for (auto x : operator->()->layout_simplified) {
if (IsSuperdim(x))
ct++;
}
return ct;
}
/*! /*!
* \brief The description of the \p i-th dimension. * \brief The description of the \p i-th dimension.
* If it is a sub-dimension, the size will be returned as well, * If it is a sub-dimension, the size will be returned as well,
...@@ -327,6 +337,17 @@ class Layout : public NodeRef { ...@@ -327,6 +337,17 @@ class Layout : public NodeRef {
return operator->()->name == rhs->name; return operator->()->name == rhs->name;
} }
/*!
* \brief allow output string of layout to ostream
* \param os the output stream
* \param l the layout
* \return the ostream
*/
friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
os << l.name();
return os;
}
using ContainerType = LayoutNode; using ContainerType = LayoutNode;
private: private:
......
...@@ -7,11 +7,13 @@ ...@@ -7,11 +7,13 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "../../pass/alter_op_layout.h"
#include "../layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// relay.nn.conv2d
TVM_REGISTER_NODE_TYPE(Conv2DAttrs); TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
bool Conv2DRel(const Array<Type>& types, bool Conv2DRel(const Array<Type>& types,
...@@ -101,6 +103,20 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -101,6 +103,20 @@ bool Conv2DRel(const Array<Type>& types,
return true; return true;
} }
template<typename T>
Array<Array<Layout> > Conv2DInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const T* params = attrs.as<T>();
Layout out_layout(params->out_layout);
// We always make other operators to fit the layouts of convolution layers
// So this inference ignores all inputs
return Array<Array<Layout> >{{params->data_layout, params->weight_layout},
{out_layout.defined() ? out_layout : params->data_layout}};
}
// Positional relay function to create conv2d operator // Positional relay function to create conv2d operator
// used by frontend FFI. // used by frontend FFI.
...@@ -156,10 +172,11 @@ with the layer input to produce a tensor of outputs. ...@@ -156,10 +172,11 @@ with the layer input to produce a tensor of outputs.
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.") .add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2) .set_support_level(2)
.add_type_rel("Conv2D", Conv2DRel); .add_type_rel("Conv2D", Conv2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
// Conv2DTranspose // relay.nn.conv2d_transpose
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
bool Conv2DTransposeRel(const Array<Type>& types, bool Conv2DTransposeRel(const Array<Type>& types,
...@@ -185,6 +202,12 @@ bool Conv2DTransposeRel(const Array<Type>& types, ...@@ -185,6 +202,12 @@ bool Conv2DTransposeRel(const Array<Type>& types,
<< "Conv only support kernel layouts that are convertible from OIHW." << "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout; << " But got "<< kernel_layout;
Layout out_layout(param->out_layout);
if (!out_layout.defined()) out_layout = in_layout;
CHECK(out_layout.Convertible(kNCHW))
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x; IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW); auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
...@@ -241,7 +264,7 @@ bool Conv2DTransposeRel(const Array<Type>& types, ...@@ -241,7 +264,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
if (out_dtype.bits() == 0) { if (out_dtype.bits() == 0) {
out_dtype = data->dtype; out_dtype = data->dtype;
} }
oshape = ConvertLayout(oshape, kNCHW, in_layout); oshape = ConvertLayout(oshape, kNCHW, out_layout);
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true; return true;
} }
...@@ -307,6 +330,8 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` ...@@ -307,6 +330,8 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.") .add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2) .set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DTransposeAttrs>)
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel); .add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
} // namespace relay } // namespace relay
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
#include <topi/nn/flatten.h> #include <topi/nn/flatten.h>
#include <vector> #include <vector>
#include "../type_relations.h" #include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
#include "../op_common.h" #include "../op_common.h"
#include "../layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// relay.nn.bias_add
TVM_REGISTER_NODE_TYPE(BiasAddAttrs); TVM_REGISTER_NODE_TYPE(BiasAddAttrs);
bool BiasAddRel(const Array<Type>& types, bool BiasAddRel(const Array<Type>& types,
...@@ -74,6 +76,7 @@ RELAY_REGISTER_OP("nn.bias_add") ...@@ -74,6 +76,7 @@ RELAY_REGISTER_OP("nn.bias_add")
.add_type_rel("BiasAdd", BiasAddRel); .add_type_rel("BiasAdd", BiasAddRel);
// relay.nn.dense
TVM_REGISTER_NODE_TYPE(DenseAttrs); TVM_REGISTER_NODE_TYPE(DenseAttrs);
...@@ -143,6 +146,8 @@ RELAY_REGISTER_OP("nn.dense") ...@@ -143,6 +146,8 @@ RELAY_REGISTER_OP("nn.dense")
.set_support_level(1) .set_support_level(1)
.add_type_rel("Dense", DenseRel); .add_type_rel("Dense", DenseRel);
// relay.leaky_relu
TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
// Positional relay function to create leaky relu operator used by frontend FFI. // Positional relay function to create leaky relu operator used by frontend FFI.
Expr MakeLeakyRelu(Expr data, Expr MakeLeakyRelu(Expr data,
...@@ -171,6 +176,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") ...@@ -171,6 +176,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Identity", IdentityRel) .add_type_rel("Identity", IdentityRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs, "FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -181,6 +187,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") ...@@ -181,6 +187,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
}); });
// relay.prelu
TVM_REGISTER_NODE_TYPE(PReluAttrs); TVM_REGISTER_NODE_TYPE(PReluAttrs);
bool PReluRel(const Array<Type>& types, bool PReluRel(const Array<Type>& types,
...@@ -235,6 +242,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. ...@@ -235,6 +242,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
.add_argument("alpha", "Tensor", "Input channelwise alpha.") .add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("PRelu", PReluRel) .add_type_rel("PRelu", PReluRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs, "FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -245,6 +253,9 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. ...@@ -245,6 +253,9 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
}); });
// relay.softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
TVM_REGISTER_API("relay.op.nn._make.softmax") TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) { auto make_func = [](Expr data, int axis) {
...@@ -282,6 +293,7 @@ RELAY_REGISTER_OP("nn.softmax") ...@@ -282,6 +293,7 @@ RELAY_REGISTER_OP("nn.softmax")
}); });
// relay.nn.log_softmax
TVM_REGISTER_API("relay.op.nn._make.log_softmax") TVM_REGISTER_API("relay.op.nn._make.log_softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) { auto make_func = [](Expr data, int axis) {
...@@ -321,8 +333,7 @@ RELAY_REGISTER_OP("nn.log_softmax") ...@@ -321,8 +333,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
}); });
// relay.nn.batch_flatten
// BatchFlatten
bool BatchFlattenRel(const Array<Type>& types, bool BatchFlattenRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& attrs,
...@@ -410,6 +421,7 @@ RELAY_REGISTER_OP("nn.relu") ...@@ -410,6 +421,7 @@ RELAY_REGISTER_OP("nn.relu")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(1)
.add_type_rel("Identity", IdentityRel) .add_type_rel("Identity", IdentityRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
...@@ -460,6 +472,7 @@ centered at that value (zero padding is added where necessary). ...@@ -460,6 +472,7 @@ centered at that value (zero padding is added where necessary).
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Identity", IdentityRel); .add_type_rel("Identity", IdentityRel);
...@@ -495,6 +508,7 @@ Normalizes along dimension axis using an L2 norm ...@@ -495,6 +508,7 @@ Normalizes along dimension axis using an L2 norm
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Identity", IdentityRel); .add_type_rel("Identity", IdentityRel);
// Dropout // Dropout
...@@ -538,6 +552,7 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input ...@@ -538,6 +552,7 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.") .add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1) .set_support_level(1)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Dropout", DropoutRel); .add_type_rel("Dropout", DropoutRel);
// batch_norm // batch_norm
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file pad.cc * \file pad.cc
* \brief Implementation of operator pad * \brief Implementation of operator pad
*/ */
#include <tvm/ir_operator.h> #include <tvm/ir_operator.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "../layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
TVM_REGISTER_NODE_TYPE(PadAttrs); // relay.nn.pad
TVM_REGISTER_NODE_TYPE(PadAttrs);
bool PadRel(const Array<Type>& types,
int num_inputs, bool PadRel(const Array<Type>& types,
const Attrs& attrs, int num_inputs,
const TypeReporter& reporter) { const Attrs& attrs,
CHECK_EQ(types.size(), 2); const TypeReporter& reporter) {
const auto* data = types[0].as<TensorTypeNode>(); CHECK_EQ(types.size(), 2);
if (data == nullptr) return false; const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const PadAttrs* param = attrs.as<PadAttrs>();
CHECK(param != nullptr); const PadAttrs* param = attrs.as<PadAttrs>();
CHECK(param != nullptr);
// check that pad widths match lengths
CHECK(data->shape.size() == param->pad_width.size()) // check that pad widths match lengths
<< "There should be as many pad width pairs as shape dimensions " CHECK(data->shape.size() == param->pad_width.size())
<< "but the shape has " << data->shape.size() << " dimensions " << "There should be as many pad width pairs as shape dimensions "
<< "and there are " << param->pad_width.size() << " pad width pairs."; << "but the shape has " << data->shape.size() << " dimensions "
<< "and there are " << param->pad_width.size() << " pad width pairs.";
// each pad width element should be a pair of positive integers
std::vector<IndexExpr> oshape; // each pad width element should be a pair of positive integers
for (size_t i = 0; i < param->pad_width.size(); i++) { std::vector<IndexExpr> oshape;
CHECK(param->pad_width[i].size() == 2) for (size_t i = 0; i < param->pad_width.size(); i++) {
<< "Each pad width element should be a pair but at index " << i CHECK(param->pad_width[i].size() == 2)
<< " there are " << param->pad_width[i].size() << " elements."; << "Each pad width element should be a pair but at index " << i
<< " there are " << param->pad_width[i].size() << " elements.";
auto width1 = as_const_int(param->pad_width[i][0]);
auto width2 = as_const_int(param->pad_width[i][1]); auto width1 = as_const_int(param->pad_width[i][0]);
CHECK(width1 != nullptr); auto width2 = as_const_int(param->pad_width[i][1]);
CHECK(width2 != nullptr); CHECK(width1 != nullptr);
CHECK(width2 != nullptr);
CHECK(*width1 >= 0)
<< "Param width elements should be positive but first pad width at " CHECK(*width1 >= 0)
<< "index " << i << " is " << *width1 << "."; << "Param width elements should be positive but first pad width at "
CHECK(*width2 >= 0) << "index " << i << " is " << *width1 << ".";
<< "Param width elements should be positive but first pad width at " CHECK(*width2 >= 0)
<< "index " << i << " is " << *width2 << "."; << "Param width elements should be positive but first pad width at "
<< "index " << i << " is " << *width2 << ".";
auto padding = make_const(data->shape[i].type(), *width1 + *width2);
oshape.push_back(data->shape[i] + padding); auto padding = make_const(data->shape[i].type(), *width1 + *width2);
} oshape.push_back(data->shape[i] + padding);
}
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
data->dtype)); reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
return true; data->dtype));
} return true;
}
// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) { // Handler to create a call to the padding op used by front-end FFI
auto attrs = make_node<PadAttrs>(); Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
attrs->pad_value = pad_value; auto attrs = make_node<PadAttrs>();
attrs->pad_width = std::move(pad_width); attrs->pad_value = pad_value;
static const Op& op = Op::Get("nn.pad"); attrs->pad_width = std::move(pad_width);
return CallNode::make(op, {data}, Attrs(attrs), {}); static const Op& op = Op::Get("nn.pad");
} return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.pad")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { TVM_REGISTER_API("relay.op.nn._make.pad")
runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv); .set_body([](const TVMArgs& args, TVMRetValue* rv) {
}); runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv);
});
RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor. RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.PadAttrs") )code" TVM_ADD_FILELINE)
.set_num_inputs(1) .set_attrs_type_key("relay.attrs.PadAttrs")
.add_argument("data", "Tensor", "The input tensor.") .set_num_inputs(1)
.set_support_level(2) .add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("Pad", PadRel); .set_support_level(2)
.add_type_rel("Pad", PadRel);
} // namespace relay
} // namespace tvm } // namespace relay
} // namespace tvm
...@@ -9,13 +9,39 @@ ...@@ -9,13 +9,39 @@
#include <topi/nn/pooling.h> #include <topi/nn/pooling.h>
#include <vector> #include <vector>
#include "../layout.h" #include "../layout.h"
#include "../../pass/alter_op_layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// relay.nn.max_pool2d & relay.nn.avg_pool2d
TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);
template <typename T>
Array<Array<Layout> > Pool2DInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
// NOTE: Discard "const" qualifier here.
T *params = const_cast<T*>(attrs.as<T>());
if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);
Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.Indexof('W') == raw_layout.Indexof('W') &&
input.Indexof('H') == raw_layout.Indexof('H') &&
!input.Contains('w') && !input.Contains('h')) {
params->layout = input.name(); // modify self to follow the input layout
}
}
return Array<Array<Layout> >{{params->layout}, {params->layout}};
}
template <typename AttrType> template <typename AttrType>
bool Pool2DRel(const Array<Type>& types, bool Pool2DRel(const Array<Type>& types,
int num_inputs, int num_inputs,
...@@ -163,6 +189,7 @@ RELAY_REGISTER_OP("nn.max_pool2d") ...@@ -163,6 +189,7 @@ RELAY_REGISTER_OP("nn.max_pool2d")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
.add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>) .add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Pool2DInferCorrectLayout<MaxPool2DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>); .set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
...@@ -219,9 +246,10 @@ Average pooling operation for one dimensional data. ...@@ -219,9 +246,10 @@ Average pooling operation for one dimensional data.
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
.add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>) .add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Pool2DInferCorrectLayout<AvgPool2DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<AvgPool2DAttrs, topi::nn::kAvgPool>); .set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);
// Global Pool // relay.nn.global_pool_2d & relay.nn.max_pool_2d
TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs); TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs);
bool GlobalPool2DRel(const Array<Type>& types, bool GlobalPool2DRel(const Array<Type>& types,
...@@ -247,8 +275,9 @@ bool GlobalPool2DRel(const Array<Type>& types, ...@@ -247,8 +275,9 @@ bool GlobalPool2DRel(const Array<Type>& types,
const auto hidx = layout.Indexof('H'); const auto hidx = layout.Indexof('H');
const auto widx = layout.Indexof('W'); const auto widx = layout.Indexof('W');
std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]}); Array<IndexExpr> oshape(dshape);
oshape[hidx] = oshape[widx] = 1; oshape.Set(hidx, 1);
oshape.Set(widx, 1);
// assign output type // assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
...@@ -307,6 +336,8 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") ...@@ -307,6 +336,8 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Pool2DInferCorrectLayout<GlobalPool2DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kAvgPool>); .set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kAvgPool>);
// GlobalMaxPool // GlobalMaxPool
...@@ -338,6 +369,8 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") ...@@ -338,6 +369,8 @@ RELAY_REGISTER_OP("nn.global_max_pool2d")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Pool2DInferCorrectLayout<GlobalPool2DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kMaxPool>); .set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kMaxPool>);
} // namespace relay } // namespace relay
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <vector> #include <vector>
#include "../pass/alter_op_layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -32,21 +33,24 @@ inline std::vector<T> AsVector(const Array<T> &array) { ...@@ -32,21 +33,24 @@ inline std::vector<T> AsVector(const Array<T> &array) {
* We make the decision to always only expose positional argument. * We make the decision to always only expose positional argument.
* We will do rewrapping in the frontend to support language * We will do rewrapping in the frontend to support language
* sugars such as keyword arguments and default value. * sugars such as keyword arguments and default value.
*
* \param Prefix the prefix of the registry, for example, "relay.op._make.".
*
* \param OpName the name of registry. * \param OpName the name of registry.
*/ */
#define RELAY_REGISTER_UNARY_OP(Prefix, OpName) \ #define RELAY_REGISTER_UNARY_OP(OpName) \
TVM_REGISTER_API(Prefix OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr)>([](Expr data) { \ .set_body_typed<Expr(Expr)>([](Expr data) { \
static const Op& op = Op::Get(OpName); \ static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(), {}); \ return CallNode::make(op, {data}, Attrs(), {}); \
}); \ }); \
RELAY_REGISTER_OP(OpName) \ RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \ .set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.") \ .add_argument("data", "Tensor", "The input tensor.") \
.set_attr<TOpPattern>("TOpPattern", kElemWise) .add_type_rel("Identity", IdentityRel) \
.set_attr<TOpPattern>("TOpPattern", kElemWise) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
ElemwiseArbitraryLayout) \
/*! Quick helper macro /*! Quick helper macro
* - Expose a positional make function to construct the node. * - Expose a positional make function to construct the node.
...@@ -56,12 +60,10 @@ inline std::vector<T> AsVector(const Array<T> &array) { ...@@ -56,12 +60,10 @@ inline std::vector<T> AsVector(const Array<T> &array) {
* We will do rewrapping in the frontend to support language * We will do rewrapping in the frontend to support language
* sugars such as keyword arguments and default value. * sugars such as keyword arguments and default value.
* *
* \param Prefix the prefix of the registry, for example, "relay.op._make.".
*
* \param OpName the name of registry. * \param OpName the name of registry.
*/ */
#define RELAY_REGISTER_BINARY_OP(Prefix, OpName) \ #define RELAY_REGISTER_BINARY_OP(OpName) \
TVM_REGISTER_API(Prefix OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \ .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \ static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
...@@ -72,7 +74,26 @@ inline std::vector<T> AsVector(const Array<T> &array) { ...@@ -72,7 +74,26 @@ inline std::vector<T> AsVector(const Array<T> &array) {
.add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel) \ .add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \ .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) .set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout)
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("BroadcastComp", BroadcastCompRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout)
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -23,71 +23,65 @@ namespace relay { ...@@ -23,71 +23,65 @@ namespace relay {
// Addition // Addition
RELAY_REGISTER_BINARY_OP("relay.op._make.", "add") RELAY_REGISTER_BINARY_OP("add")
.describe("Elementwise add with with broadcasting") .describe("Elementwise add with with broadcasting")
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));
// Subtraction // Subtraction
RELAY_REGISTER_BINARY_OP("relay.op._make.", "subtract") RELAY_REGISTER_BINARY_OP("subtract")
.describe("Elementwise substract with broadcasting") .describe("Elementwise substract with broadcasting")
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));
// Right shift // Right shift
RELAY_REGISTER_BINARY_OP("relay.op._make.", "right_shift") RELAY_REGISTER_BINARY_OP("right_shift")
.describe("Elementwise right shift with broadcasting") .describe("Elementwise right shift with broadcasting")
.set_support_level(4) .set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "left_shift")
RELAY_REGISTER_BINARY_OP("left_shift")
.describe("Elementwise left shift with broadcasting") .describe("Elementwise left shift with broadcasting")
.set_support_level(4) .set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "maximum")
RELAY_REGISTER_BINARY_OP("maximum")
.describe("Elementwise maximum of two tensors with broadcasting") .describe("Elementwise maximum of two tensors with broadcasting")
.set_support_level(4) .set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "minimum")
RELAY_REGISTER_BINARY_OP("minimum")
.describe("Elementwise minimum of two tensors with broadcasting") .describe("Elementwise minimum of two tensors with broadcasting")
.set_support_level(4) .set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "divide")
RELAY_REGISTER_BINARY_OP("divide")
.describe("Elementwise divide with broadcasting") .describe("Elementwise divide with broadcasting")
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply")
RELAY_REGISTER_BINARY_OP("multiply")
.describe("Elementwise multiply with broadcasting") .describe("Elementwise multiply with broadcasting")
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "power")
RELAY_REGISTER_BINARY_OP("power")
.describe("Elementwise power with broadcasting") .describe("Elementwise power with broadcasting")
.set_support_level(4) .set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
RELAY_REGISTER_BINARY_OP("mod")
.describe("Elementwise mod with broadcasting") .describe("Elementwise mod with broadcasting")
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("BroadcastComp", BroadcastCompRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
RELAY_REGISTER_CMP_OP("equal") RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting") .describe("Elementwise equal compare with broadcasting")
......
...@@ -11,9 +11,12 @@ ...@@ -11,9 +11,12 @@
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/broadcast.h> #include <topi/broadcast.h>
#include <topi/reduction.h> #include <topi/reduction.h>
#include <topi/nn.h>
#include <vector> #include <vector>
#include "../op_common.h" #include "../op_common.h"
#include "../../../arithmetic/compute_expr.h" #include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h"
#include "../layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -156,6 +159,7 @@ RELAY_REGISTER_OP("expand_dims") ...@@ -156,6 +159,7 @@ RELAY_REGISTER_OP("expand_dims")
.set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute) .set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast); .set_attr<TOpPattern>("TOpPattern", kBroadcast);
// relay.concatenate
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
bool ConcatenateRel(const Array<Type>& types, bool ConcatenateRel(const Array<Type>& types,
...@@ -201,6 +205,42 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -201,6 +205,42 @@ bool ConcatenateRel(const Array<Type>& types,
return true; return true;
} }
Array<Array<Layout>> ConcatenateLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
static_cast<size_t>(param->axis);
Layout ret;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
Layout::LayoutDim concate_dim = old_in_layouts[0][axis];
for (size_t i = 0; i < new_in_layouts.size(); ++i) {
if (new_in_layouts[i].ndim() > axis &&
new_in_layouts[i][axis] == concate_dim) {
ret = new_in_layouts[i];
break;
}
}
} else { // this function is called on the original correct relay ir
for (size_t i = 0; i < old_in_layouts.size(); ++i) {
if (old_in_layouts[i].defined()) {
ret = old_in_layouts[i];
break;
}
}
if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) {
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
}
return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
}
Expr MakeConcatenate(Expr data, Expr MakeConcatenate(Expr data,
int axis) { int axis) {
auto attrs = make_node<ConcatenateAttrs>(); auto attrs = make_node<ConcatenateAttrs>();
...@@ -226,7 +266,8 @@ RELAY_REGISTER_OP("concatenate") ...@@ -226,7 +266,8 @@ RELAY_REGISTER_OP("concatenate")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.") .add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1) .set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel); .add_type_rel("Concatenate", ConcatenateRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout);
/* relay.transpose */ /* relay.transpose */
TVM_REGISTER_NODE_TYPE(TransposeAttrs); TVM_REGISTER_NODE_TYPE(TransposeAttrs);
...@@ -323,7 +364,6 @@ RELAY_REGISTER_OP("transpose") ...@@ -323,7 +364,6 @@ RELAY_REGISTER_OP("transpose")
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
/* relay.reshape */ /* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs); TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
bool ReshapeRel(const Array<Type>& types, bool ReshapeRel(const Array<Type>& types,
...@@ -1252,7 +1292,7 @@ Examples:: ...@@ -1252,7 +1292,7 @@ Examples::
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
// Split // relay.split
TVM_REGISTER_NODE_TYPE(SplitAttrs); TVM_REGISTER_NODE_TYPE(SplitAttrs);
bool SplitRel(const Array<Type>& types, bool SplitRel(const Array<Type>& types,
...@@ -1367,6 +1407,7 @@ the entries indicate where along axis the array is split. ...@@ -1367,6 +1407,7 @@ the entries indicate where along axis the array is split.
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.slice_like
TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); TVM_REGISTER_NODE_TYPE(SliceLikeAttrs);
/*! /*!
...@@ -1513,5 +1554,104 @@ RELAY_REGISTER_OP("slice_like") ...@@ -1513,5 +1554,104 @@ RELAY_REGISTER_OP("slice_like")
.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute) .set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.layout_transform
Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const LayoutTransformAttrs *param = attrs.as<LayoutTransformAttrs>();
CHECK(param != nullptr);
Layout src_layout(param->src_layout);
Layout dst_layout(param->dst_layout);
if (src_layout.Equals(dst_layout)) {
return Array<Tensor>{ inputs[0] };
}
CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from " << param->src_layout << " to " << param->dst_layout;
const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout);
return Array<Tensor> {
topi::layout_transform(inputs[0], out_shape, [&](const Array<tvm::Var>& dst_indices) {
std::vector<tvm::Expr> dst_to_src_indices;
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_axis = src_layout[i];
int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis));
int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis));
int32_t src_factor = static_cast<int32_t>(src_layout.Subsizeof(src_axis));
int32_t dst_factor = static_cast<int32_t>(dst_layout.Subsizeof(src_axis));
tvm::Expr src_index(dst_indices[dst_major_pos]);
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
}
if (Layout::IsSuperdim(src_axis) && src_factor > 0) {
src_index = src_index / src_factor;
} else if (Layout::IsSubdim(src_axis) && src_factor > 0) {
src_index = src_index % src_factor;
}
dst_to_src_indices.push_back(src_index);
}
return Array<tvm::Expr>(dst_to_src_indices);
})
};
}
bool LayoutTransformRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
const LayoutTransformAttrs* params = attrs.as<LayoutTransformAttrs>();
Layout src_layout(params->src_layout);
Layout dst_layout(params->dst_layout);
CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";
CHECK(src_layout.Convertible(dst_layout))
<< "cannot convert from " << params->src_layout << " to " << params->dst_layout;
const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout);
reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
return true;
}
Expr MakeLayoutTransform(Expr data,
std::string src_layout,
std::string dst_layout) {
auto attrs = make_node<LayoutTransformAttrs>();
attrs->src_layout = std::move(src_layout);
attrs->dst_layout = std::move(dst_layout);
static const Op& op = Op::Get("layout_transform");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.layout_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeLayoutTransform, args, rv);
});
RELAY_REGISTER_OP("layout_transform")
.describe(R"code(Transform the input data layout.
For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes
the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LayoutTransformAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("layout_transform", LayoutTransformRel)
.set_support_level(5)
.set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -22,7 +22,7 @@ namespace relay { ...@@ -22,7 +22,7 @@ namespace relay {
} \ } \
RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") RELAY_REGISTER_UNARY_OP("log")
.describe(R"code(Returns the log input array, computed element-wise. .describe(R"code(Returns the log input array, computed element-wise.
.. math:: .. math::
...@@ -30,11 +30,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") ...@@ -30,11 +30,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(1) .set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") RELAY_REGISTER_UNARY_OP("exp")
.describe(R"code(Returns the exp input array, computed element-wise. .describe(R"code(Returns the exp input array, computed element-wise.
.. math:: .. math::
...@@ -42,36 +41,30 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") ...@@ -42,36 +41,30 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(1) .set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
RELAY_REGISTER_UNARY_OP("sqrt")
RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt") .describe(R"code(Returns the rsqrt input array, computed element-wise.
.describe(R"code(Returns the sqrt input array, computed element-wise.
.. math:: .. math::
sqrt(x) sqrt(x)
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(1) .set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like") RELAY_REGISTER_UNARY_OP("zeros_like")
.describe(R"code(Returns an array of zeros, with same type and shape as the input. .describe(R"code(Returns an array of zeros, with same type and shape as the input.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(1) .set_support_level(4);
.add_type_rel("Identity", IdentityRel);
RELAY_REGISTER_UNARY_OP("ones_like")
RELAY_REGISTER_UNARY_OP("relay.op._make.", "ones_like")
.describe(R"code(Returns an array of ones, with same type and shape as the input. .describe(R"code(Returns an array of ones, with same type and shape as the input.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(1) .set_support_level(4);
.add_type_rel("Identity", IdentityRel);
RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid") RELAY_REGISTER_UNARY_OP("sigmoid")
.describe(R"code(Returns the sigmoid input array, computed element-wise. .describe(R"code(Returns the sigmoid input array, computed element-wise.
.. math:: .. math::
...@@ -79,48 +72,47 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid") ...@@ -79,48 +72,47 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(1) .set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy") RELAY_REGISTER_UNARY_OP("copy")
.describe(R"code(Copy a tensor. .describe(R"code(Copy a tensor.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(3) .set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));
// relay.clip // relay.clip
TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_NODE_TYPE(ClipAttrs);
TVM_REGISTER_API("relay.op._make.clip") TVM_REGISTER_API("relay.op._make.clip")
.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) { .set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
auto attrs = make_node<ClipAttrs>(); auto attrs = make_node<ClipAttrs>();
attrs->a_min = a_min; attrs->a_min = a_min;
attrs->a_max = a_max; attrs->a_max = a_max;
static const Op& op = Op::Get("clip"); static const Op& op = Op::Get("clip");
return CallNode::make(op, {a}, Attrs(attrs), {}); return CallNode::make(op, {a}, Attrs(attrs), {});
}); });
RELAY_REGISTER_OP("clip") RELAY_REGISTER_OP("clip")
.describe(R"code(Clip tensor values. .describe(R"code(Clip tensor values.
This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("tensor", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3) .add_type_rel("Identity", IdentityRel)
.add_type_rel("Clip", IdentityRel); .set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_support_level(3);
RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor") RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise. .describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(3) .set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil") RELAY_REGISTER_UNARY_OP("ceil")
.describe(R"code(Returns the ceil of input array, computed element-wise. .describe(R"code(Returns the ceil of input array, computed element-wise.
.. math:: .. math::
...@@ -128,11 +120,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil") ...@@ -128,11 +120,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(3) .set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc") RELAY_REGISTER_UNARY_OP("trunc")
.describe(R"code(Returns the trunc of input array, computed element-wise. .describe(R"code(Returns the trunc of input array, computed element-wise.
.. math:: .. math::
...@@ -140,11 +131,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc") ...@@ -140,11 +131,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(3) .set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc));
RELAY_REGISTER_UNARY_OP("round")
RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
.describe(R"code(Returns the round of input array, computed element-wise. .describe(R"code(Returns the round of input array, computed element-wise.
.. math:: .. math::
...@@ -152,11 +141,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round") ...@@ -152,11 +141,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(3) .set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs") RELAY_REGISTER_UNARY_OP("abs")
.describe(R"code(Returns the abs of input array, computed element-wise. .describe(R"code(Returns the abs of input array, computed element-wise.
.. math:: .. math::
...@@ -164,11 +152,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs") ...@@ -164,11 +152,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(3) .set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh") RELAY_REGISTER_UNARY_OP("tanh")
.describe(R"code(Returns the tanh of input array, computed element-wise. .describe(R"code(Returns the tanh of input array, computed element-wise.
.. math:: .. math::
...@@ -176,11 +163,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh") ...@@ -176,11 +163,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(1) .set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative") RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise. .describe(R"code(Returns the numeric negative of input array, computed element-wise.
.. math:: .. math::
...@@ -188,7 +174,6 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative") ...@@ -188,7 +174,6 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(3) .set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));
} // namespace relay } // namespace relay
......
/*!
* Copyright (c) 2018 by Contributors
* \file alter_op_layout.h
* \brief Alternate the layouts of operators or replace primitive operators with
other expressions. This pass can be used for computing convolution in
custom layouts or other general weight pre-transformation.
*/
#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#include <tvm/relay/expr.h>
#include "../op/layout.h"
namespace tvm {
namespace relay {
/*!
* \brief Infer & correct function of node layout. See \p Layout for layout convention
* \param attrs The attribute of the node.
* \param new_in_layouts The layouts of input arguments after alter_op_layout.
* This can be undefined, which means we call this function before alternating
* any operators.
* \param old_in_layouts The layouts of input arguments before alter_op_layout.
* \param old_in_shapes The shapes of old input arguments.
* \return infered_layout An array of two elements that are inferred input layouts and
* inferred output layouts.
*/
using FInferCorrectLayout = runtime::TypedPackedFunc<
Array<Array<Layout>>(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes)>;
/*! \brief take arbitrary input layout and copy to output */
inline Array<Array<Layout> > ElemwiseArbitraryLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
Layout ret;
if (new_in_layouts.defined()) {
CHECK_GE(new_in_layouts.size(), 1);
ret = new_in_layouts[0];
} else {
for (size_t i = 0; i < old_in_layouts.size(); ++i) {
if (old_in_layouts[i].defined()) {
ret = old_in_layouts[i];
break;
}
}
}
return Array<Array<Layout> >{Array<Layout>(old_in_layouts.size(), ret), {ret}};
}
/*! \brief Infer layout for binary broadcast operators */
inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
Array<Layout> layouts;
if (new_in_layouts.defined()) {
layouts.assign(new_in_layouts.begin(), new_in_layouts.end());
} else {
layouts.assign(old_in_layouts.begin(), old_in_layouts.end());
}
if (!layouts[0].defined() && !layouts[1].defined()) {
// both undefined, infer fails
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
} else if (!layouts[0].defined() || !layouts[1].defined()) {
// only one is defined, use shape information to help infer
int defined_idx = layouts[0].defined() ? 0 : 1;
int undef_idx = 1 - defined_idx;
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
layouts.Set(undef_idx,
layouts[defined_idx].Sublayout(
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size()));
return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
} else {
// only know the tensor with smaller dimensions,
// so we cannot infer the final broadcasted output.
// fails in this case.
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
} else {
// try to broadcast the tensors to the larger dimension
int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1;
int small_idx = 1 - large_idx;
Layout ret = layouts[large_idx];
// extract common part
size_t i = layouts[large_idx].ndim();
for (; i != 0; --i) {
auto dim = layouts[large_idx][i-1];
if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) {
break;
}
}
Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i);
if (!layouts[small_idx].Convertible(common_part)) { // fail
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
layouts.Set(small_idx, common_part);
return Array<Array<Layout> > {layouts, {ret}};
}
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
/*!
* Copyright (c) 2018 by Contributors
* \file canonicalize_ops.cc
* \brief Canonicalize special operators to basic operators.
This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.)
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include "pattern_util.h"
namespace tvm {
namespace relay {
class BiasAddSimplifier : public ExprMutator {
public:
Expr VisitExpr_(const CallNode* n) {
static const Op& bias_add = Op::Get("nn.bias_add");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(bias_add)) {
Call call = Downcast<Call>(new_n);
CHECK_EQ(call->args.size(), 2);
const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
auto ttype = call->args[0]->type_as<TensorTypeNode>();
size_t n_dim = ttype->shape.size();
Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis});
Expr ret = Add(call->args[0], expanded_bias);
ret->checked_type_ = n->checked_type_;
return ret;
}
return new_n;
}
};
Expr CanonicalizeOps(const Expr& e) {
return BiasAddSimplifier().Mutate(e);
}
TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CanonicalizeOps(args[0]);
});
} // namespace relay
} // namespace tvm
...@@ -29,11 +29,11 @@ using runtime::TypedPackedFunc; ...@@ -29,11 +29,11 @@ using runtime::TypedPackedFunc;
// FoldScaleAxis algorithm: // FoldScaleAxis algorithm:
// //
// The general idea is to transform Expr to tuple of // The general idea is to transform Expr to tuple of
// (value, axes, scale), where the final result satiesfies: // (value, axes, scale), where the final result satisfies:
// //
// result = value // result = value
// for i, k in enumerate(axes): // for i, k in enumerate(axes):
// k-ith dimension of result *= i-th dimension of scale // k-th dimension of result *= i-th dimension of scale
// //
// Then we can propagate this signal along and fold the scale if necessary. // Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that certain scale may never be consumed // However, it is possible that certain scale may never be consumed
......
...@@ -42,13 +42,20 @@ class TempRealizer : private ExprMutator { ...@@ -42,13 +42,20 @@ class TempRealizer : private ExprMutator {
class ForwardRewriter : private ExprMutator { class ForwardRewriter : private ExprMutator {
public: public:
ForwardRewriter(const OpMap<FForwardRewrite>& rewrite_map, ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
std::function<NodeRef(const Call&)> fcontext, std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) std::function<Expr(const Expr&)> fmulti_ref_trigger)
: rewrite_map_(rewrite_map), : rewrite_map_(rewrite_map),
fcontext_(fcontext), fcontext_(fcontext),
fmulti_ref_trigger_(fmulti_ref_trigger) { fmulti_ref_trigger_(fmulti_ref_trigger) {}
}
ForwardRewriter(const FForwardRewrite* rewrite_func,
std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger)
: rewrite_func_(rewrite_func),
fcontext_(fcontext),
fmulti_ref_trigger_(fmulti_ref_trigger) {}
// Transform expression. // Transform expression.
Expr Rewrite(Expr expr) { Expr Rewrite(Expr expr) {
...@@ -60,8 +67,9 @@ class ForwardRewriter : private ExprMutator { ...@@ -60,8 +67,9 @@ class ForwardRewriter : private ExprMutator {
private: private:
// The rewrite rule. // The rewrite rule.
const OpMap<FForwardRewrite>& rewrite_map_; const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
// The context. const FForwardRewrite* rewrite_func_{nullptr};
// The context.const
std::function<NodeRef(const Call&)> fcontext_{nullptr}; std::function<NodeRef(const Call&)> fcontext_{nullptr};
// The multiple reference trigger // The multiple reference trigger
std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr}; std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
...@@ -104,9 +112,31 @@ class ForwardRewriter : private ExprMutator { ...@@ -104,9 +112,31 @@ class ForwardRewriter : private ExprMutator {
} }
} }
Expr VisitExpr_(const TupleNode* op) final {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
for (auto field : op->fields) {
auto new_field = this->GetTempExpr(field);
fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(field);
}
if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return TupleNode::make(fields);
}
}
Expr VisitExpr_(const CallNode* call_node) final { Expr VisitExpr_(const CallNode* call_node) final {
const Call& ref_call = GetRef<Call>(call_node); const Call& ref_call = GetRef<Call>(call_node);
PackedFunc frewrite = rewrite_map_.get(call_node->op, nullptr); PackedFunc frewrite;
if (rewrite_func_) {
frewrite = *rewrite_func_;
} else {
CHECK(rewrite_map_);
frewrite = rewrite_map_->get(call_node->op, nullptr);
}
auto new_op = this->Mutate(call_node->op); auto new_op = this->Mutate(call_node->op);
bool unchanged = call_node->op.same_as(new_op); bool unchanged = call_node->op.same_as(new_op);
...@@ -147,9 +177,16 @@ Expr ForwardRewrite(const Expr& expr, ...@@ -147,9 +177,16 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext, std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) { std::function<Expr(const Expr&)> fmulti_ref_trigger) {
auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name); auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
return ForwardRewriter(rewrite_map, return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
fcontext, }
fmulti_ref_trigger).Rewrite(expr);
Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) {
return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
} }
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -73,7 +73,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, ...@@ -73,7 +73,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
* the target Tensor on the specified axis via broadcasting rule. * the target Tensor on the specified axis via broadcasting rule.
* *
* \param bias The bias. * \param bias The bias.
* \param target_ndim target dimension. * \param target_ndim Target dimension.
* \param axes The axis on the output we want to match on. * \param axes The axis on the output we want to match on.
*/ */
inline Expr ExpandBiasToMatchAxis(Expr bias, inline Expr ExpandBiasToMatchAxis(Expr bias,
......
...@@ -448,6 +448,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I, ...@@ -448,6 +448,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
} }
using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>; using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>;
/*! /*!
* \brief Transform the layout according to the mapping function \p to_src_indices. * \brief Transform the layout according to the mapping function \p to_src_indices.
* \param src the source input. * \param src the source input.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment