Commit 5884cd01 authored by alex-weaver, committed by Tianqi Chen

Change TOPI ops to use C++ implementation where applicable (#357)

* Updated TVM version. Implemented fix for nnvm_compiler crash on exit on Windows. Changed TOPI ops from using Python to using C++ where applicable.

* Fix lint

* Fix lint

* Fix macro

* Fix reshape

* Update TVM to fix test fails
parent 3f6423aa
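
The pattern repeated throughout this diff: each op's Python-side @reg.register_compute handler is removed, and an equivalent FTVMCompute attribute is registered on the C++ op definition instead, delegating directly to TOPI's C++ kernels (schedules stay registered in Python). A minimal sketch of the C++ side, assuming the nnvm and topi headers added by this commit:

#include <nnvm/op.h>
#include <nnvm/compiler/op_attr_types.h>
#include "topi/nn.h"

using tvm::Tensor;
using tvm::Array;
using nnvm::compiler::FTVMCompute;

// Sketch only: mirrors the relu registration that appears in nn.cc below.
NNVM_REGISTER_OP(relu)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const nnvm::NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    // Delegate the compute rule to the TOPI C++ implementation.
    return Array<Tensor>{ topi::relu(inputs[0], 0.0f) };
  });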
@@ -22,6 +22,7 @@ include_directories(BEFORE "include")
include_directories("tvm/include")
include_directories("tvm/dlpack/include")
include_directories("tvm/HalideIR/src")
include_directories("tvm/topi/include")
set(NNVM_LINKER_LIBS "")
set(NNVM_COMPILER_LINKER_LIBS "")
...
@@ -11,7 +11,7 @@ include $(config)
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
-CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src
+CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src -Itvm/topi/include
ifdef DMLC_CORE_PATH
CFLAGS += -I$(DMLC_CORE_PATH)/include
...
/*!
* Copyright (c) 2016 by Contributors
* \file util.h
* \brief Utility functions for nnvm compiler
*/
#ifndef NNVM_COMPILER_UTIL_H_
#define NNVM_COMPILER_UTIL_H_
#include <tvm/expr.h>
#include <nnvm/tuple.h>
namespace nnvm {
namespace compiler {
/*!
 * \brief Helper function to convert TShape to TVM array. Useful for
 * passing data from NNVM param structures to TOPI ops.
 *
 * \param shape The shape to convert
 *
 * \return An Array of Expr, where each element is a constant int32
 */
inline tvm::Array<tvm::Expr> ShapeToArray(TShape shape) {
  tvm::Array<tvm::Expr> result;
  for (auto i : shape) {
    result.push_back(tvm::make_const(tvm::Int(32), i));
  }
  return result;
}
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_UTIL_H_
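
For illustration, a hypothetical call site for the helper above (the pooling, reduction, and transform ops below use it exactly this way); ExampleUse is not part of the commit:

#include <nnvm/compiler/util.h>

// Convert a parsed param shape into the Array<Expr> form TOPI expects.
void ExampleUse(const nnvm::TShape& pool_size) {
  tvm::Array<tvm::Expr> arr = nnvm::compiler::ShapeToArray(pool_size);
  // arr now holds one constant int32 Expr per dimension of pool_size.
}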
@@ -28,7 +28,7 @@ namespace nnvm {
 * symbol is the final operation of a graph and thus including all the information
 * required (the graph) to evaluate its output value.
 */
-class Symbol {
+class NNVM_DLL Symbol {
 public:
  /*! \brief option passed to ListAttr */
  enum ListAttrOption {
...
@@ -10,59 +10,26 @@ from . import registry as reg
from .registry import OpPattern
# relu
@reg.register_compute("relu")
def compute_relu(attrs, inputs, _):
    """Compute definition of relu"""
    return topi.nn.relu(inputs[0])
reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEMWISE)
# leaky_relu
@reg.register_compute("leaky_relu")
def compute_leaky_relu(attrs, inputs, _):
    """Compute definition of leaky_relu"""
    return topi.nn.leaky_relu(inputs[0], attrs.get_float("alpha"))
reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
# flatten
@reg.register_compute("flatten")
def compute_flatten(attrs, inputs, _):
    """Compute definition of flatten"""
    return topi.nn.flatten(inputs[0])
reg.register_schedule("flatten", _fschedule_broadcast)
reg.register_pattern("flatten", OpPattern.INJECTIVE)
# pad
@reg.register_compute("pad")
def compute_pad(attrs, inputs, _):
    """Compute definition of pad"""
    pad_width = attrs.get_int_pair_tuple('pad_width')
    assert len(pad_width) == len(inputs[0].shape) and \
        len(pad_width[0]) == 2, "illegal pad_width"
    pad_before = [x[0] for x in pad_width]
    pad_after = [x[1] for x in pad_width]
    pad_value = attrs.get_int('pad_value')
    return topi.nn.pad(inputs[0], pad_before, pad_after, pad_value)
reg.register_schedule("pad", _fschedule_broadcast)
reg.register_pattern("pad", OpPattern.INJECTIVE)
# softmax
@reg.register_compute("softmax")
def compute_softmax(attrs, inputs, _):
    """Compute definition of softmax"""
    axis = attrs.get_int("axis")
    assert axis == -1, "only support axis == -1 for now"
    return topi.nn.softmax(inputs[0])
@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
    """Schedule definition of softmax"""
@@ -73,13 +40,6 @@ reg.register_pattern("softmax", OpPattern.OPAQUE)
# log softmax
@reg.register_compute("log_softmax")
def compute_log_softmax(attrs, inputs, _):
    """Compute definition of log_softmax"""
    axis = attrs.get_int("axis")
    assert axis == -1, "only support axis == -1 for now"
    return topi.nn.log_softmax(inputs[0])
@reg.register_schedule("log_softmax")
def schedule_log_softmax(_, outs, target):
    """Schedule definition of log_softmax"""
@@ -91,13 +51,6 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE)
# dense
@reg.register_compute("dense")
def compute_dense(attrs, inputs, _):
    """Compute definition of dense"""
    if attrs.get_bool("use_bias"):
        return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
    return topi.nn.dense(inputs[0], inputs[1])
@reg.register_schedule("dense")
def schedule_dense(_, outs, target):
    """Schedule definition of dense"""
@@ -175,18 +128,6 @@ reg.register_pattern("conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
# max_pool2d
@reg.register_compute("max_pool2d")
def compute_max_pool2d(attrs, inputs, _):
    """Compute definition of max_pool2d"""
    pool_size = attrs.get_int_tuple("pool_size")
    strides = attrs.get_int_tuple("strides")
    padding = attrs.get_int_tuple("padding")
    layout = attrs["layout"]
    ceil_mode = attrs.get_bool("ceil_mode")
    assert layout == "NCHW", "only support nchw for now"
    return topi.nn.pool(inputs[0], pool_size, strides, padding,
                        pool_type='max', ceil_mode=ceil_mode)
@reg.register_schedule("max_pool2d")
def schedule_max_pool2d(_, outs, target):
    """Schedule definition of max_pool2d"""
@@ -197,18 +138,6 @@ reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# avg_pool2d
@reg.register_compute("avg_pool2d")
def compute_avg_pool2d(attrs, inputs, _):
    """Compute definition of avg_pool2d"""
    pool_size = attrs.get_int_tuple("pool_size")
    strides = attrs.get_int_tuple("strides")
    padding = attrs.get_int_tuple("padding")
    layout = attrs["layout"]
    ceil_mode = attrs.get_bool("ceil_mode")
    assert layout == "NCHW", "only support nchw for now"
    return topi.nn.pool(inputs[0], pool_size, strides, padding,
                        pool_type='avg', ceil_mode=ceil_mode)
@reg.register_schedule("avg_pool2d")
def schedule_avg_pool2d(_, outs, target):
    """Schedule definition of avg_pool2d"""
@@ -219,13 +148,6 @@ reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# global_max_pool2d
@reg.register_compute("global_max_pool2d")
def compute_global_max_pool2d(attrs, inputs, _):
    """Compute definition of global_max_pool2d"""
    layout = attrs["layout"]
    assert layout == "NCHW", "only support nchw for now"
    return topi.nn.global_pool(inputs[0], pool_type='max')
@reg.register_schedule("global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
    """Schedule definition of global_max_pool2d"""
@@ -236,13 +158,6 @@ reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# global_avg_pool2d
@reg.register_compute("global_avg_pool2d")
def compute_global_avg_pool2d(attrs, inputs, _):
    """Compute definition of global_avg_pool2d"""
    layout = attrs["layout"]
    assert layout == "NCHW", "only support nchw for now"
    return topi.nn.global_pool(inputs[0], pool_type='avg')
@reg.register_schedule("global_avg_pool2d")
def schedule_global_avg_pool2d(_, outs, target):
    """Schedule definition of global_avg_pool2d"""
...
@@ -27,16 +27,13 @@ def _compute_reduce(f):
    return _compute
# sum
reg.register_compute("sum", _compute_reduce(topi.sum))
reg.register_pattern("sum", OpPattern.COMM_REDUCE)
reg.register_schedule("sum", _fschedule_reduce)
# max
reg.register_compute("max", _compute_reduce(topi.max))
reg.register_pattern("max", OpPattern.COMM_REDUCE)
reg.register_schedule("max", _fschedule_reduce)
# min
reg.register_compute("min", _compute_reduce(topi.min))
reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce)
@@ -43,132 +43,97 @@ _fschedule_broadcast = _fschedule_injective
_fschedule_elemwise = _fschedule_injective
# copy
reg.register_compute("copy", _compute_unary(topi.identity))
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)
# exp
reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)
# sqrt
reg.register_compute("sqrt", _compute_unary(topi.sqrt))
reg.register_pattern("sqrt", OpPattern.ELEMWISE)
reg.register_schedule("sqrt", _fschedule_broadcast)
# log
reg.register_compute("log", _compute_unary(topi.log))
reg.register_pattern("log", OpPattern.ELEMWISE)
reg.register_schedule("log", _fschedule_broadcast)
# tanh
reg.register_compute("tanh", _compute_unary(topi.tanh))
reg.register_pattern("tanh", OpPattern.ELEMWISE)
reg.register_schedule("tanh", _fschedule_broadcast)
# negative
reg.register_compute("negative", _compute_unary(topi.negative))
reg.register_pattern("negative", OpPattern.ELEMWISE)
reg.register_schedule("negative", _fschedule_broadcast)
# sigmoid
reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
reg.register_schedule("sigmoid", _fschedule_broadcast)
# add_scalar
reg.register_compute("__add_scalar__",
                     _compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast)
# sub_scalar
reg.register_compute("__sub_scalar__",
                     _compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
# rsub_scalar
reg.register_compute("__rsub_scalar__",
                     _compute_binary_scalar(lambda x, y: y - x))
reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
# mul_scalar
reg.register_compute("__mul_scalar__",
                     _compute_binary_scalar(lambda x, y: x * y))
reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
# div_scalar
reg.register_compute("__div_scalar__",
                     _compute_binary_scalar(lambda x, y: x / y))
reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__div_scalar__", _fschedule_broadcast)
# rdiv_scalar
reg.register_compute("__rdiv_scalar__",
                     _compute_binary_scalar(lambda x, y: y / x))
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
# pow_scalar
reg.register_compute("__pow_scalar__",
                     _compute_binary_scalar(tvm.power))
reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
# rpow_scalar
reg.register_compute("__rpow_scalar__",
                     _compute_binary_scalar(lambda x, y: tvm.power(y, x)))
reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
# elemwise_add
reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)
# elemwise_sub
reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
reg.register_schedule("elemwise_sub", _fschedule_broadcast)
# elemwise_mul
reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
reg.register_schedule("elemwise_mul", _fschedule_broadcast)
# elemwise_div
reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
reg.register_schedule("elemwise_div", _fschedule_broadcast)
# broadcast_add
reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
reg.register_schedule("broadcast_add", _fschedule_broadcast)
# broadcast_sub
reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
reg.register_schedule("broadcast_sub", _fschedule_broadcast)
# broadcast_mul
reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
reg.register_schedule("broadcast_mul", _fschedule_broadcast)
# broadcast_div
reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
reg.register_schedule("broadcast_div", _fschedule_broadcast)
# broadcast_to
@reg.register_compute("broadcast_to")
def compute_broadcast_to(attrs, inputs, out_info):
    """Compute definition of broadcast_to"""
    return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
reg.register_schedule("broadcast_to", _fschedule_broadcast)
@@ -2,71 +2,30 @@
"""Tensor transformation ops"""
from __future__ import absolute_import
import topi
from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg
from .registry import OpPattern
# expand_dims
@reg.register_compute("expand_dims")
def compute_expand_dims(attrs, inputs, out_info):
    """Compute definition of expand_dims"""
    return topi.expand_dims(
        inputs[0], attrs.get_int("axis"),
        num_newaxis=attrs.get_int("num_newaxis"))
reg.register_pattern("expand_dims", OpPattern.BROADCAST)
reg.register_schedule("expand_dims", _fschedule_broadcast)
# transpose
@reg.register_compute("transpose")
def compute_transpose(attrs, inputs, out_info):
    """Compute definition of transpose"""
    axes = attrs.get_int_tuple("axes")
    axes = tuple(axes) if axes else None
    return topi.transpose(inputs[0], axes)
reg.register_pattern("transpose", OpPattern.INJECTIVE)
reg.register_schedule("transpose", _fschedule_injective)
# reshape
@reg.register_compute("reshape")
def compute_reshape(attrs, inputs, out_info):
    """Compute definition of reshape"""
    oshape = out_info[0].shape
    return topi.reshape(inputs[0], oshape)
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_injective)
-# reshape
+# squeeze
@reg.register_compute("squeeze")
def compute_squeeze(attrs, inputs, out_info):
    """Compute definition of squeeze"""
    axis = attrs.get_int_tuple("axis")
    axis = tuple(axis) if axis else None
    return topi.squeeze(inputs[0], axis)
reg.register_pattern("squeeze", OpPattern.INJECTIVE)
reg.register_schedule("squeeze", _fschedule_injective)
# concatenate
@reg.register_compute("concatenate")
def compute_concatenate(attrs, inputs, out_info):
    """Compute definition of concatenate"""
    axis = attrs.get_int("axis")
    return topi.concatenate([x for x in inputs], axis=axis)
reg.register_pattern("concatenate", OpPattern.INJECTIVE)
reg.register_schedule("concatenate", _fschedule_injective)
# split
@reg.register_compute("split")
def compute_split(attrs, inputs, out_info):
    """Compute definition of split"""
    x = attrs["indices_or_sections"]
    if x.startswith("(") or x.startswith("["):
        indices_or_sections = attrs.get_int_tuple("indices_or_sections")
    else:
        indices_or_sections = attrs.get_int("indices_or_sections")
    return topi.split(inputs[0], indices_or_sections, axis=attrs.get_int("axis"))
reg.register_pattern("split", OpPattern.INJECTIVE)
reg.register_schedule("split", _fschedule_injective)
@@ -344,7 +344,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
    *rv = ret;
  });
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
    p->stream << "GraphFunc(name=" << op->func_name
              << ", addr=" << op << ")";
...
@@ -80,7 +80,7 @@ GraphKey GraphKeyNode::make(Graph graph,
  return GraphKey(n);
}
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) {
    p->stream << "GraphKeyNode("<< op << ")";
  });
...
@@ -3,17 +3,27 @@
 * \file nn.cc
 * \brief Property def of nn operators.
 */
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/nn/dense.h"
#include "topi/nn.h"
#include "topi/nn/softmax.h"
namespace nnvm {
namespace top {
using tvm::Tensor;
using tvm::Array;
using nnvm::compiler::FTVMCompute;
// dense
DMLC_REGISTER_PARAMETER(DenseParam);
@@ -72,6 +82,21 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>)
.set_attr<FInferShape>("FInferShape", DenseInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    Tensor bias_val;
    Tensor* bias;
    const DenseParam& param = nnvm::get<DenseParam>(attrs.parsed);
    if (param.use_bias) {
      bias_val = inputs[2];
      bias = &bias_val;
    } else {
      bias = nullptr;
    }
    return Array<Tensor>{ topi::nn::dense(inputs[0], inputs[1], bias) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -110,6 +135,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
   max(input, 0)
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::relu(inputs[0], 0.0f) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -258,6 +289,14 @@ NNVM_REGISTER_OP(softmax)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
    CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
    return Array<Tensor>{ topi::nn::softmax(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -306,6 +345,14 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
    CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
    return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -357,6 +404,13 @@ NNVM_REGISTER_OP(leaky_relu)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
    return Array<Tensor>{ topi::leaky_relu<float>(inputs[0], 0.0, param.alpha) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -413,6 +467,25 @@ NNVM_REGISTER_OP(pad)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", PadInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const PadParam& param = nnvm::get<PadParam>(attrs.parsed);
    auto pad_width = param.pad_width;
    CHECK(pad_width.ndim() == inputs[0]->shape.size() &&
          pad_width[0].ndim() == 2)
      << "Illegal pad_width";
    Array<tvm::Expr> pad_before;
    for (size_t i = 0; i < pad_width.ndim(); ++i) {
      pad_before.push_back(tvm::make_const(tvm::Int(32), pad_width[i][0]));
    }
    Array<tvm::Expr> pad_after;
    for (size_t i = 0; i < pad_width.ndim(); ++i) {
      pad_after.push_back(tvm::make_const(tvm::Int(32), pad_width[i][1]));
    }
    return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after, param.pad_value) };
})
.set_support_level(1);
} // namespace top
...
@@ -6,13 +6,18 @@
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/nn.h>
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/nn/pooling.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
DMLC_REGISTER_PARAMETER(Pool2DParam);
@@ -77,6 +82,20 @@ NNVM_REGISTER_OP(max_pool2d)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
    auto pool_size = ShapeToArray(param.pool_size);
    auto strides = ShapeToArray(param.strides);
    auto padding = ShapeToArray(param.padding);
    auto ceil_mode = param.ceil_mode;
    CHECK_EQ(param.layout, kNCHW)
      << "max_pool2d currently only supports NCHW layout";
    return Array<Tensor>{
      topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kMaxPool, ceil_mode) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -124,6 +143,20 @@ NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
    auto pool_size = ShapeToArray(param.pool_size);
    auto strides = ShapeToArray(param.strides);
    auto padding = ShapeToArray(param.padding);
    auto ceil_mode = param.ceil_mode;
    CHECK_EQ(param.layout, kNCHW)
      << "avg_pool2d currently only supports NCHW layout";
    return Array<Tensor>{
      topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kAvgPool, ceil_mode) };
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);
@@ -162,6 +195,16 @@ NNVM_REGISTER_OP(global_max_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
    CHECK_EQ(param.layout, kNCHW)
      << "global_max_pool2d currently only supports NCHW layout";
    return Array<Tensor>{
      topi::nn::global_pool(inputs[0], topi::nn::kMaxPool) };
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);
@@ -182,6 +225,16 @@ NNVM_REGISTER_OP(global_avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
    CHECK_EQ(param.layout, kNCHW)
      << "global_avg_pool2d currently only supports NCHW layout";
    return Array<Tensor>{
      topi::nn::global_pool(inputs[0], topi::nn::kAvgPool) };
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);
...
@@ -3,15 +3,22 @@
 * \file broadcast.cc
 * \brief broadcast operator.
 */
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/broadcast.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
// broadcast_to
DMLC_REGISTER_PARAMETER(BroadcastToParam);
@@ -67,6 +74,14 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
.set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
    auto shape = ShapeToArray(param.shape);
    return Array<Tensor>{ topi::broadcast_to(inputs[0], shape) };
})
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4);
@@ -122,6 +137,13 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
  [](const NodeAttrs& attrs) {                                  \
    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};   \
  })                                                            \
.set_attr<FTVMCompute>(                                         \
  "FTVMCompute", [](const NodeAttrs& attrs,                     \
                    const Array<Tensor>& inputs,                \
                    const Array<Tensor>& out_info) {            \
    return Array<Tensor>{                                       \
      topi::name(inputs[0], inputs[1]) };                       \
  })                                                            \
.add_argument("lhs", "Tensor", "first input")                   \
.add_argument("rhs", "Tensor", "second input")
...
@@ -6,13 +6,19 @@
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include <cmath>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/broadcast.h"
#include "topi/elemwise.h"
#include "topi/tags.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
// undefined op
NNVM_REGISTER_ELEMWISE_UNARY_OP(__undef__)
.describe(R"code(undefined op.
@@ -32,6 +38,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::sigmoid(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -56,6 +68,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(tanh)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::tanh(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -80,6 +98,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(exp)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::exp(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -100,6 +124,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(log)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::log(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -120,6 +150,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sqrt)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::sqrt(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -140,6 +176,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
)code")
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::broadcast_add(inputs[0], inputs[1]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -154,6 +196,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::broadcast_sub(inputs[0], inputs[1]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -171,6 +219,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_mul)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::broadcast_mul(inputs[0], inputs[1]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -190,6 +244,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::broadcast_div(inputs[0], inputs[1]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -216,6 +276,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::negative(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -232,6 +298,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::identity(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -315,12 +387,29 @@ DMLC_REGISTER_PARAMETER(ScalarParam);
.set_attr_parser(ParamParser<ScalarParam>)                            \
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ScalarParam>)
inline Tensor binary_scalar_op(const NodeAttrs& attrs,
                               const Tensor& x,
                               std::function<Expr(Expr, Expr)> f) {
  const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
  auto scalar_val = static_cast<float>(param.scalar);
  return compute(x->shape, [&](const Array<Var>& i) {
    auto scalar_const = make_const(x->dtype, scalar_val);
    return f(x(i), scalar_const);
  }, "tensor", topi::kElementWise);
}
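// Illustrative usage (mirrors the registrations that follow): in the
// lambda, x is the tensor element and y is the scalar constant, cast to
// the tensor's dtype, so
//   binary_scalar_op(attrs, inputs[0],
//                    [](Expr x, Expr y) { return x + y; })
// builds the element-wise "tensor + scalar" compute for __add_scalar__.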
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__)
.describe(R"code(Tensor add scalar
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
      [](Expr x, Expr y) { return x + y; }) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -332,6 +421,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
      [](Expr x, Expr y) { return x - y; }) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -343,6 +439,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
      [](Expr x, Expr y) { return y - x; }) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -356,6 +459,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
      [](Expr x, Expr y) { return x * y; }) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -372,6 +482,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__div_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
      [](Expr x, Expr y) { return x / y; }) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -388,6 +505,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rdiv_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
      [](Expr x, Expr y) { return y / x; }) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -411,6 +535,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__pow_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
      [](Expr x, Expr y) { return tvm::pow(x, y); }) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -434,6 +565,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rpow_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
      [](Expr x, Expr y) { return tvm::pow(y, x); }) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
...
@@ -6,12 +6,17 @@
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/reduction.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
// reduce
DMLC_REGISTER_PARAMETER(ReduceParam);
@@ -127,6 +132,15 @@ Example::
   [ 12.  19.  27.]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    auto axis = ShapeToArray(param.axis);
    return Array<Tensor>{
      topi::sum(inputs[0], axis, param.keepdims) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -145,6 +159,15 @@ NNVM_REGISTER_REDUCE_OP(max)
.describe(R"code(Computes the max of array elements over given axes.
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    auto axis = ShapeToArray(param.axis);
    return Array<Tensor>{
      topi::max(inputs[0], axis, param.keepdims) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -168,6 +191,15 @@ NNVM_REGISTER_REDUCE_OP(min)
.describe(R"code(Computes the min of array elements over given axes.
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    auto axis = ShapeToArray(param.axis);
    return Array<Tensor>{
      topi::min(inputs[0], axis, param.keepdims) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
...
@@ -6,13 +6,19 @@
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include <cctype>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/nn/flatten.h"
#include "topi/transform.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
// flatten
inline bool FlattenInferShape(const NodeAttrs& attrs,
@@ -58,6 +64,12 @@ Example::
.set_attr<FInferShape>("FInferShape", FlattenInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Input data.")
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::nn::flatten(inputs[0]) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -144,6 +156,13 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ConcatenateParam& param = nnvm::get<ConcatenateParam>(attrs.parsed);
    return Array<Tensor>{ topi::concatenate(inputs, param.axis) };
})
.set_num_outputs(1)
.set_num_inputs(kVarg)
.set_support_level(1);
@@ -190,6 +209,13 @@ will return a new array with shape ``(2,5,3,4)``.
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(attrs.parsed);
    return Array<Tensor>{ topi::expand_dims(inputs[0], param.axis, param.num_newaxis) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
@@ -326,6 +352,22 @@ along which to split the array.
.set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
.set_num_inputs(1)
.set_num_outputs(SplitNumOutputs)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
    if (param.equal_split) {
      return Array<Tensor>{
        topi::split_sections(inputs[0], param.indices_or_sections[0], param.axis) };
    } else {
      Array<Expr> indices;
      for (auto i : param.indices_or_sections) {
        indices.push_back(tvm::make_const(tvm::Int(32), i));
      }
      return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
    }
})
.set_support_level(1);
// cast
@@ -504,6 +546,12 @@ The significance of each is explained below:
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::reshape(inputs[0], out_info[0]->shape) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -620,6 +668,14 @@ Examples::
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
    auto axis = ShapeToArray(param.axis);
    return Array<Tensor>{ topi::squeeze(inputs[0], axis) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
@@ -695,6 +751,14 @@ Examples::
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
    auto axes = ShapeToArray(param.axes);
    return Array<Tensor>{ topi::transpose(inputs[0], axes) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
...