/*!
 * Copyright (c) 2018 by Contributors
 * \file constant_folding.cc
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>

namespace tvm {
namespace relay {

/*! \brief Packed function that evaluates an expression and returns its value. */
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;

/*!
 * \brief Decides whether an expression is a compile-time constant.
 *
 * A ConstantNode is constant, and a tuple is constant when every one of
 * its fields is constant. No other expression kind records a result in
 * this class, so it is reported as non-constant.
 */
class ConstantChecker : private ExprVisitor {
 public:
  /*!
   * \brief Check whether an expression is constant. Results are memoized.
   * \param expr The expression to inspect.
   * \return true if expr is a constant or a tuple made only of constants.
   */
  bool Check(const Expr& expr) {
    if (expr.as<ConstantNode>()) {
      return true;
    }
    const auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;
    }
    VisitExpr(expr);
    // operator[] default-inserts false when the visit recorded nothing,
    // so the result is either the memoized answer or false.
    return memo_[expr];
  }

 private:
  /*! \brief Cache of already computed answers, keyed by expression. */
  std::unordered_map<Expr, bool, NodeHash, NodeEqual> memo_;

  // A tuple is constant iff all of its fields are constant.
  void VisitExpr_(const TupleNode* n) final {
    bool result = true;
    for (const auto& field : n->fields) {
      if (!Check(field)) {
        result = false;
        break;
      }
    }
    memo_[GetRef<Tuple>(n)] = result;
  }
};

// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
/*!
 * \brief Expression mutator that replaces constant sub-expressions by
 *  the values obtained from evaluating them with an interpreter.
 */
class ConstantFolder : public ExprMutator {
 public:
  /*!
   * \brief Construct a folder.
   * \param executor Function used to evaluate constant expressions.
   */
  explicit ConstantFolder(FInterpreter executor)
      : executor_(executor) {
  }

  /*!
   * \brief Fold a let binding.
   *
   * When the bound value mutates to a ConstantNode the binding is dropped:
   * the constant is stored in memo_ under the bound variable and the
   * mutated body is returned. Otherwise the let is rebuilt only if one of
   * its parts changed.
   */
  Expr VisitExpr_(const LetNode* op) final {
    Expr value = this->Mutate(op->value);
    if (value.as<ConstantNode>()) {
      // memo_ is not declared in this class; presumably it is the memo
      // table inherited from ExprMutator, so later visits of the variable
      // yield the constant -- TODO confirm.
      memo_[op->var] = value;
      return this->Mutate(op->body);
    } else {
      Var var = Downcast<Var>(this->Mutate(op->var));
      Expr body = this->Mutate(op->body);
      if (var.same_as(op->var) &&
          value.same_as(op->value) &&
          body.same_as(op->body)) {
        return GetRef<Expr>(op);
      } else {
        return LetNode::make(var, value, body);
      }
    }
  }

  /*!
   * \brief Fold a call to a stateless operator whose arguments are all
   *  constant. Any other call is returned as mutated by the base class.
   */
  Expr VisitExpr_(const CallNode* call) final {
    static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
    Expr res = ExprMutator::VisitExpr_(call);
    call = res.as<CallNode>();
    // We don't constant fold function with zero arguments.
    // This is a heuristic that is useful.
    // For example it is harmful to fold ones(shape=(4, 5)).
    if (call->args.size() == 0) return res;
    // Only calls whose callee is a primitive operator are folded.
    const OpNode* op = call->op.as<OpNode>();
    if (op == nullptr) return res;
    // skip stateful ops.
    if (op_stateful.get(GetRef<Op>(op), false)) return res;
    bool all_const_args = true;
    for (const auto& arg : call->args) {
      if (!checker_.Check(arg)) {
        all_const_args = false;
        // One non-constant argument settles the answer.
        break;
      }
    }
    if (all_const_args) {
      return ConstEvaluate(res);
    } else {
      return res;
    }
  }

  /*!
   * \brief Replace a projection out of a literal tuple by the selected
   *  field; any other projection is returned as mutated by the base class.
   */
  Expr VisitExpr_(const TupleGetItemNode* op) final {
    Expr res = ExprMutator::VisitExpr_(op);
    op = res.as<TupleGetItemNode>();
    if (const auto* tuple = op->tuple.as<TupleNode>()) {
      return tuple->fields[op->index];
    } else {
      return res;
    }
  }

 private:
  // Internal interpreter.
  FInterpreter executor_;
  // Internal constant checker
  ConstantChecker checker_;

  /*!
   * \brief Convert an interpreter value to an expression.
   *
   * Tensor values become constants and tuple values become tuples of the
   * converted fields. Any other value kind is a fatal error.
   */
  Expr ValueToExpr(Value value) {
    if (const auto* tensor_val = value.as<TensorValueNode>()) {
      return ConstantNode::make(tensor_val->data);
    } else if (const auto* tuple_val = value.as<TupleValueNode>()) {
      Array<Expr> fields;
      for (Value field : tuple_val->fields) {
        fields.push_back(ValueToExpr(field));
      }
      return TupleNode::make(fields);
    } else {
      LOG(FATAL) << "Cannot handle " << value->type_key();
      return Expr();
    }
  }

  /*!
   * \brief Constant evaluate an expression.
   *
   * The expression is type-inferred, passed through FuseOps with level 0,
   * type-inferred again, evaluated by executor_, and the resulting value
   * is converted back to an expression.
   */
  Expr ConstEvaluate(Expr expr) {
    expr = InferType(expr, Module(nullptr));
    expr = FuseOps(expr, 0);
    expr = InferType(expr, Module(nullptr));
    return ValueToExpr(executor_(expr));
  }
};

/*!
 * \brief Fold the constant sub-expressions of an expression.
 *
 * Builds an interpreter for device kDLCPU / id 0 with an "llvm" target
 * and runs ConstantFolder with it.
 *
 * \param expr The expression to fold.
 * \return The expression produced by ConstantFolder.
 */
Expr FoldConstant(const Expr& expr) {
  DLContext ctx;
  ctx.device_type = kDLCPU;
  ctx.device_id = 0;
  Target target = Target::create("llvm");
  // use a fresh build context
  // in case we are already in a build context.
  BuildConfigContext fresh_build_ctx(build_config());

  return ConstantFolder(CreateInterpreter(
      Module(nullptr), ctx, target)).Mutate(expr);
}

TVM_REGISTER_API("relay._ir_pass.FoldConstant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = FoldConstant(args[0]);
});

}  // namespace relay
}  // namespace tvm