/*!
 *  Copyright (c) 2017 by Contributors
 * \file source_module.cc
 * \brief Source code modules: they hold generated code for viewing or
 *  saving, but cannot be executed.
 */
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include <tvm/runtime/packed_func.h>

#include "codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"

namespace tvm {
namespace codegen {

using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

using runtime::GetFileFormat;
using runtime::GetMetaFilePath;
using runtime::FunctionInfo;
using runtime::SaveBinaryToFile;

23
// Simulator function
24
class SourceModuleNode : public runtime::ModuleNode {
25 26 27 28 29 30 31
 public:
  SourceModuleNode(std::string code,
                   std::string fmt)
      : code_(code), fmt_(fmt) {}
  const char* type_key() const {
    return "source";
  }
32

33 34 35 36 37 38 39
  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    LOG(FATAL) << "Source module cannot execute, to get executable module"
               << " build TVM with \'" << fmt_ << "\' runtime support";
    return PackedFunc();
  }
40

41 42 43 44
  std::string GetSource(const std::string& format) final {
    return code_;
  }

45
 protected:
46 47 48 49 50 51 52 53 54
  std::string code_;
  std::string fmt_;
};

/*!
 * \brief Create a non-executable module that only holds source code.
 * \param code The source code.
 * \param fmt The format of the code.
 * \return The created module.
 */
runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
  // The parameters are by-value sinks: move them instead of copying the
  // (possibly large) source string again.
  std::shared_ptr<SourceModuleNode> n =
      std::make_shared<SourceModuleNode>(std::move(code), std::move(fmt));
  return runtime::Module(n);
}
55

56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
// Simulator function
class CSourceModuleNode : public runtime::ModuleNode {
 public:
  CSourceModuleNode(std::string code,
                   std::string fmt)
      : code_(code), fmt_(fmt) {}
  const char* type_key() const {
    return "c";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    LOG(FATAL) << "C Source module cannot execute, to get executable module"
               << " build TVM with \'" << fmt_ << "\' runtime support";
    return PackedFunc();
  }

  std::string GetSource(const std::string& format) final {
    return code_;
  }

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = GetFileFormat(file_name, format);
    std::string meta_file = GetMetaFilePath(file_name);
    if (fmt == "cc") {
      CHECK_NE(code_.length(), 0);
      SaveBinaryToFile(file_name, code_);
    } else {
      CHECK_EQ(fmt, fmt_)
          << "Can only save to format=" << fmt_;
    }
  }

 protected:
  std::string code_;
  std::string fmt_;
};

/*!
 * \brief Create a non-executable module that holds C source code.
 * \param code The C source code.
 * \param fmt The format of the code.
 * \return The created module.
 */
runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
  // Move the by-value sink parameters rather than copying the source again.
  std::shared_ptr<CSourceModuleNode> n =
      std::make_shared<CSourceModuleNode>(std::move(code), std::move(fmt));
  return runtime::Module(n);
}

102
// supports limited save without cross compile
103
class DeviceSourceModuleNode final : public runtime::ModuleNode {
104
 public:
105
  DeviceSourceModuleNode(std::string data,
106 107
                         std::string fmt,
                         std::unordered_map<std::string, FunctionInfo> fmap,
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
                         std::string type_key,
                         std::function<std::string(const std::string&)> fget_source)
    : data_(data),
      fmt_(fmt),
      fmap_(fmap),
      type_key_(type_key),
      fget_source_(fget_source) {}

  PackedFunc GetFunction(
        const std::string& name,
        const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    LOG(FATAL) << "Source module cannot execute, to get executable module"
               << " build TVM with \'" << fmt_ << "\' runtime support";
    return PackedFunc();
  }

  std::string GetSource(const std::string& format) final {
    if (fget_source_ != nullptr) {
      return fget_source_(format);
    } else {
      return data_;
    }
  }
131 132 133 134 135 136

  const char* type_key() const {
    return type_key_.c_str();
  }

  void SaveToFile(const std::string& file_name,
137
                  const std::string& format) final {
138 139 140 141 142
    std::string fmt = GetFileFormat(file_name, format);
    CHECK_EQ(fmt, fmt_)
        << "Can only save to format=" << fmt_;
    std::string meta_file = GetMetaFilePath(file_name);
    SaveMetaDataToFile(meta_file, fmap_);
143
    SaveBinaryToFile(file_name, data_);
144 145 146 147 148
  }

  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
149
    stream->Write(data_);
150 151 152
  }

 private:
153 154
  std::string data_;
  std::string fmt_;
155 156
  std::unordered_map<std::string, FunctionInfo> fmap_;
  std::string type_key_;
157
  std::function<std::string(const std::string&)> fget_source_;
158 159 160
};

/*!
 * \brief Create a non-executable device module.
 * \param data The module payload.
 * \param fmt The format of data.
 * \param fmap Per-function info.
 * \param type_key The type key reported by the module.
 * \param fget_source Optional callback producing source for a format.
 * \return The created module.
 */
runtime::Module DeviceSourceModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string type_key,
    std::function<std::string(const std::string&)> fget_source) {
  // All parameters are by-value sinks; move them into the node instead of
  // copying the data, the map and the callback a second time.
  std::shared_ptr<DeviceSourceModuleNode> n =
      std::make_shared<DeviceSourceModuleNode>(
          std::move(data), std::move(fmt), std::move(fmap),
          std::move(type_key), std::move(fget_source));
  return runtime::Module(n);
}

171 172 173 174
TVM_REGISTER_GLOBAL("module.source_module_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = SourceModuleCreate(args[0], args[1]);
  });
175 176
}  // namespace codegen
}  // namespace tvm