Commit c829bd86 by Tianqi Chen

[PASS] PrintGraphIR, SimplifyBatchNormInference (#19)

parent 948f6898
......@@ -128,6 +128,33 @@ class Node {
static NodePtr Create();
};
/*!
* \brief Quick utilities make node.
* \param op_name The name of operator
* \param node_name The name of the node
* \param inputs The input entries
* \param attrs The attributes
* \return The created node entry.
*/
inline NodeEntry MakeNode(
const char* op_name,
std::string node_name,
std::vector<NodeEntry> inputs,
std::unordered_map<std::string, std::string> attrs =
std::unordered_map<std::string, std::string>()) {
NodePtr p = Node::Create();
p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = std::move(node_name);
if (attrs.size() != 0) {
p->attrs.dict = attrs;
if (p->attrs.op->attr_parser) {
p->attrs.op->attr_parser(&(p->attrs));
}
}
p->inputs = std::move(inputs);
return NodeEntry{p, 0, 0};
}
// implementation of functions.
inline const Op* Node::op() const {
return this->attrs.op;
......
......@@ -83,6 +83,5 @@ def set_layout_inputs(g, layout):
g._set_json_attr("layout_inputs", list_shape, 'list_str')
return g
_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
_move_out_module = tvm.get_global_func("nnvm.graph._move_module")
_move_out_graph = tvm.get_global_func("nnvm.graph._move_graph")
......@@ -7,6 +7,7 @@ Principle:
"""
from __future__ import absolute_import as _abs
import tvm
from . import graph_attr
......@@ -60,3 +61,26 @@ def infer_dtype(graph, **dtype):
output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
for x in index.output_entries]
return input_dtype, output_dtype
_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
def check_graph_equal(grapha, graphb):
"""Check if two graphs have equal structure.
Parameters
----------
grapha : Graph
The first graph
graphb : Graph
The second graph
Raises
------
ValueError
ValueError is raised with error message when graph not equal
"""
err = _deep_compare(grapha, graphb)
if err:
raise ValueError("Graph compare error: " + err)
......@@ -177,6 +177,10 @@ class Graph(object):
self._index = GraphIndex(self)
return self._index
def graphir(self):
"""Get text form of graph ir."""
return self.apply("PrintGraphIR").json_attr("graphir")
def apply(self, passes):
"""Apply passes to the graph
......
......@@ -10,7 +10,7 @@ from ..compiler import OpPattern
# relu
@reg.register_compute("relu")
def compute_relu(attrs, inputs):
def compute_relu(_, inputs):
"""Compute definition of relu"""
return topi.nn.relu(inputs[0])
......@@ -72,8 +72,7 @@ def schedule_conv2d(attrs, outs, target):
if target == "cuda":
if groups == 1:
return topi.cuda.schedule_conv2d_nchw(outs)
else:
return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
......
/*!
* Copyright (c) 2017 by Contributors
* \file graph_deep_compare.cc
* \brief Deep compare two graph structure
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include "./node_attr.h"
namespace nnvm {
namespace compiler {
// deep compare the graph structure
// not considering the graph attributes
// return non-empty error message if the graph mismatch.
// the comparator won't match name of intermediate node.
std::string DeepCompare(Graph a, Graph b) {
const IndexedGraph& idxa = a.indexed_graph();
const IndexedGraph& idxb = b.indexed_graph();
std::ostringstream err;
if (idxa.num_nodes() != idxb.num_nodes()) {
err << "Number of nodes mismatch";
return err.str();
}
if (idxa.num_node_entries() != idxb.num_node_entries()) {
err << "Number of node entry mismatch";
return err.str();
}
if (idxa.outputs().size() != idxb.outputs().size()) {
err << "Number of outputs mismatch";
return err.str();
}
for (size_t i = 0; i < idxa.outputs().size(); ++i) {
if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
idxa.outputs()[i].index != idxb.outputs()[i].index) {
err << "Output entry mismatch";
return err.str();
}
}
if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
err << "Number of inputs mismatch";
return err.str();
}
for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
const IndexedGraph::Node& anode = idxa[nid];
const IndexedGraph::Node& bnode = idxb[nid];
if (anode.source->op() != bnode.source->op()) {
err << "Node mismatch ";
return err.str();
}
AttrDict adict = GetAttrDict(anode.source->attrs);
AttrDict bdict = GetAttrDict(bnode.source->attrs);
auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
for (const auto& kv : adict) {
auto it = bdict.find(kv.first);
if (it != bdict.end()) {
if (it->second != kv.second) {
err << "Node attr mismatch, op=" << anode.source->attrs.name
<< " attr_key=" << kv.first << " " << it->second
<< " v.s. " << kv.second;
return false;
}
} else {
err << "One attr_key=" << kv.first << " is missing in another "
<< "op=" << anode.source->attrs.name;
return false;
}
}
return true;
};
if (!fmatch(adict, bdict)) return err.str();
if (adict.size() != bdict.size()) {
CHECK(!fmatch(bdict, adict));
return err.str();
}
if (anode.inputs.size() != bnode.inputs.size()) {
err << "Node input mismatch, op=" << anode.source->attrs.name;
return err.str();
}
if (anode.control_deps.size() != bnode.control_deps.size()) {
err << "Node control_deps mistach, op=" << anode.source->attrs.name;
return err.str();
}
for (size_t i = 0; i < anode.inputs.size(); ++i) {
const IndexedGraph::NodeEntry& ae = anode.inputs[i];
const IndexedGraph::NodeEntry& be = bnode.inputs[i];
if (ae.node_id != be.node_id ||
ae.index != be.index ||
ae.version != be.version) {
err << "Node input mismatch on, op=" << anode.source->attrs.name;
return err.str();
}
}
for (size_t i = 0; i < anode.control_deps.size(); ++i) {
if (anode.control_deps[i] != bnode.control_deps[i]) {
err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
return err.str();
}
}
}
return "";
}
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = DeepCompare(args[0], args[1]);
});
} // namespace compiler
} // namespace nnvm
......@@ -13,7 +13,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/operation.h>
#include <tvm/lowered_func.h>
#include "../../runtime/graph_executor.h"
#include "../runtime/graph_executor.h"
namespace nnvm {
namespace compiler {
......
/*!
* Copyright (c) 2017 by Contributors
* \file graph_transform.h
* \brief A mutator class that does local pattern matching and mutates a node.
*/
#ifndef NNVM_COMPILER_GRAPH_TRANSFORM_H_
#define NNVM_COMPILER_GRAPH_TRANSFORM_H_
#include <nnvm/graph.h>
#include <vector>
namespace nnvm {
namespace compiler {
/*!
* \brief Transform the graph to build a new Graph, in post DFS order.
*
* Automatically copies node when some of its children or control_deps changed.
* This function won't be called in Variable.
*
* \param graph The original graph
*
* \param ftransform Function of (int nid, const Node* node, std::vector<NodeEntry>* out) -> bool
*
* If empty vector is returned, it means original entries should be kept.
*
* \tparam FTransform The transformation function.
*/
template<typename FTransform>
Graph GraphTransform(Graph graph, FTransform ftransform) {
const IndexedGraph& idx = graph.indexed_graph();
// new nodes
std::vector<NodeEntry> new_entry_map(idx.num_node_entries());
std::vector<bool> updated(idx.num_node_entries(), false);
// setup inputs and placeholder.
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
bool need_copy = false;
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (updated[idx.entry_id(e)]) {
need_copy = true; break;
}
}
if (!need_copy) {
for (const uint32_t cid : inode.control_deps) {
const auto& cnode = idx[cid];
for (uint32_t i = 0 ; i < cnode.source->num_outputs(); ++i) {
if (updated[idx.entry_id(cid, i)]) {
need_copy = true;
}
}
if (need_copy) break;
}
}
if (!need_copy) {
std::vector<NodeEntry> ret;
if (ftransform(nid, inode.source, &ret)) {
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
updated[idx.entry_id(nid, i)] = true;
new_entry_map[idx.entry_id(nid, i)] = ret[i];
}
}
} else {
NodePtr node = Node::Create();
node->attrs = inode.source->attrs;
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const IndexedGraph::NodeEntry& e = inode.inputs[i];
if (updated[idx.entry_id(e)]) {
node->inputs.push_back(new_entry_map[idx.entry_id(e)]);
} else {
node->inputs.push_back(inode.source->inputs[i]);
}
}
for (size_t i = 0; i < inode.control_deps.size(); ++i) {
const uint32_t cid = inode.control_deps[i];
const auto& cnode = idx[cid];
CHECK_NE(cnode.source->num_outputs(), 0U);
NodePtr selected_ptr;
for (uint32_t j = 0 ; j < cnode.source->num_outputs(); ++j) {
NodePtr cptr = updated[idx.entry_id(cid, j)] ?
new_entry_map[idx.entry_id(cid, j)].node : inode.source->control_deps[i];
if (selected_ptr == nullptr) {
selected_ptr = std::move(cptr);
} else {
CHECK(selected_ptr.get() == cptr.get())
<< "Control dependency node changed to more than one node";
}
}
node->control_deps.push_back(selected_ptr);
}
std::vector<NodeEntry> ret;
if (ftransform(nid, node.get(), &ret)) {
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
updated[idx.entry_id(nid, i)] = true;
new_entry_map[idx.entry_id(nid, i)] = ret[i];
}
} else {
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
updated[idx.entry_id(nid, i)] = true;
new_entry_map[idx.entry_id(nid, i)] = NodeEntry{node, i, 0};
}
}
}
}
Graph ret;
for (size_t i = 0; i < idx.outputs().size(); ++i) {
const IndexedGraph::NodeEntry& e = idx.outputs()[i];
if (updated[idx.entry_id(e)]) {
ret.outputs.push_back(new_entry_map[idx.entry_id(e)]);
} else {
ret.outputs.push_back(graph.outputs[i]);
}
}
return ret;
}
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_GRAPH_TRANSFORM_H_
/*!
* Copyright (c) 2017 by Contributors
* \file node_attr.h
* \brief utility to access node attributes
*/
#ifndef NNVM_COMPILER_NODE_ATTR_H_
#define NNVM_COMPILER_NODE_ATTR_H_
#include <nnvm/op.h>
#include <nnvm/compiler/op_attr_types.h>
#include <unordered_map>
#include <string>
namespace nnvm {
namespace compiler {
using AttrDict = std::unordered_map<std::string, std::string>;
/*!
* \brief Get canonicalized attr dict from node
* \param attrs The node attrs
* \return The attribute dict
*/
inline AttrDict GetAttrDict(const NodeAttrs& attrs) {
static auto& fgetdict = nnvm::Op::GetAttr<FGetAttrDict>("FGetAttrDict");
if (fgetdict.count(attrs.op)) {
return fgetdict[attrs.op](attrs);
} else {
return attrs.dict;
}
}
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_NODE_ATTR_H_
......@@ -8,6 +8,7 @@
#include <nnvm/op.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <nnvm/compiler/op_attr_types.h>
#include "./node_attr.h"
namespace tvm {
namespace runtime {
......@@ -19,7 +20,6 @@ TVM_REGISTER_EXT_TYPE(nnvm::compiler::AttrDict);
} // namespace runtime
} // namespace tvm
namespace nnvm {
namespace compiler {
......@@ -58,17 +58,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys")
});
// custom version of TVM compute
inline std::unordered_map<std::string, std::string>
GetAttrDict(const NodeAttrs& attrs) {
static auto& fgetdict = nnvm::Op::GetAttr<FGetAttrDict>("FGetAttrDict");
if (fgetdict.count(attrs.op)) {
return fgetdict[attrs.op](attrs);
} else {
return attrs.dict;
}
}
TVM_REGISTER_GLOBAL("nnvm._register_compute")
.set_body([](TVMArgs args, TVMRetValue *rv) {
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
......@@ -105,14 +94,14 @@ TVM_REGISTER_GLOBAL("nnvm._register_pattern")
op.set_attr<TOpPattern>("TOpPattern", args[1].operator int(), args[2]);
});
TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_module")
TVM_REGISTER_GLOBAL("nnvm.graph._move_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const nnvm::Graph& g = args[0].AsExtension<Graph>();
*rv = const_cast<nnvm::Graph*>(&g)->
MoveCopyAttr<tvm::runtime::Module>(args[1]);
});
TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_graph")
TVM_REGISTER_GLOBAL("nnvm.graph._move_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const nnvm::Graph& g = args[0].AsExtension<Graph>();
*rv = const_cast<nnvm::Graph*>(&g)->
......
/*!
* Copyright (c) 2017 by Contributors
* \file simplify_batch_norm.cc
* \author Ziheng Jiang
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./graph_transform.h"
namespace nnvm {
namespace compiler {
std::vector<NodeEntry>
BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
nnvm::NodeEntry data,
nnvm::NodeEntry gamma,
nnvm::NodeEntry beta,
nnvm::NodeEntry moving_mean,
nnvm::NodeEntry moving_var,
int data_dim) {
CHECK(attrs.op);
static const Op* bn_op = Op::Get("batch_norm");
CHECK(attrs.op == bn_op);
const auto& param = nnvm::get<top::BatchNormParam>(attrs.parsed);
std::string bn_name = attrs.name;
// transform batch_norm(data) to scale * data + shift
NodeEntry var_add_eps = MakeNode(
"__add_scalar__", bn_name + "_add_eps",
{moving_var}, {{"scalar", std::to_string(param.epsilon)}});
NodeEntry sqrt = MakeNode(
"sqrt", bn_name + "_sqrt", {var_add_eps});
NodeEntry scale = MakeNode(
"__rdiv_scalar__", bn_name + "_div",
{sqrt}, {{"scalar", "1"}});
if (param.scale) {
scale = MakeNode(
"elemwise_mul", bn_name + "_gamma_mul_div",
{scale, gamma});
}
NodeEntry neg_mean = MakeNode(
"negative", bn_name + "_neg_mean", {moving_mean});
NodeEntry shift = MakeNode(
"elemwise_mul", bn_name + "_neg_mean_mul_a",
{neg_mean, scale});
if (param.center) {
shift = MakeNode(
"elemwise_add", bn_name + "_add_beta", {shift, beta});
}
// reshape to nhwc
std::ostringstream oshape;
oshape << "(";
for (int i = 0; i < data_dim; ++i) {
if (i != 0) oshape << ", ";
if (i == param.axis) {
oshape << "-1";
} else {
oshape << "1";
}
}
oshape << ")";
scale = MakeNode("reshape", bn_name + "_sc_reshape",
{scale}, {{"shape", oshape.str()}});
shift = MakeNode("reshape", bn_name + "_sh_reshape",
{shift}, {{"shape", oshape.str()}});
NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
{data, scale});
out = MakeNode("broadcast_add", bn_name + "_out",
{out, shift});
// It is invalid to ref the other values of BN after infernece transform.
NodeEntry undef = MakeNode("__undef__", "undef", {});
return {out, undef, undef};
}
Graph SimplifyBatchNormInference(nnvm::Graph src) {
// Get attributes from the graph
const IndexedGraph& idx = src.indexed_graph();
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) {
if (n->is_variable()) return false;
static const Op* bn_op = Op::Get("batch_norm");
if (n->op() == bn_op) {
*ret = BatchNormToInferUnpack(
n->attrs,
n->inputs[0],
n->inputs[1],
n->inputs[2],
n->inputs[3],
n->inputs[4],
shape_vec[idx.entry_id(nid, 0)].ndim());
return true;
} else {
return false;
}
};
return GraphTransform(src, transform);
}
NNVM_REGISTER_PASS(SimplifyBatchNormInference)
.set_body(SimplifyBatchNormInference);
} // namespace compiler
} // namespace nnvm
/*!
* Copyright (c) 2017 by Contributors
* \file print_graph_ir.cc
* \brief Print the graph IR in LLVM style human readable format.
*/
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <iostream>
namespace nnvm {
namespace pass {
// print the graph ir in readable format
void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
const IndexedGraph& idx = src.indexed_graph();
os << "Graph(";
if (idx.input_nodes().size() < 4) {
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
uint32_t nid = idx.input_nodes()[i];
if (i != 0) {
os << ", ";
}
os << '%' << idx[nid].source->attrs.name;
}
} else {
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
uint32_t nid = idx.input_nodes()[i];
if (i != 0) {
os << ",\n ";
}
os << '%' << idx[nid].source->attrs.name;
}
}
os << ") {\n";
auto print_entry = [&](const IndexedGraph::NodeEntry& e) {
if (idx[e.node_id].source->is_variable()) {
os << '%' << idx[e.node_id].source->attrs.name;
} else if (idx[e.node_id].source->num_outputs() == 1) {
os << '%' << e.node_id;
} else {
os << '%' << e.node_id << "." << e.index;
}
};
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
os << " " << "%" << nid << " = "
<< inode.source->op()->name << "(";
bool first = true;
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (first) {
first = false;
} else {
os << ", ";
}
print_entry(e);
}
for (const auto& kv : inode.source->attrs.dict) {
if (first) {
first = false;
} else {
os << ", ";
}
os << kv.first << "=\'" << kv.second << "\'";
}
os << ")";
if (inode.control_deps.size() != 0) {
os << ", control_deps=[";
for (size_t i = 0; i < inode.control_deps.size(); ++i) {
if (i != 0) os << ", ";
uint32_t cid = inode.control_deps[i];
if (idx[cid].source->is_variable()) {
os << '%' << idx[cid].source->attrs.name;
} else {
os << '%' << cid;
}
}
os << "]";
}
os << "\n";
}
os << " ret ";
{
bool first = true;
for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
if (first) {
first = false;
} else {
os << ", ";
}
print_entry(e);
}
}
os << "\n}";
if (src.attrs.size() != 0) {
os << "\ngraph_attr_keys = [";
bool first = true;
for (const auto& kv : src.attrs) {
if (first) {
first = false;
} else {
os << ", ";
}
os << kv.first;
}
os << "]\n";
}
}
// save a graph to json
Graph PrintGraphIR(Graph src) {
std::ostringstream os;
PrintGraphIR_(src, os);
Graph ret;
ret.attrs["graphir"] = std::make_shared<any>(os.str());
return ret;
}
// register pass
NNVM_REGISTER_PASS(PrintGraphIR)
.describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]")
.set_body(PrintGraphIR);
} // namespace pass
} // namespace nnvm
......@@ -11,6 +11,7 @@
#include <nnvm/top/nn.h>
#include <string>
#include <vector>
#include <utility>
#include <algorithm>
namespace nnvm {
......
......@@ -12,6 +12,16 @@
namespace nnvm {
namespace top {
// undefined op
NNVM_REGISTER_ELEMWISE_UNARY_OP(__undef__)
.describe(R"code(undefined op.
Used to produce invalide node during optimization.
)code" NNVM_ADD_FILELINE)
.set_num_outputs(1)
.set_num_inputs(0);
// sigmoid
NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
.describe(R"code(Computes sigmoid.
......@@ -52,6 +62,16 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(log)
)code" NNVM_ADD_FILELINE)
.set_support_level(1);
// sqrt
NNVM_REGISTER_ELEMWISE_UNARY_OP(sqrt)
.describe(R"code(Returns the sqrt input array, computed element-wise.
.. math::
\sqrt(x)
)code" NNVM_ADD_FILELINE)
.set_support_level(1);
// binary ops
NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
......
"""Unittest cases for graph pass"""
import nnvm
import nnvm.compiler
from nnvm.compiler import graph_pass
from nnvm import symbol as sym
from nnvm.compiler import graph_pass, graph_attr
def test_infer_attr():
x = nnvm.symbol.Variable("x")
x = sym.Variable("x")
y = x * 2
g = nnvm.graph.create(y)
ishape, oshape = graph_pass.infer_shape(g, x=(10,20))
......@@ -13,6 +14,5 @@ def test_infer_attr():
itype, otype = graph_pass.infer_dtype(g, x="float32")
assert otype[0] == "float32"
if __name__ == "__main__":
test_infer_attr()
"""Unittest cases for simplify batch_norm"""
import nnvm
from nnvm import symbol as sym
from nnvm.compiler import graph_pass, graph_attr
def test_simplify_batchnorm():
def simple_bn(x, gamma, beta, moving_mean, moving_var,
axis=1, epsilon=1e-5, dim=2):
# expect = (x - moving_mean) / sym.sqrt(moving_var + eps) * gamma + beta
scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
shift = sym.elemwise_add(
sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
# for 2D
shape = tuple(1 if i != axis else -1 for i in range(dim))
scale = sym.reshape(scale, shape=shape)
shift = sym.reshape(shift, shape=shape)
return x * scale + shift
# Before simplify
def check(dim, axis, nstep):
eps = 0.01
x = sym.Variable("x") + 1
beta = sym.Variable("beta")
gamma = sym.Variable("gamma")
moving_var = sym.Variable("moving_var")
moving_mean = sym.Variable("moving_mean")
y1, y2 = x, x
for i in range(nstep):
y1 = sym.batch_norm(
y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, dim=dim)
g = nnvm.graph.create(y1)
g2 = nnvm.graph.create(y2)
ishape = {"x": tuple(10 for i in range(dim))}
graph_attr.set_shape_inputs(g, ishape)
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
# Some prints for debug
# print(g1.graphir())
# assert graph equals as expected
graph_pass.check_graph_equal(g1, g2)
check(2, 1, 1)
check(4, 0, 3)
if __name__ == "__main__":
test_simplify_batchnorm()
import nnvm.symbol as sym
import nnvm.graph as graph
def test_dense():
x = sym.Variable('x')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment