Commit e470f8ea by Wuwei Lin Committed by Tianqi Chen

[RELAY] Fix type info after mutation in simplify inference (#2093)

parent ba3ddcd7
......@@ -15,7 +15,8 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
Expr gamma,
Expr beta,
Expr moving_mean,
Expr moving_var) {
Expr moving_var,
Type tdata) {
const auto param = attrs.as<BatchNormAttrs>();
Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
Expr var_add_eps = Add(moving_var, epsilon);
......@@ -32,9 +33,11 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
}
int axis = param->axis;
const auto* tdata = data->type_as<TensorTypeNode>();
scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis});
shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis});
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
auto ndim = ttype->shape.size();
scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
shift = ExpandBiasToMatchAxis(shift, ndim, {axis});
Expr out = Multiply(data, scale);
out = Add(out, shift);
......@@ -54,14 +57,26 @@ class InferenceSimplifier : public ExprMutator {
}
if (const auto* call = new_n->tuple.as<CallNode>()) {
if (call->op.same_as(batch_norm)) {
return BatchNormToInferUnpack(call->attrs,
call->args[0], call->args[1], call->args[2], call->args[3], call->args[4]);
return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
call->args[3], call->args[4], ty_map_.at(call->args[0]));
} else if (call->op.same_as(dropout)) {
return call->args[0];
}
}
return new_e;
}
Expr VisitExpr_(const CallNode* n) {
static const Op& batch_norm = Op::Get("nn.batch_norm");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(batch_norm)) {
ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
}
return new_n;
}
private:
std::unordered_map<Expr, Type, NodeHash, NodeEqual> ty_map_;
};
Expr SimplifyInference(const Expr& e) {
......
......@@ -30,12 +30,12 @@ def test_simplify_batchnorm():
y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y1 = rly.nn.dropout(y1)
y1 = rly.ir_pass.infer_type(y1)
y1 = simplify_inference(y1)
y2 = simple_bn(y2 + rly.const(1, 'float32'),
gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, shape=ttype1.shape)
y1 = rly.ir_pass.infer_type(y1)
y1 = simplify_inference(y1)
assert rly.ir_pass.graph_equal(y1, y2)
check(2, 1, 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment