Commit 17e7e3d5 by alex-weaver, committed by Tianqi Chen

Port build_module.py to C++ (#667)

* Port build_module.py to C++

* Fix lint errors

* Fix more lint errors

* Fix more lint errors

* Fix more lint errors

* Fix build error

* Implemented style fixes

* Fix lint errors

* Added function to construct target from string
lower now returns array

* Fix lint error

* Implemented review changes - style & Target options -> std::vector

* Fixed lint, argument alignment and added unit test

* Changed test to target LLVM, fixed sign compare warnings

* Reverted unit test to CUDA, changed Jenkinsfile to enable GPU for C++ tests

* Slight change to Jenkinsfile

* Changed build_module test from CUDA to LLVM

* Added function var() to construct a Var instance.
Changed implementation of LLVMEnabled()

* Reverted Jenkinsfile
parent f140ce41
/*!
 * Copyright (c) 2017 by Contributors
 * \file build_module.h
 * \brief Functions for compiling ops.
 */
#ifndef TVM_BUILD_MODULE_H_
#define TVM_BUILD_MODULE_H_

#include <string>
#include <unordered_set>
#include <vector>
#include "./runtime/packed_func.h"
#include "./schedule_pass.h"
#include "./lowered_func.h"
namespace tvm {

/*!
 * \brief Container for target device information.
 * Use the target::llvm, target::cuda etc. functions instead of constructing directly.
 */
struct Target {
  /*! \brief The name of the target device */
  std::string target_name;
  /*! \brief The type of the target device */
  DLDeviceType device_type;
  /*! \brief The maximum threads that a schedule should use for this device */
  int max_num_threads = 1;
  /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
  int thread_warp_size = 1;
  /*! \brief Keys for this target */
  std::unordered_set<std::string> keys;
  /*! \brief Options for this target */
  std::vector<std::string> options;

  Target(const std::string& target_name,
         DLDeviceType device_type,
         int max_num_threads,
         int thread_warp_size,
         const std::unordered_set<std::string>& keys,
         const std::vector<std::string>& options) :
    target_name(target_name),
    device_type(device_type),
    max_num_threads(max_num_threads),
    thread_warp_size(thread_warp_size),
    keys(keys),
    options(options) {
  }

  /*! \return the full device string to pass to codegen::Build */
  EXPORT std::string str() const;

  /*!
   * \brief Create a Target given a string
   * \param target_str the string to parse
   */
  EXPORT static Target create(const std::string& target_str);
};
/*! \brief This namespace provides functions to construct Target instances */
namespace target {
/*! \return A target for LLVM */
EXPORT Target llvm();
/*! \return A target for CUDA */
EXPORT Target cuda();
/*! \return A target for ROCm */
EXPORT Target rocm();
/*! \return A target for Metal */
EXPORT Target metal();
/*! \return A target for rasp */
EXPORT Target rasp();
/*! \return A target for stackvm */
EXPORT Target stackvm();
} // namespace target
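For illustration, a minimal sketch of how a Target might be obtained, either from one of the preset constructors above or by parsing a configuration string (variable names here are ours, not part of this commit):

Target gpu = target::cuda();                           // preset device target
Target cpu = Target::create("llvm -mcpu=cortex-a53");  // parsed from a string
// cpu.target_name == "llvm"; the extra word is kept verbatim in cpu.options.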
/*!
 * \brief Container for build configuration options
 */
struct BuildConfig {
  /*!
   * \brief The data alignment to use when constructing buffers. If this is set to
   * -1, then TVM's internal default will be used
   */
  int data_alignment = -1;
  /*!
   * \brief The offset factor to use when constructing buffers. If this is set to
   * 0, then the offset field is not used.
   */
  int offset_factor = 0;
  /*!
   * \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
   * done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
   */
  int double_buffer_split_loop = 1;
  /*! \brief Threshold of number of steps in the loop to be automatically unrolled */
  int auto_unroll_max_step = 0;
  /*! \brief The maximum nested level of loops that can be automatically unrolled */
  int auto_unroll_max_depth = 8;
  /*! \brief The maximum extent of loop that will be unrolled */
  int auto_unroll_max_extent = 0;
  /*!
   * \brief Whether to explicitly unroll the loop. If set to false, the unroll hint will
   * be passed to the CodeGen phase. Set to true if CodeGen supports unroll pragma.
   */
  bool unroll_explicit = true;
  /*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */
  bool restricted_func = true;
  /*! \brief Whether to detect global barrier */
  bool detect_global_barrier = false;

  BuildConfig() {
  }
};
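A hedged sketch of adjusting these options before lowering (the field values are illustrative examples, not recommendations):

BuildConfig config;                 // starts from the defaults above
config.data_alignment = 64;         // example value: 64-byte aligned buffers
config.auto_unroll_max_step = 16;   // permit unrolling of short loops
config.restricted_func = false;     // buffer arguments may alias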
/*!
 * \brief Build an array of LoweredFuncs given a schedule, args and binds
 * \param sch The schedule to lower.
 * \param args The arguments to the function.
 * \param name The name of the lowered function.
 * \param binds Buffer assignments.
 * \param config The build configuration.
 * \return The lowered functions.
 */
EXPORT Array<LoweredFunc> lower(Schedule sch,
                                const Array<Tensor>& args,
                                const std::string& name,
                                const std::unordered_map<Tensor, Buffer>& binds,
                                const BuildConfig& config);

/*!
 * \brief Build a device and host module for a specific target from an array of lowered functions.
 * \param funcs The functions to be built.
 * \param target The target device to build for.
 * \param target_host The target for building host code. If null, a suitable default will be used.
 * \param config The build configuration.
 * \return The built module.
 */
EXPORT runtime::Module build(const Array<LoweredFunc>& funcs,
                             const Target& target,
                             Target* target_host,
                             const BuildConfig& config);

}  // namespace tvm

#endif  // TVM_BUILD_MODULE_H_
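Taken together, a minimal end-to-end sketch of the API declared above; `s`, `A`, `B` and `C` are assumed to come from an existing schedule (the unit test at the bottom of this commit constructs them for real):

std::unordered_map<Tensor, Buffer> binds;
BuildConfig config;
Array<LoweredFunc> lowered = lower(s, { A, B, C }, "func", binds, config);
// Passing nullptr lets build() pick a default host target via DefaultTargetHost.
runtime::Module mod = build(lowered, target::llvm(), nullptr, config);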
@@ -291,6 +291,13 @@ inline const char* IterVarType2String(IterVarType t) {
  return "Unknown";
}

/*!
 * \brief Construct a new Var expression
 * \param name_hint The name hint for the expression
 * \param t The type of the expression
 */
TVM_DLL Var var(const std::string& name_hint, Type t = Int(32));

/*
 * \brief Template function to convert Map to unordered_map
 *  Sometimes useful for API gluing when internal uses unordered_map
......
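The new var() helper mirrors tvm.var on the Python side; a small usage sketch:

Var n = var("n");              // defaults to Int(32)
Var x = var("x", Float(32));   // explicit element type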
@@ -81,7 +81,7 @@ class Stage : public NodeRef {
   * \param thread_ivar The thread axis to be binded.
   * \return reference to self.
   */
-  Stage& bind(IterVar ivar, IterVar thread_ivar);
+  EXPORT Stage& bind(IterVar ivar, IterVar thread_ivar);
  /*!
   * \brief Set predicate under which store to the array can be performed.
   *  Use this when there are duplicated threads doing the same store and we only
@@ -110,7 +110,7 @@ class Stage : public NodeRef {
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
-  Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
+  EXPORT Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
  /*!
   * \brief Split the iteration with given number of parts.
   *
@@ -248,13 +248,13 @@ class Schedule : public NodeRef {
   * \brief Get the stage corresponds to the op
   * \param op The operation.
   */
-  Stage operator[](const Operation& op);
+  EXPORT Stage operator[](const Operation& op);
  /*!
   * \brief Short hand for getting the stage of tensor's operation.
   * \param tensor The tensor
   * \return The stage corresponding to the tensor's op
   */
-  Stage operator[](const Tensor& tensor) {
+  EXPORT Stage operator[](const Tensor& tensor) {
    return this->operator[](tensor->op);
  }
  /*!
@@ -493,7 +493,7 @@ class ScheduleNode : public Node {
   * \param ops The ops to be scheduled.
   * \return sch The created Schedule.
   */
-  static Schedule make(Array<Operation> ops);
+  EXPORT static Schedule make(Array<Operation> ops);

  static constexpr const char* _type_key = "Schedule";
  TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
......
/*!
 * Copyright (c) 2017 by Contributors
 * Compile executable modules.
 * \file build_module.cc
 */
#include <tvm/build_module.h>
#include <tvm/operation.h>
#include <tvm/ir_pass.h>
#include <tvm/codegen.h>
#include <sstream>

namespace tvm {

std::string Target::str() const {
  std::ostringstream result;
  result << target_name;
  for (const auto &x : options) {
    result << " " << x;
  }
  return result.str();
}

Target TargetFromName(const std::string& name) {
  if (name == "llvm") {
    return target::llvm();
  } else if (name == "cuda" || name == "nvptx") {
    return target::cuda();
  } else if (name == "rocm" || name == "opencl") {
    /* For now, assume rocm schedule for opencl */
    return target::rocm();
  } else if (name == "metal") {
    return target::metal();
  } else if (name == "stackvm" || name == "ext_dev") {
    return target::stackvm();
  } else {
    LOG(ERROR) << "Unknown target name " << name;
    return target::stackvm();
  }
}

bool StartsWith(const std::string& str, const std::string& pattern) {
  return str.compare(0, pattern.length(), pattern) == 0;
}

std::string GetDeviceName(const std::string& target_str) {
  std::istringstream ss(target_str);
  std::string target_name;
  ss >> target_name;

  std::string item;
  while (ss >> item) {
    if (StartsWith(item, "-device=")) {
      return item.substr(std::string("-device=").length());
    }
  }
  return "";
}

Target Target::create(const std::string& target_str) {
  if (target_str.length() == 0) {
    LOG(ERROR) << "target_str must not be empty";
  }

  std::istringstream ss(target_str);
  std::string target_name;
  ss >> target_name;

  auto device_name = GetDeviceName(target_str);
  auto result = device_name == "rasp" ?
    target::rasp() :
    TargetFromName(target_name);

  std::string item;
  while (ss >> item) {
    result.options.push_back(item);
  }

  return result;
}
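To illustrate the parsing rules, a sketch with a made-up option string (any word after the target name is copied into Target::options verbatim; only "-device=rasp" changes which preset is chosen):

Target t = Target::create("cuda -libs=cublas");  // "-libs=cublas" is hypothetical
// t.target_name == "cuda", t.options == { "-libs=cublas" },
// and t.str() round-trips to "cuda -libs=cublas".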
namespace target {

Target llvm() {
  std::unordered_set<std::string> keys({ "llvm", "cpu" });
  std::vector<std::string> options;
  return Target("llvm", kDLCPU, 512, 1, keys, options);
}

Target cuda() {
  std::unordered_set<std::string> keys({ "cuda", "gpu" });
  std::vector<std::string> options;
  return Target("cuda", kDLGPU, 512, 32, keys, options);
}

Target rocm() {
  std::unordered_set<std::string> keys({ "rocm", "gpu" });
  std::vector<std::string> options;
  return Target("rocm", kDLROCM, 256, 1, keys, options);
}

Target metal() {
  std::unordered_set<std::string> keys({ "gpu" });
  std::vector<std::string> options;
  return Target("metal", kDLMetal, 256, 1, keys, options);
}

Target rasp() {
  std::unordered_set<std::string> keys({ "llvm", "cpu" });
  std::vector<std::string> options({
    "-device=rasp",
    "-mtriple=armv7l-none-linux-gnueabihf",
    "-mcpu=cortex-a53",
    "-mattr=+neon"
  });
  return Target("llvm", kDLCPU, 512, 1, keys, options);
}

Target stackvm() {
  std::unordered_set<std::string> keys({ "stackvm", "cpu" });
  std::vector<std::string> options;
  return Target("stackvm", kDLCPU, 512, 1, keys, options);
}

}  // namespace target
bool LLVMEnabled() {
  const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm");
  return pf != nullptr;
}

/*! \return The default host target for a given device target */
Target DefaultTargetHost(Target target) {
  if (target.device_type == kDLCPU) {
    return target;
  } else {
    if (LLVMEnabled()) {
      return target::llvm();
    } else {
      return target::stackvm();
    }
  }
}
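A brief sketch of the fallback behaviour:

Target device = target::cuda();
Target host = DefaultTargetHost(device);
// host is target::llvm() when codegen.build_llvm is registered,
// otherwise target::stackvm(); CPU device targets serve as their own host.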
Buffer BufferWithOffsetAlignment(Array<Expr> shape,
                                 Type dtype,
                                 std::string name,
                                 int data_alignment,
                                 int offset_factor) {
  auto data = Var(name, Handle());

  Expr elem_offset;
  if (offset_factor != 0) {
    elem_offset = Var(name + "_elem_offset", shape[0].type());
  } else {
    elem_offset = Expr();
  }

  return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
                          data_alignment, offset_factor);
}

void GetBinds(const Array<Tensor>& args,
              const std::unordered_map<Tensor, Buffer>& binds,
              Map<Tensor, Buffer>* out_binds,
              Array<NodeRef>* out_arg_list,
              const BuildConfig& config) {
  *out_binds = binds;

  for (const auto &x : args) {
    if (out_binds->find(x) == out_binds->end()) {
      auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name,
                                           config.data_alignment, config.offset_factor);
      out_binds->Set(x, buf);
      out_arg_list->push_back(buf);
    } else {
      out_arg_list->push_back((*out_binds)[x]);
    }
  }
}

/*!
 * \brief Build a Stmt given a schedule, args and binds. This function runs the IR passes.
 * \param sch The schedule to build.
 * \param args The arguments for the schedule.
 * \param binds Buffer assignments.
 * \param loop_partition True if the LoopPartition pass should be included.
 * \param out_arg_list Returns the arguments for the Stmt.
 * \param config The build configuration.
 * \return The built Stmt.
 */
Stmt BuildStmt(Schedule sch,
               const Array<Tensor>& args,
               const std::unordered_map<Tensor, Buffer>& binds,
               bool loop_partition,
               Array<NodeRef> *out_arg_list,
               const BuildConfig& config) {
  Map<Tensor, Buffer> out_binds;
  GetBinds(args, binds, &out_binds, out_arg_list, config);

  sch = sch.normalize();

  // Phase 0
  auto bounds = schedule::InferBound(sch);
  auto stmt = schedule::ScheduleOps(sch, bounds);
  stmt = ir::InjectPrefetch(stmt);

  // Phase 1
  stmt = ir::StorageFlatten(stmt, out_binds, 64);
  stmt = ir::CanonicalSimplify(stmt);
  if (loop_partition) {
    stmt = ir::LoopPartition(stmt);
  }
  stmt = ir::VectorizeLoop(stmt);
  stmt = ir::InjectVirtualThread(stmt);
  stmt = ir::InjectDoubleBuffer(stmt, config.double_buffer_split_loop);
  stmt = ir::StorageRewrite(stmt);
  stmt = ir::UnrollLoop(stmt, config.auto_unroll_max_step, config.auto_unroll_max_depth,
                        config.auto_unroll_max_extent, config.unroll_explicit);

  // Phase 2
  stmt = ir::Simplify(stmt);
  stmt = ir::LowerStorageAccessInfo(stmt);
  stmt = ir::RemoveNoOp(stmt);
  stmt = ir::RewriteUnsafeSelect(stmt);

  return stmt;
}
Array<LoweredFunc> lower(Schedule sch,
                         const Array<Tensor>& args,
                         const std::string& name,
                         const std::unordered_map<Tensor, Buffer>& binds,
                         const BuildConfig& config) {
  Array<NodeRef> out_arg_list;
  auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
  return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config.restricted_func) });
}

runtime::Module build(const Array<LoweredFunc>& funcs,
                      const Target& target,
                      Target* target_host,
                      const BuildConfig& config) {
  std::unordered_set<std::string> all_names;
  for (const auto &x : funcs) {
    CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
    all_names.insert(x->name);
  }

  Target target_host_val = target_host == nullptr ?
    DefaultTargetHost(target) :
    *target_host;

  Array<LoweredFunc> fhost;
  Array<LoweredFunc> fdevice;

  for (const auto &x : funcs) {
    if (x->func_type == kMixedFunc) {
      auto func = x;
      if (config.detect_global_barrier) {
        func = ir::ThreadSync(func, "global");
      }

      func = ir::ThreadSync(func, "shared");
      func = ir::LowerThreadAllreduce(func, target.thread_warp_size);
      auto fsplits = ir::SplitHostDevice(func);
      fhost.push_back(fsplits[0]);
      for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
        fdevice.push_back(*f);
      }
    } else if (x->func_type == kHostFunc) {
      fhost.push_back(x);
    } else if (x->func_type == kDeviceFunc) {
      fdevice.push_back(x);
    } else {
      LOG(FATAL) << "unknown function type " << x->func_type;
    }
  }

  if (target.keys.count("gpu") > 0 && fdevice.size() == 0) {
    LOG(WARNING) << "Specified target " + target.str() +
      " but cannot find device code. Did you forget to bind?";
  }

  for (size_t i = 0; i < fhost.size(); ++i) {
    auto func = fhost[i];
    func = ir::BindDeviceType(func, target.device_type);
    func = ir::LowerTVMBuiltin(func);
    fhost.Set(i, func);
  }

  for (size_t i = 0; i < fdevice.size(); ++i) {
    auto func = fdevice[i];
    func = ir::LowerIntrin(func, target.target_name);
    fdevice.Set(i, func);
  }

  for (size_t i = 0; i < fhost.size(); ++i) {
    auto func = fhost[i];
    func = ir::LowerIntrin(func, target_host_val.target_name);
    func = ir::CombineContextCall(func);
    fhost.Set(i, func);
  }

  auto mhost = codegen::Build(fhost, target_host_val.str());

  if (fdevice.size() > 0) {
    auto mdev = codegen::Build(fdevice, target.str());
    mhost.Import(mdev);
  }

  return mhost;
}
} // namespace tvm
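When the default host selection is not wanted, a host target can be passed explicitly; a sketch (assumes `lowered` was produced by lower() as above):

Target host = target::llvm();
runtime::Module mod = build(lowered, target::cuda(), &host, BuildConfig());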
@@ -47,6 +47,10 @@ std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
  return os;
}

Var var(const std::string& name_hint, Type t) {
  return Var(name_hint, t);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
    p->stream << "iter_var(";
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/operation.h>
#include <tvm/build_module.h>

TEST(BuildModule, Basic) {
  using namespace tvm;
  auto n = var("n");
  Array<Expr> shape;
  shape.push_back(n);

  auto A = placeholder(shape, Float(32), "A");
  auto B = placeholder(shape, Float(32), "B");

  auto C = compute(A->shape, [&A, &B](Expr i) {
    return A[i] + B[i];
  }, "C");

  auto s = create_schedule({ C->op });

  auto cAxis = C->op.as<ComputeOpNode>()->axis;

  IterVar bx, tx;
  s[C].split(cAxis[0], 64, &bx, &tx);

  auto args = Array<Tensor>({ A, B, C });
  std::unordered_map<Tensor, Buffer> binds;

  BuildConfig config;
  auto target = target::llvm();

  auto lowered = lower(s, args, "func", binds, config);
  auto module = build(lowered, target, nullptr, config);
}

int main(int argc, char ** argv) {
  testing::InitGoogleTest(&argc, argv);
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  return RUN_ALL_TESTS();
}