/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file base64.h
 * \brief data stream support to input and output from/to base64 stream
 *   base64 is easier to store and pass as text format in mapreduce
 */
#ifndef TVM_SUPPORT_BASE64_H_
#define TVM_SUPPORT_BASE64_H_

#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <cctype>
#include <cstdio>
#include <string>

namespace tvm {
namespace support {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
// decoding table
const char DecodeTable[] = {
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  62,  // '+'
  0, 0, 0,
  63,  // '/'
  52, 53, 54, 55, 56, 57, 58, 59, 60, 61,  // '0'-'9'
  0, 0, 0, 0, 0, 0, 0,
  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  // 'A'-'Z'
  0, 0, 0, 0, 0, 0,
  26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
  39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,  // 'a'-'z'
};
// encoding table
static const char EncodeTable[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
}  // namespace base64

/*!
 * \brief Buffer reader from stream to avoid
 *  virtual call overhead on each read.
 */
class StreamBufferReader {
 public:
  explicit StreamBufferReader(size_t buffer_size) {
    buffer_.resize(buffer_size);
  }
  /*!
   * \brief set input stream
   * \param stream The stream to be set
   */
  void set_stream(dmlc::Stream *stream) {
    stream_ = stream;
    read_len_ = read_ptr_ = 1;
  }
  /*!
   * \return allows quick read using get char
   */
  int GetChar() {
    while (true) {
      if (read_ptr_ < read_len_) {
        return static_cast<int>(buffer_[read_ptr_++]);
      } else {
        read_len_ = stream_->Read(&buffer_[0], buffer_.length());
        if (read_len_ == 0) return EOF;
        read_ptr_ = 0;
      }
    }
  }
  /*! \return whether we are reaching the end of file */
  bool AtEnd() const {
    return read_len_ == 0;
  }

 private:
  /*! \brief the underlying stream */
  dmlc::Stream *stream_{nullptr};
  /*! \brief buffer to hold data */
  std::string buffer_;
  /*! \brief length of valid data in buffer */
  size_t read_len_{1};
  /*! \brief pointer in the buffer */
  size_t read_ptr_{1};
};

/*!
 * \brief Input stream from base64 encoding
 */
class Base64InStream: public dmlc::Stream {
 public:
  explicit Base64InStream(dmlc::Stream *fs) : reader_(256) {
    reader_.set_stream(fs);
  }
  /*!
   * \brief initialize the stream position to beginning of next base64 stream
   * \note call this function before actually start read
   */
  void InitPosition(void) {
    // get a character
    do {
      temp_ch_ = reader_.GetChar();
    } while (isspace(temp_ch_));
  }
  /*! \brief whether current position is end of a base64 stream */
  bool IsEOF(void) const {
    return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_));
  }
  // override read function.
  virtual size_t Read(void *ptr, size_t size) {
    using base64::DecodeTable;
    if (size == 0) return 0;
    // use tlen to record left size
    size_t tlen = size;
    unsigned char *cptr = static_cast<unsigned char*>(ptr);
    // if anything left, load from previous buffered result
    if (num_prev_ != 0) {
      if (num_prev_ == 2) {
        if (tlen >= 2) {
          *cptr++ = buf_prev[0];
          *cptr++ = buf_prev[1];
          tlen -= 2;
          num_prev_ = 0;
        } else {
          // assert tlen == 1
          *cptr++ = buf_prev[0]; --tlen;
          buf_prev[0] = buf_prev[1];
          num_prev_ = 1;
        }
      } else {
        // assert num_prev_ == 1
        *cptr++ = buf_prev[0]; --tlen; num_prev_ = 0;
      }
    }
    if (tlen == 0) return size;
    int nvalue;
    // note: everything goes with 4 bytes in Base64
    // so we process 4 bytes a unit
    while (tlen && temp_ch_ != EOF && !isspace(temp_ch_)) {
      // first byte
      nvalue = DecodeTable[temp_ch_] << 18;
      {
        // second byte
        temp_ch_ = reader_.GetChar();
        CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
        nvalue |= DecodeTable[temp_ch_] << 12;
        *cptr++ = (nvalue >> 16) & 0xFF; --tlen;
        }
      {
        // third byte
        temp_ch_ = reader_.GetChar();
        CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
        // handle termination
        if (temp_ch_ == '=') {
          temp_ch_ = reader_.GetChar();
          CHECK(temp_ch_ == '=') << "invalid base64 format";
          temp_ch_ = reader_.GetChar();
          CHECK(temp_ch_ == EOF || isspace(temp_ch_))
              << "invalid base64 format";
          break;
        }
        nvalue |= DecodeTable[temp_ch_] << 6;
        if (tlen) {
          *cptr++ = (nvalue >> 8) & 0xFF; --tlen;
        } else {
          buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF;
        }
      }
      {
        // fourth byte
        temp_ch_ = reader_.GetChar();
        CHECK(temp_ch_ != EOF && !isspace(temp_ch_))
            << "invalid base64 format";
        if (temp_ch_ == '=') {
          temp_ch_ = reader_.GetChar();
          CHECK(temp_ch_ == EOF || isspace(temp_ch_))
              << "invalid base64 format";
          break;
        }
        nvalue |= DecodeTable[temp_ch_];
        if (tlen) {
          *cptr++ = nvalue & 0xFF; --tlen;
        } else {
          buf_prev[num_prev_ ++] = nvalue & 0xFF;
        }
      }
      // get next char
      temp_ch_ = reader_.GetChar();
    }
    if (kStrictCheck) {
      CHECK_EQ(tlen, 0) << "Base64InStream: read incomplete";
    }
    return size - tlen;
  }
  virtual void Write(const void *ptr, size_t size) {
    LOG(FATAL) << "Base64InStream do not support write";
  }

 private:
  // internal reader
  StreamBufferReader reader_;
  int temp_ch_{0};
  int num_prev_{0};
  unsigned char buf_prev[2];
  // whether we need to do strict check
  static const bool kStrictCheck = false;
};

/*!
 * \brief Stream to write to base64 format.
 */
class Base64OutStream: public dmlc::Stream {
 public:
  explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) {
  }
  virtual void Write(const void *ptr, size_t size) {
    using base64::EncodeTable;
    size_t tlen = size;
    const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
    while (tlen) {
      while (buf__top_ < 3  && tlen != 0) {
        buf_[++buf__top_] = *cptr++; --tlen;
      }
      if (buf__top_ == 3) {
        // flush 4 bytes out
        PutChar(EncodeTable[buf_[1] >> 2]);
        PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]);
        PutChar(EncodeTable[((buf_[2] << 2) | (buf_[3] >> 6)) & 0x3F]);
        PutChar(EncodeTable[buf_[3] & 0x3F]);
        buf__top_ = 0;
      }
    }
  }
  virtual size_t Read(void *ptr, size_t size) {
    LOG(FATAL) << "Base64OutStream do not support read";
    return 0;
  }
  /*!
   * \brief finish writing of all current base64 stream, do some post processing
   * \param endch character to put to end of stream, if it is EOF, then nothing will be appended.
   */
  void Finish(int endch = EOF) {
    using base64::EncodeTable;
    if (buf__top_ == 1) {
      PutChar(EncodeTable[buf_[1] >> 2]);
      PutChar(EncodeTable[(buf_[1] << 4) & 0x3F]);
      PutChar('=');
      PutChar('=');
    }
    if (buf__top_ == 2) {
      PutChar(EncodeTable[buf_[1] >> 2]);
      PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]);
      PutChar(EncodeTable[(buf_[2] << 2) & 0x3F]);
      PutChar('=');
    }
    buf__top_ = 0;
    if (endch != EOF) PutChar(endch);
    this->Flush();
  }

 private:
  static constexpr size_t kBufferSize = 256;

  dmlc::Stream *fp_{nullptr};
  int buf__top_{0};
  unsigned char buf_[4];
  std::string out_buf_;


  void PutChar(char ch) {
    out_buf_ += ch;
    if (out_buf_.length() >= kBufferSize) Flush();
  }
  void Flush(void) {
    if (out_buf_.length() != 0) {
      fp_->Write(&out_buf_[0], out_buf_.length());
      out_buf_.clear();
    }
  }
};
}  // namespace support
}  // namespace tvm
#endif  // TVM_SUPPORT_BASE64_H_