Commit 71abe36e authored by Haichen Shen, committed by Jared Roesch

[Relay][OP] Fix bias_add default axis (#2829)

* Fix bias add default axis

* update

* Fix canonicalize ops for bias_add
parent 1dab4dcc
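This commit makes two changes. First, Relay call sites that follow nn.dense with nn.bias_add now pass axis=-1 explicitly instead of relying on bias_add's default axis: dense produces output of shape (batch, units), so the bias belongs on the last axis. The hunks below update the MXNet frontend and the relay.testing networks accordingly. Second, the canonicalize-ops pass is fixed to normalize a negative axis before expanding the bias. A minimal sketch of the call-site pattern, with illustrative shapes and variable names:

    from tvm import relay

    # dense: (batch, in_units) x (units, in_units) -> (batch, units).
    # The 1-D bias of shape (units,) is added along the last axis,
    # so axis=-1 is the explicit, rank-agnostic choice.
    data = relay.var("data", shape=(8, 16))
    weight = relay.var("weight", shape=(32, 16))
    bias = relay.var("bias", shape=(32,))

    out = relay.nn.dense(data, weight, units=32)
    out = relay.nn.bias_add(out, bias, axis=-1)
    func = relay.Function([data, weight, bias], out)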
@@ -34,7 +34,7 @@ def _mx_fully_connected(inputs, attrs):
     res = _op.nn.dense(inputs[0], inputs[1], units=units)
     if use_bias:
         assert len(inputs) == 3
-        res = _op.nn.bias_add(res, inputs[2])
+        res = _op.nn.bias_add(res, inputs[2], axis=-1)
     return res
@@ -413,7 +413,7 @@ def _mx_batch_dot(inputs, attrs):
         raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
     if transpose_b is False:
         b = _op.transpose(b, axes=[0, 2, 1])
-    return _op.batch_matmul(a, b)
+    return _op.nn.batch_matmul(a, b)

 def _mx_arange(inputs, attrs):
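The batch_dot hunk also fixes a namespace slip: batch_matmul lives under relay.nn, so the converter must call _op.nn.batch_matmul. Note that nn.batch_matmul treats its second operand as transposed, computing (b, m, k) x (b, n, k) -> (b, m, n), which is why b is pre-transposed above when MXNet's transpose_b is False. A small illustrative sketch:

    from tvm import relay

    # nn.batch_matmul(x, y) multiplies x by y transposed on the last
    # two axes: (b, m, k) x (b, n, k) -> (b, m, n).
    a = relay.var("a", shape=(4, 8, 16))
    b = relay.var("b", shape=(4, 32, 16))
    out = relay.nn.batch_matmul(a, b)  # result shape: (4, 8, 32)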
@@ -248,7 +248,7 @@ def get_net(batch_size,
     flatten = relay.nn.batch_flatten(pool)
     fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes)
-    fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"))
+    fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"), axis=-1)
     inception_v3 = relay.nn.softmax(data=fc1)
     args = relay.ir_pass.free_vars(inception_v3)
     return relay.Function(args, inception_v3)
@@ -134,5 +134,5 @@ def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs):
     if not bias:
         bias = relay.var(name + "_bias")
     data = relay.nn.dense(data, weight, units, **kwargs)
-    data = relay.nn.bias_add(data, bias)
+    data = relay.nn.bias_add(data, bias, axis=-1)
     return data
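For reference, a hypothetical use of this helper (assuming it lives at tvm.relay.testing.layers and that the layer name is supplied via kwargs, as the name + "_bias" line above suggests):

    from tvm import relay
    from tvm.relay.testing import layers

    data = relay.var("data", shape=(1, 784))
    # Weight and bias vars are created automatically from the name.
    fc = layers.dense_add_bias(data, units=10, name="fc")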
@@ -50,13 +50,13 @@ def get_net(batch_size,
                      dtype=dtype)
     data = relay.nn.batch_flatten(data)
     fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128)
-    fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"))
+    fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"), axis=-1)
     act1 = relay.nn.relu(fc1)
     fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64)
-    fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"))
+    fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"), axis=-1)
     act2 = relay.nn.relu(fc2)
     fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes)
-    fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"))
+    fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"), axis=-1)
     mlp = relay.nn.softmax(data=fc3)
     args = relay.ir_pass.free_vars(mlp)
     return relay.Function(args, mlp)
@@ -24,7 +24,11 @@ class BiasAddSimplifier : public ExprMutator {
     auto ttype = n->args[0]->type_as<TensorTypeNode>();
     size_t n_dim = ttype->shape.size();
-    Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis});
+    int axis = param->axis;
+    if (axis < 0) {
+      axis += n_dim;
+    }
+    Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {axis});
     Expr ret = Add(call->args[0], expanded_bias);
     ret->checked_type_ = n->checked_type_;
     return ret;
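This C++ hunk is the canonicalization fix: the pass previously forwarded param->axis straight to ExpandBiasToMatchAxis, which breaks for negative values such as the axis=-1 now used at the call sites, so the fix first wraps the axis into [0, n_dim). A minimal Python sketch of the same normalization, with a hypothetical helper name:

    def normalize_bias_axis(axis, n_dim):
        # Wrap a possibly negative axis into [0, n_dim), mirroring the
        # C++ fix; the bias is then reshaped so its single non-unit
        # dimension lines up with this axis before the broadcast add.
        if axis < 0:
            axis += n_dim
        assert 0 <= axis < n_dim
        return axis

    # With 2-D dense output (batch, units), axis=-1 normalizes to 1,
    # so a (units,) bias is expanded to (1, units) for the add.
    assert normalize_bias_axis(-1, 2) == 1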