Commit 813a3d52 by 雾雨魔理沙 Committed by Wuwei Lin

[Relay] Feature Detection (#3238)

* init

init

lint

rename

ci

fix

add

add some doc

save

add some test

add some test

lint

lint

lint

* fix build
parent 329378cf
......@@ -140,7 +140,7 @@ class Integer : public Expr {
*/
operator int64_t() const {
CHECK(node_ != nullptr)
<< " Trying get reference a null Integer";
<< " Trying to reference a null Integer";
return (*this)->value;
}
/*! \brief type indicate the container type */
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/feature.h
* \brief Detect features used in Expr/Module.
*/
#ifndef TVM_RELAY_FEATURE_H_
#define TVM_RELAY_FEATURE_H_
#include <tvm/node/container.h>
#include <tvm/expr.h>
#include <bitset>
namespace tvm {
namespace relay {
/*! \brief Different kinds of relay feature a program might use. */
enum Feature : int {
fVar = 0,
fGlobalVar = 1,
fConstant = 2,
fTuple = 3,
fTupleGetItem = 4,
fFunction = 5,
fOp = 6,
fCall = 7,
fLet = 8,
fIf = 9,
fRefCreate = 10,
fRefRead = 11,
fRefWrite = 12,
fConstructor = 13,
fMatch = 14,
/*! \brief Whether any non-atom fragment of the program is shared, making the program a graph. */
fGraph = 15,
/*! \brief Whether there is local fixpoint in the program. */
fLetRec = 16
};
constexpr size_t feature_count = 17;
/*!
* \brief A finite set of Feature.
*/
class FeatureSet {
public:
FeatureSet(const FeatureSet&) = default;
/*! \brief A singleton set containing a single Feature. */
explicit FeatureSet(Feature ft) {
bs_.set(static_cast<size_t>(ft));
}
explicit FeatureSet(const tvm::Array<tvm::Integer>& ft) {
for (Integer i : ft) {
(*this) += Feature(static_cast<int>(i));
}
}
explicit operator Array<Integer>() const {
Array<Integer> ret;
for (size_t i = 0; i < feature_count; ++i) {
if (bs_[i]) {
ret.push_back(Integer(i));
}
}
return ret;
}
/*! \brief A set that contain all the Feature. */
static FeatureSet AllFeature() {
FeatureSet fs;
fs.bs_.flip();
return fs;
}
/*! \brief The empty set. Contain no Feature. */
static FeatureSet NoFeature() {
FeatureSet fs;
return fs;
}
template<typename T>
FeatureSet& operator+=(const T& rhs) {
bs_ |= FeatureSet(rhs).bs_;
return *this;
}
/*! \brief Set union. */
template<typename T>
FeatureSet operator+(const T& rhs) const {
FeatureSet fs(*this);
fs += rhs;
return fs;
}
template<typename T>
FeatureSet& operator-=(const T& rhs) {
bs_ &= ~(FeatureSet(rhs)).bs_;
return *this;
}
/*! \brief Set difference. */
template<typename T>
FeatureSet operator-(const T& rhs) const {
FeatureSet fs(*this);
fs -= rhs;
return fs;
}
/*!
* \brief Is this a subset of rhs?
*
* \param rhs another FeatureSet.
*
* \return true only if this is a subset of rhs.
*/
bool is_subset_of(const FeatureSet& rhs) const {
return ((*this) - rhs).bs_.none();
}
private:
std::bitset<feature_count> bs_;
FeatureSet() = default;
explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) { }
};
class Expr;
/*!
* \brief Calculate the feature of the program.
*
* \param expr The expression.
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Expr& expr);
struct Module;
/*!
* \brief Calculate the feature of the program.
*
* \param mod The module.
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Module& mod);
/*!
* \brief Calculate the feature of the program.
*
* \param expr The expression.
* \param mod The module.
*
* \return The FeatureSet.
*/
inline FeatureSet DetectFeature(const Expr& expr, const Module& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_FEATURE_H_
......@@ -116,10 +116,10 @@ class TensorTypeNode : public BaseTensorTypeNode {
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
/*! \brief possible kinds of Type */
/*! \brief Possible kinds of Type. */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
/*! \brief Template variable in shape expression. */
kShapeVar = 1,
kBaseType = 2,
kShape = 3,
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
class Feature(IntEnum):
""" The features a program might contain. """
fVar = 0
fGlobalVar = 1
fConstant = 2
fTuple = 3
fTupleGetItem = 4
fFunction = 5
fOp = 6
fCall = 7
fLet = 8
fIf = 9
fRefCreate = 10
fRefRead = 11
fRefWrite = 12
fConstructor = 13
fMatch = 14
""" Whether any non-atom fragment of the program is shared, making the program a graph. """
fGraph = 15
""" Whether there is local fixpoint in the program. """
fLetRec = 16
......@@ -25,6 +25,7 @@ from . import _make
from .expr import Expr
from .ty import Type
from .module import Module
from .feature import Feature
def post_order_visit(expr, fvisit):
......@@ -604,7 +605,6 @@ def gradient(expr, mod=None, mode='higher_order'):
raise Exception('unknown mode')
def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
......@@ -641,6 +641,7 @@ def eliminate_common_subexpr(expr, fskip=None):
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
......@@ -660,6 +661,7 @@ def partial_evaluate(expr, mod=None):
"""
return _ir_pass.partial_evaluate(expr, mod)
def unmatched_cases(match, mod=None):
"""
Finds cases that the match expression does not catch, if any.
......@@ -677,3 +679,26 @@ def unmatched_cases(match, mod=None):
Patterns that the match expression does not catch.
"""
return _ir_pass.unmatched_cases(match, mod)
def detect_feature(a, b=None):
"""
Detect the feature used in a relay program.
Parameters
----------
a : Union[tvm.relay.Expr, tvm.relay.Module]
The input expression or module.
b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]]
The input expression or module.
The two arguments cannot both be expression or module.
Returns
-------
features : Set[Feature]
Features used in the program.
"""
if isinstance(a, Module):
a, b = b, a
return set([Feature(int(x)) for x in _ir_pass.detect_feature(a, b)])
......@@ -23,8 +23,8 @@ from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module
class Prelude:
"""Contains standard definitions."""
......@@ -486,7 +486,9 @@ class Prelude:
self.compose = self.mod.get_global_var("compose")
def __init__(self, mod):
def __init__(self, mod=None):
if mod is None:
mod = Module()
self.mod = mod
self.load_prelude()
self.define_list_adt()
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file alter_op_layout.cc
* \brief Alternate the layouts of operators or replace primitive operators with
other expressions. This pass can be used for computing convolution in
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file feature.cc
* \brief Detect features used in Expr/Module
*/
#include <tvm/relay/feature.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include "pass_util.h"
namespace tvm {
namespace relay {
FeatureSet DetectFeature(const Expr& expr) {
if (!expr.defined()) {
return FeatureSet::NoFeature();
}
struct FeatureDetector : ExprVisitor {
std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
FeatureSet fs = FeatureSet::NoFeature();
void VisitExpr(const Expr& expr) final {
if (visited_.count(expr) == 0) {
ExprVisitor::VisitExpr(expr);
} else {
if (!IsAtomic(expr)) {
fs += fGraph;
}
}
}
#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \
void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
STMT \
fs += f##CONSTRUCT_NAME; \
ExprVisitor::VisitExpr_(op); \
}
#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {})
DETECT_DEFAULT_CONSTRUCT(Var)
DETECT_DEFAULT_CONSTRUCT(GlobalVar)
DETECT_DEFAULT_CONSTRUCT(Constant)
DETECT_DEFAULT_CONSTRUCT(Tuple)
DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
DETECT_DEFAULT_CONSTRUCT(Function)
DETECT_DEFAULT_CONSTRUCT(Op)
DETECT_DEFAULT_CONSTRUCT(Call)
DETECT_CONSTRUCT(Let, {
for (const Var& v : FreeVars(op->value)) {
if (op->var == v) {
fs += fLetRec;
}
}
})
DETECT_DEFAULT_CONSTRUCT(If)
DETECT_DEFAULT_CONSTRUCT(RefCreate)
DETECT_DEFAULT_CONSTRUCT(RefRead)
DETECT_DEFAULT_CONSTRUCT(RefWrite)
DETECT_DEFAULT_CONSTRUCT(Constructor)
DETECT_DEFAULT_CONSTRUCT(Match)
#undef DETECT_DEFAULT_CONSTRUCT
} fd;
fd(expr);
return fd.fs;
}
FeatureSet DetectFeature(const Module& mod) {
FeatureSet fs = FeatureSet::NoFeature();
if (mod.defined()) {
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
}
return fs;
}
Array<Integer> PyDetectFeature(const Expr& expr, const Module& mod) {
FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
return static_cast<Array<Integer>>(fs);
}
TVM_REGISTER_API("relay._ir_pass.detect_feature")
.set_body_typed(PyDetectFeature);
} // namespace relay
} // namespace tvm
......@@ -39,7 +39,8 @@
namespace tvm {
namespace relay {
/*! \brief LetList allow you to transform expression into variables, so you can copy them around.
/*!
* \brief LetList allow you to transform expression into variables, so you can copy them around.
* one can insert into the LetList by calling Push, and wrap an expression with bindings with Get.
* additionally, there is the 'With' function, which automatically call Get.
*/
......
......@@ -389,10 +389,6 @@ FInterpreter CPUInterpreter() {
return CreateInterpreter(Module(nullptr), CPUContext(), target);
}
bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}
using FuncId = int;
/*!
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -97,6 +97,17 @@ inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr&
}
}
/*!
* \brief Decide whether the expression atomic or not?
* \param e the expression
* \return
* is it atomic?
* if so, the compute cost of the expression is bounded so it can be copy without graph mode.
*/
inline bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay.ir_pass import detect_feature, gradient
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
def test_prelude():
p = Prelude()
feats = detect_feature(p.mod)
assert feats == set([
Feature.fVar,
Feature.fGlobalVar,
Feature.fConstant,
Feature.fTuple,
Feature.fTupleGetItem,
Feature.fFunction,
Feature.fOp,
Feature.fCall,
Feature.fLet,
Feature.fIf,
Feature.fConstructor,
Feature.fMatch
])
def test_ad():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x + x)
back_func = relay.ir_pass.infer_type(gradient(func))
feats = detect_feature(back_func)
assert feats == set([
Feature.fVar,
Feature.fTuple,
Feature.fTupleGetItem,
Feature.fFunction,
Feature.fOp,
Feature.fCall,
Feature.fLet,
Feature.fRefCreate,
Feature.fRefRead,
Feature.fRefWrite
])
if __name__ == '__main__':
test_prelude()
test_ad()
......@@ -17,11 +17,12 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type
from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type, detect_feature
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count
from tvm.relay.feature import Feature
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
......@@ -37,9 +38,9 @@ def test_explicit_bound():
y = op.add(x, x)
z = op.add(y, y)
f = relay.Function([], op.add(z, z))
assert not "let" in f.astext() # assert the values are implicitly bounded
assert not Feature.fLet in detect_feature(f)
anf = to_a_normal_form(f)
assert "let" in anf.astext() # assert the values are explicitly bounded
assert Feature.fLet in detect_feature(anf)
check_eval(f(), 8.0)
check_eval(anf(), 8.0)
......@@ -144,7 +145,7 @@ def test_nat_add():
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()
assert Feature.fLet in detect_feature(mod[add])
def test_let():
......@@ -173,7 +174,6 @@ if __name__ == '__main__':
test_if()
test_recursion()
test_ref()
test_add()
test_let()
test_nat_add()
test_function()
......@@ -17,8 +17,9 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal
from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal, detect_feature
from tvm.relay import op, create_executor
from tvm.relay.feature import Feature
from tvm.relay.backend.interpreter import Value, TupleValue
......@@ -56,8 +57,8 @@ def test_round_trip():
f = relay.Function([], relay.Let(x, relay.const(1), body))
g = to_graph_normal_form(f)
h = to_a_normal_form(g)
assert "let" in f.astext()
assert not "let" in g.astext()
assert Feature.fLet in detect_feature(f)
assert not Feature.fLet in detect_feature(g)
check_eval(f, [], 8.0)
check_eval(g, [], 8.0)
check_eval(h, [], 8.0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment