/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file util.cc
 *
 * \brief Utility functions for Relay.
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "../ir/type_functor.h"

namespace tvm {
namespace relay {

/*!
 * \brief A set that also remembers the order in which elements were first added.
 *
 * `set` answers membership queries; `data` holds every distinct element exactly
 * once, in first-insertion order.
 *
 * \tparam T The element type, hashed with NodeHash and compared with NodeEqual.
 */
template<typename T>
struct InsertionSet {
  std::unordered_set<T, NodeHash, NodeEqual> set;
  std::vector<T> data;
  /*!
   * \brief Add t if it is not already present.
   * \param t The element to record.
   */
  void Insert(const T& t) {
    // insert() reports whether the element was actually added.
    const bool newly_added = set.insert(t).second;
    if (newly_added) {
      data.push_back(t);
    }
  }
};

class TypeVarTVisitor : public TypeVisitor {
 public:
  TypeVarTVisitor(
      InsertionSet<TypeVar>* type_vars,
      InsertionSet<TypeVar>* bound_type_vars)
    : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { }

  void VisitType_(const TypeVarNode* tp) final {
    TypeVar var = GetRef<TypeVar>(tp);
    type_vars_->Insert(var);
  }

  void VisitType_(const FuncTypeNode* f) final {
    for (auto type_param : f->type_params) {
      type_vars_->Insert(type_param);
      bound_type_vars_->Insert(type_param);
    }
    TypeVisitor::VisitType_(f);
  }

 private:
  InsertionSet<TypeVar>* type_vars_;
  InsertionSet<TypeVar>* bound_type_vars_;
};

class TypeVarEVisitor : private ExprVisitor {
 public:
  explicit TypeVarEVisitor(const Module& mod) : mod_(mod) {}

  Array<TypeVar> CollectFree() {
    Array<TypeVar> ret;
    for (const auto& v : type_vars_.data) {
      if (bound_type_vars_.set.count(v) == 0) {
        ret.push_back(v);
      }
    }
    return ret;
  }

  Array<TypeVar> CollectBound() {
    Array<TypeVar> ret;
    for (const auto& v : bound_type_vars_.data) {
      ret.push_back(v);
    }
    return ret;
  }

  Array<TypeVar> CollectAll() {
    Array<TypeVar> ret;
    for (const auto& v : type_vars_.data) {
      ret.push_back(v);
    }
    return ret;
  }

  Array<TypeVar> Free(const Expr& expr) {
    VisitExpr(expr);
    return CollectFree();
  }

  Array<TypeVar> Free(const Type& type) {
    VisitType(type);
    return CollectFree();
  }

  Array<TypeVar> Bound(const Expr& expr) {
    VisitExpr(expr);
    return CollectBound();
  }

  Array<TypeVar> Bound(const Type& type) {
    VisitType(type);
    return CollectBound();
  }

  Array<TypeVar> All(const Expr& expr) {
    VisitExpr(expr);
    return CollectAll();
  }

  Array<TypeVar> All(const Type& type) {
    VisitType(type);
    return CollectAll();
  }

  void VisitExpr_(const FunctionNode* f) final {
    for (const auto& tp : f->type_params) {
      type_vars_.Insert(tp);
      bound_type_vars_.Insert(tp);
    }
    ExprVisitor::VisitExpr_(f);
  }

  void VisitExpr_(const ConstructorNode* cn) final {
    // for constructors, type vars will be bound in the module
    auto data = mod_->LookupDef(cn->belong_to);
    for (const auto& tv : data->type_vars) {
      type_vars_.Insert(tv);
      bound_type_vars_.Insert(tv);
    }
    ExprVisitor::VisitExpr_(cn);
  }

  void VisitType(const Type& t) final {
    TypeVarTVisitor(&type_vars_, &bound_type_vars_)
        .VisitType(t);
  }

 private:
  InsertionSet<TypeVar> type_vars_;
  InsertionSet<TypeVar> bound_type_vars_;
  const Module& mod_;
};

/*!
 * \brief Collects the variables occurring in an expression or a pattern.
 *
 * Two insertion-ordered sets are maintained: every variable recorded, and the
 * subset marked as bound (function parameters, let-bound variables and pattern
 * variables). Results accumulate across calls on the same instance.
 */
class VarVisitor : protected ExprVisitor, protected PatternVisitor {
 public:
  /*! \brief Traverse expr and return the recorded variables that were never marked bound. */
  Array<Var> Free(const Expr& expr) {
    this->VisitExpr(expr);
    Array<Var> free_vars;
    for (const auto& candidate : vars_.data) {
      const bool is_bound = bound_vars_.set.count(candidate) != 0;
      if (!is_bound) {
        free_vars.push_back(candidate);
      }
    }
    return free_vars;
  }

  /*! \brief Return the variables marked bound so far, in first-seen order. */
  Array<Var> Collect() {
    return ToArray(bound_vars_);
  }

  /*! \brief Traverse expr and return the variables marked bound. */
  Array<Var> Bound(const Expr& expr) {
    this->VisitExpr(expr);
    return Collect();
  }

  /*! \brief Traverse pat and return the variables marked bound. */
  Array<Var> Bound(const Pattern& pat) {
    this->VisitPattern(pat);
    return Collect();
  }

  /*! \brief Traverse expr and return every recorded variable, bound or not. */
  Array<Var> All(const Expr& expr) {
    this->VisitExpr(expr);
    return ToArray(vars_);
  }

  /*! \brief Record v as a bound variable; it is also added to the set of all variables. */
  void MarkBounded(const Var& v) {
    bound_vars_.Insert(v);
    vars_.Insert(v);
  }

  void VisitExpr_(const VarNode* var) final {
    // A plain occurrence: recorded, but not marked bound here.
    vars_.Insert(GetRef<Var>(var));
  }

  void VisitExpr_(const FunctionNode* op) final {
    // Parameters are bound by the function; only the body is traversed afterwards.
    for (const auto& fn_param : op->params) {
      MarkBounded(fn_param);
    }
    VisitExpr(op->body);
  }

  void VisitExpr_(const LetNode* op) final {
    // The let variable is bound; its value and the body are both traversed.
    MarkBounded(op->var);
    VisitExpr(op->value);
    VisitExpr(op->body);
  }

  void VisitPattern(const Pattern& p) final {
    // Pattern traversal is forwarded to PatternVisitor.
    PatternVisitor::VisitPattern(p);
  }

  void VisitPattern_(const PatternVarNode* op) final {
    MarkBounded(op->var);
  }

 private:
  /*! \brief Copy the ordered contents of a set into an Array. */
  static Array<Var> ToArray(const InsertionSet<Var>& source) {
    Array<Var> out;
    for (const auto& item : source.data) {
      out.push_back(item);
    }
    return out;
  }

  /*! \brief Every variable recorded so far. */
  InsertionSet<Var> vars_;
  /*! \brief Variables marked as bound. */
  InsertionSet<Var> bound_vars_;
};

/*!
 * \brief Collect the free type variables of an expression.
 * \param expr The expression to traverse.
 * \param mod The module handed to TypeVarEVisitor.
 * \return The type variables reported by TypeVarEVisitor::Free.
 */
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod) {
  TypeVarEVisitor visitor(mod);
  return visitor.Free(expr);
}

/*!
 * \brief Collect the free type variables of a type.
 * \param type The type to traverse.
 * \param mod The module handed to TypeVarEVisitor.
 * \return The type variables reported by TypeVarEVisitor::Free.
 */
tvm::Array<TypeVar> FreeTypeVars(const Type& type, const Module& mod) {
  TypeVarEVisitor visitor(mod);
  return visitor.Free(type);
}

/*!
 * \brief Collect the bound type variables of an expression.
 * \param expr The expression to traverse.
 * \param mod The module handed to TypeVarEVisitor.
 * \return The type variables reported by TypeVarEVisitor::Bound.
 */
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod) {
  TypeVarEVisitor visitor(mod);
  return visitor.Bound(expr);
}

/*!
 * \brief Collect the bound type variables of a type.
 * \param type The type to traverse.
 * \param mod The module handed to TypeVarEVisitor.
 * \return The type variables reported by TypeVarEVisitor::Bound.
 */
tvm::Array<TypeVar> BoundTypeVars(const Type& type, const Module& mod) {
  TypeVarEVisitor visitor(mod);
  return visitor.Bound(type);
}

/*!
 * \brief Collect every type variable recorded while traversing an expression.
 * \param expr The expression to traverse.
 * \param mod The module handed to TypeVarEVisitor.
 * \return The type variables reported by TypeVarEVisitor::All.
 */
tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod) {
  TypeVarEVisitor visitor(mod);
  return visitor.All(expr);
}

/*!
 * \brief Collect every type variable recorded while traversing a type.
 * \param type The type to traverse.
 * \param mod The module handed to TypeVarEVisitor.
 * \return The type variables reported by TypeVarEVisitor::All.
 */
tvm::Array<TypeVar> AllTypeVars(const Type& type, const Module& mod) {
  TypeVarEVisitor visitor(mod);
  return visitor.All(type);
}

/*!
 * \brief Collect the free variables of an expression.
 * \param expr The expression to traverse.
 * \return The variables reported by VarVisitor::Free.
 */
tvm::Array<Var> FreeVars(const Expr& expr) {
  VarVisitor visitor{};
  return visitor.Free(expr);
}

/*!
 * \brief Collect the bound variables of an expression.
 * \param expr The expression to traverse.
 * \return The variables reported by VarVisitor::Bound.
 */
tvm::Array<Var> BoundVars(const Expr& expr) {
  VarVisitor visitor{};
  return visitor.Bound(expr);
}

/*!
 * \brief Collect the variables bound by a pattern.
 * \param pat The pattern to traverse.
 * \return The variables reported by VarVisitor::Bound.
 */
tvm::Array<Var> BoundVars(const Pattern& pat) {
  VarVisitor visitor{};
  return visitor.Bound(pat);
}

/*!
 * \brief Collect every variable recorded while traversing an expression.
 * \param expr The expression to traverse.
 * \return The variables reported by VarVisitor::All.
 */
tvm::Array<Var> AllVars(const Expr& expr) {
  VarVisitor visitor{};
  return visitor.All(expr);
}

// Registers FreeVars under the API name "relay._ir_pass.free_vars".
TVM_REGISTER_API("relay._ir_pass.free_vars")
.set_body_typed(FreeVars);

// Registers "relay._ir_pass.bound_vars". The single argument is dispatched to
// the Expr overload of BoundVars when it derives from ExprNode, and is
// otherwise downcast to Pattern.
TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef node = args[0];
    const bool is_expr = node.as_derived<ExprNode>() != nullptr;
    if (!is_expr) {
      *ret = BoundVars(Downcast<Pattern>(node));
      return;
    }
    *ret = BoundVars(Downcast<Expr>(node));
  });

// Registers AllVars under the API name "relay._ir_pass.all_vars".
TVM_REGISTER_API("relay._ir_pass.all_vars")
.set_body_typed(AllVars);

// Registers "relay._ir_pass.free_type_vars". Arguments are (node, module); the
// node is dispatched to the Type overload of FreeTypeVars when it derives from
// TypeNode, and is otherwise downcast to Expr.
TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef node = args[0];
    Module mod = args[1];
    const bool is_type = node.as_derived<TypeNode>() != nullptr;
    if (!is_type) {
      *ret = FreeTypeVars(Downcast<Expr>(node), mod);
      return;
    }
    *ret = FreeTypeVars(Downcast<Type>(node), mod);
  });

// Registers "relay._ir_pass.bound_type_vars". Arguments are (node, module); the
// node is dispatched to the Type overload of BoundTypeVars when it derives from
// TypeNode, and is otherwise downcast to Expr.
TVM_REGISTER_API("relay._ir_pass.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef node = args[0];
    Module mod = args[1];
    const bool is_type = node.as_derived<TypeNode>() != nullptr;
    if (!is_type) {
      *ret = BoundTypeVars(Downcast<Expr>(node), mod);
      return;
    }
    *ret = BoundTypeVars(Downcast<Type>(node), mod);
  });

// Registers "relay._ir_pass.all_type_vars". Arguments are (node, module); the
// node is dispatched to the Type overload of AllTypeVars when it derives from
// TypeNode, and is otherwise downcast to Expr.
TVM_REGISTER_API("relay._ir_pass.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef node = args[0];
    Module mod = args[1];
    const bool is_type = node.as_derived<TypeNode>() != nullptr;
    if (!is_type) {
      *ret = AllTypeVars(Downcast<Expr>(node), mod);
      return;
    }
    *ret = AllTypeVars(Downcast<Type>(node), mod);
  });

/*!
 * \brief Get reference counter of each internal ExprNode in body.
 * \param body The body expression.
 * \return The reference count mapping.
 */
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
  class ExprRefCounter : private ExprVisitor {
   public:
    std::unordered_map<const Node*, size_t>
    Get(const Expr& body) {
      this->VisitExpr(body);
      return std::move(this->visit_counter_);
    }
  };
  return ExprRefCounter().Get(body);
}

/*!
 * \brief Check whether every element of a tensor is greater than or equal to value.
 *
 * The tensor must be on the CPU, have no explicit strides and a zero byte
 * offset; these are enforced with CHECKs. The data pointer is reinterpreted as
 * an array of T, so T is expected to match the tensor's element type.
 *
 * An element fails the check only when `element < value` holds, so an element
 * that compares unordered (such as a floating-point NaN) does not make the
 * result false.
 *
 * \tparam T The element type used to read the tensor data.
 * \param tensor The tensor to inspect.
 * \param value The lower bound, inclusive.
 * \return false as soon as an element below value is found, true otherwise
 *         (including when the tensor has no elements).
 */
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
  CHECK(tensor->strides == nullptr);
  CHECK_EQ(tensor->byte_offset, 0);

  // Total element count is the product of all dimensions.
  int64_t total = 1;
  for (int axis = 0; axis < tensor->ndim; ++axis) {
    total *= tensor->shape[axis];
  }

  const T* elems = static_cast<const T*>(tensor->data);
  for (int64_t idx = 0; idx < total; ++idx) {
    if (elems[idx] < value) {
      return false;
    }
  }
  return true;
}

/*!
 * \brief Check whether expr is a constant whose elements are all non-negative.
 *
 * Despite the name, zero is accepted: every element is compared against 0 with
 * IsNDArrayAllGreaterEqual. Only single-lane float32, float64, int8, int32,
 * uint8 and uint32 constants are inspected; any other dtype yields false.
 * A call to expand_dims, reshape, transpose or squeeze is looked through by
 * checking its first argument instead.
 *
 * \param expr The expression to check.
 * \return true when the check succeeds, false otherwise (including for
 *         expressions that are neither a constant nor one of the calls above).
 */
bool IsAllPositiveConstant(const Expr& expr) {
  // peel through a few common transform ops.
  static const auto& expand_dims = Op::Get("expand_dims");
  static const auto& reshape = Op::Get("reshape");
  static const auto& transpose = Op::Get("transpose");
  static const auto& squeeze = Op::Get("squeeze");

  if (const auto* constant = expr.as<ConstantNode>()) {
    const auto& tensor = constant->data;
    const auto& dtype = tensor->dtype;
    if (dtype.lanes != 1) {
      return false;
    }
    // Pick the element type matching (code, bits); unsupported pairs are rejected.
    switch (dtype.code) {
      case kDLFloat:
        if (dtype.bits == 32) return IsNDArrayAllGreaterEqual<float>(tensor, 0);
        if (dtype.bits == 64) return IsNDArrayAllGreaterEqual<double>(tensor, 0);
        return false;
      case kDLInt:
        if (dtype.bits == 8) return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
        if (dtype.bits == 32) return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
        return false;
      case kDLUInt:
        if (dtype.bits == 8) return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
        if (dtype.bits == 32) return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
        return false;
      default:
        return false;
    }
  }

  const auto* call = expr.as<CallNode>();
  if (call == nullptr) {
    return false;
  }
  const bool is_peelable =
      call->op.same_as(expand_dims) ||
      call->op.same_as(reshape) ||
      call->op.same_as(transpose) ||
      call->op.same_as(squeeze);
  // tail recursion.
  return is_peelable && IsAllPositiveConstant(call->args[0]);
}

/*!
 * \brief Substitute a single type variable inside a type.
 * \param type The type to rewrite.
 * \param tvar The type variable to replace.
 * \param subst The type that replaces tvar.
 * \return The result of the map-based TypeSubst with the one-entry map {tvar: subst}.
 */
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
  tvm::Map<TypeVar, Type> single_binding({{tvar, subst}});
  return TypeSubst(type, single_binding);
}

/*!
 * \brief Substitute a single type variable inside an expression.
 * \param expr The expression to rewrite.
 * \param tvar The type variable to replace.
 * \param subst The type that replaces tvar.
 * \return The result of the map-based TypeSubst with the one-entry map {tvar: subst}.
 */
Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
  tvm::Map<TypeVar, Type> single_binding({{tvar, subst}});
  return TypeSubst(expr, single_binding);
}

/*!
 * \brief Substitute type variables inside a type according to a map.
 * \param type The type to rewrite.
 * \param subst_map The mapping from type variable to replacement type.
 * \return The result of Bind(type, subst_map).
 */
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
  Type substituted = Bind(type, subst_map);
  return substituted;
}

/*!
 * \brief Substitute type variables inside an expression according to a map.
 *
 * The expression is rewritten by a mutator that applies the type-level
 * TypeSubst to every type it visits.
 *
 * \param expr The expression to rewrite.
 * \param subst_map The mapping from type variable to replacement type.
 * \return The rewritten expression.
 */
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
  // The mutator only borrows the map; it does not outlive this call.
  class TypeRewriter : public ExprMutator, public PatternMutator {
   public:
    explicit TypeRewriter(const tvm::Map<TypeVar, Type>& bindings) : bindings_(bindings) { }
    Type VisitType(const Type& t) final {
      return TypeSubst(t, bindings_);
    }
    Var VisitVar(const Var& v) final {
      // Variables are rewritten through the expression path and cast back to Var.
      return Downcast<Var>(this->VisitExpr(v));
    }
   private:
    const tvm::Map<TypeVar, Type>& bindings_;
  };
  TypeRewriter rewriter(subst_map);
  return rewriter.VisitExpr(expr);
}

}  // namespace relay
}  // namespace tvm