/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file source_module.cc
 * \brief Source code modules: they hold generated code for viewing (and,
 *  for some kinds, saving to disk) but cannot be executed.
 */
#include <tvm/runtime/packed_func.h>

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"

namespace tvm {
namespace codegen {

using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

using runtime::GetFileFormat;
using runtime::GetMetaFilePath;
using runtime::FunctionInfo;
using runtime::SaveBinaryToFile;

// Simulator function
43
class SourceModuleNode : public runtime::ModuleNode {
44 45 46 47 48 49 50
 public:
  SourceModuleNode(std::string code,
                   std::string fmt)
      : code_(code), fmt_(fmt) {}
  const char* type_key() const {
    return "source";
  }
51

52 53 54 55 56 57 58
  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    LOG(FATAL) << "Source module cannot execute, to get executable module"
               << " build TVM with \'" << fmt_ << "\' runtime support";
    return PackedFunc();
  }
59

60 61 62 63
  std::string GetSource(const std::string& format) final {
    return code_;
  }

64
 protected:
65 66 67 68 69 70 71 72 73
  std::string code_;
  std::string fmt_;
};

/*!
 * \brief Create a module that only carries source code.
 * \param code The source code.
 * \param fmt The source format.
 * \return The created module.
 */
runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
  // The arguments are already owned copies, so hand them to the node by
  // move instead of copying the (potentially large) source text again.
  auto n = std::make_shared<SourceModuleNode>(std::move(code), std::move(fmt));
  return runtime::Module(n);
}

// Simulator function
class CSourceModuleNode : public runtime::ModuleNode {
 public:
  CSourceModuleNode(std::string code,
                   std::string fmt)
      : code_(code), fmt_(fmt) {}
  const char* type_key() const {
    return "c";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    LOG(FATAL) << "C Source module cannot execute, to get executable module"
               << " build TVM with \'" << fmt_ << "\' runtime support";
    return PackedFunc();
  }

  std::string GetSource(const std::string& format) final {
    return code_;
  }

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = GetFileFormat(file_name, format);
    std::string meta_file = GetMetaFilePath(file_name);
    if (fmt == "cc") {
      CHECK_NE(code_.length(), 0);
      SaveBinaryToFile(file_name, code_);
    } else {
      CHECK_EQ(fmt, fmt_)
          << "Can only save to format=" << fmt_;
    }
  }

 protected:
  std::string code_;
  std::string fmt_;
};

/*!
 * \brief Create a module that carries C source code.
 * \param code The C source code.
 * \param fmt The source format.
 * \return The created module.
 */
runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
  // Move the by-value arguments into the node rather than copying the
  // source text a second time.
  auto n = std::make_shared<CSourceModuleNode>(std::move(code), std::move(fmt));
  return runtime::Module(n);
}

// supports limited save without cross compile
122
class DeviceSourceModuleNode final : public runtime::ModuleNode {
123
 public:
124
  DeviceSourceModuleNode(std::string data,
125 126
                         std::string fmt,
                         std::unordered_map<std::string, FunctionInfo> fmap,
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
                         std::string type_key,
                         std::function<std::string(const std::string&)> fget_source)
    : data_(data),
      fmt_(fmt),
      fmap_(fmap),
      type_key_(type_key),
      fget_source_(fget_source) {}

  PackedFunc GetFunction(
        const std::string& name,
        const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    LOG(FATAL) << "Source module cannot execute, to get executable module"
               << " build TVM with \'" << fmt_ << "\' runtime support";
    return PackedFunc();
  }

  std::string GetSource(const std::string& format) final {
    if (fget_source_ != nullptr) {
      return fget_source_(format);
    } else {
      return data_;
    }
  }
150 151 152 153 154 155

  const char* type_key() const {
    return type_key_.c_str();
  }

  void SaveToFile(const std::string& file_name,
156
                  const std::string& format) final {
157 158 159 160 161
    std::string fmt = GetFileFormat(file_name, format);
    CHECK_EQ(fmt, fmt_)
        << "Can only save to format=" << fmt_;
    std::string meta_file = GetMetaFilePath(file_name);
    SaveMetaDataToFile(meta_file, fmap_);
162
    SaveBinaryToFile(file_name, data_);
163 164 165 166 167
  }

  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
168
    stream->Write(data_);
169 170 171
  }

 private:
172 173
  std::string data_;
  std::string fmt_;
174 175
  std::unordered_map<std::string, FunctionInfo> fmap_;
  std::string type_key_;
176
  std::function<std::string(const std::string&)> fget_source_;
177 178 179
};

/*!
 * \brief Create a device source module.
 * \param data The module payload.
 * \param fmt The format of data.
 * \param fmap Function metadata.
 * \param type_key The type key of the module.
 * \param fget_source Optional callback used to produce source on request.
 * \return The created module.
 */
runtime::Module DeviceSourceModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string type_key,
    std::function<std::string(const std::string&)> fget_source) {
  // All parameters are owned copies; move them into the node to avoid
  // copying the payload, metadata map and callback a second time.
  auto n = std::make_shared<DeviceSourceModuleNode>(
      std::move(data), std::move(fmt), std::move(fmap),
      std::move(type_key), std::move(fget_source));
  return runtime::Module(n);
}

190
TVM_REGISTER_GLOBAL("module.source_module_create")
191
.set_body_typed(SourceModuleCreate);
192 193
}  // namespace codegen
}  // namespace tvm