/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file ad.cc
 * \brief API for Automatic Differentiation for the Relay IR.
 */
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "let_list.h"
#include "../ir/type_functor.h"

namespace tvm {
namespace relay {

using namespace tvm::runtime;

/*! What is automatic differentiation (AD) and why is it important?
 * By AD we roughly mean: given a term that denotes some mathematical function,
 * derive a term that denotes the derivative of that mathematical function.
 * Such a method can be compile-time: a macro on a completely known function.
 * Formally speaking, this requires the input function to be a closed expression -
 * that is, it may only refer to local variables that are its parameters or are defined inside it.
 * Every top-level definition satisfies this criterion.
 * AD can also be run-time, in which case it is merely a function term of type
 * AD : (Float[] -> Float[]) -> (Float[] -> Float[]).
 * In Relay we currently only support compile-time AD, but it should be enough for many use cases.
 *
 * In deep learning, the most common way to train a deep neural network is gradient descent or one of its variants.
 * Such optimization methods require the gradient of the neural network, which can be obtained easily using AD.
 * In fact, back propagation is essentially reverse-mode automatic differentiation, a kind of AD!
 */

/*! In Relay, automatic differentiation (AD) is a macro
 *  that transforms a closed expr (an expr without free variables/free type variables) of type
 *  (x0, x1, x2, ...) -> Float[] into one of type
 *  (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)),
 *  where x0, x1, x2, ... are Floats of possibly different shapes.
 *  The return value is a pair whose left-hand side is the original value and whose right-hand side is the gradient of the input.
 *  WithGradientType takes the type of the input and produces the type of the output.
 *  There are multiple implementations of AD in Relay, with different characteristics.
 *  However, they all transform the input expr according to WithGradientType.
 */
Type WithGradientType(const Type&);

/*! Return an expression that represents the differentiation of e (according to WithGradientType).
 *  This version only works on first-order code without control flow.
 */
Expr FirstOrderGradient(const Expr& e, const Module& mod);

Type WithGradientType(const Type& t) {
  // TODO(M.K.): stricter checking
  auto ty = t.as<FuncTypeNode>();
  CHECK(ty) << "input should be a function";
  return FuncTypeNode::make(ty->arg_types,
                            TupleTypeNode::make({
                              ty->ret_type,
                              TupleTypeNode::make(ty->arg_types)}),
                            {}, {});
}
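// For illustration (a sketch; the concrete shape below is hypothetical):
// given an input function of type
//   fn (Tensor[(3, 4), float32], Tensor[(3, 4), float32]) -> Tensor[(3, 4), float32]
// WithGradientType produces
//   fn (Tensor[(3, 4), float32], Tensor[(3, 4), float32])
//       -> (Tensor[(3, 4), float32], (Tensor[(3, 4), float32], Tensor[(3, 4), float32]))
// i.e. the original result paired with one gradient per parameter.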
//! \brief If the expression is a GlobalVar, transform it into the expression it refers to.
Expr DeGlobal(const Module& mod, const Expr& e) {
  if (const auto* x = e.as<GlobalVarNode>()) {
    return mod->Lookup(GetRef<GlobalVar>(x))->body;
  } else {
    return e;
  }
}

/*! \brief A fragment of the program being built by the automatic differentiation
 *  pass.
 */
struct ADValueNode {
  virtual ~ADValueNode() { }
  template <typename T>
  T& get() {
    auto ret = dynamic_cast<T*>(this);
    CHECK(ret) << "cannot downcast";
    return *ret;
  }
};

using ADValue = std::shared_ptr<ADValueNode>;

/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
  Expr forward;
  mutable Expr reverse;  // must be a variable to avoid duplication
  ADTensor(LetList* ll, const Expr& forward) :
      forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
    this->forward->checked_type_ = forward->checked_type();
  }
};

/*! \brief A staged representation of the program, we reflect
 * Relay functions into a function over fragments of AD. We
 * can compute away this function to obtain a reverse mode program.
 */
struct ADFunction : ADValueNode {
  std::function<ADValue(const Type&,
                        const std::vector<ADValue>&,
                        const Attrs&,
                        const tvm::Array<Type>&)> func;
  explicit ADFunction(const std::function<ADValue(const Type&,
                                                  const std::vector<ADValue>&,
                                                  const Attrs&,
                                                  const tvm::Array<Type>&)>& func) :
      func(func) { }
};

struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
  std::vector<std::function<void(LetList* ll)>> backprop_actions;
  // we assume no closure so no need for lexical scoping
  std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
  LetList* ll;

  FirstOrderReverseAD(LetList* ll) : ll(ll) { }

  ADValue VisitExpr_(const OpNode* op) final {
    Op op_ref = GetRef<Op>(op);
    CHECK(rev_map.count(op_ref))
        << op->name << " does not have reverse mode defined";
    return std::make_shared<ADFunction>([this, op_ref](const Type& orig_type,
                                                       const std::vector<ADValue>& args,
                                                       const Attrs& attrs,
                                                       const tvm::Array<Type>& type_args) {
      std::vector<Expr> call_args;
      for (const ADValue& adval : args) {
        call_args.push_back(adval->get<ADTensor>().forward);
      }
      auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
      orig->checked_type_ = orig_type;
      auto ret = std::make_shared<ADTensor>(ll, orig);
      backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
        tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
        CHECK(args.size() == rev.size());
        for (size_t i = 0; i < args.size(); ++i) {
          args[i]->get<ADTensor>().reverse =
              ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
        }
      });
      return ret;
    });
  }

  ADValue VisitExpr_(const ConstantNode* op) final {
    Expr e = GetRef<Expr>(op);
    return std::make_shared<ADTensor>(ll, e);
  }

  ADValue VisitExpr_(const CallNode* op) final {
    ADValue f = VisitExpr(op->op);
    std::vector<ADValue> args;
    for (const auto& arg : op->args) {
      args.push_back(VisitExpr(arg));
    }
    return f->get<ADFunction>().func(op->checked_type(), args, op->attrs, op->type_args);
  }

  ADValue VisitExpr_(const FunctionNode* op) final {
    Function f = GetRef<Function>(op);
    // todo: assert no closure
    return std::make_shared<ADFunction>([this, f](const Type& orig_type,
                                                  const std::vector<ADValue>& args,
                                                  const Attrs& attrs,
                                                  const tvm::Array<Type>& type_args) {
      CHECK_EQ(f->params.size(), args.size());
      for (size_t i = 0; i < f->params.size(); ++i) {
        env[f->params[i]] = args[i];
      }
      return VisitExpr(f->body);
    });
  }

  ADValue VisitExpr_(const VarNode* op) final {
    Var v = GetRef<Var>(op);
    return env.at(v);
  }
};
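// Roughly, the first-order pass works as follows (a sketch, not normative):
// visiting a call `t = op(a, b)` emits the forward computation into the LetList,
// initializes t.reverse to zeros_like(t), and records a backprop action that,
// when replayed later, performs
//   a.reverse = a.reverse + grad_a
//   b.reverse = b.reverse + grad_b
// where grad_a and grad_b come from the operator's registered FPrimalGradient.
// FirstOrderGradient (below) replays these actions in reverse order after
// seeding the output's reverse value with ones_like(output).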
Type GradRetType(const Function& f) {
  // if type annotations are provided, we will construct a ret type;
  // otherwise, leave it to be inferred
  if (!f->ret_type.defined()) {
    return Type();
  }
  std::vector<Type> vt;
  for (const auto& p : f->params) {
    if (!p->type_annotation.defined()) {
      return Type();
    }
    vt.push_back(p->type_annotation);
  }
  return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}

Expr FirstOrderGradient(const Expr& re, const Module& mod) {
  // Currently we first remove any global functions for the first
  // order case.
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  CHECK(f) << "FOWithGradient expects its argument to be a function: " << re;
  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";

  // We will then build a sequence of lets which implement reverse mode.
  Expr body = LetList::With([&](LetList* ll) {
    FirstOrderReverseAD reverse_ad(ll);
    ADValue rev = reverse_ad(e);
    std::vector<ADValue> args;
    for (const auto& p : f->params) {
      args.push_back(std::make_shared<ADTensor>(ll, p));
    }
    auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
    const auto& res = c->get<ADTensor>();
    Expr grad = LetList::With([&](LetList* ll) {
      res.reverse = OnesLike(res.forward);
      for (auto it = reverse_ad.backprop_actions.rbegin();
           it != reverse_ad.backprop_actions.rend();
           ++it) {
        (*it)(ll);
      }
      std::vector<Expr> grad_res;
      for (const auto& a : args) {
        grad_res.push_back(a->get<ADTensor>().reverse);
      }
      return TupleNode::make(grad_res);
    });
    return Pair(res.forward, grad);
  });

  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._transform.first_order_gradient")
.set_body_typed(FirstOrderGradient);

struct ReverseADType : TypeMutator {
  Type VisitType_(const TensorTypeNode* ttn) final {
    Type t = GetRef<Type>(ttn);
    return TupleTypeNode::make({t, RefTypeNode::make(t)});
  }
};

Type ReverseType(const Type& t) {
  return ReverseADType()(t);
}
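// For example (a sketch; the shape is hypothetical), ReverseType maps
//   Tensor[(n, m), float32]
// to
//   (Tensor[(n, m), float32], Ref[Tensor[(n, m), float32]])
// and tuple types are mapped element-wise, so every tensor value carries a
// mutable reference cell that holds its accumulated gradient.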
/*! \brief Lift a function that transforms Tensors into a function that also transforms
 *  tuples of Tensors, by doing a structure-preserving map.
 */
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
                const std::function<Type(const Type&)>& tf,
                const Type& forward_type,
                const Expr& e,
                LetList* ll) {
  CHECK(IsAtomic(e)) << e;
  if (forward_type.as<TensorTypeNode>()) {
    auto ret = f(e);
    ret->checked_type_ = tf(forward_type);
    return ret;
  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
    tvm::Array<Expr> fields;
    tvm::Array<Type> types;
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      auto field = LiftTensor(f, tf, tt->fields[i], ll->Push(GetField(e, i)), ll);
      fields.push_back(field);
      types.push_back(field->checked_type_);
    }
    auto ret = TupleNode::make(fields);
    ret->checked_type_ = TupleTypeNode::make(types);
    return std::move(ret);
  } else {
    LOG(FATAL) << "unsupported input/output type: " << forward_type;
    throw;
  }
}

/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
 *  by stitching the references in the AD values.
 */
void TransferGrads(const Type& forward_type,
                   const Expr& from,
                   const Expr& to,
                   LetList* ll) {
  CHECK(IsAtomic(from)) << from;
  CHECK(IsAtomic(to)) << to;
  if (forward_type.as<TensorTypeNode>()) {
    auto from_ref = TupleGetItemNode::make(from, 1);
    auto to_ref = TupleGetItemNode::make(to, 1);
    ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      TransferGrads(tt->fields[i],
                    ll->Push(TupleGetItemNode::make(from, i)),
                    ll->Push(TupleGetItemNode::make(to, i)),
                    ll);
    }
  } else {
    LOG(FATAL) << "Unsupported input/output type: " << forward_type;
    throw;
  }
}

/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
  auto rev = [&](const Expr& e) {
    return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
  };
  auto rev_type = [&](const Type& forward_type) {
    return ReverseType(forward_type);
  };
  return LiftTensor(rev, rev_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
  auto val = [&](const Expr& e) {
    return GetField(e, 0);
  };
  auto val_type = [&](const Type& forward_type) {
    return forward_type;
  };
  return LiftTensor(val, val_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
  auto grad = [&](const Expr& e) {
    return ll->Push(RefReadNode::make(GetField(e, 1)));
  };
  auto grad_type = [&](const Type& forward_type) {
    return forward_type;
  };
  return LiftTensor(grad, grad_type, forward_type, e, ll);
}

void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
  if (t.as<TensorTypeNode>()) {
    ll->Push(RefWriteNode::make(GetField(arg, 1),
                                Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
                                    grad)));
  } else if (auto* tt = t.as<TupleTypeNode>()) {
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      UpdateGrad(tt->fields[i],
                 ll->Push(GetField(arg, i)),
                 ll->Push(GetField(grad, i)),
                 ll);
    }
  } else {
    LOG(FATAL) << "unsupported arg type of operator: " << t;
    throw;
  }
}
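// For a tensor-typed argument, UpdateGrad emits (roughly) the Relay code
//   ref_write(arg.1, add(ref_read(arg.1), grad))
// i.e. it accumulates `grad` into the gradient reference stored in the second
// component of the (value, Ref[gradient]) pair; tuple arguments recurse field-wise.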
Expr BPEmpty() {
  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
  return RefCreateNode::make(unitF);
}

struct ReverseAD : ExprMutator {
  using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;

  Var bp;
  std::shared_ptr<ADVarMap> ad_vars;
  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");

  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
      : bp(bp), ad_vars(ad_vars) { }

  Expr VisitExpr_(const OpNode* op) final {
    LOG(FATAL) << "op should only be inside call";
    throw;
  }

  Expr VisitCheckpoint(const CallNode* call) {
    const OpNode* op_node = call->op.as<OpNode>();
    CHECK(op_node) << "expected op in call";
    Op op_ref = GetRef<Op>(op_node);
    CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
    auto x = call->args[0];
    return LetList::With([&](LetList* ll) {
      auto x_var = ll->Push(x);
      auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
      auto bpv = ll->Push(RefReadNode::make(bp));
      Expr nbp = FunctionNode::make(
        {}, LetList::With([&](LetList* ll) {
          // we need a new ReverseAD visitor to avoid clobbering the bp local var
          auto dup_bp = ll->Push(BPEmpty());
          ReverseAD dup_diff(dup_bp, ad_vars);
          auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));

          TransferGrads(call->checked_type(), ret, dup_ad, ll);
          ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
          return CallNode::make(bpv, {});
        }),
        TupleTypeNode::make({}), {});
      ll->Push(RefWriteNode::make(bp, nbp));
      return ret;
    });
  }

  Expr VisitExpr_(const CallNode* call) final {
    if (const OpNode* op_node = call->op.as<OpNode>()) {
      Op op_ref = GetRef<Op>(op_node);

      if (op_ref->name == "annotation.checkpoint") {
        return VisitCheckpoint(call);
      }

      CHECK(rev_map.count(op_ref))
          << op_node->name << " does not have reverse mode defined";
      return LetList::With([&](LetList* ll) {
        std::vector<Var> args;
        for (const auto& arg : call->args) {
          args.push_back(ll->Push(VisitExpr(arg)));
        }
        std::vector<Expr> orig_args;
        for (size_t i = 0; i < args.size(); i++) {
          orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
        }
        Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
        orig->checked_type_ = call->checked_type();
        Var orig_var = ll->Push(orig);
        orig_var->checked_type_ = call->checked_type();
        auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
        auto bpv = ll->Push(RefReadNode::make(bp));
        Expr nbp = FunctionNode::make(
          {}, LetList::With([&](LetList* ll) {
            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
            CHECK(args.size() == rev.size());
            for (size_t i = 0; i < args.size(); ++i) {
              UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
            }
            return CallNode::make(bpv, {});
          }),
          TupleTypeNode::make({}), {});
        ll->Push(RefWriteNode::make(bp, nbp));
        return ret;
      });
    }
    return ExprMutator::VisitExpr_(call);
  }

  Expr VisitExpr_(const ConstantNode* op) final {
    Expr e = GetRef<Expr>(op);
    return Pair(e, RefCreateNode::make(ZerosLike(e)));
  }

  Expr VisitExpr_(const IfNode* op) final {
    return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
                        VisitExpr(op->true_branch),
                        VisitExpr(op->false_branch));
  }

  Expr VisitExpr_(const VarNode* var) final {
    // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
    auto var_ref = GetRef<Var>(var);
    if (!ad_vars->count(var_ref)) {
      auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
      (*ad_vars)[var_ref] = res;
    }

    return ad_vars->at(var_ref);
  }

  Type VisitType(const Type& t) final {
    return t.defined() ? ReverseType(t) : t;
  }
};
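// A sketch of how the backpropagator is threaded through the program (not
// normative): `bp` holds a reference to a closure. Every call that ReverseAD
// visits rewrites `bp` to a new closure that reads the output's gradient
// reference, computes the operator's primal gradient via FPrimalGradient,
// accumulates it into the arguments' gradient references with UpdateGrad, and
// then invokes the previously stored closure. Calling the final `bp` once (as
// Gradient does below) therefore runs every gradient update in reverse program
// order.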
bool MissingGrad(const Expr& e) {
  struct MGVisitor : ExprVisitor {
    const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
    std::unordered_set<std::string> op_names;

    void VisitExpr_(const OpNode* op) final {
      Op op_ref = GetRef<Op>(op);
      if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
        op_names.insert(op_ref->name);
      }
      ExprVisitor::VisitExpr_(op);
    }
  };

  MGVisitor mg;
  mg.VisitExpr(e);

  if (mg.op_names.size() > 0) {
    LOG(WARNING) << "found operators with missing gradients:";
    for (const auto& op : mg.op_names) {
      LOG(WARNING) << "    " << op;
    }
    return true;
  }

  return false;
}

Expr Gradient(const Expr& re, const Module& mod) {
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  CHECK(f) << "input needs to be a function";
  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
  for (const auto& p : f->params) {
    CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensors";
  }
  CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
  Expr body = LetList::With([&](LetList* ll) {
    Var bp = ll->Push(BPEmpty());
    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
    std::vector<Expr> args;
    for (const auto& p : f->params) {
      args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
    }
    auto c = ll->Push(CallNode::make(rev, args));
    std::function<void(const Expr&, const Type&)> init_grad;
    init_grad = [&](const Expr& e, const Type& t) {
      if (t.as<TensorTypeNode>()) {
        ll->Push(RefWriteNode::make(GetField(e, 1), OnesLike(GetField(e, 0))));
      } else if (auto tt = t.as<TupleTypeNode>()) {
        CHECK_GT(tt->fields.size(), 0);
        init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    init_grad(c, f->body->checked_type());
    ll->Push(CallNode::make(RefReadNode::make(bp), {}));
    std::vector<Expr> ret;
    for (const auto& a : args) {
      ret.push_back(RefReadNode::make(GetField(a, 1)));
    }
    std::function<Expr(const Expr&, const Type&)> get_final_result;
    get_final_result = [&](const Expr& e, const Type& t) -> Expr {
      if (t.as<TensorTypeNode>()) {
        return GetField(e, 0);
      } else if (auto tt = t.as<TupleTypeNode>()) {
        tvm::Array<Expr> fields;
        for (size_t i = 0; i < tt->fields.size(); ++i) {
          fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
        }
        return TupleNode::make(fields);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
  });
  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._transform.gradient")
.set_body_typed(Gradient);

}  // namespace relay
}  // namespace tvm