Commit 71abe36e by Haichen Shen Committed by Jared Roesch

[Relay][OP] Fix bias_add default axis (#2829)

* Fix bias add default axis

* update

* Fix canonicalize ops for bias_add
parent 1dab4dcc
...@@ -34,7 +34,7 @@ def _mx_fully_connected(inputs, attrs): ...@@ -34,7 +34,7 @@ def _mx_fully_connected(inputs, attrs):
res = _op.nn.dense(inputs[0], inputs[1], units=units) res = _op.nn.dense(inputs[0], inputs[1], units=units)
if use_bias: if use_bias:
assert len(inputs) == 3 assert len(inputs) == 3
res = _op.nn.bias_add(res, inputs[2]) res = _op.nn.bias_add(res, inputs[2], axis=-1)
return res return res
...@@ -413,7 +413,7 @@ def _mx_batch_dot(inputs, attrs): ...@@ -413,7 +413,7 @@ def _mx_batch_dot(inputs, attrs):
raise tvm.error.OpAttributeInvalid(msg.format(transpose_a)) raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
if transpose_b is False: if transpose_b is False:
b = _op.transpose(b, axes=[0, 2, 1]) b = _op.transpose(b, axes=[0, 2, 1])
return _op.batch_matmul(a, b) return _op.nn.batch_matmul(a, b)
def _mx_arange(inputs, attrs): def _mx_arange(inputs, attrs):
......
...@@ -248,7 +248,7 @@ def get_net(batch_size, ...@@ -248,7 +248,7 @@ def get_net(batch_size,
flatten = relay.nn.batch_flatten(pool) flatten = relay.nn.batch_flatten(pool)
fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes) fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes)
fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias")) fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"), axis=-1)
inception_v3 = relay.nn.softmax(data=fc1) inception_v3 = relay.nn.softmax(data=fc1)
args = relay.ir_pass.free_vars(inception_v3) args = relay.ir_pass.free_vars(inception_v3)
return relay.Function(args, inception_v3) return relay.Function(args, inception_v3)
......
...@@ -134,5 +134,5 @@ def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs): ...@@ -134,5 +134,5 @@ def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs):
if not bias: if not bias:
bias = relay.var(name + "_bias") bias = relay.var(name + "_bias")
data = relay.nn.dense(data, weight, units, **kwargs) data = relay.nn.dense(data, weight, units, **kwargs)
data = relay.nn.bias_add(data, bias) data = relay.nn.bias_add(data, bias, axis=-1)
return data return data
...@@ -50,13 +50,13 @@ def get_net(batch_size, ...@@ -50,13 +50,13 @@ def get_net(batch_size,
dtype=dtype) dtype=dtype)
data = relay.nn.batch_flatten(data) data = relay.nn.batch_flatten(data)
fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128) fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128)
fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias")) fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"), axis=-1)
act1 = relay.nn.relu(fc1) act1 = relay.nn.relu(fc1)
fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64) fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64)
fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias")) fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"), axis=-1)
act2 = relay.nn.relu(fc2) act2 = relay.nn.relu(fc2)
fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes) fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes)
fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias")) fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"), axis=-1)
mlp = relay.nn.softmax(data=fc3) mlp = relay.nn.softmax(data=fc3)
args = relay.ir_pass.free_vars(mlp) args = relay.ir_pass.free_vars(mlp)
return relay.Function(args, mlp) return relay.Function(args, mlp)
......
...@@ -24,7 +24,11 @@ class BiasAddSimplifier : public ExprMutator { ...@@ -24,7 +24,11 @@ class BiasAddSimplifier : public ExprMutator {
auto ttype = n->args[0]->type_as<TensorTypeNode>(); auto ttype = n->args[0]->type_as<TensorTypeNode>();
size_t n_dim = ttype->shape.size(); size_t n_dim = ttype->shape.size();
Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis}); int axis = param->axis;
if (axis < 0) {
axis += n_dim;
}
Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {axis});
Expr ret = Add(call->args[0], expanded_bias); Expr ret = Add(call->args[0], expanded_bias);
ret->checked_type_ = n->checked_type_; ret->checked_type_ = n->checked_type_;
return ret; return ret;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment