Commit b131d836 by Bing Xu Committed by Jared Roesch

Relay C++ Build Module (#3082)

* [Relay] C++ Build module

* asdf
parent 472c3146
......@@ -344,6 +344,19 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch,
const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array
*/
TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
......
......@@ -423,7 +423,7 @@ Array<LoweredFunc> lower(Schedule sch,
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
}
runtime::Module build(const Array<LoweredFunc>& funcs,
Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
......@@ -493,6 +493,17 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
func = ir::CombineContextCall(func);
fhost.Set(i, func);
}
return {fhost, fdevice};
}
runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
auto& fhost = host_dev_funcs[0];
auto& fdevice = host_dev_funcs[1];
auto mhost = codegen::Build(fhost, target_host_val->str());
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file relay/backend/build_module.cc
* \brief Code generation for TVM's graph runtime.
*/
#include <tvm/build_module.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <vector>
#include <string>
#include <memory>
#include "utils.h"
namespace tvm {
namespace relay {
namespace backend {
/*!
 * \brief Context name / index
 * See: python/tvm/_ffi/runtime_ctypes.py
 */
struct ContextMap {
  static const std::unordered_map<int, std::string> mask2str;
  static const std::unordered_map<std::string, int> str2mask;
  /*! \brief Translate a device mask to its canonical name (fatal on unknown mask). */
  static std::string Mask2Str(int mask) {
    CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
    return mask2str.at(mask);
  }
  /*! \brief Translate a context/target name to its device mask (fatal on unknown name). */
  static int Str2Mask(const std::string& str) {
    CHECK_GT(str2mask.count(str), 0) << "Unknown context.";
    return str2mask.at(str);
  }
};

// Device mask -> canonical context name.
const std::unordered_map<int, std::string> ContextMap::mask2str = {
  {1, "cpu"},      {2, "gpu"},     {4, "opencl"},
  {5, "aocl"},     {6, "sdaccel"}, {7, "vulkan"},
  {8, "metal"},    {9, "vpi"},     {10, "rocm"},
  {11, "opengl"},  {12, "ext_dev"}
};

// Context/target name -> device mask; several aliases share one mask
// (e.g. "llvm"/"cpu"/"c" are all CPU, "cuda"/"nvptx"/"gpu" are all GPU).
const std::unordered_map<std::string, int> ContextMap::str2mask = {
  {"llvm", 1},   {"cpu", 1},         {"c", 1},
  {"gpu", 2},    {"cuda", 2},        {"nvptx", 2},
  {"cl", 4},     {"opencl", 4},
  {"aocl", 5},   {"aocl_sw_emu", 5},
  {"vulkan", 7}, {"metal", 8},       {"vpi", 9},
  {"rocm", 10},  {"opengl", 11},     {"ext_dev", 12}
};
/*!
 * \brief A data structure to map the names of specific optimizations to
 * numeric optimization levels
 *
 */
struct OptPassLevel {
  static const std::unordered_map<std::string, int> _data;
  /*!
   * \brief Get level for an optimization pass
   *
   * \param key pass name
   * \return int level, or -1 when the pass name is unknown
   */
  int operator[](const std::string& key) const {
    const auto it = _data.find(key);
    return it == _data.end() ? -1 : it->second;
  }
};

// Minimum opt_level at which each pass is enabled by default.
const std::unordered_map<std::string, int> OptPassLevel::_data = {
  {"SimplifyInference", 0},
  {"OpFusion", 1},
  {"FoldConstant", 2},
  {"CombineParallelConv2D", 3},
  {"FoldScaleAxis", 3},
  {"AlterOpLayout", 3},
  {"CanonicalizeOps", 3},
  {"EliminateCommonSubexpr", 3}
};
/*!
 * \brief Output of building module
 *
 */
struct BuildOutput {
  std::string graph_json;  // serialized graph-runtime JSON produced by codegen
  runtime::Module mod;     // compiled host module with device modules imported
  // parameter name -> bound NDArray collected during codegen
  std::unordered_map<std::string, tvm::runtime::NDArray> params;
};
/*!
 * \brief Relay building config
 *
 */
struct RelayBuildConfig {
  int opt_level{2};
  std::string fallback_device{"llvm"};
  std::unordered_set<std::string> enabled_pass;
  std::unordered_set<std::string> disabled_pass;
  OptPassLevel OPT_PASS_LEVEL;
  /*!
   * \brief Decide whether a pass should run.
   *
   * Explicit disables win over explicit enables; otherwise the pass runs
   * when the configured opt_level reaches its default level.
   */
  inline bool pass_enabled(const std::string& pass_name) const {
    if (disabled_pass.count(pass_name)) return false;
    if (enabled_pass.count(pass_name)) return true;
    return opt_level >= OPT_PASS_LEVEL[pass_name];
  }
};
/*!
 * \brief GraphCodegen module wrapper
 *
 * Thin wrapper over the FFI module returned by the global packed function
 * "relay.build_module._GraphRuntimeCodegen"; every method below simply
 * forwards to a packed function of that module.
 */
struct GraphCodegen {
 public:
  GraphCodegen() {
    // Instantiate the graph-runtime codegen module via the global registry.
    auto pf = GetPackedFunc("relay.build_module._GraphRuntimeCodegen");
    mod = (*pf)();
  }
  ~GraphCodegen() {}

  // Initialize the codegen.  `targets` is flattened into an array
  // [key0, val0, key1, val1, ...] because the packed-function boundary
  // here passes an Array rather than a Map.
  void Init(runtime::Module* m,
            Map<HalideIR::Expr, HalideIR::Expr> targets) {
    Array<HalideIR::Expr> tgts;
    for (auto kv : targets) {
      tgts.push_back(kv.first);
      tgts.push_back(kv.second);
    }
    CallFunc("init", m, tgts);
  }

  // Run graph codegen over an optimized (fused, type-inferred) function.
  void Codegen(const Function& func) {
    CallFunc("codegen", func);
  }

  // Serialized graph-runtime JSON produced by Codegen().
  std::string GetJSON() {
    return CallFunc<std::string>("get_graph_json", nullptr);
  }

  // Lowered functions grouped by target string.
  Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
    return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
  }

  // Collect parameters bound during codegen: fetch the name list, then
  // look each tensor up by name.
  std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
    std::unordered_map<std::string, tvm::runtime::NDArray> ret;
    auto names = CallFunc<Array<HalideIR::Expr> >("list_params_name", nullptr);
    for (auto expr : names) {
      auto key = expr.as<ir::StringImm>()->value;
      ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
    }
    return ret;
  }

 protected:
  tvm::runtime::Module mod;  // underlying FFI codegen module

  // Call a member packed function and convert its result to R.
  template<typename R, typename ...Args>
  R CallFunc(const std::string &name, Args... args) {
    auto pf = mod.GetFunction(name, false);
    return pf(std::forward<Args>(args)...);
  }
  // Overload used when the result is discarded.
  template<typename ...Args>
  void CallFunc(const std::string &name, Args... args) {
    auto pf = mod.GetFunction(name, false);
    pf(std::forward<Args>(args)...);
    return;
  }
};
// Call a globally-registered packed function and convert its result to R.
// R must be given explicitly, e.g. CallPackedFunc<Expr>("...", args).
template<typename R, typename ...Args>
R CallPackedFunc(const std::string &name, Args... args) {
  auto pf = GetPackedFunc(name);
  return (*pf)(std::forward<Args>(args)...);
}
// Overload selected when no explicit template argument is supplied: the
// result is converted to a relay::Function (the common case in this file,
// since most relay._ir_pass.* calls return a transformed function).
template<typename ...Args>
Function CallPackedFunc(const std::string &name, Args... args) {
  auto pf = GetPackedFunc(name);
  return (*pf)(std::forward<Args>(args)...);
}
/*!
* \brief Relay build module
*
*/
class RelayBuildModule : public runtime::ModuleNode {
public:
/*!
* \brief Get member function to front-end
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
if (name == "get_graph_json") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetGraphJSON();
});
} else if (name == "get_module") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetModule();
});
} else if (name == "build") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
Array<HalideIR::Expr> tmp = args[1];
std::unordered_map<std::string, std::string> targets;
for (size_t i = 0; i < tmp.size(); i += 2) {
auto k = tmp[i].as<ir::StringImm>()->value;
auto v = tmp[i + 1].as<ir::StringImm>()->value;
targets[k] = v;
}
this->Build(args[0], targets, args[2]);
});
} else if (name == "list_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->ListParamNames();
});
} else if (name == "get_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetParams();
});
} else if (name == "set_opt_level") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 1);
int level = args[0];
this->SetOptLevel(level);
});
} else if (name == "set_fallback_device") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string dev = args[0];
this->SetFallBackDev(dev);
});
} else if (name == "add_pass") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string pass_name = args[0];
this->AddPass(pass_name);
});
} else if (name == "disable_pass") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string pass_name = args[0];
this->DisablePass(pass_name);
});
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[0];
for (const auto& kv : params) {
this->SetParam(kv.first, kv.second->data);
}
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
/*!
* \brief Get the GraphJSON for runtime
*
* \return const std::string graph_json
*/
const std::string& GetGraphJSON() {
return ret_.graph_json;
}
/*!
* \brief Add extra pass into build cfg
*
* \param pass_name name of pass
*/
void AddPass(const std::string& pass_name) {
cfg_.enabled_pass.insert(pass_name);
}
/*!
* \brief Disable a specific pass in cfg
*
* \param pass_name name of pass
*/
void DisablePass(const std::string& pass_name) {
cfg_.disabled_pass.insert(pass_name);
}
/*!
* \brief Set the Fallback device
*
* \param device name
*/
void SetFallBackDev(const std::string& dev) {
cfg_.fallback_device = dev;
}
/*!
* \brief Get the Module object
*
* \return runtime::Module
*/
runtime::Module GetModule() {
return ret_.mod;
}
/*!
* \brief List all paramter names
*
* \return Array<StringImm> names of params
*/
Array<HalideIR::Expr> ListParamNames() {
Array<HalideIR::Expr> ret;
for (const auto& kv : params_) {
ret.push_back(ir::StringImm::make(kv.first));
}
return ret;
}
/*!
* \brief Get params dictionary
*
* \return Map<std::string, Constant> params dictionary
*/
Map<std::string, Constant> GetParams() {
Map<std::string, Constant> ret;
for (const auto& kv : ret_.params) {
ret.Set(kv.first, ConstantNode::make(kv.second));
}
return ret;
}
/*!
* \brief Set the parameters
*
* \param name name of parameter
* \param data_in input DLTensor
*/
void SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}
/*!
* \brief Set the optimization level
*
* \param level
*/
void SetOptLevel(char level) {
cfg_.opt_level = level;
}
/*!
* \brief type key
*
* \return const char*
*/
const char* type_key() const final {
return "RelayBuildModule";
}
/*!
* \brief Build relay function for graph runtime
*
* \param func Relay Function
* \param target Target device
* \param target_host Host target device
*/
void Build(Function func,
const std::unordered_map<std::string, std::string>& targets,
const std::string& target_host) {
targets_ = targets;
target_host_ = target_host;
BuildRelay(func, cfg_, params_);
}
protected:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}
std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
auto e = CallPackedFunc<Expr>("relay._make.Constant", kv.second);
bind_dict[arg] = e;
}
return CallPackedFunc("relay._expr.Bind", func, tvm::Map<relay::Var, Expr>(bind_dict));
}
/*!
* \brief Optimize Relay function
*
* \param func Input function
* \param target target device
* \param cfg Relay build config
* \param params params dict
* \return relay::Function
*/
relay::Function Optimize(relay::Function func,
const std::unordered_map<std::string, std::string>& targets,
const RelayBuildConfig& cfg,
const std::unordered_map<std::string, runtime::NDArray>& params) {
if (params.size()) {
func = BindParamsByName(func, params);
}
if (cfg.pass_enabled("SimplifyInference")) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.simplify_inference", func);
}
if (cfg.pass_enabled("EliminateCommonSubexpr")) {
auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == HalideIR::Int(32)) {
*rv = true;
}
}
}
*rv = false;
});
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip);
}
if (cfg.pass_enabled("CombineParallelConv2D")) {
const int min_num_branches = 3;
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches);
}
if (cfg.pass_enabled("FoldConstant")) {
func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
}
if (cfg.pass_enabled("FoldScaleAxis")) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func);
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func);
func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
}
if (cfg.pass_enabled("CanonicalizeOps")) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func);
}
if (cfg.pass_enabled("AlterOpLayout")) {
if (targets.size() == 1) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
} else {
LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
<< " execution yet.";
}
}
if (cfg.pass_enabled("FoldConstant")) {
func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
}
return func;
}
/*!
* \brief Update the target and fallback device required for heterogeneous
* compilation. CPU is used as the fallback device if it wasn't provided.
* Meanwhile, a CPU device type and "llvm" pair will be added to the target
* dictionary in this case.
*
* \param targets dictionary
* \param cfg
* \return Map<HalideIR::Expr, HalideIR::Expr>
*/
Map<HalideIR::Expr, HalideIR::Expr> UpdateHeterogeneousInputs(
const std::unordered_map<std::string, std::string>& targets,
const RelayBuildConfig& cfg) {
Map<HalideIR::Expr, HalideIR::Expr> device_target;
std::unordered_map<int64_t, std::string> tmp_map;
auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
for (const auto& kv : targets) {
tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second;
}
if (tmp_map.count(fallback_idx) == 0) {
tmp_map[fallback_idx] = cfg.fallback_device;
}
for (const auto& kv : tmp_map) {
device_target.Set(
ir::IntImm::make(HalideIR::Int(64), kv.first),
ir::StringImm::make(kv.second));
}
return device_target;
}
/*!
* \brief Execute the device annotation passes to update the input program and
* target information.
*
* \param func
* \param cfg
* \param targets_map_ptr
* \return Function
*/
Function RunDeviceAnnotationPass(
Function func,
const RelayBuildConfig& cfg,
Map<HalideIR::Expr, HalideIR::Expr>* targets_map_ptr) {
auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx);
auto device_map = CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceInfo",
func,
nullptr);
if (device_map.size() == 0) {
auto annotation_map =
CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceAnnotationOps",
func,
nullptr);
if (annotation_map.size() == 0) {
targets_map_ptr->Set(
ir::IntImm::make(HalideIR::Int(64), 0),
ir::StringImm::make(cfg.fallback_device));
} else {
int64_t dev_type = -1;
for (auto kv : annotation_map) {
dev_type = kv.second->value;
break;
}
for (auto kv : annotation_map) {
CHECK_EQ(kv.second->value, dev_type)
<< "Expressions in the function are "
<< "annotated with various device types,"
<< "but not device copy operators "
<< "found. Please check the "
<< "RewriteAnnotation pass.";
}
targets_map_ptr->Set(
ir::IntImm::make(HalideIR::Int(64), 0),
ir::StringImm::make(ContextMap::Mask2Str(dev_type)));
}
}
return func;
}
/*!
* \brief Build module given lowered functions for each target
*
* \param lowered_funcs target_str -> Array<LoweredFunc> map
* \param targets Targets map
* \param cfg Building configuration
*/
void BuildModule(const Map<std::string, Array<LoweredFunc> >& lowered_funcs,
const Map<HalideIR::Expr, HalideIR::Expr>& targets,
const BuildConfig& cfg) {
auto target_host = Target::create(cfg_.fallback_device);
for (const auto& kv : lowered_funcs) {
std::unordered_set<std::string> fname_set;
for (auto f : kv.second) {
if (fname_set.count(f->name)) {
LOG(FATAL) << "Duplicate function name "
<< f->name;
}
fname_set.insert(f->name);
}
}
std::unordered_map<std::string, Target> target_map;
for (const auto& kv : lowered_funcs) {
target_map[kv.first] = Target::create(kv.first);
}
Array<LoweredFunc> fhost_all;
std::vector<runtime::Module> device_module;
for (const auto& kv : lowered_funcs) {
auto target = target_map[kv.first];
auto host_dev_funcs = split_dev_host_funcs(kv.second, target, target_host, cfg);
for (auto f : host_dev_funcs[0]) {
fhost_all.push_back(f);
}
if (host_dev_funcs[1].size()) {
auto mdev = codegen::Build(host_dev_funcs[1], target->str());
device_module.push_back(mdev);
}
}
auto mhost = codegen::Build(fhost_all, target_host->str());
for (auto mdev : device_module) {
mhost.Import(mdev);
}
ret_.mod = mhost;
}
/*!
* \brief Build relay function to runtime module
*
* \param func Relay Function
* \param cfg Relay build config
* \param params parameters
*/
void BuildRelay(Function func,
const RelayBuildConfig& cfg,
const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
// convert
tvm_cfg_ = build_config();
Map<HalideIR::Expr, HalideIR::Expr> device_target;
if (targets_.size() > 1) {
device_target = UpdateHeterogeneousInputs(targets_, cfg);
} else {
for (auto &kv : targets_) {
device_target.Set(
ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)),
ir::StringImm::make(kv.second));
}
}
func = Optimize(func, targets_, cfg, params);
if (device_target.size() > 1) {
func = RunDeviceAnnotationPass(func, cfg, &device_target);
}
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level);
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
graph_codegen_->Init(nullptr, device_target);
graph_codegen_->Codegen(func);
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();
BuildModule(graph_codegen_->GetLoweredFunc(),
device_target,
tvm_cfg_);
}
protected:
std::unique_ptr<GraphCodegen> graph_codegen_;
/*! \brief target device */
std::unordered_map<std::string, std::string> targets_;
/*! \brief target host device */
std::string target_host_;
/*! \brief frontend optimization configure */
RelayBuildConfig cfg_;
/*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_;
/*! \brief building output */
BuildOutput ret_;
/*! \brief tvm building cfg */
BuildConfig tvm_cfg_;
};
/*! \brief Create a fresh RelayBuildModule wrapped as a runtime module. */
runtime::Module RelayBuildCreate() {
  auto exec = std::make_shared<RelayBuildModule>();
  return runtime::Module(exec);
}

// Expose the factory to the frontend under relay.build_module._BuildModule.
TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = RelayBuildCreate();
});
} // namespace backend
} // namespace relay
} // namespace tvm
......@@ -371,7 +371,9 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node->funcs = (*f)(
spair.first, all_args, cache_node->func_name, key->source_func);
} else {
LOG(FATAL) << "relay.backend.lower is not registred";
tvm::BuildConfig bcfg = tvm::build_config();
std::unordered_map<Tensor, Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
}
value->cached_func = CachedFunc(cache_node);
return value;
......
......@@ -416,7 +416,12 @@ class GraphRuntimeCodegen
} else {
// heterogeneous execution.
const auto call_dev_key = std::to_string(call_dev_type);
const auto call_dev_name = runtime::DeviceName(call_dev_type);
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
LOG(FATAL) << "No target is provided for device "
<< call_dev_name;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>
#include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
// Register a test schedule callback: forwards to the generic injective
// schedule so the "add" op can be lowered in the test below.
TVM_REGISTER_GLOBAL("test.sch")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
  *rv = topi::generic::schedule_injective(args[0], args[1]);
});
// End-to-end test of the C++ RelayBuildModule: build (a + b) + c for the
// CPU target, run it through the graph runtime, and check the output.
TEST(Relay, BuildModule) {
  using namespace tvm;
  // Construct the relay function (a + b) + c on 2x3 float32 tensors.
  auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
  auto a = relay::VarNode::make("a", tensor_type);
  auto b = relay::VarNode::make("b", tensor_type);
  auto add_op = relay::Op::Get("add");
  auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
  auto c = relay::VarNode::make("c", tensor_type);
  auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
  auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {});
  // Input data.
  auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
  auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
  auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
  // Access the underlying buffers directly.  The original code went through
  // ToDLPack(), which allocates a DLManagedTensor that was never freed, and
  // used C-style casts.
  auto pA = static_cast<float*>(A->data);
  auto pB = static_cast<float*>(B->data);
  auto pC = static_cast<float*>(C->data);
  for (int i = 0; i < 6; ++i) {
    pA[i] = i;
    pB[i] = i + 1;
    pC[i] = i + 2;
  }
  // Register a schedule for the "add" op so it can be lowered.
  auto reg = tvm::runtime::Registry::Get("relay.op._Register");
  auto s_i = tvm::runtime::Registry::Get("test.sch");
  if (!reg) {
    LOG(FATAL) << "no _Register";
  }
  if (!s_i) {
    // Fixed: this message used to duplicate "no _Register".
    LOG(FATAL) << "no test.sch";
  }
  (*reg)("add", "FTVMSchedule", *s_i, 10);
  // Build through the packed-function interface of RelayBuildModule.
  auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
  tvm::runtime::Module build_mod = (*pfb)();
  auto build_f = build_mod.GetFunction("build", false);
  auto json_f = build_mod.GetFunction("get_graph_json", false);
  auto mod_f = build_mod.GetFunction("get_module", false);
  // Targets are passed flattened as [key, value].
  Array<HalideIR::Expr> target_pair;
  target_pair.push_back(ir::StringImm::make("cpu"));
  target_pair.push_back(ir::StringImm::make("llvm"));
  build_f(func, target_pair, "llvm");
  std::string json = json_f();
  tvm::runtime::Module mod = mod_f();
  // Run the built module through the graph runtime.
  auto ctx = A->ctx;
  auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
  tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id);
  auto set_input_f = run_mod.GetFunction("set_input", false);
  auto run_f = run_mod.GetFunction("run", false);
  auto get_output_f = run_mod.GetFunction("get_output", false);
  set_input_f("a", A);
  set_input_f("b", B);
  set_input_f("c", C);
  run_f();
  tvm::runtime::NDArray Y = get_output_f(0);
  auto pY = static_cast<float*>(Y->data);
  // Expected: Y[i] = i + (i + 1) + (i + 2).
  for (int i = 0; i < 6; ++i) {
    CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
  }
}
// gtest entry point.
int main(int argc, char ** argv) {
  testing::InitGoogleTest(&argc, argv);
  // Thread-safe death tests: TVM code under test may spawn threads.
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  return RUN_ALL_TESTS();
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm._ffi.function import _init_api
_init_api("tvm.relay.build_module")
class BuildModule(object):
    """Python wrapper over the C++ RelayBuildModule exposed through the FFI.

    Each attribute cached in ``__init__`` is a packed function of the
    underlying runtime module.
    """

    def __init__(self):
        self.mod = relay.build_module._BuildModule()
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._build = self.mod["build"]
        self._set_opt_level = self.mod["set_opt_level"]
        self._set_params_func = self.mod["set_params"]
        self._get_params_func = self.mod["get_params"]

    def build(self, func, target, target_host, params):
        """Bind params and build func for the given target dict."""
        # The FFI expects the target dict flattened as [k0, v0, k1, v1, ...].
        tgts = []
        for key, value in target.items():
            tgts.extend([key, value])
        self._set_params(params)
        self._build(func, tgts, target_host)

    def get_json(self):
        """Return the graph-runtime JSON produced by the last build."""
        return self._get_graph_json()

    def get_module(self):
        """Return the compiled runtime module."""
        return self._get_module()

    def set_opt_level(self, level):
        """Set the optimization level used by subsequent builds."""
        self._set_opt_level(level)

    def _set_params(self, params):
        # NDArrays must be wrapped as relay Constants before crossing the FFI.
        self._set_params_func({name: relay.Constant(param)
                               for name, param in params.items()})

    def get_params(self):
        """Return post-build parameters as a name -> NDArray dict."""
        params = self._get_params_func()
        return {key: value.data for key, value in params.items()}
def test_build():
    """End-to-end check of the C++ BuildModule: dense + relu + add on CPU."""
    # `import tvm` alone does not reliably expose tvm.contrib.graph_runtime;
    # import it explicitly (local import keeps the module header unchanged).
    from tvm.contrib import graph_runtime
    m_bld = BuildModule()
    tgt = "llvm"
    ctx = tvm.cpu()
    # Build the relay function: relu(dense(a, b)) + c.
    a = relay.var("a", dtype="float32", shape=(16, 8))
    b = relay.var("b", dtype="float32", shape=(8, 8))
    c = relay.var("c", dtype="float32", shape=(16, 8))
    x = relay.nn.dense(a, b)
    y = relay.nn.relu(x)
    z = y + c
    func = relay.Function([a, b, c], z)
    A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
    B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32"), ctx=ctx)
    C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
    # b and c are bound as constants; a stays a runtime input.
    params = {
        "b": B,
        "c": C
    }
    # build
    targets = {
        tgt: tgt
    }
    m_bld.set_opt_level(3)
    m_bld.build(func, targets, "llvm -mcpu=sse3", params=params)
    g_json = m_bld.get_json()
    mmod = m_bld.get_module()
    params = m_bld.get_params()
    # Run and compare against the numpy reference.
    rt = graph_runtime.create(g_json, mmod, ctx)
    rt.set_input("a", A)
    rt.load_params(relay.save_param_dict(params))
    rt.run()
    out = rt.get_output(0)
    np.testing.assert_allclose(
        out.asnumpy(),
        np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(),
        atol=1e-5, rtol=1e-5)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment