Unverified Commit 415a270d by Tianqi Chen Committed by GitHub

[C++][API] Consistent RAII scoping API. (#3231)

parent b2f8b96a
...@@ -290,14 +290,14 @@ class CanonicalSimplifier { ...@@ -290,14 +290,14 @@ class CanonicalSimplifier {
}; };
/*! /*!
* \brief A RAII constraint context. * \brief Constraint context.
* *
* \code * \code
* *
* Var("x"); * Var("x");
* arith::Analyzer analyzer; * arith::Analyzer analyzer;
* { * {
* arith::ConstraintContext cctx(&analyzer, x % 3 == 0); * With<arith::ConstraintContext> scope(&analyzer, x % 3 == 0);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3); * CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* } * }
* // constraint no longer in effect. * // constraint no longer in effect.
...@@ -306,19 +306,24 @@ class CanonicalSimplifier { ...@@ -306,19 +306,24 @@ class CanonicalSimplifier {
* \endcode * \endcode
*/ */
class ConstraintContext { class ConstraintContext {
public: private:
// declare friend to enable with.
friend class With<ConstraintContext>;
/*! /*!
* \brief Construct a constraint context. * \brief Construct a constraint context.
* \param analyzer The analyzer. * \param analyzer The analyzer.
* \param constraint The constraint to be applied. * \param constraint The constraint to be applied.
*/ */
ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION; ConstraintContext(Analyzer* analyzer, Expr constraint)
/*! \brief destructor */ : analyzer_(analyzer), constraint_(constraint) {}
~ConstraintContext() DMLC_THROW_EXCEPTION { // enter the scope.
exit_(); void EnterWithScope();
} // exit the scope.
void ExitWithScope();
private: /*! \brief The analyzer */
Analyzer* analyzer_;
/*! \brief The constraint */
Expr constraint_;
/*! \brief function to be called in recovery */ /*! \brief function to be called in recovery */
std::function<void()> exit_; std::function<void()> exit_;
}; };
......
...@@ -102,6 +102,50 @@ using ::tvm::AttrVisitor; ...@@ -102,6 +102,50 @@ using ::tvm::AttrVisitor;
}; };
/*! /*!
* \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax.
*
* \code
* // context class
* class MyContext {
* private:
* friend class With<MyContext>;
MyContext(arguments);
* void EnterWithScope();
* void ExitWithScope();
* };
*
* {
* With<MyContext> scope(arguments);
* // effect take place.
* }
* \endcode
*
* \tparam ContextType Type of the context object.
*/
template<typename ContextType>
class With {
public:
/*!
* \brief constructor.
* Enter the scope of the context.
*/
template<typename ...Args>
explicit With(Args&& ...args)
: ctx_(std::forward<Args>(args)...) {
ctx_.EnterWithScope();
}
/*! \brief destructor, leaves the scope of the context. */
~With() DMLC_THROW_EXCEPTION {
ctx_.ExitWithScope();
}
private:
/*! \brief internal context type. */
ContextType ctx_;
};
/*!
* \brief save the node as well as all the node it depends on as json. * \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object * This can be used to serialize any TVM object
* *
......
...@@ -37,7 +37,7 @@ namespace tvm { ...@@ -37,7 +37,7 @@ namespace tvm {
/*! /*!
* \brief Container for target device information. * \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly. * Use target::llvm, target::cuda etc functions instead of constructing directly.
*/ */
class TargetNode : public Node { class TargetNode : public Node {
public: public:
...@@ -89,65 +89,47 @@ class TargetNode : public Node { ...@@ -89,65 +89,47 @@ class TargetNode : public Node {
mutable std::string str_repr_; mutable std::string str_repr_;
}; };
/*! \brief reference cpass to the target. */
class Target : public NodeRef { class Target : public NodeRef {
public: public:
Target() {} Target() {}
explicit Target(NodePtr<Node> n) : NodeRef(n) {} explicit Target(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief Create a Target given a string * \brief Create a Target given a string
* \param target_str the string to parse * \param target_str the string to parse
*/ */
TVM_DLL static Target create(const std::string& target_str); TVM_DLL static Target Create(const std::string& target_str);
/*!
* \brief Push a new target context onto the thread local stack. The Target on top of
* the stack is used to determine which specialization to use when invoking a GenericFunc.
* \param target The target to set as the current context.
*/
TVM_DLL static void EnterTargetScope(const tvm::Target& target);
/*!
* \brief Pop a target off the thread local context stack, restoring the previous target
* as the current context.
*/
TVM_DLL static void ExitTargetScope();
/*! /*!
* \brief Get the current target context from thread local storage. * \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an * \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a * undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error. * runtime error.
* \return The target that is the current context. The target may not be defined if * \return The target that is the current context. The target may not be defined if
* allow_not_defined is true. * allow_not_defined is true.
*/ */
TVM_DLL static tvm::Target current_target(bool allow_not_defined = true); TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
inline const TargetNode* operator->() const { const TargetNode* operator->() const {
return static_cast<const TargetNode*>(node_.get()); return static_cast<const TargetNode*>(node_.get());
} }
using ContainerType = TargetNode; using ContainerType = TargetNode;
}; class Internal;
private:
/*! // enable with syntax.
* \brief RAII container to provide a scoped target context. Pushes a target onto the friend class Internal;
* context stack when constructed, and pops it when destructed. friend class With<Target>;
*/
struct TargetContext {
/*! /*!
* \brief Enter a new target context. The given target becomes the new current context. * \brief Push a new target context onto the thread local stack.
* When the TargetContext is destructed, the previous context is restored. * The Target on top of the stack is used to determine which
* \param target The target to set as the new current context. * specialization to use when invoking a GenericFunc.
*/ */
explicit TargetContext(const tvm::Target& target) { TVM_DLL void EnterWithScope();
Target::EnterTargetScope(target); /*!
} * \brief Pop a target off the thread local context stack,
* restoring the previous target as the current context.
/*! \brief Destructor. Pops the context off the thread local stack. */ */
~TargetContext() { TVM_DLL void ExitWithScope();
Target::ExitTargetScope();
}
}; };
/*! \brief This namespace provides functions to construct Target instances */ /*! \brief This namespace provides functions to construct Target instances */
...@@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options = ...@@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options =
} // namespace target } // namespace target
class BuildConfig;
/*! /*!
* \brief Container for build configuration options * \brief Container for build configuration options
*/ */
class BuildConfigNode : public Node { class BuildConfigNode : public Node {
public: public:
/*! /*!
...@@ -271,70 +251,49 @@ class BuildConfigNode : public Node { ...@@ -271,70 +251,49 @@ class BuildConfigNode : public Node {
}; };
/*! /*!
* \brief Container for build configuration options * \brief Build configuration for compilations.
*/ */
class BuildConfig : public ::tvm::NodeRef { class BuildConfig : public ::tvm::NodeRef {
public: public:
BuildConfig() {} BuildConfig() {}
explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const { const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get()); return static_cast<const BuildConfigNode*>(node_.get());
} }
BuildConfigNode* operator->() { BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(node_.get()); return static_cast<BuildConfigNode*>(node_.get());
} }
/*! /*!
* \brief Push a new BuildConfig context onto the thread local stack. * \brief Construct a BuildConfig containing a empty build config node.
* \param build_config The configuration to set as the current context. * \return The new BuildConfig
*/ */
TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config); TVM_DLL static BuildConfig Create();
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
TVM_DLL static void ExitBuildConfigScope();
/*! /*!
* \brief Get the current BuildConfig context from thread local storage, or a default * \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered. * configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context. * \return The configuration that is the current context.
*/ */
TVM_DLL static tvm::BuildConfig Current(); TVM_DLL static BuildConfig Current();
using ContainerType = BuildConfigNode; using ContainerType = BuildConfigNode;
}; class Internal;
/*! private:
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the // Enable with syntax.
* context stack when constructed, and pops it when destructed. friend class With<BuildConfig>;
*/
struct BuildConfigContext {
/*! /*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current * \brief Push a new BuildConfig context onto the thread local stack.
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
*/ */
explicit BuildConfigContext(const tvm::BuildConfig& build_config) { TVM_DLL void EnterWithScope();
BuildConfig::EnterBuildConfigScope(build_config);
}
/*! \brief Destructor. Pops the context off the thread local stack. */ /*!
~BuildConfigContext() { * \brief Pop a build config off the thread local context stack,
BuildConfig::ExitBuildConfigScope(); * restoring the previous configuration as the current context.
} */
TVM_DLL void ExitWithScope();
}; };
/*! /*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
TVM_DLL BuildConfig build_config();
/*!
* \brief Build a LoweredFunc given a schedule, args and binds * \brief Build a LoweredFunc given a schedule, args and binds
* \param sch The schedule to lower. * \param sch The schedule to lower.
* \param args The arguments to the function. * \param args The arguments to the function.
......
...@@ -187,7 +187,7 @@ class BuildConfig(NodeBase): ...@@ -187,7 +187,7 @@ class BuildConfig(NodeBase):
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
if self.dump_pass_ir: if self.dump_pass_ir:
BuildConfig._dump_ir.exit() BuildConfig._dump_ir.exit()
_api_internal._ExitBuildConfigScope() _api_internal._ExitBuildConfigScope(self)
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in BuildConfig._node_defaults: if name in BuildConfig._node_defaults:
......
...@@ -133,7 +133,7 @@ class Target(NodeBase): ...@@ -133,7 +133,7 @@ class Target(NodeBase):
return self return self
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
_api_internal._ExitTargetScope() _api_internal._ExitTargetScope(self)
@register_node @register_node
......
...@@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer") ...@@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
// can't use make_shared due to noexcept(false) decl in destructor, // can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314 // see https://stackoverflow.com/a/43907314
auto ctx = auto ctx = std::shared_ptr<With<ConstraintContext> >(
std::shared_ptr<ConstraintContext>(new ConstraintContext(self.get(), args[0])); new With<ConstraintContext>(self.get(), args[0]));
auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
ctx.reset(); ctx.reset();
}; };
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) { ...@@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) {
// skip rewrite simplify // skip rewrite simplify
} }
ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
void ConstraintContext::EnterWithScope() {
CHECK(exit_ == nullptr);
// entering the scope. // entering the scope.
auto f0 = analyzer->const_int_bound.EnterConstraint(constraint); auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
auto f1 = analyzer->modular_set.EnterConstraint(constraint); auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
// recovery function. // recovery function.
exit_ = [f0, f1]() { exit_ = [f0, f1]() {
if (f1 != nullptr) f1(); if (f1 != nullptr) f1();
...@@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) ...@@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
}; };
} }
void ConstraintContext::ExitWithScope() {
CHECK(exit_ != nullptr);
exit_();
}
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) { if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound; return ptr->value > lower_bound;
......
...@@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) { ...@@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) {
Expr cond = Mutate(op->condition); Expr cond = Mutate(op->condition);
Expr true_value, false_value; Expr true_value, false_value;
{ {
ConstraintContext constraint(parent_, cond); With<ConstraintContext> constraint(parent_, cond);
true_value = Mutate(op->true_value); true_value = Mutate(op->true_value);
} }
{ {
ConstraintContext constraint(parent_, Mutate(Not::make(cond))); With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->false_value); false_value = Mutate(op->false_value);
} }
if (is_zero(cond)) { if (is_zero(cond)) {
...@@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) { ...@@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) {
Expr cond = Mutate(op->args[0]); Expr cond = Mutate(op->args[0]);
Expr true_value, false_value; Expr true_value, false_value;
{ {
ConstraintContext constraint(parent_, cond); With<ConstraintContext> constraint(parent_, cond);
true_value = Mutate(op->args[1]); true_value = Mutate(op->args[1]);
} }
{ {
ConstraintContext constraint(parent_, Mutate(Not::make(cond))); With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->args[2]); false_value = Mutate(op->args[2]);
} }
if (is_zero(cond)) { if (is_zero(cond)) {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator { ...@@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator {
Expr condition = this->Mutate(op->condition); Expr condition = this->Mutate(op->condition);
Stmt then_case, else_case; Stmt then_case, else_case;
{ {
ConstraintContext ctx(&analyzer_, condition); With<ConstraintContext> ctx(&analyzer_, condition);
then_case = this->Mutate(op->then_case); then_case = this->Mutate(op->then_case);
} }
if (op->else_case.defined()) { if (op->else_case.defined()) {
ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition))); With<ConstraintContext> ctx(&analyzer_, Mutate(Not::make(condition)));
else_case = this->Mutate(op->else_case); else_case = this->Mutate(op->else_case);
} }
if (is_one(condition)) return then_case; if (is_one(condition)) return then_case;
...@@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator { ...@@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator {
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Expr condition = this->Mutate(op->condition); Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message); Expr message = this->Mutate(op->message);
ConstraintContext ctx(&analyzer_, condition); With<ConstraintContext> ctx(&analyzer_, condition);
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* Compile executable modules. * Compile executable modules.
* \file build_module.cc * \file build_module.cc
*/ */
...@@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate") ...@@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate")
TVM_REGISTER_API("_TargetFromString") TVM_REGISTER_API("_TargetFromString")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_str = args[0]; std::string target_str = args[0];
*ret = Target::Create(target_str);
*ret = Target::create(target_str);
}); });
std::vector<std::string> TargetNode::keys() const { std::vector<std::string> TargetNode::keys() const {
...@@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) { ...@@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) {
return ""; return "";
} }
Target Target::create(const std::string& target_str) { Target Target::Create(const std::string& target_str) {
if (target_str.length() == 0) { if (target_str.length() == 0) {
LOG(ERROR) << "target_str must not be empty"; LOG(ERROR) << "target_str must not be empty";
} }
...@@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) { ...@@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) {
struct TVMTargetThreadLocalEntry { struct TVMTargetThreadLocalEntry {
/*! \brief The current target context */ /*! \brief The current target context */
std::stack<tvm::Target> context_stack; std::stack<tvm::Target> context_stack;
TVMTargetThreadLocalEntry() {
}
}; };
/*! \brief Thread local store to hold the Target context stack. */ /*! \brief Thread local store to hold the Target context stack. */
typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
void Target::EnterTargetScope(const tvm::Target& target) { void Target::EnterWithScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
entry->context_stack.push(target); entry->context_stack.push(*this);
} }
void Target::ExitTargetScope() { void Target::ExitWithScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop(); entry->context_stack.pop();
} }
tvm::Target Target::current_target(bool allow_not_defined) { tvm::Target Target::Current(bool allow_not_defined) {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
if (entry->context_stack.size() > 0) { if (entry->context_stack.size() > 0) {
return entry->context_stack.top(); return entry->context_stack.top();
...@@ -574,7 +571,7 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs, ...@@ -574,7 +571,7 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
const BuildConfig& config) { const BuildConfig& config) {
Map<Target, Array<LoweredFunc>> updated_input; Map<Target, Array<LoweredFunc>> updated_input;
for (const auto& it : inputs) { for (const auto& it : inputs) {
auto target = Target::create(it.first); auto target = Target::Create(it.first);
updated_input.Set(target, it.second); updated_input.Set(target, it.second);
} }
return build(updated_input, target_host, config); return build(updated_input, target_host, config);
...@@ -589,33 +586,35 @@ runtime::Module build(const Array<LoweredFunc>& funcs, ...@@ -589,33 +586,35 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
return build(inputs, target_host, config); return build(inputs, target_host, config);
} }
BuildConfig build_config() { BuildConfig BuildConfig::Create() {
return BuildConfig(make_node<BuildConfigNode>()); return BuildConfig(make_node<BuildConfigNode>());
} }
/*! \brief Entry to hold the BuildConfig context stack. */ /*! \brief Entry to hold the BuildConfig context stack. */
struct TVMBuildConfigThreadLocalEntry { struct TVMBuildConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */ /*! \brief The default build config if the stack is empty */
tvm::BuildConfig default_config; BuildConfig default_config;
/*! \brief The current build config context */ /*! \brief The current build config context */
std::stack<tvm::BuildConfig> context_stack; std::stack<BuildConfig> context_stack;
TVMBuildConfigThreadLocalEntry() : TVMBuildConfigThreadLocalEntry() :
default_config(build_config()) { default_config(BuildConfig::Create()) {
} }
}; };
/*! \brief Thread local store to hold the BuildConfig context stack. */ /*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
void BuildConfig::EnterBuildConfigScope(const tvm::BuildConfig& build_config) { void BuildConfig::EnterWithScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.push(build_config); entry->context_stack.push(*this);
} }
void BuildConfig::ExitBuildConfigScope() { void BuildConfig::ExitWithScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop(); entry->context_stack.pop();
} }
...@@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags, ...@@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
auto node = static_cast<GenericFuncNode*>(node_.get()); auto node = static_cast<GenericFuncNode*>(node_.get());
auto target = Target::current_target(true); auto target = Target::Current(true);
PackedFunc func; PackedFunc func;
if (target.defined()) { if (target.defined()) {
...@@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig") ...@@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig")
*ret = BuildConfig::Current(); *ret = BuildConfig::Current();
}); });
class BuildConfig::Internal {
public:
static void EnterScope(BuildConfig target) {
target.EnterWithScope();
}
static void ExitScope(BuildConfig target) {
target.ExitWithScope();
}
};
TVM_REGISTER_API("_EnterBuildConfigScope") TVM_REGISTER_API("_EnterBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(BuildConfig::Internal::EnterScope);
BuildConfig target = args[0];
BuildConfig::EnterBuildConfigScope(target);
});
TVM_REGISTER_API("_ExitBuildConfigScope") TVM_REGISTER_API("_ExitBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(BuildConfig::Internal::ExitScope);
BuildConfig::ExitBuildConfigScope();
});
TVM_REGISTER_API("_BuildConfigSetAddLowerPass") TVM_REGISTER_API("_BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc") ...@@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc")
TVM_REGISTER_API("_GetCurrentTarget") TVM_REGISTER_API("_GetCurrentTarget")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
bool allow_not_defined = args[0]; bool allow_not_defined = args[0];
*ret = Target::current_target(allow_not_defined); *ret = Target::Current(allow_not_defined);
}); });
class Target::Internal {
public:
static void EnterScope(Target target) {
target.EnterWithScope();
}
static void ExitScope(Target target) {
target.ExitWithScope();
}
};
TVM_REGISTER_API("_EnterTargetScope") TVM_REGISTER_API("_EnterTargetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(Target::Internal::EnterScope);
Target target = args[0];
Target::EnterTargetScope(target);
});
TVM_REGISTER_API("_ExitTargetScope") TVM_REGISTER_API("_ExitTargetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(Target::Internal::ExitScope);
Target::ExitTargetScope();
});
} // namespace tvm } // namespace tvm
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str, ...@@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
std::string cmd = "aoc aocl.cl"; std::string cmd = "aoc aocl.cl";
// AOCL supports fp64. // AOCL supports fp64.
cmd += " -Dcl_khr_fp64"; cmd += " -Dcl_khr_fp64";
Target target = Target::create(target_str); Target target = Target::Create(target_str);
if (target->device_name != "") { if (target->device_name != "") {
cmd += " -board=" + target->device_name; cmd += " -board=" + target->device_name;
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) { ...@@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
std::string xclbin; std::string xclbin;
if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) { if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
Target target = Target::create(target_str); Target target = Target::Create(target_str);
xclbin = (*f)(kernel_info, target->device_name).operator std::string(); xclbin = (*f)(kernel_info, target->device_name).operator std::string();
} else { } else {
LOG(FATAL) << "Cannot compile Vivado HLS code."; LOG(FATAL) << "Cannot compile Vivado HLS code.";
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { ...@@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
} }
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
arith::ConstraintContext cctx(analyzer_.get(), op->condition); With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
this->VisitStmt(op->body); this->VisitStmt(op->body);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { ...@@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
} }
void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) { void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
arith::ConstraintContext cctx(analyzer_.get(), op->condition); With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
this->VisitStmt(op->body); this->VisitStmt(op->body);
} }
......
...@@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (targets.size() == 1) { if (targets.size() == 1) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
for (const auto& kv : targets) { for (const auto& kv : targets) {
TargetContext tctx(kv.second); With<Target> tctx(kv.second);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
} }
} else { } else {
...@@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode {
*/ */
Target CreateDefaultTarget(int device_type) { Target CreateDefaultTarget(int device_type) {
std::string name = runtime::DeviceName(device_type); std::string name = runtime::DeviceName(device_type);
if (name == "cpu") return Target::create("llvm"); if (name == "cpu") return Target::Create("llvm");
if (name == "gpu") return Target::create("cuda"); if (name == "gpu") return Target::Create("cuda");
return Target::create(name); return Target::Create(name);
} }
/*! /*!
* \brief Update the target and fallback device required for heterogeneous * \brief Update the target and fallback device required for heterogeneous
...@@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode {
const RelayBuildConfig& cfg, const RelayBuildConfig& cfg,
const std::unordered_map<std::string, tvm::runtime::NDArray> &params) { const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
// convert // convert
tvm_cfg_ = build_config(); tvm_cfg_ = BuildConfig::Create();
TargetsMap device_target; TargetsMap device_target;
if (targets_.size() > 1) { if (targets_.size() > 1) {
device_target = UpdateHeterogeneousInputs(targets_, cfg); device_target = UpdateHeterogeneousInputs(targets_, cfg);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_[key] = value; cache_[key] = value;
} }
// Enforce use the target. // Enforce use the target.
TargetContext target_ctx(key->target); With<Target> target_scope(key->target);
CHECK(!value->cached_func.defined()); CHECK(!value->cached_func.defined());
auto spair = CreateSchedule(key->source_func, key->target); auto spair = CreateSchedule(key->source_func, key->target);
...@@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node->funcs = (*f)( cache_node->funcs = (*f)(
spair.first, all_args, cache_node->func_name, key->source_func); spair.first, all_args, cache_node->func_name, key->source_func);
} else { } else {
tvm::BuildConfig bcfg = tvm::build_config(); tvm::BuildConfig bcfg = BuildConfig::Create();
std::unordered_map<Tensor, Buffer> binds; std::unordered_map<Tensor, Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg); cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
} }
......
...@@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// Next generate the invoke instruction. // Next generate the invoke instruction.
CHECK(func->IsPrimitive()); CHECK(func->IsPrimitive());
auto target = Target::create("llvm"); auto target = Target::Create("llvm");
auto key = CCacheKeyNode::make(func, target); auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine->Lower(key); auto cfunc = engine->Lower(key);
// TODO(jroesch): support lowered funcs for multiple targets // TODO(jroesch): support lowered funcs for multiple targets
...@@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs, ...@@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
runtime::Module mod; runtime::Module mod;
if (lowered_funcs.size() > 0) { if (lowered_funcs.size() > 0) {
// TODO(@jroesch): we need to read target from build config // TODO(@jroesch): we need to read target from build config
Target target = Target::create("llvm"); Target target = Target::Create("llvm");
if (const auto* f = runtime::Registry::Get("relay.backend.build")) { if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target); mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target);
} else { } else {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) { ...@@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) {
DLContext ctx; DLContext ctx;
ctx.device_type = kDLCPU; ctx.device_type = kDLCPU;
ctx.device_id = 0; ctx.device_id = 0;
Target target = Target::create("llvm"); Target target = Target::Create("llvm");
// use a fresh build context // use a fresh build context
// in case we are already in a build context. // in case we are already in a build context.
BuildConfigContext fresh_build_ctx(build_config()); With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return ConstantFolder(CreateInterpreter( return ConstantFolder(CreateInterpreter(
Module(nullptr), ctx, target)).Mutate(expr); Module(nullptr), ctx, target)).Mutate(expr);
......
...@@ -375,10 +375,10 @@ DLContext CPUContext() { ...@@ -375,10 +375,10 @@ DLContext CPUContext() {
} }
FInterpreter CPUInterpreter() { FInterpreter CPUInterpreter() {
Target target = Target::create("llvm"); Target target = Target::Create("llvm");
// use a fresh build context // use a fresh build context
// in case we are already in a build context. // in case we are already in a build context.
BuildConfigContext fresh_build_ctx(build_config()); With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return CreateInterpreter(Module(nullptr), CPUContext(), target); return CreateInterpreter(Module(nullptr), CPUContext(), target);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -50,14 +50,14 @@ TEST(BuildModule, Basic) { ...@@ -50,14 +50,14 @@ TEST(BuildModule, Basic) {
auto args = Array<Tensor>({ A, B, C }); auto args = Array<Tensor>({ A, B, C });
std::unordered_map<Tensor, Buffer> binds; std::unordered_map<Tensor, Buffer> binds;
auto config = build_config(); auto config = BuildConfig::Create();
auto target = target::llvm(); auto target = target::llvm();
auto lowered = lower(s, args, "func", binds, config); auto lowered = lower(s, args, "func", binds, config);
auto module = build(lowered, target, Target(), config); auto module = build(lowered, target, Target(), config);
auto mali_target = Target::create("opencl -model=Mali-T860MP4@800Mhz -device=mali"); auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali");
CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali"); CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali");
} }
TEST(BuildModule, Heterogeneous) { TEST(BuildModule, Heterogeneous) {
...@@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) { ...@@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) {
auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
auto s2 = create_schedule({elemwise_sub->op}); auto s2 = create_schedule({elemwise_sub->op});
auto config = build_config(); auto config = BuildConfig::Create();
auto args1 = Array<Tensor>({A, B, elemwise_add}); auto args1 = Array<Tensor>({A, B, elemwise_add});
auto args2 = Array<Tensor>({copy, C, elemwise_sub}); auto args2 = Array<Tensor>({copy, C, elemwise_sub});
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -75,7 +75,7 @@ TEST(Relay, BuildModule) { ...@@ -75,7 +75,7 @@ TEST(Relay, BuildModule) {
auto json_f = build_mod.GetFunction("get_graph_json", false); auto json_f = build_mod.GetFunction("get_graph_json", false);
auto mod_f = build_mod.GetFunction("get_module", false); auto mod_f = build_mod.GetFunction("get_module", false);
Map<tvm::Integer, tvm::Target> targets; Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::create("llvm"); Target llvm_tgt = Target::Create("llvm");
targets.Set(0, llvm_tgt); targets.Set(0, llvm_tgt);
build_f(func, targets, llvm_tgt); build_f(func, targets, llvm_tgt);
std::string json = json_f(); std::string json = json_f();
......
...@@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) { ...@@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) {
TVM_REGISTER_GLOBAL("topi.TEST_create_target") TVM_REGISTER_GLOBAL("topi.TEST_create_target")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tvm::Target::create(args[0]); *rv = tvm::Target::Create(args[0]);
}); });
/* Ops from broadcast.h */ /* Ops from broadcast.h */
...@@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function< ...@@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function<
*/ */
inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::current_target(false); auto target = Target::Current(false);
Array<Tensor> outs; Array<Tensor> outs;
NodeRef argNodeRef = args[0]; NodeRef argNodeRef = args[0];
if (argNodeRef->type_index() == outs->type_index()) { if (argNodeRef->type_index() == outs->type_index()) {
...@@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target, ...@@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
*/ */
inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::current_target(false); auto target = Target::Current(false);
Tensor data = args[0]; Tensor data = args[0];
Tensor weight = args[1]; Tensor weight = args[1];
Tensor bias = args[2]; Tensor bias = args[2];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment