/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * * \file base64.h * \brief data stream support to input and output from/to base64 stream * base64 is easier to store and pass as text format in mapreduce */ #ifndef TVM_SUPPORT_BASE64_H_ #define TVM_SUPPORT_BASE64_H_ #include <dmlc/logging.h> #include <dmlc/logging.h> #include <cctype> #include <cstdio> #include <string> namespace tvm { namespace support { /*! \brief namespace of base64 decoding and encoding table */ namespace base64 { // decoding table const char DecodeTable[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 62, // '+' 0, 0, 0, 63, // '/' 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' }; // encoding table static const char EncodeTable[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; } // namespace base64 /*! * \brief Buffer reader from stream to avoid * virtual call overhead on each read. */ class StreamBufferReader { public: explicit StreamBufferReader(size_t buffer_size) { buffer_.resize(buffer_size); } /*! * \brief set input stream * \param stream The stream to be set */ void set_stream(dmlc::Stream *stream) { stream_ = stream; read_len_ = read_ptr_ = 1; } /*! * \return allows quick read using get char */ int GetChar() { while (true) { if (read_ptr_ < read_len_) { return static_cast<int>(buffer_[read_ptr_++]); } else { read_len_ = stream_->Read(&buffer_[0], buffer_.length()); if (read_len_ == 0) return EOF; read_ptr_ = 0; } } } /*! \return whether we are reaching the end of file */ bool AtEnd() const { return read_len_ == 0; } private: /*! \brief the underlying stream */ dmlc::Stream *stream_{nullptr}; /*! \brief buffer to hold data */ std::string buffer_; /*! \brief length of valid data in buffer */ size_t read_len_{1}; /*! \brief pointer in the buffer */ size_t read_ptr_{1}; }; /*! * \brief Input stream from base64 encoding */ class Base64InStream: public dmlc::Stream { public: explicit Base64InStream(dmlc::Stream *fs) : reader_(256) { reader_.set_stream(fs); } /*! * \brief initialize the stream position to beginning of next base64 stream * \note call this function before actually start read */ void InitPosition(void) { // get a character do { temp_ch_ = reader_.GetChar(); } while (isspace(temp_ch_)); } /*! \brief whether current position is end of a base64 stream */ bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); } // override read function. virtual size_t Read(void *ptr, size_t size) { using base64::DecodeTable; if (size == 0) return 0; // use tlen to record left size size_t tlen = size; unsigned char *cptr = static_cast<unsigned char*>(ptr); // if anything left, load from previous buffered result if (num_prev_ != 0) { if (num_prev_ == 2) { if (tlen >= 2) { *cptr++ = buf_prev[0]; *cptr++ = buf_prev[1]; tlen -= 2; num_prev_ = 0; } else { // assert tlen == 1 *cptr++ = buf_prev[0]; --tlen; buf_prev[0] = buf_prev[1]; num_prev_ = 1; } } else { // assert num_prev_ == 1 *cptr++ = buf_prev[0]; --tlen; num_prev_ = 0; } } if (tlen == 0) return size; int nvalue; // note: everything goes with 4 bytes in Base64 // so we process 4 bytes a unit while (tlen && temp_ch_ != EOF && !isspace(temp_ch_)) { // first byte nvalue = DecodeTable[temp_ch_] << 18; { // second byte temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; nvalue |= DecodeTable[temp_ch_] << 12; *cptr++ = (nvalue >> 16) & 0xFF; --tlen; } { // third byte temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; // handle termination if (temp_ch_ == '=') { temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ == '=') << "invalid base64 format"; temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_] << 6; if (tlen) { *cptr++ = (nvalue >> 8) & 0xFF; --tlen; } else { buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF; } } { // fourth byte temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; if (temp_ch_ == '=') { temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_]; if (tlen) { *cptr++ = nvalue & 0xFF; --tlen; } else { buf_prev[num_prev_ ++] = nvalue & 0xFF; } } // get next char temp_ch_ = reader_.GetChar(); } if (kStrictCheck) { CHECK_EQ(tlen, 0) << "Base64InStream: read incomplete"; } return size - tlen; } virtual void Write(const void *ptr, size_t size) { LOG(FATAL) << "Base64InStream do not support write"; } private: // internal reader StreamBufferReader reader_; int temp_ch_{0}; int num_prev_{0}; unsigned char buf_prev[2]; // whether we need to do strict check static const bool kStrictCheck = false; }; /*! * \brief Stream to write to base64 format. */ class Base64OutStream: public dmlc::Stream { public: explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) { } virtual void Write(const void *ptr, size_t size) { using base64::EncodeTable; size_t tlen = size; const unsigned char *cptr = static_cast<const unsigned char*>(ptr); while (tlen) { while (buf__top_ < 3 && tlen != 0) { buf_[++buf__top_] = *cptr++; --tlen; } if (buf__top_ == 3) { // flush 4 bytes out PutChar(EncodeTable[buf_[1] >> 2]); PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]); PutChar(EncodeTable[((buf_[2] << 2) | (buf_[3] >> 6)) & 0x3F]); PutChar(EncodeTable[buf_[3] & 0x3F]); buf__top_ = 0; } } } virtual size_t Read(void *ptr, size_t size) { LOG(FATAL) << "Base64OutStream do not support read"; return 0; } /*! * \brief finish writing of all current base64 stream, do some post processing * \param endch character to put to end of stream, if it is EOF, then nothing will be appended. */ void Finish(int endch = EOF) { using base64::EncodeTable; if (buf__top_ == 1) { PutChar(EncodeTable[buf_[1] >> 2]); PutChar(EncodeTable[(buf_[1] << 4) & 0x3F]); PutChar('='); PutChar('='); } if (buf__top_ == 2) { PutChar(EncodeTable[buf_[1] >> 2]); PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]); PutChar(EncodeTable[(buf_[2] << 2) & 0x3F]); PutChar('='); } buf__top_ = 0; if (endch != EOF) PutChar(endch); this->Flush(); } private: static constexpr size_t kBufferSize = 256; dmlc::Stream *fp_{nullptr}; int buf__top_{0}; unsigned char buf_[4]; std::string out_buf_; void PutChar(char ch) { out_buf_ += ch; if (out_buf_.length() >= kBufferSize) Flush(); } void Flush(void) { if (out_buf_.length() != 0) { fp_->Write(&out_buf_[0], out_buf_.length()); out_buf_.clear(); } } }; } // namespace support } // namespace tvm #endif // TVM_SUPPORT_BASE64_H_