Commit 813a3d52 by 雾雨魔理沙 Committed by Wuwei Lin

[Relay] Feature Detection (#3238)

* init

init

lint

rename

ci

fix

add

add some doc

save

add some test

add some test

lint

lint

lint

* fix build
parent 329378cf
...@@ -140,7 +140,7 @@ class Integer : public Expr { ...@@ -140,7 +140,7 @@ class Integer : public Expr {
*/ */
operator int64_t() const { operator int64_t() const {
CHECK(node_ != nullptr) CHECK(node_ != nullptr)
<< " Trying get reference a null Integer"; << " Trying to reference a null Integer";
return (*this)->value; return (*this)->value;
} }
/*! \brief type indicate the container type */ /*! \brief type indicate the container type */
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/feature.h
* \brief Detect features used in Expr/Module.
*/
#ifndef TVM_RELAY_FEATURE_H_
#define TVM_RELAY_FEATURE_H_
#include <tvm/node/container.h>
#include <tvm/expr.h>
#include <bitset>
namespace tvm {
namespace relay {
/*! \brief Different kinds of relay feature a program might use. */
enum Feature : int {
fVar = 0,
fGlobalVar = 1,
fConstant = 2,
fTuple = 3,
fTupleGetItem = 4,
fFunction = 5,
fOp = 6,
fCall = 7,
fLet = 8,
fIf = 9,
fRefCreate = 10,
fRefRead = 11,
fRefWrite = 12,
fConstructor = 13,
fMatch = 14,
/*! \brief Whether any non-atom fragment of the program is shared, making the program a graph. */
fGraph = 15,
/*! \brief Whether there is local fixpoint in the program. */
fLetRec = 16
};
constexpr size_t feature_count = 17;
/*!
* \brief A finite set of Feature.
*/
class FeatureSet {
public:
FeatureSet(const FeatureSet&) = default;
/*! \brief A singleton set containing a single Feature. */
explicit FeatureSet(Feature ft) {
bs_.set(static_cast<size_t>(ft));
}
explicit FeatureSet(const tvm::Array<tvm::Integer>& ft) {
for (Integer i : ft) {
(*this) += Feature(static_cast<int>(i));
}
}
explicit operator Array<Integer>() const {
Array<Integer> ret;
for (size_t i = 0; i < feature_count; ++i) {
if (bs_[i]) {
ret.push_back(Integer(i));
}
}
return ret;
}
/*! \brief A set that contain all the Feature. */
static FeatureSet AllFeature() {
FeatureSet fs;
fs.bs_.flip();
return fs;
}
/*! \brief The empty set. Contain no Feature. */
static FeatureSet NoFeature() {
FeatureSet fs;
return fs;
}
template<typename T>
FeatureSet& operator+=(const T& rhs) {
bs_ |= FeatureSet(rhs).bs_;
return *this;
}
/*! \brief Set union. */
template<typename T>
FeatureSet operator+(const T& rhs) const {
FeatureSet fs(*this);
fs += rhs;
return fs;
}
template<typename T>
FeatureSet& operator-=(const T& rhs) {
bs_ &= ~(FeatureSet(rhs)).bs_;
return *this;
}
/*! \brief Set difference. */
template<typename T>
FeatureSet operator-(const T& rhs) const {
FeatureSet fs(*this);
fs -= rhs;
return fs;
}
/*!
* \brief Is this a subset of rhs?
*
* \param rhs another FeatureSet.
*
* \return true only if this is a subset of rhs.
*/
bool is_subset_of(const FeatureSet& rhs) const {
return ((*this) - rhs).bs_.none();
}
private:
std::bitset<feature_count> bs_;
FeatureSet() = default;
explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) { }
};
class Expr;
/*!
* \brief Calculate the feature of the program.
*
* \param expr The expression.
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Expr& expr);
struct Module;
/*!
* \brief Calculate the feature of the program.
*
* \param mod The module.
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Module& mod);
/*!
* \brief Calculate the feature of the program.
*
* \param expr The expression.
* \param mod The module.
*
* \return The FeatureSet.
*/
inline FeatureSet DetectFeature(const Expr& expr, const Module& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_FEATURE_H_
...@@ -116,10 +116,10 @@ class TensorTypeNode : public BaseTensorTypeNode { ...@@ -116,10 +116,10 @@ class TensorTypeNode : public BaseTensorTypeNode {
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
/*! \brief possible kinds of Type */ /*! \brief Possible kinds of Type. */
enum Kind : int { enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0, kType = 0,
/*! \brief Template variable in shape expression. */
kShapeVar = 1, kShapeVar = 1,
kBaseType = 2, kBaseType = 2,
kShape = 3, kShape = 3,
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
class Feature(IntEnum):
""" The features a program might contain. """
fVar = 0
fGlobalVar = 1
fConstant = 2
fTuple = 3
fTupleGetItem = 4
fFunction = 5
fOp = 6
fCall = 7
fLet = 8
fIf = 9
fRefCreate = 10
fRefRead = 11
fRefWrite = 12
fConstructor = 13
fMatch = 14
""" Whether any non-atom fragment of the program is shared, making the program a graph. """
fGraph = 15
""" Whether there is local fixpoint in the program. """
fLetRec = 16
...@@ -25,6 +25,7 @@ from . import _make ...@@ -25,6 +25,7 @@ from . import _make
from .expr import Expr from .expr import Expr
from .ty import Type from .ty import Type
from .module import Module from .module import Module
from .feature import Feature
def post_order_visit(expr, fvisit): def post_order_visit(expr, fvisit):
...@@ -604,7 +605,6 @@ def gradient(expr, mod=None, mode='higher_order'): ...@@ -604,7 +605,6 @@ def gradient(expr, mod=None, mode='higher_order'):
raise Exception('unknown mode') raise Exception('unknown mode')
def get_total_mac_number(expr): def get_total_mac_number(expr):
""" """
Count the number of MACs (multiply-accumulate) of a model Count the number of MACs (multiply-accumulate) of a model
...@@ -641,6 +641,7 @@ def eliminate_common_subexpr(expr, fskip=None): ...@@ -641,6 +641,7 @@ def eliminate_common_subexpr(expr, fskip=None):
""" """
return _ir_pass.eliminate_common_subexpr(expr, fskip) return _ir_pass.eliminate_common_subexpr(expr, fskip)
def partial_evaluate(expr, mod=None): def partial_evaluate(expr, mod=None):
""" """
Evaluate the static fragment of the code. Evaluate the static fragment of the code.
...@@ -660,6 +661,7 @@ def partial_evaluate(expr, mod=None): ...@@ -660,6 +661,7 @@ def partial_evaluate(expr, mod=None):
""" """
return _ir_pass.partial_evaluate(expr, mod) return _ir_pass.partial_evaluate(expr, mod)
def unmatched_cases(match, mod=None): def unmatched_cases(match, mod=None):
""" """
Finds cases that the match expression does not catch, if any. Finds cases that the match expression does not catch, if any.
...@@ -677,3 +679,26 @@ def unmatched_cases(match, mod=None): ...@@ -677,3 +679,26 @@ def unmatched_cases(match, mod=None):
Patterns that the match expression does not catch. Patterns that the match expression does not catch.
""" """
return _ir_pass.unmatched_cases(match, mod) return _ir_pass.unmatched_cases(match, mod)
def detect_feature(a, b=None):
"""
Detect the feature used in a relay program.
Parameters
----------
a : Union[tvm.relay.Expr, tvm.relay.Module]
The input expression or module.
b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]]
The input expression or module.
The two arguments cannot both be expression or module.
Returns
-------
features : Set[Feature]
Features used in the program.
"""
if isinstance(a, Module):
a, b = b, a
return set([Feature(int(x)) for x in _ir_pass.detect_feature(a, b)])
...@@ -23,8 +23,8 @@ from .op.tensor import add, subtract, equal ...@@ -23,8 +23,8 @@ from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard from .adt import PatternConstructor, PatternVar, PatternWildcard
from .parser import fromtext from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__)) __PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module
class Prelude: class Prelude:
"""Contains standard definitions.""" """Contains standard definitions."""
...@@ -486,7 +486,9 @@ class Prelude: ...@@ -486,7 +486,9 @@ class Prelude:
self.compose = self.mod.get_global_var("compose") self.compose = self.mod.get_global_var("compose")
def __init__(self, mod): def __init__(self, mod=None):
if mod is None:
mod = Module()
self.mod = mod self.mod = mod
self.load_prelude() self.load_prelude()
self.define_list_adt() self.define_list_adt()
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2019 by Contributors
* \file alter_op_layout.cc * \file alter_op_layout.cc
* \brief Alternate the layouts of operators or replace primitive operators with * \brief Alternate the layouts of operators or replace primitive operators with
other expressions. This pass can be used for computing convolution in other expressions. This pass can be used for computing convolution in
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file feature.cc
* \brief Detect features used in Expr/Module
*/
#include <tvm/relay/feature.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include "pass_util.h"
namespace tvm {
namespace relay {
FeatureSet DetectFeature(const Expr& expr) {
if (!expr.defined()) {
return FeatureSet::NoFeature();
}
struct FeatureDetector : ExprVisitor {
std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
FeatureSet fs = FeatureSet::NoFeature();
void VisitExpr(const Expr& expr) final {
if (visited_.count(expr) == 0) {
ExprVisitor::VisitExpr(expr);
} else {
if (!IsAtomic(expr)) {
fs += fGraph;
}
}
}
#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \
void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
STMT \
fs += f##CONSTRUCT_NAME; \
ExprVisitor::VisitExpr_(op); \
}
#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {})
DETECT_DEFAULT_CONSTRUCT(Var)
DETECT_DEFAULT_CONSTRUCT(GlobalVar)
DETECT_DEFAULT_CONSTRUCT(Constant)
DETECT_DEFAULT_CONSTRUCT(Tuple)
DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
DETECT_DEFAULT_CONSTRUCT(Function)
DETECT_DEFAULT_CONSTRUCT(Op)
DETECT_DEFAULT_CONSTRUCT(Call)
DETECT_CONSTRUCT(Let, {
for (const Var& v : FreeVars(op->value)) {
if (op->var == v) {
fs += fLetRec;
}
}
})
DETECT_DEFAULT_CONSTRUCT(If)
DETECT_DEFAULT_CONSTRUCT(RefCreate)
DETECT_DEFAULT_CONSTRUCT(RefRead)
DETECT_DEFAULT_CONSTRUCT(RefWrite)
DETECT_DEFAULT_CONSTRUCT(Constructor)
DETECT_DEFAULT_CONSTRUCT(Match)
#undef DETECT_DEFAULT_CONSTRUCT
} fd;
fd(expr);
return fd.fs;
}
FeatureSet DetectFeature(const Module& mod) {
FeatureSet fs = FeatureSet::NoFeature();
if (mod.defined()) {
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
}
return fs;
}
Array<Integer> PyDetectFeature(const Expr& expr, const Module& mod) {
FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
return static_cast<Array<Integer>>(fs);
}
TVM_REGISTER_API("relay._ir_pass.detect_feature")
.set_body_typed(PyDetectFeature);
} // namespace relay
} // namespace tvm
...@@ -39,7 +39,8 @@ ...@@ -39,7 +39,8 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
/*! \brief LetList allow you to transform expression into variables, so you can copy them around. /*!
* \brief LetList allow you to transform expression into variables, so you can copy them around.
* one can insert into the LetList by calling Push, and wrap an expression with bindings with Get. * one can insert into the LetList by calling Push, and wrap an expression with bindings with Get.
* additionally, there is the 'With' function, which automatically call Get. * additionally, there is the 'With' function, which automatically call Get.
*/ */
......
...@@ -389,10 +389,6 @@ FInterpreter CPUInterpreter() { ...@@ -389,10 +389,6 @@ FInterpreter CPUInterpreter() {
return CreateInterpreter(Module(nullptr), CPUContext(), target); return CreateInterpreter(Module(nullptr), CPUContext(), target);
} }
bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}
using FuncId = int; using FuncId = int;
/*! /*!
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -97,6 +97,17 @@ inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& ...@@ -97,6 +97,17 @@ inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr&
} }
} }
/*!
* \brief Decide whether the expression atomic or not?
* \param e the expression
* \return
* is it atomic?
* if so, the compute cost of the expression is bounded so it can be copy without graph mode.
*/
inline bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_ #endif // TVM_RELAY_PASS_PASS_UTIL_H_
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay.ir_pass import detect_feature, gradient
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
def test_prelude():
p = Prelude()
feats = detect_feature(p.mod)
assert feats == set([
Feature.fVar,
Feature.fGlobalVar,
Feature.fConstant,
Feature.fTuple,
Feature.fTupleGetItem,
Feature.fFunction,
Feature.fOp,
Feature.fCall,
Feature.fLet,
Feature.fIf,
Feature.fConstructor,
Feature.fMatch
])
def test_ad():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x + x)
back_func = relay.ir_pass.infer_type(gradient(func))
feats = detect_feature(back_func)
assert feats == set([
Feature.fVar,
Feature.fTuple,
Feature.fTupleGetItem,
Feature.fFunction,
Feature.fOp,
Feature.fCall,
Feature.fLet,
Feature.fRefCreate,
Feature.fRefRead,
Feature.fRefWrite
])
if __name__ == '__main__':
test_prelude()
test_ad()
...@@ -17,11 +17,12 @@ ...@@ -17,11 +17,12 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type, detect_feature
from tvm.relay import op, create_executor from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count from tvm.relay.testing import add_nat_definitions, count
from tvm.relay.feature import Feature
def check_eval(expr, expected_result, mod=None, rtol=1e-07): def check_eval(expr, expected_result, mod=None, rtol=1e-07):
...@@ -37,9 +38,9 @@ def test_explicit_bound(): ...@@ -37,9 +38,9 @@ def test_explicit_bound():
y = op.add(x, x) y = op.add(x, x)
z = op.add(y, y) z = op.add(y, y)
f = relay.Function([], op.add(z, z)) f = relay.Function([], op.add(z, z))
assert not "let" in f.astext() # assert the values are implicitly bounded assert not Feature.fLet in detect_feature(f)
anf = to_a_normal_form(f) anf = to_a_normal_form(f)
assert "let" in anf.astext() # assert the values are explicitly bounded assert Feature.fLet in detect_feature(anf)
check_eval(f(), 8.0) check_eval(f(), 8.0)
check_eval(anf(), 8.0) check_eval(anf(), 8.0)
...@@ -144,7 +145,7 @@ def test_nat_add(): ...@@ -144,7 +145,7 @@ def test_nat_add():
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext() assert Feature.fLet in detect_feature(mod[add])
def test_let(): def test_let():
...@@ -173,7 +174,6 @@ if __name__ == '__main__': ...@@ -173,7 +174,6 @@ if __name__ == '__main__':
test_if() test_if()
test_recursion() test_recursion()
test_ref() test_ref()
test_add()
test_let() test_let()
test_nat_add() test_nat_add()
test_function() test_function()
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal, detect_feature
from tvm.relay import op, create_executor from tvm.relay import op, create_executor
from tvm.relay.feature import Feature
from tvm.relay.backend.interpreter import Value, TupleValue from tvm.relay.backend.interpreter import Value, TupleValue
...@@ -56,8 +57,8 @@ def test_round_trip(): ...@@ -56,8 +57,8 @@ def test_round_trip():
f = relay.Function([], relay.Let(x, relay.const(1), body)) f = relay.Function([], relay.Let(x, relay.const(1), body))
g = to_graph_normal_form(f) g = to_graph_normal_form(f)
h = to_a_normal_form(g) h = to_a_normal_form(g)
assert "let" in f.astext() assert Feature.fLet in detect_feature(f)
assert not "let" in g.astext() assert not Feature.fLet in detect_feature(g)
check_eval(f, [], 8.0) check_eval(f, [], 8.0)
check_eval(g, [], 8.0) check_eval(g, [], 8.0)
check_eval(h, [], 8.0) check_eval(h, [], 8.0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment