/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file tvm/relay/backend/vm/lambda_lift.cc
 * \brief Lift all local functions into global functions.
 */

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>

using namespace tvm::runtime;

namespace tvm {
namespace relay {
namespace vm {

/*! \brief Attribute key used to tag a lifted function as a closure (see IsClosure / MarkClosure). */
static constexpr const char* kIsClosure = "IsClosure";

inline std::string GenerateName(const Function& func) {
  size_t hash = StructuralHash()(func);
  return std::string("lifted_name") + std::to_string(hash);
}

bool IsClosure(const Function& func) {
  NodeRef res = FunctionGetAttr(func, kIsClosure);
  const ir::IntImm* pval = res.as<ir::IntImm>();
  return pval && pval->value != 0;
}

/*!
 * \brief Tag a function as a closure.
 * \param func The function to tag.
 * \return The result of FunctionSetAttr with the kIsClosure attribute set to
 *         the integer 1.
 */
Function MarkClosure(const Function& func) {
  const tvm::Integer closure_flag(1);
  return FunctionSetAttr(func, kIsClosure, closure_flag);
}

/*!
 * \brief Expression mutator that hoists nested functions into module-level globals.
 *
 * Each non-primitive function expression that is visited gets registered in
 * the module under a generated name. At its original position it is replaced
 * either by the new global (when it captures no variables) or by a call to
 * the new global that passes the captured variables.
 */
struct LambdaLifter : ExprMutator {
  /*! \brief The module being rewritten; lifted functions are added to it. */
  Module module_;
  explicit LambdaLifter(const Module& module) : module_(module) {}

  /*!
   * \brief Lift a single function expression out to a global.
   * \param func_node The function node being visited.
   * \return The function itself if it is primitive; the new global if the
   *         function has no free variables; otherwise a call to the new
   *         global with the free variables as arguments.
   */
  Expr VisitExpr_(const FunctionNode* func_node) final {
    auto local_fn = GetRef<Function>(func_node);

    // Primitive functions are returned untouched.
    if (local_fn->IsPrimitive()) {
      return std::move(local_fn);
    }

    // Free variables are computed on the original function, before the
    // default mutator rewrites its contents.
    auto captured_vars = FreeVars(local_fn);
    auto captured_type_vars = FreeTypeVars(local_fn, module_);
    auto rewritten = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
    const bool needs_closure = captured_vars.size() != 0;

    // Two shapes of lifted function are produced:
    //
    //  * Nothing captured: the rewritten function's params, body and return
    //    type are reused directly, so no closure has to be allocated.
    //
    //  * Something captured: an "outer" function is built whose parameters
    //    are the captured variables and whose body is the rewritten "inner"
    //    function. The outer function is marked as a closure; later code
    //    generation uses a call to it to emit the allocation of the
    //    closure's environment, while the inner function supplies the
    //    closure's code.
    //
    // NOTE(review): in both shapes only the free type variables are passed
    // as type parameters, and the rewritten function's own type_params and
    // attrs are not forwarded -- confirm this is intended.
    Function hoisted;
    if (needs_closure) {
      hoisted = MarkClosure(FunctionNode::make(
          captured_vars, rewritten, local_fn->func_type_annotation(), captured_type_vars));
    } else {
      hoisted = FunctionNode::make(
          rewritten->params, rewritten->body, rewritten->ret_type, captured_type_vars);
    }

    CHECK(hoisted.defined());

    // Register the lifted function in the module under its generated name.
    auto global = module_->GetGlobalVar(GenerateName(hoisted));
    module_->Add(global, hoisted);

    if (!needs_closure) {
      return std::move(global);
    }

    // Closure case: call the global with the captured variables so that the
    // environment is supplied at the original site.
    Array<Expr> env_args;
    for (auto captured : captured_vars) {
      env_args.push_back(captured);
    }
    return CallNode::make(global, env_args);
  }

  /*!
   * \brief Run lambda lifting over the functions of the module.
   * \return The module, with each visited global's body rewritten and the
   *         lifted functions added.
   */
  Module Lift() {
    // NOTE: the original author recorded that there is an ordering bug in
    // this loop; no further detail is available here.
    //
    // Iteration runs over a local copy of module_->functions while the
    // visitor adds lifted functions to module_.
    auto globals = module_->functions;
    for (auto entry : globals) {
      auto global_fn = entry.second;
      DLOG(INFO) << "Lifting " << AsText(global_fn, false);

      // Only the body is visited, so the global function itself is rebuilt
      // in place rather than being lifted.
      auto new_body = VisitExpr(global_fn->body);
      auto rebuilt = FunctionNode::make(global_fn->params,
                                        new_body,
                                        global_fn->ret_type,
                                        global_fn->type_params,
                                        global_fn->attrs);

      // presumably the trailing `true` permits replacing the existing
      // definition -- confirm against Module::Add.
      module_->Add(entry.first, rebuilt, true);
    }
    return module_;
  }
};

}  // namespace vm

namespace transform {

/*!
 * \brief Create the module-level pass that performs lambda lifting.
 * \return A pass named "LambdaLift" whose body runs relay::vm::LambdaLifter
 *         on the incoming module and returns the result of Lift().
 */
Pass LambdaLift() {
  using LiftFn = runtime::TypedPackedFunc<Module(Module, PassContext)>;

  // The pass context is accepted to satisfy the signature but is not used.
  LiftFn run_lifter = [](Module mod, PassContext ctx) {
    relay::vm::LambdaLifter lifter(mod);
    return lifter.Lift();
  };

  return CreateModulePass(run_lifter, 1, "LambdaLift", {});
}

// Register the LambdaLift pass constructor under the API name
// "relay._transform.LambdaLift".
TVM_REGISTER_API("relay._transform.LambdaLift")
.set_body_typed(LambdaLift);

}  // namespace transform

}  // namespace relay
}  // namespace tvm