Commit 5884cd01 by alex-weaver Committed by Tianqi Chen

Change TOPI ops to use C++ implementation where applicable (#357)

* Updated TVM version. Implemented fix for nnvm_compiler crash on exit on windows. Changed TOPI ops from using python to using C++ where applicable.

* Fix lint

* Fix lint

* Fix macro

* Fix reshape

* Update TVM to fix test fails
parent 3f6423aa
......@@ -22,6 +22,7 @@ include_directories(BEFORE "include")
include_directories("tvm/include")
include_directories("tvm/dlpack/include")
include_directories("tvm/HalideIR/src")
include_directories("tvm/topi/include")
set(NNVM_LINKER_LIBS "")
set(NNVM_COMPILER_LINKER_LIBS "")
......
......@@ -11,7 +11,7 @@ include $(config)
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src
CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src -Itvm/topi/include
ifdef DMLC_CORE_PATH
CFLAGS += -I$(DMLC_CORE_PATH)/include
......
/*!
* Copyright (c) 2016 by Contributors
* \file util.h
* \brief Utility functions for nnvm compiler
*/
#ifndef NNVM_COMPILER_UTIL_H_
#define NNVM_COMPILER_UTIL_H_
#include <tvm/expr.h>
#include <nnvm/tuple.h>
namespace nnvm {
namespace compiler {
/*
* \brief Helper function to convert TShape to TVM array. Useful for
* passing data from NNVM param structures to TOPI ops.
*
* \param shape The shape to convert
*
* \return An Array of Expr, where each element is a constant int32
*/
inline tvm::Array<tvm::Expr> ShapeToArray(TShape shape) {
tvm::Array<tvm::Expr> result;
for (auto i : shape) {
result.push_back(tvm::make_const(tvm::Int(32), i));
}
return result;
}
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_UTIL_H_
......@@ -28,7 +28,7 @@ namespace nnvm {
* symbol is the final operation of a graph and thus including all the information
* required (the graph) to evaluate its output value.
*/
class Symbol {
class NNVM_DLL Symbol {
public:
/*! \brief option passed to ListAttr */
enum ListAttrOption {
......
......@@ -10,59 +10,26 @@ from . import registry as reg
from .registry import OpPattern
# relu
@reg.register_compute("relu")
def compute_relu(attrs, inputs, _):
"""Compute definition of relu"""
return topi.nn.relu(inputs[0])
reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEMWISE)
# leaky_relu
@reg.register_compute("leaky_relu")
def compute_leaky_relu(attrs, inputs, _):
"""Compute definition of relu"""
return topi.nn.leaky_relu(inputs[0], attrs.get_float("alpha"))
reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
# flatten
@reg.register_compute("flatten")
def compute_flatten(attrs, inputs, _):
"""Compute definition of flatten"""
return topi.nn.flatten(inputs[0])
reg.register_schedule("flatten", _fschedule_broadcast)
reg.register_pattern("flatten", OpPattern.INJECTIVE)
# pad
@reg.register_compute("pad")
def compute_pad(attrs, inputs, _):
"""Compute definition of pad"""
pad_width = attrs.get_int_pair_tuple('pad_width')
assert len(pad_width) == len(inputs[0].shape) and \
len(pad_width[0]) == 2, "illegal pad_width"
pad_before = [x[0] for x in pad_width]
pad_after = [x[1] for x in pad_width]
pad_value = attrs.get_int('pad_value')
return topi.nn.pad(inputs[0], pad_before, pad_after, pad_value)
reg.register_schedule("pad", _fschedule_broadcast)
reg.register_pattern("pad", OpPattern.INJECTIVE)
# softmax
@reg.register_compute("softmax")
def compute_softmax(attrs, inputs, _):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.softmax(inputs[0])
@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
"""Schedule definition of softmax"""
......@@ -73,13 +40,6 @@ reg.register_pattern("softmax", OpPattern.OPAQUE)
# log softmax
@reg.register_compute("log_softmax")
def compute_log_softmax(attrs, inputs, _):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.log_softmax(inputs[0])
@reg.register_schedule("log_softmax")
def schedule_log_softmax(_, outs, target):
"""Schedule definition of softmax"""
......@@ -91,13 +51,6 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE)
# dense
@reg.register_compute("dense")
def compute_dense(attrs, inputs, _):
"""Compute definition of dense"""
if attrs.get_bool("use_bias"):
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
return topi.nn.dense(inputs[0], inputs[1])
@reg.register_schedule("dense")
def schedule_dense(_, outs, target):
"""Schedule definition of dense"""
......@@ -175,18 +128,6 @@ reg.register_pattern("conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
# max_pool2d
@reg.register_compute("max_pool2d")
def compute_max_pool2d(attrs, inputs, _):
"""Compute definition of max_pool2d"""
pool_size = attrs.get_int_tuple("pool_size")
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
return topi.nn.pool(inputs[0], pool_size, strides, padding,
pool_type='max', ceil_mode=ceil_mode)
@reg.register_schedule("max_pool2d")
def schedule_max_pool2d(_, outs, target):
"""Schedule definition of max_pool2d"""
......@@ -197,18 +138,6 @@ reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# avg_pool2d
@reg.register_compute("avg_pool2d")
def compute_avg_pool2d(attrs, inputs, _):
"""Compute definition of avg_pool2d"""
pool_size = attrs.get_int_tuple("pool_size")
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
return topi.nn.pool(inputs[0], pool_size, strides, padding,
pool_type='avg', ceil_mode=ceil_mode)
@reg.register_schedule("avg_pool2d")
def schedule_avg_pool2d(_, outs, target):
"""Schedule definition of avg_pool2d"""
......@@ -219,13 +148,6 @@ reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# global_max_pool2d
@reg.register_compute("global_max_pool2d")
def compute_global_max_pool2d(attrs, inputs, _):
"""Compute definition of global_max_pool2d"""
layout = attrs["layout"]
assert layout == "NCHW", "only support nchw for now"
return topi.nn.global_pool(inputs[0], pool_type='max')
@reg.register_schedule("global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
"""Schedule definition of global_max_pool2d"""
......@@ -236,13 +158,6 @@ reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# global_avg_pool2d
@reg.register_compute("global_avg_pool2d")
def compute_global_avg_pool2d(attrs, inputs, _):
"""Compute definition of global_avg_pool2d"""
layout = attrs["layout"]
assert layout == "NCHW", "only support nchw for now"
return topi.nn.global_pool(inputs[0], pool_type='avg')
@reg.register_schedule("global_avg_pool2d")
def schedule_global_avg_pool2d(_, outs, target):
"""Schedule definition of global_avg_pool2d"""
......
......@@ -27,16 +27,13 @@ def _compute_reduce(f):
return _compute
# sum
reg.register_compute("sum", _compute_reduce(topi.sum))
reg.register_pattern("sum", OpPattern.COMM_REDUCE)
reg.register_schedule("sum", _fschedule_reduce)
# max
reg.register_compute("max", _compute_reduce(topi.max))
reg.register_pattern("max", OpPattern.COMM_REDUCE)
reg.register_schedule("max", _fschedule_reduce)
# min
reg.register_compute("min", _compute_reduce(topi.min))
reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce)
......@@ -43,132 +43,97 @@ _fschedule_broadcast = _fschedule_injective
_fschedule_elemwise = _fschedule_injective
# copy
reg.register_compute("copy", _compute_unary(topi.identity))
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)
# exp
reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)
# sqrt
reg.register_compute("sqrt", _compute_unary(topi.sqrt))
reg.register_pattern("sqrt", OpPattern.ELEMWISE)
reg.register_schedule("sqrt", _fschedule_broadcast)
# log
reg.register_compute("log", _compute_unary(topi.log))
reg.register_pattern("log", OpPattern.ELEMWISE)
reg.register_schedule("log", _fschedule_broadcast)
# tanh
reg.register_compute("tanh", _compute_unary(topi.tanh))
reg.register_pattern("tanh", OpPattern.ELEMWISE)
reg.register_schedule("tanh", _fschedule_broadcast)
# negative
reg.register_compute("negative", _compute_unary(topi.negative))
reg.register_pattern("negative", OpPattern.ELEMWISE)
reg.register_schedule("negative", _fschedule_broadcast)
# sigmoid
reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
reg.register_schedule("sigmoid", _fschedule_broadcast)
# add_scalar
reg.register_compute("__add_scalar__",
_compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast)
# sub_calar
reg.register_compute("__sub_scalar__",
_compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
# rsub_scalar
reg.register_compute("__rsub_scalar__",
_compute_binary_scalar(lambda x, y: y - x))
reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
# mul_scalar
reg.register_compute("__mul_scalar__",
_compute_binary_scalar(lambda x, y: x * y))
reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
# div_scalar
reg.register_compute("__div_scalar__",
_compute_binary_scalar(lambda x, y: x / y))
reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__div_scalar__", _fschedule_broadcast)
# rdiv_scalar
reg.register_compute("__rdiv_scalar__",
_compute_binary_scalar(lambda x, y: y / x))
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
# pow_scalar
reg.register_compute("__pow_scalar__",
_compute_binary_scalar(tvm.power))
reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
# rpow_scalar
reg.register_compute("__rpow_scalar__",
_compute_binary_scalar(lambda x, y: tvm.power(y, x)))
reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
# elemwise_add
reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)
# elemwise_sub
reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
reg.register_schedule("elemwise_sub", _fschedule_broadcast)
# elemwise_mul
reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
reg.register_schedule("elemwise_mul", _fschedule_broadcast)
# elemwise_div
reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
reg.register_schedule("elemwise_div", _fschedule_broadcast)
# broadcast_add
reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
reg.register_schedule("broadcast_add", _fschedule_broadcast)
# broadcast_sub
reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
reg.register_schedule("broadcast_sub", _fschedule_broadcast)
# broadcast_mul
reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
reg.register_schedule("broadcast_mul", _fschedule_broadcast)
# broadcast_div
reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
reg.register_schedule("broadcast_div", _fschedule_broadcast)
# broadcast_to
@reg.register_compute("broadcast_to")
def compute_broadcast_to(attrs, inputs, out_info):
"""Compute definition of softmax"""
return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
reg.register_schedule("broadcast_to", _fschedule_broadcast)
......@@ -2,71 +2,30 @@
"""Tensor transformation ops"""
from __future__ import absolute_import
import topi
from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg
from .registry import OpPattern
# expand_dims
@reg.register_compute("expand_dims")
def compute_expand_dims(attrs, inputs, out_info):
"""Compute definition of expand_dims"""
return topi.expand_dims(
inputs[0], attrs.get_int("axis"),
num_newaxis=attrs.get_int("num_newaxis"))
reg.register_pattern("expand_dims", OpPattern.BROADCAST)
reg.register_schedule("expand_dims", _fschedule_broadcast)
# transpose
@reg.register_compute("transpose")
def compute_transpose(attrs, inputs, out_info):
"""Compute definition of transpose"""
axes = attrs.get_int_tuple("axes")
axes = tuple(axes) if axes else None
return topi.transpose(inputs[0], axes)
reg.register_pattern("transpose", OpPattern.INJECTIVE)
reg.register_schedule("transpose", _fschedule_injective)
# reshape
@reg.register_compute("reshape")
def compute_reshape(attrs, inputs, out_info):
"""Compute definition of reshape"""
oshape = out_info[0].shape
return topi.reshape(inputs[0], oshape)
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_injective)
# reshape
@reg.register_compute("squeeze")
def compute_squeeze(attrs, inputs, out_info):
"""Compute definition of reshape"""
axis = attrs.get_int_tuple("axis")
axis = tuple(axis) if axis else None
return topi.squeeze(inputs[0], axis)
# squeeze
reg.register_pattern("squeeze", OpPattern.INJECTIVE)
reg.register_schedule("squeeze", _fschedule_injective)
# concatenate
@reg.register_compute("concatenate")
def compute_concatenate(attrs, inputs, out_info):
"""Compute definition of concatenate"""
axis = attrs.get_int("axis")
return topi.concatenate([x for x in inputs], axis=axis)
reg.register_pattern("concatenate", OpPattern.INJECTIVE)
reg.register_schedule("concatenate", _fschedule_injective)
# split
@reg.register_compute("split")
def compute_split(attrs, inputs, out_info):
"""Compute definition of split"""
x = attrs["indices_or_sections"]
if x.startswith("(") or x.startswith("["):
indices_or_sections = attrs.get_int_tuple("indices_or_sections")
else:
indices_or_sections = attrs.get_int("indices_or_sections")
return topi.split(inputs[0], indices_or_sections, axis=attrs.get_int("axis"))
reg.register_pattern("split", OpPattern.INJECTIVE)
reg.register_schedule("split", _fschedule_injective)
......@@ -344,7 +344,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
*rv = ret;
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
p->stream << "GraphFunc(name=" << op->func_name
<< ", addr=" << op << ")";
......
......@@ -80,7 +80,7 @@ GraphKey GraphKeyNode::make(Graph graph,
return GraphKey(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) {
p->stream << "GraphKeyNode("<< op << ")";
});
......
......@@ -3,17 +3,27 @@
* \file nn.cc
* \brief Property def of nn operators.
*/
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/nn/dense.h"
#include "topi/nn.h"
#include "topi/nn/softmax.h"
namespace nnvm {
namespace top {
using tvm::Tensor;
using tvm::Array;
using nnvm::compiler::FTVMCompute;
// dense
DMLC_REGISTER_PARAMETER(DenseParam);
......@@ -72,6 +82,21 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>)
.set_attr<FInferShape>("FInferShape", DenseInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
Tensor bias_val;
Tensor* bias;
const DenseParam& param = nnvm::get<DenseParam>(attrs.parsed);
if (param.use_bias) {
bias_val = inputs[2];
bias = &bias_val;
} else {
bias = nullptr;
}
return Array<Tensor>{ topi::nn::dense(inputs[0], inputs[1], bias) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -110,6 +135,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
max(input, 0)
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::relu(inputs[0], 0.0f) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -258,6 +289,14 @@ NNVM_REGISTER_OP(softmax)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
return Array<Tensor>{ topi::nn::softmax(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -306,6 +345,14 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -357,6 +404,13 @@ NNVM_REGISTER_OP(leaky_relu)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
return Array<Tensor>{ topi::leaky_relu<float>(inputs[0], 0.0, param.alpha) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -413,6 +467,25 @@ NNVM_REGISTER_OP(pad)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", PadInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const PadParam& param = nnvm::get<PadParam>(attrs.parsed);
auto pad_width = param.pad_width;
CHECK(pad_width.ndim() == inputs[0]->shape.size() &&
pad_width[0].ndim() == 2)
<< "Illegal pad_width";
Array<tvm::Expr> pad_before;
for (size_t i = 0; i < pad_width.ndim(); ++i) {
pad_before.push_back(tvm::make_const(tvm::Int(32), pad_width[i][0]));
}
Array<tvm::Expr> pad_after;
for (size_t i = 0; i < pad_width.ndim(); ++i) {
pad_after.push_back(tvm::make_const(tvm::Int(32), pad_width[i][1]));
}
return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after, param.pad_value) };
})
.set_support_level(1);
} // namespace top
......
......@@ -6,13 +6,18 @@
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/nn.h>
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/nn/pooling.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
DMLC_REGISTER_PARAMETER(Pool2DParam);
......@@ -77,6 +82,20 @@ NNVM_REGISTER_OP(max_pool2d)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
auto pool_size = ShapeToArray(param.pool_size);
auto strides = ShapeToArray(param.strides);
auto padding = ShapeToArray(param.padding);
auto ceil_mode = param.ceil_mode;
CHECK_EQ(param.layout, kNCHW)
<< "max_pool2d currently only supports NCHW layout";
return Array<Tensor>{
topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kMaxPool, ceil_mode) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -124,6 +143,20 @@ NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
auto pool_size = ShapeToArray(param.pool_size);
auto strides = ShapeToArray(param.strides);
auto padding = ShapeToArray(param.padding);
auto ceil_mode = param.ceil_mode;
CHECK_EQ(param.layout, kNCHW)
<< "avg_pool2d currently only supports NCHW layout";
return Array<Tensor>{
topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kAvgPool, ceil_mode) };
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);
......@@ -162,6 +195,16 @@ NNVM_REGISTER_OP(global_max_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
CHECK_EQ(param.layout, kNCHW)
<< "global_max_pool2d currently only supports NCHW layout";
return Array<Tensor>{
topi::nn::global_pool(inputs[0], topi::nn::kMaxPool) };
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);
......@@ -182,6 +225,16 @@ NNVM_REGISTER_OP(global_avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
CHECK_EQ(param.layout, kNCHW)
<< "global_avg_pool2d currently only supports NCHW layout";
return Array<Tensor>{
topi::nn::global_pool(inputs[0], topi::nn::kAvgPool) };
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);
......
......@@ -3,15 +3,22 @@
* \file broadcast.cc
* \brief broadcast operator.
*/
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/broadcast.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
// broadcast_to
DMLC_REGISTER_PARAMETER(BroadcastToParam);
......@@ -67,6 +74,14 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
.set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
auto shape = ShapeToArray(param.shape);
return Array<Tensor>{ topi::broadcast_to(inputs[0], shape) };
})
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4);
......@@ -122,6 +137,13 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.set_attr<FTVMCompute>( \
"FTVMCompute", [](const NodeAttrs& attrs, \
const Array<Tensor>& inputs, \
const Array<Tensor>& out_info) { \
return Array<Tensor>{ \
topi::name(inputs[0], inputs[1]) }; \
}) \
.add_argument("lhs", "Tensor", "first input") \
.add_argument("rhs", "Tensor", "second input")
......
......@@ -6,13 +6,19 @@
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include <cmath>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/broadcast.h"
#include "topi/elemwise.h"
#include "topi/tags.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
// undefined op
NNVM_REGISTER_ELEMWISE_UNARY_OP(__undef__)
.describe(R"code(undefined op.
......@@ -32,6 +38,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::sigmoid(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -56,6 +68,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(tanh)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::tanh(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -80,6 +98,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(exp)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::exp(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -100,6 +124,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(log)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::log(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -120,6 +150,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sqrt)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::sqrt(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -140,6 +176,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
)code")
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::broadcast_add(inputs[0], inputs[1]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -154,6 +196,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::broadcast_sub(inputs[0], inputs[1]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -171,6 +219,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_mul)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::broadcast_mul(inputs[0], inputs[1]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -190,6 +244,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
)code" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::broadcast_div(inputs[0], inputs[1]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -216,6 +276,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::negative(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -232,6 +298,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::identity(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -315,12 +387,29 @@ DMLC_REGISTER_PARAMETER(ScalarParam);
.set_attr_parser(ParamParser<ScalarParam>) \
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ScalarParam>)
inline Tensor binary_scalar_op(const NodeAttrs& attrs,
const Tensor& x,
std::function<Expr(Expr, Expr)> f) {
const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
auto scalar_val = static_cast<float>(param.scalar);
return compute(x->shape, [&](const Array<Var>& i) {
auto scalar_const = make_const(x->dtype, scalar_val);
return f(x(i), scalar_const);
}, "tensor", topi::kElementWise);
}
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__)
.describe(R"code(Tensor add scalar
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
[](Expr x, Expr y) { return x + y; }) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -332,6 +421,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
[](Expr x, Expr y) { return x - y; }) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -343,6 +439,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
[](Expr x, Expr y) { return y - x; }) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -356,6 +459,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
[](Expr x, Expr y) { return x * y; }) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -372,6 +482,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__div_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
[](Expr x, Expr y) { return x / y; }) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -388,6 +505,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rdiv_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
[](Expr x, Expr y) { return y / x; }) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -411,6 +535,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__pow_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
[](Expr x, Expr y) { return tvm::pow(x, y); }) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -434,6 +565,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rpow_scalar__)
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
[](Expr x, Expr y) { return tvm::pow(y, x); }) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......
......@@ -6,12 +6,17 @@
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/reduction.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
// reduce
DMLC_REGISTER_PARAMETER(ReduceParam);
......@@ -127,6 +132,15 @@ Example::
[ 12. 19. 27.]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
return Array<Tensor>{
topi::sum(inputs[0], axis, param.keepdims) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -145,6 +159,15 @@ NNVM_REGISTER_REDUCE_OP(max)
.describe(R"code(Computes the max of array elements over given axes.
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
return Array<Tensor>{
topi::max(inputs[0], axis, param.keepdims) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -168,6 +191,15 @@ NNVM_REGISTER_REDUCE_OP(min)
.describe(R"code(Computes the min of array elements over given axes.
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
return Array<Tensor>{
topi::min(inputs[0], axis, param.keepdims) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......
......@@ -6,13 +6,19 @@
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include <cctype>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/nn/flatten.h"
#include "topi/transform.h"
namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;
// flatten
inline bool FlattenInferShape(const NodeAttrs& attrs,
......@@ -58,6 +64,12 @@ Example::
.set_attr<FInferShape>("FInferShape", FlattenInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Input data.")
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::nn::flatten(inputs[0]) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -144,6 +156,13 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ConcatenateParam& param = nnvm::get<ConcatenateParam>(attrs.parsed);
return Array<Tensor>{ topi::concatenate(inputs, param.axis) };
})
.set_num_outputs(1)
.set_num_inputs(kVarg)
.set_support_level(1);
......@@ -190,6 +209,13 @@ will return a new array with shape ``(2,5,3,4)``.
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(attrs.parsed);
return Array<Tensor>{ topi::expand_dims(inputs[0], param.axis, param.num_newaxis) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......@@ -326,6 +352,22 @@ along which to split the array.
.set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
.set_num_inputs(1)
.set_num_outputs(SplitNumOutputs)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
if (param.equal_split) {
return Array<Tensor>{
topi::split_sections(inputs[0], param.indices_or_sections[0], param.axis) };
} else {
Array<Expr> indices;
for (auto i : param.indices_or_sections) {
indices.push_back(tvm::make_const(tvm::Int(32), i));
}
return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
}
})
.set_support_level(1);
// cast
......@@ -504,6 +546,12 @@ The significance of each is explained below:
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::reshape(inputs[0], out_info[0]->shape) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -620,6 +668,14 @@ Examples::
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
return Array<Tensor>{ topi::squeeze(inputs[0], axis) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -695,6 +751,14 @@ Examples::
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
auto axes = ShapeToArray(param.axes);
return Array<Tensor>{ topi::transpose(inputs[0], axes) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment