Commit 2b3d2e21 by Tianqi Chen

[PASS] Improve GraphFuse to include five patterns (#26)

parent 2e9b6b99
NNVM Core Primitives NNVM Core Tensor Operators
==================== ==========================
**Level 1: Basic Ops** **Level 1: Basic Operators**
This level enables fully connected multi-layer perceptrons (see the short sketch after the list below).
.. autosummary:: .. autosummary::
:nosignatures: :nosignatures:
...@@ -12,12 +13,14 @@ NNVM Core Primitives ...@@ -12,12 +13,14 @@ NNVM Core Primitives
nnvm.symbol.sigmoid nnvm.symbol.sigmoid
nnvm.symbol.exp nnvm.symbol.exp
nnvm.symbol.log nnvm.symbol.log
nnvm.symbol.sqrt
nnvm.symbol.elemwise_add nnvm.symbol.elemwise_add
nnvm.symbol.elemwise_sub nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div nnvm.symbol.elemwise_div
nnvm.symbol.flatten nnvm.symbol.flatten
nnvm.symbol.concatenate nnvm.symbol.concatenate
nnvm.symbol.expand_dims
nnvm.symbol.split nnvm.symbol.split
nnvm.symbol.dropout nnvm.symbol.dropout
nnvm.symbol.batch_norm nnvm.symbol.batch_norm
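As a quick sanity check of the Level 1 claim above, a small perceptron can be written with operators from this level. The sketch below is illustrative only: the layer sizes are made up, and it assumes nnvm.symbol.dense, relu and softmax (all Level 1 operators) behave as documented.

import nnvm.symbol as sym

# Hypothetical two-layer perceptron built from Level 1 operators only.
data = sym.Variable("data")
fc1 = sym.dense(data, units=128, name="fc1")
act1 = sym.relu(fc1, name="relu1")
fc2 = sym.dense(act1, units=10, name="fc2")
mlp = sym.softmax(fc2, name="out")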
...@@ -27,6 +30,8 @@ NNVM Core Primitives ...@@ -27,6 +30,8 @@ NNVM Core Primitives
**Level 2: Convolutions** **Level 2: Convolutions**
This level enables typical convnet models.
.. autosummary:: .. autosummary::
:nosignatures: :nosignatures:
...@@ -78,12 +83,14 @@ NNVM Core Primitives ...@@ -78,12 +83,14 @@ NNVM Core Primitives
.. autofunction:: nnvm.symbol.sigmoid .. autofunction:: nnvm.symbol.sigmoid
.. autofunction:: nnvm.symbol.exp .. autofunction:: nnvm.symbol.exp
.. autofunction:: nnvm.symbol.log .. autofunction:: nnvm.symbol.log
.. autofunction:: nnvm.symbol.sqrt
.. autofunction:: nnvm.symbol.elemwise_add .. autofunction:: nnvm.symbol.elemwise_add
.. autofunction:: nnvm.symbol.elemwise_sub .. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul .. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div .. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.flatten .. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate .. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
.. autofunction:: nnvm.symbol.split .. autofunction:: nnvm.symbol.split
.. autofunction:: nnvm.symbol.dropout .. autofunction:: nnvm.symbol.dropout
.. autofunction:: nnvm.symbol.batch_norm .. autofunction:: nnvm.symbol.batch_norm
......
...@@ -25,16 +25,23 @@ using ::tvm::Tensor; ...@@ -25,16 +25,23 @@ using ::tvm::Tensor;
using ::tvm::Schedule; using ::tvm::Schedule;
/*! \brief operator pattern used in graph fusion */ /*! \brief operator pattern used in graph fusion */
enum OpPatternKind : int { enum OpPatternKind {
// Elementwise operation // Elementwise operation
kElemWise = 0, kElemWise = 0,
// Broadcast operation // Broadcasting operator, can always map output axes to the input axes in order,
// for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
// Note that the axes need to be in order, so transpose is not a bcast operator.
kBroadcast = 1, kBroadcast = 1,
// Complex operation, can fuse bcast in input/outputs // Injective operator, can always injectively map an output axis to a single input axis.
// All injective operators can still be safely fused into injective and reduction ops.
kInjective = 2,
// Commutative reduction operator.
kCommReduce = 3,
// Complex operation, can still fuse elemwise operations into its output.
// but cannot chain another complex op // but cannot chain another complex op
kComplex = 2, kOutEWiseFusable = 4,
// Extern operation, cannot fuse anything. // Opaque operation, cannot fuse anything.
kExtern = 3 kOpaque = 8
}; };
/*! \brief the operator pattern */ /*! \brief the operator pattern */
......
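To make the five levels concrete, here is a hand-written mapping from a few representative operators to the patterns above. This is a plain Python sketch, not part of the commit; the assignments simply mirror the registrations further down in this diff, and "sum" stands in for any commutative reduction.

# Numeric values copied from the enum above.
ELEMWISE, BROADCAST, INJECTIVE, COMM_REDUCE, OUT_EWISE_FUSABLE, OPAQUE = 0, 1, 2, 3, 4, 8

pattern_of = {
    "exp": ELEMWISE,                 # out[i] = f(in[i])
    "broadcast_add": BROADCAST,      # output axes map to input axes in order
    "flatten": INJECTIVE,            # each output element reads exactly one input element
    "sum": COMM_REDUCE,              # commutative reduction over axes
    "conv2d": OUT_EWISE_FUSABLE,     # elemwise ops may fuse into its output
    "softmax": OPAQUE,               # never fused
}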
...@@ -3,12 +3,24 @@ ...@@ -3,12 +3,24 @@
import tvm import tvm
class OpPattern(object): class OpPattern(object):
ELEM_WISE = 0 """Operator generic patterns
See Also
--------
top.tag : Contains explanation of the tag type.
"""
# Elementwise operator
ELEMWISE = 0
# Broadcast operator
BROADCAST = 1 BROADCAST = 1
# Complex means we can fuse elemwise to it # Injective mapping
COMPLEX = 2 INJECTIVE = 2
# Extern means the op is not fusable # Commutative reduction
EXTERN = 3 COMM_REDUCE = 3
# Complex op, can still fuse ewise into it
OUT_ELEMWISE_FUSABLE = 4
# Not fusable opaque op
OPAQUE = 8
_register_compute = tvm.get_global_func("nnvm._register_compute") _register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule") _register_schedule = tvm.get_global_func("nnvm._register_schedule")
......
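Note that the numeric ordering of these constants is load-bearing: the fusion pass compares patterns with <=, so everything at or below INJECTIVE is treated as freely absorbable, while OPAQUE is deliberately set well above the rest. A tiny illustration (assuming only that OpPattern is importable from nnvm.compiler, as the registrations later in this diff do):

from nnvm.compiler import OpPattern

def freely_fusable(pattern):
    # Mirrors the `pt <= kInjective` checks in GraphFusePartition below.
    return pattern <= OpPattern.INJECTIVE

assert freely_fusable(OpPattern.ELEMWISE)
assert freely_fusable(OpPattern.BROADCAST)
assert not freely_fusable(OpPattern.COMM_REDUCE)
assert not freely_fusable(OpPattern.OPAQUE)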
...@@ -16,8 +16,16 @@ def compute_relu(attrs, inputs, _): ...@@ -16,8 +16,16 @@ def compute_relu(attrs, inputs, _):
return topi.nn.relu(inputs[0]) return topi.nn.relu(inputs[0])
reg.register_schedule("relu", _fschedule_broadcast) reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEM_WISE) reg.register_pattern("relu", OpPattern.ELEMWISE)
# leaky_relu
@reg.register_compute("leaky_relu")
def compute_leaky_relu(attrs, inputs, _):
"""Compute definition of leaky_relu"""
return topi.nn.leaky_relu(inputs[0])
reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
# flatten # flatten
@reg.register_compute("flatten") @reg.register_compute("flatten")
...@@ -26,7 +34,7 @@ def compute_flatten(attrs, inputs, _): ...@@ -26,7 +34,7 @@ def compute_flatten(attrs, inputs, _):
return topi.nn.flatten(inputs[0]) return topi.nn.flatten(inputs[0])
reg.register_schedule("flatten", _fschedule_broadcast) reg.register_schedule("flatten", _fschedule_broadcast)
reg.register_pattern("flatten", OpPattern.COMPLEX) reg.register_pattern("flatten", OpPattern.INJECTIVE)
# softmax # softmax
...@@ -46,7 +54,7 @@ def schedule_softmax(_, outs, target): ...@@ -46,7 +54,7 @@ def schedule_softmax(_, outs, target):
return tvm.create_schedule([x.op for x in outs]) return tvm.create_schedule([x.op for x in outs])
# Mark softmax as extern as we do not fuse it in all cases # Mark softmax as extern as we do not fuse it in all cases
reg.register_pattern("softmax", OpPattern.EXTERN) reg.register_pattern("softmax", OpPattern.OPAQUE)
# dense # dense
...@@ -67,7 +75,7 @@ def schedule_dense(_, outs, target): ...@@ -67,7 +75,7 @@ def schedule_dense(_, outs, target):
return tvm.create_schedule([x.op for x in outs]) return tvm.create_schedule([x.op for x in outs])
# register extern for now, change me when fusion is enabled. # register extern for now, change me when fusion is enabled.
reg.register_pattern("dense", OpPattern.EXTERN) reg.register_pattern("dense", OpPattern.OPAQUE)
# conv # conv
...@@ -105,4 +113,4 @@ def schedule_conv2d(attrs, outs, target): ...@@ -105,4 +113,4 @@ def schedule_conv2d(attrs, outs, target):
# naive schedule # naive schedule
return tvm.create_schedule([x.op for x in outs]) return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("conv2d", OpPattern.COMPLEX) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
...@@ -8,13 +8,15 @@ import topi.cuda ...@@ -8,13 +8,15 @@ import topi.cuda
from ..compiler import registry as reg from ..compiler import registry as reg
from ..compiler import OpPattern from ..compiler import OpPattern
def _schedule_broadcast(_, outs, target): def _schedule_injective(_, outs, target):
"""Generic schedule for binary bcast""" """Generic schedule for binary bcast"""
if target == "cuda": if target == "cuda":
return topi.cuda.schedule_elemwise(outs) return topi.cuda.schedule_injective(outs)
assert target.startswith("llvm") assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
x = outs[0]
tvm.schedule.AutoInlineInjective(s) tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s return s
def _compute_binary_scalar(f): def _compute_binary_scalar(f):
...@@ -42,89 +44,91 @@ def _compute_binary(f): ...@@ -42,89 +44,91 @@ def _compute_binary(f):
return _compute return _compute
_fschedule_broadcast = tvm.convert(_schedule_broadcast) _fschedule_injective = tvm.convert(_schedule_injective)
_fschedule_broadcast = _fschedule_injective
_fschedule_elemwise = _fschedule_injective
# copy # copy
reg.register_compute("copy", _compute_unary(topi.identity)) reg.register_compute("copy", _compute_unary(topi.identity))
reg.register_pattern("copy", OpPattern.ELEM_WISE) reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast) reg.register_schedule("copy", _fschedule_broadcast)
# exp # exp
reg.register_compute("exp", _compute_unary(topi.exp)) reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEM_WISE) reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast) reg.register_schedule("exp", _fschedule_broadcast)
# sqrt # sqrt
reg.register_compute("sqrt", _compute_unary(topi.sqrt)) reg.register_compute("sqrt", _compute_unary(topi.sqrt))
reg.register_pattern("sqrt", OpPattern.ELEM_WISE) reg.register_pattern("sqrt", OpPattern.ELEMWISE)
reg.register_schedule("sqrt", _fschedule_broadcast) reg.register_schedule("sqrt", _fschedule_broadcast)
# log # log
reg.register_compute("log", _compute_unary(topi.log)) reg.register_compute("log", _compute_unary(topi.log))
reg.register_pattern("log", OpPattern.ELEM_WISE) reg.register_pattern("log", OpPattern.ELEMWISE)
reg.register_schedule("log", _fschedule_broadcast) reg.register_schedule("log", _fschedule_broadcast)
# tanh # tanh
reg.register_compute("tanh", _compute_unary(topi.tanh)) reg.register_compute("tanh", _compute_unary(topi.tanh))
reg.register_pattern("tanh", OpPattern.ELEM_WISE) reg.register_pattern("tanh", OpPattern.ELEMWISE)
reg.register_schedule("tanh", _fschedule_broadcast) reg.register_schedule("tanh", _fschedule_broadcast)
# negative # negative
reg.register_compute("negative", _compute_unary(topi.negative)) reg.register_compute("negative", _compute_unary(topi.negative))
reg.register_pattern("negative", OpPattern.ELEM_WISE) reg.register_pattern("negative", OpPattern.ELEMWISE)
reg.register_schedule("negative", _fschedule_broadcast) reg.register_schedule("negative", _fschedule_broadcast)
# sigmoid # sigmoid
reg.register_compute("sigmoid", _compute_unary(topi.sigmoid)) reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
reg.register_pattern("sigmoid", OpPattern.ELEM_WISE) reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
reg.register_schedule("sigmoid", _fschedule_broadcast) reg.register_schedule("sigmoid", _fschedule_broadcast)
# add_scalar # add_scalar
reg.register_compute("__add_scalar__", reg.register_compute("__add_scalar__",
_compute_binary_scalar(lambda x, y: x + y)) _compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast) reg.register_schedule("__add_scalar__", _fschedule_broadcast)
# sub_scalar # sub_scalar
reg.register_compute("__sub_scalar__", reg.register_compute("__sub_scalar__",
_compute_binary_scalar(lambda x, y: x - y)) _compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast) reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
# rsub_scalar # rsub_scalar
reg.register_compute("__rsub_scalar__", reg.register_compute("__rsub_scalar__",
_compute_binary_scalar(lambda x, y: y - x)) _compute_binary_scalar(lambda x, y: y - x))
reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast) reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
# mul_scalar # mul_scalar
reg.register_compute("__mul_scalar__", reg.register_compute("__mul_scalar__",
_compute_binary_scalar(lambda x, y: x * y)) _compute_binary_scalar(lambda x, y: x * y))
reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__mul_scalar__", _fschedule_broadcast) reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
# div_scalar # div_scalar
reg.register_compute("__div_scalar__", reg.register_compute("__div_scalar__",
_compute_binary_scalar(lambda x, y: x / y)) _compute_binary_scalar(lambda x, y: x / y))
reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__div_scalar__", _fschedule_broadcast) reg.register_schedule("__div_scalar__", _fschedule_broadcast)
# rdiv_scalar # rdiv_scalar
reg.register_compute("__rdiv_scalar__", reg.register_compute("__rdiv_scalar__",
_compute_binary_scalar(lambda x, y: y / x)) _compute_binary_scalar(lambda x, y: y / x))
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast) reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
# pow_scalar # pow_scalar
reg.register_compute("__pow_scalar__", reg.register_compute("__pow_scalar__",
_compute_binary_scalar(tvm.power)) _compute_binary_scalar(tvm.power))
reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__pow_scalar__", _fschedule_broadcast) reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
# rpow_scalar # rpow_scalar
reg.register_compute("__rpow_scalar__", reg.register_compute("__rpow_scalar__",
_compute_binary_scalar(lambda x, y: tvm.power(y, x))) _compute_binary_scalar(lambda x, y: tvm.power(y, x)))
reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast) reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
# elemwise_add # elemwise_add
......
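For intuition about the _schedule_injective helper at the top of this file, the snippet below applies the same three steps (create_schedule, AutoInlineInjective, fuse the output axes) by hand to a chain of two elementwise topi stages on the LLVM target. It is a standalone sketch using public tvm/topi calls rather than the private helper itself; the fuse call follows the form used in the diff above.

import tvm
import topi

# Two chained elementwise stages; both are injective, so the first gets inlined.
x = tvm.placeholder((8, 4), name="x")
y = topi.exp(x)
z = topi.sqrt(y)

s = tvm.create_schedule([z.op])
tvm.schedule.AutoInlineInjective(s)   # inline exp into sqrt
s[z].fuse(s[z].op.axis)               # collapse the output loops, as _schedule_injective does
print(tvm.lower(s, [x, z], simple_mode=True))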
...@@ -37,5 +37,5 @@ def compute_reshape(attrs, inputs, out_info): ...@@ -37,5 +37,5 @@ def compute_reshape(attrs, inputs, out_info):
oshape = out_info[0].shape oshape = out_info[0].shape
x = inputs[0] x = inputs[0]
return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape))) return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
reg.register_pattern("reshape", OpPattern.COMPLEX) reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_broadcast) reg.register_schedule("reshape", _fschedule_broadcast)
...@@ -71,7 +71,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { ...@@ -71,7 +71,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
ref_count[e.node_id] += 2; ref_count[e.node_id] += 2;
} }
// Pattern for the subgraph // Pattern for the subgraph
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern); std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kOpaque);
// Whether node can be fused to parent. // Whether node can be fused to parent.
std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown); std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
// Master node id of fusion segment. // Master node id of fusion segment.
...@@ -84,19 +84,21 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { ...@@ -84,19 +84,21 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
if (inode.source->is_variable()) { if (inode.source->is_variable()) {
fuse_vec[nid] = FuseRule::kRealize; continue; fuse_vec[nid] = FuseRule::kRealize; continue;
} }
TOpPattern pt = op_pattern.get(inode.source->op(), kExtern); TOpPattern pt = op_pattern.get(inode.source->op(), kOpaque);
if (pt <= kBroadcast) { if (pt <= kBroadcast) {
// Try to check if we can fuse to the master.
int chosen_master = -1; int chosen_master = -1;
bool ewise = inode.source->num_outputs() == 1; bool ewise = inode.source->num_outputs() == 1;
for (const auto& e : inode.inputs) { for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) { if (fuse_vec[e.node_id] == FuseRule::kUknown) {
TOpPattern ipt = pattern_vec[e.node_id]; TOpPattern ipt = pattern_vec[e.node_id];
if (ipt != kElemWise) ewise = false; if (ipt != kElemWise) ewise = false;
if (ipt <= kBroadcast) { if (ipt <= kInjective) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster; fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else if (ipt == kComplex && chosen_master == -1 && } else if (ipt == kOutEWiseFusable &&
shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) { chosen_master == -1 &&
shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
chosen_master = master_vec[e.node_id]; chosen_master = master_vec[e.node_id];
fuse_vec[e.node_id] = FuseRule::kFuseToMaster; fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else { } else {
...@@ -111,11 +113,27 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { ...@@ -111,11 +113,27 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
} }
master_vec[nid] = chosen_master; master_vec[nid] = chosen_master;
if (chosen_master != -1) { if (chosen_master != -1) {
pt = kComplex; pt = kOutEWiseFusable;
} else { } else {
pt = ewise ? kElemWise : kBroadcast; pt = ewise ? kElemWise : kBroadcast;
} }
} else if (pt == kInjective || pt == kCommReduce) {
// fuse inputs into this comm-reduce or injective node
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
TOpPattern ipt = pattern_vec[e.node_id];
if (ipt <= kInjective) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else {
fuse_vec[e.node_id] = FuseRule::kRealize;
}
}
}
if (pt == kCommReduce) {
master_vec[nid] = nid;
}
} else { } else {
// realize
master_vec[nid] = nid; master_vec[nid] = nid;
for (const auto& e : inode.inputs) { for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) { if (fuse_vec[e.node_id] == FuseRule::kUknown) {
...@@ -136,7 +154,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { ...@@ -136,7 +154,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
} }
} }
// point to the group root id of each node // point to the group root id of each node
std::vector<int> group_vec(idx.num_nodes(), -1); std::vector<int> group_vec(idx.num_nodes(), -1);
for (uint32_t i = idx.num_nodes(); i != 0; --i) { for (uint32_t i = idx.num_nodes(); i != 0; --i) {
......
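Summarising the per-input decisions added above (and ignoring the single-master bookkeeping via chosen_master), the rules reduce to roughly the following plain-Python restatement. It is an explanatory sketch, not code from the commit.

# Pattern values as in OpPatternKind.
ELEMWISE, BROADCAST, INJECTIVE, COMM_REDUCE, OUT_EWISE_FUSABLE, OPAQUE = 0, 1, 2, 3, 4, 8

def fuse_input_to_master(node_pattern, input_pattern, shapes_match):
    """Decide whether one input edge gets FuseRule::kFuseToMaster."""
    if node_pattern <= BROADCAST:
        # elemwise/broadcast nodes absorb anything up to injective, and may
        # additionally attach to one out-ewise-fusable producer of the same shape
        return input_pattern <= INJECTIVE or (
            input_pattern == OUT_EWISE_FUSABLE and shapes_match)
    if node_pattern in (INJECTIVE, COMM_REDUCE):
        # injective and comm-reduce nodes absorb injective-or-simpler inputs
        return input_pattern <= INJECTIVE
    # out-ewise-fusable and opaque nodes realize all of their inputs
    return False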
...@@ -52,7 +52,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { ...@@ -52,7 +52,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
// use op pattern to decide whether an op is map // use op pattern to decide whether an op is map
auto is_map_op = [&](size_t nid) { auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern); TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
bool is_map = (pt <= kBroadcast); bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) { if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) { for (const auto& e : idx[nid].inputs) {
......
import nnvm
import numpy as np
import tvm
import topi
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr
from nnvm.testing.config import test_ctx_list
def test_ewise_injective():
    x = sym.Variable("x")
    y = x * 2
    y = sym.flatten(y) + 1
    dshape = (10, 2, 3)
    shape_dict = {"x": dshape}
    dtype = "float32"
    for target, ctx in test_ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        assert graph.index.num_nodes == 2
        m = nnvm.runtime.create(graph, lib, ctx)
        x_np = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty((10, 6)))
        np.testing.assert_allclose(
            out.asnumpy(), x_np.reshape(out.shape) * 2 + 1,
            atol=1e-5, rtol=1e-5)
def test_conv_ewise_injective():
    x = sym.Variable("x")
    y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
                   name="y", padding=(1, 1))
    y = sym.flatten(y + 1) + 1
    dtype = "float32"
    dshape = (1, 32, 18, 18)
    kshape = (32, 1, 3, 3)
    oshape = (1, 32 * 18 * 18)
    shape_dict = {"x": dshape}
    for target, ctx in test_ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        # print(graph.ir(join_entry_attrs=["shape"]))
        assert graph.index.num_nodes == 5
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
        m.run(x=data, y_weight=kernel, y_bias=bias)
        # get output
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        c_np = topi.testing.depthwise_conv2d_python_nchw(
            data.asnumpy(), kernel.asnumpy(), (1, 1), 'SAME')
        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1) + 1
        c_np = c_np.reshape(c_np.shape[0], np.prod(c_np.shape[1:])) + 1
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


if __name__ == "__main__":
    test_ewise_injective()
    test_conv_ewise_injective()