/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/object.h>

#include <fstream>
#include <sstream>

#include "codegen_c.h"

namespace tvm {
namespace relay {
namespace contrib {

/*!
 * \brief An example codegen that is only used for quick prototyping and
 * testing purposes. Only a few binary operators are covered. Users may
 * need to extend it to cover more operators.
 */
class CodegenC : public ExprVisitor, public CodegenCBase {
 public:
  explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }

  void VisitExpr_(const VarNode* node) final {
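    // A variable becomes an argument of the generated C function: record it
    // and expose its name as the current output.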
    ext_func_args_.push_back(GetRef<Var>(node));
    out_.clear();
    Output output;
    output.name = node->name_hint();
    out_.push_back(output);
  }

  void VisitExpr_(const ConstantNode* cn) final {
    Constant constant = GetRef<Constant>(cn);
    if (visited_.count(constant)) {
      // Note this is for demonstration purposes. ConstantNode doesn't necessarily
      // belong to calls. We need to revisit this when tuples come into play.
      out_.push_back(visited_[constant]);
      return;
    }

    std::ostringstream decl_stream;
    std::ostringstream buf_stream;

    out_.clear();
    Output output;
    output.name = "const_" + std::to_string(const_idx_++);
    out_.push_back(output);
    visited_[constant] = output;

    runtime::NDArray array = cn->data;
    const auto& shape = array.Shape();
    const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;

    // Get the number of elements.
    int64_t num_elems = 1;
    for (auto i : shape) num_elems *= i;

    const auto* type_node = cn->checked_type().as<TensorTypeNode>();
    CHECK(type_node);
    const auto& dtype = GetDtypeString(type_node);
    // Define a const buffer: float const_0[64] = {1.0, 2.0, ...};
    //
    // Technically, you may need: static float* const_0 = (float*)malloc(4 * 64)
    // to avoid possible stack overflow.
    buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
    if (dtype == "float") {
      float* p_flt = static_cast<float*>(dl_tensor.data);
      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
      if (num_elems) buf_stream << p_flt[num_elems - 1];
    } else if (dtype == "int") {
      int* p_int = static_cast<int*>(dl_tensor.data);
      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_int[i] << ", ";
      if (num_elems) buf_stream << p_int[num_elems - 1];
    } else {
      LOG(FATAL) << "Only float and int are supported for now.";
    }
    buf_stream << "};";
    ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
  }

  void VisitExpr_(const CallNode* call) final {
    std::ostringstream macro_stream;
    std::ostringstream decl_stream;
    std::ostringstream buf_stream;

    std::string func_name = ext_func_id_ + "_" + std::to_string(func_idx++);

    // Make function declaration
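    // Note: the ops handled below are all binary (two arguments), so this
    // resolves to the CSOURCE_BINARY_OP_2D macro; the example therefore
    // assumes 2-D input tensors.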
    macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", ";

    if (IsOp(call, "add")) {
      macro_stream << "+";
    } else if (IsOp(call, "subtract")) {
      macro_stream << "-";
    } else if (IsOp(call, "multiply")) {
      macro_stream << "*";
    } else {
      LOG(FATAL) << "Unrecognized op";
    }

    auto in_shape = GetShape(call->args[0]->checked_type());
    for (size_t i = 0; i < in_shape.size(); ++i) {
      macro_stream << ", " << in_shape[i];
    }

    const auto* type_node = call->checked_type().as<TensorTypeNode>();
    CHECK(type_node);
    const auto& dtype = GetDtypeString(type_node);
    macro_stream << ", " << dtype;

    macro_stream << ");";
    func_decl_.push_back(macro_stream.str());

    // Make function call when visiting arguments
    bool first = true;
    decl_stream << func_name << "(";
    for (size_t i = 0; i < call->args.size(); ++i) {
      VisitExpr(call->args[i]);
      for (auto out : out_) {
        if (!first) {
          decl_stream << ", ";
        }
        first = false;
        decl_stream << out.name;
      }
    }

    std::string out = "buf_" + std::to_string(buf_idx_++);
    auto out_shape = GetShape(call->checked_type());
    int out_size = 1;
    for (size_t i = 0; i < out_shape.size(); ++i) {
      out_size *= out_shape[i];
    }
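    // Allocate the output buffer. The hardcoded 4 assumes 4-byte (32-bit)
    // elements, matching the float/int dtypes handled by this example.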
    buf_stream << dtype << "* " << out << " = (" << dtype << "*)std::malloc(4 * " << out_size
               << ");";
    buf_decl_.push_back(buf_stream.str());

    decl_stream << ", " << out << ");";
    ext_func_body.push_back(decl_stream.str());

    // Update output buffer
    out_.clear();
    Output output;
    output.name = out;
    output.dtype = dtype;
    output.need_copy = true;
    output.size = out_size;
    out_.push_back(output);
  }
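
  // As an illustrative sketch (not verbatim output), visiting add(x, y) over
  // 2x2 float tensors appends roughly:
  //   CSOURCE_BINARY_OP_2D(<id>_0, +, 2, 2, float);  // to func_decl_
  //   float* buf_0 = (float*)std::malloc(4 * 4);     // to buf_decl_
  //   <id>_0(x, y, buf_0);                           // to ext_func_body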

  /*!
   * \brief Emit the source code that invokes C compiler compatible wrappers.
   *
   * \return The emitted code.
   */
  std::string JIT() {
    // Write function macros
    for (auto decl : func_decl_) {
      code_stream_ << decl << "\n";
    }
    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
  }

 private:
  /*! \brief The function id that represents a C source function. */
  std::string ext_func_id_ = "";
  /*! \brief The index of a wrapped C function. */
  int func_idx = 0;
  /*! \brief The index of allocated buffers. */
  int buf_idx_ = 0;
  /*! \brief The index of global constants. */
  int const_idx_ = 0;
  /*! \brief The arguments of a C compiler compatible function. */
  Array<Var> ext_func_args_;
  /*! \brief The statements of a C compiler compatible function. */
  std::vector<std::string> ext_func_body;
  /*! \brief The declaration statements of a C compiler compatible function. */
  std::vector<std::string> func_decl_;
  /*! \brief The declaration statements of buffers. */
  std::vector<std::string> buf_decl_;
  /*! \brief The name and index pairs for output. */
  std::vector<Output> out_;
  /*! \brief The cached expressions. */
  std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
};

class CSourceCodegen : public CSourceModuleCodegenBase {
 public:
  void GenCFunc(const Function& func) {
    CHECK(func.defined()) << "Input error: expect a Relay function.";

    // Record the external symbol for runtime lookup.
    auto sid = GetExtSymbol(func);

    CodegenC builder(sid);
    builder.VisitExpr(func->body);
    code_stream_ << builder.JIT();
  }

  runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
    // Create headers
    code_stream_ << "#include <cstdlib>\n";  // for std::malloc in the generated buffers
    code_stream_ << "#include <cstring>\n";
    code_stream_ << "#include <tvm/runtime/c_runtime_api.h>\n";
    code_stream_ << "#include <tvm/runtime/packed_func.h>\n";
    code_stream_ << "#include <dlpack/dlpack.h>\n";

    // Append some common macros for operator definitions.
    const char* operator_macro = R"op_macro(
    #define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_, p_DTYPE)       \
      extern "C" void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) {    \
        for (int64_t i = 0; i < p_DIM1_; ++i) {                        \
          out[i] = a[i] p_OP_ b[i];                                    \
        }                                                              \
      }

    #define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_, p_DTYPE)  \
      extern "C" void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) {        \
        for (int64_t i = 0; i < p_DIM1_; ++i) {                            \
          for (int64_t j = 0; j < p_DIM2_; ++j) {                          \
            int64_t k = i * p_DIM2_ + j;                                   \
            out[k] = a[k] p_OP_ b[k];                                      \
          }                                                                \
        }                                                                  \
      }
    )op_macro";

    code_stream_ << operator_macro << "\n\n";
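
    // For instance, CSOURCE_BINARY_OP_2D(my_add, +, 2, 2, float) expands to an
    // extern "C" function my_add(float* a, float* b, float* out) that applies
    // + elementwise over the 2x2 inputs (my_add is an illustrative name).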

    if (ref->IsInstance<FunctionNode>()) {
      GenCFunc(Downcast<Function>(ref));
    } else if (ref->IsInstance<IRModuleNode>()) {
      IRModule mod = Downcast<IRModule>(ref);
      for (const auto& it : mod->functions) {
        GenCFunc(Downcast<Function>(it.second));
      }
    } else {
      LOG(FATAL) << "The input ref is expected to be a Relay function or module";
    }

    // Create a CSourceModule
    const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
    CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
    return (*pf)(code_stream_.str(), "cc");
  }

 private:
  std::ostringstream code_stream_;
};

/*!
 * \brief The external compiler/codegen tool. It takes a Relay expression/module
 * and compiles it into a runtime module.
 *
 * The external codegen tool should have been registered similarly to LLVM,
 * CUDA, etc. under TVM, so the generated code can be packed into a runtime
 * module. This module simplifies code serialization and invocation.
 */
runtime::Module CCompiler(const ObjectRef& ref) {
  CSourceCodegen csource;
  return csource.CreateCSourceModule(ref);
}

TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler);
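
// A minimal usage sketch from the Python side (assuming the standard Relay
// partitioning flow; pass availability may vary across TVM versions):
//   mod = relay.transform.AnnotateTarget("ccompiler")(mod)
//   mod = relay.transform.PartitionGraph()(mod)
//   lib = relay.build(mod, target="llvm")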

}  // namespace contrib
}  // namespace relay
}  // namespace tvm