/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file partial_eval.cc
 *
 * \brief Perform known computation at compile time.
 *
 * The partial evaluator tries to do computation at compile time,
 * so it can generate code that does less work.
 * Additionally, it might open more chances for further optimization,
 * since the high-level, structural part of the code (closures, references, control flow)
 * might get partially evaluated away, and subsequent optimizations (for example, kernel fusion)
 * can reason across that structural code once it has been removed.
 * In the extreme case, partial evaluation can even turn the whole program
 * into pure first-order computation with no control flow.
 * In such a case, we can compile the whole computation onto SIMD instructions/GPU/FPGA,
 * and get a huge speedup.
 *
 * It works by making the following modifications to the standard relay interpreter:
 *
 * 0: The values become partially static values.
 *    Since we cannot know the value of every term at compile time,
 *    a term might get partially evaluated to an 'Unknown Value'.
 *    Every partially static value is, hence,
 *    a static fragment that might not be there (partially static),
 *    and a dynamic fragment that is semantically equivalent to the original term,
 *    so the unknown part will be computed at runtime, using the dynamic fragment.
 *
 * 1: The interpreter holds a LetList, which preserves A Normal Form for the generated code.
 *    More specifically, we require that every dynamic fragment is an atom.
 *    This avoids code duplication (which is both inefficient and incorrect), as an atom
 *    has constant size, and allows us to not handle capture-avoiding substitution
 *    (as an atom has no binder).
 *
 * 2: The map of References to partially static values is reified, as described below.
 *    Instead of a Reference having a mutable field, a Reference only has a unique identifier.
 *    There is a mutable mapping from id to partially static value, called the store.
 *    This allows us to roll back the store:
 *    when a path may or may not be executed (as in a conditional), we (conceptually) copy
 *    the store, recurse with the copy, and reinstate the original when the call returns,
 *    so that the effects of the computation are not preserved.
 *    We do this in if-else, in pattern matching, and in functions:
 *    when we see a function, we partially evaluate it with all the arguments as dynamic,
 *    to generate an efficient dynamic fragment for that function.
 *
 * 3: The generated code reuses bindings (although they are not shadowed),
 *    so we have to deduplicate them.
 *
 * 4: In the generated code, as it calls TypeSubst, multiple VarNodes might have the same Id.
 *    While this is permitted, most passes use NodeHash for Var,
 *    and having multiple VarNodes for the same Id breaks them.
 *    Thus we remap them to a single Id for now.
 *
 * It will also generate lots of dead code, so it is a good idea to feed its output
 * through the dead code eliminator after partial evaluation.
 *
 * The partial evaluator makes several assumptions, so there is room for improvement:
 *
 * 0: Every time an unknown effect happens, we clear the whole store.
 *    This is too conservative: if a local reference is created (and does not get passed
 *    outside), an unknown global function call/global reference write cannot modify it.
 *    We can pair PE with escape analysis/alias analysis.
 *
 * 1: We assume all unknown code has effects. Doing effect analysis can make the store
 *    more precise.
 *
 * 2: When doing pattern matching, we can simplify the match even for the dynamic case.
 *    Right now it is all or nothing: either a complete match, or the original dynamic code.
 *    Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
 *    We then can reify the result.
 *
 * 3: Every time a function is called, its code will get expanded and partially evaluated.
 *    We can do a binding time analysis to cache the result and avoid re-partial evaluation.
 *
 * These assumptions do not affect the correctness of the algorithm, however.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include <algorithm>
#include <functional>
#include <list>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../ir/type_functor.h"
#include "pass_util.h"
#include "let_list.h"

namespace tvm {
namespace relay {
namespace partial_eval {

using namespace runtime;

/*!
 * \brief Hash a Var by its id (vid) rather than by node identity.
 *
 * Different VarNodes might share the same vid; in that case they are considered
 * to be the same variable. Use together with VarEqual when keying hash containers.
 */
struct VarHash {
  size_t operator()(const Var& v) const {
    return NodeHash()(v->vid);
  }
};

/*!
 * \brief Compare two Vars by their id (vid) rather than by node identity.
 *
 * Two distinct VarNodes compare equal iff they point to the same vid node.
 */
struct VarEqual {
  bool operator()(const Var& l, const Var& r) const {
    return l->vid.get() == r->vid.get();
  }
};

/*! \brief Clean up a residual program; forward declaration, defined later in this file. */
Expr PostProcess(const Expr&);

/*!
 * \brief A StaticNode contains some static data that the Partial Evaluator can use.
*/ class StaticNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Static"; TVM_DECLARE_BASE_NODE_INFO(StaticNode, RelayNode); }; class Static : public NodeRef { public: Static() {} explicit Static(ObjectPtr<Object> n) : NodeRef(n) {} const StaticNode* operator->() const { return static_cast<const StaticNode*>(get()); } using ContainerType = StaticNode; }; using Time = size_t; struct PStaticNode : Node { static Time time() { static Time time_ = 0; Time ret = time_; time_++; return ret; } Static pstatic; // may be null Expr dynamic; Time created_time; PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } static constexpr const char* _type_key = "relay.PStatic"; TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); }; RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef); struct STupleNode : StaticNode { std::vector<PStatic> fields; explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) { } static constexpr const char* _type_key = "relay.STuple"; TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); }; RELAY_DEFINE_NODE_REF(STuple, STupleNode, Static); Static MkSTuple(const std::vector<PStatic>& fields) { return Static(make_node<STupleNode>(fields)); } struct STensorNode : StaticNode { runtime::NDArray data; explicit STensorNode(const NDArray& data) : data(data) { } static constexpr const char* _type_key = "relay.STensor"; TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode); }; RELAY_DEFINE_NODE_REF(STensor, STensorNode, Static); Static MkSTensor(const NDArray& data) { return Static(make_node<STensorNode>(data)); } struct SConstructorNode : StaticNode { Constructor constructor; std::vector<PStatic> fields; SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) : constructor(constructor), fields(fields) { } static constexpr const char* _type_key = 
"relay.SConstructor"; TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode); }; RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Static); Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>& fields) { return Static(make_node<SConstructorNode>(constructor, fields)); } struct SRefNode : StaticNode { static constexpr const char* _type_key = "relay.SRef"; // we will use the address as the guid for hashing TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode); }; RELAY_DEFINE_NODE_REF(SRef, SRefNode, Static); Static MkSRef() { return Static(make_node<SRefNode>()); } using Func = std::function<PStatic(const PStatic&, const std::vector<PStatic>&, const Attrs&, const Array<Type>&, LetList*)>; struct SFuncNode : StaticNode { Func func; explicit SFuncNode(const Func& func) : func(func) { } static constexpr const char* _type_key = "relay.SFunc"; TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode); }; RELAY_DEFINE_NODE_REF(SFunc, SFuncNode, Static); Static MkSFunc(const Func& func) { return Static(make_node<SFuncNode>(func)); } class FuelNode; /*! \brief A meet-semilattice with finite descending chain. * It means that we can meet two element to get an element, * and for every element, there is only a finite amount of meet before getting back the same element. * * Every time we recurse, we do a meet and require that progress must be made. * This ensures we do not recurse infinitely in the Partial Evaluator. */ class Fuel : public NodeRef { public: Fuel() {} explicit Fuel(ObjectPtr<Object> n) : NodeRef(n) {} const FuelNode* operator->() const; using ContainerType = FuelNode; }; class FuelNode : public RelayNode { public: // Please implement one of the following function or there will be infinite loop. /*! \brief return the new Fuel, and whether progress is made. * * Note that progress is not symmetric - it only measure progress for (*this). 
* * Thus, if the generated is smaller then the argument of Meet, * and the generated is not smaller then (*this), * progress should be false. */ virtual std::tuple<Fuel, bool> Meet(const Fuel& f) const { bool progress = false; auto ret = Meet(f, &progress); return std::make_tuple(ret, progress); } /*! \brief return the new Fuel, and write (*progress | is progress made) to *progress. */ virtual Fuel Meet(const Fuel& f, bool* progress) const { CHECK(progress); auto ret = Meet(f); *progress |= std::get<1>(ret); return std::get<0>(ret); } static constexpr const char* _type_key = "relay.Fuel"; TVM_DECLARE_BASE_NODE_INFO(FuelNode, RelayNode); }; const FuelNode* Fuel::operator->() const { return static_cast<const FuelNode*>(get()); } Fuel MkFSeq(const std::vector<Fuel>& fuels); struct FSeqNode : FuelNode { std::vector<Fuel> fuels; Fuel Meet(const Fuel& f, bool* progress) const final { auto x = f.as<FSeqNode>(); CHECK(x); CHECK_EQ(fuels.size(), x->fuels.size()); std::vector<Fuel> new_fuels; for (size_t i = 0; i < fuels.size(); ++i) { new_fuels.push_back(fuels[i]->Meet(x->fuels[i], progress)); } return MkFSeq(new_fuels); } explicit FSeqNode(const std::vector<Fuel>& fuels) : fuels(fuels) { } static constexpr const char* _type_key = "relay.FSeq"; TVM_DECLARE_NODE_TYPE_INFO(FSeqNode, FuelNode); }; RELAY_DEFINE_NODE_REF(FSeq, FSeqNode, Fuel); Fuel MkFSeq(const std::vector<Fuel>& fuels) { return Fuel(make_node<FSeqNode>(fuels)); } Fuel MkFTime(Time time); struct FTimeNode : FuelNode { Time time; std::tuple<Fuel, bool> Meet(const Fuel& f) const final { auto x = f.as<FTimeNode>(); CHECK(x); Time new_time = std::min(time, x->time); return std::make_tuple(MkFTime(new_time), new_time < time); } explicit FTimeNode(Time time) : time(time) { } static constexpr const char* _type_key = "relay.FTime"; TVM_DECLARE_NODE_TYPE_INFO(FTimeNode, FuelNode); }; RELAY_DEFINE_NODE_REF(FTime, FTimeNode, Fuel); Fuel MkFTime(Time time) { return Fuel(make_node<FTimeNode>(time)); } Fuel MkFTValue(size_t 
tvalue); /*! \brief If the pstatic is hold a positive integer scalar, that number, else 0. */ struct FTValueNode : FuelNode { size_t tvalue; std::tuple<Fuel, bool> Meet(const Fuel& f) const final { auto x = f.as<FTValueNode>(); CHECK(x); size_t new_tvalue = std::min(tvalue, x->tvalue); return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue); } explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { } static constexpr const char* _type_key = "relay.FTValue"; TVM_DECLARE_NODE_TYPE_INFO(FTValueNode, FuelNode); }; RELAY_DEFINE_NODE_REF(FTValue, FTValueNode, Fuel); Fuel MkFTValue(size_t tvalue) { return Fuel(make_node<FTValueNode>(tvalue)); } /*! \brief Initially every element has Fuel of FTop. It is the largest element. * * Note that it is illegal to has FTop inside some other Fuel - * doing so break the finite descending chain property. */ struct FTopNode : FuelNode { std::tuple<Fuel, bool> Meet(const Fuel& f) const final { return std::make_tuple(f, !f.as<FTopNode>()); } static constexpr const char* _type_key = "relay.FTop"; TVM_DECLARE_NODE_TYPE_INFO(FTopNode, FuelNode); }; RELAY_DEFINE_NODE_REF(FTop, FTopNode, Fuel); Fuel MkFTop() { return Fuel(make_node<FTopNode>()); } /*! * \brief A stack frame in the Relay interpreter. * * Contains a mapping from relay::Var to relay::Value. */ struct Frame { /*! \brief The set of local variables and arguments for the frame. 
*/ std::unordered_map<Var, PStatic, VarHash, VarEqual> locals; Frame() = default; }; class Environment { public: Environment() : env_({Frame()}) { } Environment(const Environment&) = delete; template<typename T> T Extend(const std::function<T()>& body) { FrameContext fc(this); return body(); } void Insert(const Var& v, const PStatic& ps) { CHECK(ps.defined()); CHECK_GT(env_.size(), 0); CHECK_EQ(env_.back().locals.count(v), 0); env_.back().locals[v] = ps; } PStatic Lookup(const Var& v) { auto rit = env_.rbegin(); while (rit != env_.rend()) { if (rit->locals.find(v) != rit->locals.end()) { return rit->locals.find(v)->second; } ++rit; } LOG(FATAL) << "Unknown Variable: " << v; throw; } private: std::list<Frame> env_; struct FrameContext { Environment* env_; explicit FrameContext(Environment* env) : env_(env) { env_->env_.push_back(Frame()); } ~FrameContext() { env_->env_.pop_back(); } }; }; /*! * \brief As our store require rollback, we implement it as a frame. * * Every time we need to copy the store, a new frame is insert. * Every time we roll back, a frame is popped. */ struct StoreFrame { std::unordered_map<const SRefNode*, PStatic> store; /*! * \brief On unknown effect, history_valid is set to true to signal above frame is outdated. * * It only outdate the frame above it, but not the current frame. 
*/ bool history_valid = true; explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { } StoreFrame() = default; }; class Store { public: Store() : store_({StoreFrame()}) { } Store(const Store&) = delete; template<typename T> T Extend(const std::function<T()>& body) { StoreFrameContext sfc(this); return body(); } void Insert(const SRefNode* r, const PStatic& ps) { CHECK(r); store_.back().store[r] = ps; } // return null if not found PStatic Lookup(const SRefNode* r) { auto rit = store_.rbegin(); while (rit != store_.rend()) { if (rit->store.find(r) != rit->store.end()) { return rit->store.find(r)->second; } if (!rit->history_valid) { return PStatic(); } ++rit; } return PStatic(); } void Invalidate() { StoreFrame sf; sf.history_valid = false; store_.push_back(sf); } private: std::list<StoreFrame> store_; struct StoreFrameContext { Store* store_; explicit StoreFrameContext(Store* store) : store_(store) { store_->store_.push_back(StoreFrame()); } ~StoreFrameContext() { // push one history valid frame off. 
while (!store_->store_.back().history_valid) { store_->store_.pop_back(); } store_->store_.pop_back(); } }; }; PStatic HasStatic(const Static& stat, const Expr& dynamic) { CHECK(stat.defined()); return PStatic(make_node<PStaticNode>(stat, dynamic)); } PStatic NoStatic(const Expr& dynamic) { return PStatic(make_node<PStaticNode>(dynamic)); } enum struct MatchStatus { Match, NoMatch, Unknown }; bool StatefulOp(const Expr& e) { static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful"); struct StatefulOpVisitor : ExprVisitor { bool stateful = false; void VisitExpr_(const OpNode* op) { stateful = stateful || op_stateful.get(GetRef<Op>(op), false); } }; StatefulOpVisitor sov; sov(e); return sov.stateful; } using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>; DLContext CPUContext() { DLContext ctx; ctx.device_type = kDLCPU; ctx.device_id = 0; return ctx; } FInterpreter CPUInterpreter() { Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. With<BuildConfig> fresh_build_ctx(BuildConfig::Create()); return CreateInterpreter(Module(nullptr), CPUContext(), target); } using FuncId = int; /*! * \brief Annotate a function with a FuncId. 
*/ struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> { FuncId fid; TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") { TVM_ATTR_FIELD(fid) .describe("The FuncId that an function is annotated with.") .set_default(-1); } }; TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); Op WithFuncIdOp() { static const Op& op = Op::Get("annotation.with_funcid"); return op; } Expr MkWithFuncId(const Expr& expr, FuncId fid) { auto attrs = make_node<WithFuncIdAttrs>(); attrs->fid = fid; return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {}); } RELAY_REGISTER_OP("annotation.with_funcid") .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("func", "Function", "The input data."); Expr StripWithFuncId(const Expr& e); Function AsFunc(const Expr& e) { if (e.as<FunctionNode>()) { return Downcast<Function>(e); } else if (const CallNode* c = e.as<CallNode>()) { CHECK(c->op.same_as(WithFuncIdOp())); CHECK_EQ(c->args.size(), 1); return AsFunc(c->args[0]); } else { LOG(FATAL) << "Unknown case"; throw; } } class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>, public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> { public: PartialEvaluator(const Module& mod) : mod_(mod) { } PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll); CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; return ret; } PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) { if (const CallNode* c = e.as<CallNode>()) { if (c->op.same_as(WithFuncIdOp())) { CHECK_EQ(c->args.size(), 1); return VisitExpr(c->args[0], ll, name); } } PStatic ret = e.as<FunctionNode>() ? 
VisitFunc(Downcast<Function>(e), ll, name) : VisitExpr(e, ll); CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; return ret; } PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op))); } PStatic VisitExpr_(const TupleNode* op, LetList* ll) final { std::vector<PStatic> value; tvm::Array<Expr> expr; for (const Expr& e : op->fields) { PStatic ps = VisitExpr(e, ll); value.push_back(ps); expr.push_back(ps->dynamic); } return HasStatic(MkSTuple(value), ll->Push(TupleNode::make(expr))); } PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final { PStatic ps = VisitExpr(op->tuple, ll); if (ps->pstatic.defined()) { return Downcast<STuple>(ps->pstatic)->fields[op->index]; } else { return NoStatic(ll->Push(TupleGetItemNode::make(ps->dynamic, op->index))); } } PStatic VisitExpr_(const VarNode* op, LetList* ll) final { return env_.Lookup(GetRef<Var>(op)); } PStatic VisitGlobalVar(const GlobalVar& gv) { CHECK(mod_.defined()); if (gv_map_.count(gv) == 0) { Function func = mod_->Lookup(gv); InitializeFuncId(func); Func f = VisitFuncStatic(func, gv); gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv))); mod_->Update(gv, func); } return gv_map_.at(gv); } PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { return VisitGlobalVar(GetRef<GlobalVar>(op)); } PStatic VisitExpr_(const LetNode* op, LetList* ll) final { env_.Insert(op->var, VisitExpr(op->value, ll, op->var)); return VisitExpr(op->body, ll); } PStatic VisitExpr_(const IfNode* op, LetList* ll) final { PStatic c = VisitExpr(op->cond, ll); if (c->pstatic.defined()) { NDArray cpu_array = Downcast<STensor>(c->pstatic)->data.CopyTo(CPUContext()); CHECK_EQ(DataType(cpu_array->dtype), DataType::Bool()); if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) { return VisitExpr(op->true_branch, ll); } else { return VisitExpr(op->false_branch, ll); } } else { Expr t = 
store_.Extend<Expr>([&]() { return LetList::With([&](LetList* ll) { return VisitExpr(op->true_branch, ll)->dynamic; }); }); Expr f = store_.Extend<Expr>([&]() { return LetList::With([&](LetList* ll) { return VisitExpr(op->false_branch, ll)->dynamic; }); }); store_.Invalidate(); return NoStatic(ll->Push(IfNode::make(c->dynamic, t, f))); } } PStatic VisitExpr_(const RefCreateNode* op, LetList* ll) final { PStatic ps = VisitExpr(op->value, ll); Static r = MkSRef(); store_.Insert(r.as<SRefNode>(), ps); return HasStatic(r, ll->Push(RefCreateNode::make(ps->dynamic))); } PStatic VisitExpr_(const RefWriteNode* op, LetList* ll) final { PStatic r = VisitExpr(op->ref, ll); PStatic v = VisitExpr(op->value, ll); if (r->pstatic.defined()) { store_.Insert(r->pstatic.as<SRefNode>(), v); } else { store_.Invalidate(); } return HasStatic(MkSTuple({}), ll->Push(RefWriteNode::make(r->dynamic, v->dynamic))); } PStatic VisitExpr_(const RefReadNode* op, LetList* ll) final { PStatic r = VisitExpr(op->ref, ll); if (r->pstatic.defined()) { PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>()); if (ret) { return ret; } } return NoStatic(ll->Push(RefReadNode::make(r->dynamic))); } PStatic VisitExpr_(const CallNode* op, LetList* ll) final { if (op->op.same_as(WithFuncIdOp())) { CHECK_EQ(op->args.size(), 1); return VisitExpr(op->args[0], ll); } PStatic f = VisitExpr(op->op, ll); std::vector<PStatic> x; tvm::Array<Expr> x_dyn; for (const Expr& e : op->args) { PStatic ps = VisitExpr(e, ll); x.push_back(ps); x_dyn.push_back(ps->dynamic); } if (f->pstatic.defined()) { return Downcast<SFunc>(f->pstatic)->func(f, x, op->attrs, op->type_args, ll); } else { store_.Invalidate(); return NoStatic(ll->Push(CallNode::make(f->dynamic, x_dyn, op->attrs, op->type_args))); } } struct FuelFrame { PartialEvaluator* pe_; FuncId fid_; Fuel old_fuel; FuelFrame(PartialEvaluator* pe, FuncId fid, const Fuel& new_fuel) : pe_(pe), fid_(fid) { CHECK_GT(pe_->fuel_map_.count(fid_), 0); old_fuel = pe_->fuel_map_[fid_]; 
pe_->fuel_map_[fid_] = new_fuel; } ~FuelFrame() { pe_->fuel_map_[fid_] = old_fuel; } }; size_t GetFTValue(const PStatic& ps) { if (ps->pstatic.defined()) { if (auto* st = ps->pstatic.as<STensorNode>()) { if (st->data.Shape().empty()) { NDArray cpu_array = st->data.CopyTo(CPUContext()); DataType dtype = DataType(cpu_array->dtype); if (dtype == DataType::Int(32)) { return std::max<int32_t>(0, *static_cast<const int32_t*>(cpu_array->data)); } else if (dtype == DataType::Int(64)) { return std::max<int64_t>(0, *static_cast<const int64_t*>(cpu_array->data)); } } } } return 0; } Fuel GetFuel(const PStatic& ps) { std::vector<Fuel> fuels; fuels.push_back(MkFTime(ps->created_time)); fuels.push_back(MkFTValue(GetFTValue(ps))); return MkFSeq(fuels); } Func VisitFuncStatic(const Function& func, const Expr& var) { CHECK(IsAtomic(var)); if (func->IsPrimitive()) { return ConstEvaluateFunc(func); } std::vector<std::pair<Var, PStatic> > free_vars; for (const auto& v : FreeVars(func)) { if (v != var) { free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v))); } } return [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs, const tvm::Array<Type>& type_args, LetList* ll) { return env_.Extend<PStatic>([&]() { CHECK_EQ(pv.size(), func->params.size()); CHECK_GT(func_map_.count(func), 0); FuncId fid = func_map_.at(func); if (fuel_map_.count(fid) == 0) { fuel_map_.insert({fid, MkFTop()}); } std::vector<Fuel> args_fuel; for (const auto& v : pv) { args_fuel.push_back(GetFuel(v)); } auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); if (std::get<1>(meet_res)) { FuelFrame tf(this, fid, std::get<0>(meet_res)); Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func))); Function func = AsFunc(dedup_func); if (var.as<VarNode>()) { env_.Insert(Downcast<Var>(var), self); } for (size_t i = 0; i < pv.size(); ++i) { env_.Insert(func->params[i], pv[i]); } for (const auto& p : free_vars) { env_.Insert(p.first, p.second); } tvm::Map<TypeVar, Type> subst; for (size_t 
i = 0; i < type_args.size(); ++i) { subst.Set(func->type_params[i], type_args[i]); } for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { subst.Set(func->type_params[i], IncompleteTypeNode::make(kType)); } return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); } else { std::vector<Expr> dyn; for (const auto& v : pv) { dyn.push_back(v->dynamic); } return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args))); } }); }; } Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { return store_.Extend<Expr>([&]() { store_.Invalidate(); return FunctionNode::make(func->params, LetList::With([&](LetList* ll) { std::vector<PStatic> pv; for (const auto& v : func->params) { pv.push_back(NoStatic(v)); } tvm::Array<Type> type_args; for (const auto& tp : func->type_params) { type_args.push_back(tp); } return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; }), func->ret_type, func->type_params, func->attrs); }); } PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = VarNode::make("x", Type())) { Func f = VisitFuncStatic(func, name); Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); // TODO(@M.K.): we seems to reduce landin knot into letrec. // restore letrec support across whole relay. 
return HasStatic(MkSFunc(f), ll->Push(name, VisitFuncDynamic(u_func, f, name))); } PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { return VisitFunc(GetRef<Function>(op), ll); } struct ReflectError : dmlc::Error { ReflectError() : dmlc::Error("static value not found") { } }; Expr Reflect(const PStatic& st) { if (!st->pstatic.defined()) { throw ReflectError(); } else if (const STensorNode* op = st->pstatic.as<STensorNode>()) { return ConstantNode::make(op->data); } else if (const STupleNode* op = st->pstatic.as<STupleNode>()) { tvm::Array<Expr> fields; for (const PStatic& field : op->fields) { fields.push_back(Reflect(field)); } return TupleNode::make(fields); } else { LOG(FATAL) << "Unknown case: " << st->dynamic; throw; } } PStatic Reify(const Value& v, LetList* ll) const { if (const TensorValueNode* op = v.as<TensorValueNode>()) { return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data))); } else if (const TupleValueNode* op = v.as<TupleValueNode>()) { std::vector<PStatic> fields; tvm::Array<Expr> fields_dyn; for (const Value& field : op->fields) { PStatic ps = Reify(field, ll); fields.push_back(ps); fields_dyn.push_back(ps->dynamic); } return HasStatic(MkSTuple(fields), ll->Push(TupleNode::make(fields_dyn))); } else { LOG(FATAL) << "Unknown case"; throw; } } // Constant evaluate a expression. PStatic ConstEvaluate(const Expr& expr, LetList* ll) { std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::InferType()}; auto mod = ModuleNode::FromExpr(expr); auto seq = transform::Sequential(passes); mod = seq(mod); auto entry_func = mod->Lookup("main"); auto fused_infered = expr.as<FunctionNode>() == nullptr ? 
entry_func->body : entry_func; return Reify(executor_(fused_infered), ll); } Func ConstEvaluateFunc(const Expr& expr) { CHECK_EQ(FreeVars(expr).size(), 0); return [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs, const tvm::Array<Type>& type_args, LetList* ll) { tvm::Array<Expr> ns_args; for (const PStatic& ps : pv) { ns_args.push_back(ps->dynamic); } auto ns = [&]() { return NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args))); }; if (StatefulOp(expr)) { return ns(); } try { tvm::Array<Expr> args; for (const PStatic& ps : pv) { args.push_back(Reflect(ps)); } return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll); } catch (const ReflectError&) { return ns(); } }; } PStatic VisitExpr_(const OpNode* op, LetList* ll) final { return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op)); } PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { Constructor c = GetRef<Constructor>(op); Func f = [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs, const tvm::Array<Type>& type_args, LetList* ll) { tvm::Array<Expr> dyn; for (const PStatic& ps : pv) { dyn.push_back(ps->dynamic); } return HasStatic(MkSConstructor(c, pv), ll->Push(CallNode::make(c, dyn))); }; return HasStatic(MkSFunc(f), GetRef<Expr>(op)); } PStatic VisitExpr_(const MatchNode* op, LetList* ll) final { PStatic ps = VisitExpr(op->data, ll); return env_.Extend<PStatic>([&]() { for (const Clause& c : op->clauses) { switch (VisitPattern(c->lhs, ps)) { case MatchStatus::Match: return VisitExpr(c->rhs, ll); case MatchStatus::NoMatch: continue; case MatchStatus::Unknown: return [&]() { tvm::Array<Clause> clauses; for (const Clause& c : op->clauses) { Expr expr = store_.Extend<Expr>([&]() { return LetList::With([&](LetList* ll) { for (const Var& v : BoundVars(c->lhs)) { env_.Insert(v, NoStatic(v)); } return VisitExpr(c->rhs, ll)->dynamic; }); }); clauses.push_back(ClauseNode::make(c->lhs, expr)); } 
store_.Invalidate(); return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete))); }(); default: LOG(FATAL) << "Unknown MatchStatus"; throw; } } LOG(FATAL) << "No case Match"; throw; }); } MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final { return MatchStatus::Match; } MatchStatus VisitPattern_(const PatternVarNode* op, const PStatic& ps) final { env_.Insert(op->var, ps); return MatchStatus::Match; } MatchStatus VisitPattern_(const PatternConstructorNode* op, const PStatic& ps) final { if (ps->pstatic.defined()) { SConstructor scn = Downcast<SConstructor>(ps->pstatic); CHECK_NE(op->constructor->tag, -1); CHECK_NE(scn->constructor->tag, -1); if (op->constructor->tag == scn->constructor->tag) { CHECK_EQ(op->patterns.size(), scn->fields.size()); MatchStatus current_match_status = MatchStatus::Match; for (size_t i = 0; i < op->patterns.size(); ++i) { MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]); switch (ms) { case MatchStatus::Match: continue; case MatchStatus::NoMatch: return MatchStatus::NoMatch; case MatchStatus::Unknown: current_match_status = MatchStatus::Unknown; } } return current_match_status; } return MatchStatus::NoMatch; } else { return MatchStatus::Unknown; } } MatchStatus VisitPattern_(const PatternTupleNode* op, const PStatic& ps) final { if (ps->pstatic.defined()) { STuple stn = Downcast<STuple>(ps->pstatic); CHECK_EQ(op->patterns.size(), stn->fields.size()); MatchStatus current_match_status = MatchStatus::Match; for (size_t i = 0; i < op->patterns.size(); ++i) { MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]); switch (ms) { case MatchStatus::Match: continue; case MatchStatus::NoMatch: return MatchStatus::NoMatch; case MatchStatus::Unknown: current_match_status = MatchStatus::Unknown; } } return current_match_status; } else { return MatchStatus::Unknown; } } void InitializeFuncId(const Expr& e) { struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor { 
PartialEvaluator* pe; explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } void VisitExpr_(const FunctionNode* op) final { Function f = GetRef<Function>(op); CHECK_EQ(pe->func_map_.count(f), 0); pe->func_map_.insert({f, pe->func_map_.size()}); VisitExpr(f->body); } void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } }; InitializeFuncIdVisitor(this).VisitExpr(e); } Expr RegisterFuncId(const Expr& e) { struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor { PartialEvaluator* pe; explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } void VisitExpr_(const CallNode* op) final { if (op->op.same_as(WithFuncIdOp())) { CHECK_EQ(op->args.size(), 1); CHECK(op->attrs.defined()); CHECK(op->attrs.as<WithFuncIdAttrs>()); Function f = AsFunc(op->args[0]); FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid; if (pe->func_map_.count(f) != 0) { CHECK_EQ(pe->func_map_.at(f), fid); } pe->func_map_.insert({f, fid}); } ExprVisitor::VisitExpr_(op); } void VisitExpr_(const FunctionNode* op) final { Function f = GetRef<Function>(op); CHECK_GT(pe->func_map_.count(f), 0); ExprVisitor::VisitExpr_(op); } void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } }; RegisterFuncIdVisitor(this).VisitExpr(e); return e; } Expr AnnotateFuncId(const Expr& e) { struct AnnotateFuncIdMutator : ExprMutator, PatternMutator { PartialEvaluator* pe; explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { } Expr VisitExpr_(const FunctionNode* op) final { Function f = GetRef<Function>(op); CHECK_GT(pe->func_map_.count(f), 0); return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f)); } Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Var VisitVar(const Var& v) final { return v; } }; return AnnotateFuncIdMutator(this).VisitExpr(e); } private: Environment env_; Module mod_; std::unordered_map<GlobalVar, PStatic, NodeHash, NodeEqual> gv_map_; /*! 
Termination checking is done as follows: * We have finitely many FunctionIds. * Each FunctionId maps to a class of semantically equivalent function (ignoring type), * as both TypeSubst and DeDup create semantically equivalent function. * We partially map each FunctionId to a Fuel. * Every time we try to inline a Function, * we make sure it either does not have a Fuel, * or we meet the existing fuel with the fuel calculated from the argument. * If no progress is made, we do not inline. * In both case, we remap the mapping to the new Fuel * when we PE inside the Function body. * Termination is guaranteed because Fuel is finitely descending - there can only be so many meet. */ std::unordered_map<Function, FuncId, NodeHash, NodeEqual> func_map_; std::unordered_map<FuncId, Fuel> fuel_map_; Store store_; DLContext context_ = CPUContext(); FInterpreter executor_ = CPUInterpreter(); }; /*! \brief Remap multiple Var sharing the same Id into the same Var. */ Expr Remap(const Expr& e) { class RemapMutator : public ExprMutator, public PatternMutator { Expr VisitExpr_(const VarNode* op) final { Var v = GetRef<Var>(op); if (remap_.count(v) == 0) { remap_.insert({v, v}); } return remap_.at(v); } Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); } private: std::unordered_map<Var, Var, VarHash, VarEqual> remap_; }; return RemapMutator().VisitExpr(e); } Expr StripWithFuncId(const Expr& e) { struct StripWithFuncIdMutator : ExprMutator, PatternMutator { Expr VisitExpr_(const CallNode* op) final { if (op->op.same_as(WithFuncIdOp())) { CHECK_EQ(op->args.size(), 1); return VisitExpr(op->args[0]); } else { return ExprMutator::VisitExpr_(op); } } Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Var VisitVar(const Var& v) final { return v; } }; return StripWithFuncIdMutator().VisitExpr(e); } Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); } } // namespace partial_eval Module PartialEval(const Module& m) 
{ relay::partial_eval::PartialEvaluator pe(m); std::vector<GlobalVar> gvs; for (const auto& p : m->functions) { gvs.push_back(p.first); } for (const auto& gv : gvs) { pe.VisitGlobalVar(gv); } return m; } namespace transform { Pass PartialEval() { runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func = [=](Module m, PassContext pc) { return PartialEval(m); }; return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } TVM_REGISTER_API("relay._transform.PartialEvaluate") .set_body_typed(PartialEval); } // namespace transform } // namespace relay } // namespace tvm