Commit 54fb55a1 by Zhao Wu Committed by Zhi

[CodeGen] Generate blob use LLVM directly (#4657)

parent f16f245c
...@@ -95,5 +95,8 @@ macro(find_llvm use_llvm) ...@@ -95,5 +95,8 @@ macro(find_llvm use_llvm)
message(STATUS "Found LLVM_INCLUDE_DIRS=" ${LLVM_INCLUDE_DIRS}) message(STATUS "Found LLVM_INCLUDE_DIRS=" ${LLVM_INCLUDE_DIRS})
message(STATUS "Found LLVM_DEFINITIONS=" ${LLVM_DEFINITIONS}) message(STATUS "Found LLVM_DEFINITIONS=" ${LLVM_DEFINITIONS})
message(STATUS "Found TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION}) message(STATUS "Found TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
if (${TVM_LLVM_VERSION} LESS 40)
message(FATAL_ERROR "TVM requires LLVM 4.0 or higher.")
endif()
endif() endif()
endmacro(find_llvm) endmacro(find_llvm)
...@@ -59,6 +59,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, ...@@ -59,6 +59,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
* \return cstr The C string representation of the file. * \return cstr The C string representation of the file.
*/ */
std::string PackImportsToC(const runtime::Module& m, bool system_lib); std::string PackImportsToC(const runtime::Module& m, bool system_lib);
/*!
* \brief Pack imported device library to a LLVM module.
* Compile the LLVM module and link with the host library
* will allow the DSO loader to automatically discover and import
* the dependency from the shared library.
*
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \param target_triple LLVM target triple
* \return runtime::Module The generated LLVM module.
*/
runtime::Module PackImportsToLLVM(const runtime::Module& m,
bool system_lib,
const std::string& target_triple);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -51,10 +51,41 @@ def create_shared(output, ...@@ -51,10 +51,41 @@ def create_shared(output,
else: else:
raise ValueError("Unsupported platform") raise ValueError("Unsupported platform")
def get_target_by_dump_machine(compiler):
""" Functor of get_target_triple that can get the target triple using compiler.
Parameters
----------
compiler : Optional[str]
The compiler.
Returns
-------
out: Callable
A function that can get target triple according to dumpmachine option of compiler.
"""
def get_target_triple():
""" Get target triple according to dumpmachine option of compiler."""
if compiler:
cmd = [compiler, "-dumpmachine"]
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "dumpmachine error:\n"
msg += py_str(out)
return None
return py_str(out)
else:
return None
return get_target_triple
# assign so as default output format # assign so as default output format
create_shared.output_format = "so" if sys.platform != "win32" else "dll" create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared.get_target_triple = get_target_by_dump_machine(
"g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None)
def build_create_shared_func(options=None, compile_cmd="g++"): def build_create_shared_func(options=None, compile_cmd="g++"):
"""Build create_shared function with particular default options and compile_cmd. """Build create_shared function with particular default options and compile_cmd.
...@@ -75,10 +106,11 @@ def build_create_shared_func(options=None, compile_cmd="g++"): ...@@ -75,10 +106,11 @@ def build_create_shared_func(options=None, compile_cmd="g++"):
def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd): def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd):
create_shared(output, objects, options, compile_cmd) create_shared(output, objects, options, compile_cmd)
create_shared_wrapper.output_format = create_shared.output_format create_shared_wrapper.output_format = create_shared.output_format
create_shared_wrapper.get_target_triple = get_target_by_dump_machine(compile_cmd)
return create_shared_wrapper return create_shared_wrapper
def cross_compiler(compile_func, base_options=None, output_format="so"): def cross_compiler(compile_func, base_options=None, output_format="so", get_target_triple=None):
"""Create a cross compiler function. """Create a cross compiler function.
Parameters Parameters
...@@ -92,6 +124,9 @@ def cross_compiler(compile_func, base_options=None, output_format="so"): ...@@ -92,6 +124,9 @@ def cross_compiler(compile_func, base_options=None, output_format="so"):
output_format : Optional[str] output_format : Optional[str]
Library output format. Library output format.
get_target_triple: Optional[Callable]
Function that can target triple according to dumpmachine option of compiler.
Returns Returns
------- -------
fcompile : Callable[[str, str, Optional[str]], None] fcompile : Callable[[str, str, Optional[str]], None]
...@@ -105,6 +140,7 @@ def cross_compiler(compile_func, base_options=None, output_format="so"): ...@@ -105,6 +140,7 @@ def cross_compiler(compile_func, base_options=None, output_format="so"):
all_options += options all_options += options
compile_func(outputs, objects, options=all_options) compile_func(outputs, objects, options=all_options)
_fcompile.output_format = output_format _fcompile.output_format = output_format
_fcompile.get_target_triple = get_target_triple
return _fcompile return _fcompile
......
...@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs ...@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
import subprocess import subprocess
import os import os
from .._ffi.base import py_str from .._ffi.base import py_str
from .cc import get_target_by_dump_machine
def create_shared(output, def create_shared(output,
objects, objects,
...@@ -64,5 +65,8 @@ def create_shared(output, ...@@ -64,5 +65,8 @@ def create_shared(output,
msg += py_str(out) msg += py_str(out)
raise RuntimeError(msg) raise RuntimeError(msg)
# assign output format # assign output format
create_shared.output_format = "so" create_shared.output_format = "so"
create_shared.get_target_triple = get_target_by_dump_machine(
os.environ["TVM_NDK_CC"]) if "TVM_NDK_CC" in os.environ else None
...@@ -123,6 +123,7 @@ class Module(ModuleBase): ...@@ -123,6 +123,7 @@ class Module(ModuleBase):
files = [] files = []
is_system_lib = False is_system_lib = False
has_c_module = False has_c_module = False
llvm_target_triple = None
for index, module in enumerate(modules): for index, module in enumerate(modules):
if fcompile is not None and hasattr(fcompile, "object_format"): if fcompile is not None and hasattr(fcompile, "object_format"):
object_format = fcompile.object_format object_format = fcompile.object_format
...@@ -138,17 +139,28 @@ class Module(ModuleBase): ...@@ -138,17 +139,28 @@ class Module(ModuleBase):
files.append(path_obj) files.append(path_obj)
is_system_lib = (module.type_key == "llvm" and is_system_lib = (module.type_key == "llvm" and
module.get_function("__tvm_is_system_module")()) module.get_function("__tvm_is_system_module")())
llvm_target_triple = (module.type_key == "llvm" and
module.get_function("_get_target_triple")())
if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"):
llvm_target_triple = fcompile.get_target_triple()
if self.imported_modules: if self.imported_modules:
if enabled("llvm") and llvm_target_triple:
path_obj = temp.relpath("devc.o")
m = _PackImportsToLLVM(self, is_system_lib, llvm_target_triple)
m.save(path_obj)
files.append(path_obj)
else:
path_cc = temp.relpath("devc.cc") path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f: with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib)) f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc) files.append(path_cc)
if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
if has_c_module: if has_c_module:
options = [] options = []
......
...@@ -43,5 +43,8 @@ TVM_REGISTER_GLOBAL("codegen._Build") ...@@ -43,5 +43,8 @@ TVM_REGISTER_GLOBAL("codegen._Build")
TVM_REGISTER_GLOBAL("module._PackImportsToC") TVM_REGISTER_GLOBAL("module._PackImportsToC")
.set_body_typed(PackImportsToC); .set_body_typed(PackImportsToC);
TVM_REGISTER_GLOBAL("module._PackImportsToLLVM")
.set_body_typed(PackImportsToLLVM);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <sstream> #include <sstream>
...@@ -158,7 +159,8 @@ class ModuleSerializer { ...@@ -158,7 +159,8 @@ class ModuleSerializer {
std::vector<uint64_t> import_tree_child_indices_; std::vector<uint64_t> import_tree_child_indices_;
}; };
std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { namespace {
std::string SerializeModule(const runtime::Module& mod) {
std::string bin; std::string bin;
dmlc::MemoryStringStream ms(&bin); dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms; dmlc::Stream* stream = &ms;
...@@ -166,6 +168,13 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { ...@@ -166,6 +168,13 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
ModuleSerializer module_serializer(mod); ModuleSerializer module_serializer(mod);
module_serializer.SerializeModule(stream); module_serializer.SerializeModule(stream);
return bin;
}
} // namespace
std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin = SerializeModule(mod);
// translate to C program // translate to C program
std::ostringstream os; std::ostringstream os;
os << "#ifdef _WIN32\n" os << "#ifdef _WIN32\n"
...@@ -211,5 +220,29 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { ...@@ -211,5 +220,29 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "#endif\n"; << "#endif\n";
return os.str(); return os.str();
} }
runtime::Module PackImportsToLLVM(const runtime::Module& mod,
bool system_lib,
const std::string& target_triple) {
std::string bin = SerializeModule(mod);
uint64_t nbytes = bin.length();
std::string header;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
header.push_back(((nbytes >> (i * 8)) & 0xffUL));
}
std::string blob = header + bin;
TVMByteArray blob_byte_array;
blob_byte_array.size = blob.length();
blob_byte_array.data = blob.data();
// Call codegen_blob to generate LLVM module
std::string codegen_f_name = "codegen.codegen_blob";
// the codegen function.
const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name);
CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
return (*codegen_f)(blob_byte_array, system_lib, target_triple);
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_blob.cc
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/module.h>
#include <cstring>
#include "codegen_blob.h"
namespace tvm {
namespace codegen {
std::pair<std::unique_ptr<llvm::Module>,
std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(const std::string& data,
bool system_lib,
const std::string& target_triple) {
InitializeLLVM();
auto tm = GetLLVMTargetMachine(std::string("-target ") + target_triple);
auto triple = tm->getTargetTriple();
auto ctx = std::make_shared<llvm::LLVMContext>();
std::string module_name = "devc";
std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
module->setTargetTriple(triple.str());
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
auto* tvm_dev_mblob = new llvm::GlobalVariable(*module, blob_value->getType(), true,
llvm::GlobalValue::ExternalLinkage, blob_value,
runtime::symbol::tvm_dev_mblob, nullptr,
llvm::GlobalVariable::NotThreadLocal, 0);
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob->setAlignment(llvm::Align(1));
#else
tvm_dev_mblob->setAlignment(1);
#endif
if (triple.isOSWindows()) {
tvm_dev_mblob->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
}
if (system_lib) {
// LLVM type helper
auto void_ty = llvm::Type::getVoidTy(*ctx);
auto int32_ty = llvm::Type::getInt32Ty(*ctx);
auto int8_ty = llvm::Type::getInt8Ty(*ctx);
auto int8_ptr_ty = int8_ty->getPointerTo(0);
llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty);
auto* tvm_dev_mblob_reg =
new llvm::GlobalVariable(*module, int32_ty,
false, llvm::GlobalValue::InternalLinkage,
constant_zero,
std::string(runtime::symbol::tvm_dev_mblob) + "_reg_");
auto tvm_dev_mblob_reg_alignment = module->getDataLayout().getABITypeAlignment(int32_ty);
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment));
#else
tvm_dev_mblob_reg->setAlignment(tvm_dev_mblob_reg_alignment);
#endif
auto* tvm_dev_mblob_string_ty =
llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1);
auto* tvm_dev_mblob_string_value =
llvm::ConstantDataArray::getString(*ctx, runtime::symbol::tvm_dev_mblob, true);
auto* tvm_dev_mblob_string =
new llvm::GlobalVariable(*module, tvm_dev_mblob_string_ty,
true, llvm::GlobalValue::PrivateLinkage,
tvm_dev_mblob_string_value,
std::string(runtime::symbol::tvm_dev_mblob) + ".str");
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_string->setAlignment(llvm::Align(1));
#else
tvm_dev_mblob_string->setAlignment(1);
#endif
// Global init function
llvm::Function* init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false),
llvm::GlobalValue::InternalLinkage,
llvm::Twine("_GLOBAL__sub_I_", module_name),
module.get());
// Create variable initialization function.
llvm::Function* var_init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false),
llvm::GlobalValue::InternalLinkage,
llvm::Twine("__cxx_global_var_init"),
module.get());
// Create TVMBackendRegisterSystemLibSymbol function
llvm::Function* tvm_backend_fn =
llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false),
llvm::GlobalValue::ExternalLinkage,
llvm::Twine("TVMBackendRegisterSystemLibSymbol"),
module.get());
// Set necessary fn sections
auto get_static_init_section_specifier = [&triple]() -> std::string {
if (triple.isOSLinux()) {
return ".text.startup";
} else if (triple.isOSDarwin()) {
return "__TEXT,__StaticInit,regular,pure_instructions";
} else {
return "";
}
};
auto static_init_section_specifier = get_static_init_section_specifier();
if (!static_init_section_specifier.empty()) {
init_fn->setSection(static_init_section_specifier);
var_init_fn->setSection(static_init_section_specifier);
}
// The priority is 65535 for all platforms as clang do.
llvm::appendToGlobalCtors(*module, init_fn, 65535);
// Define init_fn body
llvm::IRBuilder<> ir_builder(*ctx);
llvm::BasicBlock* init_fn_bb = llvm::BasicBlock::Create(*ctx, "entry", init_fn);
ir_builder.SetInsertPoint(init_fn_bb);
ir_builder.CreateCall(var_init_fn);
ir_builder.CreateRetVoid();
// Define var_init_fn body
llvm::BasicBlock* var_init_fn_bb = llvm::BasicBlock::Create(*ctx, "entry", var_init_fn);
ir_builder.SetInsertPoint(var_init_fn_bb);
llvm::Constant* indices[] = {constant_zero, constant_zero};
llvm::SmallVector<llvm::Value*, 2> args;
args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty,
tvm_dev_mblob_string,
indices));
args.push_back(llvm::ConstantExpr::getGetElementPtr(blob_value->getType(),
tvm_dev_mblob,
indices));
auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args);
ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg);
ir_builder.CreateRetVoid();
}
return std::make_pair(std::move(module), ctx);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_blob.h
* \brief Code Generation of blob data
*/
#ifndef TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_
#define TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_
#ifdef TVM_LLVM_VERSION
#include <utility>
#include <memory>
#include <string>
#include "llvm_common.h"
namespace tvm {
namespace codegen {
/**
* \brief Code Generation of blob data
*
* \param data Blob data
* \param system_lib Whether expose as system library.
* \param target_triple LLVM target triple
*
* \return LLVM module and LLVM context
*/
std::pair<std::unique_ptr<llvm::Module>,
std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(const std::string& data,
bool system_lib,
const std::string& target_triple);
} // namespace codegen
} // namespace tvm
#endif // LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <mutex> #include <mutex>
#include "llvm_common.h" #include "llvm_common.h"
#include "codegen_llvm.h" #include "codegen_llvm.h"
#include "codegen_blob.h"
#include "../../runtime/file_util.h" #include "../../runtime/file_util.h"
#include "../../runtime/library_module.h" #include "../../runtime/library_module.h"
...@@ -62,6 +63,11 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -62,6 +63,11 @@ class LLVMModuleNode final : public runtime::ModuleNode {
return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) { return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) {
* rv = flag; * rv = flag;
}); });
} else if (name == "_get_target_triple") {
std::string target_triple = tm_->getTargetTriple().str();
return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) {
* rv = target_triple;
});
} }
if (ee_ == nullptr) LazyInitJIT(); if (ee_ == nullptr) LazyInitJIT();
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
...@@ -218,15 +224,15 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -218,15 +224,15 @@ class LLVMModuleNode final : public runtime::ModuleNode {
mptr_ = module_.get(); mptr_ = module_.get();
} }
void LoadIR(const std::string& file_name) { void Init(std::unique_ptr<llvm::Module> module,
std::shared_ptr<llvm::LLVMContext> ctx) {
InitializeLLVM(); InitializeLLVM();
ctx_ = std::make_shared<llvm::LLVMContext>(); ctx_ = ctx;
llvm::SMDiagnostic err; llvm::SMDiagnostic err;
module_ = llvm::parseIRFile(file_name, err, *ctx_); module_ = std::move(module);
if (module_.get() == nullptr) { if (module_ == nullptr) {
std::string msg = err.getMessage(); std::string msg = err.getMessage();
LOG(FATAL) << "Fail to load ir file " << file_name << "\n" LOG(FATAL) << "Fail to load module: " << msg;
<< "line " << err.getLineNo() << ":" << msg;
} }
std::string target_; std::string target_;
llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target"); llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target");
...@@ -243,6 +249,18 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -243,6 +249,18 @@ class LLVMModuleNode final : public runtime::ModuleNode {
tm_ = GetLLVMTargetMachine(target_); tm_ = GetLLVMTargetMachine(target_);
} }
void LoadIR(const std::string& file_name) {
auto ctx = std::make_shared<llvm::LLVMContext>();
llvm::SMDiagnostic err;
auto module = llvm::parseIRFile(file_name, err, *ctx);
if (module == nullptr) {
std::string msg = err.getMessage();
LOG(FATAL) << "Fail to load ir file " << file_name << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
Init(std::move(module), ctx);
}
private: private:
void LazyInitJIT() { void LazyInitJIT() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
...@@ -339,7 +357,7 @@ TVM_REGISTER_GLOBAL("codegen.llvm_lookup_intrinsic_id") ...@@ -339,7 +357,7 @@ TVM_REGISTER_GLOBAL("codegen.llvm_lookup_intrinsic_id")
TVM_REGISTER_GLOBAL("codegen.build_llvm") TVM_REGISTER_GLOBAL("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<LLVMModuleNode>(); auto n = make_object<LLVMModuleNode>();
n->Init(args[0], args[1]); n->Init(args[0].operator Array<LoweredFunc>(), args[1].operator std::string());
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
...@@ -362,6 +380,16 @@ TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") ...@@ -362,6 +380,16 @@ TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled")
InitializeLLVM(); InitializeLLVM();
*rv = (GetLLVMTargetMachine(args[0], true) != nullptr); *rv = (GetLLVMTargetMachine(args[0], true) != nullptr);
}); });
TVM_REGISTER_GLOBAL("codegen.codegen_blob")
.set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<LLVMModuleNode>();
auto p = CodeGenBlob(args[0].operator std::string(),
args[1].operator bool(),
args[2].operator std::string());
n->Init(std::move(p.first), p.second);
*rv = runtime::Module(n);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_LLVM_VERSION #endif // TVM_LLVM_VERSION
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
from tvm import relay
from tvm.relay import testing
from tvm.contrib import graph_runtime
import tvm
import ctypes
def test_resnet18():
for device in ["llvm", "cuda"]:
if not tvm.module.enabled(device):
print("skip because %s is not enabled..." % device)
return
def verify(data):
mod, params = relay.testing.resnet.get_workload(num_layers=18)
with relay.build_config(opt_level=3):
graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params)
ctx = tvm.cpu()
module = graph_runtime.create(graph, lib, ctx)
module.set_input("data", data)
module.set_input(**graph_params)
module.run()
out = module.get_output(0).asnumpy()
return out
resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18)
with relay.build_config(opt_level=3):
graph, resnet18_gpu_lib, graph_params = relay.build_module.build(resnet18_mod, "cuda", params=resnet18_params)
from tvm.contrib import util
temp = util.tempdir()
path_lib = temp.relpath("deploy_lib.so")
resnet18_gpu_lib.export_library(path_lib)
with open(temp.relpath("deploy_graph.json"), "w") as fo:
fo.write(graph)
with open(temp.relpath("deploy_param.params"), "wb") as fo:
fo.write(relay.save_param_dict(graph_params))
loaded_lib = tvm.module.load(path_lib)
loaded_json = open(temp.relpath("deploy_graph.json")).read()
loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32")
ctx = tvm.gpu()
module = graph_runtime.create(loaded_json, loaded_lib, ctx)
module.load_params(loaded_params)
module.set_input("data", data)
module.run()
out = module.get_output(0).asnumpy()
tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
def test_system_lib():
ctx = tvm.gpu(0)
for device in ["llvm", "cuda"]:
if not tvm.module.enabled(device):
print("skip because %s is not enabled..." % device)
return
nn = 12
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=4)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
from tvm.contrib import util
temp = util.tempdir()
fn_add = tvm.build(s, [A, B], target="cuda", target_host="llvm -system-lib", name="add")
path_obj = temp.relpath("add.o")
path_lib = temp.relpath("deploy_lib.so")
fn_add.save(path_obj)
fn_add.export_library(path_lib)
# Load dll, will trigger system library registration
dll = ctypes.CDLL(path_lib)
# Load the system wide library
m = tvm.module.system_lib()
a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
m['add'](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
if __name__ == "__main__":
test_resnet18()
test_system_lib()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment