/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file verify_gpu_code.cc
 * \brief Verify the correctness of a GPU IR.
 *        It checks whether the amount of memory used or the number of threads
 *        in a block exceeds the given limits.
 */

#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>

namespace tvm {
namespace ir {

/*!
 * \brief IR visitor that checks a GPU statement against per-block resource limits.
 *
 *  Each top-level ProducerConsumer (nest_level_ == 0) is treated as one kernel:
 *  the statistics are reset when it is entered and compared against the limits
 *  when it is left. Per-axis thread extents (threadIdx.x/y/z) are checked as
 *  soon as the axis is first bound, and later bindings of the same axis must
 *  use the same extent.
 */
class GPUCodeVerifier : public IRVisitor {
 public:
  /*!
   * \brief Verify a statement against the given limits.
   * \param stmt The statement to check.
   * \param max_local_memory_per_block Limit, in bytes, on "local"-scope allocations per kernel.
   * \param max_shared_memory_per_block Limit, in bytes, on "shared"-scope allocations per kernel.
   * \param max_threads_per_block Limit on the product of threadIdx.x/y/z extents.
   * \param max_thread_x Limit on the threadIdx.x extent.
   * \param max_thread_y Limit on the threadIdx.y extent.
   * \param max_thread_z Limit on the threadIdx.z extent.
   * \return true if no limit is exceeded.
   * \note All limits are converted to size_t; a negative limit therefore wraps
   *       to a very large value and effectively disables that check.
   */
  bool Verify(tvm::Stmt stmt,
              int64_t max_local_memory_per_block,
              int64_t max_shared_memory_per_block,
              int64_t max_threads_per_block,
              int64_t max_thread_x,
              int64_t max_thread_y,
              int64_t max_thread_z) {
    max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
    max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
    max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
    max_thread_x_ = static_cast<size_t>(max_thread_x);
    max_thread_y_ = static_cast<size_t>(max_thread_y);
    max_thread_z_ = static_cast<size_t>(max_thread_z);

    // Start every verification from a clean state so that the verifier can be
    // reused: a failure from a previous call must not leak into this one.
    valid_ = true;
    nest_level_ = 0;
    Reset_();

    this->Visit(stmt);

    return valid_;
  }

  /*!
   * \brief Kernel boundary handling.
   *
   *  Only producers increase the nesting depth, so a ProducerConsumer reached at
   *  depth 0 starts a new kernel (statistics are reset) and, once visited, ends
   *  it (accumulated statistics are checked).
   */
  void Visit_(const ProducerConsumer *op) {
    if (nest_level_ == 0) {
      // enter a new kernel, reset statistics
      Reset_();
    }

    if (op->is_producer) {
      nest_level_++;
      IRVisitor::Visit_(op);
      nest_level_--;
    } else {
      IRVisitor::Visit_(op);
    }

    if (nest_level_ == 0) {
      // exit a kernel, check the validity
      valid_ &= thread_per_block_ <= max_threads_per_block_;

      valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
      valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
    }
  }

  /*!
   * \brief Accumulate the byte size of allocations whose buffer was tagged with
   *        a "local" or "shared" storage scope.
   *
   *  The size comes from constant_allocation_size(); non-constant allocations
   *  are presumably reported as 0 by that helper and would then not be counted
   *  (NOTE(review): confirm against Allocate::constant_allocation_size).
   */
  void Visit_(const Allocate *op) {
    IRVisitor::Visit_(op);
    // visit an allocation of a buffer in local or shared memory, record its size
    if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      local_memory_per_block_ += size * op->type.bytes() * op->type.lanes();
    } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      shared_memory_per_block_ += size * op->type.bytes() * op->type.lanes();
    }
  }

  /*!
   * \brief Record storage scopes of buffers and thread-axis extents.
   *
   *  storage_scope: remembers which buffer variables live in "local" or
   *  "shared" scope, so that later Allocate nodes can be attributed.
   *  thread_extent: multiplies the first extent seen for each of
   *  threadIdx.x/y/z into thread_per_block_ and checks the per-axis limit;
   *  later bindings of an already-seen axis must use the same extent.
   */
  void Visit_(const AttrStmt *op) {
    if (op->attr_key == attr::storage_scope) {
      const auto *scope = op->value.as<StringImm>();
      CHECK(scope) << "storage_scope attribute must carry a string value";
      const std::string &op_value = scope->value;
      if (op_value == "local") {
        visited_local_buffers_.insert(op->node.as<tvm::Variable>());
      } else if (op_value == "shared") {
        visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
      }
    } else if (op->attr_key == attr::thread_extent) {
      const auto *iv = op->node.as<tvm::IterVarNode>();
      CHECK(iv) << "thread_extent attribute must be attached to an IterVar";
      VarExpr var = iv->var;
      const auto *extent = op->value.as<IntImm>();
      CHECK(extent);

      // record the number of threads in a block
      std::string name = var.get()->name_hint;
      if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
        size_t length = static_cast<size_t>(extent->value);
        if (!visited_threads_.count(name)) {
          visited_threads_.insert(name);
          thread_per_block_ *= length;

          if (name == "threadIdx.x") {
            valid_ &= length <= max_thread_x_;
            thread_x_extent_ = length;
          } else if (name == "threadIdx.y") {
            valid_ &= length <= max_thread_y_;
            thread_y_extent_ = length;
          } else if (name == "threadIdx.z") {
            valid_ &= length <= max_thread_z_;
            thread_z_extent_ = length;
          }
        } else {
          // the thread should be bound to axes with the same length
          if (name == "threadIdx.x") {
            valid_ &= length == thread_x_extent_;
          } else if (name == "threadIdx.y") {
            valid_ &= length == thread_y_extent_;
          } else if (name == "threadIdx.z") {
            valid_ &= length == thread_z_extent_;
          }
        }
      }
    }
    IRVisitor::Visit_(op);
  }

 private:
  // Depth of nested producers; 0 means "between kernels".
  int nest_level_{0};

  // Buffer variables tagged with the "local" / "shared" storage scope.
  std::unordered_set<const tvm::Variable *> visited_local_buffers_;
  std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
  // Thread-axis names already bound in the current kernel.
  std::unordered_set<std::string> visited_threads_;

  // Extent of the first binding of each thread axis in the current kernel.
  size_t thread_x_extent_{0};
  size_t thread_y_extent_{0};
  size_t thread_z_extent_{0};

  // Statistics accumulated for the current kernel.
  size_t local_memory_per_block_{0};
  size_t shared_memory_per_block_{0};
  size_t thread_per_block_{1};

  // Limits supplied to Verify().
  size_t max_local_memory_per_block_{0};
  size_t max_shared_memory_per_block_{0};
  size_t max_threads_per_block_{0};
  size_t max_thread_x_{0};
  size_t max_thread_y_{0};
  size_t max_thread_z_{0};

  bool valid_{true};

  // Clear per-kernel statistics (buffer scopes, memory totals, thread axes).
  void Reset_() {
    visited_local_buffers_.clear();
    visited_shared_buffers_.clear();
    local_memory_per_block_ = 0;
    shared_memory_per_block_ = 0;

    visited_threads_.clear();
    thread_per_block_ = 1;
  }
};

/*!
 * \brief Check whether a GPU statement respects the given resource limits.
 * \param stmt The statement to verify.
 * \param constraints Map from limit name to an integer constant. Recognized keys:
 *        "max_local_memory_per_block", "max_shared_memory_per_block",
 *        "max_threads_per_block", "max_thread_x", "max_thread_y", "max_thread_z".
 *        Missing keys default to INT64_MAX (no limit). Any other key is a fatal error.
 * \return true if no limit is exceeded.
 */
bool VerifyGPUCode(Stmt stmt,
                   Map<std::string, Expr> constraints) {
  GPUCodeVerifier verifier;

  int64_t max_local_memory_per_block = INT64_MAX;
  int64_t max_shared_memory_per_block = INT64_MAX;
  int64_t max_threads_per_block = INT64_MAX;
  int64_t max_thread_x = INT64_MAX;
  int64_t max_thread_y = INT64_MAX;
  int64_t max_thread_z = INT64_MAX;

  for (const auto& iter : constraints) {
    const IntImm* val = iter.second.as<IntImm>();
    // Every constraint must be an integer constant; report a clear error
    // instead of dereferencing a null pointer.
    CHECK(val != nullptr) << "Constraint " << iter.first
                          << " must be an integer constant";
    if (iter.first == "max_local_memory_per_block")
      max_local_memory_per_block = val->value;
    else if (iter.first == "max_shared_memory_per_block")
      max_shared_memory_per_block = val->value;
    else if (iter.first == "max_threads_per_block")
      max_threads_per_block = val->value;
    else if (iter.first == "max_thread_x")
      max_thread_x = val->value;
    else if (iter.first == "max_thread_y")
      max_thread_y = val->value;
    else if (iter.first == "max_thread_z")
      max_thread_z = val->value;
    else
      LOG(FATAL) << "Invalid check item: " << iter.first;
  }

  return verifier.Verify(stmt,
                         max_local_memory_per_block,
                         max_shared_memory_per_block,
                         max_threads_per_block,
                         max_thread_x,
                         max_thread_y,
                         max_thread_z);
}

}  // namespace ir
}  // namespace tvm