/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file constant_folding.cc
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/container.h>
#include "pattern_util.h"

namespace tvm {
namespace relay {

// Signature of the evaluator the folder uses to compute constant
// subexpressions: it takes a Relay expression and returns the runtime object
// that expression evaluates to (FoldConstant below passes the result of
// CreateInterpreter here).
using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;

class ConstantChecker : private ExprVisitor {
 public:
  // Check whether an expression is constant. The results are memoized.
  bool Check(const Expr& expr) {
    // The `ConstantNode` case is common enough that we check directly for the
    // case here, to avoid the time overhead of dispatching through the vtable
    // and the space overhead of memoizing always-true results.
    if (expr.as<ConstantNode>()) {
      return true;
    }
    const auto it = memo_.find(expr);
    if (it != memo_.end())
      return it->second;
    VisitExpr(expr);
    return memo_[expr];  // return memoized result or the default value false
  }

 private:
  std::unordered_map<Expr, bool, ObjectHash, ObjectEqual> memo_;

  void VisitExpr_(const TupleNode* n) final {
    bool result = true;
    for (const auto& field : n->fields) {
      if (!Check(field)) {
        result = false;
        break;
      }
    }
    memo_[GetRef<Tuple>(n)] = result;
  }
};

bool ConstantCheck(const Expr& e) {
  return ConstantChecker().Check(e);
}

TVM_REGISTER_GLOBAL("relay.analysis.check_constant")
.set_body_typed(ConstantCheck);

// TODO(tvm-team): consider combining dead-code elimination with the constant
// folder, or building a more powerful partial evaluator.
/*!
 * \brief Mutator that replaces constant-computable subexpressions with their
 *        evaluated values.
 *
 * It folds the following:
 *  - calls to (non-stateful, non-skipped) ops whose arguments are all constant;
 *  - `shape_of` on tensors whose shape is statically known;
 *  - `let` bindings whose value folds to a ConstantNode (the binding is inlined);
 *  - `TupleGetItem` applied directly to a tuple literal.
 */
class ConstantFolder : public ExprMutator {
 public:
  /*!
   * \brief Construct a folder.
   * \param executor Evaluator used to compute constant subexpressions.
   * \param module Module whose type definitions and imports are copied into
   *        the scratch module built for each evaluation.
   */
  explicit ConstantFolder(FInterpreter executor, IRModule module)
      : executor_(executor),
        module_(module),
        shape_of_op_(Op::Get("shape_of")),
        invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
        shape_func_op_(Op::Get("memory.shape_func")),
        alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
        alloc_storage_op_(Op::Get("memory.alloc_storage")),
        cast_op_(Op::Get("cast")) {}

  // Inline `let` bindings whose value folds to a constant.
  Expr VisitExpr_(const LetNode* op) final {
    Expr value = this->Mutate(op->value);
    if (value.as<ConstantNode>()) {
      // Seed the mutator's memo so every later visit of the bound variable
      // yields the constant; the now-dead binding itself is dropped.
      memo_[op->var] = value;
      return this->Mutate(op->body);
    } else {
      Var var = Downcast<Var>(this->Mutate(op->var));
      Expr body = this->Mutate(op->body);
      if (var.same_as(op->var) &&
          value.same_as(op->value) &&
          body.same_as(op->body)) {
        return GetRef<Expr>(op);
      } else {
        return LetNode::make(var, value, body);
      }
    }
  }

  // Fold an op call when all of its (already mutated) arguments are constant.
  Expr VisitExpr_(const CallNode* call) final {
    static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
    // Ops that are never folded, even when every argument is constant.
    // Built once instead of on every visited call.
    static const std::unordered_set<std::string> skip_list{
        "zeros_like", "ones_like", "full_like", "full"};

    // shape_of is evaluated against the original (un-mutated) arguments.
    auto origin_args = call->args;
    Expr res = ExprMutator::VisitExpr_(call);
    call = res.as<CallNode>();
    // We don't constant fold function with zero arguments.
    // This is a heuristic that is useful.
    // For example it is harmful to fold ones(shape=(4, 5)).
    if (call->args.size() == 0) return res;
    const OpNode* op = call->op.as<OpNode>();
    if (op == nullptr) return res;
    if (skip_list.count(op->name)) {
      return res;
    }
    // skip stateful ops.
    if (op_stateful.get(GetRef<Op>(op), false)) return res;
    // Try to evaluate shape_of op
    if (call->op == shape_of_op_) {
      return EvaluateShapeOf(res, origin_args, call->attrs);
    }

    // We should think about potentially constant evaluation over these ops too.
    if (call->op == invoke_tvm_op_ ||
        call->op == shape_func_op_ ||
        call->op == alloc_tensor_op_ ||
        call->op == alloc_storage_op_) {
      return GetRef<Call>(call);
    }

    bool all_const_args = true;
    for (const Expr& arg : call->args) {
      if (!checker_.Check(arg)) {
        all_const_args = false;
        break;  // a single non-constant argument rules out folding
      }
    }
    return all_const_args ? ConstEvaluate(res) : res;
  }

  // Project a field out of a tuple literal directly.
  Expr VisitExpr_(const TupleGetItemNode* op) final {
    Expr res = ExprMutator::VisitExpr_(op);
    op = res.as<TupleGetItemNode>();
    if (const auto* tuple = op->tuple.as<TupleNode>()) {
      // The cast also rejects negative indices (they wrap to huge values).
      CHECK_LT(static_cast<size_t>(op->index), tuple->fields.size())
          << "TupleGetItem index " << op->index << " out of range for a tuple of "
          << tuple->fields.size() << " fields";
      return tuple->fields[op->index];
    } else {
      return res;
    }
  }

 private:
  // Internal interpreter.
  FInterpreter executor_;
  // Internal constant checker.
  ConstantChecker checker_;
  // Module whose type definitions/imports are reused during evaluation.
  IRModule module_;

  // Cache the following ops for equivalence checking in this pass.
  const Op& shape_of_op_;
  const Op& invoke_tvm_op_;
  const Op& shape_func_op_;
  const Op& alloc_tensor_op_;
  const Op& alloc_storage_op_;
  const Op& cast_op_;

  /*!
   * \brief Convert an evaluated runtime value back into a Relay expression.
   *
   * NDArrays become ConstantNodes (every dimension must be > 0); ADT objects
   * become tuples of their recursively converted fields. Anything else is a
   * fatal error.
   */
  Expr ObjectToExpr(const ObjectRef& value) {
    if (value->IsInstance<runtime::NDArray::ContainerType>()) {
      auto nd_array = Downcast<runtime::NDArray>(value);
      for (auto dim : nd_array.Shape()) {
        CHECK_GT(dim, 0)
          << "invalid dimension after constant eval";
      }
      return ConstantNode::make(nd_array);
    } else if (const auto* val = value.as<runtime::ADTObj>()) {
      runtime::ADT adt = GetRef<runtime::ADT>(val);
      Array<Expr> fields;
      for (size_t i = 0; i < adt.size(); ++i) {
        fields.push_back(ObjectToExpr(adt[i]));
      }
      return TupleNode::make(fields);
    } else {
      LOG(FATAL) << "Cannot handle " << value->GetTypeKey();
      return Expr();
    }
  }

  /*!
   * \brief Constant-evaluate an expression.
   *
   * Non-function expressions are wrapped in a function over their free
   * variables. The result is placed as "main" in a scratch module that shares
   * the original module's type definitions and imports. FuseOps(0) and
   * InferType are run on it, and the (function body or function) is then
   * evaluated with the interpreter and converted back into an expression.
   */
  Expr ConstEvaluate(Expr expr) {
    std::vector<transform::Pass> passes = {transform::FuseOps(0),
                                           transform::InferType()};
    const bool is_function = expr.as<FunctionNode>() != nullptr;
    Function func;
    if (is_function) {
      func = Downcast<Function>(expr);
    } else {
      // TODO(@jroesch): fix this
      func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
    }
    auto mod = IRModule(
      {},
      module_->type_definitions,
      module_->Imports());
    auto global = GlobalVar("main");
    mod->Add(global, func);
    auto seq = transform::Sequential(passes);
    mod = seq(mod);
    auto entry_func = Downcast<Function>(mod->Lookup("main"));
    expr = is_function ? Expr(entry_func) : entry_func->body;
    return ObjectToExpr(executor_(expr));
  }

  /*!
   * \brief Evaluate a call to shape_of when the input's shape is static.
   * \param expr The (mutated) shape_of call, returned unchanged on failure.
   * \param args The original call arguments; args[0] is the input tensor.
   * \param attrs The call's ShapeOfAttrs (supplies the output dtype).
   * \return A constant holding the shape cast to the requested dtype, or
   *         \p expr when the shape is not fully known.
   */
  Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) {
    Expr input = args[0];
    const auto* param = attrs.as<ShapeOfAttrs>();
    CHECK(param != nullptr);

    tvm::Array<IndexExpr> ishape;
    if (const ConstantNode* op = input.as<ConstantNode>()) {
      ishape = op->tensor_type()->shape;
    } else if (input->checked_type_.defined()) {
      const auto* tensor_type = input->checked_type().as<TensorTypeNode>();
      // Not a tensor type: there is no static shape to fold, so leave the
      // call as-is rather than dereferencing null.
      if (tensor_type == nullptr) return expr;
      ishape = tensor_type->shape;
    } else {
      return expr;
    }

    // Get the constant shape
    DLContext ctx;
    ctx.device_type = kDLCPU;
    ctx.device_id = 0;
    runtime::NDArray value;
    DLDataType cdtype = DataType::Int(32);
    if (ishape.size() == 0) {
      // NOTE(review): this 0-d buffer is never written, so the scalar read
      // below and the folded result see indeterminate data. Confirm the
      // intended result of shape_of on a rank-0 tensor.
      value = runtime::NDArray::Empty({}, cdtype, ctx);
    } else {
      std::vector<int64_t> cshape = { static_cast<int64_t>(ishape.size()) };
      value = runtime::NDArray::Empty(cshape, cdtype, ctx);
      int32_t* dims = static_cast<int32_t*>(value->data);
      using ::tvm::tir::IntImmNode;
      for (size_t i = 0; i < ishape.size(); ++i) {
        if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
          // Shapes are materialized as int32 before the final cast.
          dims[i] = static_cast<int32_t>(dim->value);
        } else {
          // Symbolic dimension: the shape is not static.
          return expr;
        }
      }
    }

    Constant shape = Downcast<Constant>(ObjectToExpr(value));

    if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
      auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
      shape = ConstantNode::make(ndarray);
    }

    // Cast the constant into correct dtype
    auto cast_attrs = make_object<CastAttrs>();
    cast_attrs->dtype = param->dtype;
    Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {});
    return ConstEvaluate(ret);
  }
};


/*!
 * \brief Fold constants in \p expr.
 *
 * Constant subexpressions are evaluated by an interpreter created for CPU
 * device 0 with the "llvm" target.
 * \param expr The expression to fold.
 * \param mod The module providing type definitions and imports.
 * \return The folded expression.
 */
Expr FoldConstant(const Expr& expr, const IRModule& mod) {
  DLContext cpu_ctx;
  cpu_ctx.device_type = kDLCPU;
  cpu_ctx.device_id = 0;
  const Target llvm_target = Target::Create("llvm");

  // Enter a fresh build context for the duration of this call, in case the
  // caller is already inside one.
  With<BuildConfig> fresh_build_ctx(BuildConfig::Create());

  FInterpreter interpreter = CreateInterpreter(mod, cpu_ctx, llvm_target);
  ConstantFolder folder(interpreter, mod);
  return folder.Mutate(expr);
}

namespace transform {

/*!
 * \brief Create the "FoldConstant" function-level pass (opt level 2).
 *
 * Each function is rewritten by relay::FoldConstant using the module the
 * pass runs in.
 */
Pass FoldConstant() {
  auto fold_function = [](Function func, IRModule mod, PassContext ctx) -> Function {
    return Downcast<Function>(relay::FoldConstant(func, mod));
  };
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      fold_function;
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

// Register the pass factory as a global packed function.
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant);

}  // namespace transform

}  // namespace relay
}  // namespace tvm