Commit d15477cd by Wuwei Lin Committed by Tianqi Chen

[Relay][Pass] Fold constant tuple (#2201)

parent e37dbd4e
......@@ -13,6 +13,36 @@ namespace relay {
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
class ConstantChecker : private ExprVisitor {
public:
// Check whether an expression is constant. The results are memorized.
bool Check(const Expr& expr) {
if (expr.as<ConstantNode>()) {
return true;
}
const auto it = memo_.find(expr);
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr]; // return memorized result or the default value false
}
private:
std::unordered_map<Expr, bool, NodeHash, NodeEqual> memo_;
void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef<Tuple>(n)] = result;
}
};
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
......@@ -53,7 +83,7 @@ class ConstantFolder : public ExprMutator {
if (op_stateful.get(GetRef<Op>(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (arg.as<ConstantNode>() == nullptr) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
......@@ -77,6 +107,9 @@ class ConstantFolder : public ExprMutator {
private:
// Internal interepreter.
FInterpreter executor_;
// Internal constant checker
ConstantChecker checker_;
// Convert value to expression.
Expr ValueToExpr(Value value) {
if (const auto* val = value.as<TensorValueNode>()) {
......
......@@ -76,7 +76,27 @@ def test_fold_tuple():
assert relay.ir_pass.graph_equal(zz, zexpected)
def test_fold_concat():
c_data = np.array([[1, 2, 3]]).astype("float32")
def before():
a = relay.const(c_data)
b = relay.const(c_data)
y = relay.concatenate((a, b), axis=0)
return relay.Function([], y)
def expected():
y_data = np.concatenate((c_data, c_data), axis=0)
y = relay.const(y_data)
return relay.Function([], y)
zz = relay.ir_pass.fold_constant(before())
zexpected = expected()
assert relay.ir_pass.graph_equal(zz, zexpected)
if __name__ == "__main__":
test_fold_const()
test_fold_let()
test_fold_tuple()
test_fold_concat()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment