/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file random/mt_random_engine.cc
 * \brief mt19937 random engine
 */
#include <dmlc/logging.h>
#include <algorithm>
#include <ctime>
#include <random>

namespace tvm {
namespace contrib {

/*!
 * \brief An interface for generating [tensors of] random numbers.
 */
class RandomEngine {
 public:
   /*!
    * \brief Creates a RandomEngine using a default seed.
    */
  RandomEngine() {
    this->Seed(time(0));
  }

   /*!
    * \brief Creates a RandomEngine, suggesting the use of a provided seed.
    */
  explicit RandomEngine(unsigned seed) {
    this->Seed(seed);
  }

   /*!
    * \brief Seeds the underlying RNG, if possible.
    */
  inline void Seed(unsigned seed) {
    rnd_engine_.seed(seed);
    this->rseed_ = static_cast<unsigned>(seed);
  }

   /*!
    * \return the seed associated with the underlying RNG.
    */
  inline unsigned GetSeed() const {
    return rseed_;
  }

   /*!
    * \return a random integer sampled from the RNG.
    */
  inline unsigned GetRandInt() {
    return rnd_engine_();
  }

   /*!
    * \brief Fills a tensor with values drawn from Unif(low, high)
    */
  void SampleUniform(DLTensor* data, float low, float high) {
    CHECK_GT(high, low) << "high must be bigger than low";
    CHECK(data->strides == nullptr);

    DLDataType dtype = data->dtype;
    int64_t size = 1;
    for (int i = 0; i < data->ndim; ++i) {
      size *= data->shape[i];
    }

    CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);

    if (data->ctx.device_type == kDLCPU) {
      std::uniform_real_distribution<float> uniform_dist(low, high);
      std::generate_n(static_cast<float*>(data->data), size, [&] () {
        return uniform_dist(rnd_engine_);
      });
    } else {
      LOG(FATAL) << "Do not support random.uniform on this device yet";
    }
  }

   /*!
    * \brief Fills a tensor with values drawn from Normal(loc, scale**2)
    */
  void SampleNormal(DLTensor* data, float loc, float scale) {
    CHECK_GT(scale, 0) << "standard deviation must be positive";
    CHECK(data->strides == nullptr);

    DLDataType dtype = data->dtype;
    int64_t size = 1;
    for (int i = 0; i < data->ndim; ++i) {
      size *= data->shape[i];
    }

    CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);

    if (data->ctx.device_type == kDLCPU) {
      std::normal_distribution<float> normal_dist(loc, scale);
      std::generate_n(static_cast<float*>(data->data), size, [&] () {
        return normal_dist(rnd_engine_);
      });
    } else {
      LOG(FATAL) << "Do not support random.normal on this device yet";
    }
  }

 private:
  std::mt19937 rnd_engine_;
  unsigned rseed_;
};

}  // namespace contrib
}  // namespace tvm