Unverified Commit 415a270d by Tianqi Chen Committed by GitHub

[C++][API] Consistent RAII scoping API. (#3231)

parent b2f8b96a
......@@ -290,14 +290,14 @@ class CanonicalSimplifier {
};
/*!
* \brief A RAII constraint context.
* \brief Constraint context.
*
* \code
*
* Var("x");
* arith::Analyzer analyzer;
* {
* arith::ConstraintContext cctx(&analyzer, x % 3 == 0);
* With<arith::ConstraintContext> scope(&analyzer, x % 3 == 0);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* }
* // constraint no longer in effect.
......@@ -306,19 +306,24 @@ class CanonicalSimplifier {
* \endcode
*/
class ConstraintContext {
public:
private:
// declare friend to enable with.
friend class With<ConstraintContext>;
/*!
* \brief Construct a constraint context.
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION;
/*! \brief destructor */
~ConstraintContext() DMLC_THROW_EXCEPTION {
exit_();
}
private:
ConstraintContext(Analyzer* analyzer, Expr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
// enter the scope.
void EnterWithScope();
// exit the scope.
void ExitWithScope();
/*! \brief The analyzer */
Analyzer* analyzer_;
/*! \brief The constraint */
Expr constraint_;
/*! \brief function to be called in recovery */
std::function<void()> exit_;
};
......
......@@ -102,6 +102,50 @@ using ::tvm::AttrVisitor;
};
/*!
* \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax.
*
* \code
* // context class
* class MyContext {
* private:
* friend class With<MyContext>;
MyContext(arguments);
* void EnterWithScope();
* void ExitWithScope();
* };
*
* {
* With<MyContext> scope(arguments);
* // effect take place.
* }
* \endcode
*
* \tparam ContextType Type of the context object.
*/
template<typename ContextType>
class With {
public:
/*!
* \brief constructor.
* Enter the scope of the context.
*/
template<typename ...Args>
explicit With(Args&& ...args)
: ctx_(std::forward<Args>(args)...) {
ctx_.EnterWithScope();
}
/*! \brief destructor, leaves the scope of the context. */
~With() DMLC_THROW_EXCEPTION {
ctx_.ExitWithScope();
}
private:
/*! \brief internal context type. */
ContextType ctx_;
};
/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
......
......@@ -37,7 +37,7 @@ namespace tvm {
/*!
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*/
class TargetNode : public Node {
public:
......@@ -89,65 +89,47 @@ class TargetNode : public Node {
mutable std::string str_repr_;
};
/*! \brief reference cpass to the target. */
class Target : public NodeRef {
public:
Target() {}
explicit Target(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
*/
TVM_DLL static Target create(const std::string& target_str);
/*!
* \brief Push a new target context onto the thread local stack. The Target on top of
* the stack is used to determine which specialization to use when invoking a GenericFunc.
* \param target The target to set as the current context.
*/
TVM_DLL static void EnterTargetScope(const tvm::Target& target);
/*!
* \brief Pop a target off the thread local context stack, restoring the previous target
* as the current context.
*/
TVM_DLL static void ExitTargetScope();
TVM_DLL static Target Create(const std::string& target_str);
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
TVM_DLL static tvm::Target current_target(bool allow_not_defined = true);
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
inline const TargetNode* operator->() const {
const TargetNode* operator->() const {
return static_cast<const TargetNode*>(node_.get());
}
using ContainerType = TargetNode;
};
/*!
* \brief RAII container to provide a scoped target context. Pushes a target onto the
* context stack when constructed, and pops it when destructed.
*/
struct TargetContext {
class Internal;
private:
// enable with syntax.
friend class Internal;
friend class With<Target>;
/*!
* \brief Enter a new target context. The given target becomes the new current context.
* When the TargetContext is destructed, the previous context is restored.
* \param target The target to set as the new current context.
* \brief Push a new target context onto the thread local stack.
* The Target on top of the stack is used to determine which
* specialization to use when invoking a GenericFunc.
*/
explicit TargetContext(const tvm::Target& target) {
Target::EnterTargetScope(target);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~TargetContext() {
Target::ExitTargetScope();
}
TVM_DLL void EnterWithScope();
/*!
* \brief Pop a target off the thread local context stack,
* restoring the previous target as the current context.
*/
TVM_DLL void ExitWithScope();
};
/*! \brief This namespace provides functions to construct Target instances */
......@@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options =
} // namespace target
class BuildConfig;
/*!
* \brief Container for build configuration options
*/
* \brief Container for build configuration options
*/
class BuildConfigNode : public Node {
public:
/*!
......@@ -271,70 +251,49 @@ class BuildConfigNode : public Node {
};
/*!
* \brief Container for build configuration options
*/
* \brief Build configuration for compilations.
*/
class BuildConfig : public ::tvm::NodeRef {
public:
BuildConfig() {}
explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get());
}
BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(node_.get());
}
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
* \brief Construct a BuildConfig containing a empty build config node.
* \return The new BuildConfig
*/
TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
TVM_DLL static void ExitBuildConfigScope();
TVM_DLL static BuildConfig Create();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
TVM_DLL static tvm::BuildConfig Current();
TVM_DLL static BuildConfig Current();
using ContainerType = BuildConfigNode;
};
class Internal;
/*!
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
* context stack when constructed, and pops it when destructed.
*/
struct BuildConfigContext {
private:
// Enable with syntax.
friend class With<BuildConfig>;
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
* \brief Push a new BuildConfig context onto the thread local stack.
*/
explicit BuildConfigContext(const tvm::BuildConfig& build_config) {
BuildConfig::EnterBuildConfigScope(build_config);
}
TVM_DLL void EnterWithScope();
/*! \brief Destructor. Pops the context off the thread local stack. */
~BuildConfigContext() {
BuildConfig::ExitBuildConfigScope();
}
/*!
* \brief Pop a build config off the thread local context stack,
* restoring the previous configuration as the current context.
*/
TVM_DLL void ExitWithScope();
};
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
TVM_DLL BuildConfig build_config();
/*!
* \brief Build a LoweredFunc given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
......
......@@ -187,7 +187,7 @@ class BuildConfig(NodeBase):
def __exit__(self, ptype, value, trace):
if self.dump_pass_ir:
BuildConfig._dump_ir.exit()
_api_internal._ExitBuildConfigScope()
_api_internal._ExitBuildConfigScope(self)
def __setattr__(self, name, value):
if name in BuildConfig._node_defaults:
......
......@@ -133,7 +133,7 @@ class Target(NodeBase):
return self
def __exit__(self, ptype, value, trace):
_api_internal._ExitTargetScope()
_api_internal._ExitTargetScope(self)
@register_node
......
......@@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto ctx =
std::shared_ptr<ConstraintContext>(new ConstraintContext(self.get(), args[0]));
auto ctx = std::shared_ptr<With<ConstraintContext> >(
new With<ConstraintContext>(self.get(), args[0]));
auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
ctx.reset();
};
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) {
// skip rewrite simplify
}
ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
void ConstraintContext::EnterWithScope() {
CHECK(exit_ == nullptr);
// entering the scope.
auto f0 = analyzer->const_int_bound.EnterConstraint(constraint);
auto f1 = analyzer->modular_set.EnterConstraint(constraint);
auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
// recovery function.
exit_ = [f0, f1]() {
if (f1 != nullptr) f1();
......@@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
};
}
void ConstraintContext::ExitWithScope() {
CHECK(exit_ != nullptr);
exit_();
}
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound;
......
......@@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) {
Expr cond = Mutate(op->condition);
Expr true_value, false_value;
{
ConstraintContext constraint(parent_, cond);
With<ConstraintContext> constraint(parent_, cond);
true_value = Mutate(op->true_value);
}
{
ConstraintContext constraint(parent_, Mutate(Not::make(cond)));
With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->false_value);
}
if (is_zero(cond)) {
......@@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) {
Expr cond = Mutate(op->args[0]);
Expr true_value, false_value;
{
ConstraintContext constraint(parent_, cond);
With<ConstraintContext> constraint(parent_, cond);
true_value = Mutate(op->args[1]);
}
{
ConstraintContext constraint(parent_, Mutate(Not::make(cond)));
With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->args[2]);
}
if (is_zero(cond)) {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator {
Expr condition = this->Mutate(op->condition);
Stmt then_case, else_case;
{
ConstraintContext ctx(&analyzer_, condition);
With<ConstraintContext> ctx(&analyzer_, condition);
then_case = this->Mutate(op->then_case);
}
if (op->else_case.defined()) {
ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition)));
With<ConstraintContext> ctx(&analyzer_, Mutate(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return then_case;
......@@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator {
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
ConstraintContext ctx(&analyzer_, condition);
With<ConstraintContext> ctx(&analyzer_, condition);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) &&
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* Compile executable modules.
* \file build_module.cc
*/
......@@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate")
TVM_REGISTER_API("_TargetFromString")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_str = args[0];
*ret = Target::create(target_str);
*ret = Target::Create(target_str);
});
std::vector<std::string> TargetNode::keys() const {
......@@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) {
return "";
}
Target Target::create(const std::string& target_str) {
Target Target::Create(const std::string& target_str) {
if (target_str.length() == 0) {
LOG(ERROR) << "target_str must not be empty";
}
......@@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) {
struct TVMTargetThreadLocalEntry {
/*! \brief The current target context */
std::stack<tvm::Target> context_stack;
TVMTargetThreadLocalEntry() {
}
};
/*! \brief Thread local store to hold the Target context stack. */
typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
void Target::EnterTargetScope(const tvm::Target& target) {
void Target::EnterWithScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
entry->context_stack.push(target);
entry->context_stack.push(*this);
}
void Target::ExitTargetScope() {
void Target::ExitWithScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
tvm::Target Target::current_target(bool allow_not_defined) {
tvm::Target Target::Current(bool allow_not_defined) {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
......@@ -574,7 +571,7 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
const BuildConfig& config) {
Map<Target, Array<LoweredFunc>> updated_input;
for (const auto& it : inputs) {
auto target = Target::create(it.first);
auto target = Target::Create(it.first);
updated_input.Set(target, it.second);
}
return build(updated_input, target_host, config);
......@@ -589,33 +586,35 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
return build(inputs, target_host, config);
}
BuildConfig build_config() {
BuildConfig BuildConfig::Create() {
return BuildConfig(make_node<BuildConfigNode>());
}
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMBuildConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */
tvm::BuildConfig default_config;
BuildConfig default_config;
/*! \brief The current build config context */
std::stack<tvm::BuildConfig> context_stack;
std::stack<BuildConfig> context_stack;
TVMBuildConfigThreadLocalEntry() :
default_config(build_config()) {
default_config(BuildConfig::Create()) {
}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
void BuildConfig::EnterBuildConfigScope(const tvm::BuildConfig& build_config) {
void BuildConfig::EnterWithScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.push(build_config);
entry->context_stack.push(*this);
}
void BuildConfig::ExitBuildConfigScope() {
void BuildConfig::ExitWithScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
......@@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
auto node = static_cast<GenericFuncNode*>(node_.get());
auto target = Target::current_target(true);
auto target = Target::Current(true);
PackedFunc func;
if (target.defined()) {
......@@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig")
*ret = BuildConfig::Current();
});
class BuildConfig::Internal {
public:
static void EnterScope(BuildConfig target) {
target.EnterWithScope();
}
static void ExitScope(BuildConfig target) {
target.ExitWithScope();
}
};
TVM_REGISTER_API("_EnterBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig target = args[0];
BuildConfig::EnterBuildConfigScope(target);
});
.set_body_typed(BuildConfig::Internal::EnterScope);
TVM_REGISTER_API("_ExitBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig::ExitBuildConfigScope();
});
.set_body_typed(BuildConfig::Internal::ExitScope);
TVM_REGISTER_API("_BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......@@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc")
TVM_REGISTER_API("_GetCurrentTarget")
.set_body([](TVMArgs args, TVMRetValue* ret) {
bool allow_not_defined = args[0];
*ret = Target::current_target(allow_not_defined);
*ret = Target::Current(allow_not_defined);
});
class Target::Internal {
public:
static void EnterScope(Target target) {
target.EnterWithScope();
}
static void ExitScope(Target target) {
target.ExitWithScope();
}
};
TVM_REGISTER_API("_EnterTargetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Target target = args[0];
Target::EnterTargetScope(target);
});
.set_body_typed(Target::Internal::EnterScope);
TVM_REGISTER_API("_ExitTargetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Target::ExitTargetScope();
});
.set_body_typed(Target::Internal::ExitScope);
} // namespace tvm
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
std::string cmd = "aoc aocl.cl";
// AOCL supports fp64.
cmd += " -Dcl_khr_fp64";
Target target = Target::create(target_str);
Target target = Target::Create(target_str);
if (target->device_name != "") {
cmd += " -board=" + target->device_name;
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
std::string xclbin;
if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
Target target = Target::create(target_str);
Target target = Target::Create(target_str);
xclbin = (*f)(kernel_info, target->device_name).operator std::string();
} else {
LOG(FATAL) << "Cannot compile Vivado HLS code.";
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
}
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
arith::ConstraintContext cctx(analyzer_.get(), op->condition);
With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
this->VisitStmt(op->body);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
}
void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
arith::ConstraintContext cctx(analyzer_.get(), op->condition);
With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
this->VisitStmt(op->body);
}
......
......@@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (targets.size() == 1) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
for (const auto& kv : targets) {
TargetContext tctx(kv.second);
With<Target> tctx(kv.second);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
}
} else {
......@@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode {
*/
Target CreateDefaultTarget(int device_type) {
std::string name = runtime::DeviceName(device_type);
if (name == "cpu") return Target::create("llvm");
if (name == "gpu") return Target::create("cuda");
return Target::create(name);
if (name == "cpu") return Target::Create("llvm");
if (name == "gpu") return Target::Create("cuda");
return Target::Create(name);
}
/*!
* \brief Update the target and fallback device required for heterogeneous
......@@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode {
const RelayBuildConfig& cfg,
const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
// convert
tvm_cfg_ = build_config();
tvm_cfg_ = BuildConfig::Create();
TargetsMap device_target;
if (targets_.size() > 1) {
device_target = UpdateHeterogeneousInputs(targets_, cfg);
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_[key] = value;
}
// Enforce use the target.
TargetContext target_ctx(key->target);
With<Target> target_scope(key->target);
CHECK(!value->cached_func.defined());
auto spair = CreateSchedule(key->source_func, key->target);
......@@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node->funcs = (*f)(
spair.first, all_args, cache_node->func_name, key->source_func);
} else {
tvm::BuildConfig bcfg = tvm::build_config();
tvm::BuildConfig bcfg = BuildConfig::Create();
std::unordered_map<Tensor, Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
}
......
......@@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// Next generate the invoke instruction.
CHECK(func->IsPrimitive());
auto target = Target::create("llvm");
auto target = Target::Create("llvm");
auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine->Lower(key);
// TODO(jroesch): support lowered funcs for multiple targets
......@@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
runtime::Module mod;
if (lowered_funcs.size() > 0) {
// TODO(@jroesch): we need to read target from build config
Target target = Target::create("llvm");
Target target = Target::Create("llvm");
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target);
} else {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) {
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::create("llvm");
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
BuildConfigContext fresh_build_ctx(build_config());
With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return ConstantFolder(CreateInterpreter(
Module(nullptr), ctx, target)).Mutate(expr);
......
......@@ -375,10 +375,10 @@ DLContext CPUContext() {
}
FInterpreter CPUInterpreter() {
Target target = Target::create("llvm");
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
BuildConfigContext fresh_build_ctx(build_config());
With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return CreateInterpreter(Module(nullptr), CPUContext(), target);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -50,14 +50,14 @@ TEST(BuildModule, Basic) {
auto args = Array<Tensor>({ A, B, C });
std::unordered_map<Tensor, Buffer> binds;
auto config = build_config();
auto config = BuildConfig::Create();
auto target = target::llvm();
auto lowered = lower(s, args, "func", binds, config);
auto module = build(lowered, target, Target(), config);
auto mali_target = Target::create("opencl -model=Mali-T860MP4@800Mhz -device=mali");
CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali");
auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali");
CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali");
}
TEST(BuildModule, Heterogeneous) {
......@@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) {
auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
auto s2 = create_schedule({elemwise_sub->op});
auto config = build_config();
auto config = BuildConfig::Create();
auto args1 = Array<Tensor>({A, B, elemwise_add});
auto args2 = Array<Tensor>({copy, C, elemwise_sub});
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -75,7 +75,7 @@ TEST(Relay, BuildModule) {
auto json_f = build_mod.GetFunction("get_graph_json", false);
auto mod_f = build_mod.GetFunction("get_module", false);
Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::create("llvm");
Target llvm_tgt = Target::Create("llvm");
targets.Set(0, llvm_tgt);
build_f(func, targets, llvm_tgt);
std::string json = json_f();
......
......@@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) {
TVM_REGISTER_GLOBAL("topi.TEST_create_target")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tvm::Target::create(args[0]);
*rv = tvm::Target::Create(args[0]);
});
/* Ops from broadcast.h */
......@@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function<
*/
inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::current_target(false);
auto target = Target::Current(false);
Array<Tensor> outs;
NodeRef argNodeRef = args[0];
if (argNodeRef->type_index() == outs->type_index()) {
......@@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
*/
inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::current_target(false);
auto target = Target::Current(false);
Tensor data = args[0];
Tensor weight = args[1];
Tensor bias = args[2];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment