Unverified Commit ae89afe0 by Zhi Committed by GitHub

[Fix] Add ConstantNode to IsAtomic (#5457)

* add constantnode to atomic

* Add ToANormalForm to FoldConstant
parent 5d75992d
...@@ -203,6 +203,7 @@ class ConstantFolder : public ExprMutator { ...@@ -203,6 +203,7 @@ class ConstantFolder : public ExprMutator {
// Constant evaluate a expression. // Constant evaluate a expression.
Expr ConstEvaluate(Expr expr) { Expr ConstEvaluate(Expr expr) {
std::vector<transform::Pass> passes = {transform::FuseOps(0), std::vector<transform::Pass> passes = {transform::FuseOps(0),
transform::ToANormalForm(),
transform::InferType()}; transform::InferType()};
Function func; Function func;
if (expr.as<FunctionNode>()) { if (expr.as<FunctionNode>()) {
......
...@@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass): ...@@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass):
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def test_concatenate_const():
def before():
data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
const = relay.const(data)
concat = relay.op.concatenate([const, const], axis=0)
func = relay.Function([], concat)
return func
def expected():
data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
const = relay.const(data)
func = relay.Function([], const)
return func
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_const(): def test_fold_const():
c_data = np.array([1, 2, 3]).astype("float32") c_data = np.array([1, 2, 3]).astype("float32")
t = relay.TensorType([1, 2, 3], "float32") t = relay.TensorType([1, 2, 3], "float32")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment